Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ build/
*.egg-info/
tmp*
notebooks/re/tmp*
notebooks/tmp*
profiling/profile.json
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 25.11.0
rev: 26.1.0
hooks:
- id: black
language_version: python3.12
Expand Down
17,799 changes: 17,799 additions & 0 deletions notebooks/compiler_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
dev = [
"pytest",
"pre-commit>=3.0.0",
"black>=23.7.0",
"black==26.1.0",
"py-spy",
]

Expand Down
1 change: 1 addition & 0 deletions qmath/compile/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .evaluate import EvaluateExpression
158 changes: 158 additions & 0 deletions qmath/compile/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import ast

from psiqworkbench import QPU, QUInt, QFixed, Qubrick
from psiqworkbench.filter_presets import BIT_DEFAULT

from qmath.utils.symbolic import alloc_temp_qreg_like
from qmath.func.common import MultiplyAdd, MultiplyConstAdd, Add, AddConst, Negate
from qmath.func.square import Square

from qmath.utils.gates import ParallelCnot

# Type alias to represent quantum register or a literal number.
QValue = QFixed | float


# Ensures that x is of type QValue.
def _make_qvalue(x) -> QValue:
if isinstance(x, QFixed):
return x
if isinstance(x, int) or isinstance(x, float):
return float(x)
raise ValueError("Unsupported type", type(x))


ops = []


class EvaluateExpression(Qubrick):
"""Evaluates arithmetic expression."""

def __init__(self, expr: str, mutable_vars: set[str] = None, **kwargs):
super().__init__(**kwargs)
self.expr = expr
self.vars = dict()
self.immutable_regs = set()
self.mutable_vars = mutable_vars or set()

def _make_copy(self, x: QFixed) -> QFixed:
_, ans = alloc_temp_qreg_like(self, x)
ParallelCnot().compute(x, ans)
return ans

def _implement_unary_op(self, op: ast.BinOp, arg: QValue) -> QValue:
if isinstance(op, ast.USub):
return self._negate(arg)
raise ValueError(f"Unsupported unary op: {op}.")

def _implement_binary_op(self, op: ast.BinOp, arg1: QValue, arg2: QValue) -> QValue:
if isinstance(op, ast.Add):
return self._add(arg1, arg2)
if isinstance(op, ast.Sub):
return self._sub(arg1, arg2)
if isinstance(op, ast.Mult):
return self._mul(arg1, arg2)
raise ValueError(f"Unsupported binary op: {op}.")

def _negate(self, arg: QValue) -> QValue:
if isinstance(arg, float):
return -arg
assert isinstance(arg, QFixed)
if arg.mask() in self.immutable_regs:
return self._negate(self._make_copy(arg))
Negate().compute(arg)
return arg

def _add(self, arg1: QValue, arg2: QValue) -> QValue:
if isinstance(arg1, float) and isinstance(arg2, float):
return arg1 + arg2
if isinstance(arg1, float):
return self._add(arg2, arg1)

assert isinstance(arg1, QFixed)

if isinstance(arg2, QFixed):
# Quantum-quantum addition.
if arg1.mask() in self.immutable_regs and arg2.mask() in self.immutable_regs:
return self._add(self._make_copy(arg1), arg2)
if arg1.mask() in self.immutable_regs:
return self._add(arg2, arg1)
Add().compute(arg1, arg2)
return arg1
else:
assert isinstance(arg2, float)
if arg1.mask() in self.immutable_regs:
return self._add(self._make_copy(arg1), arg2)
AddConst(arg2).compute(arg1)
return arg1

def _sub(self, arg1: QValue, arg2: QValue) -> QValue:
if isinstance(arg1, float):
return self._add(-arg1, arg2)
if isinstance(arg2, float):
return self._add(arg1, -arg2)

assert isinstance(arg1, QFixed)
assert isinstance(arg2, QFixed)

if arg1.mask() not in self.immutable_regs:
# arg1 -= arg2
with Negate().computed(arg1):
Add().compute(arg1, arg2)
return arg1
elif arg2.mask() not in self.immutable_regs:
# arg2 := -arg2
# arg2 += arg1
Negate().compute(arg2)
Add().compute(arg2, arg1)
return arg2
else:
# Both immutable. Allocate answer.
return self._negate(arg1, self._make_copy(arg2))

def _mul(self, arg1: QValue, arg2: QValue) -> QValue:
if isinstance(arg1, float) and isinstance(arg2, float):
return arg1 * arg2
if isinstance(arg1, float):
return self._mul(arg2, arg1)

assert isinstance(arg1, QFixed)
_, ans = alloc_temp_qreg_like(self, arg1)

if isinstance(arg2, QFixed):
if arg1.mask() == arg2.mask():
Square().compute(arg1, ans)
return ans
MultiplyAdd().compute(ans, arg1, arg2)
else:
assert isinstance(arg2, float)
MultiplyConstAdd(arg2).compute(ans, arg1)
return ans

def _convert_ast_node(self, node) -> QFixed | float:
if isinstance(node, ast.BinOp):
arg1 = self._convert_ast_node(node.left)
arg2 = self._convert_ast_node(node.right)
return self._implement_binary_op(node.op, arg1, arg2)
elif isinstance(node, ast.UnaryOp):
arg = self._convert_ast_node(node.operand)
return self._implement_unary_op(node.op, arg)
elif isinstance(node, ast.Name):
assert node.id in self.vars
return self.vars[node.id]
elif isinstance(node, ast.Constant):
return _make_qvalue(node.value)
else:
raise ValueError(f"Cannot handle: {node}")

def _compute(self, args: dict):
self.vars = dict()
for key, value in args.items():
value = _make_qvalue(value)
self.vars[key] = value
if key not in self.mutable_vars and isinstance(value, QFixed):
self.immutable_regs.add(value.mask())

root = ast.parse(self.expr, mode="eval")
ans = self._convert_ast_node(root.body)
self.set_result_qreg(ans)
111 changes: 111 additions & 0 deletions qmath/compile/evaluate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
from dataclasses import dataclass
from typing import Callable
import pytest

from psiqworkbench import QPU, QFixed
from psiqworkbench.filter_presets import BIT_DEFAULT

from qmath.compile import EvaluateExpression
from qmath.utils.test_utils import QPUTestHelper

RUN_SLOW_TESTS = os.getenv("RUN_SLOW_TESTS") == "1"


@dataclass
class EvaluateTestCase:
expr: str
args: list[str]
func: Callable[..., float]
inputs: list[list[float]]
num_qubits: int
qubits_per_reg: int = 8
radix: int = 1


def _test_evaluate(tc: EvaluateTestCase):
qpu_helper = QPUTestHelper(
num_inputs=len(tc.args),
num_qubits=tc.num_qubits,
qubits_per_reg=tc.qubits_per_reg,
radix=tc.radix,
)
v = {tc.args[i]: qpu_helper.inputs[i] for i in range(len(tc.args))}
op = EvaluateExpression(tc.expr, qc=qpu_helper.qpu)
op.compute(v)
qpu_helper.record_op(op.get_result_qreg())

for args in tc.inputs:
assert qpu_helper.apply_op(args) == tc.func(*args)


# Use this test case for debugging. It does not use any helpers.
@pytest.mark.skipif(not RUN_SLOW_TESTS, reason="slow test")
def test_debug():
qpu = QPU(filters=BIT_DEFAULT)
qpu.reset(1000)
qs_x = QFixed(20, name="x", radix=5, qpu=qpu)
qs_y = QFixed(20, name="y", radix=5, qpu=qpu)
qs_z = QFixed(20, name="z", radix=5, qpu=qpu)
x, y, z = -10, 0, 5
qs_x.write(x)
qs_y.write(y)
qs_z.write(z)

expected = -x + 2 * (y + 3 * z - x * x) + x * y + x * y * z - z * x
compiler = EvaluateExpression("-x + 2*(y + 3*z - x*x) + x*y + x*y*z - z*x", qc=qpu)
compiler.compute({"x": qs_x, "y": qs_y, "z": qs_z})
ans = compiler.get_result_qreg()

assert ans.read() == expected


def test_add():
_test_evaluate(
EvaluateTestCase(
expr="x+y+z",
args=["x", "y", "z"],
func=lambda x, y, z: x + y + z,
inputs=[[1, 2, -1], [-3.5, 4, 0]],
num_qubits=50,
)
)


def test_multiply():
_test_evaluate(
EvaluateTestCase(
expr="x*y",
args=["x", "y"],
func=lambda x, y: x * y,
inputs=[[3, 2], [-3.5, 4]],
num_qubits=100,
)
)


def test_multiply_const():
_test_evaluate(
EvaluateTestCase(
expr="x*2.5",
args=["x"],
func=lambda x: x * 2.5,
inputs=[[-1], [0], [2], [4]],
num_qubits=100,
)
)


@pytest.mark.skipif(not RUN_SLOW_TESTS, reason="slow test")
def test_complex_expression():
_test_evaluate(
EvaluateTestCase(
expr="-x + 2*(y + 3*z - x*x) + x*y + x*y*z - z*x",
args=["x", "y", "z"],
func=lambda x, y, z: -x + 2 * (y + 3 * z - x * x) + x * y + x * y * z - z * x,
inputs=[[1, 2.0, -3], [4.125, 5, 6.5], [-10, 0, 5]],
num_qubits=300,
qubits_per_reg=20,
radix=10,
)
)
Empty file added qmath/compile/optimizers.py
Empty file.
41 changes: 39 additions & 2 deletions qmath/func/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,53 @@ def _estimate(self, dst: SymbolicQFixed, lhs: SymbolicQFixed, rhs: SymbolicQFixe
self.get_qc().add_cost_event(cost)


# TODO: implement.
# Preparation for MultiplyConstAdd to handle negative inputs.
class _MulConstPrep(Qubrick):

def _compute(self, x: QFixed, x_sign: Qubits, y: float, dst: QFixed):
x_sign[0].lelbow(x[-1])
Negate().compute(x, ctrl=x_sign)
if y < 0:
x_sign.x()
dst.x(x_sign)


class MultiplyConstAdd(Qubrick):
"""Computes dst += lhs * rhs (rhs is a classical number)."""

def __init__(self, rhs: float, **kwargs):
super().__init__(**kwargs)
self.rhs = rhs

# z += y*x, assuming x>=0, y>0.
def _compute_positive(self, x: QFixed, y: float, z: QFixed):
assert y > 0
x = QInt(x)
z = QInt(z)
min_i = x.radix - z.radix - (x.num_qubits - 1)
max_i = x.radix - z.radix + (z.num_qubits - 1)

for i in range(min_i, max_i + 1):
shift = z.radix - x.radix + i
bit = int(y * (2 ** (-shift))) % 2
if bit == 1:
if shift < 0:
qbk.GidneyAdd().compute(z, x[(-shift):])
else:
qbk.GidneyAdd().compute(z[shift:], x)

def _compute(self, dst: QFixed, lhs: QFixed):
pass
if self.rhs == 0:
return

x_sign = self.alloc_temp_qreg(1, "x_sign")

# Preparation to handle negative inputs.
with _MulConstPrep().computed(lhs, x_sign, self.rhs, dst):
self._compute_positive(lhs, abs(self.rhs), dst)

x_sign.release()

def _estimate(self, dst: SymbolicQFixed, lhs: SymbolicQFixed):
# TODO: implement.
pass
15 changes: 14 additions & 1 deletion qmath/func/common_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random

from qmath.func.common import AbsInPlace, Subtract
from qmath.func.common import AbsInPlace, Subtract, MultiplyConstAdd
from qmath.utils.test_utils import QPUTestHelper


Expand All @@ -26,3 +26,16 @@ def test_subtract():
result = qpu_helper.apply_op([x, y])
expected = x - y
assert abs(result - expected) < 1e-9


def test_multiply_const_add():
for y in [-11.25, 0, 1.5, 10.3]:
qpu_helper = QPUTestHelper(num_inputs=2, num_qubits=200, qubits_per_reg=25, radix=15)
qs_x, qs_z = qpu_helper.inputs
MultiplyConstAdd(y).compute(qs_z, qs_x)
qpu_helper.record_op(qs_z)

for x in [-10, 5.5, 0, 10.125]:
result = qpu_helper.apply_op([x, 0])
expected = x * y
assert abs(result - expected) < 1e-4
Loading