From 98edf5f8a67f36396059d8145459efbcb9764431 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Thu, 15 Jan 2026 20:45:58 +0100 Subject: [PATCH 01/45] Add QRE code owners for PR review assignment. (#2853) --- .github/CODEOWNERS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3c5d5ded4d..6747cbc280 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -17,7 +17,9 @@ /library @DmitryVasilevsky @swernli /source/npm @billti @DmitryVasilevsky @minestarks /source/pip @billti @idavis @minestarks +/source/pip/qsharp/qre @msoeken @brad-lackey @jwhogabo /source/playground @billti @minestarks +/source/qre @msoeken @brad-lackey @jwhogabo /source/resource_estimator @billti @ivanbasov @swernli /samples @DmitryVasilevsky @minestarks @swernli /source/vscode @billti @idavis @minestarks From 413cba3b1924f004481ca36e34d94e74d0f6a679 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 27 Jan 2026 10:29:04 +0100 Subject: [PATCH 02/45] Types to define ISAs and ISA requirements (#2856) This pull request sets up the integration of QRE into the code base. In the first PR it defines classes and functions to create ISAs and ISA requirements. 
--- Cargo.lock | 9 + Cargo.toml | 1 + source/pip/Cargo.toml | 1 + source/pip/qsharp/qre/__init__.py | 35 ++ source/pip/qsharp/qre/_instruction.py | 127 +++++++ source/pip/qsharp/qre/_qre.py | 17 + source/pip/qsharp/qre/_qre.pyi | 416 ++++++++++++++++++++++ source/pip/qsharp/qre/instruction_ids.py | 91 +++++ source/pip/src/interpreter.rs | 2 + source/pip/src/lib.rs | 1 + source/pip/src/qre.rs | 337 ++++++++++++++++++ source/pip/tests/test_qre.py | 115 ++++++ source/qre/Cargo.toml | 18 + source/qre/src/isa.rs | 429 +++++++++++++++++++++++ source/qre/src/isa/tests.rs | 136 +++++++ source/qre/src/lib.rs | 8 + 16 files changed, 1743 insertions(+) create mode 100644 source/pip/qsharp/qre/__init__.py create mode 100644 source/pip/qsharp/qre/_instruction.py create mode 100644 source/pip/qsharp/qre/_qre.py create mode 100644 source/pip/qsharp/qre/_qre.pyi create mode 100644 source/pip/qsharp/qre/instruction_ids.py create mode 100644 source/pip/src/qre.rs create mode 100644 source/pip/tests/test_qre.py create mode 100644 source/qre/Cargo.toml create mode 100644 source/qre/src/isa.rs create mode 100644 source/qre/src/isa/tests.rs create mode 100644 source/qre/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 069dcabf19..bfe0d95443 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1863,6 +1863,14 @@ dependencies = [ "wgpu", ] +[[package]] +name = "qre" +version = "0.0.0" +dependencies = [ + "num-traits", + "rustc-hash", +] + [[package]] name = "qsc" version = "0.0.0" @@ -2269,6 +2277,7 @@ dependencies = [ "num-traits", "pyo3", "qdk_simulators", + "qre", "qsc", "rand 0.8.5", "rayon", diff --git a/Cargo.toml b/Cargo.toml index 23b0187950..6e01cd85de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "source/language_service", "source/simulators", "source/pip", + "source/qre", "source/resource_estimator", "source/samples_test", "source/wasm", diff --git a/source/pip/Cargo.toml b/source/pip/Cargo.toml index a5bee271a3..95db22fce7 100644 --- 
a/source/pip/Cargo.toml +++ b/source/pip/Cargo.toml @@ -17,6 +17,7 @@ num-traits = { workspace = true } qsc = { path = "../compiler/qsc" } qdk_simulators = { path = "../simulators" } resource_estimator = { path = "../resource_estimator" } +qre = { path = "../qre" } miette = { workspace = true, features = ["fancy"] } rustc-hash = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py new file mode 100644 index 0000000000..e3de51a171 --- /dev/null +++ b/source/pip/qsharp/qre/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._instruction import ( + LOGICAL, + PHYSICAL, + Encoding, + constraint, + instruction, +) +from ._qre import ( + ISA, + Constraint, + ConstraintBound, + ISARequirements, + block_linear_function, + constant_function, + linear_function, +) + +__all__ = [ + "block_linear_function", + "constant_function", + "constraint", + "instruction", + "isa_constraints", + "linear_function", + "Constraint", + "ConstraintBound", + "Encoding", + "ISA", + "ISARequirements", + "LOGICAL", + "PHYSICAL", +] diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py new file mode 100644 index 0000000000..af4782b3db --- /dev/null +++ b/source/pip/qsharp/qre/_instruction.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional, overload, cast +from enum import IntEnum + +from ._qre import ( + Instruction, + Constraint, + FloatFunction, + IntFunction, + constant_function, + ConstraintBound, +) + + +class Encoding(IntEnum): + PHYSICAL = 0 + LOGICAL = 1 + + +PHYSICAL = Encoding.PHYSICAL +LOGICAL = Encoding.LOGICAL + + +def constraint( + id: int, + encoding: Encoding = PHYSICAL, + *, + arity: Optional[int] = 1, + error_rate: Optional[ConstraintBound] = None +) -> Constraint: + """ + Creates an instruction constraint. 
+ + Args: + id (int): The instruction ID. + encoding (Encoding): The instruction encoding. PHYSICAL (0) or LOGICAL (1). + arity (Optional[int]): The instruction arity. If None, instruction is + assumed to have variable arity. Default is 1. + error_rate (Optional[ConstraintBound]): The constraint on the error rate. + + Returns: + Constraint: The instruction constraint. + """ + return Constraint(id, encoding, arity, error_rate) + + +@overload +def instruction( + id: int, + encoding: Encoding = PHYSICAL, + *, + time: int, + arity: int = 1, + space: Optional[int] = None, + length: Optional[int] = None, + error_rate: float +) -> Instruction: ... +@overload +def instruction( + id: int, + encoding: Encoding = PHYSICAL, + *, + time: int | IntFunction, + arity: None = ..., + space: Optional[IntFunction] = None, + length: Optional[IntFunction] = None, + error_rate: FloatFunction +) -> Instruction: ... +def instruction( + id: int, + encoding: Encoding = PHYSICAL, + *, + time: int | IntFunction, + arity: Optional[int] = 1, + space: Optional[int] | IntFunction = None, + length: Optional[int | IntFunction] = None, + error_rate: float | FloatFunction +) -> Instruction: + """ + Creates an instruction. + + Args: + id (int): The instruction ID. + encoding (Encoding): The instruction encoding. PHYSICAL (0) or LOGICAL (1). + time (int | IntFunction): The instruction time in ns. + arity (Optional[int]): The instruction arity. If None, instruction is + assumed to have variable arity. Default is 1. One can use variable arity + functions for time, space, length, and error_rate in this case. + space (Optional[int] | IntFunction): The instruction space in number of + physical qubits. If None, length is used. + length (Optional[int | IntFunction]): The arity including ancilla + qubits. If None, arity is used. + error_rate (float | FloatFunction): The instruction error rate. + + Returns: + Instruction: The instruction. 
+ """ + if arity is not None: + return Instruction.fixed_arity( + id, + encoding, + arity, + cast(int, time), + cast(int | None, space), + cast(int | None, length), + cast(float, error_rate), + ) + else: + if isinstance(time, int): + time = constant_function(time) + if isinstance(space, int): + space = constant_function(space) + if isinstance(length, int): + length = constant_function(length) + if isinstance(error_rate, float): + error_rate = constant_function(error_rate) + + return Instruction.variable_arity( + id, + encoding, + time, + cast(IntFunction, space), + cast(FloatFunction, error_rate), + length, + ) diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py new file mode 100644 index 0000000000..c01b87587b --- /dev/null +++ b/source/pip/qsharp/qre/_qre.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# flake8: noqa E402 + +from .._native import ( + ISA, + Constraint, + ConstraintBound, + Instruction, + ISARequirements, + FloatFunction, + IntFunction, + block_linear_function, + constant_function, + linear_function, +) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi new file mode 100644 index 0000000000..c442aea117 --- /dev/null +++ b/source/pip/qsharp/qre/_qre.pyi @@ -0,0 +1,416 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from typing import Iterator, Optional, overload + +class ISA: + @overload + def __new__(cls, *instructions: Instruction) -> ISA: ... + @overload + def __new__(cls, instructions: list[Instruction], /) -> ISA: ... + def __new__(cls, *instructions: Instruction | list[Instruction]) -> ISA: + """ + Creates an ISA from a list of instructions. + + Args: + instructions (list[Instruction] | *Instruction): The list of instructions. + """ + ... + + def satisfies(self, requirements: ISARequirements) -> bool: + """ + Checks if the ISA satisfies the given ISA requirements. + """ + ... 
+ + def __getitem__(self, id: int) -> Instruction: + """ + Gets an instruction by its ID. + + Args: + id (int): The instruction ID. + + Returns: + Instruction: The instruction. + """ + ... + + def __len__(self) -> int: + """ + Returns the number of instructions in the ISA. + + Returns: + int: The number of instructions. + """ + ... + + def __iter__(self) -> Iterator[Instruction]: + """ + Returns an iterator over the instructions. + + Note: + The order of instructions is not guaranteed. + + Returns: + Iterator[Instruction]: The instruction iterator. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the ISA. + + Note: + The order of instructions in the output is not guaranteed. + + Returns: + str: A string representation of the ISA. + """ + ... + +class ISARequirements: + @overload + def __new__(cls, *constraints: Constraint) -> ISARequirements: ... + @overload + def __new__(cls, constraints: list[Constraint], /) -> ISARequirements: ... + def __new__(cls, *constraints: Constraint | list[Constraint]) -> ISARequirements: + """ + Creates an ISA requirements specification from a list of instructions + constraints. + + Args: + constraints (list[InstructionConstraint] | *InstructionConstraint): The list of instruction + constraints. + """ + ... + +class Instruction: + @staticmethod + def fixed_arity( + id: int, + encoding: int, + arity: int, + time: int, + space: Optional[int], + length: Optional[int], + error_rate: float, + ) -> Instruction: + """ + Creates an instruction with a fixed arity. + + Note: + This function is not intended to be called directly by the user, use qre.instruction instead. + + Args: + id (int): The instruction ID. + encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. + arity (int): The instruction arity. + time (int): The instruction time in ns. + space (Optional[int]): The instruction space in number of physical + qubits. If None, length is used. 
+ length (Optional[int]): The arity including ancilla qubits. If None, + arity is used. + error_rate (float): The instruction error rate. + + Returns: + Instruction: The instruction. + """ + ... + + @staticmethod + def variable_arity( + id: int, + encoding: int, + time_fn: IntFunction, + space_fn: IntFunction, + error_rate_fn: FloatFunction, + length_fn: Optional[IntFunction], + ) -> Instruction: + """ + Creates an instruction with variable arity. + + Note: + This function is not intended to be called directly by the user, use qre.instruction instead. + + Args: + id (int): The instruction ID. + encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. + time_fn (IntFunction): The time function. + space_fn (IntFunction): The space function. + error_rate_fn (FloatFunction): The error rate function. + length_fn (Optional[IntFunction]): The length function. + If None, space_fn is used. + + Returns: + Instruction: The instruction. + """ + ... + + @property + def id(self) -> int: + """ + The instruction ID. + + Returns: + int: The instruction ID. + """ + ... + + @property + def encoding(self) -> int: + """ + The instruction encoding. 0 = Physical, 1 = Logical. + + Returns: + int: The instruction encoding. + """ + ... + + @property + def arity(self) -> Optional[int]: + """ + The instruction arity. + + Returns: + Optional[int]: The instruction arity. + """ + ... + + def space(self, arity: Optional[int] = None) -> Optional[int]: + """ + The instruction space in number of physical qubits. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[int]: The instruction space in number of physical qubits. + """ + ... + + def time(self, arity: Optional[int] = None) -> Optional[int]: + """ + The instruction time in ns. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[int]: The instruction time in ns. + """ + ... 
+ + def error_rate(self, arity: Optional[int] = None) -> Optional[float]: + """ + The instruction error rate. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[float]: The instruction error rate. + """ + ... + + def expect_space(self, arity: Optional[int] = None) -> int: + """ + The instruction space in number of physical qubits. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + int: The instruction space in number of physical qubits. + """ + ... + + def expect_time(self, arity: Optional[int] = None) -> int: + """ + The instruction time in ns. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + int: The instruction time in ns. + """ + ... + + def expect_error_rate(self, arity: Optional[int] = None) -> float: + """ + The instruction error rate. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + float: The instruction error rate. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the instruction. + + Returns: + str: A string representation of the instruction. + """ + ... + +class ConstraintBound: + """ + A bound for a constraint. + """ + + @staticmethod + def lt(value: float) -> ConstraintBound: + """ + Creates a less than constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def le(value: float) -> ConstraintBound: + """ + Creates a less equal constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def eq(value: float) -> ConstraintBound: + """ + Creates an equal constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... 
+ + @staticmethod + def gt(value: float) -> ConstraintBound: + """ + Creates a greater than constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def ge(value: float) -> ConstraintBound: + """ + Creates a greater equal constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + +class Constraint: + """ + An instruction constraint that can be used to describe ISA requirements + for ISA transformations. + """ + + def __new__( + cls, + id: int, + encoding: int, + arity: Optional[int], + error_rate: Optional[ConstraintBound], + ) -> Constraint: + """ + Note: + This function is not intended to be called directly by the user, use qre.constraint instead. + + Args: + id (int): The instruction ID. + encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. + arity (Optional[int]): The instruction arity. If None, instruction is + assumed to have variable arity. + error_rate (Optional[ConstraintBound]): The constraint on the error rate. + + Returns: + InstructionConstraint: The instruction constraint. + """ + ... + +class IntFunction: ... +class FloatFunction: ... + +@overload +def constant_function(value: int) -> IntFunction: ... +@overload +def constant_function(value: float) -> FloatFunction: ... +def constant_function( + value: int | float, +) -> IntFunction | FloatFunction: + """ + Creates a constant function. + + Args: + value (int | float): The constant value. + + Returns: + IntFunction | FloatFunction: The constant function. + """ + ... + +@overload +def linear_function(slope: int) -> IntFunction: ... +@overload +def linear_function(slope: float) -> FloatFunction: ... +def linear_function( + slope: int | float, +) -> IntFunction | FloatFunction: + """ + Creates a linear function. + + Args: + slope (int | float): The slope. + + Returns: + IntFunction | FloatFunction: The linear function. + """ + ... 
+ +@overload +def block_linear_function(block_size: int, slope: int) -> IntFunction: ... +@overload +def block_linear_function(block_size: int, slope: float) -> FloatFunction: ... +def block_linear_function( + block_size: int, slope: int | float +) -> IntFunction | FloatFunction: + """ + Creates a block linear function. + + Args: + block_size (int): The block size. + slope (int | float): The slope. + + Returns: + IntFunction | FloatFunction: The block linear function. + """ + ... diff --git a/source/pip/qsharp/qre/instruction_ids.py b/source/pip/qsharp/qre/instruction_ids.py new file mode 100644 index 0000000000..f89bcc6c5b --- /dev/null +++ b/source/pip/qsharp/qre/instruction_ids.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +################### +# Instruction IDs # +################### + +# Paulis +PAULI_I = 0x0 +PAULI_X = 0x1 +PAULI_Y = 0x2 +PAULI_Z = 0x3 + +# Clifford gates +H = H_XZ = 0x10 +H_XY = 0x11 +H_YZ = 0x12 +SQRT_X = 0x13 +SQRT_X_DAG = 0x14 +SQRT_Y = 0x15 +SQRT_Y_DAG = 0x16 +S = SQRT_Z = 0x17 +S_DAG = SQRT_Z_DAG = 0x18 +CNOT = CX = 0x19 +CY = 0x1A +CZ = 0x1B +SWAP = 0x1C + +# State preparation +PREP_X = 0x30 +PREP_Y = 0x31 +PREP_Z = 0x32 + +# Generic Cliffords +ONE_QUBIT_CLIFFORD = 0x50 +TWO_QUBIT_CLIFFORD = 0x51 +N_QUBIT_CLIFFORD = 0x52 + +# Measurements +MEAS_X = 0x100 +MEAS_Y = 0x101 +MEAS_Z = 0x102 +MEAS_RESET_X = 0x103 +MEAS_RESET_Y = 0x104 +MEAS_RESET_Z = 0x105 +MEAS_XX = 0x106 +MEAS_YY = 0x107 +MEAS_ZZ = 0x108 +MEAS_XZ = 0x109 +MEAS_XY = 0x10A +MEAS_YZ = 0x10B + +# Non-Clifford gates +SQRT_SQRT_X = 0x400 +SQRT_SQRT_X_DAG = 0x401 +SQRT_SQRT_Y = 0x402 +SQRT_SQRT_Y_DAG = 0x403 +SQRT_SQRT_Z = T = 0x404 +SQRT_SQRT_Z_DAG = T_DAG = 0x405 +CCX = 0x406 +CCY = 0x407 +CCZ = 0x408 +CSWAP = 0x409 +AND = 0x40A +AND_DAG = 0x40B +RX = 0x40C +RY = 0x40D +RZ = 0x40E +CRX = 0x40F +CRY = 0x410 +CRZ = 0x411 +RXX = 0x412 +RYY = 0x413 +RZZ = 0x414 + +# Multi-qubit Pauli measurement +MULTI_PAULI_MEAS = 0x1000 + +# 
Some generic logical instructions +LATTICE_SURGERY = 0x1100 + +# Memory/compute operations (used in compute parts of memory-compute layouts) +READ_FROM_MEMORY = 0x1200 +WRITE_TO_MEMORY = 0x1201 + +# Some special hardware physical instructions +CYCLIC_SHIFT = 0x1300 + +# Generic operation (for unified RE) +GENERIC = 0xFFFF diff --git a/source/pip/src/interpreter.rs b/source/pip/src/interpreter.rs index f135370924..99cb74ca32 100644 --- a/source/pip/src/interpreter.rs +++ b/source/pip/src/interpreter.rs @@ -27,6 +27,7 @@ use crate::{ cpu_simulators::{run_clifford, run_cpu_full_state}, gpu_full_state::{GpuContext, run_parallel_shots, try_create_gpu_adapter}, }, + qre::register_qre_submodule, }; use miette::{Diagnostic, Report}; use num_bigint::BigUint; @@ -131,6 +132,7 @@ fn _native<'a>(py: Python<'a>, m: &Bound<'a, PyModule>) -> PyResult<()> { m.add("QSharpError", py.get_type::())?; register_noisy_simulator_submodule(py, m)?; register_generic_estimator_submodule(m)?; + register_qre_submodule(m)?; // QASM interop m.add("QasmError", py.get_type::())?; m.add_function(wrap_pyfunction!(resource_estimate_qasm_program, m)?)?; diff --git a/source/pip/src/lib.rs b/source/pip/src/lib.rs index 72edafb7ca..b962d97df5 100644 --- a/source/pip/src/lib.rs +++ b/source/pip/src/lib.rs @@ -10,3 +10,4 @@ mod interop; mod interpreter; mod noisy_simulator; mod qir_simulation; +mod qre; diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs new file mode 100644 index 0000000000..93e9680e15 --- /dev/null +++ b/source/pip/src/qre.rs @@ -0,0 +1,337 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use pyo3::{IntoPyObjectExt, prelude::*, types::PyTuple}; + +pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(constant_function, m)?)?; + m.add_function(wrap_pyfunction!(linear_function, m)?)?; + m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; + Ok(()) +} + +#[allow(clippy::upper_case_acronyms)] +#[pyclass] +pub struct ISA(qre::ISA); + +#[pymethods] +impl ISA { + #[new] + #[pyo3(signature = (*instructions))] + pub fn new(instructions: &Bound<'_, PyTuple>) -> PyResult { + if instructions.len() == 1 { + let item = instructions.get_item(0)?; + if let Ok(seq) = item.cast_into::() { + let mut instrs = Vec::with_capacity(seq.len()); + for item in seq.iter() { + let instr = item.cast_into::()?; + instrs.push(instr.borrow().0.clone()); + } + return Ok(ISA(instrs.into_iter().collect())); + } + } + + instructions + .into_iter() + .map(|instr| { + let instr = instr.cast_into::()?; + Ok(instr.borrow().0.clone()) + }) + .collect::>() + .map(ISA) + } + + pub fn satisfies(&self, requirements: &ISARequirements) -> PyResult { + Ok(self.0.satisfies(&requirements.0)) + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + pub fn __getitem__(&self, id: u64) -> PyResult { + match self.0.get(&id) { + Some(instr) => Ok(Instruction(instr.clone())), + None => Err(PyErr::new::(format!( + "Instruction with id {id} not found" + ))), + } + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = ISAIterator { + iter: slf.0.clone().into_iter(), + }; + Py::new(slf.py(), iter) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct ISAIterator { + iter: std::collections::hash_map::IntoIter, +} + +#[pymethods] +impl ISAIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, 
Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(|(_, instr)| Instruction(instr)) + } +} + +#[pyclass] +pub struct ISARequirements(qre::ISARequirements); + +#[pymethods] +impl ISARequirements { + #[new] + #[pyo3(signature = (*constraints))] + pub fn new(constraints: &Bound<'_, PyTuple>) -> PyResult { + if constraints.len() == 1 { + let item = constraints.get_item(0)?; + if let Ok(seq) = item.cast::() { + let mut instrs = Vec::with_capacity(seq.len()); + for item in seq.iter() { + let instr = item.cast_into::()?; + instrs.push(instr.borrow().0.clone()); + } + return Ok(ISARequirements(instrs.into_iter().collect())); + } + } + + constraints + .into_iter() + .map(|instr| { + let instr = instr.cast_into::()?; + Ok(instr.borrow().0.clone()) + }) + .collect::>() + .map(ISARequirements) + } +} + +#[pyclass] +pub struct Instruction(qre::Instruction); + +#[pymethods] +impl Instruction { + #[staticmethod] + pub fn fixed_arity( + id: u64, + encoding: u64, + arity: u64, + time: u64, + space: Option, + length: Option, + error_rate: f64, + ) -> PyResult { + Ok(Instruction(qre::Instruction::fixed_arity( + id, + convert_encoding(encoding)?, + arity, + time, + space, + length, + error_rate, + ))) + } + + #[staticmethod] + pub fn variable_arity( + id: u64, + encoding: u64, + time_fn: &IntFunction, + space_fn: &IntFunction, + error_rate_fn: &FloatFunction, + length_fn: Option<&IntFunction>, + ) -> PyResult { + Ok(Instruction(qre::Instruction::variable_arity( + id, + convert_encoding(encoding)?, + time_fn.0.clone(), + space_fn.0.clone(), + length_fn.map(|f| f.0.clone()), + error_rate_fn.0.clone(), + ))) + } + + #[getter] + pub fn id(&self) -> u64 { + self.0.id() + } + + #[getter] + pub fn encoding(&self) -> u64 { + match self.0.encoding() { + qre::Encoding::Physical => 0, + qre::Encoding::Logical => 1, + } + } + + #[getter] + pub fn arity(&self) -> Option { + self.0.arity() + } + + #[pyo3(signature = (arity=None))] + pub fn space(&self, 
arity: Option) -> Option { + self.0.space(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn time(&self, arity: Option) -> Option { + self.0.time(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn error_rate(&self, arity: Option) -> Option { + self.0.error_rate(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn expect_space(&self, arity: Option) -> PyResult { + Ok(self.0.expect_space(arity)) + } + + #[pyo3(signature = (arity=None))] + pub fn expect_time(&self, arity: Option) -> PyResult { + Ok(self.0.expect_time(arity)) + } + + #[pyo3(signature = (arity=None))] + pub fn expect_error_rate(&self, arity: Option) -> PyResult { + Ok(self.0.expect_error_rate(arity)) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct Constraint(qre::InstructionConstraint); + +#[pymethods] +impl Constraint { + #[new] + pub fn new( + id: u64, + encoding: u64, + arity: Option, + error_rate: Option<&ConstraintBound>, + ) -> PyResult { + Ok(Constraint(qre::InstructionConstraint::new( + id, + convert_encoding(encoding)?, + arity, + error_rate.map(|error_rate| error_rate.0), + ))) + } +} + +fn convert_encoding(encoding: u64) -> PyResult { + match encoding { + 0 => Ok(qre::Encoding::Physical), + 1 => Ok(qre::Encoding::Logical), + _ => Err(PyErr::new::( + "Invalid encoding value", + )), + } +} + +#[pyclass] +pub struct ConstraintBound(qre::ConstraintBound); + +#[pymethods] +impl ConstraintBound { + #[staticmethod] + pub fn lt(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::less_than(value)) + } + + #[staticmethod] + pub fn le(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::less_equal(value)) + } + + #[staticmethod] + pub fn eq(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::equal(value)) + } + + #[staticmethod] + pub fn gt(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::greater_than(value)) + } + + #[staticmethod] + pub fn ge(value: 
f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::greater_equal(value)) + } +} + +#[pyclass] +pub struct IntFunction(qre::VariableArityFunction); + +#[pyclass] +pub struct FloatFunction(qre::VariableArityFunction); + +#[pyfunction] +pub fn constant_function<'py>(value: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(v) = value.extract::() { + IntFunction(qre::VariableArityFunction::Constant { value: v }).into_bound_py_any(value.py()) + } else if let Ok(v) = value.extract::() { + FloatFunction(qre::VariableArityFunction::Constant { value: v }) + .into_bound_py_any(value.py()) + } else { + Err(PyErr::new::( + "Value must be either an integer or a float", + )) + } +} + +#[pyfunction] +pub fn linear_function<'py>(slope: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(s) = slope.extract::() { + IntFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) + } else if let Ok(s) = slope.extract::() { + FloatFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) + } else { + Err(PyErr::new::( + "Slope must be either an integer or a float", + )) + } +} + +#[pyfunction] +pub fn block_linear_function<'py>( + block_size: u64, + slope: &Bound<'py, PyAny>, +) -> PyResult> { + if let Ok(s) = slope.extract::() { + IntFunction(qre::VariableArityFunction::block_linear(block_size, s)) + .into_bound_py_any(slope.py()) + } else if let Ok(s) = slope.extract::() { + FloatFunction(qre::VariableArityFunction::block_linear(block_size, s)) + .into_bound_py_any(slope.py()) + } else { + Err(PyErr::new::( + "Slope must be either an integer or a float", + )) + } +} diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py new file mode 100644 index 0000000000..b3449dfbff --- /dev/null +++ b/source/pip/tests/test_qre.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from dataclasses import dataclass, KW_ONLY +from typing import Generator + + +from qsharp.qre import ( + constraint, + ConstraintBound, + ISA, + ISARequirements, + instruction, + linear_function, + LOGICAL, +) +from qsharp.qre.instruction_ids import ( + T, + TWO_QUBIT_CLIFFORD, + H, + CNOT, + MEAS_Z, + GENERIC, + LATTICE_SURGERY, +) + + +# NOTE These classes will be generalized as part of the QRE API in the following +# pull requests and then moved out of the tests. + + +class Architecture: + @property + def provided_isa(self) -> ISA: + return ISA( + instruction(H, time=50, error_rate=1e-3), + instruction(CNOT, arity=2, time=50, error_rate=1e-3), + instruction(MEAS_Z, time=100, error_rate=1e-3), + instruction(T, time=40, error_rate=1e-4), + instruction(TWO_QUBIT_CLIFFORD, arity=2, time=50, error_rate=1e-3), + ) + + +@dataclass +class SurfaceCode: + _: KW_ONLY + distance: int = 7 + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(H, error_rate=ConstraintBound.lt(0.01)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), + constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), + ) + + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + crossing_prefactor: float = 0.03 + error_correction_threshold: float = 0.01 + + cnot_time = impl_isa[CNOT].expect_time() + h_time = impl_isa[H].expect_time() + meas_time = impl_isa[MEAS_Z].expect_time() + + physical_error_rate = max( + impl_isa[CNOT].expect_error_rate(), + impl_isa[H].expect_error_rate(), + impl_isa[MEAS_Z].expect_error_rate(), + ) + + space_formula = linear_function(2 * self.distance**2) + + time_value = (h_time + meas_time + cnot_time * 4) * self.distance + + error_formula = linear_function( + crossing_prefactor + * ( + (physical_error_rate / error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + yield ISA( + instruction( + GENERIC, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + 
error_rate=error_formula, + ), + instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ), + ) + + +def test_isa_from_architecture(): + arch = Architecture() + code = SurfaceCode() + + # Verify that the architecture satisfies the code requirements + assert arch.provided_isa.satisfies(SurfaceCode.required_isa()) + + # Generate logical ISAs + isas = list(code.provided_isa(arch.provided_isa)) + + # There is one ISA with two instructions + assert len(isas) == 1 + assert len(isas[0]) == 2 diff --git a/source/qre/Cargo.toml b/source/qre/Cargo.toml new file mode 100644 index 0000000000..e7e7c80783 --- /dev/null +++ b/source/qre/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "qre" + +version.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rustc-hash = { workspace = true } +num-traits = { workspace = true } + +[dev-dependencies] + +[lints] +workspace = true diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs new file mode 100644 index 0000000000..9668995e4d --- /dev/null +++ b/source/qre/src/isa.rs @@ -0,0 +1,429 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::{ + fmt::Display, + ops::{Add, Deref, Index}, +}; + +use num_traits::FromPrimitive; +use rustc_hash::FxHashMap; + +#[cfg(test)] +mod tests; + +#[derive(Default)] +pub struct ISA { + instructions: FxHashMap, +} + +impl ISA { + #[must_use] + pub fn new() -> Self { + ISA { + instructions: FxHashMap::default(), + } + } + + pub fn add_instruction(&mut self, instruction: Instruction) { + self.instructions.insert(instruction.id, instruction); + } + + #[must_use] + pub fn get(&self, id: &u64) -> Option<&Instruction> { + self.instructions.get(id) + } + + #[must_use] + pub fn satisfies(&self, requirements: &ISARequirements) -> bool { + for constraint in requirements.constraints.values() { + let Some(instruction) = self.instructions.get(&constraint.id) else { + return false; + }; + + if instruction.encoding != constraint.encoding { + return false; + } + + match &instruction.metrics { + Metrics::FixedArity { + arity, error_rate, .. + } => { + // Constraint requires variable arity for this instruction + let Some(constraint_arity) = constraint.arity else { + return false; + }; + + // Arity must match + if *arity != constraint_arity { + return false; + } + + // Error rate constraint must be satisfied + if let Some(ref bound) = constraint.error_rate_fn + && !bound.evaluate(error_rate) + { + return false; + } + } + + Metrics::VariableArity { error_rate_fn, .. 
} => { + // If an arity and error rate constraint is specified, it + // must be satisfied + if let (Some(constraint_arity), Some(ref bound)) = + (constraint.arity, constraint.error_rate_fn) + && !bound.evaluate(&error_rate_fn.evaluate(constraint_arity)) + { + return false; + } + } + } + } + true + } +} + +impl Deref for ISA { + type Target = FxHashMap; + + fn deref(&self) -> &Self::Target { + &self.instructions + } +} + +impl FromIterator for ISA { + fn from_iter>(iter: T) -> Self { + let mut isa = ISA::new(); + for instruction in iter { + isa.add_instruction(instruction); + } + isa + } +} + +impl Display for ISA { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for instruction in self.instructions.values() { + writeln!(f, "{instruction}")?; + } + Ok(()) + } +} + +impl Index for ISA { + type Output = Instruction; + + fn index(&self, index: u64) -> &Self::Output { + &self.instructions[&index] + } +} + +impl Add for ISA { + type Output = ISA; + + fn add(self, other: ISA) -> ISA { + let mut combined = self; + for instruction in other.instructions.into_values() { + combined.add_instruction(instruction); + } + combined + } +} + +#[derive(Default)] +pub struct ISARequirements { + constraints: FxHashMap, +} + +impl ISARequirements { + #[must_use] + pub fn new() -> Self { + ISARequirements { + constraints: FxHashMap::default(), + } + } + + pub fn add_constraint(&mut self, constraint: InstructionConstraint) { + self.constraints.insert(constraint.id, constraint); + } +} + +impl FromIterator for ISARequirements { + fn from_iter>(iter: T) -> Self { + let mut reqs = ISARequirements::new(); + for constraint in iter { + reqs.add_constraint(constraint); + } + reqs + } +} + +#[derive(Clone)] +pub struct Instruction { + id: u64, + encoding: Encoding, + metrics: Metrics, +} + +impl Instruction { + #[must_use] + pub fn fixed_arity( + id: u64, + encoding: Encoding, + arity: u64, + time: u64, + space: Option, + length: Option, + error_rate: f64, + ) -> Self { + 
let length = length.unwrap_or(arity); + let space = space.unwrap_or(length); + + Instruction { + id, + encoding, + metrics: Metrics::FixedArity { + arity, + length, + space, + time, + error_rate, + }, + } + } + + #[must_use] + pub fn variable_arity( + id: u64, + encoding: Encoding, + time_fn: VariableArityFunction, + space_fn: VariableArityFunction, + length_fn: Option>, + error_rate_fn: VariableArityFunction, + ) -> Self { + let length_fn = length_fn.unwrap_or_else(|| space_fn.clone()); + + Instruction { + id, + encoding, + metrics: Metrics::VariableArity { + length_fn, + space_fn, + time_fn, + error_rate_fn, + }, + } + } + + #[must_use] + pub fn id(&self) -> u64 { + self.id + } + + #[must_use] + pub fn encoding(&self) -> Encoding { + self.encoding + } + + #[must_use] + pub fn arity(&self) -> Option { + match &self.metrics { + Metrics::FixedArity { arity, .. } => Some(*arity), + Metrics::VariableArity { .. } => None, + } + } + + #[must_use] + pub fn space(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { space, .. } => Some(*space), + Metrics::VariableArity { space_fn, .. } => arity.map(|a| space_fn.evaluate(a)), + } + } + + #[must_use] + pub fn length(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { length, .. } => Some(*length), + Metrics::VariableArity { length_fn, .. } => arity.map(|a| length_fn.evaluate(a)), + } + } + + #[must_use] + pub fn time(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { time, .. } => Some(*time), + Metrics::VariableArity { time_fn, .. } => arity.map(|a| time_fn.evaluate(a)), + } + } + + #[must_use] + pub fn error_rate(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { error_rate, .. } => Some(*error_rate), + Metrics::VariableArity { error_rate_fn, .. 
} => { + arity.map(|a| error_rate_fn.evaluate(a)) + } + } + } + + #[must_use] + pub fn expect_space(&self, arity: Option) -> u64 { + self.space(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_length(&self, arity: Option) -> u64 { + self.length(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_time(&self, arity: Option) -> u64 { + self.time(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_error_rate(&self, arity: Option) -> f64 { + self.error_rate(arity) + .expect("Instruction does not support variable arity") + } +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.metrics { + Metrics::FixedArity { arity, .. } => { + write!(f, "{} |{:?}| arity: {arity}", self.id, self.encoding) + } + Metrics::VariableArity { .. } => write!(f, "{} |{:?}|", self.id, self.encoding), + } + } +} + +#[derive(Clone)] +pub struct InstructionConstraint { + id: u64, + encoding: Encoding, + arity: Option, + error_rate_fn: Option>, +} + +impl InstructionConstraint { + #[must_use] + pub fn new( + id: u64, + encoding: Encoding, + arity: Option, + error_rate_fn: Option>, + ) -> Self { + InstructionConstraint { + id, + encoding, + arity, + error_rate_fn, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Encoding { + Physical, + Logical, +} + +#[derive(Clone)] +pub enum Metrics { + FixedArity { + arity: u64, + length: u64, + space: u64, + time: u64, + error_rate: f64, + }, + VariableArity { + length_fn: VariableArityFunction, + space_fn: VariableArityFunction, + time_fn: VariableArityFunction, + error_rate_fn: VariableArityFunction, + }, +} + +#[derive(Clone)] +pub enum VariableArityFunction { + Constant { value: T }, + Linear { slope: T }, + BlockLinear { block_size: u64, slope: T }, +} + +impl + std::ops::Mul + Copy + FromPrimitive> + 
VariableArityFunction +{ + pub fn constant(value: T) -> Self { + VariableArityFunction::Constant { value } + } + + pub fn linear(slope: T) -> Self { + VariableArityFunction::Linear { slope } + } + + pub fn block_linear(block_size: u64, slope: T) -> Self { + VariableArityFunction::BlockLinear { block_size, slope } + } + + pub fn evaluate(&self, arity: u64) -> T { + match self { + VariableArityFunction::Constant { value } => *value, + VariableArityFunction::Linear { slope } => { + *slope * T::from_u64(arity).expect("Failed to convert u64 to target type") + } + VariableArityFunction::BlockLinear { block_size, slope } => { + let blocks = arity.div_ceil(*block_size); + *slope * T::from_u64(blocks).expect("Failed to convert u64 to target type") + } + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ConstraintBound { + LessThan(T), + LessEqual(T), + Equal(T), + GreaterThan(T), + GreaterEqual(T), +} + +impl ConstraintBound { + pub fn less_than(value: T) -> Self { + ConstraintBound::LessThan(value) + } + + pub fn less_equal(value: T) -> Self { + ConstraintBound::LessEqual(value) + } + + pub fn equal(value: T) -> Self { + ConstraintBound::Equal(value) + } + + pub fn greater_than(value: T) -> Self { + ConstraintBound::GreaterThan(value) + } + + pub fn greater_equal(value: T) -> Self { + ConstraintBound::GreaterEqual(value) + } + + pub fn evaluate(&self, other: &T) -> bool { + match self { + ConstraintBound::LessThan(v) => other < v, + ConstraintBound::LessEqual(v) => other <= v, + ConstraintBound::Equal(v) => other == v, + ConstraintBound::GreaterThan(v) => other > v, + ConstraintBound::GreaterEqual(v) => other >= v, + } + } +} diff --git a/source/qre/src/isa/tests.rs b/source/qre/src/isa/tests.rs new file mode 100644 index 0000000000..d71ae1b902 --- /dev/null +++ b/source/qre/src/isa/tests.rs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; + +#[test] +fn test_fixed_arity_instruction() { + let instr = Instruction::fixed_arity(1, Encoding::Physical, 2, 100, Some(10), Some(5), 0.01); + + assert_eq!(instr.id(), 1); + assert_eq!(instr.encoding(), Encoding::Physical); + assert_eq!(instr.arity(), Some(2)); + assert_eq!(instr.time(None), Some(100)); + assert_eq!(instr.space(None), Some(10)); + assert_eq!(instr.length(None), Some(5)); + assert_eq!(instr.error_rate(None), Some(0.01)); +} + +#[test] +fn test_variable_arity_instruction() { + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::constant(0.001); + + let instr = + Instruction::variable_arity(2, Encoding::Logical, time_fn, space_fn, None, error_rate_fn); + + assert_eq!(instr.id(), 2); + assert_eq!(instr.encoding(), Encoding::Logical); + assert_eq!(instr.arity(), None); + + // Check evaluation at specific arity + assert_eq!(instr.time(Some(3)), Some(30)); // 3 * 10 + assert_eq!(instr.space(Some(3)), Some(5)); + assert_eq!(instr.length(Some(3)), Some(5)); // Defaulted to space_fn + assert_eq!(instr.error_rate(Some(3)), Some(0.001)); + + // Check None arity returns None for variable metrics + assert_eq!(instr.time(None), None); +} + +#[test] +fn test_isa_satisfies() { + let mut isa = ISA::new(); + let instr1 = Instruction::fixed_arity(1, Encoding::Physical, 2, 100, None, None, 0.01); + isa.add_instruction(instr1); + + let mut reqs = ISARequirements::new(); + + // Test exact match + reqs.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(2), + Some(ConstraintBound::less_than(0.02)), + )); + assert!(isa.satisfies(&reqs)); + + // Test failing error rate + let mut reqs_fail = ISARequirements::new(); + reqs_fail.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(2), + Some(ConstraintBound::less_than(0.005)), + )); + assert!(!isa.satisfies(&reqs_fail)); + + // Test failing arity + let mut 
reqs_fail_arity = ISARequirements::new(); + reqs_fail_arity.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(3), + Some(ConstraintBound::less_than(0.02)), + )); + assert!(!isa.satisfies(&reqs_fail_arity)); + + // Test failing encoding + let mut reqs_fail_enc = ISARequirements::new(); + reqs_fail_enc.add_constraint(InstructionConstraint::new( + 1, + Encoding::Logical, + Some(2), + None, + )); + assert!(!isa.satisfies(&reqs_fail_enc)); + + // Test missing instruction + let mut reqs_missing = ISARequirements::new(); + reqs_missing.add_constraint(InstructionConstraint::new( + 99, + Encoding::Physical, + None, + None, + )); + assert!(!isa.satisfies(&reqs_missing)); +} + +#[test] +fn test_variable_arity_satisfies() { + let mut isa = ISA::new(); + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::linear(0.001); // 0.001 * arity + + let instr = Instruction::variable_arity( + 10, + Encoding::Logical, + time_fn, + space_fn, + None, + error_rate_fn, + ); + isa.add_instruction(instr); + + let mut reqs = ISARequirements::new(); + // Check for arity 5, error rate should be 0.005 + reqs.add_constraint(InstructionConstraint::new( + 10, + Encoding::Logical, + Some(5), + Some(ConstraintBound::less_than(0.01)), + )); + assert!(isa.satisfies(&reqs)); // 0.005 < 0.01 + + let mut reqs_fail = ISARequirements::new(); + // Check for arity 20, error rate should be 0.02 + reqs_fail.add_constraint(InstructionConstraint::new( + 10, + Encoding::Logical, + Some(20), + Some(ConstraintBound::less_than(0.01)), + )); + assert!(!isa.satisfies(&reqs_fail)); // 0.02 not < 0.01 +} diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs new file mode 100644 index 0000000000..00d70a2a53 --- /dev/null +++ b/source/qre/src/lib.rs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +mod isa; +pub use isa::{ + ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, + VariableArityFunction, +}; From dcc53ed27627b08e4b0495b67e1893566b1a5cc1 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Fri, 30 Jan 2026 10:32:46 +0100 Subject: [PATCH 03/45] Enumerate ISAs with transforms and queries (#2894) This implements ISA transforms that can provide new ISAs from existing ones. It also builds data structures to define ISA queries that can combine transformations and their parameters to define the ISA exploration space. --- source/pip/benchmarks/bench_qre.py | 79 ++++ source/pip/qsharp/qre/__init__.py | 5 +- source/pip/qsharp/qre/_architecture.py | 12 + source/pip/qsharp/qre/_enumeration.py | 105 +++++ source/pip/qsharp/qre/_instruction.py | 113 ++++- source/pip/qsharp/qre/_isa_enumeration.py | 354 +++++++++++++++ source/pip/qsharp/qre/_qre.pyi | 8 + source/pip/src/qre.rs | 6 +- source/pip/tests/test_qre.py | 497 +++++++++++++++++++++- source/qre/src/isa.rs | 2 +- 10 files changed, 1156 insertions(+), 25 deletions(-) create mode 100644 source/pip/benchmarks/bench_qre.py create mode 100644 source/pip/qsharp/qre/_architecture.py create mode 100644 source/pip/qsharp/qre/_enumeration.py create mode 100644 source/pip/qsharp/qre/_isa_enumeration.py diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py new file mode 100644 index 0000000000..6dd54f2af9 --- /dev/null +++ b/source/pip/benchmarks/bench_qre.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
import timeit
from dataclasses import dataclass, KW_ONLY, field
from qsharp.qre._enumeration import _enumerate_instances


def bench_enumerate_instances():
    """Time ``_enumerate_instances`` on a dataclass with a large field domain."""

    @dataclass
    class LargeDomain:
        _: KW_ONLY
        # Explicit 1000-element domain; `param2` adds the implicit
        # [True, False] domain for booleans.
        param1: int = field(default=0, metadata={"domain": range(1000)})
        param2: bool

    number = 100
    duration = timeit.timeit(
        lambda: list(_enumerate_instances(LargeDomain)),
        number=number,
    )

    print(f"Enumerating instances took {duration / number:.6f} seconds on average.")


def bench_enumerate_isas():
    """Time ISA enumeration for a hierarchical factory query."""
    import os
    import sys

    # The example models still live next to the tests; make them importable.
    # TODO: Remove this once the models in test_qre are moved to a proper module
    sys.path.append(os.path.join(os.path.dirname(__file__), "../tests"))
    import test_qre  # type: ignore

    from qsharp.qre._isa_enumeration import (
        Context,
        ISAQuery,
        ProductNode,
    )

    ctx = Context(architecture=test_qre.ExampleArchitecture())

    # Hierarchical query: a surface code combined with a logical factory that
    # is itself built from a surface code and a physical factory.
    query = ProductNode(
        sources=[
            ISAQuery(test_qre.SurfaceCode),
            ISAQuery(
                test_qre.ExampleLogicalFactory,
                source=ProductNode(
                    sources=[
                        ISAQuery(test_qre.SurfaceCode),
                        ISAQuery(test_qre.ExampleFactory),
                    ]
                ),
            ),
        ]
    )

    number = 100
    duration = timeit.timeit(
        lambda: list(query.enumerate(ctx)),
        number=number,
    )

    print(f"Enumerating ISAs took {duration / number:.6f} seconds on average.")


if __name__ == "__main__":
    bench_enumerate_instances()
    bench_enumerate_isas()
( ISA, @@ -17,19 +18,21 @@ constant_function, linear_function, ) +from ._architecture import Architecture __all__ = [ "block_linear_function", "constant_function", "constraint", "instruction", - "isa_constraints", "linear_function", + "Architecture", "Constraint", "ConstraintBound", "Encoding", "ISA", "ISARequirements", + "ISATransform", "LOGICAL", "PHYSICAL", ] diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py new file mode 100644 index 0000000000..0d95bb0a93 --- /dev/null +++ b/source/pip/qsharp/qre/_architecture.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from abc import ABC, abstractmethod + +from ._qre import ISA + + +class Architecture(ABC): + @property + @abstractmethod + def provided_isa(self) -> ISA: ... diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py new file mode 100644 index 0000000000..59eb1a9582 --- /dev/null +++ b/source/pip/qsharp/qre/_enumeration.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Generator, Type, TypeVar, Literal, get_args, get_origin +from dataclasses import MISSING +from itertools import product +from enum import Enum + + +T = TypeVar("T") + + +def _enumerate_instances(cls: Type[T], **kwargs) -> Generator[T, None, None]: + """ + Yields all instances of a dataclass given its class. + + The enumeration logic supports defining domains for fields using the `domain` + metadata key. This allows fields to specify their valid range of values for + enumeration directly in the definition. Additionally, boolean fields are + automatically enumerated with `[True, False]`. Enum fields are enumerated + with all their members, and Literal types with their defined values. + + Args: + cls (Type[T]): The dataclass type to enumerate. + **kwargs: Fixed values or domains for fields. 
If a value is a list + and the corresponding field is kw_only, it is treated as a domain + to enumerate over. + + Returns: + Generator[T, None, None]: A generator yielding instances of the dataclass. + + Raises: + ValueError: If a field cannot be enumerated (no domain found). + + Example: + + .. code-block:: python + from dataclasses import dataclass, field, KW_ONLY + @dataclass + class MyConfig: + # Not part of enumeration + name: str + _ : KW_ONLY + # Part of enumeration with implicit domain [True, False] + enable_logging: bool = field(kw_only=True) + # Explicit domain in metadata + retry_count: int = field(metadata={"domain": [1, 3, 5]}, kw_only=True) + """ + + names = [] + values = [] + fixed_kwargs = {} + + if (fields := getattr(cls, "__dataclass_fields__", None)) is None: + # There are no fields defined for this class, so just yield a single + # instance + yield cls(**kwargs) + return + + for field in fields.values(): + name = field.name + + if name in kwargs: + val = kwargs[name] + # If kw_only and list, it's a domain to enumerate + if field.kw_only and isinstance(val, list): + names.append(name) + values.append(val) + else: + # Otherwise, it's a fixed value + fixed_kwargs[name] = val + continue + + if not field.kw_only: + # We don't enumerate non-kw-only fields that aren't in kwargs + continue + + # Derived domain logic + names.append(name) + + domain = field.metadata.get("domain", None) + if domain is not None: + values.append(domain) + continue + + if field.type is bool: + values.append([True, False]) + continue + + if isinstance(field.type, type) and issubclass(field.type, Enum): + values.append(list(field.type)) + continue + + if get_origin(field.type) is Literal: + values.append(list(get_args(field.type))) + continue + + if field.default is not MISSING: + values.append([field.default]) + continue + + raise ValueError(f"Cannot enumerate field {name}.") + + for instance_values in product(*values): + yield cls(**fixed_kwargs, **dict(zip(names, 
instance_values))) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index af4782b3db..a74c97376b 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -1,16 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Optional, overload, cast +from abc import ABC, abstractmethod +from typing import Generator, Iterable, Optional, overload, cast from enum import IntEnum +from ._enumeration import _enumerate_instances +from ._isa_enumeration import ISA_ROOT, BindingNode, ISAQuery, Node from ._qre import ( - Instruction, + ISA, Constraint, + ConstraintBound, FloatFunction, + Instruction, IntFunction, + ISARequirements, constant_function, - ConstraintBound, ) @@ -28,7 +33,7 @@ def constraint( encoding: Encoding = PHYSICAL, *, arity: Optional[int] = 1, - error_rate: Optional[ConstraintBound] = None + error_rate: Optional[ConstraintBound] = None, ) -> Constraint: """ Creates an instruction constraint. @@ -55,7 +60,7 @@ def instruction( arity: int = 1, space: Optional[int] = None, length: Optional[int] = None, - error_rate: float + error_rate: float, ) -> Instruction: ... @overload def instruction( @@ -66,7 +71,7 @@ def instruction( arity: None = ..., space: Optional[IntFunction] = None, length: Optional[IntFunction] = None, - error_rate: FloatFunction + error_rate: FloatFunction, ) -> Instruction: ... def instruction( id: int, @@ -76,7 +81,7 @@ def instruction( arity: Optional[int] = 1, space: Optional[int] | IntFunction = None, length: Optional[int | IntFunction] = None, - error_rate: float | FloatFunction + error_rate: float | FloatFunction, ) -> Instruction: """ Creates an instruction. @@ -125,3 +130,97 @@ def instruction( cast(FloatFunction, error_rate), length, ) + + +class ISATransform(ABC): + """ + Abstract base class for transformations between ISAs (e.g., QEC schemes). 
+ + An ISA transform defines a mapping from a required input ISA (e.g., + architecture constraints) to a provided output ISA (logical instructions). + It supports enumeration of configuration parameters. + """ + + @staticmethod + @abstractmethod + def required_isa() -> ISARequirements: + """ + Returns the requirements that an implementation ISA must satisfy. + + Returns: + ISARequirements: The requirements for the underlying ISA. + """ + ... + + @abstractmethod + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + """ + Yields ISAs provided by this transform given an implementation ISA. + + Args: + impl_isa (ISA): The implementation ISA that satisfies requirements. + + Yields: + ISA: A provided logical ISA. + """ + ... + + @classmethod + def enumerate_isas( + cls, + impl_isa: ISA | Iterable[ISA], + **kwargs, + ) -> Generator[ISA, None, None]: + """ + Enumerates all valid ISAs for this transform given implementation ISAs. + + This method iterates over all instances of the transform class (enumerating + hypterparameters) and filters implementation ISAs against requirements. + + Args: + impl_isa (ISA | Iterable[ISA]): One or more implementation ISAs. + **kwargs: Arguments passed to parameter enumeration. + + Yields: + ISA: Valid provided ISAs. + """ + isas = [impl_isa] if isinstance(impl_isa, ISA) else impl_isa + for isa in isas: + if not isa.satisfies(cls.required_isa()): + continue + + for component in _enumerate_instances(cls, **kwargs): + yield from component.provided_isa(isa) + + @classmethod + def q(cls, *, source: Node | None = None, **kwargs) -> ISAQuery: + """ + Creates an ISAQuery node for this transform. + + Args: + source (Node | None): The source node providing implementation ISAs. + Defaults to ISA_ROOT. + **kwargs: Additional arguments for parameter enumeration. + + Returns: + ISAQuery: An enumeration node representing this transform. 
+ """ + return ISAQuery( + cls, source=source if source is not None else ISA_ROOT, kwargs=kwargs + ) + + @classmethod + def bind(cls, name: str, node: Node) -> BindingNode: + """ + Creates a BindingNode for this transform. + + This is a convenience method equivalent to `cls.q().bind(name, node)`. + + Args: + name (str): The name to bind the transform's output to. + node (Node): The child node that can reference this binding. + + Returns: + BindingNode: A binding node enclosing this transform. + """ + return cls.q().bind(name, node) diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py new file mode 100644 index 0000000000..54908aa9a6 --- /dev/null +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -0,0 +1,354 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import functools +import itertools +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Generator + +from ._architecture import Architecture +from ._qre import ISA + + +class Node(ABC): + """ + Abstract base class for all nodes in the ISA enumeration tree. + + Enumeration nodes define the structure of the search space for ISAs starting + from architectures and mofied by ISA transforms such as error correction + schemes. They can be composed using operators like `+` (sum) and `*` + (product) to build complex enumeration strategies. + """ + + @abstractmethod + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields all ISA instances represented by this enumeration node. + + Args: + ctx (Context): The enumeration context containing shared state, + e.g., access to the underlying architecture. + + Yields: + ISA: A possible ISA that can be generated from this node. + """ + pass + + def __add__(self, other: Node) -> SumNode: + """ + Performs a union of two enumeration nodes. 
+ + Enumerating the sum node yields all ISAs from this node, followed by all + ISAs from the other node. Duplicate ISAs may be produced if both nodes + yield the same ISA. + + Args: + other (Node): The other enumeration node. + + Returns: + SumNode: A node representing the union of both enumerations. + + Example: + + The following enumerates ISAs from both SurfaceCode and ColorCode: + + .. code-block:: python + for isa in SurfaceCode.q() + ColorCode.q(): + ... + """ + if isinstance(self, SumNode) and isinstance(other, SumNode): + sources = self.sources + other.sources + return SumNode(sources) + elif isinstance(self, SumNode): + sources = self.sources + [other] + return SumNode(sources) + elif isinstance(other, SumNode): + sources = [self] + other.sources + return SumNode(sources) + else: + return SumNode([self, other]) + + def __mul__(self, other: Node) -> ProductNode: + """ + Performs the cross product of two enumeration nodes. + + Enumerating the product node yields ISAs resulting from the Cartesian + product of ISAs from both nodes. The ISAs are combined using + concatenation (logical union). This means that instructions in the + other enumeration node with the same ID as an instruction in this + enumeration node will overwrite the instruction from this node. + + Args: + other (Node): The other enumeration node. + + Returns: + ProductNode: A node representing the product of both enumerations. + + Example: + + The following enumerates ISAs formed by combining ISAs from a + surface code and a factory: + + .. code-block:: python + + for isa in SurfaceCode.q() * Factory.q(): + ... 
+ """ + if isinstance(self, ProductNode) and isinstance(other, ProductNode): + sources = self.sources + other.sources + return ProductNode(sources) + elif isinstance(self, ProductNode): + sources = self.sources + [other] + return ProductNode(sources) + elif isinstance(other, ProductNode): + sources = [self] + other.sources + return ProductNode(sources) + else: + return ProductNode([self, other]) + + def bind(self, name: str, node: Node) -> "BindingNode": + """Create a BindingNode with this node as the component. + + Args: + name: The name to bind the component to. + node: The child enumeration node that may contain ISARefNodes. + + Returns: + A BindingNode with self as the component. + + Example: + + .. code-block:: python + ExampleErrorCorrection.q().bind("c", ISARefNode("c") * ISARefNode("c")) + """ + return BindingNode(name=name, component=self, node=node) + + +@dataclass +class Context: + """ + Context passed through enumeration, holding shared state. + + Attributes: + architecture: The base architecture for enumeration. + """ + + architecture: Architecture + _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) + + @property + def root_isa(self) -> ISA: + """The architecture's provided ISA.""" + return self.architecture.provided_isa + + def _with_binding(self, name: str, isa: ISA) -> "Context": + """Return a new context with an additional binding (internal use).""" + new_bindings = {**self._bindings, name: isa} + return Context(self.architecture, new_bindings) + + +@dataclass +class RootNode(Node): + """ + Represents the architecture's base ISA. + Reads from the context instead of holding a reference. + """ + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields the architecture ISA from the context. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: The architecture's provided ISA, called root. 
+ """ + yield ctx.root_isa + + +# Singleton instance for convenience +ISA_ROOT = RootNode() + + +@dataclass +class ISAQuery(Node): + """ + Query node that enumerates ISAs based on a component type and source. + + This node takes a component type (which must have an `enumerate_isas` class + method) and a source node. It enumerates the source node to get base ISAs, + and then calls `enumerate_isas` on the component type for each base ISA + to generate derived ISAs. + + Attributes: + component: The component type to query (e.g., a QEC code class). + source: The source node providing input ISAs (default: ISA_ROOT). + kwargs: Additional keyword arguments passed to `enumerate_isas`. + """ + + component: type + source: Node = field(default_factory=lambda: ISA_ROOT) + kwargs: dict = field(default_factory=dict) + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields ISAs generated by the component from source ISAs. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: A generated ISA instance. + """ + for isa in self.source.enumerate(ctx): + yield from self.component.enumerate_isas(isa, **self.kwargs) + + +@dataclass +class ProductNode(Node): + """ + Node representing the Cartesian product of multiple source nodes. + + Attributes: + sources: A list of source nodes to combine. + """ + + sources: list[Node] + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields ISAs formed by combining ISAs from all source nodes. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: A combined ISA instance. + """ + source_generators = [source.enumerate(ctx) for source in self.sources] + yield from ( + functools.reduce(lambda a, b: a + b, isa_tuple) + for isa_tuple in itertools.product(*source_generators) + ) + + +@dataclass +class SumNode(Node): + """ + Node representing the union of multiple source nodes. + + Attributes: + sources: A list of source nodes to enumerate sequentially. 
+ """ + + sources: list[Node] + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields ISAs from each source node in sequence. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: An ISA instance from one of the sources. + """ + for source in self.sources: + yield from source.enumerate(ctx) + + +@dataclass +class ISARefNode(Node): + """ + A reference to a bound ISA in the enumeration context. + + This node looks up the binding from the context and yields the bound ISA. + + Args: + name: The name of the bound ISA to reference. + """ + + name: str + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Yields the bound ISA from the context. + + Args: + ctx (Context): The enumeration context containing bindings. + + Yields: + ISA: The bound ISA. + + Raises: + ValueError: If the name is not bound in the context. + """ + if self.name not in ctx._bindings: + raise ValueError(f"Undefined component reference: '{self.name}'") + yield ctx._bindings[self.name] + + +@dataclass +class BindingNode(Node): + """ + Enumeration node that binds a component to a name. + + This node enables the as_/ref pattern where multiple positions in the + enumeration tree share the same component instance. The bound component + is enumerated once, and its value is shared across all ISARefNodes with + the same name via the context. + + For multiple bindings, nest BindingNode instances. + + Args: + name: The name to bind the component to. + component: An EnumerationNode (e.g., ISAQuery) that produces the bound ISAs. + node: The child enumeration node that may contain ISARefNodes. + + Example: + + .. 
code-block:: python + ctx = EnumerationContext(architecture=arch) + + # Bind a code and reference it multiple times + BindingNode( + name="c", + component=ExampleErrorCorrection.q(), + node=ISARefNode("c") * ISARefNode("c"), + ).enumerate(ctx) + + # Multiple bindings via nesting + BindingNode( + name="c", + component=ExampleErrorCorrection.q(), + node=BindingNode( + name="f", + component=ExampleFactory.q(source=ISARefNode("c")), + node=ISARefNode("c") * ISARefNode("f"), + ), + ).enumerate(ctx) + """ + + name: str + component: Node + node: Node + + def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + """ + Enumerates child nodes with the bound component in context. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: An ISA instance from the child node. + """ + # Enumerate all ISAs from the component node + for isa in self.component.enumerate(ctx): + # Add binding to context and enumerate child node + new_ctx = ctx._with_binding(self.name, isa) + yield from self.node.enumerate(new_ctx) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index c442aea117..01d999b49e 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -18,6 +18,14 @@ class ISA: """ ... + def __add__(self, other: ISA) -> ISA: + """ + Concatenates two ISAs (logical union). Instructions in the second + operand overwrite instructions in the first operand if they have the + same ID. + """ + ... + def satisfies(self, requirements: ISARequirements) -> bool: """ Checks if the ISA satisfies the given ISA requirements. 
diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 93e9680e15..fd8e80a5cd 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -48,6 +48,10 @@ impl ISA { .map(ISA) } + pub fn __add__(&self, other: &ISA) -> PyResult { + Ok(ISA(self.0.clone() + other.0.clone())) + } + pub fn satisfies(&self, requirements: &ISARequirements) -> PyResult { Ok(self.0.satisfies(&requirements.0)) } @@ -68,7 +72,7 @@ impl ISA { #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = ISAIterator { - iter: slf.0.clone().into_iter(), + iter: (*slf.0).clone().into_iter(), }; Py::new(slf.py(), iter) } diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index b3449dfbff..90430f5167 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -1,50 +1,61 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass, KW_ONLY +from dataclasses import KW_ONLY, dataclass, field +from enum import Enum from typing import Generator - from qsharp.qre import ( - constraint, - ConstraintBound, ISA, + LOGICAL, + Architecture, + ConstraintBound, ISARequirements, + ISATransform, + constraint, instruction, linear_function, - LOGICAL, +) +from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre._isa_enumeration import ( + BindingNode, + Context, + ISAQuery, + ISARefNode, + ProductNode, + SumNode, ) from qsharp.qre.instruction_ids import ( - T, - TWO_QUBIT_CLIFFORD, - H, CNOT, - MEAS_Z, GENERIC, LATTICE_SURGERY, + MEAS_Z, + TWO_QUBIT_CLIFFORD, + H, + T, ) - # NOTE These classes will be generalized as part of the QRE API in the following # pull requests and then moved out of the tests. 
-class Architecture: +class ExampleArchitecture(Architecture): @property def provided_isa(self) -> ISA: return ISA( instruction(H, time=50, error_rate=1e-3), instruction(CNOT, arity=2, time=50, error_rate=1e-3), instruction(MEAS_Z, time=100, error_rate=1e-3), - instruction(T, time=40, error_rate=1e-4), instruction(TWO_QUBIT_CLIFFORD, arity=2, time=50, error_rate=1e-3), + instruction(GENERIC, time=50, error_rate=1e-4), + instruction(T, time=50, error_rate=1e-4), ) @dataclass -class SurfaceCode: +class SurfaceCode(ISATransform): _: KW_ONLY - distance: int = 7 + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) @staticmethod def required_isa() -> ISARequirements: @@ -100,8 +111,43 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: ) +@dataclass +class ExampleFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(T), + ) + + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + yield ISA( + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ) + + +@dataclass +class ExampleLogicalFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(GENERIC, encoding=LOGICAL), + constraint(T, encoding=LOGICAL), + ) + + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + yield ISA( + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), + ) + + def test_isa_from_architecture(): - arch = Architecture() + arch = ExampleArchitecture() code = SurfaceCode() # Verify that the architecture satisfies the code requirements @@ -113,3 +159,424 @@ def test_isa_from_architecture(): # There is one ISA with two instructions assert len(isas) == 1 assert len(isas[0]) == 2 + + +def test_enumerate_instances(): + instances 
= list(_enumerate_instances(SurfaceCode)) + + # There are 12 instances with distances from 3 to 25 + assert len(instances) == 12 + expected_distances = list(range(3, 26, 2)) + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with specific distances + instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) + assert len(instances) == 3 + expected_distances = [3, 5, 7] + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with fixed distance + instances = list(_enumerate_instances(SurfaceCode, distance=9)) + assert len(instances) == 1 + assert instances[0].distance == 9 + + +def test_enumerate_instances_bool(): + @dataclass + class BoolConfig: + _: KW_ONLY + flag: bool + + instances = list(_enumerate_instances(BoolConfig)) + assert len(instances) == 2 + assert instances[0].flag is True + assert instances[1].flag is False + + +def test_enumerate_instances_enum(): + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + @dataclass + class EnumConfig: + _: KW_ONLY + color: Color + + instances = list(_enumerate_instances(EnumConfig)) + assert len(instances) == 3 + assert instances[0].color == Color.RED + assert instances[1].color == Color.GREEN + assert instances[2].color == Color.BLUE + + +def test_enumerate_instances_failure(): + import pytest + + @dataclass + class InvalidConfig: + _: KW_ONLY + # This field has no domain, is not bool/enum, and has no default + value: int + + with pytest.raises(ValueError, match="Cannot enumerate field value"): + list(_enumerate_instances(InvalidConfig)) + + +def test_enumerate_instances_single(): + @dataclass + class SingleConfig: + value: int = 42 + + instances = list(_enumerate_instances(SingleConfig)) + assert len(instances) == 1 + assert instances[0].value == 42 + + +def test_enumerate_instances_literal(): + from typing import Literal + + @dataclass + class 
LiteralConfig:
+        _: KW_ONLY
+        mode: Literal["fast", "slow"]
+
+    instances = list(_enumerate_instances(LiteralConfig))
+    assert len(instances) == 2
+    assert instances[0].mode == "fast"
+    assert instances[1].mode == "slow"
+
+
+def test_enumerate_isas():
+    ctx = Context(architecture=ExampleArchitecture())
+
+    # This will enumerate the 12 ISAs for the error correction code
+    count = sum(1 for _ in ISAQuery(SurfaceCode).enumerate(ctx))
+    assert count == 12
+
+    # This will enumerate the 2 ISAs for the error correction code when
+    # restricting the domain
+    count = sum(
+        1 for _ in ISAQuery(SurfaceCode, kwargs={"distance": [3, 5]}).enumerate(ctx)
+    )
+    assert count == 2
+
+    # This will enumerate the 3 ISAs for the factory
+    count = sum(1 for _ in ISAQuery(ExampleFactory).enumerate(ctx))
+    assert count == 3
+
+    # This will enumerate 36 ISAs for all products between the 12 error
+    # correction code ISAs and the 3 factory ISAs
+    count = sum(
+        1
+        for _ in ProductNode(
+            sources=[
+                ISAQuery(SurfaceCode),
+                ISAQuery(ExampleFactory),
+            ]
+        ).enumerate(ctx)
+    )
+    assert count == 36
+
+    # When providing a list, components are chained (OR operation). This
+    # enumerates ISAs from first factory instance OR second factory instance
+    count = sum(
+        1
+        for _ in ProductNode(
+            sources=[
+                ISAQuery(SurfaceCode),
+                SumNode(
+                    sources=[
+                        ISAQuery(ExampleFactory),
+                        ISAQuery(ExampleFactory),
+                    ]
+                ),
+            ]
+        ).enumerate(ctx)
+    )
+    assert count == 72
+
+    # When providing separate arguments, components are combined via product
+    # (AND). 
This enumerates ISAs from first factory instance AND second + # factory instance + count = sum( + 1 + for _ in ProductNode( + sources=[ + ISAQuery(SurfaceCode), + ISAQuery(ExampleFactory), + ISAQuery(ExampleFactory), + ] + ).enumerate(ctx) + ) + assert count == 108 + + # Hierarchical factory using from_components: the component receives ISAs + # from the product of other components as its source + count = sum( + 1 + for _ in ProductNode( + sources=[ + ISAQuery(SurfaceCode), + ISAQuery( + ExampleLogicalFactory, + source=ProductNode( + sources=[ + ISAQuery(SurfaceCode), + ISAQuery(ExampleFactory), + ] + ), + ), + ] + ).enumerate(ctx) + ) + assert count == 1296 + + +def test_binding_node(): + """Test BindingNode with ISARefNode for component bindings""" + ctx = Context(architecture=ExampleArchitecture()) + + # Test basic binding: same code used twice + # Without binding: 12 codes × 12 codes = 144 combinations + count_without = sum( + 1 + for _ in ProductNode( + sources=[ + ISAQuery(SurfaceCode), + ISAQuery(SurfaceCode), + ] + ).enumerate(ctx) + ) + assert count_without == 144 + + # With binding: 12 codes (same instance used twice) + count_with = sum( + 1 + for _ in BindingNode( + name="c", + component=ISAQuery(SurfaceCode), + node=ProductNode( + sources=[ISARefNode("c"), ISARefNode("c")], + ), + ).enumerate(ctx) + ) + assert count_with == 12 + + # Verify the binding works: with binding, both should use same params + for isa in BindingNode( + name="c", + component=ISAQuery(SurfaceCode), + node=ProductNode( + sources=[ISARefNode("c"), ISARefNode("c")], + ), + ).enumerate(ctx): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + # Should have 2 logical gates (GENERIC and LATTICE_SURGERY) + assert len(logical_gates) == 2 + + # Test binding with factories (nested bindings) + count_without = sum( + 1 + for _ in ProductNode( + sources=[ + ISAQuery(SurfaceCode), + ISAQuery(ExampleFactory), + ISAQuery(SurfaceCode), + ISAQuery(ExampleFactory), + ] + ).enumerate(ctx) + 
)
+    assert count_without == 1296  # 12 * 3 * 12 * 3
+
+    count_with = sum(
+        1
+        for _ in BindingNode(
+            name="c",
+            component=ISAQuery(SurfaceCode),
+            node=BindingNode(
+                name="f",
+                component=ISAQuery(ExampleFactory),
+                node=ProductNode(
+                    sources=[
+                        ISARefNode("c"),
+                        ISARefNode("f"),
+                        ISARefNode("c"),
+                        ISARefNode("f"),
+                    ],
+                ),
+            ),
+        ).enumerate(ctx)
+    )
+    assert count_with == 36  # 12 * 3
+
+    # Test binding with from_components equivalent (hierarchical)
+    # Without binding: 12 outer codes × (12 inner codes × 3 factories × 3 levels)
+    count_without = sum(
+        1
+        for _ in ProductNode(
+            sources=[
+                ISAQuery(SurfaceCode),
+                ISAQuery(
+                    ExampleLogicalFactory,
+                    source=ProductNode(
+                        sources=[
+                            ISAQuery(SurfaceCode),
+                            ISAQuery(ExampleFactory),
+                        ]
+                    ),
+                ),
+            ]
+        ).enumerate(ctx)
+    )
+    assert count_without == 1296  # 12 * 12 * 3 * 3
+
+    # With binding: 12 codes (same used twice) × 3 factories × 3 levels
+    count_with = sum(
+        1
+        for _ in BindingNode(
+            name="c",
+            component=ISAQuery(SurfaceCode),
+            node=ProductNode(
+                sources=[
+                    ISARefNode("c"),
+                    ISAQuery(
+                        ExampleLogicalFactory,
+                        source=ProductNode(
+                            sources=[
+                                ISARefNode("c"),
+                                ISAQuery(ExampleFactory),
+                            ]
+                        ),
+                    ),
+                ]
+            ),
+        ).enumerate(ctx)
+    )
+    assert count_with == 108  # 12 * 3 * 3
+
+    # Test binding with kwargs
+    count_with_kwargs = sum(
+        1
+        for _ in BindingNode(
+            name="c",
+            component=ISAQuery(SurfaceCode, kwargs={"distance": 5}),
+            node=ProductNode(
+                sources=[ISARefNode("c"), ISARefNode("c")],
+            ),
+        ).enumerate(ctx)
+    )
+    assert count_with_kwargs == 1  # Only distance=5
+
+    # Verify kwargs are applied
+    for isa in BindingNode(
+        name="c",
+        component=ISAQuery(SurfaceCode, kwargs={"distance": 5}),
+        node=ProductNode(
+            sources=[ISARefNode("c"), ISARefNode("c")],
+        ),
+    ).enumerate(ctx):
+        logical_gates = [g for g in isa if g.encoding == LOGICAL]
+        assert all(g.space(1) == 50 for g in logical_gates)
+
+    # Test multiple independent bindings (nested)
+    count = sum(
+        1
+        for _ in BindingNode(
+            
name="c1",
+            component=ISAQuery(SurfaceCode),
+            node=BindingNode(
+                name="c2",
+                component=ISAQuery(ExampleFactory),
+                node=ProductNode(
+                    sources=[
+                        ISARefNode("c1"),
+                        ISARefNode("c1"),
+                        ISARefNode("c2"),
+                        ISARefNode("c2"),
+                    ],
+                ),
+            ),
+        ).enumerate(ctx)
+    )
+    # 12 codes for c1 × 3 factories for c2
+    assert count == 36
+
+
+def test_binding_node_errors():
+    """Test error handling for BindingNode"""
+    ctx = Context(architecture=ExampleArchitecture())
+
+    # Test ISARefNode enumerate with undefined binding raises ValueError
+    try:
+        list(ISARefNode("test").enumerate(ctx))
+        assert False, "Should have raised ValueError"
+    except ValueError as e:
+        assert "Undefined component reference: 'test'" in str(e)
+
+
+def test_product_isa_enumeration_nodes():
+    terminal = ISAQuery(SurfaceCode)
+    query = terminal * terminal
+
+    # Multiplication should create ProductNode
+    assert isinstance(query, ProductNode)
+    assert len(query.sources) == 2
+    for source in query.sources:
+        assert isinstance(source, ISAQuery)
+
+    # Multiplying again should extend the sources
+    query = query * terminal
+    assert isinstance(query, ProductNode)
+    assert len(query.sources) == 3
+    for source in query.sources:
+        assert isinstance(source, ISAQuery)
+
+    # Also from the other side
+    query = terminal * query
+    assert isinstance(query, ProductNode)
+    assert len(query.sources) == 4
+    for source in query.sources:
+        assert isinstance(source, ISAQuery)
+
+    # Also for two ProductNodes
+    query = query * query
+    assert isinstance(query, ProductNode)
+    assert len(query.sources) == 8
+    for source in query.sources:
+        assert isinstance(source, ISAQuery)
+
+
+def test_sum_isa_enumeration_nodes():
+    terminal = ISAQuery(SurfaceCode)
+    query = terminal + terminal
+
+    # Addition should create SumNode
+    assert isinstance(query, SumNode)
+    assert len(query.sources) == 2
+    for source in query.sources:
+        assert isinstance(source, ISAQuery)
+
+    # Adding again should extend the sources
+    query = query + 
terminal + assert isinstance(query, SumNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, ISAQuery) + + # Also from the other side + query = terminal + query + assert isinstance(query, SumNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, ISAQuery) + + # Also for two SumNodes + query = query + query + assert isinstance(query, SumNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, ISAQuery) diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 9668995e4d..c608b91084 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap; #[cfg(test)] mod tests; -#[derive(Default)] +#[derive(Default, Clone)] pub struct ISA { instructions: FxHashMap, } From 0fae825169e3292a8dac3c3c8e8aa63abe5f9177 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Mon, 2 Feb 2026 10:59:48 -0800 Subject: [PATCH 04/45] Initial commit; hypergraph module added (#2909) Co-authored-by: Mathias Soeken --- .../pip/qsharp/magnets/geometry/__init__.py | 13 ++ .../pip/qsharp/magnets/geometry/hypergraph.py | 119 +++++++++++++ source/pip/tests/magnets/__init__.py | 4 + source/pip/tests/magnets/test_hypergraph.py | 160 ++++++++++++++++++ 4 files changed, 296 insertions(+) create mode 100644 source/pip/qsharp/magnets/geometry/__init__.py create mode 100644 source/pip/qsharp/magnets/geometry/hypergraph.py create mode 100644 source/pip/tests/magnets/__init__.py create mode 100755 source/pip/tests/magnets/test_hypergraph.py diff --git a/source/pip/qsharp/magnets/geometry/__init__.py b/source/pip/qsharp/magnets/geometry/__init__.py new file mode 100644 index 0000000000..649b2a37b2 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Geometry module for representing quantum system topologies. 
+ +This module provides hypergraph data structures for representing the +geometric structure of quantum systems, including lattice topologies +and interaction graphs. +""" + +from .hypergraph import Hyperedge, Hypergraph + +__all__ = ["Hyperedge", "Hypergraph"] diff --git a/source/pip/qsharp/magnets/geometry/hypergraph.py b/source/pip/qsharp/magnets/geometry/hypergraph.py new file mode 100644 index 0000000000..fc5cd38447 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/hypergraph.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Hypergraph data structures for representing quantum system geometries. + +This module provides classes for representing hypergraphs, which generalize +graphs by allowing edges (hyperedges) to connect any number of vertices. +Hypergraphs are useful for representing interaction terms in quantum +Hamiltonians, where multi-body interactions can involve more than two sites. +""" + +from typing import Iterator + + +class Hyperedge: + """A hyperedge connecting one or more vertices in a hypergraph. + + A hyperedge generalizes the concept of an edge in a graph. While a + traditional edge connects exactly two vertices, a hyperedge can connect + any number of vertices. This is useful for representing: + - Single-site terms (self-loops): 1 vertex + - Two-body interactions: 2 vertices + - Multi-body interactions: 3+ vertices + Each hyperedge is defined by a set of unique vertex indices, which are + stored in sorted order for consistency. + + Attributes: + vertices: Sorted list of vertex indices connected by this hyperedge. + + Example: + + .. code-block:: python + >>> edge = Hyperedge([2, 0, 1]) + >>> edge.vertices + [0, 1, 2] + """ + + def __init__(self, vertices: list[int]) -> None: + """Initialize a hyperedge with the given vertices. + + Args: + vertices: List of vertex indices. Will be sorted internally. 
+ """ + self.vertices: list[int] = sorted(set(vertices)) + + def __repr__(self) -> str: + return f"Hyperedge({self.vertices})" + + +class Hypergraph: + """A hypergraph consisting of vertices connected by hyperedges. + + A hypergraph is a generalization of a graph where edges (hyperedges) can + connect any number of vertices. This class serves as the base class for + various lattice geometries used in quantum simulations. + + Attributes: + _edges: List of hyperedges in the order they were added. + _vertex_set: Set of all unique vertex indices in the hypergraph. + _edge_list: Set of hyperedges for efficient membership testing. + + Example: + + .. code-block:: python + >>> edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + >>> graph = Hypergraph(edges) + >>> graph.nvertices() + 3 + >>> graph.nedges() + 3 + """ + + def __init__(self, edges: list[Hyperedge]) -> None: + """Initialize a hypergraph with the given edges. + + Args: + edges: List of hyperedges defining the hypergraph structure. + """ + self._edges = edges + self._vertex_set = set() + self._edge_list = set(edges) + for edge in edges: + self._vertex_set.update(edge.vertices) + + @property + def nedges(self) -> int: + """Return the number of hyperedges in the hypergraph.""" + return len(self._edges) + + @property + def nvertices(self) -> int: + """Return the number of vertices in the hypergraph.""" + return len(self._vertex_set) + + def vertices(self) -> Iterator[int]: + """Return an iterator over vertices in sorted order. + + Returns: + Iterator yielding vertex indices in ascending order. + """ + return iter(sorted(self._vertex_set)) + + def edges(self, part: int = 0) -> Iterator[Hyperedge]: + """Return an iterator over hyperedges in the hypergraph. + + Args: + part: Partition index (reserved for subclass implementations + that support edge partitioning for parallel updates). + + Returns: + Iterator over all hyperedges in the hypergraph. 
+ """ + return iter(self._edge_list) + + def __str__(self) -> str: + return f"Hypergraph with {self.nvertices()} vertices and {self.nedges()} edges." + + def __repr__(self) -> str: + return f"Hypergraph({list(self._edges)})" diff --git a/source/pip/tests/magnets/__init__.py b/source/pip/tests/magnets/__init__.py new file mode 100644 index 0000000000..686737dba3 --- /dev/null +++ b/source/pip/tests/magnets/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for the magnets library.""" diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py new file mode 100755 index 0000000000..79f071f47b --- /dev/null +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for hypergraph data structures.""" + +from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph + + +# Hyperedge tests + + +def test_hyperedge_init_basic(): + """Test basic Hyperedge initialization.""" + edge = Hyperedge([0, 1]) + assert edge.vertices == [0, 1] + + +def test_hyperedge_vertices_sorted(): + """Test that vertices are automatically sorted.""" + edge = Hyperedge([3, 1, 2]) + assert edge.vertices == [1, 2, 3] + + +def test_hyperedge_single_vertex(): + """Test hyperedge with single vertex (self-loop).""" + edge = Hyperedge([5]) + assert edge.vertices == [5] + assert len(edge.vertices) == 1 + + +def test_hyperedge_multiple_vertices(): + """Test hyperedge with multiple vertices (multi-body interaction).""" + edge = Hyperedge([0, 1, 2, 3]) + assert edge.vertices == [0, 1, 2, 3] + assert len(edge.vertices) == 4 + + +def test_hyperedge_repr(): + """Test string representation.""" + edge = Hyperedge([1, 0]) + assert repr(edge) == "Hyperedge([0, 1])" + + +def test_hyperedge_empty_vertices(): + """Test hyperedge with empty vertex list.""" + edge = Hyperedge([]) + assert edge.vertices == 
[] + assert len(edge.vertices) == 0 + + +# Hypergraph tests + + +def test_hypergraph_init_basic(): + """Test basic Hypergraph initialization.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + assert graph.nedges() == 2 + assert graph.nvertices() == 3 + + +def test_hypergraph_empty_graph(): + """Test hypergraph with no edges.""" + graph = Hypergraph([]) + assert graph.nedges() == 0 + assert graph.nvertices() == 0 + + +def test_hypergraph_nedges(): + """Test edge count.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + assert graph.nedges() == 3 + + +def test_hypergraph_nvertices(): + """Test vertex count with unique vertices.""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + assert graph.nvertices() == 4 + + +def test_hypergraph_nvertices_with_shared_vertices(): + """Test vertex count when edges share vertices.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + assert graph.nvertices() == 3 + + +def test_hypergraph_vertices_iterator(): + """Test vertices iterator returns sorted vertices.""" + edges = [Hyperedge([3, 1]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3] + + +def test_hypergraph_edges_iterator(): + """Test edges iterator returns all edges.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + edge_list = list(graph.edges()) + assert len(edge_list) == 2 + + +def test_hypergraph_edges_with_part_parameter(): + """Test edges iterator with part parameter (base class ignores it).""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + # Base class returns all edges regardless of part parameter + edge_list_0 = list(graph.edges(part=0)) + edge_list_1 = list(graph.edges(part=1)) + assert len(edge_list_0) == 2 + assert len(edge_list_1) == 2 + + +def test_hypergraph_str(): + 
"""Test string representation.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + expected = "Hypergraph with 4 vertices and 3 edges." + assert str(graph) == expected + + +def test_hypergraph_repr(): + """Test repr representation.""" + edges = [Hyperedge([0, 1])] + graph = Hypergraph(edges) + result = repr(graph) + assert "Hypergraph" in result + assert "Hyperedge" in result + + +def test_hypergraph_single_vertex_edges(): + """Test hypergraph with self-loop edges.""" + edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] + graph = Hypergraph(edges) + assert graph.nedges() == 3 + assert graph.nvertices() == 3 + + +def test_hypergraph_mixed_edge_sizes(): + """Test hypergraph with edges of different sizes.""" + edges = [ + Hyperedge([0]), # 1 vertex (self-loop) + Hyperedge([1, 2]), # 2 vertices (pair) + Hyperedge([3, 4, 5]), # 3 vertices (triple) + ] + graph = Hypergraph(edges) + assert graph.nedges() == 3 + assert graph.nvertices() == 6 + + +def test_hypergraph_non_contiguous_vertices(): + """Test hypergraph with non-contiguous vertex indices.""" + edges = [Hyperedge([0, 10]), Hyperedge([5, 20])] + graph = Hypergraph(edges) + assert graph.nvertices() == 4 + vertices = list(graph.vertices()) + assert vertices == [0, 5, 10, 20] From c0308e98c68d8788030231fc6b50d8c37268645c Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 3 Feb 2026 09:13:20 +0100 Subject: [PATCH 05/45] Data types for traces, transforms, Pareto frontiers (#2907) These are data structures for traces, transform on traces, and Pareto frontiers to store estimation results. The `Trace` struct also implements an estimate function to estimate a trace with respect to an ISA. The two current trace transformations are PSSPC and LatticeSurgery. The trace API is considered final for now, whereas the implementation might still change. 
Instruction IDs are now provided by the Rust crate, the updates for the corresponding IDs in the Python package happen in an upcoming PR to keep this PR in a reasonable size. --- Cargo.lock | 2 + .../pip/qsharp/magnets/geometry/hypergraph.py | 6 +- source/pip/tests/magnets/test_hypergraph.py | 24 +- source/qre/Cargo.toml | 5 +- source/qre/src/isa.rs | 10 +- source/qre/src/lib.rs | 43 ++ source/qre/src/pareto.rs | 259 ++++++++ source/qre/src/pareto/tests.rs | 243 ++++++++ source/qre/src/result.rs | 180 ++++++ source/qre/src/trace.rs | 590 ++++++++++++++++++ source/qre/src/trace/instruction_ids.rs | 145 +++++ source/qre/src/trace/tests.rs | 240 +++++++ source/qre/src/trace/transforms.rs | 14 + .../src/trace/transforms/lattice_surgery.rs | 30 + source/qre/src/trace/transforms/psspc.rs | 218 +++++++ 15 files changed, 1989 insertions(+), 20 deletions(-) create mode 100644 source/qre/src/pareto.rs create mode 100644 source/qre/src/pareto/tests.rs create mode 100644 source/qre/src/result.rs create mode 100644 source/qre/src/trace.rs create mode 100644 source/qre/src/trace/instruction_ids.rs create mode 100644 source/qre/src/trace/tests.rs create mode 100644 source/qre/src/trace/transforms.rs create mode 100644 source/qre/src/trace/transforms/lattice_surgery.rs create mode 100644 source/qre/src/trace/transforms/psspc.rs diff --git a/Cargo.lock b/Cargo.lock index bfe0d95443..b7ab884a4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1869,6 +1869,8 @@ version = "0.0.0" dependencies = [ "num-traits", "rustc-hash", + "serde", + "thiserror 1.0.63", ] [[package]] diff --git a/source/pip/qsharp/magnets/geometry/hypergraph.py b/source/pip/qsharp/magnets/geometry/hypergraph.py index fc5cd38447..dd55ebf408 100644 --- a/source/pip/qsharp/magnets/geometry/hypergraph.py +++ b/source/pip/qsharp/magnets/geometry/hypergraph.py @@ -28,7 +28,7 @@ class Hyperedge: vertices: Sorted list of vertex indices connected by this hyperedge. Example: - + .. 
code-block:: python >>> edge = Hyperedge([2, 0, 1]) >>> edge.vertices @@ -60,7 +60,7 @@ class Hypergraph: _edge_list: Set of hyperedges for efficient membership testing. Example: - + .. code-block:: python >>> edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] >>> graph = Hypergraph(edges) @@ -113,7 +113,7 @@ def edges(self, part: int = 0) -> Iterator[Hyperedge]: return iter(self._edge_list) def __str__(self) -> str: - return f"Hypergraph with {self.nvertices()} vertices and {self.nedges()} edges." + return f"Hypergraph with {self.nvertices} vertices and {self.nedges} edges." def __repr__(self) -> str: return f"Hypergraph({list(self._edges)})" diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index 79f071f47b..5a050993c9 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -55,36 +55,36 @@ def test_hypergraph_init_basic(): """Test basic Hypergraph initialization.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] graph = Hypergraph(edges) - assert graph.nedges() == 2 - assert graph.nvertices() == 3 + assert graph.nedges == 2 + assert graph.nvertices == 3 def test_hypergraph_empty_graph(): """Test hypergraph with no edges.""" graph = Hypergraph([]) - assert graph.nedges() == 0 - assert graph.nvertices() == 0 + assert graph.nedges == 0 + assert graph.nvertices == 0 def test_hypergraph_nedges(): """Test edge count.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] graph = Hypergraph(edges) - assert graph.nedges() == 3 + assert graph.nedges == 3 def test_hypergraph_nvertices(): """Test vertex count with unique vertices.""" edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] graph = Hypergraph(edges) - assert graph.nvertices() == 4 + assert graph.nvertices == 4 def test_hypergraph_nvertices_with_shared_vertices(): """Test vertex count when edges share vertices.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] graph 
= Hypergraph(edges) - assert graph.nvertices() == 3 + assert graph.nvertices == 3 def test_hypergraph_vertices_iterator(): @@ -135,8 +135,8 @@ def test_hypergraph_single_vertex_edges(): """Test hypergraph with self-loop edges.""" edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] graph = Hypergraph(edges) - assert graph.nedges() == 3 - assert graph.nvertices() == 3 + assert graph.nedges == 3 + assert graph.nvertices == 3 def test_hypergraph_mixed_edge_sizes(): @@ -147,14 +147,14 @@ def test_hypergraph_mixed_edge_sizes(): Hyperedge([3, 4, 5]), # 3 vertices (triple) ] graph = Hypergraph(edges) - assert graph.nedges() == 3 - assert graph.nvertices() == 6 + assert graph.nedges == 3 + assert graph.nvertices == 6 def test_hypergraph_non_contiguous_vertices(): """Test hypergraph with non-contiguous vertex indices.""" edges = [Hyperedge([0, 10]), Hyperedge([5, 20])] graph = Hypergraph(edges) - assert graph.nvertices() == 4 + assert graph.nvertices == 4 vertices = list(graph.vertices()) assert vertices == [0, 5, 10, 20] diff --git a/source/qre/Cargo.toml b/source/qre/Cargo.toml index e7e7c80783..88148dca7c 100644 --- a/source/qre/Cargo.toml +++ b/source/qre/Cargo.toml @@ -9,8 +9,11 @@ edition.workspace = true license.workspace = true [dependencies] -rustc-hash = { workspace = true } num-traits = { workspace = true } +rustc-hash = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } + [dev-dependencies] diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index c608b91084..310f375c56 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -8,6 +8,7 @@ use std::{ use num_traits::FromPrimitive; use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; #[cfg(test)] mod tests; @@ -158,7 +159,7 @@ impl FromIterator for ISARequirements { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct Instruction { id: u64, encoding: Encoding, @@ -328,13 +329,14 @@ impl InstructionConstraint { } } 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub enum Encoding {
+    #[default]
     Physical,
     Logical,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Serialize, Deserialize)]
 pub enum Metrics {
     FixedArity {
         arity: u64,
@@ -351,7 +353,7 @@ pub enum Metrics {
     },
 }
 
-#[derive(Clone)]
+#[derive(Clone, Serialize, Deserialize)]
 pub enum VariableArityFunction {
     Constant { value: T },
     Linear { slope: T },
diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs
index 00d70a2a53..4baa1f9e13 100644
--- a/source/qre/src/lib.rs
+++ b/source/qre/src/lib.rs
@@ -1,8 +1,51 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
 
+use thiserror::Error;
+
 mod isa;
+mod pareto;
+pub use pareto::{
+    ParetoFrontier as ParetoFrontier2D, ParetoFrontier3D, ParetoItem2D, ParetoItem3D,
+};
+mod result;
+pub use result::{EstimationCollection, EstimationResult, FactoryResult};
+mod trace;
 pub use isa::{
     ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint,
     VariableArityFunction,
 };
+pub use trace::instruction_ids;
+pub use trace::{Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel};
+
+/// A resource estimation error.
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum Error {
+    /// The resource estimation exceeded the maximum allowed error.
+    #[error("resource estimation exceeded the maximum allowed error: {actual_error} > {max_error}")]
+    MaximumErrorExceeded { actual_error: f64, max_error: f64 },
+    /// Missing instruction in the ISA.
+    #[error("requested instruction {0} not present in ISA")]
+    InstructionNotFound(u64),
+    /// Cannot extract space from instruction.
+    #[error("cannot extract space from instruction {0} for fixed arity")]
+    CannotExtractSpace(u64),
+    /// Cannot extract time from instruction.
+ #[error("cannot extract time from instruction {0} for fixed arity")] + CannotExtractTime(u64), + /// Cannot extract error rate from instruction. + #[error("cannot extract error rate from instruction {0} for fixed arity")] + CannotExtractErrorRate(u64), + /// Factory time exceeds algorithm runtime + #[error( + "factory instruction {id} time {factory_time} exceeds algorithm runtime {algorithm_runtime}" + )] + FactoryTimeExceedsAlgorithmRuntime { + id: u64, + factory_time: u64, + algorithm_runtime: u64, + }, + /// Unsupported instruction in trace transformation + #[error("unsupported instruction {id} in trace transformation '{name}'")] + UnsupportedInstruction { id: u64, name: &'static str }, +} diff --git a/source/qre/src/pareto.rs b/source/qre/src/pareto.rs new file mode 100644 index 0000000000..96f842a0a4 --- /dev/null +++ b/source/qre/src/pareto.rs @@ -0,0 +1,259 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use serde::{Deserialize, Serialize}; + +#[cfg(test)] +mod tests; + +pub trait ParetoItem2D { + type Objective1: PartialOrd + Copy; + type Objective2: PartialOrd + Copy; + + fn objective1(&self) -> Self::Objective1; + fn objective2(&self) -> Self::Objective2; +} + +pub trait ParetoItem3D { + type Objective1: PartialOrd + Copy; + type Objective2: PartialOrd + Copy; + type Objective3: PartialOrd + Copy; + + fn objective1(&self) -> Self::Objective1; + fn objective2(&self) -> Self::Objective2; + fn objective3(&self) -> Self::Objective3; +} + +/// A Pareto frontier for 2-dimensional objectives. +/// +/// The implementation maintains the frontier sorted by the first objective. +/// This allows for efficient updates based on the geometric property that +/// a point is dominated if and only if it is dominated by its immediate +/// predecessor in the sorted list (when sorted by the first objective). +/// +/// This approach is related to the algorithms described in: +/// H. T. Kung, F. Luccio, and F. P. 
Preparata, "On Finding the Maxima of a Set of Vectors," +/// Journal of the ACM, vol. 22, no. 4, pp. 469-476, 1975. +#[derive(Default, Debug, Clone)] +pub struct ParetoFrontier(pub Vec); + +impl ParetoFrontier { + #[must_use] + pub fn new() -> Self { + Self(Vec::new()) + } + + pub fn insert(&mut self, p: I) { + // If any objective is incomparable (e.g. NaN), we silently ignore the item + // to maintain the frontier's sorting invariant. + if p.objective1().partial_cmp(&p.objective1()).is_none() + || p.objective2().partial_cmp(&p.objective2()).is_none() + { + return; + } + + let frontier = &mut self.0; + let search = frontier.binary_search_by(|q| { + q.objective1() + .partial_cmp(&p.objective1()) + .expect("objectives must be comparable") + }); + + let pos = match search { + Ok(i) => { + if frontier[i].objective2() <= p.objective2() { + return; + } + i + } + Err(i) => { + if i > 0 { + let left = &frontier[i - 1]; + if left.objective2() <= p.objective2() { + return; + } + } + i + } + }; + let i = pos; + while i < frontier.len() && frontier[i].objective2() >= p.objective2() { + frontier.remove(i); + } + frontier.insert(pos, p); + } + + pub fn iter(&self) -> std::slice::Iter<'_, I> { + self.0.iter() + } + + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Extend for ParetoFrontier { + fn extend>(&mut self, iter: T) { + for p in iter { + self.insert(p); + } + } +} + +impl FromIterator for ParetoFrontier { + fn from_iter>(iter: T) -> Self { + let mut frontier = Self::new(); + frontier.extend(iter); + frontier + } +} + +impl IntoIterator for ParetoFrontier { + type Item = I; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, I: ParetoItem2D> IntoIterator for &'a ParetoFrontier { + type Item = &'a I; + type IntoIter = std::slice::Iter<'a, I>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} 
+ +/// A Pareto frontier for 3-dimensional objectives. +/// +/// The implementation maintains the frontier sorted lexicographically. +/// Unlike the 2D case where dominance checks are O(1) given the sorted order, +/// the 3D case requires checking the prefix or suffix to establish dominance, +/// though maintaining sorted order significantly reduces the search space. +/// +/// The theoretical O(N log N) bound for constructing the 3D frontier is established in: +/// H. T. Kung, F. Luccio, and F. P. Preparata, "On Finding the Maxima of a Set of Vectors," +/// Journal of the ACM, vol. 22, no. 4, pp. 469-476, 1975. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ParetoFrontier3D(pub Vec); + +impl ParetoFrontier3D { + #[must_use] + pub fn new() -> Self { + Self(Vec::new()) + } + + pub fn insert(&mut self, p: I) { + // If any objective is incomparable (e.g. NaN), we silently ignore the item. + if p.objective1().partial_cmp(&p.objective1()).is_none() + || p.objective2().partial_cmp(&p.objective2()).is_none() + || p.objective3().partial_cmp(&p.objective3()).is_none() + { + return; + } + + let frontier = &mut self.0; + + // Use lexicographical sort covering all objectives. + // This makes the binary search deterministic and ensures specific properties for prefix/suffix. + let Err(pos) = frontier.binary_search_by(|q| { + q.objective1() + .partial_cmp(&p.objective1()) + .expect("objectives must be comparable") + .then_with(|| { + q.objective2() + .partial_cmp(&p.objective2()) + .expect("objectives must be comparable") + }) + .then_with(|| { + q.objective3() + .partial_cmp(&p.objective3()) + .expect("objectives must be comparable") + }) + }) else { + return; + }; + + // 1. Check if dominated by any existing point in the prefix [0..pos]. + // Because the list is sorted lexicographically, any point `q` before `pos` + // satisfies `q.obj1 <= p.obj1` (often strictly less). 
+ // Therefore, we only need to check if `q` is also better in obj2 and obj3. + for q in &frontier[..pos] { + if q.objective2() <= p.objective2() && q.objective3() <= p.objective3() { + return; + } + } + + // 2. Remove points dominated by the new point in the suffix [pos..]. + // Any point `q` at or after `pos` satisfies `p.obj1 <= q.obj1`. + // So `p` can only dominate `q` if `p` is better in obj2 and obj3. + let mut i = pos; + while i < frontier.len() { + let q = &frontier[i]; + if p.objective2() <= q.objective2() && p.objective3() <= q.objective3() { + frontier.remove(i); + } else { + i += 1; + } + } + + frontier.insert(pos, p); + } + + pub fn iter(&self) -> std::slice::Iter<'_, I> { + self.0.iter() + } + + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Extend for ParetoFrontier3D { + fn extend>(&mut self, iter: T) { + for p in iter { + self.insert(p); + } + } +} + +impl FromIterator for ParetoFrontier3D { + fn from_iter>(iter: T) -> Self { + let mut frontier = Self::new(); + frontier.extend(iter); + frontier + } +} + +impl IntoIterator for ParetoFrontier3D { + type Item = I; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, I: ParetoItem3D> IntoIterator for &'a ParetoFrontier3D { + type Item = &'a I; + type IntoIter = std::slice::Iter<'a, I>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} diff --git a/source/qre/src/pareto/tests.rs b/source/qre/src/pareto/tests.rs new file mode 100644 index 0000000000..eaf2539c33 --- /dev/null +++ b/source/qre/src/pareto/tests.rs @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use crate::{ + EstimationCollection, EstimationResult, + pareto::{ParetoFrontier, ParetoFrontier3D, ParetoItem2D, ParetoItem3D}, +}; + +struct Point2D { + x: f64, + y: f64, +} + +impl ParetoItem2D for Point2D { + type Objective1 = f64; + type Objective2 = f64; + + fn objective1(&self) -> Self::Objective1 { + self.x + } + + fn objective2(&self) -> Self::Objective2 { + self.y + } +} + +#[test] +fn test_update_frontier() { + let mut frontier: ParetoFrontier = ParetoFrontier::new(); + let p1 = Point2D { x: 1.0, y: 5.0 }; + frontier.insert(p1); + assert_eq!(frontier.0.len(), 1); + let p2 = Point2D { x: 2.0, y: 4.0 }; + frontier.insert(p2); + assert_eq!(frontier.0.len(), 2); + let p3 = Point2D { x: 1.5, y: 6.0 }; + frontier.insert(p3); + assert_eq!(frontier.0.len(), 2); + let p4 = Point2D { x: 3.0, y: 3.0 }; + frontier.insert(p4); + assert_eq!(frontier.0.len(), 3); + let p5 = Point2D { x: 2.5, y: 2.0 }; + frontier.insert(p5); + assert_eq!(frontier.0.len(), 3); +} + +#[test] +fn test_iter_frontier() { + let mut frontier: ParetoFrontier = ParetoFrontier::new(); + frontier.insert(Point2D { x: 1.0, y: 5.0 }); + frontier.insert(Point2D { x: 2.0, y: 4.0 }); + + let mut iter = frontier.iter(); + let p = iter.next().expect("Has element"); + assert!((p.x - 1.0).abs() <= f64::EPSILON); + assert!((p.y - 5.0).abs() <= f64::EPSILON); + + let p = iter.next().expect("Has element"); + assert!((p.x - 2.0).abs() <= f64::EPSILON); + assert!((p.y - 4.0).abs() <= f64::EPSILON); + + assert!(iter.next().is_none()); + + // Test IntoIterator for &ParetoFrontier + for p in &frontier { + assert!(p.x > 0.0); + } +} + +#[derive(Clone, Copy, Debug)] +struct Point3D { + x: f64, + y: f64, + z: f64, +} + +impl ParetoItem3D for Point3D { + type Objective1 = f64; + type Objective2 = f64; + type Objective3 = f64; + + fn objective1(&self) -> Self::Objective1 { + self.x + } + + fn objective2(&self) -> Self::Objective2 { + self.y + } + + fn objective3(&self) -> Self::Objective3 { + self.z + } +} + +#[test] 
+fn test_update_frontier_3d() { + let mut frontier: ParetoFrontier3D = ParetoFrontier3D::new(); + + // p1: 1, 5, 5 + let p1 = Point3D { + x: 1.0, + y: 5.0, + z: 5.0, + }; + frontier.insert(p1); + assert_eq!(frontier.0.len(), 1); + + // p2: 2, 6, 6 (dominated by p1) + let p2 = Point3D { + x: 2.0, + y: 6.0, + z: 6.0, + }; + frontier.insert(p2); + assert_eq!(frontier.0.len(), 1); + + // p3: 0.5, 6, 6 (not dominated, x makes it unique) + let p3 = Point3D { + x: 0.5, + y: 6.0, + z: 6.0, + }; + frontier.insert(p3); + assert_eq!(frontier.0.len(), 2); + + // p4: 1, 4, 4 (dominates p1, should remove p1 and add p4) + // p1 (1,5,5). p4 (1,4,4). p4 <= p1? 1<=1, 4<=5, 4<=5. Yes. + // p3 (0.5,6,6). p4 (1,4,4). p4 <= p3? 1<=0.5 False. + // Result: p1 removed, p4 added. p3 remains. + let p4 = Point3D { + x: 1.0, + y: 4.0, + z: 4.0, + }; + frontier.insert(p4); + assert_eq!(frontier.0.len(), 2); + + // Verify content (generic check, not order specific) + let points: Vec<(f64, f64, f64)> = frontier.iter().map(|p| (p.x, p.y, p.z)).collect(); + + // Should contain p3 and p4 + assert!( + points + .iter() + .any(|p| (p.0 - 0.5).abs() < 1e-9 && (p.1 - 6.0).abs() < 1e-9) + ); + assert!( + points + .iter() + .any(|p| (p.0 - 1.0).abs() < 1e-9 && (p.1 - 4.0).abs() < 1e-9) + ); +} + +#[test] +fn test_estimation_results() { + let mut result_worst = EstimationResult::new(); + result_worst.add_qubits(994_570); + result_worst.add_runtime(346_196_523_750); + + let mut result_mid = EstimationResult::new(); + result_mid.add_qubits(994_570); + result_mid.add_runtime(346_191_476_400); + + let mut result_best = EstimationResult::new(); + result_best.add_qubits(994_570); + result_best.add_runtime(346_181_381_700); + + let results = [result_worst, result_mid, result_best]; + let permutations = [ + [0, 1, 2], + [0, 2, 1], + [1, 0, 2], + [1, 2, 0], + [2, 0, 1], + [2, 1, 0], + ]; + + for p in permutations { + let mut frontier = EstimationCollection::new(); + frontier.insert(results[p[0]].clone()); + 
frontier.insert(results[p[1]].clone()); + frontier.insert(results[p[2]].clone()); + assert_eq!(frontier.len(), 1, "Failed for permutation {p:?}"); + + // Verify the retained item is the best one (index 2) + let item = frontier.iter().next().expect("has item"); + assert_eq!( + item.runtime(), + 346_181_381_700, + "Wrong item retained for permutation {p:?}", + ); + } +} + +#[test] +fn test_estimation_results_3d_permutations() { + // Check that 3D frontier handles strictly dominating points correctly + // even when first dimension is equal. + + // p_worst: (10, 100, 1000) + let p_worst = Point3D { + x: 10.0, + y: 100.0, + z: 1000.0, + }; + // p_mid: (10, 90, 1000) -> Dominates p_worst + let p_mid = Point3D { + x: 10.0, + y: 90.0, + z: 1000.0, + }; + // p_best: (10, 80, 1000) -> Dominates p_mid and p_worst + let p_best = Point3D { + x: 10.0, + y: 80.0, + z: 1000.0, + }; + + let results = [p_worst, p_mid, p_best]; + + let permutations = [ + [0, 1, 2], + [0, 2, 1], + [1, 0, 2], + [1, 2, 0], + [2, 0, 1], + [2, 1, 0], + ]; + + for p in permutations { + let mut frontier = ParetoFrontier3D::new(); + frontier.insert(results[p[0]]); + frontier.insert(results[p[1]]); + frontier.insert(results[p[2]]); + assert_eq!(frontier.len(), 1, "Failed for 3D permutation {p:?}"); + + let item = frontier.iter().next().expect("has item"); + assert!( + (item.y - 80.0).abs() < f64::EPSILON, + "Wrong item retained for 3D permutation {p:?}", + ); + } +} diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs new file mode 100644 index 0000000000..fec8bf2135 --- /dev/null +++ b/source/qre/src/result.rs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::{ + fmt::Display, + ops::{Deref, DerefMut}, +}; + +use rustc_hash::FxHashMap; + +use crate::{ParetoFrontier2D, ParetoItem2D}; + +#[derive(Clone, Default)] +pub struct EstimationResult { + qubits: u64, + runtime: u64, + error: f64, + factories: FxHashMap, +} + +impl EstimationResult { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn qubits(&self) -> u64 { + self.qubits + } + + #[must_use] + pub fn runtime(&self) -> u64 { + self.runtime + } + + #[must_use] + pub fn error(&self) -> f64 { + self.error + } + + #[must_use] + pub fn factories(&self) -> &FxHashMap { + &self.factories + } + + pub fn set_qubits(&mut self, qubits: u64) { + self.qubits = qubits; + } + + pub fn set_runtime(&mut self, runtime: u64) { + self.runtime = runtime; + } + + pub fn set_error(&mut self, error: f64) { + self.error = error; + } + + /// Adds to the current qubit count and returns the new value. + pub fn add_qubits(&mut self, qubits: u64) -> u64 { + self.qubits += qubits; + self.qubits + } + + /// Adds to the current runtime and returns the new value. + pub fn add_runtime(&mut self, runtime: u64) -> u64 { + self.runtime += runtime; + self.runtime + } + + /// Adds to the current error and returns the new value. 
+ pub fn add_error(&mut self, error: f64) -> f64 { + self.error += error; + self.error + } + + pub fn add_factory_result(&mut self, id: u64, result: FactoryResult) { + self.factories.insert(id, result); + } +} + +impl Display for EstimationResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Qubits: {}, Runtime: {}, Error: {}", + self.qubits, self.runtime, self.error + )?; + + if !self.factories.is_empty() { + for (id, factory) in &self.factories { + write!( + f, + ", {id}: {} runs x {} copies", + factory.runs(), + factory.copies() + )?; + } + } + + Ok(()) + } +} + +impl ParetoItem2D for EstimationResult { + type Objective1 = u64; // qubits + type Objective2 = u64; // runtime + + fn objective1(&self) -> Self::Objective1 { + self.qubits + } + + fn objective2(&self) -> Self::Objective2 { + self.runtime + } +} + +#[derive(Default)] +pub struct EstimationCollection(ParetoFrontier2D); + +impl EstimationCollection { + #[must_use] + pub fn new() -> Self { + Self::default() + } +} + +impl Deref for EstimationCollection { + type Target = ParetoFrontier2D; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for EstimationCollection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Clone)] +pub struct FactoryResult { + copies: u64, + runs: u64, + states: u64, + error_rate: f64, +} + +impl FactoryResult { + #[must_use] + pub fn new(copies: u64, runs: u64, states: u64, error_rate: f64) -> Self { + Self { + copies, + runs, + states, + error_rate, + } + } + + #[must_use] + pub fn copies(&self) -> u64 { + self.copies + } + + #[must_use] + pub fn runs(&self) -> u64 { + self.runs + } + + #[must_use] + pub fn states(&self) -> u64 { + self.states + } + + #[must_use] + pub fn error_rate(&self) -> f64 { + self.error_rate + } +} diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs new file mode 100644 index 0000000000..0193b3c9db --- /dev/null +++ b/source/qre/src/trace.rs @@ 
-0,0 +1,590 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::fmt::{Display, Formatter}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction}; + +pub mod instruction_ids; +#[cfg(test)] +mod tests; + +mod transforms; +pub use transforms::{LatticeSurgery, PSSPC, TraceTransform}; + +#[derive(Clone, Default)] +pub struct Trace { + block: Block, + base_error: f64, + compute_qubits: u64, + memory_qubits: Option, + resource_states: Option>, + properties: FxHashMap, +} + +impl Trace { + #[must_use] + pub fn new(compute_qubits: u64) -> Self { + Self { + compute_qubits, + ..Default::default() + } + } + + #[must_use] + pub fn clone_empty(&self, compute_qubits: Option) -> Self { + Self { + block: Block::default(), + base_error: self.base_error, + compute_qubits: compute_qubits.unwrap_or(self.compute_qubits), + memory_qubits: self.memory_qubits, + resource_states: self.resource_states.clone(), + properties: self.properties.clone(), + } + } + + #[must_use] + pub fn compute_qubits(&self) -> u64 { + self.compute_qubits + } + + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.block.add_operation(id, qubits, params); + } + + pub fn add_block(&mut self, repetitions: u64) -> &mut Block { + self.block.add_block(repetitions) + } + + #[must_use] + pub fn base_error(&self) -> f64 { + self.base_error + } + + pub fn increment_base_error(&mut self, amount: f64) { + self.base_error += amount; + } + + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { + if amount == 0 { + return; + } + let states = self.resource_states.get_or_insert_with(FxHashMap::default); + *states.entry(resource_id).or_default() += amount; + } + + #[must_use] + pub fn get_resource_states(&self) -> Option<&FxHashMap> { + self.resource_states.as_ref() + } + + #[must_use] + pub fn get_resource_state_count(&self, resource_id: u64) -> u64 { + if let 
Some(states) = &self.resource_states + && let Some(count) = states.get(&resource_id) + { + return *count; + } + 0 + } + + pub fn set_property(&mut self, key: String, value: Property) { + self.properties.insert(key, value); + } + + #[must_use] + pub fn get_property(&self, key: &str) -> Option<&Property> { + self.properties.get(key) + } + + #[must_use] + pub fn deep_iter(&self) -> TraceIterator<'_> { + TraceIterator::new(&self.block) + } + + #[must_use] + pub fn depth(&self) -> u64 { + self.block.depth() + } + + #[allow( + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::cast_sign_loss + )] + pub fn estimate(&self, isa: &ISA, max_error: Option) -> Result { + let max_error = max_error.unwrap_or(1.0); + + if self.base_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error: self.base_error, + max_error, + }); + } + + let mut result = EstimationResult::new(); + + // base error starts with the error already present in the trace + result.add_error(self.base_error); + + // Counts how many magic state factories are needed per resource state ID + let mut factories: FxHashMap = FxHashMap::default(); + + // This will track the number of physical qubits per logical qubit while + // processing all the instructions. Normally, we assume that the number + // is always the same. + let mut qubit_counts: Vec = vec![]; + + // ------------------------------------------------------------------ + // Add errors from resource states. Allow callable error rates. 
+ // ------------------------------------------------------------------ + if let Some(resource_states) = &self.resource_states { + for (state_id, count) in resource_states { + let rate = get_error_rate_by_id(isa, *state_id)?; + let actual_error = result.add_error(rate * (*count as f64)); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + factories.insert(*state_id, *count); + } + } + + // ------------------------------------------------------------------ + // Gate error accumulation using recursion over block structure. + // Each block contributes repetitions * internal_gate_errors. + // Missing instructions raise an error. Callable rates use arity. + // ------------------------------------------------------------------ + for (gate, mult) in self.deep_iter() { + let instr = get_instruction(isa, gate.id)?; + + let arity = gate.qubits.len() as u64; + + let rate = instr.expect_error_rate(Some(arity)); + + let qubit_count = instr.expect_space(Some(arity)) as f64 / arity as f64; + + if let Err(i) = qubit_counts.binary_search_by(|qc| qc.total_cmp(&qubit_count)) { + qubit_counts.insert(i, qubit_count); + } + + let actual_error = result.add_error(rate * (mult as f64)); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + } + + let total_compute_qubits = (self.compute_qubits() as f64 + * qubit_counts.last().copied().unwrap_or(1.0)) + .ceil() as u64; + result.add_qubits(total_compute_qubits); + + result.add_runtime( + self.block + .depth_and_used(Some(&|op: &Gate| { + let instr = get_instruction(isa, op.id)?; + Ok(instr.expect_time(Some(op.qubits.len() as u64))) + }))? + .0, + ); + + // ------------------------------------------------------------------ + // Factory overhead estimation. Each factory produces states at + // a certain rate, so we need enough copies to meet the demand. 
+ // ------------------------------------------------------------------ + for (factory, count) in &factories { + let instr = get_instruction(isa, *factory)?; + let factory_time = get_time(instr)?; + let factory_space = get_space(instr)?; + let factory_error_rate = get_error_rate(instr)?; + let runs = result.runtime() / factory_time; + + if runs == 0 { + return Err(Error::FactoryTimeExceedsAlgorithmRuntime { + id: *factory, + factory_time, + algorithm_runtime: result.runtime(), + }); + } + + let copies = count.div_ceil(runs); + + result.add_qubits(copies * factory_space); + result.add_factory_result( + *factory, + FactoryResult::new(copies, runs, *count, factory_error_rate), + ); + } + + Ok(result) + } +} + +impl Display for Trace { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "@compute_qubits({})", self.compute_qubits())?; + + if let Some(memory_qubits) = self.memory_qubits { + writeln!(f, "@memory_qubits({memory_qubits})")?; + } + if self.base_error > 0.0 { + writeln!(f, "@base_error({})", self.base_error)?; + } + if let Some(resource_states) = &self.resource_states { + for (res_id, amount) in resource_states { + writeln!(f, "@resource_state({res_id}, {amount})")?; + } + } + write!(f, "{}", self.block) + } +} + +#[derive(Clone, Debug)] +pub enum Operation { + GateOperation(Gate), + BlockOperation(Block), +} + +#[derive(Clone, Debug)] +pub struct Gate { + id: u64, + qubits: Vec, + params: Vec, +} + +#[derive(Clone, Debug)] +pub struct Block { + operations: Vec, + repetitions: u64, +} + +impl Default for Block { + fn default() -> Self { + Self { + operations: Vec::new(), + repetitions: 1, + } + } +} + +impl Block { + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.operations + .push(Operation::gate_operation(id, qubits, params)); + } + + pub fn add_block(&mut self, repetitions: u64) -> &mut Block { + self.operations + .push(Operation::block_operation(repetitions)); + + match self.operations.last_mut() { + 
Some(Operation::BlockOperation(b)) => b, + _ => unreachable!("Last operation must be a block operation"), + } + } + + pub fn write(&self, f: &mut Formatter<'_>, indent: usize) -> std::fmt::Result { + let indent_str = " ".repeat(indent); + if self.repetitions == 1 { + writeln!(f, "{indent_str}{{")?; + } else { + writeln!(f, "{indent_str}repeat {} {{", self.repetitions)?; + } + + for op in &self.operations { + match op { + Operation::GateOperation(Gate { id, qubits, params }) => { + writeln!(f, "{indent_str} {id}({params:?})({qubits:?})")?; + } + Operation::BlockOperation(b) => { + b.write(f, indent + 2)?; + } + } + } + writeln!(f, "{indent_str}}}") + } + + fn depth_and_used Result>( + &self, + duration_fn: Option<&FnDuration>, + ) -> Result<(u64, FxHashSet), Error> { + let mut qubit_depths: FxHashMap = FxHashMap::default(); + let mut all_used = FxHashSet::default(); + + for op in &self.operations { + match op { + Operation::GateOperation(gate) => { + let start_time = gate + .qubits + .iter() + .filter_map(|q| qubit_depths.get(q)) + .max() + .copied() + .unwrap_or(0); + + let duration = match duration_fn { + Some(f) => f(gate)?, + None => 1, + }; + + let end_time = start_time + duration; + for q in &gate.qubits { + qubit_depths.insert(*q, end_time); + all_used.insert(*q); + } + } + Operation::BlockOperation(block) => { + let (duration, used) = block.depth_and_used(duration_fn)?; + if used.is_empty() { + continue; + } + + let start_time = used + .iter() + .filter_map(|q| qubit_depths.get(q)) + .max() + .copied() + .unwrap_or(0); + + let end_time = start_time + duration; + for q in &used { + qubit_depths.insert(*q, end_time); + } + all_used.extend(used); + } + } + } + + let max_depth = qubit_depths.values().max().copied().unwrap_or(0); + Ok((max_depth * self.repetitions, all_used)) + } + + #[must_use] + pub fn depth(&self) -> u64 { + self.depth_and_used:: Result>(None) + .expect("Duration function is None") + .0 + } +} + +impl Display for Block { + fn fmt(&self, f: 
&mut Formatter<'_>) -> std::fmt::Result { + self.write(f, 0) + } +} + +impl Operation { + fn gate_operation(id: u64, qubits: Vec, params: Vec) -> Self { + Operation::GateOperation(Gate { id, qubits, params }) + } + + fn block_operation(repetitions: u64) -> Self { + Operation::BlockOperation(Block { + operations: Vec::new(), + repetitions, + }) + } +} + +pub struct TraceIterator<'a> { + stack: Vec<(std::slice::Iter<'a, Operation>, u64)>, +} + +impl<'a> TraceIterator<'a> { + fn new(block: &'a Block) -> Self { + Self { + stack: vec![(block.operations.iter(), 1)], + } + } +} + +impl<'a> Iterator for TraceIterator<'a> { + type Item = (&'a Gate, u64); + + fn next(&mut self) -> Option { + loop { + let (iter, multiplier) = self.stack.last_mut()?; + match iter.next() { + Some(op) => match op { + Operation::GateOperation(g) => return Some((g, *multiplier)), + Operation::BlockOperation(block) => { + let new_multiplier = *multiplier * block.repetitions; + self.stack.push((block.operations.iter(), new_multiplier)); + } + }, + None => { + self.stack.pop(); + } + } + } + } +} + +#[derive(Clone)] +pub enum Property { + Bool(bool), + Int(i64), + Float(f64), + Str(String), +} + +impl Property { + #[must_use] + pub fn new_bool(b: bool) -> Self { + Property::Bool(b) + } + + #[must_use] + pub fn new_int(i: i64) -> Self { + Property::Int(i) + } + + #[must_use] + pub fn new_float(f: f64) -> Self { + Property::Float(f) + } + + #[must_use] + pub fn new_str(s: String) -> Self { + Property::Str(s) + } + + #[must_use] + pub fn as_bool(&self) -> Option { + match self { + Property::Bool(b) => Some(*b), + _ => None, + } + } + + #[must_use] + pub fn as_int(&self) -> Option { + match self { + Property::Int(i) => Some(*i), + _ => None, + } + } + + #[must_use] + pub fn as_float(&self) -> Option { + match self { + Property::Float(f) => Some(*f), + _ => None, + } + } + + #[must_use] + pub fn as_str(&self) -> Option<&str> { + match self { + Property::Str(s) => Some(s), + _ => None, + } + } + + 
#[must_use] + pub fn is_bool(&self) -> bool { + matches!(self, Property::Bool(_)) + } + + #[must_use] + pub fn is_int(&self) -> bool { + matches!(self, Property::Int(_)) + } + + #[must_use] + pub fn is_float(&self) -> bool { + matches!(self, Property::Float(_)) + } + + #[must_use] + pub fn is_str(&self) -> bool { + matches!(self, Property::Str(_)) + } +} + +impl Display for Property { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Property::Bool(b) => write!(f, "{b}"), + Property::Int(i) => write!(f, "{i}"), + Property::Float(fl) => write!(f, "{fl}"), + Property::Str(s) => write!(f, "{s}"), + } + } +} + +// Some helper functions to extract instructions and their metrics together with +// error handling + +fn get_instruction(isa: &ISA, id: u64) -> Result<&Instruction, Error> { + isa.get(&id).ok_or(Error::InstructionNotFound(id)) +} + +fn get_space(instruction: &Instruction) -> Result { + instruction + .space(None) + .ok_or(Error::CannotExtractSpace(instruction.id())) +} + +fn get_time(instruction: &Instruction) -> Result { + instruction + .time(None) + .ok_or(Error::CannotExtractTime(instruction.id())) +} + +fn get_error_rate(instruction: &Instruction) -> Result { + instruction + .error_rate(None) + .ok_or(Error::CannotExtractErrorRate(instruction.id())) +} + +fn get_error_rate_by_id(isa: &ISA, id: u64) -> Result { + let instr = get_instruction(isa, id)?; + instr + .error_rate(None) + .ok_or(Error::CannotExtractErrorRate(id)) +} + +fn estimate_chunks<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> Vec { + let mut local_collection = Vec::new(); + for trace in traces { + for isa in isas { + if let Ok(estimation) = trace.estimate(isa, None) { + local_collection.push(estimation); + } + } + } + local_collection +} + +#[must_use] +pub fn estimate_parallel<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> EstimationCollection { + let mut collection = EstimationCollection::new(); + std::thread::scope(|scope| { + let num_threads = 
std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + let chunk_size = traces.len().div_ceil(num_threads); + + let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for chunk in traces.chunks(chunk_size) { + let tx = tx.clone(); + scope.spawn(move || tx.send(estimate_chunks(chunk, isas))); + } + drop(tx); + + for local_collection in rx.iter().take(num_threads) { + collection.extend(local_collection.into_iter()); + } + }); + + collection +} diff --git a/source/qre/src/trace/instruction_ids.rs b/source/qre/src/trace/instruction_ids.rs new file mode 100644 index 0000000000..f8f78bc958 --- /dev/null +++ b/source/qre/src/trace/instruction_ids.rs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// NOTE: Define new instruction ids here. Then: +// - add them to `add_instruction_ids` in qre.rs +// - add them to instruction_ids.pyi + +pub const PAULI_I: u64 = 0x0; +pub const PAULI_X: u64 = 0x1; +pub const PAULI_Y: u64 = 0x2; +pub const PAULI_Z: u64 = 0x3; +pub const H: u64 = 0x10; +pub const H_XZ: u64 = 0x10; +pub const H_XY: u64 = 0x11; +pub const H_YZ: u64 = 0x12; +pub const SQRT_X: u64 = 0x13; +pub const SQRT_X_DAG: u64 = 0x14; +pub const SQRT_Y: u64 = 0x15; +pub const SQRT_Y_DAG: u64 = 0x16; +pub const S: u64 = 0x17; +pub const SQRT_Z: u64 = 0x17; +pub const S_DAG: u64 = 0x18; +pub const SQRT_Z_DAG: u64 = 0x18; +pub const CNOT: u64 = 0x19; +pub const CX: u64 = 0x19; +pub const CY: u64 = 0x1A; +pub const CZ: u64 = 0x1B; +pub const SWAP: u64 = 0x1C; +pub const PREP_X: u64 = 0x30; +pub const PREP_Y: u64 = 0x31; +pub const PREP_Z: u64 = 0x32; +pub const ONE_QUBIT_CLIFFORD: u64 = 0x50; +pub const TWO_QUBIT_CLIFFORD: u64 = 0x51; +pub const N_QUBIT_CLIFFORD: u64 = 0x52; +pub const MEAS_X: u64 = 0x100; +pub const MEAS_Y: u64 = 0x101; +pub const MEAS_Z: u64 = 0x102; +pub const MEAS_RESET_X: u64 = 0x103; +pub const MEAS_RESET_Y: u64 = 0x104; +pub const MEAS_RESET_Z: u64 = 0x105; +pub const 
MEAS_XX: u64 = 0x106; +pub const MEAS_YY: u64 = 0x107; +pub const MEAS_ZZ: u64 = 0x108; +pub const MEAS_XZ: u64 = 0x109; +pub const MEAS_XY: u64 = 0x10A; +pub const MEAS_YZ: u64 = 0x10B; +pub const SQRT_SQRT_X: u64 = 0x400; +pub const SQRT_SQRT_X_DAG: u64 = 0x401; +pub const SQRT_SQRT_Y: u64 = 0x402; +pub const SQRT_SQRT_Y_DAG: u64 = 0x403; +pub const SQRT_SQRT_Z: u64 = 0x404; +pub const T: u64 = 0x404; +pub const SQRT_SQRT_Z_DAG: u64 = 0x405; +pub const T_DAG: u64 = 0x405; +pub const CCX: u64 = 0x406; +pub const CCY: u64 = 0x407; +pub const CCZ: u64 = 0x408; +pub const CSWAP: u64 = 0x409; +pub const AND: u64 = 0x40A; +pub const AND_DAG: u64 = 0x40B; +pub const RX: u64 = 0x40C; +pub const RY: u64 = 0x40D; +pub const RZ: u64 = 0x40E; +pub const CRX: u64 = 0x40F; +pub const CRY: u64 = 0x410; +pub const CRZ: u64 = 0x411; +pub const RXX: u64 = 0x412; +pub const RYY: u64 = 0x413; +pub const RZZ: u64 = 0x414; +pub const MULTI_PAULI_MEAS: u64 = 0x1000; +pub const LATTICE_SURGERY: u64 = 0x1100; +pub const READ_FROM_MEMORY: u64 = 0x1200; +pub const WRITE_TO_MEMORY: u64 = 0x1201; +pub const CYCLIC_SHIFT: u64 = 0x1300; +pub const GENERIC: u64 = 0xFFFF; + +#[must_use] +pub fn is_pauli_measurement(id: u64) -> bool { + matches!( + id, + MEAS_X + | MEAS_Y + | MEAS_Z + | MEAS_XX + | MEAS_YY + | MEAS_ZZ + | MEAS_XZ + | MEAS_XY + | MEAS_YZ + | MULTI_PAULI_MEAS + ) +} + +#[must_use] +pub fn is_t_like(id: u64) -> bool { + matches!( + id, + SQRT_SQRT_X + | SQRT_SQRT_X_DAG + | SQRT_SQRT_Y + | SQRT_SQRT_Y_DAG + | SQRT_SQRT_Z + | SQRT_SQRT_Z_DAG + ) +} + +#[must_use] +pub fn is_ccx_like(id: u64) -> bool { + matches!(id, CCX | CCY | CCZ | CSWAP | AND | AND_DAG) +} + +#[must_use] +pub fn is_rotation_like(id: u64) -> bool { + matches!(id, RX | RY | RZ | RXX | RYY | RZZ) +} + +#[must_use] +pub fn is_clifford(id: u64) -> bool { + matches!( + id, + PAULI_I + | PAULI_X + | PAULI_Y + | PAULI_Z + | H_XZ + | H_XY + | H_YZ + | SQRT_X + | SQRT_X_DAG + | SQRT_Y + | SQRT_Y_DAG + | SQRT_Z + | SQRT_Z_DAG 
+ | CX + | CY + | CZ + | SWAP + | PREP_X + | PREP_Y + | PREP_Z + | ONE_QUBIT_CLIFFORD + | TWO_QUBIT_CLIFFORD + | N_QUBIT_CLIFFORD + ) +} diff --git a/source/qre/src/trace/tests.rs b/source/qre/src/trace/tests.rs new file mode 100644 index 0000000000..6509b30048 --- /dev/null +++ b/source/qre/src/trace/tests.rs @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#[test] +fn test_trace_iteration() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); + trace.add_operation(2, vec![1], vec![]); + + assert_eq!(trace.deep_iter().count(), 2); +} + +#[test] +fn test_nested_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(3); + trace.add_operation(1, vec![0], vec![]); + let block = trace.add_block(2); + block.add_operation(2, vec![1], vec![]); + let block = block.add_block(3); + block.add_operation(3, vec![2], vec![]); + trace.add_operation(1, vec![0], vec![]); + + let repetitions = trace.deep_iter().map(|(_, rep)| rep).collect::>(); + assert_eq!(repetitions.len(), 4); + assert_eq!(repetitions, vec![1, 2, 6, 1]); +} + +#[test] +fn test_depth_simple() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); + trace.add_operation(2, vec![1], vec![]); + + // Operations are parallel + assert_eq!(trace.depth(), 1); + + trace.add_operation(3, vec![0], vec![]); + // Operation on qubit 0 is sequential to first one + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_with_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); // Depth 1 on q0 + + let block = trace.add_block(2); + block.add_operation(2, vec![1], vec![]); // Depth 1 on q1 * 2 reps = 2 + + // Block acts as barrier *only on qubits it touches*. + // q1 is touched. q0 is not. + // q0 stays at depth 1. + // q1 ends at depth 2. 
+ + trace.add_operation(3, vec![0], vec![]); + // Next op starts at depth 1 (after op 1). Ends at 2. + + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_parallel_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(4); + + let block1 = trace.add_block(1); + block1.add_operation(1, vec![0], vec![]); // q0: 1 + + let block2 = trace.add_block(1); + block2.add_operation(2, vec![1], vec![]); // q1: 1 + + // Blocks are parallel + assert_eq!(trace.depth(), 1); + + trace.add_operation(3, vec![0, 1], vec![]); + // Dependent on q0 (1) and q1 (1). Start at 1. End at 2. + + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_entangled() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); // q0: 1 + trace.add_operation(2, vec![1], vec![]); // q1: 1 + + trace.add_operation(3, vec![0, 1], vec![]); // q0, q1 synced at 1 -> end at 2 + + assert_eq!(trace.depth(), 2); + + trace.add_operation(4, vec![0], vec![]); // q0: 3 + assert_eq!(trace.depth(), 3); +} + +#[test] +fn test_psspc_transform() { + use crate::trace::{PSSPC, Trace, TraceTransform, instruction_ids::*}; + + let mut trace = Trace::new(3); + + trace.add_operation(T, vec![0], vec![]); + trace.add_operation(CCX, vec![0, 1, 2], vec![]); + trace.add_operation(RZ, vec![0], vec![0.1]); + trace.add_operation(CX, vec![0, 1], vec![]); + trace.add_operation(RZ, vec![1], vec![0.2]); + trace.add_operation(MEAS_Z, vec![0], vec![]); + + // Configure PSSPC with 20 T states per rotation, include CCX magic states + let psspc = PSSPC::new(20, true); + + let transformed = psspc.transform(&trace).expect("Transformation failed"); + + assert_eq!(transformed.compute_qubits(), 12); + assert_eq!(transformed.depth(), 47); + + assert_eq!(transformed.get_resource_state_count(T), 41); + assert_eq!(transformed.get_resource_state_count(CCX), 1); + + assert!(transformed.base_error() > 0.0); + // Error is roughly 5e-9 for 20 Ts + assert!(transformed.base_error() < 
1e-8); +} + +#[test] +fn test_lattice_surgery_transform() { + use crate::trace::{LatticeSurgery, Trace, TraceTransform, instruction_ids::*}; + + let mut trace = Trace::new(3); + + trace.add_operation(T, vec![0], vec![]); + trace.add_operation(CX, vec![1, 2], vec![]); + trace.add_operation(T, vec![0], vec![]); + + assert_eq!(trace.depth(), 2); + + let ls = LatticeSurgery::new(); + let transformed = ls.transform(&trace).expect("Transformation failed"); + + assert_eq!(transformed.compute_qubits(), 3); + assert_eq!(transformed.depth(), 2); + + // Check that we have a LATTICE_SURGERY operation + // TraceIterator visits the operation definition once, but with a multiplier. + let ls_ops: Vec<_> = transformed + .deep_iter() + .filter(|(gate, _)| gate.id == LATTICE_SURGERY) + .collect(); + + assert_eq!(ls_ops.len(), 1); + + let (gate, mult) = ls_ops[0]; + assert_eq!(gate.id, LATTICE_SURGERY); + assert_eq!(mult, 2); // Multiplier should carry the repetition count (depth) +} + +#[test] +fn test_estimate_simple() { + use crate::isa::{Encoding, ISA, Instruction}; + use crate::trace::{Trace, instruction_ids::*}; + + let mut trace = Trace::new(1); + trace.add_operation(T, vec![0], vec![]); + + // Create ISA + let mut isa = ISA::new(); + isa.add_instruction(Instruction::fixed_arity( + T, + Encoding::Logical, + 1, // arity + 100, // time + Some(50), // space + None, // length (defaults to arity) + 0.001, // error_rate + )); + + let result = trace.estimate(&isa, None).expect("Estimation failed"); + + assert!((result.error() - 0.001).abs() <= f64::EPSILON); + assert_eq!(result.runtime(), 100); + assert_eq!(result.qubits(), 50); +} + +#[test] +fn test_estimate_with_factory() { + use crate::isa::{Encoding, ISA, Instruction}; + use crate::trace::{Trace, instruction_ids::*}; + + let mut trace = Trace::new(1); + // Algorithm needs 1000 T states + trace.increment_resource_state(T, 1000); + + // Some compute runtime to allow factories to run + trace.add_operation(GENERIC, vec![0], vec![]); 
+ + let mut isa = ISA::new(); + + // T factory instruction + // Produces 1 T state + isa.add_instruction(Instruction::fixed_arity( + T, + Encoding::Logical, + 1, // arity + 10, // time to produce 1 state + Some(50), // space for factory + None, + 0.0001, // error rate of produced state + )); + + isa.add_instruction(Instruction::fixed_arity( + GENERIC, + Encoding::Logical, + 1, + 1000, // runtime 1000 + Some(200), + None, + 0.0, + )); + + let result = trace.estimate(&isa, None).expect("Estimation failed"); + + assert_eq!(result.runtime(), 1000); + assert_eq!(result.qubits(), 700); + + // Check factory result + let factory_res = result.factories().get(&T).expect("Factory missing"); + assert_eq!(factory_res.copies(), 10); + assert_eq!(factory_res.runs(), 100); + assert_eq!(result.factories().len(), 1); +} diff --git a/source/qre/src/trace/transforms.rs b/source/qre/src/trace/transforms.rs new file mode 100644 index 0000000000..d232ba5b9f --- /dev/null +++ b/source/qre/src/trace/transforms.rs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +mod lattice_surgery; +mod psspc; + +pub use lattice_surgery::LatticeSurgery; +pub use psspc::PSSPC; + +use crate::{Error, Trace}; + +pub trait TraceTransform { + fn transform(&self, trace: &Trace) -> Result; +} diff --git a/source/qre/src/trace/transforms/lattice_surgery.rs b/source/qre/src/trace/transforms/lattice_surgery.rs new file mode 100644 index 0000000000..425606b99d --- /dev/null +++ b/source/qre/src/trace/transforms/lattice_surgery.rs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+use crate::trace::TraceTransform;
+use crate::{Error, Trace, instruction_ids};
+
+#[derive(Default)]
+pub struct LatticeSurgery;
+
+impl LatticeSurgery {
+    #[must_use]
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl TraceTransform for LatticeSurgery {
+    fn transform(&self, trace: &Trace) -> Result {
+        let mut transformed = trace.clone_empty(None);
+
+        let block = transformed.add_block(trace.depth());
+        block.add_operation(
+            instruction_ids::LATTICE_SURGERY,
+            (0..trace.compute_qubits()).collect(),
+            vec![],
+        );
+
+        Ok(transformed)
+    }
+}
diff --git a/source/qre/src/trace/transforms/psspc.rs b/source/qre/src/trace/transforms/psspc.rs
new file mode 100644
index 0000000000..287e6c0aa1
--- /dev/null
+++ b/source/qre/src/trace/transforms/psspc.rs
@@ -0,0 +1,218 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use crate::trace::{Gate, TraceTransform};
+use crate::{Error, Trace, instruction_ids};
+
+/// Implements the Parallel Synthesis Sequential Pauli Computation (PSSPC)
+/// layout algorithm described in Appendix D in
+/// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629). This scheme combines
+/// sequential Pauli-based computation (SPC) as described in
+/// [arXiv:1808.02892](https://arxiv.org/pdf/1808.02892) and
+/// [arXiv:2109.02746](https://arxiv.org/pdf/2109.02746) with an approach to
+/// synthesize sets of diagonal non-Clifford unitaries in parallel as done in
+/// [arXiv:2110.11493](https://arxiv.org/pdf/2110.11493).
+///
+/// References:
+/// - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M.
Svore, +/// Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi +/// Sundaram, Alexander Vaschillo: Assessing requirements to scale to +/// practical quantum advantage, +/// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) +/// - Daniel Litinski: A Game of Surface Codes: Large-Scale Quantum Computing +/// with Lattice Surgery, [arXiv:1808.02892](https://arxiv.org/pdf/1808.02892) +/// - Christopher Chamberland, Earl T. Campbell: Universal quantum computing +/// with twist-free and temporally encoded lattice surgery, +/// [arXiv:2109.02746](https://arxiv.org/pdf/2109.02746) +/// - Michael Beverland, Vadym Kliuchnikov, Eddie Schoute: Surface code +/// compilation via edge-disjoint paths, +/// [arXiv:2110.11493](https://arxiv.org/pdf/2110.11493). +#[derive(Clone)] +pub struct PSSPC { + /// Number of multi-qubit Pauli measurements to inject a synthesized + /// rotation, defaults to 1, see [arXiv:2211.07629, (D3)] + num_measurements_per_r: u64, + /// Number of multi-qubit Pauli measurements to apply a Toffoli gate, + /// defaults to 3, see [arXiv:2211.07629, (D3)] + num_measurements_per_ccx: u64, + /// Number of Pauli measurements to write to memory, defaults to 2, see + /// [arXiv:2109.02746, Fig. 16a] + num_measurements_per_wtm: u64, + /// Number of Pauli measurements to read from memory, defaults to 1, see + /// [arXiv:2109.02746, Fig. 
16b] + num_measurements_per_rfm: u64, + + /// Number of Ts per rotation synthesis + num_ts_per_rotation: u64, + /// Perform Toffoli gates using CCX magic states, if false, T gates are used + ccx_magic_states: bool, +} + +impl PSSPC { + #[must_use] + pub fn new(num_ts_per_rotation: u64, ccx_magic_states: bool) -> Self { + Self { + num_measurements_per_r: 1, + num_measurements_per_ccx: 3, + num_measurements_per_wtm: 2, + num_measurements_per_rfm: 1, + num_ts_per_rotation, + ccx_magic_states, + } + } +} + +impl PSSPC { + #[allow(clippy::cast_possible_truncation)] + fn psspc_counts(trace: &Trace) -> Result { + let mut counter = PSSPCCounts::default(); + + let mut max_rotation_depth = vec![0; trace.compute_qubits() as usize]; + + for (Gate { id, qubits, .. }, mult) in trace.deep_iter() { + if instruction_ids::is_pauli_measurement(*id) { + counter.measurements += mult; + } else if instruction_ids::is_t_like(*id) { + counter.t_like += mult; + } else if instruction_ids::is_ccx_like(*id) { + counter.ccx_like += mult; + } else if instruction_ids::is_rotation_like(*id) { + counter.rotation_like += mult; + + // Track rotation depth + let mut current_depth = 0; + for q in qubits { + if max_rotation_depth[*q as usize] > current_depth { + current_depth = max_rotation_depth[*q as usize]; + } + } + let new_depth = current_depth + mult; + for q in qubits { + max_rotation_depth[*q as usize] = new_depth; + } + if new_depth > counter.rotation_depth { + counter.rotation_depth = new_depth; + } + } else if *id == instruction_ids::READ_FROM_MEMORY { + counter.read_from_memory += mult; + } else if *id == instruction_ids::WRITE_TO_MEMORY { + counter.write_to_memory += mult; + } else if !instruction_ids::is_clifford(*id) { + // Unsupported non-Clifford gate + return Err(Error::UnsupportedInstruction { + id: *id, + name: "PSSPC", + }); + } else { + // For Clifford gates, synchronize depths across qubits + if !qubits.is_empty() { + let mut max_depth = 0; + for q in qubits { + if 
max_rotation_depth[*q as usize] > max_depth { + max_depth = max_rotation_depth[*q as usize]; + } + } + for q in qubits { + max_rotation_depth[*q as usize] = max_depth; + } + } + } + } + + Ok(counter) + } + + #[allow(clippy::cast_precision_loss)] + fn compute_only_trace(&self, trace: &Trace, counts: &PSSPCCounts) -> Trace { + let num_qubits = trace.compute_qubits(); + let logical_qubits = Self::logical_qubit_overhead(num_qubits); + + let mut transformed = trace.clone_empty(Some(logical_qubits)); + + let logical_depth = self.logical_depth_overhead(counts); + let (t_states, ccx_states) = self.num_magic_states(counts); + + transformed.increment_resource_state(instruction_ids::T, t_states); + transformed.increment_resource_state(instruction_ids::CCX, ccx_states); + + let block = transformed.add_block(logical_depth); + block.add_operation( + instruction_ids::MULTI_PAULI_MEAS, + (0..logical_qubits).collect(), + vec![], + ); + + // Add error due to rotation synthesis + transformed.increment_base_error(counts.rotation_like as f64 * self.synthesis_error()); + + transformed + } + + /// Calculates the number of logical qubits required for the PSSPC layout + /// according to Eq. (D1) in + /// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + #[allow( + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::cast_sign_loss + )] + fn logical_qubit_overhead(algorithm_qubits: u64) -> u64 { + let qubit_padding = ((8 * algorithm_qubits) as f64).sqrt().ceil() as u64 + 1; + 2 * algorithm_qubits + qubit_padding + } + + /// Calculates the number of multi-qubit Pauli measurements executed in + /// sequence according to Eq. 
(D3) in + /// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + fn logical_depth_overhead(&self, counter: &PSSPCCounts) -> u64 { + (counter.measurements + counter.t_like + counter.rotation_like) + * self.num_measurements_per_r + + counter.ccx_like * self.num_measurements_per_ccx + + counter.read_from_memory * self.num_measurements_per_rfm + + counter.write_to_memory * self.num_measurements_per_wtm + + (self.num_ts_per_rotation * counter.rotation_depth) * self.num_measurements_per_r + } + + /// Calculates the number of T and CCX magic states that are consumed by + /// multi-qubit Pauli measurements executed by PSSPC according to Eq. (D4) + /// in [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + /// + /// CCX magic states are only counted if the hyper parameter + /// `ccx_magic_states` is set to true. + fn num_magic_states(&self, counter: &PSSPCCounts) -> (u64, u64) { + let t_states = counter.t_like + self.num_ts_per_rotation * counter.rotation_like; + + if self.ccx_magic_states { + (t_states, counter.ccx_like) + } else { + (t_states + 4 * counter.ccx_like, 0) + } + } + + /// Calculates the synthesis error from the formula provided in Table 1 in + /// [arXiv:2203.10064](https://arxiv.org/pdf/2203.10064) for Clifford+T in + /// the mixed fallback approximation protocol. 
+ #[allow(clippy::cast_precision_loss)] + fn synthesis_error(&self) -> f64 { + 2f64.powf((4.86 - self.num_ts_per_rotation as f64) / 0.53) + } +} + +impl TraceTransform for PSSPC { + fn transform(&self, trace: &Trace) -> Result { + let counts = Self::psspc_counts(trace)?; + + Ok(self.compute_only_trace(trace, &counts)) + } +} + +#[derive(Default)] +struct PSSPCCounts { + measurements: u64, + t_like: u64, + ccx_like: u64, + rotation_like: u64, + write_to_memory: u64, + read_from_memory: u64, + rotation_depth: u64, +} From 4c706d402ed3d3dbbaaa0957f1075435d53f6ecb Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Tue, 3 Feb 2026 01:00:12 -0800 Subject: [PATCH 06/45] Magnets: added greedy edge coloring method, and two simple 1d lattices (#2912) Updated the Hypergraph class to support edge partitioning (coloring) * Added a function to perform (multiple trials) of greedy edge coloring * Added two simple 1d lattices: a open string (chain) and a ring Added/updated tests. --------- Co-authored-by: Mathias Soeken --- .../pip/qsharp/magnets/geometry/__init__.py | 11 +- .../pip/qsharp/magnets/geometry/hypergraph.py | 150 +++++++++-- .../pip/qsharp/magnets/geometry/lattice1d.py | 119 +++++++++ source/pip/tests/magnets/test_hypergraph.py | 188 +++++++++++++- source/pip/tests/magnets/test_lattice1d.py | 236 ++++++++++++++++++ 5 files changed, 677 insertions(+), 27 deletions(-) create mode 100644 source/pip/qsharp/magnets/geometry/lattice1d.py create mode 100644 source/pip/tests/magnets/test_lattice1d.py diff --git a/source/pip/qsharp/magnets/geometry/__init__.py b/source/pip/qsharp/magnets/geometry/__init__.py index 649b2a37b2..beecd639f2 100644 --- a/source/pip/qsharp/magnets/geometry/__init__.py +++ b/source/pip/qsharp/magnets/geometry/__init__.py @@ -8,6 +8,13 @@ and interaction graphs. 
""" -from .hypergraph import Hyperedge, Hypergraph +from .hypergraph import Hyperedge, Hypergraph, greedy_edge_coloring +from .lattice1d import Chain1D, Ring1D -__all__ = ["Hyperedge", "Hypergraph"] +__all__ = [ + "Hyperedge", + "Hypergraph", + "greedy_edge_coloring", + "Chain1D", + "Ring1D", +] diff --git a/source/pip/qsharp/magnets/geometry/hypergraph.py b/source/pip/qsharp/magnets/geometry/hypergraph.py index dd55ebf408..f64dc79e63 100644 --- a/source/pip/qsharp/magnets/geometry/hypergraph.py +++ b/source/pip/qsharp/magnets/geometry/hypergraph.py @@ -9,7 +9,9 @@ Hamiltonians, where multi-body interactions can involve more than two sites. """ -from typing import Iterator +from copy import deepcopy +import random +from typing import Iterator, Optional class Hyperedge: @@ -55,18 +57,20 @@ class Hypergraph: various lattice geometries used in quantum simulations. Attributes: - _edges: List of hyperedges in the order they were added. + _edge_list: List of hyperedges in the order they were added. _vertex_set: Set of all unique vertex indices in the hypergraph. - _edge_list: Set of hyperedges for efficient membership testing. + parts: List of lists, where each sublist contains indices of edges + belonging to a specific part of an edge partitioning. This is useful + for parallelism in certain architectures. Example: .. code-block:: python >>> edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] >>> graph = Hypergraph(edges) - >>> graph.nvertices() + >>> graph.nvertices 3 - >>> graph.nedges() + >>> graph.nedges 3 """ @@ -76,44 +80,156 @@ def __init__(self, edges: list[Hyperedge]) -> None: Args: edges: List of hyperedges defining the hypergraph structure. 
""" - self._edges = edges self._vertex_set = set() - self._edge_list = set(edges) + self._edge_list = edges + self.parts = [list(range(len(edges)))] # Single partition by default for edge in edges: self._vertex_set.update(edge.vertices) @property def nedges(self) -> int: """Return the number of hyperedges in the hypergraph.""" - return len(self._edges) + return len(self._edge_list) @property def nvertices(self) -> int: """Return the number of vertices in the hypergraph.""" return len(self._vertex_set) + def add_edge(self, edge: Hyperedge, part: int = 0) -> None: + """Add a hyperedge to the hypergraph. + + Args: + edge: The Hyperedge instance to add. + part: Partition index, used for implementations + with edge partitioning for parallel updates. By + default, all edges are added to the single part + with index 0. + """ + self._edge_list.append(edge) + self._vertex_set.update(edge.vertices) + self.parts[part].append(len(self._edge_list) - 1) # Add to specified partition + def vertices(self) -> Iterator[int]: - """Return an iterator over vertices in sorted order. + """Iterate over all vertex indices in the hypergraph. Returns: - Iterator yielding vertex indices in ascending order. + Iterator of vertex indices in ascending order. """ return iter(sorted(self._vertex_set)) - def edges(self, part: int = 0) -> Iterator[Hyperedge]: - """Return an iterator over hyperedges in the hypergraph. + def edges(self) -> Iterator[Hyperedge]: + """Iterate over all hyperedges in the hypergraph. + + Returns: + Iterator of all hyperedges in the hypergraph. + """ + return iter(self._edge_list) + + def edges_by_part(self, part: int) -> Iterator[Hyperedge]: + """Iterate over hyperedges in a specific partition of the hypergraph. Args: - part: Partition index (reserved for subclass implementations - that support edge partitioning for parallel updates). + part: Partition index, used for implementations + with edge partitioning for parallel updates. 
By + default, all edges are in a single part with + index 0. Returns: - Iterator over all hyperedges in the hypergraph. + Iterator of hyperedges in the specified partition. """ - return iter(self._edge_list) + return iter([self._edge_list[i] for i in self.parts[part]]) def __str__(self) -> str: return f"Hypergraph with {self.nvertices} vertices and {self.nedges} edges." def __repr__(self) -> str: - return f"Hypergraph({list(self._edges)})" + return f"Hypergraph({list(self._edge_list)})" + + +def greedy_edge_coloring( + hypergraph: Hypergraph, # The hypergraph to color. + seed: Optional[int] = None, # Random seed for reproducibility. + trials: int = 1, # Number of trials to perform. +) -> Hypergraph: + """Perform a (nondeterministic) greedy edge coloring of the hypergraph. + Args: + hypergraph: The Hypergraph instance to color. + seed: Optional random seed for reproducibility. + trials: Number of trials to perform. The coloring with the fewest colors + will be returned. Default is 1. + + Returns: + A Hypergraph where each (hyper)edge is assigned a color + such that no two (hyper)edges sharing a vertex have the + same color. 
+ """ + + best = Hypergraph(hypergraph._edge_list) # Placeholder for best coloring found + + if seed is not None: + random.seed(seed) + + # Shuffle edge indices to randomize insertion order + edge_indexes = list(range(hypergraph.nedges)) + random.shuffle(edge_indexes) + + best.parts = [[]] # Initialize with one empty color part + used_vertices = [set()] # Vertices used by each color + + for i in range(len(edge_indexes)): + edge = hypergraph._edge_list[edge_indexes[i]] + for j in range(len(best.parts) + 1): + + # If we've reached a new color, add it + if j == len(best.parts): + best.parts.append([]) + used_vertices.append(set()) + + # Check if this edge can be added to color j + # Note that we always match on the last color if it was added + # if so, add it and break + if not any(v in used_vertices[j] for v in edge.vertices): + best.parts[j].append(edge_indexes[i]) + used_vertices[j].update(edge.vertices) + break + + least_colors = len(best.parts) + + # To do: parallelize over trials + for trial in range(1, trials): + + # Set random seed for reproducibility + # Designed to work with parallel trials + if seed is not None: + random.seed(seed + trial) + + # Shuffle edge indices to randomize insertion order + edge_indexes = list(range(hypergraph.nedges)) + random.shuffle(edge_indexes) + + parts = [[]] # Initialize with one empty color part + used_vertices = [set()] # Vertices used by each color + + for i in range(len(edge_indexes)): + edge = hypergraph._edge_list[edge_indexes[i]] + for j in range(len(parts) + 1): + + # If we've reached a new color, add it + if j == len(parts): + parts.append([]) + used_vertices.append(set()) + + # Check if this edge can be added to color j + # if so, add it and break + if not any(v in used_vertices[j] for v in edge.vertices): + parts[j].append(edge_indexes[i]) + used_vertices[j].update(edge.vertices) + break + + # If this trial used fewer colors, update best + if len(parts) < least_colors: + least_colors = len(parts) + best.parts = 
deepcopy(parts)
+
+    return best
diff --git a/source/pip/qsharp/magnets/geometry/lattice1d.py b/source/pip/qsharp/magnets/geometry/lattice1d.py
new file mode 100644
index 0000000000..a5a892fff4
--- /dev/null
+++ b/source/pip/qsharp/magnets/geometry/lattice1d.py
@@ -0,0 +1,119 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""One-dimensional lattice geometries for quantum simulations.
+
+This module provides classes for representing 1D lattice structures as
+hypergraphs. These lattices are commonly used in quantum spin chain
+simulations and other one-dimensional quantum systems.
+"""
+
+from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph
+
+
+class Chain1D(Hypergraph):
+    """A one-dimensional open chain lattice.
+
+    Represents a linear chain of vertices with nearest-neighbor edges.
+    The chain has open boundary conditions, meaning the first and last
+    vertices are not connected.
+
+    The edges are partitioned into two or three parts for parallel updates:
+    - Part 0 (only if self_loops): Self-loop edges on each vertex
+    - Next part: Even-indexed nearest-neighbor edges (0-1, 2-3, ...)
+    - Last part: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...)
+
+    Attributes:
+        length: Number of vertices in the chain.
+
+    Example:
+
+    .. code-block:: python
+        >>> chain = Chain1D(4)
+        >>> chain.nvertices
+        4
+        >>> chain.nedges
+        3
+    """
+
+    def __init__(self, length: int, self_loops: bool = False) -> None:
+        """Initialize a 1D chain lattice.
+
+        Args:
+            length: Number of vertices in the chain.
+            self_loops: If True, include self-loop edges on each vertex
+                for single-site terms.
+        """
+        if self_loops:
+            _edges = [Hyperedge([i]) for i in range(length)]
+        else:
+            _edges = []
+
+        for i in range(length - 1):
+            _edges.append(Hyperedge([i, i + 1]))
+        super().__init__(_edges)
+
+        # Set up edge partitions for parallel updates
+        if self_loops:
+            self.parts = [list(range(length - 1))]
+        else:
+            self.parts = []
+
+        self.parts.append(list(range(0, length - 1, 2)))
+        self.parts.append(list(range(1, length - 1, 2)))
+
+        self.length = length
+
+
+class Ring1D(Hypergraph):
+    """A one-dimensional ring (periodic chain) lattice.
+
+    Represents a circular chain of vertices with nearest-neighbor edges.
+    The ring has periodic boundary conditions, meaning the first and last
+    vertices are connected.
+
+    The edges are partitioned into two or three parts for parallel updates:
+    - Part 0 (only if self_loops): Self-loop edges on each vertex
+    - Next part: Even-indexed nearest-neighbor edges (0-1, 2-3, ...)
+    - Last part: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...)
+
+    Attributes:
+        length: Number of vertices in the ring.
+
+    Example:
+
+    .. code-block:: python
+        >>> ring = Ring1D(4)
+        >>> ring.nvertices
+        4
+        >>> ring.nedges
+        4
+    """
+
+    def __init__(self, length: int, self_loops: bool = False) -> None:
+        """Initialize a 1D ring lattice.
+
+        Args:
+            length: Number of vertices in the ring.
+            self_loops: If True, include self-loop edges on each vertex
+                for single-site terms.
+ """ + if self_loops: + _edges = [Hyperedge([i]) for i in range(length)] + else: + _edges = [] + + for i in range(length): + _edges.append(Hyperedge([i, (i + 1) % length])) + super().__init__(_edges) + + # Set up edge partitions for parallel updates + if self_loops: + self.parts = [list(range(length))] + else: + self.parts = [] + + self.parts.append(list(range(0, length, 2))) + self.parts.append(list(range(1, length, 2))) + + self.length = length diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index 5a050993c9..3063fcb727 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -3,7 +3,11 @@ """Unit tests for hypergraph data structures.""" -from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph +from qsharp.magnets.geometry.hypergraph import ( + Hyperedge, + Hypergraph, + greedy_edge_coloring, +) # Hyperedge tests @@ -48,6 +52,12 @@ def test_hyperedge_empty_vertices(): assert len(edge.vertices) == 0 +def test_hyperedge_duplicate_vertices(): + """Test that duplicate vertices are removed.""" + edge = Hyperedge([1, 2, 2, 1, 3]) + assert edge.vertices == [1, 2, 3] + + # Hypergraph tests @@ -103,15 +113,39 @@ def test_hypergraph_edges_iterator(): assert len(edge_list) == 2 -def test_hypergraph_edges_with_part_parameter(): - """Test edges iterator with part parameter (base class ignores it).""" +def test_hypergraph_edges_by_part(): + """Test edgesByPart returns edges in a specific partition.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] graph = Hypergraph(edges) - # Base class returns all edges regardless of part parameter - edge_list_0 = list(graph.edges(part=0)) - edge_list_1 = list(graph.edges(part=1)) - assert len(edge_list_0) == 2 - assert len(edge_list_1) == 2 + # Default: all edges in part 0 + edge_list = list(graph.edges_by_part(0)) + assert len(edge_list) == 2 + + +def test_hypergraph_add_edge(): + """Test adding an edge to the 
hypergraph.""" + graph = Hypergraph([]) + graph.add_edge(Hyperedge([0, 1])) + assert graph.nedges == 1 + assert graph.nvertices == 2 + + +def test_hypergraph_add_edge_to_part(): + """Test adding edges to different partitions.""" + graph = Hypergraph([Hyperedge([0, 1])]) + graph.parts.append([]) # Add a second partition + graph.add_edge(Hyperedge([2, 3]), part=1) + assert graph.nedges == 2 + assert len(graph.parts[0]) == 1 + assert len(graph.parts[1]) == 1 + + +def test_hypergraph_parts_default(): + """Test that default parts contain all edge indices.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + assert len(graph.parts) == 1 + assert graph.parts[0] == [0, 1, 2] def test_hypergraph_str(): @@ -158,3 +192,141 @@ def test_hypergraph_non_contiguous_vertices(): assert graph.nvertices == 4 vertices = list(graph.vertices()) assert vertices == [0, 5, 10, 20] + + +# greedyEdgeColoring tests + + +def test_greedy_edge_coloring_empty(): + """Test greedy edge coloring on empty hypergraph.""" + graph = Hypergraph([]) + colored = greedy_edge_coloring(graph) + assert colored.nedges == 0 + assert len(colored.parts) == 1 + assert colored.parts[0] == [] + + +def test_greedy_edge_coloring_single_edge(): + """Test greedy edge coloring with a single edge.""" + graph = Hypergraph([Hyperedge([0, 1])]) + colored = greedy_edge_coloring(graph, seed=42) + assert colored.nedges == 1 + assert len(colored.parts) == 1 + + +def test_greedy_edge_coloring_non_overlapping(): + """Test coloring of non-overlapping edges (can share color).""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + # Non-overlapping edges can be in the same color + assert colored.nedges == 2 + assert len(colored.parts) == 1 + + +def test_greedy_edge_coloring_overlapping(): + """Test coloring of overlapping edges (need different colors).""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = 
Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + # Overlapping edges need different colors + assert colored.nedges == 2 + assert len(colored.parts) == 2 + + +def test_greedy_edge_coloring_triangle(): + """Test coloring of a triangle (3 edges, all pairwise overlapping).""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + # All edges share vertices pairwise, so need 3 colors + assert colored.nedges == 3 + assert len(colored.parts) == 3 + + +def test_greedy_edge_coloring_validity(): + """Test that coloring is valid (no two edges in same part share a vertex).""" + edges = [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([2, 3]), + Hyperedge([3, 4]), + Hyperedge([0, 4]), + ] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + + # Verify each part has no overlapping edges + for part in colored.parts: + used_vertices = set() + for edge_idx in part: + edge = colored._edge_list[edge_idx] + # No vertex should already be used in this part + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_greedy_edge_coloring_all_edges_colored(): + """Test that all edges are assigned to exactly one part.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + + # Collect all edge indices from all parts + all_colored = [] + for part in colored.parts: + all_colored.extend(part) + + # Should have exactly 3 edges colored, each once + assert sorted(all_colored) == [0, 1, 2] + + +def test_greedy_edge_coloring_reproducible_with_seed(): + """Test that coloring is reproducible with the same seed.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3]), Hyperedge([0, 3])] + graph = Hypergraph(edges) + + colored1 = greedy_edge_coloring(graph, seed=123) + colored2 = 
greedy_edge_coloring(graph, seed=123) + + assert colored1.parts == colored2.parts + + +def test_greedy_edge_coloring_multiple_trials(): + """Test that multiple trials can find better colorings.""" + edges = [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([2, 3]), + Hyperedge([3, 0]), + ] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42, trials=10) + # A cycle of 4 edges can be 2-colored + assert len(colored.parts) <= 3 # Greedy may not always find optimal + + +def test_greedy_edge_coloring_hyperedges(): + """Test coloring with multi-vertex hyperedges.""" + edges = [ + Hyperedge([0, 1, 2]), + Hyperedge([2, 3, 4]), + Hyperedge([5, 6, 7]), + ] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + + # First two share vertex 2, third is independent + assert colored.nedges == 3 + assert len(colored.parts) >= 2 + + +def test_greedy_edge_coloring_self_loops(): + """Test coloring with self-loop edges.""" + edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] + graph = Hypergraph(edges) + colored = greedy_edge_coloring(graph, seed=42) + + # Self-loops don't share vertices, can all be same color + assert colored.nedges == 3 + assert len(colored.parts) == 1 diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py new file mode 100644 index 0000000000..f940506f36 --- /dev/null +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for 1D lattice data structures.""" + +from qsharp.magnets.geometry.lattice1d import Chain1D, Ring1D + + +# Chain1D tests + + +def test_chain1d_init_basic(): + """Test basic Chain1D initialization.""" + chain = Chain1D(4) + assert chain.nvertices == 4 + assert chain.nedges == 3 + assert chain.length == 4 + + +def test_chain1d_single_vertex(): + """Test Chain1D with a single vertex (no edges).""" + chain = Chain1D(1) + assert chain.nvertices == 0 + assert chain.nedges == 0 + assert chain.length == 1 + + +def test_chain1d_two_vertices(): + """Test Chain1D with two vertices (one edge).""" + chain = Chain1D(2) + assert chain.nvertices == 2 + assert chain.nedges == 1 + + +def test_chain1d_edges(): + """Test that Chain1D creates correct nearest-neighbor edges.""" + chain = Chain1D(4) + edges = list(chain.edges()) + assert len(edges) == 3 + # Check edges are [0,1], [1,2], [2,3] + assert edges[0].vertices == [0, 1] + assert edges[1].vertices == [1, 2] + assert edges[2].vertices == [2, 3] + + +def test_chain1d_vertices(): + """Test that Chain1D vertices are correct.""" + chain = Chain1D(5) + vertices = list(chain.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_chain1d_with_self_loops(): + """Test Chain1D with self-loops enabled.""" + chain = Chain1D(4, self_loops=True) + assert chain.nvertices == 4 + # 4 self-loops + 3 nearest-neighbor edges = 7 + assert chain.nedges == 7 + + +def test_chain1d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + chain = Chain1D(3, self_loops=True) + edges = list(chain.edges()) + # First 3 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + # Next 2 edges should be nearest-neighbor + assert edges[3].vertices == [0, 1] + assert edges[4].vertices == [1, 2] + + +def test_chain1d_parts_without_self_loops(): + """Test edge partitioning without self-loops.""" + chain = Chain1D(5) + # Should have 2 parts: even 
edges [0,2] and odd edges [1,3] + assert len(chain.parts) == 2 + assert chain.parts[0] == [0, 2] # edges 0-1, 2-3 + assert chain.parts[1] == [1, 3] # edges 1-2, 3-4 + + +def test_chain1d_parts_with_self_loops(): + """Test edge partitioning with self-loops.""" + chain = Chain1D(4, self_loops=True) + # Should have 3 parts: self-loops, even edges, odd edges + assert len(chain.parts) == 3 + + +def test_chain1d_parts_non_overlapping(): + """Test that edges in the same part don't share vertices.""" + chain = Chain1D(6) + for part_indices in chain.parts: + used_vertices = set() + for idx in part_indices: + edge = chain._edge_list[idx] + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_chain1d_str(): + """Test string representation.""" + chain = Chain1D(4) + assert "4 vertices" in str(chain) + assert "3 edges" in str(chain) + + +# Ring1D tests + + +def test_ring1d_init_basic(): + """Test basic Ring1D initialization.""" + ring = Ring1D(4) + assert ring.nvertices == 4 + assert ring.nedges == 4 + assert ring.length == 4 + + +def test_ring1d_two_vertices(): + """Test Ring1D with two vertices (two edges, same pair).""" + ring = Ring1D(2) + assert ring.nvertices == 2 + # Edge 0-1 and edge 1-0 (wrapping), but both are [0,1] after sorting + assert ring.nedges == 2 + + +def test_ring1d_three_vertices(): + """Test Ring1D with three vertices (triangle).""" + ring = Ring1D(3) + assert ring.nvertices == 3 + assert ring.nedges == 3 + + +def test_ring1d_edges(): + """Test that Ring1D creates correct edges including wrap-around.""" + ring = Ring1D(4) + edges = list(ring.edges()) + assert len(edges) == 4 + # Check edges are [0,1], [1,2], [2,3], [0,3] (sorted) + assert edges[0].vertices == [0, 1] + assert edges[1].vertices == [1, 2] + assert edges[2].vertices == [2, 3] + assert edges[3].vertices == [0, 3] # Wrap-around edge + + +def test_ring1d_vertices(): + """Test that Ring1D vertices are correct.""" + ring = Ring1D(5) + 
vertices = list(ring.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_ring1d_with_self_loops(): + """Test Ring1D with self-loops enabled.""" + ring = Ring1D(4, self_loops=True) + assert ring.nvertices == 4 + # 4 self-loops + 4 nearest-neighbor edges = 8 + assert ring.nedges == 8 + + +def test_ring1d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + ring = Ring1D(3, self_loops=True) + edges = list(ring.edges()) + # First 3 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + # Next 3 edges should be nearest-neighbor (including wrap) + assert edges[3].vertices == [0, 1] + assert edges[4].vertices == [1, 2] + assert edges[5].vertices == [0, 2] # Wrap-around + + +def test_ring1d_parts_without_self_loops(): + """Test edge partitioning without self-loops.""" + ring = Ring1D(4) + # Should have 2 parts for parallel updates + assert len(ring.parts) == 2 + + +def test_ring1d_parts_with_self_loops(): + """Test edge partitioning with self-loops.""" + ring = Ring1D(4, self_loops=True) + # Should have 3 parts: self-loops, even edges, odd edges + assert len(ring.parts) == 3 + + +def test_ring1d_parts_non_overlapping(): + """Test that edges in the same part don't share vertices.""" + ring = Ring1D(6) + for part_indices in ring.parts: + used_vertices = set() + for idx in part_indices: + edge = ring._edge_list[idx] + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_ring1d_str(): + """Test string representation.""" + ring = Ring1D(4) + assert "4 vertices" in str(ring) + assert "4 edges" in str(ring) + + +def test_ring1d_vs_chain1d_edge_count(): + """Test that ring has one more edge than chain of same length.""" + for length in range(2, 10): + chain = Chain1D(length) + ring = Ring1D(length) + assert ring.nedges == chain.nedges + 1 + + +def test_chain1d_inherits_hypergraph(): + """Test that Chain1D is a 
Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + chain = Chain1D(4) + assert isinstance(chain, Hypergraph) + # Test inherited methods work + assert hasattr(chain, "edges") + assert hasattr(chain, "vertices") + assert hasattr(chain, "edges_by_part") + + +def test_ring1d_inherits_hypergraph(): + """Test that Ring1D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + ring = Ring1D(4) + assert isinstance(ring, Hypergraph) + # Test inherited methods work + assert hasattr(ring, "edges") + assert hasattr(ring, "vertices") + assert hasattr(ring, "edges_by_part") From 50d9bbc21b46a26a1b29b35ba9a9c06de4bda7f3 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Tue, 3 Feb 2026 02:06:39 -0800 Subject: [PATCH 07/45] Magnets: Base classes for Trotter-Suzuki expansions (#2913) These are the skeleton implementations for Trotter-Suzuki expansions * TrotterStep class wraps the basic functionality of a Trotter step * StrangStep class subclasses TrotterStep for the (second-order) symmetric Trotter step * TrotterExpansion wraps the basic functionality of a full Trotter expansion * Test file included. This PR is disjoint from #2912. --------- Co-authored-by: Mathias Soeken --- source/pip/qsharp/magnets/trotter/__init__.py | 12 + source/pip/qsharp/magnets/trotter/trotter.py | 156 ++++++++++ source/pip/tests/magnets/test_trotter.py | 241 ++++++++++++++++++ 3 files changed, 409 insertions(+) create mode 100644 source/pip/qsharp/magnets/trotter/__init__.py create mode 100644 source/pip/qsharp/magnets/trotter/trotter.py create mode 100644 source/pip/tests/magnets/test_trotter.py diff --git a/source/pip/qsharp/magnets/trotter/__init__.py b/source/pip/qsharp/magnets/trotter/__init__.py new file mode 100644 index 0000000000..f3107d526a --- /dev/null +++ b/source/pip/qsharp/magnets/trotter/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +"""Trotter-Suzuki methods for time evolution.""" + +from .trotter import TrotterStep, StrangStep, TrotterExpansion + +__all__ = [ + "TrotterStep", + "StrangStep", + "TrotterExpansion", +] diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py new file mode 100644 index 0000000000..b598fe5abf --- /dev/null +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Base Trotter class for first- and second-order Trotter-Suzuki decomposition.""" + + +class TrotterStep: + """ + Base class for Trotter decompositions. Essentially, this is a wrapper around + a list of (time, term_index) tuples, which specify which term to apply for + how long. + + As a default, the base class implements the first-order Trotter-Suzuki formula + for approximating time evolution under a Hamiltonian represented as a sum of + terms H = ∑_k H_k by sequentially applying each term for the full time + + e^{-i H t} ≈ ∏_k e^{-i H_k t}. + + This base class is designed for lazy evaluation: the list of (time, term_index) + tuples is only generated when the get() method is called. + + Example: + + .. code-block:: python + >>> trotter = TrotterStep(num_terms=3, time=0.5) + >>> trotter.get() + [(0.5, 0), (0.5, 1), (0.5, 2)] + """ + + def __init__(self, num_terms: int, time: float): + """ + Initialize the Trotter decomposition. + + Args: + num_terms: Number of terms in the Hamiltonian + time: Total time for the evolution + """ + self._num_terms = num_terms + self._time_step = time + + def get(self) -> list[tuple[float, int]]: + """ + Get the Trotter decomposition as a list of (time, term_index) tuples. + + Returns: + List of tuples where each tuple contains the time duration and the + index of the term to be applied. 
+ """ + return [(self._time_step, term_index) for term_index in range(self._num_terms)] + + def __str__(self) -> str: + """String representation of the Trotter decomposition.""" + return f"Trotter(time_step={self._time_step}, num_terms={self._num_terms})" + + def __repr__(self) -> str: + """String representation of the Trotter decomposition.""" + return self.__str__() + + +class StrangStep(TrotterStep): + """ + Strang splitting (second-order Trotter-Suzuki decomposition). + + The second-order Trotter formula uses symmetric splitting: + e^{-i H t} ≈ ∏_{k=1}^{n} e^{-i H_k t/2} ∏_{k=n}^{1} e^{-i H_k t/2} + + This provides second-order accuracy in the time step, compared to + first-order for the basic Trotter decomposition. + + Example: + + .. code-block:: python + >>> strang = StrangStep(num_terms=3, time=0.5) + >>> strang.get() + [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] + """ + + def __init__(self, num_terms: int, time: float): + """ + Initialize the Strang splitting. + + Args: + num_terms: Number of terms in the Hamiltonian + time: Total time for the evolution + """ + super().__init__(num_terms, time) + + def get(self) -> list[tuple[float, int]]: + """ + Get the Strang splitting as a list of (time, term_index) tuples. + + Returns: + List of tuples where each tuple contains the time duration and the + index of the term to be applied. The sequence is symmetric for + second-order accuracy. 
+ """ + terms = [] + # Forward sweep with half time steps + for term_index in range(self._num_terms - 1): + terms.append((self._time_step / 2.0, term_index)) + + # Combine the two middle terms + terms.append((self._time_step, self._num_terms - 1)) + + # Backward sweep with half time steps + for term_index in range(self._num_terms - 2, -1, -1): + terms.append((self._time_step / 2.0, term_index)) + + return terms + + def __str__(self) -> str: + """String representation of the Strang splitting.""" + return f"Strang(time_step={self._time_step}, num_terms={self._num_terms})" + + +class TrotterExpansion: + """ + Trotter expansion class for multiple Trotter steps. This class wraps around + a TrotterStep instance and specifies how many times to repeat this Trotter + step. The expansion can be used to represent the full time evolution + as a sequence of Trotter steps + + e^{-i H t} ≈ (∏_k e^{-i H_k t/n})^n. + + where n is the number of Trotter steps. + + Example: + + .. code-block:: python + >>> n = 4 # Number of Trotter steps + >>> total_time = 1.0 # Total time + >>> trotter_expansion = TrotterExpansion(TrotterStep(2, total_time/n), n) + >>> trotter_expansion.get() + [([(0.25, 0), (0.25, 1)], 4)] + """ + + def __init__(self, trotter_step: TrotterStep, num_steps: int): + """ + Initialize the Trotter expansion. + + Args: + trotter_step: An instance of TrotterStep representing a single Trotter step + num_steps: Number of Trotter steps + """ + self._trotter_step = trotter_step + self._num_steps = num_steps + + def get(self) -> list[tuple[list[tuple[float, int]], int]]: + """ + Get the Trotter expansion as a list of (terms, step_index) tuples. + + Returns: + List of tuples where each tuple contains the list of (time, term_index) + for that step and the number of times that step is executed. 
+ """ + return [(self._trotter_step.get(), self._num_steps)] diff --git a/source/pip/tests/magnets/test_trotter.py b/source/pip/tests/magnets/test_trotter.py new file mode 100644 index 0000000000..bd26ed8f72 --- /dev/null +++ b/source/pip/tests/magnets/test_trotter.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for Trotter-Suzuki decomposition classes.""" + +from qsharp.magnets.trotter import TrotterStep, StrangStep, TrotterExpansion + + +# TrotterStep tests + + +def test_trotter_step_init_basic(): + """Test basic TrotterStep initialization.""" + trotter = TrotterStep(num_terms=3, time=0.5) + assert trotter._num_terms == 3 + assert trotter._time_step == 0.5 + + +def test_trotter_step_get_single_term(): + """Test TrotterStep with a single term.""" + trotter = TrotterStep(num_terms=1, time=1.0) + result = trotter.get() + assert result == [(1.0, 0)] + + +def test_trotter_step_get_multiple_terms(): + """Test TrotterStep with multiple terms.""" + trotter = TrotterStep(num_terms=3, time=0.5) + result = trotter.get() + assert result == [(0.5, 0), (0.5, 1), (0.5, 2)] + + +def test_trotter_step_get_zero_time(): + """Test TrotterStep with zero time.""" + trotter = TrotterStep(num_terms=2, time=0.0) + result = trotter.get() + assert result == [(0.0, 0), (0.0, 1)] + + +def test_trotter_step_get_returns_all_terms(): + """Test that TrotterStep returns all term indices.""" + num_terms = 5 + trotter = TrotterStep(num_terms=num_terms, time=1.0) + result = trotter.get() + assert len(result) == num_terms + term_indices = [idx for _, idx in result] + assert term_indices == list(range(num_terms)) + + +def test_trotter_step_get_uniform_time(): + """Test that all terms have the same time in TrotterStep.""" + time = 0.25 + trotter = TrotterStep(num_terms=4, time=time) + result = trotter.get() + for t, _ in result: + assert t == time + + +def test_trotter_step_str(): + """Test string representation of TrotterStep.""" + 
trotter = TrotterStep(num_terms=3, time=0.5) + result = str(trotter) + assert "Trotter" in result + assert "0.5" in result + assert "3" in result + + +def test_trotter_step_repr(): + """Test repr representation of TrotterStep.""" + trotter = TrotterStep(num_terms=3, time=0.5) + assert repr(trotter) == str(trotter) + + +# StrangStep tests + + +def test_strang_step_init_basic(): + """Test basic StrangStep initialization.""" + strang = StrangStep(num_terms=3, time=0.5) + assert strang._num_terms == 3 + assert strang._time_step == 0.5 + + +def test_strang_step_inherits_trotter(): + """Test that StrangStep inherits from TrotterStep.""" + strang = StrangStep(num_terms=3, time=0.5) + assert isinstance(strang, TrotterStep) + + +def test_strang_step_get_single_term(): + """Test StrangStep with a single term.""" + strang = StrangStep(num_terms=1, time=1.0) + result = strang.get() + # Single term: just full time on term 0 + assert result == [(1.0, 0)] + + +def test_strang_step_get_two_terms(): + """Test StrangStep with two terms.""" + strang = StrangStep(num_terms=2, time=1.0) + result = strang.get() + # Forward: half on term 0, full on term 1, backward: half on term 0 + assert result == [(0.5, 0), (1.0, 1), (0.5, 0)] + + +def test_strang_step_get_three_terms(): + """Test StrangStep with three terms (example from docstring).""" + strang = StrangStep(num_terms=3, time=0.5) + result = strang.get() + expected = [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] + assert result == expected + + +def test_strang_step_symmetric(): + """Test that StrangStep produces symmetric sequence.""" + strang = StrangStep(num_terms=4, time=1.0) + result = strang.get() + # Check symmetry: term indices should be palindromic + term_indices = [idx for _, idx in result] + assert term_indices == term_indices[::-1] + + +def test_strang_step_time_sum(): + """Test that total time in StrangStep equals expected value.""" + time = 1.0 + num_terms = 3 + strang = StrangStep(num_terms=num_terms, time=time) 
+ result = strang.get() + total_time = sum(t for t, _ in result) + # Each term appears once with full time equivalent + # (half + half for outer terms, full for middle) + assert abs(total_time - time * num_terms) < 1e-10 + + +def test_strang_step_middle_term_full_time(): + """Test that the middle term gets full time step.""" + strang = StrangStep(num_terms=5, time=2.0) + result = strang.get() + # Middle term (index 4, the last term) should have full time + middle_entries = [(t, idx) for t, idx in result if idx == 4] + assert len(middle_entries) == 1 + assert middle_entries[0][0] == 2.0 + + +def test_strang_step_outer_terms_half_time(): + """Test that outer terms get half time steps.""" + strang = StrangStep(num_terms=4, time=2.0) + result = strang.get() + # Term 0 should appear twice with half time each + term_0_entries = [(t, idx) for t, idx in result if idx == 0] + assert len(term_0_entries) == 2 + for t, _ in term_0_entries: + assert t == 1.0 + + +def test_strang_step_str(): + """Test string representation of StrangStep.""" + strang = StrangStep(num_terms=3, time=0.5) + result = str(strang) + assert "Strang" in result + assert "0.5" in result + assert "3" in result + + +# TrotterExpansion tests + + +def test_trotter_expansion_init_basic(): + """Test basic TrotterExpansion initialization.""" + step = TrotterStep(num_terms=2, time=0.25) + expansion = TrotterExpansion(step, num_steps=4) + assert expansion._trotter_step is step + assert expansion._num_steps == 4 + + +def test_trotter_expansion_get_single_step(): + """Test TrotterExpansion with a single step.""" + step = TrotterStep(num_terms=2, time=1.0) + expansion = TrotterExpansion(step, num_steps=1) + result = expansion.get() + assert len(result) == 1 + terms, count = result[0] + assert count == 1 + assert terms == [(1.0, 0), (1.0, 1)] + + +def test_trotter_expansion_get_multiple_steps(): + """Test TrotterExpansion with multiple steps.""" + step = TrotterStep(num_terms=2, time=0.25) + expansion = 
TrotterExpansion(step, num_steps=4) + result = expansion.get() + assert len(result) == 1 + terms, count = result[0] + assert count == 4 + assert terms == [(0.25, 0), (0.25, 1)] + + +def test_trotter_expansion_with_strang_step(): + """Test TrotterExpansion using StrangStep.""" + step = StrangStep(num_terms=2, time=0.5) + expansion = TrotterExpansion(step, num_steps=2) + result = expansion.get() + assert len(result) == 1 + terms, count = result[0] + assert count == 2 + # StrangStep with 2 terms: [(0.25, 0), (0.5, 1), (0.25, 0)] + assert terms == [(0.25, 0), (0.5, 1), (0.25, 0)] + + +def test_trotter_expansion_total_time(): + """Test that total evolution time is correct.""" + total_time = 1.0 + num_steps = 4 + step = TrotterStep(num_terms=3, time=total_time / num_steps) + expansion = TrotterExpansion(step, num_steps=num_steps) + result = expansion.get() + terms, count = result[0] + # Total time = sum of times in one step * count + step_time = sum(t for t, _ in terms) + total = step_time * count + # For first-order Trotter, step_time = time * num_terms + assert abs(total - total_time * 3) < 1e-10 + + +def test_trotter_expansion_preserves_step(): + """Test that expansion preserves the original step.""" + step = TrotterStep(num_terms=3, time=0.5) + expansion = TrotterExpansion(step, num_steps=10) + result = expansion.get() + terms, _ = result[0] + assert terms == step.get() + + +def test_trotter_expansion_docstring_example(): + """Test the example from the TrotterExpansion docstring.""" + n = 4 # Number of Trotter steps + total_time = 1.0 # Total time + trotter_expansion = TrotterExpansion(TrotterStep(2, total_time / n), n) + result = trotter_expansion.get() + expected = [([(0.25, 0), (0.25, 1)], 4)] + assert result == expected From c8b973ab586262df0a8f16233bd9e1f5d2ffc18b Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Tue, 3 Feb 2026 03:38:42 -0800 Subject: [PATCH 08/45] Base class for magnet models (#2914) This is the base class for a magnet model. 
* Draws on the geometry from the Hypergraph class (base class for geometries) * Only implements simple modification of the Hamiltonian * Test file included This PR is disjoint from #2912 and #2913 --------- Co-authored-by: Mathias Soeken --- source/pip/pyproject.toml | 1 + source/pip/qsharp/magnets/models/__init__.py | 12 + source/pip/qsharp/magnets/models/model.py | 141 ++++++++++ source/pip/tests/magnets/__init__.py | 11 + source/pip/tests/magnets/test_model.py | 263 +++++++++++++++++++ 5 files changed, 428 insertions(+) create mode 100644 source/pip/qsharp/magnets/models/__init__.py create mode 100644 source/pip/qsharp/magnets/models/model.py create mode 100644 source/pip/tests/magnets/test_model.py diff --git a/source/pip/pyproject.toml b/source/pip/pyproject.toml index 94695b8009..d8cc02a94d 100644 --- a/source/pip/pyproject.toml +++ b/source/pip/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ jupyterlab = ["qsharp-jupyterlab"] widgets = ["qsharp-widgets"] qiskit = ["qiskit>=1.2.2,<3.0.0"] +cirq = ["cirq-core>=1.3.0,<=1.4.1"] [build-system] requires = ["maturin ~= 1.10.2"] diff --git a/source/pip/qsharp/magnets/models/__init__.py b/source/pip/qsharp/magnets/models/__init__.py new file mode 100644 index 0000000000..1f815fb5e9 --- /dev/null +++ b/source/pip/qsharp/magnets/models/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Models module for quantum spin models. + +This module provides classes for representing quantum spin models +as Hamiltonians built from Pauli operators. +""" + +from .model import Model + +__all__ = ["Model"] diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py new file mode 100644 index 0000000000..e6f1eb7449 --- /dev/null +++ b/source/pip/qsharp/magnets/models/model.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# pyright: reportPrivateImportUsage=false + +"""Base Model class for quantum spin models. + +This module provides the base class for representing quantum spin models +as Hamiltonians built from Pauli operators. The Model class integrates +with hypergraph geometries to define interaction topologies and uses +Cirq's PauliString and PauliSum for representing quantum operators. +""" + +from typing import Iterator +from qsharp.magnets.geometry import Hypergraph + +try: + from cirq import LineQubit, PauliSum, PauliString +except Exception as ex: + raise ImportError( + "qsharp.magnets.models requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'." + ) from ex + + +class Model: + """Base class for quantum spin models. + + This class wraps a list of cirq.PauliSum objects that define the Hamiltonian + of a quantum system. Each element of the list represents a partition of + the Hamiltonian into different terms, which is useful for: + + - Trotterization: Grouping commuting terms for efficient simulation + - Parallel execution: Terms in the same partition can be applied simultaneously + - Resource estimation: Analyzing different parts of the Hamiltonian separately + + The model is built on a hypergraph geometry that defines which qubits + interact with each other. Subclasses should populate the `terms` list + with appropriate PauliSum operators based on the geometry. + + Attributes: + geometry: The Hypergraph defining the interaction topology. + terms: List of PauliSum objects representing partitioned Hamiltonian terms. + + Example: + + .. code-block:: python + >>> from qsharp.magnets.geometry import Chain1D + >>> geometry = Chain1D(4) + >>> model = Model(geometry) + >>> model.add_term() # Add an empty term + >>> len(model.terms) + 1 + """ + + def __init__(self, geometry: Hypergraph): + """Initialize the Model. + + Creates a quantum spin model on the given geometry. 
The model starts + with no Hamiltonian terms; subclasses or callers should add terms + using `add_term()` and `add_to_term()`. + + Args: + geometry: Hypergraph defining the interaction topology. The number + of vertices determines the number of qubits in the model. + """ + self.geometry: Hypergraph = geometry + self._qubits: list[LineQubit] = [ + LineQubit(i) for i in range(geometry.nvertices) + ] + self.terms: list[PauliSum] = [] + + def add_term(self, term: PauliSum = None) -> None: + """Add a term to the Hamiltonian. + + Appends a new PauliSum to the list of Hamiltonian terms. This is + typically used to create partitions for Trotterization, where each + partition contains operators that can be applied together. + + Args: + term: The PauliSum to add. If None, an empty PauliSum is added, + which can be populated later using `add_to_term()`. + """ + if term is None: + term = PauliSum() + self.terms.append(term) + + def add_to_term(self, index: int, pauli_string: PauliString) -> None: + """Add a PauliString to a specific term in the Hamiltonian. + + Appends a Pauli operator (with coefficient) to an existing term. + This is used to build up the Hamiltonian incrementally. + + Args: + index: Index of the term to add to (0-indexed). + pauli_string: The PauliString to add to the term. This can + include a coefficient, e.g., `0.5 * cirq.Z(q0) * cirq.Z(q1)`. + + Raises: + IndexError: If index is out of range of the terms list. + """ + self.terms[index] += pauli_string + + def q(self, i: int) -> LineQubit: + """Return the qubit at index i. + + Provides convenient access to qubits by their vertex index in + the underlying geometry. + + Args: + i: Index of the qubit (0-indexed, corresponds to vertex index). + + Returns: + The LineQubit at the specified index. + """ + return self._qubits[i] + + def qubit_list(self) -> list[LineQubit]: + """Return the list of qubits in the model. + + Returns: + A list of all LineQubit objects in the model, ordered by index. 
+ """ + return self._qubits + + def qubits(self) -> Iterator[LineQubit]: + """Return an iterator over the qubits in the model. + + Returns: + An iterator yielding LineQubit objects in index order. + """ + return iter(self._qubits) + + def __str__(self) -> str: + """String representation of the model.""" + return "Generic model with {} terms on {} qubits.".format( + len(self.terms), len(self._qubits) + ) + + def __repr__(self) -> str: + """String representation of the model.""" + return self.__str__() diff --git a/source/pip/tests/magnets/__init__.py b/source/pip/tests/magnets/__init__.py index 686737dba3..4540e70bc2 100644 --- a/source/pip/tests/magnets/__init__.py +++ b/source/pip/tests/magnets/__init__.py @@ -2,3 +2,14 @@ # Licensed under the MIT License. """Unit tests for the magnets library.""" + +try: + # pylint: disable=unused-import + # flake8: noqa E401 + import cirq + + CIRQ_AVAILABLE = True +except ImportError: + CIRQ_AVAILABLE = False + +SKIP_REASON = "cirq is not available" diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py new file mode 100644 index 0000000000..0028cbcbab --- /dev/null +++ b/source/pip/tests/magnets/test_model.py @@ -0,0 +1,263 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportPrivateImportUsage=false, reportOperatorIssue=false + +"""Unit tests for the Model class.""" + +# To be updated after additional geometries are implemented + +from __future__ import annotations + +import pytest +from . 
import CIRQ_AVAILABLE, SKIP_REASON + +if CIRQ_AVAILABLE: + import cirq + from cirq import LineQubit + + from qsharp.magnets.geometry import Hyperedge, Hypergraph + from qsharp.magnets.models import Model + + +def make_chain(length: int) -> Hypergraph: + """Create a simple chain hypergraph for testing.""" + edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] + return Hypergraph(edges) + + +# Model initialization tests + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_init_basic(): + """Test basic Model initialization.""" + geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2])]) + model = Model(geometry) + assert model.geometry is geometry + assert len(model.terms) == 0 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_init_creates_qubits(): + """Test that Model creates correct number of qubits.""" + geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([2, 3])]) + model = Model(geometry) + assert len(model.qubit_list()) == 4 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_init_with_chain(): + """Test Model initialization with chain geometry.""" + geometry = make_chain(5) + model = Model(geometry) + assert len(model.qubit_list()) == 5 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_init_empty_geometry(): + """Test Model with empty geometry.""" + geometry = Hypergraph([]) + model = Model(geometry) + assert len(model.qubit_list()) == 0 + assert len(model.terms) == 0 + + +# Qubit access tests + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_q_returns_line_qubit(): + """Test that q() returns LineQubit instances.""" + geometry = make_chain(3) + model = Model(geometry) + qubit = model.q(0) + assert isinstance(qubit, LineQubit) + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_q_returns_correct_qubit(): + """Test that q() returns qubit with correct index.""" + geometry = 
make_chain(4) + model = Model(geometry) + for i in range(4): + assert model.q(i) == LineQubit(i) + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_qubit_list(): + """Test qubit_list() returns all qubits.""" + geometry = make_chain(3) + model = Model(geometry) + qubits = model.qubit_list() + assert len(qubits) == 3 + assert qubits == [LineQubit(0), LineQubit(1), LineQubit(2)] + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_qubits_iterator(): + """Test qubits() returns an iterator.""" + geometry = make_chain(3) + model = Model(geometry) + qubit_iter = model.qubits() + qubits = list(qubit_iter) + assert len(qubits) == 3 + assert qubits == [LineQubit(0), LineQubit(1), LineQubit(2)] + + +# Term management tests + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_term_empty(): + """Test adding an empty term.""" + geometry = make_chain(2) + model = Model(geometry) + model.add_term() + assert len(model.terms) == 1 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_term_with_pauli_sum(): + """Test adding a PauliSum term.""" + geometry = make_chain(2) + model = Model(geometry) + q0, q1 = model.q(0), model.q(1) + term = cirq.Z(q0) * cirq.Z(q1) + model.add_term(cirq.PauliSum.from_pauli_strings([term])) + assert len(model.terms) == 1 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_multiple_terms(): + """Test adding multiple terms.""" + geometry = make_chain(3) + model = Model(geometry) + model.add_term() + model.add_term() + model.add_term() + assert len(model.terms) == 3 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_to_term(): + """Test adding a PauliString to an existing term.""" + geometry = make_chain(2) + model = Model(geometry) + model.add_term() + q0, q1 = model.q(0), model.q(1) + pauli_string = cirq.Z(q0) * cirq.Z(q1) + model.add_to_term(0, pauli_string) + # Term 
should now contain the Pauli string + assert len(model.terms[0]) == 1 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_to_term_multiple_strings(): + """Test adding multiple PauliStrings to the same term.""" + geometry = make_chain(3) + model = Model(geometry) + model.add_term() + q0, q1, q2 = model.q(0), model.q(1), model.q(2) + model.add_to_term(0, cirq.Z(q0) * cirq.Z(q1)) + model.add_to_term(0, cirq.Z(q1) * cirq.Z(q2)) + assert len(model.terms[0]) == 2 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_to_different_terms(): + """Test adding PauliStrings to different terms.""" + geometry = make_chain(3) + model = Model(geometry) + model.add_term() + model.add_term() + q0, q1, q2 = model.q(0), model.q(1), model.q(2) + model.add_to_term(0, cirq.Z(q0) * cirq.Z(q1)) + model.add_to_term(1, cirq.Z(q1) * cirq.Z(q2)) + assert len(model.terms[0]) == 1 + assert len(model.terms[1]) == 1 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_add_to_term_with_coefficient(): + """Test adding a PauliString with a coefficient.""" + geometry = make_chain(2) + model = Model(geometry) + model.add_term() + q0, q1 = model.q(0), model.q(1) + pauli_string = 0.5 * cirq.Z(q0) * cirq.Z(q1) + model.add_to_term(0, pauli_string) + assert len(model.terms[0]) == 1 + + +# String representation tests + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_str(): + """Test string representation.""" + geometry = make_chain(4) + model = Model(geometry) + model.add_term() + model.add_term() + result = str(model) + assert "2 terms" in result + assert "4 qubits" in result + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_str_empty(): + """Test string representation with no terms.""" + geometry = make_chain(3) + model = Model(geometry) + result = str(model) + assert "0 terms" in result + assert "3 qubits" in result + + +@pytest.mark.skipif(not 
CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_repr(): + """Test repr representation.""" + geometry = make_chain(2) + model = Model(geometry) + assert repr(model) == str(model) + + +# Integration tests + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_build_simple_hamiltonian(): + """Test building a simple ZZ Hamiltonian on a chain.""" + geometry = make_chain(3) + model = Model(geometry) + model.add_term() # Single term for all interactions + + for edge in geometry.edges(): + i, j = edge.vertices + model.add_to_term(0, cirq.Z(model.q(i)) * cirq.Z(model.q(j))) + + # Should have 2 ZZ interactions: (0,1) and (1,2) + assert len(model.terms[0]) == 2 + + +@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) +def test_model_with_partitioned_terms(): + """Test building a model with partitioned terms for Trotterization.""" + geometry = make_chain(4) + model = Model(geometry) + + # Add two terms for even/odd partitioning + model.add_term() # Even edges: (0,1), (2,3) + model.add_term() # Odd edges: (1,2) + + # Add even edges to term 0 + model.add_to_term(0, cirq.Z(model.q(0)) * cirq.Z(model.q(1))) + model.add_to_term(0, cirq.Z(model.q(2)) * cirq.Z(model.q(3))) + + # Add odd edge to term 1 + model.add_to_term(1, cirq.Z(model.q(1)) * cirq.Z(model.q(2))) + + assert len(model.terms) == 2 + assert len(model.terms[0]) == 2 + assert len(model.terms[1]) == 1 From 29c03ba35137672f85494b579fbe1719888b033d Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Fri, 6 Feb 2026 09:32:17 -0800 Subject: [PATCH 09/45] Magnets: more basic geometries (#2919) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added more basic hypergraphs to the geometries library.
* Patch2D and Torus2D: two dimensional analogues of chains and rings * Complete graphs: the base geometry for the Sherrington–Kirkpatrick model * Complete bipartite graphs: useful to create models with hidden frustration The complete graph does not have edge partitioning implemented (this is a bit tricky) * All others have edge partitioning * In the case of complete bipartite graphs it still needs to be tested, and comments added Basic tests added --- .../pip/qsharp/magnets/geometry/__init__.py | 6 + .../pip/qsharp/magnets/geometry/complete.py | 126 ++++++++ .../pip/qsharp/magnets/geometry/lattice2d.py | 188 ++++++++++++ source/pip/tests/magnets/test_complete.py | 247 ++++++++++++++++ source/pip/tests/magnets/test_lattice2d.py | 277 ++++++++++++++++++ 5 files changed, 844 insertions(+) create mode 100644 source/pip/qsharp/magnets/geometry/complete.py create mode 100644 source/pip/qsharp/magnets/geometry/lattice2d.py create mode 100644 source/pip/tests/magnets/test_complete.py create mode 100644 source/pip/tests/magnets/test_lattice2d.py diff --git a/source/pip/qsharp/magnets/geometry/__init__.py b/source/pip/qsharp/magnets/geometry/__init__.py index beecd639f2..3d2ac6c1fb 100644 --- a/source/pip/qsharp/magnets/geometry/__init__.py +++ b/source/pip/qsharp/magnets/geometry/__init__.py @@ -8,13 +8,19 @@ and interaction graphs.
""" +from .complete import CompleteBipartiteGraph, CompleteGraph from .hypergraph import Hyperedge, Hypergraph, greedy_edge_coloring from .lattice1d import Chain1D, Ring1D +from .lattice2d import Patch2D, Torus2D __all__ = [ + "CompleteBipartiteGraph", + "CompleteGraph", "Hyperedge", "Hypergraph", "greedy_edge_coloring", "Chain1D", "Ring1D", + "Patch2D", + "Torus2D", ] diff --git a/source/pip/qsharp/magnets/geometry/complete.py b/source/pip/qsharp/magnets/geometry/complete.py new file mode 100644 index 0000000000..595b5ff162 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/complete.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Complete graph geometries for quantum simulations. + +This module provides classes for representing complete graphs and complete +bipartite graphs as hypergraphs. These structures are useful for quantum +systems with all-to-all or bipartite all-to-all interactions. +""" + +from qsharp.magnets.geometry.hypergraph import ( + Hyperedge, + Hypergraph, + greedy_edge_coloring, +) + + +class CompleteGraph(Hypergraph): + """A complete graph where every vertex is connected to every other vertex. + + In a complete graph K_n, there are n vertices and n(n-1)/2 edges, + with each pair of distinct vertices connected by exactly one edge. + + To do: edge partitioning for parallel updates. + + Attributes: + n: Number of vertices in the graph. + + Example: + + .. code-block:: python + >>> graph = CompleteGraph(4) + >>> graph.nvertices + 4 + >>> graph.nedges + 6 + """ + + def __init__(self, n: int, self_loops: bool = False) -> None: + """Initialize a complete graph. + + Args: + n: Number of vertices in the graph. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. 
+ """ + if self_loops: + _edges = [Hyperedge([i]) for i in range(n)] + else: + _edges = [] + + # Add all pairs of vertices + for i in range(n): + for j in range(i + 1, n): + _edges.append(Hyperedge([i, j])) + + super().__init__(_edges) + + # To do: set up edge partitions + + self.n = n + + +class CompleteBipartiteGraph(Hypergraph): + """A complete bipartite graph with two vertex sets. + + In a complete bipartite graph K_{m,n} (m <= n), there are m + n + vertices partitioned into two sets of sizes m and n. Every vertex + in the first set is connected to every vertex in the second set, + giving m * n edges total. + + Vertices 0 to m-1 form the first set, and vertices m to m+n-1 + form the second set. + + To do: edge partitioning for parallel updates. + + Attributes: + m: Number of vertices in the first set. + n: Number of vertices in the second set. + + Requires: + m <= n + + Example: + + .. code-block:: python + >>> graph = CompleteBipartiteGraph(2, 3) + >>> graph.nvertices + 5 + >>> graph.nedges + 6 + """ + + def __init__(self, m: int, n: int, self_loops: bool = False) -> None: + """Initialize a complete bipartite graph. + + Args: + m: Number of vertices in the first set (vertices 0 to m-1). + n: Number of vertices in the second set (vertices m to m+n-1). + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + assert m <= n, "Require m <= n for CompleteBipartiteGraph." 
+ total_vertices = m + n + + if self_loops: + _edges = [Hyperedge([i]) for i in range(total_vertices)] + self.parts = [list(range(total_vertices))] + else: + _edges = [] + self.parts = [] + + colors = [[] for _ in range(n)] # n colors for bipartite edges + + # Connect every vertex in first set to every vertex in second set + for i in range(m): + for j in range(m, m + n): + edge_idx = len(_edges) + _edges.append(Hyperedge([i, j])) + colors[(i + j - m) % n].append(edge_idx) # To do: explain this coloring + + super().__init__(_edges) + self.parts.extend(colors) + + self.m = m + self.n = n diff --git a/source/pip/qsharp/magnets/geometry/lattice2d.py b/source/pip/qsharp/magnets/geometry/lattice2d.py new file mode 100644 index 0000000000..fc98f9de9d --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/lattice2d.py @@ -0,0 +1,188 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Two-dimensional lattice geometries for quantum simulations. + +This module provides classes for representing 2D lattice structures as +hypergraphs. These lattices are commonly used in quantum spin system +simulations and other two-dimensional quantum systems. +""" + +from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph + + +class Patch2D(Hypergraph): + """A two-dimensional open rectangular lattice. + + Represents a rectangular grid of vertices with nearest-neighbor edges. + The patch has open boundary conditions, meaning edges do not wrap around. + + Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. + + The edges are partitioned into parts for parallel updates: + - Part 0 (if self_loops): Self-loop edges on each vertex + - Part 1: Even-column horizontal edges + - Part 2: Odd-column horizontal edges + - Part 3: Even-row vertical edges + - Part 4: Odd-row vertical edges + + Attributes: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + + Example: + + ..
code-block:: python + >>> patch = Patch2D(3, 2) + >>> patch.nvertices + 6 + >>> patch.nedges + 7 + """ + + def __init__(self, width: int, height: int, self_loops: bool = False) -> None: + """Initialize a 2D patch lattice. + + Args: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + + def index(x: int, y: int) -> int: + return y * width + x + + if self_loops: + _edges = [Hyperedge([i]) for i in range(width * height)] + else: + _edges = [] + + # Horizontal edges (connecting (x, y) to (x+1, y)) + horizontal_even = [] + horizontal_odd = [] + for y in range(height): + for x in range(width - 1): + edge_idx = len(_edges) + _edges.append(Hyperedge([index(x, y), index(x + 1, y)])) + if x % 2 == 0: + horizontal_even.append(edge_idx) + else: + horizontal_odd.append(edge_idx) + + # Vertical edges (connecting (x, y) to (x, y+1)) + vertical_even = [] + vertical_odd = [] + for y in range(height - 1): + for x in range(width): + edge_idx = len(_edges) + _edges.append(Hyperedge([index(x, y), index(x, y + 1)])) + if y % 2 == 0: + vertical_even.append(edge_idx) + else: + vertical_odd.append(edge_idx) + + super().__init__(_edges) + + # Set up edge partitions for parallel updates + if self_loops: + self.parts = [list(range(width * height))] + else: + self.parts = [] + + self.parts.append(horizontal_even) + self.parts.append(horizontal_odd) + self.parts.append(vertical_even) + self.parts.append(vertical_odd) + + self.width = width + self.height = height + + +class Torus2D(Hypergraph): + """A two-dimensional toroidal (periodic) lattice. + + Represents a rectangular grid of vertices with nearest-neighbor edges + and periodic boundary conditions in both directions. The topology is + that of a torus. + + Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. 
+ + The edges are partitioned into parts for parallel updates: + - Part 0 (if self_loops): Self-loop edges on each vertex + - Part 1: Even-column horizontal edges + - Part 2: Odd-column horizontal edges + - Part 3: Even-row vertical edges + - Part 4: Odd-row vertical edges + + Attributes: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + + Example: + + .. code-block:: python + >>> torus = Torus2D(3, 2) + >>> torus.nvertices + 6 + >>> torus.nedges + 12 + """ + + def __init__(self, width: int, height: int, self_loops: bool = False) -> None: + """Initialize a 2D torus lattice. + + Args: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + + def index(x: int, y: int) -> int: + return y * width + x + + if self_loops: + _edges = [Hyperedge([i]) for i in range(width * height)] + else: + _edges = [] + + # Horizontal edges (connecting (x, y) to ((x+1) % width, y)) + horizontal_even = [] + horizontal_odd = [] + for y in range(height): + for x in range(width): + edge_idx = len(_edges) + _edges.append(Hyperedge([index(x, y), index((x + 1) % width, y)])) + if x % 2 == 0: + horizontal_even.append(edge_idx) + else: + horizontal_odd.append(edge_idx) + + # Vertical edges (connecting (x, y) to (x, (y+1) % height)) + vertical_even = [] + vertical_odd = [] + for y in range(height): + for x in range(width): + edge_idx = len(_edges) + _edges.append(Hyperedge([index(x, y), index(x, (y + 1) % height)])) + if y % 2 == 0: + vertical_even.append(edge_idx) + else: + vertical_odd.append(edge_idx) + + super().__init__(_edges) + + # Set up edge partitions for parallel updates + if self_loops: + self.parts = [list(range(width * height))] + else: + self.parts = [] + + self.parts.append(horizontal_even) + self.parts.append(horizontal_odd) + self.parts.append(vertical_even) + 
self.parts.append(vertical_odd) + + self.width = width + self.height = height diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py new file mode 100644 index 0000000000..ad1bc28769 --- /dev/null +++ b/source/pip/tests/magnets/test_complete.py @@ -0,0 +1,247 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for complete graph data structures.""" + +from qsharp.magnets.geometry.complete import CompleteBipartiteGraph, CompleteGraph + + +# CompleteGraph tests + + +def test_complete_graph_init_basic(): + """Test basic CompleteGraph initialization.""" + graph = CompleteGraph(4) + assert graph.nvertices == 4 + assert graph.nedges == 6 # 4 * 3 / 2 = 6 + assert graph.n == 4 + + +def test_complete_graph_single_vertex(): + """Test CompleteGraph with a single vertex (no edges).""" + graph = CompleteGraph(1) + assert graph.nvertices == 0 + assert graph.nedges == 0 + assert graph.n == 1 + + +def test_complete_graph_two_vertices(): + """Test CompleteGraph with two vertices (one edge).""" + graph = CompleteGraph(2) + assert graph.nvertices == 2 + assert graph.nedges == 1 + + +def test_complete_graph_three_vertices(): + """Test CompleteGraph with three vertices (triangle).""" + graph = CompleteGraph(3) + assert graph.nvertices == 3 + assert graph.nedges == 3 + + +def test_complete_graph_five_vertices(): + """Test CompleteGraph with five vertices.""" + graph = CompleteGraph(5) + assert graph.nvertices == 5 + assert graph.nedges == 10 # 5 * 4 / 2 = 10 + + +def test_complete_graph_edges(): + """Test that CompleteGraph creates correct edges.""" + graph = CompleteGraph(4) + edges = list(graph.edges()) + assert len(edges) == 6 + # All pairs should be present + edge_sets = [set(e.vertices) for e in edges] + assert {0, 1} in edge_sets + assert {0, 2} in edge_sets + assert {0, 3} in edge_sets + assert {1, 2} in edge_sets + assert {1, 3} in edge_sets + assert {2, 3} in edge_sets + + +def 
test_complete_graph_vertices(): + """Test that CompleteGraph vertices are correct.""" + graph = CompleteGraph(5) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_complete_graph_with_self_loops(): + """Test CompleteGraph with self-loops enabled.""" + graph = CompleteGraph(4, self_loops=True) + assert graph.nvertices == 4 + # 4 self-loops + 6 edges = 10 + assert graph.nedges == 10 + + +def test_complete_graph_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + graph = CompleteGraph(3, self_loops=True) + edges = list(graph.edges()) + # First 3 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + + +def test_complete_graph_edge_count_formula(): + """Test that edge count follows n(n-1)/2 formula.""" + for n in range(1, 10): + graph = CompleteGraph(n) + expected_edges = n * (n - 1) // 2 + assert graph.nedges == expected_edges + + +def test_complete_graph_str(): + """Test string representation.""" + graph = CompleteGraph(4) + assert "4 vertices" in str(graph) + assert "6 edges" in str(graph) + + +def test_complete_graph_inherits_hypergraph(): + """Test that CompleteGraph is a Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + graph = CompleteGraph(4) + assert isinstance(graph, Hypergraph) + assert hasattr(graph, "edges") + assert hasattr(graph, "vertices") + + +# CompleteBipartiteGraph tests + + +def test_complete_bipartite_graph_init_basic(): + """Test basic CompleteBipartiteGraph initialization.""" + graph = CompleteBipartiteGraph(2, 3) + assert graph.nvertices == 5 + assert graph.nedges == 6 # 2 * 3 = 6 + assert graph.m == 2 + assert graph.n == 3 + + +def test_complete_bipartite_graph_single_each(): + """Test CompleteBipartiteGraph with one vertex in each set.""" + graph = CompleteBipartiteGraph(1, 1) + assert graph.nvertices == 2 + assert graph.nedges == 1 + + +def 
test_complete_bipartite_graph_one_and_many(): + """Test CompleteBipartiteGraph with one vertex in first set.""" + graph = CompleteBipartiteGraph(1, 5) + assert graph.nvertices == 6 + assert graph.nedges == 5 # 1 * 5 = 5 + + +def test_complete_bipartite_graph_square(): + """Test CompleteBipartiteGraph with equal set sizes.""" + graph = CompleteBipartiteGraph(3, 3) + assert graph.nvertices == 6 + assert graph.nedges == 9 # 3 * 3 = 9 + + +def test_complete_bipartite_graph_edges(): + """Test that CompleteBipartiteGraph creates correct edges.""" + graph = CompleteBipartiteGraph(2, 3) + edges = list(graph.edges()) + assert len(edges) == 6 + # Vertices 0, 1 in first set; 2, 3, 4 in second set + edge_sets = [set(e.vertices) for e in edges] + # All pairs between sets should be present + assert {0, 2} in edge_sets + assert {0, 3} in edge_sets + assert {0, 4} in edge_sets + assert {1, 2} in edge_sets + assert {1, 3} in edge_sets + assert {1, 4} in edge_sets + # No edges within sets + assert {0, 1} not in edge_sets + assert {2, 3} not in edge_sets + assert {2, 4} not in edge_sets + assert {3, 4} not in edge_sets + + +def test_complete_bipartite_graph_vertices(): + """Test that CompleteBipartiteGraph vertices are correct.""" + graph = CompleteBipartiteGraph(2, 3) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_complete_bipartite_graph_with_self_loops(): + """Test CompleteBipartiteGraph with self-loops enabled.""" + graph = CompleteBipartiteGraph(2, 3, self_loops=True) + assert graph.nvertices == 5 + # 5 self-loops + 6 edges = 11 + assert graph.nedges == 11 + + +def test_complete_bipartite_graph_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + graph = CompleteBipartiteGraph(2, 2, self_loops=True) + edges = list(graph.edges()) + # First 4 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + assert edges[3].vertices == [3] + + +def 
test_complete_bipartite_graph_edge_count_formula(): + """Test that edge count follows m * n formula.""" + for m in range(1, 6): + for n in range(m, 6): + graph = CompleteBipartiteGraph(m, n) + expected_edges = m * n + assert graph.nedges == expected_edges + + +def test_complete_bipartite_graph_parts_without_self_loops(): + """Test edge partitioning without self-loops.""" + graph = CompleteBipartiteGraph(3, 4) + # Should have at least n parts for bipartite coloring + assert len(graph.parts) >= 4 + + +def test_complete_bipartite_graph_parts_with_self_loops(): + """Test edge partitioning with self-loops.""" + graph = CompleteBipartiteGraph(3, 4, self_loops=True) + # Should have n + 1 parts: self-loops + n color groups + assert len(graph.parts) == 5 + + +def test_complete_bipartite_graph_parts_non_overlapping(): + """Test that edges in the same part don't share vertices.""" + graph = CompleteBipartiteGraph(3, 4) + # Skip the first part if it contains all edges (default from Hypergraph) + parts_to_check = graph.parts + if len(parts_to_check) > 0 and len(parts_to_check[0]) == graph.nedges: + parts_to_check = parts_to_check[1:] + for part_indices in parts_to_check: + used_vertices = set() + for idx in part_indices: + edge = graph._edge_list[idx] + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_complete_bipartite_graph_str(): + """Test string representation.""" + graph = CompleteBipartiteGraph(2, 3) + assert "5 vertices" in str(graph) + assert "6 edges" in str(graph) + + +def test_complete_bipartite_graph_inherits_hypergraph(): + """Test that CompleteBipartiteGraph is a Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + graph = CompleteBipartiteGraph(2, 3) + assert isinstance(graph, Hypergraph) + assert hasattr(graph, "edges") + assert hasattr(graph, "vertices") + assert hasattr(graph, "edges_by_part") diff --git a/source/pip/tests/magnets/test_lattice2d.py 
b/source/pip/tests/magnets/test_lattice2d.py new file mode 100644 index 0000000000..8be8e816f6 --- /dev/null +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -0,0 +1,277 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for 2D lattice data structures.""" + +from qsharp.magnets.geometry.lattice2d import Patch2D, Torus2D + + +# Patch2D tests + + +def test_patch2d_init_basic(): + """Test basic Patch2D initialization.""" + patch = Patch2D(3, 2) + assert patch.nvertices == 6 + # 2 horizontal edges per row * 2 rows + 3 vertical edges per column * 1 = 7 + assert patch.nedges == 7 + assert patch.width == 3 + assert patch.height == 2 + + +def test_patch2d_single_vertex(): + """Test Patch2D with a single vertex (no edges).""" + patch = Patch2D(1, 1) + assert patch.nvertices == 0 + assert patch.nedges == 0 + assert patch.width == 1 + assert patch.height == 1 + + +def test_patch2d_single_row(): + """Test Patch2D with a single row (like Chain1D).""" + patch = Patch2D(4, 1) + assert patch.nvertices == 4 + assert patch.nedges == 3 # Only horizontal edges + + +def test_patch2d_single_column(): + """Test Patch2D with a single column.""" + patch = Patch2D(1, 4) + assert patch.nvertices == 4 + assert patch.nedges == 3 # Only vertical edges + + +def test_patch2d_square(): + """Test Patch2D with a square lattice.""" + patch = Patch2D(3, 3) + assert patch.nvertices == 9 + # 2 horizontal * 3 rows + 3 vertical * 2 = 12 + assert patch.nedges == 12 + + +def test_patch2d_edges(): + """Test that Patch2D creates correct nearest-neighbor edges.""" + patch = Patch2D(2, 2) + edges = list(patch.edges()) + # Should have 4 edges: 2 horizontal + 2 vertical + assert len(edges) == 4 + # Vertices: 0=(0,0), 1=(1,0), 2=(0,1), 3=(1,1) + # Horizontal: [0,1], [2,3] + # Vertical: [0,2], [1,3] + edge_sets = [set(e.vertices) for e in edges] + assert {0, 1} in edge_sets + assert {2, 3} in edge_sets + assert {0, 2} in edge_sets + assert {1, 3} in edge_sets + + +def 
test_patch2d_vertices(): + """Test that Patch2D vertices are correct.""" + patch = Patch2D(3, 2) + vertices = list(patch.vertices()) + assert vertices == [0, 1, 2, 3, 4, 5] + + +def test_patch2d_with_self_loops(): + """Test Patch2D with self-loops enabled.""" + patch = Patch2D(3, 2, self_loops=True) + assert patch.nvertices == 6 + # 6 self-loops + 7 nearest-neighbor edges = 13 + assert patch.nedges == 13 + + +def test_patch2d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + patch = Patch2D(2, 2, self_loops=True) + edges = list(patch.edges()) + # First 4 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + assert edges[3].vertices == [3] + + +def test_patch2d_parts_without_self_loops(): + """Test edge partitioning without self-loops.""" + patch = Patch2D(4, 4) + # Should have 4 parts: horizontal even/odd, vertical even/odd + assert len(patch.parts) == 4 + + +def test_patch2d_parts_with_self_loops(): + """Test edge partitioning with self-loops.""" + patch = Patch2D(3, 3, self_loops=True) + # Should have 5 parts: self-loops + 4 edge groups + assert len(patch.parts) == 5 + + +def test_patch2d_parts_non_overlapping(): + """Test that edges in the same part don't share vertices.""" + patch = Patch2D(4, 4) + for part_indices in patch.parts: + used_vertices = set() + for idx in part_indices: + edge = patch._edge_list[idx] + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_patch2d_str(): + """Test string representation.""" + patch = Patch2D(3, 2) + assert "6 vertices" in str(patch) + assert "7 edges" in str(patch) + + +# Torus2D tests + + +def test_torus2d_init_basic(): + """Test basic Torus2D initialization.""" + torus = Torus2D(3, 2) + assert torus.nvertices == 6 + # 3 horizontal edges per row * 2 rows + 3 vertical edges per column * 2 = 12 + assert torus.nedges == 12 + assert torus.width == 3 + assert 
torus.height == 2 + + +def test_torus2d_single_vertex(): + """Test Torus2D with a single vertex (self-edge in both directions).""" + torus = Torus2D(1, 1) + assert torus.nvertices == 1 + # One horizontal wrap + one vertical wrap, both connect vertex 0 to itself + assert torus.nedges == 2 + + +def test_torus2d_single_row(): + """Test Torus2D with a single row (like Ring1D + vertical wraps).""" + torus = Torus2D(4, 1) + assert torus.nvertices == 4 + # 4 horizontal + 4 vertical wraps + assert torus.nedges == 8 + + +def test_torus2d_single_column(): + """Test Torus2D with a single column.""" + torus = Torus2D(1, 4) + assert torus.nvertices == 4 + # 4 horizontal wraps + 4 vertical + assert torus.nedges == 8 + + +def test_torus2d_square(): + """Test Torus2D with a square lattice.""" + torus = Torus2D(3, 3) + assert torus.nvertices == 9 + # 3 horizontal * 3 rows + 3 vertical * 3 columns = 18 + assert torus.nedges == 18 + + +def test_torus2d_edges(): + """Test that Torus2D creates correct edges including wrap-around.""" + torus = Torus2D(2, 2) + edges = list(torus.edges()) + # Should have 8 edges: 4 horizontal + 4 vertical + assert len(edges) == 8 + # Vertices: 0=(0,0), 1=(1,0), 2=(0,1), 3=(1,1) + edge_sets = [set(e.vertices) for e in edges] + # Horizontal edges (including wraps) + assert {0, 1} in edge_sets # (0,0)-(1,0) + assert {2, 3} in edge_sets # (0,1)-(1,1) + # Vertical edges (including wraps) + assert {0, 2} in edge_sets # (0,0)-(0,1) + assert {1, 3} in edge_sets # (1,0)-(1,1) + + +def test_torus2d_vertices(): + """Test that Torus2D vertices are correct.""" + torus = Torus2D(3, 2) + vertices = list(torus.vertices()) + assert vertices == [0, 1, 2, 3, 4, 5] + + +def test_torus2d_with_self_loops(): + """Test Torus2D with self-loops enabled.""" + torus = Torus2D(3, 2, self_loops=True) + assert torus.nvertices == 6 + # 6 self-loops + 12 nearest-neighbor edges = 18 + assert torus.nedges == 18 + + +def test_torus2d_self_loops_edges(): + """Test that self-loop edges are 
created correctly.""" + torus = Torus2D(2, 2, self_loops=True) + edges = list(torus.edges()) + # First 4 edges should be self-loops + assert edges[0].vertices == [0] + assert edges[1].vertices == [1] + assert edges[2].vertices == [2] + assert edges[3].vertices == [3] + + +def test_torus2d_parts_without_self_loops(): + """Test edge partitioning without self-loops.""" + torus = Torus2D(4, 4) + # Should have 4 parts: horizontal even/odd, vertical even/odd + assert len(torus.parts) == 4 + + +def test_torus2d_parts_with_self_loops(): + """Test edge partitioning with self-loops.""" + torus = Torus2D(3, 3, self_loops=True) + # Should have 5 parts: self-loops + 4 edge groups + assert len(torus.parts) == 5 + + +def test_torus2d_parts_non_overlapping(): + """Test that edges in the same part don't share vertices.""" + torus = Torus2D(4, 4) + for part_indices in torus.parts: + used_vertices = set() + for idx in part_indices: + edge = torus._edge_list[idx] + assert not any(v in used_vertices for v in edge.vertices) + used_vertices.update(edge.vertices) + + +def test_torus2d_str(): + """Test string representation.""" + torus = Torus2D(3, 2) + assert "6 vertices" in str(torus) + assert "12 edges" in str(torus) + + +def test_torus2d_vs_patch2d_edge_count(): + """Test that torus has more edges than patch of same dimensions.""" + for width in range(2, 5): + for height in range(2, 5): + patch = Patch2D(width, height) + torus = Torus2D(width, height) + # Torus has width + height extra edges (wrapping) + assert torus.nedges == patch.nedges + width + height + + +def test_patch2d_inherits_hypergraph(): + """Test that Patch2D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + patch = Patch2D(3, 3) + assert isinstance(patch, Hypergraph) + # Test inherited methods work + assert hasattr(patch, "edges") + assert hasattr(patch, "vertices") + assert hasattr(patch, "edges_by_part") + + +def test_torus2d_inherits_hypergraph(): + """Test 
that Torus2D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.geometry.hypergraph import Hypergraph + + torus = Torus2D(3, 3) + assert isinstance(torus, Hypergraph) + # Test inherited methods work + assert hasattr(torus, "edges") + assert hasattr(torus, "vertices") + assert hasattr(torus, "edges_by_part") From ecd7e237e01d689ec01b815382f8d48ec3f30789 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Sat, 7 Feb 2026 09:10:24 +0100 Subject: [PATCH 10/45] ISA queries, trace queries, and enumeration. (#2916) This introduces ISA queries and trace queries in the Python API. With that one can implement a preliminary implementation for the central `estimate` function which can enumerate traces and ISA using these queries and perform estimation. It also moves some of the preliminary models from the test code to proper places in the API (`qsharp.qre.models`) with `AQREGateBased` as first architecture and `SurfaceCode` as first QEC model. Further it defines the application base class with a `QSharpApplication` as first implementation of it. 
--- source/pip/benchmarks/bench_qre.py | 26 +- source/pip/qsharp/qre/__init__.py | 27 +- source/pip/qsharp/qre/_application.py | 139 +++++ source/pip/qsharp/qre/_architecture.py | 25 + source/pip/qsharp/qre/_enumeration.py | 27 +- source/pip/qsharp/qre/_estimation.py | 52 ++ source/pip/qsharp/qre/_instruction.py | 8 +- source/pip/qsharp/qre/_isa_enumeration.py | 101 ++-- source/pip/qsharp/qre/_qre.py | 21 +- source/pip/qsharp/qre/_qre.pyi | 540 ++++++++++++++++- source/pip/qsharp/qre/_trace.py | 89 +++ source/pip/qsharp/qre/instruction_ids.py | 89 +-- source/pip/qsharp/qre/instruction_ids.pyi | 92 +++ source/pip/qsharp/qre/models/__init__.py | 7 + source/pip/qsharp/qre/models/qec/__init__.py | 6 + .../qsharp/qre/models/qec/_surface_code.py | 93 +++ .../pip/qsharp/qre/models/qubits/__init__.py | 6 + source/pip/qsharp/qre/models/qubits/_aqre.py | 65 +++ source/pip/src/qre.rs | 551 +++++++++++++++++- source/pip/tests/test_qre.py | 490 ++++++++-------- source/qre/src/trace.rs | 30 +- source/qre/src/trace/tests.rs | 2 +- .../src/trace/transforms/lattice_surgery.rs | 25 +- 23 files changed, 2046 insertions(+), 465 deletions(-) create mode 100644 source/pip/qsharp/qre/_application.py create mode 100644 source/pip/qsharp/qre/_estimation.py create mode 100644 source/pip/qsharp/qre/_trace.py create mode 100644 source/pip/qsharp/qre/instruction_ids.pyi create mode 100644 source/pip/qsharp/qre/models/__init__.py create mode 100644 source/pip/qsharp/qre/models/qec/__init__.py create mode 100644 source/pip/qsharp/qre/models/qec/_surface_code.py create mode 100644 source/pip/qsharp/qre/models/qubits/__init__.py create mode 100644 source/pip/qsharp/qre/models/qubits/_aqre.py diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index 6dd54f2af9..561aa2c0b4 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -3,6 +3,7 @@ import timeit from dataclasses import dataclass, KW_ONLY, field +from qsharp.qre.models 
import AQREGateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -35,30 +36,13 @@ def bench_enumerate_isas(): # Add the tests directory to sys.path to import test_qre # TODO: Remove this once the models in test_qre are moved to a proper module sys.path.append(os.path.join(os.path.dirname(__file__), "../tests")) - import test_qre # type: ignore + from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore - from qsharp.qre._isa_enumeration import ( - Context, - ISAQuery, - ProductNode, - ) - - ctx = Context(architecture=test_qre.ExampleArchitecture()) + ctx = AQREGateBased().context() # Hierarchical factory using from_components - query = ProductNode( - sources=[ - ISAQuery(test_qre.SurfaceCode), - ISAQuery( - test_qre.ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(test_qre.SurfaceCode), - ISAQuery(test_qre.ExampleFactory), - ] - ), - ), - ] + query = SurfaceCode.q() * ExampleLogicalFactory.q( + source=SurfaceCode.q() * ExampleFactory.q() ) number = 100 diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 771a23ea14..d6dbb24e29 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -1,38 +1,61 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from ._application import Application, QSharpApplication +from ._architecture import Architecture +from ._estimation import estimate from ._instruction import ( LOGICAL, PHYSICAL, Encoding, + ISATransform, constraint, instruction, - ISATransform, ) +from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT from ._qre import ( ISA, + InstructionFrontier, Constraint, ConstraintBound, + EstimationResult, + FactoryResult, ISARequirements, + Block, + Trace, block_linear_function, constant_function, linear_function, ) -from ._architecture import Architecture +from ._trace import LatticeSurgery, PSSPC, TraceQuery __all__ = [ "block_linear_function", "constant_function", "constraint", + "estimate", "instruction", "linear_function", + "Application", "Architecture", + "Block", "Constraint", "ConstraintBound", "Encoding", + "EstimationResult", + "FactoryResult", + "InstructionFrontier", "ISA", + "ISA_ROOT", + "ISAQuery", + "ISARefNode", "ISARequirements", "ISATransform", + "LatticeSurgery", + "PSSPC", + "QSharpApplication", + "Trace", + "TraceQuery", "LOGICAL", "PHYSICAL", ] diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py new file mode 100644 index 0000000000..43e81ea4eb --- /dev/null +++ b/source/pip/qsharp/qre/_application.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import types +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Protocol, + TypeVar, + Generator, + get_type_hints, + cast, +) + +from .._qsharp import logical_counts +from ..estimator import LogicalCounts +from ._enumeration import _enumerate_instances +from ._qre import Trace +from .instruction_ids import CCX, MEAS_Z, RZ, T + + +class DataclassProtocol(Protocol): + __dataclass_fields__: ClassVar[dict] + + +TraceParameters = TypeVar("TraceParameters", DataclassProtocol, types.NoneType) + + +class Application(ABC, Generic[TraceParameters]): + """ + An application defines a class of quantum computation problems along with a + method to generate traces for specific problem instances. + + We distinguish between application and trace parameters. The application + parameters define which particular instance of the application we want to + consider. The trace parameters define how to generate a trace. They + change the specific way in which we solve the problem, but not the problem + itself. + + For example, in quantum cryptography, the application parameters could + define the key size for an RSA prime product, while the trace parameters + define which algorithm to use to break the cryptography, as well as + parameters therein. 
+ """ + + @abstractmethod + def get_trace(self, parameters: TraceParameters) -> Trace: + """Return the trace corresponding to this application.""" + + def context(self, **kwargs) -> _Context: + """Create a new enumeration context for this application.""" + return _Context(self, **kwargs) + + def enumerate_traces( + self, + **kwargs, + ) -> Generator[Trace, None, None]: + """Yields all traces of an application given its dataclass parameters.""" + + param_type = get_type_hints(self.__class__.get_trace).get("parameters") + if param_type is types.NoneType: + yield self.get_trace(None) # type: ignore + return + + if isinstance(param_type, TypeVar): + for c in param_type.__constraints__: + if c is not types.NoneType: + param_type = c + break + for parameters in _enumerate_instances(cast(type, param_type), **kwargs): + yield self.get_trace(parameters) + + +class _Context: + application: Application + kwargs: dict[str, Any] + + def __init__(self, application: Application, **kwargs): + self.application = application + self.kwargs = kwargs + + +@dataclass +class QSharpApplication(Application[None]): + def __init__(self, entry_expr: str | Callable | LogicalCounts): + self._entry_expr = entry_expr + + def get_trace(self, parameters: None = None) -> Trace: + if not isinstance(self._entry_expr, LogicalCounts): + self._counts = logical_counts(self._entry_expr) + else: + self._counts = self._entry_expr + return self._trace_from_logical_counts(self._counts) + + def _trace_from_logical_counts(self, counts: LogicalCounts) -> Trace: + ccx_count = counts.get("cczCount", 0) + counts.get("ccixCount", 0) + + trace = Trace(counts.get("numQubits", 0)) + + rotation_count = counts.get("rotationCount", 0) + rotation_depth = counts.get("rotationDepth", rotation_count) + + if rotation_count != 0: + if rotation_depth > 1: + rotations_per_layer = rotation_count // (rotation_depth - 1) + else: + rotations_per_layer = 0 + + last_layer = rotation_count - (rotations_per_layer * (rotation_depth - 1)) + 
+ if rotations_per_layer != 0: + block = trace.add_block(repetitions=rotation_depth - 1) + for i in range(rotations_per_layer): + block.add_operation(RZ, [i]) + block = trace.add_block() + for i in range(last_layer): + block.add_operation(RZ, [i]) + + if t_count := counts.get("tCount", 0): + block = trace.add_block(repetitions=t_count) + block.add_operation(T, [0]) + + if ccx_count: + block = trace.add_block(repetitions=ccx_count) + block.add_operation(CCX, [0, 1, 2]) + + if meas_count := counts.get("measurementCount", 0): + block = trace.add_block(repetitions=meas_count) + block.add_operation(MEAS_Z, [0]) + + # TODO: handle memory qubits + + return trace diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 0d95bb0a93..fe991aff42 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass, field from ._qre import ISA @@ -10,3 +13,25 @@ class Architecture(ABC): @property @abstractmethod def provided_isa(self) -> ISA: ... + + def context(self) -> _Context: + """Create a new enumeration context for this architecture.""" + return _Context(self.provided_isa) + + +@dataclass(slots=True, frozen=True) +class _Context: + """ + Context passed through enumeration, holding shared state. + + Attributes: + root_isa: The root ISA for enumeration. 
+ """ + + root_isa: ISA + _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) + + def _with_binding(self, name: str, isa: ISA) -> _Context: + """Return a new context with an additional binding (internal use).""" + new_bindings = {**self._bindings, name: isa} + return _Context(self.root_isa, new_bindings) diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py index 59eb1a9582..d41b279d0c 100644 --- a/source/pip/qsharp/qre/_enumeration.py +++ b/source/pip/qsharp/qre/_enumeration.py @@ -1,7 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Generator, Type, TypeVar, Literal, get_args, get_origin +from typing import ( + Generator, + Type, + TypeVar, + Literal, + get_args, + get_origin, + get_type_hints, +) from dataclasses import MISSING from itertools import product from enum import Enum @@ -57,8 +65,13 @@ class MyConfig: yield cls(**kwargs) return - for field in fields.values(): + # Resolve type hints to handle stringified types from __future__.annotations + type_hints = get_type_hints(cls) + + for field in fields.values(): # type: ignore name = field.name + # Get resolved type or fallback to field.type + current_type = type_hints.get(name, field.type) if name in kwargs: val = kwargs[name] @@ -83,16 +96,16 @@ class MyConfig: values.append(domain) continue - if field.type is bool: + if current_type is bool: values.append([True, False]) continue - if isinstance(field.type, type) and issubclass(field.type, Enum): - values.append(list(field.type)) + if isinstance(current_type, type) and issubclass(current_type, Enum): + values.append(list(current_type)) continue - if get_origin(field.type) is Literal: - values.append(list(get_args(field.type))) + if get_origin(current_type) is Literal: + values.append(list(get_args(current_type))) continue if field.default is not MISSING: diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py new file 
mode 100644 index 0000000000..79b11b9eb7 --- /dev/null +++ b/source/pip/qsharp/qre/_estimation.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._application import Application +from ._architecture import Architecture +from ._qre import EstimationCollection, estimate_parallel +from ._trace import TraceQuery +from ._isa_enumeration import ISAQuery + + +def estimate( + application: Application, + architecture: Architecture, + trace_query: TraceQuery, + isa_query: ISAQuery, + *, + max_error: float = 1.0, +) -> EstimationCollection: + """ + Estimate the resource requirements for a given application instance and + architecture. + + The application instance might return multiple traces. Each of the traces + is transformed by the trace query, which applies several trace transforms in + sequence. Each transform may return multiple traces. Similarly, the + architecture's ISA is transformed by the ISA query, which applies several + ISA transforms in sequence, each of which may return multiple ISAs. The + estimation is performed for each combination of transformed trace and ISA. + The results are collected into an EstimationCollection and returned. + + The collection only contains the results that are optimal with respect to + the total number of qubits and the total runtime. + + Args: + application (Application): The quantum application to be estimated. + architecture (Architecture): The target quantum architecture. + trace_query (TraceQuery): The trace query to enumerate traces from the + application. + isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture. + + Returns: + EstimationCollection: A collection of estimation results. 
+ """ + + app_ctx = application.context() + arch_ctx = architecture.context() + + return estimate_parallel( + list(trace_query.enumerate(app_ctx)), + list(isa_query.enumerate(arch_ctx)), + max_error, + ) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index a74c97376b..9c4b24260e 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -6,7 +6,7 @@ from enum import IntEnum from ._enumeration import _enumerate_instances -from ._isa_enumeration import ISA_ROOT, BindingNode, ISAQuery, Node +from ._isa_enumeration import ISA_ROOT, _BindingNode, _ComponentQuery, ISAQuery from ._qre import ( ISA, Constraint, @@ -193,7 +193,7 @@ def enumerate_isas( yield from component.provided_isa(isa) @classmethod - def q(cls, *, source: Node | None = None, **kwargs) -> ISAQuery: + def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery: """ Creates an ISAQuery node for this transform. @@ -205,12 +205,12 @@ def q(cls, *, source: Node | None = None, **kwargs) -> ISAQuery: Returns: ISAQuery: An enumeration node representing this transform. """ - return ISAQuery( + return _ComponentQuery( cls, source=source if source is not None else ISA_ROOT, kwargs=kwargs ) @classmethod - def bind(cls, name: str, node: Node) -> BindingNode: + def bind(cls, name: str, node: ISAQuery) -> _BindingNode: """ Creates a BindingNode for this transform. diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 54908aa9a6..0cfe5e5940 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -9,11 +9,11 @@ from dataclasses import dataclass, field from typing import Generator -from ._architecture import Architecture +from ._architecture import _Context from ._qre import ISA -class Node(ABC): +class ISAQuery(ABC): """ Abstract base class for all nodes in the ISA enumeration tree. 
@@ -24,7 +24,7 @@ class Node(ABC): """ @abstractmethod - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields all ISA instances represented by this enumeration node. @@ -37,7 +37,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: """ pass - def __add__(self, other: Node) -> SumNode: + def __add__(self, other: ISAQuery) -> _SumNode: """ Performs a union of two enumeration nodes. @@ -59,19 +59,19 @@ def __add__(self, other: Node) -> SumNode: for isa in SurfaceCode.q() + ColorCode.q(): ... """ - if isinstance(self, SumNode) and isinstance(other, SumNode): + if isinstance(self, _SumNode) and isinstance(other, _SumNode): sources = self.sources + other.sources - return SumNode(sources) - elif isinstance(self, SumNode): + return _SumNode(sources) + elif isinstance(self, _SumNode): sources = self.sources + [other] - return SumNode(sources) - elif isinstance(other, SumNode): + return _SumNode(sources) + elif isinstance(other, _SumNode): sources = [self] + other.sources - return SumNode(sources) + return _SumNode(sources) else: - return SumNode([self, other]) + return _SumNode([self, other]) - def __mul__(self, other: Node) -> ProductNode: + def __mul__(self, other: ISAQuery) -> _ProductNode: """ Performs the cross product of two enumeration nodes. @@ -97,19 +97,19 @@ def __mul__(self, other: Node) -> ProductNode: for isa in SurfaceCode.q() * Factory.q(): ... 
""" - if isinstance(self, ProductNode) and isinstance(other, ProductNode): + if isinstance(self, _ProductNode) and isinstance(other, _ProductNode): sources = self.sources + other.sources - return ProductNode(sources) - elif isinstance(self, ProductNode): + return _ProductNode(sources) + elif isinstance(self, _ProductNode): sources = self.sources + [other] - return ProductNode(sources) - elif isinstance(other, ProductNode): + return _ProductNode(sources) + elif isinstance(other, _ProductNode): sources = [self] + other.sources - return ProductNode(sources) + return _ProductNode(sources) else: - return ProductNode([self, other]) + return _ProductNode([self, other]) - def bind(self, name: str, node: Node) -> "BindingNode": + def bind(self, name: str, node: ISAQuery) -> "_BindingNode": """Create a BindingNode with this node as the component. Args: @@ -124,40 +124,17 @@ def bind(self, name: str, node: Node) -> "BindingNode": .. code-block:: python ExampleErrorCorrection.q().bind("c", ISARefNode("c") * ISARefNode("c")) """ - return BindingNode(name=name, component=self, node=node) + return _BindingNode(name=name, component=self, node=node) @dataclass -class Context: - """ - Context passed through enumeration, holding shared state. - - Attributes: - architecture: The base architecture for enumeration. - """ - - architecture: Architecture - _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) - - @property - def root_isa(self) -> ISA: - """The architecture's provided ISA.""" - return self.architecture.provided_isa - - def _with_binding(self, name: str, isa: ISA) -> "Context": - """Return a new context with an additional binding (internal use).""" - new_bindings = {**self._bindings, name: isa} - return Context(self.architecture, new_bindings) - - -@dataclass -class RootNode(Node): +class RootNode(ISAQuery): """ Represents the architecture's base ISA. Reads from the context instead of holding a reference. 
""" - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields the architecture ISA from the context. @@ -175,7 +152,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ISAQuery(Node): +class _ComponentQuery(ISAQuery): """ Query node that enumerates ISAs based on a component type and source. @@ -191,10 +168,10 @@ class ISAQuery(Node): """ component: type - source: Node = field(default_factory=lambda: ISA_ROOT) + source: ISAQuery = field(default_factory=lambda: ISA_ROOT) kwargs: dict = field(default_factory=dict) - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs generated by the component from source ISAs. @@ -209,7 +186,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ProductNode(Node): +class _ProductNode(ISAQuery): """ Node representing the Cartesian product of multiple source nodes. @@ -217,9 +194,9 @@ class ProductNode(Node): sources: A list of source nodes to combine. """ - sources: list[Node] + sources: list[ISAQuery] - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs formed by combining ISAs from all source nodes. @@ -237,7 +214,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class SumNode(Node): +class _SumNode(ISAQuery): """ Node representing the union of multiple source nodes. @@ -245,9 +222,9 @@ class SumNode(Node): sources: A list of source nodes to enumerate sequentially. """ - sources: list[Node] + sources: list[ISAQuery] - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs from each source node in sequence. 
@@ -262,7 +239,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ISARefNode(Node): +class ISARefNode(ISAQuery): """ A reference to a bound ISA in the enumeration context. @@ -274,7 +251,7 @@ class ISARefNode(Node): name: str - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields the bound ISA from the context. @@ -293,7 +270,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class BindingNode(Node): +class _BindingNode(ISAQuery): """ Enumeration node that binds a component to a name. @@ -306,7 +283,7 @@ class BindingNode(Node): Args: name: The name to bind the component to. - component: An EnumerationNode (e.g., ISAQuery) that produces the bound ISAs. + component: An EnumerationNode (e.g., _ComponentQuery) that produces the bound ISAs. node: The child enumeration node that may contain ISARefNodes. Example: @@ -334,10 +311,10 @@ class BindingNode(Node): """ name: str - component: Node - node: Node + component: ISAQuery + node: ISAQuery - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Enumerates child nodes with the bound component in context. diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index c01b87587b..3fdd913414 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -2,16 +2,27 @@ # Licensed under the MIT License. 
# flake8: noqa E402 +# pyright: reportAttributeAccessIssue=false from .._native import ( - ISA, + block_linear_function, + Block, + constant_function, Constraint, ConstraintBound, - Instruction, - ISARequirements, + estimate_parallel, + EstimationCollection, + EstimationResult, + FactoryResult, FloatFunction, + Instruction, + InstructionFrontier, IntFunction, - block_linear_function, - constant_function, + ISA, + ISARequirements, + Property, linear_function, + LatticeSurgery, + PSSPC, + Trace, ) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 01d999b49e..85be2b136e 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Iterator, Optional, overload +from typing import Any, Iterator, Optional, overload class ISA: @overload @@ -44,6 +44,22 @@ class ISA: """ ... + def get( + self, id: int, default: Optional[Instruction] = None + ) -> Optional[Instruction]: + """ + Gets an instruction by its ID, or returns a default value if not found. + + Args: + id (int): The instruction ID. + default (Optional[Instruction]): The default value to return if the + instruction is not found. + + Returns: + Optional[Instruction]: The instruction, or the default value if not found. + """ + ... + def __len__(self) -> int: """ Returns the number of instructions in the ISA. @@ -422,3 +438,525 @@ def block_linear_function( IntFunction | FloatFunction: The block linear function. """ ... + +class Property: + def __new__(cls, value: Any) -> Property: + """ + Creates a property from a value. + + Args: + value (Any): The value. + """ + ... + + def as_bool(self) -> Optional[bool]: + """ + Returns the value as a boolean. + + Returns: + Optional[bool]: The value as a boolean, or None if it is not a boolean. + """ + ... + + def as_int(self) -> Optional[int]: + """ + Returns the value as an integer. 
+ + Returns: + Optional[int]: The value as an integer, or None if it is not an integer. + """ + ... + + def as_float(self) -> Optional[float]: + """ + Returns the value as a float. + + Returns: + Optional[float]: The value as a float, or None if it is not a float. + """ + ... + + def as_str(self) -> Optional[str]: + """ + Returns the value as a string. + + Returns: + Optional[str]: The value as a string, or None if it is not a string. + """ + ... + + def is_bool(self) -> bool: + """ + Checks if the value is a boolean. + + Returns: + bool: True if the value is a boolean, False otherwise. + """ + ... + + def is_int(self) -> bool: + """ + Checks if the value is an integer. + + Returns: + bool: True if the value is an integer, False otherwise. + """ + ... + + def is_float(self) -> bool: + """ + Checks if the value is a float. + + Returns: + bool: True if the value is a float, False otherwise. + """ + ... + + def is_str(self) -> bool: + """ + Checks if the value is a string. + + Returns: + bool: True if the value is a string, False otherwise. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the property. + + Returns: + str: A string representation of the property. + """ + ... + +class EstimationResult: + """ + Represents the result of a resource estimation. + """ + + @property + def qubits(self) -> int: + """ + The number of logical qubits. + + Returns: + int: The number of logical qubits. + """ + ... + + @property + def runtime(self) -> int: + """ + The runtime in nanoseconds. + + Returns: + int: The runtime in nanoseconds. + """ + ... + + @property + def error(self) -> float: + """ + The error probability of the computation. + + Returns: + float: The error probability of the computation. + """ + ... + + @property + def factories(self) -> dict[int, FactoryResult]: + """ + The factory results. + + Returns: + dict[int, FactoryResult]: A dictionary mapping factory IDs to their results. + """ + ... 
+ + def __str__(self) -> str: + """ + Returns a string representation of the estimation result. + + Returns: + str: A string representation of the estimation result. + """ + ... + +class EstimationCollection: + """ + Represents a collection of estimation results. Results are stored as a 2D + Pareto frontier with physical qubits and runtime as objectives. + """ + + def __new__(cls) -> EstimationCollection: + """ + Creates a new estimation collection. + + Returns: + EstimationCollection: The estimation collection. + """ + ... + + def insert(self, result: EstimationResult) -> None: + """ + Inserts an estimation result into the collection. + + Args: + result (EstimationResult): The estimation result to insert. + """ + ... + + def __len__(self) -> int: + """ + Returns the number of estimation results in the collection. + + Returns: + int: The number of estimation results. + """ + ... + + def __iter__(self) -> Iterator[EstimationResult]: + """ + Returns an iterator over the estimation results. + + Returns: + Iterator[EstimationResult]: The estimation result iterator. + """ + ... + +class FactoryResult: + """ + Represents the result of a factory used in resource estimation. + """ + + @property + def copies(self) -> int: + """ + The number of factory copies. + + Returns: + int: The number of factory copies. + """ + ... + + @property + def runs(self) -> int: + """ + The number of factory runs. + + Returns: + int: The number of factory runs. + """ + ... + + @property + def error_rate(self) -> float: + """ + The error rate of the factory. + + Returns: + float: The error rate of the factory. + """ + ... + + @property + def states(self) -> int: + """ + The number of states produced by the factory. + + Returns: + int: The number of states produced by the factory. + """ + ... + +class Trace: + """ + Represents a quantum program optimized for resource estimation. + + A trace originates from a quantum application and can be modified via trace + transformations. 
It consists of blocks of operations. + """ + + def __new__(cls, compute_qubits: int) -> Trace: + """ + Creates a new trace. + + Returns: + Trace: The trace. + """ + ... + + def clone_empty(self, compute_qubits: Optional[int] = None) -> Trace: + """ + Creates a new trace with the same metadata but empty block. + + Args: + compute_qubits (Optional[int]): The number of compute qubits. If None, + the number of compute qubits of the original trace is used. + + Returns: + Trace: The new trace. + """ + ... + + @property + def compute_qubits(self) -> int: + """ + The number of compute qubits. + + Returns: + int: The number of compute qubits. + """ + ... + + @property + def base_error(self) -> float: + """ + The base error of the trace. + + Returns: + float: The base error of the trace. + """ + ... + + def increment_base_error(self, amount: float) -> None: + """ + Increments the base error. + + Args: + amount (float): The amount to increment. + """ + ... + + def increment_resource_state(self, resource_id: int, amount: int) -> None: + """ + Increments a resource state count. + + Args: + resource_id (int): The resource state ID. + amount (int): The amount to increment. + """ + ... + + def set_property(self, key: str, value: Property) -> None: + """ + Sets a property. + + Args: + key (str): The property key. + value (Property): The property value. + """ + ... + + def get_property(self, key: str) -> Optional[Property]: + """ + Gets a property. + + Args: + key (str): The property key. + + Returns: + Optional[Property]: The property value, or None if not found. + """ + ... + + @property + def depth(self) -> int: + """ + The trace depth. + + Returns: + int: The trace depth. + """ + ... + + def estimate( + self, isa: ISA, max_error: Optional[float] = None + ) -> Optional[EstimationResult]: + """ + Estimates resources for the trace given a logical ISA. + + Args: + isa (ISA): The logical ISA. + max_error (Optional[float]): The maximum allowed error. 
If None, + Pareto points are computed. + + Returns: + Optional[EstimationResult]: The estimation result if max_error is + provided, otherwise valid Pareto points. + """ + ... # The implementation in Rust returns Option, so it fits + + @property + def resource_states(self) -> dict[int, int]: + """ + The resource states used in the trace. + + Returns: + dict[int, int]: A dictionary mapping instruction IDs to their counts. + """ + ... + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Adds an operation to the trace. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Adds a block to the trace. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the trace. + + Returns: + str: A string representation of the trace. + """ + ... + +class Block: + """ + Represents a block of operations in a trace. + + An operation in a block can either refer to an instruction applied to some + qubits or can be another block to create a hierarchical structure. Blocks + can be repeated. + """ + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Adds an operation to the block. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Adds a nested block to the block. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the block. 
+ + Returns: + str: A string representation of the block. + """ + ... + +class PSSPC: + def __new__(cls, num_ts_per_rotation: int, ccx_magic_states: bool) -> PSSPC: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class LatticeSurgery: + def __new__(cls, slow_down_factor: float) -> LatticeSurgery: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class InstructionFrontier: + """ + Represents a Pareto frontier of instructions with space, time, and error + rates as objectives. + """ + + def __new__(cls) -> InstructionFrontier: + """ + Creates a new instruction frontier. + """ + ... + + def insert(self, point: Instruction): + """ + Inserts an instruction to the frontier. + + Args: + point (Instruction): The instruction to insert. + """ + ... + + def __len__(self) -> int: + """ + Returns the number of instructions in the frontier. + + Returns: + int: The number of instructions. + """ + ... + + def __iter__(self) -> Iterator[Instruction]: + """ + Returns an iterator over the instructions in the frontier. + + Returns: + Iterator[Instruction]: The iterator. + """ + ... + + @staticmethod + def load(filename: str) -> InstructionFrontier: + """ + Loads an instruction frontier from a file. + + Args: + filename (str): The file name. + + Returns: + InstructionFrontier: The loaded instruction frontier. + """ + ... + + def dump(self, filename: str) -> None: + """ + Dumps the instruction frontier to a file. + + Args: + filename (str): The file name. + """ + ... + +def estimate_parallel( + traces: list[Trace], isas: list[ISA], max_error: float = 1.0 +) -> EstimationCollection: + """ + Estimates resources for multiple traces and ISAs in parallel. + + Args: + traces (list[Trace]): The list of traces. + isas (list[ISA]): The list of ISAs. + max_error (float): The maximum allowed error. The default is 1.0. + + Returns: + EstimationCollection: The estimation collection. + """ + ... 
diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py new file mode 100644 index 0000000000..ab1d49f6ce --- /dev/null +++ b/source/pip/qsharp/qre/_trace.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass, KW_ONLY, field +from itertools import product +from typing import Any, Optional, Generator, Type +from ._application import _Context +from ._enumeration import _enumerate_instances +from ._qre import PSSPC as _PSSPC, LatticeSurgery as _LatticeSurgery, Trace + + +class TraceTransform(ABC): + @abstractmethod + def transform(self, trace: Trace) -> Optional[Trace]: ... + + @classmethod + def q(cls, **kwargs) -> TraceQuery: + return TraceQuery(cls, **kwargs) + + +@dataclass +class PSSPC(TraceTransform): + _: KW_ONLY + num_ts_per_rotation: int = field( + default=10, metadata={"domain": list(range(1, 21))} + ) + ccx_magic_states: bool = field(default=False) + + def __post_init__(self): + self._psspc = _PSSPC(self.num_ts_per_rotation, self.ccx_magic_states) + + def transform(self, trace: Trace) -> Optional[Trace]: + return self._psspc.transform(trace) + + +@dataclass +class LatticeSurgery(TraceTransform): + _: KW_ONLY + slow_down_factor: float = field(default=1.0, metadata={"domain": [1.0]}) + + def __post_init__(self): + self._lattice_surgery = _LatticeSurgery(self.slow_down_factor) + + def transform(self, trace: Trace) -> Optional[Trace]: + return self._lattice_surgery.transform(trace) + + +class _Node(ABC): + @abstractmethod + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: ... 
+ + +class RootNode(_Node): + # NOTE: this might be redundant with TransformationNode with an empty sequence + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: + yield from ctx.application.enumerate_traces(**ctx.kwargs) + + +class TraceQuery(_Node): + sequence: list[tuple[Type, dict[str, Any]]] + + def __init__(self, t: Type, **kwargs): + self.sequence = [(t, kwargs)] + + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: + for trace in ctx.application.enumerate_traces(**ctx.kwargs): + if not self.sequence: + yield trace + continue + + transformer_instances = [] + + for t, transformer_kwargs in self.sequence: + instances = _enumerate_instances(t, **transformer_kwargs) + transformer_instances.append(instances) + + # TODO: make parallel + for sequence in product(*transformer_instances): + transformed = trace + for transformer in sequence: + transformed = transformer.transform(transformed) + yield transformed + + def __mul__(self, other: TraceQuery) -> TraceQuery: + new_query = TraceQuery.__new__(TraceQuery) + new_query.sequence = self.sequence + other.sequence + return new_query diff --git a/source/pip/qsharp/qre/instruction_ids.py b/source/pip/qsharp/qre/instruction_ids.py index f89bcc6c5b..cec4a9c070 100644 --- a/source/pip/qsharp/qre/instruction_ids.py +++ b/source/pip/qsharp/qre/instruction_ids.py @@ -1,91 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# pyright: reportAttributeAccessIssue=false -################### -# Instruction IDs # -################### -# Paulis -PAULI_I = 0x0 -PAULI_X = 0x1 -PAULI_Y = 0x2 -PAULI_Z = 0x3 +from .._native import instruction_ids -# Clifford gates -H = H_XZ = 0x10 -H_XY = 0x11 -H_YZ = 0x12 -SQRT_X = 0x13 -SQRT_X_DAG = 0x14 -SQRT_Y = 0x15 -SQRT_Y_DAG = 0x16 -S = SQRT_Z = 0x17 -S_DAG = SQRT_Z_DAG = 0x18 -CNOT = CX = 0x19 -CY = 0x1A -CZ = 0x1B -SWAP = 0x1C - -# State preparation -PREP_X = 0x30 -PREP_Y = 0x31 -PREP_Z = 0x32 - -# Generic Cliffords -ONE_QUBIT_CLIFFORD = 0x50 -TWO_QUBIT_CLIFFORD = 0x51 -N_QUBIT_CLIFFORD = 0x52 - -# Measurements -MEAS_X = 0x100 -MEAS_Y = 0x101 -MEAS_Z = 0x102 -MEAS_RESET_X = 0x103 -MEAS_RESET_Y = 0x104 -MEAS_RESET_Z = 0x105 -MEAS_XX = 0x106 -MEAS_YY = 0x107 -MEAS_ZZ = 0x108 -MEAS_XZ = 0x109 -MEAS_XY = 0x10A -MEAS_YZ = 0x10B - -# Non-Clifford gates -SQRT_SQRT_X = 0x400 -SQRT_SQRT_X_DAG = 0x401 -SQRT_SQRT_Y = 0x402 -SQRT_SQRT_Y_DAG = 0x403 -SQRT_SQRT_Z = T = 0x404 -SQRT_SQRT_Z_DAG = T_DAG = 0x405 -CCX = 0x406 -CCY = 0x407 -CCZ = 0x408 -CSWAP = 0x409 -AND = 0x40A -AND_DAG = 0x40B -RX = 0x40C -RY = 0x40D -RZ = 0x40E -CRX = 0x40F -CRY = 0x410 -CRZ = 0x411 -RXX = 0x412 -RYY = 0x413 -RZZ = 0x414 - -# Multi-qubit Pauli measurement -MULTI_PAULI_MEAS = 0x1000 - -# Some generic logical instructions -LATTICE_SURGERY = 0x1100 - -# Memory/compute operations (used in compute parts of memory-compute layouts) -READ_FROM_MEMORY = 0x1200 -WRITE_TO_MEMORY = 0x1201 - -# Some special hardware physical instructions -CYCLIC_SHIFT = 0x1300 - -# Generic operation (for unified RE) -GENERIC = 0xFFFF +for name in instruction_ids.__all__: + globals()[name] = getattr(instruction_ids, name) diff --git a/source/pip/qsharp/qre/instruction_ids.pyi b/source/pip/qsharp/qre/instruction_ids.pyi new file mode 100644 index 0000000000..72934487f8 --- /dev/null +++ b/source/pip/qsharp/qre/instruction_ids.pyi @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +# Paulis +PAULI_I: int +PAULI_X: int +PAULI_Y: int +PAULI_Z: int + +# Clifford gates +H: int +H_XZ: int +H_XY: int +H_YZ: int +SQRT_X: int +SQRT_X_DAG: int +SQRT_Y: int +SQRT_Y_DAG: int +S: int +SQRT_Z: int +S_DAG: int +SQRT_Z_DAG: int +CNOT: int +CX: int +CY: int +CZ: int +SWAP: int + +# State preparation +PREP_X: int +PREP_Y: int +PREP_Z: int + +# Generic Cliffords +ONE_QUBIT_CLIFFORD: int +TWO_QUBIT_CLIFFORD: int +N_QUBIT_CLIFFORD: int + +# Measurements +MEAS_X: int +MEAS_Y: int +MEAS_Z: int +MEAS_RESET_X: int +MEAS_RESET_Y: int +MEAS_RESET_Z: int +MEAS_XX: int +MEAS_YY: int +MEAS_ZZ: int +MEAS_XZ: int +MEAS_XY: int +MEAS_YZ: int + +# Non-Clifford gates +SQRT_SQRT_X: int +SQRT_SQRT_X_DAG: int +SQRT_SQRT_Y: int +SQRT_SQRT_Y_DAG: int +SQRT_SQRT_Z: int +T: int +SQRT_SQRT_Z_DAG: int +T_DAG: int +CCX: int +CCY: int +CCZ: int +CSWAP: int +AND: int +AND_DAG: int +RX: int +RY: int +RZ: int +CRX: int +CRY: int +CRZ: int +RXX: int +RYY: int +RZZ: int + +# Multi-qubit Pauli measurement +MULTI_PAULI_MEAS: int + +# Some generic logical instructions +LATTICE_SURGERY: int + +# Memory/compute operations (used in compute parts of memory-compute layouts) +READ_FROM_MEMORY: int +WRITE_TO_MEMORY: int + +# Some special hardware physical instructions +CYCLIC_SHIFT: int + +# Generic operation (for unified RE) +GENERIC: int diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py new file mode 100644 index 0000000000..10a82c977e --- /dev/null +++ b/source/pip/qsharp/qre/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from .qec import SurfaceCode +from .qubits import AQREGateBased + +__all__ = ["SurfaceCode", "AQREGateBased"] diff --git a/source/pip/qsharp/qre/models/qec/__init__.py b/source/pip/qsharp/qre/models/qec/__init__.py new file mode 100644 index 0000000000..c813df0dc4 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._surface_code import SurfaceCode + +__all__ = ["SurfaceCode"] diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py new file mode 100644 index 0000000000..52bf94439f --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator +from ..._instruction import ( + ISA, + ISARequirements, + ISATransform, + instruction, + constraint, + ConstraintBound, + LOGICAL, +) +from ..._qre import linear_function +from ...instruction_ids import CNOT, GENERIC, H, LATTICE_SURGERY, MEAS_Z + + +@dataclass +class SurfaceCode(ISATransform): + """ + Attributes: + crossing_prefactor: float + The prefactor for logical error rate due to error correction + crossings. (Default is 0.03, see Eq. (11) in arXiv:1208.0928) + error_correction_threshold: float + The error correction threshold for the surface code. Default is + 0.01 (1%), see arXiv:1009.3686. + + Hyper parameters: + distance: int + The code distance of the surface code. 
+ + References: + - [arXiv:1208.0928](https://arxiv.org/abs/1208.0928) + - [arXiv:1009.3686](https://arxiv.org/abs/1009.3686) + """ + + crossing_prefactor: float = 0.03 + error_correction_threshold: float = 0.01 + _: KW_ONLY + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(H, error_rate=ConstraintBound.lt(0.01)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), + constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), + ) + + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + cnot_time = impl_isa[CNOT].expect_time() + h_time = impl_isa[H].expect_time() + meas_time = impl_isa[MEAS_Z].expect_time() + + physical_error_rate = max( + impl_isa[CNOT].expect_error_rate(), + impl_isa[H].expect_error_rate(), + impl_isa[MEAS_Z].expect_error_rate(), + ) + + space_formula = linear_function(2 * self.distance**2) + + time_value = (h_time + meas_time + cnot_time * 4) * self.distance + + error_formula = linear_function( + self.crossing_prefactor + * ( + (physical_error_rate / self.error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + yield ISA( + instruction( + GENERIC, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ), + instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ), + ) diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py new file mode 100644 index 0000000000..f9907adbc3 --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from ._aqre import AQREGateBased + +__all__ = ["AQREGateBased"] diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_aqre.py new file mode 100644 index 0000000000..b6add8ae2d --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/_aqre.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field + +from ..._architecture import Architecture +from ...instruction_ids import CNOT, CZ, MEAS_Z, PAULI_I, H, T +from ..._instruction import ISA, Encoding, instruction + + +@dataclass +class AQREGateBased(Architecture): + """ + References: + - [arXiv:2211.07629](https://arxiv.org/abs/2211.07629) + """ + + _: KW_ONLY + error_rate: float = field(default=1e-4) + + @property + def provided_isa(self) -> ISA: + return ISA( + instruction( + PAULI_I, + encoding=Encoding.PHYSICAL, + arity=1, + time=50, + error_rate=self.error_rate, + ), + instruction( + CNOT, + encoding=Encoding.PHYSICAL, + arity=2, + time=50, + error_rate=self.error_rate, + ), + instruction( + CZ, + encoding=Encoding.PHYSICAL, + arity=2, + time=50, + error_rate=self.error_rate, + ), + instruction( + H, + encoding=Encoding.PHYSICAL, + arity=1, + time=50, + error_rate=self.error_rate, + ), + instruction( + MEAS_Z, + encoding=Encoding.PHYSICAL, + arity=1, + time=100, + error_rate=self.error_rate, + ), + instruction( + T, + encoding=Encoding.PHYSICAL, + time=50, + error_rate=self.error_rate, + ), + ) diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index fd8e80a5cd..d9e870990c 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1,22 +1,48 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-use pyo3::{IntoPyObjectExt, prelude::*, types::PyTuple}; +use std::ptr::NonNull; + +use pyo3::{ + IntoPyObjectExt, + exceptions::{PyException, PyKeyError, PyTypeError}, + prelude::*, + types::{PyDict, PyTuple}, +}; +use qre::TraceTransform; +use serde::{Deserialize, Serialize}; pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(constant_function, m)?)?; m.add_function(wrap_pyfunction!(linear_function, m)?)?; m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; + m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; + + m.add("EstimationError", m.py().get_type::())?; + + add_instruction_ids(m)?; + Ok(()) } +pyo3::create_exception!(qsharp.qre, EstimationError, PyException); + #[allow(clippy::upper_case_acronyms)] #[pyclass] pub struct ISA(qre::ISA); @@ -63,12 +89,20 @@ impl ISA { pub fn __getitem__(&self, id: u64) -> PyResult { match self.0.get(&id) { Some(instr) => Ok(Instruction(instr.clone())), - None => Err(PyErr::new::(format!( + None => Err(PyKeyError::new_err(format!( "Instruction with id {id} not found" ))), } } + #[pyo3(signature = (id, default=None))] + pub fn get(&self, id: u64, default: Option<&Instruction>) -> Option { + match self.0.get(&id) { + Some(instr) => Some(Instruction(instr.clone())), + None => default.cloned(), + } + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = ISAIterator { @@ -129,7 +163,10 @@ impl ISARequirements { } } +#[allow(clippy::unsafe_derive_deserialize)] #[pyclass] +#[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct Instruction(qre::Instruction); #[pymethods] @@ -227,6 
+264,24 @@ impl Instruction { } } +impl qre::ParetoItem3D for Instruction { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> Self::Objective1 { + self.0.expect_space(None) + } + + fn objective2(&self) -> Self::Objective2 { + self.0.expect_time(None) + } + + fn objective3(&self) -> Self::Objective3 { + self.0.expect_error_rate(None) + } +} + #[pyclass] pub struct Constraint(qre::InstructionConstraint); @@ -252,9 +307,7 @@ fn convert_encoding(encoding: u64) -> PyResult { match encoding { 0 => Ok(qre::Encoding::Physical), 1 => Ok(qre::Encoding::Logical), - _ => Err(PyErr::new::( - "Invalid encoding value", - )), + _ => Err(EstimationError::new_err("Invalid encoding value")), } } @@ -289,6 +342,61 @@ impl ConstraintBound { } } +#[pyclass] +pub struct Property(qre::Property); + +#[pymethods] +impl Property { + #[new] + pub fn new(value: &Bound<'_, PyAny>) -> PyResult { + if value.is_instance_of::() { + Ok(Property(qre::Property::new_bool(value.extract()?))) + } else if let Ok(i) = value.extract::() { + Ok(Property(qre::Property::new_int(i))) + } else if let Ok(f) = value.extract::() { + Ok(Property(qre::Property::new_float(f))) + } else { + Ok(Property(qre::Property::new_str(value.to_string()))) + } + } + + fn as_bool(&self) -> Option { + self.0.as_bool() + } + + fn as_int(&self) -> Option { + self.0.as_int() + } + + fn as_float(&self) -> Option { + self.0.as_float() + } + + fn as_str(&self) -> Option { + self.0.as_str().map(String::from) + } + + fn is_bool(&self) -> bool { + self.0.is_bool() + } + + fn is_int(&self) -> bool { + self.0.is_int() + } + + fn is_float(&self) -> bool { + self.0.is_float() + } + + fn is_str(&self) -> bool { + self.0.is_str() + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + #[pyclass] pub struct IntFunction(qre::VariableArityFunction); @@ -303,7 +411,7 @@ pub fn constant_function<'py>(value: &Bound<'py, PyAny>) -> PyResult( + Err(PyTypeError::new_err( "Value must be 
either an integer or a float", )) } @@ -316,7 +424,7 @@ pub fn linear_function<'py>(slope: &Bound<'py, PyAny>) -> PyResult() { FloatFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) } else { - Err(PyErr::new::( + Err(PyTypeError::new_err( "Slope must be either an integer or a float", )) } @@ -334,8 +442,435 @@ pub fn block_linear_function<'py>( FloatFunction(qre::VariableArityFunction::block_linear(block_size, s)) .into_bound_py_any(slope.py()) } else { - Err(PyErr::new::( + Err(PyTypeError::new_err( "Slope must be either an integer or a float", )) } } + +#[derive(Default)] +#[pyclass] +pub struct EstimationCollection(qre::EstimationCollection); + +#[pymethods] +impl EstimationCollection { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, result: &EstimationResult) { + self.0.insert(result.0.clone()); + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = EstimationCollectionIterator { + iter: slf.0.iter().cloned().collect::>().into_iter(), + }; + Py::new(slf.py(), iter) + } +} + +#[pyclass] +pub struct EstimationCollectionIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl EstimationCollectionIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(EstimationResult) + } +} + +#[pyclass] +pub struct EstimationResult(qre::EstimationResult); + +#[pymethods] +impl EstimationResult { + #[getter] + pub fn qubits(&self) -> u64 { + self.0.qubits() + } + + #[getter] + pub fn runtime(&self) -> u64 { + self.0.runtime() + } + + #[getter] + pub fn error(&self) -> f64 { + self.0.error() + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn factories(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + + for (id, factory) in self_.0.factories() { + 
dict.set_item(id, FactoryResult(factory.clone()))?; + } + + Ok(dict) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct FactoryResult(qre::FactoryResult); + +#[pymethods] +impl FactoryResult { + #[getter] + pub fn copies(&self) -> u64 { + self.0.copies() + } + + #[getter] + pub fn runs(&self) -> u64 { + self.0.runs() + } + + #[getter] + pub fn states(&self) -> u64 { + self.0.states() + } + + #[getter] + pub fn error_rate(&self) -> f64 { + self.0.error_rate() + } +} + +#[pyclass] +pub struct Trace(qre::Trace); + +#[pymethods] +impl Trace { + #[new] + pub fn new(compute_qubits: u64) -> Self { + Trace(qre::Trace::new(compute_qubits)) + } + + #[pyo3(signature = (compute_qubits = None))] + pub fn clone_empty(&self, compute_qubits: Option) -> Self { + Trace(self.0.clone_empty(compute_qubits)) + } + + #[getter] + pub fn compute_qubits(&self) -> u64 { + self.0.compute_qubits() + } + + #[getter] + pub fn base_error(&self) -> f64 { + self.0.base_error() + } + + pub fn increment_base_error(&mut self, amount: f64) { + self.0.increment_base_error(amount); + } + + pub fn set_property(&mut self, key: String, value: &Property) { + self.0.set_property(key, value.0.clone()); + } + + pub fn get_property(&self, key: &str) -> Option { + self.0.get_property(key).map(|p| Property(p.clone())) + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn resource_states(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + if let Some(resource_states) = self_.0.get_resource_states() { + for (resource_id, count) in resource_states { + if *count != 0 { + dict.set_item(resource_id, *count)?; + } + } + } + Ok(dict) + } + + #[getter] + pub fn depth(&self) -> u64 { + self.0.depth() + } + + #[pyo3(signature = (isa, max_error = None))] + pub fn estimate(&self, isa: &ISA, max_error: Option) -> Option { + self.0 + .estimate(&isa.0, max_error) + .map(EstimationResult) + .ok() + } + + #[pyo3(signature = (id, qubits, params = 
vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.0.add_operation(id, qubits, params); + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(mut slf: PyRefMut<'_, Self>, repetitions: u64) -> PyResult { + let block = slf.0.add_block(repetitions); + let ptr = NonNull::from(block); + Ok(Block { + ptr, + parent: slf.into(), + }) + } + + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { + self.0.increment_resource_state(resource_id, amount); + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass(unsendable)] +pub struct Block { + ptr: NonNull, + #[allow(dead_code)] + parent: Py, +} + +#[pymethods] +impl Block { + #[pyo3(signature = (id, qubits, params = vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + unsafe { self.ptr.as_mut() }.add_operation(id, qubits, params); + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(&mut self, py: Python<'_>, repetitions: u64) -> PyResult { + let block = unsafe { self.ptr.as_mut() }.add_block(repetitions); + let ptr = NonNull::from(block); + Ok(Block { + ptr, + parent: self.parent.clone_ref(py), + }) + } + + fn __str__(&self) -> String { + format!("{}", unsafe { self.ptr.as_ref() }) + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyclass] +pub struct PSSPC(qre::PSSPC); + +#[pymethods] +impl PSSPC { + #[new] + pub fn new(num_ts_per_rotation: u64, ccx_magic_states: bool) -> Self { + PSSPC(qre::PSSPC::new(num_ts_per_rotation, ccx_magic_states)) + } + + pub fn transform(&self, trace: &Trace) -> PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +#[derive(Default)] +#[pyclass] +pub struct LatticeSurgery(qre::LatticeSurgery); + +#[pymethods] +impl LatticeSurgery { + #[new] + pub fn new(slow_down_factor: f64) -> Self { + Self(qre::LatticeSurgery::new(slow_down_factor)) + } + + pub fn transform(&self, trace: &Trace) -> 
PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +#[pyclass] +pub struct InstructionFrontier(qre::ParetoFrontier3D); + +impl Default for InstructionFrontier { + fn default() -> Self { + InstructionFrontier(qre::ParetoFrontier3D::new()) + } +} + +#[pymethods] +impl InstructionFrontier { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, point: &Instruction) { + self.0.insert(point.clone()); + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = InstructionFrontierIterator { + iter: slf.0.iter().cloned().collect::>().into_iter(), + }; + Py::new(slf.py(), iter) + } + + #[staticmethod] + pub fn load(filename: &str) -> PyResult { + let content = std::fs::read_to_string(filename)?; + let frontier = + serde_json::from_str(&content).map_err(|e| EstimationError::new_err(format!("{e}")))?; + Ok(InstructionFrontier(frontier)) + } + + pub fn dump(&self, filename: &str) -> PyResult<()> { + let content = + serde_json::to_string(&self.0).map_err(|e| EstimationError::new_err(format!("{e}")))?; + Ok(std::fs::write(filename, content)?) 
+ } +} + +#[pyclass] +pub struct InstructionFrontierIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl InstructionFrontierIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next() + } +} + +#[allow(clippy::needless_pass_by_value)] +#[pyfunction(signature = (traces, isas, max_error = 1.0))] +pub fn estimate_parallel( + traces: Vec>, + isas: Vec>, + max_error: f64, +) -> EstimationCollection { + let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); + let isas: Vec<_> = isas.iter().map(|i| &i.0).collect(); + + let collection = qre::estimate_parallel(&traces, &isas, Some(max_error)); + EstimationCollection(collection) +} + +fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { + #[allow(clippy::wildcard_imports)] + use qre::instruction_ids::*; + + let instruction_ids = PyModule::new(m.py(), "instruction_ids")?; + + macro_rules! add_ids { + ($($name:ident),* $(,)?) => { + $(instruction_ids.add(stringify!($name), $name)?;)* + }; + } + + add_ids!( + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + H, + H_XZ, + H_XY, + H_YZ, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + S, + SQRT_Z, + S_DAG, + SQRT_Z_DAG, + CNOT, + CX, + CY, + CZ, + SWAP, + PREP_X, + PREP_Y, + PREP_Z, + ONE_QUBIT_CLIFFORD, + TWO_QUBIT_CLIFFORD, + N_QUBIT_CLIFFORD, + MEAS_X, + MEAS_Y, + MEAS_Z, + MEAS_RESET_X, + MEAS_RESET_Y, + MEAS_RESET_Z, + MEAS_XX, + MEAS_YY, + MEAS_ZZ, + MEAS_XZ, + MEAS_XY, + MEAS_YZ, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + T, + SQRT_SQRT_Z_DAG, + T_DAG, + CCX, + CCY, + CCZ, + CSWAP, + AND, + AND_DAG, + RX, + RY, + RZ, + CRX, + CRY, + CRZ, + RXX, + RYY, + RZZ, + MULTI_PAULI_MEAS, + LATTICE_SURGERY, + READ_FROM_MEMORY, + WRITE_TO_MEMORY, + CYCLIC_SHIFT, + GENERIC + ); + + m.add_submodule(&instruction_ids)?; + + Ok(()) +} diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 
90430f5167..98e1c9de59 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -5,33 +5,30 @@ from enum import Enum from typing import Generator +import qsharp from qsharp.qre import ( ISA, LOGICAL, - Architecture, - ConstraintBound, + PSSPC, + EstimationResult, ISARequirements, ISATransform, + LatticeSurgery, + QSharpApplication, + Trace, constraint, + estimate, instruction, linear_function, ) -from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre.models import SurfaceCode, AQREGateBased from qsharp.qre._isa_enumeration import ( - BindingNode, - Context, - ISAQuery, ISARefNode, - ProductNode, - SumNode, ) from qsharp.qre.instruction_ids import ( - CNOT, + CCX, GENERIC, LATTICE_SURGERY, - MEAS_Z, - TWO_QUBIT_CLIFFORD, - H, T, ) @@ -39,78 +36,6 @@ # pull requests and then moved out of the tests. -class ExampleArchitecture(Architecture): - @property - def provided_isa(self) -> ISA: - return ISA( - instruction(H, time=50, error_rate=1e-3), - instruction(CNOT, arity=2, time=50, error_rate=1e-3), - instruction(MEAS_Z, time=100, error_rate=1e-3), - instruction(TWO_QUBIT_CLIFFORD, arity=2, time=50, error_rate=1e-3), - instruction(GENERIC, time=50, error_rate=1e-4), - instruction(T, time=50, error_rate=1e-4), - ) - - -@dataclass -class SurfaceCode(ISATransform): - _: KW_ONLY - distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(H, error_rate=ConstraintBound.lt(0.01)), - constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), - constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), - ) - - def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: - crossing_prefactor: float = 0.03 - error_correction_threshold: float = 0.01 - - cnot_time = impl_isa[CNOT].expect_time() - h_time = impl_isa[H].expect_time() - meas_time = impl_isa[MEAS_Z].expect_time() - - physical_error_rate = max( - 
impl_isa[CNOT].expect_error_rate(), - impl_isa[H].expect_error_rate(), - impl_isa[MEAS_Z].expect_error_rate(), - ) - - space_formula = linear_function(2 * self.distance**2) - - time_value = (h_time + meas_time + cnot_time * 4) * self.distance - - error_formula = linear_function( - crossing_prefactor - * ( - (physical_error_rate / error_correction_threshold) - ** ((self.distance + 1) // 2) - ) - ) - - yield ISA( - instruction( - GENERIC, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), - instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), - ) - - @dataclass class ExampleFactory(ISATransform): _: KW_ONLY @@ -147,7 +72,7 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: def test_isa_from_architecture(): - arch = ExampleArchitecture() + arch = AQREGateBased() code = SurfaceCode() # Verify that the architecture satisfies the code requirements @@ -162,6 +87,8 @@ def test_isa_from_architecture(): def test_enumerate_instances(): + from qsharp.qre._enumeration import _enumerate_instances + instances = list(_enumerate_instances(SurfaceCode)) # There are 12 instances with distances from 3 to 25 @@ -184,6 +111,8 @@ def test_enumerate_instances(): def test_enumerate_instances_bool(): + from qsharp.qre._enumeration import _enumerate_instances + @dataclass class BoolConfig: _: KW_ONLY @@ -196,6 +125,8 @@ class BoolConfig: def test_enumerate_instances_enum(): + from qsharp.qre._enumeration import _enumerate_instances + class Color(Enum): RED = 1 GREEN = 2 @@ -214,6 +145,8 @@ class EnumConfig: def test_enumerate_instances_failure(): + from qsharp.qre._enumeration import _enumerate_instances + import pytest @dataclass @@ -227,6 +160,8 @@ class InvalidConfig: def test_enumerate_instances_single(): + from qsharp.qre._enumeration import _enumerate_instances + @dataclass class SingleConfig: value: int = 42 @@ -237,6 
+172,8 @@ class SingleConfig: def test_enumerate_instances_literal(): + from qsharp.qre._enumeration import _enumerate_instances + from typing import Literal @dataclass @@ -251,50 +188,32 @@ class LiteralConfig: def test_enumerate_isas(): - ctx = Context(architecture=ExampleArchitecture()) + ctx = AQREGateBased().context() # This will enumerate the 4 ISAs for the error correction code - count = sum(1 for _ in ISAQuery(SurfaceCode).enumerate(ctx)) + count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) assert count == 12 # This will enumerate the 2 ISAs for the error correction code when # restricting the domain - count = sum( - 1 for _ in ISAQuery(SurfaceCode, kwargs={"distance": [3, 5]}).enumerate(ctx) - ) + count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) assert count == 2 # This will enumerate the 3 ISAs for the factory - count = sum(1 for _ in ISAQuery(ExampleFactory).enumerate(ctx)) + count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) assert count == 3 # This will enumerate 36 ISAs for all products between the 12 error # correction code ISAs and the 3 factory ISAs - count = sum( - 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ).enumerate(ctx) - ) + count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) assert count == 36 # When providing a list, components are chained (OR operation). 
This # enumerates ISAs from first factory instance OR second factory instance count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - SumNode( - sources=[ - ISAQuery(ExampleFactory), - ISAQuery(ExampleFactory), - ] - ), - ] + for _ in ( + SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) ).enumerate(ctx) ) assert count == 72 @@ -304,13 +223,9 @@ def test_enumerate_isas(): # factory instance count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ISAQuery(ExampleFactory), - ] - ).enumerate(ctx) + for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( + ctx + ) ) assert count == 108 @@ -318,62 +233,32 @@ def test_enumerate_isas(): # from the product of other components as its source count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) ).enumerate(ctx) ) assert count == 1296 def test_binding_node(): - """Test BindingNode with ISARefNode for component bindings""" - ctx = Context(architecture=ExampleArchitecture()) + """Test binding nodes with ISARefNode for component bindings""" + ctx = AQREGateBased().context() # Test basic binding: same code used twice # Without binding: 12 codes × 12 codes = 144 combinations - count_without = sum( - 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(SurfaceCode), - ] - ).enumerate(ctx) - ) + count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) assert count_without == 144 # With binding: 12 codes (same instance used twice) count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx) + for _ in 
SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) ) assert count_with == 12 # Verify the binding works: with binding, both should use same params - for isa in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx): + for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): logical_gates = [g for g in isa if g.encoding == LOGICAL] # Should have 2 logical gates (GENERIC and LATTICE_SURGERY) assert len(logical_gates) == 2 @@ -381,33 +266,19 @@ def test_binding_node(): # Test binding with factories (nested bindings) count_without = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] + for _ in ( + SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() ).enumerate(ctx) ) assert count_without == 1296 # 12 * 3 * 12 * 3 count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=BindingNode( - name="f", - component=ISAQuery(ExampleFactory), - node=ProductNode( - sources=[ - ISARefNode("c"), - ISARefNode("f"), - ISARefNode("c"), - ISARefNode("f"), - ], - ), + for _ in SurfaceCode.bind( + "c", + ExampleFactory.bind( + "f", + ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), ), ).enumerate(ctx) ) @@ -417,19 +288,11 @@ def test_binding_node(): # Without binding: 4 outer codes × (4 inner codes × 3 factories × 3 levels) count_without = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q( + source=(SurfaceCode.q() * ExampleFactory.q()), + ) ).enumerate(ctx) ) assert count_without == 1296 # 12 * 12 * 3 * 3 @@ -437,22 +300,11 @@ def 
test_binding_node(): # With binding: 4 codes (same used twice) × 3 factories × 3 levels count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ - ISARefNode("c"), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISARefNode("c"), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in SurfaceCode.bind( + "c", + ISARefNode("c") + * ExampleLogicalFactory.q( + source=(ISARefNode("c") * ExampleFactory.q()), ), ).enumerate(ctx) ) @@ -461,44 +313,32 @@ def test_binding_node(): # Test binding with kwargs count_with_kwargs = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode, kwargs={"distance": 5}), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx) + for _ in SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) ) assert count_with_kwargs == 1 # Only distance=5 # Verify kwargs are applied - for isa in BindingNode( - name="c", - component=ISAQuery(SurfaceCode, kwargs={"distance": 5}), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx): + for isa in ( + SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ): logical_gates = [g for g in isa if g.encoding == LOGICAL] assert all(g.space(1) == 50 for g in logical_gates) # Test multiple independent bindings (nested) count = sum( 1 - for _ in BindingNode( - name="c1", - component=ISAQuery(SurfaceCode), - node=BindingNode( - name="c2", - component=ISAQuery(ExampleFactory), - node=ProductNode( - sources=[ - ISARefNode("c1"), - ISARefNode("c1"), - ISARefNode("c2"), - ISARefNode("c2"), - ], - ), + for _ in SurfaceCode.bind( + "c1", + ExampleFactory.bind( + "c2", + ISARefNode("c1") + * ISARefNode("c1") + * ISARefNode("c2") + * ISARefNode("c2"), ), ).enumerate(ctx) ) @@ -507,8 +347,8 @@ def test_binding_node(): def test_binding_node_errors(): - """Test error handling 
for BindingNode""" - ctx = Context(architecture=ExampleArchitecture()) + """Test error handling for binding nodes""" + ctx = AQREGateBased().context() # Test ISARefNode enumerate with undefined binding raises ValueError try: @@ -519,64 +359,208 @@ def test_binding_node_errors(): def test_product_isa_enumeration_nodes(): - terminal = ISAQuery(SurfaceCode) + from qsharp.qre._isa_enumeration import _ComponentQuery, _ProductNode + + terminal = SurfaceCode.q() query = terminal * terminal # Multiplication should create ProductNode - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 2 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Multiplying again should extend the sources query = query * terminal - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 3 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also from the other side query = terminal * query - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 4 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also for two ProductNodes query = query * query - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 8 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) def test_sum_isa_enumeration_nodes(): - terminal = ISAQuery(SurfaceCode) + from qsharp.qre._isa_enumeration import _ComponentQuery, _SumNode + + terminal = SurfaceCode.q() query = terminal + terminal # Multiplication should create SumNode - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 2 for source in query.sources: - assert isinstance(source, 
ISAQuery) + assert isinstance(source, _ComponentQuery) # Multiplying again should extend the sources query = query + terminal - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 3 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also from the other side query = terminal + query - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 4 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also for two SumNodes query = query + query - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 8 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) + + +def test_qsharp_application(): + from qsharp.qre._enumeration import _enumerate_instances + + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + trace = app.get_trace() + + assert trace.compute_qubits == 3 + assert trace.depth == 3 + assert trace.resource_states == {} + + isa = ISA( + instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + error_rate=linear_function(1e-6), + space=linear_function(50), + ), + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, space=400), + instruction(CCX, encoding=LOGICAL, time=2000, error_rate=1e-10, space=800), + ) + + # Properties from the program + counts = qsharp.logical_counts(code) + num_ts = counts["tCount"] + num_ccx = counts["cczCount"] + num_rotations = counts["rotationCount"] + rotation_depth = counts["rotationDepth"] + + lattice_surgery = LatticeSurgery() + + counter = 0 + for psspc in _enumerate_instances(PSSPC): + counter += 1 + trace2 = psspc.transform(trace) + assert trace2 is not None + trace2 = 
lattice_surgery.transform(trace2) + assert trace2 is not None + assert trace2.compute_qubits == 12 + assert ( + trace2.depth + == num_ts + + num_ccx * 3 + + num_rotations + + rotation_depth * psspc.num_ts_per_rotation + ) + if psspc.ccx_magic_states: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations, + CCX: num_ccx, + } + else: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx + } + result = trace2.estimate(isa, max_error=float("inf")) + assert result is not None + _assert_estimation_result(trace2, result, isa) + assert counter == 40 + + +def test_trace_enumeration(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + + from qsharp.qre._trace import RootNode + + ctx = app.context() + root = RootNode() + assert sum(1 for _ in root.enumerate(ctx)) == 1 + + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 40 + + assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 + + q = PSSPC.q() * LatticeSurgery.q() + assert sum(1 for _ in q.enumerate(ctx)) == 40 + + +def test_estimation_max_error(): + from qsharp.estimator import LogicalCounts + + app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) + arch = AQREGateBased() + + for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: + results = estimate( + app, + arch, + PSSPC.q() * LatticeSurgery.q(), + SurfaceCode.q() * ExampleFactory.q(), + max_error=max_error, + ) + + assert len(results) == 1 + assert next(iter(results)).error <= max_error + + +def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + actual_qubits = ( + isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + + isa[T].expect_space() * result.factories[T].copies + ) + if CCX in trace.resource_states: + actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies + assert result.qubits == 
actual_qubits + + assert ( + result.runtime + == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth + ) + + actual_error = ( + trace.base_error + + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth + + isa[T].expect_error_rate() * result.factories[T].states + ) + if CCX in trace.resource_states: + actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states + assert abs(result.error - actual_error) <= 1e-8 diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 0193b3c9db..c557251534 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -552,20 +552,28 @@ fn get_error_rate_by_id(isa: &ISA, id: u64) -> Result { .ok_or(Error::CannotExtractErrorRate(id)) } -fn estimate_chunks<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> Vec { - let mut local_collection = Vec::new(); - for trace in traces { - for isa in isas { - if let Ok(estimation) = trace.estimate(isa, None) { - local_collection.push(estimation); +#[must_use] +pub fn estimate_parallel<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, +) -> EstimationCollection { + fn estimate_chunks<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, + ) -> Vec { + let mut local_collection = Vec::new(); + for trace in traces { + for isa in isas { + if let Ok(estimation) = trace.estimate(isa, max_error) { + local_collection.push(estimation); + } } } + local_collection } - local_collection -} -#[must_use] -pub fn estimate_parallel<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> EstimationCollection { let mut collection = EstimationCollection::new(); std::thread::scope(|scope| { let num_threads = std::thread::available_parallelism() @@ -577,7 +585,7 @@ pub fn estimate_parallel<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> Estimati for chunk in traces.chunks(chunk_size) { let tx = tx.clone(); - scope.spawn(move || tx.send(estimate_chunks(chunk, isas))); + scope.spawn(move || tx.send(estimate_chunks(chunk, isas, 
max_error))); } drop(tx); diff --git a/source/qre/src/trace/tests.rs b/source/qre/src/trace/tests.rs index 6509b30048..57c422c8a4 100644 --- a/source/qre/src/trace/tests.rs +++ b/source/qre/src/trace/tests.rs @@ -144,7 +144,7 @@ fn test_lattice_surgery_transform() { assert_eq!(trace.depth(), 2); - let ls = LatticeSurgery::new(); + let ls = LatticeSurgery::default(); let transformed = ls.transform(&trace).expect("Transformation failed"); assert_eq!(transformed.compute_qubits(), 3); diff --git a/source/qre/src/trace/transforms/lattice_surgery.rs b/source/qre/src/trace/transforms/lattice_surgery.rs index 425606b99d..fd3ff45f72 100644 --- a/source/qre/src/trace/transforms/lattice_surgery.rs +++ b/source/qre/src/trace/transforms/lattice_surgery.rs @@ -4,21 +4,36 @@ use crate::trace::TraceTransform; use crate::{Error, Trace, instruction_ids}; -#[derive(Default)] -pub struct LatticeSurgery; +pub struct LatticeSurgery { + slow_down_factor: f64, +} + +impl Default for LatticeSurgery { + fn default() -> Self { + Self { + slow_down_factor: 1.0, + } + } +} impl LatticeSurgery { #[must_use] - pub fn new() -> Self { - Self + pub fn new(slow_down_factor: f64) -> Self { + Self { slow_down_factor } } } impl TraceTransform for LatticeSurgery { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss + )] fn transform(&self, trace: &Trace) -> Result { let mut transformed = trace.clone_empty(None); - let block = transformed.add_block(trace.depth()); + let block = + transformed.add_block((trace.depth() as f64 * self.slow_down_factor).ceil() as u64); block.add_operation( instruction_ids::LATTICE_SURGERY, (0..trace.compute_qubits()).collect(), From aee65fa88ca6542da910039fcf24740b57cb856c Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Mon, 9 Feb 2026 00:40:07 -0800 Subject: [PATCH 11/45] Magnets: changed implementation of edge coloring (#2925) Changed how edge coloring is implemented: - Old: self.part was a list of lists indicating the 
partitions of the edges. - New: self.color is a dictionary keyed on the vertices of the edge with values indicating the "color" or which part the edge belongs. - Property ncolors gives the number of parts (colors) in the partition Convention: self loops are all called "color" -1. Remaining colors are 0, 1, .... - Automatic indexing (creating lists of parts) then places the self loops at the end (index -1). Updated all the examples of graphs to using this coloring implementation - Chain1D, Ring1D, Patch2D, Torus2D, and CompleteBipartiteGraph have only minor changes. - CompleteGraph has a new implementation of edge coloring. All test files updated. --- .../pip/qsharp/magnets/geometry/complete.py | 56 +++++-- .../pip/qsharp/magnets/geometry/hypergraph.py | 81 +++++----- .../pip/qsharp/magnets/geometry/lattice1d.py | 42 ++--- .../pip/qsharp/magnets/geometry/lattice2d.py | 133 ++++++++-------- source/pip/tests/magnets/test_complete.py | 57 +++---- source/pip/tests/magnets/test_hypergraph.py | 95 +++++------ source/pip/tests/magnets/test_lattice1d.py | 148 +++++++++++------- source/pip/tests/magnets/test_lattice2d.py | 92 ++++++----- 8 files changed, 395 insertions(+), 309 deletions(-) diff --git a/source/pip/qsharp/magnets/geometry/complete.py b/source/pip/qsharp/magnets/geometry/complete.py index 595b5ff162..c38cdfc2b5 100644 --- a/source/pip/qsharp/magnets/geometry/complete.py +++ b/source/pip/qsharp/magnets/geometry/complete.py @@ -53,10 +53,42 @@ def __init__(self, n: int, self_loops: bool = False) -> None: for i in range(n): for j in range(i + 1, n): _edges.append(Hyperedge([i, j])) - super().__init__(_edges) - # To do: set up edge partitions + # Set colors for self-loop edges if enabled + if self_loops: + for i in range(n): + self.color[(i,)] = -1 # Self-loop edges get color -1 + + # Edge coloring for parallel updates + # The even case: n-1 colors are needed + if n % 2 == 0: + m = n - 1 + for i in range(m): + self.color[(i, n - 1)] = ( + i # Connect vertex n-1 to 
all others with unique colors + ) + for j in range(1, (m - 1) // 2 + 1): + a = (i + j) % m + b = (i - j) % m + if a < b: + self.color[(a, b)] = i + else: + self.color[(b, a)] = i + + # The odd case: n colors are needed + # This is the round-robin tournament scheduling algorithm for odd n + # Set m = n for ease of reading + else: + m = n + for i in range(m): + for j in range(1, (m - 1) // 2 + 1): + a = (i + j) % m + b = (i - j) % m + if a < b: + self.color[(a, b)] = i + else: + self.color[(b, a)] = i self.n = n @@ -105,22 +137,28 @@ def __init__(self, m: int, n: int, self_loops: bool = False) -> None: if self_loops: _edges = [Hyperedge([i]) for i in range(total_vertices)] - self.parts = [list(range(total_vertices))] + else: _edges = [] - self.parts = [] - - colors = [[] for _ in range(n)] # n colors for bipartite edges # Connect every vertex in first set to every vertex in second set for i in range(m): for j in range(m, m + n): edge_idx = len(_edges) _edges.append(Hyperedge([i, j])) - colors[(i + j - m) % n].append(edge_idx) # Do to: explain this coloring - super().__init__(_edges) - self.parts.extend(colors) + + # Set colors for self-loop edges if enabled + if self_loops: + for i in range(total_vertices): + self.color[(i,)] = -1 # Self-loop edges get color -1 + + # Color edges based on the second vertex index to create n parallel partitions + for i in range(m): + for j in range(m, m + n): + self.color[(i, j)] = ( + i + j - m + ) % n # Color edges based on second vertex index self.m = m self.n = n diff --git a/source/pip/qsharp/magnets/geometry/hypergraph.py b/source/pip/qsharp/magnets/geometry/hypergraph.py index f64dc79e63..706ef9a1b5 100644 --- a/source/pip/qsharp/magnets/geometry/hypergraph.py +++ b/source/pip/qsharp/magnets/geometry/hypergraph.py @@ -24,17 +24,17 @@ class Hyperedge: - Two-body interactions: 2 vertices - Multi-body interactions: 3+ vertices Each hyperedge is defined by a set of unique vertex indices, which are - stored in sorted order for 
consistency. + stored as a sorted tuple for consistency and hashability. Attributes: - vertices: Sorted list of vertex indices connected by this hyperedge. + vertices: Sorted tuple of vertex indices connected by this hyperedge. Example: .. code-block:: python >>> edge = Hyperedge([2, 0, 1]) >>> edge.vertices - [0, 1, 2] + (0, 1, 2) """ def __init__(self, vertices: list[int]) -> None: @@ -43,10 +43,10 @@ def __init__(self, vertices: list[int]) -> None: Args: vertices: List of vertex indices. Will be sorted internally. """ - self.vertices: list[int] = sorted(set(vertices)) + self.vertices: tuple[int, ...] = tuple(sorted(set(vertices))) def __repr__(self) -> str: - return f"Hyperedge({self.vertices})" + return f"Hyperedge({list(self.vertices)})" class Hypergraph: @@ -59,9 +59,9 @@ class Hypergraph: Attributes: _edge_list: List of hyperedges in the order they were added. _vertex_set: Set of all unique vertex indices in the hypergraph. - parts: List of lists, where each sublist contains indices of edges - belonging to a specific part of an edge partitioning. This is useful - for parallelism in certain architectures. + color: Dictionary mapping edge vertex tuples to color indices. Initially + all edges have color index 0. This is useful for parallelism in + certain architectures. 
Example: @@ -82,9 +82,15 @@ def __init__(self, edges: list[Hyperedge]) -> None: """ self._vertex_set = set() self._edge_list = edges - self.parts = [list(range(len(edges)))] # Single partition by default + self.color: dict[tuple[int, ...], int] = {} # All edges start with color 0 for edge in edges: self._vertex_set.update(edge.vertices) + self.color[edge.vertices] = 0 + + @property + def ncolors(self) -> int: + """Return the number of distinct colors used in the edge coloring.""" + return len(set(self.color.values())) @property def nedges(self) -> int: @@ -96,19 +102,18 @@ def nvertices(self) -> int: """Return the number of vertices in the hypergraph.""" return len(self._vertex_set) - def add_edge(self, edge: Hyperedge, part: int = 0) -> None: + def add_edge(self, edge: Hyperedge, color: int = 0) -> None: """Add a hyperedge to the hypergraph. Args: edge: The Hyperedge instance to add. - part: Partition index, used for implementations - with edge partitioning for parallel updates. By - default, all edges are added to the single part - with index 0. + color: Color index for the edge, used for implementations + with edge coloring for parallel updates. By + default, all edges are assigned color 0. """ self._edge_list.append(edge) self._vertex_set.update(edge.vertices) - self.parts[part].append(len(self._edge_list) - 1) # Add to specified partition + self.color[edge.vertices] = color def vertices(self) -> Iterator[int]: """Iterate over all vertex indices in the hypergraph. @@ -126,19 +131,18 @@ def edges(self) -> Iterator[Hyperedge]: """ return iter(self._edge_list) - def edges_by_part(self, part: int) -> Iterator[Hyperedge]: - """Iterate over hyperedges in a specific partition of the hypergraph. + def edges_by_color(self, color: int) -> Iterator[Hyperedge]: + """Iterate over hyperedges with a specific color. Args: - part: Partition index, used for implementations - with edge partitioning for parallel updates. By - default, all edges are in a single part with - index 0. 
+ color: Color index for filtering edges. Returns: - Iterator of hyperedges in the specified partition. + Iterator of hyperedges with the specified color. """ - return iter([self._edge_list[i] for i in self.parts[part]]) + return iter( + [edge for edge in self._edge_list if self.color[edge.vertices] == color] + ) def __str__(self) -> str: return f"Hypergraph with {self.nvertices} vertices and {self.nedges} edges." @@ -174,27 +178,27 @@ def greedy_edge_coloring( edge_indexes = list(range(hypergraph.nedges)) random.shuffle(edge_indexes) - best.parts = [[]] # Initialize with one empty color part - used_vertices = [set()] # Vertices used by each color + used_vertices: list[set[int]] = [set()] # Vertices used by each color + num_colors = 1 for i in range(len(edge_indexes)): edge = hypergraph._edge_list[edge_indexes[i]] - for j in range(len(best.parts) + 1): + for j in range(num_colors + 1): # If we've reached a new color, add it - if j == len(best.parts): - best.parts.append([]) + if j == num_colors: used_vertices.append(set()) + num_colors += 1 # Check if this edge can be added to color j # Note that we always match on the last color if it was added # if so, add it and break if not any(v in used_vertices[j] for v in edge.vertices): - best.parts[j].append(edge_indexes[i]) + best.color[edge.vertices] = j used_vertices[j].update(edge.vertices) break - least_colors = len(best.parts) + least_colors = num_colors # To do: parallelize over trials for trial in range(1, trials): @@ -208,28 +212,29 @@ def greedy_edge_coloring( edge_indexes = list(range(hypergraph.nedges)) random.shuffle(edge_indexes) - parts = [[]] # Initialize with one empty color part + edge_colors: dict[tuple[int, ...], int] = {} # Edge to color mapping used_vertices = [set()] # Vertices used by each color + num_colors = 1 for i in range(len(edge_indexes)): edge = hypergraph._edge_list[edge_indexes[i]] - for j in range(len(parts) + 1): + for j in range(num_colors + 1): # If we've reached a new color, add it - 
if j == len(parts): - parts.append([]) + if j == num_colors: used_vertices.append(set()) + num_colors += 1 # Check if this edge can be added to color j # if so, add it and break if not any(v in used_vertices[j] for v in edge.vertices): - parts[j].append(edge_indexes[i]) + edge_colors[edge.vertices] = j used_vertices[j].update(edge.vertices) break # If this trial used fewer colors, update best - if len(parts) < least_colors: - least_colors = len(parts) - best.parts = deepcopy(parts) + if num_colors < least_colors: + least_colors = num_colors + best.color = deepcopy(edge_colors) return best diff --git a/source/pip/qsharp/magnets/geometry/lattice1d.py b/source/pip/qsharp/magnets/geometry/lattice1d.py index a5a892fff4..9fffeecbe2 100644 --- a/source/pip/qsharp/magnets/geometry/lattice1d.py +++ b/source/pip/qsharp/magnets/geometry/lattice1d.py @@ -18,10 +18,10 @@ class Chain1D(Hypergraph): The chain has open boundary conditions, meaning the first and last vertices are not connected. - The edges are partitioned into two parts for parallel updates: - - Part 0 (if self_loops): Self-loop edges on each vertex - - Part 1: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) - - Part 2: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) + Edges are colored for parallel updates: + - Color -1 (if self_loops): Self-loop edges on each vertex + - Color 0: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) + - Color 1: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) Attributes: length: Number of vertices in the chain. 
@@ -46,6 +46,7 @@ def __init__(self, length: int, self_loops: bool = False) -> None: """ if self_loops: _edges = [Hyperedge([i]) for i in range(length)] + else: _edges = [] @@ -53,14 +54,14 @@ def __init__(self, length: int, self_loops: bool = False) -> None: _edges.append(Hyperedge([i, i + 1])) super().__init__(_edges) - # Set up edge partitions for parallel updates + # Update color for self-loop edges if self_loops: - self.parts = [list(range(length - 1))] - else: - self.parts = [] + for i in range(length): + self.color[(i,)] = -1 - self.parts.append(list(range(0, length - 1, 2))) - self.parts.append(list(range(1, length - 1, 2))) + for i in range(length - 1): + color = i % 2 + self.color[(i, i + 1)] = color self.length = length @@ -72,10 +73,10 @@ class Ring1D(Hypergraph): The ring has periodic boundary conditions, meaning the first and last vertices are connected. - The edges are partitioned into two parts for parallel updates: - - Part 0 (if self_loops): Self-loop edges on each vertex - - Part 1: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) - - Part 2: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) + Edges are colored for parallel updates: + - Color -1 (if self_loops): Self-loop edges on each vertex + - Color 0: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) + - Color 1: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) Attributes: length: Number of vertices in the ring. 
@@ -107,13 +108,14 @@ def __init__(self, length: int, self_loops: bool = False) -> None: _edges.append(Hyperedge([i, (i + 1) % length])) super().__init__(_edges) - # Set up edge partitions for parallel updates + # Update color for self-loop edges if self_loops: - self.parts = [list(range(length))] - else: - self.parts = [] + for i in range(length): + self.color[(i,)] = -1 - self.parts.append(list(range(0, length, 2))) - self.parts.append(list(range(1, length, 2))) + for i in range(length): + j = (i + 1) % length + color = i % 2 + self.color[tuple(sorted([i, j]))] = color self.length = length diff --git a/source/pip/qsharp/magnets/geometry/lattice2d.py b/source/pip/qsharp/magnets/geometry/lattice2d.py index fc98f9de9d..e04817ef92 100644 --- a/source/pip/qsharp/magnets/geometry/lattice2d.py +++ b/source/pip/qsharp/magnets/geometry/lattice2d.py @@ -19,12 +19,12 @@ class Patch2D(Hypergraph): Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. - The edges are partitioned into parts for parallel updates: - - Part 0 (if self_loops): Self-loop edges on each vertex - - Part 1: Even-column horizontal edges - - Part 2: Odd-column horizontal edges - - Part 3: Even-row vertical edges - - Part 4: Odd-row vertical edges + Edges are colored for parallel updates: + - Color -1 (if self_loops): Self-loop edges on each vertex + - Color 0: Even-column horizontal edges + - Color 1: Odd-column horizontal edges + - Color 2: Even-row vertical edges + - Color 3: Odd-row vertical edges Attributes: width: Number of vertices in the horizontal direction. @@ -49,9 +49,8 @@ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: self_loops: If True, include self-loop edges on each vertex for single-site terms. 
""" - - def index(x: int, y: int) -> int: - return y * width + x + self.width = width + self.height = height if self_loops: _edges = [Hyperedge([i]) for i in range(width * height)] @@ -59,44 +58,38 @@ def index(x: int, y: int) -> int: _edges = [] # Horizontal edges (connecting (x, y) to (x+1, y)) - horizontal_even = [] - horizontal_odd = [] for y in range(height): for x in range(width - 1): - edge_idx = len(_edges) - _edges.append(Hyperedge([index(x, y), index(x + 1, y)])) - if x % 2 == 0: - horizontal_even.append(edge_idx) - else: - horizontal_odd.append(edge_idx) + _edges.append(Hyperedge([self._index(x, y), self._index(x + 1, y)])) # Vertical edges (connecting (x, y) to (x, y+1)) - vertical_even = [] - vertical_odd = [] for y in range(height - 1): for x in range(width): - edge_idx = len(_edges) - _edges.append(Hyperedge([index(x, y), index(x, y + 1)])) - if y % 2 == 0: - vertical_even.append(edge_idx) - else: - vertical_odd.append(edge_idx) - + _edges.append(Hyperedge([self._index(x, y), self._index(x, y + 1)])) super().__init__(_edges) - # Set up edge partitions for parallel updates + # Set up edge colors for parallel updates if self_loops: - self.parts = [list(range(width * height))] - else: - self.parts = [] + for i in range(width * height): + self.color[(i,)] = -1 + + # Color horizontal edges + for y in range(height): + for x in range(width - 1): + v1, v2 = self._index(x, y), self._index(x + 1, y) + color = 0 if x % 2 == 0 else 1 + self.color[tuple(sorted([v1, v2]))] = color - self.parts.append(horizontal_even) - self.parts.append(horizontal_odd) - self.parts.append(vertical_even) - self.parts.append(vertical_odd) + # Color vertical edges + for y in range(height - 1): + for x in range(width): + v1, v2 = self._index(x, y), self._index(x, y + 1) + color = 2 if y % 2 == 0 else 3 + self.color[tuple(sorted([v1, v2]))] = color - self.width = width - self.height = height + def _index(self, x: int, y: int) -> int: + """Convert (x, y) coordinates to vertex index.""" 
+ return y * self.width + x class Torus2D(Hypergraph): @@ -108,12 +101,12 @@ class Torus2D(Hypergraph): Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. - The edges are partitioned into parts for parallel updates: - - Part 0 (if self_loops): Self-loop edges on each vertex - - Part 1: Even-column horizontal edges - - Part 2: Odd-column horizontal edges - - Part 3: Even-row vertical edges - - Part 4: Odd-row vertical edges + Edges are colored for parallel updates: + - Color -1 (if self_loops): Self-loop edges on each vertex + - Color 0: Even-column horizontal edges + - Color 1: Odd-column horizontal edges + - Color 2: Even-row vertical edges + - Color 3: Odd-row vertical edges Attributes: width: Number of vertices in the horizontal direction. @@ -138,9 +131,8 @@ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: self_loops: If True, include self-loop edges on each vertex for single-site terms. """ - - def index(x: int, y: int) -> int: - return y * width + x + self.width = width + self.height = height if self_loops: _edges = [Hyperedge([i]) for i in range(width * height)] @@ -148,41 +140,40 @@ def index(x: int, y: int) -> int: _edges = [] # Horizontal edges (connecting (x, y) to ((x+1) % width, y)) - horizontal_even = [] - horizontal_odd = [] for y in range(height): for x in range(width): - edge_idx = len(_edges) - _edges.append(Hyperedge([index(x, y), index((x + 1) % width, y)])) - if x % 2 == 0: - horizontal_even.append(edge_idx) - else: - horizontal_odd.append(edge_idx) + _edges.append( + Hyperedge([self._index(x, y), self._index((x + 1) % width, y)]) + ) # Vertical edges (connecting (x, y) to (x, (y+1) % height)) - vertical_even = [] - vertical_odd = [] for y in range(height): for x in range(width): - edge_idx = len(_edges) - _edges.append(Hyperedge([index(x, y), index(x, (y + 1) % height)])) - if y % 2 == 0: - vertical_even.append(edge_idx) - else: - vertical_odd.append(edge_idx) + _edges.append( + 
Hyperedge([self._index(x, y), self._index(x, (y + 1) % height)]) + ) super().__init__(_edges) - # Set up edge partitions for parallel updates + # Set up edge colors for parallel updates if self_loops: - self.parts = [list(range(width * height))] - else: - self.parts = [] + for i in range(width * height): + self.color[(i,)] = -1 - self.parts.append(horizontal_even) - self.parts.append(horizontal_odd) - self.parts.append(vertical_even) - self.parts.append(vertical_odd) + # Color horizontal edges + for y in range(height): + for x in range(width): + v1, v2 = self._index(x, y), self._index((x + 1) % width, y) + color = 0 if x % 2 == 0 else 1 + self.color[tuple(sorted([v1, v2]))] = color - self.width = width - self.height = height + # Color vertical edges + for y in range(height): + for x in range(width): + v1, v2 = self._index(x, y), self._index(x, (y + 1) % height) + color = 2 if y % 2 == 0 else 3 + self.color[tuple(sorted([v1, v2]))] = color + + def _index(self, x: int, y: int) -> int: + """Convert (x, y) coordinates to vertex index.""" + return y * self.width + x diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py index ad1bc28769..93237e7ced 100644 --- a/source/pip/tests/magnets/test_complete.py +++ b/source/pip/tests/magnets/test_complete.py @@ -81,9 +81,9 @@ def test_complete_graph_self_loops_edges(): graph = CompleteGraph(3, self_loops=True) edges = list(graph.edges()) # First 3 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - assert edges[2].vertices == [2] + assert edges[0].vertices == (0,) + assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) def test_complete_graph_edge_count_formula(): @@ -185,10 +185,10 @@ def test_complete_bipartite_graph_self_loops_edges(): graph = CompleteBipartiteGraph(2, 2, self_loops=True) edges = list(graph.edges()) # First 4 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - 
assert edges[2].vertices == [2] - assert edges[3].vertices == [3] + assert edges[0].vertices == (0,) + assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) + assert edges[3].vertices == (3,) def test_complete_bipartite_graph_edge_count_formula(): @@ -200,33 +200,36 @@ def test_complete_bipartite_graph_edge_count_formula(): assert graph.nedges == expected_edges -def test_complete_bipartite_graph_parts_without_self_loops(): - """Test edge partitioning without self-loops.""" +def test_complete_bipartite_graph_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" graph = CompleteBipartiteGraph(3, 4) - # Should have at least n parts for bipartite coloring - assert len(graph.parts) >= 4 + # Should have n colors for bipartite coloring + assert graph.ncolors == 4 -def test_complete_bipartite_graph_parts_with_self_loops(): - """Test edge partitioning with self-loops.""" +def test_complete_bipartite_graph_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" graph = CompleteBipartiteGraph(3, 4, self_loops=True) - # Should have n + 1 parts: self-loops + n color groups - assert len(graph.parts) == 5 + # Self-loops get color -1, bipartite edges get n colors (0 to n-1) + # So total distinct colors = n + 1 (including -1) + assert graph.ncolors == 5 -def test_complete_bipartite_graph_parts_non_overlapping(): - """Test that edges in the same part don't share vertices.""" +def test_complete_bipartite_graph_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" graph = CompleteBipartiteGraph(3, 4) - # Skip the first part if it contains all edges (default from Hypergraph) - parts_to_check = graph.parts - if len(parts_to_check) > 0 and len(parts_to_check[0]) == graph.nedges: - parts_to_check = parts_to_check[1:] - for part_indices in parts_to_check: + # Group edges by color + colors = {} + for edge_vertices, color in graph.color.items(): + if color not in colors: + colors[color] = [] + 
colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): used_vertices = set() - for idx in part_indices: - edge = graph._edge_list[idx] - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_complete_bipartite_graph_str(): @@ -244,4 +247,4 @@ def test_complete_bipartite_graph_inherits_hypergraph(): assert isinstance(graph, Hypergraph) assert hasattr(graph, "edges") assert hasattr(graph, "vertices") - assert hasattr(graph, "edges_by_part") + assert hasattr(graph, "edges_by_color") diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index 3063fcb727..c158d9589e 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -16,26 +16,26 @@ def test_hyperedge_init_basic(): """Test basic Hyperedge initialization.""" edge = Hyperedge([0, 1]) - assert edge.vertices == [0, 1] + assert edge.vertices == (0, 1) def test_hyperedge_vertices_sorted(): """Test that vertices are automatically sorted.""" edge = Hyperedge([3, 1, 2]) - assert edge.vertices == [1, 2, 3] + assert edge.vertices == (1, 2, 3) def test_hyperedge_single_vertex(): """Test hyperedge with single vertex (self-loop).""" edge = Hyperedge([5]) - assert edge.vertices == [5] + assert edge.vertices == (5,) assert len(edge.vertices) == 1 def test_hyperedge_multiple_vertices(): """Test hyperedge with multiple vertices (multi-body interaction).""" edge = Hyperedge([0, 1, 2, 3]) - assert edge.vertices == [0, 1, 2, 3] + assert edge.vertices == (0, 1, 2, 3) assert len(edge.vertices) == 4 @@ -48,14 +48,14 @@ def test_hyperedge_repr(): def test_hyperedge_empty_vertices(): """Test hyperedge with empty vertex list.""" edge = Hyperedge([]) - assert edge.vertices == [] + assert edge.vertices == () assert 
len(edge.vertices) == 0 def test_hyperedge_duplicate_vertices(): """Test that duplicate vertices are removed.""" edge = Hyperedge([1, 2, 2, 1, 3]) - assert edge.vertices == [1, 2, 3] + assert edge.vertices == (1, 2, 3) # Hypergraph tests @@ -113,12 +113,12 @@ def test_hypergraph_edges_iterator(): assert len(edge_list) == 2 -def test_hypergraph_edges_by_part(): - """Test edgesByPart returns edges in a specific partition.""" +def test_hypergraph_edges_by_color(): + """Test edges_by_color returns edges with a specific color.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] graph = Hypergraph(edges) - # Default: all edges in part 0 - edge_list = list(graph.edges_by_part(0)) + # Default: all edges have color 0 + edge_list = list(graph.edges_by_color(0)) assert len(edge_list) == 2 @@ -130,22 +130,22 @@ def test_hypergraph_add_edge(): assert graph.nvertices == 2 -def test_hypergraph_add_edge_to_part(): - """Test adding edges to different partitions.""" +def test_hypergraph_add_edge_with_color(): + """Test adding edges with different colors.""" graph = Hypergraph([Hyperedge([0, 1])]) - graph.parts.append([]) # Add a second partition - graph.add_edge(Hyperedge([2, 3]), part=1) + graph.add_edge(Hyperedge([2, 3]), color=1) assert graph.nedges == 2 - assert len(graph.parts[0]) == 1 - assert len(graph.parts[1]) == 1 + assert graph.color[(0, 1)] == 0 + assert graph.color[(2, 3)] == 1 -def test_hypergraph_parts_default(): - """Test that default parts contain all edge indices.""" +def test_hypergraph_color_default(): + """Test that default colors are all 0.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] graph = Hypergraph(edges) - assert len(graph.parts) == 1 - assert graph.parts[0] == [0, 1, 2] + assert graph.color[(0, 1)] == 0 + assert graph.color[(1, 2)] == 0 + assert graph.color[(2, 3)] == 0 def test_hypergraph_str(): @@ -202,8 +202,7 @@ def test_greedy_edge_coloring_empty(): graph = Hypergraph([]) colored = greedy_edge_coloring(graph) assert 
colored.nedges == 0 - assert len(colored.parts) == 1 - assert colored.parts[0] == [] + assert colored.ncolors == 0 def test_greedy_edge_coloring_single_edge(): @@ -211,7 +210,7 @@ def test_greedy_edge_coloring_single_edge(): graph = Hypergraph([Hyperedge([0, 1])]) colored = greedy_edge_coloring(graph, seed=42) assert colored.nedges == 1 - assert len(colored.parts) == 1 + assert colored.ncolors == 1 def test_greedy_edge_coloring_non_overlapping(): @@ -221,7 +220,7 @@ def test_greedy_edge_coloring_non_overlapping(): colored = greedy_edge_coloring(graph, seed=42) # Non-overlapping edges can be in the same color assert colored.nedges == 2 - assert len(colored.parts) == 1 + assert colored.ncolors == 1 def test_greedy_edge_coloring_overlapping(): @@ -231,7 +230,7 @@ def test_greedy_edge_coloring_overlapping(): colored = greedy_edge_coloring(graph, seed=42) # Overlapping edges need different colors assert colored.nedges == 2 - assert len(colored.parts) == 2 + assert colored.ncolors == 2 def test_greedy_edge_coloring_triangle(): @@ -241,11 +240,11 @@ def test_greedy_edge_coloring_triangle(): colored = greedy_edge_coloring(graph, seed=42) # All edges share vertices pairwise, so need 3 colors assert colored.nedges == 3 - assert len(colored.parts) == 3 + assert colored.ncolors == 3 def test_greedy_edge_coloring_validity(): - """Test that coloring is valid (no two edges in same part share a vertex).""" + """Test that coloring is valid (no two edges with same color share a vertex).""" edges = [ Hyperedge([0, 1]), Hyperedge([1, 2]), @@ -256,29 +255,33 @@ def test_greedy_edge_coloring_validity(): graph = Hypergraph(edges) colored = greedy_edge_coloring(graph, seed=42) - # Verify each part has no overlapping edges - for part in colored.parts: + # Group edges by color + colors = {} + for edge_vertices, color in colored.color.items(): + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + + # Verify each color group has no overlapping edges + for 
color, edge_list in colors.items(): used_vertices = set() - for edge_idx in part: - edge = colored._edge_list[edge_idx] - # No vertex should already be used in this part - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + # No vertex should already be used in this color + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_greedy_edge_coloring_all_edges_colored(): - """Test that all edges are assigned to exactly one part.""" + """Test that all edges are assigned a color.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] graph = Hypergraph(edges) colored = greedy_edge_coloring(graph, seed=42) - # Collect all edge indices from all parts - all_colored = [] - for part in colored.parts: - all_colored.extend(part) - - # Should have exactly 3 edges colored, each once - assert sorted(all_colored) == [0, 1, 2] + # All edges should have a color assigned + assert len(colored.color) == 3 + assert (0, 1) in colored.color + assert (1, 2) in colored.color + assert (2, 3) in colored.color def test_greedy_edge_coloring_reproducible_with_seed(): @@ -289,7 +292,7 @@ def test_greedy_edge_coloring_reproducible_with_seed(): colored1 = greedy_edge_coloring(graph, seed=123) colored2 = greedy_edge_coloring(graph, seed=123) - assert colored1.parts == colored2.parts + assert colored1.color == colored2.color def test_greedy_edge_coloring_multiple_trials(): @@ -303,7 +306,7 @@ def test_greedy_edge_coloring_multiple_trials(): graph = Hypergraph(edges) colored = greedy_edge_coloring(graph, seed=42, trials=10) # A cycle of 4 edges can be 2-colored - assert len(colored.parts) <= 3 # Greedy may not always find optimal + assert colored.ncolors <= 3 # Greedy may not always find optimal def test_greedy_edge_coloring_hyperedges(): @@ -318,7 +321,7 @@ def test_greedy_edge_coloring_hyperedges(): # First two share vertex 2, third is independent assert colored.nedges == 
3 - assert len(colored.parts) >= 2 + assert colored.ncolors >= 2 def test_greedy_edge_coloring_self_loops(): @@ -329,4 +332,4 @@ def test_greedy_edge_coloring_self_loops(): # Self-loops don't share vertices, can all be same color assert colored.nedges == 3 - assert len(colored.parts) == 1 + assert colored.ncolors == 1 diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py index f940506f36..e9ccacd519 100644 --- a/source/pip/tests/magnets/test_lattice1d.py +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -37,10 +37,10 @@ def test_chain1d_edges(): chain = Chain1D(4) edges = list(chain.edges()) assert len(edges) == 3 - # Check edges are [0,1], [1,2], [2,3] - assert edges[0].vertices == [0, 1] - assert edges[1].vertices == [1, 2] - assert edges[2].vertices == [2, 3] + # Check edges are (0,1), (1,2), (2,3) + assert edges[0].vertices == (0, 1) + assert edges[1].vertices == (1, 2) + assert edges[2].vertices == (2, 3) def test_chain1d_vertices(): @@ -63,39 +63,54 @@ def test_chain1d_self_loops_edges(): chain = Chain1D(3, self_loops=True) edges = list(chain.edges()) # First 3 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - assert edges[2].vertices == [2] + assert edges[0].vertices == (0,) + assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) # Next 2 edges should be nearest-neighbor - assert edges[3].vertices == [0, 1] - assert edges[4].vertices == [1, 2] + assert edges[3].vertices == (0, 1) + assert edges[4].vertices == (1, 2) -def test_chain1d_parts_without_self_loops(): - """Test edge partitioning without self-loops.""" +def test_chain1d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" chain = Chain1D(5) - # Should have 2 parts: even edges [0,2] and odd edges [1,3] - assert len(chain.parts) == 2 - assert chain.parts[0] == [0, 2] # edges 0-1, 2-3 - assert chain.parts[1] == [1, 3] # edges 1-2, 3-4 + # Even edges (0-1, 2-3) should 
have color 0 + assert chain.color[(0, 1)] == 0 + assert chain.color[(2, 3)] == 0 + # Odd edges (1-2, 3-4) should have color 1 + assert chain.color[(1, 2)] == 1 + assert chain.color[(3, 4)] == 1 -def test_chain1d_parts_with_self_loops(): - """Test edge partitioning with self-loops.""" +def test_chain1d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" chain = Chain1D(4, self_loops=True) - # Should have 3 parts: self-loops, even edges, odd edges - assert len(chain.parts) == 3 - - -def test_chain1d_parts_non_overlapping(): - """Test that edges in the same part don't share vertices.""" + # Self-loops should have color -1 + assert chain.color[(0,)] == -1 + assert chain.color[(1,)] == -1 + assert chain.color[(2,)] == -1 + assert chain.color[(3,)] == -1 + # Even edges should have color 0, odd edges color 1 + assert chain.color[(0, 1)] == 0 + assert chain.color[(1, 2)] == 1 + assert chain.color[(2, 3)] == 0 + + +def test_chain1d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" chain = Chain1D(6) - for part_indices in chain.parts: + # Group edges by color + colors = {} + for edge_vertices, color in chain.color.items(): + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): used_vertices = set() - for idx in part_indices: - edge = chain._edge_list[idx] - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_chain1d_str(): @@ -136,11 +151,11 @@ def test_ring1d_edges(): ring = Ring1D(4) edges = list(ring.edges()) assert len(edges) == 4 - # Check edges are [0,1], [1,2], [2,3], [0,3] (sorted) - assert edges[0].vertices == [0, 1] - assert edges[1].vertices == [1, 2] - assert edges[2].vertices == [2, 3] - assert edges[3].vertices == [0, 3] # Wrap-around 
edge + # Check edges are (0,1), (1,2), (2,3), (0,3) (sorted) + assert edges[0].vertices == (0, 1) + assert edges[1].vertices == (1, 2) + assert edges[2].vertices == (2, 3) + assert edges[3].vertices == (0, 3) # Wrap-around edge def test_ring1d_vertices(): @@ -163,38 +178,55 @@ def test_ring1d_self_loops_edges(): ring = Ring1D(3, self_loops=True) edges = list(ring.edges()) # First 3 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - assert edges[2].vertices == [2] + assert edges[0].vertices == (0,) + assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) # Next 3 edges should be nearest-neighbor (including wrap) - assert edges[3].vertices == [0, 1] - assert edges[4].vertices == [1, 2] - assert edges[5].vertices == [0, 2] # Wrap-around + assert edges[3].vertices == (0, 1) + assert edges[4].vertices == (1, 2) + assert edges[5].vertices == (0, 2) # Wrap-around -def test_ring1d_parts_without_self_loops(): - """Test edge partitioning without self-loops.""" +def test_ring1d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" ring = Ring1D(4) - # Should have 2 parts for parallel updates - assert len(ring.parts) == 2 + # Even edges should have color 0, odd edges color 1 + assert ring.color[(0, 1)] == 0 + assert ring.color[(1, 2)] == 1 + assert ring.color[(2, 3)] == 0 + assert ring.color[(0, 3)] == 1 # Wrap-around edge (index 3) -def test_ring1d_parts_with_self_loops(): - """Test edge partitioning with self-loops.""" +def test_ring1d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" ring = Ring1D(4, self_loops=True) - # Should have 3 parts: self-loops, even edges, odd edges - assert len(ring.parts) == 3 - - -def test_ring1d_parts_non_overlapping(): - """Test that edges in the same part don't share vertices.""" + # Self-loops should have color -1 + assert ring.color[(0,)] == -1 + assert ring.color[(1,)] == -1 + assert ring.color[(2,)] == -1 + assert ring.color[(3,)] == 
-1 + # Even edges should have color 0, odd edges color 1 + assert ring.color[(0, 1)] == 0 + assert ring.color[(1, 2)] == 1 + assert ring.color[(2, 3)] == 0 + assert ring.color[(0, 3)] == 1 + + +def test_ring1d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" ring = Ring1D(6) - for part_indices in ring.parts: + # Group edges by color + colors = {} + for edge_vertices, color in ring.color.items(): + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): used_vertices = set() - for idx in part_indices: - edge = ring._edge_list[idx] - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_ring1d_str(): @@ -221,7 +253,7 @@ def test_chain1d_inherits_hypergraph(): # Test inherited methods work assert hasattr(chain, "edges") assert hasattr(chain, "vertices") - assert hasattr(chain, "edges_by_part") + assert hasattr(chain, "edges_by_color") def test_ring1d_inherits_hypergraph(): @@ -233,4 +265,4 @@ def test_ring1d_inherits_hypergraph(): # Test inherited methods work assert hasattr(ring, "edges") assert hasattr(ring, "vertices") - assert hasattr(ring, "edges_by_part") + assert hasattr(ring, "edges_by_color") diff --git a/source/pip/tests/magnets/test_lattice2d.py b/source/pip/tests/magnets/test_lattice2d.py index 8be8e816f6..ccf95c313a 100644 --- a/source/pip/tests/magnets/test_lattice2d.py +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -86,35 +86,41 @@ def test_patch2d_self_loops_edges(): patch = Patch2D(2, 2, self_loops=True) edges = list(patch.edges()) # First 4 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - assert edges[2].vertices == [2] - assert edges[3].vertices == [3] + assert edges[0].vertices == (0,) + 
assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) + assert edges[3].vertices == (3,) -def test_patch2d_parts_without_self_loops(): - """Test edge partitioning without self-loops.""" +def test_patch2d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" patch = Patch2D(4, 4) - # Should have 4 parts: horizontal even/odd, vertical even/odd - assert len(patch.parts) == 4 + # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) + assert patch.ncolors == 4 -def test_patch2d_parts_with_self_loops(): - """Test edge partitioning with self-loops.""" +def test_patch2d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" patch = Patch2D(3, 3, self_loops=True) - # Should have 5 parts: self-loops + 4 edge groups - assert len(patch.parts) == 5 + # Should have 5 colors: self-loops (-1) + 4 edge groups (0-3) + assert patch.ncolors == 5 -def test_patch2d_parts_non_overlapping(): - """Test that edges in the same part don't share vertices.""" +def test_patch2d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" patch = Patch2D(4, 4) - for part_indices in patch.parts: + # Group edges by color + colors = {} + for edge_vertices, color in patch.color.items(): + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): used_vertices = set() - for idx in part_indices: - edge = patch._edge_list[idx] - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_patch2d_str(): @@ -205,35 +211,41 @@ def test_torus2d_self_loops_edges(): torus = Torus2D(2, 2, self_loops=True) edges = list(torus.edges()) # First 4 edges should be self-loops - assert edges[0].vertices == [0] - assert edges[1].vertices == [1] - assert 
edges[2].vertices == [2] - assert edges[3].vertices == [3] + assert edges[0].vertices == (0,) + assert edges[1].vertices == (1,) + assert edges[2].vertices == (2,) + assert edges[3].vertices == (3,) -def test_torus2d_parts_without_self_loops(): - """Test edge partitioning without self-loops.""" +def test_torus2d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" torus = Torus2D(4, 4) - # Should have 4 parts: horizontal even/odd, vertical even/odd - assert len(torus.parts) == 4 + # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) + assert torus.ncolors == 4 -def test_torus2d_parts_with_self_loops(): - """Test edge partitioning with self-loops.""" +def test_torus2d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" torus = Torus2D(3, 3, self_loops=True) - # Should have 5 parts: self-loops + 4 edge groups - assert len(torus.parts) == 5 + # Should have 5 colors: self-loops (-1) + 4 edge groups (0-3) + assert torus.ncolors == 5 -def test_torus2d_parts_non_overlapping(): - """Test that edges in the same part don't share vertices.""" +def test_torus2d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" torus = Torus2D(4, 4) - for part_indices in torus.parts: + # Group edges by color + colors = {} + for edge_vertices, color in torus.color.items(): + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): used_vertices = set() - for idx in part_indices: - edge = torus._edge_list[idx] - assert not any(v in used_vertices for v in edge.vertices) - used_vertices.update(edge.vertices) + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) def test_torus2d_str(): @@ -262,7 +274,7 @@ def test_patch2d_inherits_hypergraph(): # Test inherited methods work assert hasattr(patch, "edges") assert hasattr(patch, 
"vertices") - assert hasattr(patch, "edges_by_part") + assert hasattr(patch, "edges_by_color") def test_torus2d_inherits_hypergraph(): @@ -274,4 +286,4 @@ def test_torus2d_inherits_hypergraph(): # Test inherited methods work assert hasattr(torus, "edges") assert hasattr(torus, "vertices") - assert hasattr(torus, "edges_by_part") + assert hasattr(torus, "edges_by_color") From b3f519830c650b7749f148e7c41851e29fe68015 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Mon, 9 Feb 2026 00:40:42 -0800 Subject: [PATCH 12/45] Magnets: refactored Trotter step classes and added Suzuki and Yoshida recursion (#2926) Changed some basic functionality: - The TrotterStep class is now just a wrapper class with some information functions and an iterator - Factory functions instantiate TrotterStep classes at desired orders (trotter_decomposition, strang_splitting) Added recursion - Both the Suzuki fractal recursion and Yoshide triple recursion are implemented (suzuki_recursion, yoshide_recursion) - fourth_order_trotter_suzuki is a convenience function for strang_splitting composed with suzuki_recursion Tests added. This PR is independent from #2925. 
--------- Co-authored-by: Mathias Soeken --- source/pip/qsharp/magnets/trotter/__init__.py | 16 +- source/pip/qsharp/magnets/trotter/trotter.py | 374 +++++++++++---- source/pip/tests/magnets/test_trotter.py | 440 ++++++++++++++---- 3 files changed, 644 insertions(+), 186 deletions(-) diff --git a/source/pip/qsharp/magnets/trotter/__init__.py b/source/pip/qsharp/magnets/trotter/__init__.py index f3107d526a..95dc485fa7 100644 --- a/source/pip/qsharp/magnets/trotter/__init__.py +++ b/source/pip/qsharp/magnets/trotter/__init__.py @@ -3,10 +3,22 @@ """Trotter-Suzuki methods for time evolution.""" -from .trotter import TrotterStep, StrangStep, TrotterExpansion +from .trotter import ( + TrotterStep, + TrotterExpansion, + trotter_decomposition, + strang_splitting, + suzuki_recursion, + yoshida_recursion, + fourth_order_trotter_suzuki, +) __all__ = [ "TrotterStep", - "StrangStep", "TrotterExpansion", + "trotter_decomposition", + "strang_splitting", + "suzuki_recursion", + "yoshida_recursion", + "fourth_order_trotter_suzuki", ] diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py index b598fe5abf..0568db61fc 100644 --- a/source/pip/qsharp/magnets/trotter/trotter.py +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -4,134 +4,296 @@ """Base Trotter class for first- and second-order Trotter-Suzuki decomposition.""" -class TrotterStep: - """ - Base class for Trotter decompositions. Essentially, this is a wrapper around - a list of (time, term_index) tuples, which specify which term to apply for - how long. +from typing import Iterator - As a default, the base class implements the first-order Trotter-Suzuki formula - for approximating time evolution under a Hamiltonian represented as a sum of - terms H = ∑_k H_k by sequentially applying each term for the full time - e^{-i H t} ≈ ∏_k e^{-i H_k t}. +class TrotterStep: + """ + Base class for Trotter decompositions. 
class TrotterStep:
    """
    Base class for Trotter decompositions: a thin wrapper around a list of
    (time, term_index) tuples that specify which Hamiltonian term to apply
    and for how long.

    Instances are normally produced by the factory functions
    (trotter_decomposition, strang_splitting, ...) or by the Suzuki/Yoshida
    recursion helpers, which fill in the term sequence together with the
    decomposition order, the number of Hamiltonian terms, and the time step.
    """

    def __init__(self):
        """
        Creates an empty Trotter decomposition.

        """
        # Sequence of (duration, term_index) pairs; populated by the factories.
        self.terms: list[tuple[float, int]] = []
        self._nterms = 0
        self._time_step = 0.0
        self._order = 0
        self._repr_string = "TrotterStep()"

    @property
    def order(self) -> int:
        """Get the order of the Trotter decomposition."""
        return self._order

    @property
    def nterms(self) -> int:
        """Get the number of terms in the Hamiltonian."""
        return self._nterms

    @property
    def time_step(self) -> float:
        """Get the time step for each term in the Trotter decomposition."""
        return self._time_step

    def reduce(self) -> None:
        """
        Reduce the Trotter step in place by combining consecutive terms that are the same.

        Merging adjacent applications of one term into a single application
        with the summed duration is useful after composing steps (e.g. in the
        Suzuki or Yoshida recursions), where the tail of one copy and the head
        of the next often share a term.

        Example:
            >>> trotter = TrotterStep()
            >>> trotter.terms = [(0.5, 0), (0.5, 0), (0.5, 1)]
            >>> trotter.reduce()
            >>> list(trotter.step())
            [(1.0, 0), (0.5, 1)]
        """
        if len(self.terms) <= 1:
            return

        merged: list[tuple[float, int]] = [self.terms[0]]
        for time, term in self.terms[1:]:
            last_time, last_term = merged[-1]
            if term == last_term:
                # Same term as the previous entry: extend its duration.
                merged[-1] = (last_time + time, last_term)
            else:
                merged.append((time, term))
        self.terms = merged

    def step(self) -> Iterator[tuple[float, int]]:
        """
        Iterate over the Trotter decomposition as a list of (time, term_index) tuples.

        Returns:
            Iterator of tuples where each tuple contains the time duration and the
            index of the term to be applied.
        """
        return iter(self.terms)

    def __str__(self) -> str:
        """String representation of the Trotter decomposition."""
        return f"Trotter expansion of order {self._order}: time_step={self._time_step}, num_terms={self._nterms}"

    def __repr__(self) -> str:
        """String representation of the Trotter decomposition."""
        return self._repr_string
def suzuki_recursion(trotter: TrotterStep) -> TrotterStep:
    """
    Apply one level of Suzuki recursion to raise the order of a Trotter step by 2.

    Given a k-th order Trotter step S_k(t), this function constructs a (k+2)-nd order
    step using the Suzuki fractal decomposition:

        S_{k+2}(t) = S_{k}(p t) S_{k}(p t) S_{k}((1 - 4p) t) S_{k}(p t) S_{k}(p t)

    where p = 1 / (4 - 4^{1/(k+1)}), which is the real solution of the order
    condition 4 p^{k+1} + (1 - 4p)^{k+1} = 0 for a base step of order k.

    The resulting step has improved accuracy: the error scales as O(t^{k+3}) instead
    of O(t^{k+1}), at the cost of 5x more exponential applications per step.
    NOTE(review): the two-order gain strictly holds for a symmetric (even-order)
    base step such as Strang splitting.

    Args:
        trotter: A TrotterStep of order k to be promoted to order k+2.

    Returns:
        A new TrotterStep of order k+2 constructed via Suzuki recursion.

    References:
        M. Suzuki, Phys. Lett. A 146, 319 (1990).
    """

    suzuki = TrotterStep()
    suzuki._nterms = trotter._nterms
    suzuki._time_step = trotter._time_step
    suzuki._order = trotter._order + 2
    suzuki._repr_string = f"SuzukiRecursion(order={suzuki._order}, time_step={suzuki._time_step}, num_terms={suzuki._nterms})"

    # p solves 4 p^{k+1} + (1 - 4p)^{k+1} = 0 for a base step of order k.
    # (The literature formula 1/(4 - 4^{1/(2k+1)}) uses k for HALF the order;
    # here `order` is the order itself, so the exponent is 1/(order + 1) —
    # e.g. 4^{1/3} when promoting second-order Strang splitting.)
    p = 1 / (4 - 4 ** (1 / (trotter.order + 1)))

    suzuki.terms = [(p * time, term_index) for time, term_index in trotter.step()]
    suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()]
    suzuki.terms += [
        ((1 - 4 * p) * time, term_index) for time, term_index in trotter.step()
    ]
    suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()]
    suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()]
    suzuki.reduce()  # Combine consecutive terms that are the same

    return suzuki
def yoshida_recursion(trotter: TrotterStep) -> TrotterStep:
    """
    Apply one level of Yoshida recursion to increase the order of a Trotter step by 2.

    Given a k-th order Trotter step S_k(t), this function constructs a (k+2)-nd order
    step using Yoshida's symmetric triple-jump composition:

        S_{k+2}(t) = S_{k}(w_1 t) S_{k}(w_0 t) S_{k}(w_1 t)

    where:
        w_1 = 1 / (2 - 2^{1/(k+1)})
        w_0 = -2^{1/(k+1)} / (2 - 2^{1/(k+1)}) = 1 - 2 w_1

    which solves the order condition 2 w_1^{k+1} + w_0^{k+1} = 0 for a base
    step of order k.

    The resulting step has improved accuracy: the error scales as O(t^{k+3}) instead
    of O(t^{k+1}), at the cost of 3x more exponential applications per step.
    NOTE(review): the two-order gain strictly holds for a symmetric (even-order)
    base step such as Strang splitting.

    Args:
        trotter: A TrotterStep of order k to be promoted to order k+2.

    Returns:
        A new TrotterStep of order k+2 constructed via Yoshida recursion.

    References:
        H. Yoshida, Phys. Lett. A 150, 262 (1990).
    """

    yoshida = TrotterStep()
    yoshida._nterms = trotter._nterms
    yoshida._time_step = trotter._time_step
    yoshida._order = trotter._order + 2
    yoshida._repr_string = f"YoshidaRecursion(order={yoshida._order}, time_step={yoshida._time_step}, num_terms={yoshida._nterms})"

    # 2^{1/(k+1)} for a base step of order k: the cube root of 2 when the base
    # is second-order Strang splitting.  (The literature writes 2^{1/(2k+1)}
    # with k meaning HALF the order; `order` here is the order itself.)
    root = 2 ** (1 / (trotter.order + 1))
    w1 = 1 / (2 - root)
    w0 = 1 - 2 * w1  # equivalent to -root / (2 - root)

    yoshida.terms = [(w1 * time, term_index) for time, term_index in trotter.step()]
    yoshida.terms += [(w0 * time, term_index) for time, term_index in trotter.step()]
    yoshida.terms += [(w1 * time, term_index) for time, term_index in trotter.step()]
    yoshida.reduce()  # Combine consecutive terms that are the same

    return yoshida
def trotter_decomposition(num_terms: int, time: float) -> TrotterStep:
    """
    Factory function for creating a first-order Trotter decomposition.

    The first-order Trotter-Suzuki formula for approximating time evolution
    under a Hamiltonian represented as a sum of terms

        H = ∑_k H_k

    is obtained by sequentially applying each term for the full time

        e^{-i H t} ≈ ∏_k e^{-i H_k t}.

    Args:
        num_terms: Number of terms in the Hamiltonian.
        time: Time step for the evolution.

    Example:

    .. code-block:: python
        >>> trotter = trotter_decomposition(num_terms=3, time=0.5)
        >>> list(trotter.step())
        [(0.5, 0), (0.5, 1), (0.5, 2)]

    References:
        H. F. Trotter, Proc. Amer. Math. Soc. 10, 545 (1959).
    """
    trotter = TrotterStep()
    trotter.terms = [(time, term_index) for term_index in range(num_terms)]
    trotter._nterms = num_terms
    trotter._time_step = time
    trotter._order = 1
    trotter._repr_string = f"FirstOrderTrotter(time_step={time}, num_terms={num_terms})"
    return trotter
def strang_splitting(num_terms: int, time: float) -> TrotterStep:
    """
    Factory function for creating a Strang splitting (second-order
    Trotter-Suzuki decomposition).

    The second-order Trotter formula uses symmetric splitting:

        e^{-i H t} ≈ ∏_{k=1}^{n} e^{-i H_k t/2} ∏_{k=n}^{1} e^{-i H_k t/2}

    This provides second-order accuracy in the time step, compared to
    first-order for the basic Trotter decomposition.  The two adjacent
    half-time applications of the last term are emitted as a single
    full-time application.

    Args:
        num_terms: Number of terms in the Hamiltonian.
        time: Time step for the evolution.

    Example:

    .. code-block:: python
        >>> strang = strang_splitting(num_terms=3, time=0.5)
        >>> list(strang.step())
        [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)]

    References:
        G. Strang, SIAM J. Numer. Anal. 5, 506 (1968).
    """
    strang = TrotterStep()
    strang._nterms = num_terms
    strang._time_step = time
    strang._order = 2
    strang._repr_string = f"StrangSplitting(time_step={time}, num_terms={num_terms})"
    strang.terms = []
    # Forward sweep with half time steps ...
    for term_index in range(num_terms - 1):
        strang.terms.append((time / 2, term_index))
    # ... one full step on the last (middle) term ...
    strang.terms.append((time, num_terms - 1))
    # ... and the mirrored backward sweep.
    for term_index in reversed(range(num_terms - 1)):
        strang.terms.append((time / 2, term_index))
    return strang


def fourth_order_trotter_suzuki(num_terms: int, time: float) -> TrotterStep:
    """
    Factory function for creating a fourth-order Trotter-Suzuki decomposition
    using Suzuki recursion.

    This is obtained by applying one level of Suzuki recursion to the second-order
    Strang splitting.  The resulting fourth-order decomposition has improved accuracy
    compared to the second-order Strang splitting, at the cost of five times as
    many exponential applications per step (minus merges of adjacent equal terms).

    Args:
        num_terms: Number of terms in the Hamiltonian.
        time: Time step for the evolution.

    Example:

    .. code-block:: python
        >>> fourth_order = fourth_order_trotter_suzuki(num_terms=3, time=0.5)
        >>> fourth_order.order
        4
        >>> fourth_order.nterms
        3
    """
    return suzuki_recursion(strang_splitting(num_terms, time))
class TrotterExpansion:
    """
    Trotter expansion for repeated application of a Trotter step.

    This class wraps a TrotterStep instance and specifies how many times to repeat
    the step.  The expansion represents full time evolution as a sequence of
    Trotter steps:

        e^{-i H T} ≈ (S(T/n))^n

    where S is the Trotter step formula, T is the total time, and n is the number
    of steps.
    """

    def __init__(self, trotter_step: TrotterStep, num_steps: int):
        """
        Initialize the Trotter expansion.

        Args:
            trotter_step: An instance of TrotterStep representing a single Trotter step.
            num_steps: Number of times to repeat the Trotter step.
        """
        self._trotter_step = trotter_step
        self._num_steps = num_steps

    @property
    def order(self) -> int:
        """Get the order of the underlying Trotter step."""
        return self._trotter_step.order

    @property
    def nterms(self) -> int:
        """Get the number of Hamiltonian terms."""
        return self._trotter_step.nterms

    @property
    def num_steps(self) -> int:
        """Get the number of Trotter steps."""
        return self._num_steps

    @property
    def total_time(self) -> float:
        """Get the total evolution time (time_step * num_steps)."""
        return self._trotter_step.time_step * self._num_steps

    def step(self) -> Iterator[tuple[float, int]]:
        """
        Iterate over the full Trotter expansion.

        Yields every (time, term_index) tuple of the complete evolution: the
        single-step sequence repeated num_steps times.

        Returns:
            Iterator of (time, term_index) tuples for the full evolution.
        """
        single = self._trotter_step
        for _ in range(self._num_steps):
            yield from single.step()

    def get(self) -> list[tuple[list[tuple[float, int]], int]]:
        """
        Get the Trotter expansion as a compact representation.

        Returns:
            List containing a single tuple of (terms, num_steps) where terms
            is the list of (time, term_index) for one step.
        """
        one_step = list(self._trotter_step.step())
        return [(one_step, self._num_steps)]

    def __str__(self) -> str:
        """String representation of the Trotter expansion."""
        return (
            f"TrotterExpansion(order={self.order}, num_steps={self._num_steps}, "
            f"total_time={self.total_time}, num_terms={self.nterms})"
        )

    def __repr__(self) -> str:
        """Repr representation of the Trotter expansion."""
        return f"TrotterExpansion({self._trotter_step!r}, num_steps={self._num_steps})"
def test_trotter_step_reduce_combines_consecutive():
    """reduce() merges adjacent entries that apply the same term."""
    step = TrotterStep()
    step.terms = [(0.5, 0), (0.5, 0), (0.5, 1)]
    step.reduce()
    assert list(step.step()) == [(1.0, 0), (0.5, 1)]


def test_trotter_step_reduce_no_change_when_different():
    """reduce() leaves non-adjacent repetitions of a term untouched."""
    step = TrotterStep()
    step.terms = [(0.5, 0), (0.5, 1), (0.5, 0)]
    step.reduce()
    assert list(step.step()) == [(0.5, 0), (0.5, 1), (0.5, 0)]


def test_trotter_step_reduce_empty():
    """reduce() is a no-op on an empty term list."""
    step = TrotterStep()
    step.reduce()
    assert list(step.step()) == []


# trotter_decomposition factory tests


def test_trotter_decomposition_basic():
    """The factory records nterms, time_step and first order."""
    decomposition = trotter_decomposition(num_terms=3, time=0.5)
    assert decomposition.nterms == 3
    assert decomposition.time_step == 0.5
    assert decomposition.order == 1


def test_trotter_decomposition_single_term():
    """A single-term Hamiltonian yields exactly one entry."""
    decomposition = trotter_decomposition(num_terms=1, time=1.0)
    assert list(decomposition.step()) == [(1.0, 0)]


def test_trotter_decomposition_multiple_terms():
    """Each term is applied once, in order, for the full time."""
    decomposition = trotter_decomposition(num_terms=3, time=0.5)
    assert list(decomposition.step()) == [(0.5, 0), (0.5, 1), (0.5, 2)]


def test_trotter_decomposition_zero_time():
    """Zero evolution time is propagated to every entry."""
    decomposition = trotter_decomposition(num_terms=2, time=0.0)
    assert list(decomposition.step()) == [(0.0, 0), (0.0, 1)]


def test_trotter_decomposition_returns_all_terms():
    """All term indices 0..n-1 appear exactly once, in order."""
    num_terms = 5
    entries = list(trotter_decomposition(num_terms=num_terms, time=1.0).step())
    assert len(entries) == num_terms
    assert [index for _, index in entries] == list(range(num_terms))


def test_trotter_decomposition_uniform_time():
    """Every entry carries the same time step."""
    time = 0.25
    entries = list(trotter_decomposition(num_terms=4, time=time).step())
    assert all(duration == time for duration, _ in entries)


def test_trotter_decomposition_str():
    """str() mentions the decomposition order."""
    text = str(trotter_decomposition(num_terms=3, time=0.5))
    assert "order" in text.lower() or "1" in text


def test_trotter_decomposition_repr():
    """repr() identifies the first-order factory."""
    assert "FirstOrderTrotter" in repr(trotter_decomposition(num_terms=3, time=0.5))


# strang_splitting factory tests
def test_strang_splitting_basic():
    """The factory records nterms, time_step and second order."""
    splitting = strang_splitting(num_terms=3, time=0.5)
    assert splitting.nterms == 3
    assert splitting.time_step == 0.5
    assert splitting.order == 2


def test_strang_splitting_single_term():
    """One term collapses to a single full-time application."""
    splitting = strang_splitting(num_terms=1, time=1.0)
    assert list(splitting.step()) == [(1.0, 0)]


def test_strang_splitting_two_terms():
    """Two terms: half step, full step, half step."""
    splitting = strang_splitting(num_terms=2, time=1.0)
    assert list(splitting.step()) == [(0.5, 0), (1.0, 1), (0.5, 0)]


def test_strang_splitting_three_terms():
    """Docstring example: symmetric half steps around a full middle step."""
    splitting = strang_splitting(num_terms=3, time=0.5)
    assert list(splitting.step()) == [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)]


def test_strang_splitting_symmetric():
    """The term-index sequence is a palindrome."""
    indices = [index for _, index in strang_splitting(num_terms=4, time=1.0).step()]
    assert indices == list(reversed(indices))


def test_strang_splitting_time_sum():
    """Each term receives the full time step in total."""
    time, num_terms = 1.0, 3
    total = sum(duration for duration, _ in strang_splitting(num_terms=num_terms, time=time).step())
    # Half + half for the swept terms, full for the middle one.
    assert abs(total - time * num_terms) < 1e-10


def test_strang_splitting_middle_term_full_time():
    """The last (middle) term is applied once with the full time step."""
    entries = list(strang_splitting(num_terms=5, time=2.0).step())
    middle = [(duration, index) for duration, index in entries if index == 4]
    assert middle == [(2.0, 4)]


def test_strang_splitting_outer_terms_half_time():
    """The first term appears twice, each time with half the step."""
    entries = list(strang_splitting(num_terms=4, time=2.0).step())
    first_term_durations = [duration for duration, index in entries if index == 0]
    assert first_term_durations == [1.0, 1.0]


def test_strang_splitting_repr():
    """repr() identifies the Strang factory."""
    assert "StrangSplitting" in repr(strang_splitting(num_terms=3, time=0.5))


# suzuki_recursion tests


def test_suzuki_recursion_from_strang():
    """Suzuki recursion promotes a Strang step to fourth order."""
    promoted = suzuki_recursion(strang_splitting(num_terms=2, time=1.0))
    assert promoted.order == 4
    assert promoted.nterms == 2
    assert promoted.time_step == 1.0


def test_suzuki_recursion_from_first_order():
    """Suzuki recursion raises a first-order step to third order."""
    promoted = suzuki_recursion(trotter_decomposition(num_terms=2, time=1.0))
    assert promoted.order == 3
    assert promoted.nterms == 2


def test_suzuki_recursion_preserves_nterms():
    """The number of Hamiltonian terms is unchanged by the recursion."""
    base = strang_splitting(num_terms=5, time=0.5)
    assert suzuki_recursion(base).nterms == base.nterms


def test_suzuki_recursion_preserves_time_step():
    """The time step is unchanged by the recursion."""
    base = strang_splitting(num_terms=3, time=0.75)
    assert suzuki_recursion(base).time_step == base.time_step


def test_suzuki_recursion_repr():
    """repr() identifies the Suzuki recursion."""
    base = strang_splitting(num_terms=2, time=1.0)
    assert "SuzukiRecursion" in repr(suzuki_recursion(base))


def test_suzuki_recursion_time_weights_sum():
    """The five scaled copies carry weights 4p + (1 - 4p) = 1, so the total
    applied time matches the base step."""
    base = trotter_decomposition(num_terms=2, time=1.0)
    promoted_total = sum(duration for duration, _ in suzuki_recursion(base).step())
    base_total = sum(duration for duration, _ in base.step())
    assert abs(promoted_total - base_total) < 1e-10


# yoshida_recursion tests


def test_yoshida_recursion_from_strang():
    """Yoshida recursion promotes a Strang step to fourth order."""
    promoted = yoshida_recursion(strang_splitting(num_terms=2, time=1.0))
    assert promoted.order == 4
    assert promoted.nterms == 2
    assert promoted.time_step == 1.0


def test_yoshida_recursion_from_first_order():
    """Yoshida recursion raises a first-order step to third order."""
    promoted = yoshida_recursion(trotter_decomposition(num_terms=2, time=1.0))
    assert promoted.order == 3
    assert promoted.nterms == 2


def test_yoshida_recursion_preserves_nterms():
    """The number of Hamiltonian terms is unchanged by the recursion."""
    base = strang_splitting(num_terms=5, time=0.5)
    assert yoshida_recursion(base).nterms == base.nterms


def test_yoshida_recursion_preserves_time_step():
    """The time step is unchanged by the recursion."""
    base = strang_splitting(num_terms=3, time=0.75)
    assert yoshida_recursion(base).time_step == base.time_step


def test_yoshida_recursion_repr():
    """repr() identifies the Yoshida recursion."""
    base = strang_splitting(num_terms=2, time=1.0)
    assert "YoshidaRecursion" in repr(yoshida_recursion(base))


def test_yoshida_recursion_time_weights_sum():
    """The triple jump carries weights 2*w1 + w0 = 1, so the total applied
    time matches the base step."""
    base = trotter_decomposition(num_terms=2, time=1.0)
    promoted_total = sum(duration for duration, _ in yoshida_recursion(base).step())
    base_total = sum(duration for duration, _ in base.step())
    assert abs(promoted_total - base_total) < 1e-10
def test_yoshida_fewer_terms_than_suzuki():
    """Yoshida (3 copies) never needs more entries than Suzuki (5 copies)."""
    base = strang_splitting(num_terms=3, time=1.0)
    yoshida_len = len(list(yoshida_recursion(base).step()))
    suzuki_len = len(list(suzuki_recursion(base).step()))
    # After reduction, the triple-jump composition should generally be shorter.
    assert yoshida_len <= suzuki_len


# fourth_order_trotter_suzuki tests


def test_fourth_order_trotter_suzuki_basic():
    """The convenience factory produces a fourth-order step."""
    fourth = fourth_order_trotter_suzuki(num_terms=2, time=1.0)
    assert fourth.order == 4
    assert fourth.nterms == 2
    assert fourth.time_step == 1.0


def test_fourth_order_trotter_suzuki_equals_suzuki_of_strang():
    """The factory is exactly Suzuki recursion applied to Strang splitting."""
    fourth = fourth_order_trotter_suzuki(num_terms=3, time=0.5)
    manual = suzuki_recursion(strang_splitting(num_terms=3, time=0.5))
    assert list(fourth.step()) == list(manual.step())
    assert fourth.order == manual.order


# TrotterExpansion tests


def test_trotter_expansion_init_basic():
    """The expansion stores the wrapped step and the repetition count."""
    step = trotter_decomposition(num_terms=2, time=0.25)
    expansion = TrotterExpansion(step, num_steps=4)
    assert expansion._trotter_step is step
    assert expansion._num_steps == 4
TrotterStep(num_terms=2, time=0.25) + step = trotter_decomposition(num_terms=2, time=0.25) expansion = TrotterExpansion(step, num_steps=4) result = expansion.get() assert len(result) == 1 @@ -195,15 +366,15 @@ def test_trotter_expansion_get_multiple_steps(): assert terms == [(0.25, 0), (0.25, 1)] -def test_trotter_expansion_with_strang_step(): - """Test TrotterExpansion using StrangStep.""" - step = StrangStep(num_terms=2, time=0.5) +def test_trotter_expansion_with_strang(): + """Test TrotterExpansion using strang_splitting.""" + step = strang_splitting(num_terms=2, time=0.5) expansion = TrotterExpansion(step, num_steps=2) result = expansion.get() assert len(result) == 1 terms, count = result[0] assert count == 2 - # StrangStep with 2 terms: [(0.25, 0), (0.5, 1), (0.25, 0)] + # strang_splitting with 2 terms: [(0.25, 0), (0.5, 1), (0.25, 0)] assert terms == [(0.25, 0), (0.5, 1), (0.25, 0)] @@ -211,7 +382,7 @@ def test_trotter_expansion_total_time(): """Test that total evolution time is correct.""" total_time = 1.0 num_steps = 4 - step = TrotterStep(num_terms=3, time=total_time / num_steps) + step = trotter_decomposition(num_terms=3, time=total_time / num_steps) expansion = TrotterExpansion(step, num_steps=num_steps) result = expansion.get() terms, count = result[0] @@ -224,18 +395,87 @@ def test_trotter_expansion_total_time(): def test_trotter_expansion_preserves_step(): """Test that expansion preserves the original step.""" - step = TrotterStep(num_terms=3, time=0.5) + step = trotter_decomposition(num_terms=3, time=0.5) expansion = TrotterExpansion(step, num_steps=10) result = expansion.get() terms, _ = result[0] - assert terms == step.get() + assert terms == list(step.step()) + + +def test_trotter_expansion_with_fourth_order(): + """Test TrotterExpansion with fourth-order Trotter-Suzuki.""" + step = fourth_order_trotter_suzuki(num_terms=2, time=0.25) + expansion = TrotterExpansion(step, num_steps=4) + result = expansion.get() + terms, count = result[0] + assert 
count == 4 + assert step.order == 4 + + +def test_trotter_expansion_order_property(): + """Test TrotterExpansion order property.""" + step = strang_splitting(num_terms=3, time=0.5) + expansion = TrotterExpansion(step, num_steps=4) + assert expansion.order == 2 + + +def test_trotter_expansion_nterms_property(): + """Test TrotterExpansion nterms property.""" + step = trotter_decomposition(num_terms=5, time=0.5) + expansion = TrotterExpansion(step, num_steps=4) + assert expansion.nterms == 5 + +def test_trotter_expansion_num_steps_property(): + """Test TrotterExpansion num_steps property.""" + step = trotter_decomposition(num_terms=2, time=0.25) + expansion = TrotterExpansion(step, num_steps=8) + assert expansion.num_steps == 8 -def test_trotter_expansion_docstring_example(): - """Test the example from the TrotterExpansion docstring.""" - n = 4 # Number of Trotter steps - total_time = 1.0 # Total time - trotter_expansion = TrotterExpansion(TrotterStep(2, total_time / n), n) - result = trotter_expansion.get() - expected = [([(0.25, 0), (0.25, 1)], 4)] + +def test_trotter_expansion_total_time_property(): + """Test TrotterExpansion total_time property.""" + step = trotter_decomposition(num_terms=2, time=0.25) + expansion = TrotterExpansion(step, num_steps=4) + assert expansion.total_time == 1.0 + + +def test_trotter_expansion_step_iterator(): + """Test TrotterExpansion step() iterator yields full expansion.""" + step = trotter_decomposition(num_terms=2, time=0.5) + expansion = TrotterExpansion(step, num_steps=3) + result = list(expansion.step()) + # Should yield 3 repetitions of [(0.5, 0), (0.5, 1)] + expected = [(0.5, 0), (0.5, 1), (0.5, 0), (0.5, 1), (0.5, 0), (0.5, 1)] assert result == expected + + +def test_trotter_expansion_step_iterator_with_strang(): + """Test TrotterExpansion step() with Strang splitting.""" + step = strang_splitting(num_terms=2, time=1.0) + expansion = TrotterExpansion(step, num_steps=2) + result = list(expansion.step()) + # Strang with 2 terms: 
[(0.5, 0), (1.0, 1), (0.5, 0)] + # Repeated twice + expected = [(0.5, 0), (1.0, 1), (0.5, 0), (0.5, 0), (1.0, 1), (0.5, 0)] + assert result == expected + + +def test_trotter_expansion_str(): + """Test TrotterExpansion string representation.""" + step = strang_splitting(num_terms=3, time=0.25) + expansion = TrotterExpansion(step, num_steps=4) + result = str(expansion) + assert "order=2" in result + assert "num_steps=4" in result + assert "total_time=1.0" in result + assert "num_terms=3" in result + + +def test_trotter_expansion_repr(): + """Test TrotterExpansion repr representation.""" + step = trotter_decomposition(num_terms=2, time=0.5) + expansion = TrotterExpansion(step, num_steps=4) + result = repr(expansion) + assert "TrotterExpansion" in result + assert "num_steps=4" in result From b0261828cf0a503dc0a1bef922cfa8eab26bcbdb Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 10 Feb 2026 09:38:26 +0100 Subject: [PATCH 13/45] New ISA and ISARequirements function (#2923) This PR provides some support for upcoming models by providing new features to ISAs and the Rust bindings in general - ISA instructions can have properties (e.g., the `LATTICE_SURGERY` instruction can have a distance property if it was generated by a surface code) - ISA requirement constraints can check if some properties are set, e.g., one can require that the `LATTICE_SURGERY` instruction also provides a `distance` property. - Variable arity instructions can provide space, time, and error_rate functions as generic Python functions. This can be slower than the built-in functions `const`, `linear`, and `block_linear`, but useful for cases that do not have an easy structure. The new function is called `generic` (`generic_function` in the API) - We expose a `binom_ppf` function that works like `scipy.binom.ppf` and uses the `probability` crate that is already a dependency. This is not only faster but also does not require an additional dependency. 
It will be used to model round-based distillation factories. I missed addressing some comments in the last PR. This is now done here. --- Cargo.lock | 1 + source/pip/benchmarks/bench_qre.py | 42 +++++++ source/pip/qsharp/qre/__init__.py | 4 + source/pip/qsharp/qre/_application.py | 7 +- source/pip/qsharp/qre/_instruction.py | 53 +++++++- source/pip/qsharp/qre/_qre.py | 2 + source/pip/qsharp/qre/_qre.pyi | 144 +++++++++++++++++++++- source/pip/qsharp/qre/_trace.py | 2 +- source/pip/qsharp/qre/instruction_ids.pyi | 4 + source/pip/src/qre.rs | 103 +++++++++++++++- source/pip/tests/test_qre.py | 126 ++++++++++++++++++- source/qre/Cargo.toml | 1 + source/qre/src/isa.rs | 99 ++++++++++++++- source/qre/src/isa/tests.rs | 16 +++ source/qre/src/lib.rs | 2 + source/qre/src/trace/instruction_ids.rs | 25 ++++ source/qre/src/utils.rs | 12 ++ 17 files changed, 625 insertions(+), 18 deletions(-) create mode 100644 source/qre/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index b7ab884a4f..05f76494e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1868,6 +1868,7 @@ name = "qre" version = "0.0.0" dependencies = [ "num-traits", + "probability", "rustc-hash", "serde", "thiserror 1.0.63", diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index 561aa2c0b4..f273cf5d0d 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -3,6 +3,7 @@ import timeit from dataclasses import dataclass, KW_ONLY, field +from qsharp.qre import linear_function, generic_function, instruction from qsharp.qre.models import AQREGateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -58,6 +59,47 @@ def bench_enumerate_isas(): print(f"Enumerating ISAs took {duration / number:.6f} seconds on average.") +def bench_function_evaluation_linear(): + fl = linear_function(12) + + inst = instruction(42, arity=None, space=fl, time=1, error_rate=1.0) + number = 1000 + duration = timeit.timeit( + "inst.space(5)", + globals={ + 
"inst": inst, + }, + number=number, + ) + + print( + f"Evaluating linear function took {duration / number:.6f} seconds on average." + ) + + +def bench_function_evaluation_generic(): + def func(arity: int) -> int: + return 12 * arity + + fg = generic_function(func) + + inst = instruction(42, arity=None, space=fg, time=1, error_rate=1.0) + number = 1000 + duration = timeit.timeit( + "inst.space(5)", + globals={ + "inst": inst, + }, + number=number, + ) + + print( + f"Evaluating linear function took {duration / number:.6f} seconds on average." + ) + + if __name__ == "__main__": bench_enumerate_instances() bench_enumerate_isas() + bench_function_evaluation_linear() + bench_function_evaluation_generic() diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index d6dbb24e29..2177e5ea9f 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -9,6 +9,7 @@ PHYSICAL, Encoding, ISATransform, + PropertyKey, constraint, instruction, ) @@ -25,6 +26,7 @@ Trace, block_linear_function, constant_function, + generic_function, linear_function, ) from ._trace import LatticeSurgery, PSSPC, TraceQuery @@ -44,6 +46,7 @@ "Encoding", "EstimationResult", "FactoryResult", + "generic_function", "InstructionFrontier", "ISA", "ISA_ROOT", @@ -52,6 +55,7 @@ "ISARequirements", "ISATransform", "LatticeSurgery", + "PropertyKey", "PSSPC", "QSharpApplication", "Trace", diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py index 43e81ea4eb..61300fcf38 100644 --- a/source/pip/qsharp/qre/_application.py +++ b/source/pip/qsharp/qre/_application.py @@ -39,11 +39,10 @@ class Application(ABC, Generic[TraceParameters]): We distinguish between application and trace parameters. The application parameters define which particular instance of the application we want to - consider. The trace parameters define how to generate a trace. 
They - change the specific way in which we solve the problem, but not the problem - itself. + consider. The trace parameters define how to generate a trace. They change + the specific way in which we solve the problem, but not the problem itself. - For example, in quantum cryptography, the application parameters could + For example, in quantum cryptanalysis, the application parameters could define the key size for an RSA prime product, while the trace parameters define which algorithm to use to break the cryptography, as well as parameters therein. diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 9c4b24260e..17b1dc6c39 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -24,6 +24,10 @@ class Encoding(IntEnum): LOGICAL = 1 +class PropertyKey(IntEnum): + DISTANCE = 0 + + PHYSICAL = Encoding.PHYSICAL LOGICAL = Encoding.LOGICAL @@ -34,6 +38,7 @@ def constraint( *, arity: Optional[int] = 1, error_rate: Optional[ConstraintBound] = None, + **kwargs: bool, ) -> Constraint: """ Creates an instruction constraint. @@ -44,11 +49,28 @@ def constraint( arity (Optional[int]): The instruction arity. If None, instruction is assumed to have variable arity. Default is 1. error_rate (Optional[ConstraintBound]): The constraint on the error rate. + **kwargs (bool): Required properties that matching instructions must have. + Valid property names: distance. Set to True to require the property. Returns: Constraint: The instruction constraint. + + Raises: + ValueError: If an unknown property name is provided in kwargs. """ - return Constraint(id, encoding, arity, error_rate) + c = Constraint(id, encoding, arity, error_rate) + + for key, value in kwargs.items(): + if value: + try: + prop_key = PropertyKey[key.upper()] + except KeyError: + raise ValueError( + f"Unknown property '{key}'. 
Valid properties: {[k.name.lower() for k in PropertyKey]}" + ) + c.add_property(prop_key) + + return c @overload @@ -61,6 +83,7 @@ def instruction( space: Optional[int] = None, length: Optional[int] = None, error_rate: float, + **kwargs: int, ) -> Instruction: ... @overload def instruction( @@ -69,9 +92,10 @@ def instruction( *, time: int | IntFunction, arity: None = ..., - space: Optional[IntFunction] = None, - length: Optional[IntFunction] = None, - error_rate: FloatFunction, + space: int | IntFunction, + length: Optional[int | IntFunction] = None, + error_rate: float | FloatFunction, + **kwargs: int, ) -> Instruction: ... def instruction( id: int, @@ -82,6 +106,7 @@ def instruction( space: Optional[int] | IntFunction = None, length: Optional[int | IntFunction] = None, error_rate: float | FloatFunction, + **kwargs: int, ) -> Instruction: """ Creates an instruction. @@ -98,12 +123,17 @@ def instruction( length (Optional[int | IntFunction]): The arity including ancilla qubits. If None, arity is used. error_rate (float | FloatFunction): The instruction error rate. + **kwargs (int): Additional properties to set on the instruction. + Valid property names: distance. Returns: Instruction: The instruction. + + Raises: + ValueError: If an unknown property name is provided in kwargs. """ if arity is not None: - return Instruction.fixed_arity( + instr = Instruction.fixed_arity( id, encoding, arity, @@ -122,7 +152,7 @@ def instruction( if isinstance(error_rate, float): error_rate = constant_function(error_rate) - return Instruction.variable_arity( + instr = Instruction.variable_arity( id, encoding, time, @@ -131,6 +161,17 @@ def instruction( length, ) + for key, value in kwargs.items(): + try: + prop_key = PropertyKey[key.upper()] + except KeyError: + raise ValueError( + f"Unknown property '{key}'. 
Valid properties: {[k.name.lower() for k in PropertyKey]}" + ) + instr.set_property(prop_key, value) + + return instr + class ISATransform(ABC): """ diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index 3fdd913414..dd721b344e 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -5,6 +5,7 @@ # pyright: reportAttributeAccessIssue=false from .._native import ( + binom_ppf, block_linear_function, Block, constant_function, @@ -15,6 +16,7 @@ EstimationResult, FactoryResult, FloatFunction, + generic_function, Instruction, InstructionFrontier, IntFunction, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 85be2b136e..20424d52b6 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Any, Iterator, Optional, overload +from typing import Any, Callable, Iterator, Optional, overload class ISA: @overload @@ -18,6 +18,15 @@ class ISA: """ ... + def append(self, instruction: Instruction) -> None: + """ + Appends an instruction to the ISA. + + Args: + instruction (Instruction): The instruction to append. + """ + ... + def __add__(self, other: ISA) -> ISA: """ Concatenates two ISAs (logical union). Instructions in the second @@ -26,6 +35,18 @@ class ISA: """ ... + def __contains__(self, id: int) -> bool: + """ + Checks if the ISA contains an instruction with the given ID. + + Args: + id (int): The instruction ID. + + Returns: + bool: True if the ISA contains an instruction with the given ID, False otherwise. + """ + ... + def satisfies(self, requirements: ISARequirements) -> bool: """ Checks if the ISA satisfies the given ISA requirements. @@ -171,6 +192,18 @@ class Instruction: """ ... + def with_id(self, id: int) -> Instruction: + """ + Returns a copy of the instruction with the given ID. + + Args: + id (int): The instruction ID. 
+ + Returns: + Instruction: A copy of the instruction with the given ID. + """ + ... + @property def id(self) -> int: """ @@ -273,6 +306,53 @@ class Instruction: """ ... + def set_property(self, key: int, value: int) -> None: + """ + Sets a property on the instruction. + + Args: + key (int): The property key. + value (int): The property value. + """ + ... + + def get_property(self, key: int) -> Optional[int]: + """ + Gets a property by its key. + + Args: + key (int): The property key. + + Returns: + Optional[int]: The property value, or None if not found. + """ + ... + + def has_property(self, key: int) -> bool: + """ + Checks if the instruction has a property with the given key. + + Args: + key (int): The property key. + + Returns: + bool: True if the instruction has the property, False otherwise. + """ + ... + + def get_property_or(self, key: int, default: int) -> int: + """ + Gets a property by its key, or returns a default value if not found. + + Args: + key (int): The property key. + default (int): The default value to return if the property is not found. + + Returns: + int: The property value, or the default value if not found. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the instruction. @@ -381,6 +461,27 @@ class Constraint: """ ... + def add_property(self, property: int) -> None: + """ + Adds a property requirement to the constraint. + + Args: + property (int): The property key that must be present in matching instructions. + """ + ... + + def has_property(self, property: int) -> bool: + """ + Checks if the constraint requires a specific property. + + Args: + property (int): The property key to check. + + Returns: + bool: True if the constraint requires this property, False otherwise. + """ + ... + class IntFunction: ... class FloatFunction: ... @@ -439,6 +540,31 @@ def block_linear_function( """ ... +@overload +def generic_function(func: Callable[[int], int]) -> IntFunction: ... 
+@overload +def generic_function(func: Callable[[int], float]) -> FloatFunction: ... +def generic_function( + func: Callable[[int], int | float], +) -> IntFunction | FloatFunction: + """ + Creates a generic function from a Python callable. + + Note: + Only use this function if the other function constructors + (constant_function, linear_function, and block_linear_function) do not + meet your needs, as using a Python callable can have performance + implications. If using this function, keep the logic in the callable as + simple as possible to minimize overhead. + + Args: + func (Callable[[int], int | float]): The Python callable. + + Returns: + IntFunction | FloatFunction: The generic function. + """ + ... + class Property: def __new__(cls, value: Any) -> Property: """ @@ -905,6 +1031,15 @@ class InstructionFrontier: """ ... + def extend(self, points: list[Instruction]) -> None: + """ + Extends the frontier with a list of instructions. + + Args: + points (list[Instruction]): The instructions to insert. + """ + ... + def __len__(self) -> int: """ Returns the number of instructions in the frontier. @@ -960,3 +1095,10 @@ def estimate_parallel( EstimationCollection: The estimation collection. """ ... + +def binom_ppf(q: float, n: int, p: float) -> int: + """ + A replacement for SciPy's binom.ppf that is faster and does not require + SciPy as a dependency. + """ + ... 
diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py index ab1d49f6ce..d57b30db76 100644 --- a/source/pip/qsharp/qre/_trace.py +++ b/source/pip/qsharp/qre/_trace.py @@ -24,7 +24,7 @@ def q(cls, **kwargs) -> TraceQuery: class PSSPC(TraceTransform): _: KW_ONLY num_ts_per_rotation: int = field( - default=10, metadata={"domain": list(range(1, 21))} + default=20, metadata={"domain": list(range(1, 21))} ) ccx_magic_states: bool = field(default=False) diff --git a/source/pip/qsharp/qre/instruction_ids.pyi b/source/pip/qsharp/qre/instruction_ids.pyi index 72934487f8..164e8d431c 100644 --- a/source/pip/qsharp/qre/instruction_ids.pyi +++ b/source/pip/qsharp/qre/instruction_ids.pyi @@ -75,6 +75,10 @@ RXX: int RYY: int RZZ: int +# Generic unitary gates +ONE_QUBIT_UNITARY: int +TWO_QUBIT_UNITARY: int + # Multi-qubit Pauli measurement MULTI_PAULI_MEAS: int diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index d9e870990c..54b2133f38 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-use std::ptr::NonNull; +use std::{ptr::NonNull, sync::Arc}; use pyo3::{ IntoPyObjectExt, @@ -32,7 +32,9 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(constant_function, m)?)?; m.add_function(wrap_pyfunction!(linear_function, m)?)?; m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; + m.add_function(wrap_pyfunction!(generic_function, m)?)?; m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; + m.add_function(wrap_pyfunction!(binom_ppf, m)?)?; m.add("EstimationError", m.py().get_type::())?; @@ -74,10 +76,18 @@ impl ISA { .map(ISA) } + pub fn append(&mut self, instruction: &Instruction) { + self.0.add_instruction(instruction.0.clone()); + } + pub fn __add__(&self, other: &ISA) -> PyResult { Ok(ISA(self.0.clone() + other.0.clone())) } + pub fn __contains__(&self, id: u64) -> bool { + self.0.contains(&id) + } + pub fn satisfies(&self, requirements: &ISARequirements) -> PyResult { Ok(self.0.satisfies(&requirements.0)) } @@ -211,6 +221,10 @@ impl Instruction { ))) } + pub fn with_id(&self, id: u64) -> Self { + Instruction(self.0.with_id(id)) + } + #[getter] pub fn id(&self) -> u64 { self.0.id() @@ -259,6 +273,23 @@ impl Instruction { Ok(self.0.expect_error_rate(arity)) } + pub fn set_property(&mut self, key: u64, value: u64) { + self.0.set_property(key, value); + } + + pub fn get_property(&self, key: u64) -> Option { + self.0.get_property(&key) + } + + pub fn has_property(&self, key: u64) -> bool { + self.0.has_property(&key) + } + + #[pyo3(signature = (key, default))] + pub fn get_property_or(&self, key: u64, default: u64) -> u64 { + self.0.get_property_or(&key, default) + } + fn __str__(&self) -> String { format!("{}", self.0) } @@ -301,6 +332,14 @@ impl Constraint { error_rate.map(|error_rate| error_rate.0), ))) } + + pub fn add_property(&mut self, property: u64) { + self.0.add_property(property); + } + + pub fn has_property(&self, property: u64) -> bool { + 
self.0.has_property(&property) + } } fn convert_encoding(encoding: u64) -> PyResult { @@ -448,6 +487,55 @@ pub fn block_linear_function<'py>( } } +#[pyfunction] +pub fn generic_function<'py>( + py: Python<'py>, + func: Bound<'py, PyAny>, +) -> PyResult> { + // Try to get return type annotation from the function + let is_int = if let Ok(annotations) = func.getattr("__annotations__") { + if let Ok(return_type) = annotations.get_item("return") { + // Check if return type is float + let float_type = py.get_type::(); + return_type.eq(float_type).unwrap_or(false) + } else { + false + } + } else { + false + }; + + let func: Py = func.unbind(); + + if is_int { + let closure = move |arity: u64| -> u64 { + Python::attach(|py| { + let result = func.call1(py, (arity,)); + match result { + Ok(value) => value.extract::(py).unwrap_or(0), + Err(_) => 0, + } + }) + }; + + let arc: Arc u64 + Send + Sync> = Arc::new(closure); + IntFunction(qre::VariableArityFunction::generic_from_arc(arc)).into_bound_py_any(py) + } else { + let closure = move |arity: u64| -> f64 { + Python::attach(|py| { + let result = func.call1(py, (arity,)); + match result { + Ok(value) => value.extract::(py).unwrap_or(0.0), + Err(_) => 0.0, + } + }) + }; + + let arc: Arc f64 + Send + Sync> = Arc::new(closure); + FloatFunction(qre::VariableArityFunction::generic_from_arc(arc)).into_bound_py_any(py) + } +} + #[derive(Default)] #[pyclass] pub struct EstimationCollection(qre::EstimationCollection); @@ -730,6 +818,12 @@ impl InstructionFrontier { self.0.insert(point.clone()); } + #[allow(clippy::needless_pass_by_value)] + pub fn extend(&mut self, points: Vec>) { + self.0 + .extend(points.iter().map(|p| Instruction(p.0.clone()))); + } + pub fn __len__(&self) -> usize { self.0.len() } @@ -787,6 +881,11 @@ pub fn estimate_parallel( EstimationCollection(collection) } +#[pyfunction] +pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { + qre::binom_ppf(q, n, p) +} + fn add_instruction_ids(m: &Bound<'_, PyModule>) -> 
PyResult<()> { #[allow(clippy::wildcard_imports)] use qre::instruction_ids::*; @@ -862,6 +961,8 @@ fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { RXX, RYY, RZZ, + ONE_QUBIT_UNITARY, + TWO_QUBIT_UNITARY, MULTI_PAULI_MEAS, LATTICE_SURGERY, READ_FROM_MEMORY, diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 98e1c9de59..05a3ffb63e 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -4,6 +4,7 @@ from dataclasses import KW_ONLY, dataclass, field from enum import Enum from typing import Generator +import pytest import qsharp from qsharp.qre import ( @@ -14,19 +15,25 @@ ISARequirements, ISATransform, LatticeSurgery, + PropertyKey, QSharpApplication, Trace, constraint, estimate, instruction, linear_function, + generic_function, +) +from qsharp.qre.models import ( + SurfaceCode, + AQREGateBased, ) -from qsharp.qre.models import SurfaceCode, AQREGateBased from qsharp.qre._isa_enumeration import ( ISARefNode, ) from qsharp.qre.instruction_ids import ( CCX, + CCZ, GENERIC, LATTICE_SURGERY, T, @@ -71,6 +78,123 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: ) +def test_isa(): + isa = ISA( + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, space=400), + instruction( + CCX, arity=3, encoding=LOGICAL, time=2000, error_rate=1e-10, space=800 + ), + ) + + assert T in isa + assert CCX in isa + assert LATTICE_SURGERY not in isa + + t_instr = isa[T] + assert t_instr.time() == 1000 + assert t_instr.error_rate() == 1e-8 + assert t_instr.space() == 400 + + assert len(isa) == 2 + ccz_instr = isa[CCX].with_id(CCZ) + assert ccz_instr.arity == 3 + assert ccz_instr.time() == 2000 + assert ccz_instr.error_rate() == 1e-10 + assert ccz_instr.space() == 800 + + isa.append(ccz_instr) + assert CCZ in isa + assert len(isa) == 3 + + isa.append(ccz_instr) + assert ( + len(isa) == 3 + ) # Appending the same instruction should not increase the number of instructions + + +def 
test_instruction_properties(): + # Test instruction with no properties + instr_no_props = instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) + assert instr_no_props.get_property(PropertyKey.DISTANCE) is None + assert instr_no_props.has_property(PropertyKey.DISTANCE) is False + assert instr_no_props.get_property_or(PropertyKey.DISTANCE, 5) == 5 + + # Test instruction with valid property (distance) + instr_with_distance = instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ) + assert instr_with_distance.get_property(PropertyKey.DISTANCE) == 9 + assert instr_with_distance.has_property(PropertyKey.DISTANCE) is True + assert instr_with_distance.get_property_or(PropertyKey.DISTANCE, 5) == 9 + + # Test instruction with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, invalid_prop=42) + + +def test_instruction_constraints(): + # Test constraint without properties + c_no_props = constraint(T, encoding=LOGICAL) + assert c_no_props.has_property(PropertyKey.DISTANCE) is False + + # Test constraint with valid property (distance=True) + c_with_distance = constraint(T, encoding=LOGICAL, distance=True) + assert c_with_distance.has_property(PropertyKey.DISTANCE) is True + + # Test constraint with distance=False (should not add the property) + c_distance_false = constraint(T, encoding=LOGICAL, distance=False) + assert c_distance_false.has_property(PropertyKey.DISTANCE) is False + + # Test constraint with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + constraint(T, encoding=LOGICAL, invalid_prop=True) + + # Test ISA.satisfies with property constraints + instr_no_dist = instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) + instr_with_dist = instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ) + + isa_no_dist = ISA(instr_no_dist) + isa_with_dist = 
ISA(instr_with_dist) + + reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) + reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) + + # ISA without distance property + assert isa_no_dist.satisfies(reqs_no_prop) is True + assert isa_no_dist.satisfies(reqs_with_prop) is False + + # ISA with distance property + assert isa_with_dist.satisfies(reqs_no_prop) is True + assert isa_with_dist.satisfies(reqs_with_prop) is True + + +def test_generic_function(): + from qsharp.qre._qre import IntFunction, FloatFunction + + def time(x: int) -> int: + return x * x + + time_fn = generic_function(time) + assert isinstance(time_fn, IntFunction) + + def error_rate(x: int) -> float: + return x / 2.0 + + error_rate_fn = generic_function(error_rate) + assert isinstance(error_rate_fn, FloatFunction) + + # Without annotations, defaults to FloatFunction + space_fn = generic_function(lambda x: 12) + assert isinstance(space_fn, FloatFunction) + + i = instruction(42, arity=None, space=12, time=time_fn, error_rate=error_rate_fn) + assert i.space(5) == 12 + assert i.time(5) == 25 + assert i.error_rate(5) == 2.5 + + def test_isa_from_architecture(): arch = AQREGateBased() code = SurfaceCode() diff --git a/source/qre/Cargo.toml b/source/qre/Cargo.toml index 88148dca7c..37f28e2ffc 100644 --- a/source/qre/Cargo.toml +++ b/source/qre/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true [dependencies] num-traits = { workspace = true } rustc-hash = { workspace = true } +probability = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 310f375c56..a63de6bd74 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -4,10 +4,11 @@ use std::{ fmt::Display, ops::{Add, Deref, Index}, + sync::Arc, }; use num_traits::FromPrimitive; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; #[cfg(test)] @@ -35,6 +36,11 
@@ impl ISA { self.instructions.get(id) } + #[must_use] + pub fn contains(&self, id: &u64) -> bool { + self.instructions.contains_key(id) + } + #[must_use] pub fn satisfies(&self, requirements: &ISARequirements) -> bool { for constraint in requirements.constraints.values() { @@ -79,6 +85,13 @@ impl ISA { } } } + + // Check that all required properties are present in the instruction + for prop in &constraint.properties { + if !instruction.has_property(prop) { + return false; + } + } } true } @@ -164,6 +177,7 @@ pub struct Instruction { id: u64, encoding: Encoding, metrics: Metrics, + properties: Option>, } impl Instruction { @@ -190,6 +204,7 @@ impl Instruction { time, error_rate, }, + properties: None, } } @@ -213,9 +228,17 @@ impl Instruction { time_fn, error_rate_fn, }, + properties: None, } } + #[must_use] + pub fn with_id(&self, id: u64) -> Self { + let mut new_instruction = self.clone(); + new_instruction.id = id; + new_instruction + } + #[must_use] pub fn id(&self) -> u64 { self.id @@ -291,6 +314,33 @@ impl Instruction { self.error_rate(arity) .expect("Instruction does not support variable arity") } + + pub fn set_property(&mut self, key: u64, value: u64) { + if let Some(ref mut properties) = self.properties { + properties.insert(key, value); + } else { + let mut properties = FxHashMap::default(); + properties.insert(key, value); + self.properties = Some(properties); + } + } + + #[must_use] + pub fn get_property(&self, key: &u64) -> Option { + self.properties.as_ref()?.get(key).copied() + } + + #[must_use] + pub fn has_property(&self, key: &u64) -> bool { + self.properties + .as_ref() + .is_some_and(|props| props.contains_key(key)) + } + + #[must_use] + pub fn get_property_or(&self, key: &u64, default: u64) -> u64 { + self.get_property(key).unwrap_or(default) + } } impl Display for Instruction { @@ -310,6 +360,7 @@ pub struct InstructionConstraint { encoding: Encoding, arity: Option, error_rate_fn: Option>, + properties: FxHashSet, } impl 
InstructionConstraint { @@ -325,8 +376,26 @@ impl InstructionConstraint { encoding, arity, error_rate_fn, + properties: FxHashSet::default(), } } + + /// Adds a property requirement to the constraint. + pub fn add_property(&mut self, property: u64) { + self.properties.insert(property); + } + + /// Checks if the constraint requires a specific property. + #[must_use] + pub fn has_property(&self, property: &u64) -> bool { + self.properties.contains(property) + } + + /// Returns the set of required properties. + #[must_use] + pub fn properties(&self) -> &FxHashSet { + &self.properties + } } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -355,9 +424,20 @@ pub enum Metrics { #[derive(Clone, Serialize, Deserialize)] pub enum VariableArityFunction { - Constant { value: T }, - Linear { slope: T }, - BlockLinear { block_size: u64, slope: T }, + Constant { + value: T, + }, + Linear { + slope: T, + }, + BlockLinear { + block_size: u64, + slope: T, + }, + #[serde(skip)] + Generic { + func: Arc T + Send + Sync>, + }, } impl + std::ops::Mul + Copy + FromPrimitive> @@ -375,6 +455,16 @@ impl + std::ops::Mul + Copy + FromPrimitive> VariableArityFunction::BlockLinear { block_size, slope } } + pub fn generic(func: impl Fn(u64) -> T + Send + Sync + 'static) -> Self { + VariableArityFunction::Generic { + func: Arc::new(func), + } + } + + pub fn generic_from_arc(func: Arc T + Send + Sync>) -> Self { + VariableArityFunction::Generic { func } + } + pub fn evaluate(&self, arity: u64) -> T { match self { VariableArityFunction::Constant { value } => *value, @@ -385,6 +475,7 @@ impl + std::ops::Mul + Copy + FromPrimitive> let blocks = arity.div_ceil(*block_size); *slope * T::from_u64(blocks).expect("Failed to convert u64 to target type") } + VariableArityFunction::Generic { func } => func(arity), } } } diff --git a/source/qre/src/isa/tests.rs b/source/qre/src/isa/tests.rs index d71ae1b902..802a76f30c 100644 --- a/source/qre/src/isa/tests.rs +++ 
b/source/qre/src/isa/tests.rs @@ -134,3 +134,19 @@ fn test_variable_arity_satisfies() { )); assert!(!isa.satisfies(&reqs_fail)); // 0.02 not < 0.01 } + +#[test] +fn test_variable_arity_function() { + let linear_fn = VariableArityFunction::linear(10); + assert_eq!(linear_fn.evaluate(3), 30); + assert_eq!(linear_fn.evaluate(0), 0); + + let constant_fn = VariableArityFunction::constant(5); + assert_eq!(constant_fn.evaluate(3), 5); + assert_eq!(constant_fn.evaluate(0), 5); + + // Test with a custom function + let custom_fn = VariableArityFunction::generic(|arity| arity * arity); // Quadratic + assert_eq!(custom_fn.evaluate(3), 9); + assert_eq!(custom_fn.evaluate(4), 16); +} diff --git a/source/qre/src/lib.rs index 4baa1f9e13..b8bd2db55a 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -17,6 +17,8 @@ pub use isa::{ }; pub use trace::instruction_ids; pub use trace::{Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel}; +mod utils; +pub use utils::binom_ppf; /// A resource estimation error. 
#[derive(Clone, Debug, Error, PartialEq)] diff --git a/source/qre/src/trace/instruction_ids.rs b/source/qre/src/trace/instruction_ids.rs index f8f78bc958..49457b3c79 100644 --- a/source/qre/src/trace/instruction_ids.rs +++ b/source/qre/src/trace/instruction_ids.rs @@ -5,10 +5,13 @@ // - add them to `add_instruction_ids` in qre.rs // - add them to instruction_ids.pyi +// Paulis pub const PAULI_I: u64 = 0x0; pub const PAULI_X: u64 = 0x1; pub const PAULI_Y: u64 = 0x2; pub const PAULI_Z: u64 = 0x3; + +// Clifford gates pub const H: u64 = 0x10; pub const H_XZ: u64 = 0x10; pub const H_XY: u64 = 0x11; @@ -26,12 +29,18 @@ pub const CX: u64 = 0x19; pub const CY: u64 = 0x1A; pub const CZ: u64 = 0x1B; pub const SWAP: u64 = 0x1C; + +// State preparation pub const PREP_X: u64 = 0x30; pub const PREP_Y: u64 = 0x31; pub const PREP_Z: u64 = 0x32; + +// Generic Cliffords pub const ONE_QUBIT_CLIFFORD: u64 = 0x50; pub const TWO_QUBIT_CLIFFORD: u64 = 0x51; pub const N_QUBIT_CLIFFORD: u64 = 0x52; + +// Measurements pub const MEAS_X: u64 = 0x100; pub const MEAS_Y: u64 = 0x101; pub const MEAS_Z: u64 = 0x102; @@ -44,6 +53,8 @@ pub const MEAS_ZZ: u64 = 0x108; pub const MEAS_XZ: u64 = 0x109; pub const MEAS_XY: u64 = 0x10A; pub const MEAS_YZ: u64 = 0x10B; + +// Non-Clifford gates pub const SQRT_SQRT_X: u64 = 0x400; pub const SQRT_SQRT_X_DAG: u64 = 0x401; pub const SQRT_SQRT_Y: u64 = 0x402; @@ -67,11 +78,25 @@ pub const CRZ: u64 = 0x411; pub const RXX: u64 = 0x412; pub const RYY: u64 = 0x413; pub const RZZ: u64 = 0x414; + +// Generic unitaries +pub const ONE_QUBIT_UNITARY: u64 = 0x500; +pub const TWO_QUBIT_UNITARY: u64 = 0x501; + +// Multi-qubit Pauli measurement pub const MULTI_PAULI_MEAS: u64 = 0x1000; + +// Some generic logical instructions pub const LATTICE_SURGERY: u64 = 0x1100; + +// Memory/compute operations (used in compute parts of memory-compute layouts) pub const READ_FROM_MEMORY: u64 = 0x1200; pub const WRITE_TO_MEMORY: u64 = 0x1201; + +// Some special hardware physical 
instructions pub const CYCLIC_SHIFT: u64 = 0x1300; + +// Generic operation (for unified RE) pub const GENERIC: u64 = 0xFFFF; #[must_use] diff --git a/source/qre/src/utils.rs b/source/qre/src/utils.rs new file mode 100644 index 0000000000..ea6dde8623 --- /dev/null +++ b/source/qre/src/utils.rs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use probability::prelude::{Binomial, Inverse}; + +#[allow(clippy::doc_markdown)] +/// Faster implementation of SciPy's binom.ppf +#[must_use] +pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { + let dist = Binomial::with_failure(n, 1.0 - p); + dist.inverse(q) +} From 289a181e19efe898d90667b4b5f25654f356be31 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Fri, 13 Feb 2026 18:08:46 +0100 Subject: [PATCH 14/45] Track instruction origin (#2933) The main change of this PR is the introduction of a provenance graph that keeps track of how instructions in ISAs are built from other instructions from other ISAs down to the architecture's ISA. We can then build an `InstructionSource` (also a graph) that describes the dependencies with respect to optimal resource estimation results such that we can track which instructions were used to estimate the trace and what properties they have. Other changes include: - A function `instruction_name` to turn an instruction ID into a name, e.g. `instruction_name(CNOT) == "CNOT"`, as well as using them in `__str__` (`Display`) implementations for ISAs and traces. - Prefixing some Rust bindings with an underscore when exported to the Python package to emphasize that they are private and are not exported by the `qsharp.qre` module. 
--- source/pip/qsharp/qre/__init__.py | 2 + source/pip/qsharp/qre/_architecture.py | 60 ++++- source/pip/qsharp/qre/_estimation.py | 52 +++- source/pip/qsharp/qre/_instruction.py | 154 +++++++++-- source/pip/qsharp/qre/_isa_enumeration.py | 4 +- source/pip/qsharp/qre/_qre.py | 16 +- source/pip/qsharp/qre/_qre.pyi | 255 +++++++++++++----- source/pip/qsharp/qre/instruction_ids.pyi | 1 + .../qsharp/qre/models/qec/_surface_code.py | 55 ++-- source/pip/src/qre.rs | 73 ++++- source/pip/tests/test_qre.py | 15 +- source/qre/src/isa.rs | 101 ++++++- source/qre/src/isa/tests.rs | 55 ++++ source/qre/src/lib.rs | 3 +- source/qre/src/result.rs | 12 +- source/qre/src/trace.rs | 6 +- source/qre/src/trace/instruction_ids.rs | 242 ++++++++++------- source/qre/src/trace/instruction_ids/tests.rs | 26 ++ source/qre/src/trace/tests.rs | 59 ++++ 19 files changed, 935 insertions(+), 256 deletions(-) create mode 100644 source/qre/src/trace/instruction_ids/tests.rs diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 2177e5ea9f..90bf5bda00 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -28,6 +28,7 @@ constant_function, generic_function, linear_function, + instruction_name, ) from ._trace import LatticeSurgery, PSSPC, TraceQuery @@ -47,6 +48,7 @@ "EstimationResult", "FactoryResult", "generic_function", + "instruction_name", "InstructionFrontier", "ISA", "ISA_ROOT", diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index fe991aff42..ce69e7d7c1 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -2,11 +2,15 @@ # Licensed under the MIT License. 
from __future__ import annotations +import copy +from typing import TYPE_CHECKING from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from ._qre import ISA +from ._qre import ISA, _ProvenanceGraph, _Instruction + +if TYPE_CHECKING: + from ._instruction import ISATransform class Architecture(ABC): @@ -16,22 +20,56 @@ def provided_isa(self) -> ISA: ... def context(self) -> _Context: """Create a new enumeration context for this architecture.""" - return _Context(self.provided_isa) + return _Context(self) -@dataclass(slots=True, frozen=True) class _Context: """ Context passed through enumeration, holding shared state. - - Attributes: - root_isa: The root ISA for enumeration. """ - root_isa: ISA - _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) + def __init__(self, arch: Architecture): + self._provenance: _ProvenanceGraph = _ProvenanceGraph() + + def _mark_instruction(inst: _Instruction) -> _Instruction: + node = self._provenance.add_node(inst.id, 0, []) + inst.set_source(node) + return inst + + self._isa = ISA([_mark_instruction(instr) for instr in arch.provided_isa]) + + self._bindings: dict[str, ISA] = {} + self._transforms: dict[int, Architecture | ISATransform] = {0: arch} def _with_binding(self, name: str, isa: ISA) -> _Context: """Return a new context with an additional binding (internal use).""" - new_bindings = {**self._bindings, name: isa} - return _Context(self.root_isa, new_bindings) + ctx = copy.copy(self) + ctx._bindings = {**self._bindings, name: isa} + return ctx + + def set_source( + self, + transform: ISATransform, + instruction: _Instruction, + source_instructions: list[_Instruction], + ) -> _Instruction: + """ + Record the provenance of an instruction generated by a transform, and + return the instruction with its source set. + + Args: + transform: The transform that generated the instruction. + instruction: The instruction whose provenance is being recorded. 
+ source_instructions: The instructions that were used as input to the + transform to generate this instruction. + + Returns: + The input instruction with its source set to the provenance node. + """ + + source = self._provenance.add_node( + instruction.id, id(transform), [inst.source for inst in source_instructions] + ) + + instruction.set_source(source) + return instruction diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 79b11b9eb7..17a8330237 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -1,10 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + +from dataclasses import dataclass, field + from ._application import Application from ._architecture import Architecture -from ._qre import EstimationCollection, estimate_parallel +from ._qre import _estimate_parallel from ._trace import TraceQuery +from ._instruction import InstructionSource from ._isa_enumeration import ISAQuery @@ -15,7 +20,7 @@ def estimate( isa_query: ISAQuery, *, max_error: float = 1.0, -) -> EstimationCollection: +) -> EstimationTable: """ Estimate the resource requirements for a given application instance and architecture. @@ -26,7 +31,7 @@ def estimate( architecture's ISA is transformed by the ISA query, which applies several ISA transforms in sequence, each of which may return multiple ISAs. The estimation is performed for each combination of transformed trace and ISA. - The results are collected into an EstimationCollection and returned. + The results are collected into an EstimationTable and returned. The collection only contains the results that are optimal with respect to the total number of qubits and the total runtime. @@ -39,14 +44,51 @@ def estimate( isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture. Returns: - EstimationCollection: A collection of estimation results. 
+ EstimationTable: A table containing the optimal estimation results """ app_ctx = application.context() arch_ctx = architecture.context() - return estimate_parallel( + # Obtain all results + results = _estimate_parallel( list(trace_query.enumerate(app_ctx)), list(isa_query.enumerate(arch_ctx)), max_error, ) + + # Post-process the results and add them to a results table + table = EstimationTable() + + for result in results: + entry = EstimationTableEntry( + qubits=result.qubits, + runtime=result.runtime, + error=result.error, + source=InstructionSource.from_estimation_result(arch_ctx, result), + ) + table.append(entry) + + return table + + +@dataclass(frozen=True, slots=True) +class EstimationTable: + entries: list[EstimationTableEntry] = field(default_factory=list, init=False) + + def append(self, entry: EstimationTableEntry) -> None: + self.entries.append(entry) + + def __len__(self) -> int: + return len(self.entries) + + def __iter__(self): + return iter(self.entries) + + +@dataclass(frozen=True, slots=True) +class EstimationTableEntry: + qubits: int + runtime: int + error: float + source: InstructionSource diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 17b1dc6c39..9517a04eb4 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -1,21 +1,32 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Generator, Iterable, Optional, overload, cast from enum import IntEnum +from ._architecture import _Context, Architecture from ._enumeration import _enumerate_instances -from ._isa_enumeration import ISA_ROOT, _BindingNode, _ComponentQuery, ISAQuery +from ._isa_enumeration import ( + ISA_ROOT, + _BindingNode, + _ComponentQuery, + ISAQuery, +) from ._qre import ( ISA, Constraint, ConstraintBound, - FloatFunction, - Instruction, - IntFunction, + EstimationResult, + _FloatFunction, + _Instruction, + _IntFunction, ISARequirements, constant_function, + instruction_name, ) @@ -84,56 +95,56 @@ def instruction( length: Optional[int] = None, error_rate: float, **kwargs: int, -) -> Instruction: ... +) -> _Instruction: ... @overload def instruction( id: int, encoding: Encoding = PHYSICAL, *, - time: int | IntFunction, + time: int | _IntFunction, arity: None = ..., - space: int | IntFunction, - length: Optional[int | IntFunction] = None, - error_rate: float | FloatFunction, + space: int | _IntFunction, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction, **kwargs: int, -) -> Instruction: ... +) -> _Instruction: ... def instruction( id: int, encoding: Encoding = PHYSICAL, *, - time: int | IntFunction, + time: int | _IntFunction, arity: Optional[int] = 1, - space: Optional[int] | IntFunction = None, - length: Optional[int | IntFunction] = None, - error_rate: float | FloatFunction, + space: Optional[int] | _IntFunction = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction, **kwargs: int, -) -> Instruction: +) -> _Instruction: """ Creates an instruction. Args: id (int): The instruction ID. encoding (Encoding): The instruction encoding. PHYSICAL (0) or LOGICAL (1). - time (int | IntFunction): The instruction time in ns. + time (int | _IntFunction): The instruction time in ns. 
arity (Optional[int]): The instruction arity. If None, instruction is assumed to have variable arity. Default is 1. One can use variable arity functions for time, space, length, and error_rate in this case. - space (Optional[int] | IntFunction): The instruction space in number of + space (Optional[int] | _IntFunction): The instruction space in number of physical qubits. If None, length is used. - length (Optional[int | IntFunction]): The arity including ancilla + length (Optional[int | _IntFunction]): The arity including ancilla qubits. If None, arity is used. - error_rate (float | FloatFunction): The instruction error rate. + error_rate (float | _FloatFunction): The instruction error rate. **kwargs (int): Additional properties to set on the instruction. Valid property names: distance. Returns: - Instruction: The instruction. + _Instruction: The instruction. Raises: ValueError: If an unknown property name is provided in kwargs. """ if arity is not None: - instr = Instruction.fixed_arity( + instr = _Instruction.fixed_arity( id, encoding, arity, @@ -152,12 +163,12 @@ def instruction( if isinstance(error_rate, float): error_rate = constant_function(error_rate) - instr = Instruction.variable_arity( + instr = _Instruction.variable_arity( id, encoding, time, - cast(IntFunction, space), - cast(FloatFunction, error_rate), + cast(_IntFunction, space), + cast(_FloatFunction, error_rate), length, ) @@ -194,7 +205,7 @@ def required_isa() -> ISARequirements: ... @abstractmethod - def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs provided by this transform given an implementation ISA. 
@@ -210,6 +221,7 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: def enumerate_isas( cls, impl_isa: ISA | Iterable[ISA], + ctx: _Context, **kwargs, ) -> Generator[ISA, None, None]: """ @@ -231,7 +243,8 @@ def enumerate_isas( continue for component in _enumerate_instances(cls, **kwargs): - yield from component.provided_isa(isa) + ctx._transforms[id(component)] = component + yield from component.provided_isa(isa, ctx) @classmethod def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery: @@ -265,3 +278,92 @@ def bind(cls, name: str, node: ISAQuery) -> _BindingNode: BindingNode: A binding node enclosing this transform. """ return cls.q().bind(name, node) + + +@dataclass(frozen=True, slots=True) +class InstructionSource: + nodes: list[_InstructionSourceNode] = field(default_factory=list, init=False) + roots: list[int] = field(default_factory=list, init=False) + + @classmethod + def from_estimation_result( + cls, ctx: _Context, result: EstimationResult + ) -> InstructionSource: + """ + Constructs an InstructionSource graph from an EstimationResult. + + The instruction source graph contains more information than the + provenance graph in the context, as it connects the instructions to the + transforms and architectures that generated them. + + Args: + ctx (_Context): The enumeration context containing the provenance graph. + result (EstimationResult): The estimation result containing the ISA and instruction sources. + + Returns: + InstructionSource: The instruction source graph for the estimation result. 
+ """ + + def _make_node( + graph: InstructionSource, source_table: dict[int, int], source: int + ) -> int: + if source in source_table: + return source_table[source] + + children = [ + _make_node(graph, source_table, child) + for child in ctx._provenance.children(source) + if child != 0 + ] + + node = graph.add_node( + ctx._provenance.instruction_id(source), + ctx._transforms.get(ctx._provenance.transform_id(source)), + children, + ) + + source_table[source] = node + return node + + graph = cls() + source_table: dict[int, int] = {} + + for inst in result.isa: + if inst.source != 0: + node = _make_node(graph, source_table, inst.source) + graph.add_root(node) + + return graph + + def add_root(self, node_id: int) -> None: + self.roots.append(node_id) + + def add_node( + self, + id: int, + transform: Optional[ISATransform | Architecture], + children: list[int], + ) -> int: + node_id = self.nodes.__len__() + self.nodes.append(_InstructionSourceNode(id, transform, children)) + return node_id + + def __str__(self) -> str: + def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str: + result = " " * indent + f"{instruction_name(node.id) or '??'}" + if node.transform is not None: + result += f" @ {node.transform}" + for child_index in node.children: + result += "\n" + _format_node(self.nodes[child_index], indent + 2) + return result + + return "\n".join( + _format_node(self.nodes[root_index]) for root_index in self.roots + ) + + +@dataclass(frozen=True, slots=True) +class _InstructionSourceNode: + id: int + transform: Optional[ISATransform | Architecture] + children: list[int] diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 0cfe5e5940..6298ffd847 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -144,7 +144,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: Yields: ISA: The architecture's provided ISA, called root. 
""" - yield ctx.root_isa + yield ctx._isa # Singleton instance for convenience @@ -182,7 +182,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: ISA: A generated ISA instance. """ for isa in self.source.enumerate(ctx): - yield from self.component.enumerate_isas(isa, **self.kwargs) + yield from self.component.enumerate_isas(isa, ctx, **self.kwargs) @dataclass diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index dd721b344e..32b15b45f6 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -5,24 +5,26 @@ # pyright: reportAttributeAccessIssue=false from .._native import ( - binom_ppf, + _binom_ppf, block_linear_function, Block, constant_function, Constraint, ConstraintBound, - estimate_parallel, - EstimationCollection, + _estimate_parallel, + _EstimationCollection, EstimationResult, FactoryResult, - FloatFunction, + _FloatFunction, generic_function, - Instruction, + instruction_name, + _Instruction, InstructionFrontier, - IntFunction, + _IntFunction, ISA, ISARequirements, - Property, + _Property, + _ProvenanceGraph, linear_function, LatticeSurgery, PSSPC, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 20424d52b6..93dc6750be 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -6,24 +6,24 @@ from typing import Any, Callable, Iterator, Optional, overload class ISA: @overload - def __new__(cls, *instructions: Instruction) -> ISA: ... + def __new__(cls, *instructions: _Instruction) -> ISA: ... @overload - def __new__(cls, instructions: list[Instruction], /) -> ISA: ... - def __new__(cls, *instructions: Instruction | list[Instruction]) -> ISA: + def __new__(cls, instructions: list[_Instruction], /) -> ISA: ... + def __new__(cls, *instructions: _Instruction | list[_Instruction]) -> ISA: """ Creates an ISA from a list of instructions. Args: - instructions (list[Instruction] | *Instruction): The list of instructions. 
+ instructions (list[_Instruction] | *_Instruction): The list of instructions. """ ... - def append(self, instruction: Instruction) -> None: + def append(self, instruction: _Instruction) -> None: """ Appends an instruction to the ISA. Args: - instruction (Instruction): The instruction to append. + instruction (_Instruction): The instruction to append. """ ... @@ -53,7 +53,7 @@ class ISA: """ ... - def __getitem__(self, id: int) -> Instruction: + def __getitem__(self, id: int) -> _Instruction: """ Gets an instruction by its ID. @@ -61,23 +61,23 @@ class ISA: id (int): The instruction ID. Returns: - Instruction: The instruction. + _Instruction: The instruction. """ ... def get( - self, id: int, default: Optional[Instruction] = None - ) -> Optional[Instruction]: + self, id: int, default: Optional[_Instruction] = None + ) -> Optional[_Instruction]: """ Gets an instruction by its ID, or returns a default value if not found. Args: id (int): The instruction ID. - default (Optional[Instruction]): The default value to return if the + default (Optional[_Instruction]): The default value to return if the instruction is not found. Returns: - Optional[Instruction]: The instruction, or the default value if not found. + Optional[_Instruction]: The instruction, or the default value if not found. """ ... @@ -90,7 +90,7 @@ class ISA: """ ... - def __iter__(self) -> Iterator[Instruction]: + def __iter__(self) -> Iterator[_Instruction]: """ Returns an iterator over the instructions. @@ -98,7 +98,7 @@ class ISA: The order of instructions is not guaranteed. Returns: - Iterator[Instruction]: The instruction iterator. + Iterator[_Instruction]: The instruction iterator. """ ... @@ -130,7 +130,7 @@ class ISARequirements: """ ... 
-class Instruction: +class _Instruction: @staticmethod def fixed_arity( id: int, @@ -140,7 +140,7 @@ class Instruction: space: Optional[int], length: Optional[int], error_rate: float, - ) -> Instruction: + ) -> _Instruction: """ Creates an instruction with a fixed arity. @@ -159,7 +159,7 @@ class Instruction: error_rate (float): The instruction error rate. Returns: - Instruction: The instruction. + _Instruction: The instruction. """ ... @@ -167,11 +167,11 @@ class Instruction: def variable_arity( id: int, encoding: int, - time_fn: IntFunction, - space_fn: IntFunction, - error_rate_fn: FloatFunction, - length_fn: Optional[IntFunction], - ) -> Instruction: + time_fn: _IntFunction, + space_fn: _IntFunction, + error_rate_fn: _FloatFunction, + length_fn: Optional[_IntFunction], + ) -> _Instruction: """ Creates an instruction with variable arity. @@ -181,26 +181,30 @@ class Instruction: Args: id (int): The instruction ID. encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. - time_fn (IntFunction): The time function. - space_fn (IntFunction): The space function. - error_rate_fn (FloatFunction): The error rate function. - length_fn (Optional[IntFunction]): The length function. + time_fn (_IntFunction): The time function. + space_fn (_IntFunction): The space function. + error_rate_fn (_FloatFunction): The error rate function. + length_fn (Optional[_IntFunction]): The length function. If None, space_fn is used. Returns: - Instruction: The instruction. + _Instruction: The instruction. """ ... - def with_id(self, id: int) -> Instruction: + def with_id(self, id: int) -> _Instruction: """ Returns a copy of the instruction with the given ID. + Note: + The created instruction will not inherit the source property of the + original instruction and must be set by the user if intended. + Args: id (int): The instruction ID. Returns: - Instruction: A copy of the instruction with the given ID. + _Instruction: A copy of the instruction with the given ID. """ ... 
@@ -306,6 +310,25 @@ class Instruction: """ ... + def set_source(self, index: int) -> None: + """ + Sets the source index for the instruction. + + Args: + index (int): The source index to set. + """ + ... + + @property + def source(self) -> int: + """ + Gets the source index for the instruction. + + Returns: + int: The source index for the instruction. + """ + ... + def set_property(self, key: int, value: int) -> None: """ Sets a property on the instruction. @@ -482,16 +505,16 @@ class Constraint: """ ... -class IntFunction: ... -class FloatFunction: ... +class _IntFunction: ... +class _FloatFunction: ... @overload -def constant_function(value: int) -> IntFunction: ... +def constant_function(value: int) -> _IntFunction: ... @overload -def constant_function(value: float) -> FloatFunction: ... +def constant_function(value: float) -> _FloatFunction: ... def constant_function( value: int | float, -) -> IntFunction | FloatFunction: +) -> _IntFunction | _FloatFunction: """ Creates a constant function. @@ -499,17 +522,17 @@ def constant_function( value (int | float): The constant value. Returns: - IntFunction | FloatFunction: The constant function. + _IntFunction | _FloatFunction: The constant function. """ ... @overload -def linear_function(slope: int) -> IntFunction: ... +def linear_function(slope: int) -> _IntFunction: ... @overload -def linear_function(slope: float) -> FloatFunction: ... +def linear_function(slope: float) -> _FloatFunction: ... def linear_function( slope: int | float, -) -> IntFunction | FloatFunction: +) -> _IntFunction | _FloatFunction: """ Creates a linear function. @@ -517,17 +540,17 @@ def linear_function( slope (int | float): The slope. Returns: - IntFunction | FloatFunction: The linear function. + _IntFunction | _FloatFunction: The linear function. """ ... @overload -def block_linear_function(block_size: int, slope: int) -> IntFunction: ... +def block_linear_function(block_size: int, slope: int) -> _IntFunction: ... 
@overload -def block_linear_function(block_size: int, slope: float) -> FloatFunction: ... +def block_linear_function(block_size: int, slope: float) -> _FloatFunction: ... def block_linear_function( block_size: int, slope: int | float -) -> IntFunction | FloatFunction: +) -> _IntFunction | _FloatFunction: """ Creates a block linear function. @@ -536,17 +559,17 @@ def block_linear_function( slope (int | float): The slope. Returns: - IntFunction | FloatFunction: The block linear function. + _IntFunction | _FloatFunction: The block linear function. """ ... @overload -def generic_function(func: Callable[[int], int]) -> IntFunction: ... +def generic_function(func: Callable[[int], int]) -> _IntFunction: ... @overload -def generic_function(func: Callable[[int], float]) -> FloatFunction: ... +def generic_function(func: Callable[[int], float]) -> _FloatFunction: ... def generic_function( func: Callable[[int], int | float], -) -> IntFunction | FloatFunction: +) -> _IntFunction | _FloatFunction: """ Creates a generic function from a Python callable. @@ -561,12 +584,90 @@ def generic_function( func (Callable[[int], int | float]): The Python callable. Returns: - IntFunction | FloatFunction: The generic function. + _IntFunction | _FloatFunction: The generic function. """ ... -class Property: - def __new__(cls, value: Any) -> Property: +class _ProvenanceGraph: + """ + Represents the provenance graph of instructions in a trace. Each node in + the graph corresponds to an instruction and the transform from which it was + produced, and edges represent transformations applied to instructions during + enumeration. + """ + + def add_node( + self, instruction_id: int, transform_id: int, children: list[int] + ) -> int: + """ + Adds a node to the provenance graph. + + Args: + instruction_id (int): The instruction ID corresponding to the node. + transform_id (int): The transform ID corresponding to the node. + children (list[int]): The list of child node indices in the provenance graph. 
+ + Returns: + int: The index of the added node in the provenance graph. + """ + ... + + def instruction_id(self, node_index: int) -> int: + """ + Returns the instruction ID for a given node index. + + Args: + node_index (int): The index of the node in the provenance graph. + + Returns: + int: The instruction ID corresponding to the node. + """ + ... + + def transform_id(self, node_index: int) -> int: + """ + Returns the transform ID for a given node index. + + Args: + node_index (int): The index of the node in the provenance graph. + + Returns: + int: The transform ID corresponding to the node. + """ + ... + + def children(self, node_index: int) -> list[int]: + """ + Returns the list of child node indices for a given node index. + + Args: + node_index (int): The index of the node in the provenance graph. + + Returns: + list[int]: The list of child node indices. + """ + ... + + def num_nodes(self) -> int: + """ + Returns the number of nodes in the provenance graph. + + Returns: + int: The number of nodes in the provenance graph. + """ + ... + + def num_edges(self) -> int: + """ + Returns the number of edges in the provenance graph. + + Returns: + int: The number of edges in the provenance graph. + """ + ... + +class _Property: + def __new__(cls, value: Any) -> _Property: """ Creates a property from a value. @@ -701,6 +802,16 @@ class EstimationResult: """ ... + @property + def isa(self) -> ISA: + """ + The ISA used for the estimation. + + Returns: + ISA: The ISA used for the estimation. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the estimation result. @@ -710,18 +821,18 @@ class EstimationResult: """ ... -class EstimationCollection: +class _EstimationCollection: """ Represents a collection of estimation results. Results are stored as a 2D Pareto frontier with physical qubits and runtime as objectives. 
""" - def __new__(cls) -> EstimationCollection: + def __new__(cls) -> _EstimationCollection: """ Creates a new estimation collection. Returns: - EstimationCollection: The estimation collection. + _EstimationCollection: The estimation collection. """ ... @@ -866,17 +977,17 @@ class Trace: """ ... - def set_property(self, key: str, value: Property) -> None: + def set_property(self, key: str, value: _Property) -> None: """ Sets a property. Args: key (str): The property key. - value (Property): The property value. + value (_Property): The property value. """ ... - def get_property(self, key: str) -> Optional[Property]: + def get_property(self, key: str) -> Optional[_Property]: """ Gets a property. @@ -884,7 +995,7 @@ class Trace: key (str): The property key. Returns: - Optional[Property]: The property value, or None if not found. + Optional[_Property]: The property value, or None if not found. """ ... @@ -1022,21 +1133,21 @@ class InstructionFrontier: """ ... - def insert(self, point: Instruction): + def insert(self, point: _Instruction): """ Inserts an instruction to the frontier. Args: - point (Instruction): The instruction to insert. + point (_Instruction): The instruction to insert. """ ... - def extend(self, points: list[Instruction]) -> None: + def extend(self, points: list[_Instruction]) -> None: """ Extends the frontier with a list of instructions. Args: - points (list[Instruction]): The instructions to insert. + points (list[_Instruction]): The instructions to insert. """ ... @@ -1049,12 +1160,12 @@ class InstructionFrontier: """ ... - def __iter__(self) -> Iterator[Instruction]: + def __iter__(self) -> Iterator[_Instruction]: """ Returns an iterator over the instructions in the frontier. Returns: - Iterator[Instruction]: The iterator. + Iterator[_Instruction]: The iterator. """ ... @@ -1080,9 +1191,9 @@ class InstructionFrontier: """ ... 
-def estimate_parallel( +def _estimate_parallel( traces: list[Trace], isas: list[ISA], max_error: float = 1.0 -) -> EstimationCollection: +) -> _EstimationCollection: """ Estimates resources for multiple traces and ISAs in parallel. @@ -1092,13 +1203,25 @@ def estimate_parallel( max_error (float): The maximum allowed error. The default is 1.0. Returns: - EstimationCollection: The estimation collection. + _EstimationCollection: The estimation collection. """ ... -def binom_ppf(q: float, n: int, p: float) -> int: +def _binom_ppf(q: float, n: int, p: float) -> int: """ A replacement for SciPy's binom.ppf that is faster and does not require SciPy as a dependency. """ ... + +def instruction_name(id: int) -> Optional[str]: + """ + Returns the name of an instruction given its ID, if known. + + Args: + id (int): The instruction ID. + + Returns: + Optional[str]: The name of the instruction, or None if the ID is not recognized. + """ + ... diff --git a/source/pip/qsharp/qre/instruction_ids.pyi b/source/pip/qsharp/qre/instruction_ids.pyi index 164e8d431c..e8b8c0e739 100644 --- a/source/pip/qsharp/qre/instruction_ids.pyi +++ b/source/pip/qsharp/qre/instruction_ids.pyi @@ -88,6 +88,7 @@ LATTICE_SURGERY: int # Memory/compute operations (used in compute parts of memory-compute layouts) READ_FROM_MEMORY: int WRITE_TO_MEMORY: int +MEMORY: int # Some special hardware physical instructions CYCLIC_SHIFT: int diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index 52bf94439f..d619b07a7c 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -13,6 +13,7 @@ ConstraintBound, LOGICAL, ) +from ..._isa_enumeration import _Context from ..._qre import linear_function from ...instruction_ids import CNOT, GENERIC, H, LATTICE_SURGERY, MEAS_Z @@ -50,15 +51,19 @@ def required_isa() -> ISARequirements: constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), ) - def 
provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: - cnot_time = impl_isa[CNOT].expect_time() - h_time = impl_isa[H].expect_time() - meas_time = impl_isa[MEAS_Z].expect_time() + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + cnot = impl_isa[CNOT] + h = impl_isa[H] + meas_z = impl_isa[MEAS_Z] + + cnot_time = cnot.expect_time() + h_time = h.expect_time() + meas_time = meas_z.expect_time() physical_error_rate = max( - impl_isa[CNOT].expect_error_rate(), - impl_isa[H].expect_error_rate(), - impl_isa[MEAS_Z].expect_error_rate(), + cnot.expect_error_rate(), + h.expect_error_rate(), + meas_z.expect_error_rate(), ) space_formula = linear_function(2 * self.distance**2) @@ -73,21 +78,25 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: ) ) + generic = instruction( + GENERIC, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ) + + lattice_surgery = instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ) + yield ISA( - instruction( - GENERIC, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), - instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), + ctx.set_source(self, generic, [cnot, h, meas_z]), + ctx.set_source(self, lattice_surgery, [cnot, h, meas_z]), ) diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 54b2133f38..b5f79a02e5 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -21,6 +21,7 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -35,6 +36,7 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> 
{ m.add_function(wrap_pyfunction!(generic_function, m)?)?; m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; m.add_function(wrap_pyfunction!(binom_ppf, m)?)?; + m.add_function(wrap_pyfunction!(instruction_name, m)?)?; m.add("EstimationError", m.py().get_type::())?; @@ -174,7 +176,7 @@ impl ISARequirements { } #[allow(clippy::unsafe_derive_deserialize)] -#[pyclass] +#[pyclass(name = "_Instruction")] #[derive(Clone, Serialize, Deserialize)] #[serde(transparent)] pub struct Instruction(qre::Instruction); @@ -273,6 +275,15 @@ impl Instruction { Ok(self.0.expect_error_rate(arity)) } + pub fn set_source(&mut self, index: usize) { + self.0.set_source(index); + } + + #[getter] + pub fn source(&self) -> usize { + self.0.source() + } + pub fn set_property(&mut self, key: u64, value: u64) { self.0.set_property(key, value); } @@ -381,7 +392,44 @@ impl ConstraintBound { } } -#[pyclass] +#[derive(Default)] +#[pyclass(name = "_ProvenanceGraph")] +pub struct ProvenanceGraph(qre::ProvenanceGraph); + +#[pymethods] +impl ProvenanceGraph { + #[new] + pub fn new() -> Self { + Self(qre::ProvenanceGraph::new()) + } + + #[allow(clippy::needless_pass_by_value)] + pub fn add_node(&mut self, id: u64, transform: u64, children: Vec) -> usize { + self.0.add_node(id, transform, &children) + } + + pub fn instruction_id(&self, node_index: usize) -> u64 { + self.0.instruction_id(node_index) + } + + pub fn transform_id(&self, node_index: usize) -> u64 { + self.0.transform_id(node_index) + } + + pub fn children(&self, node_index: usize) -> Vec { + self.0.children(node_index).to_vec() + } + + pub fn num_nodes(&self) -> usize { + self.0.num_nodes() + } + + pub fn num_edges(&self) -> usize { + self.0.num_edges() + } +} + +#[pyclass(name = "_Property")] pub struct Property(qre::Property); #[pymethods] @@ -436,10 +484,10 @@ impl Property { } } -#[pyclass] +#[pyclass(name = "_IntFunction")] pub struct IntFunction(qre::VariableArityFunction); -#[pyclass] +#[pyclass(name = "_FloatFunction")] pub 
struct FloatFunction(qre::VariableArityFunction); #[pyfunction] @@ -537,7 +585,7 @@ pub fn generic_function<'py>( } #[derive(Default)] -#[pyclass] +#[pyclass(name = "_EstimationCollection")] pub struct EstimationCollection(qre::EstimationCollection); #[pymethods] @@ -612,6 +660,11 @@ impl EstimationResult { Ok(dict) } + #[getter] + pub fn isa(&self) -> ISA { + ISA(self.0.isa().clone()) + } + fn __str__(&self) -> String { format!("{}", self.0) } @@ -868,7 +921,7 @@ impl InstructionFrontierIterator { } #[allow(clippy::needless_pass_by_value)] -#[pyfunction(signature = (traces, isas, max_error = 1.0))] +#[pyfunction(name = "_estimate_parallel", signature = (traces, isas, max_error = 1.0))] pub fn estimate_parallel( traces: Vec>, isas: Vec>, @@ -881,11 +934,16 @@ pub fn estimate_parallel( EstimationCollection(collection) } -#[pyfunction] +#[pyfunction(name = "_binom_ppf")] pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { qre::binom_ppf(q, n, p) } +#[pyfunction] +pub fn instruction_name(id: u64) -> Option { + qre::instruction_name(id).map(String::from) +} + fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { #[allow(clippy::wildcard_imports)] use qre::instruction_ids::*; @@ -967,6 +1025,7 @@ fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { LATTICE_SURGERY, READ_FROM_MEMORY, WRITE_TO_MEMORY, + MEMORY, CYCLIC_SHIFT, GENERIC ); diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 05a3ffb63e..e097b983f6 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -28,6 +28,7 @@ SurfaceCode, AQREGateBased, ) +from qsharp.qre._architecture import _Context from qsharp.qre._isa_enumeration import ( ISARefNode, ) @@ -54,7 +55,7 @@ def required_isa() -> ISARequirements: constraint(T), ) - def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: yield ISA( instruction(T, encoding=LOGICAL, time=1000, 
error_rate=1e-8), ) @@ -72,7 +73,7 @@ def required_isa() -> ISARequirements: constraint(T, encoding=LOGICAL), ) - def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: yield ISA( instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), ) @@ -171,23 +172,23 @@ def test_instruction_constraints(): def test_generic_function(): - from qsharp.qre._qre import IntFunction, FloatFunction + from qsharp.qre._qre import _IntFunction, _FloatFunction def time(x: int) -> int: return x * x time_fn = generic_function(time) - assert isinstance(time_fn, IntFunction) + assert isinstance(time_fn, _IntFunction) def error_rate(x: int) -> float: return x / 2.0 error_rate_fn = generic_function(error_rate) - assert isinstance(error_rate_fn, FloatFunction) + assert isinstance(error_rate_fn, _FloatFunction) # Without annotations, defaults to FloatFunction space_fn = generic_function(lambda x: 12) - assert isinstance(space_fn, FloatFunction) + assert isinstance(space_fn, _FloatFunction) i = instruction(42, arity=None, space=12, time=time_fn, error_rate=error_rate_fn) assert i.space(5) == 12 @@ -203,7 +204,7 @@ def test_isa_from_architecture(): assert arch.provided_isa.satisfies(SurfaceCode.required_isa()) # Generate logical ISAs - isas = list(code.provided_isa(arch.provided_isa)) + isas = list(code.provided_isa(arch.provided_isa, arch.context())) # There is one ISA with two instructions assert len(isas) == 1 diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index a63de6bd74..2ca511747a 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -11,6 +11,8 @@ use num_traits::FromPrimitive; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; +use crate::trace::instruction_ids::instruction_name; + #[cfg(test)] mod tests; @@ -177,6 +179,7 @@ pub struct Instruction { id: u64, encoding: Encoding, metrics: Metrics, + source: usize, properties: Option>, } 
@@ -204,6 +207,7 @@ impl Instruction { time, error_rate, }, + source: 0, properties: None, } } @@ -228,6 +232,7 @@ impl Instruction { time_fn, error_rate_fn, }, + source: 0, properties: None, } } @@ -235,6 +240,8 @@ impl Instruction { #[must_use] pub fn with_id(&self, id: u64) -> Self { let mut new_instruction = self.clone(); + // reset source for new instruction + new_instruction.source = 0; new_instruction.id = id; new_instruction } @@ -315,6 +322,15 @@ impl Instruction { .expect("Instruction does not support variable arity") } + pub fn set_source(&mut self, provenance: usize) { + self.source = provenance; + } + + #[must_use] + pub fn source(&self) -> usize { + self.source + } + pub fn set_property(&mut self, key: u64, value: u64) { if let Some(ref mut properties) = self.properties { properties.insert(key, value); @@ -345,11 +361,12 @@ impl Instruction { impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = instruction_name(self.id).unwrap_or("??"); match self.metrics { Metrics::FixedArity { arity, .. } => { - write!(f, "{} |{:?}| arity: {arity}", self.id, self.encoding) + write!(f, "{name} |{:?}| arity: {arity}", self.encoding) } - Metrics::VariableArity { .. } => write!(f, "{} |{:?}|", self.id, self.encoding), + Metrics::VariableArity { .. } => write!(f, "{name} |{:?}|", self.encoding), } } } @@ -520,3 +537,83 @@ impl ConstraintBound { } } } + +pub struct ProvenanceGraph { + nodes: Vec, + // A consecutive list of child node indices for each node, where the + // children of node i are located at children[offset..offset+num_children] + // in the children vector. 
+ children: Vec, +} + +impl Default for ProvenanceGraph { + fn default() -> Self { + // Initialize with a dummy node at index 0 to simplify indexing logic + // (so that 0 can be used as a "null" provenance) + let empty = ProvenanceNode::default(); + ProvenanceGraph { + nodes: vec![empty], + children: Vec::new(), + } + } +} + +impl ProvenanceGraph { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn add_node( + &mut self, + instruction_id: u64, + transform_id: u64, + children: &[usize], + ) -> usize { + let node_index = self.nodes.len(); + let offset = self.children.len(); + let num_children = children.len(); + self.children.extend_from_slice(children); + self.nodes.push(ProvenanceNode { + instruction_id, + transform_id, + offset, + num_children, + }); + node_index + } + + #[must_use] + pub fn instruction_id(&self, node_index: usize) -> u64 { + self.nodes[node_index].instruction_id + } + + #[must_use] + pub fn transform_id(&self, node_index: usize) -> u64 { + self.nodes[node_index].transform_id + } + + #[must_use] + pub fn children(&self, node_index: usize) -> &[usize] { + let node = &self.nodes[node_index]; + &self.children[node.offset..node.offset + node.num_children] + } + + #[must_use] + pub fn num_nodes(&self) -> usize { + self.nodes.len() - 1 + } + + #[must_use] + pub fn num_edges(&self) -> usize { + self.children.len() + } +} + +#[derive(Default)] +struct ProvenanceNode { + instruction_id: u64, + transform_id: u64, + offset: usize, + num_children: usize, +} diff --git a/source/qre/src/isa/tests.rs b/source/qre/src/isa/tests.rs index 802a76f30c..b847a6b049 100644 --- a/source/qre/src/isa/tests.rs +++ b/source/qre/src/isa/tests.rs @@ -150,3 +150,58 @@ fn test_variable_arity_function() { assert_eq!(custom_fn.evaluate(3), 9); assert_eq!(custom_fn.evaluate(4), 16); } + +#[test] +fn test_instruction_display_known_id() { + use crate::trace::instruction_ids::H; + + let instr = Instruction::fixed_arity(H, Encoding::Physical, 1, 100, None, None, 
0.01); + let display = format!("{instr}"); + + assert!(display.contains('H'), "Expected 'H' in '{display}'"); + assert!( + display.contains("arity: 1"), + "Expected 'arity: 1' in '{display}'" + ); +} + +#[test] +fn test_instruction_display_unknown_id() { + let unknown_id = 0x9999; + let instr = Instruction::fixed_arity(unknown_id, Encoding::Logical, 2, 50, None, None, 0.001); + let display = format!("{instr}"); + + assert!( + display.contains("??"), + "Expected '??' for unknown ID in '{display}'" + ); +} + +#[test] +fn test_instruction_display_variable_arity() { + use crate::trace::instruction_ids::MULTI_PAULI_MEAS; + + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::constant(0.001); + + let instr = Instruction::variable_arity( + MULTI_PAULI_MEAS, + Encoding::Logical, + time_fn, + space_fn, + None, + error_rate_fn, + ); + let display = format!("{instr}"); + + assert!( + display.contains("MULTI_PAULI_MEAS"), + "Expected 'MULTI_PAULI_MEAS' in '{display}'" + ); + // Variable arity instructions don't show arity + assert!( + !display.contains("arity:"), + "Variable arity should not show arity in '{display}'" + ); +} diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs index b8bd2db55a..bf73b47ffd 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -13,9 +13,10 @@ pub use result::{EstimationCollection, EstimationResult, FactoryResult}; mod trace; pub use isa::{ ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, - VariableArityFunction, + ProvenanceGraph, VariableArityFunction, }; pub use trace::instruction_ids; +pub use trace::instruction_ids::instruction_name; pub use trace::{Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel}; mod utils; pub use utils::binom_ppf; diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs index fec8bf2135..36c0e8aaec 100644 --- 
a/source/qre/src/result.rs +++ b/source/qre/src/result.rs @@ -8,7 +8,7 @@ use std::{ use rustc_hash::FxHashMap; -use crate::{ParetoFrontier2D, ParetoItem2D}; +use crate::{ISA, ParetoFrontier2D, ParetoItem2D}; #[derive(Clone, Default)] pub struct EstimationResult { @@ -16,6 +16,7 @@ pub struct EstimationResult { runtime: u64, error: f64, factories: FxHashMap, + isa: ISA, } impl EstimationResult { @@ -77,6 +78,15 @@ impl EstimationResult { pub fn add_factory_result(&mut self, id: u64, result: FactoryResult) { self.factories.insert(id, result); } + + pub fn set_isa(&mut self, isa: ISA) { + self.isa = isa; + } + + #[must_use] + pub fn isa(&self) -> &ISA { + &self.isa + } } impl Display for EstimationResult { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index c557251534..a6c74a6abe 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -8,6 +8,7 @@ use rustc_hash::{FxHashMap, FxHashSet}; use crate::{Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction}; pub mod instruction_ids; +use instruction_ids::instruction_name; #[cfg(test)] mod tests; @@ -223,6 +224,8 @@ impl Trace { ); } + result.set_isa(isa.clone()); + Ok(result) } } @@ -301,7 +304,8 @@ impl Block { for op in &self.operations { match op { Operation::GateOperation(Gate { id, qubits, params }) => { - writeln!(f, "{indent_str} {id}({params:?})({qubits:?})")?; + let name = instruction_name(*id).unwrap_or("??"); + writeln!(f, "{indent_str} {name}({params:?})({qubits:?})")?; } Operation::BlockOperation(b) => { b.write(f, indent + 2)?; diff --git a/source/qre/src/trace/instruction_ids.rs b/source/qre/src/trace/instruction_ids.rs index 49457b3c79..f47259ca4a 100644 --- a/source/qre/src/trace/instruction_ids.rs +++ b/source/qre/src/trace/instruction_ids.rs @@ -1,103 +1,151 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -// NOTE: Define new instruction ids here. 
Then: -// - add them to `add_instruction_ids` in qre.rs -// - add them to instruction_ids.pyi - -// Paulis -pub const PAULI_I: u64 = 0x0; -pub const PAULI_X: u64 = 0x1; -pub const PAULI_Y: u64 = 0x2; -pub const PAULI_Z: u64 = 0x3; - -// Clifford gates -pub const H: u64 = 0x10; -pub const H_XZ: u64 = 0x10; -pub const H_XY: u64 = 0x11; -pub const H_YZ: u64 = 0x12; -pub const SQRT_X: u64 = 0x13; -pub const SQRT_X_DAG: u64 = 0x14; -pub const SQRT_Y: u64 = 0x15; -pub const SQRT_Y_DAG: u64 = 0x16; -pub const S: u64 = 0x17; -pub const SQRT_Z: u64 = 0x17; -pub const S_DAG: u64 = 0x18; -pub const SQRT_Z_DAG: u64 = 0x18; -pub const CNOT: u64 = 0x19; -pub const CX: u64 = 0x19; -pub const CY: u64 = 0x1A; -pub const CZ: u64 = 0x1B; -pub const SWAP: u64 = 0x1C; - -// State preparation -pub const PREP_X: u64 = 0x30; -pub const PREP_Y: u64 = 0x31; -pub const PREP_Z: u64 = 0x32; - -// Generic Cliffords -pub const ONE_QUBIT_CLIFFORD: u64 = 0x50; -pub const TWO_QUBIT_CLIFFORD: u64 = 0x51; -pub const N_QUBIT_CLIFFORD: u64 = 0x52; - -// Measurements -pub const MEAS_X: u64 = 0x100; -pub const MEAS_Y: u64 = 0x101; -pub const MEAS_Z: u64 = 0x102; -pub const MEAS_RESET_X: u64 = 0x103; -pub const MEAS_RESET_Y: u64 = 0x104; -pub const MEAS_RESET_Z: u64 = 0x105; -pub const MEAS_XX: u64 = 0x106; -pub const MEAS_YY: u64 = 0x107; -pub const MEAS_ZZ: u64 = 0x108; -pub const MEAS_XZ: u64 = 0x109; -pub const MEAS_XY: u64 = 0x10A; -pub const MEAS_YZ: u64 = 0x10B; - -// Non-Clifford gates -pub const SQRT_SQRT_X: u64 = 0x400; -pub const SQRT_SQRT_X_DAG: u64 = 0x401; -pub const SQRT_SQRT_Y: u64 = 0x402; -pub const SQRT_SQRT_Y_DAG: u64 = 0x403; -pub const SQRT_SQRT_Z: u64 = 0x404; -pub const T: u64 = 0x404; -pub const SQRT_SQRT_Z_DAG: u64 = 0x405; -pub const T_DAG: u64 = 0x405; -pub const CCX: u64 = 0x406; -pub const CCY: u64 = 0x407; -pub const CCZ: u64 = 0x408; -pub const CSWAP: u64 = 0x409; -pub const AND: u64 = 0x40A; -pub const AND_DAG: u64 = 0x40B; -pub const RX: u64 = 0x40C; -pub const RY: u64 = 
0x40D; -pub const RZ: u64 = 0x40E; -pub const CRX: u64 = 0x40F; -pub const CRY: u64 = 0x410; -pub const CRZ: u64 = 0x411; -pub const RXX: u64 = 0x412; -pub const RYY: u64 = 0x413; -pub const RZZ: u64 = 0x414; - -// Generic unitaries -pub const ONE_QUBIT_UNITARY: u64 = 0x500; -pub const TWO_QUBIT_UNITARY: u64 = 0x501; - -// Multi-qubit Pauli measurement -pub const MULTI_PAULI_MEAS: u64 = 0x1000; - -// Some generic logical instructions -pub const LATTICE_SURGERY: u64 = 0x1100; - -// Memory/compute operations (used in compute parts of memory-compute layouts) -pub const READ_FROM_MEMORY: u64 = 0x1200; -pub const WRITE_TO_MEMORY: u64 = 0x1201; - -// Some special hardware physical instructions -pub const CYCLIC_SHIFT: u64 = 0x1300; - -// Generic operation (for unified RE) -pub const GENERIC: u64 = 0xFFFF; +// NOTE: To add a new instruction ID: +// 1. Add it to the `define_instructions!` macro below (primary or alias section) +// 2. Add it to `add_instruction_ids` in qre.rs +// 3. Add it to instruction_ids.pyi +// +// The `instruction_name` function is auto-generated from the primary entries. + +#[cfg(test)] +mod tests; + +/// Macro that defines instruction ID constants and generates the `instruction_name` function. +/// Primary entries are the canonical names returned by `instruction_name`. +/// Aliases are alternative names for the same value. +macro_rules! define_instructions { + ( + primary: [ $( ($name:ident, $value:expr) ),* $(,)? ], + aliases: [ $( ($alias:ident, $avalue:expr) ),* $(,)? ] + ) => { + // Define primary constants + $( + pub const $name: u64 = $value; + )* + + // Define alias constants + $( + pub const $alias: u64 = $avalue; + )* + + /// Returns the canonical name for an instruction ID. + /// For IDs with aliases, returns the primary name. + #[must_use] + pub fn instruction_name(id: u64) -> Option<&'static str> { + match id { + $( + $name => Some(stringify!($name)), + )* + _ => None, + } + } + }; +} + +define_instructions! 
{ + primary: [ + // Paulis + (PAULI_I, 0x0), + (PAULI_X, 0x1), + (PAULI_Y, 0x2), + (PAULI_Z, 0x3), + + // Clifford gates + (H, 0x10), + (H_XY, 0x11), + (H_YZ, 0x12), + (SQRT_X, 0x13), + (SQRT_X_DAG, 0x14), + (SQRT_Y, 0x15), + (SQRT_Y_DAG, 0x16), + (S, 0x17), + (S_DAG, 0x18), + (CNOT, 0x19), + (CY, 0x1A), + (CZ, 0x1B), + (SWAP, 0x1C), + + // State preparation + (PREP_X, 0x30), + (PREP_Y, 0x31), + (PREP_Z, 0x32), + + // Generic Cliffords + (ONE_QUBIT_CLIFFORD, 0x50), + (TWO_QUBIT_CLIFFORD, 0x51), + (N_QUBIT_CLIFFORD, 0x52), + + // Measurements + (MEAS_X, 0x100), + (MEAS_Y, 0x101), + (MEAS_Z, 0x102), + (MEAS_RESET_X, 0x103), + (MEAS_RESET_Y, 0x104), + (MEAS_RESET_Z, 0x105), + (MEAS_XX, 0x106), + (MEAS_YY, 0x107), + (MEAS_ZZ, 0x108), + (MEAS_XZ, 0x109), + (MEAS_XY, 0x10A), + (MEAS_YZ, 0x10B), + + // Non-Clifford gates + (SQRT_SQRT_X, 0x400), + (SQRT_SQRT_X_DAG, 0x401), + (SQRT_SQRT_Y, 0x402), + (SQRT_SQRT_Y_DAG, 0x403), + (T, 0x404), + (T_DAG, 0x405), + (CCX, 0x406), + (CCY, 0x407), + (CCZ, 0x408), + (CSWAP, 0x409), + (AND, 0x40A), + (AND_DAG, 0x40B), + (RX, 0x40C), + (RY, 0x40D), + (RZ, 0x40E), + (CRX, 0x40F), + (CRY, 0x410), + (CRZ, 0x411), + (RXX, 0x412), + (RYY, 0x413), + (RZZ, 0x414), + + // Generic unitaries + (ONE_QUBIT_UNITARY, 0x500), + (TWO_QUBIT_UNITARY, 0x501), + + // Multi-qubit Pauli measurement + (MULTI_PAULI_MEAS, 0x1000), + + // Some generic logical instructions + (LATTICE_SURGERY, 0x1100), + + // Memory/compute operations (used in compute parts of memory-compute layouts) + (READ_FROM_MEMORY, 0x1200), + (WRITE_TO_MEMORY, 0x1201), + (MEMORY, 0x1210), + + // Some special hardware physical instructions + (CYCLIC_SHIFT, 0x1300), + + // Generic operation (for unified RE) + (GENERIC, 0xFFFF), + ], + aliases: [ + // Clifford gate aliases + (H_XZ, 0x10), // alias for H + (SQRT_Z, 0x17), // alias for S + (SQRT_Z_DAG, 0x18), // alias for S_DAG + (CX, 0x19), // alias for CNOT + + // Non-Clifford aliases + (SQRT_SQRT_Z, 0x404), // alias for T + (SQRT_SQRT_Z_DAG, 
0x405), // alias for T_DAG + ] +} #[must_use] pub fn is_pauli_measurement(id: u64) -> bool { diff --git a/source/qre/src/trace/instruction_ids/tests.rs b/source/qre/src/trace/instruction_ids/tests.rs new file mode 100644 index 0000000000..c9d745ad9a --- /dev/null +++ b/source/qre/src/trace/instruction_ids/tests.rs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn test_instruction_name_primary() { + assert_eq!(instruction_name(H), Some("H")); + assert_eq!(instruction_name(CNOT), Some("CNOT")); + assert_eq!(instruction_name(T), Some("T")); + assert_eq!(instruction_name(MEAS_Z), Some("MEAS_Z")); +} + +#[test] +fn test_instruction_name_aliases_return_primary() { + // Aliases should return the primary name + assert_eq!(instruction_name(H_XZ), Some("H")); + assert_eq!(instruction_name(CX), Some("CNOT")); + assert_eq!(instruction_name(SQRT_Z), Some("S")); + assert_eq!(instruction_name(SQRT_SQRT_Z), Some("T")); +} + +#[test] +fn test_instruction_name_unknown() { + assert_eq!(instruction_name(0x9999), None); +} diff --git a/source/qre/src/trace/tests.rs b/source/qre/src/trace/tests.rs index 57c422c8a4..8b31717969 100644 --- a/source/qre/src/trace/tests.rs +++ b/source/qre/src/trace/tests.rs @@ -238,3 +238,62 @@ fn test_estimate_with_factory() { assert_eq!(factory_res.runs(), 100); assert_eq!(result.factories().len(), 1); } + +#[test] +fn test_trace_display_uses_instruction_names() { + use crate::trace::Trace; + use crate::trace::instruction_ids::{CNOT, H, MEAS_Z}; + + let mut trace = Trace::new(2); + trace.add_operation(H, vec![0], vec![]); + trace.add_operation(CNOT, vec![0, 1], vec![]); + trace.add_operation(MEAS_Z, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains('H'), + "Expected 'H' in trace output: {display}" + ); + assert!( + display.contains("CNOT"), + "Expected 'CNOT' in trace output: {display}" + ); + assert!( + display.contains("MEAS_Z"), + "Expected 
'MEAS_Z' in trace output: {display}" + ); +} + +#[test] +fn test_trace_display_unknown_instruction() { + use crate::trace::Trace; + + let mut trace = Trace::new(1); + trace.add_operation(0x9999, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains("??"), + "Expected '??' for unknown instruction in: {display}" + ); +} + +#[test] +fn test_block_display_with_repetitions() { + use crate::trace::Trace; + use crate::trace::instruction_ids::H; + + let mut trace = Trace::new(1); + let block = trace.add_block(10); + block.add_operation(H, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains("repeat 10"), + "Expected 'repeat 10' in: {display}" + ); + assert!(display.contains('H'), "Expected 'H' in block: {display}"); +} From 80ec7869095c42b9800ad9b09725b94a990a8116 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 17 Feb 2026 19:51:00 +0100 Subject: [PATCH 15/45] Refactor trace and result properties (#2952) This PR cleans up the way properties are handled for traces and results. Trace properties can be assigned by application generators and trace transforms and are stored in optimal resource estimation results. The Python API automatically detects whether the property value is `bool`, `int`, `float`, or `str` and does not require a dedicated property class exposed from Rust. 
Some other changes include - Fixes a deadlock due to the Python GIL - Allows parallel trace generation from application generators - Reorganizes code around Q# application generators - Sets up resource estimation for memory qubits --- source/pip/qsharp/qre/__init__.py | 5 +- source/pip/qsharp/qre/_application.py | 76 ++------ source/pip/qsharp/qre/_estimation.py | 48 +++-- source/pip/qsharp/qre/_instruction.py | 11 +- source/pip/qsharp/qre/_qre.py | 1 - source/pip/qsharp/qre/_qre.pyi | 163 +++++++--------- source/pip/qsharp/qre/application/__init__.py | 6 + source/pip/qsharp/qre/application/_qsharp.py | 22 +++ source/pip/qsharp/qre/interop/__init__.py | 6 + source/pip/qsharp/qre/interop/_qsharp.py | 78 ++++++++ source/pip/src/qre.rs | 176 ++++++++++++------ source/pip/tests/test_qre.py | 24 ++- source/qre/src/result.rs | 22 ++- source/qre/src/trace.rs | 136 ++++++++++++-- source/qre/src/trace/transforms/psspc.rs | 6 +- 15 files changed, 511 insertions(+), 269 deletions(-) create mode 100644 source/pip/qsharp/qre/application/__init__.py create mode 100644 source/pip/qsharp/qre/application/_qsharp.py create mode 100644 source/pip/qsharp/qre/interop/__init__.py create mode 100644 source/pip/qsharp/qre/interop/_qsharp.py diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 90bf5bda00..15c3477cb7 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from ._application import Application, QSharpApplication +from ._application import Application from ._architecture import Architecture from ._estimation import estimate from ._instruction import ( @@ -12,6 +12,7 @@ PropertyKey, constraint, instruction, + InstructionSource, ) from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT from ._qre import ( @@ -50,6 +51,7 @@ "generic_function", "instruction_name", "InstructionFrontier", + "InstructionSource", "ISA", "ISA_ROOT", "ISAQuery", @@ -59,7 +61,6 @@ "LatticeSurgery", "PropertyKey", "PSSPC", - "QSharpApplication", "Trace", "TraceQuery", "LOGICAL", diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py index 61300fcf38..4c9a1829e1 100644 --- a/source/pip/qsharp/qre/_application.py +++ b/source/pip/qsharp/qre/_application.py @@ -5,10 +5,9 @@ import types from abc import ABC, abstractmethod -from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor from typing import ( Any, - Callable, ClassVar, Generic, Protocol, @@ -18,11 +17,8 @@ cast, ) -from .._qsharp import logical_counts -from ..estimator import LogicalCounts from ._enumeration import _enumerate_instances from ._qre import Trace -from .instruction_ids import CCX, MEAS_Z, RZ, T class DataclassProtocol(Protocol): @@ -48,6 +44,8 @@ class Application(ABC, Generic[TraceParameters]): parameters therein. 
""" + _parallel_traces: bool = True + @abstractmethod def get_trace(self, parameters: TraceParameters) -> Trace: """Return the trace corresponding to this application.""" @@ -72,8 +70,19 @@ def enumerate_traces( if c is not types.NoneType: param_type = c break - for parameters in _enumerate_instances(cast(type, param_type), **kwargs): - yield self.get_trace(parameters) + + if self._parallel_traces: + instances = list(_enumerate_instances(cast(type, param_type), **kwargs)) + with ThreadPoolExecutor() as executor: + for trace in executor.map(self.get_trace, instances): + yield trace + else: + for instances in _enumerate_instances(cast(type, param_type), **kwargs): + yield self.get_trace(instances) + + def disable_parallel_traces(self): + """Disable parallel trace generation for this application.""" + self._parallel_traces = False class _Context: @@ -83,56 +92,3 @@ class _Context: def __init__(self, application: Application, **kwargs): self.application = application self.kwargs = kwargs - - -@dataclass -class QSharpApplication(Application[None]): - def __init__(self, entry_expr: str | Callable | LogicalCounts): - self._entry_expr = entry_expr - - def get_trace(self, parameters: None = None) -> Trace: - if not isinstance(self._entry_expr, LogicalCounts): - self._counts = logical_counts(self._entry_expr) - else: - self._counts = self._entry_expr - return self._trace_from_logical_counts(self._counts) - - def _trace_from_logical_counts(self, counts: LogicalCounts) -> Trace: - ccx_count = counts.get("cczCount", 0) + counts.get("ccixCount", 0) - - trace = Trace(counts.get("numQubits", 0)) - - rotation_count = counts.get("rotationCount", 0) - rotation_depth = counts.get("rotationDepth", rotation_count) - - if rotation_count != 0: - if rotation_depth > 1: - rotations_per_layer = rotation_count // (rotation_depth - 1) - else: - rotations_per_layer = 0 - - last_layer = rotation_count - (rotations_per_layer * (rotation_depth - 1)) - - if rotations_per_layer != 0: - block = 
trace.add_block(repetitions=rotation_depth - 1) - for i in range(rotations_per_layer): - block.add_operation(RZ, [i]) - block = trace.add_block() - for i in range(last_layer): - block.add_operation(RZ, [i]) - - if t_count := counts.get("tCount", 0): - block = trace.add_block(repetitions=t_count) - block.add_operation(T, [0]) - - if ccx_count: - block = trace.add_block(repetitions=ccx_count) - block.add_operation(CCX, [0, 1, 2]) - - if meas_count := counts.get("measurementCount", 0): - block = trace.add_block(repetitions=meas_count) - block.add_operation(MEAS_Z, [0]) - - # TODO: handle memory qubits - - return trace diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 17a8330237..c542b4c597 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -4,11 +4,12 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Optional from ._application import Application from ._architecture import Architecture from ._qre import _estimate_parallel -from ._trace import TraceQuery +from ._trace import TraceQuery, PSSPC, LatticeSurgery from ._instruction import InstructionSource from ._isa_enumeration import ISAQuery @@ -16,8 +17,8 @@ def estimate( application: Application, architecture: Architecture, - trace_query: TraceQuery, isa_query: ISAQuery, + trace_query: Optional[TraceQuery] = None, *, max_error: float = 1.0, ) -> EstimationTable: @@ -50,6 +51,9 @@ def estimate( app_ctx = application.context() arch_ctx = architecture.context() + if trace_query is None: + trace_query = PSSPC.q() * LatticeSurgery.q() + # Obtain all results results = _estimate_parallel( list(trace_query.enumerate(app_ctx)), @@ -65,25 +69,38 @@ def estimate( qubits=result.qubits, runtime=result.runtime, error=result.error, - source=InstructionSource.from_estimation_result(arch_ctx, result), + source=InstructionSource.from_isa(arch_ctx, result.isa), + properties=result.properties.copy(), ) 
+ table.append(entry) return table -@dataclass(frozen=True, slots=True) -class EstimationTable: - entries: list[EstimationTableEntry] = field(default_factory=list, init=False) - - def append(self, entry: EstimationTableEntry) -> None: - self.entries.append(entry) - - def __len__(self) -> int: - return len(self.entries) - - def __iter__(self): - return iter(self.entries) +class EstimationTable(list["EstimationTableEntry"]): + def __init__(self): + super().__init__() + + def as_frame(self): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Missing optional 'pandas' dependency. To install run: " + "pip install pandas" + ) + + return pd.DataFrame( + [ + { + "qubits": entry.qubits, + "runtime": entry.runtime, + "error": entry.error, + } + for entry in self + ] + ) @dataclass(frozen=True, slots=True) @@ -92,3 +109,4 @@ class EstimationTableEntry: runtime: int error: float source: InstructionSource + properties: dict[str, int | float | bool | str] = field(default_factory=dict) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 9517a04eb4..a907133fd7 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -20,7 +20,6 @@ ISA, Constraint, ConstraintBound, - EstimationResult, _FloatFunction, _Instruction, _IntFunction, @@ -286,11 +285,9 @@ class InstructionSource: roots: list[int] = field(default_factory=list, init=False) @classmethod - def from_estimation_result( - cls, ctx: _Context, result: EstimationResult - ) -> InstructionSource: + def from_isa(cls, ctx: _Context, isa: ISA) -> InstructionSource: """ - Constructs an InstructionSource graph from an EstimationResult. + Constructs an InstructionSource graph from an ISA. 
The instruction source graph contains more information than the provenance graph in the context, as it connects the instructions to the @@ -298,7 +295,7 @@ def from_estimation_result( Args: ctx (_Context): The enumeration context containing the provenance graph. - result (EstimationResult): The estimation result containing the ISA and instruction sources. + isa (ISA): Instructions in the ISA will serve as root nodes in the source graph. Returns: InstructionSource: The instruction source graph for the estimation result. @@ -328,7 +325,7 @@ def _make_node( graph = cls() source_table: dict[int, int] = {} - for inst in result.isa: + for inst in isa: if inst.source != 0: node = _make_node(graph, source_table, inst.source) graph.add_root(node) diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index 32b15b45f6..e7d8fe29a0 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -23,7 +23,6 @@ _IntFunction, ISA, ISARequirements, - _Property, _ProvenanceGraph, linear_function, LatticeSurgery, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 93dc6750be..4cac0a4894 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -666,97 +666,6 @@ class _ProvenanceGraph: """ ... -class _Property: - def __new__(cls, value: Any) -> _Property: - """ - Creates a property from a value. - - Args: - value (Any): The value. - """ - ... - - def as_bool(self) -> Optional[bool]: - """ - Returns the value as a boolean. - - Returns: - Optional[bool]: The value as a boolean, or None if it is not a boolean. - """ - ... - - def as_int(self) -> Optional[int]: - """ - Returns the value as an integer. - - Returns: - Optional[int]: The value as an integer, or None if it is not an integer. - """ - ... - - def as_float(self) -> Optional[float]: - """ - Returns the value as a float. - - Returns: - Optional[float]: The value as a float, or None if it is not a float. - """ - ... 
- - def as_str(self) -> Optional[str]: - """ - Returns the value as a string. - - Returns: - Optional[str]: The value as a string, or None if it is not a string. - """ - ... - - def is_bool(self) -> bool: - """ - Checks if the value is a boolean. - - Returns: - bool: True if the value is a boolean, False otherwise. - """ - ... - - def is_int(self) -> bool: - """ - Checks if the value is an integer. - - Returns: - bool: True if the value is an integer, False otherwise. - """ - ... - - def is_float(self) -> bool: - """ - Checks if the value is a float. - - Returns: - bool: True if the value is a float, False otherwise. - """ - ... - - def is_str(self) -> bool: - """ - Checks if the value is a string. - - Returns: - bool: True if the value is a string, False otherwise. - """ - ... - - def __str__(self) -> str: - """ - Returns a string representation of the property. - - Returns: - str: A string representation of the property. - """ - ... - class EstimationResult: """ Represents the result of a resource estimation. @@ -812,6 +721,16 @@ class EstimationResult: """ ... + @property + def properties(self) -> dict[str, bool | int | float | str]: + """ + Custom properties from application generation and trace transform. + + Returns: + dict[str, bool | int | float | str]: A dictionary mapping property keys to their values. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the estimation result. @@ -967,6 +886,44 @@ class Trace: """ ... + @property + def memory_qubits(self) -> Optional[int]: + """ + The number of memory qubits, if set. + + Returns: + Optional[int]: The number of memory qubits, or None if not set. + """ + ... + + def has_memory_qubits(self) -> bool: + """ + Checks if the trace has memory qubits set. + + Returns: + bool: True if memory qubits are set, False otherwise. + """ + ... + + def set_memory_qubits(self, qubits: int) -> None: + """ + Sets the number of memory qubits. + + Args: + qubits (int): The number of memory qubits. 
+ """ + ... + + def increment_memory_qubits(self, amount: int) -> None: + """ + Increments the number of memory qubits. If memory qubits have not been + set, initializes them to 0 before incrementing. + + Args: + amount (int): The amount to increment. + """ + ... + def increment_resource_state(self, resource_id: int, amount: int) -> None: """ Increments a resource state count. @@ -977,17 +934,19 @@ class Trace: """ ... - def set_property(self, key: str, value: _Property) -> None: + def set_property(self, key: str, value: Any) -> None: """ - Sets a property. + Sets a property. All values of type `int`, `float`, `bool`, and `str` + are supported. Any other value is converted to a string using its + `__str__` method. Args: key (str): The property key. - value (_Property): The property value. + value (Any): The property value. """ ... - def get_property(self, key: str) -> Optional[_Property]: + def get_property(self, key: str) -> Optional[int | float | bool | str]: """ Gets a property. @@ -995,7 +954,19 @@ class Trace: key (str): The property key. Returns: - Optional[_Property]: The property value, or None if not found. + Optional[int | float | bool | str]: The property value, or None if not found. + """ + ... + + def has_property(self, key: str) -> bool: + """ + Checks if a property with the given key exists. + + Args: + key (str): The property key. + + Returns: + bool: True if the property exists, False otherwise. """ ... diff --git a/source/pip/qsharp/qre/application/__init__.py b/source/pip/qsharp/qre/application/__init__.py new file mode 100644 index 0000000000..9f36049425 --- /dev/null +++ b/source/pip/qsharp/qre/application/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from ._qsharp import QSharpApplication + +__all__ = ["QSharpApplication"] diff --git a/source/pip/qsharp/qre/application/_qsharp.py b/source/pip/qsharp/qre/application/_qsharp.py new file mode 100644 index 0000000000..b01a8d329c --- /dev/null +++ b/source/pip/qsharp/qre/application/_qsharp.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from ...estimator import LogicalCounts +from .._qre import Trace +from .._application import Application +from ..interop import trace_from_entry_expr + + +@dataclass +class QSharpApplication(Application[None]): + def __init__(self, entry_expr: str | Callable | LogicalCounts): + self._entry_expr = entry_expr + + def get_trace(self, parameters: None = None) -> Trace: + return trace_from_entry_expr(self._entry_expr) diff --git a/source/pip/qsharp/qre/interop/__init__.py b/source/pip/qsharp/qre/interop/__init__.py new file mode 100644 index 0000000000..01a5234a5b --- /dev/null +++ b/source/pip/qsharp/qre/interop/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._qsharp import trace_from_entry_expr + +__all__ = ["trace_from_entry_expr"] diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py new file mode 100644 index 0000000000..1c7d041430 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import time +from typing import Callable + +from ..._qsharp import logical_counts +from ...estimator import LogicalCounts +from .._qre import Trace +from ..instruction_ids import CCX, MEAS_Z, RZ, T, READ_FROM_MEMORY, WRITE_TO_MEMORY + + +def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: + + start = time.time_ns() + counts = ( + logical_counts(entry_expr) + if not isinstance(entry_expr, LogicalCounts) + else entry_expr + ) + evaluation_time = time.time_ns() - start + + ccx_count = counts.get("cczCount", 0) + counts.get("ccixCount", 0) + + # Q# logical counts report total number of qubits (compute + memory) + num_qubits = counts.get("numQubits", 0) + # Compute qubits may be reported separately + compute_qubits = counts.get("numComputeQubits", num_qubits) + memory_qubits = num_qubits - compute_qubits + + trace = Trace(compute_qubits) + + rotation_count = counts.get("rotationCount", 0) + rotation_depth = counts.get("rotationDepth", rotation_count) + + if rotation_count != 0: + if rotation_depth > 1: + rotations_per_layer = rotation_count // (rotation_depth - 1) + else: + rotations_per_layer = 0 + + last_layer = rotation_count - (rotations_per_layer * (rotation_depth - 1)) + + if rotations_per_layer != 0: + block = trace.add_block(repetitions=rotation_depth - 1) + for i in range(rotations_per_layer): + block.add_operation(RZ, [i]) + block = trace.add_block() + for i in range(last_layer): + block.add_operation(RZ, [i]) + + if t_count := counts.get("tCount", 0): + block = trace.add_block(repetitions=t_count) + block.add_operation(T, [0]) + + if ccx_count: + block = trace.add_block(repetitions=ccx_count) + block.add_operation(CCX, [0, 1, 2]) + + if meas_count := counts.get("measurementCount", 0): + block = trace.add_block(repetitions=meas_count) + block.add_operation(MEAS_Z, [0]) + + if memory_qubits != 0: + trace.set_memory_qubits(memory_qubits) + + if rfm_count := counts.get("readFromMemoryCount", 0): 
+ block = trace.add_block(repetitions=rfm_count) + block.add_operation(READ_FROM_MEMORY, [0, compute_qubits]) + + if wtm_count := counts.get("writeToMemoryCount", 0): + block = trace.add_block(repetitions=wtm_count) + block.add_operation(WRITE_TO_MEMORY, [0, compute_qubits]) + + trace.set_property("evaluation_time", evaluation_time) + return trace diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index b5f79a02e5..0f45f76bda 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -7,7 +7,7 @@ use pyo3::{ IntoPyObjectExt, exceptions::{PyException, PyKeyError, PyTypeError}, prelude::*, - types::{PyDict, PyTuple}, + types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple}, }; use qre::TraceTransform; use serde::{Deserialize, Serialize}; @@ -17,7 +17,6 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -429,61 +428,6 @@ impl ProvenanceGraph { } } -#[pyclass(name = "_Property")] -pub struct Property(qre::Property); - -#[pymethods] -impl Property { - #[new] - pub fn new(value: &Bound<'_, PyAny>) -> PyResult { - if value.is_instance_of::() { - Ok(Property(qre::Property::new_bool(value.extract()?))) - } else if let Ok(i) = value.extract::() { - Ok(Property(qre::Property::new_int(i))) - } else if let Ok(f) = value.extract::() { - Ok(Property(qre::Property::new_float(f))) - } else { - Ok(Property(qre::Property::new_str(value.to_string()))) - } - } - - fn as_bool(&self) -> Option { - self.0.as_bool() - } - - fn as_int(&self) -> Option { - self.0.as_int() - } - - fn as_float(&self) -> Option { - self.0.as_float() - } - - fn as_str(&self) -> Option { - self.0.as_str().map(String::from) - } - - fn is_bool(&self) -> bool { - self.0.is_bool() - } - - fn is_int(&self) -> bool { - self.0.is_int() - } - - fn is_float(&self) -> bool { - self.0.is_float() - } - - fn is_str(&self) -> bool { - 
self.0.is_str() - } - - fn __str__(&self) -> String { - format!("{}", self.0) - } -} - #[pyclass(name = "_IntFunction")] pub struct IntFunction(qre::VariableArityFunction); @@ -665,6 +609,23 @@ impl EstimationResult { ISA(self.0.isa().clone()) } + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn properties(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + + for (key, value) in self_.0.properties() { + match value { + qre::Property::Bool(b) => dict.set_item(key, *b)?, + qre::Property::Int(i) => dict.set_item(key, *i)?, + qre::Property::Float(f) => dict.set_item(key, *f)?, + qre::Property::Str(s) => dict.set_item(key, s.clone())?, + } + } + + Ok(dict) + } + fn __str__(&self) -> String { format!("{}", self.0) } @@ -725,12 +686,46 @@ impl Trace { self.0.increment_base_error(amount); } - pub fn set_property(&mut self, key: String, value: &Property) { - self.0.set_property(key, value.0.clone()); + pub fn set_property(&mut self, key: String, value: &Bound<'_, PyAny>) -> PyResult<()> { + let property = if value.is_instance_of::() { + qre::Property::new_bool(value.extract()?) 
+ } else if let Ok(i) = value.extract::() { + qre::Property::new_int(i) + } else if let Ok(f) = value.extract::() { + qre::Property::new_float(f) + } else { + qre::Property::new_str(value.to_string()) + }; + + self.0.set_property(key, property); + + Ok(()) } - pub fn get_property(&self, key: &str) -> Option { - self.0.get_property(key).map(|p| Property(p.clone())) + #[allow(clippy::needless_pass_by_value)] + pub fn get_property<'py>(self_: PyRef<'py, Self>, key: &str) -> Option> { + if let Some(value) = self_.0.get_property(key) { + match value { + qre::Property::Bool(b) => PyBool::new(self_.py(), *b) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Int(i) => PyInt::new(self_.py(), *i) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Float(f) => PyFloat::new(self_.py(), *f) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Str(s) => PyString::new(self_.py(), s) + .into_bound_py_any(self_.py()) + .ok(), + } + } else { + None + } + } + + pub fn has_property(&self, key: &str) -> bool { + self.0.has_property(key) } #[allow(clippy::needless_pass_by_value)] @@ -775,6 +770,23 @@ impl Trace { }) } + #[getter] + pub fn memory_qubits(&self) -> Option { + self.0.memory_qubits() + } + + pub fn has_memory_qubits(&self) -> bool { + self.0.has_memory_qubits() + } + + pub fn set_memory_qubits(&mut self, qubits: u64) { + self.0.set_memory_qubits(qubits); + } + + pub fn increment_memory_qubits(&mut self, amount: u64) { + self.0.increment_memory_qubits(amount); + } + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { self.0.increment_resource_state(resource_id, amount); } @@ -923,6 +935,7 @@ impl InstructionFrontierIterator { #[allow(clippy::needless_pass_by_value)] #[pyfunction(name = "_estimate_parallel", signature = (traces, isas, max_error = 1.0))] pub fn estimate_parallel( + py: Python<'_>, traces: Vec>, isas: Vec>, max_error: f64, @@ -930,10 +943,49 @@ pub fn estimate_parallel( let traces: Vec<_> = traces.iter().map(|t| 
&t.0).collect(); let isas: Vec<_> = isas.iter().map(|i| &i.0).collect(); - let collection = qre::estimate_parallel(&traces, &isas, Some(max_error)); + // Release the GIL before entering the parallel section. + // Worker threads spawned by qre::estimate_parallel may need to acquire + // the GIL to evaluate Python callbacks (via generic_function closures). + // If the calling thread holds the GIL while blocked in + // std::thread::scope, the worker threads deadlock. + let collection = release_gil(py, || { + qre::estimate_parallel(&traces, &isas, Some(max_error)) + }); EstimationCollection(collection) } +/// Releases the GIL for the duration of the closure `f`, allowing other +/// threads to acquire it. A RAII guard ensures the thread state is restored +/// even if `f` panics. +/// +/// # Safety +/// +/// The caller must ensure that no `Bound<'_, _>` or `Python<'_>` references +/// are used inside `f`. GIL-independent `Py` handles are fine because +/// they re-acquire the GIL via `Python::attach` when needed. +/// +/// We cannot use `py.allow_threads` here because the captured data +/// (`&qre::ISA`) transitively contains `Arc` whose +/// trait object does not carry the `Ungil` auto-trait bound. +fn release_gil(_py: Python<'_>, f: F) -> R +where + F: FnOnce() -> R, +{ + struct RestoreGuard(*mut pyo3::ffi::PyThreadState); + + impl Drop for RestoreGuard { + fn drop(&mut self) { + // SAFETY: called on the same thread that saved the state. + unsafe { pyo3::ffi::PyEval_RestoreThread(self.0) }; + } + } + + // SAFETY: we hold the GIL (proven by the `_py` token) and release it + // here so that worker threads can acquire it for Python callbacks. 
+ let _guard = RestoreGuard(unsafe { pyo3::ffi::PyEval_SaveThread() }); + f() +} + #[pyfunction(name = "_binom_ppf")] pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { qre::binom_ppf(q, n, p) diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index e097b983f6..40b3e8ffd4 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -16,7 +16,6 @@ ISATransform, LatticeSurgery, PropertyKey, - QSharpApplication, Trace, constraint, estimate, @@ -24,6 +23,7 @@ linear_function, generic_function, ) +from qsharp.qre.application import QSharpApplication from qsharp.qre.models import ( SurfaceCode, AQREGateBased, @@ -551,6 +551,26 @@ def test_sum_isa_enumeration_nodes(): assert isinstance(source, _ComponentQuery) +def test_trace_properties(): + trace = Trace(42) + + trace.set_property("int", 42) + assert trace.get_property("int") == 42 + assert isinstance(trace.get_property("int"), int) + + trace.set_property("float", 3.14) + assert trace.get_property("float") == 3.14 + assert isinstance(trace.get_property("float"), float) + + trace.set_property("bool", True) + assert trace.get_property("bool") is True + assert isinstance(trace.get_property("bool"), bool) + + trace.set_property("str", "hello") + assert trace.get_property("str") == "hello" + assert isinstance(trace.get_property("str"), str) + + def test_qsharp_application(): from qsharp.qre._enumeration import _enumerate_instances @@ -658,8 +678,8 @@ def test_estimation_max_error(): results = estimate( app, arch, - PSSPC.q() * LatticeSurgery.q(), SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), max_error=max_error, ) diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs index 36c0e8aaec..485b6c1135 100644 --- a/source/qre/src/result.rs +++ b/source/qre/src/result.rs @@ -8,7 +8,7 @@ use std::{ use rustc_hash::FxHashMap; -use crate::{ISA, ParetoFrontier2D, ParetoItem2D}; +use crate::{ISA, ParetoFrontier2D, ParetoItem2D, Property}; #[derive(Clone, 
Default)] pub struct EstimationResult { @@ -17,6 +17,7 @@ pub struct EstimationResult { error: f64, factories: FxHashMap, isa: ISA, + properties: FxHashMap, } impl EstimationResult { @@ -87,6 +88,25 @@ impl EstimationResult { pub fn isa(&self) -> &ISA { &self.isa } + + pub fn set_property(&mut self, key: String, value: Property) { + self.properties.insert(key, value); + } + + #[must_use] + pub fn get_property(&self, key: &str) -> Option<&Property> { + self.properties.get(key) + } + + #[must_use] + pub fn has_property(&self, key: &str) -> bool { + self.properties.contains_key(key) + } + + #[must_use] + pub fn properties(&self) -> &FxHashMap { + &self.properties + } } impl Display for EstimationResult { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index a6c74a6abe..1d2d7081e2 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -68,6 +68,33 @@ impl Trace { self.base_error += amount; } + #[must_use] + pub fn memory_qubits(&self) -> Option { + self.memory_qubits + } + + #[must_use] + pub fn has_memory_qubits(&self) -> bool { + self.memory_qubits.is_some() + } + + pub fn set_memory_qubits(&mut self, qubits: u64) { + self.memory_qubits = Some(qubits); + } + + pub fn increment_memory_qubits(&mut self, amount: u64) { + if amount == 0 { + return; + } + let current = self.memory_qubits.get_or_insert(0); + *current += amount; + } + + #[must_use] + pub fn total_qubits(&self) -> u64 { + self.compute_qubits + self.memory_qubits.unwrap_or(0) + } + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { if amount == 0 { return; @@ -100,6 +127,11 @@ impl Trace { self.properties.get(key) } + #[must_use] + pub fn has_property(&self, key: &str) -> bool { + self.properties.contains_key(key) + } + #[must_use] pub fn deep_iter(&self) -> TraceIterator<'_> { TraceIterator::new(&self.block) @@ -224,8 +256,38 @@ impl Trace { ); } + // Memory qubits + if let Some(memory_qubits) = self.memory_qubits { + // We need a MEMORY instruction in 
our ISA + let memory = isa + .get(&instruction_ids::MEMORY) + .ok_or(Error::InstructionNotFound(instruction_ids::MEMORY))?; + + result.add_qubits(memory.expect_space(Some(memory_qubits))); + + // The number of rounds for the memory qubits to stay alive with + // respect to the total runtime of the algorithm. + let rounds = result + .runtime() + .div_ceil(memory.expect_time(Some(memory_qubits))); + + let actual_error = + result.add_error(rounds as f64 * memory.expect_error_rate(Some(memory_qubits))); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + } + result.set_isa(isa.clone()); + // Copy properties from the trace to the result + for (key, value) in &self.properties { + result.set_property(key.clone(), value.clone()); + } + Ok(result) } } @@ -556,45 +618,79 @@ fn get_error_rate_by_id(isa: &ISA, id: u64) -> Result { .ok_or(Error::CannotExtractErrorRate(id)) } +/// Estimates all (trace, ISA) combinations in parallel, returning only the +/// successful results collected into an [`EstimationCollection`]. +/// +/// This uses a shared atomic counter as a lock-free work queue. Each worker +/// thread atomically claims the next job index, maps it to a `(trace, isa)` +/// pair, and runs the estimation. This keeps all available cores busy until +/// the last job completes. +/// +/// # Work distribution +/// +/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`: +/// - `trace_idx = j / isas.len()` +/// - `isa_idx = j % isas.len()` +/// +/// Each worker accumulates results locally and sends them back over a bounded +/// channel once it runs out of work, avoiding contention on the shared +/// collection. 
#[must_use] pub fn estimate_parallel<'a>( traces: &[&'a Trace], isas: &[&'a ISA], max_error: Option, ) -> EstimationCollection { - fn estimate_chunks<'a>( - traces: &[&'a Trace], - isas: &[&'a ISA], - max_error: Option, - ) -> Vec { - let mut local_collection = Vec::new(); - for trace in traces { - for isa in isas { - if let Ok(estimation) = trace.estimate(isa, max_error) { - local_collection.push(estimation); - } - } - } - local_collection - } + let total_jobs = traces.len() * isas.len(); + let num_isas = isas.len(); + + // Shared atomic counter acts as a lock-free work queue. Workers call + // fetch_add to claim the next job index. + let next_job = std::sync::atomic::AtomicUsize::new(0); let mut collection = EstimationCollection::new(); std::thread::scope(|scope| { let num_threads = std::thread::available_parallelism() .map(std::num::NonZero::get) .unwrap_or(1); - let chunk_size = traces.len().div_ceil(num_threads); + // Bounded channel so each worker can send its batch of results back + // to the main thread without unbounded buffering. let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); - for chunk in traces.chunks(chunk_size) { + for _ in 0..num_threads { let tx = tx.clone(); - scope.spawn(move || tx.send(estimate_chunks(chunk, isas, max_error))); + let next_job = &next_job; + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + // Atomically claim the next job. Relaxed ordering is + // sufficient because there is no dependent data between + // jobs — each (trace, isa) pair is independent. + let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job >= total_jobs { + break; + } + + // Map the flat job index to a (trace, ISA) pair. + let trace_idx = job / num_isas; + let isa_idx = job % num_isas; + + if let Ok(estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) { + local_results.push(estimation); + } + } + // Send all results from this worker in one batch. 
+ let _ = tx.send(local_results); + }); } + // Drop the cloned sender so the receiver iterator terminates once all + // workers have finished. drop(tx); - for local_collection in rx.iter().take(num_threads) { - collection.extend(local_collection.into_iter()); + // Collect results from all workers into the shared collection. + for local_results in rx { + collection.extend(local_results.into_iter()); } }); diff --git a/source/qre/src/trace/transforms/psspc.rs b/source/qre/src/trace/transforms/psspc.rs index 287e6c0aa1..309cd1fb81 100644 --- a/source/qre/src/trace/transforms/psspc.rs +++ b/source/qre/src/trace/transforms/psspc.rs @@ -67,7 +67,7 @@ impl PSSPC { fn psspc_counts(trace: &Trace) -> Result { let mut counter = PSSPCCounts::default(); - let mut max_rotation_depth = vec![0; trace.compute_qubits() as usize]; + let mut max_rotation_depth = vec![0; trace.total_qubits() as usize]; for (Gate { id, qubits, .. }, mult) in trace.deep_iter() { if instruction_ids::is_pauli_measurement(*id) { @@ -123,7 +123,7 @@ impl PSSPC { } #[allow(clippy::cast_precision_loss)] - fn compute_only_trace(&self, trace: &Trace, counts: &PSSPCCounts) -> Trace { + fn get_trace(&self, trace: &Trace, counts: &PSSPCCounts) -> Trace { let num_qubits = trace.compute_qubits(); let logical_qubits = Self::logical_qubit_overhead(num_qubits); @@ -202,7 +202,7 @@ impl TraceTransform for PSSPC { fn transform(&self, trace: &Trace) -> Result { let counts = Self::psspc_counts(trace)?; - Ok(self.compute_only_trace(trace, &counts)) + Ok(self.get_trace(trace, &counts)) } } From bcc31d0a034bc42995fb6a429ba1c33ae9613a7d Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Wed, 18 Feb 2026 05:13:19 -0500 Subject: [PATCH 16/45] Brlackey/magnets ising model (#2955) Major modifications to the internal structure of the Model class. - Coefficients of the model are now a dictionary keyed by the vertex tuple of the (hyper)edges. - API more closely matches QREv3 (still needs some work). 
- Standard translation-invariant Ising model is implemented as a factory function that builds Model instances. - Tests included. --- source/pip/qsharp/magnets/models/__init__.py | 4 +- source/pip/qsharp/magnets/models/model.py | 172 ++++++---- source/pip/tests/magnets/test_model.py | 327 +++++++++++-------- 3 files changed, 293 insertions(+), 210 deletions(-) diff --git a/source/pip/qsharp/magnets/models/__init__.py b/source/pip/qsharp/magnets/models/__init__.py index 1f815fb5e9..58f47bd721 100644 --- a/source/pip/qsharp/magnets/models/__init__.py +++ b/source/pip/qsharp/magnets/models/__init__.py @@ -7,6 +7,6 @@ as Hamiltonians built from Pauli operators. """ -from .model import Model +from .model import Model, translation_invariant_ising_model -__all__ = ["Model"] +__all__ = ["Model", "translation_invariant_ising_model"] diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py index e6f1eb7449..6f69c35adf 100644 --- a/source/pip/qsharp/magnets/models/model.py +++ b/source/pip/qsharp/magnets/models/model.py @@ -6,40 +6,27 @@ """Base Model class for quantum spin models. This module provides the base class for representing quantum spin models -as Hamiltonians built from Pauli operators. The Model class integrates -with hypergraph geometries to define interaction topologies and uses -Cirq's PauliString and PauliSum for representing quantum operators. +as Hamiltonians. The Model class integrates with hypergraph geometries +to define interaction topologies and stores coefficients for each edge. """ -from typing import Iterator -from qsharp.magnets.geometry import Hypergraph - -try: - from cirq import LineQubit, PauliSum, PauliString -except Exception as ex: - raise ImportError( - "qsharp.magnets.models requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'." - ) from ex +from qsharp.magnets.geometry import Hyperedge, Hypergraph class Model: """Base class for quantum spin models. 
- This class wraps a list of cirq.PauliSum objects that define the Hamiltonian - of a quantum system. Each element of the list represents a partition of - the Hamiltonian into different terms, which is useful for: + This class represents a quantum spin Hamiltonian defined on a hypergraph + geometry. The Hamiltonian is characterized by: - - Trotterization: Grouping commuting terms for efficient simulation - - Parallel execution: Terms in the same partition can be applied simultaneously - - Resource estimation: Analyzing different parts of the Hamiltonian separately + - Coefficients: A mapping from edge vertex tuples to float coefficients + - Terms: Groupings of hyperedges for Trotterization or parallel execution The model is built on a hypergraph geometry that defines which qubits - interact with each other. Subclasses should populate the `terms` list - with appropriate PauliSum operators based on the geometry. + interact with each other. Attributes: geometry: The Hypergraph defining the interaction topology. - terms: List of PauliSum objects representing partitioned Hamiltonian terms. Example: @@ -47,95 +34,136 @@ class Model: >>> from qsharp.magnets.geometry import Chain1D >>> geometry = Chain1D(4) >>> model = Model(geometry) - >>> model.add_term() # Add an empty term - >>> len(model.terms) - 1 + >>> model.set_coefficient((0, 1), 1.5) + >>> model.get_coefficient((0, 1)) + 1.5 """ def __init__(self, geometry: Hypergraph): """Initialize the Model. Creates a quantum spin model on the given geometry. The model starts - with no Hamiltonian terms; subclasses or callers should add terms - using `add_term()` and `add_to_term()`. + with all coefficients set to zero and no term groupings. Args: geometry: Hypergraph defining the interaction topology. The number of vertices determines the number of qubits in the model. 
""" self.geometry: Hypergraph = geometry - self._qubits: list[LineQubit] = [ - LineQubit(i) for i in range(geometry.nvertices) - ] - self.terms: list[PauliSum] = [] - - def add_term(self, term: PauliSum = None) -> None: - """Add a term to the Hamiltonian. + self._qubits: set[int] = set() + self._coefficients: dict[tuple[int, ...], float] = dict() + for edge in geometry.edges(): + self._qubits.update(edge.vertices) + self._coefficients[edge.vertices] = 0.0 + self._terms: list[list[Hyperedge]] = [] - Appends a new PauliSum to the list of Hamiltonian terms. This is - typically used to create partitions for Trotterization, where each - partition contains operators that can be applied together. + def get_coefficient(self, vertices: tuple[int, ...]) -> float: + """Get the coefficient for an edge in the Hamiltonian. Args: - term: The PauliSum to add. If None, an empty PauliSum is added, - which can be populated later using `add_to_term()`. - """ - if term is None: - term = PauliSum() - self.terms.append(term) + vertices: Tuple of vertex indices identifying the edge. - def add_to_term(self, index: int, pauli_string: PauliString) -> None: - """Add a PauliString to a specific term in the Hamiltonian. - - Appends a Pauli operator (with coefficient) to an existing term. - This is used to build up the Hamiltonian incrementally. - - Args: - index: Index of the term to add to (0-indexed). - pauli_string: The PauliString to add to the term. This can - include a coefficient, e.g., `0.5 * cirq.Z(q0) * cirq.Z(q1)`. + Returns: + The coefficient value for the specified edge. Raises: - IndexError: If index is out of range of the terms list. + KeyError: If the vertex tuple does not correspond to an edge + in the geometry. 
""" - self.terms[index] += pauli_string + vertices = tuple(sorted(vertices)) + if vertices not in self._coefficients: + raise KeyError(f"No edge with vertices {vertices} in geometry") + return self._coefficients[vertices] + + def has_coefficient(self, vertices: tuple[int, ...]) -> bool: + """Check if a coefficient exists for the given edge vertices. - def q(self, i: int) -> LineQubit: - """Return the qubit at index i. + Args: + vertices: Tuple of vertex indices identifying the edge. + Returns: + True if a coefficient exists for the edge, False otherwise. + """ + return tuple(sorted(vertices)) in self._coefficients - Provides convenient access to qubits by their vertex index in - the underlying geometry. + def set_coefficient(self, vertices: tuple[int, ...], value: float) -> None: + """Set the coefficient for an edge in the Hamiltonian. Args: - i: Index of the qubit (0-indexed, corresponds to vertex index). + vertices: Tuple of vertex indices identifying the edge. + value: The coefficient value to set. - Returns: - The LineQubit at the specified index. + Raises: + KeyError: If the vertex tuple does not correspond to an edge + in the geometry. """ - return self._qubits[i] + vertices = tuple(sorted(vertices)) + if vertices not in self._coefficients: + raise KeyError(f"No edge with vertices {vertices} in geometry") + self._coefficients[vertices] = value - def qubit_list(self) -> list[LineQubit]: - """Return the list of qubits in the model. + def add_term(self, edges: list[Hyperedge]) -> None: + """Add a term grouping to the model. - Returns: - A list of all LineQubit objects in the model, ordered by index. + Appends a list of hyperedges as a term. Terms are used for + grouping edges for Trotterization or parallel execution. + + Args: + edges: List of Hyperedge objects to group as a term. """ - return self._qubits + self._terms.append(list(edges)) - def qubits(self) -> Iterator[LineQubit]: - """Return an iterator over the qubits in the model. 
+ def terms(self) -> list[list[Hyperedge]]: + """Return the list of term groupings. Returns: - An iterator yielding LineQubit objects in index order. + List of lists of Hyperedges representing term groupings. """ - return iter(self._qubits) + return self._terms def __str__(self) -> str: """String representation of the model.""" return "Generic model with {} terms on {} qubits.".format( - len(self.terms), len(self._qubits) + len(self._terms), len(self._qubits) ) def __repr__(self) -> str: """String representation of the model.""" return self.__str__() + + +def translation_invariant_ising_model( + geometry: Hypergraph, h: float, J: float +) -> Model: + """Create a translation-invariant Ising model on the given geometry. + + The Hamiltonian is: + H = -J * Σ_{} Z_i Z_j - h * Σ_i X_i + + Two-body edges (len=2) in the geometry represent ZZ interactions with + coefficient -J. Single-vertex edges (len=1) represent X field terms + with coefficient -h. Edges are grouped into terms by their color + for parallel execution. + + Args: + geometry: The Hypergraph defining the interaction topology. + Should include single-vertex edges for field terms. + h: The transverse field strength (coefficient for X terms). + J: The coupling strength (coefficient for ZZ interaction terms). + + Returns: + A Model instance representing the Ising Hamiltonian. 
+ """ + model = Model(geometry) + model._terms = [ + [] for _ in range(geometry.ncolors + 1) + ] # Initialize term groupings based on edge colors + for edge in geometry.edges(): + vertices = edge.vertices + if len(vertices) == 1: + model.set_coefficient(vertices, -h) # Set X field coefficient + elif len(vertices) == 2: + model.set_coefficient(vertices, -J) # Set ZZ interaction coefficient + color = geometry.color[vertices] + model._terms[color].append(edge) # Group edges by color for parallel execution + + return model diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py index 0028cbcbab..246b11fb15 100644 --- a/source/pip/tests/magnets/test_model.py +++ b/source/pip/tests/magnets/test_model.py @@ -1,23 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# pyright: reportPrivateImportUsage=false, reportOperatorIssue=false +# pyright: reportPrivateImportUsage=false """Unit tests for the Model class.""" -# To be updated after additional geometries are implemented - from __future__ import annotations import pytest -from . 
import CIRQ_AVAILABLE, SKIP_REASON - -if CIRQ_AVAILABLE: - import cirq - from cirq import LineQubit - from qsharp.magnets.geometry import Hyperedge, Hypergraph - from qsharp.magnets.models import Model +from qsharp.magnets.geometry import Hyperedge, Hypergraph +from qsharp.magnets.models import Model def make_chain(length: int) -> Hypergraph: @@ -26,186 +19,163 @@ def make_chain(length: int) -> Hypergraph: return Hypergraph(edges) +def make_chain_with_vertices(length: int) -> Hypergraph: + """Create a chain hypergraph with single-vertex (field) edges for testing.""" + edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] + # Add single-vertex edges for field terms + edges.extend([Hyperedge([i]) for i in range(length)]) + return Hypergraph(edges) + + # Model initialization tests -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_init_basic(): """Test basic Model initialization.""" geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2])]) model = Model(geometry) assert model.geometry is geometry - assert len(model.terms) == 0 - - -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_init_creates_qubits(): - """Test that Model creates correct number of qubits.""" - geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([2, 3])]) - model = Model(geometry) - assert len(model.qubit_list()) == 4 + assert len(model.terms()) == 0 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_init_with_chain(): """Test Model initialization with chain geometry.""" geometry = make_chain(5) model = Model(geometry) - assert len(model.qubit_list()) == 5 + assert len(model._qubits) == 5 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_init_empty_geometry(): """Test Model with empty geometry.""" geometry = Hypergraph([]) model = Model(geometry) - assert len(model.qubit_list()) == 0 - assert len(model.terms) == 0 + assert len(model._qubits) == 0 + assert len(model.terms()) == 0 -# Qubit 
access tests +def test_model_init_coefficients_zero(): + """Test that coefficients are initialized to zero.""" + geometry = make_chain(3) # edges: (0,1), (1,2) + model = Model(geometry) + assert model.get_coefficient((0, 1)) == 0.0 + assert model.get_coefficient((1, 2)) == 0.0 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_q_returns_line_qubit(): - """Test that q() returns LineQubit instances.""" - geometry = make_chain(3) - model = Model(geometry) - qubit = model.q(0) - assert isinstance(qubit, LineQubit) +# Coefficient tests -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_q_returns_correct_qubit(): - """Test that q() returns qubit with correct index.""" - geometry = make_chain(4) +def test_model_set_coefficient(): + """Test setting coefficient for an edge.""" + geometry = make_chain(2) model = Model(geometry) - for i in range(4): - assert model.q(i) == LineQubit(i) + model.set_coefficient((0, 1), 1.5) + assert model.get_coefficient((0, 1)) == 1.5 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_qubit_list(): - """Test qubit_list() returns all qubits.""" - geometry = make_chain(3) +def test_model_set_coefficient_overwrite(): + """Test overwriting an existing coefficient.""" + geometry = make_chain(2) model = Model(geometry) - qubits = model.qubit_list() - assert len(qubits) == 3 - assert qubits == [LineQubit(0), LineQubit(1), LineQubit(2)] + model.set_coefficient((0, 1), 1.5) + model.set_coefficient((0, 1), 2.5) + assert model.get_coefficient((0, 1)) == 2.5 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_qubits_iterator(): - """Test qubits() returns an iterator.""" - geometry = make_chain(3) +def test_model_set_coefficient_invalid_edge(): + """Test setting coefficient for non-existent edge raises error.""" + geometry = make_chain(2) model = Model(geometry) - qubit_iter = model.qubits() - qubits = list(qubit_iter) - assert len(qubits) == 3 - assert 
qubits == [LineQubit(0), LineQubit(1), LineQubit(2)] + with pytest.raises(KeyError): + model.set_coefficient((0, 2), 1.0) -# Term management tests +def test_model_get_coefficient_invalid_edge(): + """Test getting coefficient for non-existent edge raises error.""" + geometry = make_chain(2) + model = Model(geometry) + with pytest.raises(KeyError): + model.get_coefficient((0, 2)) -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_term_empty(): - """Test adding an empty term.""" +def test_model_get_coefficient_sorted(): + """Test that get_coefficient sorts vertices so order doesn't matter.""" geometry = make_chain(2) model = Model(geometry) - model.add_term() - assert len(model.terms) == 1 + model.set_coefficient((0, 1), 3.0) + assert model.get_coefficient((1, 0)) == 3.0 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_term_with_pauli_sum(): - """Test adding a PauliSum term.""" +def test_model_set_coefficient_sorted(): + """Test that set_coefficient sorts vertices so order doesn't matter.""" geometry = make_chain(2) model = Model(geometry) - q0, q1 = model.q(0), model.q(1) - term = cirq.Z(q0) * cirq.Z(q1) - model.add_term(cirq.PauliSum.from_pauli_strings([term])) - assert len(model.terms) == 1 + model.set_coefficient((1, 0), 4.0) + assert model.get_coefficient((0, 1)) == 4.0 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_multiple_terms(): - """Test adding multiple terms.""" +# has_coefficient tests + + +def test_model_has_coefficient_true(): + """Test has_coefficient returns True for existing edge.""" geometry = make_chain(3) model = Model(geometry) - model.add_term() - model.add_term() - model.add_term() - assert len(model.terms) == 3 + assert model.has_coefficient((0, 1)) is True + assert model.has_coefficient((1, 2)) is True -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_to_term(): - """Test adding a PauliString to an existing term.""" - 
geometry = make_chain(2) +def test_model_has_coefficient_false(): + """Test has_coefficient returns False for non-existent edge.""" + geometry = make_chain(3) model = Model(geometry) - model.add_term() - q0, q1 = model.q(0), model.q(1) - pauli_string = cirq.Z(q0) * cirq.Z(q1) - model.add_to_term(0, pauli_string) - # Term should now contain the Pauli string - assert len(model.terms[0]) == 1 + assert model.has_coefficient((0, 2)) is False + assert model.has_coefficient((5, 6)) is False -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_to_term_multiple_strings(): - """Test adding multiple PauliStrings to the same term.""" - geometry = make_chain(3) +def test_model_has_coefficient_sorted(): + """Test has_coefficient sorts vertices so order doesn't matter.""" + geometry = make_chain(2) model = Model(geometry) - model.add_term() - q0, q1, q2 = model.q(0), model.q(1), model.q(2) - model.add_to_term(0, cirq.Z(q0) * cirq.Z(q1)) - model.add_to_term(0, cirq.Z(q1) * cirq.Z(q2)) - assert len(model.terms[0]) == 2 + assert model.has_coefficient((1, 0)) is True + + +# Term management tests -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_to_different_terms(): - """Test adding PauliStrings to different terms.""" +def test_model_add_term(): + """Test adding a term with edges.""" geometry = make_chain(3) model = Model(geometry) - model.add_term() - model.add_term() - q0, q1, q2 = model.q(0), model.q(1), model.q(2) - model.add_to_term(0, cirq.Z(q0) * cirq.Z(q1)) - model.add_to_term(1, cirq.Z(q1) * cirq.Z(q2)) - assert len(model.terms[0]) == 1 - assert len(model.terms[1]) == 1 - - -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) -def test_model_add_to_term_with_coefficient(): - """Test adding a PauliString with a coefficient.""" - geometry = make_chain(2) + edge1 = Hyperedge([0, 1]) + edge2 = Hyperedge([1, 2]) + model.add_term([edge1, edge2]) + assert len(model.terms()) == 1 + assert len(model.terms()[0]) == 2 + 
+ +def test_model_add_multiple_terms(): + """Test adding multiple terms.""" + geometry = make_chain(4) model = Model(geometry) - model.add_term() - q0, q1 = model.q(0), model.q(1) - pauli_string = 0.5 * cirq.Z(q0) * cirq.Z(q1) - model.add_to_term(0, pauli_string) - assert len(model.terms[0]) == 1 + model.add_term([Hyperedge([0, 1])]) + model.add_term([Hyperedge([1, 2]), Hyperedge([2, 3])]) + assert len(model.terms()) == 2 # String representation tests -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_str(): """Test string representation.""" geometry = make_chain(4) model = Model(geometry) - model.add_term() - model.add_term() + model.add_term([Hyperedge([0, 1])]) + model.add_term([Hyperedge([1, 2])]) result = str(model) assert "2 terms" in result assert "4 qubits" in result -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_str_empty(): """Test string representation with no terms.""" geometry = make_chain(3) @@ -215,7 +185,6 @@ def test_model_str_empty(): assert "3 qubits" in result -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_repr(): """Test repr representation.""" geometry = make_chain(2) @@ -226,38 +195,124 @@ def test_model_repr(): # Integration tests -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def test_model_build_simple_hamiltonian(): """Test building a simple ZZ Hamiltonian on a chain.""" geometry = make_chain(3) model = Model(geometry) - model.add_term() # Single term for all interactions + # Set coefficients for all edges for edge in geometry.edges(): - i, j = edge.vertices - model.add_to_term(0, cirq.Z(model.q(i)) * cirq.Z(model.q(j))) + model.set_coefficient(edge.vertices, 1.0) - # Should have 2 ZZ interactions: (0,1) and (1,2) - assert len(model.terms[0]) == 2 + # Verify coefficients + assert model.get_coefficient((0, 1)) == 1.0 + assert model.get_coefficient((1, 2)) == 1.0 -@pytest.mark.skipif(not CIRQ_AVAILABLE, reason=SKIP_REASON) def 
test_model_with_partitioned_terms(): """Test building a model with partitioned terms for Trotterization.""" geometry = make_chain(4) model = Model(geometry) # Add two terms for even/odd partitioning - model.add_term() # Even edges: (0,1), (2,3) - model.add_term() # Odd edges: (1,2) + even_edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + odd_edges = [Hyperedge([1, 2])] + model.add_term(even_edges) + model.add_term(odd_edges) + + assert len(model.terms()) == 2 + assert len(model.terms()[0]) == 2 + assert len(model.terms()[1]) == 1 + + +# translation_invariant_ising_model tests + + +def test_translation_invariant_ising_model_basic(): + """Test basic creation of Ising model.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(3) + model = translation_invariant_ising_model(geometry, h=1.0, J=1.0) + + assert isinstance(model, Model) + assert model.geometry is geometry + + +def test_translation_invariant_ising_model_zz_coefficients(): + """Test that ZZ interaction coefficients are correctly set.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(4) # 3 two-body edges: (0,1), (1,2), (2,3) + J = 2.0 + model = translation_invariant_ising_model(geometry, h=0.5, J=J) + + # All two-body edge coefficients should be -J + assert model.get_coefficient((0, 1)) == -J + assert model.get_coefficient((1, 2)) == -J + assert model.get_coefficient((2, 3)) == -J + + +def test_translation_invariant_ising_model_x_coefficients(): + """Test that X field coefficients are correctly set.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(4) # 4 single-vertex edges + h = 0.5 + model = translation_invariant_ising_model(geometry, h=h, J=2.0) + + # All single-vertex edge coefficients should be -h + for v in range(4): + assert model.get_coefficient((v,)) == -h + + +def 
test_translation_invariant_ising_model_coefficients(): + """Test that coefficients are correctly applied.""" + from qsharp.magnets.models import translation_invariant_ising_model + + # Geometry with one two-body edge and two single-vertex edges + geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([0]), Hyperedge([1])]) + h, J = 0.3, 0.7 + model = translation_invariant_ising_model(geometry, h=h, J=J) + + # Check ZZ coefficient is -J + assert model.get_coefficient((0, 1)) == -J + + # Check X coefficients are -h + assert model.get_coefficient((0,)) == -h + assert model.get_coefficient((1,)) == -h + + +def test_translation_invariant_ising_model_zero_field(): + """Test Ising model with zero transverse field.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(3) + model = translation_invariant_ising_model(geometry, h=0.0, J=1.0) + + # X coefficients (single-vertex edges) should all be zero + for v in range(3): + assert model.get_coefficient((v,)) == 0.0 + + +def test_translation_invariant_ising_model_zero_coupling(): + """Test Ising model with zero coupling.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(3) + model = translation_invariant_ising_model(geometry, h=1.0, J=0.0) + + # ZZ coefficients (two-body edges) should all be zero + assert model.get_coefficient((0, 1)) == 0.0 + assert model.get_coefficient((1, 2)) == 0.0 + - # Add even edges to term 0 - model.add_to_term(0, cirq.Z(model.q(0)) * cirq.Z(model.q(1))) - model.add_to_term(0, cirq.Z(model.q(2)) * cirq.Z(model.q(3))) +def test_translation_invariant_ising_model_term_grouping(): + """Test that Ising model has correct term grouping by color.""" + from qsharp.magnets.models import translation_invariant_ising_model - # Add odd edge to term 1 - model.add_to_term(1, cirq.Z(model.q(1)) * cirq.Z(model.q(2))) + geometry = make_chain_with_vertices(4) + model = 
translation_invariant_ising_model(geometry, h=1.0, J=1.0) - assert len(model.terms) == 2 - assert len(model.terms[0]) == 2 - assert len(model.terms[1]) == 1 + # Number of terms should be ncolors + 1 + assert len(model.terms()) == geometry.ncolors + 1 From e19e6a139ceb8a647d8b90647936855ec3c43c07 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Fri, 20 Feb 2026 09:27:56 -0500 Subject: [PATCH 17/45] Magnets: added utilities (#2965) Added a utilities sublibrary. * Moved Hypergraph and Hyperedge to utilities * Added Pauli and PauliString classes (primitives around Cirq classes) * Modified Model class to use PauliStrings * Modified tests --- .../pip/qsharp/magnets/geometry/__init__.py | 4 - .../pip/qsharp/magnets/geometry/complete.py | 2 +- .../pip/qsharp/magnets/geometry/lattice1d.py | 2 +- .../pip/qsharp/magnets/geometry/lattice2d.py | 2 +- source/pip/qsharp/magnets/models/model.py | 104 ++++++-- .../pip/qsharp/magnets/utilities/__init__.py | 22 ++ .../{geometry => utilities}/hypergraph.py | 0 source/pip/qsharp/magnets/utilities/pauli.py | 235 ++++++++++++++++++ source/pip/tests/magnets/test_complete.py | 4 +- source/pip/tests/magnets/test_hypergraph.py | 2 +- source/pip/tests/magnets/test_lattice1d.py | 4 +- source/pip/tests/magnets/test_lattice2d.py | 4 +- source/pip/tests/magnets/test_model.py | 112 ++++++++- 13 files changed, 454 insertions(+), 43 deletions(-) create mode 100644 source/pip/qsharp/magnets/utilities/__init__.py rename source/pip/qsharp/magnets/{geometry => utilities}/hypergraph.py (100%) create mode 100644 source/pip/qsharp/magnets/utilities/pauli.py diff --git a/source/pip/qsharp/magnets/geometry/__init__.py b/source/pip/qsharp/magnets/geometry/__init__.py index 3d2ac6c1fb..4a7a380f86 100644 --- a/source/pip/qsharp/magnets/geometry/__init__.py +++ b/source/pip/qsharp/magnets/geometry/__init__.py @@ -9,16 +9,12 @@ """ from .complete import CompleteBipartiteGraph, CompleteGraph -from .hypergraph import Hyperedge, Hypergraph, greedy_edge_coloring from 
.lattice1d import Chain1D, Ring1D from .lattice2d import Patch2D, Torus2D __all__ = [ "CompleteBipartiteGraph", "CompleteGraph", - "Hyperedge", - "Hypergraph", - "greedy_edge_coloring", "Chain1D", "Ring1D", "Patch2D", diff --git a/source/pip/qsharp/magnets/geometry/complete.py b/source/pip/qsharp/magnets/geometry/complete.py index c38cdfc2b5..aee8f35014 100644 --- a/source/pip/qsharp/magnets/geometry/complete.py +++ b/source/pip/qsharp/magnets/geometry/complete.py @@ -8,7 +8,7 @@ systems with all-to-all or bipartite all-to-all interactions. """ -from qsharp.magnets.geometry.hypergraph import ( +from qsharp.magnets.utilities import ( Hyperedge, Hypergraph, greedy_edge_coloring, diff --git a/source/pip/qsharp/magnets/geometry/lattice1d.py b/source/pip/qsharp/magnets/geometry/lattice1d.py index 9fffeecbe2..ff091fd28a 100644 --- a/source/pip/qsharp/magnets/geometry/lattice1d.py +++ b/source/pip/qsharp/magnets/geometry/lattice1d.py @@ -8,7 +8,7 @@ simulations and other one-dimensional quantum systems. """ -from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph +from qsharp.magnets.utilities import Hyperedge, Hypergraph class Chain1D(Hypergraph): diff --git a/source/pip/qsharp/magnets/geometry/lattice2d.py b/source/pip/qsharp/magnets/geometry/lattice2d.py index e04817ef92..4821c5eaeb 100644 --- a/source/pip/qsharp/magnets/geometry/lattice2d.py +++ b/source/pip/qsharp/magnets/geometry/lattice2d.py @@ -8,7 +8,7 @@ simulations and other two-dimensional quantum systems. """ -from qsharp.magnets.geometry.hypergraph import Hyperedge, Hypergraph +from qsharp.magnets.utilities import Hyperedge, Hypergraph class Patch2D(Hypergraph): diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py index 6f69c35adf..197078f7ab 100644 --- a/source/pip/qsharp/magnets/models/model.py +++ b/source/pip/qsharp/magnets/models/model.py @@ -10,7 +10,7 @@ to define interaction topologies and stores coefficients for each edge. 
""" -from qsharp.magnets.geometry import Hyperedge, Hypergraph +from qsharp.magnets.utilities import Hyperedge, Hypergraph, PauliString class Model: @@ -19,7 +19,7 @@ class Model: This class represents a quantum spin Hamiltonian defined on a hypergraph geometry. The Hamiltonian is characterized by: - - Coefficients: A mapping from edge vertex tuples to float coefficients + - Ops: A mapping from edge vertex tuples to (coefficient, PauliString) pairs - Terms: Groupings of hyperedges for Trotterization or parallel execution The model is built on a hypergraph geometry that defines which qubits @@ -35,6 +35,7 @@ class Model: >>> geometry = Chain1D(4) >>> model = Model(geometry) >>> model.set_coefficient((0, 1), 1.5) + >>> model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) >>> model.get_coefficient((0, 1)) 1.5 """ @@ -43,7 +44,8 @@ def __init__(self, geometry: Hypergraph): """Initialize the Model. Creates a quantum spin model on the given geometry. The model starts - with all coefficients set to zero and no term groupings. + with all coefficients set to zero (with identity PauliStrings) and + no term groupings. Args: geometry: Hypergraph defining the interaction topology. The number @@ -51,10 +53,13 @@ def __init__(self, geometry: Hypergraph): """ self.geometry: Hypergraph = geometry self._qubits: set[int] = set() - self._coefficients: dict[tuple[int, ...], float] = dict() + self._ops: dict[tuple[int, ...], tuple[float, PauliString]] = dict() for edge in geometry.edges(): self._qubits.update(edge.vertices) - self._coefficients[edge.vertices] = 0.0 + self._ops[edge.vertices] = ( + 0.0, + PauliString.from_qubits(edge.vertices, [0] * len(edge.vertices)), + ) self._terms: list[list[Hyperedge]] = [] def get_coefficient(self, vertices: tuple[int, ...]) -> float: @@ -71,21 +76,43 @@ def get_coefficient(self, vertices: tuple[int, ...]) -> float: in the geometry. 
""" vertices = tuple(sorted(vertices)) - if vertices not in self._coefficients: + if vertices not in self._ops: raise KeyError(f"No edge with vertices {vertices} in geometry") - return self._coefficients[vertices] + return self._ops[vertices][0] - def has_coefficient(self, vertices: tuple[int, ...]) -> bool: - """Check if a coefficient exists for the given edge vertices. + def get_pauli_string(self, vertices: tuple[int, ...]) -> PauliString: + """Get the PauliString for an edge in the Hamiltonian. Args: vertices: Tuple of vertex indices identifying the edge. + Returns: - True if a coefficient exists for the edge, False otherwise. + The PauliString for the specified edge. + + Raises: + KeyError: If the vertex tuple does not correspond to an edge + in the geometry. """ - return tuple(sorted(vertices)) in self._coefficients + vertices = tuple(sorted(vertices)) + if vertices not in self._ops: + raise KeyError(f"No edge with vertices {vertices} in geometry") + return self._ops[vertices][1] + + def has_interaction_term(self, vertices: tuple[int, ...]) -> bool: + """Check if an interaction term exists for the given edge vertices. - def set_coefficient(self, vertices: tuple[int, ...], value: float) -> None: + Args: + vertices: Tuple of vertex indices identifying the edge. + Returns: + True if an interaction term exists for the edge, False otherwise. + """ + return tuple(sorted(vertices)) in self._ops + + def set_coefficient( + self, + vertices: tuple[int, ...], + value: float, + ) -> None: """Set the coefficient for an edge in the Hamiltonian. Args: @@ -97,9 +124,54 @@ def set_coefficient(self, vertices: tuple[int, ...], value: float) -> None: in the geometry. 
""" vertices = tuple(sorted(vertices)) - if vertices not in self._coefficients: + if vertices not in self._ops: + raise KeyError(f"No edge with vertices {vertices} in geometry") + self._ops[vertices] = (value, self._ops[vertices][1]) + + def set_pauli_string( + self, + vertices: tuple[int, ...], + pauli_string: PauliString, + ) -> None: + """Set the PauliString for an edge in the Hamiltonian. + + Args: + vertices: Tuple of vertex indices identifying the edge. + pauli_string: The PauliString to associate with this edge. + + Raises: + KeyError: If the vertex tuple does not correspond to an edge + in the geometry. + """ + vertices = tuple(sorted(vertices)) + if vertices not in self._ops: + raise KeyError(f"No edge with vertices {vertices} in geometry") + self._ops[vertices] = (self._ops[vertices][0], pauli_string) + + def set_operator( + self, + vertices: tuple[int, ...], + value: float, + pauli_string: PauliString, + ) -> None: + """Set both the coefficient and PauliString for an edge. + + Convenience method that combines :meth:`set_coefficient` and + :meth:`set_pauli_string` in a single call. + + Args: + vertices: Tuple of vertex indices identifying the edge. + value: The coefficient value to set. + pauli_string: The PauliString to associate with this edge. + + Raises: + KeyError: If the vertex tuple does not correspond to an edge + in the geometry. + """ + vertices = tuple(sorted(vertices)) + if vertices not in self._ops: raise KeyError(f"No edge with vertices {vertices} in geometry") - self._coefficients[vertices] = value + self._ops[vertices] = (value, pauli_string) def add_term(self, edges: list[Hyperedge]) -> None: """Add a term grouping to the model. 
@@ -160,9 +232,9 @@ def translation_invariant_ising_model( for edge in geometry.edges(): vertices = edge.vertices if len(vertices) == 1: - model.set_coefficient(vertices, -h) # Set X field coefficient + model.set_operator(vertices, -h, PauliString.from_qubits(vertices, "X")) elif len(vertices) == 2: - model.set_coefficient(vertices, -J) # Set ZZ interaction coefficient + model.set_operator(vertices, -J, PauliString.from_qubits(vertices, "ZZ")) color = geometry.color[vertices] model._terms[color].append(edge) # Group edges by color for parallel execution diff --git a/source/pip/qsharp/magnets/utilities/__init__.py b/source/pip/qsharp/magnets/utilities/__init__.py new file mode 100644 index 0000000000..10c2ebdf69 --- /dev/null +++ b/source/pip/qsharp/magnets/utilities/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utilities module for magnets package. + +This module provides utility data structures and algorithms used across +the magnets package, including hypergraph representations. +""" + +from .hypergraph import Hyperedge, Hypergraph, greedy_edge_coloring +from .pauli import Pauli, PauliString, PauliX, PauliY, PauliZ + +__all__ = [ + "Hyperedge", + "Hypergraph", + "Pauli", + "PauliString", + "PauliX", + "PauliY", + "PauliZ", + "greedy_edge_coloring", +] diff --git a/source/pip/qsharp/magnets/geometry/hypergraph.py b/source/pip/qsharp/magnets/utilities/hypergraph.py similarity index 100% rename from source/pip/qsharp/magnets/geometry/hypergraph.py rename to source/pip/qsharp/magnets/utilities/hypergraph.py diff --git a/source/pip/qsharp/magnets/utilities/pauli.py b/source/pip/qsharp/magnets/utilities/pauli.py new file mode 100644 index 0000000000..82c936778f --- /dev/null +++ b/source/pip/qsharp/magnets/utilities/pauli.py @@ -0,0 +1,235 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Pauli operator representation for quantum spin systems.""" + +from collections.abc import Sequence + +try: + import cirq +except Exception as ex: + raise ImportError( + "qsharp.magnets.models requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'." + ) from ex + + +class Pauli: + """A single-qubit Pauli operator (I, X, Y, or Z) acting on a specific qubit. + + Can be constructed from an integer (0–3) or a string ('I', 'X', 'Y', 'Z'), + along with the index of the qubit it acts on. + + Mapping: + 0 / 'I' → Identity + 1 / 'X' → Pauli-X + 2 / 'Z' → Pauli-Z + 3 / 'Y' → Pauli-Y + + Attributes: + qubit: The qubit index this operator acts on. + + Example: + + .. code-block:: python + >>> p = Pauli('X', 0) + >>> p.op + 1 + >>> p.qubit + 0 + """ + + _VALID_INTS = {0, 1, 2, 3} + _STR_TO_INT = {"I": 0, "X": 1, "Z": 2, "Y": 3} + + def __init__(self, value: int | str, qubit: int = 0) -> None: + """Initialize a Pauli operator. + + Args: + value: An integer 0–3 or one of 'I', 'X', 'Y', 'Z' (case-insensitive). + qubit: The index of the qubit this operator acts on. Defaults to 0. + + Raises: + ValueError: If the value is not a recognized Pauli identifier. + """ + if isinstance(value, int): + if value not in self._VALID_INTS: + raise ValueError(f"Integer value must be 0–3, got {value}.") + self._op = value + elif isinstance(value, str): + key = value.upper() + if key not in self._STR_TO_INT: + raise ValueError( + f"String value must be one of 'I', 'X', 'Y', 'Z', got '{value}'." + ) + self._op = self._STR_TO_INT[key] + else: + raise ValueError(f"Expected int or str, got {type(value).__name__}.") + self.qubit: int = qubit + + @property + def op(self) -> int: + """Return the integer representation of this Pauli operator. + + Returns: + 0 for I, 1 for X, 2 for Z, 3 for Y. 
+ """ + return self._op + + def __repr__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + return f"Pauli('{labels[self._op]}', qubit={self.qubit})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Pauli): + return NotImplemented + return self._op == other._op and self.qubit == other.qubit + + def __hash__(self) -> int: + return hash((self._op, self.qubit)) + + @property + def cirq(self): + """Return the corresponding Cirq Pauli operator. + + Returns: + ``cirq.I``, ``cirq.X``, ``cirq.Z``, or ``cirq.Y``. + """ + _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) + return _INT_TO_CIRQ[self._op].on(cirq.LineQubit(self.qubit)) + + +def PauliX(qubit: int) -> Pauli: + """Create a Pauli-X operator on the given qubit.""" + return Pauli("X", qubit) + + +def PauliY(qubit: int) -> Pauli: + """Create a Pauli-Y operator on the given qubit.""" + return Pauli("Y", qubit) + + +def PauliZ(qubit: int) -> Pauli: + """Create a Pauli-Z operator on the given qubit.""" + return Pauli("Z", qubit) + + +class PauliString: + """A multi-qubit Pauli operator acting on specific qubits. + + Stores a tuple of :class:`Pauli` objects, each carrying its own qubit index. + Can be constructed from a sequence of ``Pauli`` instances (default), or via + the :meth:`from_qubits` class method which takes qubit indices and Pauli + labels separately. + + Attributes: + _paulis: Tuple of Pauli objects defining the operator on each qubit. + + Example: + + .. code-block:: python + >>> ps = PauliString([PauliX(0), PauliZ(1)]) + >>> ps.qubits + (0, 1) + >>> list(ps) + [Pauli(X, qubit=0), Pauli(Z, qubit=1)] + >>> ps2 = PauliString.from_qubits((0, 1), "XZ") + >>> ps == ps2 + True + """ + + def __init__(self, paulis: Sequence[Pauli]) -> None: + """Initialize a PauliString from a sequence of Pauli operators. + + Args: + paulis: A sequence of :class:`Pauli` instances, each with its + own qubit index. + + Raises: + TypeError: If any element is not a Pauli instance. 
+ """ + for p in paulis: + if not isinstance(p, Pauli): + raise TypeError( + f"Expected Pauli instance, got {type(p).__name__}. " + "Use PauliString.from_qubits() for int/str values." + ) + self._paulis: tuple[Pauli, ...] = tuple(paulis) + + @classmethod + def from_qubits( + cls, + qubits: tuple[int, ...], + values: Sequence[int | str] | str, + ) -> "PauliString": + """Create a PauliString from qubit indices and Pauli labels. + + Args: + qubits: Tuple of qubit indices. + values: Sequence of Pauli identifiers (integers 0–3 or strings + 'I', 'X', 'Y', 'Z'). A plain string like ``"XZI"`` is also + accepted and treated as individual characters. + + Returns: + A new PauliString instance. + + Raises: + ValueError: If qubits and values have different lengths, or if + any value is not a valid Pauli identifier. + """ + if len(qubits) != len(values): + raise ValueError( + f"Length mismatch: {len(qubits)} qubits vs {len(values)} values." + ) + paulis = [Pauli(v, q) for q, v in zip(qubits, values)] + return cls(paulis) + + @property + def qubits(self) -> tuple[int, ...]: + """Return the tuple of qubit indices. + + Returns: + Tuple of qubit indices, one per Pauli operator. + """ + return tuple(p.qubit for p in self._paulis) + + def __iter__(self): + """Iterate over the Pauli operators in this PauliString. + + Yields: + :class:`Pauli` instances in order. 
+ """ + return iter(self._paulis) + + def __len__(self) -> int: + return len(self._paulis) + + def __getitem__(self, index: int) -> Pauli: + return self._paulis[index] + + def __repr__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + s = "".join(labels[p.op] for p in self._paulis) + return f"PauliString(qubits={self.qubits}, ops='{s}')" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PauliString): + return NotImplemented + return self._paulis == other._paulis + + def __hash__(self) -> int: + return hash(self._paulis) + + @property + def cirq(self): + """Return the corresponding Cirq ``PauliString``. + + Constructs a ``cirq.PauliString`` by applying each single-qubit + Pauli to its corresponding ``cirq.LineQubit``. + + Returns: + A ``cirq.PauliString`` acting on ``cirq.LineQubit`` instances. + """ + _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) + return cirq.PauliString( + {cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] for p in self._paulis} + ) diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py index 93237e7ced..614d030c50 100644 --- a/source/pip/tests/magnets/test_complete.py +++ b/source/pip/tests/magnets/test_complete.py @@ -103,7 +103,7 @@ def test_complete_graph_str(): def test_complete_graph_inherits_hypergraph(): """Test that CompleteGraph is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph graph = CompleteGraph(4) assert isinstance(graph, Hypergraph) @@ -241,7 +241,7 @@ def test_complete_bipartite_graph_str(): def test_complete_bipartite_graph_inherits_hypergraph(): """Test that CompleteBipartiteGraph is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph graph = CompleteBipartiteGraph(2, 3) assert isinstance(graph, Hypergraph) diff --git 
a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index c158d9589e..6df404787c 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -3,7 +3,7 @@ """Unit tests for hypergraph data structures.""" -from qsharp.magnets.geometry.hypergraph import ( +from qsharp.magnets.utilities import ( Hyperedge, Hypergraph, greedy_edge_coloring, diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py index e9ccacd519..b4bbf152c4 100644 --- a/source/pip/tests/magnets/test_lattice1d.py +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -246,7 +246,7 @@ def test_ring1d_vs_chain1d_edge_count(): def test_chain1d_inherits_hypergraph(): """Test that Chain1D is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph chain = Chain1D(4) assert isinstance(chain, Hypergraph) @@ -258,7 +258,7 @@ def test_chain1d_inherits_hypergraph(): def test_ring1d_inherits_hypergraph(): """Test that Ring1D is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph ring = Ring1D(4) assert isinstance(ring, Hypergraph) diff --git a/source/pip/tests/magnets/test_lattice2d.py b/source/pip/tests/magnets/test_lattice2d.py index ccf95c313a..d629975227 100644 --- a/source/pip/tests/magnets/test_lattice2d.py +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -267,7 +267,7 @@ def test_torus2d_vs_patch2d_edge_count(): def test_patch2d_inherits_hypergraph(): """Test that Patch2D is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph patch = Patch2D(3, 3) assert isinstance(patch, Hypergraph) @@ -279,7 +279,7 @@ def test_patch2d_inherits_hypergraph(): def test_torus2d_inherits_hypergraph(): 
"""Test that Torus2D is a Hypergraph subclass with all methods.""" - from qsharp.magnets.geometry.hypergraph import Hypergraph + from qsharp.magnets.utilities import Hypergraph torus = Torus2D(3, 3) assert isinstance(torus, Hypergraph) diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py index 246b11fb15..d86d3924c0 100644 --- a/source/pip/tests/magnets/test_model.py +++ b/source/pip/tests/magnets/test_model.py @@ -9,8 +9,8 @@ import pytest -from qsharp.magnets.geometry import Hyperedge, Hypergraph from qsharp.magnets.models import Model +from qsharp.magnets.utilities import Hyperedge, Hypergraph, PauliString def make_chain(length: int) -> Hypergraph: @@ -61,6 +61,14 @@ def test_model_init_coefficients_zero(): assert model.get_coefficient((1, 2)) == 0.0 +def test_model_init_pauli_strings_identity(): + """Test that PauliStrings are initialized to identity.""" + geometry = make_chain(3) # edges: (0,1), (1,2) + model = Model(geometry) + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "II") + assert model.get_pauli_string((1, 2)) == PauliString.from_qubits((1, 2), "II") + + # Coefficient tests @@ -113,30 +121,92 @@ def test_model_set_coefficient_sorted(): assert model.get_coefficient((0, 1)) == 4.0 -# has_coefficient tests +def test_model_set_coefficient_preserves_pauli_string(): + """Test that set_coefficient does not change the PauliString.""" + geometry = make_chain(2) + model = Model(geometry) + model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) + model.set_coefficient((0, 1), 3.0) + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") + + +# PauliString tests + + +def test_model_set_pauli_string(): + """Test setting PauliString for an edge.""" + geometry = make_chain(2) + model = Model(geometry) + model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") + + +def 
test_model_set_pauli_string_overwrite(): + """Test overwriting an existing PauliString.""" + geometry = make_chain(2) + model = Model(geometry) + model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) + model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "XX")) + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "XX") + + +def test_model_set_pauli_string_invalid_edge(): + """Test setting PauliString for non-existent edge raises error.""" + geometry = make_chain(2) + model = Model(geometry) + with pytest.raises(KeyError): + model.set_pauli_string((0, 2), PauliString.from_qubits((0, 2), "ZZ")) + + +def test_model_get_pauli_string_invalid_edge(): + """Test getting PauliString for non-existent edge raises error.""" + geometry = make_chain(2) + model = Model(geometry) + with pytest.raises(KeyError): + model.get_pauli_string((0, 2)) + + +def test_model_set_pauli_string_preserves_coefficient(): + """Test that set_pauli_string does not change the coefficient.""" + geometry = make_chain(2) + model = Model(geometry) + model.set_coefficient((0, 1), 5.0) + model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) + assert model.get_coefficient((0, 1)) == 5.0 + + +def test_model_set_pauli_string_sorted(): + """Test that set_pauli_string sorts vertices so order doesn't matter.""" + geometry = make_chain(2) + model = Model(geometry) + model.set_pauli_string((1, 0), PauliString.from_qubits((1, 0), "XZ")) + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((1, 0), "XZ") -def test_model_has_coefficient_true(): - """Test has_coefficient returns True for existing edge.""" +# has_interaction_term tests + + +def test_model_has_interaction_term_true(): + """Test has_interaction_term returns True for existing edge.""" geometry = make_chain(3) model = Model(geometry) - assert model.has_coefficient((0, 1)) is True - assert model.has_coefficient((1, 2)) is True + assert model.has_interaction_term((0, 1)) is True + 
assert model.has_interaction_term((1, 2)) is True -def test_model_has_coefficient_false(): - """Test has_coefficient returns False for non-existent edge.""" +def test_model_has_interaction_term_false(): + """Test has_interaction_term returns False for non-existent edge.""" geometry = make_chain(3) model = Model(geometry) - assert model.has_coefficient((0, 2)) is False - assert model.has_coefficient((5, 6)) is False + assert model.has_interaction_term((0, 2)) is False + assert model.has_interaction_term((5, 6)) is False -def test_model_has_coefficient_sorted(): - """Test has_coefficient sorts vertices so order doesn't matter.""" +def test_model_has_interaction_term_sorted(): + """Test has_interaction_term sorts vertices so order doesn't matter.""" geometry = make_chain(2) model = Model(geometry) - assert model.has_coefficient((1, 0)) is True + assert model.has_interaction_term((1, 0)) is True # Term management tests @@ -316,3 +386,19 @@ def test_translation_invariant_ising_model_term_grouping(): # Number of terms should be ncolors + 1 assert len(model.terms()) == geometry.ncolors + 1 + + +def test_translation_invariant_ising_model_pauli_strings(): + """Test that Ising model sets correct PauliStrings.""" + from qsharp.magnets.models import translation_invariant_ising_model + + geometry = make_chain_with_vertices(3) + model = translation_invariant_ising_model(geometry, h=1.0, J=1.0) + + # Two-body edges should have ZZ PauliString + assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") + assert model.get_pauli_string((1, 2)) == PauliString.from_qubits((1, 2), "ZZ") + + # Single-vertex edges should have X PauliString + for v in range(3): + assert model.get_pauli_string((v,)) == PauliString.from_qubits((v,), "X") From 511c953d3ceeb4a01397b1de31a784eba4055ac0 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Wed, 25 Feb 2026 14:13:28 +0100 Subject: [PATCH 18/45] Various RE models and some API enhancements (#2963) This PR adds various models: - 
Qubits - Majorana - Error correcting codes - ThreeAux - Yoked surface code - Factories - Round-based factory - Litinski19 factory It also makes some updates to the surface code model and improves documentation. Further, it adds a large test suite for all the models. Some of these tests are very detailed but they will help us to spot if we introduce inconsistencies with future code changes. Besides that there are some fixes to the Python API: - Adds caching to Q# to Trace transformation - InstructionFrontier can be 2D or 3D (without and with error rates) - Instance enumeration can also consider nested types and union types now - EstimationResult can be used from the Python API to create new results --- source/pip/qsharp/qre/_enumeration.py | 19 + source/pip/qsharp/qre/_estimation.py | 2 +- source/pip/qsharp/qre/_qre.pyi | 79 +- source/pip/qsharp/qre/_trace.py | 2 +- source/pip/qsharp/qre/application/_qsharp.py | 18 +- source/pip/qsharp/qre/interop/__init__.py | 4 +- source/pip/qsharp/qre/interop/_qsharp.py | 18 +- source/pip/qsharp/qre/models/__init__.py | 16 +- .../qsharp/qre/models/factories/__init__.py | 8 + .../qsharp/qre/models/factories/_litinski.py | 359 ++++++ .../qre/models/factories/_round_based.py | 409 +++++++ .../pip/qsharp/qre/models/factories/_utils.py | 89 ++ source/pip/qsharp/qre/models/qec/__init__.py | 4 +- .../qsharp/qre/models/qec/_surface_code.py | 45 +- .../pip/qsharp/qre/models/qec/_three_aux.py | 118 ++ source/pip/qsharp/qre/models/qec/_yoked.py | 158 +++ .../pip/qsharp/qre/models/qubits/__init__.py | 3 +- source/pip/qsharp/qre/models/qubits/_aqre.py | 19 +- source/pip/qsharp/qre/models/qubits/_msft.py | 97 ++ source/pip/src/qre.rs | 137 ++- source/pip/test_requirements.txt | 1 + source/pip/tests/test_qre.py | 89 +- source/pip/tests/test_qre_models.py | 1019 +++++++++++++++++ source/qre/src/pareto.rs | 2 +- source/qre/src/trace.rs | 11 +- 25 files changed, 2651 insertions(+), 75 deletions(-) create mode 100644 
source/pip/qsharp/qre/models/factories/__init__.py create mode 100644 source/pip/qsharp/qre/models/factories/_litinski.py create mode 100644 source/pip/qsharp/qre/models/factories/_round_based.py create mode 100644 source/pip/qsharp/qre/models/factories/_utils.py create mode 100644 source/pip/qsharp/qre/models/qec/_three_aux.py create mode 100644 source/pip/qsharp/qre/models/qec/_yoked.py create mode 100644 source/pip/qsharp/qre/models/qubits/_msft.py create mode 100644 source/pip/tests/test_qre_models.py diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py index d41b279d0c..07d4b81466 100644 --- a/source/pip/qsharp/qre/_enumeration.py +++ b/source/pip/qsharp/qre/_enumeration.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import types from typing import ( Generator, Type, TypeVar, Literal, + Union, get_args, get_origin, get_type_hints, @@ -108,6 +110,23 @@ class MyConfig: values.append(list(get_args(current_type))) continue + # Union types (e.g., OptionA | OptionB or Union[OptionA, OptionB]) + if get_origin(current_type) is Union or isinstance( + current_type, types.UnionType + ): + union_domain = [] + for member_type in get_args(current_type): + union_domain.extend(_enumerate_instances(member_type)) + values.append(union_domain) + continue + + # Nested dataclass types + if isinstance(current_type, type) and hasattr( + current_type, "__dataclass_fields__" + ): + values.append(list(_enumerate_instances(current_type))) + continue + if field.default is not MISSING: values.append([field.default]) continue diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index c542b4c597..1d1fd170c3 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -95,7 +95,7 @@ def as_frame(self): [ { "qubits": entry.qubits, - "runtime": entry.runtime, + "runtime": pd.Timedelta(entry.runtime, unit="ns"), "error": entry.error, } for 
entry in self diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 4cac0a4894..c3301a448d 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -671,6 +671,22 @@ class EstimationResult: Represents the result of a resource estimation. """ + def __new__( + cls, *, qubits: int = 0, runtime: int = 0, error: float = 0.0 + ) -> EstimationResult: + """ + Creates a new estimation result. + + Args: + qubits (int): The number of logical qubits. + runtime (int): The runtime in nanoseconds. + error (float): The error probability of the computation. + + Returns: + EstimationResult: The estimation result. + """ + ... + @property def qubits(self) -> int: """ @@ -681,6 +697,15 @@ class EstimationResult: """ ... + def add_qubits(self, qubits: int) -> None: + """ + Adds to the number of logical qubits. + + Args: + qubits (int): The number of logical qubits to add. + """ + ... + @property def runtime(self) -> int: """ @@ -691,6 +716,15 @@ class EstimationResult: """ ... + def add_runtime(self, runtime: int) -> None: + """ + Adds to the runtime. + + Args: + runtime (int): The amount of runtime in nanoseconds to add. + """ + ... + @property def error(self) -> float: """ @@ -701,6 +735,15 @@ class EstimationResult: """ ... + def add_error(self, error: float) -> None: + """ + Adds to the error probability. + + Args: + error (float): The amount to add to the error probability. + """ + ... + @property def factories(self) -> dict[int, FactoryResult]: """ @@ -857,6 +900,28 @@ class Trace: """ ... + @classmethod + def from_json(cls, json: str) -> Trace: + """ + Creates a trace from a JSON string. + + Args: + json (str): The JSON string. + + Returns: + Trace: The trace. + """ + ... + + def to_json(self) -> str: + """ + Serializes the trace to a JSON string. + + Returns: + str: The JSON string representation of the trace. + """ + ... 
+ @property def compute_qubits(self) -> int: """ @@ -1098,9 +1163,14 @@ class InstructionFrontier: rates as objectives. """ - def __new__(cls) -> InstructionFrontier: + def __new__(cls, *, with_error_objective: bool = True) -> InstructionFrontier: """ Creates a new instruction frontier. + + Args: + with_error_objective (bool): If True (default), the frontier uses + three objectives (space, time, error rate). If False, it uses + two objectives (space, time). """ ... @@ -1141,12 +1211,17 @@ class InstructionFrontier: ... @staticmethod - def load(filename: str) -> InstructionFrontier: + def load( + filename: str, *, with_error_objective: bool = True + ) -> InstructionFrontier: """ Loads an instruction frontier from a file. Args: filename (str): The file name. + with_error_objective (bool): If True (default), the frontier uses + three objectives (space, time, error rate). If False, it uses + two objectives (space, time). Returns: InstructionFrontier: The loaded instruction frontier. diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py index d57b30db76..efe7985390 100644 --- a/source/pip/qsharp/qre/_trace.py +++ b/source/pip/qsharp/qre/_trace.py @@ -24,7 +24,7 @@ def q(cls, **kwargs) -> TraceQuery: class PSSPC(TraceTransform): _: KW_ONLY num_ts_per_rotation: int = field( - default=20, metadata={"domain": list(range(1, 21))} + default=20, metadata={"domain": list(range(5, 21))} ) ccx_magic_states: bool = field(default=False) diff --git a/source/pip/qsharp/qre/application/_qsharp.py b/source/pip/qsharp/qre/application/_qsharp.py index b01a8d329c..0ed642d826 100644 --- a/source/pip/qsharp/qre/application/_qsharp.py +++ b/source/pip/qsharp/qre/application/_qsharp.py @@ -4,19 +4,31 @@ from __future__ import annotations -from dataclasses import dataclass +from pathlib import Path +from dataclasses import dataclass, field from typing import Callable from ...estimator import LogicalCounts from .._qre import Trace from .._application import Application 
-from ..interop import trace_from_entry_expr +from ..interop import trace_from_entry_expr_cached @dataclass class QSharpApplication(Application[None]): + cache_dir: Path = field( + default=Path.home() / ".cache" / "re3" / "qsharp", repr=False + ) + use_cache: bool = field(default=False, repr=False) + def __init__(self, entry_expr: str | Callable | LogicalCounts): self._entry_expr = entry_expr def get_trace(self, parameters: None = None) -> Trace: - return trace_from_entry_expr(self._entry_expr) + # TODO: make caching work for `Callable` as well + if self.use_cache and isinstance(self._entry_expr, str): + cache_path = self.cache_dir / f"{self._entry_expr}.json" + else: + cache_path = None + + return trace_from_entry_expr_cached(self._entry_expr, cache_path) diff --git a/source/pip/qsharp/qre/interop/__init__.py b/source/pip/qsharp/qre/interop/__init__.py index 01a5234a5b..3b6c04c922 100644 --- a/source/pip/qsharp/qre/interop/__init__.py +++ b/source/pip/qsharp/qre/interop/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from ._qsharp import trace_from_entry_expr +from ._qsharp import trace_from_entry_expr, trace_from_entry_expr_cached -__all__ = ["trace_from_entry_expr"] +__all__ = ["trace_from_entry_expr", "trace_from_entry_expr_cached"] diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index 1c7d041430..d2f534fa4d 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -3,8 +3,9 @@ from __future__ import annotations +from pathlib import Path import time -from typing import Callable +from typing import Callable, Optional from ..._qsharp import logical_counts from ...estimator import LogicalCounts @@ -76,3 +77,18 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: trace.set_property("evaluation_time", evaluation_time) return trace + + +def trace_from_entry_expr_cached( + entry_expr: str | Callable | LogicalCounts, cache_path: Optional[Path] +) -> Trace: + if cache_path and cache_path.exists(): + return Trace.from_json(cache_path.read_text()) + + trace = trace_from_entry_expr(entry_expr) + + if cache_path: + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_text(trace.to_json()) + + return trace diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py index 10a82c977e..5f22ddb1a6 100644 --- a/source/pip/qsharp/qre/models/__init__.py +++ b/source/pip/qsharp/qre/models/__init__.py @@ -1,7 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from .qec import SurfaceCode -from .qubits import AQREGateBased +from .factories import Litinski19Factory, MagicUpToClifford, RoundBasedFactory +from .qec import SurfaceCode, ThreeAux, YokedSurfaceCode +from .qubits import AQREGateBased, Majorana -__all__ = ["SurfaceCode", "AQREGateBased"] +__all__ = [ + "AQREGateBased", + "Litinski19Factory", + "Majorana", + "MagicUpToClifford", + "RoundBasedFactory", + "SurfaceCode", + "ThreeAux", + "YokedSurfaceCode", +] diff --git a/source/pip/qsharp/qre/models/factories/__init__.py b/source/pip/qsharp/qre/models/factories/__init__.py new file mode 100644 index 0000000000..e652dfc983 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._litinski import Litinski19Factory +from ._round_based import RoundBasedFactory +from ._utils import MagicUpToClifford + +__all__ = ["Litinski19Factory", "MagicUpToClifford", "RoundBasedFactory"] diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py new file mode 100644 index 0000000000..3ce98c31f7 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -0,0 +1,359 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass +from math import ceil +from typing import Generator + +from ..._architecture import _Context +from ..._qre import ISA, ISARequirements, ConstraintBound, _Instruction +from ..._instruction import ISATransform, constraint, instruction, LOGICAL +from ...instruction_ids import T, CNOT, H, MEAS_Z, CCZ + + +@dataclass +class Litinski19Factory(ISATransform): + """ + T and CCZ factories based on the paper + [arXiv:1905.06903](https://arxiv.org/abs/1905.06903). + + It contains two categories of estimates. 
If the input T error rate is + similar to the Clifford error, it produces magic state instructions based on + Table 1 in the paper. If the input T error rate is at most 10 times higher + than the Clifford error rate, it produces magic state instructions based on + Table 2 in the paper. + + It requires Clifford error rates of at most 0.1% for CNOT, H, and MEAS_Z + instructions. If these instructions have different error rates, the maximum + error rate is assumed. + + References: + + - Daniel Litinski: Magic state distillation: not as costly as you think, + [arXiv:1905.06903](https://arxiv.org/abs/1905.06903) + """ + + def __post_init__(self): + self._initialize_entries() + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + # T error rate may be at least 10x higher than Clifford error rates + constraint(T, error_rate=ConstraintBound.le(1e-2)), + constraint(H, error_rate=ConstraintBound.le(1e-3)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.le(1e-3)), + constraint(MEAS_Z, error_rate=ConstraintBound.le(1e-3)), + ) + + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + h = impl_isa[H] + cnot = impl_isa[CNOT] + meas_z = impl_isa[MEAS_Z] + t = impl_isa[T] + + clifford_error_rate = max( + h.expect_error_rate(), + cnot.expect_error_rate(), + meas_z.expect_error_rate(), + ) + + t_error_rate = t.expect_error_rate() + + entries_by_state = None + + if clifford_error_rate <= 1e-4: + if t_error_rate <= 1e-4: + entries_by_state = self._entries[1e-4][0] + elif t_error_rate <= 1e-3: + entries_by_state = self._entries[1e-4][1] + else: + # NOTE: This assertion is valid due to the constraint bound in the + # required_isa method + assert clifford_error_rate <= 1e-3 + if t_error_rate <= 1e-3: + entries_by_state = self._entries[1e-3][0] + elif t_error_rate <= 1e-2: + entries_by_state = self._entries[1e-3][1] + + if entries_by_state is None: + return + + t_entries = entries_by_state.get(T, []) + ccz_entries = 
entries_by_state.get(CCZ, []) + + syndrome_extraction_time = ( + 4 * impl_isa[CNOT].expect_time() + + impl_isa[H].expect_time() + + impl_isa[MEAS_Z].expect_time() + ) + + def make_instruction(entry: _Entry) -> _Instruction: + # Convert cycles (number of syndrome extraction cycles) to time + # based on fast surface code + time = ceil(syndrome_extraction_time * entry.cycles) + + # NOTE: If the protocol outputs multiple states, we assume that the + # space cost is divided by the number of output states. This is a + # simplification that allows us to fit all protocols in the ISA, but + # it may not be accurate for all protocols. + inst = instruction( + entry.state, + arity=3 if entry.state == CCZ else 1, + encoding=LOGICAL, + space=ceil(entry.space / entry.output_states), + time=time, + error_rate=entry.error_rate, + ) + return ctx.set_source(self, inst, [cnot, h, meas_z, t]) + + # Yield combinations of T and CCZ entries + if ccz_entries: + for t_entry in t_entries: + for ccz_entry in ccz_entries: + yield ISA( + make_instruction(t_entry), + make_instruction(ccz_entry), + ) + else: + # Table 2 scenarios: only T gates available + for t_entry in t_entries: + yield ISA(make_instruction(t_entry)) + + def _initialize_entries(self): + self._entries = { + # Assuming a Clifford error rate of at most 1e-4: + 1e-4: ( + # Assuming a T error rate of at most 1e-4 (Table 1): + { + T: [ + _Entry(_Protocol(15, 1, 7, 3, 3), 4.4e-8, 810, 18.1), + _Entry(_Protocol(15, 1, 9, 3, 3), 9.3e-10, 1_150, 18.1), + _Entry(_Protocol(15, 1, 11, 5, 5), 1.9e-11, 2_070, 30.0), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(20, 4, 15, 7, 9), 1), + ], + 2.4e-15, + 16_400, + 90.3, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(15, 1, 25, 9, 9), 1), + ], + 6.3e-25, + 18_600, + 67.8, + ), + _Entry(_Protocol(15, 1, 9, 3, 3), 1.5e-9, 762, 36.2), + ], + CCZ: [ + _Entry( + [ + (_Protocol(15, 1, 7, 3, 3), 4), + (_Protocol(8, 1, 15, 7, 9, CCZ), 1), + ], + 7.2e-14, + 12_400, + 36.1, 
+ ), + ], + }, + # Assuming a T error rate of at most 1e-3 (10x higher than Clifford, Table 2): + { + T: [ + _Entry(_Protocol(15, 1, 9, 3, 3), 2.1e-8, 1_150, 18.2), + _Entry( + [ + (_Protocol(15, 1, 7, 3, 3), 6), + (_Protocol(20, 4, 13, 5, 7), 1), + ], + 1.4e-12, + 13_200, + 70, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(20, 4, 15, 7, 9), 1), + ], + 6.6e-15, + 16_400, + 91.2, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(15, 1, 25, 9, 9), 1), + ], + 4.2e-22, + 18_600, + 68.4, + ), + ], + CCZ: [], + }, + ), + # Assuming a Clifford error rate of at most 1e-3: + 1e-3: ( + # Assuming a T error rate of at most 1e-3 (Table 1): + { + T: [ + _Entry(_Protocol(15, 1, 17, 7, 7), 4.5e-8, 4_620, 42.6), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(20, 4, 23, 11, 13), 1), + ], + 1.4e-10, + 43_300, + 130, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 4), + (_Protocol(20, 4, 27, 13, 15), 1), + ], + 2.6e-11, + 46_800, + 157, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 25, 11, 11), 1), + ], + 2.7e-12, + 30_700, + 82.5, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(15, 1, 29, 11, 13), 1), + ], + 3.3e-14, + 39_100, + 97.5, + ), + _Entry( + [ + (_Protocol(15, 1, 15, 7, 7), 6), + (_Protocol(15, 1, 41, 17, 17), 1), + ], + 4.5e-20, + 73_400, + 128, + ), + ], + CCZ: [ + _Entry( + [ + (_Protocol(15, 1, 13, 7, 7), 6), + (_Protocol(8, 1, 25, 15, 15, CCZ), 1), + ], + 5.2e-11, + 47_000, + 60, + ), + ], + }, + # Assuming a T error rate of at most 1e-2 (10x higher than Clifford, Table 2): + { + T: [ + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(20, 4, 21, 11, 13), 1), + ], + 5.7e-9, + 40_700, + 130, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 21, 9, 11), 1), + ], + 2.1e-10, + 27_400, + 85.7, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 23, 11, 11), 1), + ], + 2.5e-11, + 29_500, + 85.7, + ), + _Entry( + [ + 
(_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 25, 11, 11), 1), + ], + 6.4e-12, + 30_700, + 85.7, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 7, 7), 8), + (_Protocol(15, 1, 29, 13, 13), 1), + ], + 1.5e-13, + 52_400, + 97.5, + ), + ], + CCZ: [], + }, + ), + } + + +@dataclass(frozen=True, slots=True) +class _Entry: + protocol: list[tuple[_Protocol, int]] | _Protocol + error_rate: float + # Space estimation in number of physical qubits + space: int + # Number of code cycles to estimate time; a code cycle corresponds to + # measuring all surface-code check operators exactly once. + cycles: float + + @property + def output_states(self) -> int: + if isinstance(self.protocol, list): + return self.protocol[-1][0].output_states + else: + return self.protocol.output_states + + @property + def state(self) -> int: + if isinstance(self.protocol, list): + return self.protocol[-1][0].state + else: + return self.protocol.state + + +@dataclass(frozen=True, slots=True) +class _Protocol: + # Number of input T states in protocol + input_states: int + # Number of output T states in protocol + output_states: int + # Spatial X distance (arXiv:1905.06903, Section 2) + d_x: int + # Spatial Z distance (arXiv:1905.06903, Section 2) + d_z: int + # Temporal distance (arXiv:1905.06903, Section 2) + d_m: int + # Magic state + state: int = T diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py new file mode 100644 index 0000000000..bdde7cfcf9 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -0,0 +1,409 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from itertools import combinations_with_replacement +from math import ceil +from pathlib import Path +from typing import Callable, Generator, Iterable, Optional, Sequence + +from ..._qre import ISA, InstructionFrontier, ISARequirements, _Instruction, _binom_ppf +from ..._instruction import ( + LOGICAL, + PHYSICAL, + ISAQuery, + ISATransform, + constraint, + instruction, +) +from ..._architecture import _Context +from ...instruction_ids import CNOT, LATTICE_SURGERY, T, MEAS_ZZ +from ..qec import SurfaceCode + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RoundBasedFactory(ISATransform): + """ + A magic state factory that produces T gate instructions using round-based + distillation pipelines. + + This factory explores combinations of distillation units (such as "15-to-1 + RM prep" and "15-to-1 space efficient") to find optimal configurations that + minimize time and space while achieving target error rates. It supports + both physical-level distillation (when the input T gate is physically + encoded) and logical-level distillation (using lattice surgery via surface + codes). + + In order to account for the success probability of distillation rounds, the + factory models the pipeline using a failure probability requirement + (defaulting to 1%) that each round must meet. The number of distillation + units per round is adjusted to meet this requirement, which in turn affects + the overall space requirements. + + Space requirements are calculated using a user-provided function that + aggregates per-round space (e.g., sum or max). The `sum` function models + the case in which qubits are not reused across rounds, while the `max` + function models the case in which qubits are reused across rounds. 
+ + For the enumeration of logical-level distillation units, the factory relies + on a user-provided `ISAQuery` (defaulting to `SurfaceCode.q()`) to explore + different surface code configurations and their corresponding lattice + surgery instructions. These need to be provided by the user and cannot + automatically be derived from the provided implementation ISA, as they can + only contain a subset of the required instructions. The user needs to + ensure that the provided query matches the architecture for which this + factory is being used. + + Results are cached to disk for efficiency. + + Attributes: + code_query: ISAQuery + Query to enumerate QEC codes for logical distillation units. + Defaults to SurfaceCode.q(). + physical_qubit_calculation: Callable[[Iterable], int] + Function to calculate total physical qubits from per-round space + requirements, e.g., sum or max. Defaults to sum. + cache_dir: Path + Directory for caching computed factory configurations. Defaults to + ~/.cache/re3/round_based. + use_cache: bool + Whether to use cached results. Defaults to True. + + References: + + - Sergei Bravyi, Alexei Kitaev: Universal Quantum Computation with ideal + Clifford gates and noisy ancillas, + [arXiv:quant-ph/0403025](https://arxiv.org/abs/quant-ph/0403025) + - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. 
Svore,
+      Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi
+      Sundaram, Alexander Vaschillo: Assessing requirements to scale to
+      practical quantum advantage,
+      [arXiv:2211.07629](https://arxiv.org/abs/2211.07629)
+    """
+
+    code_query: ISAQuery = field(default_factory=lambda: SurfaceCode.q())
+    physical_qubit_calculation: Callable[[Iterable], int] = field(default=sum)
+    # optional: make cache directory configurable
+    cache_dir: Path = field(
+        default=Path.home() / ".cache" / "re3" / "round_based", repr=False
+    )
+    use_cache: bool = field(default=True, repr=False)
+
+    @staticmethod
+    def required_isa() -> ISARequirements:
+        # NOTE: A T gate is required, but a CNOT is only required to explore
+        # physical units.
+        return ISARequirements(
+            constraint(T),
+        )
+
+    def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]:
+        cache_path = self._cache_path(impl_isa)
+
+        # 1) Try to load from cache
+        if self.use_cache and cache_path.exists():
+            cached_states = InstructionFrontier.load(str(cache_path))
+            for state in cached_states:
+                yield ISA(state)
+            return
+
+        # 2) Compute as before
+        t_gate_error = impl_isa[T].expect_error_rate()
+
+        units: list[_DistillationUnit] = []
+        initial_unit = []
+
+        # Physical units?
+ if impl_isa[T].encoding == PHYSICAL: + clifford_gate = impl_isa.get(CNOT) or impl_isa.get(MEAS_ZZ) + if clifford_gate is None: + raise ValueError( + "CNOT or MEAS_ZZ instruction is required for physical units" + ) + + gate_time = clifford_gate.expect_time() + clifford_error = clifford_gate.expect_error_rate() + units.extend(self._physical_units(gate_time, clifford_error)) + else: + initial_unit.append( + _DistillationUnit( + 1, + impl_isa[T].expect_time(), + impl_isa[T].expect_space(), + [1, 0], + [0], + ) + ) + + for code_isa in self.code_query.enumerate(ctx): + units.extend(self._logical_units(code_isa[LATTICE_SURGERY])) + + optimal_states = InstructionFrontier() + + for r in range(1, 4 - len(initial_unit)): + for k in combinations_with_replacement(units, r): + pipeline = _Pipeline.try_create( + initial_unit + list(k), + t_gate_error, + physical_qubit_calculation=self.physical_qubit_calculation, + ) + if pipeline is not None: + state = self._state_from_pipeline(pipeline) + optimal_states.insert(state) + logger.debug(f"Optimal states after {r} rounds: {len(optimal_states)}") + + # 3) Save to cache, then yield + if self.use_cache: + optimal_states.dump(str(cache_path)) + + for state in optimal_states: + yield ISA(ctx.set_source(self, state, [impl_isa[T]])) + + def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: + return [ + _DistillationUnit( + num_input_states=15, + time=24 * gate_time, + space=31, + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * clifford_error], + failure_probability_coeffs=[15, 356 * clifford_error], + name="15-to-1 RM prep", + ), + _DistillationUnit( + num_input_states=15, + time=45 * gate_time, + space=12, + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * clifford_error], + failure_probability_coeffs=[15, 356 * clifford_error], + name="15-to-1 space efficient", + ), + ] + + def _logical_units( + self, lattice_surgery_instruction: _Instruction + ) -> list[_DistillationUnit]: + logical_cycle_time = 
lattice_surgery_instruction.expect_time(1) + logical_error = lattice_surgery_instruction.expect_error_rate(1) + + return [ + _DistillationUnit( + num_input_states=15, + time=11 * logical_cycle_time, + space=lattice_surgery_instruction.expect_space(31), + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * logical_error], + failure_probability_coeffs=[15, 356 * logical_error], + name="15-to-1 RM prep", + ), + _DistillationUnit( + num_input_states=15, + time=13 * logical_cycle_time, + space=lattice_surgery_instruction.expect_space(20), + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * logical_error], + failure_probability_coeffs=[15, 356 * logical_error], + name="15-to-1 space efficient", + ), + ] + + def _state_from_pipeline(self, pipeline: _Pipeline) -> _Instruction: + return instruction( + T, + encoding=LOGICAL, + time=pipeline.time, + error_rate=pipeline.error_rate, + space=pipeline.space, + ) + + def _cache_key(self, impl_isa: ISA) -> str: + """Build a deterministic key from factory configuration and impl_isa.""" + # You can refine this if ISA has a better serialization method. 
+ payload = { + "factory": type(self).__qualname__, + "code_query": getattr( + self.code_query, "__qualname__", repr(self.code_query) + ), + "impl_isa": str(impl_isa), + } + data = repr(payload).encode("utf-8") + return hashlib.sha256(data).hexdigest() + + def _cache_path(self, impl_isa: ISA) -> Path: + self.cache_dir.mkdir(parents=True, exist_ok=True) + return self.cache_dir / f"{self._cache_key(impl_isa)}.json" + + +class _Pipeline: + def __init__( + self, + units: Sequence[_DistillationUnit], + initial_input_error_rate: float, + *, + failure_probability_requirement: float = 0.01, + physical_qubit_calculation: Callable[[Iterable], int] = sum, + ): + self.failure_probability_requirement = failure_probability_requirement + self.rounds: list["_DistillationRound"] = [] + self.output_error_rate: float = initial_input_error_rate + self.physical_qubit_calculation = physical_qubit_calculation + + self._add_rounds(units) + + @classmethod + def try_create( + cls, + units: Sequence[_DistillationUnit], + initial_input_error_rate: float, + *, + failure_probability_requirement: float = 0.01, + physical_qubit_calculation: Callable[[Iterable], int] = sum, + ) -> Optional[_Pipeline]: + pipeline = cls( + units, + initial_input_error_rate, + failure_probability_requirement=failure_probability_requirement, + physical_qubit_calculation=physical_qubit_calculation, + ) + if not pipeline._compute_units_per_round(): + return None + return pipeline + + def _compute_units_per_round(self) -> bool: + if len(self.rounds) > 0: + states_needed_next = self.rounds[-1].unit.num_output_states + + for dist_round in reversed(self.rounds): + if not dist_round.adjust_num_units_to(states_needed_next): + return False + states_needed_next = dist_round.num_input_states + + return True + + def _add_rounds(self, units: Sequence[_DistillationUnit]): + per_round_failure_prob_req = self.failure_probability_requirement / len(units) + + for unit in units: + self.rounds.append( + _DistillationRound( + unit, + 
per_round_failure_prob_req, + self.output_error_rate, + ) + ) + # TODO: handle case when output_error_rate is larger than input_error_rate + self.output_error_rate = unit.error_rate(self.output_error_rate) + + @property + def space(self) -> int: + return self.physical_qubit_calculation(round.space for round in self.rounds) + + @property + def time(self) -> int: + return sum(round.unit.time for round in self.rounds) + + @property + def error_rate(self) -> float: + return self.output_error_rate + + @property + def num_output_states(self) -> int: + return self.rounds[-1].compute_num_output_states() + + +@dataclass(slots=True) +class _DistillationUnit: + num_input_states: int + time: int + space: int + error_rate_coeffs: Sequence[float] + failure_probability_coeffs: Sequence[float] + name: Optional[str] = None + num_output_states: int = 1 + + def error_rate(self, input_error_rate: float) -> float: + result = 0.0 + for c in self.error_rate_coeffs: + result = result * input_error_rate + c + return result + + def failure_probability(self, input_error_rate: float) -> float: + result = 0.0 + for c in self.failure_probability_coeffs: + result = result * input_error_rate + c + return result + + +@dataclass(slots=True) +class _DistillationRound: + unit: _DistillationUnit + failure_probability_requirement: float + input_error_rate: float + num_units: int = 1 + failure_probability: float = field(init=False) + + def __post_init__(self): + self.failure_probability = self.unit.failure_probability(self.input_error_rate) + + def adjust_num_units_to(self, output_states_needed_next: int) -> bool: + if self.failure_probability == 0.0: + self.num_units = output_states_needed_next + return True + + # Binary search to find the minimal number of units needed + self.num_units = ceil(output_states_needed_next / self.max_num_output_states) + + while True: + num_output_states = self.compute_num_output_states() + if num_output_states < output_states_needed_next: + self.num_units *= 2 + + # 
Distillation round requires unreasonably high number of units + if self.num_units >= 1_000_000_000_000_000: + return False + else: + break + + upper = self.num_units + lower = self.num_units // 2 + while lower < upper: + self.num_units = (lower + upper) // 2 + num_output_states = self.compute_num_output_states() + if num_output_states >= output_states_needed_next: + upper = self.num_units + else: + lower = self.num_units + 1 + self.num_units = upper + + return True + + @property + def space(self) -> int: + return self.num_units * self.unit.space + + @property + def num_input_states(self) -> int: + return self.num_units * self.unit.num_input_states + + @property + def max_num_output_states(self) -> int: + return self.num_units * self.unit.num_output_states + + def compute_num_output_states(self) -> int: + failure_prob = self.failure_probability + + if failure_prob <= 1e-8: + return self.num_units * self.unit.num_output_states + + # A replacement for SciPy's binom.ppf that is faster + k = _binom_ppf( + self.failure_probability_requirement, + self.num_units, + 1.0 - failure_prob, + ) + + return int(k) * self.unit.num_output_states diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py new file mode 100644 index 0000000000..c52b3583e4 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Generator + +from ..._architecture import _Context +from ..._qre import ISARequirements, ISA +from ..._instruction import ISATransform +from ...instruction_ids import ( + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + CCX, + CCY, + CCZ, +) + + +class MagicUpToClifford(ISATransform): + """ + An ISA transform that adds Clifford equivalent representations of magic + states. 
For example, if the input ISA contains a T gate, the provided ISA + will also contain `SQRT_SQRT_X`, `SQRT_SQRT_X_DAG`, `SQRT_SQRT_Y`, + `SQRT_SQRT_Y_DAG`, and `T_DAG`. The same is applied for `CCZ` gates and + their Clifford equivalents. + + Example: + + .. code-block:: python + app = SomeApplication() + arch = SomeArchitecture() + + # This will contain CCX states + trace_query = PSSPC.q(ccx_magic_states=True) * LatticeSurgery.q() + + # This will contain CCZ states + isa_query = SurfaceCode.q() * Litinski19Factory.q() + + # There will be no results from the estimation because there is no + # instruction to support CCX magic states in the query + results = estimate(app, arch, isa_query, trace_query) + assert len(results) == 0 + + # We solve this by wrapping the Litinski19Factory with the + # MagicUpToClifford transform, which transforms the CCZ states in the + # provided ISA into CCX states. + isa_query = SurfaceCode.q() * MagicUpToClifford.q(source=Litinski19Factory.q()) + + # Now we will get results + results = estimate(app, arch, isa_query, trace_query) + assert len(results) != 0 + """ + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements() + + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + # Families of equivalent gates under Clifford conjugation. + families = [ + [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + ], + [CCX, CCY, CCZ], + ] + + # For each family, if any member of the family is present in the input ISA, add all members of the family to the provided ISA. 
+ for family in families: + for id in family: + if id in impl_isa: + instr = impl_isa[id] + for equivalent_id in family: + if equivalent_id != id: + impl_isa.append( + ctx.set_source( + self, instr.with_id(equivalent_id), [instr] + ) + ) + break # Check next family + + yield impl_isa diff --git a/source/pip/qsharp/qre/models/qec/__init__.py b/source/pip/qsharp/qre/models/qec/__init__.py index c813df0dc4..588544fb3a 100644 --- a/source/pip/qsharp/qre/models/qec/__init__.py +++ b/source/pip/qsharp/qre/models/qec/__init__.py @@ -2,5 +2,7 @@ # Licensed under the MIT License. from ._surface_code import SurfaceCode +from ._three_aux import ThreeAux +from ._yoked import YokedSurfaceCode -__all__ = ["SurfaceCode"] +__all__ = ["SurfaceCode", "ThreeAux", "YokedSurfaceCode"] diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index d619b07a7c..6758d5796d 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -15,27 +15,38 @@ ) from ..._isa_enumeration import _Context from ..._qre import linear_function -from ...instruction_ids import CNOT, GENERIC, H, LATTICE_SURGERY, MEAS_Z +from ...instruction_ids import CNOT, H, LATTICE_SURGERY, MEAS_Z @dataclass class SurfaceCode(ISATransform): """ + This class models the gate-based rotated surface code. + Attributes: crossing_prefactor: float The prefactor for logical error rate due to error correction - crossings. (Default is 0.03, see Eq. (11) in arXiv:1208.0928) + crossings. (Default is 0.03, see Eq. (11) in + [arXiv:1208.0928](https://arxiv.org/abs/1208.0928)) error_correction_threshold: float - The error correction threshold for the surface code. Default is - 0.01 (1%), see arXiv:1009.3686. + The error correction threshold for the surface code. (Default is + 0.01 (1%), see [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) Hyper parameters: distance: int The code distance of the surface code. 
References: - - [arXiv:1208.0928](https://arxiv.org/abs/1208.0928) - - [arXiv:1009.3686](https://arxiv.org/abs/1009.3686) + + - Dominic Horsman, Austin G. Fowler, Simon Devitt, Rodney Van Meter: Surface + code quantum computing by lattice surgery, + [arXiv:1111.4022](https://arxiv.org/abs/1111.4022) + - Austin G. Fowler, Matteo Mariantoni, John M. Martinis, Andrew N. Cleland: + Surface codes: Towards practical large-scale quantum computation, + [arXiv:1208.0928](https://arxiv.org/abs/1208.0928) + - David S. Wang, Austin G. Fowler, Lloyd C. L. Hollenberg: Quantum computing + with nearest neighbor interactions and error rates over 1%, + [arXiv:1009.3686](https://arxiv.org/abs/1009.3686) """ crossing_prefactor: float = 0.03 @@ -66,10 +77,15 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non meas_z.expect_error_rate(), ) - space_formula = linear_function(2 * self.distance**2) + # There are d^2 data qubits and (d^2 - 1) ancilla qubits in the rotated + # surface code. (See Section 7.1 in arXiv:1111.4022) + space_formula = linear_function(2 * self.distance**2 - 1) - time_value = (h_time + meas_time + cnot_time * 4) * self.distance + # Each syndrome extraction cycle consists of ancilla preparation, 4 + # rounds of CNOTs, and measurement. (See Fig. 2 in arXiv:1009.3686) + time_value = (h_time + 4 * cnot_time + meas_time) * self.distance + # See Eqs. 
(10) and (11) in arXiv:1208.0928 error_formula = linear_function( self.crossing_prefactor * ( @@ -78,15 +94,8 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non ) ) - generic = instruction( - GENERIC, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ) - + # We provide a generic lattice surgery instruction (See Section 3 in + # arXiv:1111.4022) lattice_surgery = instruction( LATTICE_SURGERY, encoding=LOGICAL, @@ -94,9 +103,9 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non space=space_formula, time=time_value, error_rate=error_formula, + distance=self.distance, ) yield ISA( - ctx.set_source(self, generic, [cnot, h, meas_z]), ctx.set_source(self, lattice_surgery, [cnot, h, meas_z]), ) diff --git a/source/pip/qsharp/qre/models/qec/_three_aux.py b/source/pip/qsharp/qre/models/qec/_three_aux.py new file mode 100644 index 0000000000..f276061c73 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_three_aux.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator + +from ..._architecture import _Context +from ..._instruction import ( + LOGICAL, + ISATransform, + constraint, + instruction, +) +from ..._qre import ( + ISA, + ISARequirements, + linear_function, +) +from ...instruction_ids import ( + LATTICE_SURGERY, + MEAS_X, + MEAS_XX, + MEAS_Z, + MEAS_ZZ, +) + + +@dataclass +class ThreeAux(ISATransform): + """ + This class models the pairwise measurement-based surface code with three + auxiliary qubits per stabilizer measurement. + + Hyper parameters: + distance: int + The code distance of the surface code. + single_rail: bool + Whether to use single-rail encoding. + + References: + + - Linnea Grans-Samuelsson, Ryan V. 
Mishmash, David Aasen, Christina Knapp, + Bela Bauer, Brad Lackey, Marcus P. da Silva, Parsa Bonderson: Improved + Pairwise Measurement-Based Surface Code, + [arXiv:2310.12981](https://arxiv.org/abs/2310.12981) + """ + + _: KW_ONLY + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) + single_rail: bool = field(default=False) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(MEAS_X), + constraint(MEAS_Z), + constraint(MEAS_XX, arity=2), + constraint(MEAS_ZZ, arity=2), + ) + + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + meas_x = impl_isa[MEAS_X] + meas_z = impl_isa[MEAS_Z] + meas_xx = impl_isa[MEAS_XX] + meas_zz = impl_isa[MEAS_ZZ] + + gate_time = max(meas_xx.expect_time(), meas_zz.expect_time()) + + physical_error_rate = max( + meas_x.expect_error_rate(), + meas_z.expect_error_rate(), + meas_xx.expect_error_rate(), + meas_zz.expect_error_rate(), + ) + + # See arXiv:2310.12981, Table 1 and Figs. 2, 3, 4, 6, and 7 + depth = 5 if self.single_rail else 4 + + # See arXiv:2310.12981, Table 1 + error_correction_threshold = 0.0051 if self.single_rail else 0.0066 + + # See arXiv:2310.12981, Fig. 23 + crossing_prefactor = 0.05 + + # d^2 data qubits and 3 qubits for each of the d^2 - 1 stabilizer + # measurements + space_formula = linear_function(4 * self.distance**2 - 3) + + # The measurement circuits do not overlap perfectly, so there is an + # additional 4 steps that need to be accounted for independent of the + # distance (see Section 2 between Eqs. 
(2) and (3) in arXiv:2310.12981) + time_value = gate_time * (depth * self.distance + 4) + + # Typical fitting curve for surface code logical error (see + # arXiv:1208.0928) + error_formula = linear_function( + crossing_prefactor + * ( + (physical_error_rate / error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + lattice_surgery = instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + distance=self.distance, + ) + + yield ISA( + ctx.set_source(self, lattice_surgery, [meas_x, meas_z, meas_xx, meas_zz]) + ) diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py new file mode 100644 index 0000000000..280ee4bb24 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass, KW_ONLY, field +from enum import IntEnum +from math import ceil +from typing import Generator + +from ..._instruction import ISATransform, constraint, LOGICAL, PropertyKey, instruction +from ..._qre import ISA, ISARequirements, generic_function +from ..._architecture import _Context +from ...instruction_ids import LATTICE_SURGERY, MEMORY + + +class ShapeHeuristic(IntEnum): + """ + The heuristic to determine the shape of the memory qubits with respect to + the number of required rows and columns. + + Attributes: + MIN_AREA: The shape that minimizes the total number of qubits. + SQUARE: The shape that minimizes the difference between the number of rows + and columns. + """ + + MIN_AREA = 0 + SQUARE = 1 + + +@dataclass +class YokedSurfaceCode(ISATransform): + """ + This class models the Yoked surface code to provide a generic memory + instruction based on lattice surgery instructions from a surface code like + error correction code. 
+ + Attributes: + crossing_prefactor: float + The prefactor for logical error rate (Default is 0.016) + error_correction_threshold: float + The error correction threshold for the surface code (Default is + 0.064) + + Hyper parameters: + shape_heuristic: ShapeHeuristic + The heuristic to determine the shape of the surface code patch for a + given number of logical qubits. (Default is ShapeHeuristic.MIN_AREA) + + References: + + - Craig Gidney, Michael Newman, Peter Brooks, Cody Jones: Yoked surface + codes, [arXiv:2312.04522](https://arxiv.org/abs/2312.04522) + """ + + crossing_prefactor: float = 0.016 + error_correction_threshold: float = 0.064 + _: KW_ONLY + shape_heuristic: ShapeHeuristic = field( + default=ShapeHeuristic.MIN_AREA, metadata={"domain": list(ShapeHeuristic)} + ) + + @staticmethod + def required_isa() -> ISARequirements: + # We require a lattice surgery instruction that also provides the code + # distance as a property. This is necessary to compute the time + # and error rate formulas for the provided memory instruction. 
+ return ISARequirements( + constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), + ) + + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + lattice_surgery = impl_isa[LATTICE_SURGERY] + distance = lattice_surgery.get_property(PropertyKey.DISTANCE) + assert distance is not None + + shape_fn = [self._min_area_shape, self._square_shape][self.shape_heuristic] + + def space(arity: int) -> int: + a, b = shape_fn(arity) + return lattice_surgery.expect_space(a * b) + + space_fn = generic_function(space) + + def time(arity: int) -> int: + a, b = shape_fn(arity) + s = lattice_surgery.expect_time(a * b) + return s * (8 * distance * (a - 1) + 2 * distance) + + time_fn = generic_function(time) + + def error_rate(arity: int) -> float: + a, b = shape_fn(arity) + rounds = 2 * (a - 2) + # logical error rate on a single surface code patch + p = lattice_surgery.expect_error_rate(1) + return ( + rounds**2 + * (a * b) ** 2 + * self.crossing_prefactor + * (p / self.error_correction_threshold) ** ((distance + 1) // 2) + ) + + error_rate_fn = generic_function(error_rate) + + yield ISA( + ctx.set_source( + self, + instruction( + MEMORY, + arity=None, + encoding=LOGICAL, + space=space_fn, + time=time_fn, + error_rate=error_rate_fn, + distance=distance, + ), + [lattice_surgery], + ) + ) + + @staticmethod + def _square_shape(num_qubits: int) -> tuple[int, int]: + """ + Given a number of qubits num_qubits, returns numbers (a + 1) and (b + 2) + such that a * b >= num_qubits and a and b are as close as possible. + """ + + a = int(num_qubits**0.5) + while num_qubits % a != 0: + a -= 1 + b = num_qubits // a + return a + 1, b + 2 + + @staticmethod + def _min_area_shape(num_qubits: int) -> tuple[int, int]: + """ + Given a number of qubits num_qubits, returns numbers (a + 1) and (b + 2) + such that a * b >= num_qubits and a * b is as small as possible. 
+ """ + + best_a = None + best_b = None + best_qubits = num_qubits**2 + + for a in range(1, num_qubits): + # Compute required number of columns to reach the required number + # of logical qubits + b = ceil(num_qubits / a) + + qubits = (a + 1) * (b + 2) + if qubits < best_qubits: + best_qubits = qubits + best_a = a + best_b = b + + assert best_a is not None + assert best_b is not None + return best_a + 1, best_b + 2 diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py index f9907adbc3..99c9e1c156 100644 --- a/source/pip/qsharp/qre/models/qubits/__init__.py +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -2,5 +2,6 @@ # Licensed under the MIT License. from ._aqre import AQREGateBased +from ._msft import Majorana -__all__ = ["AQREGateBased"] +__all__ = ["AQREGateBased", "Majorana"] diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_aqre.py index b6add8ae2d..981c6223d7 100644 --- a/source/pip/qsharp/qre/models/qubits/_aqre.py +++ b/source/pip/qsharp/qre/models/qubits/_aqre.py @@ -11,8 +11,25 @@ @dataclass class AQREGateBased(Architecture): """ + A generic gate-based architecture based on the qubit parameters in Azure + Quantum Resource Estimator (AQRE, + [arXiv:2211.07629](https://arxiv.org/abs/2211.07629)). The error rate can + be set arbitrarily and is either 1e-3 or 1e-4 in the reference. Gate times + are set to 50ns and measurement times are set to 100ns, which are typical + for superconducting transmon qubits + [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002). + References: - - [arXiv:2211.07629](https://arxiv.org/abs/2211.07629) + + - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. 
Svore, + Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi + Sundaram, Alexander Vaschillo: Assessing requirements to scale to + practical quantum advantage, + [arXiv:2211.07629](https://arxiv.org/abs/2211.07629) + - Jens Koch, Terri M. Yu, Jay Gambetta, A. A. Houck, D. I. Schuster, J. + Majer, Alexandre Blais, M. H. Devoret, S. M. Girvin, R. J. Schoelkopf: + Charge insensitive qubit design derived from the Cooper pair box, + [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002) """ _: KW_ONLY diff --git a/source/pip/qsharp/qre/models/qubits/_msft.py b/source/pip/qsharp/qre/models/qubits/_msft.py new file mode 100644 index 0000000000..9ce6fcb3c9 --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/_msft.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field + +from ..._architecture import Architecture +from ...instruction_ids import ( + T, + PREP_X, + PREP_Z, + MEAS_XX, + MEAS_ZZ, + MEAS_X, + MEAS_Z, +) +from ..._instruction import ISA, instruction + + +@dataclass +class Majorana(Architecture): + """ + This class models physical instructions that may be relevant for future + Majorana qubits. For these qubits, we assume that measurements + and the physical T gate each take 1 µs. Owing to topological protection in + the hardware, we assume single and two-qubit measurement error rates + (Clifford error rates) in $10^{-4}$, $10^{-5}$, and $10^{-6}$ as a range + between realistic and optimistic targets. Non-Clifford operations in this + architecture do not have topological protection, so we assume a 5%, 1.5%, + and 1% error rate for non-Clifford physical T gates for the three cases, + respectively. + + References: + + - Torsten Karzig, Christina Knapp, Roman M. Lutchyn, Parsa Bonderson, + Matthew B. Hastings, Chetan Nayak, Jason Alicea, Karsten Flensberg, + Stephan Plugge, Yuval Oreg, Charles M. Marcus, Michael H. 
Freedman: + Scalable Designs for Quasiparticle-Poisoning-Protected Topological Quantum + Computation with Majorana Zero Modes, + [arXiv:1610.05289](https://arxiv.org/abs/1610.05289) + - Alexei Kitaev: Unpaired Majorana fermions in quantum wires, + [arXiv:cond-mat/0010440](https://arxiv.org/abs/cond-mat/0010440) + - Sankar Das Sarma, Michael Freedman, Chetan Nayak: Majorana Zero Modes and + Topological Quantum Computation, + [arXiv:1501.02813](https://arxiv.org/abs/1501.02813) + """ + + _: KW_ONLY + error_rate: float = field(default=1e-5, metadata={"domain": [1e-4, 1e-5, 1e-6]}) + + @property + def provided_isa(self) -> ISA: + if abs(self.error_rate - 1e-4) <= 1e-8: + t_error_rate = 0.05 + elif abs(self.error_rate - 1e-5) <= 1e-8: + t_error_rate = 0.015 + elif abs(self.error_rate - 1e-6) <= 1e-8: + t_error_rate = 0.01 + + return ISA( + instruction( + PREP_X, + time=1000, + error_rate=self.error_rate, + ), + instruction( + PREP_Z, + time=1000, + error_rate=self.error_rate, + ), + instruction( + MEAS_XX, + arity=2, + time=1000, + error_rate=self.error_rate, + ), + instruction( + MEAS_ZZ, + arity=2, + time=1000, + error_rate=self.error_rate, + ), + instruction( + MEAS_X, + time=1000, + error_rate=self.error_rate, + ), + instruction( + MEAS_Z, + time=1000, + error_rate=self.error_rate, + ), + instruction( + T, + time=1000, + error_rate=t_error_rate, + ), + ) diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 0f45f76bda..37e9f57ba6 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -7,7 +7,7 @@ use pyo3::{ IntoPyObjectExt, exceptions::{PyException, PyKeyError, PyTypeError}, prelude::*, - types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple}, + types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple, PyType}, }; use qre::TraceTransform; use serde::{Deserialize, Serialize}; @@ -305,6 +305,19 @@ impl Instruction { } } +impl qre::ParetoItem2D for Instruction { + type Objective1 = u64; + type Objective2 = u64; + + fn objective1(&self) -> 
Self::Objective1 { + self.0.expect_space(None) + } + + fn objective2(&self) -> Self::Objective2 { + self.0.expect_time(None) + } +} + impl qre::ParetoItem3D for Instruction { type Objective1 = u64; type Objective2 = u64; @@ -577,21 +590,44 @@ pub struct EstimationResult(qre::EstimationResult); #[pymethods] impl EstimationResult { + #[new] + #[pyo3(signature = (*, qubits = 0, runtime = 0, error = 0.0))] + pub fn new(qubits: u64, runtime: u64, error: f64) -> Self { + let mut result = qre::EstimationResult::new(); + result.add_qubits(qubits); + result.add_runtime(runtime); + result.add_error(error); + + EstimationResult(result) + } + #[getter] pub fn qubits(&self) -> u64 { self.0.qubits() } + pub fn add_qubits(&mut self, amount: u64) { + self.0.add_qubits(amount); + } + #[getter] pub fn runtime(&self) -> u64 { self.0.runtime() } + pub fn add_runtime(&mut self, amount: u64) { + self.0.add_runtime(amount); + } + #[getter] pub fn error(&self) -> f64 { self.0.error() } + pub fn add_error(&mut self, amount: f64) { + self.0.add_error(amount); + } + #[allow(clippy::needless_pass_by_value)] #[getter] pub fn factories(self_: PyRef<'_, Self>) -> PyResult> { @@ -672,6 +708,21 @@ impl Trace { Trace(self.0.clone_empty(compute_qubits)) } + #[classmethod] + pub fn from_json(_cls: &Bound<'_, PyType>, json: &str) -> PyResult { + let trace: qre::Trace = serde_json::from_str(json).map_err(|e| { + EstimationError::new_err(format!("Failed to parse trace from JSON: {e}")) + })?; + + Ok(Self(trace)) + } + + pub fn to_json(&self) -> PyResult { + serde_json::to_string(&self.0).map_err(|e| { + EstimationError::new_err(format!("Failed to serialize trace to JSON: {e}")) + }) + } + #[getter] pub fn compute_qubits(&self) -> u64 { self.0.compute_qubits() @@ -863,55 +914,103 @@ impl LatticeSurgery { } } +/// Dispatches a method call to the inner frontier variant, avoiding +/// repetitive match arms. 
Use as: +/// +/// ```ignore +/// dispatch_frontier!(self, f => f.len()) +/// dispatch_frontier!(mut self, f => f.insert(point.clone())) +/// ``` +macro_rules! dispatch_frontier { + ($self:ident, $f:ident => $body:expr) => { + match &$self.0 { + InstructionFrontierInner::Frontier2D($f) => $body, + InstructionFrontierInner::Frontier3D($f) => $body, + } + }; + (mut $self:ident, $f:ident => $body:expr) => { + match &mut $self.0 { + InstructionFrontierInner::Frontier2D($f) => $body, + InstructionFrontierInner::Frontier3D($f) => $body, + } + }; +} + +#[derive(Clone)] +enum InstructionFrontierInner { + Frontier2D(qre::ParetoFrontier2D), + Frontier3D(qre::ParetoFrontier3D), +} + #[pyclass] -pub struct InstructionFrontier(qre::ParetoFrontier3D); +pub struct InstructionFrontier(InstructionFrontierInner); impl Default for InstructionFrontier { fn default() -> Self { - InstructionFrontier(qre::ParetoFrontier3D::new()) + Self(InstructionFrontierInner::Frontier3D( + qre::ParetoFrontier3D::new(), + )) } } #[pymethods] impl InstructionFrontier { #[new] - pub fn new() -> Self { - Self::default() + #[pyo3(signature = (*, with_error_objective = true))] + pub fn new(with_error_objective: bool) -> Self { + if with_error_objective { + Self(InstructionFrontierInner::Frontier3D( + qre::ParetoFrontier3D::new(), + )) + } else { + Self(InstructionFrontierInner::Frontier2D( + qre::ParetoFrontier2D::new(), + )) + } } pub fn insert(&mut self, point: &Instruction) { - self.0.insert(point.clone()); + dispatch_frontier!(mut self, f => f.insert(point.clone())); } #[allow(clippy::needless_pass_by_value)] pub fn extend(&mut self, points: Vec>) { - self.0 - .extend(points.iter().map(|p| Instruction(p.0.clone()))); + dispatch_frontier!(mut self, f => f.extend(points.iter().map(|p| Instruction(p.0.clone())))); } pub fn __len__(&self) -> usize { - self.0.len() + dispatch_frontier!(self, f => f.len()) } #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { - let 
iter = InstructionFrontierIterator { - iter: slf.0.iter().cloned().collect::>().into_iter(), - }; - Py::new(slf.py(), iter) + let items: Vec = dispatch_frontier!(slf, f => f.iter().cloned().collect()); + Py::new( + slf.py(), + InstructionFrontierIterator { + iter: items.into_iter(), + }, + ) } #[staticmethod] - pub fn load(filename: &str) -> PyResult { + #[pyo3(signature = (filename, *, with_error_objective = true))] + pub fn load(filename: &str, with_error_objective: bool) -> PyResult { let content = std::fs::read_to_string(filename)?; - let frontier = - serde_json::from_str(&content).map_err(|e| EstimationError::new_err(format!("{e}")))?; - Ok(InstructionFrontier(frontier)) + let err = |e: serde_json::Error| EstimationError::new_err(format!("{e}")); + + let inner = if with_error_objective { + InstructionFrontierInner::Frontier3D(serde_json::from_str(&content).map_err(err)?) + } else { + InstructionFrontierInner::Frontier2D(serde_json::from_str(&content).map_err(err)?) + }; + Ok(InstructionFrontier(inner)) } pub fn dump(&self, filename: &str) -> PyResult<()> { - let content = - serde_json::to_string(&self.0).map_err(|e| EstimationError::new_err(format!("{e}")))?; + let content = dispatch_frontier!(self, f => + serde_json::to_string(f).map_err(|e| EstimationError::new_err(format!("{e}")))? + ); Ok(std::fs::write(filename, content)?) 
} } diff --git a/source/pip/test_requirements.txt b/source/pip/test_requirements.txt index b12a203d61..3e0b184b99 100644 --- a/source/pip/test_requirements.txt +++ b/source/pip/test_requirements.txt @@ -1,2 +1,3 @@ pytest pyqir<0.12 +cirq==1.6.1 diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 40b3e8ffd4..c741538b75 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -32,13 +32,7 @@ from qsharp.qre._isa_enumeration import ( ISARefNode, ) -from qsharp.qre.instruction_ids import ( - CCX, - CCZ, - GENERIC, - LATTICE_SURGERY, - T, -) +from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T, RZ # NOTE These classes will be generalized as part of the QRE API in the following # pull requests and then moved out of the tests. @@ -69,7 +63,7 @@ class ExampleLogicalFactory(ISATransform): @staticmethod def required_isa() -> ISARequirements: return ISARequirements( - constraint(GENERIC, encoding=LOGICAL), + constraint(LATTICE_SURGERY, encoding=LOGICAL), constraint(T, encoding=LOGICAL), ) @@ -206,9 +200,9 @@ def test_isa_from_architecture(): # Generate logical ISAs isas = list(code.provided_isa(arch.provided_isa, arch.context())) - # There is one ISA with two instructions + # There is one ISA with one instructions assert len(isas) == 1 - assert len(isas[0]) == 2 + assert len(isas[0]) == 1 def test_enumerate_instances(): @@ -312,6 +306,51 @@ class LiteralConfig: assert instances[1].mode == "slow" +def test_enumerate_instances_nested(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + instances = list(_enumerate_instances(OuterConfig)) + assert len(instances) == 2 + assert instances[0].inner.option is True + assert instances[1].inner.option is False + + +def test_enumerate_instances_union(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + 
class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + instances = list(_enumerate_instances(UnionConfig)) + assert len(instances) == 5 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert isinstance(instances[2].option, OptionB) + assert instances[2].option.number == 1 + + def test_enumerate_isas(): ctx = AQREGateBased().context() @@ -385,8 +424,8 @@ def test_binding_node(): # Verify the binding works: with binding, both should use same params for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): logical_gates = [g for g in isa if g.encoding == LOGICAL] - # Should have 2 logical gates (GENERIC and LATTICE_SURGERY) - assert len(logical_gates) == 2 + # Should have 1 logical gate (LATTICE_SURGERY) + assert len(logical_gates) == 1 # Test binding with factories (nested bindings) count_without = sum( @@ -451,7 +490,7 @@ def test_binding_node(): .enumerate(ctx) ): logical_gates = [g for g in isa if g.encoding == LOGICAL] - assert all(g.space(1) == 50 for g in logical_gates) + assert all(g.space(1) == 49 for g in logical_gates) # Test multiple independent bindings (nested) count = sum( @@ -639,7 +678,7 @@ def test_qsharp_application(): result = trace2.estimate(isa, max_error=float("inf")) assert result is not None _assert_estimation_result(trace2, result, isa) - assert counter == 40 + assert counter == 32 def test_trace_enumeration(): @@ -660,12 +699,30 @@ def test_trace_enumeration(): root = RootNode() assert sum(1 for _ in root.enumerate(ctx)) == 1 - assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 40 + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 q = PSSPC.q() * LatticeSurgery.q() - assert sum(1 for _ in q.enumerate(ctx)) == 40 + assert sum(1 
for _ in q.enumerate(ctx)) == 32 + + +def test_rotation_error_psspc(): + from qsharp.qre._enumeration import _enumerate_instances + + # This test helps to bound the variables for the number of rotations in PSSPC + + # Create a trace with a single rotation gate and ensure that the base error + # after PSSPC transformation is less than 1. + trace = Trace(1) + trace.add_operation(RZ, [0]) + + for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): + transformed = psspc.transform(trace) + assert transformed is not None + assert ( + transformed.base_error < 1.0 + ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" def test_estimation_max_error(): diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/test_qre_models.py new file mode 100644 index 0000000000..85c0643a92 --- /dev/null +++ b/source/pip/tests/test_qre_models.py @@ -0,0 +1,1019 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from qsharp.qre import LOGICAL, PHYSICAL, Encoding, PropertyKey, instruction +from qsharp.qre.instruction_ids import ( + T, + CCZ, + CCX, + CCY, + CNOT, + CZ, + H, + MEAS_Z, + MEAS_X, + MEAS_XX, + MEAS_ZZ, + PAULI_I, + PREP_X, + PREP_Z, + LATTICE_SURGERY, + MEMORY, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, +) +from qsharp.qre.models import ( + AQREGateBased, + Majorana, + RoundBasedFactory, + MagicUpToClifford, + Litinski19Factory, + SurfaceCode, + ThreeAux, + YokedSurfaceCode, +) + + +# --------------------------------------------------------------------------- +# AQREGateBased architecture tests +# --------------------------------------------------------------------------- + + +class TestAQREGateBased: + def test_default_error_rate(self): + arch = AQREGateBased() + assert arch.error_rate == 1e-4 + + def test_custom_error_rate(self): + arch = AQREGateBased(error_rate=1e-3) + assert arch.error_rate == 1e-3 + + def 
test_provided_isa_contains_expected_instructions(self): + arch = AQREGateBased() + isa = arch.provided_isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert instr_id in isa + + def test_instruction_encodings_are_physical(self): + arch = AQREGateBased() + isa = arch.provided_isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert isa[instr_id].encoding == PHYSICAL + + def test_instruction_error_rates_match(self): + rate = 1e-3 + arch = AQREGateBased(error_rate=rate) + isa = arch.provided_isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert isa[instr_id].expect_error_rate() == rate + + def test_gate_times(self): + arch = AQREGateBased() + isa = arch.provided_isa + + # Single-qubit gates: 50ns + for instr_id in [PAULI_I, H, T]: + assert isa[instr_id].expect_time() == 50 + + # Two-qubit gates: 50ns + for instr_id in [CNOT, CZ]: + assert isa[instr_id].expect_time() == 50 + + # Measurement: 100ns + assert isa[MEAS_Z].expect_time() == 100 + + def test_arities(self): + arch = AQREGateBased() + isa = arch.provided_isa + + assert isa[PAULI_I].arity == 1 + assert isa[CNOT].arity == 2 + assert isa[CZ].arity == 2 + assert isa[H].arity == 1 + assert isa[MEAS_Z].arity == 1 + + def test_context_creation(self): + arch = AQREGateBased() + ctx = arch.context() + assert ctx is not None + + +# --------------------------------------------------------------------------- +# Majorana architecture tests +# --------------------------------------------------------------------------- + + +class TestMajorana: + def test_default_error_rate(self): + arch = Majorana() + assert arch.error_rate == 1e-5 + + def test_provided_isa_contains_expected_instructions(self): + arch = Majorana() + isa = arch.provided_isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z, T]: + assert instr_id in isa + + def test_all_times_are_1us(self): + arch = Majorana() + isa = arch.provided_isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, 
MEAS_Z, T]: + assert isa[instr_id].expect_time() == 1000 + + def test_clifford_error_rates_match_qubit_error(self): + for rate in [1e-4, 1e-5, 1e-6]: + arch = Majorana(error_rate=rate) + isa = arch.provided_isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z]: + assert isa[instr_id].expect_error_rate() == rate + + def test_t_error_rate_mapping(self): + """T error rate maps: 1e-4 -> 5%, 1e-5 -> 1.5%, 1e-6 -> 1%.""" + expected = {1e-4: 0.05, 1e-5: 0.015, 1e-6: 0.01} + + for qubit_rate, t_rate in expected.items(): + arch = Majorana(error_rate=qubit_rate) + isa = arch.provided_isa + assert isa[T].expect_error_rate() == t_rate + + def test_two_qubit_measurement_arities(self): + arch = Majorana() + isa = arch.provided_isa + + assert isa[MEAS_XX].arity == 2 + assert isa[MEAS_ZZ].arity == 2 + + +# --------------------------------------------------------------------------- +# SurfaceCode QEC tests +# --------------------------------------------------------------------------- + + +class TestSurfaceCode: + def test_required_isa(self): + reqs = SurfaceCode.required_isa() + assert reqs is not None + + def test_default_distance(self): + sc = SurfaceCode(distance=3) + assert sc.distance == 3 + + def test_provides_lattice_surgery(self): + arch = AQREGateBased() + ctx = arch.context() + sc = SurfaceCode(distance=3) + + isas = list(sc.provided_isa(arch.provided_isa, ctx)) + assert len(isas) == 1 + + isa = isas[0] + assert LATTICE_SURGERY in isa + + ls = isa[LATTICE_SURGERY] + assert ls.encoding == LOGICAL + + def test_space_scales_with_distance(self): + """Space = 2*d^2 - 1 physical qubits per logical qubit.""" + arch = AQREGateBased() + + for d in [3, 5, 7, 9]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(arch.provided_isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected_space = 2 * d**2 - 1 + assert ls.expect_space(1) == expected_space + + def test_time_scales_with_distance(self): + """Time = (h_time + 4*cnot_time + 
meas_time) * d.""" + arch = AQREGateBased() + # h=50, cnot=50, meas=100 for AQREGateBased + syndrome_time = 50 + 4 * 50 + 100 # = 350 + + for d in [3, 5, 7]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(arch.provided_isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + assert ls.expect_time(1) == syndrome_time * d + + def test_error_rate_decreases_with_distance(self): + arch = AQREGateBased() + + errors = [] + for d in [3, 5, 7, 9, 11]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(arch.provided_isa, ctx)) + errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) + + # Each successive distance should have a lower error rate + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1] + + def test_enumeration_via_query(self): + """Enumerating SurfaceCode.q() should yield multiple distances.""" + arch = AQREGateBased() + ctx = arch.context() + + count = 0 + for isa in SurfaceCode.q().enumerate(ctx): + assert LATTICE_SURGERY in isa + count += 1 + + # domain is range(3, 26, 2) = 12 distances + assert count == 12 + + def test_custom_crossing_prefactor(self): + arch = AQREGateBased() + ctx = arch.context() + + sc_default = SurfaceCode(distance=5) + sc_custom = SurfaceCode(crossing_prefactor=0.06, distance=5) + + default_error = list(sc_default.provided_isa(arch.provided_isa, ctx))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = arch.context() + custom_error = list(sc_custom.provided_isa(arch.provided_isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Doubling prefactor should double the error rate + assert abs(custom_error - 2 * default_error) < 1e-20 + + def test_custom_error_correction_threshold(self): + arch = AQREGateBased() + + ctx1 = arch.context() + sc_low_threshold = SurfaceCode(error_correction_threshold=0.005, distance=5) + error_low = list(sc_low_threshold.provided_isa(arch.provided_isa, ctx1))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = 
arch.context() + sc_high_threshold = SurfaceCode(error_correction_threshold=0.02, distance=5) + error_high = list(sc_high_threshold.provided_isa(arch.provided_isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Lower threshold means worse ratio => higher logical error + assert error_low > error_high + + +# --------------------------------------------------------------------------- +# ThreeAux QEC tests +# --------------------------------------------------------------------------- + + +class TestThreeAux: + def test_required_isa(self): + reqs = ThreeAux.required_isa() + assert reqs is not None + + def test_provides_lattice_surgery(self): + arch = Majorana() + ctx = arch.context() + ta = ThreeAux(distance=3) + + isas = list(ta.provided_isa(arch.provided_isa, ctx)) + assert len(isas) == 1 + assert LATTICE_SURGERY in isas[0] + + def test_space_formula(self): + """Space = 4*d^2 - 3 per logical qubit.""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d) + isas = list(ta.provided_isa(arch.provided_isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected = 4 * d**2 - 3 + assert ls.expect_space(1) == expected + + def test_time_formula_double_rail(self): + """Time = gate_time * (4*d + 4) for double-rail (default).""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d, single_rail=False) + isas = list(ta.provided_isa(arch.provided_isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + # MEAS_XX and MEAS_ZZ have time=1000 each; max is 1000 + expected_time = 1000 * (4 * d + 4) + assert ls.expect_time(1) == expected_time + + def test_time_formula_single_rail(self): + """Time = gate_time * (5*d + 4) for single-rail.""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d, single_rail=True) + isas = list(ta.provided_isa(arch.provided_isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected_time = 1000 * (5 * d + 4) + assert ls.expect_time(1) == 
expected_time + + def test_error_rate_decreases_with_distance(self): + arch = Majorana() + + errors = [] + for d in [3, 5, 7, 9]: + ctx = arch.context() + ta = ThreeAux(distance=d) + isas = list(ta.provided_isa(arch.provided_isa, ctx)) + errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) + + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1] + + def test_single_rail_has_different_error_threshold(self): + """Single-rail has threshold 0.0051, double-rail 0.0066.""" + arch = Majorana() + + ctx1 = arch.context() + double = ThreeAux(distance=5, single_rail=False) + error_double = list(double.provided_isa(arch.provided_isa, ctx1))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = arch.context() + single = ThreeAux(distance=5, single_rail=True) + error_single = list(single.provided_isa(arch.provided_isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Both should be positive but differ + assert error_double > 0 + assert error_single > 0 + assert error_double != error_single + + def test_enumeration_via_query(self): + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in ThreeAux.q().enumerate(ctx): + assert LATTICE_SURGERY in isa + count += 1 + + # domain: range(3, 26, 2) × {True, False} for single_rail + # = 12 distances × 2 = 24 + assert count == 24 + + +# --------------------------------------------------------------------------- +# YokedSurfaceCode tests +# --------------------------------------------------------------------------- + + +class TestYokedSurfaceCode: + def _get_lattice_surgery_isa(self, distance=5): + """Helper to get a lattice surgery ISA from SurfaceCode.""" + arch = AQREGateBased() + ctx = arch.context() + sc = SurfaceCode(distance=distance) + isas = list(sc.provided_isa(arch.provided_isa, ctx)) + return isas[0], ctx + + def test_provides_memory_instruction(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + 
assert len(isas) == 1 + assert MEMORY in isas[0] + + def test_memory_is_logical(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + assert mem.encoding == LOGICAL + + def test_memory_arity_is_variable(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + # arity=None means variable arity + assert mem.arity is None + + def test_space_increases_with_arity(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + spaces = [mem.expect_space(n) for n in [4, 16, 64]] + for i in range(len(spaces) - 1): + assert spaces[i] < spaces[i + 1] + + def test_time_increases_with_arity(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + times = [mem.expect_time(n) for n in [4, 16, 64]] + for i in range(len(times) - 1): + assert times[i] < times[i + 1] + + def test_error_rate_increases_with_arity(self): + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + errors = [mem.expect_error_rate(n) for n in [4, 16, 64]] + for i in range(len(errors) - 1): + assert errors[i] < errors[i + 1] + + def test_distance_property_propagated(self): + d = 7 + ls_isa, ctx = self._get_lattice_surgery_isa(distance=d) + ysc = YokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + assert mem.get_property(PropertyKey.DISTANCE) == d + + +# --------------------------------------------------------------------------- +# Litinski19Factory tests +# --------------------------------------------------------------------------- + + +class TestLitinski19Factory: + def 
test_required_isa(self): + reqs = Litinski19Factory.required_isa() + assert reqs is not None + + def test_table1_aqre_yields_t_and_ccz(self): + """AQREGateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + + isas = list(factory.provided_isa(arch.provided_isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + assert T in isa + assert CCZ in isa + assert len(isa) == 2 + + def test_table1_instruction_properties(self): + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + + for isa in factory.provided_isa(arch.provided_isa, ctx): + t_instr = isa[T] + ccz_instr = isa[CCZ] + + assert t_instr.arity == 1 + assert t_instr.encoding == LOGICAL + assert t_instr.expect_error_rate() > 0 + assert t_instr.expect_time() > 0 + assert t_instr.expect_space() > 0 + + assert ccz_instr.arity == 3 + assert ccz_instr.encoding == LOGICAL + assert ccz_instr.expect_error_rate() > 0 + + def test_table1_t_error_rates_are_diverse(self): + """T entries in Table 1 should span a range of error rates.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + + isas = list(factory.provided_isa(arch.provided_isa, ctx)) + t_errors = [isa[T].expect_error_rate() for isa in isas] + + # Should have multiple distinct T error rates + unique_errors = set(t_errors) + assert len(unique_errors) > 1 + + # All error rates should be positive and very small + for err in t_errors: + assert 0 < err < 1e-5 + + def test_table1_1e3_clifford_yields_6_isas(self): + """AQREGateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" + arch = AQREGateBased(error_rate=1e-3) + ctx = arch.context() + factory = Litinski19Factory() + + isas = list(factory.provided_isa(arch.provided_isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + assert T in isa + assert CCZ in isa + + def 
test_table2_scenario_no_ccz(self): + """Table 2 scenario: T error ~10x higher than Clifford, no CCZ.""" + from qsharp.qre import ISA as ISAType + + arch = AQREGateBased() + ctx = arch.context() + + # Manually create ISA with T error rate 10x Clifford + isa_input = ISAType( + instruction( + CNOT, encoding=Encoding.PHYSICAL, arity=2, time=50, error_rate=1e-4 + ), + instruction( + H, encoding=Encoding.PHYSICAL, arity=1, time=50, error_rate=1e-4 + ), + instruction( + MEAS_Z, encoding=Encoding.PHYSICAL, arity=1, time=100, error_rate=1e-4 + ), + instruction(T, encoding=Encoding.PHYSICAL, time=50, error_rate=1e-3), + ) + + factory = Litinski19Factory() + isas = list(factory.provided_isa(isa_input, ctx)) + + # Table 2 at 1e-4 Clifford: 4 T entries, no CCZ + assert len(isas) == 4 + + for isa in isas: + assert T in isa + assert CCZ not in isa + + def test_no_yield_when_error_too_high(self): + """If T error > 10x Clifford, no entries match.""" + from qsharp.qre import ISA as ISAType + + arch = AQREGateBased() + ctx = arch.context() + + isa_input = ISAType( + instruction( + CNOT, encoding=Encoding.PHYSICAL, arity=2, time=50, error_rate=1e-4 + ), + instruction( + H, encoding=Encoding.PHYSICAL, arity=1, time=50, error_rate=1e-4 + ), + instruction( + MEAS_Z, encoding=Encoding.PHYSICAL, arity=1, time=100, error_rate=1e-4 + ), + instruction(T, encoding=Encoding.PHYSICAL, time=50, error_rate=0.05), + ) + + factory = Litinski19Factory() + isas = list(factory.provided_isa(isa_input, ctx)) + assert len(isas) == 0 + + def test_time_based_on_syndrome_extraction(self): + """Time should be based on syndrome extraction time × cycles.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + + # For AQREGateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 + syndrome_time = 4 * 50 + 50 + 100 # 350 ns + + isas = list(factory.provided_isa(arch.provided_isa, ctx)) + for isa in isas: + t_time = isa[T].expect_time() + assert t_time > 0 + # Time should be 
ceil(syndrome_time * cycles), so it must be at + # least syndrome_time (cycles >= 1) + assert t_time >= syndrome_time + + +# --------------------------------------------------------------------------- +# MagicUpToClifford tests +# --------------------------------------------------------------------------- + + +class TestMagicUpToClifford: + def test_required_isa_is_empty(self): + reqs = MagicUpToClifford.required_isa() + assert reqs is not None + + def test_adds_clifford_equivalent_t_gates(self): + """Given T gate, should add SQRT_SQRT_X/Y/Z and dagger variants.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(arch.provided_isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + assert len(modified_isas) == 1 + modified_isa = modified_isas[0] + + # T family equivalents + for equiv_id in [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + ]: + assert equiv_id in modified_isa + + break # Just test the first one + + def test_adds_clifford_equivalent_ccz(self): + """Given CCZ, should add CCX and CCY.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(arch.provided_isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + modified_isa = modified_isas[0] + + assert CCX in modified_isa + assert CCY in modified_isa + assert CCZ in modified_isa + break + + def test_full_count_of_instructions(self): + """T gate (1) + 5 equivalents (SQRT_SQRT_*) + CCZ (1) + 2 equivalents (CCX, CCY) = 9.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(arch.provided_isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + assert len(modified_isas[0]) == 9 + break + + def 
test_equivalent_instructions_share_properties(self): + """Clifford equivalents should have same time, space, error rate.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(arch.provided_isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + modified_isa = modified_isas[0] + + t_instr = modified_isa[T] + for equiv_id in [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z_DAG, + ]: + equiv = modified_isa[equiv_id] + assert equiv.expect_error_rate() == t_instr.expect_error_rate() + assert equiv.expect_time() == t_instr.expect_time() + assert equiv.expect_space() == t_instr.expect_space() + + ccz_instr = modified_isa[CCZ] + for equiv_id in [CCX, CCY]: + equiv = modified_isa[equiv_id] + assert equiv.expect_error_rate() == ccz_instr.expect_error_rate() + + break + + def test_modification_count_matches_factory_output(self): + """MagicUpToClifford should produce one modified ISA per input ISA.""" + arch = AQREGateBased() + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + modified_count = 0 + for isa in factory.provided_isa(arch.provided_isa, ctx): + for _ in modifier.provided_isa(isa, ctx): + modified_count += 1 + + assert modified_count == 6 + + def test_no_family_present_passes_through(self): + """If no family member is present, ISA passes through unchanged.""" + from qsharp.qre import ISA as ISAType + + arch = AQREGateBased() + ctx = arch.context() + modifier = MagicUpToClifford() + + # ISA with only a LATTICE_SURGERY instruction (no T or CCZ family) + from qsharp.qre import linear_function + + ls = instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=linear_function(17), + time=1000, + error_rate=linear_function(1e-10), + ) + isa_input = ISAType(ls) + + results = list(modifier.provided_isa(isa_input, ctx)) + assert len(results) == 1 + # Should only contain the 
original instruction + assert len(results[0]) == 1 + + +# --------------------------------------------------------------------------- +# Litinski19Factory + MagicUpToClifford integration (from original test) +# --------------------------------------------------------------------------- + + +def test_isa_manipulation(): + arch = AQREGateBased() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + ctx = arch.context() + + # Table 1 scenario: should yield ISAs with both T and CCZ instructions + isas = list(factory.provided_isa(arch.provided_isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + # Each ISA should contain both T and CCZ instructions + assert T in isa + assert CCZ in isa + assert len(isa) == 2 + + t_instr = isa[T] + ccz_instr = isa[CCZ] + + # Verify instruction properties + assert t_instr.arity == 1 + assert t_instr.encoding == LOGICAL + assert t_instr.expect_error_rate() > 0 + + assert ccz_instr.arity == 3 + assert ccz_instr.encoding == LOGICAL + assert ccz_instr.expect_error_rate() > 0 + + # After MagicUpToClifford modifier + modified_count = 0 + for isa in factory.provided_isa(arch.provided_isa, ctx): + for modified_isa in modifier.provided_isa(isa, ctx): + modified_count += 1 + # MagicUpToClifford should add derived instructions + assert T in modified_isa + assert CCZ in modified_isa + assert CCX in modified_isa + assert len(modified_isa) == 9 + + assert modified_count == 6 + + +# --------------------------------------------------------------------------- +# RoundBasedFactory tests +# --------------------------------------------------------------------------- + + +class TestRoundBasedFactory: + def test_required_isa(self): + reqs = RoundBasedFactory.required_isa() + assert reqs is not None + + def test_produces_logical_t_gates(self): + arch = AQREGateBased() + + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + t = isa[T] + assert t.encoding == LOGICAL + 
assert t.arity == 1 + assert t.expect_error_rate() > 0 + assert t.expect_time() > 0 + assert t.expect_space() > 0 + break # Just check the first + + def test_error_rates_are_bounded(self): + """Distilled T error rates should be bounded and mostly small.""" + arch = AQREGateBased() # T error rate is 1e-4 + + errors = [] + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + errors.append(isa[T].expect_error_rate()) + + # All should be positive + assert all(e > 0 for e in errors) + # Most distilled error rates should be much lower than 1 + assert min(errors) < 1e-4 + # Median should be well below raw physical error + sorted_errors = sorted(errors) + median = sorted_errors[len(sorted_errors) // 2] + assert median < 1e-3 + + def test_max_produces_fewer_or_equal_results_than_sum(self): + """Using max for physical_qubit_calculation may filter differently.""" + arch = AQREGateBased() + + sum_count = sum( + 1 for _ in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) + ) + max_count = sum( + 1 + for _ in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + ).enumerate(arch.context()) + ) + + assert max_count <= sum_count + + def test_max_space_less_than_or_equal_sum_space(self): + """max-aggregated space should be <= sum-aggregated space for each.""" + arch = AQREGateBased() + + sum_spaces = sorted( + isa[T].expect_space() + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) + ) + + max_spaces = sorted( + isa[T].expect_space() + for isa in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + ).enumerate(arch.context()) + ) + + # The minimum space with max should be <= minimum space with sum + assert max_spaces[0] <= sum_spaces[0] + + def test_with_three_aux_code_query(self): + """RoundBasedFactory with ThreeAux code query should produce results.""" + arch = Majorana() + + count = 0 + for isa in RoundBasedFactory.q( + use_cache=False, code_query=ThreeAux.q() + 
).enumerate(arch.context()): + assert T in isa + assert isa[T].encoding == LOGICAL + count += 1 + + assert count > 0 + + def test_round_based_aqre_sum(self): + arch = AQREGateBased() + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 12_946_488 + assert total_time == 12_032_250 + assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 + assert count == 107 + + def test_round_based_aqre_max(self): + arch = AQREGateBased() + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + ).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 4_651_617 + assert total_time == 7_785_000 + assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 + assert count == 77 + + def test_round_based_msft_sum(self): + arch = Majorana() + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q( + use_cache=False, code_query=ThreeAux.q() + ).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 255_952_723 + assert total_time == 478_235_000 + assert abs(total_error - 0.000_880_967_766_732_897_4) < 1e-8 + assert count == 301 + + +# --------------------------------------------------------------------------- +# Cross-model integration tests +# --------------------------------------------------------------------------- + + +class TestCrossModelIntegration: + def test_surface_code_feeds_into_litinski(self): + """SurfaceCode -> 
Litinski19Factory pipeline works end to end.""" + arch = AQREGateBased() + ctx = arch.context() + + # SurfaceCode takes AQRE physical ISA -> LATTICE_SURGERY + sc = SurfaceCode(distance=5) + sc_isas = list(sc.provided_isa(arch.provided_isa, ctx)) + assert len(sc_isas) == 1 + + # Litinski takes H, CNOT, MEAS_Z, T from the physical ISA + factory = Litinski19Factory() + factory_isas = list(factory.provided_isa(arch.provided_isa, ctx)) + assert len(factory_isas) > 0 + + def test_three_aux_feeds_into_round_based(self): + """ThreeAux -> RoundBasedFactory pipeline works.""" + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in RoundBasedFactory.q( + use_cache=False, code_query=ThreeAux.q() + ).enumerate(ctx): + assert T in isa + count += 1 + + assert count > 0 + + def test_litinski_with_magic_up_to_clifford_query(self): + """Full query chain: Litinski19Factory -> MagicUpToClifford.""" + arch = AQREGateBased() + ctx = arch.context() + + count = 0 + for isa in MagicUpToClifford.q(source=Litinski19Factory.q()).enumerate(ctx): + assert T in isa + assert CCX in isa + assert CCY in isa + assert CCZ in isa + count += 1 + + assert count == 6 + + def test_surface_code_with_yoked_surface_code(self): + """SurfaceCode -> YokedSurfaceCode pipeline provides MEMORY.""" + arch = AQREGateBased() + ctx = arch.context() + + count = 0 + for isa in YokedSurfaceCode.q(source=SurfaceCode.q()).enumerate(ctx): + assert MEMORY in isa + count += 1 + + # 12 distances × 2 shape heuristics = 24 + assert count == 24 + + def test_majorana_three_aux_yoked(self): + """Majorana -> ThreeAux -> YokedSurfaceCode pipeline.""" + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in YokedSurfaceCode.q(source=ThreeAux.q()).enumerate(ctx): + assert MEMORY in isa + count += 1 + + assert count > 0 diff --git a/source/qre/src/pareto.rs b/source/qre/src/pareto.rs index 96f842a0a4..ace12ca172 100644 --- a/source/qre/src/pareto.rs +++ b/source/qre/src/pareto.rs @@ -34,7 +34,7 @@ pub trait 
ParetoItem3D { /// This approach is related to the algorithms described in: /// H. T. Kung, F. Luccio, and F. P. Preparata, "On Finding the Maxima of a Set of Vectors," /// Journal of the ACM, vol. 22, no. 4, pp. 469-476, 1975. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ParetoFrontier(pub Vec); impl ParetoFrontier { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 1d2d7081e2..d9a4f12162 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter}; use rustc_hash::{FxHashMap, FxHashSet}; +use serde::{Deserialize, Serialize}; use crate::{Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction}; @@ -15,7 +16,7 @@ mod tests; mod transforms; pub use transforms::{LatticeSurgery, PSSPC, TraceTransform}; -#[derive(Clone, Default)] +#[derive(Clone, Default, Serialize, Deserialize)] pub struct Trace { block: Block, base_error: f64, @@ -311,20 +312,20 @@ impl Display for Trace { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum Operation { GateOperation(Gate), BlockOperation(Block), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Gate { id: u64, qubits: Vec, params: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Block { operations: Vec, repetitions: u64, @@ -493,7 +494,7 @@ impl<'a> Iterator for TraceIterator<'a> { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub enum Property { Bool(bool), Int(i64), From 41bf7b59b91692d6db374a9080484c22bdd7f866 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Thu, 26 Feb 2026 09:59:52 -0800 Subject: [PATCH 19/45] Magnets: revisions to (hypergraph) coloring (#2972) Issue: different terms in a Trotter expansion may need different coloring of the edges. So the coloring is not an intrinsic property of the graph. 
We remove (edge) coloring of a graph to be a separate class, and shift all the coloring functionality to this class. Hence a single graph can now support multiple different colorings. - Major revisions to the Hypergraph class and introduction of the HypergraphEdgeColoring class - Introduction of the edge_coloring() function that wraps greedy_edge_coloring() for the base class - Revisions of each of the graph subclasses to remove coloring and create an overloaded version of edge_coloring() - Revisions of the docstrings across all these classes - Minor changes to some string representations - Update to all the unit tests --- .../pip/qsharp/magnets/geometry/complete.py | 127 +++--- .../pip/qsharp/magnets/geometry/lattice1d.py | 64 +-- .../pip/qsharp/magnets/geometry/lattice2d.py | 130 +++--- source/pip/qsharp/magnets/models/__init__.py | 4 +- source/pip/qsharp/magnets/models/model.py | 228 +++------- .../pip/qsharp/magnets/utilities/__init__.py | 8 +- .../qsharp/magnets/utilities/hypergraph.py | 299 ++++++++----- source/pip/qsharp/magnets/utilities/pauli.py | 98 ++-- source/pip/tests/magnets/test_complete.py | 39 +- source/pip/tests/magnets/test_hypergraph.py | 189 ++++++-- source/pip/tests/magnets/test_lattice1d.py | 114 ++--- source/pip/tests/magnets/test_lattice2d.py | 64 +-- source/pip/tests/magnets/test_model.py | 420 ++++-------------- source/pip/tests/magnets/test_pauli.py | 113 +++++ 14 files changed, 946 insertions(+), 951 deletions(-) create mode 100644 source/pip/tests/magnets/test_pauli.py diff --git a/source/pip/qsharp/magnets/geometry/complete.py b/source/pip/qsharp/magnets/geometry/complete.py index aee8f35014..057abb950b 100644 --- a/source/pip/qsharp/magnets/geometry/complete.py +++ b/source/pip/qsharp/magnets/geometry/complete.py @@ -11,7 +11,7 @@ from qsharp.magnets.utilities import ( Hyperedge, Hypergraph, - greedy_edge_coloring, + HypergraphEdgeColoring, ) @@ -21,8 +21,6 @@ class CompleteGraph(Hypergraph): In a complete graph K_n, there are n
vertices and n(n-1)/2 edges, with each pair of distinct vertices connected by exactly one edge. - To do: edge partitioning for parallel updates. - Attributes: n: Number of vertices in the graph. @@ -55,42 +53,32 @@ def __init__(self, n: int, self_loops: bool = False) -> None: _edges.append(Hyperedge([i, j])) super().__init__(_edges) - # Set colors for self-loop edges if enabled - if self_loops: - for i in range(n): - self.color[(i,)] = -1 # Self-loop edges get color -1 - - # Edge coloring for parallel updates - # The even case: n-1 colors are needed - if n % 2 == 0: - m = n - 1 - for i in range(m): - self.color[(i, n - 1)] = ( - i # Connect vertex n-1 to all others with unique colors - ) - for j in range(1, (m - 1) // 2 + 1): - a = (i + j) % m - b = (i - j) % m - if a < b: - self.color[(a, b)] = i - else: - self.color[(b, a)] = i + self.n = n - # The odd case: n colors are needed - # This is the round-robin tournament scheduling algorithm for odd n - # Set m = n for ease of reading - else: - m = n - for i in range(m): - for j in range(1, (m - 1) // 2 + 1): - a = (i + j) % m - b = (i - j) % m - if a < b: - self.color[(a, b)] = i + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this complete graph.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + if self.n % 2 == 0: + i, j = edge.vertices + m = self.n - 1 + if j == m: + coloring.add_edge(edge, i) + elif (j - i) % 2 == 0: + coloring.add_edge(edge, (j - i) // 2) else: - self.color[(b, a)] = i - - self.n = n + coloring.add_edge(edge, (j - i + m) // 2) + else: + m = self.n + i, j = edge.vertices + if (j - i) % 2 == 0: + coloring.add_edge(edge, (j - i) // 2) + else: + coloring.add_edge(edge, (j - i + m) // 2) + return coloring class CompleteBipartiteGraph(Hypergraph): @@ -104,8 +92,6 @@ class CompleteBipartiteGraph(Hypergraph): Vertices 0 to m-1 form the first set, and vertices m to m+n-1 form 
the second set. - To do: edge partitioning for parallel updates. - Attributes: m: Number of vertices in the first set. n: Number of vertices in the second set. @@ -144,21 +130,58 @@ def __init__(self, m: int, n: int, self_loops: bool = False) -> None: # Connect every vertex in first set to every vertex in second set for i in range(m): for j in range(m, m + n): - edge_idx = len(_edges) _edges.append(Hyperedge([i, j])) super().__init__(_edges) - # Set colors for self-loop edges if enabled - if self_loops: - for i in range(total_vertices): - self.color[(i,)] = -1 # Self-loop edges get color -1 - - # Color edges based on the second vertex index to create n parallel partitions - for i in range(m): - for j in range(m, m + n): - self.color[(i, j)] = ( - i + j - m - ) % n # Color edges based on second vertex index - self.m = m self.n = n + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this complete bipartite graph.""" + coloring = HypergraphEdgeColoring(self) + m = self.m + n = self.n + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + coloring.add_edge(edge, (i + j - m) % n) + return coloring + + # Color edges based on the second vertex index to create n parallel partitions + # for i in range(m): + # for j in range(m, m + n): + # self.color[(i, j)] = ( + # i + j - m + # ) % n # Color edges based on second vertex index + + # Edge coloring for parallel updates + # The even case: n-1 colors are needed + # if n % 2 == 0: + # m = n - 1 + # for i in range(m): + # self.color[(i, n - 1)] = ( + # i # Connect vertex n-1 to all others with unique colors + # ) + # for j in range(1, (m - 1) // 2 + 1): + # a = (i + j) % m + # b = (i - j) % m + # if a < b: + # self.color[(a, b)] = i + # else: + # self.color[(b, a)] = i + + # The odd case: n colors are needed + # This is the round-robin tournament scheduling algorithm for odd n + # Set m = n for ease of reading + # else: + # m = n + 
# for i in range(m): + # for j in range(1, (m - 1) // 2 + 1): + # a = (i + j) % m + # b = (i - j) % m + # if a < b: + # self.color[(a, b)] = i + # else: + # self.color[(b, a)] = i diff --git a/source/pip/qsharp/magnets/geometry/lattice1d.py b/source/pip/qsharp/magnets/geometry/lattice1d.py index ff091fd28a..9586167276 100644 --- a/source/pip/qsharp/magnets/geometry/lattice1d.py +++ b/source/pip/qsharp/magnets/geometry/lattice1d.py @@ -8,7 +8,11 @@ simulations and other one-dimensional quantum systems. """ -from qsharp.magnets.utilities import Hyperedge, Hypergraph +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) class Chain1D(Hypergraph): @@ -18,11 +22,6 @@ class Chain1D(Hypergraph): The chain has open boundary conditions, meaning the first and last vertices are not connected. - Edges are colored for parallel updates: - - Color -1 (if self_loops): Self-loop edges on each vertex - - Color 0: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) - - Color 1: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) - Attributes: length: Number of vertices in the chain. @@ -52,19 +51,22 @@ def __init__(self, length: int, self_loops: bool = False) -> None: for i in range(length - 1): _edges.append(Hyperedge([i, i + 1])) - super().__init__(_edges) - - # Update color for self-loop edges - if self_loops: - for i in range(length): - self.color[(i,)] = -1 - - for i in range(length - 1): - color = i % 2 - self.color[(i, i + 1)] = color + super().__init__(_edges) self.length = length + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute a valid edge coloring for this chain.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + color = min(i, j) % 2 + coloring.add_edge(edge, color) + return coloring + class Ring1D(Hypergraph): """A one-dimensional ring (periodic chain) lattice. 
@@ -73,11 +75,6 @@ class Ring1D(Hypergraph): The ring has periodic boundary conditions, meaning the first and last vertices are connected. - Edges are colored for parallel updates: - - Color -1 (if self_loops): Self-loop edges on each vertex - - Color 0: Even-indexed nearest-neighbor edges (0-1, 2-3, ...) - - Color 1: Odd-indexed nearest-neighbor edges (1-2, 3-4, ...) - Attributes: length: Number of vertices in the ring. @@ -108,14 +105,19 @@ def __init__(self, length: int, self_loops: bool = False) -> None: _edges.append(Hyperedge([i, (i + 1) % length])) super().__init__(_edges) - # Update color for self-loop edges - if self_loops: - for i in range(length): - self.color[(i,)] = -1 - - for i in range(length): - j = (i + 1) % length - color = i % 2 - self.color[tuple(sorted([i, j]))] = color - self.length = length + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute a valid edge coloring for this ring.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + if {i, j} == {0, self.length - 1}: + color = (self.length % 2) + 1 + else: + color = min(i, j) % 2 + coloring.add_edge(edge, color) + return coloring diff --git a/source/pip/qsharp/magnets/geometry/lattice2d.py b/source/pip/qsharp/magnets/geometry/lattice2d.py index 4821c5eaeb..a69a8c7644 100644 --- a/source/pip/qsharp/magnets/geometry/lattice2d.py +++ b/source/pip/qsharp/magnets/geometry/lattice2d.py @@ -8,7 +8,11 @@ simulations and other two-dimensional quantum systems. """ -from qsharp.magnets.utilities import Hyperedge, Hypergraph +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) class Patch2D(Hypergraph): @@ -19,13 +23,6 @@ class Patch2D(Hypergraph): Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. 
- Edges are colored for parallel updates: - - Color -1 (if self_loops): Self-loop edges on each vertex - - Color 0: Even-column horizontal edges - - Color 1: Odd-column horizontal edges - - Color 2: Even-row vertical edges - - Color 3: Odd-row vertical edges - Attributes: width: Number of vertices in the horizontal direction. height: Number of vertices in the vertical direction. @@ -34,10 +31,8 @@ class Patch2D(Hypergraph): .. code-block:: python >>> patch = Patch2D(3, 2) - >>> patch.nvertices - 6 - >>> patch.nedges - 7 + >>> str(patch) + '3x2 lattice patch with 6 vertices and 7 edges' """ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: @@ -68,29 +63,37 @@ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: _edges.append(Hyperedge([self._index(x, y), self._index(x, y + 1)])) super().__init__(_edges) - # Set up edge colors for parallel updates - if self_loops: - for i in range(width * height): - self.color[(i,)] = -1 - - # Color horizontal edges - for y in range(height): - for x in range(width - 1): - v1, v2 = self._index(x, y), self._index(x + 1, y) - color = 0 if x % 2 == 0 else 1 - self.color[tuple(sorted([v1, v2]))] = color - - # Color vertical edges - for y in range(height - 1): - for x in range(width): - v1, v2 = self._index(x, y), self._index(x, y + 1) - color = 2 if y % 2 == 0 else 3 - self.color[tuple(sorted([v1, v2]))] = color - def _index(self, x: int, y: int) -> int: """Convert (x, y) coordinates to vertex index.""" return y * self.width + x + def __str__(self) -> str: + """Return the summary string ``"{width}x{height} lattice patch with {nvertices} vertices and {nedges} edges"``.""" + return f"{self.width}x{self.height} lattice patch with {self.nvertices} vertices and {self.nedges} edges" + + def __repr__(self) -> str: + """Return a string representation of the Patch2D geometry.""" + return f"Patch2D(width={self.width}, height={self.height})" + + def edge_coloring(self) -> HypergraphEdgeColoring: 
+ """Compute edge coloring for this 2D patch.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + u, v = edge.vertices + x_u, y_u = u % self.width, u // self.width + x_v, y_v = v % self.width, v // self.width + + if y_u == y_v: + color = 0 if min(x_u, x_v) % 2 == 0 else 1 + else: + color = 2 if min(y_u, y_v) % 2 == 0 else 3 + coloring.add_edge(edge, color) + return coloring + class Torus2D(Hypergraph): """A two-dimensional toroidal (periodic) lattice. @@ -101,13 +104,6 @@ class Torus2D(Hypergraph): Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. - Edges are colored for parallel updates: - - Color -1 (if self_loops): Self-loop edges on each vertex - - Color 0: Even-column horizontal edges - - Color 1: Odd-column horizontal edges - - Color 2: Even-row vertical edges - - Color 3: Odd-row vertical edges - Attributes: width: Number of vertices in the horizontal direction. height: Number of vertices in the vertical direction. @@ -116,10 +112,8 @@ class Torus2D(Hypergraph): .. 
code-block:: python >>> torus = Torus2D(3, 2) - >>> torus.nvertices - 6 - >>> torus.nedges - 12 + >>> str(torus) + '3x2 lattice torus with 6 vertices and 12 edges' """ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: @@ -155,25 +149,39 @@ def __init__(self, width: int, height: int, self_loops: bool = False) -> None: super().__init__(_edges) - # Set up edge colors for parallel updates - if self_loops: - for i in range(width * height): - self.color[(i,)] = -1 - - # Color horizontal edges - for y in range(height): - for x in range(width): - v1, v2 = self._index(x, y), self._index((x + 1) % width, y) - color = 0 if x % 2 == 0 else 1 - self.color[tuple(sorted([v1, v2]))] = color - - # Color vertical edges - for y in range(height): - for x in range(width): - v1, v2 = self._index(x, y), self._index(x, (y + 1) % height) - color = 2 if y % 2 == 0 else 3 - self.color[tuple(sorted([v1, v2]))] = color - def _index(self, x: int, y: int) -> int: """Convert (x, y) coordinates to vertex index.""" return y * self.width + x + + def __str__(self) -> str: + """Return the summary string ``"{width}x{height} lattice torus with {nvertices} vertices and {nedges} edges"``.""" + return f"{self.width}x{self.height} lattice torus with {self.nvertices} vertices and {self.nedges} edges" + + def __repr__(self) -> str: + """Return a string representation of the Torus2D geometry.""" + return f"Torus2D(width={self.width}, height={self.height})" + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this 2D torus.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + u, v = edge.vertices + x_u, y_u = u % self.width, u // self.width + x_v, y_v = v % self.width, v // self.width + + if y_u == y_v: + if {x_u, x_v} == {0, self.width - 1}: + color = 1 if self.width % 2 == 0 else 4 + else: + color = 0 if min(x_u, x_v) % 2 == 0 else 1 + else: + if {y_u, 
y_v} == {0, self.height - 1}: + color = 3 if self.height % 2 == 0 else 5 + else: + color = 2 if min(y_u, y_v) % 2 == 0 else 3 + coloring.add_edge(edge, color) + return coloring diff --git a/source/pip/qsharp/magnets/models/__init__.py b/source/pip/qsharp/magnets/models/__init__.py index 58f47bd721..224270e17e 100644 --- a/source/pip/qsharp/magnets/models/__init__.py +++ b/source/pip/qsharp/magnets/models/__init__.py @@ -7,6 +7,6 @@ as Hamiltonians built from Pauli operators. """ -from .model import Model, translation_invariant_ising_model +from .model import IsingModel, Model -__all__ = ["Model", "translation_invariant_ising_model"] +__all__ = ["Model", "IsingModel"] diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py index 197078f7ab..f262b113a5 100644 --- a/source/pip/qsharp/magnets/models/model.py +++ b/source/pip/qsharp/magnets/models/model.py @@ -3,6 +3,10 @@ # pyright: reportPrivateImportUsage=false +from collections.abc import Sequence +from typing import Optional + + """Base Model class for quantum spin models. This module provides the base class for representing quantum spin models @@ -10,7 +14,12 @@ to define interaction topologies and stores coefficients for each edge. """ -from qsharp.magnets.utilities import Hyperedge, Hypergraph, PauliString +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, + PauliString, +) class Model: @@ -19,8 +28,8 @@ class Model: This class represents a quantum spin Hamiltonian defined on a hypergraph geometry. The Hamiltonian is characterized by: - - Ops: A mapping from edge vertex tuples to (coefficient, PauliString) pairs - - Terms: Groupings of hyperedges for Trotterization or parallel execution + - Ops: A list of PauliStrings (one entry per interaction term) + - Terms: Groupings of operator indices for Trotterization or parallel execution The model is built on a hypergraph geometry that defines which qubits interact with each other. 
@@ -43,9 +52,10 @@ class Model: def __init__(self, geometry: Hypergraph): """Initialize the Model. - Creates a quantum spin model on the given geometry. The model starts - with all coefficients set to zero (with identity PauliStrings) and - no term groupings. + Creates a quantum spin model on the given geometry. + + The model stores operators lazily in ``_ops`` as terms are defined. + ``_terms`` is initialized with one empty term group. Args: geometry: Hypergraph defining the interaction topology. The number @@ -53,144 +63,43 @@ def __init__(self, geometry: Hypergraph): """ self.geometry: Hypergraph = geometry self._qubits: set[int] = set() - self._ops: dict[tuple[int, ...], tuple[float, PauliString]] = dict() + self._ops: list[PauliString] = [] for edge in geometry.edges(): self._qubits.update(edge.vertices) - self._ops[edge.vertices] = ( - 0.0, - PauliString.from_qubits(edge.vertices, [0] * len(edge.vertices)), - ) - self._terms: list[list[Hyperedge]] = [] - - def get_coefficient(self, vertices: tuple[int, ...]) -> float: - """Get the coefficient for an edge in the Hamiltonian. - - Args: - vertices: Tuple of vertex indices identifying the edge. - - Returns: - The coefficient value for the specified edge. - - Raises: - KeyError: If the vertex tuple does not correspond to an edge - in the geometry. - """ - vertices = tuple(sorted(vertices)) - if vertices not in self._ops: - raise KeyError(f"No edge with vertices {vertices} in geometry") - return self._ops[vertices][0] - - def get_pauli_string(self, vertices: tuple[int, ...]) -> PauliString: - """Get the PauliString for an edge in the Hamiltonian. - - Args: - vertices: Tuple of vertex indices identifying the edge. + self._terms: dict[int, list[int]] = {} - Returns: - The PauliString for the specified edge. - - Raises: - KeyError: If the vertex tuple does not correspond to an edge - in the geometry. 
- """ - vertices = tuple(sorted(vertices)) - if vertices not in self._ops: - raise KeyError(f"No edge with vertices {vertices} in geometry") - return self._ops[vertices][1] - - def has_interaction_term(self, vertices: tuple[int, ...]) -> bool: - """Check if an interaction term exists for the given edge vertices. - - Args: - vertices: Tuple of vertex indices identifying the edge. - Returns: - True if an interaction term exists for the edge, False otherwise. - """ - return tuple(sorted(vertices)) in self._ops - - def set_coefficient( + def add_interaction( self, - vertices: tuple[int, ...], - value: float, + edge: Hyperedge, + pauli_string: Sequence[int | str] | str, + coefficient: complex = 1.0, + term: Optional[int] = None, ) -> None: - """Set the coefficient for an edge in the Hamiltonian. + """Add an interaction term to the model. Args: - vertices: Tuple of vertex indices identifying the edge. - value: The coefficient value to set. - - Raises: - KeyError: If the vertex tuple does not correspond to an edge - in the geometry. + edge: The Hyperedge representing the qubits involved in the interaction. + pauli_string: The PauliString operator for this interaction. + coefficient: The complex coefficient multiplying this term (default 1.0). """ - vertices = tuple(sorted(vertices)) - if vertices not in self._ops: - raise KeyError(f"No edge with vertices {vertices} in geometry") - self._ops[vertices] = (value, self._ops[vertices][1]) - - def set_pauli_string( - self, - vertices: tuple[int, ...], - pauli_string: PauliString, - ) -> None: - """Set the PauliString for an edge in the Hamiltonian. - - Args: - vertices: Tuple of vertex indices identifying the edge. - pauli_string: The PauliString to associate with this edge. - - Raises: - KeyError: If the vertex tuple does not correspond to an edge - in the geometry. 
- """ - vertices = tuple(sorted(vertices)) - if vertices not in self._ops: - raise KeyError(f"No edge with vertices {vertices} in geometry") - self._ops[vertices] = (self._ops[vertices][0], pauli_string) - - def set_operator( - self, - vertices: tuple[int, ...], - value: float, - pauli_string: PauliString, - ) -> None: - """Set both the coefficient and PauliString for an edge. - - Convenience method that combines :meth:`set_coefficient` and - :meth:`set_pauli_string` in a single call. - - Args: - vertices: Tuple of vertex indices identifying the edge. - value: The coefficient value to set. - pauli_string: The PauliString to associate with this edge. - - Raises: - KeyError: If the vertex tuple does not correspond to an edge - in the geometry. - """ - vertices = tuple(sorted(vertices)) - if vertices not in self._ops: - raise KeyError(f"No edge with vertices {vertices} in geometry") - self._ops[vertices] = (value, pauli_string) - - def add_term(self, edges: list[Hyperedge]) -> None: - """Add a term grouping to the model. - - Appends a list of hyperedges as a term. Terms are used for - grouping edges for Trotterization or parallel execution. - - Args: - edges: List of Hyperedge objects to group as a term. - """ - self._terms.append(list(edges)) - - def terms(self) -> list[list[Hyperedge]]: - """Return the list of term groupings. - - Returns: - List of lists of Hyperedges representing term groupings. 
- """ - return self._terms + if edge not in self.geometry.edges(): + raise ValueError("Edge is not part of the model geometry.") + s = PauliString.from_qubits(edge.vertices, pauli_string, coefficient) + self._ops.append(s) + if term is not None: + if term not in self._terms: + self._terms[term] = [] + self._terms[term].append(len(self._ops) - 1) + + @property + def nqubits(self) -> int: + """Return the number of qubits in the model.""" + return len(self._qubits) + + @property + def nterms(self) -> int: + """Return the number of term groups in the model.""" + return len(self._terms) def __str__(self) -> str: """String representation of the model.""" @@ -203,39 +112,26 @@ def __repr__(self) -> str: return self.__str__() -def translation_invariant_ising_model( - geometry: Hypergraph, h: float, J: float -) -> Model: - """Create a translation-invariant Ising model on the given geometry. +class IsingModel(Model): + """Translation-invariant Ising model on a hypergraph geometry. The Hamiltonian is: H = -J * Σ_{} Z_i Z_j - h * Σ_i X_i - Two-body edges (len=2) in the geometry represent ZZ interactions with - coefficient -J. Single-vertex edges (len=1) represent X field terms - with coefficient -h. Edges are grouped into terms by their color - for parallel execution. + - Single-vertex edges define X-field terms with coefficient ``-h``. + - Two-vertex edges define ZZ-coupling terms with coefficient ``-J``. + - Terms are grouped into two groups: ``0`` for field terms and ``1`` for + coupling terms. + """ - Args: - geometry: The Hypergraph defining the interaction topology. - Should include single-vertex edges for field terms. - h: The transverse field strength (coefficient for X terms). - J: The coupling strength (coefficient for ZZ interaction terms). 
+ def __init__(self, geometry: Hypergraph, h: float, J: float): + super().__init__(geometry) + self.coloring: HypergraphEdgeColoring = geometry.edge_coloring() + self._terms = {0: [], 1: []} - Returns: - A Model instance representing the Ising Hamiltonian. - """ - model = Model(geometry) - model._terms = [ - [] for _ in range(geometry.ncolors + 1) - ] # Initialize term groupings based on edge colors - for edge in geometry.edges(): - vertices = edge.vertices - if len(vertices) == 1: - model.set_operator(vertices, -h, PauliString.from_qubits(vertices, "X")) - elif len(vertices) == 2: - model.set_operator(vertices, -J, PauliString.from_qubits(vertices, "ZZ")) - color = geometry.color[vertices] - model._terms[color].append(edge) # Group edges by color for parallel execution - - return model + for edge in geometry.edges(): + vertices = edge.vertices + if len(vertices) == 1: + self.add_interaction(edge, "X", -h, term=0) + elif len(vertices) == 2: + self.add_interaction(edge, "ZZ", -J, term=1) diff --git a/source/pip/qsharp/magnets/utilities/__init__.py b/source/pip/qsharp/magnets/utilities/__init__.py index 10c2ebdf69..b350f7da40 100644 --- a/source/pip/qsharp/magnets/utilities/__init__.py +++ b/source/pip/qsharp/magnets/utilities/__init__.py @@ -7,16 +7,20 @@ the magnets package, including hypergraph representations. 
""" -from .hypergraph import Hyperedge, Hypergraph, greedy_edge_coloring +from .hypergraph import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) from .pauli import Pauli, PauliString, PauliX, PauliY, PauliZ __all__ = [ "Hyperedge", "Hypergraph", + "HypergraphEdgeColoring", "Pauli", "PauliString", "PauliX", "PauliY", "PauliZ", - "greedy_edge_coloring", ] diff --git a/source/pip/qsharp/magnets/utilities/hypergraph.py b/source/pip/qsharp/magnets/utilities/hypergraph.py index 706ef9a1b5..efbe7c6ad3 100644 --- a/source/pip/qsharp/magnets/utilities/hypergraph.py +++ b/source/pip/qsharp/magnets/utilities/hypergraph.py @@ -9,7 +9,6 @@ Hamiltonians, where multi-body interactions can involve more than two sites. """ -from copy import deepcopy import random from typing import Iterator, Optional @@ -45,6 +44,9 @@ def __init__(self, vertices: list[int]) -> None: """ self.vertices: tuple[int, ...] = tuple(sorted(set(vertices))) + def __str__(self) -> str: + return str(self.vertices) + def __repr__(self) -> str: return f"Hyperedge({list(self.vertices)})" @@ -57,11 +59,12 @@ class Hypergraph: various lattice geometries used in quantum simulations. Attributes: - _edge_list: List of hyperedges in the order they were added. + _edge_set: Set of hyperedges in the hypergraph. _vertex_set: Set of all unique vertex indices in the hypergraph. - color: Dictionary mapping edge vertex tuples to color indices. Initially - all edges have color index 0. This is useful for parallelism in - certain architectures. + + Note: + Edge colors are managed separately by :class:`HypergraphEdgeColoring`. + Use :meth:`edge_coloring` to generate a coloring for this hypergraph. Example: @@ -81,40 +84,15 @@ def __init__(self, edges: list[Hyperedge]) -> None: edges: List of hyperedges defining the hypergraph structure. 
""" self._vertex_set = set() - self._edge_list = edges - self.color: dict[tuple[int, ...], int] = {} # All edges start with color 0 + self._edge_set = set(edges) for edge in edges: self._vertex_set.update(edge.vertices) - self.color[edge.vertices] = 0 - - @property - def ncolors(self) -> int: - """Return the number of distinct colors used in the edge coloring.""" - return len(set(self.color.values())) - - @property - def nedges(self) -> int: - """Return the number of hyperedges in the hypergraph.""" - return len(self._edge_list) @property def nvertices(self) -> int: """Return the number of vertices in the hypergraph.""" return len(self._vertex_set) - def add_edge(self, edge: Hyperedge, color: int = 0) -> None: - """Add a hyperedge to the hypergraph. - - Args: - edge: The Hyperedge instance to add. - color: Color index for the edge, used for implementations - with edge coloring for parallel updates. By - default, all edges are assigned color 0. - """ - self._edge_list.append(edge) - self._vertex_set.update(edge.vertices) - self.color[edge.vertices] = color - def vertices(self) -> Iterator[int]: """Iterate over all vertex indices in the hypergraph. @@ -123,118 +101,215 @@ def vertices(self) -> Iterator[int]: """ return iter(sorted(self._vertex_set)) + @property + def nedges(self) -> int: + """Return the number of hyperedges in the hypergraph.""" + return len(self._edge_set) + def edges(self) -> Iterator[Hyperedge]: """Iterate over all hyperedges in the hypergraph. Returns: Iterator of all hyperedges in the hypergraph. """ - return iter(self._edge_list) + return iter(self._edge_set) - def edges_by_color(self, color: int) -> Iterator[Hyperedge]: - """Iterate over hyperedges with a specific color. + def add_edge(self, edge: Hyperedge) -> None: + """Add a hyperedge to the hypergraph. Args: - color: Color index for filtering edges. + edge: The Hyperedge instance to add. 
+ """ + self._edge_set.add(edge) + self._vertex_set.update(edge.vertices) + + def edge_coloring( + self, seed: Optional[int] = 0, trials: int = 1 + ) -> "HypergraphEdgeColoring": + """Compute a (nondeterministic) greedy edge coloring of this hypergraph. + + Args: + seed: Optional random seed for reproducibility. + trials: Number of randomized trials to attempt. The best coloring + (fewest colors) is returned. Returns: - Iterator of hyperedges with the specified color. + A :class:`HypergraphEdgeColoring` for this hypergraph. """ - return iter( - [edge for edge in self._edge_list if self.color[edge.vertices] == color] - ) + all_edges = sorted(self.edges(), key=lambda edge: edge.vertices) + + if not all_edges: + return HypergraphEdgeColoring(self) + + num_trials = max(trials, 1) + best_coloring: Optional[HypergraphEdgeColoring] = None + least_colors: Optional[int] = None + + for trial in range(num_trials): + trial_seed = None if seed is None else seed + trial + rng = random.Random(trial_seed) + + edge_order = list(all_edges) + rng.shuffle(edge_order) + + coloring = HypergraphEdgeColoring(self) + num_colors = 0 + + for edge in edge_order: + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + assigned = False + for color in range(num_colors): + used_vertices = set().union( + *( + candidate.vertices + for candidate in coloring.edges_of_color(color) + ) + ) + if not any(vertex in used_vertices for vertex in edge.vertices): + coloring.add_edge(edge, color) + assigned = True + break + + if not assigned: + coloring.add_edge(edge, num_colors) + num_colors += 1 + + if least_colors is None or coloring.ncolors < least_colors: + least_colors = coloring.ncolors + best_coloring = coloring + + assert best_coloring is not None + return best_coloring def __str__(self) -> str: return f"Hypergraph with {self.nvertices} vertices and {self.nedges} edges." 
def __repr__(self) -> str: - return f"Hypergraph({list(self._edge_list)})" - - -def greedy_edge_coloring( - hypergraph: Hypergraph, # The hypergraph to color. - seed: Optional[int] = None, # Random seed for reproducibility. - trials: int = 1, # Number of trials to perform. -) -> Hypergraph: - """Perform a (nondeterministic) greedy edge coloring of the hypergraph. - Args: - hypergraph: The Hypergraph instance to color. - seed: Optional random seed for reproducibility. - trials: Number of trials to perform. The coloring with the fewest colors - will be returned. Default is 1. - - Returns: - A Hypergraph where each (hyper)edge is assigned a color - such that no two (hyper)edges sharing a vertex have the - same color. - """ + return f"Hypergraph({list(self._edge_set)})" - best = Hypergraph(hypergraph._edge_list) # Placeholder for best coloring found - if seed is not None: - random.seed(seed) +class HypergraphEdgeColoring: + """Edge-color assignment for a :class:`Hypergraph`. - # Shuffle edge indices to randomize insertion order - edge_indexes = list(range(hypergraph.nedges)) - random.shuffle(edge_indexes) + This class stores colors separately from :class:`Hypergraph` and enforces + the rule that multi-vertex edges sharing a color do not share any vertices. - used_vertices: list[set[int]] = [set()] # Vertices used by each color - num_colors = 1 + Conventions: - for i in range(len(edge_indexes)): - edge = hypergraph._edge_list[edge_indexes[i]] - for j in range(num_colors + 1): + - Colors for nontrivial edges must be nonnegative integers. + - Single-vertex edges may use a special color (for example ``-1``). + - Only nonnegative colors contribute to :attr:`ncolors`. - # If we've reached a new color, add it - if j == num_colors: - used_vertices.append(set()) - num_colors += 1 + Note: + Colors are keyed by edge vertex tuples (``edge.vertices``), not by + ``Hyperedge`` object identity. 
As a result, :meth:`color` accepts any + ``Hyperedge`` with matching vertices, while :meth:`add_edge` still + requires an edge instance that belongs to :attr:`hypergraph`. - # Check if this edge can be added to color j - # Note that we always match on the last color if it was added - # if so, add it and break - if not any(v in used_vertices[j] for v in edge.vertices): - best.color[edge.vertices] = j - used_vertices[j].update(edge.vertices) - break + Attributes: + hypergraph: The supporting :class:`Hypergraph` whose edges can be + colored by this instance. + """ - least_colors = num_colors + def __init__(self, hypergraph: Hypergraph) -> None: + self.hypergraph = hypergraph + self._colors: dict[tuple[int, ...], int] = {} # Vertices-to-color mapping + self._used_vertices: dict[int, set[int]] = ( + {} + ) # Set of vertices used by each color - # To do: parallelize over trials - for trial in range(1, trials): + @property + def ncolors(self) -> int: + """Return the number of distinct nonnegative colors in the coloring.""" + return len(self._used_vertices) + + def color(self, edge: Hyperedge) -> Optional[int]: + """Return the color assigned to a specific edge. + + Args: + edge: Hyperedge to query. Any ``Hyperedge`` with the same + ``vertices`` tuple resolves to the same stored color. - # Set random seed for reproducibility - # Designed to work with parallel trials - if seed is not None: - random.seed(seed + trial) + Returns: + The color assigned to ``edge``, or ``None`` if the edge has not + been added to this coloring. + """ + if not isinstance(edge, Hyperedge): + raise TypeError(f"edge must be Hyperedge, got {type(edge).__name__}") + return self._colors.get(edge.vertices) - # Shuffle edge indices to randomize insertion order - edge_indexes = list(range(hypergraph.nedges)) - random.shuffle(edge_indexes) + def colors(self) -> Iterator[int]: + """Iterate over distinct nonnegative colors present in the coloring. 
- edge_colors: dict[tuple[int, ...], int] = {} # Edge to color mapping - used_vertices = [set()] # Vertices used by each color - num_colors = 1 + Returns: + Iterator of distinct nonnegative color indices. + """ + return iter(self._used_vertices.keys()) - for i in range(len(edge_indexes)): - edge = hypergraph._edge_list[edge_indexes[i]] - for j in range(num_colors + 1): + def add_edge(self, edge: Hyperedge, color: int) -> None: + """Add ``edge`` to this coloring with the specified ``color``. - # If we've reached a new color, add it - if j == num_colors: - used_vertices.append(set()) - num_colors += 1 + For multi-vertex edges, this enforces that no previously added edge + with the same color shares a vertex with ``edge``. - # Check if this edge can be added to color j - # if so, add it and break - if not any(v in used_vertices[j] for v in edge.vertices): - edge_colors[edge.vertices] = j - used_vertices[j].update(edge.vertices) - break + Args: + edge: The Hyperedge instance to add. This must be an edge present + in :attr:`hypergraph` (typically one returned by + ``hypergraph.edges()``). + color: Color index for the edge. + + Raises: + TypeError: If ``edge`` is not a :class:`Hyperedge`. + ValueError: If ``edge`` is not part of :attr:`hypergraph`. + ValueError: If ``color`` is negative for a nontrivial edge. + RuntimeError: If adding ``edge`` would create a same-color vertex + conflict. + """ + if not isinstance(edge, Hyperedge): + raise TypeError(f"edge must be Hyperedge, got {type(edge).__name__}") + + if edge not in self.hypergraph.edges(): + raise ValueError("edge must belong to the supporting Hypergraph") + + vertices = edge.vertices + + if len(vertices) == 1: + # Single-vertex edges can be colored with a special color (e.g., -1) + self._colors[vertices] = color + else: + if color < 0: + raise ValueError( + "Color index must be nonnegative for multi-vertex edges." 
+ ) + if color not in self._used_vertices: + self._colors[vertices] = color + self._used_vertices[color] = set(vertices) + else: + if any(v in self._used_vertices[color] for v in vertices): + raise RuntimeError( + "Edge conflicts with existing edge of same color." + ) + self._colors[vertices] = color + self._used_vertices[color].update(vertices) + + self._colors[vertices] = color + + def edges_of_color(self, color: int) -> Iterator[Hyperedge]: + """Iterate over hyperedges with a specific color. - # If this trial used fewer colors, update best - if num_colors < least_colors: - least_colors = num_colors - best.color = deepcopy(edge_colors) + Args: + color: Color index for filtering edges. - return best + Returns: + Iterator of edges currently assigned to ``color``. + """ + return iter( + [ + edge + for edge in self.hypergraph.edges() + if self._colors.get(edge.vertices) == color + ] + ) diff --git a/source/pip/qsharp/magnets/utilities/pauli.py b/source/pip/qsharp/magnets/utilities/pauli.py index 82c936778f..4708cb67d4 100644 --- a/source/pip/qsharp/magnets/utilities/pauli.py +++ b/source/pip/qsharp/magnets/utilities/pauli.py @@ -14,28 +14,28 @@ class Pauli: - """A single-qubit Pauli operator (I, X, Y, or Z) acting on a specific qubit. + """Single-qubit Pauli term tied to an explicit qubit index. - Can be constructed from an integer (0–3) or a string ('I', 'X', 'Y', 'Z'), - along with the index of the qubit it acts on. + ``Pauli`` stores a Pauli identifier and the qubit it acts on. The Pauli + identifier can be provided either as an integer code or a label: - Mapping: - 0 / 'I' → Identity - 1 / 'X' → Pauli-X - 2 / 'Z' → Pauli-Z - 3 / 'Y' → Pauli-Y + - ``0`` / ``"I"`` + - ``1`` / ``"X"`` + - ``2`` / ``"Z"`` + - ``3`` / ``"Y"`` - Attributes: - qubit: The qubit index this operator acts on. + Note: + The integer mapping follows the internal QDK convention where ``2`` is + ``Z`` and ``3`` is ``Y``. Example: .. 
code-block:: python - >>> p = Pauli('X', 0) + >>> p = Pauli("Y", qubit=2) >>> p.op - 1 + 3 >>> p.qubit - 0 + 2 """ _VALID_INTS = {0, 1, 2, 3} @@ -45,15 +45,15 @@ def __init__(self, value: int | str, qubit: int = 0) -> None: """Initialize a Pauli operator. Args: - value: An integer 0–3 or one of 'I', 'X', 'Y', 'Z' (case-insensitive). + value: An integer 0-3 or one of 'I', 'X', 'Y', 'Z' (case-insensitive). qubit: The index of the qubit this operator acts on. Defaults to 0. Raises: - ValueError: If the value is not a recognized Pauli identifier. + ValueError: If ``value`` is not a valid integer/string Pauli identifier. """ if isinstance(value, int): if value not in self._VALID_INTS: - raise ValueError(f"Integer value must be 0–3, got {value}.") + raise ValueError(f"Integer value must be 0-3, got {value}.") self._op = value elif isinstance(value, str): key = value.upper() @@ -68,13 +68,17 @@ def __init__(self, value: int | str, qubit: int = 0) -> None: @property def op(self) -> int: - """Return the integer representation of this Pauli operator. + """Integer encoding of this Pauli term. Returns: - 0 for I, 1 for X, 2 for Z, 3 for Y. + ``0`` for ``I``, ``1`` for ``X``, ``2`` for ``Z``, ``3`` for ``Y``. """ return self._op + def __str__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + return f"{labels[self._op]}({self.qubit})" + def __repr__(self) -> str: labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} return f"Pauli('{labels[self._op]}', qubit={self.qubit})" @@ -89,10 +93,11 @@ def __hash__(self) -> int: @property def cirq(self): - """Return the corresponding Cirq Pauli operator. + """Return this Pauli term as a Cirq gate operation on ``LineQubit``. Returns: - ``cirq.I``, ``cirq.X``, ``cirq.Z``, or ``cirq.Y``. + A Cirq operation equivalent to + ``cirq.{I|X|Z|Y}.on(cirq.LineQubit(self.qubit))``. 
""" _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) return _INT_TO_CIRQ[self._op].on(cirq.LineQubit(self.qubit)) @@ -114,35 +119,36 @@ def PauliZ(qubit: int) -> Pauli: class PauliString: - """A multi-qubit Pauli operator acting on specific qubits. + """Ordered tensor product of single-qubit ``Pauli`` terms with a coefficient. + + ``PauliString`` stores: - Stores a tuple of :class:`Pauli` objects, each carrying its own qubit index. - Can be constructed from a sequence of ``Pauli`` instances (default), or via - the :meth:`from_qubits` class method which takes qubit indices and Pauli - labels separately. + - an ordered tuple of :class:`Pauli` objects (including each term's qubit), and + - a complex scalar coefficient. - Attributes: - _paulis: Tuple of Pauli objects defining the operator on each qubit. + Construction options: + + - pass a sequence of :class:`Pauli` objects to ``PauliString(...)`` + - use :meth:`from_qubits` to pair qubit indices with Pauli labels/codes Example: .. code-block:: python - >>> ps = PauliString([PauliX(0), PauliZ(1)]) + >>> ps = PauliString([PauliX(0), PauliZ(1)], coefficient=-1j) >>> ps.qubits (0, 1) - >>> list(ps) - [Pauli(X, qubit=0), Pauli(Z, qubit=1)] - >>> ps2 = PauliString.from_qubits((0, 1), "XZ") + >>> ps2 = PauliString.from_qubits((0, 1), "XZ", coefficient=-1j) >>> ps == ps2 True """ - def __init__(self, paulis: Sequence[Pauli]) -> None: + def __init__(self, paulis: Sequence[Pauli], coefficient: complex = 1.0) -> None: """Initialize a PauliString from a sequence of Pauli operators. Args: paulis: A sequence of :class:`Pauli` instances, each with its own qubit index. + coefficient: Complex coefficient multiplying the Pauli string (default 1.0). Raises: TypeError: If any element is not a Pauli instance. @@ -154,20 +160,23 @@ def __init__(self, paulis: Sequence[Pauli]) -> None: "Use PauliString.from_qubits() for int/str values." ) self._paulis: tuple[Pauli, ...] 
= tuple(paulis) + self._coefficient: complex = coefficient @classmethod def from_qubits( cls, qubits: tuple[int, ...], values: Sequence[int | str] | str, + coefficient: complex = 1.0, ) -> "PauliString": """Create a PauliString from qubit indices and Pauli labels. Args: qubits: Tuple of qubit indices. - values: Sequence of Pauli identifiers (integers 0–3 or strings + values: Sequence of Pauli identifiers (integers 0-3 or strings 'I', 'X', 'Y', 'Z'). A plain string like ``"XZI"`` is also accepted and treated as individual characters. + coefficient: Complex coefficient multiplying the Pauli string. Returns: A new PauliString instance. @@ -181,11 +190,11 @@ def from_qubits( f"Length mismatch: {len(qubits)} qubits vs {len(values)} values." ) paulis = [Pauli(v, q) for q, v in zip(qubits, values)] - return cls(paulis) + return cls(paulis, coefficient=coefficient) @property def qubits(self) -> tuple[int, ...]: - """Return the tuple of qubit indices. + """Tuple of qubit indices in the same order as the stored Pauli terms. Returns: Tuple of qubit indices, one per Pauli operator. @@ -193,7 +202,7 @@ def qubits(self) -> tuple[int, ...]: return tuple(p.qubit for p in self._paulis) def __iter__(self): - """Iterate over the Pauli operators in this PauliString. + """Iterate over Pauli terms in stored order. Yields: :class:`Pauli` instances in order. 
@@ -206,18 +215,23 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Pauli: return self._paulis[index] + def __str__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + s = "".join(map(str, self._paulis)) + return f"{self._coefficient} * {s}" + def __repr__(self) -> str: labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} s = "".join(labels[p.op] for p in self._paulis) - return f"PauliString(qubits={self.qubits}, ops='{s}')" + return f"PauliString(qubits={self.qubits}, ops='{s}', coefficient={self._coefficient})" def __eq__(self, other: object) -> bool: if not isinstance(other, PauliString): return NotImplemented - return self._paulis == other._paulis + return self._paulis == other._paulis and self._coefficient == other._coefficient def __hash__(self) -> int: - return hash(self._paulis) + return hash((self._paulis, self._coefficient)) @property def cirq(self): @@ -227,9 +241,11 @@ def cirq(self): Pauli to its corresponding ``cirq.LineQubit``. Returns: - A ``cirq.PauliString`` acting on ``cirq.LineQubit`` instances. + A ``cirq.PauliString`` on ``cirq.LineQubit`` instances with + ``self._coefficient`` as its coefficient. 
""" _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) return cirq.PauliString( - {cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] for p in self._paulis} + {cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] for p in self._paulis}, + coefficient=self._coefficient, ) diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py index 614d030c50..d49c668e63 100644 --- a/source/pip/tests/magnets/test_complete.py +++ b/source/pip/tests/magnets/test_complete.py @@ -3,7 +3,11 @@ """Unit tests for complete graph data structures.""" -from qsharp.magnets.geometry.complete import CompleteBipartiteGraph, CompleteGraph +from qsharp.magnets.geometry.complete import ( + CompleteBipartiteGraph, + CompleteGraph, +) +from qsharp.magnets.utilities import HypergraphEdgeColoring # CompleteGraph tests @@ -79,11 +83,8 @@ def test_complete_graph_with_self_loops(): def test_complete_graph_self_loops_edges(): """Test that self-loop edges are created correctly.""" graph = CompleteGraph(3, self_loops=True) - edges = list(graph.edges()) - # First 3 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) + edge_vertices = {edge.vertices for edge in graph.edges()} + assert {(0,), (1,), (2,)}.issubset(edge_vertices) def test_complete_graph_edge_count_formula(): @@ -183,12 +184,8 @@ def test_complete_bipartite_graph_with_self_loops(): def test_complete_bipartite_graph_self_loops_edges(): """Test that self-loop edges are created correctly.""" graph = CompleteBipartiteGraph(2, 2, self_loops=True) - edges = list(graph.edges()) - # First 4 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) - assert edges[3].vertices == (3,) + edge_vertices = {edge.vertices for edge in graph.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) def test_complete_bipartite_graph_edge_count_formula(): @@ -203,24 +200,30 @@ def 
test_complete_bipartite_graph_edge_count_formula(): def test_complete_bipartite_graph_coloring_without_self_loops(): """Test edge coloring without self-loops.""" graph = CompleteBipartiteGraph(3, 4) + coloring = graph.edge_coloring() # Should have n colors for bipartite coloring - assert graph.ncolors == 4 + assert coloring.ncolors == 4 def test_complete_bipartite_graph_coloring_with_self_loops(): """Test edge coloring with self-loops.""" graph = CompleteBipartiteGraph(3, 4, self_loops=True) + coloring = graph.edge_coloring() # Self-loops get color -1, bipartite edges get n colors (0 to n-1) - # So total distinct colors = n + 1 (including -1) - assert graph.ncolors == 5 + # ncolors counts nonnegative colors only. + assert coloring.ncolors == 4 def test_complete_bipartite_graph_coloring_non_overlapping(): """Test that edges with the same color don't share vertices.""" graph = CompleteBipartiteGraph(3, 4) + coloring = graph.edge_coloring() # Group edges by color colors = {} - for edge_vertices, color in graph.color.items(): + for edge in graph.edges(): + color = coloring.color(edge) + assert color is not None + edge_vertices = edge.vertices if color not in colors: colors[color] = [] colors[color].append(edge_vertices) @@ -247,4 +250,6 @@ def test_complete_bipartite_graph_inherits_hypergraph(): assert isinstance(graph, Hypergraph) assert hasattr(graph, "edges") assert hasattr(graph, "vertices") - assert hasattr(graph, "edges_by_color") + coloring = graph.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index 6df404787c..adf539407a 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -3,10 +3,12 @@ """Unit tests for hypergraph data structures.""" +import pytest + from qsharp.magnets.utilities import ( Hyperedge, Hypergraph, - greedy_edge_coloring, 
+ HypergraphEdgeColoring, ) @@ -113,15 +115,100 @@ def test_hypergraph_edges_iterator(): assert len(edge_list) == 2 -def test_hypergraph_edges_by_color(): - """Test edges_by_color returns edges with a specific color.""" - edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] +def test_hypergraph_edges_of_color(): + """Test HypergraphEdgeColoring returns edges with a specific color.""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] graph = Hypergraph(edges) - # Default: all edges have color 0 - edge_list = list(graph.edges_by_color(0)) + coloring = HypergraphEdgeColoring(graph) + coloring.add_edge(edges[0], 0) + coloring.add_edge(edges[1], 0) + edge_list = list(coloring.edges_of_color(0)) assert len(edge_list) == 2 +def test_hypergraph_edge_coloring_method_returns_coloring(): + """Test Hypergraph.edge_coloring returns a HypergraphEdgeColoring.""" + graph = Hypergraph([Hyperedge([0, 1]), Hyperedge([2, 3])]) + coloring = graph.edge_coloring(seed=42) + assert isinstance(coloring, HypergraphEdgeColoring) + + +def test_hypergraph_edge_coloring_stores_supporting_hypergraph(): + """Test HypergraphEdgeColoring keeps a reference to its Hypergraph.""" + graph = Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + assert coloring.hypergraph is graph + + +def test_hypergraph_edge_coloring_rejects_non_hyperedge(): + """Test add_edge rejects non-Hyperedge values.""" + graph = Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises(TypeError, match="edge must be Hyperedge"): + coloring.add_edge((0, 1), 0) + + +def test_hypergraph_edge_coloring_rejects_edge_not_in_hypergraph(): + """Test add_edge rejects Hyperedge values not in the supporting Hypergraph.""" + graph = Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="edge must belong to the supporting Hypergraph" + ): + coloring.add_edge(Hyperedge([1, 2]), 0) + + +def 
test_hypergraph_edge_coloring_rejects_equivalent_edge_not_in_hypergraph(): + """Test add_edge requires an edge instance from the supporting Hypergraph.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="edge must belong to the supporting Hypergraph" + ): + coloring.add_edge(Hyperedge([0, 1]), 0) + + +def test_hypergraph_edge_coloring_color_matches_equivalent_vertices(): + """Test color lookup uses edge vertices, not Hyperedge object identity.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + coloring.add_edge(edge, 3) + assert coloring.color(Hyperedge([1, 0])) == 3 + + +def test_hypergraph_edge_coloring_rejects_negative_color_for_nontrivial_edge(): + """Test add_edge raises ValueError for negative color on nontrivial edges.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="Color index must be nonnegative for multi-vertex edges" + ): + coloring.add_edge(edge, -1) + + +def test_hypergraph_edge_coloring_rejects_conflicting_edge(): + """Test add_edge raises RuntimeError when same-color edges share a vertex.""" + edge1 = Hyperedge([0, 1]) + edge2 = Hyperedge([1, 2]) + graph = Hypergraph([edge1, edge2]) + coloring = HypergraphEdgeColoring(graph) + + coloring.add_edge(edge1, 0) + + with pytest.raises( + RuntimeError, match="Edge conflicts with existing edge of same color" + ): + coloring.add_edge(edge2, 0) + + def test_hypergraph_add_edge(): """Test adding an edge to the hypergraph.""" graph = Hypergraph([]) @@ -131,21 +218,20 @@ def test_hypergraph_add_edge(): def test_hypergraph_add_edge_with_color(): - """Test adding edges with different colors.""" + """Test assigning colors via HypergraphEdgeColoring.""" graph = Hypergraph([Hyperedge([0, 1])]) - graph.add_edge(Hyperedge([2, 3]), color=1) + edge = Hyperedge([2, 3]) + 
graph.add_edge(edge) + coloring = HypergraphEdgeColoring(graph) + coloring.add_edge(edge, color=1) assert graph.nedges == 2 - assert graph.color[(0, 1)] == 0 - assert graph.color[(2, 3)] == 1 + assert coloring.color(edge) == 1 def test_hypergraph_color_default(): - """Test that default colors are all 0.""" - edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] - graph = Hypergraph(edges) - assert graph.color[(0, 1)] == 0 - assert graph.color[(1, 2)] == 0 - assert graph.color[(2, 3)] == 0 + """Test that Hypergraph has no built-in color mapping.""" + graph = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])]) + assert not hasattr(graph, "color") def test_hypergraph_str(): @@ -200,16 +286,18 @@ def test_hypergraph_non_contiguous_vertices(): def test_greedy_edge_coloring_empty(): """Test greedy edge coloring on empty hypergraph.""" graph = Hypergraph([]) - colored = greedy_edge_coloring(graph) - assert colored.nedges == 0 + colored = graph.edge_coloring() + assert isinstance(colored, HypergraphEdgeColoring) + assert len(list(colored.colors())) == 0 assert colored.ncolors == 0 def test_greedy_edge_coloring_single_edge(): """Test greedy edge coloring with a single edge.""" - graph = Hypergraph([Hyperedge([0, 1])]) - colored = greedy_edge_coloring(graph, seed=42) - assert colored.nedges == 1 + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + colored = graph.edge_coloring(seed=42) + assert colored.color(edge) == 0 assert colored.ncolors == 1 @@ -217,9 +305,10 @@ def test_greedy_edge_coloring_non_overlapping(): """Test coloring of non-overlapping edges (can share color).""" edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # Non-overlapping edges can be in the same color - assert colored.nedges == 2 + assert colored.color(edges[0]) is not None + assert colored.color(edges[1]) is not None assert colored.ncolors == 1 @@ -227,9 
+316,10 @@ def test_greedy_edge_coloring_overlapping(): """Test coloring of overlapping edges (need different colors).""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # Overlapping edges need different colors - assert colored.nedges == 2 + assert colored.color(edges[0]) is not None + assert colored.color(edges[1]) is not None assert colored.ncolors == 2 @@ -237,9 +327,11 @@ def test_greedy_edge_coloring_triangle(): """Test coloring of a triangle (3 edges, all pairwise overlapping).""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # All edges share vertices pairwise, so need 3 colors - assert colored.nedges == 3 + assert colored.color(edges[0]) is not None + assert colored.color(edges[1]) is not None + assert colored.color(edges[2]) is not None assert colored.ncolors == 3 @@ -253,14 +345,16 @@ def test_greedy_edge_coloring_validity(): Hyperedge([0, 4]), ] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # Group edges by color colors = {} - for edge_vertices, color in colored.color.items(): + for edge in edges: + color = colored.color(edge) + assert color is not None if color not in colors: colors[color] = [] - colors[color].append(edge_vertices) + colors[color].append(edge.vertices) # Verify each color group has no overlapping edges for color, edge_list in colors.items(): @@ -275,13 +369,12 @@ def test_greedy_edge_coloring_all_edges_colored(): """Test that all edges are assigned a color.""" edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # All edges should have a color assigned - assert len(colored.color) == 3 - assert 
(0, 1) in colored.color - assert (1, 2) in colored.color - assert (2, 3) in colored.color + assert colored.color(edges[0]) is not None + assert colored.color(edges[1]) is not None + assert colored.color(edges[2]) is not None def test_greedy_edge_coloring_reproducible_with_seed(): @@ -289,10 +382,12 @@ def test_greedy_edge_coloring_reproducible_with_seed(): edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3]), Hyperedge([0, 3])] graph = Hypergraph(edges) - colored1 = greedy_edge_coloring(graph, seed=123) - colored2 = greedy_edge_coloring(graph, seed=123) + colored1 = graph.edge_coloring(seed=123) + colored2 = graph.edge_coloring(seed=123) - assert colored1.color == colored2.color + color_map_1 = {edge.vertices: colored1.color(edge) for edge in edges} + color_map_2 = {edge.vertices: colored2.color(edge) for edge in edges} + assert color_map_1 == color_map_2 def test_greedy_edge_coloring_multiple_trials(): @@ -304,7 +399,7 @@ def test_greedy_edge_coloring_multiple_trials(): Hyperedge([3, 0]), ] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42, trials=10) + colored = graph.edge_coloring(seed=42, trials=10) # A cycle of 4 edges can be 2-colored assert colored.ncolors <= 3 # Greedy may not always find optimal @@ -317,10 +412,12 @@ def test_greedy_edge_coloring_hyperedges(): Hyperedge([5, 6, 7]), ] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) # First two share vertex 2, third is independent - assert colored.nedges == 3 + assert colored.color(edges[0]) is not None + assert colored.color(edges[1]) is not None + assert colored.color(edges[2]) is not None assert colored.ncolors >= 2 @@ -328,8 +425,10 @@ def test_greedy_edge_coloring_self_loops(): """Test coloring with self-loop edges.""" edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] graph = Hypergraph(edges) - colored = greedy_edge_coloring(graph, seed=42) + colored = graph.edge_coloring(seed=42) - # 
Self-loops don't share vertices, can all be same color - assert colored.nedges == 3 - assert colored.ncolors == 1 + # Self-loops use the special -1 color and do not contribute to ncolors. + assert colored.color(edges[0]) == -1 + assert colored.color(edges[1]) == -1 + assert colored.color(edges[2]) == -1 + assert colored.ncolors == 0 diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py index b4bbf152c4..d8553d2b99 100644 --- a/source/pip/tests/magnets/test_lattice1d.py +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -4,6 +4,12 @@ """Unit tests for 1D lattice data structures.""" from qsharp.magnets.geometry.lattice1d import Chain1D, Ring1D +from qsharp.magnets.utilities import Hyperedge, HypergraphEdgeColoring + + +def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: + coloring = graph.edge_coloring() + return {edge.vertices: coloring.color(edge) for edge in graph.edges()} # Chain1D tests @@ -35,12 +41,8 @@ def test_chain1d_two_vertices(): def test_chain1d_edges(): """Test that Chain1D creates correct nearest-neighbor edges.""" chain = Chain1D(4) - edges = list(chain.edges()) - assert len(edges) == 3 - # Check edges are (0,1), (1,2), (2,3) - assert edges[0].vertices == (0, 1) - assert edges[1].vertices == (1, 2) - assert edges[2].vertices == (2, 3) + edge_vertices = {edge.vertices for edge in chain.edges()} + assert edge_vertices == {(0, 1), (1, 2), (2, 3)} def test_chain1d_vertices(): @@ -61,47 +63,47 @@ def test_chain1d_with_self_loops(): def test_chain1d_self_loops_edges(): """Test that self-loop edges are created correctly.""" chain = Chain1D(3, self_loops=True) - edges = list(chain.edges()) - # First 3 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) - # Next 2 edges should be nearest-neighbor - assert edges[3].vertices == (0, 1) - assert edges[4].vertices == (1, 2) + edge_vertices = {edge.vertices for edge in 
chain.edges()} + assert edge_vertices == {(0,), (1,), (2,), (0, 1), (1, 2)} def test_chain1d_coloring_without_self_loops(): """Test edge coloring without self-loops.""" chain = Chain1D(5) + color = _vertex_color_map(chain) # Even edges (0-1, 2-3) should have color 0 - assert chain.color[(0, 1)] == 0 - assert chain.color[(2, 3)] == 0 + assert color[(0, 1)] == 0 + assert color[(2, 3)] == 0 # Odd edges (1-2, 3-4) should have color 1 - assert chain.color[(1, 2)] == 1 - assert chain.color[(3, 4)] == 1 + assert color[(1, 2)] == 1 + assert color[(3, 4)] == 1 def test_chain1d_coloring_with_self_loops(): """Test edge coloring with self-loops.""" chain = Chain1D(4, self_loops=True) + color = _vertex_color_map(chain) # Self-loops should have color -1 - assert chain.color[(0,)] == -1 - assert chain.color[(1,)] == -1 - assert chain.color[(2,)] == -1 - assert chain.color[(3,)] == -1 + assert color[(0,)] == -1 + assert color[(1,)] == -1 + assert color[(2,)] == -1 + assert color[(3,)] == -1 # Even edges should have color 0, odd edges color 1 - assert chain.color[(0, 1)] == 0 - assert chain.color[(1, 2)] == 1 - assert chain.color[(2, 3)] == 0 + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 def test_chain1d_coloring_non_overlapping(): """Test that edges with the same color don't share vertices.""" chain = Chain1D(6) + coloring = chain.edge_coloring() # Group edges by color colors = {} - for edge_vertices, color in chain.color.items(): + for edge in chain.edges(): + color = coloring.color(edge) + assert color is not None + edge_vertices = edge.vertices if color not in colors: colors[color] = [] colors[color].append(edge_vertices) @@ -149,13 +151,8 @@ def test_ring1d_three_vertices(): def test_ring1d_edges(): """Test that Ring1D creates correct edges including wrap-around.""" ring = Ring1D(4) - edges = list(ring.edges()) - assert len(edges) == 4 - # Check edges are (0,1), (1,2), (2,3), (0,3) (sorted) - assert edges[0].vertices == (0, 1) - assert 
edges[1].vertices == (1, 2) - assert edges[2].vertices == (2, 3) - assert edges[3].vertices == (0, 3) # Wrap-around edge + edge_vertices = {edge.vertices for edge in ring.edges()} + assert edge_vertices == {(0, 1), (1, 2), (2, 3), (0, 3)} def test_ring1d_vertices(): @@ -176,48 +173,47 @@ def test_ring1d_with_self_loops(): def test_ring1d_self_loops_edges(): """Test that self-loop edges are created correctly.""" ring = Ring1D(3, self_loops=True) - edges = list(ring.edges()) - # First 3 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) - # Next 3 edges should be nearest-neighbor (including wrap) - assert edges[3].vertices == (0, 1) - assert edges[4].vertices == (1, 2) - assert edges[5].vertices == (0, 2) # Wrap-around + edge_vertices = {edge.vertices for edge in ring.edges()} + assert edge_vertices == {(0,), (1,), (2,), (0, 1), (1, 2), (0, 2)} def test_ring1d_coloring_without_self_loops(): """Test edge coloring without self-loops.""" ring = Ring1D(4) + color = _vertex_color_map(ring) # Even edges should have color 0, odd edges color 1 - assert ring.color[(0, 1)] == 0 - assert ring.color[(1, 2)] == 1 - assert ring.color[(2, 3)] == 0 - assert ring.color[(0, 3)] == 1 # Wrap-around edge (index 3) + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 + assert color[(0, 3)] == 1 # Wrap-around edge def test_ring1d_coloring_with_self_loops(): """Test edge coloring with self-loops.""" ring = Ring1D(4, self_loops=True) + color = _vertex_color_map(ring) # Self-loops should have color -1 - assert ring.color[(0,)] == -1 - assert ring.color[(1,)] == -1 - assert ring.color[(2,)] == -1 - assert ring.color[(3,)] == -1 + assert color[(0,)] == -1 + assert color[(1,)] == -1 + assert color[(2,)] == -1 + assert color[(3,)] == -1 # Even edges should have color 0, odd edges color 1 - assert ring.color[(0, 1)] == 0 - assert ring.color[(1, 2)] == 1 - assert ring.color[(2, 3)] == 0 - 
assert ring.color[(0, 3)] == 1 + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 + assert color[(0, 3)] == 1 def test_ring1d_coloring_non_overlapping(): """Test that edges with the same color don't share vertices.""" ring = Ring1D(6) + coloring = ring.edge_coloring() # Group edges by color colors = {} - for edge_vertices, color in ring.color.items(): + for edge in ring.edges(): + color = coloring.color(edge) + assert color is not None + edge_vertices = edge.vertices if color not in colors: colors[color] = [] colors[color].append(edge_vertices) @@ -253,7 +249,9 @@ def test_chain1d_inherits_hypergraph(): # Test inherited methods work assert hasattr(chain, "edges") assert hasattr(chain, "vertices") - assert hasattr(chain, "edges_by_color") + coloring = chain.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") def test_ring1d_inherits_hypergraph(): @@ -265,4 +263,6 @@ def test_ring1d_inherits_hypergraph(): # Test inherited methods work assert hasattr(ring, "edges") assert hasattr(ring, "vertices") - assert hasattr(ring, "edges_by_color") + coloring = ring.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_lattice2d.py b/source/pip/tests/magnets/test_lattice2d.py index d629975227..5f85fda913 100644 --- a/source/pip/tests/magnets/test_lattice2d.py +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -4,6 +4,12 @@ """Unit tests for 2D lattice data structures.""" from qsharp.magnets.geometry.lattice2d import Patch2D, Torus2D +from qsharp.magnets.utilities import HypergraphEdgeColoring + + +def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: + coloring = graph.edge_coloring() + return {edge.vertices: coloring.color(edge) for edge in graph.edges()} # Patch2D tests @@ -84,34 +90,36 @@ def test_patch2d_with_self_loops(): def test_patch2d_self_loops_edges(): 
"""Test that self-loop edges are created correctly.""" patch = Patch2D(2, 2, self_loops=True) - edges = list(patch.edges()) - # First 4 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) - assert edges[3].vertices == (3,) + edge_vertices = {edge.vertices for edge in patch.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) def test_patch2d_coloring_without_self_loops(): """Test edge coloring without self-loops.""" patch = Patch2D(4, 4) + coloring = patch.edge_coloring() # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) - assert patch.ncolors == 4 + assert coloring.ncolors == 4 def test_patch2d_coloring_with_self_loops(): """Test edge coloring with self-loops.""" patch = Patch2D(3, 3, self_loops=True) - # Should have 5 colors: self-loops (-1) + 4 edge groups (0-3) - assert patch.ncolors == 5 + coloring = patch.edge_coloring() + # Self-loops are -1 and do not contribute to ncolors. 
+ assert coloring.ncolors == 4 def test_patch2d_coloring_non_overlapping(): """Test that edges with the same color don't share vertices.""" patch = Patch2D(4, 4) + coloring = patch.edge_coloring() # Group edges by color colors = {} - for edge_vertices, color in patch.color.items(): + for edge in patch.edges(): + color = coloring.color(edge) + assert color is not None + edge_vertices = edge.vertices if color not in colors: colors[color] = [] colors[color].append(edge_vertices) @@ -126,8 +134,7 @@ def test_patch2d_coloring_non_overlapping(): def test_patch2d_str(): """Test string representation.""" patch = Patch2D(3, 2) - assert "6 vertices" in str(patch) - assert "7 edges" in str(patch) + assert str(patch) == "3x2 lattice patch with 6 vertices and 7 edges" # Torus2D tests @@ -209,34 +216,36 @@ def test_torus2d_with_self_loops(): def test_torus2d_self_loops_edges(): """Test that self-loop edges are created correctly.""" torus = Torus2D(2, 2, self_loops=True) - edges = list(torus.edges()) - # First 4 edges should be self-loops - assert edges[0].vertices == (0,) - assert edges[1].vertices == (1,) - assert edges[2].vertices == (2,) - assert edges[3].vertices == (3,) + edge_vertices = {edge.vertices for edge in torus.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) def test_torus2d_coloring_without_self_loops(): """Test edge coloring without self-loops.""" torus = Torus2D(4, 4) + coloring = torus.edge_coloring() # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) - assert torus.ncolors == 4 + assert coloring.ncolors == 4 def test_torus2d_coloring_with_self_loops(): """Test edge coloring with self-loops.""" torus = Torus2D(3, 3, self_loops=True) - # Should have 5 colors: self-loops (-1) + 4 edge groups (0-3) - assert torus.ncolors == 5 + coloring = torus.edge_coloring() + # Odd periodic dimensions require dedicated wrap colors. 
+ assert coloring.ncolors == 6 def test_torus2d_coloring_non_overlapping(): """Test that edges with the same color don't share vertices.""" torus = Torus2D(4, 4) + coloring = torus.edge_coloring() # Group edges by color colors = {} - for edge_vertices, color in torus.color.items(): + for edge in torus.edges(): + color = coloring.color(edge) + assert color is not None + edge_vertices = edge.vertices if color not in colors: colors[color] = [] colors[color].append(edge_vertices) @@ -251,8 +260,7 @@ def test_torus2d_coloring_non_overlapping(): def test_torus2d_str(): """Test string representation.""" torus = Torus2D(3, 2) - assert "6 vertices" in str(torus) - assert "12 edges" in str(torus) + assert str(torus) == "3x2 lattice torus with 6 vertices and 12 edges" def test_torus2d_vs_patch2d_edge_count(): @@ -274,7 +282,9 @@ def test_patch2d_inherits_hypergraph(): # Test inherited methods work assert hasattr(patch, "edges") assert hasattr(patch, "vertices") - assert hasattr(patch, "edges_by_color") + coloring = patch.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") def test_torus2d_inherits_hypergraph(): @@ -286,4 +296,6 @@ def test_torus2d_inherits_hypergraph(): # Test inherited methods work assert hasattr(torus, "edges") assert hasattr(torus, "vertices") - assert hasattr(torus, "edges_by_color") + coloring = torus.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py index d86d3924c0..5b9b48c521 100644 --- a/source/pip/tests/magnets/test_model.py +++ b/source/pip/tests/magnets/test_model.py @@ -3,402 +3,144 @@ # pyright: reportPrivateImportUsage=false -"""Unit tests for the Model class.""" +"""Unit tests for the Model classes.""" from __future__ import annotations import pytest -from qsharp.magnets.models import Model -from qsharp.magnets.utilities 
import Hyperedge, Hypergraph, PauliString +from qsharp.magnets.models import IsingModel, Model +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, + PauliString, +) def make_chain(length: int) -> Hypergraph: - """Create a simple chain hypergraph for testing.""" edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] return Hypergraph(edges) def make_chain_with_vertices(length: int) -> Hypergraph: - """Create a chain hypergraph with single-vertex (field) edges for testing.""" edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] - # Add single-vertex edges for field terms edges.extend([Hyperedge([i]) for i in range(length)]) return Hypergraph(edges) -# Model initialization tests +class CountingColoringHypergraph(Hypergraph): + def __init__(self, edges: list[Hyperedge]): + super().__init__(edges) + self.edge_coloring_calls = 0 + + def edge_coloring(self): + self.edge_coloring_calls += 1 + return super().edge_coloring() def test_model_init_basic(): - """Test basic Model initialization.""" geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2])]) model = Model(geometry) assert model.geometry is geometry - assert len(model.terms()) == 0 - - -def test_model_init_with_chain(): - """Test Model initialization with chain geometry.""" - geometry = make_chain(5) - model = Model(geometry) - assert len(model._qubits) == 5 + assert model.nqubits == 3 + assert model.nterms == 0 + assert model._ops == [] + assert model._terms == {} def test_model_init_empty_geometry(): - """Test Model with empty geometry.""" - geometry = Hypergraph([]) - model = Model(geometry) - assert len(model._qubits) == 0 - assert len(model.terms()) == 0 - - -def test_model_init_coefficients_zero(): - """Test that coefficients are initialized to zero.""" - geometry = make_chain(3) # edges: (0,1), (1,2) - model = Model(geometry) - assert model.get_coefficient((0, 1)) == 0.0 - assert model.get_coefficient((1, 2)) == 0.0 - - -def 
test_model_init_pauli_strings_identity(): - """Test that PauliStrings are initialized to identity.""" - geometry = make_chain(3) # edges: (0,1), (1,2) - model = Model(geometry) - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "II") - assert model.get_pauli_string((1, 2)) == PauliString.from_qubits((1, 2), "II") - - -# Coefficient tests - - -def test_model_set_coefficient(): - """Test setting coefficient for an edge.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_coefficient((0, 1), 1.5) - assert model.get_coefficient((0, 1)) == 1.5 - - -def test_model_set_coefficient_overwrite(): - """Test overwriting an existing coefficient.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_coefficient((0, 1), 1.5) - model.set_coefficient((0, 1), 2.5) - assert model.get_coefficient((0, 1)) == 2.5 - - -def test_model_set_coefficient_invalid_edge(): - """Test setting coefficient for non-existent edge raises error.""" - geometry = make_chain(2) - model = Model(geometry) - with pytest.raises(KeyError): - model.set_coefficient((0, 2), 1.0) - - -def test_model_get_coefficient_invalid_edge(): - """Test getting coefficient for non-existent edge raises error.""" - geometry = make_chain(2) - model = Model(geometry) - with pytest.raises(KeyError): - model.get_coefficient((0, 2)) - - -def test_model_get_coefficient_sorted(): - """Test that get_coefficient sorts vertices so order doesn't matter.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_coefficient((0, 1), 3.0) - assert model.get_coefficient((1, 0)) == 3.0 - - -def test_model_set_coefficient_sorted(): - """Test that set_coefficient sorts vertices so order doesn't matter.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_coefficient((1, 0), 4.0) - assert model.get_coefficient((0, 1)) == 4.0 - - -def test_model_set_coefficient_preserves_pauli_string(): - """Test that set_coefficient does not change the PauliString.""" - geometry = 
make_chain(2) - model = Model(geometry) - model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) - model.set_coefficient((0, 1), 3.0) - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") - - -# PauliString tests - - -def test_model_set_pauli_string(): - """Test setting PauliString for an edge.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") - - -def test_model_set_pauli_string_overwrite(): - """Test overwriting an existing PauliString.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) - model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "XX")) - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "XX") - - -def test_model_set_pauli_string_invalid_edge(): - """Test setting PauliString for non-existent edge raises error.""" - geometry = make_chain(2) - model = Model(geometry) - with pytest.raises(KeyError): - model.set_pauli_string((0, 2), PauliString.from_qubits((0, 2), "ZZ")) - - -def test_model_get_pauli_string_invalid_edge(): - """Test getting PauliString for non-existent edge raises error.""" - geometry = make_chain(2) - model = Model(geometry) - with pytest.raises(KeyError): - model.get_pauli_string((0, 2)) - - -def test_model_set_pauli_string_preserves_coefficient(): - """Test that set_pauli_string does not change the coefficient.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_coefficient((0, 1), 5.0) - model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) - assert model.get_coefficient((0, 1)) == 5.0 - - -def test_model_set_pauli_string_sorted(): - """Test that set_pauli_string sorts vertices so order doesn't matter.""" - geometry = make_chain(2) - model = Model(geometry) - model.set_pauli_string((1, 0), 
PauliString.from_qubits((1, 0), "XZ")) - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((1, 0), "XZ") - + model = Model(Hypergraph([])) + assert model.nqubits == 0 + assert model.nterms == 0 -# has_interaction_term tests +def test_model_add_interaction_basic(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.5) -def test_model_has_interaction_term_true(): - """Test has_interaction_term returns True for existing edge.""" - geometry = make_chain(3) - model = Model(geometry) - assert model.has_interaction_term((0, 1)) is True - assert model.has_interaction_term((1, 2)) is True - - -def test_model_has_interaction_term_false(): - """Test has_interaction_term returns False for non-existent edge.""" - geometry = make_chain(3) - model = Model(geometry) - assert model.has_interaction_term((0, 2)) is False - assert model.has_interaction_term((5, 6)) is False - - -def test_model_has_interaction_term_sorted(): - """Test has_interaction_term sorts vertices so order doesn't matter.""" - geometry = make_chain(2) - model = Model(geometry) - assert model.has_interaction_term((1, 0)) is True - - -# Term management tests - - -def test_model_add_term(): - """Test adding a term with edges.""" - geometry = make_chain(3) - model = Model(geometry) - edge1 = Hyperedge([0, 1]) - edge2 = Hyperedge([1, 2]) - model.add_term([edge1, edge2]) - assert len(model.terms()) == 1 - assert len(model.terms()[0]) == 2 - - -def test_model_add_multiple_terms(): - """Test adding multiple terms.""" - geometry = make_chain(4) - model = Model(geometry) - model.add_term([Hyperedge([0, 1])]) - model.add_term([Hyperedge([1, 2]), Hyperedge([2, 3])]) - assert len(model.terms()) == 2 - + assert len(model._ops) == 1 + assert model._ops[0] == PauliString.from_qubits((0, 1), "ZZ", -1.5) + assert model.nterms == 0 -# String representation tests +def test_model_add_interaction_with_term(): + edge = Hyperedge([0, 1]) + model = 
Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -2.0, term=3) -def test_model_str(): - """Test string representation.""" - geometry = make_chain(4) - model = Model(geometry) - model.add_term([Hyperedge([0, 1])]) - model.add_term([Hyperedge([1, 2])]) - result = str(model) - assert "2 terms" in result - assert "4 qubits" in result + assert model.nterms == 1 + assert 3 in model._terms + assert model._terms[3] == [0] -def test_model_str_empty(): - """Test string representation with no terms.""" - geometry = make_chain(3) - model = Model(geometry) - result = str(model) - assert "0 terms" in result - assert "3 qubits" in result +def test_model_add_interaction_rejects_edge_not_in_geometry(): + model = Model(Hypergraph([Hyperedge([0, 1])])) + with pytest.raises(ValueError, match="Edge is not part of the model geometry"): + model.add_interaction(Hyperedge([1, 2]), "ZZ", -1.0) -def test_model_repr(): - """Test repr representation.""" - geometry = make_chain(2) - model = Model(geometry) +def test_model_str_and_repr(): + model = Model(make_chain(3)) + assert "0 terms" in str(model) + assert "3 qubits" in str(model) assert repr(model) == str(model) -# Integration tests - - -def test_model_build_simple_hamiltonian(): - """Test building a simple ZZ Hamiltonian on a chain.""" - geometry = make_chain(3) - model = Model(geometry) - - # Set coefficients for all edges - for edge in geometry.edges(): - model.set_coefficient(edge.vertices, 1.0) - - # Verify coefficients - assert model.get_coefficient((0, 1)) == 1.0 - assert model.get_coefficient((1, 2)) == 1.0 - - -def test_model_with_partitioned_terms(): - """Test building a model with partitioned terms for Trotterization.""" - geometry = make_chain(4) - model = Model(geometry) - - # Add two terms for even/odd partitioning - even_edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] - odd_edges = [Hyperedge([1, 2])] - model.add_term(even_edges) - model.add_term(odd_edges) - - assert len(model.terms()) == 2 - assert 
len(model.terms()[0]) == 2 - assert len(model.terms()[1]) == 1 - - -# translation_invariant_ising_model tests - - -def test_translation_invariant_ising_model_basic(): - """Test basic creation of Ising model.""" - from qsharp.magnets.models import translation_invariant_ising_model - +def test_ising_model_basic(): geometry = make_chain_with_vertices(3) - model = translation_invariant_ising_model(geometry, h=1.0, J=1.0) + model = IsingModel(geometry, h=1.0, J=1.0) assert isinstance(model, Model) assert model.geometry is geometry + assert model.nterms == 2 + assert isinstance(model.coloring, HypergraphEdgeColoring) -def test_translation_invariant_ising_model_zz_coefficients(): - """Test that ZZ interaction coefficients are correctly set.""" - from qsharp.magnets.models import translation_invariant_ising_model - - geometry = make_chain_with_vertices(4) # 3 two-body edges: (0,1), (1,2), (2,3) - J = 2.0 - model = translation_invariant_ising_model(geometry, h=0.5, J=J) - - # All two-body edge coefficients should be -J - assert model.get_coefficient((0, 1)) == -J - assert model.get_coefficient((1, 2)) == -J - assert model.get_coefficient((2, 3)) == -J - - -def test_translation_invariant_ising_model_x_coefficients(): - """Test that X field coefficients are correctly set.""" - from qsharp.magnets.models import translation_invariant_ising_model - - geometry = make_chain_with_vertices(4) # 4 single-vertex edges - h = 0.5 - model = translation_invariant_ising_model(geometry, h=h, J=2.0) - - # All single-vertex edge coefficients should be -h - for v in range(4): - assert model.get_coefficient((v,)) == -h - - -def test_translation_invariant_ising_model_coefficients(): - """Test that coefficients are correctly applied.""" - from qsharp.magnets.models import translation_invariant_ising_model - - # Geometry with one two-body edge and two single-vertex edges - geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([0]), Hyperedge([1])]) - h, J = 0.3, 0.7 - model = 
translation_invariant_ising_model(geometry, h=h, J=J) - - # Check ZZ coefficient is -J - assert model.get_coefficient((0, 1)) == -J - - # Check X coefficients are -h - assert model.get_coefficient((0,)) == -h - assert model.get_coefficient((1,)) == -h +def test_ising_model_coloring_matches_geometry_coloring(): + geometry = make_chain_with_vertices(4) + model = IsingModel(geometry, h=1.0, J=1.0) + geometry_coloring = geometry.edge_coloring() + for edge in geometry.edges(): + assert model.coloring.color(edge) == geometry_coloring.color(edge) -def test_translation_invariant_ising_model_zero_field(): - """Test Ising model with zero transverse field.""" - from qsharp.magnets.models import translation_invariant_ising_model - geometry = make_chain_with_vertices(3) - model = translation_invariant_ising_model(geometry, h=0.0, J=1.0) +def test_ising_model_initialization_calls_geometry_edge_coloring_once(): + geometry = CountingColoringHypergraph( + [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([0]), + Hyperedge([1]), + Hyperedge([2]), + ] + ) - # X coefficients (single-vertex edges) should all be zero - for v in range(3): - assert model.get_coefficient((v,)) == 0.0 + model = IsingModel(geometry, h=1.0, J=1.0) + assert isinstance(model.coloring, HypergraphEdgeColoring) + assert geometry.edge_coloring_calls == 1 -def test_translation_invariant_ising_model_zero_coupling(): - """Test Ising model with zero coupling.""" - from qsharp.magnets.models import translation_invariant_ising_model +def test_ising_model_coefficients_and_paulis(): geometry = make_chain_with_vertices(3) - model = translation_invariant_ising_model(geometry, h=1.0, J=0.0) + model = IsingModel(geometry, h=0.5, J=2.0) - # ZZ coefficients (two-body edges) should all be zero - assert model.get_coefficient((0, 1)) == 0.0 - assert model.get_coefficient((1, 2)) == 0.0 + ops_by_qubits = {tuple(sorted(op.qubits)): op for op in model._ops} + assert ops_by_qubits[(0, 1)] == PauliString.from_qubits((0, 1), "ZZ", 
-2.0) + assert ops_by_qubits[(1, 2)] == PauliString.from_qubits((1, 2), "ZZ", -2.0) + assert ops_by_qubits[(0,)] == PauliString.from_qubits((0,), "X", -0.5) + assert ops_by_qubits[(1,)] == PauliString.from_qubits((1,), "X", -0.5) + assert ops_by_qubits[(2,)] == PauliString.from_qubits((2,), "X", -0.5) -def test_translation_invariant_ising_model_term_grouping(): - """Test that Ising model has correct term grouping by color.""" - from qsharp.magnets.models import translation_invariant_ising_model +def test_ising_model_term_grouping_indices(): geometry = make_chain_with_vertices(4) - model = translation_invariant_ising_model(geometry, h=1.0, J=1.0) - - # Number of terms should be ncolors + 1 - assert len(model.terms()) == geometry.ncolors + 1 - - -def test_translation_invariant_ising_model_pauli_strings(): - """Test that Ising model sets correct PauliStrings.""" - from qsharp.magnets.models import translation_invariant_ising_model - - geometry = make_chain_with_vertices(3) - model = translation_invariant_ising_model(geometry, h=1.0, J=1.0) - - # Two-body edges should have ZZ PauliString - assert model.get_pauli_string((0, 1)) == PauliString.from_qubits((0, 1), "ZZ") - assert model.get_pauli_string((1, 2)) == PauliString.from_qubits((1, 2), "ZZ") + model = IsingModel(geometry, h=1.0, J=1.0) - # Single-vertex edges should have X PauliString - for v in range(3): - assert model.get_pauli_string((v,)) == PauliString.from_qubits((v,), "X") + assert set(model._terms.keys()) == {0, 1} + assert all(len(model._ops[index].qubits) == 1 for index in model._terms[0]) + assert all(len(model._ops[index].qubits) == 2 for index in model._terms[1]) diff --git a/source/pip/tests/magnets/test_pauli.py b/source/pip/tests/magnets/test_pauli.py new file mode 100644 index 0000000000..7ca82e7c6f --- /dev/null +++ b/source/pip/tests/magnets/test_pauli.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for Pauli and PauliString utilities.""" + +import pytest + +cirq = pytest.importorskip("cirq") + +from qsharp.magnets.utilities import Pauli, PauliString, PauliX, PauliY, PauliZ + + +def test_pauli_init_from_int_and_string(): + """Test Pauli initialization from int and case-insensitive string labels.""" + p_i = Pauli(0, qubit=1) + p_x = Pauli("x", qubit=2) + p_z = Pauli(2, qubit=3) + p_y = Pauli("Y", qubit=4) + + assert p_i.op == 0 and p_i.qubit == 1 + assert p_x.op == 1 and p_x.qubit == 2 + assert p_z.op == 2 and p_z.qubit == 3 + assert p_y.op == 3 and p_y.qubit == 4 + + +@pytest.mark.parametrize("value", [-1, 4, 42]) +def test_pauli_invalid_int_raises(value: int): + """Test invalid integer Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="Integer value must be 0-3"): + Pauli(value) + + +def test_pauli_invalid_string_raises(): + """Test invalid string Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="String value must be one of"): + Pauli("A") + + +def test_pauli_invalid_type_raises(): + """Test non-int/non-str Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="Expected int or str"): + Pauli(1.5) + + +def test_pauli_helpers_create_expected_operator(): + """Test PauliX/PauliY/PauliZ helper constructors.""" + assert PauliX(0) == Pauli("X", 0) + assert PauliY(1) == Pauli("Y", 1) + assert PauliZ(2) == Pauli("Z", 2) + + +def test_pauli_cirq_property_returns_operation_on_line_qubit(): + """Test Pauli.cirq returns a Cirq operation on the target qubit.""" + q = cirq.LineQubit(3) + assert Pauli("I", 3).cirq == cirq.I.on(q) + assert Pauli("X", 3).cirq == cirq.X.on(q) + assert Pauli("Y", 3).cirq == cirq.Y.on(q) + assert Pauli("Z", 3).cirq == cirq.Z.on(q) + + +def test_pauli_string_init_requires_pauli_instances(): + """Test PauliString initializer validates element types.""" + with pytest.raises(TypeError, match="Expected Pauli instance"): + PauliString([PauliX(0), "Z"]) + + 
+def test_pauli_string_from_qubits_accepts_string_and_int_values(): + """Test PauliString.from_qubits accepts both string and int identifiers.""" + from_string = PauliString.from_qubits((0, 1, 2), "XZY", coefficient=-1j) + from_ints = PauliString.from_qubits((0, 1, 2), [1, 2, 3], coefficient=-1j) + + assert from_string == from_ints + assert len(from_string) == 3 + assert from_string.qubits == (0, 1, 2) + + +def test_pauli_string_from_qubits_length_mismatch_raises(): + """Test from_qubits raises when qubit/value lengths differ.""" + with pytest.raises(ValueError, match="Length mismatch"): + PauliString.from_qubits((0, 1), "XYZ") + + +def test_pauli_string_sequence_protocol_and_indexing(): + """Test iteration, len, and indexing behavior.""" + ps = PauliString([PauliX(0), PauliZ(2)], coefficient=2.0) + + assert ps.qubits == (0, 2) + assert len(ps) == 2 + assert ps[0] == PauliX(0) + assert list(ps) == [PauliX(0), PauliZ(2)] + + +def test_pauli_string_equality_and_hash_include_coefficient(): + """Test equality/hash depend on Pauli terms and coefficient.""" + p1 = PauliString.from_qubits((0, 1), "XZ", coefficient=1.0) + p2 = PauliString.from_qubits((0, 1), "XZ", coefficient=1.0) + p3 = PauliString.from_qubits((0, 1), "XZ", coefficient=-1.0) + + assert p1 == p2 + assert hash(p1) == hash(p2) + assert p1 != p3 + + +def test_pauli_string_cirq_property_preserves_terms_and_coefficient(): + """Test PauliString.cirq conversion with coefficient.""" + ps = PauliString.from_qubits((0, 2), "XZ", coefficient=-0.5j) + + expected = cirq.PauliString( + {cirq.LineQubit(0): cirq.X, cirq.LineQubit(2): cirq.Z}, + coefficient=-0.5j, + ) + + assert ps.cirq == expected From 674124f3e9f592c6ac0773bdbc3d4bdcf2df77ce Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Thu, 26 Feb 2026 23:43:27 -0800 Subject: [PATCH 20/45] =?UTF-8?q?Minor=20changes=20to=20colorings=20and=20?= =?UTF-8?q?term=20representation;=20added=20Heisenberg=20=E2=80=A6=20(#297?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A small PR around the change in functionality of terms in the Model class. - Minor revisions on how edge coloring is presented - Terms are now dict[int, dict[int, list[int]]], where the first index is the term and the second is the color, the value is the index of the interaction operator - Fixed IsingModel to use colorings properly - Added HeisenbergModel - Added tests --- source/pip/qsharp/magnets/models/model.py | 87 ++++++++++++++++--- .../qsharp/magnets/utilities/hypergraph.py | 26 +++--- source/pip/tests/magnets/test_complete.py | 2 +- source/pip/tests/magnets/test_hypergraph.py | 46 +++++----- source/pip/tests/magnets/test_lattice1d.py | 6 +- source/pip/tests/magnets/test_lattice2d.py | 6 +- source/pip/tests/magnets/test_model.py | 85 ++++++++++++++++-- 7 files changed, 198 insertions(+), 60 deletions(-) diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py index f262b113a5..af6373fcd9 100644 --- a/source/pip/qsharp/magnets/models/model.py +++ b/source/pip/qsharp/magnets/models/model.py @@ -4,7 +4,7 @@ # pyright: reportPrivateImportUsage=false from collections.abc import Sequence -from typing import Optional +from typing import Iterator, Optional """Base Model class for quantum spin models. @@ -54,8 +54,11 @@ def __init__(self, geometry: Hypergraph): Creates a quantum spin model on the given geometry. - The model stores operators lazily in ``_ops`` as terms are defined. - ``_terms`` is initialized with one empty term group. + The model stores operators lazily in ``_ops`` as interaction operators + are defined. Noncommuting collections of operators are collected in + ``_terms`` that stores the indices of its interaction operators. This + list of arrays separates terms into parallelizable groups by color. It is + initialized as one empty term group. Args: geometry: Hypergraph defining the interaction topology.
The number @@ -66,7 +69,7 @@ def __init__(self, geometry: Hypergraph): self._ops: list[PauliString] = [] for edge in geometry.edges(): self._qubits.update(edge.vertices) - self._terms: dict[int, list[int]] = {} + self._terms: dict[int, dict[int, list[int]]] = {} def add_interaction( self, @@ -74,6 +77,7 @@ def add_interaction( pauli_string: Sequence[int | str] | str, coefficient: complex = 1.0, term: Optional[int] = None, + color: int = 0, ) -> None: """Add an interaction term to the model. @@ -88,8 +92,16 @@ def add_interaction( self._ops.append(s) if term is not None: if term not in self._terms: - self._terms[term] = [] - self._terms[term].append(len(self._ops) - 1) + self._terms[term] = {} + if color not in self._terms[term]: + self._terms[term][color] = [] + self._terms[term][color].append(len(self._ops) - 1) + + def terms(self, t: int) -> Iterator[PauliString]: + """Get the list of PauliStrings corresponding to a term group.""" + if t not in self._terms: + raise ValueError("Term group does not exist.") + return iter([self._ops[i] for i in self._terms[t]]) @property def nqubits(self) -> int: @@ -126,12 +138,67 @@ class IsingModel(Model): def __init__(self, geometry: Hypergraph, h: float, J: float): super().__init__(geometry) - self.coloring: HypergraphEdgeColoring = geometry.edge_coloring() - self._terms = {0: [], 1: []} + self.h = h + self.J = J + self._terms = {0: {}, 1: {}} + coloring: HypergraphEdgeColoring = geometry.edge_coloring() for edge in geometry.edges(): vertices = edge.vertices if len(vertices) == 1: - self.add_interaction(edge, "X", -h, term=0) + self.add_interaction(edge, "X", -h, term=0, color=0) elif len(vertices) == 2: - self.add_interaction(edge, "ZZ", -J, term=1) + color = coloring.color(edge.vertices) + if color is None: + raise ValueError("Geometry edge coloring failed to assign a color.") + self.add_interaction(edge, "ZZ", -J, term=1, color=color) + + def __str__(self) -> str: + return ( + f"Ising model with {self.nterms} terms on 
{self.nqubits} qubits " + f"(h={self.h}, J={self.J})." + ) + + def __repr__(self) -> str: + return ( + f"IsingModel(nqubits={self.nqubits}, nterms={self.nterms}, " + f"h={self.h}, J={self.J})" + ) + + +class HeisenbergModel(Model): + """Translation-invariant Heisenberg model on a hypergraph geometry. + + The Hamiltonian is: + H = -J * Σ_{} (X_i X_j + Y_i Y_j + Z_i Z_j) + + - Two-vertex edges define XX, YY, and ZZ coupling terms with coefficient ``-J``. + - Terms are grouped into three parts: ``0`` for XX, ``1`` for YY, and ``2`` for ZZ. + """ + + def __init__(self, geometry: Hypergraph, J: float): + super().__init__(geometry) + self.J = J + self.coloring: HypergraphEdgeColoring = geometry.edge_coloring() + self._terms = {0: {}, 1: {}, 2: {}} + for edge in geometry.edges(): + vertices = edge.vertices + if len(vertices) == 2: + color = self.coloring.color(edge.vertices) + if color is None: + raise ValueError("Geometry edge coloring failed to assign a color.") + self.add_interaction(edge, "XX", -J, term=0, color=color) + self.add_interaction(edge, "YY", -J, term=1, color=color) + self.add_interaction(edge, "ZZ", -J, term=2, color=color) + + def __str__(self) -> str: + return ( + f"Heisenberg model with {self.nterms} terms on {self.nqubits} qubits " + f"(J={self.J})." + ) + + def __repr__(self) -> str: + return ( + f"HeisenbergModel(nqubits={self.nqubits}, nterms={self.nterms}, " + f"J={self.J})" + ) diff --git a/source/pip/qsharp/magnets/utilities/hypergraph.py b/source/pip/qsharp/magnets/utilities/hypergraph.py index efbe7c6ad3..b7caffbd99 100644 --- a/source/pip/qsharp/magnets/utilities/hypergraph.py +++ b/source/pip/qsharp/magnets/utilities/hypergraph.py @@ -205,9 +205,9 @@ class HypergraphEdgeColoring: Note: Colors are keyed by edge vertex tuples (``edge.vertices``), not by - ``Hyperedge`` object identity. 
As a result, :meth:`color` accepts any - ``Hyperedge`` with matching vertices, while :meth:`add_edge` still - requires an edge instance that belongs to :attr:`hypergraph`. + ``Hyperedge`` object identity. As a result, :meth:`color` accepts edge + vertex tuples directly, while :meth:`add_edge` still requires an edge + instance that belongs to :attr:`hypergraph`. Attributes: hypergraph: The supporting :class:`Hypergraph` whose edges can be @@ -226,20 +226,22 @@ def ncolors(self) -> int: """Return the number of distinct nonnegative colors in the coloring.""" return len(self._used_vertices) - def color(self, edge: Hyperedge) -> Optional[int]: - """Return the color assigned to a specific edge. + def color(self, vertices: tuple[int, ...]) -> Optional[int]: + """Return the color assigned to edge vertices. Args: - edge: Hyperedge to query. Any ``Hyperedge`` with the same - ``vertices`` tuple resolves to the same stored color. + vertices: Canonical vertex tuple for the edge to query (typically + ``edge.vertices``). Returns: - The color assigned to ``edge``, or ``None`` if the edge has not - been added to this coloring. + The color assigned to ``vertices``, or ``None`` if the edge has + not been added to this coloring. """ - if not isinstance(edge, Hyperedge): - raise TypeError(f"edge must be Hyperedge, got {type(edge).__name__}") - return self._colors.get(edge.vertices) + if not isinstance(vertices, tuple) or not all( + isinstance(vertex, int) for vertex in vertices + ): + raise TypeError("vertices must be tuple[int, ...]") + return self._colors.get(vertices) def colors(self) -> Iterator[int]: """Iterate over distinct nonnegative colors present in the coloring. 
diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py index d49c668e63..38052dc668 100644 --- a/source/pip/tests/magnets/test_complete.py +++ b/source/pip/tests/magnets/test_complete.py @@ -221,7 +221,7 @@ def test_complete_bipartite_graph_coloring_non_overlapping(): # Group edges by color colors = {} for edge in graph.edges(): - color = coloring.color(edge) + color = coloring.color(edge.vertices) assert color is not None edge_vertices = edge.vertices if color not in colors: diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py index adf539407a..2c28289824 100755 --- a/source/pip/tests/magnets/test_hypergraph.py +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -173,13 +173,13 @@ def test_hypergraph_edge_coloring_rejects_equivalent_edge_not_in_hypergraph(): def test_hypergraph_edge_coloring_color_matches_equivalent_vertices(): - """Test color lookup uses edge vertices, not Hyperedge object identity.""" + """Test color lookup uses edge vertex tuples as keys.""" edge = Hyperedge([0, 1]) graph = Hypergraph([edge]) coloring = HypergraphEdgeColoring(graph) coloring.add_edge(edge, 3) - assert coloring.color(Hyperedge([1, 0])) == 3 + assert coloring.color((0, 1)) == 3 def test_hypergraph_edge_coloring_rejects_negative_color_for_nontrivial_edge(): @@ -225,7 +225,7 @@ def test_hypergraph_add_edge_with_color(): coloring = HypergraphEdgeColoring(graph) coloring.add_edge(edge, color=1) assert graph.nedges == 2 - assert coloring.color(edge) == 1 + assert coloring.color(edge.vertices) == 1 def test_hypergraph_color_default(): @@ -297,7 +297,7 @@ def test_greedy_edge_coloring_single_edge(): edge = Hyperedge([0, 1]) graph = Hypergraph([edge]) colored = graph.edge_coloring(seed=42) - assert colored.color(edge) == 0 + assert colored.color(edge.vertices) == 0 assert colored.ncolors == 1 @@ -307,8 +307,8 @@ def test_greedy_edge_coloring_non_overlapping(): graph = Hypergraph(edges) colored 
= graph.edge_coloring(seed=42) # Non-overlapping edges can be in the same color - assert colored.color(edges[0]) is not None - assert colored.color(edges[1]) is not None + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None assert colored.ncolors == 1 @@ -318,8 +318,8 @@ def test_greedy_edge_coloring_overlapping(): graph = Hypergraph(edges) colored = graph.edge_coloring(seed=42) # Overlapping edges need different colors - assert colored.color(edges[0]) is not None - assert colored.color(edges[1]) is not None + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None assert colored.ncolors == 2 @@ -329,9 +329,9 @@ def test_greedy_edge_coloring_triangle(): graph = Hypergraph(edges) colored = graph.edge_coloring(seed=42) # All edges share vertices pairwise, so need 3 colors - assert colored.color(edges[0]) is not None - assert colored.color(edges[1]) is not None - assert colored.color(edges[2]) is not None + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None assert colored.ncolors == 3 @@ -350,7 +350,7 @@ def test_greedy_edge_coloring_validity(): # Group edges by color colors = {} for edge in edges: - color = colored.color(edge) + color = colored.color(edge.vertices) assert color is not None if color not in colors: colors[color] = [] @@ -372,9 +372,9 @@ def test_greedy_edge_coloring_all_edges_colored(): colored = graph.edge_coloring(seed=42) # All edges should have a color assigned - assert colored.color(edges[0]) is not None - assert colored.color(edges[1]) is not None - assert colored.color(edges[2]) is not None + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None def test_greedy_edge_coloring_reproducible_with_seed(): @@ -385,8 +385,8 @@ def 
test_greedy_edge_coloring_reproducible_with_seed(): colored1 = graph.edge_coloring(seed=123) colored2 = graph.edge_coloring(seed=123) - color_map_1 = {edge.vertices: colored1.color(edge) for edge in edges} - color_map_2 = {edge.vertices: colored2.color(edge) for edge in edges} + color_map_1 = {edge.vertices: colored1.color(edge.vertices) for edge in edges} + color_map_2 = {edge.vertices: colored2.color(edge.vertices) for edge in edges} assert color_map_1 == color_map_2 @@ -415,9 +415,9 @@ def test_greedy_edge_coloring_hyperedges(): colored = graph.edge_coloring(seed=42) # First two share vertex 2, third is independent - assert colored.color(edges[0]) is not None - assert colored.color(edges[1]) is not None - assert colored.color(edges[2]) is not None + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None assert colored.ncolors >= 2 @@ -428,7 +428,7 @@ def test_greedy_edge_coloring_self_loops(): colored = graph.edge_coloring(seed=42) # Self-loops use the special -1 color and do not contribute to ncolors. 
- assert colored.color(edges[0]) == -1 - assert colored.color(edges[1]) == -1 - assert colored.color(edges[2]) == -1 + assert colored.color(edges[0].vertices) == -1 + assert colored.color(edges[1].vertices) == -1 + assert colored.color(edges[2].vertices) == -1 assert colored.ncolors == 0 diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py index d8553d2b99..8117ee3617 100644 --- a/source/pip/tests/magnets/test_lattice1d.py +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -9,7 +9,7 @@ def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: coloring = graph.edge_coloring() - return {edge.vertices: coloring.color(edge) for edge in graph.edges()} + return {edge.vertices: coloring.color(edge.vertices) for edge in graph.edges()} # Chain1D tests @@ -101,7 +101,7 @@ def test_chain1d_coloring_non_overlapping(): # Group edges by color colors = {} for edge in chain.edges(): - color = coloring.color(edge) + color = coloring.color(edge.vertices) assert color is not None edge_vertices = edge.vertices if color not in colors: @@ -211,7 +211,7 @@ def test_ring1d_coloring_non_overlapping(): # Group edges by color colors = {} for edge in ring.edges(): - color = coloring.color(edge) + color = coloring.color(edge.vertices) assert color is not None edge_vertices = edge.vertices if color not in colors: diff --git a/source/pip/tests/magnets/test_lattice2d.py b/source/pip/tests/magnets/test_lattice2d.py index 5f85fda913..6a1291e9b4 100644 --- a/source/pip/tests/magnets/test_lattice2d.py +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -9,7 +9,7 @@ def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: coloring = graph.edge_coloring() - return {edge.vertices: coloring.color(edge) for edge in graph.edges()} + return {edge.vertices: coloring.color(edge.vertices) for edge in graph.edges()} # Patch2D tests @@ -117,7 +117,7 @@ def test_patch2d_coloring_non_overlapping(): # Group edges by color colors = {} for edge in 
patch.edges(): - color = coloring.color(edge) + color = coloring.color(edge.vertices) assert color is not None edge_vertices = edge.vertices if color not in colors: @@ -243,7 +243,7 @@ def test_torus2d_coloring_non_overlapping(): # Group edges by color colors = {} for edge in torus.edges(): - color = coloring.color(edge) + color = coloring.color(edge.vertices) assert color is not None edge_vertices = edge.vertices if color not in colors: diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py index 5b9b48c521..00222ec4fe 100644 --- a/source/pip/tests/magnets/test_model.py +++ b/source/pip/tests/magnets/test_model.py @@ -10,10 +10,10 @@ import pytest from qsharp.magnets.models import IsingModel, Model +from qsharp.magnets.models.model import HeisenbergModel from qsharp.magnets.utilities import ( Hyperedge, Hypergraph, - HypergraphEdgeColoring, PauliString, ) @@ -72,7 +72,7 @@ def test_model_add_interaction_with_term(): assert model.nterms == 1 assert 3 in model._terms - assert model._terms[3] == [0] + assert model._terms[3] == {0: [0]} def test_model_add_interaction_rejects_edge_not_in_geometry(): @@ -95,7 +95,23 @@ def test_ising_model_basic(): assert isinstance(model, Model) assert model.geometry is geometry assert model.nterms == 2 - assert isinstance(model.coloring, HypergraphEdgeColoring) + assert set(model._terms.keys()) == {0, 1} + + +def test_ising_model_str_and_repr(): + geometry = make_chain_with_vertices(3) + model = IsingModel(geometry, h=0.5, J=2.0) + + assert str(model) == "Ising model with 2 terms on 3 qubits (h=0.5, J=2.0)." + assert repr(model) == "IsingModel(nqubits=3, nterms=2, h=0.5, J=2.0)" + + +def test_heisenberg_model_str_and_repr(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=1.5) + + assert str(model) == "Heisenberg model with 3 terms on 3 qubits (J=1.5)." 
+ assert repr(model) == "HeisenbergModel(nqubits=3, nterms=3, J=1.5)" def test_ising_model_coloring_matches_geometry_coloring(): @@ -103,8 +119,11 @@ def test_ising_model_coloring_matches_geometry_coloring(): model = IsingModel(geometry, h=1.0, J=1.0) geometry_coloring = geometry.edge_coloring() - for edge in geometry.edges(): - assert model.coloring.color(edge) == geometry_coloring.color(edge) + for color, indices in model._terms[1].items(): + for index in indices: + op = model._ops[index] + edge_vertices = tuple(sorted(op.qubits)) + assert geometry_coloring.color(edge_vertices) == color def test_ising_model_initialization_calls_geometry_edge_coloring_once(): @@ -120,7 +139,7 @@ def test_ising_model_initialization_calls_geometry_edge_coloring_once(): model = IsingModel(geometry, h=1.0, J=1.0) - assert isinstance(model.coloring, HypergraphEdgeColoring) + assert isinstance(model, IsingModel) assert geometry.edge_coloring_calls == 1 @@ -142,5 +161,55 @@ def test_ising_model_term_grouping_indices(): model = IsingModel(geometry, h=1.0, J=1.0) assert set(model._terms.keys()) == {0, 1} - assert all(len(model._ops[index].qubits) == 1 for index in model._terms[0]) - assert all(len(model._ops[index].qubits) == 2 for index in model._terms[1]) + assert all( + len(model._ops[index].qubits) == 1 + for indices in model._terms[0].values() + for index in indices + ) + assert all( + len(model._ops[index].qubits) == 2 + for indices in model._terms[1].values() + for index in indices + ) + + +def test_heisenberg_model_basic(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=1.0) + + assert isinstance(model, Model) + assert model.geometry is geometry + assert model.nterms == 3 + assert set(model._terms.keys()) == {0, 1, 2} + + +def test_heisenberg_model_coefficients_and_paulis(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=2.5) + + expected = [ + PauliString.from_qubits((0, 1), "XX", -2.5), + PauliString.from_qubits((1, 2), "XX", -2.5), + 
PauliString.from_qubits((0, 1), "YY", -2.5), + PauliString.from_qubits((1, 2), "YY", -2.5), + PauliString.from_qubits((0, 1), "ZZ", -2.5), + PauliString.from_qubits((1, 2), "ZZ", -2.5), + ] + for pauli in expected: + assert pauli in model._ops + + +def test_heisenberg_model_term_grouping_colors_and_paulis(): + geometry = make_chain(4) + model = HeisenbergModel(geometry, J=1.0) + + paulis_by_term = {0: "XX", 1: "YY", 2: "ZZ"} + for term, pauli in paulis_by_term.items(): + for color, indices in model._terms[term].items(): + for index in indices: + op = model._ops[index] + expected = PauliString.from_qubits( + tuple(sorted(op.qubits)), pauli, -1.0 + ) + assert op == expected + assert model.coloring.color(tuple(sorted(op.qubits))) == color From 104001657a83bc120acb7876b139a5d41eef89b3 Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Thu, 5 Mar 2026 09:39:31 -0800 Subject: [PATCH 21/45] More edits to functionality --- source/pip/qsharp/magnets/models/model.py | 0 source/pip/qsharp/magnets/trotter/__init__.py | 2 - source/pip/qsharp/magnets/trotter/trotter.py | 79 +++++++++--------- source/pip/qsharp/magnets/utilities/pauli.py | 15 ++++ source/pip/tests/magnets/test_model.py | 0 source/pip/tests/magnets/test_trotter.py | 81 +++++++++---------- 6 files changed, 91 insertions(+), 86 deletions(-) mode change 100644 => 100755 source/pip/qsharp/magnets/models/model.py mode change 100644 => 100755 source/pip/tests/magnets/test_model.py diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py old mode 100644 new mode 100755 diff --git a/source/pip/qsharp/magnets/trotter/__init__.py b/source/pip/qsharp/magnets/trotter/__init__.py index 95dc485fa7..d4beaa68c5 100644 --- a/source/pip/qsharp/magnets/trotter/__init__.py +++ b/source/pip/qsharp/magnets/trotter/__init__.py @@ -6,7 +6,6 @@ from .trotter import ( TrotterStep, TrotterExpansion, - trotter_decomposition, strang_splitting, suzuki_recursion, yoshida_recursion, @@ -16,7 +15,6 @@ 
__all__ = [ "TrotterStep", "TrotterExpansion", - "trotter_decomposition", "strang_splitting", "suzuki_recursion", "yoshida_recursion", diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py index 0568db61fc..20ef8cb845 100644 --- a/source/pip/qsharp/magnets/trotter/trotter.py +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -9,26 +9,51 @@ class TrotterStep: """ - Base class for Trotter decompositions. Essentially, this is a wrapper around a - list of (time, term_index) tuples, which specify which term to apply for how long. + Base class for Trotter decompositions. + + Essentially, this is a wrapper around a list of ``(time, term_index)`` tuples, + which specify which term to apply for how long, independent of the specific + Trotter decomposition or model being used. The TrotterStep class provides a common interface for different Trotter decompositions, such as first-order Trotter and Strang splitting. It also serves as the base class for higher-order Trotter steps that can be constructed via Suzuki or Yoshida recursion. Each Trotter step is defined by the sequence of terms to apply and their corresponding time durations, as well as the overall order of the decomposition and the time step for each term. + + The constructor creates an empty Trotter step (when ``num_terms = 0``), or a + first-order Trotter step: + + .. math:: + + e^{-i H t} \\approx \\prod_k e^{-i H_k t}, \\quad H = \\sum_k H_k. + + In the first-order case, each term index from ``0`` to ``num_terms - 1`` appears + once, each with duration ``time_step``. + + Example: + + .. code-block:: python + + >>> trotter = TrotterStep(num_terms=3, time_step=0.5) + >>> list(trotter.step()) + [(0.5, 0), (0.5, 1), (0.5, 2)] + + References: + H. F. Trotter, Proc. Amer. Math. Soc. 10, 545 (1959). + + TODO: Initializer offers randomized order of terms. 
""" - def __init__(self): + def __init__(self, num_terms: int = 0, time_step: float = 0.0): """ Creates an empty Trotter decomposition. """ - self.terms: list[tuple[float, int]] = [] - self._nterms = 0 - self._time_step = 0.0 - self._order = 0 - self._repr_string = "TrotterStep()" + self._nterms = num_terms + self._time_step = time_step + self._order = 1 if num_terms > 0 else 0 + self.terms: list[tuple[float, int]] = [(time_step, j) for j in range(num_terms)] @property def order(self) -> int: @@ -89,7 +114,7 @@ def __str__(self) -> str: def __repr__(self) -> str: """String representation of the Trotter decomposition.""" - return self._repr_string + return f"TrotterStep(num_terms={self._nterms}, time_step={self._time_step})" def suzuki_recursion(trotter: TrotterStep) -> TrotterStep: @@ -180,38 +205,6 @@ def yoshida_recursion(trotter: TrotterStep) -> TrotterStep: return yoshida -def trotter_decomposition(num_terms: int, time: float) -> TrotterStep: - """ - Factory function for creating a first-order Trotter decomposition. - - The first-order Trotter-Suzuki formula for approximating time evolution - under a Hamiltonian represented as a sum of terms - - H = ∑_k H_k - - is obtained by sequentially applying each term for the full time - - e^{-i H t} ≈ ∏_k e^{-i H_k t}. - - Example: - - .. code-block:: python - >>> trotter = first_order_trotter(num_terms=3, time=0.5) - >>> list(trotter.step()) - [(0.5, 0), (0.5, 1), (0.5, 2)] - - References: - H. F. Trotter, Proc. Amer. Math. Soc. 10, 545 (1959). - """ - trotter = TrotterStep() - trotter.terms = [(time, term_index) for term_index in range(num_terms)] - trotter._nterms = num_terms - trotter._time_step = time - trotter._order = 1 - trotter._repr_string = f"FirstOrderTrotter(time_step={time}, num_terms={num_terms})" - return trotter - - def strang_splitting(num_terms: int, time: float) -> TrotterStep: """ Factory function for creating a Strang splitting (second-order @@ -286,7 +279,7 @@ class TrotterExpansion: .. 
code-block:: python >>> n = 4 # Number of Trotter steps >>> total_time = 1.0 # Total time - >>> step = trotter_decomposition(num_terms=2, time=total_time/n) + >>> step = TrotterStep(num_terms=2, time_step=total_time/n) >>> expansion = TrotterExpansion(step, n) >>> expansion.order 1 @@ -318,7 +311,7 @@ def nterms(self) -> int: return self._trotter_step.nterms @property - def num_steps(self) -> int: + def nsteps(self) -> int: """Get the number of Trotter steps.""" return self._num_steps diff --git a/source/pip/qsharp/magnets/utilities/pauli.py b/source/pip/qsharp/magnets/utilities/pauli.py index 4708cb67d4..c681aa2987 100644 --- a/source/pip/qsharp/magnets/utilities/pauli.py +++ b/source/pip/qsharp/magnets/utilities/pauli.py @@ -201,6 +201,21 @@ def qubits(self) -> tuple[int, ...]: """ return tuple(p.qubit for p in self._paulis) + @property + def coefficient(self) -> complex: + """Complex coefficient multiplying this Pauli string.""" + return self._coefficient + + @property + def paulis(self) -> str: + """String of Pauli labels in the same order as the stored Pauli terms. + + Returns: + String of Pauli labels ('I', 'X', 'Z', 'Y'), one per Pauli operator. + """ + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + return "".join(labels[p.op] for p in self._paulis) + def __iter__(self): """Iterate over Pauli terms in stored order. 
diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py old mode 100644 new mode 100755 diff --git a/source/pip/tests/magnets/test_trotter.py b/source/pip/tests/magnets/test_trotter.py index db981a9c16..fb7c72046d 100644 --- a/source/pip/tests/magnets/test_trotter.py +++ b/source/pip/tests/magnets/test_trotter.py @@ -6,7 +6,6 @@ from qsharp.magnets.trotter import ( TrotterStep, TrotterExpansion, - trotter_decomposition, strang_splitting, suzuki_recursion, yoshida_recursion, @@ -49,68 +48,68 @@ def test_trotter_step_reduce_empty(): assert list(trotter.step()) == [] -# trotter_decomposition factory tests +# first-order TrotterStep constructor tests -def test_trotter_decomposition_basic(): - """Test basic trotter_decomposition creation.""" - trotter = trotter_decomposition(num_terms=3, time=0.5) +def test_trotter_step_first_order_basic(): + """Test basic first-order TrotterStep creation.""" + trotter = TrotterStep(num_terms=3, time_step=0.5) assert trotter.nterms == 3 assert trotter.time_step == 0.5 assert trotter.order == 1 -def test_trotter_decomposition_single_term(): - """Test trotter_decomposition with a single term.""" - trotter = trotter_decomposition(num_terms=1, time=1.0) +def test_trotter_step_first_order_single_term(): + """Test first-order TrotterStep with a single term.""" + trotter = TrotterStep(num_terms=1, time_step=1.0) result = list(trotter.step()) assert result == [(1.0, 0)] -def test_trotter_decomposition_multiple_terms(): - """Test trotter_decomposition with multiple terms.""" - trotter = trotter_decomposition(num_terms=3, time=0.5) +def test_trotter_step_first_order_multiple_terms(): + """Test first-order TrotterStep with multiple terms.""" + trotter = TrotterStep(num_terms=3, time_step=0.5) result = list(trotter.step()) assert result == [(0.5, 0), (0.5, 1), (0.5, 2)] -def test_trotter_decomposition_zero_time(): - """Test trotter_decomposition with zero time.""" - trotter = trotter_decomposition(num_terms=2, 
time=0.0) +def test_trotter_step_first_order_zero_time(): + """Test first-order TrotterStep with zero time.""" + trotter = TrotterStep(num_terms=2, time_step=0.0) result = list(trotter.step()) assert result == [(0.0, 0), (0.0, 1)] -def test_trotter_decomposition_returns_all_terms(): - """Test that trotter_decomposition returns all term indices.""" +def test_trotter_step_first_order_returns_all_terms(): + """Test that first-order TrotterStep returns all term indices.""" num_terms = 5 - trotter = trotter_decomposition(num_terms=num_terms, time=1.0) + trotter = TrotterStep(num_terms=num_terms, time_step=1.0) result = list(trotter.step()) assert len(result) == num_terms term_indices = [idx for _, idx in result] assert term_indices == list(range(num_terms)) -def test_trotter_decomposition_uniform_time(): - """Test that all terms have the same time in trotter_decomposition.""" +def test_trotter_step_first_order_uniform_time(): + """Test that all terms have the same time in first-order TrotterStep.""" time = 0.25 - trotter = trotter_decomposition(num_terms=4, time=time) + trotter = TrotterStep(num_terms=4, time_step=time) result = list(trotter.step()) for t, _ in result: assert t == time -def test_trotter_decomposition_str(): - """Test string representation of trotter_decomposition result.""" - trotter = trotter_decomposition(num_terms=3, time=0.5) +def test_trotter_step_first_order_str(): + """Test string representation of first-order TrotterStep.""" + trotter = TrotterStep(num_terms=3, time_step=0.5) result = str(trotter) assert "order" in result.lower() or "1" in result -def test_trotter_decomposition_repr(): - """Test repr representation of trotter_decomposition result.""" - trotter = trotter_decomposition(num_terms=3, time=0.5) - assert "FirstOrderTrotter" in repr(trotter) +def test_trotter_step_first_order_repr(): + """Test repr representation of first-order TrotterStep.""" + trotter = TrotterStep(num_terms=3, time_step=0.5) + assert "TrotterStep" in repr(trotter) # 
strang_splitting factory tests @@ -210,7 +209,7 @@ def test_suzuki_recursion_from_strang(): def test_suzuki_recursion_from_first_order(): """Test Suzuki recursion applied to first-order Trotter produces 3rd order.""" - trotter = trotter_decomposition(num_terms=2, time=1.0) + trotter = TrotterStep(num_terms=2, time_step=1.0) suzuki = suzuki_recursion(trotter) assert suzuki.order == 3 assert suzuki.nterms == 2 @@ -239,7 +238,7 @@ def test_suzuki_recursion_repr(): def test_suzuki_recursion_time_weights_sum(): """Test that time weights in Suzuki recursion sum correctly.""" - base = trotter_decomposition(num_terms=2, time=1.0) + base = TrotterStep(num_terms=2, time_step=1.0) suzuki = suzuki_recursion(base) # The total scaled time should equal the original total time * nterms # because we're scaling times, not adding them @@ -265,7 +264,7 @@ def test_yoshida_recursion_from_strang(): def test_yoshida_recursion_from_first_order(): """Test Yoshida recursion applied to first-order Trotter produces 3rd order.""" - trotter = trotter_decomposition(num_terms=2, time=1.0) + trotter = TrotterStep(num_terms=2, time_step=1.0) yoshida = yoshida_recursion(trotter) assert yoshida.order == 3 assert yoshida.nterms == 2 @@ -294,7 +293,7 @@ def test_yoshida_recursion_repr(): def test_yoshida_recursion_time_weights_sum(): """Test that time weights in Yoshida recursion sum correctly.""" - base = trotter_decomposition(num_terms=2, time=1.0) + base = TrotterStep(num_terms=2, time_step=1.0) yoshida = yoshida_recursion(base) # The total scaled time should equal the original total time * nterms # because weights w1 + w0 + w1 = 2*w1 + w0 = 2*w1 + (1 - 2*w1) = 1 @@ -338,7 +337,7 @@ def test_fourth_order_trotter_suzuki_equals_suzuki_of_strang(): def test_trotter_expansion_init_basic(): """Test basic TrotterExpansion initialization.""" - step = trotter_decomposition(num_terms=2, time=0.25) + step = TrotterStep(num_terms=2, time_step=0.25) expansion = TrotterExpansion(step, num_steps=4) assert 
expansion._trotter_step is step assert expansion._num_steps == 4 @@ -346,7 +345,7 @@ def test_trotter_expansion_init_basic(): def test_trotter_expansion_get_single_step(): """Test TrotterExpansion with a single step.""" - step = trotter_decomposition(num_terms=2, time=1.0) + step = TrotterStep(num_terms=2, time_step=1.0) expansion = TrotterExpansion(step, num_steps=1) result = expansion.get() assert len(result) == 1 @@ -357,7 +356,7 @@ def test_trotter_expansion_get_single_step(): def test_trotter_expansion_get_multiple_steps(): """Test TrotterExpansion with multiple steps.""" - step = trotter_decomposition(num_terms=2, time=0.25) + step = TrotterStep(num_terms=2, time_step=0.25) expansion = TrotterExpansion(step, num_steps=4) result = expansion.get() assert len(result) == 1 @@ -382,7 +381,7 @@ def test_trotter_expansion_total_time(): """Test that total evolution time is correct.""" total_time = 1.0 num_steps = 4 - step = trotter_decomposition(num_terms=3, time=total_time / num_steps) + step = TrotterStep(num_terms=3, time_step=total_time / num_steps) expansion = TrotterExpansion(step, num_steps=num_steps) result = expansion.get() terms, count = result[0] @@ -395,7 +394,7 @@ def test_trotter_expansion_total_time(): def test_trotter_expansion_preserves_step(): """Test that expansion preserves the original step.""" - step = trotter_decomposition(num_terms=3, time=0.5) + step = TrotterStep(num_terms=3, time_step=0.5) expansion = TrotterExpansion(step, num_steps=10) result = expansion.get() terms, _ = result[0] @@ -421,28 +420,28 @@ def test_trotter_expansion_order_property(): def test_trotter_expansion_nterms_property(): """Test TrotterExpansion nterms property.""" - step = trotter_decomposition(num_terms=5, time=0.5) + step = TrotterStep(num_terms=5, time_step=0.5) expansion = TrotterExpansion(step, num_steps=4) assert expansion.nterms == 5 def test_trotter_expansion_num_steps_property(): """Test TrotterExpansion num_steps property.""" - step = 
trotter_decomposition(num_terms=2, time=0.25) + step = TrotterStep(num_terms=2, time_step=0.25) expansion = TrotterExpansion(step, num_steps=8) assert expansion.num_steps == 8 def test_trotter_expansion_total_time_property(): """Test TrotterExpansion total_time property.""" - step = trotter_decomposition(num_terms=2, time=0.25) + step = TrotterStep(num_terms=2, time_step=0.25) expansion = TrotterExpansion(step, num_steps=4) assert expansion.total_time == 1.0 def test_trotter_expansion_step_iterator(): """Test TrotterExpansion step() iterator yields full expansion.""" - step = trotter_decomposition(num_terms=2, time=0.5) + step = TrotterStep(num_terms=2, time_step=0.5) expansion = TrotterExpansion(step, num_steps=3) result = list(expansion.step()) # Should yield 3 repetitions of [(0.5, 0), (0.5, 1)] @@ -474,7 +473,7 @@ def test_trotter_expansion_str(): def test_trotter_expansion_repr(): """Test TrotterExpansion repr representation.""" - step = trotter_decomposition(num_terms=2, time=0.5) + step = TrotterStep(num_terms=2, time_step=0.5) expansion = TrotterExpansion(step, num_steps=4) result = repr(expansion) assert "TrotterExpansion" in result From df1f2c2274ab1777d565fa9dc27c92729646d031 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 10 Mar 2026 18:41:57 +0100 Subject: [PATCH 22/45] Fix magnet tests (#3006) This fixes some tests that caused CI errors. The faulty tests were checked in with a previous PR. 
Fixes #2997 --- source/pip/qsharp/magnets/trotter/trotter.py | 8 ++++++-- source/pip/tests/magnets/test_trotter.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py index 20ef8cb845..f7ac8b18f0 100644 --- a/source/pip/qsharp/magnets/trotter/trotter.py +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -4,7 +4,7 @@ """Base Trotter class for first- and second-order Trotter-Suzuki decomposition.""" -from typing import Iterator +from typing import Iterator, Optional class TrotterStep: @@ -53,6 +53,7 @@ def __init__(self, num_terms: int = 0, time_step: float = 0.0): self._nterms = num_terms self._time_step = time_step self._order = 1 if num_terms > 0 else 0 + self._repr_string: Optional[str] = None self.terms: list[tuple[float, int]] = [(time_step, j) for j in range(num_terms)] @property @@ -114,7 +115,10 @@ def __str__(self) -> str: def __repr__(self) -> str: """String representation of the Trotter decomposition.""" - return f"TrotterStep(num_terms={self._nterms}, time_step={self._time_step})" + if self._repr_string is not None: + return self._repr_string + else: + return f"TrotterStep(num_terms={self._nterms}, time_step={self._time_step})" def suzuki_recursion(trotter: TrotterStep) -> TrotterStep: diff --git a/source/pip/tests/magnets/test_trotter.py b/source/pip/tests/magnets/test_trotter.py index fb7c72046d..2b2fe22572 100644 --- a/source/pip/tests/magnets/test_trotter.py +++ b/source/pip/tests/magnets/test_trotter.py @@ -429,7 +429,7 @@ def test_trotter_expansion_num_steps_property(): """Test TrotterExpansion num_steps property.""" step = TrotterStep(num_terms=2, time_step=0.25) expansion = TrotterExpansion(step, num_steps=8) - assert expansion.num_steps == 8 + assert expansion.nsteps == 8 def test_trotter_expansion_total_time_property(): From bdddb704a4462c1ae6bc93b076f9f00fec5477a9 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Wed, 11 Mar 2026 
17:16:49 +0100 Subject: [PATCH 23/45] Resource estimation results table (#2973) The main feature from this PR is a resource estimation result table that is returned from the `estimate` call. It can be configured to have additional columns which are considered in the creation of a pandas DataFrame. Besides that, the PR - refactors how instructions are built: they are now all built inside the provenance graph that is part of the architecture context - enhances enumeration capabilities: one can now enumerate over nested attributes and union attributes, and also restrict domains, further one can restrict the domain of the application inside the trace query - adds post processing to applications: applications can post-process estimation results _before_ they are inserted in the estimation table. This post processing step comes currently at a cost in runtime, because estimation is parallelized in Python and not in Rust --- source/pip/benchmarks/bench_qre.py | 9 +- source/pip/qsharp/qre/__init__.py | 15 +- source/pip/qsharp/qre/_application.py | 60 +- source/pip/qsharp/qre/_architecture.py | 218 ++++- source/pip/qsharp/qre/_enumeration.py | 163 ++++- source/pip/qsharp/qre/_estimation.py | 199 +++++- source/pip/qsharp/qre/_instruction.py | 222 ++++--- source/pip/qsharp/qre/_qre.pyi | 183 ++++-- source/pip/qsharp/qre/_trace.py | 50 +- .../qsharp/qre/models/factories/_litinski.py | 19 +- .../qre/models/factories/_round_based.py | 19 +- .../pip/qsharp/qre/models/factories/_utils.py | 11 +- .../qsharp/qre/models/qec/_surface_code.py | 25 +- .../pip/qsharp/qre/models/qec/_three_aux.py | 25 +- source/pip/qsharp/qre/models/qec/_yoked.py | 26 +- source/pip/qsharp/qre/models/qubits/_aqre.py | 57 +- source/pip/qsharp/qre/models/qubits/_msft.py | 53 +- source/pip/src/qre.rs | 324 ++++++++-- source/pip/test_requirements.txt | 1 + source/pip/tests/test_qre.py | 598 ++++++++++++++++-- source/pip/tests/test_qre_models.py | 240 +++---- source/qre/src/isa.rs | 175 +++--
source/qre/src/trace.rs | 8 +- source/qre/src/trace/transforms/psspc.rs | 10 +- 24 files changed, 2056 insertions(+), 654 deletions(-) diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index f273cf5d0d..e236594921 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -3,7 +3,8 @@ import timeit from dataclasses import dataclass, KW_ONLY, field -from qsharp.qre import linear_function, generic_function, instruction +from qsharp.qre import linear_function, generic_function +from qsharp.qre._architecture import _make_instruction from qsharp.qre.models import AQREGateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -39,7 +40,7 @@ def bench_enumerate_isas(): sys.path.append(os.path.join(os.path.dirname(__file__), "../tests")) from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore - ctx = AQREGateBased().context() + ctx = AQREGateBased(gate_time=50, measurement_time=100).context() # Hierarchical factory using from_components query = SurfaceCode.q() * ExampleLogicalFactory.q( @@ -62,7 +63,7 @@ def bench_enumerate_isas(): def bench_function_evaluation_linear(): fl = linear_function(12) - inst = instruction(42, arity=None, space=fl, time=1, error_rate=1.0) + inst = _make_instruction(42, 0, None, 1, fl, None, 1.0, {}) number = 1000 duration = timeit.timeit( "inst.space(5)", @@ -83,7 +84,7 @@ def func(arity: int) -> int: fg = generic_function(func) - inst = instruction(42, arity=None, space=fg, time=1, error_rate=1.0) + inst = _make_instruction(42, 0, None, 1, fg, None, 1.0, {}) number = 1000 duration = timeit.timeit( "inst.space(5)", diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 15c3477cb7..86aba4790d 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -3,7 +3,12 @@ from ._application import Application from ._architecture import Architecture -from ._estimation import estimate 
+from ._estimation import ( + estimate, + EstimationTable, + EstimationTableColumn, + EstimationTableEntry, +) from ._instruction import ( LOGICAL, PHYSICAL, @@ -11,7 +16,6 @@ ISATransform, PropertyKey, constraint, - instruction, InstructionSource, ) from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT @@ -31,14 +35,13 @@ linear_function, instruction_name, ) -from ._trace import LatticeSurgery, PSSPC, TraceQuery +from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform __all__ = [ "block_linear_function", "constant_function", "constraint", "estimate", - "instruction", "linear_function", "Application", "Architecture", @@ -47,6 +50,9 @@ "ConstraintBound", "Encoding", "EstimationResult", + "EstimationTable", + "EstimationTableColumn", + "EstimationTableEntry", "FactoryResult", "generic_function", "instruction_name", @@ -63,6 +69,7 @@ "PSSPC", "Trace", "TraceQuery", + "TraceTransform", "LOGICAL", "PHYSICAL", ] diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py index 4c9a1829e1..8f2e1d33ed 100644 --- a/source/pip/qsharp/qre/_application.py +++ b/source/pip/qsharp/qre/_application.py @@ -6,8 +6,8 @@ import types from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from types import NoneType from typing import ( - Any, ClassVar, Generic, Protocol, @@ -18,7 +18,8 @@ ) from ._enumeration import _enumerate_instances -from ._qre import Trace +from ._qre import Trace, EstimationResult +from ._trace import TraceQuery class DataclassProtocol(Protocol): @@ -50,9 +51,19 @@ class Application(ABC, Generic[TraceParameters]): def get_trace(self, parameters: TraceParameters) -> Trace: """Return the trace corresponding to this application.""" - def context(self, **kwargs) -> _Context: + @staticmethod + def q(**kwargs) -> TraceQuery: + return TraceQuery(NoneType, **kwargs) + + def context(self) -> _Context: """Create a new enumeration context for this application.""" - return _Context(self, 
**kwargs) + return _Context(self) + + def post_process( + self, parameters: TraceParameters, estimation: EstimationResult + ) -> EstimationResult: + """Post-process an estimation result for a given set of trace parameters.""" + return estimation def enumerate_traces( self, @@ -80,6 +91,45 @@ def enumerate_traces( for instances in _enumerate_instances(cast(type, param_type), **kwargs): yield self.get_trace(instances) + def enumerate_traces_with_parameters( + self, + **kwargs, + ) -> Generator[tuple[TraceParameters, Trace], None, None]: + """Yields (parameters, trace) pairs for an application. + + Like ``enumerate_traces``, but each yielded trace is accompanied by the + trace parameters that were used to generate it. + + Args: + **kwargs: Domain overrides forwarded to ``_enumerate_instances``. + + Returns: + Generator[tuple[TraceParameters, Trace], None, None]: A generator + of (parameters, trace) pairs. + """ + + param_type = get_type_hints(self.__class__.get_trace).get("parameters") + if param_type is types.NoneType: + yield None, self.get_trace(None) # type: ignore + return + + if isinstance(param_type, TypeVar): + for c in param_type.__constraints__: + if c is not types.NoneType: + param_type = c + break + + if self._parallel_traces: + instances = list(_enumerate_instances(cast(type, param_type), **kwargs)) + with ThreadPoolExecutor() as executor: + for instance, trace in zip( + instances, executor.map(self.get_trace, instances) + ): + yield instance, trace + else: + for instance in _enumerate_instances(cast(type, param_type), **kwargs): + yield instance, self.get_trace(instance) + def disable_parallel_traces(self): """Disable parallel trace generation for this application.""" self._parallel_traces = False @@ -87,8 +137,6 @@ def disable_parallel_traces(self): class _Context: application: Application - kwargs: dict[str, Any] def __init__(self, application: Application, **kwargs): self.application = application - self.kwargs = kwargs diff --git 
a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index ce69e7d7c1..15bdd3afdb 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -3,20 +3,44 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING from abc import ABC, abstractmethod -from ._qre import ISA, _ProvenanceGraph, _Instruction +from ._qre import ( + ISA, + _ProvenanceGraph, + _Instruction, + _IntFunction, + _FloatFunction, + constant_function, +) if TYPE_CHECKING: - from ._instruction import ISATransform + from typing import Optional + + from ._instruction import ISATransform, Encoding + + +# Valid property names for instructions, mapped to their integer keys. +_PROPERTY_KEYS: dict[str, int] = {"distance": 0} class Architecture(ABC): - @property @abstractmethod - def provided_isa(self) -> ISA: ... + def provided_isa(self, ctx: _Context) -> ISA: + """ + Creates the ISA provided by this architecture, adding instructions + directly to the context's provenance graph. + + Args: + ctx: The enumeration context whose provenance graph stores + the instructions. + + Returns: + ISA: The ISA backed by the context's provenance graph. + """ + ... def context(self) -> _Context: """Create a new enumeration context for this architecture.""" @@ -31,12 +55,8 @@ class _Context: def __init__(self, arch: Architecture): self._provenance: _ProvenanceGraph = _ProvenanceGraph() - def _mark_instruction(inst: _Instruction) -> _Instruction: - node = self._provenance.add_node(inst.id, 0, []) - inst.set_source(node) - return inst - - self._isa = ISA([_mark_instruction(instr) for instr in arch.provided_isa]) + # Let the architecture create instructions directly in the graph. 
+ self._isa = arch.provided_isa(self) self._bindings: dict[str, ISA] = {} self._transforms: dict[int, Architecture | ISATransform] = {0: arch} @@ -47,29 +67,173 @@ def _with_binding(self, name: str, isa: ISA) -> _Context: ctx._bindings = {**self._bindings, name: isa} return ctx - def set_source( + @property + def isa(self) -> ISA: + """The ISA provided by the architecture for this context.""" + return self._isa + + def add_instruction( self, - transform: ISATransform, - instruction: _Instruction, - source_instructions: list[_Instruction], - ) -> _Instruction: + id_or_instruction: int | _Instruction, + encoding: Encoding = 0, # type: ignore + *, + arity: Optional[int] = 1, + time: int | _IntFunction = 0, + space: Optional[int] | _IntFunction = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = 0.0, + transform: ISATransform | None = None, + source: list[_Instruction] | None = None, + **kwargs: int, + ) -> int: """ - Record the provenance of an instruction generated by a transform, and - return the instruction with its source set. + Create an instruction and add it to the provenance graph. + + Can be called in two ways: + + 1. With keyword args to create a new instruction:: + + ctx.add_instruction(T, encoding=LOGICAL, time=1000, + error_rate=1e-8) + + 2. With a pre-existing ``_Instruction`` object (e.g. from + ``with_id()``):: + + ctx.add_instruction(existing_instruction) + + Provenance is recorded when *transform* and/or *source* are + supplied: + + - **transform** — the ``ISATransform`` that produced the + instruction. + - **source** — input instructions consumed by the transform. Args: - transform: The transform that generated the instruction. - instruction: The instruction whose provenance is being recorded. - source_instructions: The instructions that were used as input to the - transform to generate this instruction. 
+ id_or_instruction: Either an instruction ID (int) for creating + a new instruction, or an existing ``_Instruction`` object. + encoding: The instruction encoding (0 = Physical, 1 = Logical). + Ignored when passing an existing ``_Instruction``. + arity: The instruction arity. ``None`` for variable arity. + Ignored when passing an existing ``_Instruction``. + time: Instruction time in ns (or ``_IntFunction`` for variable + arity). Ignored when passing an existing ``_Instruction``. + space: Instruction space in physical qubits (or ``_IntFunction`` + for variable arity). Ignored when passing an existing + ``_Instruction``. + length: Arity including ancilla qubits. Ignored when passing an + existing ``_Instruction``. + error_rate: Instruction error rate (or ``_FloatFunction`` for + variable arity). Ignored when passing an existing + ``_Instruction``. + transform: The ``ISATransform`` that produced the instruction. + source: List of source ``_Instruction`` objects consumed by the + transform. + **kwargs: Additional properties (e.g. ``distance=9``). Ignored + when passing an existing ``_Instruction``. Returns: - The input instruction with its source set to the provenance node. + The node index in the provenance graph. + + Raises: + ValueError: If an unknown property name is provided in kwargs. 
""" + if transform is None and source is None: + return self._provenance.add_instruction( + cast(int, id_or_instruction), + encoding, + arity=arity, + time=time, + space=space, + length=length, + error_rate=error_rate, + **kwargs, + ) + + if isinstance(id_or_instruction, _Instruction): + instr = id_or_instruction + else: + instr = _make_instruction( + id_or_instruction, + int(encoding), + arity, + time, + space, + length, + error_rate, + kwargs, + ) + + transform_id = id(transform) if transform is not None else 0 + children = [inst.source for inst in source] if source else [] + + node_index = self._provenance.add_node(instr, transform_id, children) + + if transform is not None: + self._transforms[transform_id] = transform + + return node_index + + def make_isa(self, *node_indices: int) -> ISA: + """ + Creates an ISA backed by this context's provenance graph from the + given node indices. + + Args: + *node_indices: Node indices in the provenance graph. - source = self._provenance.add_node( - instruction.id, id(transform), [inst.source for inst in source_instructions] + Returns: + ISA: An ISA referencing the provenance graph. 
+ """ + return self._provenance.make_isa(list(node_indices)) + + +def _make_instruction( + id: int, + encoding: int, + arity: int | None, + time: int | _IntFunction, + space: int | _IntFunction | None, + length: int | _IntFunction | None, + error_rate: float | _FloatFunction, + properties: dict[str, int], +) -> _Instruction: + """Build an ``_Instruction`` from keyword arguments.""" + if arity is not None: + instr = _Instruction.fixed_arity( + id, + encoding, + arity, + cast(int, time), + cast(int | None, space), + cast(int | None, length), + cast(float, error_rate), + ) + else: + if isinstance(time, int): + time = constant_function(time) + if isinstance(space, int): + space = constant_function(space) + if isinstance(length, int): + length = constant_function(length) + if isinstance(error_rate, (int, float)): + error_rate = constant_function(float(error_rate)) + + instr = _Instruction.variable_arity( + id, + encoding, + time, + cast(_IntFunction, space), + error_rate, + length, ) - instruction.set_source(source) - return instruction + for key, value in properties.items(): + prop_key = _PROPERTY_KEYS.get(key) + if prop_key is None: + raise ValueError( + f"Unknown property '{key}'. " + f"Valid properties: {list(_PROPERTY_KEYS)}" + ) + instr.set_property(prop_key, value) + + return instr diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py index 07d4b81466..fef8b85314 100644 --- a/source/pip/qsharp/qre/_enumeration.py +++ b/source/pip/qsharp/qre/_enumeration.py @@ -8,6 +8,7 @@ TypeVar, Literal, Union, + cast, get_args, get_origin, get_type_hints, @@ -20,41 +21,127 @@ T = TypeVar("T") +def _is_union_type(tp) -> bool: + """Check if a type is a Union or Python 3.10+ union (X | Y).""" + return get_origin(tp) is Union or isinstance(tp, types.UnionType) + + +def _is_type_filter(val, union_members: tuple) -> bool: + """ + Check if *val* is a union member type or a list of union member types, + i.e. 
a type filter for a union field (as opposed to a fixed value or + instance domain). + """ + member_set = set(union_members) + if isinstance(val, type) and val in member_set: + return True + if isinstance(val, list) and all( + isinstance(v, type) and v in member_set for v in val + ): + return True + return False + + +def _is_union_constraint_dict(val) -> bool: + """ + Check if *val* is a dict whose keys are all types, i.e. a per-member + constraint mapping for a union field. + + Example: ``{OptionA: {"number": [2, 3]}, OptionB: {}}`` + """ + return isinstance(val, dict) and all(isinstance(k, type) for k in val) + + +def _enumerate_union_members( + union_members: tuple, + val=None, +) -> list: + """ + Enumerate instances for a union-typed field. + + *val* controls which members are enumerated and how: + + - ``None`` - enumerate all members with their default domains. + - A single type (e.g. ``OptionB``) - enumerate only that member. + - A list of types (e.g. ``[OptionA, OptionB]``) - enumerate those members. + - A dict mapping types to constraint dicts + (e.g. ``{OptionA: {"number": [2, 3]}, OptionB: {}}``) - + enumerate only the listed members, forwarding the constraint dicts. + """ + # No override - enumerate all members with defaults + if val is None: + domain: list = [] + for member_type in union_members: + domain.extend(_enumerate_instances(member_type)) + return domain + + # Single type + if isinstance(val, type): + return list(_enumerate_instances(val)) + + # List of types + if isinstance(val, list) and all(isinstance(v, type) for v in val): + domain = [] + for member_type in val: + domain.extend(_enumerate_instances(member_type)) + return domain + + # Dict of type → constraint dict + if _is_union_constraint_dict(val): + domain = [] + for member_type, member_kwargs in cast(dict, val).items(): + domain.extend(_enumerate_instances(member_type, **member_kwargs)) + return domain + + raise ValueError( + f"Invalid value for union field: {val!r}. 
" + "Expected a union member type, a list of types, or a dict mapping " + "types to constraint dicts." + ) + + def _enumerate_instances(cls: Type[T], **kwargs) -> Generator[T, None, None]: """ Yields all instances of a dataclass given its class. - The enumeration logic supports defining domains for fields using the `domain` - metadata key. This allows fields to specify their valid range of values for - enumeration directly in the definition. Additionally, boolean fields are - automatically enumerated with `[True, False]`. Enum fields are enumerated - with all their members, and Literal types with their defined values. + The enumeration logic supports defining domains for fields using the + ``domain`` metadata key. Additionally, boolean fields are automatically + enumerated with ``[True, False]``, Enum fields with all their members, + and Literal types with their defined values. + + **Nested dataclass fields** can be constrained by passing a dict:: + + _enumerate_instances(Outer, inner={"option": True}) + + **Union-typed fields** support several override forms: + + - A single type to select one member:: + + _enumerate_instances(Config, option=OptionB) + + - A list of types to select a subset:: + + _enumerate_instances(Config, option=[OptionA, OptionB]) + + - A dict mapping types to constraint dicts:: + + _enumerate_instances(Config, option={OptionA: {"number": [2, 3]}, OptionB: {}}) Args: cls (Type[T]): The dataclass type to enumerate. - **kwargs: Fixed values or domains for fields. If a value is a list + **kwargs: Fixed values or domains for fields. If a value is a list and the corresponding field is kw_only, it is treated as a domain - to enumerate over. + to enumerate over. For nested dataclass fields a ``dict`` value + is forwarded as keyword arguments. For union-typed fields a type, + list of types, or ``dict[type, dict]`` controls member selection + and constraints. Returns: - Generator[T, None, None]: A generator yielding instances of the dataclass. 
+ Generator[T, None, None]: A generator yielding instances of the + dataclass. Raises: ValueError: If a field cannot be enumerated (no domain found). - - Example: - - .. code-block:: python - from dataclasses import dataclass, field, KW_ONLY - @dataclass - class MyConfig: - # Not part of enumeration - name: str - _ : KW_ONLY - # Part of enumeration with implicit domain [True, False] - enable_logging: bool = field(kw_only=True) - # Explicit domain in metadata - retry_count: int = field(metadata={"domain": [1, 3, 5]}, kw_only=True) """ names = [] @@ -77,6 +164,29 @@ class MyConfig: if name in kwargs: val = kwargs[name] + + is_union = _is_union_type(current_type) + union_members = get_args(current_type) if is_union else () + + # Union field with a type filter or constraint dict + if is_union and ( + _is_type_filter(val, union_members) or _is_union_constraint_dict(val) + ): + names.append(name) + values.append(_enumerate_union_members(union_members, val)) + continue + + # Nested dataclass field with a dict of constraints + if ( + isinstance(val, dict) + and not is_union + and isinstance(current_type, type) + and hasattr(current_type, "__dataclass_fields__") + ): + names.append(name) + values.append(list(_enumerate_instances(current_type, **val))) + continue + # If kw_only and list, it's a domain to enumerate if field.kw_only and isinstance(val, list): names.append(name) @@ -111,13 +221,8 @@ class MyConfig: continue # Union types (e.g., OptionA | OptionB or Union[OptionA, OptionB]) - if get_origin(current_type) is Union or isinstance( - current_type, types.UnionType - ): - union_domain = [] - for member_type in get_args(current_type): - union_domain.extend(_enumerate_instances(member_type)) - values.append(union_domain) + if _is_union_type(current_type): + values.append(_enumerate_union_members(get_args(current_type), None)) continue # Nested dataclass types diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 
1d1fd170c3..dde6b23d35 100644
--- a/source/pip/qsharp/qre/_estimation.py
+++ b/source/pip/qsharp/qre/_estimation.py
@@ -3,12 +3,21 @@
 
 from __future__ import annotations
 
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import cast, Optional, Callable, Any
+
+import pandas as pd
 
 from ._application import Application
 from ._architecture import Architecture
-from ._qre import _estimate_parallel
+from ._qre import (
+    _estimate_parallel,
+    _EstimationCollection,
+    Trace,
+    FactoryResult,
+    instruction_name,
+)
 from ._trace import TraceQuery, PSSPC, LatticeSurgery
 from ._instruction import InstructionSource
 from ._isa_enumeration import ISAQuery
@@ -21,6 +30,8 @@ def estimate(
     trace_query: Optional[TraceQuery] = None,
     *,
     max_error: float = 1.0,
+    post_process: bool = False,
+    name: Optional[str] = None,
 ) -> EstimationTable:
     """
     Estimate the resource requirements for a given application instance and
@@ -40,9 +51,15 @@
     Args:
         application (Application): The quantum application to be estimated.
         architecture (Architecture): The target quantum architecture.
+        isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture.
         trace_query (TraceQuery): The trace query to enumerate traces from the
             application.
-        isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture.
+        max_error (float): The maximum allowed error for the estimation results.
+        post_process (bool): If True, use the Python-threaded estimation path
+            (intended for future post-processing logic). If False (default),
+            use the Rust parallel estimation path.
+        name (Optional[str]): An optional name for the estimation. If given, this
+            will be added as a first column to the results table for all entries.
Returns: EstimationTable: A table containing the optimal estimation results @@ -54,22 +71,50 @@ def estimate( if trace_query is None: trace_query = PSSPC.q() * LatticeSurgery.q() - # Obtain all results - results = _estimate_parallel( - list(trace_query.enumerate(app_ctx)), - list(isa_query.enumerate(arch_ctx)), - max_error, - ) + if post_process: + # Enumerate traces with their parameters so we can post-process later + params_and_traces = list(trace_query.enumerate(app_ctx, track_parameters=True)) + isas = list(isa_query.enumerate(arch_ctx)) + + # Estimate all trace × ISA combinations using Python threads + collection = _EstimationCollection() + + def _estimate_one(params, trace, isa): + result = trace.estimate(isa, max_error) + if result is not None: + result = app_ctx.application.post_process(params, result) + return result + + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(_estimate_one, params, trace, isa) + for params, trace in cast(list[tuple[Any, Trace]], params_and_traces) + for isa in isas + ] + for future in futures: + result = future.result() + if result is not None: + collection.insert(result) + else: + traces = list(trace_query.enumerate(app_ctx)) + isas = list(isa_query.enumerate(arch_ctx)) + + # Use the Rust parallel estimation path + collection = _estimate_parallel(cast(list[Trace], traces), isas, max_error) # Post-process the results and add them to a results table table = EstimationTable() - for result in results: + if name is not None: + table.insert_column(0, "name", lambda entry: name) + + for result in collection: entry = EstimationTableEntry( qubits=result.qubits, runtime=result.runtime, error=result.error, source=InstructionSource.from_isa(arch_ctx, result.isa), + factories=result.factories.copy(), properties=result.properties.copy(), ) @@ -79,34 +124,148 @@ def estimate( class EstimationTable(list["EstimationTableEntry"]): + """A table of quantum resource estimation results. 
+ + Extends ``list[EstimationTableEntry]`` and provides configurable columns for + displaying estimation data. By default the table includes *qubits*, + *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. + Additional columns can be added or inserted with :meth:`add_column` and + :meth:`insert_column`. + """ + def __init__(self): + """Initialize an empty estimation table with default columns.""" super().__init__() - def as_frame(self): - try: - import pandas as pd - except ImportError: - raise ImportError( - "Missing optional 'pandas' dependency. To install run: " - "pip install pandas" + self._columns: list[tuple[str, EstimationTableColumn]] = [ + ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), + ( + "runtime", + EstimationTableColumn( + lambda entry: entry.runtime, + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ), + ), + ("error", EstimationTableColumn(lambda entry: entry.error)), + ] + + def add_column( + self, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Adds a column to the estimation table. + + Args: + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. + """ + self._columns.append((name, EstimationTableColumn(function, formatter))) + + def insert_column( + self, + index: int, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Inserts a column at the specified index in the estimation table. + + Args: + index (int): The index at which to insert the column. + name (str): The name of the column. 
+ function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. + """ + self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) + + def add_factory_summary_column(self) -> None: + """Adds a column to the estimation table that summarizes the factories used in the estimation.""" + + def summarize_factories(entry: EstimationTableEntry) -> str: + if not entry.factories: + return "None" + return ", ".join( + f"{factory_result.copies}×{instruction_name(id)}" + for id, factory_result in entry.factories.items() ) + self.add_column("factories", summarize_factories) + + def as_frame(self): + """Convert the estimation table to a :class:`pandas.DataFrame`. + + Each row corresponds to an :class:`EstimationTableEntry` and each + column is determined by the columns registered on this table. Column + formatters, when present, are applied to the values before they are + placed in the frame. + + Returns: + pandas.DataFrame: A DataFrame representation of the estimation + results. + """ return pd.DataFrame( [ { - "qubits": entry.qubits, - "runtime": pd.Timedelta(entry.runtime, unit="ns"), - "error": entry.error, + column_name: ( + column.formatter(column.function(entry)) + if column.formatter is not None + else column.function(entry) + ) + for column_name, column in self._columns } for entry in self ] ) +@dataclass(frozen=True, slots=True) +class EstimationTableColumn: + """Definition of a single column in an :class:`EstimationTable`. + + Attributes: + function: A callable that extracts the raw column value from an + :class:`EstimationTableEntry`. + formatter: An optional callable that transforms the raw value for + display purposes (e.g. converting nanoseconds to a + ``pandas.Timedelta``). 
+ """ + + function: Callable[[EstimationTableEntry], Any] + formatter: Optional[Callable[[Any], Any]] = None + + @dataclass(frozen=True, slots=True) class EstimationTableEntry: + """A single row in an :class:`EstimationTable`. + + Each entry represents one Pareto-optimal estimation result for a + particular combination of application trace and architecture ISA. + + Attributes: + qubits: Total number of physical qubits required. + runtime: Total runtime of the algorithm in nanoseconds. + error: Total estimated error probability. + source: The instruction source derived from the architecture ISA used + for this estimation. + factories: A mapping from instruction id to the + :class:`FactoryResult` describing the magic-state factory used + and the number of copies required. + properties: Additional key-value properties attached to the + estimation result. + """ + qubits: int runtime: int error: float source: InstructionSource + factories: dict[int, FactoryResult] = field(default_factory=dict) properties: dict[str, int | float | bool | str] = field(default_factory=dict) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index a907133fd7..1dc5b0a135 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Generator, Iterable, Optional, overload, cast +from typing import Generator, Iterable, Optional from enum import IntEnum from ._architecture import _Context, Architecture @@ -20,11 +20,8 @@ ISA, Constraint, ConstraintBound, - _FloatFunction, _Instruction, - _IntFunction, ISARequirements, - constant_function, instruction_name, ) @@ -83,105 +80,6 @@ def constraint( return c -@overload -def instruction( - id: int, - encoding: Encoding = PHYSICAL, - *, - time: int, - arity: int = 1, - space: Optional[int] = None, - length: Optional[int] = None, - error_rate: float, - **kwargs: int, -) 
-> _Instruction: ... -@overload -def instruction( - id: int, - encoding: Encoding = PHYSICAL, - *, - time: int | _IntFunction, - arity: None = ..., - space: int | _IntFunction, - length: Optional[int | _IntFunction] = None, - error_rate: float | _FloatFunction, - **kwargs: int, -) -> _Instruction: ... -def instruction( - id: int, - encoding: Encoding = PHYSICAL, - *, - time: int | _IntFunction, - arity: Optional[int] = 1, - space: Optional[int] | _IntFunction = None, - length: Optional[int | _IntFunction] = None, - error_rate: float | _FloatFunction, - **kwargs: int, -) -> _Instruction: - """ - Creates an instruction. - - Args: - id (int): The instruction ID. - encoding (Encoding): The instruction encoding. PHYSICAL (0) or LOGICAL (1). - time (int | _IntFunction): The instruction time in ns. - arity (Optional[int]): The instruction arity. If None, instruction is - assumed to have variable arity. Default is 1. One can use variable arity - functions for time, space, length, and error_rate in this case. - space (Optional[int] | _IntFunction): The instruction space in number of - physical qubits. If None, length is used. - length (Optional[int | _IntFunction]): The arity including ancilla - qubits. If None, arity is used. - error_rate (float | _FloatFunction): The instruction error rate. - **kwargs (int): Additional properties to set on the instruction. - Valid property names: distance. - - Returns: - _Instruction: The instruction. - - Raises: - ValueError: If an unknown property name is provided in kwargs. 
- """ - if arity is not None: - instr = _Instruction.fixed_arity( - id, - encoding, - arity, - cast(int, time), - cast(int | None, space), - cast(int | None, length), - cast(float, error_rate), - ) - else: - if isinstance(time, int): - time = constant_function(time) - if isinstance(space, int): - space = constant_function(space) - if isinstance(length, int): - length = constant_function(length) - if isinstance(error_rate, float): - error_rate = constant_function(error_rate) - - instr = _Instruction.variable_arity( - id, - encoding, - time, - cast(_IntFunction, space), - cast(_FloatFunction, error_rate), - length, - ) - - for key, value in kwargs.items(): - try: - prop_key = PropertyKey[key.upper()] - except KeyError: - raise ValueError( - f"Unknown property '{key}'. Valid properties: {[k.name.lower() for k in PropertyKey]}" - ) - instr.set_property(prop_key, value) - - return instr - class ISATransform(ABC): """ @@ -279,7 +177,7 @@ def bind(cls, name: str, node: ISAQuery) -> _BindingNode: return cls.q().bind(name, node) -@dataclass(frozen=True, slots=True) +@dataclass(slots=True) class InstructionSource: nodes: list[_InstructionSourceNode] = field(default_factory=list, init=False) roots: list[int] = field(default_factory=list, init=False) @@ -314,7 +212,7 @@ def _make_node( ] node = graph.add_node( - ctx._provenance.instruction_id(source), + ctx._provenance.instruction(source), ctx._transforms.get(ctx._provenance.transform_id(source)), children, ) @@ -326,8 +224,9 @@ def _make_node( source_table: dict[int, int] = {} for inst in isa: - if inst.source != 0: - node = _make_node(graph, source_table, inst.source) + node_idx = isa.node_index(inst.id) + if node_idx is not None and node_idx != 0: + node = _make_node(graph, source_table, node_idx) graph.add_root(node) return graph @@ -337,17 +236,17 @@ def add_root(self, node_id: int) -> None: def add_node( self, - id: int, + instruction: _Instruction, transform: Optional[ISATransform | Architecture], children: list[int], ) 
-> int: - node_id = self.nodes.__len__() - self.nodes.append(_InstructionSourceNode(id, transform, children)) + node_id = len(self.nodes) + self.nodes.append(_InstructionSourceNode(instruction, transform, children)) return node_id def __str__(self) -> str: def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str: - result = " " * indent + f"{instruction_name(node.id) or '??'}" + result = " " * indent + f"{instruction_name(node.instruction.id) or '??'}" if node.transform is not None: result += f" @ {node.transform}" for child_index in node.children: @@ -358,9 +257,108 @@ def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str: _format_node(self.nodes[root_index]) for root_index in self.roots ) + def __getitem__(self, id: int) -> _InstructionSourceNodeReference: + """ + Retrieves the first instruction source root node with the given + instruction ID. Raises KeyError if no such node exists. + + Args: + id (int): The instruction ID to search for. + + Returns: + _InstructionSourceNodeReference: The first instruction source node with the + given instruction ID. + """ + if (node := self.get(id)) is not None: + return node + + raise KeyError(f"Instruction ID {id} not found in instruction source graph.") + + def get( + self, id: int, default: Optional[_InstructionSourceNodeReference] = None + ) -> Optional[_InstructionSourceNodeReference]: + """ + Retrieves the first instruction source root node with the given + instruction ID. Returns default if no such node exists. + + Args: + id (int): The instruction ID to search for. + default (Optional[_InstructionSourceNodeReference]): The value to return if no + node with the given ID is found. Default is None. + + Returns: + Optional[_InstructionSourceNodeReference]: The first instruction source node with the + given instruction ID, or default if no such node exists. 
+ """ + for root in self.roots: + if self.nodes[root].instruction.id == id: + return _InstructionSourceNodeReference(self, root) + + return default + @dataclass(frozen=True, slots=True) class _InstructionSourceNode: - id: int + instruction: _Instruction transform: Optional[ISATransform | Architecture] children: list[int] + + +class _InstructionSourceNodeReference: + def __init__(self, graph: InstructionSource, node_id: int): + self.graph = graph + self.node_id = node_id + + @property + def instruction(self) -> _Instruction: + return self.graph.nodes[self.node_id].instruction + + @property + def transform(self) -> Optional[ISATransform | Architecture]: + return self.graph.nodes[self.node_id].transform + + def __str__(self) -> str: + return str(self.graph.nodes[self.node_id]) + + def __getitem__(self, id: int) -> _InstructionSourceNodeReference: + """ + Retrieves the first child instruction source node with the given + instruction ID. Raises KeyError if no such node exists. + + Args: + id (int): The instruction ID to search for. + + Returns: + _InstructionSourceNodeReference: The first child instruction source node with the + given instruction ID. + """ + if (node := self.get(id)) is not None: + return node + + raise KeyError( + f"Instruction ID {id} not found in children of instruction {instruction_name(self.instruction.id) or '??'}." + ) + + def get( + self, id: int, default: Optional[_InstructionSourceNodeReference] = None + ) -> Optional[_InstructionSourceNodeReference]: + """ + Retrieves the first child instruction source node with the given + instruction ID. Returns default if no such node exists. + + Args: + id (int): The instruction ID to search for. + default (Optional[_InstructionSourceNodeReference]): The value to return if no + node with the given ID is found. Default is None. + + Returns: + Optional[_InstructionSourceNodeReference]: The first child instruction source + node with the given instruction ID, or default if no such node + exists. 
+ """ + + for child_id in self.graph.nodes[self.node_id].children: + if self.graph.nodes[child_id].instruction.id == id: + return _InstructionSourceNodeReference(self.graph, child_id) + + return default diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index c3301a448d..458bf1e842 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -5,28 +5,6 @@ from __future__ import annotations from typing import Any, Callable, Iterator, Optional, overload class ISA: - @overload - def __new__(cls, *instructions: _Instruction) -> ISA: ... - @overload - def __new__(cls, instructions: list[_Instruction], /) -> ISA: ... - def __new__(cls, *instructions: _Instruction | list[_Instruction]) -> ISA: - """ - Creates an ISA from a list of instructions. - - Args: - instructions (list[_Instruction] | *_Instruction): The list of instructions. - """ - ... - - def append(self, instruction: _Instruction) -> None: - """ - Appends an instruction to the ISA. - - Args: - instruction (_Instruction): The instruction to append. - """ - ... - def __add__(self, other: ISA) -> ISA: """ Concatenates two ISAs (logical union). Instructions in the second @@ -90,6 +68,28 @@ class ISA: """ ... + def node_index(self, id: int) -> Optional[int]: + """ + Returns the provenance graph node index for the given instruction ID. + + Args: + id (int): The instruction ID. + + Returns: + Optional[int]: The node index, or None if not found. + """ + ... + + def add_node(self, instruction_id: int, node_index: int) -> None: + """ + Adds a pre-existing provenance graph node to the ISA. + + Args: + instruction_id (int): The instruction ID. + node_index (int): The node index in the provenance graph. + """ + ... + def __iter__(self) -> Iterator[_Instruction]: """ Returns an iterator over the instructions. @@ -376,6 +376,18 @@ class _Instruction: """ ... + def __getitem__(self, key: int) -> int: + """ + Gets a property by its key, or raises an error if not found. 
+ + Args: + key (int): The property key. + + Returns: + int: The property value. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the instruction. @@ -545,18 +557,23 @@ def linear_function( ... @overload -def block_linear_function(block_size: int, slope: int) -> _IntFunction: ... +def block_linear_function(block_size: int, slope: int, offset: int) -> _IntFunction: ... @overload -def block_linear_function(block_size: int, slope: float) -> _FloatFunction: ... def block_linear_function( - block_size: int, slope: int | float + block_size: int, slope: float, offset: float +) -> _FloatFunction: ... +def block_linear_function( + block_size: int, slope: int | float, offset: int | float ) -> _IntFunction | _FloatFunction: """ - Creates a block linear function. + Creates a block linear function that takes an arity (number of qubits) as + input. Given an arity, it will compute the number of blocks `num_blocks` by + computing `ceil(arity / block_size)` and then return `slope * num_blocks + + offset`. Args: - block_size (int): The block size. - slope (int | float): The slope. + block_size (int): The block size. slope (int | float): The slope. offset + (int | float): The offset Returns: _IntFunction | _FloatFunction: The block linear function. @@ -597,13 +614,13 @@ class _ProvenanceGraph: """ def add_node( - self, instruction_id: int, transform_id: int, children: list[int] + self, instruction: _Instruction, transform_id: int, children: list[int] ) -> int: """ Adds a node to the provenance graph. Args: - instruction_id (int): The instruction ID corresponding to the node. + instruction (int): The instruction corresponding to the node. transform_id (int): The transform ID corresponding to the node. children (list[int]): The list of child node indices in the provenance graph. @@ -612,15 +629,15 @@ class _ProvenanceGraph: """ ... 
- def instruction_id(self, node_index: int) -> int: + def instruction(self, node_index: int) -> _Instruction: """ - Returns the instruction ID for a given node index. + Returns the instruction for a given node index. Args: node_index (int): The index of the node in the provenance graph. Returns: - int: The instruction ID corresponding to the node. + int: The instruction corresponding to the node. """ ... @@ -666,6 +683,74 @@ class _ProvenanceGraph: """ ... + @overload + def add_instruction( + self, + instruction: _Instruction, + ) -> int: ... + @overload + def add_instruction( + self, + id: int, + encoding: int = 0, + *, + arity: Optional[int] = 1, + time: int | _IntFunction = ..., + space: Optional[int | _IntFunction] = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = ..., + **kwargs: int, + ) -> int: ... + def add_instruction( + self, + id_or_instruction: int | _Instruction, + encoding: int = 0, + *, + arity: Optional[int] = 1, + time: int | _IntFunction = ..., + space: Optional[int | _IntFunction] = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = ..., + **kwargs: int, + ) -> int: + """ + Adds an instruction to the provenance graph with no transform or + children. + + Can be called with a pre-existing ``_Instruction`` or with keyword + args to create one inline. + + Args: + id_or_instruction: An instruction ID (int) or ``_Instruction``. + encoding: 0 = Physical, 1 = Logical. Ignored for ``_Instruction``. + arity: Instruction arity, ``None`` for variable. Ignored for + ``_Instruction``. + time: Time in ns (or ``_IntFunction``). Ignored for ``_Instruction``. + space: Space in physical qubits (or ``_IntFunction``). Ignored for + ``_Instruction``. + length: Arity including ancillas. Ignored for ``_Instruction``. + error_rate: Error rate (or ``_FloatFunction``). Ignored for + ``_Instruction``. + **kwargs: Additional properties (e.g. ``distance=9``). 
+ + Returns: + int: The node index of the added instruction. + """ + ... + + def make_isa(self, node_indices: list[int]) -> ISA: + """ + Creates an ISA backed by this provenance graph from the given node + indices. + + Args: + node_indices: A list of node indices in the provenance graph. + + Returns: + ISA: An ISA referencing this provenance graph. + """ + ... + class EstimationResult: """ Represents the result of a resource estimation. @@ -697,12 +782,13 @@ class EstimationResult: """ ... - def add_qubits(self, qubits: int) -> None: + @qubits.setter + def qubits(self, qubits: int) -> None: """ - Adds to the number of logical qubits. + Sets the number of logical qubits. Args: - qubits (int): The number of logical qubits to add. + qubits (int): The number of logical qubits to set. """ ... @@ -716,12 +802,13 @@ class EstimationResult: """ ... - def add_runtime(self, runtime: int) -> None: + @runtime.setter + def runtime(self, runtime: int) -> None: """ - Adds to the runtime. + Sets the runtime. Args: - runtime (int): The amount of runtime in nanoseconds to add. + runtime (int): The runtime in nanoseconds to set. """ ... @@ -735,12 +822,13 @@ class EstimationResult: """ ... - def add_error(self, error: float) -> None: + @error.setter + def error(self, error: float) -> None: """ - Adds to the error probability. + Sets the error probability. Args: - error (float): The amount to add to the error probability. + error (float): The error probability to set. """ ... @@ -774,6 +862,17 @@ class EstimationResult: """ ... + def set_property(self, key: str, value: bool | int | float | str) -> None: + """ + Sets a custom property. + + Args: + key (str): The property key. + value (bool | int | float | str): The property value. All values of type `int`, `float`, `bool`, and `str` + are supported. Any other value is converted to a string using its `__str__` method. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the estimation result. 
diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py index efe7985390..15873ebf15 100644 --- a/source/pip/qsharp/qre/_trace.py +++ b/source/pip/qsharp/qre/_trace.py @@ -5,8 +5,11 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, KW_ONLY, field from itertools import product -from typing import Any, Optional, Generator, Type -from ._application import _Context +from types import NoneType +from typing import Any, Optional, Generator, Type, TYPE_CHECKING + +if TYPE_CHECKING: + from ._application import _Context from ._enumeration import _enumerate_instances from ._qre import PSSPC as _PSSPC, LatticeSurgery as _LatticeSurgery, Trace @@ -52,38 +55,53 @@ class _Node(ABC): def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: ... -class RootNode(_Node): - # NOTE: this might be redundant with TransformationNode with an empty sequence - def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: - yield from ctx.application.enumerate_traces(**ctx.kwargs) - - class TraceQuery(_Node): + # This is a sequence of trace transforms together with possible kwargs to + # override their default domains. 
The first element might be sequence: list[tuple[Type, dict[str, Any]]] def __init__(self, t: Type, **kwargs): self.sequence = [(t, kwargs)] - def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: - for trace in ctx.application.enumerate_traces(**ctx.kwargs): - if not self.sequence: - yield trace + def enumerate( + self, ctx: _Context, track_parameters: bool = False + ) -> Generator[Trace | tuple[Any, Trace], None, None]: + sequence = self.sequence + kwargs = {} + if len(sequence) > 0 and sequence[0][0] is NoneType: + kwargs = sequence[0][1] + sequence = sequence[1:] + + if track_parameters: + source = ctx.application.enumerate_traces_with_parameters(**kwargs) + else: + source = ((None, t) for t in ctx.application.enumerate_traces(**kwargs)) + + for params, trace in source: + if not sequence: + yield (params, trace) if track_parameters else trace continue transformer_instances = [] - for t, transformer_kwargs in self.sequence: + for t, transformer_kwargs in sequence: instances = _enumerate_instances(t, **transformer_kwargs) transformer_instances.append(instances) # TODO: make parallel - for sequence in product(*transformer_instances): + for combination in product(*transformer_instances): transformed = trace - for transformer in sequence: + for transformer in combination: transformed = transformer.transform(transformed) - yield transformed + yield (params, transformed) if track_parameters else transformed def __mul__(self, other: TraceQuery) -> TraceQuery: new_query = TraceQuery.__new__(TraceQuery) + + if len(other.sequence) > 0 and other.sequence[0][0] is NoneType: + raise ValueError( + "Cannot multiply with a TraceQuery that has a None transform at the beginning of its sequence." 
+ ) + new_query.sequence = self.sequence + other.sequence return new_query diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py index 3ce98c31f7..d4f35117e4 100644 --- a/source/pip/qsharp/qre/models/factories/_litinski.py +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -8,8 +8,8 @@ from typing import Generator from ..._architecture import _Context -from ..._qre import ISA, ISARequirements, ConstraintBound, _Instruction -from ..._instruction import ISATransform, constraint, instruction, LOGICAL +from ..._qre import ISARequirements, ConstraintBound, ISA +from ..._instruction import ISATransform, constraint, LOGICAL from ...instruction_ids import T, CNOT, H, MEAS_Z, CCZ @@ -90,7 +90,7 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non + impl_isa[MEAS_Z].expect_time() ) - def make_instruction(entry: _Entry) -> _Instruction: + def make_node(entry: _Entry) -> int: # Convert cycles (number of syndrome extraction cycles) to time # based on fast surface code time = ceil(syndrome_extraction_time * entry.cycles) @@ -99,28 +99,29 @@ def make_instruction(entry: _Entry) -> _Instruction: # space cost is divided by the number of output states. This is a # simplification that allows us to fit all protocols in the ISA, but # it may not be accurate for all protocols. 
- inst = instruction( + return ctx.add_instruction( entry.state, arity=3 if entry.state == CCZ else 1, encoding=LOGICAL, space=ceil(entry.space / entry.output_states), time=time, error_rate=entry.error_rate, + transform=self, + source=[cnot, h, meas_z, t], ) - return ctx.set_source(self, inst, [cnot, h, meas_z, t]) # Yield combinations of T and CCZ entries if ccz_entries: for t_entry in t_entries: for ccz_entry in ccz_entries: - yield ISA( - make_instruction(t_entry), - make_instruction(ccz_entry), + yield ctx.make_isa( + make_node(t_entry), + make_node(ccz_entry), ) else: # Table 2 scenarios: only T gates available for t_entry in t_entries: - yield ISA(make_instruction(t_entry)) + yield ctx.make_isa(make_node(t_entry)) def _initialize_entries(self): self._entries = { diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index bdde7cfcf9..b14a0336fe 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -18,7 +18,6 @@ ISAQuery, ISATransform, constraint, - instruction, ) from ..._architecture import _Context from ...instruction_ids import CNOT, LATTICE_SURGERY, T, MEAS_ZZ @@ -111,7 +110,7 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non if self.use_cache and cache_path.exists(): cached_states = InstructionFrontier.load(str(cache_path)) for state in cached_states: - yield ISA(state) + yield ctx.make_isa(ctx.add_instruction(state)) return # 2) Compute as before @@ -164,7 +163,9 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non optimal_states.dump(str(cache_path)) for state in optimal_states: - yield ISA(ctx.set_source(self, state, [impl_isa[T]])) + yield ctx.make_isa( + ctx.add_instruction(state, transform=self, source=[impl_isa[T]]) + ) def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: return [ @@ -212,12 +213,14 @@ def 
_logical_units( ] def _state_from_pipeline(self, pipeline: _Pipeline) -> _Instruction: - return instruction( + return _Instruction.fixed_arity( T, - encoding=LOGICAL, - time=pipeline.time, - error_rate=pipeline.error_rate, - space=pipeline.space, + int(LOGICAL), + 1, + pipeline.time, + pipeline.space, + None, + pipeline.error_rate, ) def _cache_key(self, impl_isa: ISA) -> str: diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py index c52b3583e4..dcd72c6afe 100644 --- a/source/pip/qsharp/qre/models/factories/_utils.py +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -58,7 +58,7 @@ class MagicUpToClifford(ISATransform): def required_isa() -> ISARequirements: return ISARequirements() - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa(self, impl_isa, ctx: _Context) -> Generator[ISA, None, None]: # Families of equivalent gates under Clifford conjugation. families = [ [ @@ -79,11 +79,12 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non instr = impl_isa[id] for equivalent_id in family: if equivalent_id != id: - impl_isa.append( - ctx.set_source( - self, instr.with_id(equivalent_id), [instr] - ) + node_idx = ctx.add_instruction( + instr.with_id(equivalent_id), + transform=self, + source=[instr], ) + impl_isa.add_node(equivalent_id, node_idx) break # Check next family yield impl_isa diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index 6758d5796d..ce80a23506 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -8,7 +8,6 @@ ISA, ISARequirements, ISATransform, - instruction, constraint, ConstraintBound, LOGICAL, @@ -96,16 +95,16 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non # We provide a generic lattice surgery instruction (See Section 3 in # 
arXiv:1111.4022) - lattice_surgery = instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - distance=self.distance, - ) - - yield ISA( - ctx.set_source(self, lattice_surgery, [cnot, h, meas_z]), + yield ctx.make_isa( + ctx.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + transform=self, + source=[cnot, h, meas_z], + distance=self.distance, + ), ) diff --git a/source/pip/qsharp/qre/models/qec/_three_aux.py b/source/pip/qsharp/qre/models/qec/_three_aux.py index f276061c73..2af1879205 100644 --- a/source/pip/qsharp/qre/models/qec/_three_aux.py +++ b/source/pip/qsharp/qre/models/qec/_three_aux.py @@ -11,7 +11,6 @@ LOGICAL, ISATransform, constraint, - instruction, ) from ..._qre import ( ISA, @@ -103,16 +102,16 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non ) ) - lattice_surgery = instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - distance=self.distance, - ) - - yield ISA( - ctx.set_source(self, lattice_surgery, [meas_x, meas_z, meas_xx, meas_zz]) + yield ctx.make_isa( + ctx.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + transform=self, + source=[meas_x, meas_z, meas_xx, meas_zz], + distance=self.distance, + ) ) diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py index 280ee4bb24..9b24069aa2 100644 --- a/source/pip/qsharp/qre/models/qec/_yoked.py +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -6,7 +6,7 @@ from math import ceil from typing import Generator -from ..._instruction import ISATransform, constraint, LOGICAL, PropertyKey, instruction +from ..._instruction import ISATransform, constraint, LOGICAL, PropertyKey from ..._qre 
import ISA, ISARequirements, generic_function from ..._architecture import _Context from ...instruction_ids import LATTICE_SURGERY, MEMORY @@ -102,19 +102,17 @@ def error_rate(arity: int) -> float: error_rate_fn = generic_function(error_rate) - yield ISA( - ctx.set_source( - self, - instruction( - MEMORY, - arity=None, - encoding=LOGICAL, - space=space_fn, - time=time_fn, - error_rate=error_rate_fn, - distance=distance, - ), - [lattice_surgery], + yield ctx.make_isa( + ctx.add_instruction( + MEMORY, + arity=None, + encoding=LOGICAL, + space=space_fn, + time=time_fn, + error_rate=error_rate_fn, + transform=self, + source=[lattice_surgery], + distance=distance, ) ) diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_aqre.py index 981c6223d7..2726daf219 100644 --- a/source/pip/qsharp/qre/models/qubits/_aqre.py +++ b/source/pip/qsharp/qre/models/qubits/_aqre.py @@ -2,10 +2,11 @@ # Licensed under the MIT License. from dataclasses import KW_ONLY, dataclass, field +from typing import Optional -from ..._architecture import Architecture +from ..._architecture import Architecture, _Context from ...instruction_ids import CNOT, CZ, MEAS_Z, PAULI_I, H, T -from ..._instruction import ISA, Encoding, instruction +from ..._instruction import ISA, Encoding @dataclass @@ -14,11 +15,18 @@ class AQREGateBased(Architecture): A generic gate-based architecture based on the qubit parameters in Azure Quantum Resource Estimator (AQRE, [arXiv:2211.07629](https://arxiv.org/abs/2211.07629)). The error rate can - be set arbitrarily and is either 1e-3 or 1e-4 in the reference. Gate times - are set to 50ns and measurement times are set to 100ns, which are typical - for superconducting transmon qubits + be set arbitrarily and is either 1e-3 or 1e-4 in the reference. Typical + gate times are 50ns and measurement times are 100ns for superconducting + transmon qubits [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002). 
+ Args: + error_rate: The error rate for all gates. Defaults to 1e-4. + gate_time: The time (in ns) for single-qubit gates. + measurement_time: The time (in ns) for measurement operations. + two_qubit_gate_time: The time (in ns) for two-qubit gates (CNOT, CZ). + If not provided, defaults to the value of ``gate_time``. + References: - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. Svore, @@ -34,49 +42,58 @@ class AQREGateBased(Architecture): _: KW_ONLY error_rate: float = field(default=1e-4) + gate_time: int + measurement_time: int + two_qubit_gate_time: Optional[int] = field(default=None) + + def __post_init__(self): + if self.two_qubit_gate_time is None: + self.two_qubit_gate_time = self.gate_time + + def provided_isa(self, ctx: _Context) -> ISA: + # Value is initialized in __post_init__ + assert self.two_qubit_gate_time is not None - @property - def provided_isa(self) -> ISA: - return ISA( - instruction( + return ctx.make_isa( + ctx.add_instruction( PAULI_I, encoding=Encoding.PHYSICAL, arity=1, - time=50, + time=self.gate_time, error_rate=self.error_rate, ), - instruction( + ctx.add_instruction( CNOT, encoding=Encoding.PHYSICAL, arity=2, - time=50, + time=self.two_qubit_gate_time, error_rate=self.error_rate, ), - instruction( + ctx.add_instruction( CZ, encoding=Encoding.PHYSICAL, arity=2, - time=50, + time=self.two_qubit_gate_time, error_rate=self.error_rate, ), - instruction( + ctx.add_instruction( H, encoding=Encoding.PHYSICAL, arity=1, - time=50, + time=self.gate_time, error_rate=self.error_rate, ), - instruction( + ctx.add_instruction( MEAS_Z, encoding=Encoding.PHYSICAL, arity=1, - time=100, + time=self.measurement_time, error_rate=self.error_rate, ), - instruction( + ctx.add_instruction( T, encoding=Encoding.PHYSICAL, - time=50, + time=self.gate_time, error_rate=self.error_rate, ), ) diff --git a/source/pip/qsharp/qre/models/qubits/_msft.py b/source/pip/qsharp/qre/models/qubits/_msft.py index 9ce6fcb3c9..022157c1d4 100644 --- 
a/source/pip/qsharp/qre/models/qubits/_msft.py +++ b/source/pip/qsharp/qre/models/qubits/_msft.py @@ -3,7 +3,7 @@ from dataclasses import KW_ONLY, dataclass, field -from ..._architecture import Architecture +from ..._architecture import Architecture, _Context from ...instruction_ids import ( T, PREP_X, @@ -13,7 +13,7 @@ MEAS_X, MEAS_Z, ) -from ..._instruction import ISA, instruction +from ..._instruction import ISA @dataclass @@ -47,8 +47,7 @@ class Majorana(Architecture): _: KW_ONLY error_rate: float = field(default=1e-5, metadata={"domain": [1e-4, 1e-5, 1e-6]}) - @property - def provided_isa(self) -> ISA: + def provided_isa(self, ctx: _Context) -> ISA: if abs(self.error_rate - 1e-4) <= 1e-8: t_error_rate = 0.05 elif abs(self.error_rate - 1e-5) <= 1e-8: @@ -56,42 +55,16 @@ def provided_isa(self) -> ISA: elif abs(self.error_rate - 1e-6) <= 1e-8: t_error_rate = 0.01 - return ISA( - instruction( - PREP_X, - time=1000, - error_rate=self.error_rate, + return ctx.make_isa( + ctx.add_instruction(PREP_X, time=1000, error_rate=self.error_rate), + ctx.add_instruction(PREP_Z, time=1000, error_rate=self.error_rate), + ctx.add_instruction( + MEAS_XX, arity=2, time=1000, error_rate=self.error_rate ), - instruction( - PREP_Z, - time=1000, - error_rate=self.error_rate, - ), - instruction( - MEAS_XX, - arity=2, - time=1000, - error_rate=self.error_rate, - ), - instruction( - MEAS_ZZ, - arity=2, - time=1000, - error_rate=self.error_rate, - ), - instruction( - MEAS_X, - time=1000, - error_rate=self.error_rate, - ), - instruction( - MEAS_Z, - time=1000, - error_rate=self.error_rate, - ), - instruction( - T, - time=1000, - error_rate=t_error_rate, + ctx.add_instruction( + MEAS_ZZ, arity=2, time=1000, error_rate=self.error_rate ), + ctx.add_instruction(MEAS_X, time=1000, error_rate=self.error_rate), + ctx.add_instruction(MEAS_Z, time=1000, error_rate=self.error_rate), + ctx.add_instruction(T, time=1000, error_rate=t_error_rate), ) diff --git a/source/pip/src/qre.rs 
b/source/pip/src/qre.rs index 37e9f57ba6..8c016d646a 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use std::{ptr::NonNull, sync::Arc}; +use std::{ + ptr::NonNull, + sync::{Arc, RwLock}, +}; use pyo3::{ IntoPyObjectExt, - exceptions::{PyException, PyKeyError, PyTypeError}, + exceptions::{PyException, PyKeyError, PyRuntimeError, PyTypeError}, prelude::*, types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple, PyType}, }; @@ -46,41 +49,22 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3::create_exception!(qsharp.qre, EstimationError, PyException); +fn poisoned_lock_err(_: std::sync::PoisonError) -> PyErr { + PyRuntimeError::new_err("provenance graph lock poisoned") +} + #[allow(clippy::upper_case_acronyms)] #[pyclass] pub struct ISA(qre::ISA); -#[pymethods] impl ISA { - #[new] - #[pyo3(signature = (*instructions))] - pub fn new(instructions: &Bound<'_, PyTuple>) -> PyResult { - if instructions.len() == 1 { - let item = instructions.get_item(0)?; - if let Ok(seq) = item.cast_into::() { - let mut instrs = Vec::with_capacity(seq.len()); - for item in seq.iter() { - let instr = item.cast_into::()?; - instrs.push(instr.borrow().0.clone()); - } - return Ok(ISA(instrs.into_iter().collect())); - } - } - - instructions - .into_iter() - .map(|instr| { - let instr = instr.cast_into::()?; - Ok(instr.borrow().0.clone()) - }) - .collect::>() - .map(ISA) - } - - pub fn append(&mut self, instruction: &Instruction) { - self.0.add_instruction(instruction.0.clone()); + pub fn inner(&self) -> &qre::ISA { + &self.0 } +} +#[pymethods] +impl ISA { pub fn __add__(&self, other: &ISA) -> PyResult { Ok(ISA(self.0.clone() + other.0.clone())) } @@ -99,7 +83,7 @@ impl ISA { pub fn __getitem__(&self, id: u64) -> PyResult { match self.0.get(&id) { - Some(instr) => Ok(Instruction(instr.clone())), + Some(instr) => Ok(Instruction(instr)), None => 
Err(PyKeyError::new_err(format!( "Instruction with id {id} not found" ))), @@ -109,15 +93,26 @@ impl ISA { #[pyo3(signature = (id, default=None))] pub fn get(&self, id: u64, default: Option<&Instruction>) -> Option { match self.0.get(&id) { - Some(instr) => Some(Instruction(instr.clone())), + Some(instr) => Some(Instruction(instr)), None => default.cloned(), } } + /// Returns the provenance graph node index for the given instruction ID. + pub fn node_index(&self, id: u64) -> Option { + self.0.node_index(&id) + } + + /// Adds a node (by instruction ID and node index) that already exists in the graph. + pub fn add_node(&mut self, instruction_id: u64, node_index: usize) { + self.0.add_node(instruction_id, node_index); + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let instructions = slf.0.instructions(); let iter = ISAIterator { - iter: (*slf.0).clone().into_iter(), + iter: instructions.into_iter(), }; Py::new(slf.py(), iter) } @@ -129,7 +124,7 @@ impl ISA { #[pyclass] pub struct ISAIterator { - iter: std::collections::hash_map::IntoIter, + iter: std::vec::IntoIter, } #[pymethods] @@ -139,7 +134,7 @@ impl ISAIterator { } fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { - slf.iter.next().map(|(_, instr)| Instruction(instr)) + slf.iter.next().map(Instruction) } } @@ -300,6 +295,15 @@ impl Instruction { self.0.get_property_or(&key, default) } + pub fn __getitem__(&self, key: u64) -> PyResult { + match self.0.get_property(&key) { + Some(value) => Ok(value), + None => Err(PyKeyError::new_err(format!( + "Property with key {key} not found" + ))), + } + } + fn __str__(&self) -> String { format!("{}", self.0) } @@ -373,6 +377,102 @@ fn convert_encoding(encoding: u64) -> PyResult { } } +/// Property name → integer key mapping (must match Python `_PROPERTY_KEYS`). 
+fn property_name_to_key(name: &str) -> PyResult { + match name { + "distance" => Ok(0), + other => Err(PyTypeError::new_err(format!( + "Unknown property '{other}'. Valid properties: [\"distance\"]" + ))), + } +} + +/// Build a `qre::Instruction` from either an existing `Instruction` Python +/// object or from keyword arguments (id + encoding + arity + …). +#[allow(clippy::too_many_arguments)] +fn build_instruction( + id_or_instruction: &Bound<'_, PyAny>, + encoding: u64, + arity: Option, + time: Option<&Bound<'_, PyAny>>, + space: Option<&Bound<'_, PyAny>>, + length: Option<&Bound<'_, PyAny>>, + error_rate: Option<&Bound<'_, PyAny>>, + kwargs: Option<&Bound<'_, PyDict>>, +) -> PyResult { + // If the first argument is already an Instruction, return its inner clone. + if let Ok(instr) = id_or_instruction.cast::() { + return Ok(instr.borrow().0.clone()); + } + + // Otherwise treat the first arg as an instruction ID (int). + let id: u64 = id_or_instruction.extract()?; + let enc = convert_encoding(encoding)?; + + let mut instr = if let Some(arity_val) = arity { + // Fixed-arity path + let time_val: u64 = time + .ok_or_else(|| PyTypeError::new_err("'time' is required"))? + .extract()?; + let space_val: Option = space.map(PyAnyMethods::extract).transpose()?; + let length_val: Option = length.map(PyAnyMethods::extract).transpose()?; + let error_rate_val: f64 = error_rate + .ok_or_else(|| PyTypeError::new_err("'error_rate' is required"))? 
+ .extract()?; + qre::Instruction::fixed_arity( + id, + enc, + arity_val, + time_val, + space_val, + length_val, + error_rate_val, + ) + } else { + // Variable-arity path: time/space/error_rate may be functions + let time_fn = + extract_int_function(time.ok_or_else(|| PyTypeError::new_err("'time' is required"))?)?; + let space_fn = extract_int_function( + space.ok_or_else(|| PyTypeError::new_err("'space' is required for variable arity"))?, + )?; + let length_fn = length.map(extract_int_function).transpose()?; + let error_rate_fn = extract_float_function( + error_rate.ok_or_else(|| PyTypeError::new_err("'error_rate' is required"))?, + )?; + qre::Instruction::variable_arity(id, enc, time_fn, space_fn, length_fn, error_rate_fn) + }; + + // Apply additional properties from kwargs + if let Some(kw) = kwargs { + for (key, value) in kw { + let key_str: String = key.extract()?; + let prop_key = property_name_to_key(&key_str)?; + let prop_value: u64 = value.extract()?; + instr.set_property(prop_key, prop_value); + } + } + + Ok(instr) +} + +/// Extract an `_IntFunction` or convert a plain int into a constant function. +fn extract_int_function(obj: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(f) = obj.cast::() { + return Ok(f.borrow().0.clone()); + } + let val: u64 = obj.extract()?; + Ok(qre::VariableArityFunction::constant(val)) +} + +/// Extract a `_FloatFunction` or convert a plain float into a constant function. 
+fn extract_float_function(obj: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(f) = obj.cast::() { + return Ok(f.borrow().0.clone()); + } + let val: f64 = obj.extract()?; + Ok(qre::VariableArityFunction::constant(val)) +} + #[pyclass] pub struct ConstraintBound(qre::ConstraintBound); @@ -404,40 +504,114 @@ impl ConstraintBound { } } -#[derive(Default)] +#[derive(Clone)] #[pyclass(name = "_ProvenanceGraph")] -pub struct ProvenanceGraph(qre::ProvenanceGraph); +pub struct ProvenanceGraph(Arc>); + +impl Default for ProvenanceGraph { + fn default() -> Self { + Self(Arc::new(RwLock::new(qre::ProvenanceGraph::new()))) + } +} #[pymethods] impl ProvenanceGraph { #[new] pub fn new() -> Self { - Self(qre::ProvenanceGraph::new()) + Self::default() } #[allow(clippy::needless_pass_by_value)] - pub fn add_node(&mut self, id: u64, transform: u64, children: Vec) -> usize { - self.0.add_node(id, transform, &children) + pub fn add_node( + &mut self, + instruction: &Instruction, + transform: u64, + children: Vec, + ) -> PyResult { + Ok(self.0.write().map_err(poisoned_lock_err)?.add_node( + instruction.0.clone(), + transform, + &children, + )) } - pub fn instruction_id(&self, node_index: usize) -> u64 { - self.0.instruction_id(node_index) + pub fn instruction(&self, node_index: usize) -> PyResult { + Ok(Instruction( + self.0 + .read() + .map_err(poisoned_lock_err)? + .instruction(node_index) + .clone(), + )) } - pub fn transform_id(&self, node_index: usize) -> u64 { - self.0.transform_id(node_index) + pub fn transform_id(&self, node_index: usize) -> PyResult { + Ok(self + .0 + .read() + .map_err(poisoned_lock_err)? + .transform_id(node_index)) } - pub fn children(&self, node_index: usize) -> Vec { - self.0.children(node_index).to_vec() + pub fn children(&self, node_index: usize) -> PyResult> { + Ok(self + .0 + .read() + .map_err(poisoned_lock_err)? 
+ .children(node_index) + .to_vec()) } - pub fn num_nodes(&self) -> usize { - self.0.num_nodes() + pub fn num_nodes(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.num_nodes()) } - pub fn num_edges(&self) -> usize { - self.0.num_edges() + pub fn num_edges(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.num_edges()) + } + + /// Adds an instruction to the provenance graph with no transform or children. + /// Accepts either a pre-existing Instruction or keyword args to create one. + /// Returns the node index. + #[pyo3(signature = (id_or_instruction, encoding=0, *, arity=Some(1), time=None, space=None, length=None, error_rate=None, **kwargs))] + #[allow(clippy::too_many_arguments)] + pub fn add_instruction( + &mut self, + id_or_instruction: &Bound<'_, PyAny>, + encoding: u64, + arity: Option, + time: Option<&Bound<'_, PyAny>>, + space: Option<&Bound<'_, PyAny>>, + length: Option<&Bound<'_, PyAny>>, + error_rate: Option<&Bound<'_, PyAny>>, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult { + let instr = build_instruction( + id_or_instruction, + encoding, + arity, + time, + space, + length, + error_rate, + kwargs, + )?; + Ok(self + .0 + .write() + .map_err(poisoned_lock_err)? + .add_node(instr, 0, &[])) + } + + /// Creates an ISA backed by this provenance graph from the given node indices. 
+ pub fn make_isa(&self, node_indices: Vec) -> PyResult { + let graph = self.0.read().map_err(poisoned_lock_err)?; + let mut isa = qre::ISA::with_graph(self.0.clone()); + for idx in node_indices { + let id = graph.instruction(idx).id(); + isa.add_node(id, idx); + } + Ok(ISA(isa)) } } @@ -474,17 +648,26 @@ pub fn linear_function<'py>(slope: &Bound<'py, PyAny>) -> PyResult( block_size: u64, slope: &Bound<'py, PyAny>, + offset: &Bound<'py, PyAny>, ) -> PyResult> { - if let Ok(s) = slope.extract::() { - IntFunction(qre::VariableArityFunction::block_linear(block_size, s)) - .into_bound_py_any(slope.py()) + if let Ok(s) = slope.extract() { + let offset = offset.extract::()?; + IntFunction(qre::VariableArityFunction::block_linear( + block_size, s, offset, + )) + .into_bound_py_any(slope.py()) } else if let Ok(s) = slope.extract::() { - FloatFunction(qre::VariableArityFunction::block_linear(block_size, s)) - .into_bound_py_any(slope.py()) + let offset = offset.extract()?; + FloatFunction(qre::VariableArityFunction::block_linear( + block_size, s, offset, + )) + .into_bound_py_any(slope.py()) } else { Err(PyTypeError::new_err( "Slope must be either an integer or a float", @@ -606,8 +789,9 @@ impl EstimationResult { self.0.qubits() } - pub fn add_qubits(&mut self, amount: u64) { - self.0.add_qubits(amount); + #[setter] + pub fn set_qubits(&mut self, qubits: u64) { + self.0.set_qubits(qubits); } #[getter] @@ -615,8 +799,9 @@ impl EstimationResult { self.0.runtime() } - pub fn add_runtime(&mut self, amount: u64) { - self.0.add_runtime(amount); + #[setter] + pub fn set_runtime(&mut self, runtime: u64) { + self.0.set_runtime(runtime); } #[getter] @@ -624,8 +809,9 @@ impl EstimationResult { self.0.error() } - pub fn add_error(&mut self, amount: f64) { - self.0.add_error(amount); + #[setter] + pub fn set_error(&mut self, error: f64) { + self.0.set_error(error); } #[allow(clippy::needless_pass_by_value)] @@ -662,6 +848,22 @@ impl EstimationResult { Ok(dict) } + pub fn 
set_property(&mut self, key: String, value: &Bound<'_, PyAny>) -> PyResult<()> { + let property = if value.is_instance_of::() { + qre::Property::new_bool(value.extract()?) + } else if let Ok(i) = value.extract::() { + qre::Property::new_int(i) + } else if let Ok(f) = value.extract::() { + qre::Property::new_float(f) + } else { + qre::Property::new_str(value.to_string()) + }; + + self.0.set_property(key, property); + + Ok(()) + } + fn __str__(&self) -> String { format!("{}", self.0) } diff --git a/source/pip/test_requirements.txt b/source/pip/test_requirements.txt index 3fc61b1768..5526d0eed5 100644 --- a/source/pip/test_requirements.txt +++ b/source/pip/test_requirements.txt @@ -2,3 +2,4 @@ pytest expecttest==0.3.0 pyqir<0.12 cirq==1.6.1 +pandas>=2.1 diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index c741538b75..ecd5b7e673 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -3,11 +3,13 @@ from dataclasses import KW_ONLY, dataclass, field from enum import Enum -from typing import Generator +from typing import cast, Generator import pytest +import pandas as pd import qsharp from qsharp.qre import ( + Application, ISA, LOGICAL, PSSPC, @@ -19,16 +21,21 @@ Trace, constraint, estimate, - instruction, linear_function, generic_function, ) +from qsharp.qre._qre import _ProvenanceGraph from qsharp.qre.application import QSharpApplication from qsharp.qre.models import ( SurfaceCode, AQREGateBased, ) -from qsharp.qre._architecture import _Context +from qsharp.qre._architecture import _Context, _make_instruction +from qsharp.qre._estimation import ( + EstimationTable, + EstimationTableEntry, +) +from qsharp.qre._instruction import InstructionSource from qsharp.qre._isa_enumeration import ( ISARefNode, ) @@ -50,8 +57,8 @@ def required_isa() -> ISARequirements: ) def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ISA( - instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + 
yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), ) @@ -68,17 +75,22 @@ def required_isa() -> ISARequirements: ) def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ISA( - instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), ) def test_isa(): - isa = ISA( - instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, space=400), - instruction( - CCX, arity=3, encoding=LOGICAL, time=2000, error_rate=1e-10, space=800 - ), + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 + ), + ] ) assert T in isa @@ -97,26 +109,27 @@ def test_isa(): assert ccz_instr.error_rate() == 1e-10 assert ccz_instr.space() == 800 - isa.append(ccz_instr) + # Add another instruction to the graph and register it in the ISA + ccz_node = graph.add_instruction(ccz_instr) + isa.add_node(CCZ, ccz_node) assert CCZ in isa assert len(isa) == 3 - isa.append(ccz_instr) - assert ( - len(isa) == 3 - ) # Appending the same instruction should not increase the number of instructions + # Adding the same instruction ID should not increase the count + isa.add_node(CCZ, ccz_node) + assert len(isa) == 3 def test_instruction_properties(): # Test instruction with no properties - instr_no_props = instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) + instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) assert instr_no_props.get_property(PropertyKey.DISTANCE) is None assert instr_no_props.has_property(PropertyKey.DISTANCE) is False assert instr_no_props.get_property_or(PropertyKey.DISTANCE, 5) == 5 # Test instruction with valid property (distance) - instr_with_distance = instruction( - T, encoding=LOGICAL, time=1000, 
error_rate=1e-8, distance=9 + instr_with_distance = _make_instruction( + T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} ) assert instr_with_distance.get_property(PropertyKey.DISTANCE) == 9 assert instr_with_distance.has_property(PropertyKey.DISTANCE) is True @@ -124,7 +137,7 @@ def test_instruction_properties(): # Test instruction with invalid property name with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): - instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, invalid_prop=42) + _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) def test_instruction_constraints(): @@ -145,13 +158,19 @@ def test_instruction_constraints(): constraint(T, encoding=LOGICAL, invalid_prop=True) # Test ISA.satisfies with property constraints - instr_no_dist = instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) - instr_with_dist = instruction( - T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + graph = _ProvenanceGraph() + isa_no_dist = graph.make_isa( + [ + graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ] + ) + isa_with_dist = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ), + ] ) - - isa_no_dist = ISA(instr_no_dist) - isa_with_dist = ISA(instr_with_dist) reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) @@ -184,21 +203,22 @@ def error_rate(x: int) -> float: space_fn = generic_function(lambda x: 12) assert isinstance(space_fn, _FloatFunction) - i = instruction(42, arity=None, space=12, time=time_fn, error_rate=error_rate_fn) + i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) assert i.space(5) == 12 assert i.time(5) == 25 assert i.error_rate(5) == 2.5 def test_isa_from_architecture(): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) code = SurfaceCode() + ctx = arch.context() 
# Verify that the architecture satisfies the code requirements - assert arch.provided_isa.satisfies(SurfaceCode.required_isa()) + assert ctx.isa.satisfies(SurfaceCode.required_isa()) # Generate logical ISAs - isas = list(code.provided_isa(arch.provided_isa, arch.context())) + isas = list(code.provided_isa(ctx.isa, ctx)) # There is one ISA with one instructions assert len(isas) == 1 @@ -351,8 +371,137 @@ class UnionConfig: assert instances[2].option.number == 1 +def test_enumerate_instances_nested_with_constraints(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + # Constrain nested field via dict + instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) + assert len(instances) == 1 + assert instances[0].inner.option is True + + +def test_enumerate_instances_union_single_type(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Restrict to OptionB only - uses its default domain + instances = list(_enumerate_instances(UnionConfig, option=OptionB)) + assert len(instances) == 3 + assert all(isinstance(i.option, OptionB) for i in instances) + assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] + + # Restrict to OptionA only + instances = list(_enumerate_instances(UnionConfig, option=OptionA)) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionA) for i in instances) + assert cast(OptionA, instances[0].option).value is True + assert cast(OptionA, instances[1].option).value is False + + +def test_enumerate_instances_union_list_of_types(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + 
class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class OptionC: + _: KW_ONLY + flag: bool + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB | OptionC + + # Select a subset: only OptionA and OptionB + instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) + assert len(instances) == 5 # 2 from OptionA + 3 from OptionB + assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) + + +def test_enumerate_instances_union_constraint_dict(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Constrain OptionA, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) + ) + assert len(instances) == 1 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + + # Constrain OptionB with a domain, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) + ) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionB) for i in instances) + assert cast(OptionB, instances[0].option).number == 2 + assert cast(OptionB, instances[1].option).number == 3 + + # Constrain one member and keep another with defaults + instances = list( + _enumerate_instances( + UnionConfig, + option={OptionA: {"value": True}, OptionB: {}}, + ) + ) + assert len(instances) == 4 # 1 from OptionA + 3 from OptionB + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert all(isinstance(i.option, OptionB) for i in instances[1:]) + assert [cast(OptionB, 
i.option).number for i in instances[1:]] == [1, 2, 3] + + def test_enumerate_isas(): - ctx = AQREGateBased().context() + ctx = AQREGateBased(gate_time=50, measurement_time=100).context() # This will enumerate the 4 ISAs for the error correction code count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) @@ -407,7 +556,7 @@ def test_enumerate_isas(): def test_binding_node(): """Test binding nodes with ISARefNode for component bindings""" - ctx = AQREGateBased().context() + ctx = AQREGateBased(gate_time=50, measurement_time=100).context() # Test basic binding: same code used twice # Without binding: 12 codes × 12 codes = 144 combinations @@ -512,7 +661,7 @@ def test_binding_node(): def test_binding_node_errors(): """Test error handling for binding nodes""" - ctx = AQREGateBased().context() + ctx = AQREGateBased(gate_time=50, measurement_time=100).context() # Test ISARefNode enumerate with undefined binding raises ValueError try: @@ -629,17 +778,24 @@ def test_qsharp_application(): assert trace.depth == 3 assert trace.resource_states == {} - isa = ISA( - instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - time=1000, - error_rate=linear_function(1e-6), - space=linear_function(50), - ), - instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, space=400), - instruction(CCX, encoding=LOGICAL, time=2000, error_rate=1e-10, space=800), + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(50), + error_rate=linear_function(1e-6), + ), + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 + ), + ] ) # Properties from the program @@ -681,6 +837,21 @@ def test_qsharp_application(): assert counter == 32 +def test_application_enumeration(): + @dataclass(kw_only=True) + class _Params: + size: int = field(default=1, 
metadata={"domain": range(1, 4)}) + + class TestApp(Application[_Params]): + def get_trace(self, parameters: _Params) -> Trace: + return Trace(parameters.size) + + app = TestApp() + assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 + assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 + assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 + + def test_trace_enumeration(): code = """ {{ @@ -693,11 +864,8 @@ def test_trace_enumeration(): app = QSharpApplication(code) - from qsharp.qre._trace import RootNode - ctx = app.context() - root = RootNode() - assert sum(1 for _ in root.enumerate(ctx)) == 1 + assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 @@ -729,7 +897,7 @@ def test_estimation_max_error(): from qsharp.estimator import LogicalCounts app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: results = estimate( @@ -766,3 +934,333 @@ def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): if CCX in trace.resource_states: actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states assert abs(result.error - actual_error) <= 1e-8 + + +# --- EstimationTable tests --- + + +def _make_entry(qubits, runtime, error, properties=None): + """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" + return EstimationTableEntry( + qubits=qubits, + runtime=runtime, + error=error, + source=InstructionSource(), + properties=properties or {}, + ) + + +def test_estimation_table_default_columns(): + """Test that a new EstimationTable has the three default columns.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error"] + assert 
frame["qubits"][0] == 100 + assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") + assert frame["error"][0] == 0.01 + + +def test_estimation_table_multiple_rows(): + """Test as_frame with multiple entries.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + table.append(_make_entry(300, 15000, 0.03)) + + frame = table.as_frame() + assert len(frame) == 3 + assert list(frame["qubits"]) == [100, 200, 300] + assert list(frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_empty(): + """Test as_frame with no entries produces an empty DataFrame.""" + table = EstimationTable() + frame = table.as_frame() + assert len(frame) == 0 + + +def test_estimation_table_add_column(): + """Test adding a column to the table.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"val": 42})) + table.append(_make_entry(200, 10000, 0.02, properties={"val": 84})) + + table.add_column("val", lambda e: e.properties["val"]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "val"] + assert list(frame["val"]) == [42, 84] + + +def test_estimation_table_add_column_with_formatter(): + """Test adding a column with a formatter.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"ns": 1000})) + + table.add_column( + "duration", + lambda e: e.properties["ns"], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") + + +def test_estimation_table_add_multiple_columns(): + """Test adding multiple columns preserves order.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"a": 1, "b": 2, "c": 3})) + + table.add_column("a", lambda e: e.properties["a"]) + table.add_column("b", lambda e: e.properties["b"]) + table.add_column("c", lambda e: e.properties["c"]) + + frame = 
table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] + assert frame["a"][0] == 1 + assert frame["b"][0] == 2 + assert frame["c"][0] == 3 + + +def test_estimation_table_insert_column_at_beginning(): + """Test inserting a column at index 0.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"name": "test"})) + + table.insert_column(0, "name", lambda e: e.properties["name"]) + + frame = table.as_frame() + assert list(frame.columns) == ["name", "qubits", "runtime", "error"] + assert frame["name"][0] == "test" + + +def test_estimation_table_insert_column_in_middle(): + """Test inserting a column between existing default columns.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"extra": 99})) + + # Insert between qubits and runtime (index 1) + table.insert_column(1, "extra", lambda e: e.properties["extra"]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] + assert frame["extra"][0] == 99 + + +def test_estimation_table_insert_column_at_end(): + """Test inserting a column at the end (same effect as add_column).""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"last": True})) + + # 3 default columns, inserting at index 3 = end + table.insert_column(3, "last", lambda e: e.properties["last"]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "last"] + assert frame["last"][0] + + +def test_estimation_table_insert_column_with_formatter(): + """Test inserting a column with a formatter.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"ns": 2000})) + + table.insert_column( + 0, + "custom_time", + lambda e: e.properties["ns"], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") + assert 
list(frame.columns)[0] == "custom_time" + + +def test_estimation_table_insert_and_add_columns(): + """Test combining insert_column and add_column.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={"a": 1, "b": 2})) + + table.add_column("b", lambda e: e.properties["b"]) + table.insert_column(0, "a", lambda e: e.properties["a"]) + + frame = table.as_frame() + assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] + + +def test_estimation_table_factory_summary_no_factories(): + """Test factory summary column when entries have no factories.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + table.add_factory_summary_column() + + frame = table.as_frame() + assert "factories" in frame.columns + assert frame["factories"][0] == "None" + + +def test_estimation_table_factory_summary_with_estimation(): + """Test factory summary column with real estimation results.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = AQREGateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_factory_summary_column() + frame = results.as_frame() + + assert "factories" in frame.columns + # Each result should mention T in the factory summary + for val in frame["factories"]: + assert "T" in val + + +def test_estimation_table_add_column_from_source(): + """Test adding a column that accesses the InstructionSource (like distance).""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = AQREGateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + 
max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "compute_distance", + lambda entry: entry.source[LATTICE_SURGERY].instruction[PropertyKey.DISTANCE], + ) + + frame = results.as_frame() + assert "compute_distance" in frame.columns + for d in frame["compute_distance"]: + assert isinstance(d, int) + assert d >= 3 + + +def test_estimation_table_add_column_from_properties(): + """Test adding columns that access trace properties from estimation.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = AQREGateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "num_ts_per_rotation", + lambda entry: entry.properties["num_ts_per_rotation"], + ) + + frame = results.as_frame() + assert "num_ts_per_rotation" in frame.columns + for val in frame["num_ts_per_rotation"]: + assert isinstance(val, int) + assert val >= 1 + + +def test_estimation_table_insert_column_before_defaults(): + """Test inserting a name column before all default columns, similar to the factoring notebook.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = AQREGateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + name="test_experiment", + ) + + assert len(results) >= 1 + + # Insert a name column at the beginning and add factory summary at the end + results.insert_column(0, "name", lambda entry: entry.properties.get("name", "")) + results.add_factory_summary_column() + + frame = results.as_frame() + assert frame.columns[0] == "name" + assert frame.columns[-1] == "factories" + # 
Default columns should still be in order + assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] + + +def test_estimation_table_as_frame_sortable(): + """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" + table = EstimationTable() + table.append(_make_entry(300, 15000, 0.03)) + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + + frame = table.as_frame() + sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) + + assert list(sorted_frame["qubits"]) == [100, 200, 300] + assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_computed_column(): + """Test adding a column that computes a derived value from the entry.""" + table = EstimationTable() + table.append(_make_entry(100, 5_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000, 0.02)) + + # Compute qubits * error as a derived metric + table.add_column("qubit_error_product", lambda e: e.qubits * e.error) + + frame = table.as_frame() + assert frame["qubit_error_product"][0] == pytest.approx(1.0) + assert frame["qubit_error_product"][1] == pytest.approx(4.0) diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/test_qre_models.py index 85c0643a92..a8c8f8462a 100644 --- a/source/pip/tests/test_qre_models.py +++ b/source/pip/tests/test_qre_models.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from qsharp.qre import LOGICAL, PHYSICAL, Encoding, PropertyKey, instruction +from qsharp.qre import LOGICAL, PHYSICAL, PropertyKey from qsharp.qre.instruction_ids import ( T, CCZ, @@ -45,38 +45,42 @@ class TestAQREGateBased: def test_default_error_rate(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) assert arch.error_rate == 1e-4 def test_custom_error_rate(self): - arch = AQREGateBased(error_rate=1e-3) + arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) assert arch.error_rate == 1e-3 def test_provided_isa_contains_expected_instructions(self): - arch = AQREGateBased() - isa = arch.provided_isa + arch = AQREGateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: assert instr_id in isa def test_instruction_encodings_are_physical(self): - arch = AQREGateBased() - isa = arch.provided_isa + arch = AQREGateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: assert isa[instr_id].encoding == PHYSICAL def test_instruction_error_rates_match(self): rate = 1e-3 - arch = AQREGateBased(error_rate=rate) - isa = arch.provided_isa + arch = AQREGateBased(error_rate=rate, gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: assert isa[instr_id].expect_error_rate() == rate def test_gate_times(self): - arch = AQREGateBased() - isa = arch.provided_isa + arch = AQREGateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa # Single-qubit gates: 50ns for instr_id in [PAULI_I, H, T]: @@ -90,8 +94,9 @@ def test_gate_times(self): assert isa[MEAS_Z].expect_time() == 100 def test_arities(self): - arch = AQREGateBased() - isa = arch.provided_isa + arch = AQREGateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa assert isa[PAULI_I].arity == 1 
assert isa[CNOT].arity == 2 @@ -100,7 +105,7 @@ def test_arities(self): assert isa[MEAS_Z].arity == 1 def test_context_creation(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() assert ctx is not None @@ -117,14 +122,16 @@ def test_default_error_rate(self): def test_provided_isa_contains_expected_instructions(self): arch = Majorana() - isa = arch.provided_isa + ctx = arch.context() + isa = ctx.isa for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z, T]: assert instr_id in isa def test_all_times_are_1us(self): arch = Majorana() - isa = arch.provided_isa + ctx = arch.context() + isa = ctx.isa for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z, T]: assert isa[instr_id].expect_time() == 1000 @@ -132,7 +139,8 @@ def test_all_times_are_1us(self): def test_clifford_error_rates_match_qubit_error(self): for rate in [1e-4, 1e-5, 1e-6]: arch = Majorana(error_rate=rate) - isa = arch.provided_isa + ctx = arch.context() + isa = ctx.isa for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z]: assert isa[instr_id].expect_error_rate() == rate @@ -143,12 +151,14 @@ def test_t_error_rate_mapping(self): for qubit_rate, t_rate in expected.items(): arch = Majorana(error_rate=qubit_rate) - isa = arch.provided_isa + ctx = arch.context() + isa = ctx.isa assert isa[T].expect_error_rate() == t_rate def test_two_qubit_measurement_arities(self): arch = Majorana() - isa = arch.provided_isa + ctx = arch.context() + isa = ctx.isa assert isa[MEAS_XX].arity == 2 assert isa[MEAS_ZZ].arity == 2 @@ -169,11 +179,11 @@ def test_default_distance(self): assert sc.distance == 3 def test_provides_lattice_surgery(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=3) - isas = list(sc.provided_isa(arch.provided_isa, ctx)) + isas = list(sc.provided_isa(ctx.isa, ctx)) assert len(isas) == 1 isa = isas[0] @@ -184,37 +194,37 @@ 
def test_provides_lattice_surgery(self): def test_space_scales_with_distance(self): """Space = 2*d^2 - 1 physical qubits per logical qubit.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) for d in [3, 5, 7, 9]: ctx = arch.context() sc = SurfaceCode(distance=d) - isas = list(sc.provided_isa(arch.provided_isa, ctx)) + isas = list(sc.provided_isa(ctx.isa, ctx)) ls = isas[0][LATTICE_SURGERY] expected_space = 2 * d**2 - 1 assert ls.expect_space(1) == expected_space def test_time_scales_with_distance(self): """Time = (h_time + 4*cnot_time + meas_time) * d.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) # h=50, cnot=50, meas=100 for AQREGateBased syndrome_time = 50 + 4 * 50 + 100 # = 350 for d in [3, 5, 7]: ctx = arch.context() sc = SurfaceCode(distance=d) - isas = list(sc.provided_isa(arch.provided_isa, ctx)) + isas = list(sc.provided_isa(ctx.isa, ctx)) ls = isas[0][LATTICE_SURGERY] assert ls.expect_time(1) == syndrome_time * d def test_error_rate_decreases_with_distance(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) errors = [] for d in [3, 5, 7, 9, 11]: ctx = arch.context() sc = SurfaceCode(distance=d) - isas = list(sc.provided_isa(arch.provided_isa, ctx)) + isas = list(sc.provided_isa(ctx.isa, ctx)) errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) # Each successive distance should have a lower error rate @@ -223,7 +233,7 @@ def test_error_rate_decreases_with_distance(self): def test_enumeration_via_query(self): """Enumerating SurfaceCode.q() should yield multiple distances.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -235,18 +245,18 @@ def test_enumeration_via_query(self): assert count == 12 def test_custom_crossing_prefactor(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc_default = 
SurfaceCode(distance=5) sc_custom = SurfaceCode(crossing_prefactor=0.06, distance=5) - default_error = list(sc_default.provided_isa(arch.provided_isa, ctx))[0][ + default_error = list(sc_default.provided_isa(ctx.isa, ctx))[0][ LATTICE_SURGERY ].expect_error_rate(1) ctx2 = arch.context() - custom_error = list(sc_custom.provided_isa(arch.provided_isa, ctx2))[0][ + custom_error = list(sc_custom.provided_isa(ctx2.isa, ctx2))[0][ LATTICE_SURGERY ].expect_error_rate(1) @@ -254,17 +264,17 @@ def test_custom_crossing_prefactor(self): assert abs(custom_error - 2 * default_error) < 1e-20 def test_custom_error_correction_threshold(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx1 = arch.context() sc_low_threshold = SurfaceCode(error_correction_threshold=0.005, distance=5) - error_low = list(sc_low_threshold.provided_isa(arch.provided_isa, ctx1))[0][ + error_low = list(sc_low_threshold.provided_isa(ctx1.isa, ctx1))[0][ LATTICE_SURGERY ].expect_error_rate(1) ctx2 = arch.context() sc_high_threshold = SurfaceCode(error_correction_threshold=0.02, distance=5) - error_high = list(sc_high_threshold.provided_isa(arch.provided_isa, ctx2))[0][ + error_high = list(sc_high_threshold.provided_isa(ctx2.isa, ctx2))[0][ LATTICE_SURGERY ].expect_error_rate(1) @@ -287,7 +297,7 @@ def test_provides_lattice_surgery(self): ctx = arch.context() ta = ThreeAux(distance=3) - isas = list(ta.provided_isa(arch.provided_isa, ctx)) + isas = list(ta.provided_isa(ctx.isa, ctx)) assert len(isas) == 1 assert LATTICE_SURGERY in isas[0] @@ -298,7 +308,7 @@ def test_space_formula(self): for d in [3, 5, 7]: ctx = arch.context() ta = ThreeAux(distance=d) - isas = list(ta.provided_isa(arch.provided_isa, ctx)) + isas = list(ta.provided_isa(ctx.isa, ctx)) ls = isas[0][LATTICE_SURGERY] expected = 4 * d**2 - 3 assert ls.expect_space(1) == expected @@ -310,7 +320,7 @@ def test_time_formula_double_rail(self): for d in [3, 5, 7]: ctx = arch.context() ta = ThreeAux(distance=d, 
single_rail=False) - isas = list(ta.provided_isa(arch.provided_isa, ctx)) + isas = list(ta.provided_isa(ctx.isa, ctx)) ls = isas[0][LATTICE_SURGERY] # MEAS_XX and MEAS_ZZ have time=1000 each; max is 1000 expected_time = 1000 * (4 * d + 4) @@ -323,7 +333,7 @@ def test_time_formula_single_rail(self): for d in [3, 5, 7]: ctx = arch.context() ta = ThreeAux(distance=d, single_rail=True) - isas = list(ta.provided_isa(arch.provided_isa, ctx)) + isas = list(ta.provided_isa(ctx.isa, ctx)) ls = isas[0][LATTICE_SURGERY] expected_time = 1000 * (5 * d + 4) assert ls.expect_time(1) == expected_time @@ -335,7 +345,7 @@ def test_error_rate_decreases_with_distance(self): for d in [3, 5, 7, 9]: ctx = arch.context() ta = ThreeAux(distance=d) - isas = list(ta.provided_isa(arch.provided_isa, ctx)) + isas = list(ta.provided_isa(ctx.isa, ctx)) errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) for i in range(len(errors) - 1): @@ -347,13 +357,13 @@ def test_single_rail_has_different_error_threshold(self): ctx1 = arch.context() double = ThreeAux(distance=5, single_rail=False) - error_double = list(double.provided_isa(arch.provided_isa, ctx1))[0][ + error_double = list(double.provided_isa(ctx1.isa, ctx1))[0][ LATTICE_SURGERY ].expect_error_rate(1) ctx2 = arch.context() single = ThreeAux(distance=5, single_rail=True) - error_single = list(single.provided_isa(arch.provided_isa, ctx2))[0][ + error_single = list(single.provided_isa(ctx2.isa, ctx2))[0][ LATTICE_SURGERY ].expect_error_rate(1) @@ -384,10 +394,10 @@ def test_enumeration_via_query(self): class TestYokedSurfaceCode: def _get_lattice_surgery_isa(self, distance=5): """Helper to get a lattice surgery ISA from SurfaceCode.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=distance) - isas = list(sc.provided_isa(arch.provided_isa, ctx)) + isas = list(sc.provided_isa(ctx.isa, ctx)) return isas[0], ctx def test_provides_memory_instruction(self): @@ 
-470,11 +480,11 @@ def test_required_isa(self): def test_table1_aqre_yields_t_and_ccz(self): """AQREGateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - isas = list(factory.provided_isa(arch.provided_isa, ctx)) + isas = list(factory.provided_isa(ctx.isa, ctx)) # 6 T entries × 1 CCZ entry = 6 combinations assert len(isas) == 6 @@ -485,11 +495,11 @@ def test_table1_aqre_yields_t_and_ccz(self): assert len(isa) == 2 def test_table1_instruction_properties(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): t_instr = isa[T] ccz_instr = isa[CCZ] @@ -505,11 +515,11 @@ def test_table1_instruction_properties(self): def test_table1_t_error_rates_are_diverse(self): """T entries in Table 1 should span a range of error rates.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - isas = list(factory.provided_isa(arch.provided_isa, ctx)) + isas = list(factory.provided_isa(ctx.isa, ctx)) t_errors = [isa[T].expect_error_rate() for isa in isas] # Should have multiple distinct T error rates @@ -522,11 +532,11 @@ def test_table1_t_error_rates_are_diverse(self): def test_table1_1e3_clifford_yields_6_isas(self): """AQREGateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" - arch = AQREGateBased(error_rate=1e-3) + arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - isas = list(factory.provided_isa(arch.provided_isa, ctx)) + isas = list(factory.provided_isa(ctx.isa, ctx)) # 6 T entries × 1 CCZ entry = 6 combinations assert len(isas) == 6 @@ -537,23 +547,20 @@ def 
test_table1_1e3_clifford_yields_6_isas(self): def test_table2_scenario_no_ccz(self): """Table 2 scenario: T error ~10x higher than Clifford, no CCZ.""" - from qsharp.qre import ISA as ISAType + from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() # Manually create ISA with T error rate 10x Clifford - isa_input = ISAType( - instruction( - CNOT, encoding=Encoding.PHYSICAL, arity=2, time=50, error_rate=1e-4 - ), - instruction( - H, encoding=Encoding.PHYSICAL, arity=1, time=50, error_rate=1e-4 - ), - instruction( - MEAS_Z, encoding=Encoding.PHYSICAL, arity=1, time=100, error_rate=1e-4 - ), - instruction(T, encoding=Encoding.PHYSICAL, time=50, error_rate=1e-3), + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction(CNOT, arity=2, time=50, error_rate=1e-4), + graph.add_instruction(H, time=50, error_rate=1e-4), + graph.add_instruction(MEAS_Z, time=100, error_rate=1e-4), + graph.add_instruction(T, time=50, error_rate=1e-3), + ] ) factory = Litinski19Factory() @@ -568,22 +575,19 @@ def test_table2_scenario_no_ccz(self): def test_no_yield_when_error_too_high(self): """If T error > 10x Clifford, no entries match.""" - from qsharp.qre import ISA as ISAType - - arch = AQREGateBased() - ctx = arch.context() - - isa_input = ISAType( - instruction( - CNOT, encoding=Encoding.PHYSICAL, arity=2, time=50, error_rate=1e-4 - ), - instruction( - H, encoding=Encoding.PHYSICAL, arity=1, time=50, error_rate=1e-4 - ), - instruction( - MEAS_Z, encoding=Encoding.PHYSICAL, arity=1, time=100, error_rate=1e-4 - ), - instruction(T, encoding=Encoding.PHYSICAL, time=50, error_rate=0.05), + from qsharp.qre._qre import _ProvenanceGraph + + arch = AQREGateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction(CNOT, arity=2, time=50, error_rate=1e-4), + graph.add_instruction(H, 
time=50, error_rate=1e-4), + graph.add_instruction(MEAS_Z, time=100, error_rate=1e-4), + graph.add_instruction(T, time=50, error_rate=0.05), + ] ) factory = Litinski19Factory() @@ -592,14 +596,14 @@ def test_no_yield_when_error_too_high(self): def test_time_based_on_syndrome_extraction(self): """Time should be based on syndrome extraction time × cycles.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() # For AQREGateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 syndrome_time = 4 * 50 + 50 + 100 # 350 ns - isas = list(factory.provided_isa(arch.provided_isa, ctx)) + isas = list(factory.provided_isa(ctx.isa, ctx)) for isa in isas: t_time = isa[T].expect_time() assert t_time > 0 @@ -620,12 +624,12 @@ def test_required_isa_is_empty(self): def test_adds_clifford_equivalent_t_gates(self): """Given T gate, should add SQRT_SQRT_X/Y/Z and dagger variants.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): modified_isas = list(modifier.provided_isa(isa, ctx)) assert len(modified_isas) == 1 modified_isa = modified_isas[0] @@ -645,12 +649,12 @@ def test_adds_clifford_equivalent_t_gates(self): def test_adds_clifford_equivalent_ccz(self): """Given CCZ, should add CCX and CCY.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): modified_isas = list(modifier.provided_isa(isa, ctx)) modified_isa = modified_isas[0] @@ -661,24 +665,24 @@ def test_adds_clifford_equivalent_ccz(self): def test_full_count_of_instructions(self): """T gate (1) + 5 
equivalents (SQRT_SQRT_*) + CCZ (1) + 2 equivalents (CCX, CCY) = 9.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): modified_isas = list(modifier.provided_isa(isa, ctx)) assert len(modified_isas[0]) == 9 break def test_equivalent_instructions_share_properties(self): """Clifford equivalents should have same time, space, error rate.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): modified_isas = list(modifier.provided_isa(isa, ctx)) modified_isa = modified_isas[0] @@ -704,13 +708,13 @@ def test_equivalent_instructions_share_properties(self): def test_modification_count_matches_factory_output(self): """MagicUpToClifford should produce one modified ISA per input ISA.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() modified_count = 0 - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): for _ in modifier.provided_isa(isa, ctx): modified_count += 1 @@ -718,24 +722,28 @@ def test_modification_count_matches_factory_output(self): def test_no_family_present_passes_through(self): """If no family member is present, ISA passes through unchanged.""" - from qsharp.qre import ISA as ISAType + from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() modifier = MagicUpToClifford() # ISA with only a LATTICE_SURGERY instruction (no T or CCZ family) from qsharp.qre import 
linear_function - ls = instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=linear_function(17), - time=1000, - error_rate=linear_function(1e-10), + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(17), + error_rate=linear_function(1e-10), + ) + ] ) - isa_input = ISAType(ls) results = list(modifier.provided_isa(isa_input, ctx)) assert len(results) == 1 @@ -749,14 +757,14 @@ def test_no_family_present_passes_through(self): def test_isa_manipulation(): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) factory = Litinski19Factory() modifier = MagicUpToClifford() ctx = arch.context() # Table 1 scenario: should yield ISAs with both T and CCZ instructions - isas = list(factory.provided_isa(arch.provided_isa, ctx)) + isas = list(factory.provided_isa(ctx.isa, ctx)) # 6 T entries × 1 CCZ entry = 6 combinations assert len(isas) == 6 @@ -781,7 +789,7 @@ def test_isa_manipulation(): # After MagicUpToClifford modifier modified_count = 0 - for isa in factory.provided_isa(arch.provided_isa, ctx): + for isa in factory.provided_isa(ctx.isa, ctx): for modified_isa in modifier.provided_isa(isa, ctx): modified_count += 1 # MagicUpToClifford should add derived instructions @@ -804,7 +812,7 @@ def test_required_isa(self): assert reqs is not None def test_produces_logical_t_gates(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): t = isa[T] @@ -817,7 +825,7 @@ def test_produces_logical_t_gates(self): def test_error_rates_are_bounded(self): """Distilled T error rates should be bounded and mostly small.""" - arch = AQREGateBased() # T error rate is 1e-4 + arch = AQREGateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 errors = [] for isa in 
RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): @@ -834,7 +842,7 @@ def test_error_rates_are_bounded(self): def test_max_produces_fewer_or_equal_results_than_sum(self): """Using max for physical_qubit_calculation may filter differently.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) sum_count = sum( 1 for _ in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) @@ -850,7 +858,7 @@ def test_max_produces_fewer_or_equal_results_than_sum(self): def test_max_space_less_than_or_equal_sum_space(self): """max-aggregated space should be <= sum-aggregated space for each.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) sum_spaces = sorted( isa[T].expect_space() @@ -882,7 +890,7 @@ def test_with_three_aux_code_query(self): assert count > 0 def test_round_based_aqre_sum(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -901,7 +909,7 @@ def test_round_based_aqre_sum(self): assert count == 107 def test_round_based_aqre_max(self): - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -951,17 +959,17 @@ def test_round_based_msft_sum(self): class TestCrossModelIntegration: def test_surface_code_feeds_into_litinski(self): """SurfaceCode -> Litinski19Factory pipeline works end to end.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() # SurfaceCode takes AQRE physical ISA -> LATTICE_SURGERY sc = SurfaceCode(distance=5) - sc_isas = list(sc.provided_isa(arch.provided_isa, ctx)) + sc_isas = list(sc.provided_isa(ctx.isa, ctx)) assert len(sc_isas) == 1 # Litinski takes H, CNOT, MEAS_Z, T from the physical ISA factory = Litinski19Factory() - factory_isas = list(factory.provided_isa(arch.provided_isa, ctx)) + factory_isas = list(factory.provided_isa(ctx.isa, ctx)) assert len(factory_isas) > 0 
def test_three_aux_feeds_into_round_based(self): @@ -980,7 +988,7 @@ def test_three_aux_feeds_into_round_based(self): def test_litinski_with_magic_up_to_clifford_query(self): """Full query chain: Litinski19Factory -> MagicUpToClifford.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -995,7 +1003,7 @@ def test_litinski_with_magic_up_to_clifford_query(self): def test_surface_code_with_yoked_surface_code(self): """SurfaceCode -> YokedSurfaceCode pipeline provides MEMORY.""" - arch = AQREGateBased() + arch = AQREGateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 2ca511747a..01dac1762c 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -3,8 +3,8 @@ use std::{ fmt::Display, - ops::{Add, Deref, Index}, - sync::Arc, + ops::Add, + sync::{Arc, RwLock}, }; use num_traits::FromPrimitive; @@ -16,40 +16,115 @@ use crate::trace::instruction_ids::instruction_name; #[cfg(test)] mod tests; -#[derive(Default, Clone)] +#[derive(Clone)] pub struct ISA { - instructions: FxHashMap, + graph: Arc>, + nodes: FxHashMap, +} + +impl Default for ISA { + fn default() -> Self { + ISA { + graph: Arc::new(RwLock::new(ProvenanceGraph::new())), + nodes: FxHashMap::default(), + } + } } impl ISA { #[must_use] pub fn new() -> Self { + Self::default() + } + + /// Creates an ISA backed by the given shared provenance graph. + #[must_use] + pub fn with_graph(graph: Arc>) -> Self { ISA { - instructions: FxHashMap::default(), + graph, + nodes: FxHashMap::default(), } } - pub fn add_instruction(&mut self, instruction: Instruction) { - self.instructions.insert(instruction.id, instruction); + /// Returns a reference to the shared provenance graph. + #[must_use] + pub fn graph(&self) -> &Arc> { + &self.graph + } + + /// Adds an instruction to the provenance graph and records its node index. + /// Returns the node index in the graph. 
+ pub fn add_instruction(&mut self, instruction: Instruction) -> usize { + let id = instruction.id; + let mut graph = self.graph.write().expect("provenance graph lock poisoned"); + let node_idx = graph.add_node(instruction, 0, &[]); + self.nodes.insert(id, node_idx); + node_idx + } + + /// Records an existing provenance graph node in this ISA. + pub fn add_node(&mut self, instruction_id: u64, node_index: usize) { + self.nodes.insert(instruction_id, node_index); } + /// Returns the node index for an instruction ID, if present. #[must_use] - pub fn get(&self, id: &u64) -> Option<&Instruction> { - self.instructions.get(id) + pub fn node_index(&self, id: &u64) -> Option { + self.nodes.get(id).copied() + } + + /// Returns a clone of the instruction with the given ID, if present. + #[must_use] + pub fn get(&self, id: &u64) -> Option { + let &node_idx = self.nodes.get(id)?; + let graph = self.read_graph(); + Some(graph.instruction(node_idx).clone()) } #[must_use] pub fn contains(&self, id: &u64) -> bool { - self.instructions.contains_key(id) + self.nodes.contains_key(id) + } + + #[must_use] + pub fn len(&self) -> usize { + self.nodes.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + fn read_graph(&self) -> std::sync::RwLockReadGuard<'_, ProvenanceGraph> { + self.graph.read().expect("provenance graph lock poisoned") + } + + /// Returns an iterator over pairs of instruction IDs and node IDs. + pub fn node_entries(&self) -> impl Iterator { + self.nodes.iter() + } + + /// Returns all instructions as owned clones. 
+ #[must_use] + pub fn instructions(&self) -> Vec { + let graph = self.read_graph(); + self.nodes + .values() + .map(|&idx| graph.instruction(idx).clone()) + .collect() } #[must_use] pub fn satisfies(&self, requirements: &ISARequirements) -> bool { + let graph = self.read_graph(); for constraint in requirements.constraints.values() { - let Some(instruction) = self.instructions.get(&constraint.id) else { + let Some(&node_idx) = self.nodes.get(&constraint.id) else { return false; }; + let instruction = graph.instruction(node_idx); + if instruction.encoding != constraint.encoding { return false; } @@ -99,14 +174,6 @@ impl ISA { } } -impl Deref for ISA { - type Target = FxHashMap; - - fn deref(&self) -> &Self::Target { - &self.instructions - } -} - impl FromIterator for ISA { fn from_iter>(iter: T) -> Self { let mut isa = ISA::new(); @@ -119,28 +186,37 @@ impl FromIterator for ISA { impl Display for ISA { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for instruction in self.instructions.values() { + let graph = self.read_graph(); + for &node_idx in self.nodes.values() { + let instruction = graph.instruction(node_idx); writeln!(f, "{instruction}")?; } Ok(()) } } -impl Index for ISA { - type Output = Instruction; - - fn index(&self, index: u64) -> &Self::Output { - &self.instructions[&index] - } -} - impl Add for ISA { type Output = ISA; fn add(self, other: ISA) -> ISA { let mut combined = self; - for instruction in other.instructions.into_values() { - combined.add_instruction(instruction); + if Arc::ptr_eq(&combined.graph, &other.graph) { + // Same graph: just merge node maps + for (id, node_idx) in other.nodes { + combined.nodes.insert(id, node_idx); + } + } else { + // Different graphs: copy instructions into combined's graph + let other_graph = other.read_graph(); + let mut self_graph = combined + .graph + .write() + .expect("provenance graph lock poisoned"); + for (id, node_idx) in &other.nodes { + let instruction = 
other_graph.instruction(*node_idx).clone(); + let new_idx = self_graph.add_node(instruction, 0, &[]); + combined.nodes.insert(*id, new_idx); + } } combined } @@ -450,6 +526,7 @@ pub enum VariableArityFunction { BlockLinear { block_size: u64, slope: T, + offset: T, }, #[serde(skip)] Generic { @@ -468,8 +545,12 @@ impl + std::ops::Mul + Copy + FromPrimitive> VariableArityFunction::Linear { slope } } - pub fn block_linear(block_size: u64, slope: T) -> Self { - VariableArityFunction::BlockLinear { block_size, slope } + pub fn block_linear(block_size: u64, slope: T, offset: T) -> Self { + VariableArityFunction::BlockLinear { + block_size, + slope, + offset, + } } pub fn generic(func: impl Fn(u64) -> T + Send + Sync + 'static) -> Self { @@ -488,9 +569,14 @@ impl + std::ops::Mul + Copy + FromPrimitive> VariableArityFunction::Linear { slope } => { *slope * T::from_u64(arity).expect("Failed to convert u64 to target type") } - VariableArityFunction::BlockLinear { block_size, slope } => { + VariableArityFunction::BlockLinear { + block_size, + slope, + offset, + } => { let blocks = arity.div_ceil(*block_size); *slope * T::from_u64(blocks).expect("Failed to convert u64 to target type") + + *offset } VariableArityFunction::Generic { func } => func(arity), } @@ -566,16 +652,17 @@ impl ProvenanceGraph { pub fn add_node( &mut self, - instruction_id: u64, + mut instruction: Instruction, transform_id: u64, children: &[usize], ) -> usize { let node_index = self.nodes.len(); + instruction.source = node_index; let offset = self.children.len(); let num_children = children.len(); self.children.extend_from_slice(children); self.nodes.push(ProvenanceNode { - instruction_id, + instruction, transform_id, offset, num_children, @@ -584,8 +671,8 @@ impl ProvenanceGraph { } #[must_use] - pub fn instruction_id(&self, node_index: usize) -> u64 { - self.nodes[node_index].instruction_id + pub fn instruction(&self, node_index: usize) -> &Instruction { + &self.nodes[node_index].instruction } 
#[must_use] @@ -610,10 +697,20 @@ impl ProvenanceGraph { } } -#[derive(Default)] struct ProvenanceNode { - instruction_id: u64, + instruction: Instruction, transform_id: u64, offset: usize, num_children: usize, } + +impl Default for ProvenanceNode { + fn default() -> Self { + ProvenanceNode { + instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0), + transform_id: 0, + offset: 0, + num_children: 0, + } + } +} diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index d9a4f12162..60f1764a4e 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -235,9 +235,9 @@ impl Trace { // ------------------------------------------------------------------ for (factory, count) in &factories { let instr = get_instruction(isa, *factory)?; - let factory_time = get_time(instr)?; - let factory_space = get_space(instr)?; - let factory_error_rate = get_error_rate(instr)?; + let factory_time = get_time(&instr)?; + let factory_space = get_space(&instr)?; + let factory_error_rate = get_error_rate(&instr)?; let runs = result.runtime() / factory_time; if runs == 0 { @@ -590,7 +590,7 @@ impl Display for Property { // Some helper functions to extract instructions and their metrics together with // error handling -fn get_instruction(isa: &ISA, id: u64) -> Result<&Instruction, Error> { +fn get_instruction(isa: &ISA, id: u64) -> Result { isa.get(&id).ok_or(Error::InstructionNotFound(id)) } diff --git a/source/qre/src/trace/transforms/psspc.rs b/source/qre/src/trace/transforms/psspc.rs index 309cd1fb81..0f671b14f3 100644 --- a/source/qre/src/trace/transforms/psspc.rs +++ b/source/qre/src/trace/transforms/psspc.rs @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
use crate::trace::{Gate, TraceTransform}; -use crate::{Error, Trace, instruction_ids}; +use crate::{Error, Property, Trace, instruction_ids}; /// Implements the Parellel Synthesis Sequential Pauli Computation (PSSPC) /// layout algorithm described in Appendix D in @@ -122,7 +122,7 @@ impl PSSPC { Ok(counter) } - #[allow(clippy::cast_precision_loss)] + #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap)] fn get_trace(&self, trace: &Trace, counts: &PSSPCCounts) -> Trace { let num_qubits = trace.compute_qubits(); let logical_qubits = Self::logical_qubit_overhead(num_qubits); @@ -145,6 +145,12 @@ impl PSSPC { // Add error due to rotation synthesis transformed.increment_base_error(counts.rotation_like as f64 * self.synthesis_error()); + // Track some properties + transformed.set_property( + String::from("num_ts_per_rotation"), + Property::Int(self.num_ts_per_rotation as i64), + ); + transformed } From a7d15c9196acf1a97287659d02263a9aa1cdc227 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 16 Mar 2026 17:53:56 +0100 Subject: [PATCH 24/45] Refactor property keys from string to integer and add estimation improvements (#3013) This PR replaces string-keyed properties with integer keys defined via a Rust macro, and includes several estimation performance and API improvements. 
**Property keys** - New `define_properties!` macro in isa/property_keys.rs auto-assigns u64 values and generates `property_name_to_key` - Properties on `Trace`, `EstimationResult`, and `Instruction` now use `u64` keys instead of `String` - Python `property_keys` submodule exposes all keys as constants; `PropertyKey` enum is removed **Estimation improvements** - `LockedISA` holds a single read lock for the duration of `estimate()`, avoiding repeated lock acquisitions - ISA cloning is deferred to Pareto-surviving results only, reducing allocation in `estimate_parallel` - `EstimationTableStats` tracks job counts (total, successful, Pareto-surviving) - `add_qubit_partition_column()` breaks down physical qubits into compute, factory, and memory **Model updates** - `SurfaceCode`: configurable `one_qubit_gate_depth` / `two_qubit_gate_depth` with per-instruction time factor overrides - `RoundBasedFactory`: deterministic cache key using full instruction serialization Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- source/pip/qsharp/qre/__init__.py | 2 - source/pip/qsharp/qre/_architecture.py | 13 +- source/pip/qsharp/qre/_estimation.py | 48 ++++++- source/pip/qsharp/qre/_instruction.py | 15 +-- source/pip/qsharp/qre/_qre.py | 3 + source/pip/qsharp/qre/_qre.pyi | 65 ++++++++-- source/pip/qsharp/qre/interop/_qsharp.py | 3 +- .../qre/models/factories/_round_based.py | 24 ++-- .../qsharp/qre/models/qec/_surface_code.py | 30 ++++- source/pip/qsharp/qre/models/qec/_yoked.py | 5 +- source/pip/qsharp/qre/property_keys.py | 10 ++ source/pip/qsharp/qre/property_keys.pyi | 14 +++ source/pip/src/qre.rs | 110 ++++++++++++---- source/pip/tests/test_qre.py | 117 +++++++++++------- source/pip/tests/test_qre_models.py | 5 +- source/qre/src/isa.rs | 32 ++++- source/qre/src/isa/property_keys.rs | 49 ++++++++ source/qre/src/lib.rs | 10 +- source/qre/src/pareto.rs | 13 ++ source/qre/src/result.rs | 52 ++++++-- source/qre/src/trace.rs | 81 ++++++++---- 
source/qre/src/trace/transforms/psspc.rs | 3 +- source/qre/src/utils.rs | 10 ++ 23 files changed, 556 insertions(+), 158 deletions(-) create mode 100644 source/pip/qsharp/qre/property_keys.py create mode 100644 source/pip/qsharp/qre/property_keys.pyi create mode 100644 source/qre/src/isa/property_keys.rs diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 86aba4790d..2cdaf8dfc1 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -14,7 +14,6 @@ PHYSICAL, Encoding, ISATransform, - PropertyKey, constraint, InstructionSource, ) @@ -65,7 +64,6 @@ "ISARequirements", "ISATransform", "LatticeSurgery", - "PropertyKey", "PSSPC", "Trace", "TraceQuery", diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 15bdd3afdb..2260ac44e8 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -14,6 +14,8 @@ _IntFunction, _FloatFunction, constant_function, + instruction_name, + property_name_to_key, ) if TYPE_CHECKING: @@ -22,10 +24,6 @@ from ._instruction import ISATransform, Encoding -# Valid property names for instructions, mapped to their integer keys. -_PROPERTY_KEYS: dict[str, int] = {"distance": 0} - - class Architecture(ABC): @abstractmethod def provided_isa(self, ctx: _Context) -> ISA: @@ -228,12 +226,9 @@ def _make_instruction( ) for key, value in properties.items(): - prop_key = _PROPERTY_KEYS.get(key) + prop_key = property_name_to_key(key) if prop_key is None: - raise ValueError( - f"Unknown property '{key}'. 
" - f"Valid properties: {list(_PROPERTY_KEYS)}" - ) + raise ValueError(f"Unknown property '{key}'.") instr.set_property(prop_key, value) return instr diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index dde6b23d35..1a78225a94 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -21,6 +21,11 @@ from ._trace import TraceQuery, PSSPC, LatticeSurgery from ._instruction import InstructionSource from ._isa_enumeration import ISAQuery +from .property_keys import ( + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_MEMORY_QUBITS, + PHYSICAL_FACTORY_QUBITS, +) def estimate( @@ -76,6 +81,9 @@ def estimate( params_and_traces = list(trace_query.enumerate(app_ctx, track_parameters=True)) isas = list(isa_query.enumerate(arch_ctx)) + num_traces = len(params_and_traces) + num_isas = len(isas) + # Estimate all trace × ISA combinations using Python threads collection = _EstimationCollection() @@ -85,6 +93,7 @@ def _estimate_one(params, trace, isa): result = app_ctx.application.post_process(params, result) return result + successful = 0 with ThreadPoolExecutor() as executor: futures = [ executor.submit(_estimate_one, params, trace, isa) @@ -94,13 +103,18 @@ def _estimate_one(params, trace, isa): for future in futures: result = future.result() if result is not None: + successful += 1 collection.insert(result) else: traces = list(trace_query.enumerate(app_ctx)) isas = list(isa_query.enumerate(arch_ctx)) + num_traces = len(traces) + num_isas = len(isas) + # Use the Rust parallel estimation path collection = _estimate_parallel(cast(list[Trace], traces), isas, max_error) + successful = collection.successful_estimates # Post-process the results and add them to a results table table = EstimationTable() @@ -120,6 +134,13 @@ def _estimate_one(params, trace, isa): table.append(entry) + # Fill in the stats for this estimation run + table.stats.num_traces = num_traces + table.stats.num_isas = num_isas + 
table.stats.total_jobs = num_traces * num_isas + table.stats.successful_estimates = successful + table.stats.pareto_results = len(collection) + return table @@ -137,6 +158,8 @@ def __init__(self): """Initialize an empty estimation table with default columns.""" super().__init__() + self.stats = EstimationTableStats() + self._columns: list[tuple[str, EstimationTableColumn]] = [ ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), ( @@ -187,6 +210,20 @@ def insert_column( """ self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) + def add_qubit_partition_column(self) -> None: + self.add_column( + "physical_compute_qubits", + lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), + ) + self.add_column( + "physical_factory_qubits", + lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), + ) + self.add_column( + "physical_memory_qubits", + lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), + ) + def add_factory_summary_column(self) -> None: """Adds a column to the estimation table that summarizes the factories used in the estimation.""" @@ -268,4 +305,13 @@ class EstimationTableEntry: error: float source: InstructionSource factories: dict[int, FactoryResult] = field(default_factory=dict) - properties: dict[str, int | float | bool | str] = field(default_factory=dict) + properties: dict[int, int | float | bool | str] = field(default_factory=dict) + + +@dataclass(slots=True) +class EstimationTableStats: + num_traces: int = 0 + num_isas: int = 0 + total_jobs: int = 0 + successful_estimates: int = 0 + pareto_results: int = 0 diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 1dc5b0a135..d6c25b704f 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -23,6 +23,7 @@ _Instruction, ISARequirements, instruction_name, + property_name_to_key, ) @@ -31,10 +32,6 @@ class Encoding(IntEnum): LOGICAL = 1 -class 
PropertyKey(IntEnum): - DISTANCE = 0 - - PHYSICAL = Encoding.PHYSICAL LOGICAL = Encoding.LOGICAL @@ -69,18 +66,14 @@ def constraint( for key, value in kwargs.items(): if value: - try: - prop_key = PropertyKey[key.upper()] - except KeyError: - raise ValueError( - f"Unknown property '{key}'. Valid properties: {[k.name.lower() for k in PropertyKey]}" - ) + if (prop_key := property_name_to_key(key)) is None: + raise ValueError(f"Unknown property '{key}'") + c.add_property(prop_key) return c - class ISATransform(ABC): """ Abstract base class for transformations between ISAs (e.g., QEC schemes). diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index e7d8fe29a0..46870b4aae 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -28,4 +28,7 @@ LatticeSurgery, PSSPC, Trace, + property_name_to_key, + _float_to_bits, + _float_from_bits, ) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 458bf1e842..3ad89d9ff4 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -853,21 +853,21 @@ class EstimationResult: ... @property - def properties(self) -> dict[str, bool | int | float | str]: + def properties(self) -> dict[int, bool | int | float | str]: """ Custom properties from application generation and trace transform. Returns: - dict[str, bool | int | float | str]: A dictionary mapping property keys to their values. + dict[int, bool | int | float | str]: A dictionary mapping property keys to their values. """ ... - def set_property(self, key: str, value: bool | int | float | str) -> None: + def set_property(self, key: int, value: bool | int | float | str) -> None: """ Sets a custom property. Args: - key (str): The property key. + key (int) The property key. value (bool | int | float | str): The property value. All values of type `int`, `float`, `bool`, and `str` are supported. Any other value is converted to a string using its `__str__` method. 
""" @@ -924,6 +924,27 @@ class _EstimationCollection: """ ... + @property + def total_jobs(self) -> int: + """ + Returns the total number of (trace, ISA) estimation jobs. + + Returns: + int: The total number of jobs. + """ + ... + + @property + def successful_estimates(self) -> int: + """ + Returns the number of estimation jobs that completed successfully + (before Pareto filtering). + + Returns: + int: The number of successful estimates. + """ + ... + class FactoryResult: """ Represents the result of a factory used in resource estimation. @@ -1098,36 +1119,36 @@ class Trace: """ ... - def set_property(self, key: str, value: Any) -> None: + def set_property(self, key: int, value: Any) -> None: """ Sets a property. All values of type `int`, `float`, `bool`, and `str` are supported. Any other value is converted to a string using its `__str__` method. Args: - key (str): The property key. + key (int): The property key. value (Any): The property value. """ ... - def get_property(self, key: str) -> Optional[int | float | bool | str]: + def get_property(self, key: int) -> Optional[int | float | bool | str]: """ Gets a property. Args: - key (str): The property key. + key (int): The property key. Returns: Optional[int | float | bool | str]: The property value, or None if not found. """ ... - def has_property(self, key: str) -> bool: + def has_property(self, key: int) -> bool: """ Checks if a property with the given key exists. Args: - key (str): The property key. + key (int): The property key. Returns: bool: True if the property exists, False otherwise. @@ -1359,6 +1380,18 @@ def _binom_ppf(q: float, n: int, p: float) -> int: """ ... +def _float_to_bits(f: float) -> int: + """ + Converts a float to its bit representation as an integer. + """ + ... + +def _float_from_bits(b: int) -> float: + """ + Converts a float from its bit representation as an integer. + """ + ... 
+ def instruction_name(id: int) -> Optional[str]: """ Returns the name of an instruction given its ID, if known. @@ -1370,3 +1403,15 @@ def instruction_name(id: int) -> Optional[str]: Optional[str]: The name of the instruction, or None if the ID is not recognized. """ ... + +def property_name_to_key(name: str) -> Optional[int]: + """ + Converts a property name to its corresponding key, if known. + + Args: + name (str): The property name. + + Returns: + Optional[int]: The property key, or None if the name is not recognized. + """ + ... diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index d2f534fa4d..e14b372a9d 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -11,6 +11,7 @@ from ...estimator import LogicalCounts from .._qre import Trace from ..instruction_ids import CCX, MEAS_Z, RZ, T, READ_FROM_MEMORY, WRITE_TO_MEMORY +from ..property_keys import EVALUATION_TIME def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: @@ -75,7 +76,7 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: block = trace.add_block(repetitions=wtm_count) block.add_operation(WRITE_TO_MEMORY, [0, compute_qubits]) - trace.set_property("evaluation_time", evaluation_time) + trace.set_property(EVALUATION_TIME, evaluation_time) return trace diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index b14a0336fe..2371d0b9d2 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -225,15 +225,21 @@ def _state_from_pipeline(self, pipeline: _Pipeline) -> _Instruction: def _cache_key(self, impl_isa: ISA) -> str: """Build a deterministic key from factory configuration and impl_isa.""" - # You can refine this if ISA has a better serialization method. 
- payload = { - "factory": type(self).__qualname__, - "code_query": getattr( - self.code_query, "__qualname__", repr(self.code_query) - ), - "impl_isa": str(impl_isa), - } - data = repr(payload).encode("utf-8") + parts = [ + f"factory={type(self).__qualname__}", + f"code_query={repr(self.code_query)}", + f"physical_qubit_calculation={self.physical_qubit_calculation.__name__}", + ] + + # Include full instruction details, sorted by id for determinism + for instr in sorted(impl_isa, key=lambda i: i.id): + parts.append( + f"id={instr.id}|encoding={instr.encoding}|arity={instr.arity}" + f"|time={instr.time()}|space={instr.space()}" + f"|error_rate={instr.error_rate()}" + ) + + data = "\n".join(parts).encode("utf-8") return hashlib.sha256(data).hexdigest() def _cache_path(self, impl_isa: ISA) -> Path: diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index ce80a23506..e402ea9c41 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -15,6 +15,10 @@ from ..._isa_enumeration import _Context from ..._qre import linear_function from ...instruction_ids import CNOT, H, LATTICE_SURGERY, MEAS_Z +from ...property_keys import ( + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, +) @dataclass @@ -30,6 +34,12 @@ class SurfaceCode(ISATransform): error_correction_threshold: float The error correction threshold for the surface code. (Default is 0.01 (1%), see [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) + one_qubit_gate_depth: int + The depth of one-qubit gates in each syndrome extraction cycle. + (Default is 1, see Fig. 2 in [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) + two_qubit_gate_depth: int + The depth of two-qubit gates in each syndrome extraction cycle. + (Default is 4, see Fig. 
2 in [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) Hyper parameters: distance: int @@ -50,6 +60,8 @@ class SurfaceCode(ISATransform): crossing_prefactor: float = 0.03 error_correction_threshold: float = 0.01 + one_qubit_gate_depth: int = 1 + two_qubit_gate_depth: int = 4 _: KW_ONLY distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) @@ -81,8 +93,22 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non space_formula = linear_function(2 * self.distance**2 - 1) # Each syndrome extraction cycle consists of ancilla preparation, 4 - # rounds of CNOTs, and measurement. (See Fig. 2 in arXiv:1009.3686) - time_value = (h_time + 4 * cnot_time + meas_time) * self.distance + # rounds of CNOTs, and measurement. (See Fig. 2 in arXiv:1009.3686); + # these may be modified by the one_qubit_gate_depth and + # two_qubit_gate_depth parameters, or scaled by the time factors + # provided in the instruction properties. The syndrome extraction cycle + # is repeated d times for a distance-d code. + one_qubit_gate_depth = self.one_qubit_gate_depth * h.get_property_or( + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, 1 + ) + two_qubit_gate_depth = self.two_qubit_gate_depth * cnot.get_property_or( + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, 1 + ) + + code_cycle_time = ( + one_qubit_gate_depth * h_time + two_qubit_gate_depth * cnot_time + meas_time + ) + time_value = code_cycle_time * self.distance # See Eqs. 
(10) and (11) in arXiv:1208.0928 error_formula = linear_function( diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py index 9b24069aa2..fd6652ddaf 100644 --- a/source/pip/qsharp/qre/models/qec/_yoked.py +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -6,10 +6,11 @@ from math import ceil from typing import Generator -from ..._instruction import ISATransform, constraint, LOGICAL, PropertyKey +from ..._instruction import ISATransform, constraint, LOGICAL from ..._qre import ISA, ISARequirements, generic_function from ..._architecture import _Context from ...instruction_ids import LATTICE_SURGERY, MEMORY +from ...property_keys import DISTANCE class ShapeHeuristic(IntEnum): @@ -70,7 +71,7 @@ def required_isa() -> ISARequirements: def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: lattice_surgery = impl_isa[LATTICE_SURGERY] - distance = lattice_surgery.get_property(PropertyKey.DISTANCE) + distance = lattice_surgery.get_property(DISTANCE) assert distance is not None shape_fn = [self._min_area_shape, self._square_shape][self.shape_heuristic] diff --git a/source/pip/qsharp/qre/property_keys.py b/source/pip/qsharp/qre/property_keys.py new file mode 100644 index 0000000000..917e25ca0d --- /dev/null +++ b/source/pip/qsharp/qre/property_keys.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportAttributeAccessIssue=false + + +from .._native import property_keys + +for name in property_keys.__all__: + globals()[name] = getattr(property_keys, name) diff --git a/source/pip/qsharp/qre/property_keys.pyi b/source/pip/qsharp/qre/property_keys.pyi new file mode 100644 index 0000000000..17ebdb49af --- /dev/null +++ b/source/pip/qsharp/qre/property_keys.pyi @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +DISTANCE: int +SURFACE_CODE_ONE_QUBIT_TIME_FACTOR: int +SURFACE_CODE_TWO_QUBIT_TIME_FACTOR: int +ACCELERATION: int +NUM_TS_PER_ROTATION: int +EXPECTED_SHOTS: int +RUNTIME_SINGLE_SHOT: int +EVALUATION_TIME: int +PHYSICAL_COMPUTE_QUBITS: int +PHYSICAL_FACTORY_QUBITS: int +PHYSICAL_MEMORY_QUBITS: int diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 8c016d646a..ff16b4cde2 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -8,7 +8,7 @@ use std::{ use pyo3::{ IntoPyObjectExt, - exceptions::{PyException, PyKeyError, PyRuntimeError, PyTypeError}, + exceptions::{PyException, PyKeyError, PyRuntimeError, PyTypeError, PyValueError}, prelude::*, types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple, PyType}, }; @@ -38,11 +38,15 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(generic_function, m)?)?; m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; m.add_function(wrap_pyfunction!(binom_ppf, m)?)?; + m.add_function(wrap_pyfunction!(float_to_bits, m)?)?; + m.add_function(wrap_pyfunction!(float_from_bits, m)?)?; m.add_function(wrap_pyfunction!(instruction_name, m)?)?; + m.add_function(wrap_pyfunction!(property_name_to_key, m)?)?; m.add("EstimationError", m.py().get_type::())?; add_instruction_ids(m)?; + add_property_keys(m)?; Ok(()) } @@ -314,11 +318,15 @@ impl qre::ParetoItem2D for Instruction { type Objective2 = u64; fn objective1(&self) -> Self::Objective1 { - self.0.expect_space(None) + self.0 + .space(None) + .unwrap_or_else(|| self.0.expect_space(Some(1))) } fn objective2(&self) -> Self::Objective2 { - self.0.expect_time(None) + self.0 + .time(None) + .unwrap_or_else(|| self.0.expect_time(Some(1))) } } @@ -328,15 +336,21 @@ impl qre::ParetoItem3D for Instruction { type Objective3 = f64; fn objective1(&self) -> Self::Objective1 { - self.0.expect_space(None) + self.0 + .space(None) + .unwrap_or_else(|| self.0.expect_space(Some(1))) } fn objective2(&self) -> 
Self::Objective2 { - self.0.expect_time(None) + self.0 + .time(None) + .unwrap_or_else(|| self.0.expect_time(Some(1))) } fn objective3(&self) -> Self::Objective3 { - self.0.expect_error_rate(None) + self.0 + .error_rate(None) + .unwrap_or_else(|| self.0.expect_error_rate(Some(1))) } } @@ -377,16 +391,6 @@ fn convert_encoding(encoding: u64) -> PyResult { } } -/// Property name → integer key mapping (must match Python `_PROPERTY_KEYS`). -fn property_name_to_key(name: &str) -> PyResult { - match name { - "distance" => Ok(0), - other => Err(PyTypeError::new_err(format!( - "Unknown property '{other}'. Valid properties: [\"distance\"]" - ))), - } -} - /// Build a `qre::Instruction` from either an existing `Instruction` Python /// object or from keyword arguments (id + encoding + arity + …). #[allow(clippy::too_many_arguments)] @@ -446,7 +450,10 @@ fn build_instruction( if let Some(kw) = kwargs { for (key, value) in kw { let key_str: String = key.extract()?; - let prop_key = property_name_to_key(&key_str)?; + let prop_key = + qre::property_name_to_key(&key_str.to_ascii_uppercase()).ok_or_else(|| { + PyValueError::new_err(format!("Unknown property name: {key_str}")) + })?; let prop_value: u64 = value.extract()?; instr.set_property(prop_key, prop_value); } @@ -743,6 +750,16 @@ impl EstimationCollection { self.0.len() } + #[getter] + pub fn total_jobs(&self) -> usize { + self.0.total_jobs() + } + + #[getter] + pub fn successful_estimates(&self) -> usize { + self.0.successful_estimates() + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = EstimationCollectionIterator { @@ -848,7 +865,7 @@ impl EstimationResult { Ok(dict) } - pub fn set_property(&mut self, key: String, value: &Bound<'_, PyAny>) -> PyResult<()> { + pub fn set_property(&mut self, key: u64, value: &Bound<'_, PyAny>) -> PyResult<()> { let property = if value.is_instance_of::() { qre::Property::new_bool(value.extract()?) 
} else if let Ok(i) = value.extract::() { @@ -939,7 +956,7 @@ impl Trace { self.0.increment_base_error(amount); } - pub fn set_property(&mut self, key: String, value: &Bound<'_, PyAny>) -> PyResult<()> { + pub fn set_property(&mut self, key: u64, value: &Bound<'_, PyAny>) -> PyResult<()> { let property = if value.is_instance_of::() { qre::Property::new_bool(value.extract()?) } else if let Ok(i) = value.extract::() { @@ -956,7 +973,7 @@ impl Trace { } #[allow(clippy::needless_pass_by_value)] - pub fn get_property<'py>(self_: PyRef<'py, Self>, key: &str) -> Option> { + pub fn get_property(self_: PyRef<'_, Self>, key: u64) -> Option> { if let Some(value) = self_.0.get_property(key) { match value { qre::Property::Bool(b) => PyBool::new(self_.py(), *b) @@ -977,7 +994,7 @@ impl Trace { } } - pub fn has_property(&self, key: &str) -> bool { + pub fn has_property(&self, key: u64) -> bool { self.0.has_property(key) } @@ -1004,7 +1021,10 @@ impl Trace { pub fn estimate(&self, isa: &ISA, max_error: Option) -> Option { self.0 .estimate(&isa.0, max_error) - .map(EstimationResult) + .map(|mut r| { + r.set_isa(isa.0.clone()); + EstimationResult(r) + }) .ok() } @@ -1292,6 +1312,16 @@ pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { qre::binom_ppf(q, n, p) } +#[pyfunction(name = "_float_to_bits")] +pub fn float_to_bits(f: f64) -> u64 { + qre::float_to_bits(f) +} + +#[pyfunction(name = "_float_from_bits")] +pub fn float_from_bits(bits: u64) -> f64 { + qre::float_from_bits(bits) +} + #[pyfunction] pub fn instruction_name(id: u64) -> Option { qre::instruction_name(id).map(String::from) @@ -1387,3 +1417,39 @@ fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +#[pyfunction] +pub fn property_name_to_key(name: &str) -> Option { + qre::property_name_to_key(&name.to_ascii_uppercase()) +} + +fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { + #[allow(clippy::wildcard_imports)] + use qre::property_keys::*; + + let property_keys = 
PyModule::new(m.py(), "property_keys")?; + + macro_rules! add_ids { + ($($name:ident),* $(,)?) => { + $(property_keys.add(stringify!($name), $name)?;)* + }; + } + + add_ids!( + DISTANCE, + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, + ACCELERATION, + NUM_TS_PER_ROTATION, + EXPECTED_SHOTS, + RUNTIME_SINGLE_SHOT, + EVALUATION_TIME, + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, + PHYSICAL_MEMORY_QUBITS, + ); + + m.add_submodule(&property_keys)?; + + Ok(()) +} diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index ecd5b7e673..7454a639b9 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -17,7 +17,6 @@ ISARequirements, ISATransform, LatticeSurgery, - PropertyKey, Trace, constraint, estimate, @@ -40,6 +39,7 @@ ISARefNode, ) from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T, RZ +from qsharp.qre.property_keys import DISTANCE, NUM_TS_PER_ROTATION # NOTE These classes will be generalized as part of the QRE API in the following # pull requests and then moved out of the tests. 
@@ -123,17 +123,17 @@ def test_isa(): def test_instruction_properties(): # Test instruction with no properties instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) - assert instr_no_props.get_property(PropertyKey.DISTANCE) is None - assert instr_no_props.has_property(PropertyKey.DISTANCE) is False - assert instr_no_props.get_property_or(PropertyKey.DISTANCE, 5) == 5 + assert instr_no_props.get_property(DISTANCE) is None + assert instr_no_props.has_property(DISTANCE) is False + assert instr_no_props.get_property_or(DISTANCE, 5) == 5 # Test instruction with valid property (distance) instr_with_distance = _make_instruction( T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} ) - assert instr_with_distance.get_property(PropertyKey.DISTANCE) == 9 - assert instr_with_distance.has_property(PropertyKey.DISTANCE) is True - assert instr_with_distance.get_property_or(PropertyKey.DISTANCE, 5) == 9 + assert instr_with_distance.get_property(DISTANCE) == 9 + assert instr_with_distance.has_property(DISTANCE) is True + assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 # Test instruction with invalid property name with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): @@ -143,15 +143,15 @@ def test_instruction_properties(): def test_instruction_constraints(): # Test constraint without properties c_no_props = constraint(T, encoding=LOGICAL) - assert c_no_props.has_property(PropertyKey.DISTANCE) is False + assert c_no_props.has_property(DISTANCE) is False # Test constraint with valid property (distance=True) c_with_distance = constraint(T, encoding=LOGICAL, distance=True) - assert c_with_distance.has_property(PropertyKey.DISTANCE) is True + assert c_with_distance.has_property(DISTANCE) is True # Test constraint with distance=False (should not add the property) c_distance_false = constraint(T, encoding=LOGICAL, distance=False) - assert c_distance_false.has_property(PropertyKey.DISTANCE) is False + assert 
c_distance_false.has_property(DISTANCE) is False # Test constraint with invalid property name with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): @@ -742,21 +742,26 @@ def test_sum_isa_enumeration_nodes(): def test_trace_properties(): trace = Trace(42) - trace.set_property("int", 42) - assert trace.get_property("int") == 42 - assert isinstance(trace.get_property("int"), int) + INT = 0 + FLOAT = 1 + BOOL = 2 + STR = 3 - trace.set_property("float", 3.14) - assert trace.get_property("float") == 3.14 - assert isinstance(trace.get_property("float"), float) + trace.set_property(INT, 42) + assert trace.get_property(INT) == 42 + assert isinstance(trace.get_property(INT), int) - trace.set_property("bool", True) - assert trace.get_property("bool") is True - assert isinstance(trace.get_property("bool"), bool) + trace.set_property(FLOAT, 3.14) + assert trace.get_property(FLOAT) == 3.14 + assert isinstance(trace.get_property(FLOAT), float) - trace.set_property("str", "hello") - assert trace.get_property("str") == "hello" - assert isinstance(trace.get_property("str"), str) + trace.set_property(BOOL, True) + assert trace.get_property(BOOL) is True + assert isinstance(trace.get_property(BOOL), bool) + + trace.set_property(STR, "hello") + assert trace.get_property(STR) == "hello" + assert isinstance(trace.get_property(STR), str) def test_qsharp_application(): @@ -984,11 +989,13 @@ def test_estimation_table_empty(): def test_estimation_table_add_column(): """Test adding a column to the table.""" + VAL = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"val": 42})) - table.append(_make_entry(200, 10000, 0.02, properties={"val": 84})) + table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) + table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) - table.add_column("val", lambda e: e.properties["val"]) + table.add_column("val", lambda e: e.properties[VAL]) frame = table.as_frame() assert list(frame.columns) 
== ["qubits", "runtime", "error", "val"] @@ -997,12 +1004,14 @@ def test_estimation_table_add_column(): def test_estimation_table_add_column_with_formatter(): """Test adding a column with a formatter.""" + NS = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"ns": 1000})) + table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) table.add_column( "duration", - lambda e: e.properties["ns"], + lambda e: e.properties[NS], formatter=lambda x: pd.Timedelta(x, unit="ns"), ) @@ -1012,12 +1021,16 @@ def test_estimation_table_add_column_with_formatter(): def test_estimation_table_add_multiple_columns(): """Test adding multiple columns preserves order.""" + A = 0 + B = 1 + C = 2 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"a": 1, "b": 2, "c": 3})) + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) - table.add_column("a", lambda e: e.properties["a"]) - table.add_column("b", lambda e: e.properties["b"]) - table.add_column("c", lambda e: e.properties["c"]) + table.add_column("a", lambda e: e.properties[A]) + table.add_column("b", lambda e: e.properties[B]) + table.add_column("c", lambda e: e.properties[C]) frame = table.as_frame() assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] @@ -1028,10 +1041,12 @@ def test_estimation_table_add_multiple_columns(): def test_estimation_table_insert_column_at_beginning(): """Test inserting a column at index 0.""" + NAME = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"name": "test"})) + table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) - table.insert_column(0, "name", lambda e: e.properties["name"]) + table.insert_column(0, "name", lambda e: e.properties[NAME]) frame = table.as_frame() assert list(frame.columns) == ["name", "qubits", "runtime", "error"] @@ -1040,11 +1055,13 @@ def test_estimation_table_insert_column_at_beginning(): def 
test_estimation_table_insert_column_in_middle(): """Test inserting a column between existing default columns.""" + EXTRA = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"extra": 99})) + table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) # Insert between qubits and runtime (index 1) - table.insert_column(1, "extra", lambda e: e.properties["extra"]) + table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) frame = table.as_frame() assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] @@ -1053,11 +1070,13 @@ def test_estimation_table_insert_column_in_middle(): def test_estimation_table_insert_column_at_end(): """Test inserting a column at the end (same effect as add_column).""" + LAST = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"last": True})) + table.append(_make_entry(100, 5000, 0.01, properties={LAST: True})) # 3 default columns, inserting at index 3 = end - table.insert_column(3, "last", lambda e: e.properties["last"]) + table.insert_column(3, "last", lambda e: e.properties[LAST]) frame = table.as_frame() assert list(frame.columns) == ["qubits", "runtime", "error", "last"] @@ -1066,13 +1085,15 @@ def test_estimation_table_insert_column_at_end(): def test_estimation_table_insert_column_with_formatter(): """Test inserting a column with a formatter.""" + NS = 0 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={"ns": 2000})) + table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) table.insert_column( 0, "custom_time", - lambda e: e.properties["ns"], + lambda e: e.properties[NS], formatter=lambda x: pd.Timedelta(x, unit="ns"), ) @@ -1083,11 +1104,14 @@ def test_estimation_table_insert_column_with_formatter(): def test_estimation_table_insert_and_add_columns(): """Test combining insert_column and add_column.""" + A = 0 + B = 1 + table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, 
properties={"a": 1, "b": 2})) + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2})) - table.add_column("b", lambda e: e.properties["b"]) - table.insert_column(0, "a", lambda e: e.properties["a"]) + table.add_column("b", lambda e: e.properties[B]) + table.insert_column(0, "a", lambda e: e.properties[A]) frame = table.as_frame() assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] @@ -1160,7 +1184,7 @@ def test_estimation_table_add_column_from_source(): results.add_column( "compute_distance", - lambda entry: entry.source[LATTICE_SURGERY].instruction[PropertyKey.DISTANCE], + lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], ) frame = results.as_frame() @@ -1194,7 +1218,7 @@ def test_estimation_table_add_column_from_properties(): results.add_column( "num_ts_per_rotation", - lambda entry: entry.properties["num_ts_per_rotation"], + lambda entry: entry.properties[NUM_TS_PER_ROTATION], ) frame = results.as_frame() @@ -1227,8 +1251,7 @@ def test_estimation_table_insert_column_before_defaults(): assert len(results) >= 1 - # Insert a name column at the beginning and add factory summary at the end - results.insert_column(0, "name", lambda entry: entry.properties.get("name", "")) + # Add a factory summary at the end results.add_factory_summary_column() frame = results.as_frame() @@ -1264,3 +1287,5 @@ def test_estimation_table_computed_column(): frame = table.as_frame() assert frame["qubit_error_product"][0] == pytest.approx(1.0) assert frame["qubit_error_product"][1] == pytest.approx(4.0) + + diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/test_qre_models.py index a8c8f8462a..d0e8c33d37 100644 --- a/source/pip/tests/test_qre_models.py +++ b/source/pip/tests/test_qre_models.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from qsharp.qre import LOGICAL, PHYSICAL, PropertyKey +from qsharp.qre import LOGICAL, PHYSICAL from qsharp.qre.instruction_ids import ( T, CCZ, @@ -36,6 +36,7 @@ ThreeAux, YokedSurfaceCode, ) +from qsharp.qre.property_keys import DISTANCE # --------------------------------------------------------------------------- @@ -465,7 +466,7 @@ def test_distance_property_propagated(self): isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] - assert mem.get_property(PropertyKey.DISTANCE) == d + assert mem.get_property(DISTANCE) == d # --------------------------------------------------------------------------- diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 01dac1762c..4ed2e4d5b5 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -4,7 +4,7 @@ use std::{ fmt::Display, ops::Add, - sync::{Arc, RwLock}, + sync::{Arc, RwLock, RwLockReadGuard}, }; use num_traits::FromPrimitive; @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize}; use crate::trace::instruction_ids::instruction_name; +pub mod property_keys; + #[cfg(test)] mod tests; @@ -96,7 +98,17 @@ impl ISA { self.nodes.is_empty() } - fn read_graph(&self) -> std::sync::RwLockReadGuard<'_, ProvenanceGraph> { + /// Returns a read-locked view of this ISA, enabling zero-clone + /// instruction access for the lifetime of the returned guard. + #[must_use] + pub fn lock(&self) -> LockedISA<'_> { + LockedISA { + graph: self.read_graph(), + nodes: &self.nodes, + } + } + + fn read_graph(&self) -> RwLockReadGuard<'_, ProvenanceGraph> { self.graph.read().expect("provenance graph lock poisoned") } @@ -222,6 +234,22 @@ impl Add for ISA { } } +/// A read-locked view of an ISA. Holds the graph read lock for the +/// lifetime of this struct, enabling zero-clone instruction access. +pub struct LockedISA<'a> { + graph: RwLockReadGuard<'a, ProvenanceGraph>, + nodes: &'a FxHashMap, +} + +impl LockedISA<'_> { + /// Returns a reference to the instruction with the given ID, if present. 
+ #[must_use] + pub fn get(&self, id: &u64) -> Option<&Instruction> { + let &node_idx = self.nodes.get(id)?; + Some(self.graph.instruction(node_idx)) + } +} + #[derive(Default)] pub struct ISARequirements { constraints: FxHashMap, diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs new file mode 100644 index 0000000000..b031cba828 --- /dev/null +++ b/source/qre/src/isa/property_keys.rs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// NOTE: To add a new property key: +// 1. Add the name to the `define_properties!` macro below (values are auto-assigned) +// 2. Add it to `add_property_keys` in qre.rs +// 3. Add it to property_keys.pyi +// +// The `property_name_to_key` function is auto-generated from the entries. + +/// Property keys for instruction properties. These are used to query properties of instructions in the ISA. +macro_rules! define_properties { + // Internal rule: accumulator-based counting to auto-assign incrementing u64 values. + (@step $counter:expr, $name:ident, $($rest:ident),* $(,)?) => { + pub const $name: u64 = $counter; + define_properties!(@step $counter + 1, $($rest),*); + }; + (@step $counter:expr, $name:ident $(,)?) => { + pub const $name: u64 = $counter; + }; + // Entry point + ( $($name:ident),* $(,)? ) => { + define_properties!(@step 0, $($name),*); + + /// Property name → integer key mapping + pub fn property_name_to_key(name: &str) -> Option { + match name { + $( + stringify!($name) => Some($name), + )* + _ => None + } + } + }; +} + +define_properties! 
{ + DISTANCE, + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, + ACCELERATION, + NUM_TS_PER_ROTATION, + RUNTIME_SINGLE_SHOT, + EXPECTED_SHOTS, + EVALUATION_TIME, + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, + PHYSICAL_MEMORY_QUBITS, +} diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs index bf73b47ffd..e7c13e01f5 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -9,17 +9,19 @@ pub use pareto::{ ParetoFrontier as ParetoFrontier2D, ParetoFrontier3D, ParetoItem2D, ParetoItem3D, }; mod result; -pub use result::{EstimationCollection, EstimationResult, FactoryResult}; -mod trace; +pub use isa::property_keys; +pub use isa::property_keys::property_name_to_key; pub use isa::{ - ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, + ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, VariableArityFunction, }; +pub use result::{EstimationCollection, EstimationResult, FactoryResult}; +mod trace; pub use trace::instruction_ids; pub use trace::instruction_ids::instruction_name; pub use trace::{Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel}; mod utils; -pub use utils::binom_ppf; +pub use utils::{binom_ppf, float_from_bits, float_to_bits}; /// A resourc estimation error. 
#[derive(Clone, Debug, Error, PartialEq)] diff --git a/source/qre/src/pareto.rs b/source/qre/src/pareto.rs index ace12ca172..414faaad39 100644 --- a/source/qre/src/pareto.rs +++ b/source/qre/src/pareto.rs @@ -87,6 +87,10 @@ impl ParetoFrontier { self.0.iter() } + pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, I> { + self.0.iter_mut() + } + #[must_use] pub fn len(&self) -> usize { self.0.len() @@ -132,6 +136,15 @@ impl<'a, I: ParetoItem2D> IntoIterator for &'a ParetoFrontier { } } +impl<'a, I: ParetoItem2D> IntoIterator for &'a mut ParetoFrontier { + type Item = &'a mut I; + type IntoIter = std::slice::IterMut<'a, I>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + /// A Pareto frontier for 3-dimensional objectives. /// /// The implementation maintains the frontier sorted lexicographically. diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs index 485b6c1135..f8b48c016d 100644 --- a/source/qre/src/result.rs +++ b/source/qre/src/result.rs @@ -17,7 +17,8 @@ pub struct EstimationResult { error: f64, factories: FxHashMap, isa: ISA, - properties: FxHashMap, + isa_index: Option, + properties: FxHashMap, } impl EstimationResult { @@ -89,22 +90,31 @@ impl EstimationResult { &self.isa } - pub fn set_property(&mut self, key: String, value: Property) { + pub fn set_isa_index(&mut self, index: usize) { + self.isa_index = Some(index); + } + + #[must_use] + pub fn isa_index(&self) -> Option { + self.isa_index + } + + pub fn set_property(&mut self, key: u64, value: Property) { self.properties.insert(key, value); } #[must_use] - pub fn get_property(&self, key: &str) -> Option<&Property> { - self.properties.get(key) + pub fn get_property(&self, key: u64) -> Option<&Property> { + self.properties.get(&key) } #[must_use] - pub fn has_property(&self, key: &str) -> bool { - self.properties.contains_key(key) + pub fn has_property(&self, key: u64) -> bool { + self.properties.contains_key(&key) } #[must_use] - pub fn properties(&self) -> 
&FxHashMap { + pub fn properties(&self) -> &FxHashMap { &self.properties } } @@ -146,26 +156,48 @@ impl ParetoItem2D for EstimationResult { } #[derive(Default)] -pub struct EstimationCollection(ParetoFrontier2D); +pub struct EstimationCollection { + frontier: ParetoFrontier2D, + total_jobs: usize, + successful_estimates: usize, +} impl EstimationCollection { #[must_use] pub fn new() -> Self { Self::default() } + + #[must_use] + pub fn total_jobs(&self) -> usize { + self.total_jobs + } + + pub fn set_total_jobs(&mut self, total_jobs: usize) { + self.total_jobs = total_jobs; + } + + #[must_use] + pub fn successful_estimates(&self) -> usize { + self.successful_estimates + } + + pub fn set_successful_estimates(&mut self, successful_estimates: usize) { + self.successful_estimates = successful_estimates; + } } impl Deref for EstimationCollection { type Target = ParetoFrontier2D; fn deref(&self) -> &Self::Target { - &self.0 + &self.frontier } } impl DerefMut for EstimationCollection { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + &mut self.frontier } } diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 60f1764a4e..2ac610be15 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -6,7 +6,10 @@ use std::fmt::{Display, Formatter}; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; -use crate::{Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction}; +use crate::{ + Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction, LockedISA, + property_keys::{PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS}, +}; pub mod instruction_ids; use instruction_ids::instruction_name; @@ -23,7 +26,7 @@ pub struct Trace { compute_qubits: u64, memory_qubits: Option, resource_states: Option>, - properties: FxHashMap, + properties: FxHashMap, } impl Trace { @@ -119,18 +122,18 @@ impl Trace { 0 } - pub fn set_property(&mut self, key: String, value: 
Property) { + pub fn set_property(&mut self, key: u64, value: Property) { self.properties.insert(key, value); } #[must_use] - pub fn get_property(&self, key: &str) -> Option<&Property> { - self.properties.get(key) + pub fn get_property(&self, key: u64) -> Option<&Property> { + self.properties.get(&key) } #[must_use] - pub fn has_property(&self, key: &str) -> bool { - self.properties.contains_key(key) + pub fn has_property(&self, key: u64) -> bool { + self.properties.contains_key(&key) } #[must_use] @@ -146,9 +149,11 @@ impl Trace { #[allow( clippy::cast_precision_loss, clippy::cast_possible_truncation, - clippy::cast_sign_loss + clippy::cast_sign_loss, + clippy::too_many_lines )] pub fn estimate(&self, isa: &ISA, max_error: Option) -> Result { + let locked = isa.lock(); let max_error = max_error.unwrap_or(1.0); if self.base_error > max_error { @@ -176,7 +181,7 @@ impl Trace { // ------------------------------------------------------------------ if let Some(resource_states) = &self.resource_states { for (state_id, count) in resource_states { - let rate = get_error_rate_by_id(isa, *state_id)?; + let rate = get_error_rate_by_id(&locked, *state_id)?; let actual_error = result.add_error(rate * (*count as f64)); if actual_error > max_error { return Err(Error::MaximumErrorExceeded { @@ -194,7 +199,7 @@ impl Trace { // Missing instructions raise an error. Callable rates use arity. 
// ------------------------------------------------------------------ for (gate, mult) in self.deep_iter() { - let instr = get_instruction(isa, gate.id)?; + let instr = get_instruction(&locked, gate.id)?; let arity = gate.qubits.len() as u64; @@ -219,11 +224,15 @@ impl Trace { * qubit_counts.last().copied().unwrap_or(1.0)) .ceil() as u64; result.add_qubits(total_compute_qubits); + result.set_property( + PHYSICAL_COMPUTE_QUBITS, + Property::Int(total_compute_qubits.cast_signed()), + ); result.add_runtime( self.block .depth_and_used(Some(&|op: &Gate| { - let instr = get_instruction(isa, op.id)?; + let instr = get_instruction(&locked, op.id)?; Ok(instr.expect_time(Some(op.qubits.len() as u64))) }))? .0, @@ -233,11 +242,12 @@ impl Trace { // Factory overhead estimation. Each factory produces states at // a certain rate, so we need enough copies to meet the demand. // ------------------------------------------------------------------ + let mut total_factory_qubits = 0; for (factory, count) in &factories { - let instr = get_instruction(isa, *factory)?; - let factory_time = get_time(&instr)?; - let factory_space = get_space(&instr)?; - let factory_error_rate = get_error_rate(&instr)?; + let instr = get_instruction(&locked, *factory)?; + let factory_time = get_time(instr)?; + let factory_space = get_space(instr)?; + let factory_error_rate = get_error_rate(instr)?; let runs = result.runtime() / factory_time; if runs == 0 { @@ -250,21 +260,31 @@ impl Trace { let copies = count.div_ceil(runs); - result.add_qubits(copies * factory_space); + total_factory_qubits += copies * factory_space; result.add_factory_result( *factory, FactoryResult::new(copies, runs, *count, factory_error_rate), ); } + result.add_qubits(total_factory_qubits); + result.set_property( + PHYSICAL_FACTORY_QUBITS, + Property::Int(total_factory_qubits.cast_signed()), + ); // Memory qubits if let Some(memory_qubits) = self.memory_qubits { // We need a MEMORY instruction in our ISA - let memory = isa + let memory 
= locked .get(&instruction_ids::MEMORY) .ok_or(Error::InstructionNotFound(instruction_ids::MEMORY))?; - result.add_qubits(memory.expect_space(Some(memory_qubits))); + let memory_space = memory.expect_space(Some(memory_qubits)); + result.add_qubits(memory_space); + result.set_property( + PHYSICAL_MEMORY_QUBITS, + Property::Int(memory_space.cast_signed()), + ); // The number of rounds for the memory qubits to stay alive with // respect to the total runtime of the algorithm. @@ -282,11 +302,9 @@ impl Trace { } } - result.set_isa(isa.clone()); - // Copy properties from the trace to the result for (key, value) in &self.properties { - result.set_property(key.clone(), value.clone()); + result.set_property(*key, value.clone()); } Ok(result) @@ -590,7 +608,7 @@ impl Display for Property { // Some helper functions to extract instructions and their metrics together with // error handling -fn get_instruction(isa: &ISA, id: u64) -> Result { +fn get_instruction<'a>(isa: &'a LockedISA<'_>, id: u64) -> Result<&'a Instruction, Error> { isa.get(&id).ok_or(Error::InstructionNotFound(id)) } @@ -612,7 +630,7 @@ fn get_error_rate(instruction: &Instruction) -> Result { .ok_or(Error::CannotExtractErrorRate(instruction.id())) } -fn get_error_rate_by_id(isa: &ISA, id: u64) -> Result { +fn get_error_rate_by_id(isa: &LockedISA<'_>, id: u64) -> Result { let instr = get_instruction(isa, id)?; instr .error_rate(None) @@ -650,6 +668,8 @@ pub fn estimate_parallel<'a>( let next_job = std::sync::atomic::AtomicUsize::new(0); let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + std::thread::scope(|scope| { let num_threads = std::thread::available_parallelism() .map(std::num::NonZero::get) @@ -677,7 +697,9 @@ pub fn estimate_parallel<'a>( let trace_idx = job / num_isas; let isa_idx = job % num_isas; - if let Ok(estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) { + if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) + 
{ + estimation.set_isa_index(isa_idx); local_results.push(estimation); } } @@ -690,10 +712,21 @@ pub fn estimate_parallel<'a>( drop(tx); // Collect results from all workers into the shared collection. + let mut successful = 0; for local_results in rx { + successful += local_results.len(); collection.extend(local_results.into_iter()); } + collection.set_successful_estimates(successful); }); + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isas[idx].clone()); + } + } + collection } diff --git a/source/qre/src/trace/transforms/psspc.rs b/source/qre/src/trace/transforms/psspc.rs index 0f671b14f3..80ec36bd99 100644 --- a/source/qre/src/trace/transforms/psspc.rs +++ b/source/qre/src/trace/transforms/psspc.rs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +use crate::property_keys::NUM_TS_PER_ROTATION; use crate::trace::{Gate, TraceTransform}; use crate::{Error, Property, Trace, instruction_ids}; @@ -147,7 +148,7 @@ impl PSSPC { // Track some properties transformed.set_property( - String::from("num_ts_per_rotation"), + NUM_TS_PER_ROTATION, Property::Int(self.num_ts_per_rotation as i64), ); diff --git a/source/qre/src/utils.rs b/source/qre/src/utils.rs index ea6dde8623..ffa82b04d4 100644 --- a/source/qre/src/utils.rs +++ b/source/qre/src/utils.rs @@ -10,3 +10,13 @@ pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { let dist = Binomial::with_failure(n, 1.0 - p); dist.inverse(q) } + +#[must_use] +pub fn float_to_bits(f: f64) -> u64 { + f.to_bits() +} + +#[must_use] +pub fn float_from_bits(b: u64) -> f64 { + f64::from_bits(b) +} From 8b1b520199c623d0ea6e29dbb3e36f1738da0dcf Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 16 Mar 2026 18:37:34 +0100 Subject: [PATCH 25/45] QIR interop in QRE (#3018) This adds QIR support to QRE. 
It uses existing code to walk the QIR tree. The transformation will fail in programs with branches until we have better heuristics for branch prediction in place. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- source/pip/qsharp/qre/interop/__init__.py | 3 +- source/pip/qsharp/qre/interop/_qir.py | 135 ++++++++++++++++ source/pip/tests/test_qre.py | 181 ++++++++++++++++++++++ 3 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 source/pip/qsharp/qre/interop/_qir.py diff --git a/source/pip/qsharp/qre/interop/__init__.py b/source/pip/qsharp/qre/interop/__init__.py index 3b6c04c922..5f49608679 100644 --- a/source/pip/qsharp/qre/interop/__init__.py +++ b/source/pip/qsharp/qre/interop/__init__.py @@ -2,5 +2,6 @@ # Licensed under the MIT License. from ._qsharp import trace_from_entry_expr, trace_from_entry_expr_cached +from ._qir import trace_from_qir -__all__ = ["trace_from_entry_expr", "trace_from_entry_expr_cached"] +__all__ = ["trace_from_entry_expr", "trace_from_entry_expr_cached", "trace_from_qir"] diff --git a/source/pip/qsharp/qre/interop/_qir.py b/source/pip/qsharp/qre/interop/_qir.py new file mode 100644 index 0000000000..e3d9499c36 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_qir.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import pyqir + +from ..._native import QirInstructionId +from ..._simulation import AggregateGatesPass +from .. 
import instruction_ids as ids +from .._qre import Trace + +# Maps QirInstructionId to (instruction_id, arity) where arity is: +# 1 = single-qubit gate: tuple is (op, qubit) +# 2 = two-qubit gate: tuple is (op, qubit1, qubit2) +# 3 = three-qubit gate: tuple is (op, qubit1, qubit2, qubit3) +# -1 = single-qubit rotation: tuple is (op, angle, qubit) +# -2 = two-qubit rotation: tuple is (op, angle, qubit1, qubit2) +_GATE_MAP: list[tuple[QirInstructionId, int, int]] = [ + # Single-qubit gates + (QirInstructionId.I, ids.PAULI_I, 1), + (QirInstructionId.H, ids.H, 1), + (QirInstructionId.X, ids.PAULI_X, 1), + (QirInstructionId.Y, ids.PAULI_Y, 1), + (QirInstructionId.Z, ids.PAULI_Z, 1), + (QirInstructionId.S, ids.S, 1), + (QirInstructionId.SAdj, ids.S_DAG, 1), + (QirInstructionId.SX, ids.SQRT_X, 1), + (QirInstructionId.SXAdj, ids.SQRT_X_DAG, 1), + (QirInstructionId.T, ids.T, 1), + (QirInstructionId.TAdj, ids.T_DAG, 1), + # Two-qubit gates + (QirInstructionId.CNOT, ids.CNOT, 2), + (QirInstructionId.CX, ids.CX, 2), + (QirInstructionId.CY, ids.CY, 2), + (QirInstructionId.CZ, ids.CZ, 2), + (QirInstructionId.SWAP, ids.SWAP, 2), + # Three-qubit gates + (QirInstructionId.CCX, ids.CCX, 3), + # Single-qubit rotations (op, angle, qubit) + (QirInstructionId.RX, ids.RX, -1), + (QirInstructionId.RY, ids.RY, -1), + (QirInstructionId.RZ, ids.RZ, -1), + # Two-qubit rotations (op, angle, qubit1, qubit2) + (QirInstructionId.RXX, ids.RXX, -2), + (QirInstructionId.RYY, ids.RYY, -2), + (QirInstructionId.RZZ, ids.RZZ, -2), +] + +_MEAS_MAP: list[tuple[QirInstructionId, int]] = [ + (QirInstructionId.M, ids.MEAS_Z), + (QirInstructionId.MZ, ids.MEAS_Z), + (QirInstructionId.MResetZ, ids.MEAS_RESET_Z), +] + +_SKIP = ( + # Resets qubit to |0⟩ without measuring; we do not currently account for + # that in resource estimation + QirInstructionId.RESET, + # Runtime qubit state transfer; an implementation detail, not a logical operation + QirInstructionId.Move, + # Reads a measurement result from classical 
memory; purely classical I/O + QirInstructionId.ReadResult, + # The following are classical output recording operations that do not represent + # quantum operations and have no impact on resource estimation. + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, +) + + +def trace_from_qir(input: str | bytes) -> Trace: + """Convert a QIR program into a resource-estimation Trace. + + Parses the QIR module, extracts quantum gates, and builds a Trace that + can be used for resource estimation. Conditional branches are resolved + by always following the false path (assuming measurement results are Zero). + + Args: + input: QIR input as LLVM IR text (str) or bitcode (bytes). + + Returns: + A Trace containing the quantum operations from the QIR program. + """ + context = pyqir.Context() + + if isinstance(input, str): + mod = pyqir.Module.from_ir(context, input) + else: + mod = pyqir.Module.from_bitcode(context, input) + + gates, num_qubits, _ = AggregateGatesPass().run(mod) + + trace = Trace(compute_qubits=num_qubits) + + for gate in gates: + # NOTE: AggregateGatesPass does not return QirInstruction objects + assert isinstance(gate, tuple) + _add_gate(trace, gate) + + return trace + + +def _add_gate(trace: Trace, gate: tuple) -> None: + op = gate[0] + + for qir_id, instr_id, arity in _GATE_MAP: + if op == qir_id: + if arity == 1: + trace.add_operation(instr_id, [gate[1]]) + elif arity == 2: + trace.add_operation(instr_id, [gate[1], gate[2]]) + elif arity == 3: + trace.add_operation(instr_id, [gate[1], gate[2], gate[3]]) + elif arity == -1: + trace.add_operation(instr_id, [gate[2]], [gate[1]]) + elif arity == -2: + trace.add_operation(instr_id, [gate[2], gate[3]], [gate[1]]) + return + + for qir_id, instr_id in _MEAS_MAP: + if op == qir_id: + trace.add_operation(instr_id, [gate[1]]) + return + + for 
skip_id in _SKIP: + if op == skip_id: + return + + # The only unhandled QirInstructionId is CorrelatedNoise + assert op == QirInstructionId.CorrelatedNoise, f"Unexpected QIR instruction: {op}" + raise NotImplementedError(f"Unsupported QIR instruction: {op}") diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 7454a639b9..3b1817f180 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -3,6 +3,7 @@ from dataclasses import KW_ONLY, dataclass, field from enum import Enum +from pathlib import Path from typing import cast, Generator import pytest @@ -29,6 +30,7 @@ SurfaceCode, AQREGateBased, ) +from qsharp.qre.interop import trace_from_qir from qsharp.qre._architecture import _Context, _make_instruction from qsharp.qre._estimation import ( EstimationTable, @@ -1289,3 +1291,182 @@ def test_estimation_table_computed_column(): assert frame["qubit_error_product"][1] == pytest.approx(4.0) +def _ll_files(): + ll_dir = ( + Path(__file__).parent.parent + / "tests-integration" + / "resources" + / "adaptive_ri" + / "output" + ) + return sorted(ll_dir.glob("*.ll")) + + +@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) +def test_trace_from_qir(ll_file): + # NOTE: This test is primarily to ensure that the function can parse real + # QIR output without errors, rather than checking specific properties of the + # trace. + try: + trace_from_qir(ll_file.read_text()) + except ValueError as e: + # The only reason of failure is presence of control flow + assert ( + str(e) + == "simulation of programs with branching control flow is not supported" + ) + + +def test_trace_from_qir_handles_all_instruction_ids(): + """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. + + Generates a synthetic QIR program containing one instance of each gate + intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir + processes all of them without error. 
+ """ + import pyqir + import pyqir.qis as qis + from qsharp._native import QirInstructionId + from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP + + # -- Completeness check: every QirInstructionId must be covered -------- + handled_ids = ( + [qir_id for qir_id, _, _ in _GATE_MAP] + + [qir_id for qir_id, _ in _MEAS_MAP] + + list(_SKIP) + ) + # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) + all_ids = [ + QirInstructionId.I, + QirInstructionId.H, + QirInstructionId.X, + QirInstructionId.Y, + QirInstructionId.Z, + QirInstructionId.S, + QirInstructionId.SAdj, + QirInstructionId.SX, + QirInstructionId.SXAdj, + QirInstructionId.T, + QirInstructionId.TAdj, + QirInstructionId.CNOT, + QirInstructionId.CX, + QirInstructionId.CY, + QirInstructionId.CZ, + QirInstructionId.CCX, + QirInstructionId.SWAP, + QirInstructionId.RX, + QirInstructionId.RY, + QirInstructionId.RZ, + QirInstructionId.RXX, + QirInstructionId.RYY, + QirInstructionId.RZZ, + QirInstructionId.RESET, + QirInstructionId.M, + QirInstructionId.MResetZ, + QirInstructionId.MZ, + QirInstructionId.Move, + QirInstructionId.ReadResult, + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, + QirInstructionId.CorrelatedNoise, + ] + unhandled = [ + i + for i in all_ids + if i not in handled_ids and i != QirInstructionId.CorrelatedNoise + ] + assert unhandled == [], ( + f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " + f"{', '.join(str(i) for i in unhandled)}" + ) + + # -- Generate a QIR program with every producible gate ----------------- + simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) + builder = simple.builder + ctx = simple.context + q = simple.qubits + r = simple.results + + void_ty = pyqir.Type.void(ctx) + qubit_ty = pyqir.qubit_type(ctx) + 
result_ty = pyqir.result_type(ctx) + double_ty = pyqir.Type.double(ctx) + i64_ty = pyqir.IntType(ctx, 64) + + def declare(name, param_types): + return simple.add_external_function( + name, pyqir.FunctionType(void_ty, param_types) + ) + + # Single-qubit gates (pyqir.qis builtins) + qis.h(builder, q[0]) + qis.x(builder, q[0]) + qis.y(builder, q[0]) + qis.z(builder, q[0]) + qis.s(builder, q[0]) + qis.s_adj(builder, q[0]) + qis.t(builder, q[0]) + qis.t_adj(builder, q[0]) + + # SX — not in pyqir.qis + sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) + builder.call(sx_fn, [q[0]]) + + # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the + # pass does not handle, so use builder.call with the correct name) + cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) + builder.call(cx_fn, [q[0], q[1]]) + qis.cz(builder, q[0], q[1]) + qis.swap(builder, q[0], q[1]) + + cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) + builder.call(cy_fn, [q[0], q[1]]) + + # Three-qubit gate + qis.ccx(builder, q[0], q[1], q[2]) + + # Single-qubit rotations + qis.rx(builder, 1.0, q[0]) + qis.ry(builder, 1.0, q[0]) + qis.rz(builder, 1.0, q[0]) + + # Two-qubit rotations — not in pyqir.qis + rot2_ty = [double_ty, qubit_ty, qubit_ty] + angle = pyqir.const(double_ty, 1.0) + for name in ("rxx", "ryy", "rzz"): + fn = declare(f"__quantum__qis__{name}__body", rot2_ty) + builder.call(fn, [angle, q[0], q[1]]) + + # Measurements + qis.mz(builder, q[0], r[0]) + + m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) + builder.call(m_fn, [q[1], r[1]]) + + mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) + builder.call(mresetz_fn, [q[2], r[2]]) + + # Reset / Move + qis.reset(builder, q[0]) + + move_fn = declare("__quantum__qis__move__body", [qubit_ty]) + builder.call(move_fn, [q[0]]) + + # Output recording + tag = simple.add_byte_string(b"tag") + arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) + 
builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) + + rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) + builder.call(rec_fn, [r[0], tag]) + + tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) + builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) + + # -- Run trace_from_qir and verify it succeeds ------------------------- + trace = trace_from_qir(simple.ir()) + assert trace is not None From 5ce010a28ff95a6c44ca18e752686123da59d806 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 17 Mar 2026 10:06:47 +0100 Subject: [PATCH 26/45] Split Yoked SC model into two codes (#3019) Splits Yoked SC into two models, also adds more property IDs. --- source/pip/qsharp/qre/models/__init__.py | 10 +- source/pip/qsharp/qre/models/qec/__init__.py | 9 +- source/pip/qsharp/qre/models/qec/_yoked.py | 168 ++++++++++++++----- source/pip/qsharp/qre/property_keys.pyi | 1 + source/pip/src/qre.rs | 1 + source/pip/tests/test_qre_models.py | 26 +-- source/qre/src/isa/property_keys.rs | 1 + 7 files changed, 157 insertions(+), 59 deletions(-) diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py index 5f22ddb1a6..5b8400002c 100644 --- a/source/pip/qsharp/qre/models/__init__.py +++ b/source/pip/qsharp/qre/models/__init__.py @@ -2,7 +2,12 @@ # Licensed under the MIT License. 
from .factories import Litinski19Factory, MagicUpToClifford, RoundBasedFactory -from .qec import SurfaceCode, ThreeAux, YokedSurfaceCode +from .qec import ( + SurfaceCode, + ThreeAux, + OneDimensionalYokedSurfaceCode, + TwoDimensionalYokedSurfaceCode, +) from .qubits import AQREGateBased, Majorana __all__ = [ @@ -13,5 +18,6 @@ "RoundBasedFactory", "SurfaceCode", "ThreeAux", - "YokedSurfaceCode", + "OneDimensionalYokedSurfaceCode", + "TwoDimensionalYokedSurfaceCode", ] diff --git a/source/pip/qsharp/qre/models/qec/__init__.py b/source/pip/qsharp/qre/models/qec/__init__.py index 588544fb3a..4e4cf816f7 100644 --- a/source/pip/qsharp/qre/models/qec/__init__.py +++ b/source/pip/qsharp/qre/models/qec/__init__.py @@ -3,6 +3,11 @@ from ._surface_code import SurfaceCode from ._three_aux import ThreeAux -from ._yoked import YokedSurfaceCode +from ._yoked import OneDimensionalYokedSurfaceCode, TwoDimensionalYokedSurfaceCode -__all__ = ["SurfaceCode", "ThreeAux", "YokedSurfaceCode"] +__all__ = [ + "SurfaceCode", + "ThreeAux", + "OneDimensionalYokedSurfaceCode", + "TwoDimensionalYokedSurfaceCode", +] diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py index fd6652ddaf..8bb9bf9597 100644 --- a/source/pip/qsharp/qre/models/qec/_yoked.py +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass, KW_ONLY, field -from enum import IntEnum +from dataclasses import dataclass from math import ceil from typing import Generator @@ -13,23 +12,8 @@ from ...property_keys import DISTANCE -class ShapeHeuristic(IntEnum): - """ - The heuristic to determine the shape of the memory qubits with respect to - the number of required rows and columns. - - Attributes: - MIN_AREA: The shape that minimizes the total number of qubits. - SQUARE: The shape that minimizes the difference between the number of rows - and columns. 
- """ - - MIN_AREA = 0 - SQUARE = 1 - - @dataclass -class YokedSurfaceCode(ISATransform): +class OneDimensionalYokedSurfaceCode(ISATransform): """ This class models the Yoked surface code to provide a generic memory instruction based on lattice surgery instructions from a surface code like @@ -53,12 +37,17 @@ class YokedSurfaceCode(ISATransform): codes, [arXiv:2312.04522](https://arxiv.org/abs/2312.04522) """ - crossing_prefactor: float = 0.016 - error_correction_threshold: float = 0.064 - _: KW_ONLY - shape_heuristic: ShapeHeuristic = field( - default=ShapeHeuristic.MIN_AREA, metadata={"domain": list(ShapeHeuristic)} - ) + # NOTE: The crossing_prefactor is relative to that of the underlying surface + # code. That is if the surface code model is p(SC) = + # A*(p(phy)/th(SC))^((d+1)/2), then multiplier for its yoked extension is + # crossing_prefactor*A + crossing_prefactor: float = 8 / 15 + + # NOTE: The threshold is relative to that of the underlying surface code. + # Namely, as the yoking doubles the distance, one would expect the yoked + # surface code to have a threshold of sqrt(th(SC)). However modeling shows + # it falls short of this. 
+ error_correction_threshold: float = 64 / 10 @staticmethod def required_isa() -> ISARequirements: @@ -74,23 +63,21 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non distance = lattice_surgery.get_property(DISTANCE) assert distance is not None - shape_fn = [self._min_area_shape, self._square_shape][self.shape_heuristic] - def space(arity: int) -> int: - a, b = shape_fn(arity) + a, b = self._min_area_shape(arity) return lattice_surgery.expect_space(a * b) space_fn = generic_function(space) def time(arity: int) -> int: - a, b = shape_fn(arity) + a, b = self._min_area_shape(arity) s = lattice_surgery.expect_time(a * b) return s * (8 * distance * (a - 1) + 2 * distance) time_fn = generic_function(time) def error_rate(arity: int) -> float: - a, b = shape_fn(arity) + a, b = self._min_area_shape(arity) rounds = 2 * (a - 2) # logical error rate on a single surface code patch p = lattice_surgery.expect_error_rate(1) @@ -98,7 +85,8 @@ def error_rate(arity: int) -> float: rounds**2 * (a * b) ** 2 * self.crossing_prefactor - * (p / self.error_correction_threshold) ** ((distance + 1) // 2) + * p + * (1 / self.error_correction_threshold) ** ((distance + 1) // 2) ) error_rate_fn = generic_function(error_rate) @@ -117,19 +105,6 @@ def error_rate(arity: int) -> float: ) ) - @staticmethod - def _square_shape(num_qubits: int) -> tuple[int, int]: - """ - Given a number of qubits num_qubits, returns numbers (a + 1) and (b + 2) - such that a * b >= num_qubits and a and b are as close as possible. 
- """ - - a = int(num_qubits**0.5) - while num_qubits % a != 0: - a -= 1 - b = num_qubits // a - return a + 1, b + 2 - @staticmethod def _min_area_shape(num_qubits: int) -> tuple[int, int]: """ @@ -155,3 +130,110 @@ def _min_area_shape(num_qubits: int) -> tuple[int, int]: assert best_a is not None assert best_b is not None return best_a + 1, best_b + 2 + + +@dataclass +class TwoDimensionalYokedSurfaceCode(ISATransform): + """ + This class models the Yoked surface code to provide a generic memory + instruction based on lattice surgery instructions from a surface code like + error correction code. + + Attributes: + crossing_prefactor: float + The prefactor for logical error rate (Default is 0.016) + error_correction_threshold: float + The error correction threshold for the surface code (Default is + 0.064) + + Hyper parameters: + shape_heuristic: ShapeHeuristic + The heuristic to determine the shape of the surface code patch for a + given number of logical qubits. (Default is ShapeHeuristic.MIN_AREA) + + References: + + - Craig Gidney, Michael Newman, Peter Brooks, Cody Jones: Yoked surface + codes, [arXiv:2312.04522](https://arxiv.org/abs/2312.04522) + """ + + # NOTE: The crossing_prefactor is relative to that of the underlying surface + # code. That is if the surface code model is p(SC) = + # A*(p(phy)/th(SC))^((d+1)/2), then multiplier for its yoked extension is + # crossing_prefactor*A + crossing_prefactor: float = 5 / 600 + + # NOTE: The threshold is relative to that of the underlying surface code. + # Namely, as the yoking doubles the distance, one would expect the yoked + # surface code to have a threshold of sqrt(th(SC)). However modeling shows + # it falls short of this. + error_correction_threshold: float = 2500 / 10 + + @staticmethod + def required_isa() -> ISARequirements: + # We require a lattice surgery instruction that also provides the code + # distance as a property. 
This is necessary to compute the time + # and error rate formulas for the provided memory instruction. + return ISARequirements( + constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), + ) + + def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + lattice_surgery = impl_isa[LATTICE_SURGERY] + distance = lattice_surgery.get_property(DISTANCE) + assert distance is not None + + def space(arity: int) -> int: + a, b = self._square_shape(arity) + return lattice_surgery.expect_space(a * b) + + space_fn = generic_function(space) + + def time(arity: int) -> int: + a, b = self._square_shape(arity) + s = lattice_surgery.expect_time(a * b) + return s * (8 * distance * max(a - 2, b - 2) + 2 * distance) + + time_fn = generic_function(time) + + def error_rate(arity: int) -> float: + a, b = self._square_shape(arity) + rounds = 2 * max(a - 3, b - 3) + # logical error rate on a single surface code patch + p = lattice_surgery.expect_error_rate(1) + return ( + rounds**4 + * (a * b) ** 2 + * self.crossing_prefactor + * p + * (1 / self.error_correction_threshold) ** ((distance + 1) // 2) + ) + + error_rate_fn = generic_function(error_rate) + + yield ctx.make_isa( + ctx.add_instruction( + MEMORY, + arity=None, + encoding=LOGICAL, + space=space_fn, + time=time_fn, + error_rate=error_rate_fn, + transform=self, + source=[lattice_surgery], + distance=distance, + ) + ) + + @staticmethod + def _square_shape(num_qubits: int) -> tuple[int, int]: + """ + Given a number of qubits num_qubits, returns numbers (a + 2) and (b + 2) + such that a * b >= num_qubits and a and b are as close as possible. 
+ """ + + a = int(num_qubits**0.5) + while num_qubits % a != 0: + a -= 1 + b = num_qubits // a + return a + 2, b + 2 diff --git a/source/pip/qsharp/qre/property_keys.pyi b/source/pip/qsharp/qre/property_keys.pyi index 17ebdb49af..62f5fd5213 100644 --- a/source/pip/qsharp/qre/property_keys.pyi +++ b/source/pip/qsharp/qre/property_keys.pyi @@ -12,3 +12,4 @@ EVALUATION_TIME: int PHYSICAL_COMPUTE_QUBITS: int PHYSICAL_FACTORY_QUBITS: int PHYSICAL_MEMORY_QUBITS: int +MOLECULE: int diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index ff16b4cde2..73c87da189 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1447,6 +1447,7 @@ fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, + MOLECULE ); m.add_submodule(&property_keys)?; diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/test_qre_models.py index d0e8c33d37..ef03a1eb42 100644 --- a/source/pip/tests/test_qre_models.py +++ b/source/pip/tests/test_qre_models.py @@ -34,7 +34,7 @@ Litinski19Factory, SurfaceCode, ThreeAux, - YokedSurfaceCode, + TwoDimensionalYokedSurfaceCode, ) from qsharp.qre.property_keys import DISTANCE @@ -403,7 +403,7 @@ def _get_lattice_surgery_isa(self, distance=5): def test_provides_memory_instruction(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) assert len(isas) == 1 @@ -411,7 +411,7 @@ def test_provides_memory_instruction(self): def test_memory_is_logical(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -419,7 +419,7 @@ def test_memory_is_logical(self): def test_memory_arity_is_variable(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = 
list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -428,7 +428,7 @@ def test_memory_arity_is_variable(self): def test_space_increases_with_arity(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -439,7 +439,7 @@ def test_space_increases_with_arity(self): def test_time_increases_with_arity(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -450,7 +450,7 @@ def test_time_increases_with_arity(self): def test_error_rate_increases_with_arity(self): ls_isa, ctx = self._get_lattice_surgery_isa() - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -462,7 +462,7 @@ def test_error_rate_increases_with_arity(self): def test_distance_property_propagated(self): d = 7 ls_isa, ctx = self._get_lattice_surgery_isa(distance=d) - ysc = YokedSurfaceCode() + ysc = TwoDimensionalYokedSurfaceCode() isas = list(ysc.provided_isa(ls_isa, ctx)) mem = isas[0][MEMORY] @@ -1008,12 +1008,14 @@ def test_surface_code_with_yoked_surface_code(self): ctx = arch.context() count = 0 - for isa in YokedSurfaceCode.q(source=SurfaceCode.q()).enumerate(ctx): + for isa in TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()).enumerate( + ctx + ): assert MEMORY in isa count += 1 - # 12 distances × 2 shape heuristics = 24 - assert count == 24 + # 12 distances × 1 shape heuristic = 12 + assert count == 12 def test_majorana_three_aux_yoked(self): """Majorana -> ThreeAux -> YokedSurfaceCode pipeline.""" @@ -1021,7 +1023,7 @@ def test_majorana_three_aux_yoked(self): ctx = arch.context() count = 0 - for isa in YokedSurfaceCode.q(source=ThreeAux.q()).enumerate(ctx): + for isa in TwoDimensionalYokedSurfaceCode.q(source=ThreeAux.q()).enumerate(ctx): 
assert MEMORY in isa count += 1 diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs index b031cba828..376d9979f8 100644 --- a/source/qre/src/isa/property_keys.rs +++ b/source/qre/src/isa/property_keys.rs @@ -46,4 +46,5 @@ define_properties! { PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, + MOLECULE, } From 0c86dbb0a92ee858930068fdeb1bc6e595e79b21 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Tue, 17 Mar 2026 18:24:19 +0100 Subject: [PATCH 27/45] Use Rust parallelism for estimation with post processing (#3025) When performing post processing we need to know the updated results after the post_process function on the application is called (in Python). As a result, so far a Python-based parallelism was used for estimation when post processing is in play. This is changed in this PR, by first computing all estimates in parallel (without filtering to Pareto optimal results) and then performing the filtering in Python. This reduces runtime by about 30% on some QRE tests. Further, it allows to use the same Rust-based estimation function which improves code sharing. 
--- source/pip/qsharp/qre/_estimation.py | 77 ++++++++++++++++++++-------- source/pip/qsharp/qre/_qre.pyi | 12 +++++ source/pip/src/qre.rs | 11 ++++ source/qre/src/lib.rs | 2 +- source/qre/src/result.rs | 31 +++++++++++ source/qre/src/trace.rs | 18 ++++++- 6 files changed, 126 insertions(+), 25 deletions(-) diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 1a78225a94..225845654c 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -3,7 +3,6 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import cast, Optional, Callable, Any @@ -78,33 +77,67 @@ def estimate( if post_process: # Enumerate traces with their parameters so we can post-process later - params_and_traces = list(trace_query.enumerate(app_ctx, track_parameters=True)) + params_and_traces = cast( + list[tuple[Any, Trace]], + list(trace_query.enumerate(app_ctx, track_parameters=True)), + ) isas = list(isa_query.enumerate(arch_ctx)) num_traces = len(params_and_traces) num_isas = len(isas) - # Estimate all trace × ISA combinations using Python threads - collection = _EstimationCollection() - - def _estimate_one(params, trace, isa): - result = trace.estimate(isa, max_error) + # Phase 1: Run all estimates in Rust (parallel, fast). + traces_only = [trace for _, trace in params_and_traces] + collection = _estimate_parallel(cast(list[Trace], traces_only), isas, max_error) + successful = collection.successful_estimates + summaries = collection.all_summaries # (trace_idx, isa_idx, qubits, runtime) + + # Phase 2: Learn per-trace runtime multiplier and qubit multiplier from + # one sample each: if post_process changes runtime or qubit count it + # will affect the Pareto optimality, but the changes depend only on the + # trace, not on the ISA. 
+ trace_multipliers: dict[int, tuple[float, float]] = {} + trace_sample_isa: dict[int, int] = {} + for t_idx, i_idx, _q, r in summaries: + if t_idx not in trace_sample_isa: + trace_sample_isa[t_idx] = i_idx + for t_idx, i_idx in trace_sample_isa.items(): + params, trace = params_and_traces[t_idx] + sample = trace.estimate(isas[i_idx], max_error) + if sample is not None: + pre_q = sample.qubits + pre_r = sample.runtime + pp = app_ctx.application.post_process(params, sample) + if pp is not None and pre_r > 0 and pre_q > 0: + trace_multipliers[t_idx] = (pp.qubits / pre_q, pp.runtime / pre_r) + + # Phase 3: Estimate post-pp values and filter to Pareto candidates. + estimated_pp: list[tuple[int, int, int, int]] = [] # (t, i, q, est_r) + for t_idx, i_idx, q, r in summaries: + mult_q, mult_r = trace_multipliers.get(t_idx, (0.0, 0.0)) + est_q = int(q * mult_q) if mult_q > 0 else q + est_r = int(r * mult_r) if mult_r > 0 else r + estimated_pp.append((t_idx, i_idx, est_q, est_r)) + + # Build approximate post-pp Pareto frontier to identify candidates. + estimated_pp.sort(key=lambda x: (x[2], x[3])) # sort by qubits, then runtime + approx_pareto: list[tuple[int, int, int, int]] = [] + min_r = float("inf") + for item in estimated_pp: + if item[3] < min_r: + approx_pareto.append(item) + min_r = item[3] + + # Phase 4: Re-estimate and post-process only the Pareto candidates. 
+ pp_collection = _EstimationCollection() + for t_idx, i_idx, _q, _r in approx_pareto: + params, trace = params_and_traces[t_idx] + result = trace.estimate(isas[i_idx], max_error) if result is not None: - result = app_ctx.application.post_process(params, result) - return result - - successful = 0 - with ThreadPoolExecutor() as executor: - futures = [ - executor.submit(_estimate_one, params, trace, isa) - for params, trace in cast(list[tuple[Any, Trace]], params_and_traces) - for isa in isas - ] - for future in futures: - result = future.result() - if result is not None: - successful += 1 - collection.insert(result) + pp_result = app_ctx.application.post_process(params, result) + if pp_result is not None: + pp_collection.insert(pp_result) + collection = pp_collection else: traces = list(trace_query.enumerate(app_ctx)) isas = list(isa_query.enumerate(arch_ctx)) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 3ad89d9ff4..651adf1280 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -945,6 +945,18 @@ class _EstimationCollection: """ ... + @property + def all_summaries(self) -> list[tuple[int, int, int, int]]: + """ + Returns lightweight summaries of ALL successful estimates as a list + of (trace_index, isa_index, qubits, runtime) tuples. + + Returns: + list[tuple[int, int, int, int]]: List of (trace_index, isa_index, + qubits, runtime) for every successful estimation. + """ + ... + class FactoryResult: """ Represents the result of a factory used in resource estimation. diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 73c87da189..ef8a502ecb 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -760,6 +760,17 @@ impl EstimationCollection { self.0.successful_estimates() } + /// Returns lightweight summaries of ALL successful estimates as a list + /// of (trace index, isa index, qubits, runtime) tuples. 
+ #[getter] + pub fn all_summaries(&self) -> Vec<(usize, usize, u64, u64)> { + self.0 + .all_summaries() + .iter() + .map(|s| (s.trace_index, s.isa_index, s.qubits, s.runtime)) + .collect() + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = EstimationCollectionIterator { diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs index e7c13e01f5..079afd79fb 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -15,7 +15,7 @@ pub use isa::{ ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, VariableArityFunction, }; -pub use result::{EstimationCollection, EstimationResult, FactoryResult}; +pub use result::{EstimationCollection, EstimationResult, FactoryResult, ResultSummary}; mod trace; pub use trace::instruction_ids; pub use trace::instruction_ids::instruction_name; diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs index f8b48c016d..7039ebad62 100644 --- a/source/qre/src/result.rs +++ b/source/qre/src/result.rs @@ -18,6 +18,7 @@ pub struct EstimationResult { factories: FxHashMap, isa: ISA, isa_index: Option, + trace_index: Option, properties: FxHashMap, } @@ -99,6 +100,15 @@ impl EstimationResult { self.isa_index } + pub fn set_trace_index(&mut self, index: usize) { + self.trace_index = Some(index); + } + + #[must_use] + pub fn trace_index(&self) -> Option { + self.trace_index + } + pub fn set_property(&mut self, key: u64, value: Property) { self.properties.insert(key, value); } @@ -155,9 +165,21 @@ impl ParetoItem2D for EstimationResult { } } +/// Lightweight summary of a successful estimation, used to identify +/// post-processing candidates without storing full results. 
+#[derive(Clone, Copy)] +pub struct ResultSummary { + pub trace_index: usize, + pub isa_index: usize, + pub qubits: u64, + pub runtime: u64, +} + #[derive(Default)] pub struct EstimationCollection { frontier: ParetoFrontier2D, + /// Lightweight summaries of ALL successful estimates (not just Pareto). + all_summaries: Vec, total_jobs: usize, successful_estimates: usize, } @@ -185,6 +207,15 @@ impl EstimationCollection { pub fn set_successful_estimates(&mut self, successful_estimates: usize) { self.successful_estimates = successful_estimates; } + + pub fn push_summary(&mut self, summary: ResultSummary) { + self.all_summaries.push(summary); + } + + #[must_use] + pub fn all_summaries(&self) -> &[ResultSummary] { + &self.all_summaries + } } impl Deref for EstimationCollection { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 2ac610be15..d5fa95dd6c 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -1,13 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use std::fmt::{Display, Formatter}; +use std::{ + fmt::{Display, Formatter}, + sync::atomic::AtomicUsize, +}; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use crate::{ Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction, LockedISA, + ResultSummary, property_keys::{PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS}, }; @@ -665,7 +669,7 @@ pub fn estimate_parallel<'a>( // Shared atomic counter acts as a lock-free work queue. Workers call // fetch_add to claim the next job index. 
- let next_job = std::sync::atomic::AtomicUsize::new(0); + let next_job = AtomicUsize::new(0); let mut collection = EstimationCollection::new(); collection.set_total_jobs(total_jobs); @@ -700,6 +704,8 @@ pub fn estimate_parallel<'a>( if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) { estimation.set_isa_index(isa_idx); + estimation.set_trace_index(trace_idx); + local_results.push(estimation); } } @@ -714,6 +720,14 @@ pub fn estimate_parallel<'a>( // Collect results from all workers into the shared collection. let mut successful = 0; for local_results in rx { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } successful += local_results.len(); collection.extend(local_results.into_iter()); } From dcadb5c7dc2324ec6894f699d347ce73c23fdbc5 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 08:02:06 +0100 Subject: [PATCH 28/45] Fixes bucketing logic to distribute rotations to match the right depth (#3028) The previous implementation hit a corner case (the one in the test) which is now fixed with a new bucketing logic. --- source/pip/qsharp/qre/interop/_qsharp.py | 54 ++++++++++++++++++------ source/pip/tests/test_qre.py | 24 +++++++++++ 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index e14b372a9d..cded428266 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -14,6 +14,42 @@ from ..property_keys import EVALUATION_TIME +def _bucketize_rotation_counts( + rotation_count: int, rotation_depth: int +) -> list[tuple[int, int]]: + """ + Returns a list of (count, depth) pairs representing the rotation layers in + the trace. 
+ + The following properties hold for the returned list `result`: + - sum(depth for _, depth in result) == rotation_depth + - sum(count * depth for count, depth in result) == rotation_count + - count > 0 for each (count, _) in result + - count <= qubit_count for each (count, _) in result holds by definition + when rotation_count <= rotation_depth * qubit_count + + Args: + rotation_count: Total number of rotations. + rotation_depth: Total depth of the rotation layers. + + Returns: + A list of (count, depth) pairs, where 'count' is the number of + rotations in a layer and 'depth' is the depth of that layer. + """ + if rotation_depth == 0: + return [] + + base = rotation_count // rotation_depth + extra = rotation_count % rotation_depth + + result: list[tuple[int, int]] = [] + if extra > 0: + result.append((base + 1, extra)) + if rotation_depth - extra > 0: + result.append((base, rotation_depth - extra)) + return result + + def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: start = time.time_ns() @@ -37,21 +73,11 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: rotation_count = counts.get("rotationCount", 0) rotation_depth = counts.get("rotationDepth", rotation_count) - if rotation_count != 0: - if rotation_depth > 1: - rotations_per_layer = rotation_count // (rotation_depth - 1) - else: - rotations_per_layer = 0 - - last_layer = rotation_count - (rotations_per_layer * (rotation_depth - 1)) - - if rotations_per_layer != 0: - block = trace.add_block(repetitions=rotation_depth - 1) - for i in range(rotations_per_layer): + if rotation_count != 0 and rotation_depth != 0: + for count, depth in _bucketize_rotation_counts(rotation_count, rotation_depth): + block = trace.add_block(repetitions=depth) + for i in range(count): block.add_operation(RZ, [i]) - block = trace.add_block() - for i in range(last_layer): - block.add_operation(RZ, [i]) if t_count := counts.get("tCount", 0): block = 
trace.add_block(repetitions=t_count) diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 3b1817f180..68dafbd7eb 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -1470,3 +1470,27 @@ def declare(name, param_types): # -- Run trace_from_qir and verify it succeeds ------------------------- trace = trace_from_qir(simple.ir()) assert trace is not None + + +def test_rotation_buckets(): + from qsharp.qre.interop._qsharp import _bucketize_rotation_counts + + print() + + r_count = 15066 + r_depth = 14756 + q_count = 291 + + result = _bucketize_rotation_counts(r_count, r_depth) + + a_count = 0 + a_depth = 0 + for c, d in result: + print(c, d) + assert c <= q_count + assert c > 0 + a_count += c * d + a_depth += d + + assert a_count == r_count + assert a_depth == r_depth From 317ff6aae17b5bae8aa5bb8469c66cb244cd57c9 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 09:03:06 +0100 Subject: [PATCH 29/45] Graph-based ISA pruning for resource estimation (#3031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an alternative estimation path that builds a **provenance graph** of ISA instructions and prunes suboptimal candidates before forming the Cartesian product, significantly reducing the combinatorial search space. ### Changes **Rust (`source/qre/src/`)** - `isa.rs`: Add `build_pareto_index()` to compute per-instruction-ID Pareto-optimal node sets over (space, time, error). Add `query_satisfying()` to enumerate ISAs from pruned graph nodes. Extract `InstructionConstraint::is_satisfied_by()` from inline logic. - `trace.rs`: Add `estimate_with_graph()` — a new parallel estimator that uses the provenance graph with per-slot dominance pruning to skip combinations dominated by previously successful estimates. Add `Trace::required_instruction_ids()` helper. Add `post_process` flag to `estimate_parallel()` to control summary collection. 
- `result.rs` / `lib.rs`: Expose new types and re-exports. **Python (`source/pip/`)** - `_estimation.py`: Add `use_graph` parameter to `estimate()` (default `True`). When enabled, populates the provenance graph and calls the graph-based estimator instead of the flat enumerator. - `_isa_enumeration.py`: Add `populate()` method to `ISAQuery` and its subclasses to fill the provenance graph without yielding ISA objects. - `_instruction.py`: Add `InstructionSource` utility. - `qre.rs`: Expose `_estimate_with_graph`, `ProvenanceGraph` bindings, and related Python-facing APIs. - `_qre.pyi`: Update type stubs. **Tests** - Add tests verifying graph-based estimation produces Pareto-optimal results consistent with the exhaustive path. ### Trade-offs The graph-based pruning filters ISA instructions by comparing per-instruction space, time, and error independently. Because total qubit counts depend on the interaction between factory space and runtime (copies × factory_space), an instruction dominated on per-instruction metrics can still contribute to a globally Pareto-optimal result. `use_graph=False` can be used when completeness of the Pareto frontier is required. 
--- source/pip/qsharp/qre/_estimation.py | 72 +++- source/pip/qsharp/qre/_instruction.py | 17 + source/pip/qsharp/qre/_isa_enumeration.py | 97 ++++++ source/pip/qsharp/qre/_qre.py | 1 + source/pip/qsharp/qre/_qre.pyi | 99 +++++- source/pip/src/qre.rs | 114 +++++-- source/pip/tests/test_qre.py | 63 ++++ source/qre/src/isa.rs | 326 +++++++++++++++--- source/qre/src/lib.rs | 5 +- source/qre/src/result.rs | 15 + source/qre/src/trace.rs | 383 +++++++++++++++++++++- 11 files changed, 1099 insertions(+), 93 deletions(-) diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 225845654c..4a4b75366d 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -12,6 +12,7 @@ from ._architecture import Architecture from ._qre import ( _estimate_parallel, + _estimate_with_graph, _EstimationCollection, Trace, FactoryResult, @@ -35,6 +36,7 @@ def estimate( *, max_error: float = 1.0, post_process: bool = False, + use_graph: bool = True, name: Optional[str] = None, ) -> EstimationTable: """ @@ -52,6 +54,20 @@ def estimate( The collection only contains the results that are optimal with respect to the total number of qubits and the total runtime. + Note: + The pruning strategy used when `use_graph` is set to True (default) + filters ISA instructions by comparing their per-instruction space, time, + and error independently. However, the total qubit count of a result + depends on the interaction between factory space and runtime: + ``factory_qubits = copies × factory_space`` where copies are determined + by ``count.div_ceil(runtime / factory_time)``. Because of this, an ISA + instruction that is dominated on per-instruction metrics can still + contribute to a globally Pareto-optimal result (e.g., a factory with + higher time may need fewer copies, leading to fewer total qubits). As a + consequence, `use_graph=True` may miss some results that + `use_graph=False` would find. 
Use `use_graph=False` when completeness of + the Pareto frontier is required. + Args: application (Application): The quantum application to be estimated. architecture (Architecture): The target quantum architecture. @@ -62,6 +78,10 @@ def estimate( post_process (bool): If True, use the Python-threaded estimation path (intended for future post-processing logic). If False (default), use the Rust parallel estimation path. + use_graph (bool): If True (default), use the Rust estimation path that + builds a graph of ISAs and prunes suboptimal ISAs during estimation. + If False, use the Rust estimation path that does not perform any + pruning and simply enumerates all ISAs for each trace. name (Optional[str]): An optional name for the estimation. If give, this will be added as a first column to the results table for all entries. @@ -81,14 +101,31 @@ def estimate( list[tuple[Any, Trace]], list(trace_query.enumerate(app_ctx, track_parameters=True)), ) - isas = list(isa_query.enumerate(arch_ctx)) - num_traces = len(params_and_traces) - num_isas = len(isas) # Phase 1: Run all estimates in Rust (parallel, fast). 
traces_only = [trace for _, trace in params_and_traces] - collection = _estimate_parallel(cast(list[Trace], traces_only), isas, max_error) + + if use_graph: + isa_query.populate(arch_ctx) + arch_ctx._provenance.build_pareto_index() + + num_isas = arch_ctx._provenance.total_isa_count() + + collection = _estimate_with_graph( + cast(list[Trace], traces_only), arch_ctx._provenance, max_error, True + ) + isas = collection.isas + else: + isas = list(isa_query.enumerate(arch_ctx)) + + num_isas = len(isas) + + collection = _estimate_parallel( + cast(list[Trace], traces_only), isas, max_error, True + ) + + total_jobs = collection.total_jobs successful = collection.successful_estimates summaries = collection.all_summaries # (trace_idx, isa_idx, qubits, runtime) @@ -140,13 +177,28 @@ def estimate( collection = pp_collection else: traces = list(trace_query.enumerate(app_ctx)) - isas = list(isa_query.enumerate(arch_ctx)) - num_traces = len(traces) - num_isas = len(isas) - # Use the Rust parallel estimation path - collection = _estimate_parallel(cast(list[Trace], traces), isas, max_error) + if use_graph: + isa_query.populate(arch_ctx) + arch_ctx._provenance.build_pareto_index() + + num_isas = arch_ctx._provenance.total_isa_count() + + collection = _estimate_with_graph( + cast(list[Trace], traces), arch_ctx._provenance, max_error, False + ) + else: + isas = list(isa_query.enumerate(arch_ctx)) + + num_isas = len(isas) + + # Use the Rust parallel estimation path + collection = _estimate_parallel( + cast(list[Trace], traces), isas, max_error, False + ) + + total_jobs = collection.total_jobs successful = collection.successful_estimates # Post-process the results and add them to a results table @@ -170,7 +222,7 @@ def estimate( # Fill in the stats for this estimation run table.stats.num_traces = num_traces table.stats.num_isas = num_isas - table.stats.total_jobs = num_traces * num_isas + table.stats.total_jobs = total_jobs table.stats.successful_estimates = successful 
table.stats.pareto_results = len(collection) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index d6c25b704f..56645fd2aa 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -267,6 +267,23 @@ def __getitem__(self, id: int) -> _InstructionSourceNodeReference: raise KeyError(f"Instruction ID {id} not found in instruction source graph.") + def __contains__(self, id: int) -> bool: + """ + Checks if there is an instruction source root node with the given + instruction ID. + + Args: + id (int): The instruction ID to search for. + + Returns: + bool: True if a node with the given instruction ID exists, False otherwise. + """ + for root in self.roots: + if self.nodes[root].instruction.id == id: + return True + + return False + def get( self, id: int, default: Optional[_InstructionSourceNodeReference] = None ) -> Optional[_InstructionSourceNodeReference]: diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 6298ffd847..5cbb9fa187 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -10,6 +10,7 @@ from typing import Generator from ._architecture import _Context +from ._enumeration import _enumerate_instances from ._qre import ISA @@ -37,6 +38,29 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ pass + def populate(self, ctx: _Context) -> int: + """ + Populates the provenance graph with instructions from this node. + + Unlike `enumerate`, this does not yield ISA objects. Each transform + queries the graph for Pareto-optimal instructions matching its + requirements, and adds produced instructions directly to the graph. + + Args: + ctx (_Context): The enumeration context whose provenance graph + will be populated. + + Returns: + int: The starting node index of the instructions contributed by + this subtree. 
Used by consumers to scope graph queries to only + see their source's nodes. + """ + # Default implementation: consume enumerate for its side effects + start = ctx._provenance.raw_node_count() + for _ in self.enumerate(ctx): + pass + return start + def __add__(self, other: ISAQuery) -> _SumNode: """ Performs a union of two enumeration nodes. @@ -146,6 +170,14 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ yield ctx._isa + def populate(self, ctx: _Context) -> int: + """Architecture ISA is already in the graph from ``_Context.__init__``. + + Returns: + int: 1, since architecture nodes start at index 1. + """ + return 1 + # Singleton instance for convenience ISA_ROOT = RootNode() @@ -184,6 +216,31 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa in self.source.enumerate(ctx): yield from self.component.enumerate_isas(isa, ctx, **self.kwargs) + def populate(self, ctx: _Context) -> int: + """ + Populates the graph by querying matching instructions. + + Runs the source first to ensure dependency instructions are in + the graph, then queries the graph for all instructions matching + this component's requirements within the source's node range. + For each matching ISA × each hyperparameter instance, calls + ``provided_isa`` to add new instructions to the graph. + + Returns: + int: The starting node index of this component's own additions. 
+ """ + source_start = self.source.populate(ctx) + impl_isas = ctx._provenance.query_satisfying( + self.component.required_isa(), min_node_idx=source_start + ) + own_start = ctx._provenance.raw_node_count() + for instance in _enumerate_instances(self.component, **self.kwargs): + ctx._transforms[id(instance)] = instance + for impl_isa in impl_isas: + for _ in instance.provided_isa(impl_isa, ctx): + pass + return own_start + @dataclass class _ProductNode(ISAQuery): @@ -212,6 +269,17 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa_tuple in itertools.product(*source_generators) ) + def populate(self, ctx: _Context) -> int: + """Populates the graph from each source sequentially (no cross product). + + Returns: + int: The starting node index before any source populated. + """ + first = ctx._provenance.raw_node_count() + for source in self.sources: + source.populate(ctx) + return first + @dataclass class _SumNode(ISAQuery): @@ -237,6 +305,17 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for source in self.sources: yield from source.enumerate(ctx) + def populate(self, ctx: _Context) -> int: + """Populates the graph from each source sequentially. + + Returns: + int: The starting node index before any source populated. + """ + first = ctx._provenance.raw_node_count() + for source in self.sources: + source.populate(ctx) + return first + @dataclass class ISARefNode(ISAQuery): @@ -268,6 +347,14 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: raise ValueError(f"Undefined component reference: '{self.name}'") yield ctx._bindings[self.name] + def populate(self, ctx: _Context) -> int: + """Instructions already in graph from the bound component. + + Returns: + int: 1, since bound component nodes start at index 1. 
+ """ + return 1 + @dataclass class _BindingNode(ISAQuery): @@ -329,3 +416,13 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: # Add binding to context and enumerate child node new_ctx = ctx._with_binding(self.name, isa) yield from self.node.enumerate(new_ctx) + + def populate(self, ctx: _Context) -> int: + """Populates the graph from both the component and the child node. + + Returns: + int: The starting node index of the component's additions. + """ + comp_start = self.component.populate(ctx) + self.node.populate(ctx) + return comp_start diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index 46870b4aae..a67e320218 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -12,6 +12,7 @@ Constraint, ConstraintBound, _estimate_parallel, + _estimate_with_graph, _EstimationCollection, EstimationResult, FactoryResult, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 651adf1280..610a901431 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -751,6 +751,63 @@ class _ProvenanceGraph: """ ... + def build_pareto_index(self) -> None: + """ + Builds the per-instruction-ID Pareto index. + + For each instruction ID, retains only the Pareto-optimal nodes w.r.t. + (space, time, error_rate) evaluated at arity 1. Must be called after + all nodes have been added. + """ + ... + + def query_satisfying( + self, + requirements: ISARequirements, + min_node_idx: Optional[int] = None, + ) -> list[ISA]: + """ + Returns ISAs formed from Pareto-optimal graph nodes satisfying the + given requirements. + + For each constraint in requirements, selects matching Pareto-optimal + nodes. Returns the Cartesian product of per-constraint matches, + augmented with one representative node per unconstrained instruction + ID. + + Must be called after ``build_pareto_index``. + + Args: + requirements: The ISA requirements to satisfy. 
+ min_node_idx: If provided, only consider Pareto nodes at or above + this index for constrained groups. + + Returns: + list[ISA]: ISAs formed from matching Pareto-optimal nodes. + """ + ... + + def raw_node_count(self) -> int: + """ + Returns the raw node count (including the sentinel at index 0). + + Returns: + int: The number of nodes in the graph. + """ + ... + + def total_isa_count(self) -> int: + """ + Returns the total number of ISAs that can be formed from Pareto-optimal + nodes. + + Requires ``build_pareto_index`` to have been called. + + Returns: + int: The total number of ISAs that can be formed. + """ + ... + class EstimationResult: """ Represents the result of a resource estimation. @@ -957,6 +1014,16 @@ class _EstimationCollection: """ ... + @property + def isas(self) -> list[ISA]: + """ + Returns the list of ISAs for which estimates were performed. + + Returns: + list[ISA]: The list of ISAs. + """ + ... + class FactoryResult: """ Represents the result of a factory used in resource estimation. @@ -1370,7 +1437,10 @@ class InstructionFrontier: ... def _estimate_parallel( - traces: list[Trace], isas: list[ISA], max_error: float = 1.0 + traces: list[Trace], + isas: list[ISA], + max_error: float = 1.0, + post_process: bool = False, ) -> _EstimationCollection: """ Estimates resources for multiple traces and ISAs in parallel. @@ -1379,6 +1449,33 @@ def _estimate_parallel( traces (list[Trace]): The list of traces. isas (list[ISA]): The list of ISAs. max_error (float): The maximum allowed error. The default is 1.0. + post_process (bool): If True, computes auxiliary data such as result + summaries needed for post-processing after estimation. + + Returns: + _EstimationCollection: The estimation collection. + """ + ... + +def _estimate_with_graph( + traces: list[Trace], + graph: _ProvenanceGraph, + max_error: float = 1.0, + post_process: bool = False, +) -> _EstimationCollection: + """ + Estimates resources using a Pareto-filtered provenance graph. 
+ + Instead of forming the full Cartesian product of ISAs × traces, this + function enumerates per-trace instruction combinations from the + Pareto-optimal subsets in the frozen graph. + + Args: + traces (list[Trace]): The list of traces to estimate. + graph (_ProvenanceGraph): The provenance graph to use for estimation. + max_error (float): The maximum allowed error. The default is 1.0. + post_process (bool): If True, computes auxiliary data such as result + summaries and ISAs needed for post-processing after estimation. Returns: _EstimationCollection: The estimation collection. diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index ef8a502ecb..6e69218f9a 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -37,6 +37,7 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; m.add_function(wrap_pyfunction!(generic_function, m)?)?; m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; + m.add_function(wrap_pyfunction!(estimate_with_graph, m)?)?; m.add_function(wrap_pyfunction!(binom_ppf, m)?)?; m.add_function(wrap_pyfunction!(float_to_bits, m)?)?; m.add_function(wrap_pyfunction!(float_from_bits, m)?)?; @@ -620,6 +621,59 @@ impl ProvenanceGraph { } Ok(ISA(isa)) } + + /// Builds the per-instruction-ID Pareto index. + /// + /// Must be called after all nodes have been added. For each instruction + /// ID, retains only the Pareto-optimal nodes w.r.t. (space, time, + /// error rate) evaluated at arity 1. + pub fn build_pareto_index(&self) -> PyResult<()> { + self.0 + .write() + .map_err(poisoned_lock_err)? + .build_pareto_index(); + Ok(()) + } + + /// Returns the raw node count (including the sentinel at index 0). + pub fn raw_node_count(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.raw_node_count()) + } + + /// Computes an upper bound on the possible ISAs that can be formed from + /// this graph. 
+ /// + /// Must be called after `build_pareto_index`. + pub fn total_isa_count(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.total_isa_count()) + } + + /// Returns ISAs formed from Pareto-optimal graph nodes satisfying the + /// given requirements. + /// + /// For each constraint in `requirements`, selects matching Pareto-optimal + /// nodes. Returns the Cartesian product of per-constraint matches, + /// augmented with one representative node per unconstrained instruction + /// ID. + /// + /// When ``min_node_idx`` is provided, only Pareto nodes at or above + /// that index are considered for constrained groups (useful for scoping + /// queries to a subset of the graph). + /// + /// Must be called after `build_pareto_index`. + #[pyo3(signature = (requirements, min_node_idx=None))] + pub fn query_satisfying( + &self, + requirements: &ISARequirements, + min_node_idx: Option, + ) -> PyResult> { + let graph = self.0.read().map_err(poisoned_lock_err)?; + Ok(graph + .query_satisfying(&self.0, &requirements.0, min_node_idx) + .into_iter() + .map(ISA) + .collect()) + } } #[pyclass(name = "_IntFunction")] @@ -771,6 +825,13 @@ impl EstimationCollection { .collect() } + /// Returns the set of ISAs for which this collection contains successful + /// estimates. 
+ #[getter] + pub fn isas(&self) -> Vec { + self.0.isas().iter().cloned().map(ISA).collect() + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = EstimationCollectionIterator { @@ -1265,12 +1326,13 @@ impl InstructionFrontierIterator { } #[allow(clippy::needless_pass_by_value)] -#[pyfunction(name = "_estimate_parallel", signature = (traces, isas, max_error = 1.0))] +#[pyfunction(name = "_estimate_parallel", signature = (traces, isas, max_error = 1.0, post_process = false))] pub fn estimate_parallel( py: Python<'_>, traces: Vec>, isas: Vec>, max_error: f64, + post_process: bool, ) -> EstimationCollection { let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); let isas: Vec<_> = isas.iter().map(|i| &i.0).collect(); @@ -1281,41 +1343,43 @@ pub fn estimate_parallel( // If the calling thread holds the GIL while blocked in // std::thread::scope, the worker threads deadlock. let collection = release_gil(py, || { - qre::estimate_parallel(&traces, &isas, Some(max_error)) + qre::estimate_parallel(&traces, &isas, Some(max_error), post_process) }); EstimationCollection(collection) } +#[allow(clippy::needless_pass_by_value)] +#[pyfunction(name = "_estimate_with_graph", signature = (traces, graph, max_error = 1.0, post_process = false))] +pub fn estimate_with_graph( + py: Python<'_>, + traces: Vec>, + graph: &ProvenanceGraph, + max_error: f64, + post_process: bool, +) -> PyResult { + let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); + + let collection = release_gil(py, || { + qre::estimate_with_graph(&traces, &graph.0, Some(max_error), post_process) + }); + Ok(EstimationCollection(collection)) +} + /// Releases the GIL for the duration of the closure `f`, allowing other -/// threads to acquire it. A RAII guard ensures the thread state is restored -/// even if `f` panics. -/// -/// # Safety +/// threads to acquire it. 
Delegates to `py.detach()` so that pyo3's internal +/// attach-count is properly reset; this ensures that any `Python::attach` +/// calls inside `f` (e.g. from `generic_function` callbacks) will correctly +/// call `PyGILState_Ensure` to re-acquire the GIL. /// /// The caller must ensure that no `Bound<'_, _>` or `Python<'_>` references /// are used inside `f`. GIL-independent `Py` handles are fine because /// they re-acquire the GIL via `Python::attach` when needed. -/// -/// We cannot use `py.allow_threads` here because the captured data -/// (`&qre::ISA`) transitively contains `Arc` whose -/// trait object does not carry the `Ungil` auto-trait bound. -fn release_gil(_py: Python<'_>, f: F) -> R +fn release_gil(py: Python<'_>, f: F) -> R where - F: FnOnce() -> R, + F: FnOnce() -> R + Send, + R: Send, { - struct RestoreGuard(*mut pyo3::ffi::PyThreadState); - - impl Drop for RestoreGuard { - fn drop(&mut self) { - // SAFETY: called on the same thread that saved the state. - unsafe { pyo3::ffi::PyEval_RestoreThread(self.0) }; - } - } - - // SAFETY: we hold the GIL (proven by the `_py` token) and release it - // here so that worker threads can acquire it for Python callbacks. 
- let _guard = RestoreGuard(unsafe { pyo3::ffi::PyEval_SaveThread() }); - f() + py.detach(f) } #[pyfunction(name = "_binom_ppf")] diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 68dafbd7eb..34a70dfec4 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -5,10 +5,12 @@ from enum import Enum from pathlib import Path from typing import cast, Generator +import os import pytest import pandas as pd import qsharp +from qsharp.estimator import LogicalCounts from qsharp.qre import ( Application, ISA, @@ -29,6 +31,8 @@ from qsharp.qre.models import ( SurfaceCode, AQREGateBased, + RoundBasedFactory, + TwoDimensionalYokedSurfaceCode, ) from qsharp.qre.interop import trace_from_qir from qsharp.qre._architecture import _Context, _make_instruction @@ -1472,6 +1476,65 @@ def declare(name, param_types): assert trace is not None +@pytest.mark.skipif( + "SLOW_TESTS" not in os.environ, + reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", +) +@pytest.mark.parametrize( + "post_process, use_graph", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_estimation_methods(post_process, use_graph): + counts = LogicalCounts( + { + "numQubits": 1000, + "tCount": 1_500_000, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 1_000_000_000, + "ccixCount": 0, + "measurementCount": 25_000_000, + "numComputeQubits": 200, + "readFromMemoryCount": 30_000_000, + "writeToMemoryCount": 30_000_000, + } + ) + + trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) + isa_query = ( + SurfaceCode.q() + * RoundBasedFactory.q() + * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) + ) + + app = QSharpApplication(counts) + arch = AQREGateBased(gate_time=50, measurement_time=100) + + results = estimate( + app, + arch, + isa_query, + trace_query, + max_error=1 / 3, + post_process=post_process, + use_graph=use_graph, + ) + results.add_factory_summary_column() + + assert 
[(result.qubits, result.runtime) for result in results] == [ + (238707, 23997050000000), + (240407, 11998525000000), + ] + + print() + print(results.stats) + + def test_rotation_buckets(): from qsharp.qre.interop._qsharp import _bucketize_rotation_counts diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 4ed2e4d5b5..b35d364c45 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -11,7 +11,7 @@ use num_traits::FromPrimitive; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; -use crate::trace::instruction_ids::instruction_name; +use crate::{ParetoFrontier3D, trace::instruction_ids::instruction_name}; pub mod property_keys; @@ -137,50 +137,9 @@ impl ISA { let instruction = graph.instruction(node_idx); - if instruction.encoding != constraint.encoding { + if !constraint.is_satisfied_by(instruction) { return false; } - - match &instruction.metrics { - Metrics::FixedArity { - arity, error_rate, .. - } => { - // Constraint requires variable arity for this instruction - let Some(constraint_arity) = constraint.arity else { - return false; - }; - - // Arity must match - if *arity != constraint_arity { - return false; - } - - // Error rate constraint must be satisfied - if let Some(ref bound) = constraint.error_rate_fn - && !bound.evaluate(error_rate) - { - return false; - } - } - - Metrics::VariableArity { error_rate_fn, .. 
} => { - // If an arity and error rate constraint is specified, it - // must be satisfied - if let (Some(constraint_arity), Some(ref bound)) = - (constraint.arity, constraint.error_rate_fn) - && !bound.evaluate(&error_rate_fn.evaluate(constraint_arity)) - { - return false; - } - } - } - - // Check that all required properties are present in the instruction - for prop in &constraint.properties { - if !instruction.has_property(prop) { - return false; - } - } } true } @@ -517,6 +476,54 @@ impl InstructionConstraint { pub fn properties(&self) -> &FxHashSet { &self.properties } + + /// Returns the instruction ID this constraint applies to. + #[must_use] + pub fn id(&self) -> u64 { + self.id + } + + /// Checks whether a given instruction satisfies this constraint. + #[must_use] + pub fn is_satisfied_by(&self, instruction: &Instruction) -> bool { + if instruction.encoding != self.encoding { + return false; + } + + match &instruction.metrics { + Metrics::FixedArity { + arity, error_rate, .. + } => { + // Constraint requires variable arity but instruction is fixed + let Some(constraint_arity) = self.arity else { + return false; + }; + if *arity != constraint_arity { + return false; + } + if let Some(ref bound) = self.error_rate_fn + && !bound.evaluate(error_rate) + { + return false; + } + } + Metrics::VariableArity { error_rate_fn, .. } => { + if let (Some(constraint_arity), Some(bound)) = (self.arity, &self.error_rate_fn) + && !bound.evaluate(&error_rate_fn.evaluate(constraint_arity)) + { + return false; + } + } + } + + for prop in &self.properties { + if !instruction.has_property(prop) { + return false; + } + } + + true + } } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -658,6 +665,9 @@ pub struct ProvenanceGraph { // children of node i are located at children[offset..offset+num_children] // in the children vector. children: Vec, + // Per-instruction-ID index of Pareto-optimal node indices. 
+ // Built by `build_pareto_index()` after all nodes have been added. + pareto_index: FxHashMap>, } impl Default for ProvenanceGraph { @@ -668,10 +678,35 @@ impl Default for ProvenanceGraph { ProvenanceGraph { nodes: vec![empty], children: Vec::new(), + pareto_index: FxHashMap::default(), } } } +/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. +struct InstructionParetoItem { + node_index: usize, + space: u64, + time: u64, + error: f64, +} + +impl crate::ParetoItem3D for InstructionParetoItem { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> u64 { + self.space + } + fn objective2(&self) -> u64 { + self.time + } + fn objective3(&self) -> f64 { + self.error + } +} + impl ProvenanceGraph { #[must_use] pub fn new() -> Self { @@ -723,6 +758,211 @@ impl ProvenanceGraph { pub fn num_edges(&self) -> usize { self.children.len() } + + /// Builds the per-instruction-ID Pareto index. + /// + /// For each instruction ID in the graph, collects all nodes and retains + /// only the Pareto-optimal subset with respect to (space, time, `error_rate`) + /// evaluated at arity 1. Instructions with different encodings or + /// properties are never in competition. + /// + /// Must be called after all nodes have been added. 
+ pub fn build_pareto_index(&mut self) { + // Group node indices by (instruction_id, encoding, properties) + let mut groups: FxHashMap> = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let instr = &self.nodes[idx].instruction; + groups.entry(instr.id).or_default().push(idx); + } + + let mut pareto_index = FxHashMap::default(); + for (id, node_indices) in groups { + // Sub-partition by encoding and property keys to avoid comparing + // incompatible instructions (Risk R2 mitigation) + #[allow(clippy::type_complexity)] + let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec> = + FxHashMap::default(); + for &idx in &node_indices { + let instr = &self.nodes[idx].instruction; + let mut prop_vec: Vec<(u64, u64)> = instr + .properties + .as_ref() + .map(|p| { + let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect(); + v.sort_unstable(); + v + }) + .unwrap_or_default(); + prop_vec.sort_unstable(); + sub_groups + .entry((instr.encoding, prop_vec)) + .or_default() + .push(idx); + } + + let mut pareto_nodes = Vec::new(); + for (_key, indices) in sub_groups { + let items: Vec = indices + .iter() + .filter_map(|&idx| { + let instr = &self.nodes[idx].instruction; + let space = instr.space(Some(1))?; + let time = instr.time(Some(1))?; + let error = instr.error_rate(Some(1))?; + Some(InstructionParetoItem { + node_index: idx, + space, + time, + error, + }) + }) + .collect(); + + let frontier: ParetoFrontier3D = items.into_iter().collect(); + pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index)); + } + + pareto_index.insert(id, pareto_nodes); + } + + self.pareto_index = pareto_index; + } + + /// Returns the Pareto-optimal node indices for a given instruction ID. + #[must_use] + pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> { + self.pareto_index.get(&instruction_id).map(Vec::as_slice) + } + + /// Returns all instruction IDs that have Pareto-optimal entries. 
+ #[must_use] + pub fn pareto_instruction_ids(&self) -> Vec { + self.pareto_index.keys().copied().collect() + } + + /// Returns the raw node count (including the sentinel at index 0). + #[must_use] + pub fn raw_node_count(&self) -> usize { + self.nodes.len() + } + + /// Returns the total number of ISAs that can be formed from Pareto-optimal + /// nodes. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn total_isa_count(&self) -> usize { + self.pareto_index.values().map(Vec::len).product() + } + + /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given + /// requirements. + /// + /// For each constraint, selects matching Pareto-optimal nodes. Produces + /// the Cartesian product of per-constraint match sets, each augmented + /// with one representative node per unconstrained instruction ID (so + /// that returned ISAs contain entries for all instruction types in the + /// graph). + /// + /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n + /// are considered for constrained groups. Unconstrained "extra" nodes + /// are not filtered since they serve only as default placeholders. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn query_satisfying( + &self, + graph_arc: &Arc>, + requirements: &ISARequirements, + min_node_idx: Option, + ) -> Vec { + let min_idx = min_node_idx.unwrap_or(0); + + let mut constrained_groups: Vec> = Vec::new(); + let mut constrained_ids: FxHashSet = FxHashSet::default(); + + for constraint in requirements.constraints.values() { + constrained_ids.insert(constraint.id()); + + // When a node range is specified, scan ALL nodes in the range + // instead of using the global Pareto index. The global index + // may have pruned nodes from this range as duplicates of + // earlier, equivalent nodes outside the range. 
+ let matching: Vec<(u64, usize)> = if min_idx > 0 { + (min_idx..self.nodes.len()) + .filter(|&node_idx| { + let instr = &self.nodes[node_idx].instruction; + instr.id == constraint.id() && constraint.is_satisfied_by(instr) + }) + .map(|node_idx| (constraint.id(), node_idx)) + .collect() + } else { + let Some(pareto) = self.pareto_index.get(&constraint.id()) else { + return Vec::new(); + }; + pareto + .iter() + .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx))) + .map(|&node_idx| (constraint.id(), node_idx)) + .collect() + }; + + if matching.is_empty() { + return Vec::new(); + } + constrained_groups.push(matching); + } + + // One representative node per unconstrained instruction ID. + // When a Pareto index is available, use it; otherwise scan all + // nodes (this path is used during populate() before the index + // is built). + let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() { + let mut seen: FxHashMap = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let id = self.nodes[idx].instruction.id; + if !constrained_ids.contains(&id) { + seen.entry(id).or_insert(idx); + } + } + seen.into_iter().collect() + } else { + self.pareto_index + .iter() + .filter(|(id, _)| !constrained_ids.contains(id)) + .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n))) + .collect() + }; + + // Cartesian product of constrained groups + let mut combinations: Vec> = vec![Vec::new()]; + for group in &constrained_groups { + let mut next = Vec::with_capacity(combinations.len() * group.len()); + for combo in &combinations { + for &item in group { + let mut extended = combo.clone(); + extended.push(item); + next.push(extended); + } + } + combinations = next; + } + + // Build ISAs from selections + combinations + .into_iter() + .map(|mut combo| { + combo.extend(extra_nodes.iter().copied()); + let mut isa = ISA::with_graph(Arc::clone(graph_arc)); + for (id, node_idx) in combo { + isa.add_node(id, node_idx); + } + isa + }) + .collect() + } 
} struct ProvenanceNode { diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs index 079afd79fb..b22334f6da 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -19,7 +19,10 @@ pub use result::{EstimationCollection, EstimationResult, FactoryResult, ResultSu mod trace; pub use trace::instruction_ids; pub use trace::instruction_ids::instruction_name; -pub use trace::{Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel}; +pub use trace::{ + Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel, + estimate_with_graph, +}; mod utils; pub use utils::{binom_ppf, float_from_bits, float_to_bits}; diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs index 7039ebad62..208195531b 100644 --- a/source/qre/src/result.rs +++ b/source/qre/src/result.rs @@ -182,6 +182,7 @@ pub struct EstimationCollection { all_summaries: Vec, total_jobs: usize, successful_estimates: usize, + isas: Vec, } impl EstimationCollection { @@ -216,6 +217,20 @@ impl EstimationCollection { pub fn all_summaries(&self) -> &[ResultSummary] { &self.all_summaries } + + pub fn push_isa(&mut self, isa: ISA) -> usize { + self.isas.push(isa); + self.isas.len() - 1 + } + + pub fn set_isas(&mut self, isas: Vec) { + self.isas = isas; + } + + #[must_use] + pub fn isas(&self) -> &[ISA] { + &self.isas + } } impl Deref for EstimationCollection { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index d5fa95dd6c..21b44cd6ad 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -2,8 +2,12 @@ // Licensed under the MIT License. 
use std::{ + collections::hash_map::DefaultHasher, fmt::{Display, Formatter}, - sync::atomic::AtomicUsize, + hash::{Hash, Hasher}, + iter::repeat_with, + sync::{Arc, RwLock, atomic::AtomicUsize}, + vec, }; use rustc_hash::{FxHashMap, FxHashSet}; @@ -11,7 +15,7 @@ use serde::{Deserialize, Serialize}; use crate::{ Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction, LockedISA, - ResultSummary, + ProvenanceGraph, ResultSummary, property_keys::{PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS}, }; @@ -145,6 +149,31 @@ impl Trace { TraceIterator::new(&self.block) } + /// Returns the set of used instruction IDs in the trace including their volume + #[must_use] + pub fn required_instruction_ids(&self) -> FxHashMap { + let mut ids = FxHashMap::default(); + for (gate, mult) in self.deep_iter() { + let arity = gate.qubits.len() as u64; + ids.entry(gate.id) + .and_modify(|c| *c += mult * arity) + .or_insert(mult * (gate.qubits.len() as u64)); + } + if let Some(ref rs) = self.resource_states { + for (res_id, count) in rs { + ids.entry(*res_id) + .and_modify(|c| *c += *count) + .or_insert(*count); + } + } + if let Some(memory_qubits) = self.memory_qubits { + ids.entry(instruction_ids::MEMORY) + .and_modify(|c| *c += memory_qubits) + .or_insert(memory_qubits); + } + ids + } + #[must_use] pub fn depth(&self) -> u64 { self.block.depth() @@ -327,7 +356,11 @@ impl Display for Trace { } if let Some(resource_states) = &self.resource_states { for (res_id, amount) in resource_states { - writeln!(f, "@resource_state({res_id}, {amount})")?; + writeln!( + f, + "@resource_state({}, {amount})", + instruction_name(*res_id).unwrap_or("??") + )?; } } write!(f, "{}", self.block) @@ -390,7 +423,27 @@ impl Block { match op { Operation::GateOperation(Gate { id, qubits, params }) => { let name = instruction_name(*id).unwrap_or("??"); - writeln!(f, "{indent_str} {name}({params:?})({qubits:?})")?; + write!(f, "{indent_str} {name}")?; + if 
!params.is_empty() { + write!( + f, + "({})", + params + .iter() + .map(f64::to_string) + .collect::>() + .join(", ") + )?; + } + writeln!( + f, + "({})", + qubits + .iter() + .map(u64::to_string) + .collect::>() + .join(", ") + )?; } Operation::BlockOperation(b) => { b.write(f, indent + 2)?; @@ -420,7 +473,7 @@ impl Block { let duration = match duration_fn { Some(f) => f(gate)?, - None => 1, + _ => 1, }; let end_time = start_time + duration; @@ -508,7 +561,7 @@ impl<'a> Iterator for TraceIterator<'a> { self.stack.push((block.operations.iter(), new_multiplier)); } }, - None => { + _ => { self.stack.pop(); } } @@ -663,6 +716,7 @@ pub fn estimate_parallel<'a>( traces: &[&'a Trace], isas: &[&'a ISA], max_error: Option, + post_process: bool, ) -> EstimationCollection { let total_jobs = traces.len() * isas.len(); let num_isas = isas.len(); @@ -720,13 +774,15 @@ pub fn estimate_parallel<'a>( // Collect results from all workers into the shared collection. let mut successful = 0; for local_results in rx { - for result in &local_results { - collection.push_summary(ResultSummary { - trace_index: result.trace_index().unwrap_or(0), - isa_index: result.isa_index().unwrap_or(0), - qubits: result.qubits(), - runtime: result.runtime(), - }); + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } } successful += local_results.len(); collection.extend(local_results.into_iter()); @@ -744,3 +800,304 @@ pub fn estimate_parallel<'a>( collection } + +/// A single entry in a combination of instruction choices for estimation. +#[derive(Clone, Copy, Hash, Eq, PartialEq)] +struct CombinationEntry { + instruction_id: u64, + node_index: usize, + space: u64, + time: u64, +} + +/// Per-slot pruning witnesses: maps a context hash to the `(space, time)` +/// pairs observed in successful estimations. 
+type SlotWitnesses = RwLock>>; + +/// Computes a hash of the combination context (all slots except the excluded +/// one). Two combinations that agree on every slot except `exclude_idx` +/// produce the same context hash. +fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 { + let mut hasher = DefaultHasher::new(); + for (i, entry) in combination.iter().enumerate() { + if i != exclude_idx { + entry.instruction_id.hash(&mut hasher); + entry.node_index.hash(&mut hasher); + } + } + hasher.finish() +} + +/// Checks whether a combination is dominated by a previously successful one. +/// +/// A combination is prunable if, for any instruction slot, there exists a +/// successful combination with the same instructions in all other slots and +/// an instruction at that slot with `space <=` and `time <=`. +fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let map = trace_pruning[slot_idx] + .read() + .expect("Pruning lock poisoned"); + if map.get(&ctx_hash).is_some_and(|w| { + w.iter() + .any(|&(ws, wt)| ws <= entry.space && wt <= entry.time) + }) { + return true; + } + } + false +} + +/// Records a successful estimation as a pruning witness for future +/// combinations. 
+fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let mut map = trace_pruning[slot_idx] + .write() + .expect("Pruning lock poisoned"); + map.entry(ctx_hash) + .or_default() + .push((entry.space, entry.time)); + } +} + +#[derive(Default)] +struct ISAIndex { + index: FxHashMap, usize>, + isas: Vec, +} + +impl From for Vec { + fn from(value: ISAIndex) -> Self { + value.isas + } +} + +impl ISAIndex { + pub fn push(&mut self, combination: &Vec, isa: &ISA) -> usize { + if let Some(&idx) = self.index.get(combination) { + idx + } else { + let idx = self.isas.len(); + self.isas.push(isa.clone()); + self.index.insert(combination.clone(), idx); + idx + } + } +} + +#[must_use] +#[allow(clippy::cast_precision_loss, clippy::too_many_lines)] +pub fn estimate_with_graph( + traces: &[&Trace], + graph: &Arc>, + max_error: Option, + post_process: bool, +) -> EstimationCollection { + let max_error = max_error.unwrap_or(1.0); + + // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially. + // This reads the provenance graph once per trace and generates the + // cartesian product of Pareto-filtered nodes. Each node carries + // pre-computed (space, time) values for dominance pruning in Phase 2. + let mut jobs: Vec<(usize, Vec)> = Vec::new(); + + // Use the maximum number of instruction slots across all combinations to + // size the pruning witness structure. This will be updated while we generate + // jobs. 
+ let mut max_slots = 0; + + for (trace_idx, trace) in traces.iter().enumerate() { + if trace.base_error() > max_error { + continue; + } + + let required = trace.required_instruction_ids(); + + let graph_lock = graph.read().expect("Graph lock poisoned"); + let id_and_nodes: Vec<_> = required + .iter() + .filter_map(|(&id, &volume)| { + let max_error_rate = max_error / (volume as f64); + graph_lock.pareto_nodes(id).map(|nodes| { + ( + id, + nodes + .iter() + .filter(|&&node| { + let instruction = graph_lock.instruction(node); + instruction.error_rate(Some(1)).unwrap_or(0.0) <= max_error_rate + }) + .map(|&node| { + let instruction = graph_lock.instruction(node); + let space = instruction.space(Some(1)).unwrap_or(0); + let time = instruction.time(Some(1)).unwrap_or(0); + (node, space, time) + }) + .collect::>(), + ) + }) + }) + .collect(); + drop(graph_lock); + + if id_and_nodes.len() != required.len() { + // If any required instruction is missing from the graph, we can't + // run any estimation for this trace. + continue; + } + + let mut combinations: Vec> = vec![Vec::new()]; + for (id, nodes) in id_and_nodes { + let mut new_combinations = Vec::new(); + for (node, space, time) in nodes { + for combo in &combinations { + let mut new_combo = combo.clone(); + new_combo.push(CombinationEntry { + instruction_id: id, + node_index: node, + space, + time, + }); + new_combinations.push(new_combo); + } + } + combinations = new_combinations; + } + + for combination in combinations { + max_slots = max_slots.max(combination.len()); + jobs.push((trace_idx, combination)); + } + } + + // Sort jobs so that combinations with smaller total (space + time) are + // processed first. This maximises the effectiveness of dominance pruning + // because successful "cheap" combinations establish witnesses that let us + // skip more expensive ones. 
+ jobs.sort_by_key(|(_, combo)| { + combo + .iter() + .map(|entry| entry.space + entry.time) + .sum::() + }); + + let total_jobs = jobs.len(); + + // Phase 2: Run estimations in parallel with dominance-based pruning. + // + // For each instruction slot in a combination, we track (space, time) + // witnesses from successful estimations keyed by the "context", which is a + // hash of the node indices in all *other* slots. Before running an + // estimation, we check every slot: if a witness with space ≤ and time ≤ + // exists for that context, the combination is dominated and skipped. + let next_job = AtomicUsize::new(0); + + let pruning_witnesses: Vec> = repeat_with(|| { + repeat_with(|| RwLock::new(FxHashMap::default())) + .take(max_slots) + .collect() + }) + .take(traces.len()) + .collect(); + + // There are no explicit ISAs in this estimation function, as we create them + // on the fly from the graph nodes. For successful jobs, we will attach the + // ISAs to the results collection in a vector with the ISA index addressing + // that vector. In order to avoid storing duplicate ISAs we hash the ISA + // index. 
+ let isa_index = Arc::new(RwLock::new(ISAIndex::default())); + + let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + + std::thread::scope(|scope| { + let num_threads = std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + + let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for _ in 0..num_threads { + let tx = tx.clone(); + let next_job = &next_job; + let jobs = &jobs; + let pruning_witnesses = &pruning_witnesses; + let isa_index = Arc::clone(&isa_index); + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job_idx >= total_jobs { + break; + } + + let (trace_idx, combination) = &jobs[job_idx]; + + // Dominance pruning: skip if a cheaper instruction at any + // slot already succeeded with the same surrounding context. + if is_dominated(combination, &pruning_witnesses[*trace_idx]) { + continue; + } + + let mut isa = ISA::with_graph(graph.clone()); + for entry in combination { + isa.add_node(entry.instruction_id, entry.node_index); + } + + if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) { + let isa_idx = isa_index + .write() + .expect("RwLock should not be poisoned") + .push(combination, &isa); + result.set_isa_index(isa_idx); + + result.set_trace_index(*trace_idx); + + local_results.push(result); + record_success(combination, &pruning_witnesses[*trace_idx]); + } + } + let _ = tx.send(local_results); + }); + } + drop(tx); + + let mut successful = 0; + for local_results in rx { + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } + } + successful += local_results.len(); + collection.extend(local_results.into_iter()); + } + 
 collection.set_successful_estimates(successful); + }); + + let isa_index = Arc::try_unwrap(isa_index) + .ok() + .expect("all threads joined; Arc refcount should be 1") + .into_inner() + .expect("RwLock should not be poisoned"); + + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isa_index.isas[idx].clone()); + } + } + + collection.set_isas(isa_index.into()); + + collection +} From 5de8ec1ea154adccab6be1159768991013a8a0f Mon Sep 17 00:00:00 2001 From: Brad Lackey Date: Mon, 23 Mar 2026 10:40:53 -0400 Subject: [PATCH 30/45] Magnets: modified Trotter expansion classes and added Cirq output. (#3040) New features for Trotter expansions. * Main class is now TrotterExpansion. * TrotterStep represents a single Trotter step. * TrotterExpansion accepts a function that creates the target type of Trotter step. * Minor changes to the Model class and PauliString class to support changes. * Added cirq.Circuit output to TrotterStep instances. * Added cirq.CircuitOperation output to TrotterExpansion instances. * Updated unit tests. 
--- source/pip/qsharp/magnets/models/model.py | 39 ++- source/pip/qsharp/magnets/trotter/trotter.py | 225 +++++++-------- source/pip/qsharp/magnets/utilities/pauli.py | 4 + source/pip/tests/magnets/test_model.py | 39 +++ source/pip/tests/magnets/test_pauli.py | 12 + source/pip/tests/magnets/test_trotter.py | 272 +++++++++---------- 6 files changed, 328 insertions(+), 263 deletions(-) diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py index af6373fcd9..d0cf8b1887 100755 --- a/source/pip/qsharp/magnets/models/model.py +++ b/source/pip/qsharp/magnets/models/model.py @@ -97,12 +97,6 @@ def add_interaction( self._terms[term][color] = [] self._terms[term][color].append(len(self._ops) - 1) - def terms(self, t: int) -> Iterator[PauliString]: - """Get the list of PauliStrings corresponding to a term group.""" - if t not in self._terms: - raise ValueError("Term group does not exist.") - return iter([self._ops[i] for i in self._terms[t]]) - @property def nqubits(self) -> int: """Return the number of qubits in the model.""" @@ -113,6 +107,39 @@ def nterms(self) -> int: """Return the number of term groups in the model.""" return len(self._terms) + @property + def terms(self) -> list[int]: + """Get the list of term indices in the model.""" + return list(self._terms.keys()) + + def ncolors(self, term: int) -> int: + """Return the number of colors in a given term.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + return len(self._terms[term]) + + def colors(self, term: int) -> list[int]: + """Return the list of colors in a given term.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + return list(self._terms[term].keys()) + + def nops(self, term: int, color: int) -> int: + """Return the number of operators in a given term and color.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + if color 
not in self._terms[term]: + raise ValueError(f"Color {color} does not exist in term {term}.") + return len(self._terms[term][color]) + + def ops(self, term: int, color: int) -> list[PauliString]: + """Return the list of operators in a given term and color.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + if color not in self._terms[term]: + raise ValueError(f"Color {color} does not exist in term {term}.") + return [self._ops[i] for i in self._terms[term][color]] + def __str__(self) -> str: """String representation of the model.""" return "Generic model with {} terms on {} qubits.".format( diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py index f7ac8b18f0..383aca5ae2 100644 --- a/source/pip/qsharp/magnets/trotter/trotter.py +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -1,60 +1,63 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Base Trotter class for first- and second-order Trotter-Suzuki decomposition.""" +"""Trotter schedule utilities for magnet models. +This module provides: +- ``TrotterStep``: a schedule of ``(time, term_index)`` entries, +- recursion helpers (Suzuki and Yoshida) that raise the order by 2, +- factory helpers such as Strang splitting, and +- ``TrotterExpansion`` to apply a step repeatedly to a concrete model. +""" + +from collections.abc import Callable from typing import Iterator, Optional +from qsharp.magnets.models import Model +from qsharp.magnets.utilities import PauliString +import math -class TrotterStep: - """ - Base class for Trotter decompositions. +try: + import cirq +except Exception as ex: + raise ImportError( + "qsharp.magnets.models requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'." 
+ ) from ex - Essentially, this is a wrapper around a list of ``(time, term_index)`` tuples, - which specify which term to apply for how long, independent of the specific - Trotter decomposition or model being used. - The TrotterStep class provides a common interface for different Trotter decompositions, - such as first-order Trotter and Strang splitting. It also serves as the base class for - higher-order Trotter steps that can be constructed via Suzuki or Yoshida recursion. Each - Trotter step is defined by the sequence of terms to apply and their corresponding time - durations, as well as the overall order of the decomposition and the time step for each term. +class TrotterStep: + """Schedule of Hamiltonian-term applications for one Trotter step. + + A ``TrotterStep`` stores an ordered list of ``(time, term_index)`` tuples. + Each tuple indicates that term group ``term_index`` should be applied for + evolution time ``time``. - The constructor creates an empty Trotter step (when ``num_terms = 0``), or a - first-order Trotter step: + The constructor builds a first-order step over the provided term indices: .. math:: e^{-i H t} \\approx \\prod_k e^{-i H_k t}, \\quad H = \\sum_k H_k. - In the first-order case, each term index from ``0`` to ``num_terms - 1`` appears - once, each with duration ``time_step``. - - Example: - - .. code-block:: python - - >>> trotter = TrotterStep(num_terms=3, time_step=0.5) - >>> list(trotter.step()) - [(0.5, 0), (0.5, 1), (0.5, 2)] - - References: - H. F. Trotter, Proc. Amer. Math. Soc. 10, 545 (1959). - - TODO: Initializer offers randomized order of terms. + where each supplied term index appears once with duration ``time_step``. """ - def __init__(self, num_terms: int = 0, time_step: float = 0.0): - """ - Creates an empty Trotter decomposition. + def __init__(self, terms: list[int] = [], time_step: float = 0.0): + """Initialize a Trotter step from explicit term indices. + + Args: + terms: Ordered term indices to include in this step. 
+ time_step: Duration associated with each listed term. + Notes: + If ``terms`` is empty, the step is initialized as order 0. + Otherwise, it is initialized as order 1. """ - self._nterms = num_terms + self._nterms = len(terms) self._time_step = time_step - self._order = 1 if num_terms > 0 else 0 + self._order = 1 if self._nterms > 0 else 0 self._repr_string: Optional[str] = None - self.terms: list[tuple[float, int]] = [(time_step, j) for j in range(num_terms)] + self.terms: list[tuple[float, int]] = [(time_step, j) for j in terms] @property def order(self) -> int: @@ -63,12 +66,12 @@ def order(self) -> int: @property def nterms(self) -> int: - """Get the number of terms in the Hamiltonian.""" + """Get the number of term entries used to build this schedule.""" return self._nterms @property def time_step(self) -> float: - """Get the time step for each term in the Trotter decomposition.""" + """Get the base time step metadata stored on this step.""" return self._time_step def reduce(self) -> None: @@ -100,14 +103,33 @@ def reduce(self) -> None: self.terms = reduced_terms def step(self) -> Iterator[tuple[float, int]]: - """ - Iterate over the Trotter decomposition as a list of (time, term_index) tuples. + """Iterate over ``(time, term_index)`` entries for this step.""" + return iter(self.terms) + + def cirq(self, model: Model) -> cirq.Circuit: + """Build a Cirq circuit for one application of this Trotter step. + + Args: + model: Model that maps each term index to grouped Pauli operators. Returns: - Iterator of tuples where each tuple contains the time duration and the - index of the term to be applied. + A ``cirq.Circuit`` containing ``cirq.PauliStringPhasor`` operations + in the same order as ``self.step()``. 
""" - return iter(self.terms) + _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) + circuit = cirq.Circuit() + for time, term_index in self.step(): + for color in model.colors(term_index): + for op in model.ops(term_index, color): + pauli = cirq.PauliString( + { + cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] + for p in op._paulis + }, + ) + oper = cirq.PauliStringPhasor(pauli, exponent_neg=time / math.pi) + circuit.append(oper) + return circuit def __str__(self) -> str: """String representation of the Trotter decomposition.""" @@ -209,14 +231,13 @@ def yoshida_recursion(trotter: TrotterStep) -> TrotterStep: return yoshida -def strang_splitting(num_terms: int, time: float) -> TrotterStep: +def strang_splitting(terms: list[int], time: float) -> TrotterStep: """ - Factory function for creating a Strang splitting (second-order - Trotter-Suzuki decomposition). + Create a second-order Strang splitting schedule for explicit term indices. The second-order Trotter formula uses symmetric splitting: - e^{-i H t} ≈ ∏_{k=1}^{n} e^{-i H_k t/2} ∏_{k=n}^{1} e^{-i H_k t/2} + e^{-i H t} \\approx \\prod_{k=1}^{n-1} e^{-i H_k t/2} \\, e^{-i H_n t} \\, \\prod_{k=n-1}^{1} e^{-i H_k t/2} This provides second-order accuracy in the time step, compared to first-order for the basic Trotter decomposition. @@ -224,28 +245,35 @@ def strang_splitting(num_terms: int, time: float) -> TrotterStep: Example: .. code-block:: python - >>> strang = strang_splitting(num_terms=3, time=0.5) + >>> strang = strang_splitting(terms=[0, 1, 2], time=0.5) >>> list(strang.step()) [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] + Args: + terms: Ordered term indices for a single symmetric step. Must be non-empty. + time: Total evolution time assigned to this second-order step. + + Returns: + A second-order ``TrotterStep``. + References: G. Strang, SIAM J. Numer. Anal. 5, 506 (1968). 
""" strang = TrotterStep() - strang._nterms = num_terms + strang._nterms = len(terms) strang._time_step = time strang._order = 2 - strang._repr_string = f"StrangSplitting(time_step={time}, num_terms={num_terms})" + strang._repr_string = f"StrangSplitting(time_step={time}, num_terms={len(terms)})" strang.terms = [] - for term_index in range(num_terms - 1): - strang.terms.append((time / 2, term_index)) - strang.terms.append((time, num_terms - 1)) - for term_index in reversed(range(num_terms - 1)): - strang.terms.append((time / 2, term_index)) + for i in range(len(terms) - 1): + strang.terms.append((time / 2, terms[i])) + strang.terms.append((time, terms[-1])) + for i in reversed(range(len(terms) - 1)): + strang.terms.append((time / 2, terms[i])) return strang -def fourth_order_trotter_suzuki(num_terms: int, time: float) -> TrotterStep: +def fourth_order_trotter_suzuki(terms: list[int], time: float) -> TrotterStep: """ Factory function for creating a fourth-order Trotter-Suzuki decomposition using Suzuki recursion. @@ -258,51 +286,41 @@ def fourth_order_trotter_suzuki(num_terms: int, time: float) -> TrotterStep: Example: .. code-block:: python - >>> fourth_order = fourth_order_trotter_suzuki(num_terms=3, time=0.5) + >>> fourth_order = fourth_order_trotter_suzuki(terms=[0, 1, 2], time=0.5) >>> list(fourth_order.step()) [(0.1767766952966369, 0), (0.1767766952966369, 1), (0.1767766952966369, 2), (0.3535533905932738, 1), (0.3535533905932738, 0), (0.1767766952966369, 1), (0.1767766952966369, 2), (0.1767766952966369, 1), (0.1767766952966369, 0)] """ - return suzuki_recursion(strang_splitting(num_terms, time)) + return suzuki_recursion(strang_splitting(terms, time)) class TrotterExpansion: - """ - Trotter expansion for repeated application of a Trotter step. + """Repeated application of a Trotter method on a concrete model. - This class wraps a TrotterStep instance and specifies how many times to repeat - the step. 
The expansion represents full time evolution as a sequence of - Trotter steps: + ``TrotterExpansion`` builds one step with ``trotter_method(model.terms, dt)`` + where ``dt = time / num_steps`` and then repeats it ``num_steps`` times. - e^{-i H T} ≈ (S(T/n))^n - - where S is the Trotter step formula, T is the total time, and n is the number - of steps. - - Example: - - .. code-block:: python - >>> n = 4 # Number of Trotter steps - >>> total_time = 1.0 # Total time - >>> step = TrotterStep(num_terms=2, time_step=total_time/n) - >>> expansion = TrotterExpansion(step, n) - >>> expansion.order - 1 - >>> expansion.total_time - 1.0 - >>> list(expansion.step())[:4] - [(0.25, 0), (0.25, 1), (0.25, 0), (0.25, 1)] + Iteration via :meth:`step` yields ``PauliString`` operators already scaled by + the per-entry schedule time. """ - def __init__(self, trotter_step: TrotterStep, num_steps: int): - """ - Initialize the Trotter expansion. + def __init__( + self, + trotter_method: Callable[[list[int], float], TrotterStep], + model: Model, + time: float, + num_steps: int, + ): + """Initialize a repeated-step Trotter expansion. Args: - trotter_step: An instance of TrotterStep representing a single Trotter step. - num_steps: Number of times to repeat the Trotter step. + trotter_method: Callable mapping ``(terms, dt)`` to a ``TrotterStep``. + model: Model that defines term groups and per-term operators. + time: Total evolution time. + num_steps: Number of repeated Trotter steps. 
""" - self._trotter_step = trotter_step + self._model = model self._num_steps = num_steps + self._trotter_step = trotter_method(model.terms, time / num_steps) @property def order(self) -> int: @@ -312,7 +330,7 @@ def order(self) -> int: @property def nterms(self) -> int: """Get the number of Hamiltonian terms.""" - return self._trotter_step.nterms + return self._model.nterms @property def nsteps(self) -> int: @@ -324,28 +342,23 @@ def total_time(self) -> float: """Get the total evolution time (time_step * num_steps).""" return self._trotter_step.time_step * self._num_steps - def step(self) -> Iterator[tuple[float, int]]: - """ - Iterate over the full Trotter expansion. + def step(self) -> Iterator[PauliString]: + """Iterate over scaled operators for the full expansion. - Yields all (time, term_index) tuples for the complete expansion, - repeating the Trotter step sequence num_steps times. - - Returns: - Iterator of (time, term_index) tuples for the full evolution. + Yields: + ``PauliString`` operators with coefficients scaled by schedule time, + in execution order across all repeated steps. """ for _ in range(self._num_steps): - yield from self._trotter_step.step() - - def get(self) -> list[tuple[list[tuple[float, int]], int]]: - """ - Get the Trotter expansion as a compact representation. - - Returns: - List containing a single tuple of (terms, num_steps) where terms - is the list of (time, term_index) for one step. 
- """ - return [(list(self._trotter_step.step()), self._num_steps)] + for s, i in self._trotter_step.step(): + for c in self._model.colors(i): + for op in self._model.ops(i, c): + yield (op * s) + + def cirq(self) -> cirq.CircuitOperation: + """Get a repeated Cirq circuit operation for this expansion.""" + circuit = self._trotter_step.cirq(self._model).freeze() + return cirq.CircuitOperation(circuit, repetitions=self._num_steps) def __str__(self) -> str: """String representation of the Trotter expansion.""" diff --git a/source/pip/qsharp/magnets/utilities/pauli.py b/source/pip/qsharp/magnets/utilities/pauli.py index c681aa2987..4eb7b92e5b 100644 --- a/source/pip/qsharp/magnets/utilities/pauli.py +++ b/source/pip/qsharp/magnets/utilities/pauli.py @@ -230,6 +230,10 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Pauli: return self._paulis[index] + def __mul__(self, scalar: complex) -> "PauliString": + """Scale the coefficient of this PauliString by a complex scalar.""" + return PauliString(self._paulis, coefficient=self._coefficient * scalar) + def __str__(self) -> str: labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} s = "".join(map(str, self._paulis)) diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py index 00222ec4fe..31913b62ad 100755 --- a/source/pip/tests/magnets/test_model.py +++ b/source/pip/tests/magnets/test_model.py @@ -75,6 +75,45 @@ def test_model_add_interaction_with_term(): assert model._terms[3] == {0: [0]} +def test_model_term_color_query_methods(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.0, term=1, color=2) + model.add_interaction(edge, "XX", -0.5, term=1, color=2) + model.add_interaction(edge, "YY", -0.25, term=1, color=3) + + assert model.terms == [1] + assert model.ncolors(1) == 2 + assert set(model.colors(1)) == {2, 3} + assert model.nops(1, 2) == 2 + assert model.nops(1, 3) == 1 + assert model.ops(1, 2) == [ + 
PauliString.from_qubits((0, 1), "ZZ", -1.0), + PauliString.from_qubits((0, 1), "XX", -0.5), + ] + assert model.ops(1, 3) == [PauliString.from_qubits((0, 1), "YY", -0.25)] + + +def test_model_query_methods_raise_for_missing_term_and_color(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.0, term=0, color=0) + + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.ncolors(99) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.colors(99) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.nops(99, 0) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.ops(99, 0) + + with pytest.raises(ValueError, match="Color 7 does not exist in term 0"): + model.nops(0, 7) + with pytest.raises(ValueError, match="Color 7 does not exist in term 0"): + model.ops(0, 7) + + def test_model_add_interaction_rejects_edge_not_in_geometry(): model = Model(Hypergraph([Hyperedge([0, 1])])) with pytest.raises(ValueError, match="Edge is not part of the model geometry"): diff --git a/source/pip/tests/magnets/test_pauli.py b/source/pip/tests/magnets/test_pauli.py index 7ca82e7c6f..03d29e8a96 100644 --- a/source/pip/tests/magnets/test_pauli.py +++ b/source/pip/tests/magnets/test_pauli.py @@ -101,6 +101,18 @@ def test_pauli_string_equality_and_hash_include_coefficient(): assert p1 != p3 +def test_pauli_string_mul_scales_coefficient_and_preserves_terms(): + """Test PauliString.__mul__ returns scaled coefficient with same operators.""" + ps = PauliString.from_qubits((0, 2), "XZ", coefficient=2.0) + + scaled = ps * (-0.25j) + + assert scaled.qubits == ps.qubits + assert list(scaled) == list(ps) + assert scaled.coefficient == -0.5j + assert ps.coefficient == 2.0 + + def test_pauli_string_cirq_property_preserves_terms_and_coefficient(): """Test PauliString.cirq conversion with coefficient.""" ps = 
PauliString.from_qubits((0, 2), "XZ", coefficient=-0.5j) diff --git a/source/pip/tests/magnets/test_trotter.py b/source/pip/tests/magnets/test_trotter.py index 2b2fe22572..2bab4fa3c8 100644 --- a/source/pip/tests/magnets/test_trotter.py +++ b/source/pip/tests/magnets/test_trotter.py @@ -3,6 +3,11 @@ """Unit tests for Trotter-Suzuki decomposition classes and factory functions.""" +import pytest + +from qsharp.magnets.models import Model +from qsharp.magnets.utilities import Hyperedge, Hypergraph, PauliString + from qsharp.magnets.trotter import ( TrotterStep, TrotterExpansion, @@ -13,7 +18,15 @@ ) -# TrotterStep base class tests +def make_two_term_model() -> Model: + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -2.0, term=0, color=0) + model.add_interaction(edge, "XX", -0.5, term=1, color=0) + return model + + +# TrotterStep tests def test_trotter_step_empty_init(): @@ -51,64 +64,64 @@ def test_trotter_step_reduce_empty(): # first-order TrotterStep constructor tests -def test_trotter_step_first_order_basic(): - """Test basic first-order TrotterStep creation.""" - trotter = TrotterStep(num_terms=3, time_step=0.5) +def test_trotter_step_from_explicit_terms_basic(): + """Test basic TrotterStep creation from explicit term indices.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) assert trotter.nterms == 3 assert trotter.time_step == 0.5 assert trotter.order == 1 def test_trotter_step_first_order_single_term(): - """Test first-order TrotterStep with a single term.""" - trotter = TrotterStep(num_terms=1, time_step=1.0) + """Test TrotterStep with a single explicit term.""" + trotter = TrotterStep(terms=[7], time_step=1.0) result = list(trotter.step()) - assert result == [(1.0, 0)] + assert result == [(1.0, 7)] def test_trotter_step_first_order_multiple_terms(): - """Test first-order TrotterStep with multiple terms.""" - trotter = TrotterStep(num_terms=3, time_step=0.5) + """Test TrotterStep with multiple 
explicit terms.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) result = list(trotter.step()) assert result == [(0.5, 0), (0.5, 1), (0.5, 2)] def test_trotter_step_first_order_zero_time(): - """Test first-order TrotterStep with zero time.""" - trotter = TrotterStep(num_terms=2, time_step=0.0) + """Test TrotterStep with zero time.""" + trotter = TrotterStep(terms=[0, 1], time_step=0.0) result = list(trotter.step()) assert result == [(0.0, 0), (0.0, 1)] def test_trotter_step_first_order_returns_all_terms(): - """Test that first-order TrotterStep returns all term indices.""" - num_terms = 5 - trotter = TrotterStep(num_terms=num_terms, time_step=1.0) + """Test that TrotterStep returns all provided term indices in order.""" + terms = [2, 4, 7, 11, 15] + trotter = TrotterStep(terms=terms, time_step=1.0) result = list(trotter.step()) - assert len(result) == num_terms + assert len(result) == len(terms) term_indices = [idx for _, idx in result] - assert term_indices == list(range(num_terms)) + assert term_indices == terms def test_trotter_step_first_order_uniform_time(): - """Test that all terms have the same time in first-order TrotterStep.""" + """Test that all entries have the same configured time.""" time = 0.25 - trotter = TrotterStep(num_terms=4, time_step=time) + trotter = TrotterStep(terms=[0, 1, 2, 3], time_step=time) result = list(trotter.step()) for t, _ in result: assert t == time def test_trotter_step_first_order_str(): - """Test string representation of first-order TrotterStep.""" - trotter = TrotterStep(num_terms=3, time_step=0.5) + """Test string representation of TrotterStep.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) result = str(trotter) assert "order" in result.lower() or "1" in result def test_trotter_step_first_order_repr(): - """Test repr representation of first-order TrotterStep.""" - trotter = TrotterStep(num_terms=3, time_step=0.5) + """Test repr representation of TrotterStep.""" + trotter = TrotterStep(terms=[0, 1, 2], 
time_step=0.5) assert "TrotterStep" in repr(trotter) @@ -117,7 +130,7 @@ def test_trotter_step_first_order_repr(): def test_strang_splitting_basic(): """Test basic strang_splitting creation.""" - strang = strang_splitting(num_terms=3, time=0.5) + strang = strang_splitting(terms=[0, 1, 2], time=0.5) assert strang.nterms == 3 assert strang.time_step == 0.5 assert strang.order == 2 @@ -125,7 +138,7 @@ def test_strang_splitting_basic(): def test_strang_splitting_single_term(): """Test strang_splitting with a single term.""" - strang = strang_splitting(num_terms=1, time=1.0) + strang = strang_splitting(terms=[0], time=1.0) result = list(strang.step()) # Single term: just full time on term 0 assert result == [(1.0, 0)] @@ -133,7 +146,7 @@ def test_strang_splitting_single_term(): def test_strang_splitting_two_terms(): """Test strang_splitting with two terms.""" - strang = strang_splitting(num_terms=2, time=1.0) + strang = strang_splitting(terms=[0, 1], time=1.0) result = list(strang.step()) # Forward: half on term 0, full on term 1, backward: half on term 0 assert result == [(0.5, 0), (1.0, 1), (0.5, 0)] @@ -141,7 +154,7 @@ def test_strang_splitting_two_terms(): def test_strang_splitting_three_terms(): """Test strang_splitting with three terms (example from docstring).""" - strang = strang_splitting(num_terms=3, time=0.5) + strang = strang_splitting(terms=[0, 1, 2], time=0.5) result = list(strang.step()) expected = [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] assert result == expected @@ -149,7 +162,7 @@ def test_strang_splitting_three_terms(): def test_strang_splitting_symmetric(): """Test that strang_splitting produces symmetric sequence.""" - strang = strang_splitting(num_terms=4, time=1.0) + strang = strang_splitting(terms=[0, 1, 2, 3], time=1.0) result = list(strang.step()) # Check symmetry: term indices should be palindromic term_indices = [idx for _, idx in result] @@ -159,18 +172,18 @@ def test_strang_splitting_symmetric(): def 
test_strang_splitting_time_sum(): """Test that total time in strang_splitting equals expected value.""" time = 1.0 - num_terms = 3 - strang = strang_splitting(num_terms=num_terms, time=time) + terms = [0, 1, 2] + strang = strang_splitting(terms=terms, time=time) result = list(strang.step()) total_time = sum(t for t, _ in result) # Each term appears once with full time equivalent # (half + half for outer terms, full for middle) - assert abs(total_time - time * num_terms) < 1e-10 + assert abs(total_time - time * len(terms)) < 1e-10 def test_strang_splitting_middle_term_full_time(): """Test that the middle term gets full time step.""" - strang = strang_splitting(num_terms=5, time=2.0) + strang = strang_splitting(terms=[0, 1, 2, 3, 4], time=2.0) result = list(strang.step()) # Middle term (index 4, the last term) should have full time middle_entries = [(t, idx) for t, idx in result if idx == 4] @@ -180,7 +193,7 @@ def test_strang_splitting_middle_term_full_time(): def test_strang_splitting_outer_terms_half_time(): """Test that outer terms get half time steps.""" - strang = strang_splitting(num_terms=4, time=2.0) + strang = strang_splitting(terms=[0, 1, 2, 3], time=2.0) result = list(strang.step()) # Term 0 should appear twice with half time each term_0_entries = [(t, idx) for t, idx in result if idx == 0] @@ -191,7 +204,7 @@ def test_strang_splitting_outer_terms_half_time(): def test_strang_splitting_repr(): """Test repr representation of strang_splitting result.""" - strang = strang_splitting(num_terms=3, time=0.5) + strang = strang_splitting(terms=[0, 1, 2], time=0.5) assert "StrangSplitting" in repr(strang) @@ -200,7 +213,7 @@ def test_strang_splitting_repr(): def test_suzuki_recursion_from_strang(): """Test Suzuki recursion applied to Strang splitting produces 4th order.""" - strang = strang_splitting(num_terms=2, time=1.0) + strang = strang_splitting(terms=[0, 1], time=1.0) suzuki = suzuki_recursion(strang) assert suzuki.order == 4 assert suzuki.nterms == 2 @@ 
-209,7 +222,7 @@ def test_suzuki_recursion_from_strang(): def test_suzuki_recursion_from_first_order(): """Test Suzuki recursion applied to first-order Trotter produces 3rd order.""" - trotter = TrotterStep(num_terms=2, time_step=1.0) + trotter = TrotterStep(terms=[0, 1], time_step=1.0) suzuki = suzuki_recursion(trotter) assert suzuki.order == 3 assert suzuki.nterms == 2 @@ -217,28 +230,28 @@ def test_suzuki_recursion_from_first_order(): def test_suzuki_recursion_preserves_nterms(): """Test that Suzuki recursion preserves number of terms.""" - base = strang_splitting(num_terms=5, time=0.5) + base = strang_splitting(terms=[0, 1, 2, 3, 4], time=0.5) suzuki = suzuki_recursion(base) assert suzuki.nterms == base.nterms def test_suzuki_recursion_preserves_time_step(): """Test that Suzuki recursion preserves time step.""" - base = strang_splitting(num_terms=3, time=0.75) + base = strang_splitting(terms=[0, 1, 2], time=0.75) suzuki = suzuki_recursion(base) assert suzuki.time_step == base.time_step def test_suzuki_recursion_repr(): """Test repr of Suzuki recursion result.""" - base = strang_splitting(num_terms=2, time=1.0) + base = strang_splitting(terms=[0, 1], time=1.0) suzuki = suzuki_recursion(base) assert "SuzukiRecursion" in repr(suzuki) def test_suzuki_recursion_time_weights_sum(): """Test that time weights in Suzuki recursion sum correctly.""" - base = TrotterStep(num_terms=2, time_step=1.0) + base = TrotterStep(terms=[0, 1], time_step=1.0) suzuki = suzuki_recursion(base) # The total scaled time should equal the original total time * nterms # because we're scaling times, not adding them @@ -255,7 +268,7 @@ def test_suzuki_recursion_time_weights_sum(): def test_yoshida_recursion_from_strang(): """Test Yoshida recursion applied to Strang splitting produces 4th order.""" - strang = strang_splitting(num_terms=2, time=1.0) + strang = strang_splitting(terms=[0, 1], time=1.0) yoshida = yoshida_recursion(strang) assert yoshida.order == 4 assert yoshida.nterms == 2 @@ -264,7 
+277,7 @@ def test_yoshida_recursion_from_strang(): def test_yoshida_recursion_from_first_order(): """Test Yoshida recursion applied to first-order Trotter produces 3rd order.""" - trotter = TrotterStep(num_terms=2, time_step=1.0) + trotter = TrotterStep(terms=[0, 1], time_step=1.0) yoshida = yoshida_recursion(trotter) assert yoshida.order == 3 assert yoshida.nterms == 2 @@ -272,28 +285,28 @@ def test_yoshida_recursion_from_first_order(): def test_yoshida_recursion_preserves_nterms(): """Test that Yoshida recursion preserves number of terms.""" - base = strang_splitting(num_terms=5, time=0.5) + base = strang_splitting(terms=[0, 1, 2, 3, 4], time=0.5) yoshida = yoshida_recursion(base) assert yoshida.nterms == base.nterms def test_yoshida_recursion_preserves_time_step(): """Test that Yoshida recursion preserves time step.""" - base = strang_splitting(num_terms=3, time=0.75) + base = strang_splitting(terms=[0, 1, 2], time=0.75) yoshida = yoshida_recursion(base) assert yoshida.time_step == base.time_step def test_yoshida_recursion_repr(): """Test repr of Yoshida recursion result.""" - base = strang_splitting(num_terms=2, time=1.0) + base = strang_splitting(terms=[0, 1], time=1.0) yoshida = yoshida_recursion(base) assert "YoshidaRecursion" in repr(yoshida) def test_yoshida_recursion_time_weights_sum(): """Test that time weights in Yoshida recursion sum correctly.""" - base = TrotterStep(num_terms=2, time_step=1.0) + base = TrotterStep(terms=[0, 1], time_step=1.0) yoshida = yoshida_recursion(base) # The total scaled time should equal the original total time * nterms # because weights w1 + w0 + w1 = 2*w1 + w0 = 2*w1 + (1 - 2*w1) = 1 @@ -305,7 +318,7 @@ def test_yoshida_recursion_time_weights_sum(): def test_yoshida_fewer_terms_than_suzuki(): """Test that Yoshida produces fewer terms than Suzuki (3x vs 5x).""" - base = strang_splitting(num_terms=3, time=1.0) + base = strang_splitting(terms=[0, 1, 2], time=1.0) suzuki = suzuki_recursion(base) yoshida = 
yoshida_recursion(base) # Yoshida uses 3 copies, Suzuki uses 5 copies @@ -318,7 +331,7 @@ def test_yoshida_fewer_terms_than_suzuki(): def test_fourth_order_trotter_suzuki_basic(): """Test fourth_order_trotter_suzuki factory function.""" - fourth = fourth_order_trotter_suzuki(num_terms=2, time=1.0) + fourth = fourth_order_trotter_suzuki(terms=[0, 1], time=1.0) assert fourth.order == 4 assert fourth.nterms == 2 assert fourth.time_step == 1.0 @@ -326,8 +339,8 @@ def test_fourth_order_trotter_suzuki_basic(): def test_fourth_order_trotter_suzuki_equals_suzuki_of_strang(): """Test that fourth_order_trotter_suzuki equals suzuki_recursion(strang_splitting).""" - fourth = fourth_order_trotter_suzuki(num_terms=3, time=0.5) - manual = suzuki_recursion(strang_splitting(num_terms=3, time=0.5)) + fourth = fourth_order_trotter_suzuki(terms=[0, 1, 2], time=0.5) + manual = suzuki_recursion(strang_splitting(terms=[0, 1, 2], time=0.5)) assert list(fourth.step()) == list(manual.step()) assert fourth.order == manual.order @@ -335,146 +348,103 @@ def test_fourth_order_trotter_suzuki_equals_suzuki_of_strang(): # TrotterExpansion tests -def test_trotter_expansion_init_basic(): - """Test basic TrotterExpansion initialization.""" - step = TrotterStep(num_terms=2, time_step=0.25) - expansion = TrotterExpansion(step, num_steps=4) - assert expansion._trotter_step is step - assert expansion._num_steps == 4 - - -def test_trotter_expansion_get_single_step(): - """Test TrotterExpansion with a single step.""" - step = TrotterStep(num_terms=2, time_step=1.0) - expansion = TrotterExpansion(step, num_steps=1) - result = expansion.get() - assert len(result) == 1 - terms, count = result[0] - assert count == 1 - assert terms == [(1.0, 0), (1.0, 1)] - - -def test_trotter_expansion_get_multiple_steps(): - """Test TrotterExpansion with multiple steps.""" - step = TrotterStep(num_terms=2, time_step=0.25) - expansion = TrotterExpansion(step, num_steps=4) - result = expansion.get() - assert len(result) == 1 - 
terms, count = result[0] - assert count == 4 - assert terms == [(0.25, 0), (0.25, 1)] - - -def test_trotter_expansion_with_strang(): - """Test TrotterExpansion using strang_splitting.""" - step = strang_splitting(num_terms=2, time=0.5) - expansion = TrotterExpansion(step, num_steps=2) - result = expansion.get() - assert len(result) == 1 - terms, count = result[0] - assert count == 2 - # strang_splitting with 2 terms: [(0.25, 0), (0.5, 1), (0.25, 0)] - assert terms == [(0.25, 0), (0.5, 1), (0.25, 0)] - - -def test_trotter_expansion_total_time(): - """Test that total evolution time is correct.""" - total_time = 1.0 - num_steps = 4 - step = TrotterStep(num_terms=3, time_step=total_time / num_steps) - expansion = TrotterExpansion(step, num_steps=num_steps) - result = expansion.get() - terms, count = result[0] - # Total time = sum of times in one step * count - step_time = sum(t for t, _ in terms) - total = step_time * count - # For first-order Trotter, step_time = time * num_terms - assert abs(total - total_time * 3) < 1e-10 - - -def test_trotter_expansion_preserves_step(): - """Test that expansion preserves the original step.""" - step = TrotterStep(num_terms=3, time_step=0.5) - expansion = TrotterExpansion(step, num_steps=10) - result = expansion.get() - terms, _ = result[0] - assert terms == list(step.step()) - - -def test_trotter_expansion_with_fourth_order(): - """Test TrotterExpansion with fourth-order Trotter-Suzuki.""" - step = fourth_order_trotter_suzuki(num_terms=2, time=0.25) - expansion = TrotterExpansion(step, num_steps=4) - result = expansion.get() - terms, count = result[0] - assert count == 4 - assert step.order == 4 - - def test_trotter_expansion_order_property(): """Test TrotterExpansion order property.""" - step = strang_splitting(num_terms=3, time=0.5) - expansion = TrotterExpansion(step, num_steps=4) + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, time=1.0, num_steps=4) assert expansion.order == 2 def 
test_trotter_expansion_nterms_property(): """Test TrotterExpansion nterms property.""" - step = TrotterStep(num_terms=5, time_step=0.5) - expansion = TrotterExpansion(step, num_steps=4) - assert expansion.nterms == 5 + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) + assert expansion.nterms == 2 def test_trotter_expansion_num_steps_property(): """Test TrotterExpansion num_steps property.""" - step = TrotterStep(num_terms=2, time_step=0.25) - expansion = TrotterExpansion(step, num_steps=8) + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=8) assert expansion.nsteps == 8 def test_trotter_expansion_total_time_property(): """Test TrotterExpansion total_time property.""" - step = TrotterStep(num_terms=2, time_step=0.25) - expansion = TrotterExpansion(step, num_steps=4) + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) assert expansion.total_time == 1.0 def test_trotter_expansion_step_iterator(): - """Test TrotterExpansion step() iterator yields full expansion.""" - step = TrotterStep(num_terms=2, time_step=0.5) - expansion = TrotterExpansion(step, num_steps=3) + """Test TrotterExpansion.step() yields scaled PauliStrings.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.2, num_steps=3) result = list(expansion.step()) - # Should yield 3 repetitions of [(0.5, 0), (0.5, 1)] - expected = [(0.5, 0), (0.5, 1), (0.5, 0), (0.5, 1), (0.5, 0), (0.5, 1)] - assert result == expected + + # dt = 1.2 / 3 = 0.4 and model terms are 0->ZZ(-2.0), 1->XX(-0.5) + expected = [ + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ] + assert len(result) == len(expected) + for op, (qubits, paulis, coefficient) in zip(result, expected): + assert op.qubits == qubits + assert op.paulis == paulis + 
assert op.coefficient == pytest.approx(coefficient) def test_trotter_expansion_step_iterator_with_strang(): - """Test TrotterExpansion step() with Strang splitting.""" - step = strang_splitting(num_terms=2, time=1.0) - expansion = TrotterExpansion(step, num_steps=2) + """Test TrotterExpansion.step() with Strang splitting schedule.""" + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, time=2.0, num_steps=2) result = list(expansion.step()) - # Strang with 2 terms: [(0.5, 0), (1.0, 1), (0.5, 0)] - # Repeated twice - expected = [(0.5, 0), (1.0, 1), (0.5, 0), (0.5, 0), (1.0, 1), (0.5, 0)] + + # dt = 1.0; one Strang step over terms [0,1] is: + # (0.5,0), (1.0,1), (0.5,0) + expected_single = [ + PauliString.from_qubits((0, 1), "ZZ", -1.0), + PauliString.from_qubits((0, 1), "XX", -0.5), + PauliString.from_qubits((0, 1), "ZZ", -1.0), + ] + expected = expected_single * 2 assert result == expected def test_trotter_expansion_str(): """Test TrotterExpansion string representation.""" - step = strang_splitting(num_terms=3, time=0.25) - expansion = TrotterExpansion(step, num_steps=4) + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, time=1.0, num_steps=4) result = str(expansion) assert "order=2" in result assert "num_steps=4" in result assert "total_time=1.0" in result - assert "num_terms=3" in result + assert "num_terms=2" in result def test_trotter_expansion_repr(): """Test TrotterExpansion repr representation.""" - step = TrotterStep(num_terms=2, time_step=0.5) - expansion = TrotterExpansion(step, num_steps=4) + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) result = repr(expansion) assert "TrotterExpansion" in result assert "num_steps=4" in result + + +def test_trotter_expansion_cirq_repetitions(): + """Test that TrotterExpansion.cirq repeats one-step circuit num_steps times.""" + model = make_two_term_model() + expansion = 
TrotterExpansion(TrotterStep, model, time=1.0, num_steps=5) + + op = expansion.cirq() + assert op.repetitions == 5 + + +def test_strang_splitting_rejects_empty_terms(): + """Test strang_splitting raises for empty term list.""" + with pytest.raises(IndexError): + strang_splitting([], time=1.0) From 1d5b5689ee8b3317a4896355fdbfba121ee4a9f2 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 17:52:57 +0100 Subject: [PATCH 31/45] Two small fixes to QRE 3 (#3041) - Remove unused import - Propagate source and transform information in cached T states --- source/pip/qsharp/qre/_architecture.py | 1 - source/pip/qsharp/qre/models/factories/_round_based.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 2260ac44e8..9045caee73 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -14,7 +14,6 @@ _IntFunction, _FloatFunction, constant_function, - instruction_name, property_name_to_key, ) diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index 2371d0b9d2..aed95e1243 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -110,7 +110,9 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non if self.use_cache and cache_path.exists(): cached_states = InstructionFrontier.load(str(cache_path)) for state in cached_states: - yield ctx.make_isa(ctx.add_instruction(state)) + yield ctx.make_isa( + ctx.add_instruction(state, transform=self, source=[impl_isa[T]]) + ) return # 2) Compute as before From cd308a913d774a6f244fe2c6121dc6d57a8a8c66 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 17:53:18 +0100 Subject: [PATCH 32/45] Return ISA as a pandas data frame (#3042) --- source/pip/qsharp/qre/__init__.py | 5 +++++ 
source/pip/qsharp/qre/_instruction.py | 24 ++++++++++++++++++++++++ source/pip/qsharp/qre/_qre.pyi | 15 +++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 2cdaf8dfc1..fbfd891be7 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -36,6 +36,11 @@ ) from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform +# Extend Rust Python types with additional Python-side functionality +from ._instruction import _isa_as_frame + +ISA.as_frame = _isa_as_frame + __all__ = [ "block_linear_function", "constant_function", diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 56645fd2aa..3f950669b0 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -8,6 +8,8 @@ from typing import Generator, Iterable, Optional from enum import IntEnum +import pandas as pd + from ._architecture import _Context, Architecture from ._enumeration import _enumerate_instances from ._isa_enumeration import ( @@ -372,3 +374,25 @@ def get( return _InstructionSourceNodeReference(self.graph, child_id) return default + + +def _isa_as_frame(self: ISA) -> pd.DataFrame: + data = { + "id": [instruction_name(inst.id) for inst in self], + "encoding": [Encoding(inst.encoding).name for inst in self], + "arity": [inst.arity for inst in self], + "space": [ + inst.expect_space() if inst.arity is not None else None for inst in self + ], + "time": [ + inst.expect_time() if inst.arity is not None else None for inst in self + ], + "error": [ + inst.expect_error_rate() if inst.arity is not None else None + for inst in self + ], + } + + df = pd.DataFrame(data) + df.set_index("id", inplace=True) + return df diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 610a901431..e2e9999a65 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -4,6 +4,8 @@ 
from __future__ import annotations from typing import Any, Callable, Iterator, Optional, overload +import pandas as pd + class ISA: def __add__(self, other: ISA) -> ISA: """ @@ -90,6 +92,19 @@ class ISA: """ ... + def as_frame(self) -> pd.DataFrame: + """ + Returns a pandas DataFrame representation of the ISA. + + The DataFrame will have one row per instruction, with columns for + instruction properties such as time, space, and error rate. The exact + columns may vary based on the properties of the instructions in the ISA. + + Returns: + pd.DataFrame: A DataFrame representation of the ISA. + """ + ... + def __iter__(self) -> Iterator[_Instruction]: """ Returns an iterator over the instructions. From 8597d0a37893d63c899df5cef5a3cb4b9da6eadd Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 17:53:34 +0100 Subject: [PATCH 33/45] Temporary plotting for estimation results (#3043) Adding a plotting function for estimation results. We may change this later to use Q# widgets like other parts in the repo. But adding this functionality now to support samples of higher priority. --- source/pip/qsharp/qre/_estimation.py | 122 +++++++++++++++++++++++++++ source/pip/tests/test_qre.py | 72 +++++++++++++++- 2 files changed, 193 insertions(+), 1 deletion(-) diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 4a4b75366d..18d219eb2c 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -348,6 +348,128 @@ def as_frame(self): ] ) + # Mapping from runtime unit name to its value in nanoseconds. + _TIME_UNITS: dict[str, float] = { + "ns": 1, + "µs": 1e3, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + "min": 60e9, + "hours": 3600e9, + "days": 86_400e9, + "weeks": 604_800e9, + "months": 31 * 86_400e9, + "years": 365 * 86_400e9, + "decades": 10 * 365 * 86_400e9, + "centuries": 100 * 365 * 86_400e9, + } + + # Ordered subset of _TIME_UNITS used for default x-axis tick labels. 
+    _TICK_UNITS: list[tuple[str, float]] = [
+        ("1 ns", _TIME_UNITS["ns"]),
+        ("1 µs", _TIME_UNITS["µs"]),
+        ("1 ms", _TIME_UNITS["ms"]),
+        ("1 s", _TIME_UNITS["s"]),
+        ("1 min", _TIME_UNITS["min"]),
+        ("1 hour", _TIME_UNITS["hours"]),
+        ("1 day", _TIME_UNITS["days"]),
+        ("1 week", _TIME_UNITS["weeks"]),
+        ("1 month", _TIME_UNITS["months"]),
+        ("1 year", _TIME_UNITS["years"]),
+        ("1 decade", _TIME_UNITS["decades"]),
+        ("1 century", _TIME_UNITS["centuries"]),
+    ]
+
+    def plot(
+        self,
+        *,
+        runtime_unit: Optional[str] = None,
+        figsize: tuple[float, float] = (15, 8),
+        scatter_args: dict[str, Any] = {"marker": "x"},
+    ):
+        """Returns a plot of the estimates displaying qubits vs runtime.
+
+        Creates a log-log scatter plot where the x-axis shows the total
+        runtime and the y-axis shows the total number of physical qubits.
+
+        When *runtime_unit* is ``None`` (the default), the x-axis uses
+        human-readable time-unit tick labels spanning nanoseconds to
+        centuries. When a unit string is given (e.g. ``"hours"``), all
+        runtimes are scaled to that unit and the x-axis label includes the
+        unit while the ticks are plain numbers.
+
+        Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``),
+        ``"ms"``, ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``,
+        ``"months"``, ``"years"``, ``"decades"``, ``"centuries"``.
+
+        Args:
+            runtime_unit: Optional time unit to scale the x-axis to.
+            scatter_args: Additional keyword arguments to pass to
+                ``matplotlib.axes.Axes.scatter`` when plotting the points.
+
+        Returns:
+            matplotlib.figure.Figure: The figure containing the plot.
+
+        Raises:
+            ImportError: If matplotlib is not installed.
+            ValueError: If the table is empty or *runtime_unit* is not
+                recognized.
+        """
+        try:
+            import matplotlib.pyplot as plt
+        except ImportError:
+            raise ImportError(
+                "Missing optional 'matplotlib' dependency. 
To install run: " + "pip install matplotlib" + ) + + if len(self) == 0: + raise ValueError("Cannot plot an empty EstimationTable.") + + if runtime_unit is not None and runtime_unit not in self._TIME_UNITS: + raise ValueError( + f"Unknown runtime_unit {runtime_unit!r}. " + f"Supported units: {', '.join(self._TIME_UNITS)}" + ) + + ys = [entry.qubits for entry in self] + + fig, ax = plt.subplots(figsize=figsize) + + ax.set_ylabel("Physical qubits") + + if runtime_unit is not None: + scale = self._TIME_UNITS[runtime_unit] + xs = [entry.runtime / scale for entry in self] + ax.set_xlabel(f"Runtime ({runtime_unit})") + ax.set_xscale("log") + ax.set_yscale("log") + ax.scatter(x=xs, y=ys, **scatter_args) + else: + xs = [entry.runtime for entry in self] + ax.set_xlabel("Runtime") + ax.set_xscale("log") + ax.set_yscale("log") + ax.scatter(x=xs, y=ys, **scatter_args) + + time_labels, time_units = zip(*self._TICK_UNITS) + + cutoff = ( + next( + (i for i, x in enumerate(time_units) if x > max(xs)), + len(time_units) - 1, + ) + + 1 + ) + + ax.set_xticks(time_units[:cutoff]) + ax.set_xticklabels(time_labels[:cutoff], rotation=90) + + plt.close(fig) + + return fig + @dataclass(frozen=True, slots=True) class EstimationTableColumn: diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 34a70dfec4..dcedca0409 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -4,7 +4,7 @@ from dataclasses import KW_ONLY, dataclass, field from enum import Enum from pathlib import Path -from typing import cast, Generator +from typing import cast, Generator, Sized import os import pytest @@ -1295,6 +1295,76 @@ def test_estimation_table_computed_column(): assert frame["qubit_error_product"][1] == pytest.approx(4.0) +def test_estimation_table_plot_returns_figure(): + """Test that plot() returns a matplotlib Figure with correct axes.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 5_000_000_000, 
0.01)) + table.append(_make_entry(200, 10_000_000_000, 0.02)) + table.append(_make_entry(50, 50_000_000_000, 0.005)) + + fig = table.plot() + + assert isinstance(fig, Figure) + ax = fig.axes[0] + assert ax.get_ylabel() == "Physical qubits" + assert ax.get_xlabel() == "Runtime" + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "log" + + # Verify data points + offsets = ax.collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 3 + + +def test_estimation_table_plot_empty_raises(): + """Test that plot() raises ValueError on an empty table.""" + table = EstimationTable() + with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): + table.plot() + + +def test_estimation_table_plot_single_entry(): + """Test that plot() works with a single entry.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 1_000_000, 0.01)) + + fig = table.plot() + assert isinstance(fig, Figure) + + offsets = fig.axes[0].collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 1 + + +def test_estimation_table_plot_with_runtime_unit(): + """Test that plot(runtime_unit=...) 
scales x values and labels the axis.""" + table = EstimationTable() + # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns + table.append(_make_entry(100, int(3600e9), 0.01)) + table.append(_make_entry(200, int(7200e9), 0.02)) + + fig = table.plot(runtime_unit="hours") + + ax = fig.axes[0] + assert ax.get_xlabel() == "Runtime (hours)" + + # Verify the x data is scaled: should be 1.0 and 2.0 hours + offsets = cast(list, ax.collections[0].get_offsets()) + assert offsets[0][0] == pytest.approx(1.0) + assert offsets[1][0] == pytest.approx(2.0) + + +def test_estimation_table_plot_invalid_runtime_unit(): + """Test that plot() raises ValueError for an unknown runtime_unit.""" + table = EstimationTable() + table.append(_make_entry(100, 1000, 0.01)) + with pytest.raises(ValueError, match="Unknown runtime_unit"): + table.plot(runtime_unit="fortnights") + + def _ll_files(): ll_dir = ( Path(__file__).parent.parent From 232df87a4756c7a17819c5fed787700b51395093 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 17:53:53 +0100 Subject: [PATCH 34/45] Get the name of a known property ID as a string (#3044) --- source/pip/qsharp/qre/__init__.py | 4 ++++ source/pip/qsharp/qre/_qre.py | 1 + source/pip/qsharp/qre/_qre.pyi | 12 ++++++++++++ source/pip/src/qre.rs | 6 ++++++ source/pip/tests/test_qre.py | 20 ++++++++++++++++++++ source/qre/src/isa/property_keys.rs | 11 +++++++++++ source/qre/src/lib.rs | 2 +- 7 files changed, 55 insertions(+), 1 deletion(-) diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index fbfd891be7..e90eb3b0b0 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -33,6 +33,8 @@ generic_function, linear_function, instruction_name, + property_name, + property_name_to_key, ) from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform @@ -70,6 +72,8 @@ "ISATransform", "LatticeSurgery", "PSSPC", + "property_name", + "property_name_to_key", "Trace", "TraceQuery", "TraceTransform", diff 
--git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index a67e320218..f724349388 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -30,6 +30,7 @@ PSSPC, Trace, property_name_to_key, + property_name, _float_to_bits, _float_from_bits, ) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index e2e9999a65..03a8cd9bfe 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -1539,3 +1539,15 @@ def property_name_to_key(name: str) -> Optional[int]: Optional[int]: The property key, or None if the name is not recognized. """ ... + +def property_name(id: int) -> Optional[str]: + """ + Converts a property key to its corresponding name, if known. + + Args: + id (int): The property key. + + Returns: + Optional[str]: The property name, or None if the key is not recognized. + """ + ... diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 6e69218f9a..87bdc2f0db 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -43,6 +43,7 @@ pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(float_from_bits, m)?)?; m.add_function(wrap_pyfunction!(instruction_name, m)?)?; m.add_function(wrap_pyfunction!(property_name_to_key, m)?)?; + m.add_function(wrap_pyfunction!(property_name, m)?)?; m.add("EstimationError", m.py().get_type::())?; @@ -1498,6 +1499,11 @@ pub fn property_name_to_key(name: &str) -> Option { qre::property_name_to_key(&name.to_ascii_uppercase()) } +#[pyfunction] +pub fn property_name(id: u64) -> Option { + qre::property_name(id).map(String::from) +} + fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { #[allow(clippy::wildcard_imports)] use qre::property_keys::*; diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index dcedca0409..aa66f421c5 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -25,6 +25,8 @@ estimate, linear_function, 
generic_function, + property_name, + property_name_to_key, ) from qsharp.qre._qre import _ProvenanceGraph from qsharp.qre.application import QSharpApplication @@ -190,6 +192,24 @@ def test_instruction_constraints(): assert isa_with_dist.satisfies(reqs_with_prop) is True +def test_property_names(): + assert property_name(DISTANCE) == "DISTANCE" + + # An unregistered property + UNKNOWN = 10_000 + assert property_name(UNKNOWN) is None + + # But using an existing property key with a different variable name will + # still return something + UNKNOWN = 0 + assert property_name(UNKNOWN) == "DISTANCE" + + assert property_name_to_key("DISTANCE") == DISTANCE + + # But we also allow case-insensitive lookup + assert property_name_to_key("distance") == DISTANCE + + def test_generic_function(): from qsharp.qre._qre import _IntFunction, _FloatFunction diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs index 376d9979f8..4f6eb50f0b 100644 --- a/source/qre/src/isa/property_keys.rs +++ b/source/qre/src/isa/property_keys.rs @@ -31,6 +31,17 @@ macro_rules! 
define_properties { _ => None } } + + /// Integer key → property name mapping + #[must_use] + pub fn property_name(id: u64) -> Option<&'static str> { + match id { + $( + $name => Some(stringify!($name)), + )* + _ => None, + } + } }; } diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs index b22334f6da..42db079461 100644 --- a/source/qre/src/lib.rs +++ b/source/qre/src/lib.rs @@ -10,7 +10,7 @@ pub use pareto::{ }; mod result; pub use isa::property_keys; -pub use isa::property_keys::property_name_to_key; +pub use isa::property_keys::{property_name, property_name_to_key}; pub use isa::{ ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, VariableArityFunction, From f1ceab8e2d19ede266dec4c2c867cb3af445a534 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 17:54:13 +0100 Subject: [PATCH 35/45] Trace properties for logical qubits (#3045) We always get logical compute and memory qubits (memory is 0, if there are no memory qubits). These are the qubits that are in the trace before estimation, and therefore might already contain auxiliary qubits. Application generators may add algorithm logical qubits (for compute and memory) to the trace properties as well. 
--- source/pip/qsharp/qre/interop/_qsharp.py | 8 +++++++- source/pip/qsharp/qre/property_keys.pyi | 4 ++++ source/pip/src/qre.rs | 6 +++++- source/pip/tests/test_qre.py | 15 ++++++++++++++- source/qre/src/isa/property_keys.rs | 4 ++++ source/qre/src/trace.rs | 15 ++++++++++++++- 6 files changed, 48 insertions(+), 4 deletions(-) diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index cded428266..d595ad9e9c 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -11,7 +11,11 @@ from ...estimator import LogicalCounts from .._qre import Trace from ..instruction_ids import CCX, MEAS_Z, RZ, T, READ_FROM_MEMORY, WRITE_TO_MEMORY -from ..property_keys import EVALUATION_TIME +from ..property_keys import ( + EVALUATION_TIME, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, +) def _bucketize_rotation_counts( @@ -103,6 +107,8 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: block.add_operation(WRITE_TO_MEMORY, [0, compute_qubits]) trace.set_property(EVALUATION_TIME, evaluation_time) + trace.set_property(ALGORITHM_COMPUTE_QUBITS, compute_qubits) + trace.set_property(ALGORITHM_MEMORY_QUBITS, memory_qubits) return trace diff --git a/source/pip/qsharp/qre/property_keys.pyi b/source/pip/qsharp/qre/property_keys.pyi index 62f5fd5213..ed0b311821 100644 --- a/source/pip/qsharp/qre/property_keys.pyi +++ b/source/pip/qsharp/qre/property_keys.pyi @@ -13,3 +13,7 @@ PHYSICAL_COMPUTE_QUBITS: int PHYSICAL_FACTORY_QUBITS: int PHYSICAL_MEMORY_QUBITS: int MOLECULE: int +LOGICAL_COMPUTE_QUBITS: int +LOGICAL_MEMORY_QUBITS: int +ALGORITHM_COMPUTE_QUBITS: int +ALGORITHM_MEMORY_QUBITS: int diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 87bdc2f0db..d31b5bd6f8 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1528,7 +1528,11 @@ fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { PHYSICAL_COMPUTE_QUBITS, 
PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, - MOLECULE + MOLECULE, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, ); m.add_submodule(&property_keys)?; diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index aa66f421c5..4aad35c92e 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -47,7 +47,14 @@ ISARefNode, ) from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T, RZ -from qsharp.qre.property_keys import DISTANCE, NUM_TS_PER_ROTATION +from qsharp.qre.property_keys import ( + DISTANCE, + NUM_TS_PER_ROTATION, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, +) # NOTE These classes will be generalized as part of the QRE API in the following # pull requests and then moved out of the tests. @@ -862,8 +869,14 @@ def test_qsharp_application(): assert trace2.resource_states == { T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx } + assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 + assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 result = trace2.estimate(isa, max_error=float("inf")) assert result is not None + assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 + assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 + assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 + assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 _assert_estimation_result(trace2, result, isa) assert counter == 32 diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs index 4f6eb50f0b..6f4e7ca877 100644 --- a/source/qre/src/isa/property_keys.rs +++ b/source/qre/src/isa/property_keys.rs @@ -58,4 +58,8 @@ define_properties! 
{ PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, MOLECULE, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, } diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 21b44cd6ad..ccfdd5c0cd 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -16,7 +16,10 @@ use serde::{Deserialize, Serialize}; use crate::{ Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction, LockedISA, ProvenanceGraph, ResultSummary, - property_keys::{PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS}, + property_keys::{ + LOGICAL_COMPUTE_QUBITS, LOGICAL_MEMORY_QUBITS, PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, + }, }; pub mod instruction_ids; @@ -335,6 +338,16 @@ impl Trace { } } + // Make main trace metrics properties to access them from the result + result.set_property( + LOGICAL_COMPUTE_QUBITS, + Property::Int(self.compute_qubits.cast_signed()), + ); + result.set_property( + LOGICAL_MEMORY_QUBITS, + Property::Int(self.memory_qubits.unwrap_or(0).cast_signed()), + ); + // Copy properties from the trace to the result for (key, value) in &self.properties { result.set_property(*key, value.clone()); From 9d87a9c6e180c295af44cebd770a3232a127ae1e Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 23 Mar 2026 18:37:18 +0100 Subject: [PATCH 36/45] Expose ISA requirements of a trace (#3046) Some refactoring of the trace code for graph estimation, and a way to expose the ISA requirements of a trace through the Python API. 
--- source/pip/qsharp/qre/_qre.pyi | 53 ++++++++++++++++- source/pip/src/qre.rs | 56 ++++++++++++++++++ source/pip/tests/test_qre.py | 4 ++ source/qre/src/isa.rs | 22 ++++++++ source/qre/src/trace.rs | 100 +++++++++++++++++++++++---------- 5 files changed, 203 insertions(+), 32 deletions(-) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 03a8cd9bfe..4562cc8f3a 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -140,11 +140,32 @@ class ISARequirements: constraints. Args: - constraints (list[InstructionConstraint] | *InstructionConstraint): The list of instruction + constraints (list[Constraint] | *Constraint): The list of instruction constraints. """ ... + def __len__(self) -> int: + """ + Returns the number of constraints in the requirements specification. + + Returns: + int: The number of constraints. + """ + ... + + def __iter__(self) -> Iterator[Constraint]: + """ + Returns an iterator over the constraints. + + Note: + The order of constraints is not guaranteed. + + Returns: + Iterator[Constraint]: The constraint iterator. + """ + ... + class _Instruction: @staticmethod def fixed_arity( @@ -511,6 +532,26 @@ class Constraint: """ ... + @property + def id(self) -> int: + """ + The instruction ID. + + Returns: + int: The instruction ID. + """ + ... + + @property + def encoding(self) -> int: + """ + The instruction encoding. 0 = Physical, 1 = Logical. + + Returns: + int: The instruction encoding. + """ + ... + def add_property(self, property: int) -> None: """ Adds a property requirement to the constraint. @@ -1311,6 +1352,16 @@ class Trace: """ ... + @property + def required_isa(self) -> ISARequirements: + """ + The required ISA for the trace. + + Returns: + ISARequirements: The required ISA for the trace. + """ + ... + def __str__(self) -> str: """ Returns a string representation of the trace. 
diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index d31b5bd6f8..e459788515 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -173,6 +173,35 @@ impl ISARequirements { .collect::>() .map(ISARequirements) } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let constraints: Vec = slf.0.constraints(); + let iter = ISARequirementsIterator { + iter: constraints.into_iter(), + }; + Py::new(slf.py(), iter) + } +} + +#[pyclass] +pub struct ISARequirementsIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl ISARequirementsIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(Constraint) + } } #[allow(clippy::unsafe_derive_deserialize)] @@ -376,6 +405,19 @@ impl Constraint { ))) } + #[getter] + pub fn id(&self) -> u64 { + self.0.id() + } + + #[getter] + pub fn encoding(&self) -> u64 { + match self.0.encoding() { + qre::Encoding::Physical => 0, + qre::Encoding::Logical => 1, + } + } + pub fn add_property(&mut self, property: u64) { self.0.add_property(property); } @@ -1140,6 +1182,20 @@ impl Trace { fn __str__(&self) -> String { format!("{}", self.0) } + + #[getter] + pub fn required_isa(&self) -> ISARequirements { + let constraints = self + .0 + .required_instruction_ids() + .keys() + .map(|id| + // NOTE: Retrieve more precise arity information from the trace + qre::InstructionConstraint::new(*id, qre::Encoding::Logical, None, None)) + .collect(); + + ISARequirements(constraints) + } } #[pyclass(unsendable)] diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 4aad35c92e..2dc318fa5e 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -816,6 +816,8 @@ def test_qsharp_application(): assert trace.depth == 3 assert trace.resource_states == {} + assert {c.id for c in 
trace.required_isa} == {CCX, T, RZ} + graph = _ProvenanceGraph() isa = graph.make_isa( [ @@ -865,10 +867,12 @@ def test_qsharp_application(): T: num_ts + psspc.num_ts_per_rotation * num_rotations, CCX: num_ccx, } + assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} else: assert trace2.resource_states == { T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx } + assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 result = trace2.estimate(isa, max_error=float("inf")) diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index b35d364c45..f4761542fa 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -225,6 +225,22 @@ impl ISARequirements { pub fn add_constraint(&mut self, constraint: InstructionConstraint) { self.constraints.insert(constraint.id, constraint); } + + #[must_use] + pub fn len(&self) -> usize { + self.constraints.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.constraints.is_empty() + } + + /// Returns all instructions as owned clones. + #[must_use] + pub fn constraints(&self) -> Vec { + self.constraints.values().cloned().collect() + } } impl FromIterator for ISARequirements { @@ -483,6 +499,12 @@ impl InstructionConstraint { self.id } + /// Returns the required encoding for this constraint. + #[must_use] + pub fn encoding(&self) -> Encoding { + self.encoding + } + /// Checks whether a given instruction satisfies this constraint. #[must_use] pub fn is_satisfied_by(&self, instruction: &Instruction) -> bool { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index ccfdd5c0cd..6b75689c4b 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -814,13 +814,20 @@ pub fn estimate_parallel<'a>( collection } +/// A node in the provenance graph along with pre-computed (space, time) values +/// for pruning. 
+#[derive(Clone, Copy, Hash, PartialEq, Eq)] +struct NodeProfile { + node_index: usize, + space: u64, + time: u64, +} + /// A single entry in a combination of instruction choices for estimation. #[derive(Clone, Copy, Hash, Eq, PartialEq)] struct CombinationEntry { instruction_id: u64, - node_index: usize, - space: u64, - time: u64, + node: NodeProfile, } /// Per-slot pruning witnesses: maps a context hash to the `(space, time)` @@ -835,7 +842,7 @@ fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize for (i, entry) in combination.iter().enumerate() { if i != exclude_idx { entry.instruction_id.hash(&mut hasher); - entry.node_index.hash(&mut hasher); + entry.node.node_index.hash(&mut hasher); } } hasher.finish() @@ -854,7 +861,7 @@ fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses .expect("Pruning lock poisoned"); if map.get(&ctx_hash).is_some_and(|w| { w.iter() - .any(|&(ws, wt)| ws <= entry.space && wt <= entry.time) + .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time) }) { return true; } @@ -872,7 +879,7 @@ fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitness .expect("Pruning lock poisoned"); map.entry(ctx_hash) .or_default() - .push((entry.space, entry.time)); + .push((entry.node.space, entry.node.time)); } } @@ -901,6 +908,54 @@ impl ISAIndex { } } +/// Generates the cartesian product of `id_and_nodes` and pushes each +/// combination directly into `jobs`, avoiding intermediate allocations. +/// +/// The cartesian product is enumerated using mixed-radix indexing. Given +/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations +/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely +/// identifies one element from every dimension: the index into dimension `d` is +/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by +/// repeatedly taking `i % nd` and then dividing `i` by `nd`. 
This is +/// analogous to extracting digits from a number in a mixed-radix system. +fn push_cartesian_product( + id_and_nodes: &[(u64, Vec)], + trace_idx: usize, + jobs: &mut Vec<(usize, Vec)>, + max_slots: &mut usize, +) { + // The product of all dimension sizes gives the total number of + // combinations. If any dimension is empty the product is zero and there + // are no valid combinations to generate. + let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product(); + if total == 0 { + return; + } + + *max_slots = (*max_slots).max(id_and_nodes.len()); + jobs.reserve(total); + + // Enumerate every combination by treating the combination index `i` as a + // mixed-radix number. The inner loop "peels off" one digit per dimension: + // node_idx = i % nodes.len() — selects this dimension's element + // i /= nodes.len() — shifts to the next dimension's digit + // After processing all dimensions, `i` is exhausted (becomes 0), and + // `combo` contains exactly one entry per instruction id. 
+ for mut i in 0..total { + let mut combo = Vec::with_capacity(id_and_nodes.len()); + for (id, nodes) in id_and_nodes { + let node_idx = i % nodes.len(); + i /= nodes.len(); + let profile = nodes[node_idx]; + combo.push(CombinationEntry { + instruction_id: *id, + node: profile, + }); + } + jobs.push((trace_idx, combo)); + } +} + #[must_use] #[allow(clippy::cast_precision_loss, clippy::too_many_lines)] pub fn estimate_with_graph( @@ -947,7 +1002,11 @@ pub fn estimate_with_graph( let instruction = graph_lock.instruction(node); let space = instruction.space(Some(1)).unwrap_or(0); let time = instruction.time(Some(1)).unwrap_or(0); - (node, space, time) + NodeProfile { + node_index: node, + space, + time, + } }) .collect::>(), ) @@ -962,28 +1021,7 @@ pub fn estimate_with_graph( continue; } - let mut combinations: Vec> = vec![Vec::new()]; - for (id, nodes) in id_and_nodes { - let mut new_combinations = Vec::new(); - for (node, space, time) in nodes { - for combo in &combinations { - let mut new_combo = combo.clone(); - new_combo.push(CombinationEntry { - instruction_id: id, - node_index: node, - space, - time, - }); - new_combinations.push(new_combo); - } - } - combinations = new_combinations; - } - - for combination in combinations { - max_slots = max_slots.max(combination.len()); - jobs.push((trace_idx, combination)); - } + push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots); } // Sort jobs so that combinations with smaller total (space + time) are @@ -993,7 +1031,7 @@ pub fn estimate_with_graph( jobs.sort_by_key(|(_, combo)| { combo .iter() - .map(|entry| entry.space + entry.time) + .map(|entry| entry.node.space + entry.node.time) .sum::() }); @@ -1057,7 +1095,7 @@ pub fn estimate_with_graph( let mut isa = ISA::with_graph(graph.clone()); for entry in combination { - isa.add_node(entry.instruction_id, entry.node_index); + isa.add_node(entry.instruction_id, entry.node.node_index); } if let Ok(mut result) = traces[*trace_idx].estimate(&isa, 
Some(max_error)) { From dabfa32c93c725b20c25f37274d0f0b1a74ec7f4 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Wed, 25 Mar 2026 18:54:52 +0100 Subject: [PATCH 37/45] Trace ISA requirements (#3050) This improves the function to retrieve ISA requirements from a trace. This can then be used both as a trace profile to prune estimation jobs based on the minimum required error rate for an instruction, but also to report the provided ISA for a trace through the Python API. A function to display it as a pandas data frame facilitates using it in a Jupyter notebook. --- source/pip/qsharp/qre/__init__.py | 3 +- source/pip/qsharp/qre/_instruction.py | 12 +++++ source/pip/qsharp/qre/_qre.pyi | 32 +++++++++++ source/pip/src/qre.rs | 21 ++++---- source/qre/src/isa.rs | 23 ++++++++ source/qre/src/trace.rs | 78 +++++++++++++++++++-------- 6 files changed, 137 insertions(+), 32 deletions(-) diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index e90eb3b0b0..7defa10627 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -39,9 +39,10 @@ from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform # Extend Rust Python types with additional Python-side functionality -from ._instruction import _isa_as_frame +from ._instruction import _isa_as_frame, _requirements_as_frame ISA.as_frame = _isa_as_frame +ISARequirements.as_frame = _requirements_as_frame __all__ = [ "block_linear_function", diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index 3f950669b0..de54bfd657 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -396,3 +396,15 @@ def _isa_as_frame(self: ISA) -> pd.DataFrame: df = pd.DataFrame(data) df.set_index("id", inplace=True) return df + + +def _requirements_as_frame(self: ISARequirements) -> pd.DataFrame: + data = { + "id": [instruction_name(inst.id) for inst in self], + "encoding": 
[Encoding(inst.encoding).name for inst in self], + "arity": [inst.arity for inst in self], + } + + df = pd.DataFrame(data) + df.set_index("id", inplace=True) + return df diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 4562cc8f3a..3f71b19c55 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -166,6 +166,18 @@ class ISARequirements: """ ... + def as_frame(self) -> pd.DataFrame: + """ + Returns a pandas DataFrame representation of the ISA requirements. + + The DataFrame will have one row per instruction, with columns for + constraint properties such as encoding. + + Returns: + pd.DataFrame: A DataFrame representation of the ISA requirements. + """ + ... + class _Instruction: @staticmethod def fixed_arity( @@ -552,6 +564,26 @@ class Constraint: """ ... + @property + def arity(self) -> Optional[int]: + """ + The instruction arity. + + Returns: + Optional[int]: The instruction arity. + """ + ... + + @property + def error_rate(self) -> Optional[ConstraintBound]: + """ + The constraint on the instruction error rate. + + Returns: + Optional[ConstraintBound]: The constraint on the instruction error rate. + """ + ... + def add_property(self, property: int) -> None: """ Adds a property requirement to the constraint. 
diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index e459788515..6a47e50245 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -418,6 +418,16 @@ impl Constraint { } } + #[getter] + pub fn arity(&self) -> Option { + self.0.arity() + } + + #[getter] + pub fn error_rate(&self) -> Option { + self.0.error_rate().copied().map(ConstraintBound) + } + pub fn add_property(&mut self, property: u64) { self.0.add_property(property); } @@ -1185,16 +1195,7 @@ impl Trace { #[getter] pub fn required_isa(&self) -> ISARequirements { - let constraints = self - .0 - .required_instruction_ids() - .keys() - .map(|id| - // NOTE: Retrieve more precise arity information from the trace - qre::InstructionConstraint::new(*id, qre::Encoding::Logical, None, None)) - .collect(); - - ISARequirements(constraints) + ISARequirements(self.0.required_instruction_ids(None)) } } diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index f4761542fa..3fcde87d89 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -2,6 +2,7 @@ // Licensed under the MIT License. use std::{ + collections::hash_map::Entry, fmt::Display, ops::Add, sync::{Arc, RwLock, RwLockReadGuard}, @@ -236,6 +237,10 @@ impl ISARequirements { self.constraints.is_empty() } + pub fn entry(&mut self, id: u64) -> Entry<'_, u64, InstructionConstraint> { + self.constraints.entry(id) + } + /// Returns all instructions as owned clones. #[must_use] pub fn constraints(&self) -> Vec { @@ -505,6 +510,24 @@ impl InstructionConstraint { self.encoding } + #[must_use] + pub fn arity(&self) -> Option { + self.arity + } + + pub fn set_arity(&mut self, arity: Option) { + self.arity = arity; + } + + #[must_use] + pub fn error_rate(&self) -> Option<&ConstraintBound> { + self.error_rate_fn.as_ref() + } + + pub fn set_error_rate(&mut self, error_rate_fn: Option>) { + self.error_rate_fn = error_rate_fn; + } + /// Checks whether a given instruction satisfies this constraint. 
#[must_use] pub fn is_satisfied_by(&self, instruction: &Instruction) -> bool { diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 6b75689c4b..23d8a67898 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -14,8 +14,8 @@ use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use crate::{ - Error, EstimationCollection, EstimationResult, FactoryResult, ISA, Instruction, LockedISA, - ProvenanceGraph, ResultSummary, + ConstraintBound, Encoding, Error, EstimationCollection, EstimationResult, FactoryResult, ISA, + ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, ResultSummary, property_keys::{ LOGICAL_COMPUTE_QUBITS, LOGICAL_MEMORY_QUBITS, PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, @@ -152,29 +152,62 @@ impl Trace { TraceIterator::new(&self.block) } - /// Returns the set of used instruction IDs in the trace including their volume + /// Returns the set of instruction IDs required by this trace, along with + /// their arity constraints if available. We take the actual arity from the + /// instruction, and if we see instructions with the same ID but different + /// arities, we mark them as variable arity in the returned requirements. + /// If `max_error` is provided, also adds error rate constraints based on + /// the instruction usage volume and the maximum allowed error. These error + /// rate constraints can be used for instruction pruning during estimation. 
+ #[allow(clippy::cast_precision_loss)] #[must_use] - pub fn required_instruction_ids(&self) -> FxHashMap { - let mut ids = FxHashMap::default(); + pub fn required_instruction_ids(&self, max_error: Option) -> ISARequirements { + let mut constraints = FxHashMap::::default(); + + let mut update_constraints = |id: u64, arity: u64, added_volume: u64| { + constraints + .entry(id) + .and_modify(|(constraint, volume)| { + if let Some(prev_arity) = constraint.arity() + && prev_arity != arity + { + constraint.set_arity(None); + } + *volume += added_volume; + }) + .or_insert({ + let constraint = + InstructionConstraint::new(id, Encoding::Logical, Some(arity), None); + (constraint, added_volume) + }); + }; + for (gate, mult) in self.deep_iter() { let arity = gate.qubits.len() as u64; - ids.entry(gate.id) - .and_modify(|c| *c += mult * arity) - .or_insert(mult * (gate.qubits.len() as u64)); + update_constraints(gate.id, arity, mult * arity); } if let Some(ref rs) = self.resource_states { for (res_id, count) in rs { - ids.entry(*res_id) - .and_modify(|c| *c += *count) - .or_insert(*count); + update_constraints(*res_id, 1, *count); } } if let Some(memory_qubits) = self.memory_qubits { - ids.entry(instruction_ids::MEMORY) - .and_modify(|c| *c += memory_qubits) - .or_insert(memory_qubits); + update_constraints(instruction_ids::MEMORY, memory_qubits, memory_qubits); + } + + if let Some(max_error) = max_error { + constraints + .into_values() + .map(|(mut c, volume)| { + c.set_error_rate(Some(ConstraintBound::less_equal( + max_error / (volume as f64), + ))); + c + }) + .collect() + } else { + constraints.into_values().map(|(c, _)| c).collect() } - ids } #[must_use] @@ -982,21 +1015,24 @@ pub fn estimate_with_graph( continue; } - let required = trace.required_instruction_ids(); + let required = trace.required_instruction_ids(Some(max_error)); let graph_lock = graph.read().expect("Graph lock poisoned"); let id_and_nodes: Vec<_> = required + .constraints() .iter() - .filter_map(|(&id, 
&volume)| { - let max_error_rate = max_error / (volume as f64); - graph_lock.pareto_nodes(id).map(|nodes| { + .filter_map(|constraint| { + graph_lock.pareto_nodes(constraint.id()).map(|nodes| { ( - id, + constraint.id(), nodes .iter() .filter(|&&node| { + // Filter out nodes that don't meet the constraint bounds. let instruction = graph_lock.instruction(node); - instruction.error_rate(Some(1)).unwrap_or(0.0) <= max_error_rate + constraint.error_rate().is_none_or(|c| { + c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0)) + }) }) .map(|&node| { let instruction = graph_lock.instruction(node); From c22cf5cb8119f478c8fa31cd31c24fc1754871a8 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Wed, 25 Mar 2026 18:55:14 +0100 Subject: [PATCH 38/45] Plot multiple estimation tables in single plot (#3054) This is still an intermediate way to get plots into Jupyter notebooks and should be replaced by qdk widgets in a future PR. --- source/pip/qsharp/qre/__init__.py | 2 + source/pip/qsharp/qre/_estimation.py | 280 ++++++++++++++++----------- 2 files changed, 165 insertions(+), 117 deletions(-) diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 7defa10627..a17bc2122c 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -8,6 +8,7 @@ EstimationTable, EstimationTableColumn, EstimationTableEntry, + plot_estimates, ) from ._instruction import ( LOGICAL, @@ -50,6 +51,7 @@ "constraint", "estimate", "linear_function", + "plot_estimates", "Application", "Architecture", "Block", diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 18d219eb2c..14b628b763 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import cast, Optional, Callable, Any +from typing import cast, Optional, Callable, Any, Iterable import pandas as pd @@ 
-204,6 +204,8 @@ def estimate( # Post-process the results and add them to a results table table = EstimationTable() + table.name = name + if name is not None: table.insert_column(0, "name", lambda entry: name) @@ -243,6 +245,7 @@ def __init__(self): """Initialize an empty estimation table with default columns.""" super().__init__() + self.name: Optional[str] = None self.stats = EstimationTableStats() self._columns: list[tuple[str, EstimationTableColumn]] = [ @@ -348,127 +351,16 @@ def as_frame(self): ] ) - # Mapping from runtime unit name to its value in nanoseconds. - _TIME_UNITS: dict[str, float] = { - "ns": 1, - "µs": 1e3, - "us": 1e3, - "ms": 1e6, - "s": 1e9, - "min": 60e9, - "hours": 3600e9, - "days": 86_400e9, - "weeks": 604_800e9, - "months": 31 * 86_400e9, - "years": 365 * 86_400e9, - "decades": 10 * 365 * 86_400e9, - "centuries": 100 * 365 * 86_400e9, - } - - # Ordered subset of _TIME_UNITS used for default x-axis tick labels. - _TICK_UNITS: list[tuple[str, float]] = [ - ("1 ns", _TIME_UNITS["ns"]), - ("1 µs", _TIME_UNITS["µs"]), - ("1 ms", _TIME_UNITS["ms"]), - ("1 s", _TIME_UNITS["s"]), - ("1 min", _TIME_UNITS["min"]), - ("1 hour", _TIME_UNITS["hours"]), - ("1 day", _TIME_UNITS["days"]), - ("1 week", _TIME_UNITS["weeks"]), - ("1 month", _TIME_UNITS["months"]), - ("1 year", _TIME_UNITS["years"]), - ("1 decade", _TIME_UNITS["decades"]), - ("1 century", _TIME_UNITS["centuries"]), - ] - - def plot( - self, - *, - runtime_unit: Optional[str] = None, - figsize: tuple[float, float] = (15, 8), - scatter_args: dict[str, Any] = {"marker": "x"}, - ): - """Returns a plot of the estimates displaying qubits vs runtime. - - Creates a log-log scatter plot where the x-axis shows the total - runtime and the y-axis shows the total number of physical qubits. - - When *runtime_unit* is ``None`` (the default), the x-axis uses - human-readable time-unit tick labels spanning nanoseconds to - centuries. When a unit string is given (e.g. 
``"hours"``), all - runtimes are scaled to that unit and the x-axis label includes the - unit while the ticks are plain numbers. - - Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), - ``"ms"``, ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, - ``"months"``, ``"years"``. + def plot(self, **kwargs): + """Plot this table's results. - Args: - runtime_unit: Optional time unit to scale the x-axis to. - scatter_args: Additional keyword arguments to pass to - ``matplotlib.axes.Axes.scatter`` when plotting the points. + Convenience wrapper around :func:`plot_estimates`. All keyword + arguments are forwarded. Returns: matplotlib.figure.Figure: The figure containing the plot. - - Raises: - ImportError: If matplotlib is not installed. - ValueError: If the table is empty or *runtime_unit* is not - recognised. """ - try: - import matplotlib.pyplot as plt - except ImportError: - raise ImportError( - "Missing optional 'matplotlib' dependency. To install run: " - "pip install matplotlib" - ) - - if len(self) == 0: - raise ValueError("Cannot plot an empty EstimationTable.") - - if runtime_unit is not None and runtime_unit not in self._TIME_UNITS: - raise ValueError( - f"Unknown runtime_unit {runtime_unit!r}. 
" - f"Supported units: {', '.join(self._TIME_UNITS)}" - ) - - ys = [entry.qubits for entry in self] - - fig, ax = plt.subplots(figsize=figsize) - - ax.set_ylabel("Physical qubits") - - if runtime_unit is not None: - scale = self._TIME_UNITS[runtime_unit] - xs = [entry.runtime / scale for entry in self] - ax.set_xlabel(f"Runtime ({runtime_unit})") - ax.set_xscale("log") - ax.set_yscale("log") - ax.scatter(x=xs, y=ys, **scatter_args) - else: - xs = [entry.runtime for entry in self] - ax.set_xlabel("Runtime") - ax.set_xscale("log") - ax.set_yscale("log") - ax.scatter(x=xs, y=ys, **scatter_args) - - time_labels, time_units = zip(*self._TICK_UNITS) - - cutoff = ( - next( - (i for i, x in enumerate(time_units) if x > max(xs)), - len(time_units) - 1, - ) - + 1 - ) - - ax.set_xticks(time_units[:cutoff]) - ax.set_xticklabels(time_labels[:cutoff], rotation=90) - - plt.close(fig) - - return fig + return plot_estimates(self, **kwargs) @dataclass(frozen=True, slots=True) @@ -522,3 +414,157 @@ class EstimationTableStats: total_jobs: int = 0 successful_estimates: int = 0 pareto_results: int = 0 + + +# Mapping from runtime unit name to its value in nanoseconds. +_TIME_UNITS: dict[str, float] = { + "ns": 1, + "µs": 1e3, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + "min": 60e9, + "hours": 3600e9, + "days": 86_400e9, + "weeks": 604_800e9, + "months": 31 * 86_400e9, + "years": 365 * 86_400e9, + "decades": 10 * 365 * 86_400e9, + "centuries": 100 * 365 * 86_400e9, +} + +# Ordered subset of _TIME_UNITS used for default x-axis tick labels. 
+_TICK_UNITS: list[tuple[str, float]] = [ + ("1 ns", _TIME_UNITS["ns"]), + ("1 µs", _TIME_UNITS["µs"]), + ("1 ms", _TIME_UNITS["ms"]), + ("1 s", _TIME_UNITS["s"]), + ("1 min", _TIME_UNITS["min"]), + ("1 hour", _TIME_UNITS["hours"]), + ("1 day", _TIME_UNITS["days"]), + ("1 week", _TIME_UNITS["weeks"]), + ("1 month", _TIME_UNITS["months"]), + ("1 year", _TIME_UNITS["years"]), + ("1 decade", _TIME_UNITS["decades"]), + ("1 century", _TIME_UNITS["centuries"]), +] + + +def plot_estimates( + data: EstimationTable | Iterable[EstimationTable], + *, + runtime_unit: Optional[str] = None, + figsize: tuple[float, float] = (15, 8), + scatter_args: dict[str, Any] = {"marker": "x"}, +): + """Returns a plot of the estimates displaying qubits vs runtime. + + Creates a log-log scatter plot where the x-axis shows the total runtime and + the y-axis shows the total number of physical qubits. + + *data* may be a single `EstimationTable` or an iterable of tables. When + multiple tables are provided, each is plotted as a separate series. If a + table has a `EstimationTable.name` (set via the *name* parameter of + `estimate`), it is used as the legend label for that series. + + When *runtime_unit* is ``None`` (the default), the x-axis uses + human-readable time-unit tick labels spanning nanoseconds to centuries. + When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to + that unit and the x-axis label includes the unit while the ticks are plain + numbers. + + Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, + ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, + ``"years"``. + + Args: + data: A single EstimationTable or an iterable of + EstimationTable objects to plot. + runtime_unit: Optional time unit to scale the x-axis to. + figsize: Figure dimensions in inches as ``(width, height)``. + scatter_args: Additional keyword arguments to pass to + ``matplotlib.axes.Axes.scatter`` when plotting the points. 
+ + Returns: + matplotlib.figure.Figure: The figure containing the plot. + + Raises: + ImportError: If matplotlib is not installed. + ValueError: If all tables are empty or *runtime_unit* is not + recognised. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "Missing optional 'matplotlib' dependency. To install run: " + "pip install matplotlib" + ) + + # Normalize to a list of tables + if isinstance(data, EstimationTable): + tables = [data] + else: + tables = list(data) + + if not tables or all(len(t) == 0 for t in tables): + raise ValueError("Cannot plot an empty EstimationTable.") + + if runtime_unit is not None and runtime_unit not in _TIME_UNITS: + raise ValueError( + f"Unknown runtime_unit {runtime_unit!r}. " + f"Supported units: {', '.join(_TIME_UNITS)}" + ) + + fig, ax = plt.subplots(figsize=figsize) + ax.set_ylabel("Physical qubits") + ax.set_xscale("log") + ax.set_yscale("log") + + all_xs: list[float] = [] + has_labels = False + + for table in tables: + if len(table) == 0: + continue + + ys = [entry.qubits for entry in table] + + if runtime_unit is not None: + scale = _TIME_UNITS[runtime_unit] + xs = [entry.runtime / scale for entry in table] + else: + xs = [float(entry.runtime) for entry in table] + + all_xs.extend(xs) + + label = table.name + if label is not None: + has_labels = True + + ax.scatter(x=xs, y=ys, label=label, **scatter_args) + + if runtime_unit is not None: + ax.set_xlabel(f"Runtime ({runtime_unit})") + else: + ax.set_xlabel("Runtime") + + time_labels, time_units = zip(*_TICK_UNITS) + + cutoff = ( + next( + (i for i, x in enumerate(time_units) if x > max(all_xs)), + len(time_units) - 1, + ) + + 1 + ) + + ax.set_xticks(time_units[:cutoff]) + ax.set_xticklabels(time_labels[:cutoff], rotation=90) + + if has_labels: + ax.legend() + + plt.close(fig) + + return fig From 0b921238e32a44135f056dae4422d01a8294f279 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Fri, 27 Mar 2026 09:40:41 +0100 Subject: 
[PATCH 39/45] Create QRE traces from cirq circuits (#3060) This adds support to generate traces from cirq circuits. These traces represent the original circuit and do not simply represent the gate counts. They monkey patch a `_to_trace` method to cirq gates and users of QRE can add a `_to_trace` method to their custom operations and gates. If no such method is found, the operation's or gate's decomposition method will be called to walk the circuit tree. The PR has some other small fixes: - extending the gate-based architecture gate set - minor fixes to `Trace` Python API - minor improvements on how to create _estimation table results_ (with extended info such as source ISA tree) from unannotated _estimation results_ --- source/pip/qsharp/qre/_estimation.py | 30 +- source/pip/qsharp/qre/_qre.pyi | 42 +- source/pip/qsharp/qre/application/__init__.py | 3 +- source/pip/qsharp/qre/application/_cirq.py | 48 ++ source/pip/qsharp/qre/interop/__init__.py | 10 +- source/pip/qsharp/qre/interop/_cirq.py | 417 ++++++++++++++++++ source/pip/qsharp/qre/interop/_qsharp.py | 2 +- source/pip/qsharp/qre/models/qubits/_aqre.py | 133 ++++-- source/pip/qsharp/qre/property_keys.pyi | 1 + source/pip/src/qre.rs | 32 +- source/pip/tests/qre/test_cirq_interop.py | 78 ++++ source/qre/src/isa/property_keys.rs | 1 + source/qre/src/trace.rs | 13 + 13 files changed, 747 insertions(+), 63 deletions(-) create mode 100644 source/pip/qsharp/qre/application/_cirq.py create mode 100644 source/pip/qsharp/qre/interop/_cirq.py create mode 100644 source/pip/tests/qre/test_cirq_interop.py diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 14b628b763..b49f92d60b 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -9,7 +9,7 @@ import pandas as pd from ._application import Application -from ._architecture import Architecture +from ._architecture import Architecture, _Context from ._qre import ( _estimate_parallel, 
_estimate_with_graph, @@ -17,6 +17,7 @@ Trace, FactoryResult, instruction_name, + EstimationResult, ) from ._trace import TraceQuery, PSSPC, LatticeSurgery from ._instruction import InstructionSource @@ -209,17 +210,9 @@ def estimate( if name is not None: table.insert_column(0, "name", lambda entry: name) - for result in collection: - entry = EstimationTableEntry( - qubits=result.qubits, - runtime=result.runtime, - error=result.error, - source=InstructionSource.from_isa(arch_ctx, result.isa), - factories=result.factories.copy(), - properties=result.properties.copy(), - ) - - table.append(entry) + table.extend( + EstimationTableEntry.from_result(result, arch_ctx) for result in collection + ) # Fill in the stats for this estimation run table.stats.num_traces = num_traces @@ -406,6 +399,19 @@ class EstimationTableEntry: factories: dict[int, FactoryResult] = field(default_factory=dict) properties: dict[int, int | float | bool | str] = field(default_factory=dict) + @classmethod + def from_result( + cls, result: EstimationResult, ctx: _Context + ) -> EstimationTableEntry: + return cls( + qubits=result.qubits, + runtime=result.runtime, + error=result.error, + source=InstructionSource.from_isa(ctx, result.isa), + factories=result.factories.copy(), + properties=result.properties.copy(), + ) + @dataclass(slots=True) class EstimationTableStats: diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 3f71b19c55..e143333df4 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -1219,6 +1219,16 @@ class Trace: """ ... + @compute_qubits.setter + def compute_qubits(self, qubits: int) -> None: + """ + Sets the number of compute qubits. + + Args: + qubits (int): The number of compute qubits to set. + """ + ... + @property def base_error(self) -> float: """ @@ -1257,7 +1267,8 @@ class Trace: """ ... 
- def set_memory_qubits(self, qubits: int) -> None: + @memory_qubits.setter + def memory_qubits(self, qubits: int) -> None: """ Sets the number of memory qubits. @@ -1322,6 +1333,16 @@ class Trace: """ ... + @property + def total_qubits(self) -> int: + """ + The total number of qubits (compute + memory). + + Returns: + int: The total number of qubits. + """ + ... + @property def depth(self) -> int: """ @@ -1332,6 +1353,16 @@ class Trace: """ ... + @property + def num_gates(self) -> int: + """ + The total number of gates in the trace. + + Returns: + int: The total number of gates. + """ + ... + def estimate( self, isa: ISA, max_error: Optional[float] = None ) -> Optional[EstimationResult]: @@ -1372,6 +1403,15 @@ class Trace: """ ... + def root_block(self) -> Block: + """ + Returns the root block of the trace. + + Returns: + Block: The root block of the trace. + """ + ... + def add_block(self, repetitions: int = 1) -> Block: """ Adds a block to the trace. diff --git a/source/pip/qsharp/qre/application/__init__.py b/source/pip/qsharp/qre/application/__init__.py index 9f36049425..7e39460c7e 100644 --- a/source/pip/qsharp/qre/application/__init__.py +++ b/source/pip/qsharp/qre/application/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from ._cirq import CirqApplication from ._qsharp import QSharpApplication -__all__ = ["QSharpApplication"] +__all__ = ["CirqApplication", "QSharpApplication"] diff --git a/source/pip/qsharp/qre/application/_cirq.py b/source/pip/qsharp/qre/application/_cirq.py new file mode 100644 index 0000000000..6f054213b7 --- /dev/null +++ b/source/pip/qsharp/qre/application/_cirq.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +from __future__ import annotations + +from dataclasses import dataclass + +import cirq + +from .._application import Application +from .._qre import Trace +from ..interop import trace_from_cirq + + +@dataclass +class CirqApplication(Application[None]): + """Application that produces a resource estimation trace from a Cirq circuit. + + Accepts either a Cirq ``Circuit`` object or an OpenQASM string. When a + QASM string is provided, it is parsed into a circuit using + ``cirq.contrib.qasm_import`` (requires the optional ``ply`` dependency). + + Args: + circuit_or_qasm: A Cirq Circuit or an OpenQASM string. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. Defaults to 0.5. + """ + + circuit_or_qasm: str | cirq.CIRCUIT_LIKE + classical_control_probability: float = 0.5 + + def __post_init__(self): + if isinstance(self.circuit_or_qasm, str): + try: + from cirq.contrib.qasm_import import circuit_from_qasm + + self._circuit = circuit_from_qasm(self.circuit_or_qasm) + except ImportError: + raise ImportError( + "Missing optional 'ply' dependency. To install run: " + "pip install ply" + ) + else: + self._circuit = self.circuit_or_qasm + + def get_trace(self, parameters: None = None) -> Trace: + return trace_from_cirq(self._circuit) diff --git a/source/pip/qsharp/qre/interop/__init__.py b/source/pip/qsharp/qre/interop/__init__.py index 5f49608679..bbf927d3e8 100644 --- a/source/pip/qsharp/qre/interop/__init__.py +++ b/source/pip/qsharp/qre/interop/__init__.py @@ -1,7 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from ._cirq import trace_from_cirq, PushBlock, PopBlock from ._qsharp import trace_from_entry_expr, trace_from_entry_expr_cached from ._qir import trace_from_qir -__all__ = ["trace_from_entry_expr", "trace_from_entry_expr_cached", "trace_from_qir"] +__all__ = [ + "trace_from_cirq", + "trace_from_entry_expr", + "trace_from_entry_expr_cached", + "trace_from_qir", + "PushBlock", + "PopBlock", +] diff --git a/source/pip/qsharp/qre/interop/_cirq.py b/source/pip/qsharp/qre/interop/_cirq.py new file mode 100644 index 0000000000..0153c00320 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_cirq.py @@ -0,0 +1,417 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import random +from dataclasses import dataclass +from math import pi +from typing import Iterable + +import cirq +from cirq import ( + HPowGate, + XPowGate, + YPowGate, + ZPowGate, + CXPowGate, + CZPowGate, + CCXPowGate, + CCZPowGate, + MeasurementGate, + ResetChannel, + GateOperation, + ClassicallyControlledOperation, + PhaseGradientGate, + SwapPowGate, +) +from qsharp.qre import Trace, Block +from qsharp.qre.instruction_ids import ( + H, + PAULI_X, + PAULI_Y, + PAULI_Z, + SQRT_X, + SQRT_X_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + S, + S_DAG, + T, + T_DAG, + CX, + CZ, + RX, + RY, + RZ, + MEAS_Z, + CCX, + CCZ, + SWAP, +) + +_TOLERANCE = 1e-8 + + +def _approx_eq(a: float, b: float) -> bool: + """Check whether two floats are approximately equal.""" + return abs(a - b) <= _TOLERANCE + + +def trace_from_cirq( + circuit: cirq.CIRCUIT_LIKE, *, classical_control_probability: float = 0.5 +) -> Trace: + """Convert a Cirq circuit into a resource estimation Trace. + + Iterates through all moments and operations in the circuit, converting + each gate into trace operations. 
Gates with a ``_to_trace`` method are + converted directly; others are recursively decomposed via Cirq's + ``_decompose_with_context_`` or ``_decompose_`` protocols. + + Args: + circuit: The Cirq circuit to convert. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. Defaults to 0.5. + + Returns: + A Trace representing the resource profile of the circuit. + """ + + if isinstance(circuit, cirq.Circuit): + # circuit is already in the expected format, so we can process it directly. + pass + elif isinstance(circuit, cirq.Gate): + circuit = cirq.Circuit(circuit.on(*cirq.LineQid.for_gate(circuit))) + else: + # circuit is OP_TREE + circuit = cirq.Circuit(circuit) + + context = _Context(circuit, classical_control_probability) + + for moment in circuit: + for op in moment.operations: + context.handle_op(op) + + return context.trace + + +class _Context: + """Tracks the current trace and block nesting during trace generation. + + Maintains a stack of blocks so that ``PushBlock`` and ``PopBlock`` + operations can create nested repeated sections in the trace. 
+ """ + + def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): + self._trace = Trace(len(circuit.all_qubits())) + self._classical_control_probability = classical_control_probability + self._blocks = [self._trace.root_block()] + self._q_to_id = _QidToTraceId(circuit.all_qubits()) + self._decomp_context = cirq.DecompositionContext( + qubit_manager=cirq.GreedyQubitManager("trace_from_cirq") + ) + + def push_block(self, repetitions: int): + block = self.block.add_block(repetitions) + self._blocks.append(block) + + def pop_block(self): + self._blocks.pop() + + @property + def trace(self) -> Trace: + self._trace.compute_qubits = len(self._q_to_id) + return self._trace + + @property + def block(self) -> Block: + return self._blocks[-1] + + @property + def q_to_id(self) -> _QidToTraceId: + return self._q_to_id + + @property + def classical_control_probability(self) -> float: + return self._classical_control_probability + + @property + def decomp_context(self) -> cirq.DecompositionContext: + return self._decomp_context + + def handle_op( + self, + op: cirq.OP_TREE | TraceGate | PushBlock | PopBlock, + ) -> None: + """Recursively convert a single operation into trace instructions. + + Supported operation forms: + + - ``TraceGate``: A raw trace instruction, added directly to the current block. + - ``PushBlock`` / ``PopBlock``: Control block nesting with repetitions. + - ``GateOperation``: Dispatched via ``_to_trace`` if available on the + gate, otherwise decomposed via ``_decompose_with_context_`` or + ``_decompose_``. + - ``ClassicallyControlledOperation``: Included with the probability + specified in the generation context. + - ``list``: Each element is handled recursively. + - Any other operation: Decomposed via ``_decompose_with_context_``. + + Args: + op: The operation to convert. 
+ """ + if isinstance(op, TraceGate): + qs = [ + self.q_to_id[q] + for q in ([op.qubits] if isinstance(op.qubits, cirq.Qid) else op.qubits) + ] + + if op.params is None: + self.block.add_operation(op.id, qs) + else: + self.block.add_operation( + op.id, qs, op.params if isinstance(op.params, list) else [op.params] + ) + elif isinstance(op, PushBlock): + self.push_block(op.repetitions) + elif isinstance(op, PopBlock): + self.pop_block() + elif isinstance(op, cirq.Operation): + if isinstance(op, GateOperation): + gate = op.gate + + if hasattr(gate, "_to_trace"): + for sub_op in gate._to_trace(self.decomp_context, op): # type: ignore + self.handle_op(sub_op) + elif hasattr(gate, "_decompose_with_context_"): + for sub_op in gate._decompose_with_context_(op.qubits, self.decomp_context): # type: ignore + self.handle_op(sub_op) + elif hasattr(gate, "_decompose_"): + # decompose the gate and handle the resulting operations recursively + for sub_op in gate._decompose_(op.qubits): # type: ignore + self.handle_op(sub_op) + else: + for sub_op in op._decompose_with_context_(self.decomp_context): # type: ignore + self.handle_op(sub_op) + elif isinstance(op, ClassicallyControlledOperation): + if random.random() < self.classical_control_probability: + self.handle_op(op.without_classical_controls()) + else: + for sub_op in op._decompose_with_context_(self.decomp_context): # type: ignore + self.handle_op(sub_op) + else: + # op is Iterable[OP_TREE] + for sub_op in op: + self.handle_op(sub_op) + + +@dataclass(frozen=True, slots=True) +class PushBlock: + """Signals the start of a repeated block in the trace. + + Args: + repetitions: Number of times the block is repeated. + """ + + repetitions: int + + +@dataclass(frozen=True, slots=True) +class PopBlock: + """Signals the end of the current repeated block in the trace.""" + + ... 
+ + +@dataclass(frozen=True, slots=True) +class TraceGate: + id: int + qubits: list[cirq.Qid] | cirq.Qid + params: list[float] | float | None = None + + +class _QidToTraceId(dict): + """Mapping from Cirq qubits to integer trace qubit indices. + + Initialized with a set of known qubits. If an unknown qubit is looked + up, it is automatically assigned the next available index. + """ + + def __init__(self, init: Iterable[cirq.Qid]): + super().__init__({q: i for i, q in enumerate(init)}) + + def __getitem__(self, key: cirq.Qid) -> int: + """ + If the key is not present, add it to the mapping with the next available id. + """ + + if key not in self: + self[key] = len(self) + return super().__getitem__(key) + + +def h_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(H, [op.qubits[0]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def x_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_X, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(SQRT_X, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(SQRT_X_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(SQRT_SQRT_X, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(SQRT_SQRT_X_DAG, q) + else: + yield TraceGate(RX, q, exp * pi) + + +def y_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_Y, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(SQRT_Y, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(SQRT_Y_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(SQRT_SQRT_Y, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(SQRT_SQRT_Y_DAG, q) + else: + yield TraceGate(RY, q, exp * pi) + + +def 
z_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_Z, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(S, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(S_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(T, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(T_DAG, q) + else: + yield TraceGate(RZ, q, exp * pi) + + +def cx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CX, [op.qubits[0], op.qubits[1]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def cz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + exp = self.exponent + c, t = op.qubits[0], op.qubits[1] + if _approx_eq(abs(exp), 1): + yield TraceGate(CZ, [c, t]) + elif _approx_eq(exp, 0.5): + # controlled S gate + yield TraceGate(T, [c]) + yield TraceGate(T, [t]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(T_DAG, [t]) + yield TraceGate(CZ, [c, t]) + elif _approx_eq(exp, -0.5): + # controlled S† gate + yield TraceGate(T_DAG, [c]) + yield TraceGate(T_DAG, [t]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(T, [t]) + yield TraceGate(CZ, [c, t]) + else: + rads = exp / 2 * pi + yield TraceGate(RZ, [c], [rads]) + yield TraceGate(RZ, [t], [rads]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(RZ, [t], [-rads]) + yield TraceGate(CZ, [c, t]) + + +def swap_pow_gate_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(SWAP, [op.qubits[0], op.qubits[1]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def ccx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CCX, [op.qubits[0], op.qubits[1], op.qubits[2]]) + 
else: + yield from op._decompose_with_context_(context) # type: ignore + + +def ccz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CCZ, [op.qubits[0], op.qubits[1], op.qubits[2]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def measurement_gate_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + for q in op.qubits: + yield TraceGate(MEAS_Z, [q]) + + +def reset_channel_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + yield from () + + +# Attach _to_trace methods to Cirq gate classes so that handle_op can +# convert them directly into trace instructions without decomposition. +HPowGate._to_trace = h_pow_gate_to_trace +XPowGate._to_trace = x_pow_gate_to_trace +YPowGate._to_trace = y_pow_gate_to_trace +ZPowGate._to_trace = z_pow_gate_to_trace +CXPowGate._to_trace = cx_pow_gate_to_trace +CZPowGate._to_trace = cz_pow_gate_to_trace +SwapPowGate._to_trace = swap_pow_gate_to_trace +CCXPowGate._to_trace = ccx_pow_gate_to_trace +CCZPowGate._to_trace = ccz_pow_gate_to_trace +MeasurementGate._to_trace = measurement_gate_to_trace +ResetChannel._to_trace = reset_channel_to_trace + +# Decomposition overrides + + +def phase_gradient_decompose(self, qubits): + """ + Overrides implementation of PhaseGradientGate._decompose_ to skip rotations + with very small angles. In particular the original implementation may lead + to FP overflows for large values of i. 
+ """ + + for i, q in enumerate(qubits): + exp = self.exponent / 2**i + if exp < 1e-16: + break + yield cirq.Z(q) ** exp + + +PhaseGradientGate._decompose_ = phase_gradient_decompose diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index d595ad9e9c..0a38644694 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -96,7 +96,7 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: block.add_operation(MEAS_Z, [0]) if memory_qubits != 0: - trace.set_memory_qubits(memory_qubits) + trace.memory_qubits = memory_qubits if rfm_count := counts.get("readFromMemoryCount", 0): block = trace.add_block(repetitions=rfm_count) diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_aqre.py index 2726daf219..6e6f09b8be 100644 --- a/source/pip/qsharp/qre/models/qubits/_aqre.py +++ b/source/pip/qsharp/qre/models/qubits/_aqre.py @@ -5,8 +5,34 @@ from typing import Optional from ..._architecture import Architecture, _Context -from ...instruction_ids import CNOT, CZ, MEAS_Z, PAULI_I, H, T from ..._instruction import ISA, Encoding +from ...instruction_ids import ( + CNOT, + CZ, + MEAS_X, + MEAS_Y, + MEAS_Z, + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + RX, + RY, + RZ, + S_DAG, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + T_DAG, + H, + S, + T, +) @dataclass @@ -54,46 +80,65 @@ def provided_isa(self, ctx: _Context) -> ISA: # Value is initialized in __post_init__ assert self.two_qubit_gate_time is not None - return ctx.make_isa( - ctx.add_instruction( - PAULI_I, - encoding=Encoding.PHYSICAL, - arity=1, - time=self.gate_time, - error_rate=self.error_rate, - ), - ctx.add_instruction( - CNOT, - encoding=Encoding.PHYSICAL, - arity=2, - time=self.two_qubit_gate_time, - error_rate=self.error_rate, - ), - ctx.add_instruction( - CZ, - encoding=Encoding.PHYSICAL, - 
arity=2, - time=self.two_qubit_gate_time, - error_rate=self.error_rate, - ), - ctx.add_instruction( - H, - encoding=Encoding.PHYSICAL, - arity=1, - time=self.gate_time, - error_rate=self.error_rate, - ), - ctx.add_instruction( - MEAS_Z, - encoding=Encoding.PHYSICAL, - arity=1, - time=self.measurement_time, - error_rate=self.error_rate, - ), - ctx.add_instruction( - T, - encoding=Encoding.PHYSICAL, - time=self.gate_time, - error_rate=self.error_rate, - ), - ) + # NOTE: This can be improved with instruction coercion once implemented. + instructions = [] + + # Single-qubit gates + single = [ + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + H, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + S, + S_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + T, + T_DAG, + RX, + RY, + RZ, + ] + + for instr in single: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=1, + time=self.gate_time, + error_rate=self.error_rate, + ) + ) + + for instr in [MEAS_X, MEAS_Y, MEAS_Z]: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=1, + time=self.measurement_time, + error_rate=self.error_rate, + ) + ) + + # Two-qubit gates + for instr in [CNOT, CZ]: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=2, + time=self.two_qubit_gate_time, + error_rate=self.error_rate, + ) + ) + + return ctx.make_isa(*instructions) diff --git a/source/pip/qsharp/qre/property_keys.pyi b/source/pip/qsharp/qre/property_keys.pyi index ed0b311821..0e0e2358b4 100644 --- a/source/pip/qsharp/qre/property_keys.pyi +++ b/source/pip/qsharp/qre/property_keys.pyi @@ -17,3 +17,4 @@ LOGICAL_COMPUTE_QUBITS: int LOGICAL_MEMORY_QUBITS: int ALGORITHM_COMPUTE_QUBITS: int ALGORITHM_MEMORY_QUBITS: int +NAME: int diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 6a47e50245..23b5f6baf7 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1072,6 +1072,11 @@ impl 
Trace { self.0.compute_qubits() } + #[setter] + pub fn set_compute_qubits(&mut self, qubits: u64) { + self.0.set_compute_qubits(qubits); + } + #[getter] pub fn base_error(&self) -> f64 { self.0.base_error() @@ -1137,11 +1142,21 @@ impl Trace { Ok(dict) } + #[getter] + pub fn total_qubits(&self) -> u64 { + self.0.total_qubits() + } + #[getter] pub fn depth(&self) -> u64 { self.0.depth() } + #[getter] + pub fn num_gates(&self) -> u64 { + self.0.num_gates() + } + #[pyo3(signature = (isa, max_error = None))] pub fn estimate(&self, isa: &ISA, max_error: Option) -> Option { self.0 @@ -1158,14 +1173,23 @@ impl Trace { self.0.add_operation(id, qubits, params); } + pub fn root_block(mut slf: PyRefMut<'_, Self>) -> Block { + let block = slf.0.root_block_mut(); + let ptr = NonNull::from(block); + Block { + ptr, + parent: slf.into(), + } + } + #[pyo3(signature = (repetitions = 1))] - pub fn add_block(mut slf: PyRefMut<'_, Self>, repetitions: u64) -> PyResult { + pub fn add_block(mut slf: PyRefMut<'_, Self>, repetitions: u64) -> Block { let block = slf.0.add_block(repetitions); let ptr = NonNull::from(block); - Ok(Block { + Block { ptr, parent: slf.into(), - }) + } } #[getter] @@ -1177,6 +1201,7 @@ impl Trace { self.0.has_memory_qubits() } + #[setter] pub fn set_memory_qubits(&mut self, qubits: u64) { self.0.set_memory_qubits(qubits); } @@ -1590,6 +1615,7 @@ fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { LOGICAL_MEMORY_QUBITS, ALGORITHM_COMPUTE_QUBITS, ALGORITHM_MEMORY_QUBITS, + NAME ); m.add_submodule(&property_keys)?; diff --git a/source/pip/tests/qre/test_cirq_interop.py b/source/pip/tests/qre/test_cirq_interop.py new file mode 100644 index 0000000000..b65fe39bdb --- /dev/null +++ b/source/pip/tests/qre/test_cirq_interop.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import cirq +from qsharp.qre.application import CirqApplication + + +def test_with_qft(): + _test_one_circuit(cirq.qft(*cirq.LineQubit.range(1025)), 1025, 212602, 266007) + + +def test_h(): + _test_one_circuit(cirq.H, 1, 1, 1) + _test_one_circuit(cirq.H**0.5, 1, 3, 3) + + +def test_cx(): + _test_one_circuit(cirq.CX, 2, 1, 1) + _test_one_circuit(cirq.CX**0.5, 2, 6, 7) + _test_one_circuit(cirq.CX**0.25, 2, 6, 7) + + +def test_cz(): + _test_one_circuit(cirq.CZ, 2, 1, 1) + _test_one_circuit(cirq.CZ**0.5, 2, 4, 5) + _test_one_circuit(cirq.CZ**0.25, 2, 4, 5) + + +def test_swap(): + _test_one_circuit(cirq.SWAP, 2, 1, 1) + _test_one_circuit(cirq.SWAP**0.5, 2, 8, 9) + + +def test_ccx(): + _test_one_circuit(cirq.CCX, 3, 1, 1) + _test_one_circuit(cirq.CCX**0.5, 3, 11, 17) + + +def test_ccz(): + _test_one_circuit(cirq.CCZ, 3, 1, 1) + _test_one_circuit(cirq.CCZ**0.5, 3, 10, 15) + + +def test_circuit_with_block(): + class CustomGate(cirq.Gate): + def num_qubits(self) -> int: + return 2 + + def _decompose_(self, qubits): + a, b = qubits + yield cirq.CX(a, b) + yield cirq.CX(b, a) + yield cirq.CX(a, b) + + q0, q1 = cirq.LineQubit.range(2) + _test_one_circuit( + [ + cirq.H.on_each(q0, q1), + CustomGate().on(q0, q1), + ], + 2, + 4, + 5, + ) + + +def _test_one_circuit( + circuit: cirq.CIRCUIT_LIKE, + expected_qubits: int, + expected_depth: int, + expected_gates: int, +): + app = CirqApplication(circuit) + trace = app.get_trace() + + assert trace.total_qubits == expected_qubits, "unexpected number of qubits in trace" + assert trace.depth == expected_depth, "unexpected depth of trace" + assert trace.num_gates == expected_gates, "unexpected number of gates in trace" diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs index 6f4e7ca877..16158e31db 100644 --- a/source/qre/src/isa/property_keys.rs +++ b/source/qre/src/isa/property_keys.rs @@ -62,4 +62,5 @@ define_properties! 
{ LOGICAL_MEMORY_QUBITS, ALGORITHM_COMPUTE_QUBITS, ALGORITHM_MEMORY_QUBITS, + NAME, } diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 23d8a67898..08858a4551 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -66,6 +66,10 @@ impl Trace { self.compute_qubits } + pub fn set_compute_qubits(&mut self, compute_qubits: u64) { + self.compute_qubits = compute_qubits; + } + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { self.block.add_operation(id, qubits, params); } @@ -74,6 +78,10 @@ impl Trace { self.block.add_block(repetitions) } + pub fn root_block_mut(&mut self) -> &mut Block { + &mut self.block + } + #[must_use] pub fn base_error(&self) -> f64 { self.base_error @@ -215,6 +223,11 @@ impl Trace { self.block.depth() } + #[must_use] + pub fn num_gates(&self) -> u64 { + self.deep_iter().map(|(_, m)| m).sum() + } + #[allow( clippy::cast_precision_loss, clippy::cast_possible_truncation, From 358171d9ff306682981b9efbf3759c5a76dad71a Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Fri, 27 Mar 2026 12:09:59 +0000 Subject: [PATCH 40/45] File missing from previous merge. --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 50d82c092a..d8660e70ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2082,7 +2082,7 @@ dependencies = [ "probability", "rustc-hash", "serde", - "thiserror 1.0.63", + "thiserror 2.0.18", ] [[package]] From 8f08a2b2abc965d3570b544087f63dab9b920f0e Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Mon, 30 Mar 2026 19:04:00 +0200 Subject: [PATCH 41/45] Reorganize files and spit up large files in QRE (#3069) Some files became too large, so they have been split up. Also some Python names are used in the public API and the `_` prefix was removed. The list of commits gives a good overview of what happened in this PR. There is no changes to the functionality or no added functionality in this PR. 
--- source/pip/benchmarks/bench_qre.py | 8 +- source/pip/qsharp/qre/__init__.py | 14 +- source/pip/qsharp/qre/_architecture.py | 46 +- source/pip/qsharp/qre/_estimation.py | 388 +--- source/pip/qsharp/qre/_instruction.py | 20 +- source/pip/qsharp/qre/_isa_enumeration.py | 34 +- source/pip/qsharp/qre/_qre.py | 2 +- source/pip/qsharp/qre/_qre.pyi | 66 +- source/pip/qsharp/qre/_results.py | 374 ++++ source/pip/qsharp/qre/interop/_cirq.py | 53 +- source/pip/qsharp/qre/models/__init__.py | 4 +- .../qsharp/qre/models/factories/_litinski.py | 6 +- .../qre/models/factories/_round_based.py | 14 +- .../pip/qsharp/qre/models/factories/_utils.py | 4 +- .../qsharp/qre/models/qec/_surface_code.py | 6 +- .../pip/qsharp/qre/models/qec/_three_aux.py | 6 +- source/pip/qsharp/qre/models/qec/_yoked.py | 10 +- .../pip/qsharp/qre/models/qubits/__init__.py | 4 +- .../qubits/{_aqre.py => _gate_based.py} | 15 +- source/pip/qsharp/qre/models/qubits/_msft.py | 4 +- source/pip/src/qre.rs | 4 +- source/pip/tests/qre/__init__.py | 2 + source/pip/tests/qre/conftest.py | 58 + source/pip/tests/qre/test_application.py | 216 +++ source/pip/tests/qre/test_enumeration.py | 527 ++++++ source/pip/tests/qre/test_estimation.py | 98 + source/pip/tests/qre/test_estimation_table.py | 439 +++++ source/pip/tests/qre/test_interop.py | 213 +++ source/pip/tests/qre/test_isa.py | 181 ++ .../test_models.py} | 100 +- source/pip/tests/test_qre.py | 1666 ----------------- source/qre/src/isa.rs | 329 +--- source/qre/src/isa/provenance.rs | 332 ++++ source/qre/src/trace.rs | 460 +---- source/qre/src/trace/estimation.rs | 462 +++++ 35 files changed, 3153 insertions(+), 3012 deletions(-) create mode 100644 source/pip/qsharp/qre/_results.py rename source/pip/qsharp/qre/models/qubits/{_aqre.py => _gate_based.py} (86%) create mode 100644 source/pip/tests/qre/__init__.py create mode 100644 source/pip/tests/qre/conftest.py create mode 100644 source/pip/tests/qre/test_application.py create mode 100644 
source/pip/tests/qre/test_enumeration.py create mode 100644 source/pip/tests/qre/test_estimation.py create mode 100644 source/pip/tests/qre/test_estimation_table.py create mode 100644 source/pip/tests/qre/test_interop.py create mode 100644 source/pip/tests/qre/test_isa.py rename source/pip/tests/{test_qre_models.py => qre/test_models.py} (90%) delete mode 100644 source/pip/tests/test_qre.py create mode 100644 source/qre/src/isa/provenance.rs create mode 100644 source/qre/src/trace/estimation.rs diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index e236594921..536669a8aa 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, KW_ONLY, field from qsharp.qre import linear_function, generic_function from qsharp.qre._architecture import _make_instruction -from qsharp.qre.models import AQREGateBased, SurfaceCode +from qsharp.qre.models import GateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -37,10 +37,10 @@ def bench_enumerate_isas(): # Add the tests directory to sys.path to import test_qre # TODO: Remove this once the models in test_qre are moved to a proper module - sys.path.append(os.path.join(os.path.dirname(__file__), "../tests")) - from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore + sys.path.append(os.path.join(os.path.dirname(__file__), "../tests/qre/")) + from conftest import ExampleLogicalFactory, ExampleFactory # type: ignore - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() + ctx = GateBased(gate_time=50, measurement_time=100).context() # Hierarchical factory using from_components query = SurfaceCode.q() * ExampleLogicalFactory.q( diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index a17bc2122c..6ba945acf1 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -3,13 +3,7 @@ from ._application 
import Application from ._architecture import Architecture -from ._estimation import ( - estimate, - EstimationTable, - EstimationTableColumn, - EstimationTableEntry, - plot_estimates, -) +from ._estimation import estimate from ._instruction import ( LOGICAL, PHYSICAL, @@ -37,6 +31,12 @@ property_name, property_name_to_key, ) +from ._results import ( + EstimationTable, + EstimationTableColumn, + EstimationTableEntry, + plot_estimates, +) from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform # Extend Rust Python types with additional Python-side functionality diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 9045caee73..1bfb3f29ff 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -10,7 +10,7 @@ from ._qre import ( ISA, _ProvenanceGraph, - _Instruction, + Instruction, _IntFunction, _FloatFunction, constant_function, @@ -25,7 +25,7 @@ class Architecture(ABC): @abstractmethod - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: """ Creates the ISA provided by this architecture, adding instructions directly to the context's provenance graph. @@ -39,12 +39,12 @@ def provided_isa(self, ctx: _Context) -> ISA: """ ... - def context(self) -> _Context: + def context(self) -> ISAContext: """Create a new enumeration context for this architecture.""" - return _Context(self) + return ISAContext(self) -class _Context: +class ISAContext: """ Context passed through enumeration, holding shared state. 
""" @@ -58,7 +58,7 @@ def __init__(self, arch: Architecture): self._bindings: dict[str, ISA] = {} self._transforms: dict[int, Architecture | ISATransform] = {0: arch} - def _with_binding(self, name: str, isa: ISA) -> _Context: + def _with_binding(self, name: str, isa: ISA) -> ISAContext: """Return a new context with an additional binding (internal use).""" ctx = copy.copy(self) ctx._bindings = {**self._bindings, name: isa} @@ -71,7 +71,7 @@ def isa(self) -> ISA: def add_instruction( self, - id_or_instruction: int | _Instruction, + id_or_instruction: int | Instruction, encoding: Encoding = 0, # type: ignore *, arity: Optional[int] = 1, @@ -80,7 +80,7 @@ def add_instruction( length: Optional[int | _IntFunction] = None, error_rate: float | _FloatFunction = 0.0, transform: ISATransform | None = None, - source: list[_Instruction] | None = None, + source: list[Instruction] | None = None, **kwargs: int, ) -> int: """ @@ -93,7 +93,7 @@ def add_instruction( ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) - 2. With a pre-existing ``_Instruction`` object (e.g. from + 2. With a pre-existing ``Instruction`` object (e.g. from ``with_id()``):: ctx.add_instruction(existing_instruction) @@ -107,26 +107,26 @@ def add_instruction( Args: id_or_instruction: Either an instruction ID (int) for creating - a new instruction, or an existing ``_Instruction`` object. + a new instruction, or an existing ``Instruction`` object. encoding: The instruction encoding (0 = Physical, 1 = Logical). - Ignored when passing an existing ``_Instruction``. + Ignored when passing an existing ``Instruction``. arity: The instruction arity. ``None`` for variable arity. - Ignored when passing an existing ``_Instruction``. + Ignored when passing an existing ``Instruction``. time: Instruction time in ns (or ``_IntFunction`` for variable - arity). Ignored when passing an existing ``_Instruction``. + arity). Ignored when passing an existing ``Instruction``. 
space: Instruction space in physical qubits (or ``_IntFunction`` for variable arity). Ignored when passing an existing - ``_Instruction``. + ``Instruction``. length: Arity including ancilla qubits. Ignored when passing an - existing ``_Instruction``. + existing ``Instruction``. error_rate: Instruction error rate (or ``_FloatFunction`` for variable arity). Ignored when passing an existing - ``_Instruction``. + ``Instruction``. transform: The ``ISATransform`` that produced the instruction. - source: List of source ``_Instruction`` objects consumed by the + source: List of source ``Instruction`` objects consumed by the transform. **kwargs: Additional properties (e.g. ``distance=9``). Ignored - when passing an existing ``_Instruction``. + when passing an existing ``Instruction``. Returns: The node index in the provenance graph. @@ -146,7 +146,7 @@ def add_instruction( **kwargs, ) - if isinstance(id_or_instruction, _Instruction): + if isinstance(id_or_instruction, Instruction): instr = id_or_instruction else: instr = _make_instruction( @@ -193,10 +193,10 @@ def _make_instruction( length: int | _IntFunction | None, error_rate: float | _FloatFunction, properties: dict[str, int], -) -> _Instruction: - """Build an ``_Instruction`` from keyword arguments.""" +) -> Instruction: + """Build an ``Instruction`` from keyword arguments.""" if arity is not None: - instr = _Instruction.fixed_arity( + instr = Instruction.fixed_arity( id, encoding, arity, @@ -215,7 +215,7 @@ def _make_instruction( if isinstance(error_rate, (int, float)): error_rate = constant_function(float(error_rate)) - instr = _Instruction.variable_arity( + instr = Instruction.variable_arity( id, encoding, time, diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index b49f92d60b..7f39fd1683 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -3,30 +3,20 @@ from __future__ import annotations -from dataclasses import dataclass, field 
-from typing import cast, Optional, Callable, Any, Iterable +from typing import cast, Optional, Any -import pandas as pd from ._application import Application -from ._architecture import Architecture, _Context +from ._architecture import Architecture from ._qre import ( _estimate_parallel, _estimate_with_graph, _EstimationCollection, Trace, - FactoryResult, - instruction_name, - EstimationResult, ) from ._trace import TraceQuery, PSSPC, LatticeSurgery -from ._instruction import InstructionSource from ._isa_enumeration import ISAQuery -from .property_keys import ( - PHYSICAL_COMPUTE_QUBITS, - PHYSICAL_MEMORY_QUBITS, - PHYSICAL_FACTORY_QUBITS, -) +from ._results import EstimationTable, EstimationTableEntry def estimate( @@ -136,12 +126,12 @@ def estimate( # trace, not on the ISA. trace_multipliers: dict[int, tuple[float, float]] = {} trace_sample_isa: dict[int, int] = {} - for t_idx, i_idx, _q, r in summaries: + for t_idx, isa_idx, _q, r in summaries: if t_idx not in trace_sample_isa: - trace_sample_isa[t_idx] = i_idx - for t_idx, i_idx in trace_sample_isa.items(): + trace_sample_isa[t_idx] = isa_idx + for t_idx, isa_idx in trace_sample_isa.items(): params, trace = params_and_traces[t_idx] - sample = trace.estimate(isas[i_idx], max_error) + sample = trace.estimate(isas[isa_idx], max_error) if sample is not None: pre_q = sample.qubits pre_r = sample.runtime @@ -150,12 +140,14 @@ def estimate( trace_multipliers[t_idx] = (pp.qubits / pre_q, pp.runtime / pre_r) # Phase 3: Estimate post-pp values and filter to Pareto candidates. 
- estimated_pp: list[tuple[int, int, int, int]] = [] # (t, i, q, est_r) - for t_idx, i_idx, q, r in summaries: + estimated_pp: list[tuple[int, int, int, int]] = ( + [] + ) # (t_idx, isa_idx, est_q, est_r) + for t_idx, isa_idx, q, r in summaries: mult_q, mult_r = trace_multipliers.get(t_idx, (0.0, 0.0)) est_q = int(q * mult_q) if mult_q > 0 else q est_r = int(r * mult_r) if mult_r > 0 else r - estimated_pp.append((t_idx, i_idx, est_q, est_r)) + estimated_pp.append((t_idx, isa_idx, est_q, est_r)) # Build approximate post-pp Pareto frontier to identify candidates. estimated_pp.sort(key=lambda x: (x[2], x[3])) # sort by qubits, then runtime @@ -168,9 +160,9 @@ def estimate( # Phase 4: Re-estimate and post-process only the Pareto candidates. pp_collection = _EstimationCollection() - for t_idx, i_idx, _q, _r in approx_pareto: + for t_idx, isa_idx, _q, _r in approx_pareto: params, trace = params_and_traces[t_idx] - result = trace.estimate(isas[i_idx], max_error) + result = trace.estimate(isas[isa_idx], max_error) if result is not None: pp_result = app_ctx.application.post_process(params, result) if pp_result is not None: @@ -222,355 +214,3 @@ def estimate( table.stats.pareto_results = len(collection) return table - - -class EstimationTable(list["EstimationTableEntry"]): - """A table of quantum resource estimation results. - - Extends ``list[EstimationTableEntry]`` and provides configurable columns for - displaying estimation data. By default the table includes *qubits*, - *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. - Additional columns can be added or inserted with :meth:`add_column` and - :meth:`insert_column`. 
- """ - - def __init__(self): - """Initialize an empty estimation table with default columns.""" - super().__init__() - - self.name: Optional[str] = None - self.stats = EstimationTableStats() - - self._columns: list[tuple[str, EstimationTableColumn]] = [ - ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), - ( - "runtime", - EstimationTableColumn( - lambda entry: entry.runtime, - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ), - ), - ("error", EstimationTableColumn(lambda entry: entry.error)), - ] - - def add_column( - self, - name: str, - function: Callable[[EstimationTableEntry], Any], - formatter: Optional[Callable[[Any], Any]] = None, - ) -> None: - """Adds a column to the estimation table. - - Args: - name (str): The name of the column. - function (Callable[[EstimationTableEntry], Any]): A function that - takes an EstimationTableEntry and returns the value for this - column. - formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. - """ - self._columns.append((name, EstimationTableColumn(function, formatter))) - - def insert_column( - self, - index: int, - name: str, - function: Callable[[EstimationTableEntry], Any], - formatter: Optional[Callable[[Any], Any]] = None, - ) -> None: - """Inserts a column at the specified index in the estimation table. - - Args: - index (int): The index at which to insert the column. - name (str): The name of the column. - function (Callable[[EstimationTableEntry], Any]): A function that - takes an EstimationTableEntry and returns the value for this - column. - formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. 
- """ - self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) - - def add_qubit_partition_column(self) -> None: - self.add_column( - "physical_compute_qubits", - lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), - ) - self.add_column( - "physical_factory_qubits", - lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), - ) - self.add_column( - "physical_memory_qubits", - lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), - ) - - def add_factory_summary_column(self) -> None: - """Adds a column to the estimation table that summarizes the factories used in the estimation.""" - - def summarize_factories(entry: EstimationTableEntry) -> str: - if not entry.factories: - return "None" - return ", ".join( - f"{factory_result.copies}×{instruction_name(id)}" - for id, factory_result in entry.factories.items() - ) - - self.add_column("factories", summarize_factories) - - def as_frame(self): - """Convert the estimation table to a :class:`pandas.DataFrame`. - - Each row corresponds to an :class:`EstimationTableEntry` and each - column is determined by the columns registered on this table. Column - formatters, when present, are applied to the values before they are - placed in the frame. - - Returns: - pandas.DataFrame: A DataFrame representation of the estimation - results. - """ - return pd.DataFrame( - [ - { - column_name: ( - column.formatter(column.function(entry)) - if column.formatter is not None - else column.function(entry) - ) - for column_name, column in self._columns - } - for entry in self - ] - ) - - def plot(self, **kwargs): - """Plot this table's results. - - Convenience wrapper around :func:`plot_estimates`. All keyword - arguments are forwarded. - - Returns: - matplotlib.figure.Figure: The figure containing the plot. - """ - return plot_estimates(self, **kwargs) - - -@dataclass(frozen=True, slots=True) -class EstimationTableColumn: - """Definition of a single column in an :class:`EstimationTable`. 
- - Attributes: - function: A callable that extracts the raw column value from an - :class:`EstimationTableEntry`. - formatter: An optional callable that transforms the raw value for - display purposes (e.g. converting nanoseconds to a - ``pandas.Timedelta``). - """ - - function: Callable[[EstimationTableEntry], Any] - formatter: Optional[Callable[[Any], Any]] = None - - -@dataclass(frozen=True, slots=True) -class EstimationTableEntry: - """A single row in an :class:`EstimationTable`. - - Each entry represents one Pareto-optimal estimation result for a - particular combination of application trace and architecture ISA. - - Attributes: - qubits: Total number of physical qubits required. - runtime: Total runtime of the algorithm in nanoseconds. - error: Total estimated error probability. - source: The instruction source derived from the architecture ISA used - for this estimation. - factories: A mapping from instruction id to the - :class:`FactoryResult` describing the magic-state factory used - and the number of copies required. - properties: Additional key-value properties attached to the - estimation result. - """ - - qubits: int - runtime: int - error: float - source: InstructionSource - factories: dict[int, FactoryResult] = field(default_factory=dict) - properties: dict[int, int | float | bool | str] = field(default_factory=dict) - - @classmethod - def from_result( - cls, result: EstimationResult, ctx: _Context - ) -> EstimationTableEntry: - return cls( - qubits=result.qubits, - runtime=result.runtime, - error=result.error, - source=InstructionSource.from_isa(ctx, result.isa), - factories=result.factories.copy(), - properties=result.properties.copy(), - ) - - -@dataclass(slots=True) -class EstimationTableStats: - num_traces: int = 0 - num_isas: int = 0 - total_jobs: int = 0 - successful_estimates: int = 0 - pareto_results: int = 0 - - -# Mapping from runtime unit name to its value in nanoseconds. 
-_TIME_UNITS: dict[str, float] = { - "ns": 1, - "µs": 1e3, - "us": 1e3, - "ms": 1e6, - "s": 1e9, - "min": 60e9, - "hours": 3600e9, - "days": 86_400e9, - "weeks": 604_800e9, - "months": 31 * 86_400e9, - "years": 365 * 86_400e9, - "decades": 10 * 365 * 86_400e9, - "centuries": 100 * 365 * 86_400e9, -} - -# Ordered subset of _TIME_UNITS used for default x-axis tick labels. -_TICK_UNITS: list[tuple[str, float]] = [ - ("1 ns", _TIME_UNITS["ns"]), - ("1 µs", _TIME_UNITS["µs"]), - ("1 ms", _TIME_UNITS["ms"]), - ("1 s", _TIME_UNITS["s"]), - ("1 min", _TIME_UNITS["min"]), - ("1 hour", _TIME_UNITS["hours"]), - ("1 day", _TIME_UNITS["days"]), - ("1 week", _TIME_UNITS["weeks"]), - ("1 month", _TIME_UNITS["months"]), - ("1 year", _TIME_UNITS["years"]), - ("1 decade", _TIME_UNITS["decades"]), - ("1 century", _TIME_UNITS["centuries"]), -] - - -def plot_estimates( - data: EstimationTable | Iterable[EstimationTable], - *, - runtime_unit: Optional[str] = None, - figsize: tuple[float, float] = (15, 8), - scatter_args: dict[str, Any] = {"marker": "x"}, -): - """Returns a plot of the estimates displaying qubits vs runtime. - - Creates a log-log scatter plot where the x-axis shows the total runtime and - the y-axis shows the total number of physical qubits. - - *data* may be a single `EstimationTable` or an iterable of tables. When - multiple tables are provided, each is plotted as a separate series. If a - table has a `EstimationTable.name` (set via the *name* parameter of - `estimate`), it is used as the legend label for that series. - - When *runtime_unit* is ``None`` (the default), the x-axis uses - human-readable time-unit tick labels spanning nanoseconds to centuries. - When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to - that unit and the x-axis label includes the unit while the ticks are plain - numbers. 
- - Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, - ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, - ``"years"``. - - Args: - data: A single EstimationTable or an iterable of - EstimationTable objects to plot. - runtime_unit: Optional time unit to scale the x-axis to. - figsize: Figure dimensions in inches as ``(width, height)``. - scatter_args: Additional keyword arguments to pass to - ``matplotlib.axes.Axes.scatter`` when plotting the points. - - Returns: - matplotlib.figure.Figure: The figure containing the plot. - - Raises: - ImportError: If matplotlib is not installed. - ValueError: If all tables are empty or *runtime_unit* is not - recognised. - """ - try: - import matplotlib.pyplot as plt - except ImportError: - raise ImportError( - "Missing optional 'matplotlib' dependency. To install run: " - "pip install matplotlib" - ) - - # Normalize to a list of tables - if isinstance(data, EstimationTable): - tables = [data] - else: - tables = list(data) - - if not tables or all(len(t) == 0 for t in tables): - raise ValueError("Cannot plot an empty EstimationTable.") - - if runtime_unit is not None and runtime_unit not in _TIME_UNITS: - raise ValueError( - f"Unknown runtime_unit {runtime_unit!r}. 
" - f"Supported units: {', '.join(_TIME_UNITS)}" - ) - - fig, ax = plt.subplots(figsize=figsize) - ax.set_ylabel("Physical qubits") - ax.set_xscale("log") - ax.set_yscale("log") - - all_xs: list[float] = [] - has_labels = False - - for table in tables: - if len(table) == 0: - continue - - ys = [entry.qubits for entry in table] - - if runtime_unit is not None: - scale = _TIME_UNITS[runtime_unit] - xs = [entry.runtime / scale for entry in table] - else: - xs = [float(entry.runtime) for entry in table] - - all_xs.extend(xs) - - label = table.name - if label is not None: - has_labels = True - - ax.scatter(x=xs, y=ys, label=label, **scatter_args) - - if runtime_unit is not None: - ax.set_xlabel(f"Runtime ({runtime_unit})") - else: - ax.set_xlabel("Runtime") - - time_labels, time_units = zip(*_TICK_UNITS) - - cutoff = ( - next( - (i for i, x in enumerate(time_units) if x > max(all_xs)), - len(time_units) - 1, - ) - + 1 - ) - - ax.set_xticks(time_units[:cutoff]) - ax.set_xticklabels(time_labels[:cutoff], rotation=90) - - if has_labels: - ax.legend() - - plt.close(fig) - - return fig diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index de54bfd657..ab3c176e69 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -10,7 +10,7 @@ import pandas as pd -from ._architecture import _Context, Architecture +from ._architecture import ISAContext, Architecture from ._enumeration import _enumerate_instances from ._isa_enumeration import ( ISA_ROOT, @@ -22,7 +22,7 @@ ISA, Constraint, ConstraintBound, - _Instruction, + Instruction, ISARequirements, instruction_name, property_name_to_key, @@ -97,7 +97,9 @@ def required_isa() -> ISARequirements: ... 
@abstractmethod - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: """ Yields ISAs provided by this transform given an implementation ISA. @@ -113,7 +115,7 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non def enumerate_isas( cls, impl_isa: ISA | Iterable[ISA], - ctx: _Context, + ctx: ISAContext, **kwargs, ) -> Generator[ISA, None, None]: """ @@ -178,7 +180,7 @@ class InstructionSource: roots: list[int] = field(default_factory=list, init=False) @classmethod - def from_isa(cls, ctx: _Context, isa: ISA) -> InstructionSource: + def from_isa(cls, ctx: ISAContext, isa: ISA) -> InstructionSource: """ Constructs an InstructionSource graph from an ISA. @@ -187,7 +189,7 @@ def from_isa(cls, ctx: _Context, isa: ISA) -> InstructionSource: transforms and architectures that generated them. Args: - ctx (_Context): The enumeration context containing the provenance graph. + ctx (ISAContext): The enumeration context containing the provenance graph. isa (ISA): Instructions in the ISA will serve as root nodes in the source graph. 
Returns: @@ -231,7 +233,7 @@ def add_root(self, node_id: int) -> None: def add_node( self, - instruction: _Instruction, + instruction: Instruction, transform: Optional[ISATransform | Architecture], children: list[int], ) -> int: @@ -311,7 +313,7 @@ def get( @dataclass(frozen=True, slots=True) class _InstructionSourceNode: - instruction: _Instruction + instruction: Instruction transform: Optional[ISATransform | Architecture] children: list[int] @@ -322,7 +324,7 @@ def __init__(self, graph: InstructionSource, node_id: int): self.node_id = node_id @property - def instruction(self) -> _Instruction: + def instruction(self) -> Instruction: return self.graph.nodes[self.node_id].instruction @property diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 5cbb9fa187..c33fdac435 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Generator -from ._architecture import _Context +from ._architecture import ISAContext from ._enumeration import _enumerate_instances from ._qre import ISA @@ -25,7 +25,7 @@ class ISAQuery(ABC): """ @abstractmethod - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields all ISA instances represented by this enumeration node. @@ -38,7 +38,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ pass - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """ Populates the provenance graph with instructions from this node. @@ -47,7 +47,7 @@ def populate(self, ctx: _Context) -> int: requirements, and adds produced instructions directly to the graph. Args: - ctx (_Context): The enumeration context whose provenance graph + ctx (ISAContext): The enumeration context whose provenance graph will be populated. 
Returns: @@ -158,7 +158,7 @@ class RootNode(ISAQuery): Reads from the context instead of holding a reference. """ - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields the architecture ISA from the context. @@ -170,8 +170,8 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ yield ctx._isa - def populate(self, ctx: _Context) -> int: - """Architecture ISA is already in the graph from ``_Context.__init__``. + def populate(self, ctx: ISAContext) -> int: + """Architecture ISA is already in the graph from ``ISAContext.__init__``. Returns: int: 1, since architecture nodes start at index 1. @@ -203,7 +203,7 @@ class _ComponentQuery(ISAQuery): source: ISAQuery = field(default_factory=lambda: ISA_ROOT) kwargs: dict = field(default_factory=dict) - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs generated by the component from source ISAs. @@ -216,7 +216,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa in self.source.enumerate(ctx): yield from self.component.enumerate_isas(isa, ctx, **self.kwargs) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """ Populates the graph by querying matching instructions. @@ -253,7 +253,7 @@ class _ProductNode(ISAQuery): sources: list[ISAQuery] - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs formed by combining ISAs from all source nodes. @@ -269,7 +269,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa_tuple in itertools.product(*source_generators) ) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from each source sequentially (no cross product). 
Returns: @@ -292,7 +292,7 @@ class _SumNode(ISAQuery): sources: list[ISAQuery] - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs from each source node in sequence. @@ -305,7 +305,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for source in self.sources: yield from source.enumerate(ctx) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from each source sequentially. Returns: @@ -330,7 +330,7 @@ class ISARefNode(ISAQuery): name: str - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields the bound ISA from the context. @@ -347,7 +347,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: raise ValueError(f"Undefined component reference: '{self.name}'") yield ctx._bindings[self.name] - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Instructions already in graph from the bound component. Returns: @@ -401,7 +401,7 @@ class _BindingNode(ISAQuery): component: ISAQuery node: ISAQuery - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Enumerates child nodes with the bound component in context. @@ -417,7 +417,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: new_ctx = ctx._with_binding(self.name, isa) yield from self.node.enumerate(new_ctx) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from both the component and the child node. 
Returns: diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index f724349388..2d1aaa7aa5 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -19,7 +19,7 @@ _FloatFunction, generic_function, instruction_name, - _Instruction, + Instruction, InstructionFrontier, _IntFunction, ISA, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index e143333df4..7e9f92ddc8 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -33,7 +33,7 @@ class ISA: """ ... - def __getitem__(self, id: int) -> _Instruction: + def __getitem__(self, id: int) -> Instruction: """ Gets an instruction by its ID. @@ -41,23 +41,23 @@ class ISA: id (int): The instruction ID. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... def get( - self, id: int, default: Optional[_Instruction] = None - ) -> Optional[_Instruction]: + self, id: int, default: Optional[Instruction] = None + ) -> Optional[Instruction]: """ Gets an instruction by its ID, or returns a default value if not found. Args: id (int): The instruction ID. - default (Optional[_Instruction]): The default value to return if the + default (Optional[Instruction]): The default value to return if the instruction is not found. Returns: - Optional[_Instruction]: The instruction, or the default value if not found. + Optional[Instruction]: The instruction, or the default value if not found. """ ... @@ -105,7 +105,7 @@ class ISA: """ ... - def __iter__(self) -> Iterator[_Instruction]: + def __iter__(self) -> Iterator[Instruction]: """ Returns an iterator over the instructions. @@ -113,7 +113,7 @@ class ISA: The order of instructions is not guaranteed. Returns: - Iterator[_Instruction]: The instruction iterator. + Iterator[Instruction]: The instruction iterator. """ ... @@ -178,7 +178,7 @@ class ISARequirements: """ ... 
-class _Instruction: +class Instruction: @staticmethod def fixed_arity( id: int, @@ -188,7 +188,7 @@ class _Instruction: space: Optional[int], length: Optional[int], error_rate: float, - ) -> _Instruction: + ) -> Instruction: """ Creates an instruction with a fixed arity. @@ -207,7 +207,7 @@ class _Instruction: error_rate (float): The instruction error rate. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... @@ -219,7 +219,7 @@ class _Instruction: space_fn: _IntFunction, error_rate_fn: _FloatFunction, length_fn: Optional[_IntFunction], - ) -> _Instruction: + ) -> Instruction: """ Creates an instruction with variable arity. @@ -236,11 +236,11 @@ class _Instruction: If None, space_fn is used. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... - def with_id(self, id: int) -> _Instruction: + def with_id(self, id: int) -> Instruction: """ Returns a copy of the instruction with the given ID. @@ -252,7 +252,7 @@ class _Instruction: id (int): The instruction ID. Returns: - _Instruction: A copy of the instruction with the given ID. + Instruction: A copy of the instruction with the given ID. """ ... @@ -702,7 +702,7 @@ class _ProvenanceGraph: """ def add_node( - self, instruction: _Instruction, transform_id: int, children: list[int] + self, instruction: Instruction, transform_id: int, children: list[int] ) -> int: """ Adds a node to the provenance graph. @@ -717,7 +717,7 @@ class _ProvenanceGraph: """ ... - def instruction(self, node_index: int) -> _Instruction: + def instruction(self, node_index: int) -> Instruction: """ Returns the instruction for a given node index. @@ -774,7 +774,7 @@ class _ProvenanceGraph: @overload def add_instruction( self, - instruction: _Instruction, + instruction: Instruction, ) -> int: ... @overload def add_instruction( @@ -791,7 +791,7 @@ class _ProvenanceGraph: ) -> int: ... 
def add_instruction( self, - id_or_instruction: int | _Instruction, + id_or_instruction: int | Instruction, encoding: int = 0, *, arity: Optional[int] = 1, @@ -805,20 +805,20 @@ class _ProvenanceGraph: Adds an instruction to the provenance graph with no transform or children. - Can be called with a pre-existing ``_Instruction`` or with keyword + Can be called with a pre-existing ``Instruction`` or with keyword args to create one inline. Args: - id_or_instruction: An instruction ID (int) or ``_Instruction``. - encoding: 0 = Physical, 1 = Logical. Ignored for ``_Instruction``. + id_or_instruction: An instruction ID (int) or ``Instruction``. + encoding: 0 = Physical, 1 = Logical. Ignored for ``Instruction``. arity: Instruction arity, ``None`` for variable. Ignored for - ``_Instruction``. - time: Time in ns (or ``_IntFunction``). Ignored for ``_Instruction``. + ``Instruction``. + time: Time in ns (or ``_IntFunction``). Ignored for ``Instruction``. space: Space in physical qubits (or ``_IntFunction``). Ignored for - ``_Instruction``. - length: Arity including ancillas. Ignored for ``_Instruction``. + ``Instruction``. + length: Arity including ancillas. Ignored for ``Instruction``. error_rate: Error rate (or ``_FloatFunction``). Ignored for - ``_Instruction``. + ``Instruction``. **kwargs: Additional properties (e.g. ``distance=9``). Returns: @@ -1511,21 +1511,21 @@ class InstructionFrontier: """ ... - def insert(self, point: _Instruction): + def insert(self, point: Instruction): """ Inserts an instruction to the frontier. Args: - point (_Instruction): The instruction to insert. + point (Instruction): The instruction to insert. """ ... - def extend(self, points: list[_Instruction]) -> None: + def extend(self, points: list[Instruction]) -> None: """ Extends the frontier with a list of instructions. Args: - points (list[_Instruction]): The instructions to insert. + points (list[Instruction]): The instructions to insert. """ ... 
@@ -1538,12 +1538,12 @@ class InstructionFrontier: """ ... - def __iter__(self) -> Iterator[_Instruction]: + def __iter__(self) -> Iterator[Instruction]: """ Returns an iterator over the instructions in the frontier. Returns: - Iterator[_Instruction]: The iterator. + Iterator[Instruction]: The iterator. """ ... diff --git a/source/pip/qsharp/qre/_results.py b/source/pip/qsharp/qre/_results.py new file mode 100644 index 0000000000..efaa3be144 --- /dev/null +++ b/source/pip/qsharp/qre/_results.py @@ -0,0 +1,374 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Callable, Any, Iterable + +import pandas as pd + +from ._architecture import ISAContext +from ._qre import ( + FactoryResult, + instruction_name, + EstimationResult, +) +from ._instruction import InstructionSource +from .property_keys import ( + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_MEMORY_QUBITS, + PHYSICAL_FACTORY_QUBITS, +) + + +class EstimationTable(list["EstimationTableEntry"]): + """A table of quantum resource estimation results. + + Extends ``list[EstimationTableEntry]`` and provides configurable columns for + displaying estimation data. By default the table includes *qubits*, + *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. + Additional columns can be added or inserted with :meth:`add_column` and + :meth:`insert_column`. 
+ """ + + def __init__(self): + """Initialize an empty estimation table with default columns.""" + super().__init__() + + self.name: Optional[str] = None + self.stats = EstimationTableStats() + + self._columns: list[tuple[str, EstimationTableColumn]] = [ + ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), + ( + "runtime", + EstimationTableColumn( + lambda entry: entry.runtime, + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ), + ), + ("error", EstimationTableColumn(lambda entry: entry.error)), + ] + + def add_column( + self, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Adds a column to the estimation table. + + Args: + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. + """ + self._columns.append((name, EstimationTableColumn(function, formatter))) + + def insert_column( + self, + index: int, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Inserts a column at the specified index in the estimation table. + + Args: + index (int): The index at which to insert the column. + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. 
+ """ + self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) + + def add_qubit_partition_column(self) -> None: + self.add_column( + "physical_compute_qubits", + lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), + ) + self.add_column( + "physical_factory_qubits", + lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), + ) + self.add_column( + "physical_memory_qubits", + lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), + ) + + def add_factory_summary_column(self) -> None: + """Adds a column to the estimation table that summarizes the factories used in the estimation.""" + + def summarize_factories(entry: EstimationTableEntry) -> str: + if not entry.factories: + return "None" + return ", ".join( + f"{factory_result.copies}×{instruction_name(id)}" + for id, factory_result in entry.factories.items() + ) + + self.add_column("factories", summarize_factories) + + def as_frame(self): + """Convert the estimation table to a :class:`pandas.DataFrame`. + + Each row corresponds to an :class:`EstimationTableEntry` and each + column is determined by the columns registered on this table. Column + formatters, when present, are applied to the values before they are + placed in the frame. + + Returns: + pandas.DataFrame: A DataFrame representation of the estimation + results. + """ + return pd.DataFrame( + [ + { + column_name: ( + column.formatter(column.function(entry)) + if column.formatter is not None + else column.function(entry) + ) + for column_name, column in self._columns + } + for entry in self + ] + ) + + def plot(self, **kwargs): + """Plot this table's results. + + Convenience wrapper around :func:`plot_estimates`. All keyword + arguments are forwarded. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. + """ + return plot_estimates(self, **kwargs) + + +@dataclass(frozen=True, slots=True) +class EstimationTableColumn: + """Definition of a single column in an :class:`EstimationTable`. 
+ + Attributes: + function: A callable that extracts the raw column value from an + :class:`EstimationTableEntry`. + formatter: An optional callable that transforms the raw value for + display purposes (e.g. converting nanoseconds to a + ``pandas.Timedelta``). + """ + + function: Callable[[EstimationTableEntry], Any] + formatter: Optional[Callable[[Any], Any]] = None + + +@dataclass(frozen=True, slots=True) +class EstimationTableEntry: + """A single row in an :class:`EstimationTable`. + + Each entry represents one Pareto-optimal estimation result for a + particular combination of application trace and architecture ISA. + + Attributes: + qubits: Total number of physical qubits required. + runtime: Total runtime of the algorithm in nanoseconds. + error: Total estimated error probability. + source: The instruction source derived from the architecture ISA used + for this estimation. + factories: A mapping from instruction id to the + :class:`FactoryResult` describing the magic-state factory used + and the number of copies required. + properties: Additional key-value properties attached to the + estimation result. + """ + + qubits: int + runtime: int + error: float + source: InstructionSource + factories: dict[int, FactoryResult] = field(default_factory=dict) + properties: dict[int, int | float | bool | str] = field(default_factory=dict) + + @classmethod + def from_result( + cls, result: EstimationResult, ctx: ISAContext + ) -> EstimationTableEntry: + return cls( + qubits=result.qubits, + runtime=result.runtime, + error=result.error, + source=InstructionSource.from_isa(ctx, result.isa), + factories=result.factories.copy(), + properties=result.properties.copy(), + ) + + +@dataclass(slots=True) +class EstimationTableStats: + num_traces: int = 0 + num_isas: int = 0 + total_jobs: int = 0 + successful_estimates: int = 0 + pareto_results: int = 0 + + +# Mapping from runtime unit name to its value in nanoseconds. 
+_TIME_UNITS: dict[str, float] = { + "ns": 1, + "µs": 1e3, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + "min": 60e9, + "hours": 3600e9, + "days": 86_400e9, + "weeks": 604_800e9, + "months": 31 * 86_400e9, + "years": 365 * 86_400e9, + "decades": 10 * 365 * 86_400e9, + "centuries": 100 * 365 * 86_400e9, +} + +# Ordered subset of _TIME_UNITS used for default x-axis tick labels. +_TICK_UNITS: list[tuple[str, float]] = [ + ("1 ns", _TIME_UNITS["ns"]), + ("1 µs", _TIME_UNITS["µs"]), + ("1 ms", _TIME_UNITS["ms"]), + ("1 s", _TIME_UNITS["s"]), + ("1 min", _TIME_UNITS["min"]), + ("1 hour", _TIME_UNITS["hours"]), + ("1 day", _TIME_UNITS["days"]), + ("1 week", _TIME_UNITS["weeks"]), + ("1 month", _TIME_UNITS["months"]), + ("1 year", _TIME_UNITS["years"]), + ("1 decade", _TIME_UNITS["decades"]), + ("1 century", _TIME_UNITS["centuries"]), +] + + +def plot_estimates( + data: EstimationTable | Iterable[EstimationTable], + *, + runtime_unit: Optional[str] = None, + figsize: tuple[float, float] = (15, 8), + scatter_args: dict[str, Any] = {"marker": "x"}, +): + """Returns a plot of the estimates displaying qubits vs runtime. + + Creates a log-log scatter plot where the x-axis shows the total runtime and + the y-axis shows the total number of physical qubits. + + *data* may be a single `EstimationTable` or an iterable of tables. When + multiple tables are provided, each is plotted as a separate series. If a + table has a `EstimationTable.name` (set via the *name* parameter of + `estimate`), it is used as the legend label for that series. + + When *runtime_unit* is ``None`` (the default), the x-axis uses + human-readable time-unit tick labels spanning nanoseconds to centuries. + When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to + that unit and the x-axis label includes the unit while the ticks are plain + numbers. 
+ + Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, + ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, + ``"years"``. + + Args: + data: A single EstimationTable or an iterable of + EstimationTable objects to plot. + runtime_unit: Optional time unit to scale the x-axis to. + figsize: Figure dimensions in inches as ``(width, height)``. + scatter_args: Additional keyword arguments to pass to + ``matplotlib.axes.Axes.scatter`` when plotting the points. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. + + Raises: + ImportError: If matplotlib is not installed. + ValueError: If all tables are empty or *runtime_unit* is not + recognised. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "Missing optional 'matplotlib' dependency. To install run: " + "pip install matplotlib" + ) + + # Normalize to a list of tables + if isinstance(data, EstimationTable): + tables = [data] + else: + tables = list(data) + + if not tables or all(len(t) == 0 for t in tables): + raise ValueError("Cannot plot an empty EstimationTable.") + + if runtime_unit is not None and runtime_unit not in _TIME_UNITS: + raise ValueError( + f"Unknown runtime_unit {runtime_unit!r}. 
" + f"Supported units: {', '.join(_TIME_UNITS)}" + ) + + fig, ax = plt.subplots(figsize=figsize) + ax.set_ylabel("Physical qubits") + ax.set_xscale("log") + ax.set_yscale("log") + + all_xs: list[float] = [] + has_labels = False + + for table in tables: + if len(table) == 0: + continue + + ys = [entry.qubits for entry in table] + + if runtime_unit is not None: + scale = _TIME_UNITS[runtime_unit] + xs = [entry.runtime / scale for entry in table] + else: + xs = [float(entry.runtime) for entry in table] + + all_xs.extend(xs) + + label = table.name + if label is not None: + has_labels = True + + ax.scatter(x=xs, y=ys, label=label, **scatter_args) + + if runtime_unit is not None: + ax.set_xlabel(f"Runtime ({runtime_unit})") + else: + ax.set_xlabel("Runtime") + + time_labels, time_units = zip(*_TICK_UNITS) + + cutoff = ( + next( + (i for i, x in enumerate(time_units) if x > max(all_xs)), + len(time_units) - 1, + ) + + 1 + ) + + ax.set_xticks(time_units[:cutoff]) + ax.set_xticklabels(time_labels[:cutoff], rotation=90) + + if has_labels: + ax.legend() + + plt.close(fig) + + return fig diff --git a/source/pip/qsharp/qre/interop/_cirq.py b/source/pip/qsharp/qre/interop/_cirq.py index 0153c00320..b685456d84 100644 --- a/source/pip/qsharp/qre/interop/_cirq.py +++ b/source/pip/qsharp/qre/interop/_cirq.py @@ -90,7 +90,7 @@ def trace_from_cirq( # circuit is OP_TREE circuit = cirq.Circuit(circuit) - context = _Context(circuit, classical_control_probability) + context = _CirqTraceBuilder(circuit, classical_control_probability) for moment in circuit: for op in moment.operations: @@ -99,11 +99,25 @@ def trace_from_cirq( return context.trace -class _Context: - """Tracks the current trace and block nesting during trace generation. +class _CirqTraceBuilder: + """Builds a resource estimation ``Trace`` from a Cirq circuit. - Maintains a stack of blocks so that ``PushBlock`` and ``PopBlock`` - operations can create nested repeated sections in the trace. 
+ This class walks the operations produced by ``trace_from_cirq`` and + translates each one into trace instructions. It maintains the state + needed during the conversion: + + * A ``Trace`` instance that accumulates the result. + * A stack of ``Block`` objects so that ``PushBlock`` / ``PopBlock`` + markers can create nested repeated sections. + * A qubit-id mapping (``_QidToTraceId``) that assigns each Cirq qubit + a sequential integer index. + * A Cirq ``DecompositionContext`` for gates that need recursive + decomposition. + + Args: + circuit: The Cirq circuit being converted. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. """ def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): @@ -116,31 +130,41 @@ def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): ) def push_block(self, repetitions: int): + """Open a new repeated block with the given number of repetitions.""" block = self.block.add_block(repetitions) self._blocks.append(block) def pop_block(self): + """Close the current repeated block, returning to the parent.""" self._blocks.pop() @property def trace(self) -> Trace: + """The accumulated trace, with ``compute_qubits`` updated to reflect + all qubits seen so far (including any allocated during decomposition).""" self._trace.compute_qubits = len(self._q_to_id) return self._trace @property def block(self) -> Block: + """The innermost open block in the trace.""" return self._blocks[-1] @property def q_to_id(self) -> _QidToTraceId: + """Mapping from Cirq ``Qid`` to integer trace qubit index.""" return self._q_to_id @property def classical_control_probability(self) -> float: + """Probability used to stochastically include classically controlled + operations.""" return self._classical_control_probability @property def decomp_context(self) -> cirq.DecompositionContext: + """Cirq decomposition context shared across all recursive + 
decompositions.""" return self._decomp_context def handle_op( @@ -151,15 +175,18 @@ def handle_op( Supported operation forms: - - ``TraceGate``: A raw trace instruction, added directly to the current block. - - ``PushBlock`` / ``PopBlock``: Control block nesting with repetitions. - - ``GateOperation``: Dispatched via ``_to_trace`` if available on the - gate, otherwise decomposed via ``_decompose_with_context_`` or - ``_decompose_``. + - ``TraceGate``: A raw trace instruction, added directly to the + current block. + - ``PushBlock`` / ``PopBlock``: Control block nesting with + repetitions. + - ``GateOperation``: Dispatched via ``_to_trace`` if available on + the gate, otherwise decomposed via + ``_decompose_with_context_`` or ``_decompose_``. - ``ClassicallyControlledOperation``: Included with the probability - specified in the generation context. - - ``list``: Each element is handled recursively. - - Any other operation: Decomposed via ``_decompose_with_context_``. + given by ``classical_control_probability``. + - ``list`` / iterable: Each element is handled recursively. + - Any other ``cirq.Operation``: Decomposed via + ``_decompose_with_context_``. Args: op: The operation to convert. 
diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py index 5b8400002c..3da76797ac 100644 --- a/source/pip/qsharp/qre/models/__init__.py +++ b/source/pip/qsharp/qre/models/__init__.py @@ -8,10 +8,10 @@ OneDimensionalYokedSurfaceCode, TwoDimensionalYokedSurfaceCode, ) -from .qubits import AQREGateBased, Majorana +from .qubits import GateBased, Majorana __all__ = [ - "AQREGateBased", + "GateBased", "Litinski19Factory", "Majorana", "MagicUpToClifford", diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py index d4f35117e4..30d3b444c6 100644 --- a/source/pip/qsharp/qre/models/factories/_litinski.py +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -7,7 +7,7 @@ from math import ceil from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._qre import ISARequirements, ConstraintBound, ISA from ..._instruction import ISATransform, constraint, LOGICAL from ...instruction_ids import T, CNOT, H, MEAS_Z, CCZ @@ -48,7 +48,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_Z, error_rate=ConstraintBound.le(1e-3)), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: h = impl_isa[H] cnot = impl_isa[CNOT] meas_z = impl_isa[MEAS_Z] diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index aed95e1243..5f746595bd 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Callable, Generator, Iterable, Optional, Sequence -from ..._qre import ISA, InstructionFrontier, ISARequirements, _Instruction, _binom_ppf +from ..._qre import ISA, InstructionFrontier, ISARequirements, 
Instruction, _binom_ppf from ..._instruction import ( LOGICAL, PHYSICAL, @@ -19,7 +19,7 @@ ISATransform, constraint, ) -from ..._architecture import _Context +from ..._architecture import ISAContext from ...instruction_ids import CNOT, LATTICE_SURGERY, T, MEAS_ZZ from ..qec import SurfaceCode @@ -103,7 +103,9 @@ def required_isa() -> ISARequirements: constraint(T), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: cache_path = self._cache_path(impl_isa) # 1) Try to load from cache @@ -190,7 +192,7 @@ def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: ] def _logical_units( - self, lattice_surgery_instruction: _Instruction + self, lattice_surgery_instruction: Instruction ) -> list[_DistillationUnit]: logical_cycle_time = lattice_surgery_instruction.expect_time(1) logical_error = lattice_surgery_instruction.expect_error_rate(1) @@ -214,8 +216,8 @@ def _logical_units( ), ] - def _state_from_pipeline(self, pipeline: _Pipeline) -> _Instruction: - return _Instruction.fixed_arity( + def _state_from_pipeline(self, pipeline: _Pipeline) -> Instruction: + return Instruction.fixed_arity( T, int(LOGICAL), 1, diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py index dcd72c6afe..a0efbc4ec5 100644 --- a/source/pip/qsharp/qre/models/factories/_utils.py +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -3,7 +3,7 @@ from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._qre import ISARequirements, ISA from ..._instruction import ISATransform from ...instruction_ids import ( @@ -58,7 +58,7 @@ class MagicUpToClifford(ISATransform): def required_isa() -> ISARequirements: return ISARequirements() - def provided_isa(self, impl_isa, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa(self, 
impl_isa, ctx: ISAContext) -> Generator[ISA, None, None]: # Families of equivalent gates under Clifford conjugation. families = [ [ diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index e402ea9c41..ee5cc8bace 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -12,7 +12,7 @@ ConstraintBound, LOGICAL, ) -from ..._isa_enumeration import _Context +from ..._isa_enumeration import ISAContext from ..._qre import linear_function from ...instruction_ids import CNOT, H, LATTICE_SURGERY, MEAS_Z from ...property_keys import ( @@ -73,7 +73,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: cnot = impl_isa[CNOT] h = impl_isa[H] meas_z = impl_isa[MEAS_Z] diff --git a/source/pip/qsharp/qre/models/qec/_three_aux.py b/source/pip/qsharp/qre/models/qec/_three_aux.py index 2af1879205..5f7cff6da3 100644 --- a/source/pip/qsharp/qre/models/qec/_three_aux.py +++ b/source/pip/qsharp/qre/models/qec/_three_aux.py @@ -6,7 +6,7 @@ from dataclasses import KW_ONLY, dataclass, field from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._instruction import ( LOGICAL, ISATransform, @@ -59,7 +59,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_ZZ, arity=2), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: meas_x = impl_isa[MEAS_X] meas_z = impl_isa[MEAS_Z] meas_xx = impl_isa[MEAS_XX] diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py index 8bb9bf9597..9cb1b26527 100644 --- 
a/source/pip/qsharp/qre/models/qec/_yoked.py +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -7,7 +7,7 @@ from ..._instruction import ISATransform, constraint, LOGICAL from ..._qre import ISA, ISARequirements, generic_function -from ..._architecture import _Context +from ..._architecture import ISAContext from ...instruction_ids import LATTICE_SURGERY, MEMORY from ...property_keys import DISTANCE @@ -58,7 +58,9 @@ def required_isa() -> ISARequirements: constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: lattice_surgery = impl_isa[LATTICE_SURGERY] distance = lattice_surgery.get_property(DISTANCE) assert distance is not None @@ -178,7 +180,9 @@ def required_isa() -> ISARequirements: constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: lattice_surgery = impl_isa[LATTICE_SURGERY] distance = lattice_surgery.get_property(DISTANCE) assert distance is not None diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py index 99c9e1c156..ab7887faf3 100644 --- a/source/pip/qsharp/qre/models/qubits/__init__.py +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from ._aqre import AQREGateBased +from ._gate_based import GateBased from ._msft import Majorana -__all__ = ["AQREGateBased", "Majorana"] +__all__ = ["GateBased", "Majorana"] diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_gate_based.py similarity index 86% rename from source/pip/qsharp/qre/models/qubits/_aqre.py rename to source/pip/qsharp/qre/models/qubits/_gate_based.py index 6e6f09b8be..d9ee589485 100644 --- a/source/pip/qsharp/qre/models/qubits/_aqre.py +++ b/source/pip/qsharp/qre/models/qubits/_gate_based.py @@ -4,7 +4,7 @@ from dataclasses import KW_ONLY, dataclass, field from typing import Optional -from ..._architecture import Architecture, _Context +from ..._architecture import Architecture, ISAContext from ..._instruction import ISA, Encoding from ...instruction_ids import ( CNOT, @@ -36,15 +36,10 @@ @dataclass -class AQREGateBased(Architecture): +class GateBased(Architecture): """ - A generic gate-based architecture based on the qubit parameters in Azure - Quantum Resource Estimator (AQRE, - [arXiv:2211.07629](https://arxiv.org/abs/2211.07629)). The error rate can - be set arbitrarily and is either 1e-3 or 1e-4 in the reference. Typical - gate times are 50ns and measurement times are 100ns for superconducting - transmon qubits - [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002). + A generic gate-based architecture. The error rate can be set arbitrarily + and is either 1e-3 or 1e-4 in the reference. Args: error_rate: The error rate for all gates. Defaults to 1e-4. 
@@ -76,7 +71,7 @@ def __post_init__(self): if self.two_qubit_gate_time is None: self.two_qubit_gate_time = self.gate_time - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: # Value is initialized in __post_init__ assert self.two_qubit_gate_time is not None diff --git a/source/pip/qsharp/qre/models/qubits/_msft.py b/source/pip/qsharp/qre/models/qubits/_msft.py index 022157c1d4..1d74300e3e 100644 --- a/source/pip/qsharp/qre/models/qubits/_msft.py +++ b/source/pip/qsharp/qre/models/qubits/_msft.py @@ -3,7 +3,7 @@ from dataclasses import KW_ONLY, dataclass, field -from ..._architecture import Architecture, _Context +from ..._architecture import Architecture, ISAContext from ...instruction_ids import ( T, PREP_X, @@ -47,7 +47,7 @@ class Majorana(Architecture): _: KW_ONLY error_rate: float = field(default=1e-5, metadata={"domain": [1e-4, 1e-5, 1e-6]}) - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: if abs(self.error_rate - 1e-4) <= 1e-8: t_error_rate = 0.05 elif abs(self.error_rate - 1e-5) <= 1e-8: diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 23b5f6baf7..0e9daa1686 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -205,7 +205,7 @@ impl ISARequirementsIterator { } #[allow(clippy::unsafe_derive_deserialize)] -#[pyclass(name = "_Instruction")] +#[pyclass(from_py_object)] #[derive(Clone, Serialize, Deserialize)] #[serde(transparent)] pub struct Instruction(qre::Instruction); @@ -566,7 +566,7 @@ impl ConstraintBound { } #[derive(Clone)] -#[pyclass(name = "_ProvenanceGraph")] +#[pyclass(name = "_ProvenanceGraph", from_py_object)] pub struct ProvenanceGraph(Arc>); impl Default for ProvenanceGraph { diff --git a/source/pip/tests/qre/__init__.py b/source/pip/tests/qre/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/source/pip/tests/qre/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. diff --git a/source/pip/tests/qre/conftest.py b/source/pip/tests/qre/conftest.py new file mode 100644 index 0000000000..c779e6ff31 --- /dev/null +++ b/source/pip/tests/qre/conftest.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator + +from qsharp.qre import ( + ISA, + LOGICAL, + ISARequirements, + ISATransform, + constraint, +) +from qsharp.qre._architecture import ISAContext +from qsharp.qre.instruction_ids import LATTICE_SURGERY, T + + +# NOTE These classes will be generalized as part of the QRE API in the following +# pull requests and then moved out of the tests. + + +@dataclass +class ExampleFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(T), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ) + + +@dataclass +class ExampleLogicalFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(LATTICE_SURGERY, encoding=LOGICAL), + constraint(T, encoding=LOGICAL), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), + ) diff --git a/source/pip/tests/qre/test_application.py b/source/pip/tests/qre/test_application.py new file mode 100644 index 0000000000..6b73222e12 --- /dev/null +++ b/source/pip/tests/qre/test_application.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from dataclasses import dataclass, field + +import qsharp + +from qsharp.qre import ( + Application, + ISA, + LOGICAL, + PSSPC, + EstimationResult, + LatticeSurgery, + Trace, + linear_function, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre.application import QSharpApplication +from qsharp.qre.instruction_ids import CCX, LATTICE_SURGERY, T, RZ +from qsharp.qre.property_keys import ( + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, +) + + +def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + actual_qubits = ( + isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + + isa[T].expect_space() * result.factories[T].copies + ) + if CCX in trace.resource_states: + actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies + assert result.qubits == actual_qubits + + assert ( + result.runtime + == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth + ) + + actual_error = ( + trace.base_error + + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth + + isa[T].expect_error_rate() * result.factories[T].states + ) + if CCX in trace.resource_states: + actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states + assert abs(result.error - actual_error) <= 1e-8 + + +def test_trace_properties(): + trace = Trace(42) + + INT = 0 + FLOAT = 1 + BOOL = 2 + STR = 3 + + trace.set_property(INT, 42) + assert trace.get_property(INT) == 42 + assert isinstance(trace.get_property(INT), int) + + trace.set_property(FLOAT, 3.14) + assert trace.get_property(FLOAT) == 3.14 + assert isinstance(trace.get_property(FLOAT), float) + + trace.set_property(BOOL, True) + assert trace.get_property(BOOL) is True + assert isinstance(trace.get_property(BOOL), bool) + + trace.set_property(STR, "hello") + assert trace.get_property(STR) == "hello" + assert 
isinstance(trace.get_property(STR), str) + + +def test_qsharp_application(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + trace = app.get_trace() + + assert trace.compute_qubits == 3 + assert trace.depth == 3 + assert trace.resource_states == {} + + assert {c.id for c in trace.required_isa} == {CCX, T, RZ} + + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(50), + error_rate=linear_function(1e-6), + ), + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + # Properties from the program + counts = qsharp.logical_counts(code) + num_ts = counts["tCount"] + num_ccx = counts["cczCount"] + num_rotations = counts["rotationCount"] + rotation_depth = counts["rotationDepth"] + + lattice_surgery = LatticeSurgery() + + counter = 0 + for psspc in _enumerate_instances(PSSPC): + counter += 1 + trace2 = psspc.transform(trace) + assert trace2 is not None + trace2 = lattice_surgery.transform(trace2) + assert trace2 is not None + assert trace2.compute_qubits == 12 + assert ( + trace2.depth + == num_ts + + num_ccx * 3 + + num_rotations + + rotation_depth * psspc.num_ts_per_rotation + ) + if psspc.ccx_magic_states: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations, + CCX: num_ccx, + } + assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} + else: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx + } + assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} + assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 + assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 + result = 
trace2.estimate(isa, max_error=float("inf")) + assert result is not None + assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 + assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 + assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 + assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 + _assert_estimation_result(trace2, result, isa) + assert counter == 32 + + +def test_application_enumeration(): + @dataclass(kw_only=True) + class _Params: + size: int = field(default=1, metadata={"domain": range(1, 4)}) + + class TestApp(Application[_Params]): + def get_trace(self, parameters: _Params) -> Trace: + return Trace(parameters.size) + + app = TestApp() + assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 + assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 + assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 + + +def test_trace_enumeration(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + + ctx = app.context() + assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 + + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 + + assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 + + q = PSSPC.q() * LatticeSurgery.q() + assert sum(1 for _ in q.enumerate(ctx)) == 32 + + +def test_rotation_error_psspc(): + # This test helps to bound the variables for the number of rotations in PSSPC + + # Create a trace with a single rotation gate and ensure that the base error + # after PSSPC transformation is less than 1. 
+ trace = Trace(1) + trace.add_operation(RZ, [0]) + + for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): + transformed = psspc.transform(trace) + assert transformed is not None + assert ( + transformed.base_error < 1.0 + ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" diff --git a/source/pip/tests/qre/test_enumeration.py b/source/pip/tests/qre/test_enumeration.py new file mode 100644 index 0000000000..476e65f22b --- /dev/null +++ b/source/pip/tests/qre/test_enumeration.py @@ -0,0 +1,527 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from enum import Enum +from typing import cast + +import pytest + +from qsharp.qre import LOGICAL +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._isa_enumeration import ( + ISARefNode, + _ComponentQuery, + _ProductNode, + _SumNode, +) + +from .conftest import ExampleFactory, ExampleLogicalFactory + + +def test_enumerate_instances(): + from qsharp.qre._enumeration import _enumerate_instances + + instances = list(_enumerate_instances(SurfaceCode)) + + # There are 12 instances with distances from 3 to 25 + assert len(instances) == 12 + expected_distances = list(range(3, 26, 2)) + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with specific distances + instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) + assert len(instances) == 3 + expected_distances = [3, 5, 7] + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with fixed distance + instances = list(_enumerate_instances(SurfaceCode, distance=9)) + assert len(instances) == 1 + assert instances[0].distance == 9 + + +def test_enumerate_instances_bool(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + 
class BoolConfig: + _: KW_ONLY + flag: bool + + instances = list(_enumerate_instances(BoolConfig)) + assert len(instances) == 2 + assert instances[0].flag is True + assert instances[1].flag is False + + +def test_enumerate_instances_enum(): + from qsharp.qre._enumeration import _enumerate_instances + + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + @dataclass + class EnumConfig: + _: KW_ONLY + color: Color + + instances = list(_enumerate_instances(EnumConfig)) + assert len(instances) == 3 + assert instances[0].color == Color.RED + assert instances[1].color == Color.GREEN + assert instances[2].color == Color.BLUE + + +def test_enumerate_instances_failure(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InvalidConfig: + _: KW_ONLY + # This field has no domain, is not bool/enum, and has no default + value: int + + with pytest.raises(ValueError, match="Cannot enumerate field value"): + list(_enumerate_instances(InvalidConfig)) + + +def test_enumerate_instances_single(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class SingleConfig: + value: int = 42 + + instances = list(_enumerate_instances(SingleConfig)) + assert len(instances) == 1 + assert instances[0].value == 42 + + +def test_enumerate_instances_literal(): + from qsharp.qre._enumeration import _enumerate_instances + + from typing import Literal + + @dataclass + class LiteralConfig: + _: KW_ONLY + mode: Literal["fast", "slow"] + + instances = list(_enumerate_instances(LiteralConfig)) + assert len(instances) == 2 + assert instances[0].mode == "fast" + assert instances[1].mode == "slow" + + +def test_enumerate_instances_nested(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + instances = list(_enumerate_instances(OuterConfig)) + assert len(instances) == 2 + assert instances[0].inner.option is 
True + assert instances[1].inner.option is False + + +def test_enumerate_instances_union(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + instances = list(_enumerate_instances(UnionConfig)) + assert len(instances) == 5 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert isinstance(instances[2].option, OptionB) + assert instances[2].option.number == 1 + + +def test_enumerate_instances_nested_with_constraints(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + # Constrain nested field via dict + instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) + assert len(instances) == 1 + assert instances[0].inner.option is True + + +def test_enumerate_instances_union_single_type(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Restrict to OptionB only - uses its default domain + instances = list(_enumerate_instances(UnionConfig, option=OptionB)) + assert len(instances) == 3 + assert all(isinstance(i.option, OptionB) for i in instances) + assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] + + # Restrict to OptionA only + instances = list(_enumerate_instances(UnionConfig, option=OptionA)) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionA) for i in instances) + assert cast(OptionA, 
instances[0].option).value is True + assert cast(OptionA, instances[1].option).value is False + + +def test_enumerate_instances_union_list_of_types(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class OptionC: + _: KW_ONLY + flag: bool + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB | OptionC + + # Select a subset: only OptionA and OptionB + instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) + assert len(instances) == 5 # 2 from OptionA + 3 from OptionB + assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) + + +def test_enumerate_instances_union_constraint_dict(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Constrain OptionA, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) + ) + assert len(instances) == 1 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + + # Constrain OptionB with a domain, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) + ) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionB) for i in instances) + assert cast(OptionB, instances[0].option).number == 2 + assert cast(OptionB, instances[1].option).number == 3 + + # Constrain one member and keep another with defaults + instances = list( + _enumerate_instances( + UnionConfig, + option={OptionA: {"value": True}, OptionB: {}}, + ) + ) + assert len(instances) == 4 # 1 
from OptionA + 3 from OptionB + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert all(isinstance(i.option, OptionB) for i in instances[1:]) + assert [cast(OptionB, i.option).number for i in instances[1:]] == [1, 2, 3] + + +def test_enumerate_isas(): + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # This will enumerate the 4 ISAs for the error correction code + count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) + assert count == 12 + + # This will enumerate the 2 ISAs for the error correction code when + # restricting the domain + count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) + assert count == 2 + + # This will enumerate the 3 ISAs for the factory + count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) + assert count == 3 + + # This will enumerate 36 ISAs for all products between the 12 error + # correction code ISAs and the 3 factory ISAs + count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) + assert count == 36 + + # When providing a list, components are chained (OR operation). This + # enumerates ISAs from first factory instance OR second factory instance + count = sum( + 1 + for _ in ( + SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) + ).enumerate(ctx) + ) + assert count == 72 + + # When providing separate arguments, components are combined via product + # (AND). 
This enumerates ISAs from first factory instance AND second + # factory instance + count = sum( + 1 + for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( + ctx + ) + ) + assert count == 108 + + # Hierarchical factory using from_components: the component receives ISAs + # from the product of other components as its source + count = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) + ).enumerate(ctx) + ) + assert count == 1296 + + +def test_binding_node(): + """Test binding nodes with ISARefNode for component bindings""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test basic binding: same code used twice + # Without binding: 12 codes × 12 codes = 144 combinations + count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) + assert count_without == 144 + + # With binding: 12 codes (same instance used twice) + count_with = sum( + 1 + for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) + ) + assert count_with == 12 + + # Verify the binding works: with binding, both should use same params + for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + # Should have 1 logical gate (LATTICE_SURGERY) + assert len(logical_gates) == 1 + + # Test binding with factories (nested bindings) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 3 * 12 * 3 + + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ExampleFactory.bind( + "f", + ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), + ), + ).enumerate(ctx) + ) + assert count_with == 36 # 12 * 3 + + # Test binding with from_components equivalent (hierarchical) + # Without binding: 4 outer codes × (4 inner codes × 3 
factories × 3 levels) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q( + source=(SurfaceCode.q() * ExampleFactory.q()), + ) + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 12 * 3 * 3 + + # With binding: 4 codes (same used twice) × 3 factories × 3 levels + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ISARefNode("c") + * ExampleLogicalFactory.q( + source=(ISARefNode("c") * ExampleFactory.q()), + ), + ).enumerate(ctx) + ) + assert count_with == 108 # 12 * 3 * 3 + + # Test binding with kwargs + count_with_kwargs = sum( + 1 + for _ in SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ) + assert count_with_kwargs == 1 # Only distance=5 + + # Verify kwargs are applied + for isa in ( + SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + assert all(g.space(1) == 49 for g in logical_gates) + + # Test multiple independent bindings (nested) + count = sum( + 1 + for _ in SurfaceCode.bind( + "c1", + ExampleFactory.bind( + "c2", + ISARefNode("c1") + * ISARefNode("c1") + * ISARefNode("c2") + * ISARefNode("c2"), + ), + ).enumerate(ctx) + ) + # 12 codes for c1 × 3 factories for c2 + assert count == 36 + + +def test_binding_node_errors(): + """Test error handling for binding nodes""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test ISARefNode enumerate with undefined binding raises ValueError + try: + list(ISARefNode("test").enumerate(ctx)) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Undefined component reference: 'test'" in str(e) + + +def test_product_isa_enumeration_nodes(): + terminal = SurfaceCode.q() + query = terminal * terminal + + # Multiplication should create ProductNode + assert isinstance(query, _ProductNode) + assert len(query.sources) == 2 + for source in query.sources: + assert 
isinstance(source, _ComponentQuery) + + # Multiplying again should extend the sources + query = query * terminal + assert isinstance(query, _ProductNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two ProductNodes + query = query * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + +def test_sum_isa_enumeration_nodes(): + terminal = SurfaceCode.q() + query = terminal + terminal + + # Multiplication should create SumNode + assert isinstance(query, _SumNode) + assert len(query.sources) == 2 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Multiplying again should extend the sources + query = query + terminal + assert isinstance(query, _SumNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two SumNodes + query = query + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) diff --git a/source/pip/tests/qre/test_estimation.py b/source/pip/tests/qre/test_estimation.py new file mode 100644 index 0000000000..bb857115ed --- /dev/null +++ b/source/pip/tests/qre/test_estimation.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import os + +import pytest + +from qsharp.estimator import LogicalCounts +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import ( + SurfaceCode, + GateBased, + RoundBasedFactory, + TwoDimensionalYokedSurfaceCode, +) + +from .conftest import ExampleFactory + + +def test_estimation_max_error(): + app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) + arch = GateBased(gate_time=50, measurement_time=100) + + for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=max_error, + ) + + assert len(results) == 1 + assert next(iter(results)).error <= max_error + + +@pytest.mark.skipif( + "SLOW_TESTS" not in os.environ, + reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", +) +@pytest.mark.parametrize( + "post_process, use_graph", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_estimation_methods(post_process, use_graph): + counts = LogicalCounts( + { + "numQubits": 1000, + "tCount": 1_500_000, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 1_000_000_000, + "ccixCount": 0, + "measurementCount": 25_000_000, + "numComputeQubits": 200, + "readFromMemoryCount": 30_000_000, + "writeToMemoryCount": 30_000_000, + } + ) + + trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) + isa_query = ( + SurfaceCode.q() + * RoundBasedFactory.q() + * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) + ) + + app = QSharpApplication(counts) + arch = GateBased(gate_time=50, measurement_time=100) + + results = estimate( + app, + arch, + isa_query, + trace_query, + max_error=1 / 3, + post_process=post_process, + use_graph=use_graph, + ) + results.add_factory_summary_column() + + assert [(result.qubits, result.runtime) for result in results] == [ + (238707, 
23997050000000), + (240407, 11998525000000), + ] + + print() + print(results.stats) diff --git a/source/pip/tests/qre/test_estimation_table.py b/source/pip/tests/qre/test_estimation_table.py new file mode 100644 index 0000000000..d2a25ae31b --- /dev/null +++ b/source/pip/tests/qre/test_estimation_table.py @@ -0,0 +1,439 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import cast, Sized + +import pytest +import pandas as pd + +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._estimation import ( + EstimationTable, + EstimationTableEntry, +) +from qsharp.qre._instruction import InstructionSource +from qsharp.qre.instruction_ids import LATTICE_SURGERY +from qsharp.qre.property_keys import DISTANCE, NUM_TS_PER_ROTATION + +from .conftest import ExampleFactory + + +def _make_entry(qubits, runtime, error, properties=None): + """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" + return EstimationTableEntry( + qubits=qubits, + runtime=runtime, + error=error, + source=InstructionSource(), + properties=properties or {}, + ) + + +def test_estimation_table_default_columns(): + """Test that a new EstimationTable has the three default columns.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error"] + assert frame["qubits"][0] == 100 + assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") + assert frame["error"][0] == 0.01 + + +def test_estimation_table_multiple_rows(): + """Test as_frame with multiple entries.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + table.append(_make_entry(300, 15000, 0.03)) + + frame = table.as_frame() + assert len(frame) == 3 + assert 
list(frame["qubits"]) == [100, 200, 300] + assert list(frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_empty(): + """Test as_frame with no entries produces an empty DataFrame.""" + table = EstimationTable() + frame = table.as_frame() + assert len(frame) == 0 + + +def test_estimation_table_add_column(): + """Test adding a column to the table.""" + VAL = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) + table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) + + table.add_column("val", lambda e: e.properties[VAL]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "val"] + assert list(frame["val"]) == [42, 84] + + +def test_estimation_table_add_column_with_formatter(): + """Test adding a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) + + table.add_column( + "duration", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") + + +def test_estimation_table_add_multiple_columns(): + """Test adding multiple columns preserves order.""" + A = 0 + B = 1 + C = 2 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) + + table.add_column("a", lambda e: e.properties[A]) + table.add_column("b", lambda e: e.properties[B]) + table.add_column("c", lambda e: e.properties[C]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] + assert frame["a"][0] == 1 + assert frame["b"][0] == 2 + assert frame["c"][0] == 3 + + +def test_estimation_table_insert_column_at_beginning(): + """Test inserting a column at index 0.""" + NAME = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) + + table.insert_column(0, 
"name", lambda e: e.properties[NAME]) + + frame = table.as_frame() + assert list(frame.columns) == ["name", "qubits", "runtime", "error"] + assert frame["name"][0] == "test" + + +def test_estimation_table_insert_column_in_middle(): + """Test inserting a column between existing default columns.""" + EXTRA = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) + + # Insert between qubits and runtime (index 1) + table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] + assert frame["extra"][0] == 99 + + +def test_estimation_table_insert_column_at_end(): + """Test inserting a column at the end (same effect as add_column).""" + LAST = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={LAST: True})) + + # 3 default columns, inserting at index 3 = end + table.insert_column(3, "last", lambda e: e.properties[LAST]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "last"] + assert frame["last"][0] + + +def test_estimation_table_insert_column_with_formatter(): + """Test inserting a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) + + table.insert_column( + 0, + "custom_time", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") + assert list(frame.columns)[0] == "custom_time" + + +def test_estimation_table_insert_and_add_columns(): + """Test combining insert_column and add_column.""" + A = 0 + B = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2})) + + table.add_column("b", lambda e: e.properties[B]) + table.insert_column(0, "a", lambda e: e.properties[A]) + + frame = table.as_frame() + 
assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] + + +def test_estimation_table_factory_summary_no_factories(): + """Test factory summary column when entries have no factories.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + table.add_factory_summary_column() + + frame = table.as_frame() + assert "factories" in frame.columns + assert frame["factories"][0] == "None" + + +def test_estimation_table_factory_summary_with_estimation(): + """Test factory summary column with real estimation results.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_factory_summary_column() + frame = results.as_frame() + + assert "factories" in frame.columns + # Each result should mention T in the factory summary + for val in frame["factories"]: + assert "T" in val + + +def test_estimation_table_add_column_from_source(): + """Test adding a column that accesses the InstructionSource (like distance).""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "compute_distance", + lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], + ) + + frame = results.as_frame() + assert "compute_distance" in frame.columns + for d in frame["compute_distance"]: + assert isinstance(d, int) + assert d >= 3 + + +def test_estimation_table_add_column_from_properties(): + """Test adding 
columns that access trace properties from estimation.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "num_ts_per_rotation", + lambda entry: entry.properties[NUM_TS_PER_ROTATION], + ) + + frame = results.as_frame() + assert "num_ts_per_rotation" in frame.columns + for val in frame["num_ts_per_rotation"]: + assert isinstance(val, int) + assert val >= 1 + + +def test_estimation_table_insert_column_before_defaults(): + """Test inserting a name column before all default columns, similar to the factoring notebook.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + name="test_experiment", + ) + + assert len(results) >= 1 + + # Add a factory summary at the end + results.add_factory_summary_column() + + frame = results.as_frame() + assert frame.columns[0] == "name" + assert frame.columns[-1] == "factories" + # Default columns should still be in order + assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] + + +def test_estimation_table_as_frame_sortable(): + """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" + table = EstimationTable() + table.append(_make_entry(300, 15000, 0.03)) + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + + frame = table.as_frame() + sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) + + assert 
list(sorted_frame["qubits"]) == [100, 200, 300] + assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_computed_column(): + """Test adding a column that computes a derived value from the entry.""" + table = EstimationTable() + table.append(_make_entry(100, 5_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000, 0.02)) + + # Compute qubits * error as a derived metric + table.add_column("qubit_error_product", lambda e: e.qubits * e.error) + + frame = table.as_frame() + assert frame["qubit_error_product"][0] == pytest.approx(1.0) + assert frame["qubit_error_product"][1] == pytest.approx(4.0) + + +def test_estimation_table_plot_returns_figure(): + """Test that plot() returns a matplotlib Figure with correct axes.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 5_000_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000_000, 0.02)) + table.append(_make_entry(50, 50_000_000_000, 0.005)) + + fig = table.plot() + + assert isinstance(fig, Figure) + ax = fig.axes[0] + assert ax.get_ylabel() == "Physical qubits" + assert ax.get_xlabel() == "Runtime" + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "log" + + # Verify data points + offsets = ax.collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 3 + + +def test_estimation_table_plot_empty_raises(): + """Test that plot() raises ValueError on an empty table.""" + table = EstimationTable() + with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): + table.plot() + + +def test_estimation_table_plot_single_entry(): + """Test that plot() works with a single entry.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 1_000_000, 0.01)) + + fig = table.plot() + assert isinstance(fig, Figure) + + offsets = fig.axes[0].collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 1 + + +def 
test_estimation_table_plot_with_runtime_unit(): + """Test that plot(runtime_unit=...) scales x values and labels the axis.""" + table = EstimationTable() + # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns + table.append(_make_entry(100, int(3600e9), 0.01)) + table.append(_make_entry(200, int(7200e9), 0.02)) + + fig = table.plot(runtime_unit="hours") + + ax = fig.axes[0] + assert ax.get_xlabel() == "Runtime (hours)" + + # Verify the x data is scaled: should be 1.0 and 2.0 hours + offsets = cast(list, ax.collections[0].get_offsets()) + assert offsets[0][0] == pytest.approx(1.0) + assert offsets[1][0] == pytest.approx(2.0) + + +def test_estimation_table_plot_invalid_runtime_unit(): + """Test that plot() raises ValueError for an unknown runtime_unit.""" + table = EstimationTable() + table.append(_make_entry(100, 1000, 0.01)) + with pytest.raises(ValueError, match="Unknown runtime_unit"): + table.plot(runtime_unit="fortnights") diff --git a/source/pip/tests/qre/test_interop.py b/source/pip/tests/qre/test_interop.py new file mode 100644 index 0000000000..a8f7900abb --- /dev/null +++ b/source/pip/tests/qre/test_interop.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path + +import pytest + +from qsharp.qre.interop import trace_from_qir + + +def _ll_files(): + ll_dir = ( + Path(__file__).parent.parent.parent + / "tests-integration" + / "resources" + / "adaptive_ri" + / "output" + ) + return sorted(ll_dir.glob("*.ll")) + + +@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) +def test_trace_from_qir(ll_file): + # NOTE: This test is primarily to ensure that the function can parse real + # QIR output without errors, rather than checking specific properties of the + # trace. 
+ try: + trace_from_qir(ll_file.read_text()) + except ValueError as e: + # The only reason of failure is presence of control flow + assert ( + str(e) + == "simulation of programs with branching control flow is not supported" + ) + + +def test_trace_from_qir_handles_all_instruction_ids(): + """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. + + Generates a synthetic QIR program containing one instance of each gate + intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir + processes all of them without error. + """ + import pyqir + import pyqir.qis as qis + from qsharp._native import QirInstructionId + from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP + + # -- Completeness check: every QirInstructionId must be covered -------- + handled_ids = ( + [qir_id for qir_id, _, _ in _GATE_MAP] + + [qir_id for qir_id, _ in _MEAS_MAP] + + list(_SKIP) + ) + # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) + all_ids = [ + QirInstructionId.I, + QirInstructionId.H, + QirInstructionId.X, + QirInstructionId.Y, + QirInstructionId.Z, + QirInstructionId.S, + QirInstructionId.SAdj, + QirInstructionId.SX, + QirInstructionId.SXAdj, + QirInstructionId.T, + QirInstructionId.TAdj, + QirInstructionId.CNOT, + QirInstructionId.CX, + QirInstructionId.CY, + QirInstructionId.CZ, + QirInstructionId.CCX, + QirInstructionId.SWAP, + QirInstructionId.RX, + QirInstructionId.RY, + QirInstructionId.RZ, + QirInstructionId.RXX, + QirInstructionId.RYY, + QirInstructionId.RZZ, + QirInstructionId.RESET, + QirInstructionId.M, + QirInstructionId.MResetZ, + QirInstructionId.MZ, + QirInstructionId.Move, + QirInstructionId.ReadResult, + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, + QirInstructionId.CorrelatedNoise, + ] + unhandled = [ + i + 
for i in all_ids + if i not in handled_ids and i != QirInstructionId.CorrelatedNoise + ] + assert unhandled == [], ( + f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " + f"{', '.join(str(i) for i in unhandled)}" + ) + + # -- Generate a QIR program with every producible gate ----------------- + simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) + builder = simple.builder + ctx = simple.context + q = simple.qubits + r = simple.results + + void_ty = pyqir.Type.void(ctx) + qubit_ty = pyqir.qubit_type(ctx) + result_ty = pyqir.result_type(ctx) + double_ty = pyqir.Type.double(ctx) + i64_ty = pyqir.IntType(ctx, 64) + + def declare(name, param_types): + return simple.add_external_function( + name, pyqir.FunctionType(void_ty, param_types) + ) + + # Single-qubit gates (pyqir.qis builtins) + qis.h(builder, q[0]) + qis.x(builder, q[0]) + qis.y(builder, q[0]) + qis.z(builder, q[0]) + qis.s(builder, q[0]) + qis.s_adj(builder, q[0]) + qis.t(builder, q[0]) + qis.t_adj(builder, q[0]) + + # SX — not in pyqir.qis + sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) + builder.call(sx_fn, [q[0]]) + + # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the + # pass does not handle, so use builder.call with the correct name) + cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) + builder.call(cx_fn, [q[0], q[1]]) + qis.cz(builder, q[0], q[1]) + qis.swap(builder, q[0], q[1]) + + cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) + builder.call(cy_fn, [q[0], q[1]]) + + # Three-qubit gate + qis.ccx(builder, q[0], q[1], q[2]) + + # Single-qubit rotations + qis.rx(builder, 1.0, q[0]) + qis.ry(builder, 1.0, q[0]) + qis.rz(builder, 1.0, q[0]) + + # Two-qubit rotations — not in pyqir.qis + rot2_ty = [double_ty, qubit_ty, qubit_ty] + angle = pyqir.const(double_ty, 1.0) + for name in ("rxx", "ryy", "rzz"): + fn = declare(f"__quantum__qis__{name}__body", rot2_ty) + builder.call(fn, [angle, q[0], q[1]]) + 
+ # Measurements + qis.mz(builder, q[0], r[0]) + + m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) + builder.call(m_fn, [q[1], r[1]]) + + mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) + builder.call(mresetz_fn, [q[2], r[2]]) + + # Reset / Move + qis.reset(builder, q[0]) + + move_fn = declare("__quantum__qis__move__body", [qubit_ty]) + builder.call(move_fn, [q[0]]) + + # Output recording + tag = simple.add_byte_string(b"tag") + arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) + builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) + + rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) + builder.call(rec_fn, [r[0], tag]) + + tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) + builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) + + # -- Run trace_from_qir and verify it succeeds ------------------------- + trace = trace_from_qir(simple.ir()) + assert trace is not None + + +def test_rotation_buckets(): + from qsharp.qre.interop._qsharp import _bucketize_rotation_counts + + print() + + r_count = 15066 + r_depth = 14756 + q_count = 291 + + result = _bucketize_rotation_counts(r_count, r_depth) + + a_count = 0 + a_depth = 0 + for c, d in result: + print(c, d) + assert c <= q_count + assert c > 0 + a_count += c * d + a_depth += d + + assert a_count == r_count + assert a_depth == r_depth diff --git a/source/pip/tests/qre/test_isa.py b/source/pip/tests/qre/test_isa.py new file mode 100644 index 0000000000..6c1e8e318a --- /dev/null +++ b/source/pip/tests/qre/test_isa.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +from qsharp.qre import ( + LOGICAL, + ISARequirements, + constraint, + generic_function, + property_name, + property_name_to_key, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._architecture import _make_instruction +from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T +from qsharp.qre.property_keys import DISTANCE + + +def test_isa(): + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + assert T in isa + assert CCX in isa + assert LATTICE_SURGERY not in isa + + t_instr = isa[T] + assert t_instr.time() == 1000 + assert t_instr.error_rate() == 1e-8 + assert t_instr.space() == 400 + + assert len(isa) == 2 + ccz_instr = isa[CCX].with_id(CCZ) + assert ccz_instr.arity == 3 + assert ccz_instr.time() == 2000 + assert ccz_instr.error_rate() == 1e-10 + assert ccz_instr.space() == 800 + + # Add another instruction to the graph and register it in the ISA + ccz_node = graph.add_instruction(ccz_instr) + isa.add_node(CCZ, ccz_node) + assert CCZ in isa + assert len(isa) == 3 + + # Adding the same instruction ID should not increase the count + isa.add_node(CCZ, ccz_node) + assert len(isa) == 3 + + +def test_instruction_properties(): + # Test instruction with no properties + instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) + assert instr_no_props.get_property(DISTANCE) is None + assert instr_no_props.has_property(DISTANCE) is False + assert instr_no_props.get_property_or(DISTANCE, 5) == 5 + + # Test instruction with valid property (distance) + instr_with_distance = _make_instruction( + T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} + ) + assert instr_with_distance.get_property(DISTANCE) == 9 + assert instr_with_distance.has_property(DISTANCE) is 
True + assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 + + # Test instruction with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) + + +def test_instruction_constraints(): + # Test constraint without properties + c_no_props = constraint(T, encoding=LOGICAL) + assert c_no_props.has_property(DISTANCE) is False + + # Test constraint with valid property (distance=True) + c_with_distance = constraint(T, encoding=LOGICAL, distance=True) + assert c_with_distance.has_property(DISTANCE) is True + + # Test constraint with distance=False (should not add the property) + c_distance_false = constraint(T, encoding=LOGICAL, distance=False) + assert c_distance_false.has_property(DISTANCE) is False + + # Test constraint with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + constraint(T, encoding=LOGICAL, invalid_prop=True) + + # Test ISA.satisfies with property constraints + graph = _ProvenanceGraph() + isa_no_dist = graph.make_isa( + [ + graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ] + ) + isa_with_dist = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ), + ] + ) + + reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) + reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) + + # ISA without distance property + assert isa_no_dist.satisfies(reqs_no_prop) is True + assert isa_no_dist.satisfies(reqs_with_prop) is False + + # ISA with distance property + assert isa_with_dist.satisfies(reqs_no_prop) is True + assert isa_with_dist.satisfies(reqs_with_prop) is True + + +def test_property_names(): + assert property_name(DISTANCE) == "DISTANCE" + + # An unregistered property + UNKNOWN = 10_000 + assert property_name(UNKNOWN) is None + + # But using an existing property key with a 
different variable name will + # still return something + UNKNOWN = 0 + assert property_name(UNKNOWN) == "DISTANCE" + + assert property_name_to_key("DISTANCE") == DISTANCE + + # But we also allow case-insensitive lookup + assert property_name_to_key("distance") == DISTANCE + + +def test_generic_function(): + from qsharp.qre._qre import _IntFunction, _FloatFunction + + def time(x: int) -> int: + return x * x + + time_fn = generic_function(time) + assert isinstance(time_fn, _IntFunction) + + def error_rate(x: int) -> float: + return x / 2.0 + + error_rate_fn = generic_function(error_rate) + assert isinstance(error_rate_fn, _FloatFunction) + + # Without annotations, defaults to FloatFunction + space_fn = generic_function(lambda x: 12) + assert isinstance(space_fn, _FloatFunction) + + i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) + assert i.space(5) == 12 + assert i.time(5) == 25 + assert i.error_rate(5) == 2.5 + + +def test_isa_from_architecture(): + arch = GateBased(gate_time=50, measurement_time=100) + code = SurfaceCode() + ctx = arch.context() + + # Verify that the architecture satisfies the code requirements + assert ctx.isa.satisfies(SurfaceCode.required_isa()) + + # Generate logical ISAs + isas = list(code.provided_isa(ctx.isa, ctx)) + + # There is one ISA with one instructions + assert len(isas) == 1 + assert len(isas[0]) == 1 diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/qre/test_models.py similarity index 90% rename from source/pip/tests/test_qre_models.py rename to source/pip/tests/qre/test_models.py index ef03a1eb42..728d557169 100644 --- a/source/pip/tests/test_qre_models.py +++ b/source/pip/tests/qre/test_models.py @@ -27,7 +27,7 @@ SQRT_SQRT_Z_DAG, ) from qsharp.qre.models import ( - AQREGateBased, + GateBased, Majorana, RoundBasedFactory, MagicUpToClifford, @@ -40,21 +40,21 @@ # --------------------------------------------------------------------------- -# AQREGateBased architecture tests +# GateBased 
architecture tests # --------------------------------------------------------------------------- -class TestAQREGateBased: +class TestGateBased: def test_default_error_rate(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) assert arch.error_rate == 1e-4 def test_custom_error_rate(self): - arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) assert arch.error_rate == 1e-3 def test_provided_isa_contains_expected_instructions(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -62,7 +62,7 @@ def test_provided_isa_contains_expected_instructions(self): assert instr_id in isa def test_instruction_encodings_are_physical(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -71,7 +71,7 @@ def test_instruction_encodings_are_physical(self): def test_instruction_error_rates_match(self): rate = 1e-3 - arch = AQREGateBased(error_rate=rate, gate_time=50, measurement_time=100) + arch = GateBased(error_rate=rate, gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -79,7 +79,7 @@ def test_instruction_error_rates_match(self): assert isa[instr_id].expect_error_rate() == rate def test_gate_times(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -95,7 +95,7 @@ def test_gate_times(self): assert isa[MEAS_Z].expect_time() == 100 def test_arities(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -106,7 +106,7 @@ def test_arities(self): assert isa[MEAS_Z].arity == 1 def 
test_context_creation(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() assert ctx is not None @@ -180,7 +180,7 @@ def test_default_distance(self): assert sc.distance == 3 def test_provides_lattice_surgery(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=3) @@ -195,7 +195,7 @@ def test_provides_lattice_surgery(self): def test_space_scales_with_distance(self): """Space = 2*d^2 - 1 physical qubits per logical qubit.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) for d in [3, 5, 7, 9]: ctx = arch.context() @@ -207,8 +207,8 @@ def test_space_scales_with_distance(self): def test_time_scales_with_distance(self): """Time = (h_time + 4*cnot_time + meas_time) * d.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) - # h=50, cnot=50, meas=100 for AQREGateBased + arch = GateBased(gate_time=50, measurement_time=100) + # h=50, cnot=50, meas=100 for GateBased syndrome_time = 50 + 4 * 50 + 100 # = 350 for d in [3, 5, 7]: @@ -219,7 +219,7 @@ def test_time_scales_with_distance(self): assert ls.expect_time(1) == syndrome_time * d def test_error_rate_decreases_with_distance(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) errors = [] for d in [3, 5, 7, 9, 11]: @@ -234,7 +234,7 @@ def test_error_rate_decreases_with_distance(self): def test_enumeration_via_query(self): """Enumerating SurfaceCode.q() should yield multiple distances.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -246,7 +246,7 @@ def test_enumeration_via_query(self): assert count == 12 def test_custom_crossing_prefactor(self): - arch = 
AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc_default = SurfaceCode(distance=5) @@ -265,7 +265,7 @@ def test_custom_crossing_prefactor(self): assert abs(custom_error - 2 * default_error) < 1e-20 def test_custom_error_correction_threshold(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx1 = arch.context() sc_low_threshold = SurfaceCode(error_correction_threshold=0.005, distance=5) @@ -395,7 +395,7 @@ def test_enumeration_via_query(self): class TestYokedSurfaceCode: def _get_lattice_surgery_isa(self, distance=5): """Helper to get a lattice surgery ISA from SurfaceCode.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=distance) isas = list(sc.provided_isa(ctx.isa, ctx)) @@ -479,9 +479,9 @@ def test_required_isa(self): reqs = Litinski19Factory.required_isa() assert reqs is not None - def test_table1_aqre_yields_t_and_ccz(self): - """AQREGateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + def test_table1_yields_t_and_ccz(self): + """GateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -496,7 +496,7 @@ def test_table1_aqre_yields_t_and_ccz(self): assert len(isa) == 2 def test_table1_instruction_properties(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -516,7 +516,7 @@ def test_table1_instruction_properties(self): def test_table1_t_error_rates_are_diverse(self): """T entries in Table 1 should span a range of error rates.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = 
GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -532,8 +532,8 @@ def test_table1_t_error_rates_are_diverse(self): assert 0 < err < 1e-5 def test_table1_1e3_clifford_yields_6_isas(self): - """AQREGateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" - arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + """GateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -550,7 +550,7 @@ def test_table2_scenario_no_ccz(self): """Table 2 scenario: T error ~10x higher than Clifford, no CCZ.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() # Manually create ISA with T error rate 10x Clifford @@ -578,7 +578,7 @@ def test_no_yield_when_error_too_high(self): """If T error > 10x Clifford, no entries match.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() graph = _ProvenanceGraph() @@ -597,11 +597,11 @@ def test_no_yield_when_error_too_high(self): def test_time_based_on_syndrome_extraction(self): """Time should be based on syndrome extraction time × cycles.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - # For AQREGateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 + # For GateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 syndrome_time = 4 * 50 + 50 + 100 # 350 ns isas = list(factory.provided_isa(ctx.isa, ctx)) @@ -625,7 +625,7 @@ def test_required_isa_is_empty(self): def test_adds_clifford_equivalent_t_gates(self): """Given T gate, should add SQRT_SQRT_X/Y/Z and dagger 
variants.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -650,7 +650,7 @@ def test_adds_clifford_equivalent_t_gates(self): def test_adds_clifford_equivalent_ccz(self): """Given CCZ, should add CCX and CCY.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -666,7 +666,7 @@ def test_adds_clifford_equivalent_ccz(self): def test_full_count_of_instructions(self): """T gate (1) + 5 equivalents (SQRT_SQRT_*) + CCZ (1) + 2 equivalents (CCX, CCY) = 9.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -678,7 +678,7 @@ def test_full_count_of_instructions(self): def test_equivalent_instructions_share_properties(self): """Clifford equivalents should have same time, space, error rate.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -709,7 +709,7 @@ def test_equivalent_instructions_share_properties(self): def test_modification_count_matches_factory_output(self): """MagicUpToClifford should produce one modified ISA per input ISA.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -725,7 +725,7 @@ def test_no_family_present_passes_through(self): """If no family member is present, ISA passes through unchanged.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = 
GateBased(gate_time=50, measurement_time=100) ctx = arch.context() modifier = MagicUpToClifford() @@ -758,7 +758,7 @@ def test_no_family_present_passes_through(self): def test_isa_manipulation(): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -813,7 +813,7 @@ def test_required_isa(self): assert reqs is not None def test_produces_logical_t_gates(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): t = isa[T] @@ -826,7 +826,7 @@ def test_produces_logical_t_gates(self): def test_error_rates_are_bounded(self): """Distilled T error rates should be bounded and mostly small.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 + arch = GateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 errors = [] for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): @@ -843,7 +843,7 @@ def test_error_rates_are_bounded(self): def test_max_produces_fewer_or_equal_results_than_sum(self): """Using max for physical_qubit_calculation may filter differently.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) sum_count = sum( 1 for _ in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) @@ -859,7 +859,7 @@ def test_max_produces_fewer_or_equal_results_than_sum(self): def test_max_space_less_than_or_equal_sum_space(self): """max-aggregated space should be <= sum-aggregated space for each.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) sum_spaces = sorted( isa[T].expect_space() @@ -890,8 +890,8 @@ def test_with_three_aux_code_query(self): assert count > 0 - def test_round_based_aqre_sum(self): - arch = 
AQREGateBased(gate_time=50, measurement_time=100) + def test_round_based_gate_based_sum(self): + arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -909,8 +909,8 @@ def test_round_based_aqre_sum(self): assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 assert count == 107 - def test_round_based_aqre_max(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + def test_round_based_gate_based_max(self): + arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -960,10 +960,10 @@ def test_round_based_msft_sum(self): class TestCrossModelIntegration: def test_surface_code_feeds_into_litinski(self): """SurfaceCode -> Litinski19Factory pipeline works end to end.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() - # SurfaceCode takes AQRE physical ISA -> LATTICE_SURGERY + # SurfaceCode takes gate-based physical ISA -> LATTICE_SURGERY sc = SurfaceCode(distance=5) sc_isas = list(sc.provided_isa(ctx.isa, ctx)) assert len(sc_isas) == 1 @@ -989,7 +989,7 @@ def test_three_aux_feeds_into_round_based(self): def test_litinski_with_magic_up_to_clifford_query(self): """Full query chain: Litinski19Factory -> MagicUpToClifford.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -1004,7 +1004,7 @@ def test_litinski_with_magic_up_to_clifford_query(self): def test_surface_code_with_yoked_surface_code(self): """SurfaceCode -> YokedSurfaceCode pipeline provides MEMORY.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py deleted file mode 100644 index 2dc318fa5e..0000000000 --- a/source/pip/tests/test_qre.py +++ /dev/null @@ -1,1666 +0,0 @@ -# Copyright 
(c) Microsoft Corporation. -# Licensed under the MIT License. - -from dataclasses import KW_ONLY, dataclass, field -from enum import Enum -from pathlib import Path -from typing import cast, Generator, Sized -import os -import pytest - -import pandas as pd -import qsharp -from qsharp.estimator import LogicalCounts -from qsharp.qre import ( - Application, - ISA, - LOGICAL, - PSSPC, - EstimationResult, - ISARequirements, - ISATransform, - LatticeSurgery, - Trace, - constraint, - estimate, - linear_function, - generic_function, - property_name, - property_name_to_key, -) -from qsharp.qre._qre import _ProvenanceGraph -from qsharp.qre.application import QSharpApplication -from qsharp.qre.models import ( - SurfaceCode, - AQREGateBased, - RoundBasedFactory, - TwoDimensionalYokedSurfaceCode, -) -from qsharp.qre.interop import trace_from_qir -from qsharp.qre._architecture import _Context, _make_instruction -from qsharp.qre._estimation import ( - EstimationTable, - EstimationTableEntry, -) -from qsharp.qre._instruction import InstructionSource -from qsharp.qre._isa_enumeration import ( - ISARefNode, -) -from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T, RZ -from qsharp.qre.property_keys import ( - DISTANCE, - NUM_TS_PER_ROTATION, - ALGORITHM_COMPUTE_QUBITS, - ALGORITHM_MEMORY_QUBITS, - LOGICAL_COMPUTE_QUBITS, - LOGICAL_MEMORY_QUBITS, -) - -# NOTE These classes will be generalized as part of the QRE API in the following -# pull requests and then moved out of the tests. 
- - -@dataclass -class ExampleFactory(ISATransform): - _: KW_ONLY - level: int = field(default=1, metadata={"domain": range(1, 4)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(T), - ) - - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ctx.make_isa( - ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), - ) - - -@dataclass -class ExampleLogicalFactory(ISATransform): - _: KW_ONLY - level: int = field(default=1, metadata={"domain": range(1, 4)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(LATTICE_SURGERY, encoding=LOGICAL), - constraint(T, encoding=LOGICAL), - ) - - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ctx.make_isa( - ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), - ) - - -def test_isa(): - graph = _ProvenanceGraph() - isa = graph.make_isa( - [ - graph.add_instruction( - T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 - ), - graph.add_instruction( - CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 - ), - ] - ) - - assert T in isa - assert CCX in isa - assert LATTICE_SURGERY not in isa - - t_instr = isa[T] - assert t_instr.time() == 1000 - assert t_instr.error_rate() == 1e-8 - assert t_instr.space() == 400 - - assert len(isa) == 2 - ccz_instr = isa[CCX].with_id(CCZ) - assert ccz_instr.arity == 3 - assert ccz_instr.time() == 2000 - assert ccz_instr.error_rate() == 1e-10 - assert ccz_instr.space() == 800 - - # Add another instruction to the graph and register it in the ISA - ccz_node = graph.add_instruction(ccz_instr) - isa.add_node(CCZ, ccz_node) - assert CCZ in isa - assert len(isa) == 3 - - # Adding the same instruction ID should not increase the count - isa.add_node(CCZ, ccz_node) - assert len(isa) == 3 - - -def test_instruction_properties(): - # Test instruction with no properties - 
instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) - assert instr_no_props.get_property(DISTANCE) is None - assert instr_no_props.has_property(DISTANCE) is False - assert instr_no_props.get_property_or(DISTANCE, 5) == 5 - - # Test instruction with valid property (distance) - instr_with_distance = _make_instruction( - T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} - ) - assert instr_with_distance.get_property(DISTANCE) == 9 - assert instr_with_distance.has_property(DISTANCE) is True - assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 - - # Test instruction with invalid property name - with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): - _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) - - -def test_instruction_constraints(): - # Test constraint without properties - c_no_props = constraint(T, encoding=LOGICAL) - assert c_no_props.has_property(DISTANCE) is False - - # Test constraint with valid property (distance=True) - c_with_distance = constraint(T, encoding=LOGICAL, distance=True) - assert c_with_distance.has_property(DISTANCE) is True - - # Test constraint with distance=False (should not add the property) - c_distance_false = constraint(T, encoding=LOGICAL, distance=False) - assert c_distance_false.has_property(DISTANCE) is False - - # Test constraint with invalid property name - with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): - constraint(T, encoding=LOGICAL, invalid_prop=True) - - # Test ISA.satisfies with property constraints - graph = _ProvenanceGraph() - isa_no_dist = graph.make_isa( - [ - graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), - ] - ) - isa_with_dist = graph.make_isa( - [ - graph.add_instruction( - T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 - ), - ] - ) - - reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) - reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) - - 
# ISA without distance property - assert isa_no_dist.satisfies(reqs_no_prop) is True - assert isa_no_dist.satisfies(reqs_with_prop) is False - - # ISA with distance property - assert isa_with_dist.satisfies(reqs_no_prop) is True - assert isa_with_dist.satisfies(reqs_with_prop) is True - - -def test_property_names(): - assert property_name(DISTANCE) == "DISTANCE" - - # An unregistered property - UNKNOWN = 10_000 - assert property_name(UNKNOWN) is None - - # But using an existing property key with a different variable name will - # still return something - UNKNOWN = 0 - assert property_name(UNKNOWN) == "DISTANCE" - - assert property_name_to_key("DISTANCE") == DISTANCE - - # But we also allow case-insensitive lookup - assert property_name_to_key("distance") == DISTANCE - - -def test_generic_function(): - from qsharp.qre._qre import _IntFunction, _FloatFunction - - def time(x: int) -> int: - return x * x - - time_fn = generic_function(time) - assert isinstance(time_fn, _IntFunction) - - def error_rate(x: int) -> float: - return x / 2.0 - - error_rate_fn = generic_function(error_rate) - assert isinstance(error_rate_fn, _FloatFunction) - - # Without annotations, defaults to FloatFunction - space_fn = generic_function(lambda x: 12) - assert isinstance(space_fn, _FloatFunction) - - i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) - assert i.space(5) == 12 - assert i.time(5) == 25 - assert i.error_rate(5) == 2.5 - - -def test_isa_from_architecture(): - arch = AQREGateBased(gate_time=50, measurement_time=100) - code = SurfaceCode() - ctx = arch.context() - - # Verify that the architecture satisfies the code requirements - assert ctx.isa.satisfies(SurfaceCode.required_isa()) - - # Generate logical ISAs - isas = list(code.provided_isa(ctx.isa, ctx)) - - # There is one ISA with one instructions - assert len(isas) == 1 - assert len(isas[0]) == 1 - - -def test_enumerate_instances(): - from qsharp.qre._enumeration import _enumerate_instances - - instances = 
list(_enumerate_instances(SurfaceCode)) - - # There are 12 instances with distances from 3 to 25 - assert len(instances) == 12 - expected_distances = list(range(3, 26, 2)) - for instance, expected_distance in zip(instances, expected_distances): - assert instance.distance == expected_distance - - # Test with specific distances - instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) - assert len(instances) == 3 - expected_distances = [3, 5, 7] - for instance, expected_distance in zip(instances, expected_distances): - assert instance.distance == expected_distance - - # Test with fixed distance - instances = list(_enumerate_instances(SurfaceCode, distance=9)) - assert len(instances) == 1 - assert instances[0].distance == 9 - - -def test_enumerate_instances_bool(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class BoolConfig: - _: KW_ONLY - flag: bool - - instances = list(_enumerate_instances(BoolConfig)) - assert len(instances) == 2 - assert instances[0].flag is True - assert instances[1].flag is False - - -def test_enumerate_instances_enum(): - from qsharp.qre._enumeration import _enumerate_instances - - class Color(Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - @dataclass - class EnumConfig: - _: KW_ONLY - color: Color - - instances = list(_enumerate_instances(EnumConfig)) - assert len(instances) == 3 - assert instances[0].color == Color.RED - assert instances[1].color == Color.GREEN - assert instances[2].color == Color.BLUE - - -def test_enumerate_instances_failure(): - from qsharp.qre._enumeration import _enumerate_instances - - import pytest - - @dataclass - class InvalidConfig: - _: KW_ONLY - # This field has no domain, is not bool/enum, and has no default - value: int - - with pytest.raises(ValueError, match="Cannot enumerate field value"): - list(_enumerate_instances(InvalidConfig)) - - -def test_enumerate_instances_single(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class 
SingleConfig: - value: int = 42 - - instances = list(_enumerate_instances(SingleConfig)) - assert len(instances) == 1 - assert instances[0].value == 42 - - -def test_enumerate_instances_literal(): - from qsharp.qre._enumeration import _enumerate_instances - - from typing import Literal - - @dataclass - class LiteralConfig: - _: KW_ONLY - mode: Literal["fast", "slow"] - - instances = list(_enumerate_instances(LiteralConfig)) - assert len(instances) == 2 - assert instances[0].mode == "fast" - assert instances[1].mode == "slow" - - -def test_enumerate_instances_nested(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class InnerConfig: - _: KW_ONLY - option: bool - - @dataclass - class OuterConfig: - _: KW_ONLY - inner: InnerConfig - - instances = list(_enumerate_instances(OuterConfig)) - assert len(instances) == 2 - assert instances[0].inner.option is True - assert instances[1].inner.option is False - - -def test_enumerate_instances_union(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB - - instances = list(_enumerate_instances(UnionConfig)) - assert len(instances) == 5 - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - assert isinstance(instances[2].option, OptionB) - assert instances[2].option.number == 1 - - -def test_enumerate_instances_nested_with_constraints(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class InnerConfig: - _: KW_ONLY - option: bool - - @dataclass - class OuterConfig: - _: KW_ONLY - inner: InnerConfig - - # Constrain nested field via dict - instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) - assert len(instances) == 1 - assert instances[0].inner.option is True 
- - -def test_enumerate_instances_union_single_type(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB - - # Restrict to OptionB only - uses its default domain - instances = list(_enumerate_instances(UnionConfig, option=OptionB)) - assert len(instances) == 3 - assert all(isinstance(i.option, OptionB) for i in instances) - assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] - - # Restrict to OptionA only - instances = list(_enumerate_instances(UnionConfig, option=OptionA)) - assert len(instances) == 2 - assert all(isinstance(i.option, OptionA) for i in instances) - assert cast(OptionA, instances[0].option).value is True - assert cast(OptionA, instances[1].option).value is False - - -def test_enumerate_instances_union_list_of_types(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class OptionC: - _: KW_ONLY - flag: bool - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB | OptionC - - # Select a subset: only OptionA and OptionB - instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) - assert len(instances) == 5 # 2 from OptionA + 3 from OptionB - assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) - - -def test_enumerate_instances_union_constraint_dict(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - 
option: OptionA | OptionB - - # Constrain OptionA, enumerate only that member - instances = list( - _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) - ) - assert len(instances) == 1 - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - - # Constrain OptionB with a domain, enumerate only that member - instances = list( - _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) - ) - assert len(instances) == 2 - assert all(isinstance(i.option, OptionB) for i in instances) - assert cast(OptionB, instances[0].option).number == 2 - assert cast(OptionB, instances[1].option).number == 3 - - # Constrain one member and keep another with defaults - instances = list( - _enumerate_instances( - UnionConfig, - option={OptionA: {"value": True}, OptionB: {}}, - ) - ) - assert len(instances) == 4 # 1 from OptionA + 3 from OptionB - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - assert all(isinstance(i.option, OptionB) for i in instances[1:]) - assert [cast(OptionB, i.option).number for i in instances[1:]] == [1, 2, 3] - - -def test_enumerate_isas(): - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # This will enumerate the 4 ISAs for the error correction code - count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) - assert count == 12 - - # This will enumerate the 2 ISAs for the error correction code when - # restricting the domain - count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) - assert count == 2 - - # This will enumerate the 3 ISAs for the factory - count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) - assert count == 3 - - # This will enumerate 36 ISAs for all products between the 12 error - # correction code ISAs and the 3 factory ISAs - count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) - assert count == 36 - - # When providing a list, components are chained (OR operation). 
This - # enumerates ISAs from first factory instance OR second factory instance - count = sum( - 1 - for _ in ( - SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) - ).enumerate(ctx) - ) - assert count == 72 - - # When providing separate arguments, components are combined via product - # (AND). This enumerates ISAs from first factory instance AND second - # factory instance - count = sum( - 1 - for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( - ctx - ) - ) - assert count == 108 - - # Hierarchical factory using from_components: the component receives ISAs - # from the product of other components as its source - count = sum( - 1 - for _ in ( - SurfaceCode.q() - * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) - ).enumerate(ctx) - ) - assert count == 1296 - - -def test_binding_node(): - """Test binding nodes with ISARefNode for component bindings""" - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # Test basic binding: same code used twice - # Without binding: 12 codes × 12 codes = 144 combinations - count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) - assert count_without == 144 - - # With binding: 12 codes (same instance used twice) - count_with = sum( - 1 - for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) - ) - assert count_with == 12 - - # Verify the binding works: with binding, both should use same params - for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): - logical_gates = [g for g in isa if g.encoding == LOGICAL] - # Should have 1 logical gate (LATTICE_SURGERY) - assert len(logical_gates) == 1 - - # Test binding with factories (nested bindings) - count_without = sum( - 1 - for _ in ( - SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() - ).enumerate(ctx) - ) - assert count_without == 1296 # 12 * 3 * 12 * 3 - - count_with = sum( - 1 - for _ in 
SurfaceCode.bind( - "c", - ExampleFactory.bind( - "f", - ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), - ), - ).enumerate(ctx) - ) - assert count_with == 36 # 12 * 3 - - # Test binding with from_components equivalent (hierarchical) - # Without binding: 4 outer codes × (4 inner codes × 3 factories × 3 levels) - count_without = sum( - 1 - for _ in ( - SurfaceCode.q() - * ExampleLogicalFactory.q( - source=(SurfaceCode.q() * ExampleFactory.q()), - ) - ).enumerate(ctx) - ) - assert count_without == 1296 # 12 * 12 * 3 * 3 - - # With binding: 4 codes (same used twice) × 3 factories × 3 levels - count_with = sum( - 1 - for _ in SurfaceCode.bind( - "c", - ISARefNode("c") - * ExampleLogicalFactory.q( - source=(ISARefNode("c") * ExampleFactory.q()), - ), - ).enumerate(ctx) - ) - assert count_with == 108 # 12 * 3 * 3 - - # Test binding with kwargs - count_with_kwargs = sum( - 1 - for _ in SurfaceCode.q(distance=5) - .bind("c", ISARefNode("c") * ISARefNode("c")) - .enumerate(ctx) - ) - assert count_with_kwargs == 1 # Only distance=5 - - # Verify kwargs are applied - for isa in ( - SurfaceCode.q(distance=5) - .bind("c", ISARefNode("c") * ISARefNode("c")) - .enumerate(ctx) - ): - logical_gates = [g for g in isa if g.encoding == LOGICAL] - assert all(g.space(1) == 49 for g in logical_gates) - - # Test multiple independent bindings (nested) - count = sum( - 1 - for _ in SurfaceCode.bind( - "c1", - ExampleFactory.bind( - "c2", - ISARefNode("c1") - * ISARefNode("c1") - * ISARefNode("c2") - * ISARefNode("c2"), - ), - ).enumerate(ctx) - ) - # 12 codes for c1 × 3 factories for c2 - assert count == 36 - - -def test_binding_node_errors(): - """Test error handling for binding nodes""" - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # Test ISARefNode enumerate with undefined binding raises ValueError - try: - list(ISARefNode("test").enumerate(ctx)) - assert False, "Should have raised ValueError" - except ValueError as e: - assert "Undefined 
component reference: 'test'" in str(e) - - -def test_product_isa_enumeration_nodes(): - from qsharp.qre._isa_enumeration import _ComponentQuery, _ProductNode - - terminal = SurfaceCode.q() - query = terminal * terminal - - # Multiplication should create ProductNode - assert isinstance(query, _ProductNode) - assert len(query.sources) == 2 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Multiplying again should extend the sources - query = query * terminal - assert isinstance(query, _ProductNode) - assert len(query.sources) == 3 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also from the other side - query = terminal * query - assert isinstance(query, _ProductNode) - assert len(query.sources) == 4 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also for two ProductNodes - query = query * query - assert isinstance(query, _ProductNode) - assert len(query.sources) == 8 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - -def test_sum_isa_enumeration_nodes(): - from qsharp.qre._isa_enumeration import _ComponentQuery, _SumNode - - terminal = SurfaceCode.q() - query = terminal + terminal - - # Multiplication should create SumNode - assert isinstance(query, _SumNode) - assert len(query.sources) == 2 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Multiplying again should extend the sources - query = query + terminal - assert isinstance(query, _SumNode) - assert len(query.sources) == 3 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also from the other side - query = terminal + query - assert isinstance(query, _SumNode) - assert len(query.sources) == 4 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also for two SumNodes - query = query + query - assert isinstance(query, _SumNode) - assert len(query.sources) == 8 - for source in query.sources: - 
assert isinstance(source, _ComponentQuery) - - -def test_trace_properties(): - trace = Trace(42) - - INT = 0 - FLOAT = 1 - BOOL = 2 - STR = 3 - - trace.set_property(INT, 42) - assert trace.get_property(INT) == 42 - assert isinstance(trace.get_property(INT), int) - - trace.set_property(FLOAT, 3.14) - assert trace.get_property(FLOAT) == 3.14 - assert isinstance(trace.get_property(FLOAT), float) - - trace.set_property(BOOL, True) - assert trace.get_property(BOOL) is True - assert isinstance(trace.get_property(BOOL), bool) - - trace.set_property(STR, "hello") - assert trace.get_property(STR) == "hello" - assert isinstance(trace.get_property(STR), str) - - -def test_qsharp_application(): - from qsharp.qre._enumeration import _enumerate_instances - - code = """ - {{ - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - }} - """ - - app = QSharpApplication(code) - trace = app.get_trace() - - assert trace.compute_qubits == 3 - assert trace.depth == 3 - assert trace.resource_states == {} - - assert {c.id for c in trace.required_isa} == {CCX, T, RZ} - - graph = _ProvenanceGraph() - isa = graph.make_isa( - [ - graph.add_instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - time=1000, - space=linear_function(50), - error_rate=linear_function(1e-6), - ), - graph.add_instruction( - T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 - ), - graph.add_instruction( - CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 - ), - ] - ) - - # Properties from the program - counts = qsharp.logical_counts(code) - num_ts = counts["tCount"] - num_ccx = counts["cczCount"] - num_rotations = counts["rotationCount"] - rotation_depth = counts["rotationDepth"] - - lattice_surgery = LatticeSurgery() - - counter = 0 - for psspc in _enumerate_instances(PSSPC): - counter += 1 - trace2 = psspc.transform(trace) - assert trace2 is not None - trace2 = lattice_surgery.transform(trace2) - assert trace2 is not None - assert 
trace2.compute_qubits == 12 - assert ( - trace2.depth - == num_ts - + num_ccx * 3 - + num_rotations - + rotation_depth * psspc.num_ts_per_rotation - ) - if psspc.ccx_magic_states: - assert trace2.resource_states == { - T: num_ts + psspc.num_ts_per_rotation * num_rotations, - CCX: num_ccx, - } - assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} - else: - assert trace2.resource_states == { - T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx - } - assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} - assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 - assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 - result = trace2.estimate(isa, max_error=float("inf")) - assert result is not None - assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 - assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 - assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 - assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 - _assert_estimation_result(trace2, result, isa) - assert counter == 32 - - -def test_application_enumeration(): - @dataclass(kw_only=True) - class _Params: - size: int = field(default=1, metadata={"domain": range(1, 4)}) - - class TestApp(Application[_Params]): - def get_trace(self, parameters: _Params) -> Trace: - return Trace(parameters.size) - - app = TestApp() - assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 - assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 - assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 - - -def test_trace_enumeration(): - code = """ - {{ - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - }} - """ - - app = QSharpApplication(code) - - ctx = app.context() - assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 - - assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 - - assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 - - q = PSSPC.q() * 
LatticeSurgery.q() - assert sum(1 for _ in q.enumerate(ctx)) == 32 - - -def test_rotation_error_psspc(): - from qsharp.qre._enumeration import _enumerate_instances - - # This test helps to bound the variables for the number of rotations in PSSPC - - # Create a trace with a single rotation gate and ensure that the base error - # after PSSPC transformation is less than 1. - trace = Trace(1) - trace.add_operation(RZ, [0]) - - for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): - transformed = psspc.transform(trace) - assert transformed is not None - assert ( - transformed.base_error < 1.0 - ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" - - -def test_estimation_max_error(): - from qsharp.estimator import LogicalCounts - - app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) - arch = AQREGateBased(gate_time=50, measurement_time=100) - - for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=max_error, - ) - - assert len(results) == 1 - assert next(iter(results)).error <= max_error - - -def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): - actual_qubits = ( - isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) - + isa[T].expect_space() * result.factories[T].copies - ) - if CCX in trace.resource_states: - actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies - assert result.qubits == actual_qubits - - assert ( - result.runtime - == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth - ) - - actual_error = ( - trace.base_error - + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth - + isa[T].expect_error_rate() * result.factories[T].states - ) - if CCX in trace.resource_states: - actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states - assert 
abs(result.error - actual_error) <= 1e-8 - - -# --- EstimationTable tests --- - - -def _make_entry(qubits, runtime, error, properties=None): - """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" - return EstimationTableEntry( - qubits=qubits, - runtime=runtime, - error=error, - source=InstructionSource(), - properties=properties or {}, - ) - - -def test_estimation_table_default_columns(): - """Test that a new EstimationTable has the three default columns.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error"] - assert frame["qubits"][0] == 100 - assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") - assert frame["error"][0] == 0.01 - - -def test_estimation_table_multiple_rows(): - """Test as_frame with multiple entries.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - table.append(_make_entry(200, 10000, 0.02)) - table.append(_make_entry(300, 15000, 0.03)) - - frame = table.as_frame() - assert len(frame) == 3 - assert list(frame["qubits"]) == [100, 200, 300] - assert list(frame["error"]) == [0.01, 0.02, 0.03] - - -def test_estimation_table_empty(): - """Test as_frame with no entries produces an empty DataFrame.""" - table = EstimationTable() - frame = table.as_frame() - assert len(frame) == 0 - - -def test_estimation_table_add_column(): - """Test adding a column to the table.""" - VAL = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) - table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) - - table.add_column("val", lambda e: e.properties[VAL]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "val"] - assert list(frame["val"]) == [42, 84] - - -def test_estimation_table_add_column_with_formatter(): - """Test adding a column with a formatter.""" - NS = 0 - - table = EstimationTable() - 
table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) - - table.add_column( - "duration", - lambda e: e.properties[NS], - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ) - - frame = table.as_frame() - assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") - - -def test_estimation_table_add_multiple_columns(): - """Test adding multiple columns preserves order.""" - A = 0 - B = 1 - C = 2 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) - - table.add_column("a", lambda e: e.properties[A]) - table.add_column("b", lambda e: e.properties[B]) - table.add_column("c", lambda e: e.properties[C]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] - assert frame["a"][0] == 1 - assert frame["b"][0] == 2 - assert frame["c"][0] == 3 - - -def test_estimation_table_insert_column_at_beginning(): - """Test inserting a column at index 0.""" - NAME = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) - - table.insert_column(0, "name", lambda e: e.properties[NAME]) - - frame = table.as_frame() - assert list(frame.columns) == ["name", "qubits", "runtime", "error"] - assert frame["name"][0] == "test" - - -def test_estimation_table_insert_column_in_middle(): - """Test inserting a column between existing default columns.""" - EXTRA = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) - - # Insert between qubits and runtime (index 1) - table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] - assert frame["extra"][0] == 99 - - -def test_estimation_table_insert_column_at_end(): - """Test inserting a column at the end (same effect as add_column).""" - LAST = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={LAST: 
True})) - - # 3 default columns, inserting at index 3 = end - table.insert_column(3, "last", lambda e: e.properties[LAST]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "last"] - assert frame["last"][0] - - -def test_estimation_table_insert_column_with_formatter(): - """Test inserting a column with a formatter.""" - NS = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) - - table.insert_column( - 0, - "custom_time", - lambda e: e.properties[NS], - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ) - - frame = table.as_frame() - assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") - assert list(frame.columns)[0] == "custom_time" - - -def test_estimation_table_insert_and_add_columns(): - """Test combining insert_column and add_column.""" - A = 0 - B = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2})) - - table.add_column("b", lambda e: e.properties[B]) - table.insert_column(0, "a", lambda e: e.properties[A]) - - frame = table.as_frame() - assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] - - -def test_estimation_table_factory_summary_no_factories(): - """Test factory summary column when entries have no factories.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - - table.add_factory_summary_column() - - frame = table.as_frame() - assert "factories" in frame.columns - assert frame["factories"][0] == "None" - - -def test_estimation_table_factory_summary_with_estimation(): - """Test factory summary column with real estimation results.""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - 
max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_factory_summary_column() - frame = results.as_frame() - - assert "factories" in frame.columns - # Each result should mention T in the factory summary - for val in frame["factories"]: - assert "T" in val - - -def test_estimation_table_add_column_from_source(): - """Test adding a column that accesses the InstructionSource (like distance).""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_column( - "compute_distance", - lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], - ) - - frame = results.as_frame() - assert "compute_distance" in frame.columns - for d in frame["compute_distance"]: - assert isinstance(d, int) - assert d >= 3 - - -def test_estimation_table_add_column_from_properties(): - """Test adding columns that access trace properties from estimation.""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_column( - "num_ts_per_rotation", - lambda entry: entry.properties[NUM_TS_PER_ROTATION], - ) - - frame = results.as_frame() - assert "num_ts_per_rotation" in frame.columns - for val in frame["num_ts_per_rotation"]: - assert isinstance(val, int) - assert val >= 1 - - -def test_estimation_table_insert_column_before_defaults(): - """Test inserting a name column before all default columns, similar to the factoring notebook.""" - 
code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - name="test_experiment", - ) - - assert len(results) >= 1 - - # Add a factory summary at the end - results.add_factory_summary_column() - - frame = results.as_frame() - assert frame.columns[0] == "name" - assert frame.columns[-1] == "factories" - # Default columns should still be in order - assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] - - -def test_estimation_table_as_frame_sortable(): - """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" - table = EstimationTable() - table.append(_make_entry(300, 15000, 0.03)) - table.append(_make_entry(100, 5000, 0.01)) - table.append(_make_entry(200, 10000, 0.02)) - - frame = table.as_frame() - sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) - - assert list(sorted_frame["qubits"]) == [100, 200, 300] - assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] - - -def test_estimation_table_computed_column(): - """Test adding a column that computes a derived value from the entry.""" - table = EstimationTable() - table.append(_make_entry(100, 5_000_000, 0.01)) - table.append(_make_entry(200, 10_000_000, 0.02)) - - # Compute qubits * error as a derived metric - table.add_column("qubit_error_product", lambda e: e.qubits * e.error) - - frame = table.as_frame() - assert frame["qubit_error_product"][0] == pytest.approx(1.0) - assert frame["qubit_error_product"][1] == pytest.approx(4.0) - - -def test_estimation_table_plot_returns_figure(): - """Test that plot() returns a matplotlib Figure with correct axes.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 
5_000_000_000, 0.01)) - table.append(_make_entry(200, 10_000_000_000, 0.02)) - table.append(_make_entry(50, 50_000_000_000, 0.005)) - - fig = table.plot() - - assert isinstance(fig, Figure) - ax = fig.axes[0] - assert ax.get_ylabel() == "Physical qubits" - assert ax.get_xlabel() == "Runtime" - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - # Verify data points - offsets = ax.collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 3 - - -def test_estimation_table_plot_empty_raises(): - """Test that plot() raises ValueError on an empty table.""" - table = EstimationTable() - with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): - table.plot() - - -def test_estimation_table_plot_single_entry(): - """Test that plot() works with a single entry.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 1_000_000, 0.01)) - - fig = table.plot() - assert isinstance(fig, Figure) - - offsets = fig.axes[0].collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 1 - - -def test_estimation_table_plot_with_runtime_unit(): - """Test that plot(runtime_unit=...) 
scales x values and labels the axis.""" - table = EstimationTable() - # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns - table.append(_make_entry(100, int(3600e9), 0.01)) - table.append(_make_entry(200, int(7200e9), 0.02)) - - fig = table.plot(runtime_unit="hours") - - ax = fig.axes[0] - assert ax.get_xlabel() == "Runtime (hours)" - - # Verify the x data is scaled: should be 1.0 and 2.0 hours - offsets = cast(list, ax.collections[0].get_offsets()) - assert offsets[0][0] == pytest.approx(1.0) - assert offsets[1][0] == pytest.approx(2.0) - - -def test_estimation_table_plot_invalid_runtime_unit(): - """Test that plot() raises ValueError for an unknown runtime_unit.""" - table = EstimationTable() - table.append(_make_entry(100, 1000, 0.01)) - with pytest.raises(ValueError, match="Unknown runtime_unit"): - table.plot(runtime_unit="fortnights") - - -def _ll_files(): - ll_dir = ( - Path(__file__).parent.parent - / "tests-integration" - / "resources" - / "adaptive_ri" - / "output" - ) - return sorted(ll_dir.glob("*.ll")) - - -@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) -def test_trace_from_qir(ll_file): - # NOTE: This test is primarily to ensure that the function can parse real - # QIR output without errors, rather than checking specific properties of the - # trace. - try: - trace_from_qir(ll_file.read_text()) - except ValueError as e: - # The only reason of failure is presence of control flow - assert ( - str(e) - == "simulation of programs with branching control flow is not supported" - ) - - -def test_trace_from_qir_handles_all_instruction_ids(): - """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. - - Generates a synthetic QIR program containing one instance of each gate - intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir - processes all of them without error. 
- """ - import pyqir - import pyqir.qis as qis - from qsharp._native import QirInstructionId - from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP - - # -- Completeness check: every QirInstructionId must be covered -------- - handled_ids = ( - [qir_id for qir_id, _, _ in _GATE_MAP] - + [qir_id for qir_id, _ in _MEAS_MAP] - + list(_SKIP) - ) - # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) - all_ids = [ - QirInstructionId.I, - QirInstructionId.H, - QirInstructionId.X, - QirInstructionId.Y, - QirInstructionId.Z, - QirInstructionId.S, - QirInstructionId.SAdj, - QirInstructionId.SX, - QirInstructionId.SXAdj, - QirInstructionId.T, - QirInstructionId.TAdj, - QirInstructionId.CNOT, - QirInstructionId.CX, - QirInstructionId.CY, - QirInstructionId.CZ, - QirInstructionId.CCX, - QirInstructionId.SWAP, - QirInstructionId.RX, - QirInstructionId.RY, - QirInstructionId.RZ, - QirInstructionId.RXX, - QirInstructionId.RYY, - QirInstructionId.RZZ, - QirInstructionId.RESET, - QirInstructionId.M, - QirInstructionId.MResetZ, - QirInstructionId.MZ, - QirInstructionId.Move, - QirInstructionId.ReadResult, - QirInstructionId.ResultRecordOutput, - QirInstructionId.BoolRecordOutput, - QirInstructionId.IntRecordOutput, - QirInstructionId.DoubleRecordOutput, - QirInstructionId.TupleRecordOutput, - QirInstructionId.ArrayRecordOutput, - QirInstructionId.CorrelatedNoise, - ] - unhandled = [ - i - for i in all_ids - if i not in handled_ids and i != QirInstructionId.CorrelatedNoise - ] - assert unhandled == [], ( - f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " - f"{', '.join(str(i) for i in unhandled)}" - ) - - # -- Generate a QIR program with every producible gate ----------------- - simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) - builder = simple.builder - ctx = simple.context - q = simple.qubits - r = simple.results - - void_ty = pyqir.Type.void(ctx) - qubit_ty = pyqir.qubit_type(ctx) - 
result_ty = pyqir.result_type(ctx) - double_ty = pyqir.Type.double(ctx) - i64_ty = pyqir.IntType(ctx, 64) - - def declare(name, param_types): - return simple.add_external_function( - name, pyqir.FunctionType(void_ty, param_types) - ) - - # Single-qubit gates (pyqir.qis builtins) - qis.h(builder, q[0]) - qis.x(builder, q[0]) - qis.y(builder, q[0]) - qis.z(builder, q[0]) - qis.s(builder, q[0]) - qis.s_adj(builder, q[0]) - qis.t(builder, q[0]) - qis.t_adj(builder, q[0]) - - # SX — not in pyqir.qis - sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) - builder.call(sx_fn, [q[0]]) - - # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the - # pass does not handle, so use builder.call with the correct name) - cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) - builder.call(cx_fn, [q[0], q[1]]) - qis.cz(builder, q[0], q[1]) - qis.swap(builder, q[0], q[1]) - - cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) - builder.call(cy_fn, [q[0], q[1]]) - - # Three-qubit gate - qis.ccx(builder, q[0], q[1], q[2]) - - # Single-qubit rotations - qis.rx(builder, 1.0, q[0]) - qis.ry(builder, 1.0, q[0]) - qis.rz(builder, 1.0, q[0]) - - # Two-qubit rotations — not in pyqir.qis - rot2_ty = [double_ty, qubit_ty, qubit_ty] - angle = pyqir.const(double_ty, 1.0) - for name in ("rxx", "ryy", "rzz"): - fn = declare(f"__quantum__qis__{name}__body", rot2_ty) - builder.call(fn, [angle, q[0], q[1]]) - - # Measurements - qis.mz(builder, q[0], r[0]) - - m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) - builder.call(m_fn, [q[1], r[1]]) - - mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) - builder.call(mresetz_fn, [q[2], r[2]]) - - # Reset / Move - qis.reset(builder, q[0]) - - move_fn = declare("__quantum__qis__move__body", [qubit_ty]) - builder.call(move_fn, [q[0]]) - - # Output recording - tag = simple.add_byte_string(b"tag") - arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) - 
builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) - - rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) - builder.call(rec_fn, [r[0], tag]) - - tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) - builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) - - # -- Run trace_from_qir and verify it succeeds ------------------------- - trace = trace_from_qir(simple.ir()) - assert trace is not None - - -@pytest.mark.skipif( - "SLOW_TESTS" not in os.environ, - reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", -) -@pytest.mark.parametrize( - "post_process, use_graph", - [ - (False, False), - (True, False), - (False, True), - (True, True), - ], -) -def test_estimation_methods(post_process, use_graph): - counts = LogicalCounts( - { - "numQubits": 1000, - "tCount": 1_500_000, - "rotationCount": 0, - "rotationDepth": 0, - "cczCount": 1_000_000_000, - "ccixCount": 0, - "measurementCount": 25_000_000, - "numComputeQubits": 200, - "readFromMemoryCount": 30_000_000, - "writeToMemoryCount": 30_000_000, - } - ) - - trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) - isa_query = ( - SurfaceCode.q() - * RoundBasedFactory.q() - * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) - ) - - app = QSharpApplication(counts) - arch = AQREGateBased(gate_time=50, measurement_time=100) - - results = estimate( - app, - arch, - isa_query, - trace_query, - max_error=1 / 3, - post_process=post_process, - use_graph=use_graph, - ) - results.add_factory_summary_column() - - assert [(result.qubits, result.runtime) for result in results] == [ - (238707, 23997050000000), - (240407, 11998525000000), - ] - - print() - print(results.stats) - - -def test_rotation_buckets(): - from qsharp.qre.interop._qsharp import _bucketize_rotation_counts - - print() - - r_count = 15066 - r_depth = 14756 - q_count = 291 - - result = _bucketize_rotation_counts(r_count, r_depth) - - a_count = 0 - a_depth = 0 - for c, d in 
result: - print(c, d) - assert c <= q_count - assert c > 0 - a_count += c * d - a_depth += d - - assert a_count == r_count - assert a_depth == r_depth diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 3fcde87d89..cdeac80524 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -12,10 +12,13 @@ use num_traits::FromPrimitive; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; -use crate::{ParetoFrontier3D, trace::instruction_ids::instruction_name}; +use crate::trace::instruction_ids::instruction_name; pub mod property_keys; +mod provenance; +pub use provenance::ProvenanceGraph; + #[cfg(test)] mod tests; @@ -703,327 +706,3 @@ impl ConstraintBound { } } } - -pub struct ProvenanceGraph { - nodes: Vec, - // A consecutive list of child node indices for each node, where the - // children of node i are located at children[offset..offset+num_children] - // in the children vector. - children: Vec, - // Per-instruction-ID index of Pareto-optimal node indices. - // Built by `build_pareto_index()` after all nodes have been added. - pareto_index: FxHashMap>, -} - -impl Default for ProvenanceGraph { - fn default() -> Self { - // Initialize with a dummy node at index 0 to simplify indexing logic - // (so that 0 can be used as a "null" provenance) - let empty = ProvenanceNode::default(); - ProvenanceGraph { - nodes: vec![empty], - children: Vec::new(), - pareto_index: FxHashMap::default(), - } - } -} - -/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. 
-struct InstructionParetoItem { - node_index: usize, - space: u64, - time: u64, - error: f64, -} - -impl crate::ParetoItem3D for InstructionParetoItem { - type Objective1 = u64; - type Objective2 = u64; - type Objective3 = f64; - - fn objective1(&self) -> u64 { - self.space - } - fn objective2(&self) -> u64 { - self.time - } - fn objective3(&self) -> f64 { - self.error - } -} - -impl ProvenanceGraph { - #[must_use] - pub fn new() -> Self { - Self::default() - } - - pub fn add_node( - &mut self, - mut instruction: Instruction, - transform_id: u64, - children: &[usize], - ) -> usize { - let node_index = self.nodes.len(); - instruction.source = node_index; - let offset = self.children.len(); - let num_children = children.len(); - self.children.extend_from_slice(children); - self.nodes.push(ProvenanceNode { - instruction, - transform_id, - offset, - num_children, - }); - node_index - } - - #[must_use] - pub fn instruction(&self, node_index: usize) -> &Instruction { - &self.nodes[node_index].instruction - } - - #[must_use] - pub fn transform_id(&self, node_index: usize) -> u64 { - self.nodes[node_index].transform_id - } - - #[must_use] - pub fn children(&self, node_index: usize) -> &[usize] { - let node = &self.nodes[node_index]; - &self.children[node.offset..node.offset + node.num_children] - } - - #[must_use] - pub fn num_nodes(&self) -> usize { - self.nodes.len() - 1 - } - - #[must_use] - pub fn num_edges(&self) -> usize { - self.children.len() - } - - /// Builds the per-instruction-ID Pareto index. - /// - /// For each instruction ID in the graph, collects all nodes and retains - /// only the Pareto-optimal subset with respect to (space, time, `error_rate`) - /// evaluated at arity 1. Instructions with different encodings or - /// properties are never in competition. - /// - /// Must be called after all nodes have been added. 
- pub fn build_pareto_index(&mut self) { - // Group node indices by (instruction_id, encoding, properties) - let mut groups: FxHashMap> = FxHashMap::default(); - for idx in 1..self.nodes.len() { - let instr = &self.nodes[idx].instruction; - groups.entry(instr.id).or_default().push(idx); - } - - let mut pareto_index = FxHashMap::default(); - for (id, node_indices) in groups { - // Sub-partition by encoding and property keys to avoid comparing - // incompatible instructions (Risk R2 mitigation) - #[allow(clippy::type_complexity)] - let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec> = - FxHashMap::default(); - for &idx in &node_indices { - let instr = &self.nodes[idx].instruction; - let mut prop_vec: Vec<(u64, u64)> = instr - .properties - .as_ref() - .map(|p| { - let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect(); - v.sort_unstable(); - v - }) - .unwrap_or_default(); - prop_vec.sort_unstable(); - sub_groups - .entry((instr.encoding, prop_vec)) - .or_default() - .push(idx); - } - - let mut pareto_nodes = Vec::new(); - for (_key, indices) in sub_groups { - let items: Vec = indices - .iter() - .filter_map(|&idx| { - let instr = &self.nodes[idx].instruction; - let space = instr.space(Some(1))?; - let time = instr.time(Some(1))?; - let error = instr.error_rate(Some(1))?; - Some(InstructionParetoItem { - node_index: idx, - space, - time, - error, - }) - }) - .collect(); - - let frontier: ParetoFrontier3D = items.into_iter().collect(); - pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index)); - } - - pareto_index.insert(id, pareto_nodes); - } - - self.pareto_index = pareto_index; - } - - /// Returns the Pareto-optimal node indices for a given instruction ID. - #[must_use] - pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> { - self.pareto_index.get(&instruction_id).map(Vec::as_slice) - } - - /// Returns all instruction IDs that have Pareto-optimal entries. 
- #[must_use] - pub fn pareto_instruction_ids(&self) -> Vec { - self.pareto_index.keys().copied().collect() - } - - /// Returns the raw node count (including the sentinel at index 0). - #[must_use] - pub fn raw_node_count(&self) -> usize { - self.nodes.len() - } - - /// Returns the total number of ISAs that can be formed from Pareto-optimal - /// nodes. - /// - /// Requires [`build_pareto_index`](Self::build_pareto_index) to have - /// been called. - #[must_use] - pub fn total_isa_count(&self) -> usize { - self.pareto_index.values().map(Vec::len).product() - } - - /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given - /// requirements. - /// - /// For each constraint, selects matching Pareto-optimal nodes. Produces - /// the Cartesian product of per-constraint match sets, each augmented - /// with one representative node per unconstrained instruction ID (so - /// that returned ISAs contain entries for all instruction types in the - /// graph). - /// - /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n - /// are considered for constrained groups. Unconstrained "extra" nodes - /// are not filtered since they serve only as default placeholders. - /// - /// Requires [`build_pareto_index`](Self::build_pareto_index) to have - /// been called. - #[must_use] - pub fn query_satisfying( - &self, - graph_arc: &Arc>, - requirements: &ISARequirements, - min_node_idx: Option, - ) -> Vec { - let min_idx = min_node_idx.unwrap_or(0); - - let mut constrained_groups: Vec> = Vec::new(); - let mut constrained_ids: FxHashSet = FxHashSet::default(); - - for constraint in requirements.constraints.values() { - constrained_ids.insert(constraint.id()); - - // When a node range is specified, scan ALL nodes in the range - // instead of using the global Pareto index. The global index - // may have pruned nodes from this range as duplicates of - // earlier, equivalent nodes outside the range. 
- let matching: Vec<(u64, usize)> = if min_idx > 0 { - (min_idx..self.nodes.len()) - .filter(|&node_idx| { - let instr = &self.nodes[node_idx].instruction; - instr.id == constraint.id() && constraint.is_satisfied_by(instr) - }) - .map(|node_idx| (constraint.id(), node_idx)) - .collect() - } else { - let Some(pareto) = self.pareto_index.get(&constraint.id()) else { - return Vec::new(); - }; - pareto - .iter() - .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx))) - .map(|&node_idx| (constraint.id(), node_idx)) - .collect() - }; - - if matching.is_empty() { - return Vec::new(); - } - constrained_groups.push(matching); - } - - // One representative node per unconstrained instruction ID. - // When a Pareto index is available, use it; otherwise scan all - // nodes (this path is used during populate() before the index - // is built). - let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() { - let mut seen: FxHashMap = FxHashMap::default(); - for idx in 1..self.nodes.len() { - let id = self.nodes[idx].instruction.id; - if !constrained_ids.contains(&id) { - seen.entry(id).or_insert(idx); - } - } - seen.into_iter().collect() - } else { - self.pareto_index - .iter() - .filter(|(id, _)| !constrained_ids.contains(id)) - .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n))) - .collect() - }; - - // Cartesian product of constrained groups - let mut combinations: Vec> = vec![Vec::new()]; - for group in &constrained_groups { - let mut next = Vec::with_capacity(combinations.len() * group.len()); - for combo in &combinations { - for &item in group { - let mut extended = combo.clone(); - extended.push(item); - next.push(extended); - } - } - combinations = next; - } - - // Build ISAs from selections - combinations - .into_iter() - .map(|mut combo| { - combo.extend(extra_nodes.iter().copied()); - let mut isa = ISA::with_graph(Arc::clone(graph_arc)); - for (id, node_idx) in combo { - isa.add_node(id, node_idx); - } - isa - }) - .collect() - } 
-} - -struct ProvenanceNode { - instruction: Instruction, - transform_id: u64, - offset: usize, - num_children: usize, -} - -impl Default for ProvenanceNode { - fn default() -> Self { - ProvenanceNode { - instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0), - transform_id: 0, - offset: 0, - num_children: 0, - } - } -} diff --git a/source/qre/src/isa/provenance.rs b/source/qre/src/isa/provenance.rs new file mode 100644 index 0000000000..8b59660639 --- /dev/null +++ b/source/qre/src/isa/provenance.rs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::sync::{Arc, RwLock}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{Encoding, ISA, ISARequirements, Instruction, ParetoFrontier3D}; + +pub struct ProvenanceGraph { + nodes: Vec, + // A consecutive list of child node indices for each node, where the + // children of node i are located at children[offset..offset+num_children] + // in the children vector. + children: Vec, + // Per-instruction-ID index of Pareto-optimal node indices. + // Built by `build_pareto_index()` after all nodes have been added. + pareto_index: FxHashMap>, +} + +impl Default for ProvenanceGraph { + fn default() -> Self { + // Initialize with a dummy node at index 0 to simplify indexing logic + // (so that 0 can be used as a "null" provenance) + let empty = ProvenanceNode::default(); + ProvenanceGraph { + nodes: vec![empty], + children: Vec::new(), + pareto_index: FxHashMap::default(), + } + } +} + +/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. 
+struct InstructionParetoItem { + node_index: usize, + space: u64, + time: u64, + error: f64, +} + +impl crate::ParetoItem3D for InstructionParetoItem { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> u64 { + self.space + } + fn objective2(&self) -> u64 { + self.time + } + fn objective3(&self) -> f64 { + self.error + } +} + +impl ProvenanceGraph { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn add_node( + &mut self, + mut instruction: Instruction, + transform_id: u64, + children: &[usize], + ) -> usize { + let node_index = self.nodes.len(); + instruction.source = node_index; + let offset = self.children.len(); + let num_children = children.len(); + self.children.extend_from_slice(children); + self.nodes.push(ProvenanceNode { + instruction, + transform_id, + offset, + num_children, + }); + node_index + } + + #[must_use] + pub fn instruction(&self, node_index: usize) -> &Instruction { + &self.nodes[node_index].instruction + } + + #[must_use] + pub fn transform_id(&self, node_index: usize) -> u64 { + self.nodes[node_index].transform_id + } + + #[must_use] + pub fn children(&self, node_index: usize) -> &[usize] { + let node = &self.nodes[node_index]; + &self.children[node.offset..node.offset + node.num_children] + } + + #[must_use] + pub fn num_nodes(&self) -> usize { + self.nodes.len() - 1 + } + + #[must_use] + pub fn num_edges(&self) -> usize { + self.children.len() + } + + /// Builds the per-instruction-ID Pareto index. + /// + /// For each instruction ID in the graph, collects all nodes and retains + /// only the Pareto-optimal subset with respect to (space, time, `error_rate`) + /// evaluated at arity 1. Instructions with different encodings or + /// properties are never in competition. + /// + /// Must be called after all nodes have been added. 
+ pub fn build_pareto_index(&mut self) { + // Group node indices by (instruction_id, encoding, properties) + let mut groups: FxHashMap> = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let instr = &self.nodes[idx].instruction; + groups.entry(instr.id).or_default().push(idx); + } + + let mut pareto_index = FxHashMap::default(); + for (id, node_indices) in groups { + // Sub-partition by encoding and property keys to avoid comparing + // incompatible instructions (Risk R2 mitigation) + #[allow(clippy::type_complexity)] + let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec> = + FxHashMap::default(); + for &idx in &node_indices { + let instr = &self.nodes[idx].instruction; + let mut prop_vec: Vec<(u64, u64)> = instr + .properties + .as_ref() + .map(|p| { + let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect(); + v.sort_unstable(); + v + }) + .unwrap_or_default(); + prop_vec.sort_unstable(); + sub_groups + .entry((instr.encoding, prop_vec)) + .or_default() + .push(idx); + } + + let mut pareto_nodes = Vec::new(); + for (_key, indices) in sub_groups { + let items: Vec = indices + .iter() + .filter_map(|&idx| { + let instr = &self.nodes[idx].instruction; + let space = instr.space(Some(1))?; + let time = instr.time(Some(1))?; + let error = instr.error_rate(Some(1))?; + Some(InstructionParetoItem { + node_index: idx, + space, + time, + error, + }) + }) + .collect(); + + let frontier: ParetoFrontier3D = items.into_iter().collect(); + pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index)); + } + + pareto_index.insert(id, pareto_nodes); + } + + self.pareto_index = pareto_index; + } + + /// Returns the Pareto-optimal node indices for a given instruction ID. + #[must_use] + pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> { + self.pareto_index.get(&instruction_id).map(Vec::as_slice) + } + + /// Returns all instruction IDs that have Pareto-optimal entries. 
+ #[must_use] + pub fn pareto_instruction_ids(&self) -> Vec { + self.pareto_index.keys().copied().collect() + } + + /// Returns the raw node count (including the sentinel at index 0). + #[must_use] + pub fn raw_node_count(&self) -> usize { + self.nodes.len() + } + + /// Returns the total number of ISAs that can be formed from Pareto-optimal + /// nodes. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn total_isa_count(&self) -> usize { + self.pareto_index.values().map(Vec::len).product() + } + + /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given + /// requirements. + /// + /// For each constraint, selects matching Pareto-optimal nodes. Produces + /// the Cartesian product of per-constraint match sets, each augmented + /// with one representative node per unconstrained instruction ID (so + /// that returned ISAs contain entries for all instruction types in the + /// graph). + /// + /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n + /// are considered for constrained groups. Unconstrained "extra" nodes + /// are not filtered since they serve only as default placeholders. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn query_satisfying( + &self, + graph_arc: &Arc>, + requirements: &ISARequirements, + min_node_idx: Option, + ) -> Vec { + let min_idx = min_node_idx.unwrap_or(0); + + let mut constrained_groups: Vec> = Vec::new(); + let mut constrained_ids: FxHashSet = FxHashSet::default(); + + for constraint in requirements.constraints.values() { + constrained_ids.insert(constraint.id()); + + // When a node range is specified, scan ALL nodes in the range + // instead of using the global Pareto index. The global index + // may have pruned nodes from this range as duplicates of + // earlier, equivalent nodes outside the range. 
+ let matching: Vec<(u64, usize)> = if min_idx > 0 { + (min_idx..self.nodes.len()) + .filter(|&node_idx| { + let instr = &self.nodes[node_idx].instruction; + instr.id == constraint.id() && constraint.is_satisfied_by(instr) + }) + .map(|node_idx| (constraint.id(), node_idx)) + .collect() + } else { + let Some(pareto) = self.pareto_index.get(&constraint.id()) else { + return Vec::new(); + }; + pareto + .iter() + .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx))) + .map(|&node_idx| (constraint.id(), node_idx)) + .collect() + }; + + if matching.is_empty() { + return Vec::new(); + } + constrained_groups.push(matching); + } + + // One representative node per unconstrained instruction ID. + // When a Pareto index is available, use it; otherwise scan all + // nodes (this path is used during populate() before the index + // is built). + let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() { + let mut seen: FxHashMap = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let id = self.nodes[idx].instruction.id; + if !constrained_ids.contains(&id) { + seen.entry(id).or_insert(idx); + } + } + seen.into_iter().collect() + } else { + self.pareto_index + .iter() + .filter(|(id, _)| !constrained_ids.contains(id)) + .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n))) + .collect() + }; + + // Cartesian product of constrained groups + let mut combinations: Vec> = vec![Vec::new()]; + for group in &constrained_groups { + let mut next = Vec::with_capacity(combinations.len() * group.len()); + for combo in &combinations { + for &item in group { + let mut extended = combo.clone(); + extended.push(item); + next.push(extended); + } + } + combinations = next; + } + + // Build ISAs from selections + combinations + .into_iter() + .map(|mut combo| { + combo.extend(extra_nodes.iter().copied()); + let mut isa = ISA::with_graph(Arc::clone(graph_arc)); + for (id, node_idx) in combo { + isa.add_node(id, node_idx); + } + isa + }) + .collect() + } 
+} + +struct ProvenanceNode { + instruction: Instruction, + transform_id: u64, + offset: usize, + num_children: usize, +} + +impl Default for ProvenanceNode { + fn default() -> Self { + ProvenanceNode { + instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0), + transform_id: 0, + offset: 0, + num_children: 0, + } + } +} diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 08858a4551..0e2d1ef106 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -2,11 +2,7 @@ // Licensed under the MIT License. use std::{ - collections::hash_map::DefaultHasher, fmt::{Display, Formatter}, - hash::{Hash, Hasher}, - iter::repeat_with, - sync::{Arc, RwLock, atomic::AtomicUsize}, vec, }; @@ -14,14 +10,17 @@ use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use crate::{ - ConstraintBound, Encoding, Error, EstimationCollection, EstimationResult, FactoryResult, ISA, - ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, ResultSummary, + ConstraintBound, Encoding, Error, EstimationResult, FactoryResult, ISA, ISARequirements, + Instruction, InstructionConstraint, LockedISA, property_keys::{ LOGICAL_COMPUTE_QUBITS, LOGICAL_MEMORY_QUBITS, PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, }, }; +mod estimation; +pub use estimation::{estimate_parallel, estimate_with_graph}; + pub mod instruction_ids; use instruction_ids::instruction_name; #[cfg(test)] @@ -752,452 +751,3 @@ fn get_error_rate_by_id(isa: &LockedISA<'_>, id: u64) -> Result { .error_rate(None) .ok_or(Error::CannotExtractErrorRate(id)) } - -/// Estimates all (trace, ISA) combinations in parallel, returning only the -/// successful results collected into an [`EstimationCollection`]. -/// -/// This uses a shared atomic counter as a lock-free work queue. Each worker -/// thread atomically claims the next job index, maps it to a `(trace, isa)` -/// pair, and runs the estimation. 
This keeps all available cores busy until -/// the last job completes. -/// -/// # Work distribution -/// -/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`: -/// - `trace_idx = j / isas.len()` -/// - `isa_idx = j % isas.len()` -/// -/// Each worker accumulates results locally and sends them back over a bounded -/// channel once it runs out of work, avoiding contention on the shared -/// collection. -#[must_use] -pub fn estimate_parallel<'a>( - traces: &[&'a Trace], - isas: &[&'a ISA], - max_error: Option, - post_process: bool, -) -> EstimationCollection { - let total_jobs = traces.len() * isas.len(); - let num_isas = isas.len(); - - // Shared atomic counter acts as a lock-free work queue. Workers call - // fetch_add to claim the next job index. - let next_job = AtomicUsize::new(0); - - let mut collection = EstimationCollection::new(); - collection.set_total_jobs(total_jobs); - - std::thread::scope(|scope| { - let num_threads = std::thread::available_parallelism() - .map(std::num::NonZero::get) - .unwrap_or(1); - - // Bounded channel so each worker can send its batch of results back - // to the main thread without unbounded buffering. - let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); - - for _ in 0..num_threads { - let tx = tx.clone(); - let next_job = &next_job; - scope.spawn(move || { - let mut local_results = Vec::new(); - loop { - // Atomically claim the next job. Relaxed ordering is - // sufficient because there is no dependent data between - // jobs — each (trace, isa) pair is independent. - let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if job >= total_jobs { - break; - } - - // Map the flat job index to a (trace, ISA) pair. 
- let trace_idx = job / num_isas; - let isa_idx = job % num_isas; - - if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) - { - estimation.set_isa_index(isa_idx); - estimation.set_trace_index(trace_idx); - - local_results.push(estimation); - } - } - // Send all results from this worker in one batch. - let _ = tx.send(local_results); - }); - } - // Drop the cloned sender so the receiver iterator terminates once all - // workers have finished. - drop(tx); - - // Collect results from all workers into the shared collection. - let mut successful = 0; - for local_results in rx { - if post_process { - for result in &local_results { - collection.push_summary(ResultSummary { - trace_index: result.trace_index().unwrap_or(0), - isa_index: result.isa_index().unwrap_or(0), - qubits: result.qubits(), - runtime: result.runtime(), - }); - } - } - successful += local_results.len(); - collection.extend(local_results.into_iter()); - } - collection.set_successful_estimates(successful); - }); - - // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap - // clones for discarded results. - for result in collection.iter_mut() { - if let Some(idx) = result.isa_index() { - result.set_isa(isas[idx].clone()); - } - } - - collection -} - -/// A node in the provenance graph along with pre-computed (space, time) values -/// for pruning. -#[derive(Clone, Copy, Hash, PartialEq, Eq)] -struct NodeProfile { - node_index: usize, - space: u64, - time: u64, -} - -/// A single entry in a combination of instruction choices for estimation. -#[derive(Clone, Copy, Hash, Eq, PartialEq)] -struct CombinationEntry { - instruction_id: u64, - node: NodeProfile, -} - -/// Per-slot pruning witnesses: maps a context hash to the `(space, time)` -/// pairs observed in successful estimations. -type SlotWitnesses = RwLock>>; - -/// Computes a hash of the combination context (all slots except the excluded -/// one). 
Two combinations that agree on every slot except `exclude_idx` -/// produce the same context hash. -fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 { - let mut hasher = DefaultHasher::new(); - for (i, entry) in combination.iter().enumerate() { - if i != exclude_idx { - entry.instruction_id.hash(&mut hasher); - entry.node.node_index.hash(&mut hasher); - } - } - hasher.finish() -} - -/// Checks whether a combination is dominated by a previously successful one. -/// -/// A combination is prunable if, for any instruction slot, there exists a -/// successful combination with the same instructions in all other slots and -/// an instruction at that slot with `space <=` and `time <=`. -fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { - for (slot_idx, entry) in combination.iter().enumerate() { - let ctx_hash = combination_context_hash(combination, slot_idx); - let map = trace_pruning[slot_idx] - .read() - .expect("Pruning lock poisoned"); - if map.get(&ctx_hash).is_some_and(|w| { - w.iter() - .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time) - }) { - return true; - } - } - false -} - -/// Records a successful estimation as a pruning witness for future -/// combinations. 
-fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) { - for (slot_idx, entry) in combination.iter().enumerate() { - let ctx_hash = combination_context_hash(combination, slot_idx); - let mut map = trace_pruning[slot_idx] - .write() - .expect("Pruning lock poisoned"); - map.entry(ctx_hash) - .or_default() - .push((entry.node.space, entry.node.time)); - } -} - -#[derive(Default)] -struct ISAIndex { - index: FxHashMap, usize>, - isas: Vec, -} - -impl From for Vec { - fn from(value: ISAIndex) -> Self { - value.isas - } -} - -impl ISAIndex { - pub fn push(&mut self, combination: &Vec, isa: &ISA) -> usize { - if let Some(&idx) = self.index.get(combination) { - idx - } else { - let idx = self.isas.len(); - self.isas.push(isa.clone()); - self.index.insert(combination.clone(), idx); - idx - } - } -} - -/// Generates the cartesian product of `id_and_nodes` and pushes each -/// combination directly into `jobs`, avoiding intermediate allocations. -/// -/// The cartesian product is enumerated using mixed-radix indexing. Given -/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations -/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely -/// identifies one element from every dimension: the index into dimension `d` is -/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by -/// repeatedly taking `i % nd` and then dividing `i` by `nd`. This is -/// analogous to extracting digits from a number in a mixed-radix system. -fn push_cartesian_product( - id_and_nodes: &[(u64, Vec)], - trace_idx: usize, - jobs: &mut Vec<(usize, Vec)>, - max_slots: &mut usize, -) { - // The product of all dimension sizes gives the total number of - // combinations. If any dimension is empty the product is zero and there - // are no valid combinations to generate. 
- let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product(); - if total == 0 { - return; - } - - *max_slots = (*max_slots).max(id_and_nodes.len()); - jobs.reserve(total); - - // Enumerate every combination by treating the combination index `i` as a - // mixed-radix number. The inner loop "peels off" one digit per dimension: - // node_idx = i % nodes.len() — selects this dimension's element - // i /= nodes.len() — shifts to the next dimension's digit - // After processing all dimensions, `i` is exhausted (becomes 0), and - // `combo` contains exactly one entry per instruction id. - for mut i in 0..total { - let mut combo = Vec::with_capacity(id_and_nodes.len()); - for (id, nodes) in id_and_nodes { - let node_idx = i % nodes.len(); - i /= nodes.len(); - let profile = nodes[node_idx]; - combo.push(CombinationEntry { - instruction_id: *id, - node: profile, - }); - } - jobs.push((trace_idx, combo)); - } -} - -#[must_use] -#[allow(clippy::cast_precision_loss, clippy::too_many_lines)] -pub fn estimate_with_graph( - traces: &[&Trace], - graph: &Arc>, - max_error: Option, - post_process: bool, -) -> EstimationCollection { - let max_error = max_error.unwrap_or(1.0); - - // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially. - // This reads the provenance graph once per trace and generates the - // cartesian product of Pareto-filtered nodes. Each node carries - // pre-computed (space, time) values for dominance pruning in Phase 2. - let mut jobs: Vec<(usize, Vec)> = Vec::new(); - - // Use the maximum number of instruction slots across all combinations to - // size the pruning witness structure. This will updated while we generate - // jobs. 
- let mut max_slots = 0; - - for (trace_idx, trace) in traces.iter().enumerate() { - if trace.base_error() > max_error { - continue; - } - - let required = trace.required_instruction_ids(Some(max_error)); - - let graph_lock = graph.read().expect("Graph lock poisoned"); - let id_and_nodes: Vec<_> = required - .constraints() - .iter() - .filter_map(|constraint| { - graph_lock.pareto_nodes(constraint.id()).map(|nodes| { - ( - constraint.id(), - nodes - .iter() - .filter(|&&node| { - // Filter out nodes that don't meet the constraint bounds. - let instruction = graph_lock.instruction(node); - constraint.error_rate().is_none_or(|c| { - c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0)) - }) - }) - .map(|&node| { - let instruction = graph_lock.instruction(node); - let space = instruction.space(Some(1)).unwrap_or(0); - let time = instruction.time(Some(1)).unwrap_or(0); - NodeProfile { - node_index: node, - space, - time, - } - }) - .collect::>(), - ) - }) - }) - .collect(); - drop(graph_lock); - - if id_and_nodes.len() != required.len() { - // If any required instruction is missing from the graph, we can't - // run any estimation for this trace. - continue; - } - - push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots); - } - - // Sort jobs so that combinations with smaller total (space + time) are - // processed first. This maximises the effectiveness of dominance pruning - // because successful "cheap" combinations establish witnesses that let us - // skip more expensive ones. - jobs.sort_by_key(|(_, combo)| { - combo - .iter() - .map(|entry| entry.node.space + entry.node.time) - .sum::() - }); - - let total_jobs = jobs.len(); - - // Phase 2: Run estimations in parallel with dominance-based pruning. - // - // For each instruction slot in a combination, we track (space, time) - // witnesses from successful estimations keyed by the "context", which is a - // hash of the node indices in all *other* slots. 
Before running an - // estimation, we check every slot: if a witness with space ≤ and time ≤ - // exists for that context, the combination is dominated and skipped. - let next_job = AtomicUsize::new(0); - - let pruning_witnesses: Vec> = repeat_with(|| { - repeat_with(|| RwLock::new(FxHashMap::default())) - .take(max_slots) - .collect() - }) - .take(traces.len()) - .collect(); - - // There are no explicit ISAs in this estimation function, as we create them - // on the fly from the graph nodes. For successful jobs, we will attach the - // ISAs to the results collection in a vector with the ISA index addressing - // that vector. In order to avoid storing duplicate ISAs we hash the ISA - // index. - let isa_index = Arc::new(RwLock::new(ISAIndex::default())); - - let mut collection = EstimationCollection::new(); - collection.set_total_jobs(total_jobs); - - std::thread::scope(|scope| { - let num_threads = std::thread::available_parallelism() - .map(std::num::NonZero::get) - .unwrap_or(1); - - let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); - - for _ in 0..num_threads { - let tx = tx.clone(); - let next_job = &next_job; - let jobs = &jobs; - let pruning_witnesses = &pruning_witnesses; - let isa_index = Arc::clone(&isa_index); - scope.spawn(move || { - let mut local_results = Vec::new(); - loop { - let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if job_idx >= total_jobs { - break; - } - - let (trace_idx, combination) = &jobs[job_idx]; - - // Dominance pruning: skip if a cheaper instruction at any - // slot already succeeded with the same surrounding context. 
- if is_dominated(combination, &pruning_witnesses[*trace_idx]) { - continue; - } - - let mut isa = ISA::with_graph(graph.clone()); - for entry in combination { - isa.add_node(entry.instruction_id, entry.node.node_index); - } - - if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) { - let isa_idx = isa_index - .write() - .expect("RwLock should not be poisoned") - .push(combination, &isa); - result.set_isa_index(isa_idx); - - result.set_trace_index(*trace_idx); - - local_results.push(result); - record_success(combination, &pruning_witnesses[*trace_idx]); - } - } - let _ = tx.send(local_results); - }); - } - drop(tx); - - let mut successful = 0; - for local_results in rx { - if post_process { - for result in &local_results { - collection.push_summary(ResultSummary { - trace_index: result.trace_index().unwrap_or(0), - isa_index: result.isa_index().unwrap_or(0), - qubits: result.qubits(), - runtime: result.runtime(), - }); - } - } - successful += local_results.len(); - collection.extend(local_results.into_iter()); - } - collection.set_successful_estimates(successful); - }); - - let isa_index = Arc::try_unwrap(isa_index) - .ok() - .expect("all threads joined; Arc refcount should be 1") - .into_inner() - .expect("RwLock should not be poisoned"); - - // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap - // clones for discarded results. - for result in collection.iter_mut() { - if let Some(idx) = result.isa_index() { - result.set_isa(isa_index.isas[idx].clone()); - } - } - - collection.set_isas(isa_index.into()); - - collection -} diff --git a/source/qre/src/trace/estimation.rs b/source/qre/src/trace/estimation.rs new file mode 100644 index 0000000000..b75ab35fde --- /dev/null +++ b/source/qre/src/trace/estimation.rs @@ -0,0 +1,462 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + iter::repeat_with, + sync::{Arc, RwLock, atomic::AtomicUsize}, +}; + +use rustc_hash::FxHashMap; + +use crate::{EstimationCollection, ISA, ProvenanceGraph, ResultSummary, Trace}; + +/// Estimates all (trace, ISA) combinations in parallel, returning only the +/// successful results collected into an [`EstimationCollection`]. +/// +/// This uses a shared atomic counter as a lock-free work queue. Each worker +/// thread atomically claims the next job index, maps it to a `(trace, isa)` +/// pair, and runs the estimation. This keeps all available cores busy until +/// the last job completes. +/// +/// # Work distribution +/// +/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`: +/// - `trace_idx = j / isas.len()` +/// - `isa_idx = j % isas.len()` +/// +/// Each worker accumulates results locally and sends them back over a bounded +/// channel once it runs out of work, avoiding contention on the shared +/// collection. +#[must_use] +pub fn estimate_parallel<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, + post_process: bool, +) -> EstimationCollection { + let total_jobs = traces.len() * isas.len(); + let num_isas = isas.len(); + + // Shared atomic counter acts as a lock-free work queue. Workers call + // fetch_add to claim the next job index. + let next_job = AtomicUsize::new(0); + + let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + + std::thread::scope(|scope| { + let num_threads = std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + + // Bounded channel so each worker can send its batch of results back + // to the main thread without unbounded buffering. 
+ let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for _ in 0..num_threads { + let tx = tx.clone(); + let next_job = &next_job; + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + // Atomically claim the next job. Relaxed ordering is + // sufficient because there is no dependent data between + // jobs — each (trace, isa) pair is independent. + let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job >= total_jobs { + break; + } + + // Map the flat job index to a (trace, ISA) pair. + let trace_idx = job / num_isas; + let isa_idx = job % num_isas; + + if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) + { + estimation.set_isa_index(isa_idx); + estimation.set_trace_index(trace_idx); + + local_results.push(estimation); + } + } + // Send all results from this worker in one batch. + let _ = tx.send(local_results); + }); + } + // Drop the cloned sender so the receiver iterator terminates once all + // workers have finished. + drop(tx); + + // Collect results from all workers into the shared collection. + let mut successful = 0; + for local_results in rx { + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } + } + successful += local_results.len(); + collection.extend(local_results.into_iter()); + } + collection.set_successful_estimates(successful); + }); + + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isas[idx].clone()); + } + } + + collection +} + +/// A node in the provenance graph along with pre-computed (space, time) values +/// for pruning. 
+#[derive(Clone, Copy, Hash, PartialEq, Eq)] +struct NodeProfile { + node_index: usize, + space: u64, + time: u64, +} + +/// A single entry in a combination of instruction choices for estimation. +#[derive(Clone, Copy, Hash, Eq, PartialEq)] +struct CombinationEntry { + instruction_id: u64, + node: NodeProfile, +} + +/// Per-slot pruning witnesses: maps a context hash to the `(space, time)` +/// pairs observed in successful estimations. +type SlotWitnesses = RwLock>>; + +/// Computes a hash of the combination context (all slots except the excluded +/// one). Two combinations that agree on every slot except `exclude_idx` +/// produce the same context hash. +fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 { + let mut hasher = DefaultHasher::new(); + for (i, entry) in combination.iter().enumerate() { + if i != exclude_idx { + entry.instruction_id.hash(&mut hasher); + entry.node.node_index.hash(&mut hasher); + } + } + hasher.finish() +} + +/// Checks whether a combination is dominated by a previously successful one. +/// +/// A combination is prunable if, for any instruction slot, there exists a +/// successful combination with the same instructions in all other slots and +/// an instruction at that slot with `space <=` and `time <=`. +fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let map = trace_pruning[slot_idx] + .read() + .expect("Pruning lock poisoned"); + if map.get(&ctx_hash).is_some_and(|w| { + w.iter() + .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time) + }) { + return true; + } + } + false +} + +/// Records a successful estimation as a pruning witness for future +/// combinations. 
+fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let mut map = trace_pruning[slot_idx] + .write() + .expect("Pruning lock poisoned"); + map.entry(ctx_hash) + .or_default() + .push((entry.node.space, entry.node.time)); + } +} + +#[derive(Default)] +struct ISAIndex { + index: FxHashMap, usize>, + isas: Vec, +} + +impl From for Vec { + fn from(value: ISAIndex) -> Self { + value.isas + } +} + +impl ISAIndex { + pub fn push(&mut self, combination: &Vec, isa: &ISA) -> usize { + if let Some(&idx) = self.index.get(combination) { + idx + } else { + let idx = self.isas.len(); + self.isas.push(isa.clone()); + self.index.insert(combination.clone(), idx); + idx + } + } +} + +/// Generates the cartesian product of `id_and_nodes` and pushes each +/// combination directly into `jobs`, avoiding intermediate allocations. +/// +/// The cartesian product is enumerated using mixed-radix indexing. Given +/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations +/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely +/// identifies one element from every dimension: the index into dimension `d` is +/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by +/// repeatedly taking `i % nd` and then dividing `i` by `nd`. This is +/// analogous to extracting digits from a number in a mixed-radix system. +fn push_cartesian_product( + id_and_nodes: &[(u64, Vec)], + trace_idx: usize, + jobs: &mut Vec<(usize, Vec)>, + max_slots: &mut usize, +) { + // The product of all dimension sizes gives the total number of + // combinations. If any dimension is empty the product is zero and there + // are no valid combinations to generate. 
+    let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product();
+    if total == 0 {
+        return;
+    }
+
+    *max_slots = (*max_slots).max(id_and_nodes.len());
+    jobs.reserve(total);
+
+    // Enumerate every combination by treating the combination index `i` as a
+    // mixed-radix number. The inner loop "peels off" one digit per dimension:
+    //   node_idx = i % nodes.len() — selects this dimension's element
+    //   i /= nodes.len() — shifts to the next dimension's digit
+    // After processing all dimensions, `i` is exhausted (becomes 0), and
+    // `combo` contains exactly one entry per instruction id.
+    for mut i in 0..total {
+        let mut combo = Vec::with_capacity(id_and_nodes.len());
+        for (id, nodes) in id_and_nodes {
+            let node_idx = i % nodes.len();
+            i /= nodes.len();
+            let profile = nodes[node_idx];
+            combo.push(CombinationEntry {
+                instruction_id: *id,
+                node: profile,
+            });
+        }
+        jobs.push((trace_idx, combo));
+    }
+}
+
+#[must_use]
+#[allow(clippy::cast_precision_loss, clippy::too_many_lines)]
+pub fn estimate_with_graph(
+    traces: &[&Trace],
+    graph: &Arc<RwLock<ProvenanceGraph>>,
+    max_error: Option<f64>,
+    post_process: bool,
+) -> EstimationCollection {
+    let max_error = max_error.unwrap_or(1.0);
+
+    // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially.
+    // This reads the provenance graph once per trace and generates the
+    // cartesian product of Pareto-filtered nodes. Each node carries
+    // pre-computed (space, time) values for dominance pruning in Phase 2.
+    let mut jobs: Vec<(usize, Vec<CombinationEntry>)> = Vec::new();
+
+    // Use the maximum number of instruction slots across all combinations to
+    // size the pruning witness structure. This will be updated while we
+    // generate jobs.
+ let mut max_slots = 0; + + for (trace_idx, trace) in traces.iter().enumerate() { + if trace.base_error() > max_error { + continue; + } + + let required = trace.required_instruction_ids(Some(max_error)); + + let graph_lock = graph.read().expect("Graph lock poisoned"); + let id_and_nodes: Vec<_> = required + .constraints() + .iter() + .filter_map(|constraint| { + graph_lock.pareto_nodes(constraint.id()).map(|nodes| { + ( + constraint.id(), + nodes + .iter() + .filter(|&&node| { + // Filter out nodes that don't meet the constraint bounds. + let instruction = graph_lock.instruction(node); + constraint.error_rate().is_none_or(|c| { + c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0)) + }) + }) + .map(|&node| { + let instruction = graph_lock.instruction(node); + let space = instruction.space(Some(1)).unwrap_or(0); + let time = instruction.time(Some(1)).unwrap_or(0); + NodeProfile { + node_index: node, + space, + time, + } + }) + .collect::>(), + ) + }) + }) + .collect(); + drop(graph_lock); + + if id_and_nodes.len() != required.len() { + // If any required instruction is missing from the graph, we can't + // run any estimation for this trace. + continue; + } + + push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots); + } + + // Sort jobs so that combinations with smaller total (space + time) are + // processed first. This maximises the effectiveness of dominance pruning + // because successful "cheap" combinations establish witnesses that let us + // skip more expensive ones. + jobs.sort_by_key(|(_, combo)| { + combo + .iter() + .map(|entry| entry.node.space + entry.node.time) + .sum::() + }); + + let total_jobs = jobs.len(); + + // Phase 2: Run estimations in parallel with dominance-based pruning. + // + // For each instruction slot in a combination, we track (space, time) + // witnesses from successful estimations keyed by the "context", which is a + // hash of the node indices in all *other* slots. 
Before running an + // estimation, we check every slot: if a witness with space ≤ and time ≤ + // exists for that context, the combination is dominated and skipped. + let next_job = AtomicUsize::new(0); + + let pruning_witnesses: Vec> = repeat_with(|| { + repeat_with(|| RwLock::new(FxHashMap::default())) + .take(max_slots) + .collect() + }) + .take(traces.len()) + .collect(); + + // There are no explicit ISAs in this estimation function, as we create them + // on the fly from the graph nodes. For successful jobs, we will attach the + // ISAs to the results collection in a vector with the ISA index addressing + // that vector. In order to avoid storing duplicate ISAs we hash the ISA + // index. + let isa_index = Arc::new(RwLock::new(ISAIndex::default())); + + let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + + std::thread::scope(|scope| { + let num_threads = std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + + let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for _ in 0..num_threads { + let tx = tx.clone(); + let next_job = &next_job; + let jobs = &jobs; + let pruning_witnesses = &pruning_witnesses; + let isa_index = Arc::clone(&isa_index); + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job_idx >= total_jobs { + break; + } + + let (trace_idx, combination) = &jobs[job_idx]; + + // Dominance pruning: skip if a cheaper instruction at any + // slot already succeeded with the same surrounding context. 
+ if is_dominated(combination, &pruning_witnesses[*trace_idx]) { + continue; + } + + let mut isa = ISA::with_graph(graph.clone()); + for entry in combination { + isa.add_node(entry.instruction_id, entry.node.node_index); + } + + if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) { + let isa_idx = isa_index + .write() + .expect("RwLock should not be poisoned") + .push(combination, &isa); + result.set_isa_index(isa_idx); + + result.set_trace_index(*trace_idx); + + local_results.push(result); + record_success(combination, &pruning_witnesses[*trace_idx]); + } + } + let _ = tx.send(local_results); + }); + } + drop(tx); + + let mut successful = 0; + for local_results in rx { + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } + } + successful += local_results.len(); + collection.extend(local_results.into_iter()); + } + collection.set_successful_estimates(successful); + }); + + let isa_index = Arc::try_unwrap(isa_index) + .ok() + .expect("all threads joined; Arc refcount should be 1") + .into_inner() + .expect("RwLock should not be poisoned"); + + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isa_index.isas[idx].clone()); + } + } + + collection.set_isas(isa_index.into()); + + collection +} From c1bc4fc3fbfc65d9bbbb1b7d42ec647f1936cd4f Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Thu, 2 Apr 2026 17:42:02 +0200 Subject: [PATCH 42/45] Complete and consistent docs for QRE (#3077) Makes sure that Python docs for QRE are complete and are consistent with the style used in the other Python files. 
--- source/pip/qsharp/qre/_application.py | 42 +++- source/pip/qsharp/qre/_architecture.py | 21 +- source/pip/qsharp/qre/_enumeration.py | 2 +- source/pip/qsharp/qre/_estimation.py | 10 +- source/pip/qsharp/qre/_instruction.py | 77 +++++-- source/pip/qsharp/qre/_isa_enumeration.py | 32 +-- source/pip/qsharp/qre/_qre.pyi | 195 +++++++++--------- source/pip/qsharp/qre/_results.py | 56 +++-- source/pip/qsharp/qre/_trace.py | 91 +++++++- source/pip/qsharp/qre/application/_cirq.py | 8 + source/pip/qsharp/qre/application/_qsharp.py | 25 +++ source/pip/qsharp/qre/interop/_cirq.py | 27 ++- source/pip/qsharp/qre/interop/_qir.py | 1 + source/pip/qsharp/qre/interop/_qsharp.py | 31 ++- .../qsharp/qre/models/factories/_litinski.py | 33 +++ .../qre/models/factories/_round_based.py | 35 +++- .../pip/qsharp/qre/models/factories/_utils.py | 4 +- source/pip/tests/qre/test_application.py | 6 + source/pip/tests/qre/test_cirq_interop.py | 9 + source/pip/tests/qre/test_enumeration.py | 15 ++ source/pip/tests/qre/test_estimation.py | 2 + source/pip/tests/qre/test_interop.py | 3 + source/pip/tests/qre/test_isa.py | 6 + source/pip/tests/qre/test_models.py | 39 ++++ 24 files changed, 595 insertions(+), 175 deletions(-) diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py index 8f2e1d33ed..6c20621b2b 100644 --- a/source/pip/qsharp/qre/_application.py +++ b/source/pip/qsharp/qre/_application.py @@ -49,10 +49,25 @@ class Application(ABC, Generic[TraceParameters]): @abstractmethod def get_trace(self, parameters: TraceParameters) -> Trace: - """Return the trace corresponding to this application.""" + """Return the trace corresponding to this application and parameters. + + Args: + parameters (TraceParameters): The trace parameters. + + Returns: + Trace: The trace for this application instance and parameters. + """ @staticmethod def q(**kwargs) -> TraceQuery: + """Create a trace query for this application. 
+ + Args: + **kwargs: Domain overrides forwarded to trace parameter enumeration. + + Returns: + TraceQuery: A trace query for this application type. + """ return TraceQuery(NoneType, **kwargs) def context(self) -> _Context: @@ -69,7 +84,14 @@ def enumerate_traces( self, **kwargs, ) -> Generator[Trace, None, None]: - """Yields all traces of an application given its dataclass parameters.""" + """Yield all traces of an application given its dataclass parameters. + + Args: + **kwargs: Domain overrides forwarded to ``_enumerate_instances``. + + Yields: + Trace: A trace for each enumerated set of trace parameters. + """ param_type = get_type_hints(self.__class__.get_trace).get("parameters") if param_type is types.NoneType: @@ -95,7 +117,7 @@ def enumerate_traces_with_parameters( self, **kwargs, ) -> Generator[tuple[TraceParameters, Trace], None, None]: - """Yields (parameters, trace) pairs for an application. + """Yield (parameters, trace) pairs for an application. Like ``enumerate_traces``, but each yielded trace is accompanied by the trace parameters that were used to generate it. @@ -103,9 +125,9 @@ def enumerate_traces_with_parameters( Args: **kwargs: Domain overrides forwarded to ``_enumerate_instances``. - Returns: - Generator[tuple[TraceParameters, Trace], None, None]: A generator - of (parameters, trace) pairs. + Yields: + tuple[TraceParameters, Trace]: A pair of trace parameters and + the corresponding trace. """ param_type = get_type_hints(self.__class__.get_trace).get("parameters") @@ -136,7 +158,15 @@ def disable_parallel_traces(self): class _Context: + """Enumeration context wrapping an application instance.""" + application: Application def __init__(self, application: Application, **kwargs): + """Initialize the context for the given application. + + Args: + application (Application): The application instance. + **kwargs: Additional keyword arguments (reserved for future use). 
+ """ self.application = application diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 1bfb3f29ff..cd8bb52e64 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -24,14 +24,16 @@ class Architecture(ABC): + """Abstract base class for quantum hardware architectures.""" + @abstractmethod def provided_isa(self, ctx: ISAContext) -> ISA: """ - Creates the ISA provided by this architecture, adding instructions + Create the ISA provided by this architecture, adding instructions directly to the context's provenance graph. Args: - ctx: The enumeration context whose provenance graph stores + ctx (ISAContext): The enumeration context whose provenance graph stores the instructions. Returns: @@ -40,7 +42,11 @@ def provided_isa(self, ctx: ISAContext) -> ISA: ... def context(self) -> ISAContext: - """Create a new enumeration context for this architecture.""" + """Create a new enumeration context for this architecture. + + Returns: + ISAContext: A new enumeration context. + """ return ISAContext(self) @@ -50,6 +56,11 @@ class ISAContext: """ def __init__(self, arch: Architecture): + """Initialize the ISA context for the given architecture. + + Args: + arch (Architecture): The architecture providing the base ISA. + """ self._provenance: _ProvenanceGraph = _ProvenanceGraph() # Let the architecture create instructions directly in the graph. @@ -172,11 +183,11 @@ def add_instruction( def make_isa(self, *node_indices: int) -> ISA: """ - Creates an ISA backed by this context's provenance graph from the + Create an ISA backed by this context's provenance graph from the given node indices. Args: - *node_indices: Node indices in the provenance graph. + *node_indices (int): Node indices in the provenance graph. Returns: ISA: An ISA referencing the provenance graph. 
diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py index fef8b85314..b01d706944 100644 --- a/source/pip/qsharp/qre/_enumeration.py +++ b/source/pip/qsharp/qre/_enumeration.py @@ -102,7 +102,7 @@ def _enumerate_union_members( def _enumerate_instances(cls: Type[T], **kwargs) -> Generator[T, None, None]: """ - Yields all instances of a dataclass given its class. + Yield all instances of a dataclass given its class. The enumeration logic supports defining domains for fields using the ``domain`` metadata key. Additionally, boolean fields are automatically diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index 7f39fd1683..174f0cee90 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -46,7 +46,7 @@ def estimate( the total number of qubits and the total runtime. Note: - The pruning strategy used when `use_graph` is set to True (default) + The pruning strategy used when ``use_graph`` is set to True (default) filters ISA instructions by comparing their per-instruction space, time, and error independently. However, the total qubit count of a result depends on the interaction between factory space and runtime: @@ -55,8 +55,8 @@ def estimate( instruction that is dominated on per-instruction metrics can still contribute to a globally Pareto-optimal result (e.g., a factory with higher time may need fewer copies, leading to fewer total qubits). As a - consequence, `use_graph=True` may miss some results that - `use_graph=False` would find. Use `use_graph=False` when completeness of + consequence, ``use_graph=True`` may miss some results that + ``use_graph=False`` would find. Use ``use_graph=False`` when completeness of the Pareto frontier is required. Args: @@ -73,11 +73,11 @@ def estimate( builds a graph of ISAs and prunes suboptimal ISAs during estimation. 
If False, use the Rust estimation path that does not perform any pruning and simply enumerates all ISAs for each trace. - name (Optional[str]): An optional name for the estimation. If give, this + name (Optional[str]): An optional name for the estimation. If given, this will be added as a first column to the results table for all entries. Returns: - EstimationTable: A table containing the optimal estimation results + EstimationTable: A table containing the optimal estimation results. """ app_ctx = application.context() diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index ab3c176e69..e48bcecd43 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -47,7 +47,7 @@ def constraint( **kwargs: bool, ) -> Constraint: """ - Creates an instruction constraint. + Create an instruction constraint. Args: id (int): The instruction ID. @@ -89,7 +89,7 @@ class ISATransform(ABC): @abstractmethod def required_isa() -> ISARequirements: """ - Returns the requirements that an implementation ISA must satisfy. + Return the requirements that an implementation ISA must satisfy. Returns: ISARequirements: The requirements for the underlying ISA. @@ -105,6 +105,8 @@ def provided_isa( Args: impl_isa (ISA): The implementation ISA that satisfies requirements. + ctx (ISAContext): The enumeration context whose provenance graph + stores the instructions. Yields: ISA: A provided logical ISA. @@ -119,13 +121,14 @@ def enumerate_isas( **kwargs, ) -> Generator[ISA, None, None]: """ - Enumerates all valid ISAs for this transform given implementation ISAs. + Enumerate all valid ISAs for this transform given implementation ISAs. This method iterates over all instances of the transform class (enumerating - hypterparameters) and filters implementation ISAs against requirements. + hyperparameters) and filters implementation ISAs against requirements. Args: impl_isa (ISA | Iterable[ISA]): One or more implementation ISAs. 
+ ctx (ISAContext): The enumeration context. **kwargs: Arguments passed to parameter enumeration. Yields: @@ -143,7 +146,7 @@ def enumerate_isas( @classmethod def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery: """ - Creates an ISAQuery node for this transform. + Create an ISAQuery node for this transform. Args: source (Node | None): The source node providing implementation ISAs. @@ -160,9 +163,9 @@ def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery: @classmethod def bind(cls, name: str, node: ISAQuery) -> _BindingNode: """ - Creates a BindingNode for this transform. + Create a BindingNode for this transform. - This is a convenience method equivalent to `cls.q().bind(name, node)`. + This is a convenience method equivalent to ``cls.q().bind(name, node)``. Args: name (str): The name to bind the transform's output to. @@ -182,7 +185,7 @@ class InstructionSource: @classmethod def from_isa(cls, ctx: ISAContext, isa: ISA) -> InstructionSource: """ - Constructs an InstructionSource graph from an ISA. + Construct an InstructionSource graph from an ISA. The instruction source graph contains more information than the provenance graph in the context, as it connects the instructions to the @@ -229,6 +232,11 @@ def _make_node( return graph def add_root(self, node_id: int) -> None: + """Add a root node to the instruction source graph. + + Args: + node_id (int): The index of the node to add as a root. + """ self.roots.append(node_id) def add_node( @@ -237,11 +245,24 @@ def add_node( transform: Optional[ISATransform | Architecture], children: list[int], ) -> int: + """Add a node to the instruction source graph. + + Args: + instruction (Instruction): The instruction for this node. + transform (Optional[ISATransform | Architecture]): The transform + that produced the instruction. + children (list[int]): Indices of child nodes. + + Returns: + int: The index of the newly added node. 
+ """ node_id = len(self.nodes) self.nodes.append(_InstructionSourceNode(instruction, transform, children)) return node_id def __str__(self) -> str: + """Return a formatted string representation of the instruction source graph.""" + def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str: result = " " * indent + f"{instruction_name(node.instruction.id) or '??'}" if node.transform is not None: @@ -256,7 +277,7 @@ def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str: def __getitem__(self, id: int) -> _InstructionSourceNodeReference: """ - Retrieves the first instruction source root node with the given + Retrieve the first instruction source root node with the given instruction ID. Raises KeyError if no such node exists. Args: @@ -273,7 +294,7 @@ def __getitem__(self, id: int) -> _InstructionSourceNodeReference: def __contains__(self, id: int) -> bool: """ - Checks if there is an instruction source root node with the given + Check if there is an instruction source root node with the given instruction ID. Args: @@ -292,7 +313,7 @@ def get( self, id: int, default: Optional[_InstructionSourceNodeReference] = None ) -> Optional[_InstructionSourceNodeReference]: """ - Retrieves the first instruction source root node with the given + Retrieve the first instruction source root node with the given instruction ID. Returns default if no such node exists. Args: @@ -313,30 +334,43 @@ def get( @dataclass(frozen=True, slots=True) class _InstructionSourceNode: + """A node in the instruction source graph.""" + instruction: Instruction transform: Optional[ISATransform | Architecture] children: list[int] class _InstructionSourceNodeReference: + """Reference to a node in an InstructionSource graph.""" + def __init__(self, graph: InstructionSource, node_id: int): + """Initialize a reference to a node in the instruction source graph. + + Args: + graph (InstructionSource): The owning instruction source graph. 
+ node_id (int): The index of the referenced node. + """ self.graph = graph self.node_id = node_id @property def instruction(self) -> Instruction: + """The instruction at this node.""" return self.graph.nodes[self.node_id].instruction @property def transform(self) -> Optional[ISATransform | Architecture]: + """The transform that produced this node's instruction, if any.""" return self.graph.nodes[self.node_id].transform def __str__(self) -> str: + """Return a string representation of the referenced node.""" return str(self.graph.nodes[self.node_id]) def __getitem__(self, id: int) -> _InstructionSourceNodeReference: """ - Retrieves the first child instruction source node with the given + Retrieve the first child instruction source node with the given instruction ID. Raises KeyError if no such node exists. Args: @@ -357,7 +391,7 @@ def get( self, id: int, default: Optional[_InstructionSourceNodeReference] = None ) -> Optional[_InstructionSourceNodeReference]: """ - Retrieves the first child instruction source node with the given + Retrieve the first child instruction source node with the given instruction ID. Returns default if no such node exists. Args: @@ -379,6 +413,15 @@ def get( def _isa_as_frame(self: ISA) -> pd.DataFrame: + """Convert an ISA to a pandas DataFrame. + + Args: + self (ISA): The ISA to convert. + + Returns: + pd.DataFrame: A DataFrame with columns for id, encoding, arity, + space, time, and error. + """ data = { "id": [instruction_name(inst.id) for inst in self], "encoding": [Encoding(inst.encoding).name for inst in self], @@ -401,6 +444,14 @@ def _isa_as_frame(self: ISA) -> pd.DataFrame: def _requirements_as_frame(self: ISARequirements) -> pd.DataFrame: + """Convert ISA requirements to a pandas DataFrame. + + Args: + self (ISARequirements): The requirements to convert. + + Returns: + pd.DataFrame: A DataFrame with columns for id, encoding, and arity. 
+ """ data = { "id": [instruction_name(inst.id) for inst in self], "encoding": [Encoding(inst.encoding).name for inst in self], diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index c33fdac435..7543c071ed 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -19,8 +19,8 @@ class ISAQuery(ABC): Abstract base class for all nodes in the ISA enumeration tree. Enumeration nodes define the structure of the search space for ISAs starting - from architectures and mofied by ISA transforms such as error correction - schemes. They can be composed using operators like `+` (sum) and `*` + from architectures and modified by ISA transforms such as error correction + schemes. They can be composed using operators like ``+`` (sum) and ``*`` (product) to build complex enumeration strategies. """ @@ -30,8 +30,8 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: Yields all ISA instances represented by this enumeration node. Args: - ctx (Context): The enumeration context containing shared state, - e.g., access to the underlying architecture. + ctx (ISAContext): The enumeration context containing shared state, + e.g., access to the underlying architecture. Yields: ISA: A possible ISA that can be generated from this node. @@ -40,9 +40,9 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: def populate(self, ctx: ISAContext) -> int: """ - Populates the provenance graph with instructions from this node. + Populate the provenance graph with instructions from this node. - Unlike `enumerate`, this does not yield ISA objects. Each transform + Unlike ``enumerate``, this does not yield ISA objects. Each transform queries the graph for Pareto-optimal instructions matching its requirements, and adds produced instructions directly to the graph. 
@@ -63,7 +63,7 @@ def populate(self, ctx: ISAContext) -> int: def __add__(self, other: ISAQuery) -> _SumNode: """ - Performs a union of two enumeration nodes. + Perform a union of two enumeration nodes. Enumerating the sum node yields all ISAs from this node, followed by all ISAs from the other node. Duplicate ISAs may be produced if both nodes @@ -97,7 +97,7 @@ def __add__(self, other: ISAQuery) -> _SumNode: def __mul__(self, other: ISAQuery) -> _ProductNode: """ - Performs the cross product of two enumeration nodes. + Perform the cross product of two enumeration nodes. Enumerating the product node yields ISAs resulting from the Cartesian product of ISAs from both nodes. The ISAs are combined using @@ -188,15 +188,15 @@ class _ComponentQuery(ISAQuery): """ Query node that enumerates ISAs based on a component type and source. - This node takes a component type (which must have an `enumerate_isas` class + This node takes a component type (which must have an ``enumerate_isas`` class method) and a source node. It enumerates the source node to get base ISAs, - and then calls `enumerate_isas` on the component type for each base ISA + and then calls ``enumerate_isas`` on the component type for each base ISA to generate derived ISAs. Attributes: component: The component type to query (e.g., a QEC code class). source: The source node providing input ISAs (default: ISA_ROOT). - kwargs: Additional keyword arguments passed to `enumerate_isas`. + kwargs: Additional keyword arguments passed to ``enumerate_isas``. """ component: type @@ -218,7 +218,7 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: def populate(self, ctx: ISAContext) -> int: """ - Populates the graph by querying matching instructions. + Populate the graph by querying matching instructions. 
Runs the source first to ensure dependency instructions are in the graph, then queries the graph for all instructions matching @@ -270,7 +270,7 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: ) def populate(self, ctx: ISAContext) -> int: - """Populates the graph from each source sequentially (no cross product). + """Populate the graph from each source sequentially (no cross product). Returns: int: The starting node index before any source populated. @@ -306,7 +306,7 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: yield from source.enumerate(ctx) def populate(self, ctx: ISAContext) -> int: - """Populates the graph from each source sequentially. + """Populate the graph from each source sequentially. Returns: int: The starting node index before any source populated. @@ -403,7 +403,7 @@ class _BindingNode(ISAQuery): def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ - Enumerates child nodes with the bound component in context. + Enumerate child nodes with the bound component in context. Args: ctx (Context): The enumeration context. @@ -418,7 +418,7 @@ def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: yield from self.node.enumerate(new_ctx) def populate(self, ctx: ISAContext) -> int: - """Populates the graph from both the component and the child node. + """Populate the graph from both the component and the child node. Returns: int: The starting node index of the component's additions. diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 7e9f92ddc8..da2a822063 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -9,7 +9,7 @@ import pandas as pd class ISA: def __add__(self, other: ISA) -> ISA: """ - Concatenates two ISAs (logical union). Instructions in the second + Concatenate two ISAs (logical union). Instructions in the second operand overwrite instructions in the first operand if they have the same ID. 
""" @@ -17,7 +17,7 @@ class ISA: def __contains__(self, id: int) -> bool: """ - Checks if the ISA contains an instruction with the given ID. + Check if the ISA contains an instruction with the given ID. Args: id (int): The instruction ID. @@ -29,13 +29,13 @@ class ISA: def satisfies(self, requirements: ISARequirements) -> bool: """ - Checks if the ISA satisfies the given ISA requirements. + Check if the ISA satisfies the given ISA requirements. """ ... def __getitem__(self, id: int) -> Instruction: """ - Gets an instruction by its ID. + Get an instruction by its ID. Args: id (int): The instruction ID. @@ -49,7 +49,7 @@ class ISA: self, id: int, default: Optional[Instruction] = None ) -> Optional[Instruction]: """ - Gets an instruction by its ID, or returns a default value if not found. + Get an instruction by its ID, or return a default value if not found. Args: id (int): The instruction ID. @@ -63,7 +63,7 @@ class ISA: def __len__(self) -> int: """ - Returns the number of instructions in the ISA. + Return the number of instructions in the ISA. Returns: int: The number of instructions. @@ -72,7 +72,7 @@ class ISA: def node_index(self, id: int) -> Optional[int]: """ - Returns the provenance graph node index for the given instruction ID. + Return the provenance graph node index for the given instruction ID. Args: id (int): The instruction ID. @@ -84,7 +84,7 @@ class ISA: def add_node(self, instruction_id: int, node_index: int) -> None: """ - Adds a pre-existing provenance graph node to the ISA. + Add a pre-existing provenance graph node to the ISA. Args: instruction_id (int): The instruction ID. @@ -94,7 +94,7 @@ class ISA: def as_frame(self) -> pd.DataFrame: """ - Returns a pandas DataFrame representation of the ISA. + Return a pandas DataFrame representation of the ISA. The DataFrame will have one row per instruction, with columns for instruction properties such as time, space, and error rate. 
The exact @@ -107,7 +107,7 @@ class ISA: def __iter__(self) -> Iterator[Instruction]: """ - Returns an iterator over the instructions. + Return an iterator over the instructions. Note: The order of instructions is not guaranteed. @@ -119,7 +119,7 @@ class ISA: def __str__(self) -> str: """ - Returns a string representation of the ISA. + Return a string representation of the ISA. Note: The order of instructions in the output is not guaranteed. @@ -136,7 +136,7 @@ class ISARequirements: def __new__(cls, constraints: list[Constraint], /) -> ISARequirements: ... def __new__(cls, *constraints: Constraint | list[Constraint]) -> ISARequirements: """ - Creates an ISA requirements specification from a list of instructions + Create an ISA requirements specification from a list of instructions constraints. Args: @@ -147,7 +147,7 @@ class ISARequirements: def __len__(self) -> int: """ - Returns the number of constraints in the requirements specification. + Return the number of constraints in the requirements specification. Returns: int: The number of constraints. @@ -156,7 +156,7 @@ class ISARequirements: def __iter__(self) -> Iterator[Constraint]: """ - Returns an iterator over the constraints. + Return an iterator over the constraints. Note: The order of constraints is not guaranteed. @@ -168,7 +168,7 @@ class ISARequirements: def as_frame(self) -> pd.DataFrame: """ - Returns a pandas DataFrame representation of the ISA requirements. + Return a pandas DataFrame representation of the ISA requirements. The DataFrame will have one row per instruction, with columns for constraint properties such as encoding. @@ -190,7 +190,7 @@ class Instruction: error_rate: float, ) -> Instruction: """ - Creates an instruction with a fixed arity. + Create an instruction with a fixed arity. Note: This function is not intended to be called directly by the user, use qre.instruction instead. 
@@ -221,7 +221,7 @@ class Instruction: length_fn: Optional[_IntFunction], ) -> Instruction: """ - Creates an instruction with variable arity. + Create an instruction with variable arity. Note: This function is not intended to be called directly by the user, use qre.instruction instead. @@ -242,7 +242,7 @@ class Instruction: def with_id(self, id: int) -> Instruction: """ - Returns a copy of the instruction with the given ID. + Return a copy of the instruction with the given ID. Note: The created instruction will not inherit the source property of the @@ -360,7 +360,7 @@ class Instruction: def set_source(self, index: int) -> None: """ - Sets the source index for the instruction. + Set the source index for the instruction. Args: index (int): The source index to set. @@ -370,7 +370,7 @@ class Instruction: @property def source(self) -> int: """ - Gets the source index for the instruction. + Get the source index for the instruction. Returns: int: The source index for the instruction. @@ -379,7 +379,7 @@ class Instruction: def set_property(self, key: int, value: int) -> None: """ - Sets a property on the instruction. + Set a property on the instruction. Args: key (int): The property key. @@ -389,7 +389,7 @@ class Instruction: def get_property(self, key: int) -> Optional[int]: """ - Gets a property by its key. + Get a property by its key. Args: key (int): The property key. @@ -401,7 +401,7 @@ class Instruction: def has_property(self, key: int) -> bool: """ - Checks if the instruction has a property with the given key. + Check if the instruction has a property with the given key. Args: key (int): The property key. @@ -413,7 +413,7 @@ class Instruction: def get_property_or(self, key: int, default: int) -> int: """ - Gets a property by its key, or returns a default value if not found. + Get a property by its key, or return a default value if not found. Args: key (int): The property key. 
@@ -426,7 +426,7 @@ class Instruction: def __getitem__(self, key: int) -> int: """ - Gets a property by its key, or raises an error if not found. + Get a property by its key, or raise an error if not found. Args: key (int): The property key. @@ -438,7 +438,7 @@ class Instruction: def __str__(self) -> str: """ - Returns a string representation of the instruction. + Return a string representation of the instruction. Returns: str: A string representation of the instruction. @@ -453,7 +453,7 @@ class ConstraintBound: @staticmethod def lt(value: float) -> ConstraintBound: """ - Creates a less than constraint bound. + Create a less than constraint bound. Args: value (float): The value. @@ -466,7 +466,7 @@ class ConstraintBound: @staticmethod def le(value: float) -> ConstraintBound: """ - Creates a less equal constraint bound. + Create a less equal constraint bound. Args: value (float): The value. @@ -479,7 +479,7 @@ class ConstraintBound: @staticmethod def eq(value: float) -> ConstraintBound: """ - Creates an equal constraint bound. + Create an equal constraint bound. Args: value (float): The value. @@ -492,7 +492,7 @@ class ConstraintBound: @staticmethod def gt(value: float) -> ConstraintBound: """ - Creates a greater than constraint bound. + Create a greater than constraint bound. Args: value (float): The value. @@ -505,7 +505,7 @@ class ConstraintBound: @staticmethod def ge(value: float) -> ConstraintBound: """ - Creates a greater equal constraint bound. + Create a greater equal constraint bound. Args: value (float): The value. @@ -586,7 +586,7 @@ class Constraint: def add_property(self, property: int) -> None: """ - Adds a property requirement to the constraint. + Add a property requirement to the constraint. Args: property (int): The property key that must be present in matching instructions. @@ -595,7 +595,7 @@ class Constraint: def has_property(self, property: int) -> bool: """ - Checks if the constraint requires a specific property. 
+ Check if the constraint requires a specific property. Args: property (int): The property key to check. @@ -616,7 +616,7 @@ def constant_function( value: int | float, ) -> _IntFunction | _FloatFunction: """ - Creates a constant function. + Create a constant function. Args: value (int | float): The constant value. @@ -634,7 +634,7 @@ def linear_function( slope: int | float, ) -> _IntFunction | _FloatFunction: """ - Creates a linear function. + Create a linear function. Args: slope (int | float): The slope. @@ -654,14 +654,15 @@ def block_linear_function( block_size: int, slope: int | float, offset: int | float ) -> _IntFunction | _FloatFunction: """ - Creates a block linear function that takes an arity (number of qubits) as + Create a block linear function that takes an arity (number of qubits) as input. Given an arity, it will compute the number of blocks `num_blocks` by computing `ceil(arity / block_size)` and then return `slope * num_blocks + offset`. Args: - block_size (int): The block size. slope (int | float): The slope. offset - (int | float): The offset + block_size (int): The block size. + slope (int | float): The slope. + offset (int | float): The offset. Returns: _IntFunction | _FloatFunction: The block linear function. @@ -676,7 +677,7 @@ def generic_function( func: Callable[[int], int | float], ) -> _IntFunction | _FloatFunction: """ - Creates a generic function from a Python callable. + Create a generic function from a Python callable. Note: Only use this function if the other function constructors @@ -705,7 +706,7 @@ class _ProvenanceGraph: self, instruction: Instruction, transform_id: int, children: list[int] ) -> int: """ - Adds a node to the provenance graph. + Add a node to the provenance graph. Args: instruction (int): The instruction corresponding to the node. @@ -719,7 +720,7 @@ class _ProvenanceGraph: def instruction(self, node_index: int) -> Instruction: """ - Returns the instruction for a given node index. 
+ Return the instruction for a given node index. Args: node_index (int): The index of the node in the provenance graph. @@ -731,7 +732,7 @@ class _ProvenanceGraph: def transform_id(self, node_index: int) -> int: """ - Returns the transform ID for a given node index. + Return the transform ID for a given node index. Args: node_index (int): The index of the node in the provenance graph. @@ -743,7 +744,7 @@ class _ProvenanceGraph: def children(self, node_index: int) -> list[int]: """ - Returns the list of child node indices for a given node index. + Return the list of child node indices for a given node index. Args: node_index (int): The index of the node in the provenance graph. @@ -755,7 +756,7 @@ class _ProvenanceGraph: def num_nodes(self) -> int: """ - Returns the number of nodes in the provenance graph. + Return the number of nodes in the provenance graph. Returns: int: The number of nodes in the provenance graph. @@ -764,7 +765,7 @@ class _ProvenanceGraph: def num_edges(self) -> int: """ - Returns the number of edges in the provenance graph. + Return the number of edges in the provenance graph. Returns: int: The number of edges in the provenance graph. @@ -802,7 +803,7 @@ class _ProvenanceGraph: **kwargs: int, ) -> int: """ - Adds an instruction to the provenance graph with no transform or + Add an instruction to the provenance graph with no transform or children. Can be called with a pre-existing ``Instruction`` or with keyword @@ -828,7 +829,7 @@ class _ProvenanceGraph: def make_isa(self, node_indices: list[int]) -> ISA: """ - Creates an ISA backed by this provenance graph from the given node + Create an ISA backed by this provenance graph from the given node indices. Args: @@ -855,7 +856,7 @@ class _ProvenanceGraph: min_node_idx: Optional[int] = None, ) -> list[ISA]: """ - Returns ISAs formed from Pareto-optimal graph nodes satisfying the + Return ISAs formed from Pareto-optimal graph nodes satisfying the given requirements. 
For each constraint in requirements, selects matching Pareto-optimal @@ -877,7 +878,7 @@ class _ProvenanceGraph: def raw_node_count(self) -> int: """ - Returns the raw node count (including the sentinel at index 0). + Return the raw node count (including the sentinel at index 0). Returns: int: The number of nodes in the graph. @@ -886,7 +887,7 @@ class _ProvenanceGraph: def total_isa_count(self) -> int: """ - Returns the total number of ISAs that can be formed from Pareto-optimal + Return the total number of ISAs that can be formed from Pareto-optimal nodes. Requires ``build_pareto_index`` to have been called. @@ -905,7 +906,7 @@ class EstimationResult: cls, *, qubits: int = 0, runtime: int = 0, error: float = 0.0 ) -> EstimationResult: """ - Creates a new estimation result. + Create a new estimation result. Args: qubits (int): The number of logical qubits. @@ -930,7 +931,7 @@ class EstimationResult: @qubits.setter def qubits(self, qubits: int) -> None: """ - Sets the number of logical qubits. + Set the number of logical qubits. Args: qubits (int): The number of logical qubits to set. @@ -950,7 +951,7 @@ class EstimationResult: @runtime.setter def runtime(self, runtime: int) -> None: """ - Sets the runtime. + Set the runtime. Args: runtime (int): The runtime in nanoseconds to set. @@ -970,7 +971,7 @@ class EstimationResult: @error.setter def error(self, error: float) -> None: """ - Sets the error probability. + Set the error probability. Args: error (float): The error probability to set. @@ -1009,7 +1010,7 @@ class EstimationResult: def set_property(self, key: int, value: bool | int | float | str) -> None: """ - Sets a custom property. + Set a custom property. Args: key (int) The property key. @@ -1020,7 +1021,7 @@ class EstimationResult: def __str__(self) -> str: """ - Returns a string representation of the estimation result. + Return a string representation of the estimation result. Returns: str: A string representation of the estimation result. 
@@ -1035,7 +1036,7 @@ class _EstimationCollection: def __new__(cls) -> _EstimationCollection: """ - Creates a new estimation collection. + Create a new estimation collection. Returns: _EstimationCollection: The estimation collection. @@ -1044,7 +1045,7 @@ class _EstimationCollection: def insert(self, result: EstimationResult) -> None: """ - Inserts an estimation result into the collection. + Insert an estimation result into the collection. Args: result (EstimationResult): The estimation result to insert. @@ -1053,7 +1054,7 @@ class _EstimationCollection: def __len__(self) -> int: """ - Returns the number of estimation results in the collection. + Return the number of estimation results in the collection. Returns: int: The number of estimation results. @@ -1062,7 +1063,7 @@ class _EstimationCollection: def __iter__(self) -> Iterator[EstimationResult]: """ - Returns an iterator over the estimation results. + Return an iterator over the estimation results. Returns: Iterator[EstimationResult]: The estimation result iterator. @@ -1072,7 +1073,7 @@ class _EstimationCollection: @property def total_jobs(self) -> int: """ - Returns the total number of (trace, ISA) estimation jobs. + Return the total number of (trace, ISA) estimation jobs. Returns: int: The total number of jobs. @@ -1082,7 +1083,7 @@ class _EstimationCollection: @property def successful_estimates(self) -> int: """ - Returns the number of estimation jobs that completed successfully + Return the number of estimation jobs that completed successfully (before Pareto filtering). Returns: @@ -1093,7 +1094,7 @@ class _EstimationCollection: @property def all_summaries(self) -> list[tuple[int, int, int, int]]: """ - Returns lightweight summaries of ALL successful estimates as a list + Return lightweight summaries of ALL successful estimates as a list of (trace_index, isa_index, qubits, runtime) tuples. 
Returns: @@ -1105,7 +1106,7 @@ class _EstimationCollection: @property def isas(self) -> list[ISA]: """ - Returns the list of ISAs for which estimates were performed. + Return the list of ISAs for which estimates were performed. Returns: list[ISA]: The list of ISAs. @@ -1167,7 +1168,7 @@ class Trace: def __new__(cls, compute_qubits: int) -> Trace: """ - Creates a new trace. + Create a new trace. Returns: Trace: The trace. @@ -1176,7 +1177,7 @@ class Trace: def clone_empty(self, compute_qubits: Optional[int] = None) -> Trace: """ - Creates a new trace with the same metadata but empty block. + Create a new trace with the same metadata but empty block. Args: compute_qubits (Optional[int]): The number of compute qubits. If None, @@ -1190,7 +1191,7 @@ class Trace: @classmethod def from_json(cls, json: str) -> Trace: """ - Creates a trace from a JSON string. + Create a trace from a JSON string. Args: json (str): The JSON string. @@ -1222,7 +1223,7 @@ class Trace: @compute_qubits.setter def compute_qubits(self, qubits: int) -> None: """ - Sets the number of compute qubits. + Set the number of compute qubits. Args: qubits (int): The number of compute qubits to set. @@ -1260,7 +1261,7 @@ class Trace: def has_memory_qubits(self) -> bool: """ - Checks if the trace has memory qubits set. + Check if the trace has memory qubits set. Returns: bool: True if memory qubits are set, False otherwise. @@ -1270,7 +1271,7 @@ class Trace: @memory_qubits.setter def memory_qubits(self, qubits: int) -> None: """ - Sets the number of memory qubits. + Set the number of memory qubits. Args: qubits (int): The number of memory qubits. @@ -1299,7 +1300,7 @@ class Trace: def set_property(self, key: int, value: Any) -> None: """ - Sets a property. All values of type `int`, `float`, `bool`, and `str` + Set a property. All values of type `int`, `float`, `bool`, and `str` are supported. Any other value is converted to a string using its `__str__` method. 
@@ -1311,7 +1312,7 @@ class Trace: def get_property(self, key: int) -> Optional[int | float | bool | str]: """ - Gets a property. + Get a property. Args: key (int): The property key. @@ -1323,7 +1324,7 @@ class Trace: def has_property(self, key: int) -> bool: """ - Checks if a property with the given key exists. + Check if a property with the given key exists. Args: key (int): The property key. @@ -1367,7 +1368,7 @@ class Trace: self, isa: ISA, max_error: Optional[float] = None ) -> Optional[EstimationResult]: """ - Estimates resources for the trace given a logical ISA. + Estimate resources for the trace given a logical ISA. Args: isa (ISA): The logical ISA. @@ -1394,7 +1395,7 @@ class Trace: self, id: int, qubits: list[int], params: list[float] = [] ) -> None: """ - Adds an operation to the trace. + Add an operation to the trace. Args: id (int): The operation ID. @@ -1405,7 +1406,7 @@ class Trace: def root_block(self) -> Block: """ - Returns the root block of the trace. + Return the root block of the trace. Returns: Block: The root block of the trace. @@ -1414,7 +1415,7 @@ class Trace: def add_block(self, repetitions: int = 1) -> Block: """ - Adds a block to the trace. + Add a block to the trace. Args: repetitions (int): The number of times the block is repeated. @@ -1436,7 +1437,7 @@ class Trace: def __str__(self) -> str: """ - Returns a string representation of the trace. + Return a string representation of the trace. Returns: str: A string representation of the trace. @@ -1456,7 +1457,7 @@ class Block: self, id: int, qubits: list[int], params: list[float] = [] ) -> None: """ - Adds an operation to the block. + Add an operation to the block. Args: id (int): The operation ID. @@ -1467,7 +1468,7 @@ class Block: def add_block(self, repetitions: int = 1) -> Block: """ - Adds a nested block to the block. + Add a nested block to the block. Args: repetitions (int): The number of times the block is repeated. 
@@ -1479,7 +1480,7 @@ class Block: def __str__(self) -> str: """ - Returns a string representation of the block. + Return a string representation of the block. Returns: str: A string representation of the block. @@ -1502,7 +1503,7 @@ class InstructionFrontier: def __new__(cls, *, with_error_objective: bool = True) -> InstructionFrontier: """ - Creates a new instruction frontier. + Create a new instruction frontier. Args: with_error_objective (bool): If True (default), the frontier uses @@ -1513,7 +1514,7 @@ class InstructionFrontier: def insert(self, point: Instruction): """ - Inserts an instruction to the frontier. + Insert an instruction into the frontier. Args: point (Instruction): The instruction to insert. @@ -1522,7 +1523,7 @@ class InstructionFrontier: def extend(self, points: list[Instruction]) -> None: """ - Extends the frontier with a list of instructions. + Extend the frontier with a list of instructions. Args: points (list[Instruction]): The instructions to insert. @@ -1531,7 +1532,7 @@ class InstructionFrontier: def __len__(self) -> int: """ - Returns the number of instructions in the frontier. + Return the number of instructions in the frontier. Returns: int: The number of instructions. @@ -1540,7 +1541,7 @@ class InstructionFrontier: def __iter__(self) -> Iterator[Instruction]: """ - Returns an iterator over the instructions in the frontier. + Return an iterator over the instructions in the frontier. Returns: Iterator[Instruction]: The iterator. @@ -1552,7 +1553,7 @@ class InstructionFrontier: filename: str, *, with_error_objective: bool = True ) -> InstructionFrontier: """ - Loads an instruction frontier from a file. + Load an instruction frontier from a file. Args: filename (str): The file name. @@ -1567,7 +1568,7 @@ class InstructionFrontier: def dump(self, filename: str) -> None: """ - Dumps the instruction frontier to a file. + Dump the instruction frontier to a file. Args: filename (str): The file name. 
@@ -1581,7 +1582,7 @@ def _estimate_parallel( post_process: bool = False, ) -> _EstimationCollection: """ - Estimates resources for multiple traces and ISAs in parallel. + Estimate resources for multiple traces and ISAs in parallel. Args: traces (list[Trace]): The list of traces. @@ -1602,7 +1603,7 @@ def _estimate_with_graph( post_process: bool = False, ) -> _EstimationCollection: """ - Estimates resources using a Pareto-filtered provenance graph. + Estimate resources using a Pareto-filtered provenance graph. Instead of forming the full Cartesian product of ISAs × traces, this function enumerates per-trace instruction combinations from the @@ -1628,20 +1629,16 @@ def _binom_ppf(q: float, n: int, p: float) -> int: ... def _float_to_bits(f: float) -> int: - """ - Converts a float to its bit representation as an integer. - """ + """Convert a float to its bit representation as an integer.""" ... def _float_from_bits(b: int) -> float: - """ - Converts a float from its bit representation as an integer. - """ + """Convert a float from its bit representation as an integer.""" ... def instruction_name(id: int) -> Optional[str]: """ - Returns the name of an instruction given its ID, if known. + Return the name of an instruction given its ID, if known. Args: id (int): The instruction ID. @@ -1653,7 +1650,7 @@ def instruction_name(id: int) -> Optional[str]: def property_name_to_key(name: str) -> Optional[int]: """ - Converts a property name to its corresponding key, if known. + Convert a property name to its corresponding key, if known. Args: name (str): The property name. @@ -1665,7 +1662,7 @@ def property_name_to_key(name: str) -> Optional[int]: def property_name(id: int) -> Optional[str]: """ - Converts a property key to its corresponding name, if known. + Convert a property key to its corresponding name, if known. Args: id (int): The property key. 
diff --git a/source/pip/qsharp/qre/_results.py b/source/pip/qsharp/qre/_results.py index efaa3be144..1a47a0d975 100644 --- a/source/pip/qsharp/qre/_results.py +++ b/source/pip/qsharp/qre/_results.py @@ -28,8 +28,8 @@ class EstimationTable(list["EstimationTableEntry"]): Extends ``list[EstimationTableEntry]`` and provides configurable columns for displaying estimation data. By default the table includes *qubits*, *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. - Additional columns can be added or inserted with :meth:`add_column` and - :meth:`insert_column`. + Additional columns can be added or inserted with ``add_column`` and + ``insert_column``. """ def __init__(self): @@ -57,7 +57,7 @@ def add_column( function: Callable[[EstimationTableEntry], Any], formatter: Optional[Callable[[Any], Any]] = None, ) -> None: - """Adds a column to the estimation table. + """Add a column to the estimation table. Args: name (str): The name of the column. @@ -65,7 +65,7 @@ def add_column( takes an EstimationTableEntry and returns the value for this column. formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. + that formats the output of ``function`` for display purposes. """ self._columns.append((name, EstimationTableColumn(function, formatter))) @@ -76,7 +76,7 @@ def insert_column( function: Callable[[EstimationTableEntry], Any], formatter: Optional[Callable[[Any], Any]] = None, ) -> None: - """Inserts a column at the specified index in the estimation table. + """Insert a column at the specified index in the estimation table. Args: index (int): The index at which to insert the column. @@ -85,11 +85,12 @@ def insert_column( takes an EstimationTableEntry and returns the value for this column. formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. + that formats the output of ``function`` for display purposes. 
""" self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) def add_qubit_partition_column(self) -> None: + """Add columns for the physical compute, factory, and memory qubit counts.""" self.add_column( "physical_compute_qubits", lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), @@ -104,7 +105,7 @@ def add_qubit_partition_column(self) -> None: ) def add_factory_summary_column(self) -> None: - """Adds a column to the estimation table that summarizes the factories used in the estimation.""" + """Add a column to the estimation table that summarizes the factories used in the estimation.""" def summarize_factories(entry: EstimationTableEntry) -> str: if not entry.factories: @@ -117,9 +118,9 @@ def summarize_factories(entry: EstimationTableEntry) -> str: self.add_column("factories", summarize_factories) def as_frame(self): - """Convert the estimation table to a :class:`pandas.DataFrame`. + """Convert the estimation table to a ``pandas.DataFrame``. - Each row corresponds to an :class:`EstimationTableEntry` and each + Each row corresponds to an ``EstimationTableEntry`` and each column is determined by the columns registered on this table. Column formatters, when present, are applied to the values before they are placed in the frame. @@ -145,7 +146,7 @@ def as_frame(self): def plot(self, **kwargs): """Plot this table's results. - Convenience wrapper around :func:`plot_estimates`. All keyword + Convenience wrapper around ``plot_estimates``. All keyword arguments are forwarded. Returns: @@ -156,11 +157,11 @@ def plot(self, **kwargs): @dataclass(frozen=True, slots=True) class EstimationTableColumn: - """Definition of a single column in an :class:`EstimationTable`. + """Definition of a single column in an ``EstimationTable``. Attributes: function: A callable that extracts the raw column value from an - :class:`EstimationTableEntry`. + ``EstimationTableEntry``. 
formatter: An optional callable that transforms the raw value for display purposes (e.g. converting nanoseconds to a ``pandas.Timedelta``). @@ -172,7 +173,7 @@ class EstimationTableColumn: @dataclass(frozen=True, slots=True) class EstimationTableEntry: - """A single row in an :class:`EstimationTable`. + """A single row in an ``EstimationTable``. Each entry represents one Pareto-optimal estimation result for a particular combination of application trace and architecture ISA. @@ -184,7 +185,7 @@ class EstimationTableEntry: source: The instruction source derived from the architecture ISA used for this estimation. factories: A mapping from instruction id to the - :class:`FactoryResult` describing the magic-state factory used + ``FactoryResult`` describing the magic-state factory used and the number of copies required. properties: Additional key-value properties attached to the estimation result. @@ -201,6 +202,15 @@ class EstimationTableEntry: def from_result( cls, result: EstimationResult, ctx: ISAContext ) -> EstimationTableEntry: + """Create an entry from an estimation result and architecture context. + + Args: + result (EstimationResult): The raw estimation result. + ctx (ISAContext): The architecture context used for the estimation. + + Returns: + EstimationTableEntry: A new table entry populated from the result. + """ return cls( qubits=result.qubits, runtime=result.runtime, @@ -213,6 +223,16 @@ def from_result( @dataclass(slots=True) class EstimationTableStats: + """Statistics for a single estimation run. + + Attributes: + num_traces (int): Number of traces evaluated. + num_isas (int): Number of ISAs evaluated. + total_jobs (int): Total estimation jobs executed. + successful_estimates (int): Number of jobs that produced a result. + pareto_results (int): Number of Pareto-optimal results retained. 
+ """ + num_traces: int = 0 num_isas: int = 0 total_jobs: int = 0 @@ -261,15 +281,15 @@ def plot_estimates( figsize: tuple[float, float] = (15, 8), scatter_args: dict[str, Any] = {"marker": "x"}, ): - """Returns a plot of the estimates displaying qubits vs runtime. + """Plot estimation results displaying qubits vs runtime. Creates a log-log scatter plot where the x-axis shows the total runtime and the y-axis shows the total number of physical qubits. - *data* may be a single `EstimationTable` or an iterable of tables. When + *data* may be a single ``EstimationTable`` or an iterable of tables. When multiple tables are provided, each is plotted as a separate series. If a - table has a `EstimationTable.name` (set via the *name* parameter of - `estimate`), it is used as the legend label for that series. + table has a ``EstimationTable.name`` (set via the *name* parameter of + ``estimate``), it is used as the legend label for that series. When *runtime_unit* is ``None`` (the default), the x-axis uses human-readable time-unit tick labels spanning nanoseconds to centuries. diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py index 15873ebf15..965454ac95 100644 --- a/source/pip/qsharp/qre/_trace.py +++ b/source/pip/qsharp/qre/_trace.py @@ -15,16 +15,49 @@ class TraceTransform(ABC): + """Abstract base class for trace transformations.""" + @abstractmethod - def transform(self, trace: Trace) -> Optional[Trace]: ... + def transform(self, trace: Trace) -> Optional[Trace]: + """Apply this transformation to a trace. + + Args: + trace (Trace): The input trace. + + Returns: + Optional[Trace]: The transformed trace, or None if the + transformation is not applicable. + """ + ... @classmethod def q(cls, **kwargs) -> TraceQuery: + """Create a trace query for this transform type. + + Args: + **kwargs: Domain overrides for parameter enumeration. + + Returns: + TraceQuery: A trace query wrapping this transform type. 
+ """ return TraceQuery(cls, **kwargs) @dataclass class PSSPC(TraceTransform): + """Pauli-based computation trace transform (PSSPC). + + Converts rotation gates and optionally CCX gates into T-state-based + operations suitable for lattice surgery resource estimation. + + Attributes: + num_ts_per_rotation (int): Number of T states used per rotation + gate. Default is 20. + ccx_magic_states (bool): If True, CCX gates are treated as magic + states rather than being decomposed into T gates. Default is + False. + """ + _: KW_ONLY num_ts_per_rotation: int = field( default=20, metadata={"domain": list(range(5, 21))} @@ -35,11 +68,29 @@ def __post_init__(self): self._psspc = _PSSPC(self.num_ts_per_rotation, self.ccx_magic_states) def transform(self, trace: Trace) -> Optional[Trace]: + """Apply the PSSPC transformation to a trace. + + Args: + trace (Trace): The input trace. + + Returns: + Optional[Trace]: The transformed trace. + """ return self._psspc.transform(trace) @dataclass class LatticeSurgery(TraceTransform): + """Lattice surgery trace transform. + + Converts a trace into a form suitable for lattice-surgery-based + resource estimation. + + Attributes: + slow_down_factor (float): Multiplicative factor applied to the + trace depth. Default is 1.0. + """ + _: KW_ONLY slow_down_factor: float = field(default=1.0, metadata={"domain": [1.0]}) @@ -47,15 +98,31 @@ def __post_init__(self): self._lattice_surgery = _LatticeSurgery(self.slow_down_factor) def transform(self, trace: Trace) -> Optional[Trace]: + """Apply the lattice surgery transformation to a trace. + + Args: + trace (Trace): The input trace. + + Returns: + Optional[Trace]: The transformed trace. + """ return self._lattice_surgery.transform(trace) class _Node(ABC): + """Abstract base class for trace enumeration nodes.""" + @abstractmethod def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: ... class TraceQuery(_Node): + """A query that enumerates transformed traces from an application. 
+ + A trace query chains a sequence of trace transforms, each with optional + keyword arguments to override their default parameter domains. + """ + # This is a sequence of trace transforms together with possible kwargs to # override their default domains. The first element might be sequence: list[tuple[Type, dict[str, Any]]] @@ -66,6 +133,17 @@ def __init__(self, t: Type, **kwargs): def enumerate( self, ctx: _Context, track_parameters: bool = False ) -> Generator[Trace | tuple[Any, Trace], None, None]: + """Enumerate transformed traces from the application context. + + Args: + ctx (_Context): The application enumeration context. + track_parameters (bool): If True, yield ``(parameters, trace)`` + tuples instead of plain traces. Default is False. + + Yields: + Trace | tuple[Any, Trace]: A transformed trace, or a + ``(parameters, trace)`` tuple when *track_parameters* is True. + """ sequence = self.sequence kwargs = {} if len(sequence) > 0 and sequence[0][0] is NoneType: @@ -96,6 +174,17 @@ def enumerate( yield (params, transformed) if track_parameters else transformed def __mul__(self, other: TraceQuery) -> TraceQuery: + """Chain another trace query onto this one. + + Args: + other (TraceQuery): The trace query to append. + + Returns: + TraceQuery: A new query with the combined transform sequence. + + Raises: + ValueError: If *other* begins with a None transform. + """ new_query = TraceQuery.__new__(TraceQuery) if len(other.sequence) > 0 and other.sequence[0][0] is NoneType: diff --git a/source/pip/qsharp/qre/application/_cirq.py b/source/pip/qsharp/qre/application/_cirq.py index 6f054213b7..f2bc7272f3 100644 --- a/source/pip/qsharp/qre/application/_cirq.py +++ b/source/pip/qsharp/qre/application/_cirq.py @@ -45,4 +45,12 @@ def __post_init__(self): self._circuit = self.circuit_or_qasm def get_trace(self, parameters: None = None) -> Trace: + """Return the resource estimation trace for the Cirq circuit. + + Args: + parameters (None): Unused. Defaults to None. 
+ + Returns: + Trace: The resource estimation trace. + """ return trace_from_cirq(self._circuit) diff --git a/source/pip/qsharp/qre/application/_qsharp.py b/source/pip/qsharp/qre/application/_qsharp.py index 0ed642d826..abdad2bce4 100644 --- a/source/pip/qsharp/qre/application/_qsharp.py +++ b/source/pip/qsharp/qre/application/_qsharp.py @@ -16,15 +16,40 @@ @dataclass class QSharpApplication(Application[None]): + """Application that produces a resource estimation trace from Q# code. + + Accepts a Q# entry expression string, a callable, or pre-computed + ``LogicalCounts``. + + Attributes: + cache_dir (Path): Directory for caching compiled traces. + use_cache (bool): Whether to use the trace cache. Default is False. + """ + cache_dir: Path = field( default=Path.home() / ".cache" / "re3" / "qsharp", repr=False ) use_cache: bool = field(default=False, repr=False) def __init__(self, entry_expr: str | Callable | LogicalCounts): + """Initialize the Q# application. + + Args: + entry_expr (str | Callable | LogicalCounts): The Q# entry + expression, a callable returning logical counts, or + pre-computed logical counts. + """ self._entry_expr = entry_expr def get_trace(self, parameters: None = None) -> Trace: + """Return the resource estimation trace for the Q# program. + + Args: + parameters (None): Unused. Defaults to None. + + Returns: + Trace: The resource estimation trace. + """ # TODO: make caching work for `Callable` as well if self.use_cache and isinstance(self._entry_expr, str): cache_path = self.cache_dir / f"{self._entry_expr}.json" diff --git a/source/pip/qsharp/qre/interop/_cirq.py b/source/pip/qsharp/qre/interop/_cirq.py index b685456d84..fe84dfe4c5 100644 --- a/source/pip/qsharp/qre/interop/_cirq.py +++ b/source/pip/qsharp/qre/interop/_cirq.py @@ -256,6 +256,14 @@ class PopBlock: @dataclass(frozen=True, slots=True) class TraceGate: + """A raw trace instruction emitted during Cirq circuit conversion. + + Attributes: + id (int): The instruction ID. 
+ qubits (list[cirq.Qid] | cirq.Qid): The target qubits. + params (list[float] | float | None): Optional gate parameters. + """ + id: int qubits: list[cirq.Qid] | cirq.Qid params: list[float] | float | None = None @@ -282,6 +290,7 @@ def __getitem__(self, key: cirq.Qid) -> int: def h_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert an HPowGate into trace instructions.""" if _approx_eq(abs(self.exponent), 1): yield TraceGate(H, [op.qubits[0]]) else: @@ -289,6 +298,7 @@ def h_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Opera def x_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert an XPowGate into trace instructions.""" q = [op.qubits[0]] exp = self.exponent if _approx_eq(exp, 1) or _approx_eq(exp, -1): @@ -306,6 +316,7 @@ def x_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Opera def y_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a YPowGate into trace instructions.""" q = [op.qubits[0]] exp = self.exponent if _approx_eq(exp, 1) or _approx_eq(exp, -1): @@ -323,6 +334,7 @@ def y_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Opera def z_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a ZPowGate into trace instructions.""" q = [op.qubits[0]] exp = self.exponent if _approx_eq(exp, 1) or _approx_eq(exp, -1): @@ -340,6 +352,7 @@ def z_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Opera def cx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CXPowGate into trace instructions.""" if _approx_eq(abs(self.exponent), 1): yield TraceGate(CX, [op.qubits[0], op.qubits[1]]) else: @@ -347,6 +360,7 @@ def cx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Oper def cz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): 
+ """Convert a CZPowGate into trace instructions.""" exp = self.exponent c, t = op.qubits[0], op.qubits[1] if _approx_eq(abs(exp), 1): @@ -377,6 +391,7 @@ def cz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Oper def swap_pow_gate_to_trace( self, context: cirq.DecompositionContext, op: cirq.Operation ): + """Convert a SwapPowGate into trace instructions.""" if _approx_eq(abs(self.exponent), 1): yield TraceGate(SWAP, [op.qubits[0], op.qubits[1]]) else: @@ -384,6 +399,7 @@ def swap_pow_gate_to_trace( def ccx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CCXPowGate into trace instructions.""" if _approx_eq(abs(self.exponent), 1): yield TraceGate(CCX, [op.qubits[0], op.qubits[1], op.qubits[2]]) else: @@ -391,6 +407,7 @@ def ccx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Ope def ccz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CCZPowGate into trace instructions.""" if _approx_eq(abs(self.exponent), 1): yield TraceGate(CCZ, [op.qubits[0], op.qubits[1], op.qubits[2]]) else: @@ -400,6 +417,7 @@ def ccz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Ope def measurement_gate_to_trace( self, context: cirq.DecompositionContext, op: cirq.Operation ): + """Convert a MeasurementGate into trace instructions.""" for q in op.qubits: yield TraceGate(MEAS_Z, [q]) @@ -407,6 +425,7 @@ def measurement_gate_to_trace( def reset_channel_to_trace( self, context: cirq.DecompositionContext, op: cirq.Operation ): + """Convert a ResetChannel into trace instructions (no-op).""" yield from () @@ -428,10 +447,10 @@ def reset_channel_to_trace( def phase_gradient_decompose(self, qubits): - """ - Overrides implementation of PhaseGradientGate._decompose_ to skip rotations - with very small angles. In particular the original implementation may lead - to FP overflows for large values of i. 
+ """Override PhaseGradientGate._decompose_ to skip rotations with very small angles. + + The original implementation may lead to floating-point overflows for + large values of i. """ for i, q in enumerate(qubits): diff --git a/source/pip/qsharp/qre/interop/_qir.py b/source/pip/qsharp/qre/interop/_qir.py index e3d9499c36..ebfb9559d1 100644 --- a/source/pip/qsharp/qre/interop/_qir.py +++ b/source/pip/qsharp/qre/interop/_qir.py @@ -105,6 +105,7 @@ def trace_from_qir(input: str | bytes) -> Trace: def _add_gate(trace: Trace, gate: tuple) -> None: + """Add a single QIR gate tuple to the trace.""" op = gate[0] for qir_id, instr_id, arity in _GATE_MAP: diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py index 0a38644694..b7c31767f0 100644 --- a/source/pip/qsharp/qre/interop/_qsharp.py +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -22,10 +22,10 @@ def _bucketize_rotation_counts( rotation_count: int, rotation_depth: int ) -> list[tuple[int, int]]: """ - Returns a list of (count, depth) pairs representing the rotation layers in + Return a list of (count, depth) pairs representing the rotation layers in the trace. - The following properties hold for the returned list `result`: + The following properties hold for the returned list ``result``: - sum(depth for _, depth in result) == rotation_depth - sum(count * depth for count, depth in result) == rotation_count - count > 0 for each (count, _) in result @@ -55,6 +55,18 @@ def _bucketize_rotation_counts( def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: + """Convert a Q# entry expression into a resource-estimation Trace. + + Evaluates the entry expression to obtain logical counts, then builds + a trace containing the corresponding quantum operations. + + Args: + entry_expr (str | Callable | LogicalCounts): A Q# entry expression + string, a callable, or pre-computed logical counts. 
+ + Returns: + Trace: A trace representing the resource profile of the program. + """ start = time.time_ns() counts = ( @@ -115,6 +127,21 @@ def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: def trace_from_entry_expr_cached( entry_expr: str | Callable | LogicalCounts, cache_path: Optional[Path] ) -> Trace: + """Convert a Q# entry expression into a Trace, with optional caching. + + If *cache_path* is provided and exists, the trace is loaded from disk. + Otherwise, the trace is computed via ``trace_from_entry_expr`` and + optionally written to *cache_path*. + + Args: + entry_expr (str | Callable | LogicalCounts): A Q# entry expression + string, a callable, or pre-computed logical counts. + cache_path (Optional[Path]): Path for reading/writing the cached + trace. If None, caching is disabled. + + Returns: + Trace: A trace representing the resource profile of the program. + """ if cache_path and cache_path.exists(): return Trace.from_json(cache_path.read_text()) diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py index 30d3b444c6..ffe4b2558d 100644 --- a/source/pip/qsharp/qre/models/factories/_litinski.py +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -51,6 +51,15 @@ def required_isa() -> ISARequirements: def provided_isa( self, impl_isa: ISA, ctx: ISAContext ) -> Generator[ISA, None, None]: + """Yield ISAs with T and CCZ factory instructions. + + Args: + impl_isa (ISA): The implementation ISA providing physical gates. + ctx (ISAContext): The enumeration context. + + Yields: + ISA: An ISA containing distilled T and/or CCZ instructions. 
+ """ h = impl_isa[H] cnot = impl_isa[CNOT] meas_z = impl_isa[MEAS_Z] @@ -126,6 +135,7 @@ def make_node(entry: _Entry) -> int: yield ctx.make_isa(make_node(t_entry)) def _initialize_entries(self): + """Initialize the distillation protocol lookup tables.""" self._entries = { # Assuming a Clifford error rate of at most 1e-4: 1e-4: ( @@ -323,6 +333,16 @@ def _initialize_entries(self): @dataclass(frozen=True, slots=True) class _Entry: + """A single distillation protocol entry from the Litinski tables. + + Attributes: + protocol (list[tuple[_Protocol, int]] | _Protocol): The distillation + protocol or pipeline of protocols. + error_rate (float): Output error rate of the protocol. + space (int): Space cost in physical qubits. + cycles (float): Number of syndrome extraction cycles. + """ + protocol: list[tuple[_Protocol, int]] | _Protocol error_rate: float # Space estimation in number of physical qubits @@ -333,6 +353,7 @@ class _Entry: @property def output_states(self) -> int: + """Return the number of output magic states.""" if isinstance(self.protocol, list): return self.protocol[-1][0].output_states else: @@ -340,6 +361,7 @@ def output_states(self) -> int: @property def state(self) -> int: + """Return the magic state instruction ID (T or CCZ).""" if isinstance(self.protocol, list): return self.protocol[-1][0].state else: @@ -348,6 +370,17 @@ def state(self) -> int: @dataclass(frozen=True, slots=True) class _Protocol: + """Parameters for a single distillation protocol. + + Attributes: + input_states (int): Number of input T states. + output_states (int): Number of output T states. + d_x (int): Spatial X distance. + d_z (int): Spatial Z distance. + d_m (int): Temporal distance. + state (int): Magic state instruction ID. Default is T. 
+ """ + # Number of input T states in protocol input_states: int # Number of output T states in protocol diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index 5f746595bd..c4b9379448 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -47,12 +47,12 @@ class RoundBasedFactory(ISATransform): the overall space requirements. Space requirements are calculated using a user-provided function that - aggregates per-round space (e.g., sum or max). The `sum` function models - the case in which qubits are not reused across rounds, while the `max` + aggregates per-round space (e.g., sum or max). The ``sum`` function models + the case in which qubits are not reused across rounds, while the ``max`` function models the case in which qubits are reused across rounds. For the enumeration of logical-level distillation units, the factory relies - on a user-provided `ISAQuery` (defaulting to `SurfaceCode.q()`) to explore + on a user-provided ``ISAQuery`` (defaulting to ``SurfaceCode.q()``) to explore different surface code configurations and their corresponding lattice surgery instructions. 
These need to be provided by the user and cannot automatically be derived from the provided implementation ISA, as they can @@ -172,6 +172,7 @@ def provided_isa( ) def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: + """Return physical distillation units for the given gate parameters.""" return [ _DistillationUnit( num_input_states=15, @@ -194,6 +195,7 @@ def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: def _logical_units( self, lattice_surgery_instruction: Instruction ) -> list[_DistillationUnit]: + """Return logical distillation units derived from a lattice surgery instruction.""" logical_cycle_time = lattice_surgery_instruction.expect_time(1) logical_error = lattice_surgery_instruction.expect_error_rate(1) @@ -217,6 +219,7 @@ def _logical_units( ] def _state_from_pipeline(self, pipeline: _Pipeline) -> Instruction: + """Create a T-gate instruction from a distillation pipeline.""" return Instruction.fixed_arity( T, int(LOGICAL), @@ -247,11 +250,14 @@ def _cache_key(self, impl_isa: ISA) -> str: return hashlib.sha256(data).hexdigest() def _cache_path(self, impl_isa: ISA) -> Path: + """Return the cache file path for the given implementation ISA.""" self.cache_dir.mkdir(parents=True, exist_ok=True) return self.cache_dir / f"{self._cache_key(impl_isa)}.json" class _Pipeline: + """A multi-round distillation pipeline.""" + def __init__( self, units: Sequence[_DistillationUnit], @@ -276,6 +282,12 @@ def try_create( failure_probability_requirement: float = 0.01, physical_qubit_calculation: Callable[[Iterable], int] = sum, ) -> Optional[_Pipeline]: + """Create a pipeline if the configuration is feasible. + + Returns: + Optional[_Pipeline]: The pipeline, or None if the required + number of units per round is infeasible. 
+ """ pipeline = cls( units, initial_input_error_rate, @@ -287,6 +299,7 @@ def try_create( return pipeline def _compute_units_per_round(self) -> bool: + """Adjust the number of units per round to meet output requirements.""" if len(self.rounds) > 0: states_needed_next = self.rounds[-1].unit.num_output_states @@ -298,6 +311,7 @@ def _compute_units_per_round(self) -> bool: return True def _add_rounds(self, units: Sequence[_DistillationUnit]): + """Append distillation rounds from the given units.""" per_round_failure_prob_req = self.failure_probability_requirement / len(units) for unit in units: @@ -313,23 +327,29 @@ def _add_rounds(self, units: Sequence[_DistillationUnit]): @property def space(self) -> int: + """Total physical-qubit space of the pipeline.""" return self.physical_qubit_calculation(round.space for round in self.rounds) @property def time(self) -> int: + """Total time of the pipeline in nanoseconds.""" return sum(round.unit.time for round in self.rounds) @property def error_rate(self) -> float: + """Output error rate of the pipeline.""" return self.output_error_rate @property def num_output_states(self) -> int: + """Number of output magic states produced by the pipeline.""" return self.rounds[-1].compute_num_output_states() @dataclass(slots=True) class _DistillationUnit: + """A single distillation unit with fixed input/output characteristics.""" + num_input_states: int time: int space: int @@ -339,12 +359,14 @@ class _DistillationUnit: num_output_states: int = 1 def error_rate(self, input_error_rate: float) -> float: + """Compute the output error rate for a given input error rate.""" result = 0.0 for c in self.error_rate_coeffs: result = result * input_error_rate + c return result def failure_probability(self, input_error_rate: float) -> float: + """Compute the failure probability for a given input error rate.""" result = 0.0 for c in self.failure_probability_coeffs: result = result * input_error_rate + c @@ -353,6 +375,8 @@ def 
failure_probability(self, input_error_rate: float) -> float: @dataclass(slots=True) class _DistillationRound: + """A single round in a distillation pipeline.""" + unit: _DistillationUnit failure_probability_requirement: float input_error_rate: float @@ -363,6 +387,7 @@ def __post_init__(self): self.failure_probability = self.unit.failure_probability(self.input_error_rate) def adjust_num_units_to(self, output_states_needed_next: int) -> bool: + """Adjust the number of units to produce at least the required output states.""" if self.failure_probability == 0.0: self.num_units = output_states_needed_next return True @@ -396,17 +421,21 @@ def adjust_num_units_to(self, output_states_needed_next: int) -> bool: @property def space(self) -> int: + """Total physical-qubit space for this round.""" return self.num_units * self.unit.space @property def num_input_states(self) -> int: + """Total number of input states consumed by this round.""" return self.num_units * self.unit.num_input_states @property def max_num_output_states(self) -> int: + """Maximum number of output states this round can produce.""" return self.num_units * self.unit.num_output_states def compute_num_output_states(self) -> int: + """Compute the expected number of output states accounting for failure probability.""" failure_prob = self.failure_probability if failure_prob <= 1e-8: diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py index a0efbc4ec5..0fbec26ed7 100644 --- a/source/pip/qsharp/qre/models/factories/_utils.py +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -23,8 +23,8 @@ class MagicUpToClifford(ISATransform): """ An ISA transform that adds Clifford equivalent representations of magic states. For example, if the input ISA contains a T gate, the provided ISA - will also contain `SQRT_SQRT_X`, `SQRT_SQRT_X_DAG`, `SQRT_SQRT_Y`, - `SQRT_SQRT_Y_DAG`, and `T_DAG`. 
The same is applied for `CCZ` gates and + will also contain ``SQRT_SQRT_X``, ``SQRT_SQRT_X_DAG``, ``SQRT_SQRT_Y``, + ``SQRT_SQRT_Y_DAG``, and ``T_DAG``. The same is applied for ``CCZ`` gates and their Clifford equivalents. Example: diff --git a/source/pip/tests/qre/test_application.py b/source/pip/tests/qre/test_application.py index 6b73222e12..c62537ee88 100644 --- a/source/pip/tests/qre/test_application.py +++ b/source/pip/tests/qre/test_application.py @@ -28,6 +28,7 @@ def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + """Assert that an estimation result matches expected qubit, runtime, and error values.""" actual_qubits = ( isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + isa[T].expect_space() * result.factories[T].copies @@ -52,6 +53,7 @@ def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): def test_trace_properties(): + """Test setting and getting typed properties on a Trace.""" trace = Trace(42) INT = 0 @@ -77,6 +79,7 @@ def test_trace_properties(): def test_qsharp_application(): + """Test QSharpApplication trace generation and estimation from a Q# program.""" code = """ {{ use (a, b, c) = (Qubit(), Qubit(), Qubit()); @@ -163,6 +166,7 @@ def test_qsharp_application(): def test_application_enumeration(): + """Test that Application.q() enumerates the correct number of traces.""" @dataclass(kw_only=True) class _Params: size: int = field(default=1, metadata={"domain": range(1, 4)}) @@ -178,6 +182,7 @@ def get_trace(self, parameters: _Params) -> Trace: def test_trace_enumeration(): + """Test trace query enumeration with PSSPC and LatticeSurgery transforms.""" code = """ {{ use (a, b, c) = (Qubit(), Qubit(), Qubit()); @@ -201,6 +206,7 @@ def test_trace_enumeration(): def test_rotation_error_psspc(): + """Test that PSSPC base error stays below 1.0 for a single rotation gate.""" # This test helps to bound the variables for the number of rotations in PSSPC # Create a trace with a single rotation 
gate and ensure that the base error diff --git a/source/pip/tests/qre/test_cirq_interop.py b/source/pip/tests/qre/test_cirq_interop.py index b65fe39bdb..826ac54c7e 100644 --- a/source/pip/tests/qre/test_cirq_interop.py +++ b/source/pip/tests/qre/test_cirq_interop.py @@ -6,42 +6,50 @@ def test_with_qft(): + """Test trace generation from a 1025-qubit QFT circuit.""" _test_one_circuit(cirq.qft(*cirq.LineQubit.range(1025)), 1025, 212602, 266007) def test_h(): + """Test trace generation from Hadamard and fractional Hadamard gates.""" _test_one_circuit(cirq.H, 1, 1, 1) _test_one_circuit(cirq.H**0.5, 1, 3, 3) def test_cx(): + """Test trace generation from CX and fractional CX gates.""" _test_one_circuit(cirq.CX, 2, 1, 1) _test_one_circuit(cirq.CX**0.5, 2, 6, 7) _test_one_circuit(cirq.CX**0.25, 2, 6, 7) def test_cz(): + """Test trace generation from CZ and fractional CZ gates.""" _test_one_circuit(cirq.CZ, 2, 1, 1) _test_one_circuit(cirq.CZ**0.5, 2, 4, 5) _test_one_circuit(cirq.CZ**0.25, 2, 4, 5) def test_swap(): + """Test trace generation from SWAP and fractional SWAP gates.""" _test_one_circuit(cirq.SWAP, 2, 1, 1) _test_one_circuit(cirq.SWAP**0.5, 2, 8, 9) def test_ccx(): + """Test trace generation from CCX and fractional CCX gates.""" _test_one_circuit(cirq.CCX, 3, 1, 1) _test_one_circuit(cirq.CCX**0.5, 3, 11, 17) def test_ccz(): + """Test trace generation from CCZ and fractional CCZ gates.""" _test_one_circuit(cirq.CCZ, 3, 1, 1) _test_one_circuit(cirq.CCZ**0.5, 3, 10, 15) def test_circuit_with_block(): + """Test trace generation from a circuit with a custom decomposable gate.""" class CustomGate(cirq.Gate): def num_qubits(self) -> int: return 2 @@ -70,6 +78,7 @@ def _test_one_circuit( expected_depth: int, expected_gates: int, ): + """Assert that a Cirq circuit produces a trace with the expected qubits, depth, and gates.""" app = CirqApplication(circuit) trace = app.get_trace() diff --git a/source/pip/tests/qre/test_enumeration.py 
b/source/pip/tests/qre/test_enumeration.py index 476e65f22b..636f982699 100644 --- a/source/pip/tests/qre/test_enumeration.py +++ b/source/pip/tests/qre/test_enumeration.py @@ -20,6 +20,7 @@ def test_enumerate_instances(): + """Test enumeration of SurfaceCode instances with default and custom domains.""" from qsharp.qre._enumeration import _enumerate_instances instances = list(_enumerate_instances(SurfaceCode)) @@ -44,6 +45,7 @@ def test_enumerate_instances(): def test_enumerate_instances_bool(): + """Test that boolean dataclass fields enumerate both True and False.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -58,6 +60,7 @@ class BoolConfig: def test_enumerate_instances_enum(): + """Test that Enum dataclass fields enumerate all members.""" from qsharp.qre._enumeration import _enumerate_instances class Color(Enum): @@ -78,6 +81,7 @@ class EnumConfig: def test_enumerate_instances_failure(): + """Test that a field with no domain and no default raises ValueError.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -91,6 +95,7 @@ class InvalidConfig: def test_enumerate_instances_single(): + """Test enumeration of a dataclass with a single non-kw-only field.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -103,6 +108,7 @@ class SingleConfig: def test_enumerate_instances_literal(): + """Test that Literal-typed fields enumerate their allowed values.""" from qsharp.qre._enumeration import _enumerate_instances from typing import Literal @@ -119,6 +125,7 @@ class LiteralConfig: def test_enumerate_instances_nested(): + """Test enumeration of nested dataclass fields.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -138,6 +145,7 @@ class OuterConfig: def test_enumerate_instances_union(): + """Test enumeration of union-typed dataclass fields.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -164,6 +172,7 @@ class UnionConfig: def 
test_enumerate_instances_nested_with_constraints(): + """Test constraining nested dataclass fields via a dict.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -183,6 +192,7 @@ class OuterConfig: def test_enumerate_instances_union_single_type(): + """Test restricting a union field to a single member type.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -215,6 +225,7 @@ class UnionConfig: def test_enumerate_instances_union_list_of_types(): + """Test restricting a union field to a subset of member types.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -244,6 +255,7 @@ class UnionConfig: def test_enumerate_instances_union_constraint_dict(): + """Test constraining union field members via a type-to-kwargs dict.""" from qsharp.qre._enumeration import _enumerate_instances @dataclass @@ -293,6 +305,7 @@ class UnionConfig: def test_enumerate_isas(): + """Test ISA enumeration with products, sums, and hierarchical factories.""" ctx = GateBased(gate_time=50, measurement_time=100).context() # This will enumerate the 4 ISAs for the error correction code @@ -464,6 +477,7 @@ def test_binding_node_errors(): def test_product_isa_enumeration_nodes(): + """Test that multiplying ISAQuery nodes produces flattened ProductNodes.""" terminal = SurfaceCode.q() query = terminal * terminal @@ -496,6 +510,7 @@ def test_product_isa_enumeration_nodes(): def test_sum_isa_enumeration_nodes(): + """Test that adding ISAQuery nodes produces flattened SumNodes.""" terminal = SurfaceCode.q() query = terminal + terminal diff --git a/source/pip/tests/qre/test_estimation.py b/source/pip/tests/qre/test_estimation.py index bb857115ed..2947db2075 100644 --- a/source/pip/tests/qre/test_estimation.py +++ b/source/pip/tests/qre/test_estimation.py @@ -23,6 +23,7 @@ def test_estimation_max_error(): + """Test that estimation results respect the max_error constraint.""" app = QSharpApplication(LogicalCounts({"numQubits": 100, 
"measurementCount": 100})) arch = GateBased(gate_time=50, measurement_time=100) @@ -53,6 +54,7 @@ def test_estimation_max_error(): ], ) def test_estimation_methods(post_process, use_graph): + """Test all combinations of post_process and use_graph estimation paths.""" counts = LogicalCounts( { "numQubits": 1000, diff --git a/source/pip/tests/qre/test_interop.py b/source/pip/tests/qre/test_interop.py index a8f7900abb..4e94d6f549 100644 --- a/source/pip/tests/qre/test_interop.py +++ b/source/pip/tests/qre/test_interop.py @@ -9,6 +9,7 @@ def _ll_files(): + """Return the list of QIR .ll test files.""" ll_dir = ( Path(__file__).parent.parent.parent / "tests-integration" @@ -21,6 +22,7 @@ def _ll_files(): @pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) def test_trace_from_qir(ll_file): + """Test that trace_from_qir can parse real QIR output files.""" # NOTE: This test is primarily to ensure that the function can parse real # QIR output without errors, rather than checking specific properties of the # trace. 
@@ -190,6 +192,7 @@ def declare(name, param_types): def test_rotation_buckets(): + """Test that rotation bucketization preserves total count and depth.""" from qsharp.qre.interop._qsharp import _bucketize_rotation_counts print() diff --git a/source/pip/tests/qre/test_isa.py b/source/pip/tests/qre/test_isa.py index 6c1e8e318a..95809968b6 100644 --- a/source/pip/tests/qre/test_isa.py +++ b/source/pip/tests/qre/test_isa.py @@ -19,6 +19,7 @@ def test_isa(): + """Test ISA creation, instruction lookup, and dynamic node addition.""" graph = _ProvenanceGraph() isa = graph.make_isa( [ @@ -59,6 +60,7 @@ def test_isa(): def test_instruction_properties(): + """Test getting and setting instruction properties.""" # Test instruction with no properties instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) assert instr_no_props.get_property(DISTANCE) is None @@ -79,6 +81,7 @@ def test_instruction_properties(): def test_instruction_constraints(): + """Test constraint property filtering and ISA.satisfies behavior.""" # Test constraint without properties c_no_props = constraint(T, encoding=LOGICAL) assert c_no_props.has_property(DISTANCE) is False @@ -123,6 +126,7 @@ def test_instruction_constraints(): def test_property_names(): + """Test property name lookup and case-insensitive key resolution.""" assert property_name(DISTANCE) == "DISTANCE" # An unregistered property @@ -141,6 +145,7 @@ def test_property_names(): def test_generic_function(): + """Test generic_function wrapping for int and float return types.""" from qsharp.qre._qre import _IntFunction, _FloatFunction def time(x: int) -> int: @@ -166,6 +171,7 @@ def error_rate(x: int) -> float: def test_isa_from_architecture(): + """Test generating logical ISAs from an architecture and QEC code.""" arch = GateBased(gate_time=50, measurement_time=100) code = SurfaceCode() ctx = arch.context() diff --git a/source/pip/tests/qre/test_models.py b/source/pip/tests/qre/test_models.py index 728d557169..46b236afb0 100644 --- 
a/source/pip/tests/qre/test_models.py +++ b/source/pip/tests/qre/test_models.py @@ -46,14 +46,17 @@ class TestGateBased: def test_default_error_rate(self): + """Test that default error rate is 1e-4.""" arch = GateBased(gate_time=50, measurement_time=100) assert arch.error_rate == 1e-4 def test_custom_error_rate(self): + """Test that a custom error rate is accepted.""" arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) assert arch.error_rate == 1e-3 def test_provided_isa_contains_expected_instructions(self): + """Test that GateBased ISA contains all expected physical instructions.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -62,6 +65,7 @@ def test_provided_isa_contains_expected_instructions(self): assert instr_id in isa def test_instruction_encodings_are_physical(self): + """Test that all GateBased ISA instructions have PHYSICAL encoding.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -70,6 +74,7 @@ def test_instruction_encodings_are_physical(self): assert isa[instr_id].encoding == PHYSICAL def test_instruction_error_rates_match(self): + """Test that all instruction error rates match the architecture error rate.""" rate = 1e-3 arch = GateBased(error_rate=rate, gate_time=50, measurement_time=100) ctx = arch.context() @@ -79,6 +84,7 @@ def test_instruction_error_rates_match(self): assert isa[instr_id].expect_error_rate() == rate def test_gate_times(self): + """Test that gate times match the configured gate and measurement times.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -95,6 +101,7 @@ def test_gate_times(self): assert isa[MEAS_Z].expect_time() == 100 def test_arities(self): + """Test that instruction arities match expected values.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -106,6 +113,7 @@ def test_arities(self): assert isa[MEAS_Z].arity == 1 def 
test_context_creation(self): + """Test that context creation succeeds.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() assert ctx is not None @@ -118,10 +126,12 @@ def test_context_creation(self): class TestMajorana: def test_default_error_rate(self): + """Test that default Majorana error rate is 1e-5.""" arch = Majorana() assert arch.error_rate == 1e-5 def test_provided_isa_contains_expected_instructions(self): + """Test that Majorana ISA contains all expected instructions.""" arch = Majorana() ctx = arch.context() isa = ctx.isa @@ -130,6 +140,7 @@ def test_provided_isa_contains_expected_instructions(self): assert instr_id in isa def test_all_times_are_1us(self): + """Test that all Majorana instruction times are 1000 ns.""" arch = Majorana() ctx = arch.context() isa = ctx.isa @@ -138,6 +149,7 @@ def test_all_times_are_1us(self): assert isa[instr_id].expect_time() == 1000 def test_clifford_error_rates_match_qubit_error(self): + """Test that Clifford error rates match the qubit error rate.""" for rate in [1e-4, 1e-5, 1e-6]: arch = Majorana(error_rate=rate) ctx = arch.context() @@ -157,6 +169,7 @@ def test_t_error_rate_mapping(self): assert isa[T].expect_error_rate() == t_rate def test_two_qubit_measurement_arities(self): + """Test that two-qubit measurement instructions have arity 2.""" arch = Majorana() ctx = arch.context() isa = ctx.isa @@ -172,14 +185,17 @@ def test_two_qubit_measurement_arities(self): class TestSurfaceCode: def test_required_isa(self): + """Test that SurfaceCode has non-None required ISA.""" reqs = SurfaceCode.required_isa() assert reqs is not None def test_default_distance(self): + """Test SurfaceCode with explicit distance parameter.""" sc = SurfaceCode(distance=3) assert sc.distance == 3 def test_provides_lattice_surgery(self): + """Test that SurfaceCode provides a logical LATTICE_SURGERY instruction.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=3) @@ 
-219,6 +235,7 @@ def test_time_scales_with_distance(self): assert ls.expect_time(1) == syndrome_time * d def test_error_rate_decreases_with_distance(self): + """Test that logical error rate decreases as code distance increases.""" arch = GateBased(gate_time=50, measurement_time=100) errors = [] @@ -246,6 +263,7 @@ def test_enumeration_via_query(self): assert count == 12 def test_custom_crossing_prefactor(self): + """Test that doubling the crossing prefactor doubles the error rate.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() @@ -265,6 +283,7 @@ def test_custom_crossing_prefactor(self): assert abs(custom_error - 2 * default_error) < 1e-20 def test_custom_error_correction_threshold(self): + """Test that a lower error correction threshold yields a higher logical error.""" arch = GateBased(gate_time=50, measurement_time=100) ctx1 = arch.context() @@ -290,10 +309,12 @@ def test_custom_error_correction_threshold(self): class TestThreeAux: def test_required_isa(self): + """Test that ThreeAux has non-None required ISA.""" reqs = ThreeAux.required_isa() assert reqs is not None def test_provides_lattice_surgery(self): + """Test that ThreeAux provides a LATTICE_SURGERY instruction.""" arch = Majorana() ctx = arch.context() ta = ThreeAux(distance=3) @@ -340,6 +361,7 @@ def test_time_formula_single_rail(self): assert ls.expect_time(1) == expected_time def test_error_rate_decreases_with_distance(self): + """Test that ThreeAux error rate decreases with increasing distance.""" arch = Majorana() errors = [] @@ -374,6 +396,7 @@ def test_single_rail_has_different_error_threshold(self): assert error_double != error_single def test_enumeration_via_query(self): + """Test that ThreeAux.q() enumerates all distance and rail combinations.""" arch = Majorana() ctx = arch.context() @@ -402,6 +425,7 @@ def _get_lattice_surgery_isa(self, distance=5): return isas[0], ctx def test_provides_memory_instruction(self): + """Test that YokedSurfaceCode provides a MEMORY 
instruction.""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -410,6 +434,7 @@ def test_provides_memory_instruction(self): assert MEMORY in isas[0] def test_memory_is_logical(self): + """Test that the MEMORY instruction has LOGICAL encoding.""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -418,6 +443,7 @@ def test_memory_is_logical(self): assert mem.encoding == LOGICAL def test_memory_arity_is_variable(self): + """Test that MEMORY instruction has variable arity (None).""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -427,6 +453,7 @@ def test_memory_arity_is_variable(self): assert mem.arity is None def test_space_increases_with_arity(self): + """Test that MEMORY space increases with the number of qubits.""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -438,6 +465,7 @@ def test_space_increases_with_arity(self): assert spaces[i] < spaces[i + 1] def test_time_increases_with_arity(self): + """Test that MEMORY time increases with the number of qubits.""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -449,6 +477,7 @@ def test_time_increases_with_arity(self): assert times[i] < times[i + 1] def test_error_rate_increases_with_arity(self): + """Test that MEMORY error rate increases with the number of qubits.""" ls_isa, ctx = self._get_lattice_surgery_isa() ysc = TwoDimensionalYokedSurfaceCode() @@ -460,6 +489,7 @@ def test_error_rate_increases_with_arity(self): assert errors[i] < errors[i + 1] def test_distance_property_propagated(self): + """Test that the distance property is propagated to the MEMORY instruction.""" d = 7 ls_isa, ctx = self._get_lattice_surgery_isa(distance=d) ysc = TwoDimensionalYokedSurfaceCode() @@ -476,6 +506,7 @@ def test_distance_property_propagated(self): class TestLitinski19Factory: def test_required_isa(self): + """Test that 
Litinski19Factory has non-None required ISA.""" reqs = Litinski19Factory.required_isa() assert reqs is not None @@ -496,6 +527,7 @@ def test_table1_yields_t_and_ccz(self): assert len(isa) == 2 def test_table1_instruction_properties(self): + """Test that Table 1 T and CCZ instructions have valid properties.""" arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -620,6 +652,7 @@ def test_time_based_on_syndrome_extraction(self): class TestMagicUpToClifford: def test_required_isa_is_empty(self): + """Test that MagicUpToClifford has non-None required ISA.""" reqs = MagicUpToClifford.required_isa() assert reqs is not None @@ -758,6 +791,7 @@ def test_no_family_present_passes_through(self): def test_isa_manipulation(): + """Test Litinski19Factory and MagicUpToClifford ISA integration.""" arch = GateBased(gate_time=50, measurement_time=100) factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -809,10 +843,12 @@ def test_isa_manipulation(): class TestRoundBasedFactory: def test_required_isa(self): + """Test that RoundBasedFactory has non-None required ISA.""" reqs = RoundBasedFactory.required_isa() assert reqs is not None def test_produces_logical_t_gates(self): + """Test that RoundBasedFactory produces logical T gates with valid properties.""" arch = GateBased(gate_time=50, measurement_time=100) for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): @@ -891,6 +927,7 @@ def test_with_three_aux_code_query(self): assert count > 0 def test_round_based_gate_based_sum(self): + """Test RoundBasedFactory aggregated totals with GateBased sum mode.""" arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 @@ -910,6 +947,7 @@ def test_round_based_gate_based_sum(self): assert count == 107 def test_round_based_gate_based_max(self): + """Test RoundBasedFactory aggregated totals with GateBased max mode.""" arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 @@ -931,6 
+969,7 @@ def test_round_based_gate_based_max(self): assert count == 77 def test_round_based_msft_sum(self): + """Test RoundBasedFactory aggregated totals with Majorana sum mode.""" arch = Majorana() total_space = 0 From 556e8a49f43166bf5374176135fceefa5afb7e5b Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Thu, 2 Apr 2026 17:42:22 +0200 Subject: [PATCH 43/45] Remove matplotlib from tests (#3078) Using matplotlib in tests is causing CI problems. This fixes #3076 --- source/pip/tests/qre/test_estimation_table.py | 72 ------------------- 1 file changed, 72 deletions(-) diff --git a/source/pip/tests/qre/test_estimation_table.py b/source/pip/tests/qre/test_estimation_table.py index d2a25ae31b..744a3dc607 100644 --- a/source/pip/tests/qre/test_estimation_table.py +++ b/source/pip/tests/qre/test_estimation_table.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import cast, Sized - import pytest import pandas as pd @@ -367,73 +365,3 @@ def test_estimation_table_computed_column(): frame = table.as_frame() assert frame["qubit_error_product"][0] == pytest.approx(1.0) assert frame["qubit_error_product"][1] == pytest.approx(4.0) - - -def test_estimation_table_plot_returns_figure(): - """Test that plot() returns a matplotlib Figure with correct axes.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 5_000_000_000, 0.01)) - table.append(_make_entry(200, 10_000_000_000, 0.02)) - table.append(_make_entry(50, 50_000_000_000, 0.005)) - - fig = table.plot() - - assert isinstance(fig, Figure) - ax = fig.axes[0] - assert ax.get_ylabel() == "Physical qubits" - assert ax.get_xlabel() == "Runtime" - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - # Verify data points - offsets = ax.collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 3 - - -def test_estimation_table_plot_empty_raises(): - """Test that plot() raises ValueError on an empty 
table.""" - table = EstimationTable() - with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): - table.plot() - - -def test_estimation_table_plot_single_entry(): - """Test that plot() works with a single entry.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 1_000_000, 0.01)) - - fig = table.plot() - assert isinstance(fig, Figure) - - offsets = fig.axes[0].collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 1 - - -def test_estimation_table_plot_with_runtime_unit(): - """Test that plot(runtime_unit=...) scales x values and labels the axis.""" - table = EstimationTable() - # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns - table.append(_make_entry(100, int(3600e9), 0.01)) - table.append(_make_entry(200, int(7200e9), 0.02)) - - fig = table.plot(runtime_unit="hours") - - ax = fig.axes[0] - assert ax.get_xlabel() == "Runtime (hours)" - - # Verify the x data is scaled: should be 1.0 and 2.0 hours - offsets = cast(list, ax.collections[0].get_offsets()) - assert offsets[0][0] == pytest.approx(1.0) - assert offsets[1][0] == pytest.approx(2.0) - - -def test_estimation_table_plot_invalid_runtime_unit(): - """Test that plot() raises ValueError for an unknown runtime_unit.""" - table = EstimationTable() - table.append(_make_entry(100, 1000, 0.01)) - with pytest.raises(ValueError, match="Unknown runtime_unit"): - table.plot(runtime_unit="fortnights") From cf70736e4ea45afd229b787a8c9cb91c067cbdd2 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Thu, 2 Apr 2026 17:43:07 +0200 Subject: [PATCH 44/45] Adjust pruning logic for graph based search (#3089) Graph-based estimation missed some estimates when using ISA transforms with a pass-through (e.g., those extending instruction sets and keeping the previous ones). 
--- source/qre/src/isa/provenance.rs | 18 +++++++++-- source/qre/src/trace/estimation.rs | 48 ++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/source/qre/src/isa/provenance.rs b/source/qre/src/isa/provenance.rs index 8b59660639..5f68ca180e 100644 --- a/source/qre/src/isa/provenance.rs +++ b/source/qre/src/isa/provenance.rs @@ -239,13 +239,27 @@ impl ProvenanceGraph { // may have pruned nodes from this range as duplicates of // earlier, equivalent nodes outside the range. let matching: Vec<(u64, usize)> = if min_idx > 0 { - (min_idx..self.nodes.len()) + let mut m: Vec<(u64, usize)> = (min_idx..self.nodes.len()) .filter(|&node_idx| { let instr = &self.nodes[node_idx].instruction; instr.id == constraint.id() && constraint.is_satisfied_by(instr) }) .map(|node_idx| (constraint.id(), node_idx)) - .collect() + .collect(); + + // Fall back to the full graph for passthrough instructions + // that the source does not modify (e.g. architecture base + // gates that a wrapper leaves unchanged). 
+ if m.is_empty() { + m = (1..min_idx) + .filter(|&node_idx| { + let instr = &self.nodes[node_idx].instruction; + instr.id == constraint.id() && constraint.is_satisfied_by(instr) + }) + .map(|node_idx| (constraint.id(), node_idx)) + .collect(); + } + m } else { let Some(pareto) = self.pareto_index.get(&constraint.id()) else { return Vec::new(); diff --git a/source/qre/src/trace/estimation.rs b/source/qre/src/trace/estimation.rs index b75ab35fde..3bbc7dd25d 100644 --- a/source/qre/src/trace/estimation.rs +++ b/source/qre/src/trace/estimation.rs @@ -8,7 +8,7 @@ use std::{ sync::{Arc, RwLock, atomic::AtomicUsize}, }; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use crate::{EstimationCollection, ISA, ProvenanceGraph, ResultSummary, Trace}; @@ -158,8 +158,28 @@ fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize /// A combination is prunable if, for any instruction slot, there exists a /// successful combination with the same instructions in all other slots and /// an instruction at that slot with `space <=` and `time <=`. -fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { +/// +/// However, for instruction slots that affect the algorithm runtime (gate +/// instructions) when the trace has factory instructions (resource states), +/// time-based dominance is **unsound**. A faster gate instruction reduces +/// the algorithm runtime, which reduces the number of factory runs per copy, +/// requiring more factory copies and potentially more total qubits. The +/// "dominated" combination (with a slower gate) can therefore produce a +/// Pareto-optimal result with fewer qubits but more runtime. +/// +/// When `runtime_affecting_ids` is non-empty, slots whose instruction ID +/// appears in that set are skipped entirely during the dominance check. 
+fn is_dominated( + combination: &[CombinationEntry], + trace_pruning: &[SlotWitnesses], + runtime_affecting_ids: &FxHashSet<u64>, +) -> bool { for (slot_idx, entry) in combination.iter().enumerate() { + // Skip dominance check for runtime-affecting slots when factories + // exist, because shorter gate time can increase factory overhead. + if runtime_affecting_ids.contains(&entry.instruction_id) { + continue; + } let ctx_hash = combination_context_hash(combination, slot_idx); let map = trace_pruning[slot_idx] .read() @@ -282,6 +302,23 @@ pub fn estimate_with_graph( // jobs. let mut max_slots = 0; + // For each trace, collect the set of instruction IDs that affect the + // algorithm runtime (gate instructions from the trace block structure). + // When a trace also has resource states (factories), dominance pruning + // on these slots is unsound because shorter gate time can increase + // factory overhead (see `is_dominated` documentation). + let runtime_affecting_ids: Vec<FxHashSet<u64>> = traces .iter() .map(|trace| { + let has_factories = trace.get_resource_states().is_some_and(|rs| !rs.is_empty()); + if has_factories { + trace.deep_iter().map(|(gate, _)| gate.id).collect() + } else { + FxHashSet::default() + } + }) + .collect(); + for (trace_idx, trace) in traces.iter().enumerate() { if trace.base_error() > max_error { continue; } @@ -384,6 +421,7 @@ pub fn estimate_with_graph( let next_job = &next_job; let jobs = &jobs; let pruning_witnesses = &pruning_witnesses; + let runtime_affecting_ids = &runtime_affecting_ids; let isa_index = Arc::clone(&isa_index); scope.spawn(move || { let mut local_results = Vec::new(); @@ -397,7 +435,11 @@ // Dominance pruning: skip if a cheaper instruction at any // slot already succeeded with the same surrounding context.
- if is_dominated(combination, &pruning_witnesses[*trace_idx]) { + if is_dominated( + combination, + &pruning_witnesses[*trace_idx], + &runtime_affecting_ids[*trace_idx], + ) { continue; } From e5708d1fc74a9638aa2b3058bfa4f533efd90d32 Mon Sep 17 00:00:00 2001 From: Mathias Soeken Date: Thu, 2 Apr 2026 17:39:35 +0000 Subject: [PATCH 45/45] Fix CODEOWNERS file. --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ed9636ca57..5edd1862eb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -- @billti @idavis @minestarks @swernli +* @billti @idavis @minestarks @swernli /.github/ISSUE_TEMPLATE/fuzz_bug_report.md @billti @idavis @swernli /.github/workflows/fuzz.yml @billti @idavis @swernli