diff --git a/.github/AGENTS.md b/.github/AGENTS.md
deleted file mode 100644
index 1e0f18d8..00000000
--- a/.github/AGENTS.md
+++ /dev/null
@@ -1,205 +0,0 @@
-# Agent Instructions for egglog-python
-
-This file provides instructions for AI coding agents (including GitHub Copilot) working on this repository.
-
-## Project Overview
-
-This repository provides Python bindings for the Rust library `egglog`, enabling the use of e-graphs in Python for optimization, symbolic computation, and analysis. It is a hybrid project combining:
-- **Python code** in `python/egglog/` - The main Python API and library
-- **Rust code** in `src/` - PyO3-based bindings to the egglog Rust library
-- **Documentation** in `docs/` - Sphinx-based documentation
-
-## Repository Structure
-
-- `python/egglog/` - Main Python package source code
-- `python/tests/` - Python test suite (pytest-based)
-- `src/` - Rust source code for Python bindings (PyO3)
-- `docs/` - Documentation source files (Sphinx)
-- `test-data/` - Test data files
-- `pyproject.toml` - Python project configuration and dependencies
-- `Cargo.toml` - Rust project configuration
-- `uv.lock` - Locked dependencies (managed by uv)
-
-## Build and Development Commands
-
-### Prerequisites
-- **uv** - Package manager (https://github.com/astral-sh/uv)
-- **Rust toolchain** - Version pinned in `rust-toolchain.toml`
-- **Python** - Version pinned in `.python-version`
-
-### Common Commands
-
-```bash
-# Install dependencies
-uv sync --all-extras
-
-# Reinstall the Rust extension after changing code in `src/`
-uv sync --reinstall-package egglog --all-extras
-
-# Run tests
-uv run pytest --benchmark-disable -vvv --durations=10
-
-# Type checking with mypy
-make mypy
-
-# Stub testing
-make stubtest
-
-# Build documentation
-make docs
-
-# Refresh the bundled visualizer assets
-make clean
-make
-
-# Format code (auto-run by pre-commit)
-uv run ruff format .
-
-# Lint code (auto-run by pre-commit)
-uv run ruff check --fix .
-```
-
-## Python Code Standards
-
-### General Guidelines
-- **Line length**: 120 characters maximum
-- **Type hints**: Use type annotations for public APIs and functions
-- **Formatting**: Use Ruff for code formatting and linting
-- **Testing**: Write tests using pytest in `python/tests/`
-- **Docstrings**: Use clear, concise docstrings for public functions and classes
-
-### Ruff Configuration
-The project uses Ruff for linting and formatting with specific rules:
-- Allows uppercase variable names (N806, N802)
-- Allows star imports (F405, F403)
-- Allows `exec` and subprocess usage (S102, S307, S603)
-- Allows `Any` type annotations (ANN401)
-- Test files don't require full type annotations
-
-See `pyproject.toml` for complete Ruff configuration.
-
-### Type Checking
-- **mypy** is used for static type checking
-- Run `make mypy` to type check Python code
-- Run `make stubtest` to validate type stubs against runtime behavior
-- Exclusions: `__snapshots__`, `_build`, `conftest.py`
-
-### Testing
-- Tests are located in `python/tests/`
-- Use pytest with snapshot testing (syrupy)
-- Benchmarks use pytest-benchmark and CodSpeed
-- Run tests with: `uv run pytest --benchmark-disable -vvv`
-
-## Rust Code Standards
-
-### General Guidelines
-- **Edition**: Rust 2024 (experimental)
-- **FFI**: Uses PyO3 for Python bindings
-- **Main library**: Uses egglog from git (saulshanabrook/egg-smol, clone-cost branch)
-
-### Rust File Organization
-- `src/lib.rs` - Main library entry point
-- `src/egraph.rs` - E-graph implementation
-- `src/conversions.rs` - Type conversions between Python and Rust
-- `src/py_object_sort.rs` - Python object handling
-- `src/extract.rs` - Extraction functionality
-- `src/error.rs` - Error handling
-- `src/serialize.rs` - Serialization support
-- `src/termdag.rs` - Term DAG operations
-- `src/utils.rs` - Utility functions
-
-### Python File Organization
-
-#### Public Interface
-All public Python APIs are exported from the top-level `egglog` module. Anything that is public should be exported in `python/egglog/__init__.py` at the top level.
-
-#### Lower-Level Bindings
-The `egglog.bindings` module provides lower-level access to the Rust implementation for advanced use cases.
-
-#### Core Python Files
-- `python/egglog/__init__.py` - Top-level module exports, defines the public API
-- `python/egglog/egraph.py` - Main EGraph class and e-graph management
-- `python/egglog/egraph_state.py` - E-graph state and execution management
-- `python/egglog/runtime.py` - Runtime system for expression evaluation and method definitions
-- `python/egglog/builtins.py` - Built-in types (i64, f64, String, Vec, etc.) and operations
-- `python/egglog/declarations.py` - Class, function, and method declaration decorators
-- `python/egglog/conversion.py` - Type conversion between Python and egglog types
-- `python/egglog/pretty.py` - Pretty printing for expressions and e-graph visualization
-- `python/egglog/deconstruct.py` - Deconstruction of Python values into egglog expressions
-- `python/egglog/thunk.py` - Lazy evaluation support
-- `python/egglog/type_constraint_solver.py` - Type inference and constraint solving
-- `python/egglog/config.py` - Configuration settings
-- `python/egglog/ipython_magic.py` - IPython/Jupyter integration
-- `python/egglog/visualizer_widget.py` - Interactive visualization widget
-- `python/egglog/version_compat.py` - Python version compatibility utilities
-- `python/egglog/examples/` - End-to-end samples and tutorials demonstrating the API
-- `python/egglog/exp/` - Experimental Array API integrations and code generation helpers
-
-The compiled extension artifact `python/egglog/bindings.cpython-*.so` is generated by `uv sync` and should not be edited manually.
-
-## Code Style Preferences
-
-1. **Imports**: Follow Ruff's import sorting
-2. **Naming**:
-   - Python: snake_case for functions and variables, PascalCase for classes
-   - Rust: Follow standard Rust conventions
-3. **Comments**: Use clear, explanatory comments for complex logic
-4. **Documentation**: Keep docs synchronized with code changes
-
-## Contributing Guidelines
-
-When making changes:
-1. Update or add tests in `python/tests/` for Python changes
-2. Run the full test suite before committing
-3. Ensure type checking passes with `make mypy`
-4. Build documentation if changing public APIs
-5. Follow existing code patterns and style
-6. Keep changes minimal and focused
-7. Ensure the automatic changelog entry in `docs/changelog.md` (added when opening the PR) accurately reflects your change and add manual notes if additional clarification is needed
-
-## Common Patterns
-
-### Python API Design
-- Define e-graph classes by inheriting from `egglog.Expr`
-- Use `@egraph.function` decorator for functions
-- Use `@egraph.method` decorator for methods
-- Leverage type annotations for better IDE support
-
-### Working with Values
-- Use `get_literal_value(expr)` or the `.value` property to get Python values from primitives
-- Use pattern matching with `match`/`case` for destructuring egglog primitives
-- Use `get_callable_fn(expr)` to get the underlying Python function from a callable expression
-- Use `get_callable_args(expr)` to get arguments to a callable
-
-### Parallelism
-- The underlying Rust library uses Rayon for parallelism
-- Control worker thread count via `RAYON_NUM_THREADS` environment variable
-- Defaults to single thread if not set
-
-### Rust-Python Integration
-- Use PyO3's `#[pyclass]` and `#[pymethods]` macros
-- Handle errors with appropriate Python exceptions
-- Convert between Rust and Python types in `conversions.rs`
-
-## Documentation
-
-Documentation is built with Sphinx:
-- Source files in `docs/`
-- Build with `make docs`
-- Output in `docs/_build/html/`
-- Hosted on ReadTheDocs
-
-## Testing Strategy
-
-1. **Unit tests**: Test individual functions and classes
-2. **Integration tests**: Test complete workflows
-3. **Snapshot tests**: Use syrupy for snapshot testing of complex outputs
-4. **Benchmarks**: Performance testing with pytest-benchmark and pytest-codspeed
-5. **Parallel testing**: Use pytest-xdist for faster test runs
-6. **Type checking**: Validate type stubs and annotations
-
-## Performance Considerations
-
-- The library uses Rust for performance-critical operations
-- Benchmarking is done via CodSpeed for continuous performance monitoring
-- Profile with release builds (`cargo build --release`) when needed
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 2c297eb9..b1435c55 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -58,20 +58,22 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v6
         with:
-          python-version-file: ".python-version"
+          # Work around import-time SciPy/sklearn segfaults on the codspeed ARM64 Python 3.13 runner.
+          python-version: "3.12"
       - uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
+          python-version: "3.12"
       - uses: dtolnay/rust-toolchain@1.79.0
       - uses: Swatinem/rust-cache@v2
       - run: |
           export UV_PROJECT_ENVIRONMENT="${pythonLocation}"
-          uv sync --extra test --locked
-      - uses: CodSpeedHQ/action@v4.3.1
+          uv sync --extra test --locked --python 3.12
+      - uses: CodSpeedHQ/action@v4.11.1
         with:
           token: ${{ secrets.CODSPEED_TOKEN }}
           # allow updating snapshots due to indeterministic benchmarks
-          run: pytest -vvv --snapshot-update --durations=10
+          run: pytest -vvv --snapshot-update --durations=10 --codspeed-max-rounds=10 python/tests/test_array_api.py -k "test_jit or test_run_lda"
           mode: ${{ matrix.runner == 'ubuntu-latest' && 'instrumentation' || 'walltime' }}
 
   docs:
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 00000000..f8828ade
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,46 @@
+# Repo Guidance
+
+## Overview
+
+- This repo combines the high-level Python bindings in `python/egglog/`, the Rust extension in `src/`, and the Sphinx docs in `docs/`.
+- Public Python APIs are exported from `python/egglog/__init__.py`.
+- The compiled `python/egglog/bindings.cpython-*.so` artifact is generated and should not be edited directly.
+
+## Common Commands
+
+- `uv sync --all-extras` installs the full dev environment.
+- `uv sync --reinstall-package egglog --all-extras` rebuilds the Rust extension after changes in `src/`.
+- `uv run pytest --benchmark-disable -q` runs the Python tests without benchmark calibration.
+- `make mypy` runs the type checker.
+- `make stubtest` checks the runtime against the type stubs.
+- `make docs` builds the docs.
+
+## Docs
+
+- Use the Context7 MCP server for egglog documentation instead of copying external doc summaries into this file.
+- Keep general workflows in the how-to guides, and keep Python-specific runtime/reference examples in `docs/reference/python-integration.md`.
+- If a PR adds or updates a changelog entry in `docs/changelog.md`, keep it aligned with the final code changes.
+- For a clean docs rebuild, clear `docs/_build/`; the MyST-NB execution cache lives in `docs/_build/.jupyter_cache`.
+
+## Python bindings
+
+- Prefer relative imports inside `python/egglog`.
+- When changing public high-level APIs, update the public docs, stubs, and pretty/freeze round-trip expectations together.
+- Higher-order callable type probing should stay isolated from the live ruleset: copy declarations and run with no current ruleset so inference does not register temporary unnamed functions or rewrites.
+
+## Array API
+
+- Start with `python/egglog/exp/array_api.py` and `python/tests/test_array_api.py`.
+- `Vec[...]` is a primitive sort; avoid rewrites or unions that merge distinct vec values.
+- Guard vector indexing rewrites with explicit bounds checks.
+
+## CI
+
+- When debugging GitHub Actions logs, prefer the private `$github-actions-rest-logs` skill or the equivalent REST API flow with `GITHUB_PAT_TOKEN`.
+
+## Verification
+
+- Prefer the minimal code change and the minimal diff that solves the task; only broaden the change if the smaller fix is not sufficient.
+- Run `make mypy` for typing changes.
+- Run targeted pytest for touched modules.
+- Run `make docs` for docs or public API changes.
diff --git a/Cargo.lock b/Cargo.lock
index 820a96b9..0616be4d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -79,6 +79,23 @@ version = "1.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
 
+[[package]]
+name = "async-trait"
+version = "0.1.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.107",
+]
+
+[[package]]
+name = "atomic-waker"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
+
 [[package]]
 name = "autocfg"
 version = "1.5.0"
@@ -121,12 +138,27 @@ version = "3.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
 
+[[package]]
+name = "bytes"
+version = "1.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
+
 [[package]]
 name = "cfg-if"
 version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
 
+[[package]]
+name = "chrono"
+version = "0.4.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "clap"
 version = "4.5.51"
@@ -293,6 +325,17 @@ dependencies = [
  "crypto-common",
 ]
 
+[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] + [[package]] name = "dot-generator" version = "0.2.0" @@ -316,8 +359,8 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "egglog" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "csv", "dyn-clone", @@ -343,8 +386,8 @@ dependencies = [ [[package]] name = "egglog-add-primitive" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "quote", "syn 2.0.107", @@ -352,16 +395,16 @@ dependencies = [ [[package]] name = "egglog-ast" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "ordered-float", ] [[package]] name = "egglog-bridge" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "anyhow", "dyn-clone", @@ -384,17 +427,19 @@ dependencies = [ 
[[package]] name = "egglog-concurrency" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "arc-swap", + "egglog-numeric-id", "rayon", + "smallvec", ] [[package]] name = "egglog-core-relations" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "anyhow", "bumpalo", @@ -424,7 +469,7 @@ dependencies = [ [[package]] name = "egglog-experimental" version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-experimental?branch=main#908c47d7046c840c8ff07caa0f7ff29a2e7adc82" +source = "git+https://github.com/egraphs-good/egglog-experimental?branch=main#eae9570d78105c53497fccdf0ff7fb1937592036" dependencies = [ "egglog", "egglog-ast", @@ -436,16 +481,16 @@ dependencies = [ [[package]] name = "egglog-numeric-id" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "rayon", ] [[package]] name = "egglog-reports" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "clap", "hashbrown 0.16.0", @@ -458,8 +503,8 @@ dependencies = [ [[package]] name = 
"egglog-union-find" -version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-container-fn-bug#00a6ce015ec9b898d1eefda86d46e52103607baf" dependencies = [ "crossbeam", "egglog-concurrency", @@ -468,7 +513,7 @@ dependencies = [ [[package]] name = "egglog_python" -version = "13.0.0" +version = "13.0.1" dependencies = [ "base64", "egglog", @@ -478,15 +523,23 @@ dependencies = [ "egglog-experimental", "egglog-reports", "egraph-serialize", + "indexmap", "lalrpop-util", "log", "num-bigint", "num-rational", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry-stdout", + "opentelemetry_sdk", "ordered-float", "pyo3", "pyo3-log", "rayon", "serde_json", + "tracing", + "tracing-opentelemetry", + "tracing-subscriber", "uuid", ] @@ -550,6 +603,87 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + [[package]] name = "generic-array" version = "0.14.9" @@ -560,6 +694,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -572,6 +717,12 @@ dependencies = [ "wasip2", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "graphviz-rust" version = "0.9.6" @@ -621,6 +772,191 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + 
"hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "im-rc" version = "15.1.0" @@ -675,12 +1011,37 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -725,6 +1086,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "litemap" +version = "0.8.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + [[package]] name = "lock_api" version = "0.4.14" @@ -755,6 +1122,26 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mio" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num" version = "0.4.3" @@ -840,6 +1227,99 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "opentelemetry" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "236e667b670a5cdf90c258f5a55794ec5ac5027e960c224bff8367a59e1e6426" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "pin-project-lite", + "thiserror", + "tracing", +] + +[[package]] +name = "opentelemetry-http" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8863faf2910030d139fb48715ad5ff2f35029fc5f244f6d5f689ddcf4d26253" +dependencies = [ + "async-trait", + "bytes", + "http", + "opentelemetry", + "reqwest", + "tracing", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bef114c6d41bea83d6dc60eb41720eedd0261a67af57b66dd2b84ac46c01d91" +dependencies = [ + "async-trait", + "futures-core", + "http", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-proto", + "opentelemetry_sdk", + "prost", + "reqwest", + 
"thiserror", + "tracing", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f8870d3024727e99212eb3bb1762ec16e255e3e6f58eeb3dc8db1aa226746d" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", +] + +[[package]] +name = "opentelemetry-stdout" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb0e5a5132e4b80bf037a78e3e12c8402535199f5de490d0c38f7eac71bc831" +dependencies = [ + "async-trait", + "chrono", + "futures-util", + "opentelemetry", + "opentelemetry_sdk", + "serde", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84dfad6042089c7fc1f6118b7040dc2eb4ab520abbf410b79dc481032af39570" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "opentelemetry", + "percent-encoding", + "rand 0.8.5", + "serde_json", + "thiserror", + "tracing", +] + [[package]] name = "ordered-float" version = "5.1.0" @@ -864,6 +1344,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "pest" version = "2.8.3" @@ -920,18 +1406,59 @@ dependencies = [ ] [[package]] -name = "portable-atomic" -version = "1.11.1" +name = "pin-project" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +dependencies = [ + "pin-project-internal", +] [[package]] -name = "ppv-lite86" -version = "0.2.21" +name = "pin-project-internal" +version = "1.1.11" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ - "zerocopy", + "proc-macro2", + "quote", + "syn 2.0.107", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", ] [[package]] @@ -943,12 +1470,36 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.107", +] + 
[[package]] name = "pyo3" version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37a6df7eab65fc7bee654a421404947e10a0f7085b6951bf2ea395f4659fb0cf" dependencies = [ + "indexmap", "indoc", "libc", "memoffset", @@ -1039,6 +1590,8 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "libc", + "rand_chacha 0.3.1", "rand_core 0.6.4", "serde", ] @@ -1049,10 +1602,20 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", + "rand_chacha 0.9.0", "rand_core 0.9.3", ] +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + [[package]] name = "rand_chacha" version = "0.9.0" @@ -1069,6 +1632,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ + "getrandom 0.2.17", "serde", ] @@ -1078,7 +1642,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom", + "getrandom 0.3.4", ] [[package]] @@ -1136,6 +1700,40 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + 
"futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -1144,9 +1742,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags", "errno", @@ -1217,6 +1815,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1228,6 +1838,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "sized-chunks" version = "0.6.5" @@ -1238,12 +1857,34 @@ dependencies = [ "typenum", ] +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + [[package]] name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + [[package]] name = "strsim" version = "0.11.1" @@ -1272,6 +1913,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] + [[package]] name = "target-lexicon" version = "0.13.3" @@ -1280,12 +1941,12 @@ checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" [[package]] name = "tempfile" -version = "3.23.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys 0.61.2", @@ -1311,6 +1972,196 @@ dependencies = [ "syn 2.0.107", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + 
+[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + "socket2", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tonic" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +dependencies = [ + "async-trait", + "base64", + "bytes", + "http", + "http-body", + "http-body-util", + "percent-encoding", + "pin-project", + "prost", + "tokio-stream", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = 
"0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721f2d2569dce9f3dfbbddee5906941e953bfcdf736a62da3377f5751650cc36" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "typenum" version = "1.19.0" @@ -1335,6 +2186,24 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -1347,17 +2216,38 @@ version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ - "getrandom", + "getrandom 0.3.4", "js-sys", "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + [[package]] name = "wasip2" version = "1.0.1+wasi-0.2.4" @@ -1394,6 +2284,19 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.104" @@ -1426,6 +2329,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-sys" +version = "0.3.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "web-time" version = "1.1.0" @@ -1531,6 +2444,35 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.8.27" @@ -1550,3 +2492,57 @@ dependencies = [ "quote", "syn 2.0.107", ] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] diff --git a/Cargo.toml b/Cargo.toml index 4aab67c9..e3741144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,32 +10,53 @@ name = "egglog" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] } +pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational", "indexmap"] } num-bigint = "*" num-rational = "*" 
-egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug", default-features = false } -egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } +indexmap = "2.12" +opentelemetry = "0.28" +opentelemetry-otlp = { version = "0.28", features = ["http-proto", "reqwest-blocking-client", "trace"] } +opentelemetry-stdout = { version = "0.28", features = ["trace"] } +opentelemetry_sdk = "0.28" +# egglog = { path = "../egg-smol", default-features = false } +# egglog-bridge = { path = "../egg-smol/egglog-bridge" } +# egglog-core-relations = { path = "../egg-smol/core-relations" } +# egglog-ast = { path = "../egg-smol/egglog-ast" } +# egglog-reports = { path = "../egg-smol/egglog-reports" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug", default-features = false } +egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } + + egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false } -egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] } serde_json = "1" pyo3-log = "*" log = "0.4" lalrpop-util = { version = "0.22", features = ["lexer"] } ordered-float = "5" +tracing = "0.1" 
+tracing-opentelemetry = "0.29" +tracing-subscriber = "0.3" uuid = { version = "1.18", features = ["v4"] } rayon = "1.11" base64 = "0.22.1" # Use patched version of egglog in experimental [patch.'https://github.com/egraphs-good/egglog'] -egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } +# egglog = { path = "../egg-smol" } +# egglog-core-relations = { path = "../egg-smol/core-relations" } +# egglog-ast = { path = "../egg-smol/egglog-ast" } +# egglog-reports = { path = "../egg-smol/egglog-reports" } +# egglog-bridge = { path = "../egg-smol/egglog-bridge" } + +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } +egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" } # enable debug symbols for easier profiling [profile.release] diff --git a/conftest.py b/conftest.py index dd7fdece..97da4af6 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,12 @@ import os import pathlib +import sys +from dataclasses import dataclass +from importlib import import_module +from typing import Any ROOT_DIR = pathlib.Path(__file__).parent + # So that it finds the local typings os.environ["MYPYPATH"] = str(ROOT_DIR / "python") 
@@ -10,3 +15,85 @@ pytest_plugins = ["mypy.test.data"] # Set this to the root directory so it finds the `test-data` directory os.environ["MYPY_TEST_PREFIX"] = str(ROOT_DIR) + +DEFAULT_OTLP_ENDPOINT = "http://127.0.0.1:4318/v1/traces" +SERVICE_NAME = "egglog" + +sys.modules.setdefault("egglog_pytest_otel", sys.modules[__name__]) + + +@dataclass(frozen=True) +class PytestOtelConfig: + endpoint: str | None + traces: str + + +def add_pytest_otel_options(parser: Any) -> None: + group = parser.getgroup("egglog-otel") + group.addoption( + "--otel-traces", + action="store", + default="off", + choices=("off", "console", "jaeger"), + help="Export egglog traces during tests.", + ) + group.addoption( + "--otel-otlp-endpoint", + action="store", + default=None, + help="OTLP/HTTP traces endpoint used when --otel-traces=jaeger.", + ) + + +def get_pytest_otel_config(config: Any) -> PytestOtelConfig: + traces = config.getoption("--otel-traces") + endpoint = config.getoption("--otel-otlp-endpoint") + if traces == "jaeger" and not endpoint: + endpoint = DEFAULT_OTLP_ENDPOINT + return PytestOtelConfig(traces=traces, endpoint=endpoint) + + +def configure_pytest_otel(config: Any): + otel_config = get_pytest_otel_config(config) + if otel_config.traces == "off": + return None + + trace = import_module("opentelemetry.trace") + OTLPSpanExporter = import_module("opentelemetry.exporter.otlp.proto.http.trace_exporter").OTLPSpanExporter + Resource = import_module("opentelemetry.sdk.resources").Resource + TracerProvider = import_module("opentelemetry.sdk.trace").TracerProvider + trace_export = import_module("opentelemetry.sdk.trace.export") + BatchSpanProcessor = trace_export.BatchSpanProcessor + ConsoleSpanExporter = trace_export.ConsoleSpanExporter + SimpleSpanProcessor = trace_export.SimpleSpanProcessor + + bindings = import_module("egglog.bindings") + + provider = TracerProvider(resource=Resource.create({"service.name": SERVICE_NAME})) + if otel_config.traces == "console": + 
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) + else: + provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=otel_config.endpoint))) + trace.set_tracer_provider(provider) + if otel_config.traces == "console": + bindings.setup_tracing(exporter="console") + else: + bindings.setup_tracing(exporter="http", endpoint=otel_config.endpoint) + return provider + + +def pytest_addoption(parser): + add_pytest_otel_options(parser) + + +def pytest_configure(config): + provider = configure_pytest_otel(config) + if provider is not None: + config._egglog_otel_provider = provider + + +def pytest_unconfigure(config): + provider = getattr(config, "_egglog_otel_provider", None) + if provider is not None: + import_module("egglog.bindings").shutdown_tracing() + provider.shutdown() diff --git a/docs/changelog.md b/docs/changelog.md index 162af2fb..f0f665e6 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,19 @@ _This project uses semantic versioning_ ## UNRELEASED +- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397) + - Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection. + - Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor. + - Add richer general-purpose builtin helpers used by the new containers/polynomials work, including additional `Vec`, `MultiSet`, numeric, and string operations. 
+ - Expand the [Python integration reference](reference/python-integration.md) with examples for `run`, `stats`, `function_values`, `freeze`, `display`, and `saturate`, and clarify that proof-mode commands exposed in low-level bindings are not yet supported as a full high-level Python workflow. + - Add [OpenTelemetry tracing docs](how-to-guides/tracing.md) and the linked [`pytest` tracing workflow](reference/contributing.md#tracing) for debugging Python and Rust spans together. + - Add the new [containers/polynomials write-up](explanation/2026_02_containers.md). + - Fix several high-level bugs: + - auto-prefix generated `let` bindings with `$` so recorded programs and round-tripped output are valid egglog; + - keep schedules and default-rewrite rules live after materialization so later declarations are not missed; + - fix empty Python container conversions and several higher-order callable / nested-lambda inference edge cases. + - Update the experimental `egglog.exp.array_api` module. + ## 13.0.1 (2026-03-04) - Fix install by adding cloudpickle as required dependency [#405](https://github.com/egraphs-good/egglog-python/pull/405) diff --git a/docs/conf.py b/docs/conf.py index 49142405..4d8f159b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -import pathlib # noqa: INP001 +import pathlib import subprocess ## diff --git a/docs/explanation/2026_02_containers.md b/docs/explanation/2026_02_containers.md new file mode 100644 index 00000000..e504eea1 --- /dev/null +++ b/docs/explanation/2026_02_containers.md @@ -0,0 +1,724 @@ +--- +file_format: mystnb +--- + +```{post} 2026-03-13 +``` + +# Custom Data Structures in E-Graphs + +*Cross-posted on the [UW PLSE blog](https://uwplse.org/2026/02/24/egglog-containers.html).* + +[E-graphs](https://en.wikipedia.org/wiki/E-graph) are a data structure used to reason about program equivalence. +Combined with specialized algorithms they can be used to build optimizers or compilers. 
However, +their performance can struggle as the number of equivalent expressions explodes if we include algebraic identities, such as +associativity, commutativity, and distributivity (A/C/D). + +Alternatively we can attempt to build these identities into our underlying data structure, such as Philip Zucker's explorations of +[Gröbner basis](https://www.philipzucker.com/multiset_rw/) and [bottom up e-matching](https://www.semanticscholar.org/paper/Omelets-Need-Onions%3A-E-graphs-Modulo-Theories-via-Zucker/b07bdef17fdbb7cf927a5a844fc587335864e89a). +For example, instead of representing a sequence of additions as, say, a tree of binary operations, we can represent it as a sorted list of terms being added +or a multiset mapping terms to their counts. +However, building entirely new e-graph systems to take advantage of this is a large engineering lift and splits the ecosystem of users, +reducing the possibility for code reuse between projects that use e-graphs. + +Here, I explore how supporting custom data structures and higher order functions can +be used to build efficient algebraic representations without changing the internals of an e-graph system. + +## EGraphs in Egglog + +In this post we will be using egglog, an e-graph framework built on top of a custom database. It's written in +[Rust](https://github.com/egraphs-good/egglog) with [bindings in Python](https://github.com/egraphs-good/egglog-python), +which I will be using here. + +First we start with an example where we define a language through a set of types, uninterpreted functions, and rewrite rules +to define equivalences between expressions. + +This lets us check things like whether `2 * (x + 3) == 6 + 2 * x`, given distributivity and commutativity, along with constant folding rules: + +```{code-cell} python +from __future__ import annotations + +from egglog import * + +# 1. Create a custom type +class Num(Expr): + # 2.
Define constructors for this type from an integer, + # string, binary addition or multiplication + def __init__(self, value: i64Like) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Num: ... + + def __add__(self, other: Num) -> Num: ... + + def __mul__(self, other: Num) -> Num: ... + + +# 3. Define a set of rewrite rules that add equivalences +# They work by finding an expression that matches the LHS +# modulo the variables, then adding the RHS with the variables +# substituted, and setting them as equivalent to it +@ruleset +def comm_dist_fold(a: Num, b: Num, c: Num, i: i64, j: i64): + # commutativity + yield rewrite(a + b).to(b + a) + # distributivity + yield rewrite(a * (b + c)).to((a * b) + (a * c)) + # constant folding + yield rewrite(Num(i) + Num(j)).to(Num(i + j)) + yield rewrite(Num(i) * Num(j)).to(Num(i * j)) + + +# 4. Create an empty e-graph +egraph = EGraph() + +# 5. Add our two initial expressions +expr1 = egraph.let("expr1", Num(2) * (Num.var("x") + Num(3))) +expr2 = egraph.let("expr2", Num(6) + Num(2) * Num.var("x")) + + +# 6. Run this ruleset until it is "saturated", +# meaning that further application will be no-ops, +# as well as output a visualization showing the progress +egraph.saturate(comm_dist_fold) +# 7. Verify that our two expressions are now equivalent +egraph.check(expr1 == expr2) +``` + +The visualization shows the final state of the e-graph and allows us to step through it using the slider at the top: + +The arrows point from a function to its arguments. When two expressions are equivalent they are placed in the same cluster, +called an e-class. The top e-class has the labels `expr1` and `expr2` in it, meaning they are now equivalent. + +Dragging the slider to the left shows the initial state, before any of the rules were run, +when the e-graph just contains our two initial expressions. They start in different e-classes, since we don't know they are equal +until we run our rules.
As you drag the slider to the right, you will see the state of the e-graph after each rule application. + +EGraphs can also be used for program optimization. By choosing a cost model, for example based on the total number of terms, +we can try to find an expression equivalent to our initial expression and extract it out. + +This is comparable to how term rewriting systems can also be used for optimization or transformation. One way to look at +e-graphs is as a term rewriting system where we remember all the previous terms we have encountered and defer +picking the "best" one until the end. +This lets us focus less on rule application order, but it does mean that our memory use will increase over time, which +is the problem we will get to soon. + +For a more thorough introduction, check out the [egglog tutorial](https://egglog-python.readthedocs.io/latest/tutorials/tut_1_basics.html), +and for an example of how it can be used inside of a larger system see the [Numba v2 mini book](https://numba.pydata.org/numba-prototypes/sealir_tutorials/index.html). +More examples are collected [in the awesome e-graphs repo](https://github.com/philzook58/awesome-egraphs). +EGraphs also have [an active community](https://egraphs.org/) around them that [chats online](https://egraphs.zulipchat.com/) and +[meets in person](https://egraphs.org/workshop/). + +## Size Blow Up + +While e-graphs are powerful, they can "blow up", increasing in size drastically even when starting with a small expression.
+For example, if we start with `2 + a + b + b + 3` and add A/C/D rules, we can see it increase in size: + +```{code-cell} python +@ruleset +def assoc(a: Num, b: Num, c: Num): + yield birewrite(a + (b + c)).to((a + b) + c) + + +egraph = EGraph() +a, b = Num.var("a"), Num.var("b") +new_expr = egraph.let("new_expr", Num(2) + a + b + b + Num(3)) + +# run both the associativity and commutativity/distributivity +# rules together +egraph.saturate(assoc | comm_dist_fold) +egraph.extract(new_expr) +``` + +As the number of e-nodes increases so does the memory usage and also the runtime, limiting the ability +to use these kinds of rules on large expressions. One way to work around this +is to limit the number of times we apply certain rules or limit the size of e-graphs, with the tradeoff that this +limits the size of our search space. What if instead we could keep +an optimization like constant folding without the size blow-up caused by the other rules? + +We will look at how containers can be used to achieve this, but first some background on how Egglog handles primitives. + +## Primitives and Containers in Egglog + +In addition to letting you define your own types, like `Num`, Egglog also comes with a number of builtins/primitives like `i64` and `String`. +The core comes with a number of them, but they can also be written in Rust extensions as plugins. To define a new primitive type in Rust +you must define how to compare its values for equality and how to hash them. +Primitives are treated as opaque values: Egglog doesn't reason about their inner structure. Functions can also be defined +over primitives, again either in the core or as a Rust extension. Egglog doesn't know anything about their semantics, just that +they take in primitives and return other primitives. + +If we think of primitives as opaque values, what if we want to contain a primitive inside of another one? For example, +a `Vec` type that contains a number of items.
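To make this concrete, here is a toy model in plain Python of a container holding references to e-classes. This is a sketch under assumptions, not Egglog's actual Rust implementation: `UnionFind`, `rebuild_vec`, and the string e-class ids are all invented for illustration. The key idea, which the next paragraphs develop, is that after two e-classes are merged we can re-canonicalize the ids stored in a container so that congruent containers compare (and hash) equal again:

```python
# Toy sketch: containers store e-class ids, and after a union we
# canonicalize those ids so structurally equal containers compare equal.

class UnionFind:
    def __init__(self):
        self.parent = {}

    def add(self, x):
        self.parent.setdefault(x, x)

    def find(self, x):
        # Follow parent pointers to the canonical representative,
        # compressing the path as we go.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)


def rebuild_vec(vec, uf):
    # Replace every stored e-class id with its canonical representative.
    return tuple(uf.find(x) for x in vec)


uf = UnionFind()
for eclass in ("a", "b", "c"):
    uf.add(eclass)

v1 = ("a", "c")  # stands in for Vec(a, c)
v2 = ("b", "c")  # stands in for Vec(b, c)
assert v1 != v2

uf.union("a", "b")  # make a == b
assert rebuild_vec(v1, uf) == rebuild_vec(v2, uf)  # congruence restored
```

In a real system this canonicalization is deferred and batched rather than run after every union, which is exactly the tradeoff discussed next.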
To define such a type in Rust, we need the above properties around hashing
+and equality, but we also need to make sure it respects congruence. Congruence is the property that if you have
+two expressions `f(a)` and `f(b)` in the e-graph, and you then make `a == b`, then `f(a)` should also equal `f(b)`. We want
+this same property to hold for something like a vector, so if you have `Vec(a, c)` and `Vec(b, c)` in the e-graph, and you make `a == b`,
+then these two vecs should also be equal: `Vec(a, c) == Vec(b, c)`.
+
+We do this by implementing one additional operation on containers: rebuilding. This is called whenever we want to renormalize
+the e-graph to preserve congruence. It is deferred, rather than run after every union operation, to reduce the amount of work.
+Since containers "contain" references to other e-classes, we need to update those references. That is what this rebuilding
+operation does: when it is time to rebuild, the `Vec` type calls rebuilding on each of its inner values, updating them with
+the new canonical name for each e-class.
+Equality checks after that point then preserve congruence.
+
+Egglog doesn't know anything more about the structure of containers besides how to rebuild them and any primitive functions you define on them.
+This makes them relatively easy to implement and add, but it also limits the ability to "match" over them, which we will see how to work around
+in the next section.
+
+One use case for containers is to represent operations with more structure. In the above, we defined
+addition as a binary operation with two ordered arguments. However, we may decide that for our use case
+we not only want to make `a + b` equal to `b + a`, but in fact indistinguishable. This effectively
+replaces the commutativity and associativity rules with a representation that maintains their invariants.
+One way we could do that is with a container that represents all terms being added.
+
+This would be a [multiset aka bag](https://en.wikipedia.org/wiki/Multiset), since
+we want to know how many times a term is being added, but we don't care about the order.
+
+We can write this in the Python bindings like so:
+
+```{code-cell} python
+@function
+def sum_(xs: MultiSet[Num]) -> Num: ...
+```
+
+Now if we construct `sum_(MultiSet(a, b))`, this will be equal to `sum_(MultiSet(b, a))`, because the multiset implementation
+is order insensitive:
+
+```{code-cell} python
+egraph = EGraph()
+x = Num.var("x")
+y = Num.var("y")
+z = egraph.let("z", sum_(MultiSet(x, y)))
+egraph.check(z == sum_(MultiSet(y, x)))
+```
+
+We also have the rebuilding property we talked about above, to maintain congruence. If we now union `x` with `y`,
+the sum will reflect this and become `sum_(MultiSet(x, x))`:
+
+```{code-cell} python
+egraph.register(union(x).with_(y))
+egraph.check(z == sum_(MultiSet(x, x)))
+```
+
+So here we can represent a whole set of equal summations with only one multiset, instead of having to add
+many terms to the e-graph.
+
+## Matching on Containers by Index
+
+Given this new implementation, how would we replicate the above constant folding example on it?
+
+Well, first we can start by creating a new e-graph and adding the expression, this time using our `sum_` function with multisets
+instead of binary addition:
+
+```{code-cell} python
+egraph = EGraph()
+new_expr = egraph.let("new_expr", sum_(MultiSet(Num(2), a, b, b, Num(3))))
+```
+
+One way to think about a constant folding rule would be to "look for a sum that contains two constant numbers,
+take them both out, and add their sum back in". However, we don't have the ability to match on the contents of a multiset directly,
+since, as we said above, Egglog doesn't know anything about its inner structure.
+
+One way we can work around this is to build up an "index" function for the contents of the multiset.
This maps
+a multiset and an item inside of it to the number of times that item shows up in the multiset:
+
+```{code-cell} python
+@function
+def ms_num_index(xs: MultiSet[Num], x: Num) -> i64: ...
+```
+
+This is similar to a database: if you need to do a join efficiently, you first have to build an index.
+
+Then we can add two rules: one that fills in the index whenever we have a sum, and another that matches on the index
+to do the constant folding:
+
+```{code-cell} python
+@ruleset
+def constant_fold_index(xs: MultiSet[Num], i: i64, k: i64):
+    # For all sums, fill in the index function
+    yield rule(sum_(xs)).then(xs.fill_index(ms_num_index))
+
+    # Try replacing any sum with the folded version
+    yield rewrite(sum_(xs)).to(
+        # Replace the two numbers with their sum, by removing
+        # them and then inserting their sum back in
+        sum_(xs.remove(Num(i)).remove(Num(k)).insert(Num(i + k))),
+        # These are conditions for the rewrite to match:
+        # Look for a multiset that contains two numbers that
+        # are not the same one
+        ms_num_index(xs, Num(i)),
+        ms_num_index(xs, Num(k)),
+        i != k,
+    )
+
+
+egraph = EGraph()
+new_expr = egraph.let("new_expr", sum_(MultiSet(Num(2), a, b, b, Num(3))))
+egraph.saturate(constant_fold_index)
+egraph.extract(new_expr)
+```
+
+If we run this, we can see that we get back out the folded expression, without the blow-up from before.
+
+However, we still add a number of nodes to the e-graph to maintain the index. While this works in this small example,
+if many intermediate multisets are generated, this can again lead to a blow-up in the e-graph size.
+
+So what if instead there were a way to express this rule without needing to maintain this index?
+
+## Matching on Containers with Higher Order Functions
+
+We can look at the rule above as trying to pull out two numbers from a sum and fold them together. So if there were `n`
+constants, it would trigger `n * (n - 1)` times, since we can choose any two of them to fold together.
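To make the pairwise blow-up concrete, here is a toy plain-Python model (an illustration of the counting argument, not the egglog API):

```python
from itertools import permutations

# Toy model: a "sum" is a list of terms, where ints stand for
# constants and strings stand for variables.
terms = [2, "a", "b", 3, 7]

# The index-based rule matches any ordered pair of constants with
# distinct values, so with n constants it can fire n * (n - 1) times.
constants = [t for t in terms if isinstance(t, int)]
pairs = [(i, k) for i, k in permutations(constants, 2) if i != k]
print(len(pairs))  # 3 constants -> 3 * 2 = 6 matches
```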
What if instead
+we wanted to express a rule that selects *all* constants from a multiset and folds them together?
+
+The index approach won't work here, because we don't have a fixed number of items to match on. Instead, we can use higher order functions
+to express this as a blockwise operation. Effectively, we want to say "pull out all constants in the multiset, add them together,
+and then add that back into the multiset with all the non-constants".
+
+But first we need to add a helper function that returns an `i64` for a `Num` if it's a constant:
+
+```{code-cell} python
+@function
+def get_i64(x: Num) -> i64: ...
+
+
+@ruleset
+def set_get_i64(i: i64):
+    yield rule(Num(i)).then(set_(get_i64(Num(i))).to(i))
+```
+
+Then we can define the constant folding, using higher order `fold` and `map` operations:
+
+```{code-cell} python
+@ruleset
+def constant_fold_sum(xs: MultiSet[Num]):
+    # Extract out all the constants from the sum
+    constants = xs.map(get_i64)
+    # Filter for the remaining values that are not constants
+    remaining = xs - constants.map(UnstableFn(Num))
+    # Sum all the constants to fold them together
+    folded = multiset_fold(i64.__add__, i64(0), constants)
+    yield rewrite(sum_(xs)).to(
+        # replace it with the non constants plus the folded
+        sum_(remaining.insert(Num(folded))),
+        # Only run this rule if there are more than one
+        # constant to fold together
+        constants.length() > 1,
+    )
+
+
+egraph = EGraph()
+new_expr = egraph.let("new_expr", sum_(MultiSet(Num(2), a, b, b, Num(3))))
+egraph.saturate(set_get_i64.saturate() + constant_fold_sum.saturate())
+egraph.extract(new_expr)
+```
+
+Running the setting ruleset first and then the folding ruleset,
+we can see that we get the same result as before, but without needing to maintain the index.
+
+Using higher order functions on containers, we can express efficient rewrite rules
+that reduce the size of the e-graph compared to using binary operations.
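The shape of this blockwise rule can be sketched in plain Python, with a `Counter` standing in for the multiset (this is an illustration of the idea, not the egglog API):

```python
from collections import Counter

# A sum is modeled as a multiset (Counter) of terms; ints are the
# constants and strings are the symbolic variables.
def fold_constants(xs: Counter) -> Counter:
    constants = Counter({t: n for t, n in xs.items() if isinstance(t, int)})
    if sum(constants.values()) <= 1:
        return xs  # mirrors the `constants.length() > 1` guard
    remaining = xs - constants  # the non-constant terms
    folded = sum(t * n for t, n in constants.items())  # fold with +
    remaining[folded] += 1  # insert the folded constant back in
    return remaining

print(fold_constants(Counter([2, "a", "b", "b", 3])))
# -> Counter({'b': 2, 'a': 1, 5: 1})
```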
The containers themselves preserve some of the core
+identities that would lead to blow-up, and the higher order functions support blockwise operations to process an
+arbitrary number of items at once.
+
+For a larger example that motivated this work, see the case study in Appendix 1, where we have a large polynomial expression with many terms that
+we want to factor. That case study also demonstrates how we can convert from binary operations into containers. In Appendix 2,
+there are a few more examples that this approach could be applied to.
+
+## Takeaways
+
+Experimenting with using containers in this way explores how we can add more efficient representations in e-graphs
+to an existing system like Egglog, by using custom data structures.
+
+It's also interesting that these representations can not only be more efficient but also correspond more directly
+to the semantics of your use case, compared to, say, a tree of binary operations.
+
+This work also highlights some of the current limitations of Egglog.
+
+One issue is that composing functions of primitives is currently very limited. The only tool we have is currying, but
+it is not possible to reorder arguments or compose them in more complicated ways. This inevitably leads to
+creating more bespoke functions. For example, I had to add a `multiset_contains_swapped` function that swaps the argument order
+of the `contains` method, since I needed to partially apply it with the second argument. Further exploring this line of
+work might lead to trying out different ways of enriching primitive functions, possibly by allowing a way at runtime
+to create new ones by composing others, either through a DSL/JIT or a higher order composition approach like the
+[compiling to categories](http://conal.net/papers/compiling-to-categories/) work.
+
+Implementing these higher order functional primitives on containers is also challenging, due to the lack of built-in
+generic type support in Egglog.
Adding them currently is fiddly and requires careful thought over how to implement
+their generic types. Adding built-in support for generic types, both in primitives and in user code, could make this more
+scalable.
+
+Overall, I hope that this work shows that there is a design space here in Egglog to try out different
+ways of representing new normalized forms of different domains and then to design algorithms over them. As opposed to
+creating a whole new e-graph implementation, adding them as custom containers to Egglog supports reuse of the existing
+engineering work and compositionality within the ecosystem. I am left wondering how further improvements to Egglog
+could extend this kind of experimentation with how to efficiently represent different domains inside of e-graphs.
+
+## Appendix 1: Case Study from a Cloth Simulation Workload
+
+Here we start with an expression from the paper ["Interactive design of periodic yarn-level cloth patterns"](https://www.semanticscholar.org/paper/Interactive-design-of-periodic-yarn-level-cloth-Leaf-Wu/6350d7feb2dfc37d434da2839eacd5e8b025edda),
+which is part of a larger program that does cloth simulation. It was recommended by my advisor, [Gilbert Bernstein](http://www.gilbertbernstein.com/),
+since we can use their reference implementation in Mathematica to verify that our implementation matches theirs.
+
+![meme from tim and eric TV show with someone miming their mind being blown, with the text "yarn = polynomials" imposed](./2026_02_yarn-polynomials.gif)
+
+*Note that all code for this case study is reproducible in [this notebook](https://github.com/egraphs-good/egglog-python/blob/270a1876b6dbea37e441c132adbfdc8c11cbb319/docs/explanation/2026_02_containers_code.ipynb).*
+*It is currently based on a branch of the Python bindings and Rust source that adds additional multiset operations.*
+For this docs version, the notebook content is reproduced later in this page in a folded appendix block.
+
+We define a function to produce the amount of bending for a certain point over the [Python Array API Specification](https://data-apis.org/array-api/latest/API_specification/),
+so that it works on both concrete NumPy arrays and symbolic arrays. It takes in a number of 1D arrays and returns a 0D array.
+The details of what each argument represents and the underlying semantics are not important to this example; our main objective
+is for a graphics researcher to be able to prototype a function like this. Then we would want our system
+to "optimize" it in some way, before compiling it to something like CUDA to run on a GPU:
+
+```{code-cell} python
+def bending_function(Q, Bp, Bpp):
+    xp = Q.__array_namespace__()
+    QM = xp.reshape(Q, (4, 3)).T
+
+    yip = xp.vecdot(QM, Bp)
+    yipp = xp.vecdot(QM, Bpp)
+    num = xp.linalg.vector_norm(xp.cross(yip, yipp))
+    den = xp.linalg.vector_norm(yip) ** 3
+    return (num / den) ** 2
+```
+
+We symbolically evaluate the result of this function by using [an implementation of the Array API written in Egglog](https://github.com/egraphs-good/egglog-python/blob/cb263b163150181d164db25fbbac6e8a1e2da719/python/egglog/exp/array_api.py):
+
+```{code-cell} python
+import egglog
+import egglog.exp.array_api as enp
+
+Bp = enp.NDArray([enp.Value.var(f"bp{i}") for i in range(1, 5)])
+Bpp = enp.NDArray([enp.Value.var(f"bpp{i}") for i in range(1, 5)])
+Q = enp.NDArray([enp.Value.var(f"q{i}") for i in range(1, 13)])
+FunctionBending = enp.NDArray(bending_function(Q, Bp, Bpp).eval())
+FunctionBending
+```
+
+We can also compute its gradient with respect to `Q`, to give us an even larger expression:
+
+```{code-cell} python
+GradientBending = enp.NDArray(FunctionBending.diff(Q).eval())
+```
+
+Calling `eval` here will create the necessary e-graph with rewrites, add the expression, and reduce it to a simplified
+form that only contains a rational expression with polynomial subexpressions.
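As a standalone sketch of why full distribution makes expressions large (plain Python, unrelated to the egglog code above): distributing a product of n two-term sums picks one term from each factor, giving 2**n monomials.

```python
from itertools import product

# Distributing (x1 + y1) * (x2 + y2) * ... * (xn + yn) picks one term
# from each factor, so the fully distributed form has 2**n monomials.
def distribute(factors):
    return [tuple(choice) for choice in product(*factors)]

factors = [(f"x{i}", f"y{i}") for i in range(1, 11)]
print(len(distribute(factors)))  # 10 factors -> 2**10 = 1024 monomials
```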
+
+For the sake of this example, let's first fully "distribute" the polynomial we have. This means expanding it into a normal
+form by applying the distributivity rule, so that `a(x + y)` becomes `ax + ay`. This is meant to simulate a worst-case scenario,
+since the cost increases as we distribute, duplicating terms. This suffices to give us a large enough example to stress test our system:
+
+```{code-cell} python
+@egglog.ruleset
+def remove_subtraction(a: enp.Value, b: enp.Value):
+    yield egglog.rewrite(a - b, subsume=True).to(a + (-1) * b)
+
+
+@egglog.ruleset
+def distribute(a: enp.Value, b: enp.Value, c: enp.Value):
+    yield egglog.rewrite((a + b) * c, subsume=True).to(a * c + b * c)
+    yield egglog.rewrite(c * (a + b), subsume=True).to(c * a + c * b)
+```
+
+```{code-cell} python
+:tags: [hide-output]
+egraph = egglog.EGraph()
+egraph.register(FunctionBending)
+egraph.run(remove_subtraction.saturate() + distribute.saturate())
+FunctionBending_distributed = egraph.extract(FunctionBending)
+
+gradient_egraph = egglog.EGraph()
+gradient_egraph.register(GradientBending)
+gradient_egraph.run(remove_subtraction.saturate() + distribute.saturate())
+GradientBending_distributed = gradient_egraph.extract(GradientBending)
+FunctionBending_distributed
+```
+
+We now have an expression that is mainly a sum of products, a multivariate polynomial.
+
+For some sense of their size, `FunctionBending` has an initial cost of 401 and `GradientBending` has 20,570.
+This cost is produced by the Egglog extractor, corresponding roughly to one unit per op like `*` and one per variable, counted as a tree.
+It is meant to reflect roughly the cost to compute the expression, so an expression with a lower cost should run faster.
+
+One way to lower the cost of a polynomial is to factor it, so that `ax + ay` becomes `a(x + y)`. This is the same as applying the distributivity rule in reverse.
+There are, however, many equivalent factorizations, and some may be better than others.
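As a small illustration of why factorizations differ in cost, here is a sketch of the tree-cost model described above (one unit per operator node and one per variable leaf), applied to three equivalent forms of `ax + ay + bx + by`. The tuple encoding and `cost` function are our own toy model, not the Egglog extractor:

```python
# Expressions are nested tuples: ("+"/"*", left, right); strings are leaves.
def cost(expr):
    if isinstance(expr, str):
        return 1  # one unit per variable leaf
    _op, left, right = expr
    return 1 + cost(left) + cost(right)  # one unit per operator node

distributed = ("+", ("+", ("*", "a", "x"), ("*", "a", "y")),
                    ("+", ("*", "b", "x"), ("*", "b", "y")))
partially_factored = ("+", ("*", "a", ("+", "x", "y")),
                           ("*", "b", ("+", "x", "y")))
fully_factored = ("*", ("+", "a", "b"), ("+", "x", "y"))
print(cost(distributed), cost(partially_factored), cost(fully_factored))
# -> 15 11 7
```

The fully factored form is the cheapest, but nothing in the distributed form alone tells us which of the many equivalent factorizations to aim for.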
+One way to use Egglog to explore this space is to add the associativity, commutativity, and distributivity rules, run them until saturation, and extract
+the lowest-cost expression:
+
+```{code-cell} python
+:tags: [hide-output]
+
+@egglog.ruleset
+def factoring(a: enp.Value, b: enp.Value, c: enp.Value):
+    yield egglog.birewrite((a + b) * c).to(a * c + b * c)
+    yield egglog.rewrite(a * b).to(b * a)
+    yield egglog.rewrite(a + b).to(b + a)
+    yield egglog.birewrite(a * (b * c)).to((a * b) * c)
+
+egraph.run(factoring.saturate())
+egraph.extract(FunctionBending_distributed)
+```
+
+For the `FunctionBending` example, this works fine, taking about a tenth of a second to saturate and then extract out the smallest expression.
+
+However, if we use `GradientBending`, each iteration will take longer and longer. Cutting it off after 10 seconds per iteration,
+we get through only three iterations and the e-graph is not saturated. It will have decreased the cost to 2,126,268 from the original
+4,250,786. However, it will also have increased the number of nodes in the e-graph from the original 588,125 to
+2,583,064. This blow-up is due to the evaluation of the associativity and commutativity rules.
+
+The gradient is also only the first derivative of the bending function. In the real workload from the paper,
+we also need to compute the second derivative and ideally consider it as part of a larger expression.
+So, at least if we fully distribute first, trying to naively explore the entire search space of factorizations through
+associativity, commutativity, and distributivity rules is not really feasible for this type of expression.
+
+### Representing Polynomials with Multisets
+
+Taking a step back, the main space we want to explore here is the different options for factoring the expression.
We
+don't really care which expression gets picked due to associativity or commutativity, since the cost will be the same (at this
+point we are not considering common subexpression elimination, and constant folding doesn't apply in this example).
+We add those rules only so that we can explore the space of factorizations through the distributivity rule.
+
+So what if instead we chose to represent a polynomial such that the form is agnostic to ordering and association?
+To represent just a product of values, we need a single multiset, storing the exponent of each term as the number of times
+that expression shows up in the product. For example, the expression `a * b * b` would be represented as the multiset `{a: 1, b: 2}`.
+To represent a sum of products (aka a polynomial), we need a multiset of multisets, where each inner multiset is a monomial and the outer multiset is the sum of these monomials, with the counts holding the coefficient of each.
+For example, the expression `2 * a * b + 3 * a**2` would be represented as the multiset of multisets `{ {a: 1, b: 1}: 2, {a: 2}: 3}`.
+
+We can add a new function to construct values from this representation:
+
+```python
+@function
+def polynomial(x: MultiSet[MultiSet[Value]]) -> Value: ...
+```
+
+Our first task then is to translate between our binary operations and this multiset form.
+The first couple of rules are relatively straightforward, just converting addition, multiplication, and exponentiation to the
+corresponding forms, along with saving some analyses on terms that we will use later. This is a one-way translation, so
+we can also delete the source terms once we match them, so that extraction doesn't consider them.
We create a ruleset to do this translation: + +```python +@function(merge=lambda old, new: new) +def get_monomial(x: Value) -> MultiSet[Value]: + """ + Will be defined on all polynomials with exactly one monomial created in `to_polynomial_ruleset`: + + get_monomial(polynomial(MultiSet(xs))) => xs + """ + + +@function(merge=lambda old, new: new) +def get_sole_polynomial(xs: MultiSet[Value]) -> MultiSet[MultiSet[Value]]: + """ + Will be defined on all monomials that contain a single polynomial created in `to_polynomial_ruleset`: + + get_sole_polynomial(MultiSet(polynomial(xss))) => xss + """ + +@ruleset +def to_polynomial_ruleset( + n1: Value, + n2: Value, + n3: Value, + i: i64, + ms: MultiSet[Value], + mss: MultiSet[MultiSet[Value]], + mss1: MultiSet[MultiSet[Value]], +): + yield rule( + eq(n3).to(n1 + n2), + eq(mss).to(MultiSet(MultiSet(n1), MultiSet(n2))), + name="add", + ).then( + union(n3).with_(polynomial(mss)), + set_(get_sole_polynomial(MultiSet(polynomial(mss)))).to(mss), + delete(n1 + n2), + ) + yield rule( + eq(n3).to(n1 * n2), + eq(ms).to(MultiSet(n1, n2)), + name="mul", + ).then( + union(n3).with_(polynomial(MultiSet(ms))), + set_(get_monomial(polynomial(MultiSet(ms)))).to(ms), + delete(n1 * n2), + ) + yield rule( + eq(n3).to(n1**i), + i >= 0, + eq(ms).to(MultiSet.single(n1, i)), + name="pow", + ).then( + union(n3).with_(polynomial(MultiSet(ms))), + set_(get_monomial(polynomial(MultiSet(ms)))).to(ms), + delete(n1**i), + ) +``` + +When applying this ruleset we will replace binary operations with multiset values, but they will be unnecessarily +nested. For example, we might end up with a term like `polynomial(MultiSet(MultiSet(polynomial(xs))))`, which should be replaced +with just `polynomial(xs)`. 
We define two additional rules to cover cases like this:
+
+```python
+    yield rule(
+        eq(n1).to(polynomial(mss)),
+        # For each monomial, if any of its terms is a polynomial with a single monomial, flatten
+        # that into the monomial, otherwise keep it as is
+        mss1 == mss.map(partial(multiset_flat_map, get_monomial)),
+        mss != mss1,  # skip if this is a no-op
+        name="unwrap monomial",
+    ).then(
+        union(n1).with_(polynomial(mss1)),
+        delete(polynomial(mss)),
+        set_(get_sole_polynomial(MultiSet(polynomial(mss1)))).to(mss1),
+    )
+    yield rule(
+        eq(n1).to(polynomial(mss)),
+        # If any of the monomials just has a single item which is a polynomial, then flatten that into the outer polynomial
+        mss1 == multiset_flat_map(UnstableFn(get_sole_polynomial), mss),
+        mss != mss1,
+        name="unwrap polynomial",
+    ).then(
+        union(n1).with_(polynomial(mss1)),
+        delete(polynomial(mss)),
+        set_(get_sole_polynomial(MultiSet(polynomial(mss1)))).to(mss1),
+    )
+```
+
+We have avoided the need to match inside of containers by instead using higher order functions to apply blockwise
+operations that are executed in Rust during rule matching. We had to create the above analysis functions for the same reason:
+we cannot create functions whose implementation is deferred until a later rewrite; they must be available at match time.
+
+After running these rulesets, any subexpressions that contain only additions and multiplications will be turned into
+flattened multisets, which is what we wanted to do in this section.
+What's also nice here is that if there are any other operations defined, like `/`, this will work transparently
+with them, making this kind of analysis extensible as the system grows, since we only normalize polynomial subtrees.
The
+contents of their terms don't have to be limited to integers and variables:
+
+```{code-cell} python
+:tags: [remove-input]
+
+polynomial_egraph = egglog.EGraph()
+polynomial_egraph.register(FunctionBending_distributed)
+polynomial_egraph.run(enp.to_polynomial_ruleset.saturate())
+FunctionBending_polynomial_multisets, FunctionBending_polynomial_multisets_cost = polynomial_egraph.extract(
+    FunctionBending_distributed,
+    include_cost=True,
+)
+print(FunctionBending_polynomial_multisets_cost)
+FunctionBending_polynomial_multisets
+```
+
+### Greedy Multivariate Horner Factorization
+
+Now that we have our polynomial subterms represented as nested multisets, the next step is to see if we can find a form
+with lower cost. One way to do this with polynomials is to find a factorization that minimizes the number
+of multiplications. For univariate polynomials, there is an optimal algorithm called [Horner's method](https://en.wikipedia.org/wiki/Horner%27s_method).
+For multivariate polynomials, there isn't an efficient algorithm that is guaranteed to produce the optimal factoring, but
+there is [a greedy algorithm that will often produce a good one](https://www.semanticscholar.org/paper/Greedy-algorithms-for-optimizing-multivariate-Ceberio-Kreinovich/96103f6f48bd15d40de43a716922d1177b2b5ea2).
+
+So instead of considering all possible factorizations and then waiting until extraction to pick out the best one,
+we can try to implement this greedy algorithm directly. This is made easier by the fact that we have flattened the polynomial
+into a multiset of multisets, so we can analyze it holistically.
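For reference, the univariate case fits in a few lines: Horner's method rewrites `a0 + a1*x + a2*x**2 + a3*x**3` as `a0 + x*(a1 + x*(a2 + x*a3))`, using one multiplication per coefficient. A plain-Python sketch:

```python
# Evaluate a polynomial given as coefficients [a0, a1, a2, ...] in Horner form.
def horner_eval(coeffs, x):
    acc = 0
    for a in reversed(coeffs):  # fold from the highest degree down
        acc = acc * x + a
    return acc

# 1 + 2x + 3x**2 at x = 2 -> 1 + 4 + 12 = 17, using only 2 multiplications
print(horner_eval([1, 2, 3], 2))  # -> 17
```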
+
+To implement the greedy algorithm, we find the factor that shows up in the most monomials, then find the subset of monomials that contain it, take
+the intersection of all of those (to find the largest factor we can pull out of all of them), factor that out, and add the result to the remainder
+that didn't include that factor:
+
+```python
+@ruleset
+def factor_ruleset(
+    n: Value,
+    mss: MultiSet[MultiSet[Value]],
+    counts: MultiSet[Value],
+    picked_term: Value,
+    picked: MultiSet[MultiSet[Value]],
+    divided: MultiSet[MultiSet[Value]],
+    factor: MultiSet[Value],
+    remainder: MultiSet[MultiSet[Value]],
+):
+    yield rule(
+        eq(n).to(polynomial(mss)),
+        # Find the factor that shows up in the most monomials, at least two of them
+        counts == MultiSet.sum_multisets(mss.map(MultiSet.reset_counts)),
+        eq(picked_term).to(counts.pick_max()),  # on ties pick an arbitrary one
+        # Only factor out if it appears in more than one monomial
+        counts.count(picked_term) > 1,
+        # The factor we choose is the largest intersection between all the monomials that have that factored term
+        picked == mss.filter(partial(multiset_contains_swapped, picked_term)),
+        factor == multiset_fold(MultiSet.__and__, picked.pick(), picked),  # intersection
+        divided == picked.map(partial(multiset_subtract_swapped, factor)),
+        # remainder is those monomials that do not contain the factor
+        remainder == mss.filter(partial(multiset_not_contains_swapped, picked_term)),
+        name="factor",
+    ).then(
+        # factor * polynomial(divided) + remainder
+        union(n).with_(polynomial(MultiSet(factor.insert(polynomial(divided))) + remainder)),
+        delete(polynomial(mss)),
+    )
+```
+
+If we apply this, we now have a factored form! We can see that this uses a similar technique to the one above, where we use higher order functions
+to create a new polynomial based on the old one and replace it.
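The same greedy step can be sketched in plain Python, with `Counter`s standing in for the multisets (a toy model of the rule above, not the egglog implementation; coefficients are ignored for simplicity):

```python
from collections import Counter

# A monomial is a Counter of variable -> exponent; a polynomial is a
# list of monomials.
def factor_step(poly):
    # variable that appears in the most monomials (presence, not exponent)
    presence = Counter(v for m in poly for v in m)
    var, hits = presence.most_common(1)[0]  # ties picked arbitrarily
    if hits <= 1:
        return None  # nothing worth factoring out
    picked = [m for m in poly if var in m]
    remainder = [m for m in poly if var not in m]
    factor = picked[0].copy()
    for m in picked[1:]:
        factor &= m  # multiset intersection: largest common factor
    divided = [m - factor for m in picked]
    # result reads as: factor * sum(divided) + sum(remainder)
    return factor, divided, remainder

# a*x + a*y + b  ->  a * (x + y) + b
poly = [Counter({"a": 1, "x": 1}), Counter({"a": 1, "y": 1}), Counter({"b": 1})]
factor, divided, remainder = factor_step(poly)
print(factor)  # -> Counter({'a': 1})
```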
+
+```{code-cell} python
+:tags: [remove-input]
+
+polynomial_egraph.run(enp.factor_ruleset.saturate())
+FunctionBending_polynomial_multisets_factored, FunctionBending_polynomial_multisets_factored_cost = polynomial_egraph.extract(
+    FunctionBending_polynomial_multisets,
+    include_cost=True,
+)
+print(FunctionBending_polynomial_multisets_factored_cost)
+FunctionBending_polynomial_multisets_factored
+```
+
+We can then turn this multiset form back into one with binary operations, giving us an end-to-end way to factor polynomials
+in Egglog without exploring the full A/C/D space, reducing the size blow-up.
+
+For the smaller bending-function expression, this produces a result with the same cost as the full factorization
+and takes half the time. It also produces many fewer nodes: while the fully factored version has 13,040 nodes in the e-graph,
+this one only has 927, which is only slightly more than the original size after distributing (904 nodes).
+
+For the larger gradient expression, the difference is even starker. This approach is able to factor it to a cost of 79,974,
+whereas we stopped the full factorization after it reached 2,125,338. In terms of e-graph size, we end with 112,144 nodes,
+compared to the 2,582,934 of the full factorization.
+
+So overall, this example shows a way to do a directed factorization to reduce the total cost of an expression,
+without having to explore the full space of equivalent expressions. While this isolated use case might not be a good
+fit for e-graphs, inside of a larger optimization pipeline this shows how we can capture this type of optimization in
+a composable way.
+
+Moreover, it is an experiment in how we can build rules on top of containers that use higher order functions
+to do more complicated analyses, without leading to a node blow-up.
+
+## Appendix 2: Further Examples
+
+Above I presented a large example that comes from my current line of research.
However, there are many
+smaller examples that could also be used to explore the usefulness of this kind of technique. Due to time constraints
+I haven't explored these deeply, but I did want to mention them.
+
+[Yihong](https://effect.systems/) shared with me an example of how just having a simple associativity rule and a rule for multiplying by zero will lead
+to never saturating:
+
+```clojure
+(datatype Int (mul Int Int) (a) (zero))
+
+(birewrite (mul x (mul y z)) (mul (mul x y) z))
+(rewrite (mul (zero) x) (zero))
+
+(mul (zero) (a))
+(run-schedule (repeat 8 (run)))
+```
+
+Instead, if we represented this as a product of a multiset, we could simply have a rule that looks for a zero element
+in the multiset and replaces the whole product with zero. Then no associativity rule would be needed, and so there would be no chance for this to blow up. A `product(MultiSet(...))`
+operation handles associativity and commutativity by construction, and the rebuilding handles merges.
+
+When I asked on the e-graphs Zulip for more examples, Sophia B also [shared another example with me](https://egraphs.zulipchat.com/#narrow/channel/328972-general/topic/A.2FC.20Blowup.20Example/near/573091425). If you have the rule `f(a + b) + 1 = f(a) + f(b)` plus A/C, you can derive equalities like `f(x) + (f(y) + f(z)) = f(x + (y + z)) + 2`, but deriving them can take a large number of nodes.
+In this system, we would instead encode that rule over multisets and add constant propagation to the sum function, to see whether it could be found more directly through normalization.
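Returning to the first example, the zero-absorption rule over a product multiset can be sketched with a `Counter` standing in for the multiset (plain Python, not egglog syntax): a membership test replaces all the associativity shuffling.

```python
from collections import Counter

# A product is a multiset of factors; a 0 anywhere absorbs the whole product.
def simplify_product(factors: Counter) -> Counter:
    if factors[0] > 0:  # constant-time membership test, no reassociation needed
        return Counter({0: 1})
    return factors

print(simplify_product(Counter([0, "a", "b"])))  # -> Counter({0: 1})
```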
+ +*I [used an LLM](https://chatgpt.com/share/69969a20-26d4-8011-879a-62a04adfed31) to get feedback on the draft and revised based on its suggestions to improve readability, organization, and consistency.* +*Thank you Oliver, Yihong, Gilbert and Alexandra also who gave feedback to me throughout this process and while drafting this post.* diff --git a/docs/explanation/2026_02_yarn-polynomials.gif b/docs/explanation/2026_02_yarn-polynomials.gif new file mode 100644 index 00000000..838060c7 Binary files /dev/null and b/docs/explanation/2026_02_yarn-polynomials.gif differ diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md index 3e2abe89..91862f77 100644 --- a/docs/how-to-guides.md +++ b/docs/how-to-guides.md @@ -4,18 +4,8 @@ file_format: mystnb # How-to guides -## Parsing and running program strings - -You can provide your program in a special DSL language. You can parse this with {meth}`egglog.bindings.EGraph.parse_program` and then run the result with You can parse this with {meth}`egglog.bindings.EGraph.run_program`:: - -```{code-cell} -from egglog.bindings import EGraph - -egraph = EGraph() -commands = egraph.parse_program("(check (= (+ 1 2) 3))") -commands -``` - -```{code-cell} -egraph.run_program(*commands) +```{toctree} +:maxdepth: 1 +how-to-guides/parsing-and-running-program-strings +how-to-guides/tracing ``` diff --git a/docs/how-to-guides/parsing-and-running-program-strings.md b/docs/how-to-guides/parsing-and-running-program-strings.md new file mode 100644 index 00000000..a9d2c5d0 --- /dev/null +++ b/docs/how-to-guides/parsing-and-running-program-strings.md @@ -0,0 +1,21 @@ +--- +file_format: mystnb +--- + +# Parsing and running program strings + +You can provide your program in a special DSL language. 
Parse it with +{meth}`egglog.bindings.EGraph.parse_program` and run the resulting commands with +{meth}`egglog.bindings.EGraph.run_program`: + +```{code-cell} +from egglog.bindings import EGraph + +egraph = EGraph() +commands = egraph.parse_program("(check (= (+ 1 2) 3))") +commands +``` + +```{code-cell} +egraph.run_program(*commands) +``` diff --git a/docs/how-to-guides/tracing.md b/docs/how-to-guides/tracing.md new file mode 100644 index 00000000..d6468349 --- /dev/null +++ b/docs/how-to-guides/tracing.md @@ -0,0 +1,93 @@ +# Tracing + +`egglog` can emit OpenTelemetry spans from both the high-level Python wrapper and the Rust bindings. +The Python package stays library-style: it only depends on `opentelemetry-api`, and it starts emitting Python spans +once your application configures an OpenTelemetry tracer provider. + +The Rust side uses the current Python trace context when one exists, so Rust spans can appear under the same parent +trace. To export Rust spans, call `egglog.bindings.setup_tracing(...)` before the traced Rust calls: + +- `exporter="console"` writes Rust spans to stdout. +- `exporter="http"` sends Rust spans to an OTLP/HTTP endpoint. + +For the contributor-oriented pytest workflow, see {doc}`../reference/contributing`. + +## Trace A Host Application + +This example configures Python tracing in an application that happens to call into `egglog`. +The Python spans come from the configured tracer provider, and the Rust spans join the same trace because the +bindings propagate the current `traceparent` and `tracestate` into the Rust tracing layer. 
+ +```python +from opentelemetry import trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor + +from egglog import EGraph, bindings, i64 + +provider = TracerProvider(resource=Resource.create({"service.name": "demo-app"})) +provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) +trace.set_tracer_provider(provider) +bindings.setup_tracing(exporter="console") + +tracer = trace.get_tracer(__name__) + +with tracer.start_as_current_span("optimize"): + EGraph().extract(i64(0)) + +bindings.shutdown_tracing() +provider.shutdown() +``` + +In that setup: + +- Python spans use the module tracer names such as `egglog.egraph` and `egglog.egraph_state`. +- Python span names are the public method names such as `create`, `push`, `pop`, `register`, `run`, and `extract`, plus `run_schedule_to_egg` while schedules are lowered. +- Rust spans use names such as `bindings.run_program`, `bindings.serialize`, and `bindings.extractor.extract_best`. + +If you call the low-level bindings directly, pass `traceparent=` and `tracestate=` yourself on the traced methods. +The high-level Python API does that automatically. 
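
As a rough sketch of what that looks like, the W3C `traceparent` header for the current span is a fixed-width hex string; the `format_traceparent` helper below is hypothetical (not part of `egglog`), but it mirrors the f-string the high-level wrapper builds internally, and the example ids are taken from the W3C Trace Context specification:

```python
# Hypothetical helper: format a W3C Trace Context `traceparent` header by hand.
# This mirrors what the high-level wrapper builds before calling into Rust.
def format_traceparent(trace_id: int, span_id: int, flags: int) -> str:
    # version "00", 128-bit trace id, 64-bit span id, 8-bit flags, all lowercase hex
    return f"00-{trace_id:032x}-{span_id:016x}-{flags:02x}"


# Example ids from the W3C Trace Context specification.
header = format_traceparent(0x4BF92F3577B34DA6A3CE929D0E0E4736, 0x00F067AA0BA902B7, 0x01)
print(header)  # 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01
```

You could then pass the resulting string as `traceparent=` to a low-level call such as `egraph.run_program(*commands, traceparent=header)`.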
+ +## Send Traces To Jaeger + +The official Jaeger getting-started docs use this container: + +```bash +docker run --rm \ + --name jaeger \ + -p 16686:16686 \ + -p 4317:4317 \ + -p 4318:4318 \ + -p 5778:5778 \ + -p 9411:9411 \ + cr.jaegertracing.io/jaegertracing/jaeger:2.16.0 +``` + +Point both Python and Rust tracing at Jaeger over OTLP/HTTP: + +```python +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +from egglog import bindings + +provider = TracerProvider(resource=Resource.create({"service.name": "demo-app"})) +provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint="http://127.0.0.1:4318/v1/traces"))) +trace.set_tracer_provider(provider) +bindings.setup_tracing(exporter="http", endpoint="http://127.0.0.1:4318/v1/traces") +``` + +After that, open [http://localhost:16686](http://localhost:16686) and search for traces from `demo-app` and `egglog`. + +## Local Test Runs + +If you want the same tracing setup during `pytest`, use the built-in test flags documented in +{doc}`../reference/contributing`. + +When using `--otel-traces=console` under `pytest`, pass `-s` so the console exporter output is shown as the test runs. +Console mode is best for short, targeted runs because it is intentionally verbose. For longer or hotter tests, prefer +OTLP/Jaeger tracing. diff --git a/docs/reference/contributing.md b/docs/reference/contributing.md index 5ba0125c..971029bb 100644 --- a/docs/reference/contributing.md +++ b/docs/reference/contributing.md @@ -86,6 +86,39 @@ If there is a performance sensitive piece of code, you could isolate it in a fil uv run py-spy record --format speedscope -- python -O tmp.py ``` +### Tracing + +`pytest` can also configure OpenTelemetry tracing for local debugging. 
For the full host-application setup and the +Jaeger startup command, see {doc}`../how-to-guides/tracing`. The pytest plugin configures both the Python tracer +provider and `egglog.bindings.setup_tracing(...)` for you. + +To print both Python and Rust spans to the console during a test run: + +```bash +uv run pytest python/tests/test_tracing.py --benchmark-disable -q -s --otel-traces=console +``` + +Console mode is intentionally verbose. It works best for short runs or a single targeted test. + +For a targeted test, pass the same flag to the test you are debugging: + +```bash +uv run pytest python/tests/test_array_api.py::test_jit[lda] -vv --benchmark-disable -s --otel-traces=console +``` + +To send spans to Jaeger over OTLP/HTTP, start Jaeger as shown in {doc}`../how-to-guides/tracing`, then run pytest +with the Jaeger tracing mode: + +```bash +uv run pytest python/tests/test_array_api.py::test_jit[lda] -vv --benchmark-disable --otel-traces=jaeger +``` + +Then open [http://localhost:16686](http://localhost:16686). + +For a longer-running or performance-sensitive test, prefer `--otel-traces=jaeger` over console mode. + +If you need a non-default OTLP endpoint, add `--otel-otlp-endpoint=http://host:4318/v1/traces`. + ### Making changes All changes that impact users should be documented in the `docs/changelog.md` file. Please also add tests for any new features diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index a472d916..390af8b0 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -211,6 +211,15 @@ The correct function type (in this case it would be `Callable[[i64, i64], Unit]` To run actions in Python, they are passed as arguments to the `egraph.register` function. We have constructors to create each kind of action. 
They are created and registered in this way, so that we can use the same syntax for executing them on the top level egraph as we do for defining them as results for rules. +You can also pass initial actions directly to the high-level constructor: + +```{code-cell} python +egraph = EGraph( + let("x", i64(1)), + set_(fib(0)).to(i64(0)), +) +``` + Here are examples of all the actions: ### Let @@ -526,6 +535,10 @@ The `(check ...)` command to verify that some facts are true, can be translated egraph.check(eq(fib(1)).to(i64(1))) ``` +Low-level proof-mode commands such as `Prove`, `ProveExists`, and +`ProveExistsOutput` are exposed through the bindings layer, but the high-level +Python API does not yet support a complete proof workflow. + ## Extract The `(extract ...)` command in egglog translates to the `egraph.extract` method, returning the lowest cost expression: diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 2789a51d..23f52d8f 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -652,33 +652,107 @@ egraph.check(eq(x).to(WrappedMath(math_float(3.14)) + WrappedMath(math_float(3.1 egraph ``` -## Visualization +## Debugging and Inspection -The default renderer for the e-graph in a Jupyter Notebook [an interactive Javascript visualizer](https://github.com/egraphs-good/egraph-visualizer): +When a rule does not fire or an equality appears unexpectedly, the most useful +high-level inspection methods are `run`, `stats`, `function_values`, `freeze`, +`display`, and `saturate`. ```{code-cell} python -egraph +from __future__ import annotations + +from egglog import * + + +class DebugMath(Expr): + def __init__(self, value: i64Like) -> None: ... + + def __add__(self, other: DebugMath) -> DebugMath: ... + + +@function +def score(x: DebugMath) -> i64: ... 
+ + +debug_rules = ruleset() + + +@debug_rules.register +def _(i: i64, j: i64): + yield rewrite(DebugMath(i) + DebugMath(j)).to(DebugMath(i + j)) +``` + +### `run` + +Use {meth}`egglog.egraph.EGraph.run` to execute a schedule and inspect the +`RunReport` for per-run counters and timings: + +```{code-cell} python +egraph = EGraph() +expr = egraph.let("expr", DebugMath(2) + DebugMath(3)) +egraph.register(set_(score(expr)).to(5)) + +report = egraph.run(debug_rules) +report.num_matches_per_rule +``` + +### `stats` + +Use {meth}`egglog.egraph.EGraph.stats` when you want cumulative counters for the +current e-graph instead of only the most recent run: + +```{code-cell} python +stats = egraph.stats() +stats.num_matches_per_rule +``` + +### `function_values` + +Use {meth}`egglog.egraph.EGraph.function_values` to inspect the current rows in a +function table: + +```{code-cell} python +egraph.function_values(score) +``` + +### `freeze` + +Use {meth}`egglog.egraph.EGraph.freeze` to snapshot the current state into a +replayable high-level program: + +```{code-cell} python +frozen = egraph.freeze() +str(frozen) ``` -You can also customize the visualization through using the method: +### `display` + +In Jupyter, the default rich display for an e-graph is the interactive +[egraph visualizer](https://github.com/egraphs-good/egraph-visualizer). 
You can +also call {meth}`egglog.egraph.EGraph.display` directly: ```{code-cell} python egraph.display() ``` -If you would like to visualize the progression of the e-graph over time, you can use the method to -run a number of iterations and then visualize the e-graph at each step: +### `saturate` + +Use {meth}`egglog.egraph.EGraph.saturate` to keep running until the schedule +stops changing the graph while printing the extracted form after each step: ```{code-cell} python egraph = EGraph() -egraph.register(Math(2) + Math(100)) -i, j = vars_("i j", i64) -r = ruleset( - rewrite(Math(i) + Math(j)).to(Math(i + j)), -) -egraph.saturate(r) +expr = egraph.let("expr", DebugMath(2) + DebugMath(100)) +egraph.saturate(debug_rules, expr=expr, max=2, visualize=False) ``` +Common pitfalls when authoring rules: + +- Primitive container sorts like `Vec[...]` should not be merged or unioned. +- Guard vector indexing rules with bounds checks (`0 <= k < vs.length()`). +- Ensure rules that subtract from lengths only fire when the length is proven + positive. 
+ ## Custom Cost Models By default, when extracting from the e-graph, we use a simple cost model, that looks at the costs assigned to each diff --git a/pyproject.toml b/pyproject.toml index 42a608a4..6a8b3017 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Topic :: Software Development :: Interpreters", "Typing :: Typed", ] -dependencies = ["typing-extensions", "black", "graphviz", "anywidget", "cloudpickle>=3"] +dependencies = ["typing-extensions", "black", "graphviz", "anywidget", "cloudpickle>=3", "opentelemetry-api"] [project.optional-dependencies] @@ -52,6 +52,8 @@ test = [ "mypy", "syrupy>=5", "egglog[array]", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-http", "pytest-codspeed", "pytest-benchmark", "pytest-xdist" @@ -145,8 +147,6 @@ ignore = [ "ANN201", # Allow uppercase args "N803", - # allow generic df name - "PD901", # Allow future anywhere in file "F404", # allow imports anywhere in cell @@ -205,6 +205,10 @@ ignore = [ "TC003", # allow eq without hash "PLW1641", + # allow non module file for docs + "INP001", + # don't replace lambdas with functions because we need them to defer resolution + "PLW0108", ] select = ["ALL"] @@ -281,4 +285,5 @@ members = ["egglog"] [dependency-groups] dev = [ "py-spy>=0.4.1", + "sympy>=1.14.0", ] diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index 7d20dfdb..27b2bc61 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -8,6 +8,7 @@ from .conversion import * from .deconstruct import * from .egraph import * +from .egraph import ActionLike as ActionLike from .runtime import define_expr_method as define_expr_method del ipython_magic diff --git a/python/egglog/_tracing.py b/python/egglog/_tracing.py new file mode 100644 index 00000000..4af21f31 --- /dev/null +++ b/python/egglog/_tracing.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TypeVar + +from opentelemetry 
import trace + +_R = TypeVar("_R") + + +def call_with_current_trace(fn: Callable[..., _R], /, *args: Any, **kwargs: Any) -> _R: + span_context = trace.get_current_span().get_span_context() + if not span_context.is_valid: + return fn(*args, **kwargs) + + trace_kwargs = { + "traceparent": ( + f"00-{span_context.trace_id:032x}-{span_context.span_id:016x}-{int(span_context.trace_flags):02x}" + ) + } + tracestate = span_context.trace_state.to_header() + if tracestate: + trace_kwargs["tracestate"] = tracestate + return fn(*args, **kwargs, **trace_kwargs) diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index 7e585dab..6c05af47 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -2,7 +2,7 @@ from collections.abc import Callable from datetime import timedelta from fractions import Fraction from pathlib import Path -from typing import Any, Generic, Protocol, TypeAlias, TypeVar, final +from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar, final __all__ = [ "ActionCommand", @@ -31,6 +31,9 @@ __all__ = [ "Fact", "Fail", "Float", + "FrozenEGraph", + "FrozenFunction", + "FrozenRow", "Function", "FunctionCommand", "FusedIntersect", @@ -55,6 +58,9 @@ __all__ = [ "PrintFunctionSize", "PrintOverallStatistics", "PrintSize", + "Prove", + "ProveExists", + "ProveExistsOutput", "Push", "Relation", "Repeat", @@ -99,8 +105,13 @@ __all__ = [ "Var", "Variant", "WithPlan", + "setup_tracing", + "shutdown_tracing", ] +def setup_tracing(*, exporter: Literal["console", "http"], endpoint: str | None = None) -> None: ... +def shutdown_tracing() -> None: ... + @final class SerializedEGraph: @property @@ -121,7 +132,9 @@ class EGraph: ) -> EGraph: ... def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ... def commands(self) -> str | None: ... - def run_program(self, *commands: _Command) -> list[_CommandOutput]: ... 
+ def run_program( + self, *commands: _Command, traceparent: str | None = None, tracestate: str | None = None + ) -> list[_CommandOutput]: ... def serialize( self, root_eclasses: list[_Expr], @@ -129,10 +142,14 @@ class EGraph: max_functions: int | None = None, max_calls_per_function: int | None = None, include_temporary_functions: bool = False, + traceparent: str | None = None, + tracestate: str | None = None, ) -> SerializedEGraph: ... def set_report_level(self, level: _ReportLevel) -> None: ... def lookup_function(self, name: str, key: list[Value]) -> Value | None: ... - def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ... + def eval_expr( + self, expr: _Expr, *, traceparent: str | None = None, tracestate: str | None = None + ) -> tuple[str, Value]: ... def value_to_i64(self, v: Value) -> int: ... def value_to_f64(self, v: Value) -> float: ... def value_to_string(self, v: Value) -> str: ... @@ -147,6 +164,7 @@ class EGraph: def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ... def value_to_set(self, v: Value) -> set[Value]: ... # def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ... + def freeze(self) -> FrozenEGraph: ... @final class Value: @@ -268,6 +286,7 @@ class TermApp: def __new__(cls, name: str, args: list[int]) -> TermApp: ... _Term: TypeAlias = TermLit | TermVar | TermApp +_TermId: TypeAlias = int ## # Facts @@ -530,15 +549,20 @@ class PrintAllFunctionsSize: @final class ExtractVariants: termdag: TermDag - terms: list[_Term] - def __new__(cls, termdag: TermDag, terms: list[_Term]) -> ExtractVariants: ... + terms: list[_TermId] + def __new__(cls, termdag: TermDag, terms: list[_TermId]) -> ExtractVariants: ... @final class ExtractBest: termdag: TermDag cost: int - term: _Term - def __new__(cls, termdag: TermDag, cost: int, term: _Term) -> ExtractBest: ... + term: _TermId + def __new__(cls, termdag: TermDag, cost: int, term: _TermId) -> ExtractBest: ... 
+ +@final +class ProveExistsOutput: + proof: str + def __new__(cls, proof: str) -> ProveExistsOutput: ... @final class OverallStatistics: @@ -554,10 +578,10 @@ class RunScheduleOutput: class PrintFunctionOutput: function: Function termdag: TermDag - terms: list[tuple[_Term, _Term]] + terms: list[tuple[_TermId, _TermId]] mode: _PrintFunctionMode def __new__( - cls, function: Function, termdag: TermDag, terms: list[tuple[_Term, _Term]], mode: _PrintFunctionMode + cls, function: Function, termdag: TermDag, terms: list[tuple[_TermId, _TermId]], mode: _PrintFunctionMode ) -> PrintFunctionOutput: ... @final @@ -570,6 +594,7 @@ _CommandOutput: TypeAlias = ( | PrintAllFunctionsSize | ExtractVariants | ExtractBest + | ProveExistsOutput | OverallStatistics | RunScheduleOutput | PrintFunctionOutput @@ -717,6 +742,18 @@ class Check: facts: list[_Fact] def __new__(cls, span: _Span, facts: list[_Fact]) -> Check: ... +@final +class Prove: + span: _Span + facts: list[_Fact] + def __new__(cls, span: _Span, facts: list[_Fact]) -> Prove: ... + +@final +class ProveExists: + span: _Span + expr: str + def __new__(cls, span: _Span, expr: str) -> ProveExists: ... + @final class PrintFunction: span: _Span @@ -822,6 +859,8 @@ _Command: TypeAlias = ( | RunSchedule | Extract | Check + | Prove + | ProveExists | PrintFunction | PrintSize | Output @@ -844,14 +883,14 @@ _Command: TypeAlias = ( @final class TermDag: def size(self) -> int: ... - def lookup(self, node: _Term) -> int: ... - def get(self, id: int) -> _Term: ... - def app(self, sym: str, children: list[int]) -> _Term: ... - def lit(self, lit: _Literal) -> _Term: ... - def var(self, sym: str) -> _Term: ... - def expr_to_term(self, expr: _Expr) -> _Term: ... - def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ... - def to_string(self, term: _Term) -> str: ... + def lookup(self, node: _Term) -> _TermId: ... + def get(self, id: _TermId) -> _Term: ... + def app(self, sym: str, children: list[_TermId]) -> _TermId: ... 
+ def lit(self, lit: _Literal) -> _TermId: ... + def var(self, sym: str) -> _TermId: ... + def expr_to_term(self, expr: _Expr) -> _TermId: ... + def term_to_expr(self, term: _TermId, span: _Span) -> _Expr: ... + def to_string(self, term: _TermId) -> str: ... ## # Extraction @@ -879,9 +918,53 @@ class CostModel(Generic[_COST, _ENODE_COST]): @final class Extractor(Generic[_COST]): def __new__( - cls, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any] + cls, + rootsorts: list[str] | None, + egraph: EGraph, + cost_model: CostModel[_COST, Any], + *, + traceparent: str | None = None, + tracestate: str | None = None, ) -> Extractor[_COST]: ... - def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ... + def extract_best( + self, + egraph: EGraph, + termdag: TermDag, + value: Value, + sort: str, + *, + traceparent: str | None = None, + tracestate: str | None = None, + ) -> tuple[_COST, _TermId]: ... def extract_variants( - self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str - ) -> list[tuple[_COST, _Term]]: ... + self, + egraph: EGraph, + termdag: TermDag, + value: Value, + nvariants: int, + sort: str, + *, + traceparent: str | None = None, + tracestate: str | None = None, + ) -> list[tuple[_COST, _TermId]]: ... 
+ +## +# Frozen +## + +@final +class FrozenEGraph: + functions: dict[str, FrozenFunction] + +@final +class FrozenFunction: + input_sorts: list[str] + output_sort: str + is_let_binding: bool + rows: list[FrozenRow] + +@final +class FrozenRow: + subsumed: bool + inputs: list[Value] + output: Value diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index a9c75f69..5f11934b 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -56,6 +56,12 @@ "i64", "i64Like", "join", + "multiset_contains_swapped", + "multiset_flat_map", + "multiset_fold", + "multiset_not_contains_swapped", + "multiset_remove_swapped", + "multiset_subtract_swapped", "py_eval", "py_eval_fn", "py_exec", @@ -77,7 +83,7 @@ def __str__(self) -> str: class Unit(BuiltinExpr, egg_sort="Unit"): """ - The unit type. This is used to reprsent if a value exists in the e-graph or not. + The unit type. This is used to represent if a value exists in the e-graph or not. """ def __init__(self) -> None: ... @@ -111,6 +117,9 @@ def replace(self, old: StringLike, new: StringLike) -> String: ... def __add__(self, other: StringLike) -> String: return join(self, other) + @method(egg_fn="log") + def log(self) -> Unit: ... + StringLike: TypeAlias = String | str @@ -281,8 +290,14 @@ def bool_le(self, other: i64Like) -> Bool: ... @method(egg_fn="bool->=") def bool_ge(self, other: i64Like) -> Bool: ... + @method(egg_fn="abs") + def __abs__(self) -> i64: ... + + @method(egg_fn="vec-range") + def range(self) -> Vec[i64]: ... + -# The types which can be convertered into an i64 +# The types which can be converted into an i64 i64Like: TypeAlias = i64 | int # noqa: N816, PYI042 converter(int, i64, i64) @@ -348,6 +363,9 @@ def __rtruediv__(self, other: f64Like) -> f64: ... def __rmod__(self, other: f64Like) -> f64: ... + @method(egg_fn="abs") + def __abs__(self) -> f64: ... + @method(egg_fn="<") def __lt__(self, other: f64Like) -> Unit: # type: ignore[has-type] ... 
@@ -399,7 +417,7 @@ def eval(self) -> dict[T, V]: @property def value(self) -> dict[T, V]: d = {} - while args := get_callable_args(self, Map[T, V].insert): + while args := get_callable_args(self, Map.insert): # type: ignore[var-annotated] self, k, v = args # noqa: PLW0642 d[k] = v if get_callable_args(self, Map.empty) is None: @@ -521,9 +539,7 @@ def rebuild(self) -> Set[T]: ... converter( set, Set, - lambda t: Set[get_type_args()[0]]( # type: ignore[misc,operator] - *(convert(x, get_type_args()[0]) for x in t) - ), + lambda t: Set(*(convert(x, get_type_args()[0]) for x in t)) if t else Set[get_type_args()[0]].empty(), # type: ignore[misc] ) SetLike: TypeAlias = Set[T] | set[TO] @@ -559,6 +575,17 @@ def __contains__(self, key: T) -> bool: @method(egg_fn="multiset-of") def __init__(self, *args: T) -> None: ... + @method(egg_fn="multiset-intersection") + def __and__(self, other: MultiSet[T]) -> MultiSet[T]: ... + + @method(egg_fn="multiset-single") + @classmethod + def single(cls, x: T, i: i64Like) -> MultiSet[T]: ... + + @method(egg_fn="multiset-sum-multisets") + @classmethod + def sum_multisets(cls, xs: MultiSet[MultiSet[T]]) -> MultiSet[T]: ... + @method(egg_fn="multiset-insert") def insert(self, value: T) -> MultiSet[T]: ... @@ -580,16 +607,63 @@ def pick(self) -> T: ... @method(egg_fn="multiset-sum") def __add__(self, other: MultiSet[T]) -> MultiSet[T]: ... + @method(egg_fn="multiset-subtract") + def __sub__(self, other: MultiSet[T]) -> MultiSet[T]: ... + @method(egg_fn="unstable-multiset-map", reverse_args=True) - def map(self, f: Callable[[T], T]) -> MultiSet[T]: ... + def map(self, f: Callable[[T], V]) -> MultiSet[V]: ... + + @method(egg_fn="unstable-multiset-fill-index") + def fill_index(self, f: Callable[[MultiSet[T], T], i64]) -> Unit: ... + + @method(egg_fn="unstable-multiset-clear-index") + def clear_index(self, f: Callable[[MultiSet[T], T], i64]) -> Unit: ... + + @method(egg_fn="multiset-pick-max") + def pick_max(self) -> T: ... 
+ + @method(egg_fn="multiset-count") + def count(self, value: T) -> i64: ... + + @method(egg_fn="unstable-multiset-filter", reverse_args=True) + def filter(self, f: Callable[[T], Unit]) -> MultiSet[T]: ... + + @method(egg_fn="unstable-multiset-filter-not", reverse_args=True) + def filter_not(self, f: Callable[[T], Unit]) -> MultiSet[T]: ... + + @method(egg_fn="multiset-reset-counts") + def reset_counts(self) -> MultiSet[T]: ... + + +# TODO: Move to method when partial supports reverse_args +@function(egg_fn="unstable-multiset-flat-map", builtin=True) +def multiset_flat_map(f: Callable[[T], MultiSet[T]], xs: MultiSet[T]) -> MultiSet[T]: ... + + +@function(egg_fn="multiset-remove-swapped", builtin=True) +def multiset_remove_swapped(x: T, xs: MultiSet[T]) -> MultiSet[T]: ... + + +@function(egg_fn="multiset-subtract-swapped", builtin=True) +def multiset_subtract_swapped(x: MultiSet[T], xs: MultiSet[T]) -> MultiSet[T]: ... + + +@function(egg_fn="multiset-not-contains-swapped", builtin=True) +def multiset_not_contains_swapped(x: T, xs: MultiSet[T]) -> Unit: ... + + +@function(egg_fn="multiset-contains-swapped", builtin=True) +def multiset_contains_swapped(x: T, xs: MultiSet[T]) -> Unit: ... + + +@function(egg_fn="unstable-multiset-reduce", builtin=True) +def multiset_fold(f: Callable[[T, T], T], initial: T, xs: MultiSet[T]) -> T: ... converter( tuple, MultiSet, - lambda t: MultiSet[get_type_args()[0]]( # type: ignore[misc,operator] - *(convert(x, get_type_args()[0]) for x in t) - ), + lambda t: MultiSet(*(convert(x, get_type_args()[0]) for x in t)) if t else MultiSet[get_type_args()[0]](), # type: ignore[operator,misc] ) MultiSetLike: TypeAlias = MultiSet[T] | tuple[TO, ...] @@ -801,7 +875,7 @@ def bool_le(self, other: BigIntLike) -> Bool: ... def bool_ge(self, other: BigIntLike) -> Bool: ... -converter(i64, BigInt, lambda i: BigInt(i)) +converter(i64, BigInt, BigInt) BigIntLike: TypeAlias = BigInt | i64Like @@ -943,7 +1017,7 @@ def __init__(self, *args: T) -> None: ... 
def empty(cls) -> Vec[T]: ... @method(egg_fn="vec-append") - def append(self, *others: Vec[T]) -> Vec[T]: ... + def append(self, *others: VecLike[T, T]) -> Vec[T]: ... @method(egg_fn="vec-push") def push(self, value: T) -> Vec[T]: ... @@ -972,97 +1046,23 @@ def remove(self, index: i64Like) -> Vec[T]: ... @method(egg_fn="vec-set") def set(self, index: i64Like, value: T) -> Vec[T]: ... + @method(egg_fn="vec-union") + def __or__(self, other: Vec[T]) -> Vec[T]: ... + + @method(egg_fn="unstable-vec-map", reverse_args=True) + def map(self, fn: Callable[[T], V]) -> Vec[V]: ... + for sequence_type in (list, tuple): converter( sequence_type, Vec, - lambda t: Vec[get_type_args()[0]]( # type: ignore[misc,operator] - *(convert(x, get_type_args()[0]) for x in t) - ), + lambda t: Vec(*(convert(x, get_type_args()[0]) for x in t)) if t else Vec[get_type_args()[0]].empty(), # type: ignore[misc] ) VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO] -class PyObject(BuiltinExpr, egg_sort="PyObject"): - @method(preserve=True) - @deprecated("use .value") - def eval(self) -> object: - return self.value - - @method(preserve=True) # type: ignore[prop-decorator] - @property - def value(self) -> object: - expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr - if not isinstance(expr, PyObjectDecl): - raise ExprValueError(self, "PyObject(x)") - return cloudpickle.loads(expr.pickled) - - __match_args__ = ("value",) - - def __init__(self, value: object) -> None: ... - - @method(egg_fn="py-call") - def __call__(self, *args: object) -> PyObject: ... - - @method(egg_fn="py-call-extended") - def call_extended(self, args: PyObject, kwargs: PyObject) -> PyObject: - """ - Call the PyObject with the given args and kwargs PyObjects. - """ - - @method(egg_fn="py-from-string") - @classmethod - def from_string(cls, s: StringLike) -> PyObject: ... - - @method(egg_fn="py-to-string") - def to_string(self) -> String: ... - - @method(egg_fn="py-to-bool") - def to_bool(self) -> Bool: ... 
- - @method(egg_fn="py-dict-update") - def dict_update(self, *keys_and_values: object) -> PyObject: ... - - @method(egg_fn="py-from-int") - @classmethod - def from_int(cls, i: i64Like) -> PyObject: ... - - @method(egg_fn="py-dict") - @classmethod - def dict(cls, *keys_and_values: object) -> PyObject: ... - - -converter(object, PyObject, PyObject) - - -@function(builtin=True, egg_fn="py-eval") -def py_eval(code: StringLike, globals: object = PyObject.dict(), locals: object = PyObject.dict()) -> PyObject: ... - - -class PyObjectFunction(Protocol): - def __call__(self, *__args: PyObject) -> PyObject: ... - - -@deprecated("use PyObject(fn) directly") -def py_eval_fn(fn: Callable) -> PyObjectFunction: - """ - Takes a python callable and maps it to a callable which takes and returns PyObjects. - - It translates it to a call which uses `py_eval` to call the function, passing in the - args as locals, and using the globals from function. - """ - return PyObject(fn) - - -@function(builtin=True, egg_fn="py-exec") -def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object = PyObject.dict()) -> PyObject: - """ - Copies the locals, execs the Python code, and returns the locals with any updates. - """ - - TS = TypeVarTuple("TS") T1 = TypeVar("T1") @@ -1112,9 +1112,11 @@ def __call__(self, *args: *TS) -> T: ... # Method Type is for builtins like __getitem__ -converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__)) -converter(RuntimeFunction, UnstableFn, UnstableFn) -converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args)) +converter(MethodType, UnstableFn, lambda m: UnstableFn[*get_type_args()](m.__func__, m.__self__)) # type: ignore[operator, misc] +# Ignore PLW0108. 
+converter(RuntimeFunction, UnstableFn, lambda rf: UnstableFn[*get_type_args()](rf)) # type: ignore[operator, misc] +# converter(RuntimeClass, UnstableFn, lambda rc: UnstableFn[*get_type_args()](rc)) # type: ignore[operator, misc] +converter(partial, UnstableFn, lambda p: UnstableFn[*get_type_args()](p.func, *p.args)) # type: ignore[operator, misc] def _convert_function(fn: FunctionType) -> UnstableFn: @@ -1152,5 +1154,83 @@ def _convert_function(fn: FunctionType) -> UnstableFn: converter(FunctionType, UnstableFn, _convert_function) +class PyObject(BuiltinExpr, egg_sort="PyObject"): + @method(preserve=True) + @deprecated("use .value") + def eval(self) -> object: + return self.value + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> object: + expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr + if not isinstance(expr, PyObjectDecl): + raise ExprValueError(self, "PyObject(x)") + return cloudpickle.loads(expr.pickled) + + __match_args__ = ("value",) + + def __init__(self, value: object) -> None: ... + + @method(egg_fn="py-call") + def __call__(self, *args: object) -> PyObject: ... + + @method(egg_fn="py-call-extended") + def call_extended(self, args: PyObject, kwargs: PyObject) -> PyObject: + """ + Call the PyObject with the given args and kwargs PyObjects. + """ + + @method(egg_fn="py-from-string") + @classmethod + def from_string(cls, s: StringLike) -> PyObject: ... + + @method(egg_fn="py-to-string") + def to_string(self) -> String: ... + + @method(egg_fn="py-to-bool") + def to_bool(self) -> Bool: ... + + @method(egg_fn="py-dict-update") + def dict_update(self, *keys_and_values: object) -> PyObject: ... + + @method(egg_fn="py-from-int") + @classmethod + def from_int(cls, i: i64Like) -> PyObject: ... + + @method(egg_fn="py-dict") + @classmethod + def dict(cls, *keys_and_values: object) -> PyObject: ... 
+ + +converter(object, PyObject, PyObject) + + +@function(builtin=True, egg_fn="py-eval") +def py_eval(code: StringLike, globals_: object = PyObject.dict(), locals_: object = PyObject.dict()) -> PyObject: ... + + +class PyObjectFunction(Protocol): + def __call__(self, *__args: PyObject) -> PyObject: ... + + +@deprecated("use PyObject(fn) directly") +def py_eval_fn(fn: Callable) -> PyObjectFunction: + """ + Takes a python callable and maps it to a callable which takes and returns PyObjects. + + It translates it to a call which uses `py_eval` to call the function, passing in the + args as locals, and using the globals from function. + """ + return PyObject(fn) + + +@function(builtin=True, egg_fn="py-exec") +def py_exec(code: StringLike, globals_: object = PyObject.dict(), locals_: object = PyObject.dict()) -> PyObject: + """ + Copies the locals, execs the Python code, and returns the locals with any updates. + """ + + Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py index 3690df1b..13baac55 100644 --- a/python/egglog/conversion.py +++ b/python/egglog/conversion.py @@ -11,22 +11,20 @@ from .pretty import * from .runtime import * from .thunk import * -from .type_constraint_solver import TypeConstraintError if TYPE_CHECKING: from collections.abc import Generator from .egraph import BaseExpr - from .type_constraint_solver import TypeConstraintSolver __all__ = ["ConvertError", "convert", "converter", "get_type_args"] # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {} -# Global declerations to store all convertable types so we can query if they have certain methods or not +# Global declarations to store all 
convertible types so we can query if they have certain methods or not _CONVERSION_DECLS = Declarations.create() -# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing +# Defer a list of declarations to be added to the global declarations, so that we do not trigger processing them # until we need them -_TO_PROCESS_DECLS: list[DeclerationsLike] = [] +_TO_PROCESS_DECLS: list[DeclarationsLike] = [] def retrieve_conversion_decls() -> Declarations: @@ -134,7 +132,7 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr: def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type: """ - Process a type before converting it, to add it to the global declerations and resolve to a ref. + Process a type before converting it, to add it to the global declarations and resolve to a ref. """ if isinstance(tp, RuntimeClass): _TO_PROCESS_DECLS.append(tp) @@ -220,41 +218,19 @@ def resolve_literal( tp: TypeOrVarRef, arg: object, decls: Callable[[], Declarations] = retrieve_conversion_decls, - tcs: TypeConstraintSolver | None = None, - cls_ident: Ident | None = None, ) -> RuntimeExpr: """ Try to convert an object to a type, raising a ConvertError if it is not possible. - If the type has vars in it, they will be tried to be resolved into concrete vars based on the type constraint solver. - If it cannot be resolved, we assume that the value passed in will resolve it.
""" - arg_type = resolve_type(arg) - - # If we have any type variables, dont bother trying to resolve the literal, just return the arg - try: - tp_just = tp.to_just() - except TypeVarError: - # If this is a generic arg but passed in a non runtime expression, try to resolve the generic - # args first based on the existing type constraint solver - if tcs: - try: - tp_just = tcs.substitute_typevars(tp, cls_ident) - # If we can't resolve the type var yet, then just assume it is the right value - except TypeConstraintError: - assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}" - tp_just = arg.__egg_typed_expr__.tp - else: - # If this is a var, it has to be a runtime expession - assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}" - return arg - if tcs: - tcs.infer_typevars(tp, tp_just, cls_ident) - if arg_type == tp_just: - # If the type is an egg type, it has to be a runtime expr - assert isinstance(arg, RuntimeExpr) + # If this is a runtime expression that could match the type already, just return it + if isinstance(arg, RuntimeExpr) and tp.matches_just({}, arg.__egg_typed_expr__.tp): return arg + tp_just = tp.to_just() + if arg is DUMMY_VALUE: + return RuntimeExpr.__from_values__(decls(), TypedExprDecl(tp_just, DummyDecl())) + arg_type = resolve_type(arg) if (conversion := _lookup_conversion(arg_type, tp_just)) is not None: with with_type_args(tp_just.args, decls): return conversion[1](arg) @@ -265,7 +241,7 @@ def _lookup_conversion(lhs: type | JustTypeRef, rhs: JustTypeRef) -> tuple[int, """ Looks up a conversion function for the given types. - Also looks up all parent types of the lhs if it is a Python type and looks up more general not paramtrized types for rhs. + Also looks up all parent types of the lhs if it is a Python type and looks up more general not parametrized types for rhs. 
""" for lhs_type in lhs.__mro__ if isinstance(lhs, type) else [lhs]: if (key := (lhs_type, rhs)) in CONVERSIONS: @@ -275,7 +251,7 @@ def _lookup_conversion(lhs: type | JustTypeRef, rhs: JustTypeRef) -> tuple[int, return None -def _debug_print_converers(): +def _debug_print_converters(): """ Prints a mapping of all source types to target types that have a conversion function. """ diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index e2ffa08f..9a5d08e2 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -1,13 +1,14 @@ """ Data only descriptions of the components of an egraph and the expressions. -We seperate it it into two pieces, the references the declerations, so that we can report mutually recursive types. +We separate it it into two pieces, the references the declarations, so that we can report mutually recursive types. """ from __future__ import annotations from dataclasses import dataclass, field -from functools import cached_property +from functools import cache, cached_property +from itertools import chain, repeat from typing import ( TYPE_CHECKING, ClassVar, @@ -41,7 +42,6 @@ "ChangeDecl", "ClassDecl", "ClassMethodRef", - "ClassTypeVarRef", "ClassVariableRef", "CombinedRulesetDecl", "CommandDecl", @@ -50,9 +50,11 @@ "ConstructorDecl", "Declarations", "Declarations", - "DeclerationsLike", + "DeclarationsLike", "DefaultRewriteDecl", - "DelayedDeclerations", + "DelayedDeclarations", + "DummyDecl", + "EGraphDecl", "EqDecl", "ExprActionDecl", "ExprDecl", @@ -62,7 +64,7 @@ "FunctionRef", "FunctionSignature", "GetCostDecl", - "HasDeclerations", + "HasDeclarations", "Ident", "InitRef", "JustTypeRef", @@ -92,6 +94,7 @@ "TypeOrVarRef", "TypeRefWithVars", "TypeVarError", + "TypeVarRef", "TypedExprDecl", "UnboundVarDecl", "UnionDecl", @@ -99,12 +102,12 @@ "ValueDecl", "collect_unbound_vars", "replace_typed_expr", - "upcast_declerations", + "upcast_declarations", ] @dataclass(match_args=False) -class 
DelayedDeclerations: +class DelayedDeclarations: __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False) @property @@ -120,20 +123,20 @@ def __egg_decls__(self) -> Declarations: @runtime_checkable -class HasDeclerations(Protocol): +class HasDeclarations(Protocol): @property def __egg_decls__(self) -> Declarations: ... -DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"] +DeclarationsLike: TypeAlias = Union[HasDeclarations, None, "Declarations"] -def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]: +def upcast_declarations(declarations_like: Iterable[DeclarationsLike]) -> list[Declarations]: d = [] - for l in declerations_like: + for l in declarations_like: if l is None: continue - if isinstance(l, HasDeclerations): + if isinstance(l, HasDeclarations): d.append(l.__egg_decls__) elif isinstance(l, Declarations): d.append(l) @@ -177,8 +180,8 @@ def default_ruleset(self) -> RulesetDecl: return ruleset @classmethod - def create(cls, *others: DeclerationsLike) -> Declarations: - others = upcast_declerations(others) + def create(cls, *others: DeclarationsLike) -> Declarations: + others = upcast_declarations(others) if not others: return Declarations() first, *rest = others @@ -193,26 +196,26 @@ def copy(self) -> Declarations: self.update_other(new) return new - def update(self, *others: DeclerationsLike) -> None: + def update(self, *others: DeclarationsLike) -> None: for other in others: self |= other - def __or__(self, other: DeclerationsLike) -> Declarations: + def __or__(self, other: DeclarationsLike) -> Declarations: result = self.copy() result |= other return result - def __ior__(self, other: DeclerationsLike) -> Self: + def __ior__(self, other: DeclarationsLike) -> Self: if other is None: return self - if isinstance(other, HasDeclerations): + if isinstance(other, HasDeclarations): other = other.__egg_decls__ other.update_other(self) return self def update_other(self, other: Declarations) -> 
None: """ - Updates the other decl with these values in palce. + Updates the other decl with these values in place. """ other._functions |= self._functions other._classes |= self._classes @@ -269,7 +272,7 @@ def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRe """ Checks if the class has a binary method compatible with the given types. """ - vars: dict[ClassTypeVarRef, JustTypeRef] = {} + vars: dict[TypeVarRef, JustTypeRef] = {} if callable_decl := self._classes[self_type.ident].methods.get(method_name): match callable_decl.signature: case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just( @@ -282,7 +285,7 @@ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTy """ Checks if the class has a binary method with the given name and self type. Returns the other type if it exists. """ - vars: dict[ClassTypeVarRef, JustTypeRef] = {} + vars: dict[TypeVarRef, JustTypeRef] = {} class_decl = self._classes.get(self_type.ident) if class_decl is None: return None @@ -297,7 +300,7 @@ def check_binary_method_with_other_type(self, method_name: str, other_type: Just Returns the types which are compatible with the given binary method name and other type. """ for class_decl in self._classes.values(): - vars: dict[ClassTypeVarRef, JustTypeRef] = {} + vars: dict[TypeVarRef, JustTypeRef] = {} if callable_decl := class_decl.methods.get(method_name): match callable_decl.signature: case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just( @@ -308,9 +311,9 @@ def check_binary_method_with_other_type(self, method_name: str, other_type: Just def get_class_decl(self, ident: Ident) -> ClassDecl: return self._classes[ident] - def get_paramaterized_class(self, ident: Ident) -> TypeRefWithVars: + def get_parameterized_class(self, ident: Ident) -> TypeRefWithVars: """ - Returns a class reference with type parameters, if the class is paramaterized. 
+ Returns a class reference with type parameters, if the class is parameterized. """ type_vars = self._classes[ident].type_vars return TypeRefWithVars(ident, type_vars) @@ -319,11 +322,11 @@ def get_paramaterized_class(self, ident: Ident) -> TypeRefWithVars: @dataclass class ClassDecl: egg_name: str | None = None - type_vars: tuple[ClassTypeVarRef, ...] = () + type_vars: tuple[TypeVarRef, ...] = () builtin: bool = False init: ConstructorDecl | FunctionDecl | None = None class_methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) - # These have to be seperate from class_methods so that printing them can be done easily + # These have to be separate from class_methods so that printing them can be done easily class_variables: dict[str, ConstantDecl] = field(default_factory=dict) methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) @@ -347,6 +350,175 @@ class CombinedRulesetDecl: rulesets: tuple[Ident, ...] +T_expr_decl = TypeVar("T_expr_decl", bound="ExprDecl") + + +@dataclass(frozen=True) +class EGraphDecl: + """ + State of an e-graph, which when re-added to a new e-graph will reconstruct the same e-graph, given the same Declarations. + + All the expressions in here may reference values which appear in the `e_classes` mapping. + """ + + # Mapping from top level let binding names to their types and expressions + let_bindings: dict[str, TypedExprDecl] = field(default_factory=dict) + # Mapping from egglog values representing e-classes to all the expressions in that e-class + e_classes: dict[Value, tuple[JustTypeRef, tuple[CallDecl, ...]]] = field(default_factory=dict) + # Mapping from function calls to the values they are set to + sets: dict[CallDecl, TypedExprDecl] = field(default_factory=dict) + # Top-level expr actions such as relation facts. + expr_actions: tuple[TypedExprDecl, ...] 
= field(default=()) + # Mapping from function calls to the set costs. + costs: dict[CallDecl, tuple[JustTypeRef, int]] = field(default_factory=dict) + # Set of values which are subsumed + subsumed: tuple[tuple[JustTypeRef, CallDecl], ...] = field(default=()) + + def __hash__(self) -> int: + return hash(( + type(self), + frozenset(self.let_bindings.items()), + frozenset((value, tp, exprs) for value, (tp, exprs) in self.e_classes.items()), + frozenset(self.sets.items()), + self.expr_actions, + frozenset(self.costs.items()), + self.subsumed, + )) + + @cached_property + def to_actions(self) -> list[ActionDecl]: # noqa: C901 + """ + Converts this egraph decl to a list of actions that can be executed to reconstruct the egraph. + + Converts all e-classes to grounded terms + unions. + + Currently does not support cycles or empty e-classes. + """ + # First fill up the e_class_grounded_term for all e_classes + # by iteratively adding grounded terms for e-classes which have a grounded term until no more progress can be made. + + # mapping from e-class to a grounded term in that e-class + e_class_grounded_term: dict[Value, CallDecl] = {} + + def is_grounded(expr: ExprDecl) -> bool: + """ + Checks if the given expression is grounded, meaning any values recursively in it have grounded terms in their e-classes. 
+ """ + match expr: + case LetRefDecl(name): + raise ValueError(f"Cannot have unexpanded let bindings in egraph decl: {name}") + case UnboundVarDecl(_): + msg = "Cannot have unbound variables in egraph decl" + raise ValueError(msg) + case CallDecl(_, args, _): + return all(is_grounded(a.expr) for a in args) + case LitDecl(_) | PyObjectDecl(_): + return True + case PartialCallDecl(call): + return is_grounded(call) + case DummyDecl(): + msg = "Cannot have dummy decls in egraph decl" + raise ValueError(msg) + case ValueDecl(value): + return value in e_class_grounded_term + case GetCostDecl(): + msg = "Cannot have GetCostDecl in egraph decl" + raise ValueError(msg) + case _: + assert_never(expr) + + made_progress = True + while made_progress: + made_progress = False + for e_class, (_, exprs) in self.e_classes.items(): + if e_class in e_class_grounded_term: + continue + for expr in exprs: + if is_grounded(expr): + e_class_grounded_term[e_class] = expr + made_progress = True + break + + # call declarations already emitted as part of other actions. + emitted_call_decls = set[CallDecl]() + + @cache + def to_grounded(expr: ExprDecl) -> ExprDecl: + """ + Converts the given expression to a grounded term, by replacing any values in it with their grounded terms. 
+ """ + match expr: + case LetRefDecl(name): + raise ValueError(f"Cannot have unexpanded let bindings in egraph decl: {name}") + case UnboundVarDecl(_): + msg = "Cannot have unbound variables in egraph decl" + raise ValueError(msg) + case CallDecl(callable, args, bound_tp_params): + emitted_call_decls.add(expr) + new_args = tuple(TypedExprDecl(a.tp, to_grounded(a.expr)) for a in args) + return CallDecl(callable, new_args, bound_tp_params) + case LitDecl(_) | PyObjectDecl(_): + return expr + case PartialCallDecl(call): + return PartialCallDecl(cast("CallDecl", to_grounded(call))) + case DummyDecl(): + msg = "Cannot have dummy decls in egraph decl" + raise ValueError(msg) + case ValueDecl(value): + if value not in e_class_grounded_term: + raise ValueError(f"Value {value} does not have a grounded term in egraph decl") + return to_grounded(e_class_grounded_term[value]) + case GetCostDecl(): + msg = "Cannot have GetCostDecl in egraph decl" + raise ValueError(msg) + case _: + assert_never(expr) + + # calls that are in e-classes with only one value, so wouldn't be added as a union and might need + # to be added as a single expr action if they don't show up anywhere else + single_e_class_calls: list[tuple[JustTypeRef, CallDecl]] = [] + + # Now add all e-classes as actions. 
+ actions: list[ActionDecl] = [] + for e_class, (tp, exprs) in self.e_classes.items(): + chosen_term = e_class_grounded_term[e_class] + if len(exprs) == 1: + single_e_class_calls.append((tp, chosen_term)) + continue + + grounded_chosen_term = to_grounded(chosen_term) + for expr in exprs: + if expr == chosen_term: + continue + actions.append(UnionDecl(tp, grounded_chosen_term, to_grounded(expr))) + actions.extend( + LetDecl(name, TypedExprDecl(typed_expr.tp, to_grounded(typed_expr.expr))) + for name, typed_expr in self.let_bindings.items() + ) + actions.extend( + SetDecl(set_expr.tp, cast("CallDecl", to_grounded(call)), to_grounded(set_expr.expr)) + for call, set_expr in self.sets.items() + ) + actions.extend( + ExprActionDecl(TypedExprDecl(typed_expr.tp, to_grounded(typed_expr.expr))) + for typed_expr in self.expr_actions + ) + actions.extend( + SetCostDecl(tp, cast("CallDecl", to_grounded(call)), LitDecl(cost)) + for call, (tp, cost) in self.costs.items() + ) + actions.extend(ChangeDecl(tp, cast("CallDecl", to_grounded(call)), "subsume") for tp, call in self.subsumed) + + # Now add any remaining calls that weren't part of any other actions + actions.extend( + ExprActionDecl(TypedExprDecl(tp, to_grounded(expr))) + for (tp, expr) in single_e_class_calls + if expr not in emitted_call_decls + ) + + return actions + + # Have two different types of type refs, one that can include vars recursively and one that cannot. # We only use the one with vars for classmethods and methods, and the other one for egg references as # well as runtime values. 
@@ -371,7 +543,7 @@ def __str__(self) -> str: # mapping of name and module of resolved typevars to runtime values # so that when spitting them back out again can use same instance # since equality is based on identity not value -_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {} +_RESOLVED_TYPEVARS: dict[TypeVarRef, TypeVar] = {} class TypeVarError(RuntimeError): @@ -379,14 +551,14 @@ class TypeVarError(RuntimeError): @dataclass(frozen=True) -class ClassTypeVarRef: +class TypeVarRef: """ - A class type variable represents one of the types of the class, if it is a generic class. + A generic type variable reference. """ ident: Ident - def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + def to_just(self, vars: dict[TypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: if vars is None or self not in vars: raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings") return vars[self] @@ -395,7 +567,7 @@ def __str__(self) -> str: return str(self.to_type_var()) @classmethod - def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef: + def from_type_var(cls, typevar: TypeVar) -> TypeVarRef: res = cls(Ident(typevar.__name__, typevar.__module__)) _RESOLVED_TYPEVARS[res] = typevar return res @@ -403,7 +575,7 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef: def to_type_var(self) -> TypeVar: return _RESOLVED_TYPEVARS[self] - def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + def matches_just(self, vars: dict[TypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: """ Checks if this type variable matches the given JustTypeRef, including type variables. """ @@ -412,13 +584,20 @@ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustType vars[self] = other return True + @property + def vars(self) -> set[TypeVarRef]: + """ + Returns all type variables in this type reference. 
+ """ + return {self} + @dataclass(frozen=True) class TypeRefWithVars: ident: Ident args: tuple[TypeOrVarRef, ...] = () - def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + def to_just(self, vars: dict[TypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: return JustTypeRef(self.ident, tuple(a.to_just(vars) for a in self.args)) def __str__(self) -> str: @@ -426,7 +605,7 @@ def __str__(self) -> str: return f"{self.ident.name}[{', '.join(str(a) for a in self.args)}]" return str(self.ident.name) - def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + def matches_just(self, vars: dict[TypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: """ Checks if this type reference matches the given JustTypeRef, including type variables. """ @@ -436,8 +615,18 @@ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustType and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True)) ) + @property + def vars(self) -> set[TypeVarRef]: + """ + Returns all type variables in this type reference. + """ + vars = set[TypeVarRef]() + for arg in self.args: + vars.update(arg.vars) + return vars + -TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars +TypeOrVarRef: TypeAlias = TypeVarRef | TypeRefWithVars ## # Callables References @@ -588,6 +777,25 @@ def semantic_return_type(self) -> TypeOrVarRef: def mutates(self) -> bool: return self.return_type is None + @property + def arg_vars(self) -> set[TypeVarRef]: + """ + Returns all type variables in the argument types. + """ + vars = set[TypeVarRef]() + for arg in self.arg_types: + vars.update(arg.vars) + if self.var_arg_type: + vars.update(self.var_arg_type.vars) + return vars + + @property + def all_args(self) -> Iterable[TypeOrVarRef]: + """ + Returns all argument types, including var args. 
+ """ + return chain(self.arg_types, (repeat(self.var_arg_type) if self.var_arg_type else [])) + @dataclass(frozen=True) class FunctionDecl: @@ -620,6 +828,11 @@ class UnboundVarDecl: egg_name: str | None = None +@dataclass(frozen=True) +class DummyDecl: + pass + + @dataclass(frozen=True) class LetRefDecl: name: str @@ -669,7 +882,7 @@ def __new__(cls, *args: object, **kwargs: object) -> Self: """ Pool CallDecls so that they can be compared by identity more quickly. - Neccessary bc we search for common parents when serializing CallDecl trees to egglog to + Necessary bc we search for common parents when serializing CallDecl trees to egglog to only serialize each sub-tree once. """ # normalize the args/kwargs to a tuple so that they can be compared @@ -711,7 +924,7 @@ class PartialCallDecl: Note it does not need to have any args, in which case it's just a function pointer. - Seperated from the call decl so it's clear it is translated to a `unstable-fn` call. + Separated from the call decl so it's clear it is translated to a `unstable-fn` call. 
""" call: CallDecl @@ -729,7 +942,15 @@ class ValueDecl: ExprDecl: TypeAlias = ( - UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl + DummyDecl + | UnboundVarDecl + | LetRefDecl + | LitDecl + | CallDecl + | PyObjectDecl + | PartialCallDecl + | ValueDecl + | GetCostDecl ) diff --git a/python/egglog/deconstruct.py b/python/egglog/deconstruct.py index 26a35622..1a953b06 100644 --- a/python/egglog/deconstruct.py +++ b/python/egglog/deconstruct.py @@ -23,7 +23,14 @@ T = TypeVar("T", bound=BaseExpr) TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]]) -__all__ = ["get_callable_args", "get_callable_fn", "get_let_name", "get_literal_value", "get_var_name"] +__all__ = [ + "get_callable_args", + "get_callable_fn", + "get_constant_name", + "get_let_name", + "get_literal_value", + "get_var_name", +] @overload @@ -74,6 +81,19 @@ def get_literal_value(x: object) -> object: return None +def get_constant_name(x: BaseExpr) -> Ident | None: + """ + Check if the expression is a constant and return its name. + If it is not a constant, return None. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case CallDecl(ConstantRef(ident)): + return ident + return None + + def get_let_name(x: BaseExpr) -> str | None: """ Check if the expression is a `let` expression and return the name of the variable. 
@@ -168,8 +188,8 @@ def _deconstruct_call_decl( TypeRefWithVars(call.callable.ident, tuple(tp.to_var() for tp in (call.bound_tp_params or []))), ), arg_exprs egg_bound = ( - JustTypeRef(call.callable.ident, call.bound_tp_params or ()) - if isinstance(call.callable, ClassMethodRef) + JustTypeRef(call.callable.ident, call.bound_tp_params) + if isinstance(call.callable, (ClassMethodRef, MethodRef)) and call.bound_tp_params else None ) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 2dc75985..8fc6643b 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -31,9 +31,11 @@ from warnings import warn import graphviz +from opentelemetry import trace from typing_extensions import ParamSpec, Unpack from . import bindings +from ._tracing import call_with_current_trace from .conversion import * from .conversion import convert_to_same_type, resolve_literal from .declarations import * @@ -47,8 +49,12 @@ from .builtins import String, Unit, i64, i64Like +_TRACER = trace.get_tracer(__name__) + + __all__ = [ "Action", + "ActionLike", "BackOff", "BaseExpr", "BuiltinExpr", @@ -322,7 +328,6 @@ def __new__( # type: ignore[misc] if not bases or bases == (BaseExpr,): return super().__new__(cls, name, bases, namespace) builtin = BuiltinExpr in bases - # TODO: Raise error on subclassing or multiple inheritence frame = currentframe() assert frame @@ -395,20 +400,24 @@ def _generate_class_decls( # noqa: C901,PLR0912 runtime_cls: RuntimeClass, ) -> Declarations: """ - Lazy constructor for class declerations to support classes with methods whose types are not yet defined. + Lazy constructor for class declarations to support classes with methods whose types are not yet defined. 
""" parameters: list[TypeVar] = ( # Get the generic params from the orig bases generic class namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else [] ) - type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters) + type_vars = tuple(TypeVarRef.from_type_var(p) for p in parameters) del parameters cls_decl = ClassDecl( egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ()), doc=namespace.pop("__doc__", None) ) decls = Declarations(_classes={cls_ident: cls_decl}) - # Update class think eagerly when resolving so that lookups work in methods + # Update class thunk eagerly when resolving so that lookups work in methods. runtime_cls.__egg_decls_thunk__ = Thunk.value(decls) + # Cached RuntimeFunction/RuntimeExpr wrappers capture the current decl thunk, so + # swapping in the concrete declarations must invalidate any wrappers created while + # the class was still pointing at the lazy declaration builder. + runtime_cls.__egg_attr_cache__.clear() ## # Register class variables @@ -591,7 +600,7 @@ def _fn_decl( else: var_arg_type = None arg_types = tuple( - decls.get_paramaterized_class(ref.ident) + decls.get_parameterized_class(ref.ident) if i == 0 and isinstance(ref, MethodRef | PropertyRef) else resolve_type_annotation_mutate(decls, hints[t.name]) for i, t in enumerate(params) @@ -606,7 +615,7 @@ def _fn_decl( decls.update(*arg_defaults) return_type = ( - decls.get_paramaterized_class(ref.ident) + decls.get_parameterized_class(ref.ident) if isinstance(ref, InitRef) else arg_types[0] if mutates_first_arg @@ -663,6 +672,8 @@ def _fn_decl( doc=doc, ) decls.set_function_decl(ref, decl) + if is_builtin: + return lambda: None return Thunk.fn( _add_default_rewrite_function, decls, @@ -778,7 +789,7 @@ def _add_default_rewrite_function( arg_exprs: list[RuntimeExpr | RuntimeClass] = [RuntimeExpr.__from_values__(decls, a) for a in args] # If this is a classmethod, add the class as the first arg if isinstance(ref, 
ClassMethodRef): - tp = decls.get_paramaterized_class(ref.ident) + tp = decls.get_parameterized_class(ref.ident) arg_exprs.insert(0, RuntimeClass(Thunk.value(decls), tp)) with set_current_ruleset(ruleset): res = fn(*arg_exprs) @@ -806,6 +817,7 @@ def _add_default_rewrite( resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls)) rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume) ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset) + ruleset_decls |= decls ruleset_decls |= resolved_value @@ -858,20 +870,27 @@ class EGraph: Can run actions, check facts, run schedules, or extract minimal cost expressions. """ - seminaive: InitVar[bool] = True - save_egglog_string: InitVar[bool] = False - _state: EGraphState = field(init=False, repr=False) # For pushing/popping with egglog _state_stack: list[EGraphState] = field(default_factory=list, repr=False) # For storing the global "current" egraph _token_stack: list[EGraph] = field(default_factory=list, repr=False) - def __post_init__(self, seminaive: bool, save_egglog_string: bool) -> None: - egraph = bindings.EGraph(seminaive=seminaive, record=save_egglog_string) - self._state = EGraphState(egraph) - - def _add_decls(self, *decls: DeclerationsLike) -> None: + def __init__( + self, + *actions: ActionLike, + seminaive: bool = True, + save_egglog_string: bool = False, + ) -> None: + with _TRACER.start_as_current_span("create"): + with _TRACER.start_as_current_span("create_bindings"): + self._state = EGraphState(bindings.EGraph(seminaive=seminaive, record=save_egglog_string)) + self._state_stack = [] + self._token_stack = [] + if actions: + self.register(*actions) + + def _add_decls(self, *decls: DeclarationsLike) -> None: for d in decls: self._state.__egg_decls__ |= d @@ -899,7 +918,7 @@ def input(self, fn: Callable[..., String], path: str) -> None: """ Loads a CSV file and sets it as *input, output of the function. 
""" - self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn)[1], path)) + self._run_program(bindings.Input(span(1), self._callable_to_egg(fn)[1], path)) def _callable_to_egg(self, fn: ExprCallable) -> tuple[CallableRef, str]: ref, decls = resolve_callable(fn) @@ -939,6 +958,7 @@ def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bi @overload def run(self, schedule: Schedule, /) -> bindings.RunReport: ... + @_TRACER.start_as_current_span("run") def run( self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None ) -> bindings.RunReport: @@ -952,7 +972,7 @@ def run( def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: self._add_decls(schedule) cmd = self._state.run_schedule_to_egg(schedule.schedule) - (command_output,) = self._egraph.run_program(cmd) + (command_output,) = self._run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) return command_output.report @@ -960,7 +980,7 @@ def stats(self) -> bindings.RunReport: """ Returns the overall run report for the egraph. """ - (output,) = self._egraph.run_program(bindings.PrintOverallStatistics(span(1), None)) + (output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None)) assert isinstance(output, bindings.OverallStatistics) return output.report @@ -977,17 +997,18 @@ def check_bool(self, *facts: FactLike) -> bool: raise return True + @_TRACER.start_as_current_span("check") def check(self, *facts: FactLike) -> None: """ Check if a fact is true in the egraph. 
""" - self._egraph.run_program(self._facts_to_check(facts)) + self._run_program(self._facts_to_check(facts)) def check_fail(self, *facts: FactLike) -> None: """ Checks that one of the facts is not true """ - self._egraph.run_program(bindings.Fail(span(1), self._facts_to_check(facts))) + self._run_program(bindings.Fail(span(1), self._facts_to_check(facts))) def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check: facts = _fact_likes(fact_likes) @@ -1010,6 +1031,7 @@ def extract( self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST] ) -> tuple[BASE_EXPR, COST]: ... + @_TRACER.start_as_current_span("extract") def extract( self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None ) -> BASE_EXPR | tuple[BASE_EXPR, COST]: @@ -1029,14 +1051,14 @@ def extract( self.register(expr) egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model() egg_sort = self._state.type_ref_to_egg(tp) - extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model) + extractor = call_with_current_trace(bindings.Extractor, [egg_sort], self._state.egraph, egg_cost_model) termdag = bindings.TermDag() value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__) - cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort) + cost, term = call_with_current_trace(extractor.extract_best, self._state.egraph, termdag, value, egg_sort) res = self._from_termdag(termdag, term, tp) return (res, cost) if include_cost else res - def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any: + def _from_termdag(self, termdag: bindings.TermDag, term: int, tp: JustTypeRef) -> Any: (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp) return RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr) @@ -1062,24 +1084,26 @@ def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput: else: cmd = 
bindings.Extract(span(2), *args) try: - return self._egraph.run_program(cmd)[0] + return self._run_program(cmd)[0] except BaseException as e: - e.add_note("while extracting expr:\n" + str(expr)) + e.add_note("while extracting: " + str(expr)) raise + @_TRACER.start_as_current_span("push") def push(self) -> None: """ Push the current state of the egraph, so that it can be popped later and reverted back. """ - self._egraph.run_program(bindings.Push(1)) + self._run_program(bindings.Push(1)) self._state_stack.append(self._state) self._state = self._state.copy() + @_TRACER.start_as_current_span("pop") def pop(self) -> None: """ Pop the current state of the egraph, reverting back to the previous state. """ - self._egraph.run_program(bindings.Pop(span(1), 1)) + self._run_program(bindings.Pop(span(1), 1)) self._state = self._state_stack.pop() def __enter__(self) -> Self: @@ -1104,7 +1128,8 @@ def _serialize( split_functions = kwargs.pop("split_functions", []) include_temporary_functions = kwargs.pop("include_temporary_functions", False) n_inline_leaves = kwargs.pop("n_inline_leaves", 0) - serialized = self._egraph.serialize( + serialized = call_with_current_trace( + self._egraph.serialize, [], max_functions=max_functions, max_calls_per_function=max_calls_per_function, @@ -1178,13 +1203,14 @@ def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> N serialized = self._serialize(**kwargs) VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open() - def saturate( + def saturate( # noqa: C901 self, schedule: Schedule | None = None, *, expr: Expr | None = None, max: int = 1000, visualize: bool = True, + print_frozen: bool = False, **kwargs: Unpack[GraphvizKwargs], ) -> None: """ @@ -1209,9 +1235,15 @@ def to_json() -> str: i += 1 if visualize: egraphs.append(to_json()) + if print_frozen: + print(f"After iteration {i}:") + print(self.freeze()) + print("\n") except: if visualize: egraphs.append(to_json()) + if print_frozen: + print(self.freeze()) 
raise finally: if visualize: @@ -1221,10 +1253,14 @@ def to_json() -> str: def _egraph(self) -> bindings.EGraph: return self._state.egraph + def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]: + return call_with_current_trace(self._egraph.run_program, *commands) + @property def __egg_decls__(self) -> Declarations: return self._state.__egg_decls__ + @_TRACER.start_as_current_span("register") def register( self, /, @@ -1249,7 +1285,7 @@ def register( def _register_commands(self, cmds: list[Command]) -> None: self._add_decls(*cmds) egg_cmds = [egg_cmd for cmd in cmds if (egg_cmd := self._command_to_egg(cmd)) is not None] - self._egraph.run_program(*egg_cmds) + self._run_program(*egg_cmds) def _command_to_egg(self, cmd: Command) -> bindings._Command | None: ruleset_ident = Ident("") @@ -1269,7 +1305,7 @@ def function_size(self, fn: ExprCallable) -> int: Returns the number of rows in a certain function """ egg_name = self._callable_to_egg(fn)[1] - (output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name)) + (output,) = self._run_program(bindings.PrintSize(span(1), egg_name)) assert isinstance(output, bindings.PrintFunctionSize) return output.size @@ -1277,7 +1313,7 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]: """ Returns a list of all functions and their sizes. 
""" - (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None)) + (output,) = self._run_program(bindings.PrintSize(span(1), None)) assert isinstance(output, bindings.PrintAllFunctionsSize) return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))] @@ -1298,7 +1334,7 @@ def function_values( """ ref, egg_name = self._callable_to_egg(fn) cmd = bindings.PrintFunction(span(1), egg_name, length, None, bindings.DefaultPrintFunctionMode()) - (output,) = self._egraph.run_program(cmd) + (output,) = self._run_program(cmd) assert isinstance(output, bindings.PrintFunctionOutput) signature = self.__egg_decls__.get_callable_decl(ref).signature assert isinstance(signature, FunctionSignature) @@ -1335,6 +1371,144 @@ def has_custom_cost(self, fn: ExprCallable) -> bool: resolved, _ = resolve_callable(fn) return resolved in self._state.cost_callables + def freeze(self) -> FrozenEGraph: # noqa: C901,PLR0912 + """ + Freeze the current e-graph state for debugging. + + The returned :class:`FrozenEGraph` contains a snapshot of the current + declarations and can be pretty-printed back into replayable high-level + actions with ``str(...)``. + + This is useful when debugging unexpected unions, sets, subsumptions, or + costs after a run: + + >>> from egglog import * + >>> class Math(Expr): + ... def __init__(self, value: i64Like) -> None: ... + ... 
+ >>> egraph = EGraph(Math(1)) + >>> str(egraph.freeze()) + 'EGraph(Math(1)).freeze()' + """ + frozen = self._egraph.freeze() + let_bindings: dict[str, TypedExprDecl] = {} + e_classes: dict[bindings.Value, tuple[JustTypeRef, list[CallDecl]]] = {} + sets: dict[CallDecl, TypedExprDecl] = {} + costs: dict[CallDecl, tuple[JustTypeRef, int]] = {} + expr_actions: list[TypedExprDecl] = [] + subsumed: list[tuple[JustTypeRef, CallDecl]] = [] + + def append_e_class_row(output: bindings.Value, tp: JustTypeRef, call: CallDecl, is_subsumed: bool) -> None: + if output not in e_classes: + e_classes[output] = (tp, []) + e_class_tp, exprs = e_classes[output] + assert e_class_tp == tp + exprs.append(call) + if is_subsumed: + subsumed.append((tp, call)) + + for name, fn in frozen.functions.items(): + if fn.is_let_binding: + if name.startswith("$__expr_"): + continue + for row in fn.rows: + output_tp = self._state.egg_sort_to_type_ref[fn.output_sort] + let_bindings[name] = TypedExprDecl(output_tp, self._state.value_to_expr(output_tp, row.output)) + continue + is_cost = False + if name in self._state.egg_fn_to_callable_refs: + (callable_ref,) = self._state.egg_fn_to_callable_refs[name] + else: + (callable_ref,) = ( + ref for ref in self._state.cost_callables if name == self._state.cost_table_name(ref) + ) + is_cost = True + callable_decl = self.__egg_decls__.get_callable_decl(callable_ref) + signature = callable_decl.signature + assert isinstance(signature, FunctionSignature), ( + f"Cannot freeze special callable {callable_ref} with signature {signature}" + ) + assert signature.var_arg_type is None, f"Frozen calls do not support var args: {callable_ref}" + assert not signature.reverse_args, f"Frozen calls do not support reverse_args: {callable_ref}" + + for row in fn.rows: + arg_exprs = tuple( + TypedExprDecl(tp, self._state.value_to_expr(tp, value)) + for arg_type, value in zip(signature.arg_types, row.inputs, strict=True) + for tp in (arg_type.to_just(),) + ) + call = 
CallDecl(callable_ref, arg_exprs) + if is_cost: + cost_tp = self._state.egg_sort_to_type_ref[fn.output_sort] + cost_expr = TypedExprDecl(cost_tp, self._state.value_to_expr(cost_tp, row.output)) + match cost_expr.expr: + case LitDecl(int(value)): + costs[call] = (signature.semantic_return_type.to_just(), value) + case _: + raise TypeError(f"Expected integer cost for {callable_ref}, got {cost_expr.expr}") + continue + + output_tp = signature.semantic_return_type.to_just() + match callable_decl: + case ConstructorDecl(): + append_e_class_row(row.output, output_tp, call, row.subsumed) + case FunctionDecl(): + set_tp = self._state.egg_sort_to_type_ref[fn.output_sort] + sets[call] = TypedExprDecl(set_tp, self._state.value_to_expr(set_tp, row.output)) + case ConstantDecl(type_ref): + if type_ref.ident.module == Ident.builtin("").module: + set_tp = self._state.egg_sort_to_type_ref[fn.output_sort] + sets[call] = TypedExprDecl(set_tp, self._state.value_to_expr(set_tp, row.output)) + continue + append_e_class_row(row.output, output_tp, call, row.subsumed) + case RelationDecl(): + if row.subsumed: + raise TypeError(f"Cannot freeze subsumed relation row for {callable_ref}") + expr_actions.append(TypedExprDecl(output_tp, call)) + case _: + assert_never(callable_decl) + + return FrozenEGraph( + self.__egg_decls__.copy(), + EGraphDecl( + let_bindings=let_bindings, + e_classes={value: (tp, tuple(exprs)) for value, (tp, exprs) in e_classes.items()}, + sets=sets, + expr_actions=tuple(expr_actions), + costs=costs, + subsumed=tuple(subsumed), + ), + ) + + def _values_to_expr(self, args: list[bindings.Value], name: str) -> RuntimeExpr | None: + if name not in self._state.egg_fn_to_callable_refs: + return None + (callable_ref,) = self._state.egg_fn_to_callable_refs[name] + signature = self.__egg_decls__.get_callable_decl(callable_ref).signature + assert isinstance(signature, FunctionSignature) + arg_exprs = tuple( + TypedExprDecl(tp, self._state.value_to_expr(tp, arg)) + for arg_type, 
arg in zip(signature.arg_types, args, strict=True) + for tp in (arg_type.to_just(),) + ) + res_type = signature.semantic_return_type.to_just() + return RuntimeExpr.__from_values__( + self.__egg_decls__, + TypedExprDecl(res_type, CallDecl(callable_ref, arg_exprs)), + ) + + +@dataclass(frozen=True) +class FrozenEGraph: + __egg_decls__: Declarations + decl: EGraphDecl + + def __str__(self) -> str: + return pretty_decl(self.__egg_decls__, self.decl) + + def __repr__(self) -> str: + return str(self) + # Either a constant or a function. ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr @@ -1384,7 +1558,7 @@ def ruleset( @dataclass -class Schedule(DelayedDeclerations): +class Schedule(DelayedDeclarations): """ A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met """ @@ -1414,7 +1588,7 @@ def __add__(self, other: Schedule) -> Schedule: """ Run two schedules in sequence. """ - return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule))) + return Schedule(partial(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule))) @dataclass @@ -1625,11 +1799,14 @@ def set_cost(expr: BaseExpr, cost: i64Like) -> Action: from .builtins import i64 # noqa: PLC0415 expr_runtime = to_runtime_expr(expr) + cost_runtime = to_runtime_expr(convert(cost, i64)) typed_expr_decl = expr_runtime.__egg_typed_expr__ expr_decl = typed_expr_decl.expr assert isinstance(expr_decl, CallDecl), "Can only set cost of calls, not literals or vars" - cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr - return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl)) + return Action( + Declarations.create(expr_runtime, cost_runtime), + SetCostDecl(typed_expr_decl.tp, expr_decl, cost_runtime.__egg_typed_expr__.expr), + ) def let(name: str, expr: BaseExpr) -> Action: @@ -1892,7 
+2069,7 @@ def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | N """ facts = _fact_likes(until) return Schedule( - Thunk.fn(Declarations.create, ruleset, *facts), + partial(Declarations.create, ruleset, *facts), RunDecl( ruleset.__egg_ident__ if ruleset else Ident(""), tuple(f.fact for f in facts), @@ -1935,7 +2112,7 @@ def seq(*schedules: Schedule) -> Schedule: """ Run a sequence of schedules. """ - return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules))) + return Schedule(partial(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules))) def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: @@ -2214,18 +2391,10 @@ def enode_cost(self, name: str, args: list[bindings.Value]) -> int: return self.enode_cost_results[(name, tuple(args))] except KeyError: pass - (callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name] - signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature - assert isinstance(signature, FunctionSignature) - arg_exprs = [ - TypedExprDecl(tp.to_just(), self.egraph._state.value_to_expr(tp.to_just(), arg)) - for (arg, tp) in zip(args, signature.arg_types, strict=True) - ] - res_type = signature.semantic_return_type.to_just() - res = RuntimeExpr.__from_values__( - self.egraph.__egg_decls__, - TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))), - ) + res = self.egraph._values_to_expr(args, name) + if res is None: + msg = f"Cannot compute custom cost for unknown egg function {name!r}" + raise ValueError(msg) index = len(self.enode_cost_expressions) self.enode_cost_expressions.append(res) self.enode_cost_results[(name, tuple(args))] = index diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 2a72418c..1d65aeff 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -12,8 +12,10 @@ from uuid import UUID import cloudpickle +from 
opentelemetry import trace from . import bindings +from ._tracing import call_with_current_trace from .declarations import * from .declarations import ConstructorDecl from .pretty import * @@ -25,6 +27,9 @@ __all__ = ["EGraphState", "span"] +_TRACER = trace.get_tracer(__name__) + + def span(frame_index: int = 0) -> bindings.RustSpan: """ Returns a span for the current file and line. @@ -39,22 +44,26 @@ def span(frame_index: int = 0) -> bindings.RustSpan: return bindings.RustSpan("", 0, 0) +def _normalize_global_let_name(name: str) -> str: + return name if name.startswith("$") else f"${name}" + + @dataclass class EGraphState: """ - State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined. + State of the EGraph declarations and rulesets, so when we pop/push the stack we know what's defined. Used for converting to/from egg and for pretty printing. """ egraph: bindings.EGraph - # The decleratons we have added. + # The declarations we have added. __egg_decls__: Declarations = field(default_factory=Declarations) # Mapping of added rulesets to the added rules rulesets: dict[Ident, set[RewriteOrRuleDecl]] = field(default_factory=dict) # Bidirectional mapping between egg function names and python callable references. - # Note that there are possibly mutliple callable references for a single egg function name, like `+` + # Note that there are possibly multiple callable references for a single egg function name, like `+` # for both int and rational classes. egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field( default_factory=lambda: defaultdict(set, {"!=": {FunctionRef(Ident.builtin("!="))}}) @@ -72,10 +81,14 @@ class EGraphState: # Callables which have cost tables associated with them cost_callables: set[CallableRef] = field(default_factory=set) + # Counter for deterministic synthetic let bindings created while lowering expressions to egg.
+ expr_to_let_counter: int = 0 + # Counter for deterministic synthetic names assigned to unnamed functions. + unnamed_function_counter: int = 0 def copy(self) -> EGraphState: """ - Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping. + Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping. """ return EGraphState( egraph=self.egraph, @@ -87,8 +100,14 @@ def copy(self) -> EGraphState: egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(), expr_to_egg_cache=self.expr_to_egg_cache.copy(), cost_callables=self.cost_callables.copy(), + expr_to_let_counter=self.expr_to_let_counter, + unnamed_function_counter=self.unnamed_function_counter, ) + def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]: + return call_with_current_trace(self.egraph.run_program, *commands) + + @_TRACER.start_as_current_span("run_schedule_to_egg") def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command: """ Turn a run schedule into an egg command. 
@@ -236,7 +255,7 @@ def ruleset_to_egg(self, ident: Ident) -> None: case RulesetDecl(rules): if ident not in self.rulesets: if str(ident): - self.egraph.run_program(bindings.AddRuleset(span(), str(ident))) + self._run_program(bindings.AddRuleset(span(), str(ident))) added_rules = self.rulesets[ident] = set() else: added_rules = self.rulesets[ident] @@ -245,7 +264,7 @@ def ruleset_to_egg(self, ident: Ident) -> None: continue cmd = self.command_to_egg(rule, ident) if cmd is not None: - self.egraph.run_program(cmd) + self._run_program(cmd) added_rules.add(rule) case CombinedRulesetDecl(rulesets): if ident in self.rulesets: @@ -253,7 +272,7 @@ def ruleset_to_egg(self, ident: Ident) -> None: self.rulesets[ident] = set() for ruleset in rulesets: self.ruleset_to_egg(ruleset) - self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), str(ident), list(map(str, rulesets)))) + self._run_program(bindings.UnstableCombinedRuleset(span(), str(ident), list(map(str, rulesets)))) def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | None: match cmd: @@ -316,8 +335,10 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)) case SetDecl(tp, call, rhs): self.type_ref_to_egg(tp) - call_ = self._expr_to_egg(call) - return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs)) + egg_fn, typed_args = self.translate_call(call) + return bindings.Set( + span(), egg_fn, [self.typed_expr_to_egg(arg, False) for arg in typed_args], self._expr_to_egg(rhs) + ) case ExprActionDecl(typed_expr): if expr_to_let: maybe_typed_expr = self._transform_let(typed_expr) @@ -328,7 +349,7 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr)) case ChangeDecl(tp, call, change): self.type_ref_to_egg(tp) - call_ = self._expr_to_egg(call) + egg_fn, typed_args = 
self.translate_call(call) egg_change: bindings._Change match change: case "delete": @@ -337,7 +358,9 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin egg_change = bindings.Subsume() case _: assert_never(change) - return bindings.Change(span(), egg_change, call_.name, call_.args) + return bindings.Change( + span(), egg_change, egg_fn, [self.typed_expr_to_egg(arg, False) for arg in typed_args] + ) case UnionDecl(tp, lhs, rhs): self.type_ref_to_egg(tp) return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs)) @@ -361,9 +384,7 @@ def create_cost_table(self, ref: CallableRef) -> str: signature = self.__egg_decls__.get_callable_decl(ref).signature assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions" signature = replace(signature, return_type=TypeRefWithVars(Ident.builtin("i64"))) - self.egraph.run_program( - bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None) - ) + self._run_program(bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None)) return name def cost_table_name(self, ref: CallableRef) -> str: @@ -393,18 +414,16 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C9 reverse_args = False match decl: case RelationDecl(arg_types, _, _): - self.egraph.run_program( - bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types]) - ) + self._run_program(bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types])) case ConstantDecl(tp, _): - # Use constructor decleration instead of constant b/c constants cannot be extracted + # Use constructor declaration instead of constant b/c constants cannot be extracted # https://github.com/egraphs-good/egglog/issues/334 is_function = self.__egg_decls__._classes[tp.ident].builtin schema = bindings.Schema([], self.type_ref_to_egg(tp)) if is_function: - self.egraph.run_program(bindings.FunctionCommand(span(), 
egg_name, schema, None)) + self._run_program(bindings.FunctionCommand(span(), egg_name, schema, None)) else: - self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False)) + self._run_program(bindings.Constructor(span(), egg_name, schema, None, False)) case FunctionDecl(signature, builtin, _, merge): if isinstance(signature, FunctionSignature): reverse_args = signature.reverse_args @@ -417,27 +436,26 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C9 if merge: msg = "Cannot specify a merge function for a function that returns unit" raise ValueError(msg) - self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input)) + self._run_program(bindings.Relation(span(), egg_name, schema.input)) else: - self.egraph.run_program( + self._run_program( bindings.FunctionCommand( span(), egg_name, self._signature_to_egg_schema(signature), self._expr_to_egg(merge) if merge else None, - ) + ), ) case ConstructorDecl(signature, _, cost, unextractable): - self.egraph.run_program( + self._run_program( bindings.Constructor( span(), egg_name, self._signature_to_egg_schema(signature), cost, unextractable, - ) + ), ) - case _: assert_never(decl) self.callable_ref_to_egg_fn[ref] = egg_name, reverse_args @@ -449,9 +467,10 @@ def _signature_to_egg_schema(self, signature: FunctionSignature) -> bindings.Sch self.type_ref_to_egg(signature.semantic_return_type.to_just()), ) - def type_ref_to_egg(self, ref: JustTypeRef) -> str: # noqa: C901, PLR0912 + def type_ref_to_egg(self, ref: JustTypeRef) -> str: """ - Returns the egg sort name for a type reference, registering it if it is not already registered. + Returns the egg sort name for a type reference, registering it if it is not already registered, and also + recursively registering any type args.
""" try: return self.type_ref_to_egg_sort[ref] @@ -460,47 +479,35 @@ def type_ref_to_egg(self, ref: JustTypeRef) -> str: # noqa: C901, PLR0912 decl = self.__egg_decls__._classes[ref.ident] self.type_ref_to_egg_sort[ref] = egg_name = (not ref.args and decl.egg_name) or _generate_type_egg_name(ref) self.egg_sort_to_type_ref[egg_name] = ref - if not decl.builtin or ref.args: + + if decl.builtin: + # If this has args, create a new parameterized version of the builtin class if ref.args: if ref.ident == Ident.builtin("UnstableFn"): - # UnstableFn is a special case, where the rest of args are collected into a call - if len(ref.args) < 2: - msg = "Zero argument higher order functions not supported" - raise NotImplementedError(msg) type_args: list[bindings._Expr] = [ bindings.Call( span(), self.type_ref_to_egg(ref.args[1]), [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args[2:]], - ), + ) + if len(ref.args) > 1 + else bindings.Lit(span(), bindings.Unit()), bindings.Var(span(), self.type_ref_to_egg(ref.args[0])), ] else: - # If any of methods have another type ref in them process all those first with substituted vars - # so that things like multiset - mapp will be added. Function type must be added first. 
- # Find all args of all methods and find any with type args themselves that are not this type and add them - tcs = TypeConstraintSolver(self.__egg_decls__) - tcs.bind_class(ref) - for method in decl.methods.values(): - if not isinstance((signature := method.signature), FunctionSignature): - continue - for arg_tp in signature.arg_types: - if isinstance(arg_tp, TypeRefWithVars) and arg_tp.args and arg_tp.ident != ref.ident: - self.type_ref_to_egg(tcs.substitute_typevars(arg_tp, ref.ident)) - type_args = [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args] - args = (self.type_ref_to_egg(JustTypeRef(ref.ident)), type_args) - else: - args = None - self.egraph.run_program(bindings.Sort(span(), egg_name, args)) - # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because - # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted - # even if you never use that function. - if decl.builtin: + assert decl.egg_name + self._run_program(bindings.Sort(span(), egg_name, (decl.egg_name, type_args))) + + # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because + # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted + # even if you never use that function. for method_name in decl.class_methods: self.callable_ref_to_egg(ClassMethodRef(ref.ident, method_name)) if decl.init: self.callable_ref_to_egg(InitRef(ref.ident)) + else: + self._run_program(bindings.Sort(span(), egg_name, None)) return egg_name @@ -543,14 +550,14 @@ def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl | None: """ Rewrites this expression as a let binding if it's not already a let binding.
""" - # TODO: Replace with counter so that it works with hash collisions and is more stable - var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}") - if var_decl in self.expr_to_egg_cache: + if isinstance(self.expr_to_egg_cache.get(typed_expr.expr), bindings.Var): return None + var_decl = LetRefDecl(f"$__expr_{self.expr_to_let_counter}") + self.expr_to_let_counter += 1 var_egg = self._expr_to_egg(var_decl) cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))) try: - self.egraph.run_program(cmd) + self._run_program(cmd) # errors when creating let bindings for things like `(vec-empty)` except bindings.EggSmolError: return typed_expr @@ -578,7 +585,7 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, res: bindings._Expr match expr_decl: case LetRefDecl(name): - res = bindings.Var(span(), f"{name}") + res = bindings.Var(span(), _normalize_global_let_name(name)) case UnboundVarDecl(name, egg_name): res = bindings.Var(span(), egg_name or f"_{name}") case LitDecl(value): @@ -617,6 +624,9 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, case ValueDecl(): msg = "Cannot turn a Value into an expression" raise ValueError(msg) + case DummyDecl(): + msg = "Cannot turn a DummyDecl into an expression" + raise ValueError(msg) case _: assert_never(expr_decl.expr) self.expr_to_egg_cache[expr_decl] = res @@ -639,14 +649,12 @@ def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedE case _: assert_never(expr) - def exprs_from_egg( - self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef - ) -> Iterable[TypedExprDecl]: + def exprs_from_egg(self, termdag: bindings.TermDag, terms: list[int], tp: JustTypeRef) -> Iterable[TypedExprDecl]: """ Create a function that can convert from an egg term to a typed expr. 
""" state = FromEggState(self, termdag) - return [state.from_expr(tp, term) for term in terms] + return [state.resolve_term(term_id, tp) for term_id in terms] def _get_possible_types(self, cls_ident: Ident) -> frozenset[JustTypeRef]: """ @@ -674,11 +682,10 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str: return f"{cls_ident}.{name}" case InitRef(cls_ident): return f"{cls_ident}.__init__" - case UnnamedFunctionRef(args, val): - parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [ - str(self.typed_expr_to_egg(val, False)) - ] - return "_".join(parts) + case UnnamedFunctionRef(): + name = f"_lambda_{self.unnamed_function_counter}" + self.unnamed_function_counter += 1 + return name case _: assert_never(ref) @@ -686,7 +693,7 @@ def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value: if isinstance(typed_expr.expr, ValueDecl): return typed_expr.expr.value egg_expr = self.typed_expr_to_egg(typed_expr, False) - return self.egraph.eval_expr(egg_expr)[1] + return call_with_current_trace(self.egraph.eval_expr, egg_expr)[1] def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912 if tp.ident.module != Ident.builtin("").module: @@ -770,7 +777,7 @@ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # return CallDecl( InitRef(Ident.builtin("Set")), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), - (v_tp,), + (v_tp,) if not xs_ else (), ) case "Vec": xs = self.egraph.value_to_vec(value) @@ -778,7 +785,7 @@ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # return CallDecl( InitRef(Ident.builtin("Vec")), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), - (v_tp,), + (v_tp,) if not xs else (), ) case "MultiSet": xs = self.egraph.value_to_multiset(value) @@ -786,19 +793,20 @@ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # 
return CallDecl( InitRef(Ident.builtin("MultiSet")), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), - (v_tp,), + (v_tp,) if not xs else (), ) case "UnstableFn": _names, _args = self.egraph.value_to_function(value) return_tp, *arg_types = tp.args return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types) case _: - raise NotImplementedError(f"Value to expr not implemented for type {tp.ident}") + # If this is not a builtin type, or we don't know how to convert it, just return as value + return ValueDecl(value) def _unstable_fn_value_to_expr( self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef] ) -> PartialCallDecl: - # Similar to FromEggState::from_call but accepts partial list of args and returns in values + # Similar to FromEggState::from_call but reconstructs a partial application from serialized values. # Find first callable ref whose return type matches and fill in arg types. for callable_ref in self.egg_fn_to_callable_refs[name]: signature = self.__egg_decls__.get_callable_decl(callable_ref).signature @@ -806,23 +814,13 @@ def _unstable_fn_value_to_expr( continue if signature.semantic_return_type.ident != return_tp.ident: continue - tcs = TypeConstraintSolver(self.__egg_decls__) - - arg_types, bound_tp_params = tcs.infer_arg_types( - signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None + arg_types = TypeConstraintSolver().infer_arg_types( + signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp ) - args = tuple( TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False) ) - - call_decl = CallDecl( - callable_ref, - args, - # Don't include bound type params if this is just a method, we only needed them for type resolution - # but dont need to store them - bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (), - ) + call_decl = 
CallDecl(callable_ref, args) return PartialCallDecl(call_decl) raise ValueError(f"Function '{name}' not found") @@ -919,11 +917,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl: assert_never(term) return TypedExprDecl(tp, expr_decl) - def from_call( - self, - tp: JustTypeRef, - term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...] - ) -> CallDecl: + def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl: """ Convert a call to a CallDecl. @@ -941,33 +935,32 @@ def from_call( signature = self.decls.get_callable_decl(callable_ref).signature assert isinstance(signature, FunctionSignature) if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef): - # Need OR in case we have class method whose class whas never added as a sort, which would happen + # Need OR in case we have class method whose class was never added as a sort, which would happen # if the class method didn't return that type and no other function did. In this case, we don't need - # to care about the type vars and we we don't need to bind any possible type. + # to care about the type vars and we don't need to bind any possible type. 
possible_types = self.state._get_possible_types(callable_ref.ident) or [None] - cls_name = callable_ref.ident else: possible_types = [None] - cls_name = None for possible_type in possible_types: - tcs = TypeConstraintSolver(self.decls) + tcs = TypeConstraintSolver() if possible_type and possible_type.args: - tcs.bind_class(possible_type) + tcs.bind_class(possible_type, self.decls) + bound_args = possible_type.args + else: + bound_args = () try: - arg_types, bound_tp_params = tcs.infer_arg_types( - signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name + arg_types = tcs.infer_arg_types( + signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp ) + # Zip inside the try because iterating arg_types lazily can raise TypeConstraintError + a_tp = list(zip(term.args, arg_types, strict=False)) except TypeConstraintError: continue - args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False)) - - return CallDecl( - callable_ref, - args, - # Don't include bound type params if this is just a method, we only needed them for type resolution - # but dont need to store them - bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (), - ) + args = tuple(self.resolve_term(a, tp) for a, tp in a_tp) + # Only save bound tp params if needed for inferring the return type, + # i.e. if the set of type vars in the return is not a subset of those in the args + bound_tp_params = () if signature.semantic_return_type.vars.issubset(signature.arg_vars) else bound_args + return CallDecl(callable_ref, args, bound_tp_params) raise ValueError( f"Could not find callable ref for call {term}.
None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}" ) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index b69d8212..42ea19a0 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1,53 +1,5 @@ """ - - -## Lists - -Lists have two main constructors: - -- `List(length, idx_fn)` -- `List.EMPTY` / `initial.append(last)` - -This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic -length that could not be resolved to an integer. - -There are rewrites to convert between these constructors in both directions. The only limitation however is that -`length` has to a real i64 in order to be converted to a cons list. - -When you are writing a function that uses ints, feel free to the `__getitem__` or `length()` methods or match -directly on `List()` constructor. If you can write your function using that interface please do. But for some other -methods whether the resulting length/index function is dependent on the rest of it, you can only define it with a known -length, so you can then use the const list constructors. - -We also support creating lists from vectors. These can be converted one to one to the snoc list representation. - -It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet. - -We are gauranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use -the `.length` and `.__getitem__` methods, unles you want to to depend on it having known length, in which -case you can match directly on the cons list. 
- -To be a list, you must implement two methods: - -* `l.length() -> Int` -* `l.__getitem__(i: Int) -> T` - -There are three main types of constructors for lists which all implement these methods: - -* Functional `List(length, idx_fn)` -* cons (well reversed cons) lists `List.EMPTY` and `l.append(x)` -* Vectors `List.from_vec(vec)` - -Also all lists constructors must be converted to the functional representation, so that we can match on it -and convert lists with known lengths into cons lists and into vectors. - -This is neccessary so that known length lists are properly materialized during extraction. - -Q: Why are they implemented as SNOC lists instead of CONS lists? -A: So that when converting from functional to lists we can use the same index function by starting at the end and folding - that way recursively. - - +Experimental Array API support. """ # mypy: disable-error-code="empty-body" @@ -62,10 +14,13 @@ import sys from collections.abc import Callable from copy import copy +from fractions import Fraction +from functools import partial from types import EllipsisType -from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast +from typing import TYPE_CHECKING, ClassVar, Protocol, TypeAlias, TypeVar, cast import numpy as np +from opentelemetry import trace from egglog import * from egglog.runtime import RuntimeExpr @@ -76,6 +31,8 @@ from collections.abc import Iterator from types import ModuleType +_TRACER = trace.get_tracer(__name__) + # Pretend that exprs are numbers b/c sklearn does isinstance checks numbers.Integral.register(RuntimeExpr) @@ -88,18 +45,46 @@ class Boolean(Expr, ruleset=array_api_ruleset): + """ + A boolean expression + """ + + NEVER: ClassVar[Boolean] + def __init__(self, value: BoolLike) -> None: ... 
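The docstring removed above describes converting a functional list `(length, idx_fn)` into a snoc list by starting at the end, so the same index function can be reused while recursing. A minimal plain-Python sketch of that conversion (the names and the tuple encoding are hypothetical stand-ins, not the library's API):

```python
def functional_to_snoc(length, idx_fn):
    # Build nested (init, last) pairs, recursing from the end so idx_fn
    # can be reused unchanged — the ordering trick the docstring mentions.
    if length == 0:
        return ()  # stands in for List.EMPTY
    return (functional_to_snoc(length - 1, idx_fn), idx_fn(length - 1))


def snoc_to_list(snoc):
    # Flatten the nested (init, last) pairs back into a Python list.
    out = []
    while snoc != ():
        snoc, last = snoc
        out.append(last)
    return list(reversed(out))
```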
@method(preserve=True) def __bool__(self) -> bool: + """ + >>> bool(Boolean(True)) + True + >>> bool(Boolean(False)) + False + """ + # Special case bool so it works when comparing to arrays outside of tracing, like when indexing + if ( + not _CURRENT_EGRAPH + and ( + args := get_callable_args(self, Int.__eq__) + or get_callable_args(self, Boolean.__eq__) # type: ignore[arg-type] + or get_callable_args(self, Value.__eq__) # type: ignore[arg-type] + ) + is not None + ): + return bool(eq(args[0]).to(cast("Int", args[1]))) return self.eval() @method(preserve=True) def eval(self) -> bool: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool) + return try_evaling(self) + @method(preserve=True) # type: ignore[prop-decorator] @property - def to_bool(self) -> Bool: ... + def value(self) -> bool: + match get_callable_args(self, Boolean): + case (b,): + return cast("Bool", b).value + raise ExprValueError(self, "Boolean(b)") def __or__(self, other: BooleanLike) -> Boolean: ... @@ -109,8 +94,17 @@ def __invert__(self) -> Boolean: ... def __eq__(self, other: BooleanLike) -> Boolean: ... # type: ignore[override] + @classmethod + def if_(cls, b: BooleanLike, i: Callable[[], Boolean], j: Callable[[], Boolean]) -> Boolean: + """ + Returns i() if b is True, else j(). Wrapped in callables to avoid eager evaluation. + + >>> bool(Boolean.if_(TRUE, lambda: Boolean(True), lambda: Boolean(False))) + True + """ + -BooleanLike = Boolean | BoolLike +BooleanLike: TypeAlias = Boolean | BoolLike TRUE = Boolean(True) FALSE = Boolean(False) @@ -118,9 +112,10 @@ def __eq__(self, other: BooleanLike) -> Boolean: ... 
# type: ignore[override] @array_api_ruleset.register -def _bool(x: Boolean, i: Int, j: Int, b: Bool): +def _bool( + x: Boolean, y: Boolean, i: Int, j: Int, b: Bool, b1: Bool, bt: Callable[[], Boolean], bf: Callable[[], Boolean] +): return [ - rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b)), rewrite(TRUE | x).to(TRUE), rewrite(FALSE | x).to(x), rewrite(TRUE & x).to(x), @@ -131,6 +126,9 @@ def _bool(x: Boolean, i: Int, j: Int, b: Bool): rewrite(x == x).to(TRUE), # noqa: PLR0124 rewrite(FALSE == TRUE).to(FALSE), rewrite(TRUE == FALSE).to(FALSE), + rewrite(Boolean.if_(TRUE, bt, bf), subsume=True).to(bt()), + rewrite(Boolean.if_(FALSE, bt, bf), subsume=True).to(bf()), + rule(eq(Boolean(b)).to(Boolean(b1)), ne(b).to(b1)).then(panic("Different booleans cannot be equal")), ] @@ -149,6 +147,7 @@ def __invert__(self) -> Int: ... def __lt__(self, other: IntLike) -> Boolean: ... def __le__(self, other: IntLike) -> Boolean: ... + def __abs__(self) -> Int: ... def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override] ... @@ -157,7 +156,10 @@ def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override] # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487 @method(preserve=True) def __hash__(self) -> int: - egraph = _get_current_egraph() + # Only hash if we have a current e-graph saved, like in the middle of tracing + egraph = _CURRENT_EGRAPH + if egraph is None: + return hash(self.__egg_typed_expr__) # type: ignore[attr-defined] egraph.register(self) egraph.run(array_api_schedule) simplified = egraph.extract(self) @@ -228,12 +230,17 @@ def __rxor__(self, other: IntLike) -> Int: ... def __ror__(self, other: IntLike) -> Int: ... - @property - def to_i64(self) -> i64: ... 
- @method(preserve=True) def eval(self) -> int: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64) + return try_evaling(self) + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> int: + match get_callable_args(self, Int): + case (i,): + return cast("i64", i).value + raise ExprValueError(self, "Int(i)") @method(preserve=True) def __index__(self) -> int: @@ -252,11 +259,17 @@ def __bool__(self) -> bool: return bool(self.eval()) @classmethod - def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ... + def if_(cls, b: BooleanLike, i: Callable[[], Int], j: Callable[[], Int]) -> Int: + """ + Returns i() if b is True, else j(). Wrapped in callables to avoid eager evaluation. + + >>> int(Int.if_(TRUE, lambda: Int(1), lambda: Int(2))) + 1 + """ @array_api_ruleset.register -def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int): +def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int, ot: Callable[[], Int], bt: Callable[[], Int]): yield rewrite(Int(i) == Int(i)).to(TRUE) yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE)) @@ -272,8 +285,6 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int): yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE)) yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE)) - yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j)) - yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints")) yield rewrite(Int(i) + Int(j)).to(Int(i + j)) @@ -287,17 +298,20 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int): yield rewrite(Int(i) << Int(j)).to(Int(i << j)) yield rewrite(Int(i) >> Int(j)).to(Int(i >> j)) yield rewrite(~Int(i)).to(Int(~i)) + yield rewrite(Int(i).__abs__()).to(Int(i.__abs__())) - yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o) - yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b) + yield rewrite(Int.if_(TRUE, ot, bt), subsume=True).to(ot()) 
+ yield rewrite(Int.if_(FALSE, ot, bt), subsume=True).to(bt()) yield rewrite(o.__round__(OptionalInt.none)).to(o) # Never cannot be equal to anything real yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int")) + # If two integers are equal, panic + yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Different ints cannot be equal")) -converter(i64, Int, lambda x: Int(x)) +converter(i64, Int, Int) IntLike: TypeAlias = Int | i64Like @@ -309,7 +323,20 @@ def check_index(length: IntLike, idx: IntLike) -> Int: """ length = cast("Int", length) idx = cast("Int", idx) - return Int.if_(((idx >= 0) & (idx < length)), idx, Int.NEVER) + return Int.if_(((idx >= 0) & (idx < length)), lambda: idx, lambda: Int.NEVER) + + +class OptionalInt(Expr, ruleset=array_api_ruleset): + none: ClassVar[OptionalInt] + + @classmethod + def some(cls, value: Int) -> OptionalInt: ... + + +OptionalIntLike: TypeAlias = OptionalInt | IntLike | None + +converter(type(None), OptionalInt, lambda _: OptionalInt.none) +converter(Int, OptionalInt, lambda x: OptionalInt.some(x)) # @array_api_ruleset.register @@ -342,22 +369,33 @@ class Float(Expr, ruleset=array_api_ruleset): @method(cost=3) def __init__(self, value: f64Like) -> None: ... - @property - def to_f64(self) -> f64: ... + @method(cost=2) + @classmethod + def rational(cls, r: BigRat) -> Float: ... @method(preserve=True) - def eval(self) -> float: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64) + def eval(self) -> float | Fraction: + return try_evaling(self) - def abs(self) -> Float: ... 
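`Int.if_` (and its siblings on `Boolean`, `TupleInt`, etc.) now takes callables so neither branch is built eagerly; the subsuming rewrites `if_(TRUE, t, f) -> t()` and `if_(FALSE, t, f) -> f()` expand only the chosen thunk once the condition resolves. The same pattern in plain Python, with `NEVER` as a stand-in sentinel for `Int.NEVER`:

```python
NEVER = object()  # stand-in sentinel for Int.NEVER


def if_(cond, then_thunk, else_thunk):
    # Evaluate only the selected branch, mirroring the subsuming rewrites
    # Int.if_(TRUE, t, f) -> t() and Int.if_(FALSE, t, f) -> f().
    return then_thunk() if cond else else_thunk()


def check_index(length, idx):
    # Mirrors the diff's check_index: a valid index passes through,
    # anything out of range collapses to the NEVER sentinel.
    return if_(0 <= idx < length, lambda: idx, lambda: NEVER)
```

In the e-graph this laziness also keeps recursive definitions such as `foldl` from unfolding forever: the recursive call sits inside a thunk that is only expanded when the length test resolves.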
+ @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> float | Fraction: + match get_callable_args(self, Float.rational): + case (r,): + return r.value + match get_callable_args(self, Float): + case (f,): + return cast("f64", f).value + raise ExprValueError(self, "Float(f) or Float.rational(r)") - @method(cost=2) - @classmethod - def rational(cls, r: BigRat) -> Float: ... + def __float__(self) -> float: + return float(self.eval()) @classmethod def from_int(cls, i: IntLike) -> Float: ... + def abs(self) -> Float: ... + def __truediv__(self, other: FloatLike) -> Float: ... def __mul__(self, other: FloatLike) -> Float: ... @@ -365,6 +403,7 @@ def __mul__(self, other: FloatLike) -> Float: ... def __add__(self, other: FloatLike) -> Float: ... def __sub__(self, other: FloatLike) -> Float: ... + def __abs__(self) -> Float: ... def __pow__(self, other: FloatLike) -> Float: ... def __round__(self, ndigits: OptionalIntLike = None) -> Float: ... @@ -377,24 +416,25 @@ def __gt__(self, other: FloatLike) -> Boolean: ... def __ge__(self, other: FloatLike) -> Boolean: ... -converter(float, Float, lambda x: Float(x)) -converter(Int, Float, lambda x: Float.from_int(x)) +FloatLike: TypeAlias = Float | float | IntLike -FloatLike: TypeAlias = Float | float | IntLike +converter(float, Float, Float) +converter(Int, Float, lambda x: Float.from_int(x)) +converter(BigRat, Float, lambda x: Float.rational(x)) @array_api_ruleset.register def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int): return [ - rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)), rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))), rewrite(Float(f).abs()).to(Float(f), f >= 0.0), rewrite(Float(f).abs()).to(Float(-f), f < 0.0), - # Convert from float to rationl, if its a whole number i.e. can be converted to int + # Convert from float to rational, if its a whole number i.e. 
can be converted to int rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)), # always convert from int to rational rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))), + rewrite(Float.rational(r)).to(Float(r.to_f64())), rewrite(Float(f) + Float(f2)).to(Float(f + f2)), rewrite(Float(f) - Float(f2)).to(Float(f - f2)), rewrite(Float(f) * Float(f2)).to(Float(f * f2)), @@ -417,112 +457,403 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int): rewrite(Float(f) < Float(f2)).to(TRUE, f < f2), rewrite(Float.rational(r) == Float.rational(r)).to(TRUE), rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)), - # round rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())), + rewrite(Float(f).__abs__()).to(Float(f.__abs__())), + # Two different floats cannot be equal + rule(eq(Float(f)).to(Float(f2)), ne(f).to(f2)).then(panic("Different floats cannot be equal")), ] class TupleInt(Expr, ruleset=array_api_ruleset): """ - Should act like a tuple[int, ...] + A tuple of integers. - All constructors should be rewritten to the functional semantics in the __init__ method. - """ + The following is true for all types of tuple: - @classmethod - def var(cls, name: StringLike) -> TupleInt: ... + Tuples have two main constructors: - def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ... + - `Tuple[T](vs: Vec[T]=[])` + - `Tuple.fn(length: Int, idx_fn: Callable[[Int], T])` - EMPTY: ClassVar[TupleInt] - NEVER: ClassVar[TupleInt] + This is so that they can be defined either with a known fixed integer length or a symbolic + length that could not be resolved to an integer. - def append(self, i: IntLike) -> TupleInt: ... 
+ Both constructors must implement two methods: - @classmethod - def single(cls, i: Int) -> TupleInt: - return TupleInt(Int(1), lambda _: i) + * `l.length() -> Int` + * `l.__getitem__(i: Int) -> T` - @method(subsume=True) - @classmethod - def range(cls, stop: IntLike) -> TupleInt: - return TupleInt(stop, lambda i: i) + Lists with a known length will be subsumed into the vector representation. + + Lists that have vecs that are equal will have the elements unified. + + Methods that transform lists should also subsume or be unextractable, so that the vector version will be preferred. + """ + + def __init__(self, vec: VecLike[Int, IntLike] = Vec[Int].empty()) -> None: + """ + Create a TupleInt from a Vec of Ints. + + >>> list(TupleInt(Vec(i64(1), i64(2), i64(3)))) + [i64(1), i64(2), i64(3)] + >>> list(TupleInt()) + [] + """ @classmethod - def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ... + def fn(cls, length: IntLike, idx_fn: Callable[[Int], Int]) -> TupleInt: + """ + Create a TupleInt from a length and an index function. - def __add__(self, other: TupleIntLike) -> TupleInt: + >>> list(TupleInt.fn(3, lambda i: i * 10)) + [Int(0), Int(10), Int(20)] + """ + + def length(self) -> Int: + """ + Return the length of the tuple. + + >>> int(TupleInt([1, 2, 3]).length()) + 3 + >>> int(TupleInt.fn(5, lambda i: i).length()) + 5 + """ + + def __getitem__(self, i: IntLike) -> Int: + """ + Return the integer at index i. + + >>> int(TupleInt([10, 20, 30])[1]) + 20 + >>> int(TupleInt.fn(3, lambda i: i * 10)[2]) + 20 + """ + + def __eq__(self, other: TupleIntLike) -> Boolean: # type: ignore[override] other = cast("TupleInt", other) - return TupleInt( - self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()]) + return Boolean.if_( + self.length() == other.length(), + lambda: TupleInt.range(self.length()).foldl_boolean(lambda acc, i: acc & (self[i] == other[i]), TRUE), + lambda: FALSE, ) - def length(self) -> Int: ... 
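`TupleInt.__eq__` above guards on the lengths first (lazily, via `Boolean.if_`) and only then folds elementwise equality over `range(length)`. In eager plain Python the same shape is:

```python
def tuples_equal(a, b):
    # Mirrors TupleInt.__eq__: check lengths first, then fold
    # acc & (a[i] == b[i]) over all indices, starting from TRUE.
    if len(a) != len(b):
        return False
    acc = True
    for i in range(len(a)):
        acc = acc and (a[i] == b[i])
    return acc
```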
- def __getitem__(self, i: IntLike) -> Int: ... - @method(preserve=True) def __len__(self) -> int: + """ + >>> len(TupleInt([1, 2, 3])) + 3 + """ return self.length().eval() @method(preserve=True) def __iter__(self) -> Iterator[Int]: return iter(self.eval()) - @property - def to_vec(self) -> Vec[Int]: ... - @method(preserve=True) def eval(self) -> tuple[Int, ...]: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec) + """ + Returns the evaluated tuple of Ints. + """ + return try_evaling(self) + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> tuple[Int, ...]: + match get_callable_args(self, TupleInt): + case (vec,): + return tuple(cast("Vec[Int]", vec)) + raise ExprValueError(self, "TupleInt(vec)") + + @method(unextractable=True) + def append(self, i: IntLike) -> TupleInt: + """ + Append an integer to the end of the tuple. + + >>> ti = TupleInt.range(3) + >>> ti2 = ti.append(3) + >>> list(ti2) + [Int(0), Int(1), Int(2), Int(3)] + """ + return TupleInt.fn( + self.length() + 1, lambda j: Int.if_(j == self.length(), lambda: cast("Int", i), lambda: self[j]) + ) + + @method(unextractable=True) + def append_start(self, i: IntLike) -> TupleInt: + """ + Prepend an integer to the start of the tuple. + >>> ti = TupleInt.range(3) + >>> ti2 = ti.append_start( -1) + >>> list(ti2) + [Int(-1), Int(0), Int(1), Int(2)] + """ + return TupleInt.fn(self.length() + 1, lambda j: Int.if_(j == 0, lambda: cast("Int", i), lambda: self[j - 1])) + + @method(unextractable=True) + def __add__(self, other: TupleIntLike) -> TupleInt: + """ + Concatenate two TupleInts. 
+ >>> ti1 = TupleInt.range(3) + >>> ti2 = TupleInt.range(2) + >>> ti3 = ti1 + ti2 + >>> list(ti3) + [Int(0), Int(1), Int(2), Int(0), Int(1)] + """ + other = cast("TupleInt", other) + return TupleInt.fn( + self.length() + other.length(), + lambda i: Int.if_(i < self.length(), lambda: self[i], lambda: other[i - self.length()]), + ) + + @method(unextractable=True) + def drop(self, n: IntLike) -> TupleInt: + """ + Return a new tuple with the first n elements dropped. + + >>> ti = TupleInt([1, 2, 3, 4]) + >>> list(ti.drop(2)) + [Int(3), Int(4)] + """ + return TupleInt.fn(self.length() - n, lambda i: self[i + n]) + + @method(unextractable=True) + def take(self, n: IntLike) -> TupleInt: + """ + Return a new tuple with only the first n elements, + + >>> ti = TupleInt([1, 2, 3, 4]) + >>> list(ti.take(2)) + [Int(1), Int(2)] + """ + return TupleInt.fn(n, self.__getitem__) + + @method(unextractable=True) + def rest(self) -> TupleInt: + """ + Return a new tuple with the first element dropped. + + >>> ti = TupleInt([1, 2, 3]) + >>> list(ti.rest()) + [Int(2), Int(3)] + """ + return self.drop(i64(1)) + + @method(unextractable=True) + def last(self) -> Int: + """ + Return the last element in the tuple. + + >>> ti = TupleInt([1, 2, 3]) + >>> int(ti.last()) + 3 + """ + return self[self.length() - 1] + + @method(unextractable=True) + def drop_last(self) -> TupleInt: + """ + Return a new tuple with the last element dropped. + + >>> ti = TupleInt([1, 2, 3]) + >>> list(ti.drop_last()) + [Int(1), Int(2)] + """ + return TupleInt.fn(self.length() - 1, self.__getitem__) + + @method(unextractable=True) + @classmethod + def range(cls, stop: IntLike) -> TupleInt: + """ + Create a TupleInt with the integers from 0 to stop - 1. + >>> list(TupleInt.range(5)) + [Int(0), Int(1), Int(2), Int(3), Int(4)] + """ + return TupleInt.fn(stop, lambda i: i) - def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ... 
- def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ... - def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: ... + @method(unextractable=True) + def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: + """ + Fold the tuple from the left with the given function and initial value. + + >>> ti = TupleInt([1, 2, 3]) + >>> int(ti.foldl(lambda acc, x: acc + x, i64(0))) + 6 + """ + return Int.if_(self.length() == 0, lambda: init, lambda: f(self.drop_last().foldl(f, init), self.last())) + + @method(unextractable=True) + def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: + """ + Fold the tuple from the left with the given boolean function and initial value. + + >>> ti = TupleInt([1, 2, 3]) + >>> bool(ti.foldl_boolean(lambda acc, x: acc | (x == i64(2)), FALSE)) + True + >>> bool(ti.foldl_boolean(lambda acc, x: acc & (x < i64(3)), TRUE)) + False + """ + return Boolean.if_( + self.length() == 0, lambda: init, lambda: f(self.drop_last().foldl_boolean(f, init), self.last()) + ) + + @method(unextractable=True) + def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: + """ + Fold the tuple from the left with the given tuple function and initial value. + + >>> ti = TupleInt([1, 2, 3]) + >>> ti2 = ti.foldl_tuple_int(lambda acc, x: acc.append(x * 2), TupleInt()) + >>> list(ti2) + [Int(2), Int(4), Int(6)] + """ + init = cast("TupleInt", init) + return TupleInt.if_( + self.length() == 0, lambda: init, lambda: f(self.drop_last().foldl_tuple_int(f, init), self.last()) + ) + + @method(unextractable=True) + def foldl_value(self, f: Callable[[Value, Int], Value], init: ValueLike) -> Value: + """ + Fold the tuple from the left with the given value function and initial value. 
+ >>> ti = TupleInt([1, 2, 3]) + >>> v = ti.foldl_value(lambda acc, x: Value.from_int(x) + acc, Value.from_int(0)) + >>> int(v.to_int) + 6 + """ + init = cast("Value", init) + return Value.if_( + self.length() == 0, lambda: init, lambda: f(self.drop_last().foldl_value(f, init), self.last()) + ) - @method(subsume=True) + @method(unextractable=True) def contains(self, i: Int) -> Boolean: + """ + Returns True if the tuple contains the given integer. + + >>> ti = TupleInt([1, 2, 3]) + >>> bool(ti.contains(i64(2))) + True + >>> bool(ti.contains(i64(4))) + False + """ return self.foldl_boolean(lambda acc, j: acc | (i == j), FALSE) - @method(subsume=True) + @method(unextractable=True) def filter(self, f: Callable[[Int], Boolean]) -> TupleInt: + """ + Returns a new tuple with only the elements that satisfy the given predicate. + + >>> ti = TupleInt([1, 2, 3, 4]) + >>> list(ti.filter(lambda x: x % Int(2) == Int(0))) + [Int(2), Int(4)] + >>> list(ti.filter(lambda x: x > Int(2))) + [Int(3), Int(4)] + """ return self.foldl_tuple_int( - lambda acc, v: TupleInt.if_(f(v), acc.append(v), acc), - TupleInt.EMPTY, + lambda acc, v: TupleInt.if_(f(v), lambda: acc.append(v), lambda: acc), + TupleInt(), ) - @method(subsume=True) - def map(self, f: Callable[[Int], Int]) -> TupleInt: - return TupleInt(self.length(), lambda i: f(self[i])) - @classmethod - def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ... + def if_(cls, b: BooleanLike, i: Callable[[], TupleInt], j: Callable[[], TupleInt]) -> TupleInt: + """ + Returns i() if b is True, else j(). Wrapped in callables to avoid eager evaluation. 
- def drop(self, n: Int) -> TupleInt: - return TupleInt(self.length() - n, lambda i: self[i + n]) + >>> ti1 = TupleInt([1, 2]) + >>> ti2 = TupleInt([3, 4]) + >>> ti = TupleInt.if_(TRUE, lambda: ti1, lambda: ti2) + >>> list(map(int, ti)) + [1, 2] + """ + @method(unextractable=True) def product(self) -> Int: - return self.foldl(lambda acc, i: acc * i, Int(1)) + """ + Return the product of all elements in the tuple. - def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt: - return TupleTupleInt(self.length(), lambda i: f(self[i])) + >>> ti = TupleInt([1, 2, 3, 4]) + >>> int(ti.product()) + 24 + """ + return self.foldl(lambda acc, i: acc * i, Int(1)) + @method(unextractable=True) def select(self, indices: TupleIntLike) -> TupleInt: """ Return a new tuple with the elements at the given indices + + >>> ti = TupleInt([10, 20, 30, 40]) + >>> indices = TupleInt([1, 3]) + >>> list(ti.select(indices)) + [Int(20), Int(40)] """ indices = cast("TupleInt", indices) return indices.map(lambda i: self[i]) + @method(unextractable=True) def deselect(self, indices: TupleIntLike) -> TupleInt: """ Return a new tuple with the elements not at the given indices + + >>> ti = TupleInt([10, 20, 30, 40]) + >>> indices = TupleInt([1, 3]) + >>> list(ti.deselect(indices)) + [Int(10), Int(30)] """ indices = cast("TupleInt", indices) return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i]) + @method(unextractable=True) + def reverse(self) -> TupleInt: + """ + Return a new tuple with the elements in reverse order. + + >>> ti = TupleInt([1, 2, 3]) + >>> list(ti.reverse()) + [Int(3), Int(2), Int(1)] + """ + return TupleInt.fn(self.length(), lambda i: self[self.length() - i - 1]) + + @method(unextractable=True) + def map(self, f: Callable[[Int], Int]) -> TupleInt: + """ + Returns a new tuple with each element transformed by the given function. 
+ + >>> ti = TupleInt([1, 2]) + >>> list(ti.map(lambda x: x * Int(2))) + [Int(2), Int(4)] + """ + return TupleInt.fn(self.length(), lambda i: f(self[i])) + + # Put at bottom so can use previous methods when resolving + @method(unextractable=True) + def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt: + """ + Returns a new tuple of TupleInts with each element transformed by the given function. + + >>> ti = TupleInt([1, 2]) + >>> tti = ti.map_tuple_int(lambda x: TupleInt([x, x + 10])) + >>> list(tti[0]) + [Int(1), Int(11)] + >>> list(tti[1]) + [Int(2), Int(12)] + """ + return TupleTupleInt.fn(self.length(), lambda i: f(self[i])) + + @method(unextractable=True) + def map_value(self, f: Callable[[Int], Value]) -> TupleValue: + """ + Returns a new tuple of Values with each element transformed by the given function. + + >>> ti = TupleInt([1, 2]) + >>> tv = ti.map_value(lambda x: Value.from_int(x * 3)) + >>> list(tv) + [Value.from_int(Int(3)), Value.from_int(Int(6))] + """ + return TupleValue.fn(self.length(), lambda i: f(self[i])) -converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x)) +converter(Vec[Int], TupleInt, TupleInt) TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike] @@ -530,84 +861,35 @@ def deselect(self, indices: TupleIntLike) -> TupleInt: def _tuple_int( i: Int, i2: Int, - f: Callable[[Int, Int], Int], - bool_f: Callable[[Boolean, Int], Boolean], idx_fn: Callable[[Int], Int], - tuple_int_f: Callable[[TupleInt, Int], TupleInt], vs: Vec[Int], - b: Boolean, + vs2: Vec[Int], ti: TupleInt, - ti2: TupleInt, k: i64, + lt: Callable[[], TupleInt], + lf: Callable[[], TupleInt], ): - return [ - rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs)), - # Functional access - rewrite(TupleInt(i, idx_fn).length()).to(i), - rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2))), - # cons access - rewrite(TupleInt.EMPTY.length()).to(Int(0)), - rewrite(TupleInt.EMPTY[i]).to(Int.NEVER), - 
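The fold family above (`foldl`, `foldl_boolean`, `foldl_tuple_int`, `foldl_value`) all share one recursion scheme: recurse on `drop_last()` and combine the result with `last()`, guarded by a lazy length check. A plain-Python rendering of that scheme:

```python
def foldl(xs, f, init):
    # Mirrors TupleInt.foldl: if_(length == 0, init,
    #                            f(drop_last(xs).foldl(f, init), last(xs)))
    if len(xs) == 0:
        return init
    return f(foldl(xs[:-1], f, init), xs[-1])
```

Recursing on the front of the tuple and combining with the last element is what makes this a *left* fold: the initial value is combined with `xs[0]` first, then `xs[1]`, and so on.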
rewrite(ti.append(i).length()).to(ti.length() + 1), - rewrite(ti.append(i)[i2]).to(Int.if_(i2 == ti.length(), i, ti[i2])), - # cons to functional (removed this so that there is not infinite replacements between the,) - # rewrite(TupleInt.EMPTY).to(TupleInt(0, lambda _: Int.NEVER)), - # rewrite(TupleInt(i, idx_fn).append(i2)).to(TupleInt(i + 1, lambda j: Int.if_(j == i, i2, idx_fn(j)))), - # functional to cons - rewrite(TupleInt(0, idx_fn), subsume=True).to(TupleInt.EMPTY), - rewrite(TupleInt(Int(k), idx_fn), subsume=True).to(TupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0), - # cons to vec - rewrite(TupleInt.EMPTY).to(TupleInt.from_vec(Vec[Int]())), - rewrite(TupleInt.from_vec(vs).append(i)).to(TupleInt.from_vec(vs.append(Vec(i)))), - # fold - rewrite(TupleInt.EMPTY.foldl(f, i), subsume=True).to(i), - rewrite(ti.append(i2).foldl(f, i), subsume=True).to(f(ti.foldl(f, i), i2)), - # fold boolean - rewrite(TupleInt.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b), - rewrite(ti.append(i2).foldl_boolean(bool_f, b), subsume=True).to(bool_f(ti.foldl_boolean(bool_f, b), i2)), - # fold tuple_int - rewrite(TupleInt.EMPTY.foldl_tuple_int(tuple_int_f, ti), subsume=True).to(ti), - rewrite(ti.append(i2).foldl_tuple_int(tuple_int_f, ti2), subsume=True).to( - tuple_int_f(ti.foldl_tuple_int(tuple_int_f, ti2), i2) - ), - # if_ - rewrite(TupleInt.if_(TRUE, ti, ti2), subsume=True).to(ti), - rewrite(TupleInt.if_(FALSE, ti, ti2), subsume=True).to(ti2), - # unify append - rule(eq(ti.append(i)).to(ti2.append(i2))).then(union(ti).with_(ti2), union(i).with_(i2)), - ] + # Unify the elements of equal tuples + yield rule(eq(ti).to(TupleInt(vs)), eq(ti).to(TupleInt(vs2)), vs != vs2).then(vs | vs2) + yield rewrite(TupleInt.fn(i2, idx_fn).length(), subsume=False).to(i2) + yield rewrite(TupleInt.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i))) -class TupleTupleInt(Expr, ruleset=array_api_ruleset): - @classmethod - def var(cls, name: StringLike) -> TupleTupleInt: ... 
+ yield rewrite(TupleInt(vs).length()).to(Int(vs.length())) + yield rewrite(TupleInt(vs)[Int(k)]).to(vs[k]) - EMPTY: ClassVar[TupleTupleInt] + yield rewrite(TupleInt.if_(TRUE, lt, lf), subsume=True).to(lt()) + yield rewrite(TupleInt.if_(FALSE, lt, lf), subsume=True).to(lf()) - def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ... + yield rewrite(TupleInt.fn(Int(k), idx_fn), subsume=True).to(TupleInt(k.range().map(lambda i: idx_fn(Int(i))))) - @method(subsume=True) - @classmethod - def single(cls, i: TupleIntLike) -> TupleTupleInt: - i = cast("TupleInt", i) - return TupleTupleInt(1, lambda _: i) - @method(subsume=True) +class TupleTupleInt(Expr, ruleset=array_api_ruleset): + def __init__(self, vec: VecLike[TupleInt, TupleIntLike] = ()) -> None: ... @classmethod - def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ... - - def append(self, i: TupleIntLike) -> TupleTupleInt: ... - - def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt: - other = cast("TupleTupleInt", other) - return TupleTupleInt( - self.length() + other.length(), - lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]), - ) - + def fn(cls, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> TupleTupleInt: ... def length(self) -> Int: ... def __getitem__(self, i: IntLike) -> TupleInt: ... - @method(preserve=True) def __len__(self) -> int: return self.length().eval() @@ -616,22 +898,60 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[TupleInt]: return iter(self.eval()) - @property - def to_vec(self) -> Vec[TupleInt]: ... 
- @method(preserve=True) def eval(self) -> tuple[TupleInt, ...]: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec) + return try_evaling(self) + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> tuple[TupleInt, ...]: + match get_callable_args(self, TupleTupleInt): + case (vec,): + return tuple(cast("Vec[TupleInt]", vec)) + raise ExprValueError(self, "TupleTupleInt(vec)") + + @method(unextractable=True) + def append(self, i: TupleIntLike) -> TupleTupleInt: + return TupleTupleInt.fn( + self.length() + 1, lambda j: TupleInt.if_(j == self.length(), lambda: cast("TupleInt", i), lambda: self[j]) + ) + + @method(unextractable=True) + def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt: + other = cast("TupleTupleInt", other) + return TupleTupleInt.fn( + self.length() + other.length(), + lambda i: TupleInt.if_(i < self.length(), lambda: self[i], lambda: other[i - self.length()]), + ) + @method(unextractable=True) def drop(self, n: Int) -> TupleTupleInt: - return TupleTupleInt(self.length() - n, lambda i: self[i + n]) + return TupleTupleInt.fn(self.length() - n, lambda i: self[i + n]) + @method(unextractable=True) def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt: - return TupleInt(self.length(), lambda i: f(self[i])) + return TupleInt.fn(self.length(), lambda i: f(self[i])) + + @method(unextractable=True) + def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: + return Value.if_( + self.length() == 0, + lambda: cast("Value", init), + lambda: f(self.drop_last().foldl_value(f, init), self.last()), + ) + + @method(unextractable=True) + def last(self) -> TupleInt: + return self[self.length() - 1] + + @method(unextractable=True) + def drop_last(self) -> TupleTupleInt: + return TupleTupleInt.fn(self.length() - 1, self.__getitem__) - def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ... 
+    @classmethod
+    def if_(cls, b: BooleanLike, i: Callable[[], TupleTupleInt], j: Callable[[], TupleTupleInt]) -> TupleTupleInt: ...
-    @method(subsume=True)
+    @method(unextractable=True)
     def product(self) -> TupleTupleInt:
         """
         Cartesian product of inputs
@@ -639,73 +959,50 @@ def product(self) -> TupleTupleInt:
         https://docs.python.org/3/library/itertools.html#itertools.product
         https://github.com/saulshanabrook/saulshanabrook/discussions/39
+
+        >>> [[int(x) for x in row] for row in TupleTupleInt([TupleInt([1, 2]), TupleInt([3, 4])]).product()]
+        [[1, 3], [1, 4], [2, 3], [2, 4]]
         """
-        return TupleTupleInt(
+        return TupleTupleInt.fn(
             self.map_int(lambda x: x.length()).product(),
-            lambda i: TupleInt(
+            lambda i: TupleInt.fn(
                 self.length(),
                 lambda j: self[j][(i // self.drop(j + 1).map_int(lambda x: x.length()).product()) % self[j].length()],
             ),
         )
-converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))
+converter(Vec[TupleInt], TupleTupleInt, TupleTupleInt)
 TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
 @array_api_ruleset.register
 def _tuple_tuple_int(
-    length: Int,
-    fn: Callable[[TupleInt], Int],
+    i: Int,
+    i2: Int,
     idx_fn: Callable[[Int], TupleInt],
-    f: Callable[[Value, TupleInt], Value],
-    i: Value,
-    k: i64,
-    idx: Int,
     vs: Vec[TupleInt],
-    ti: TupleInt,
-    ti1: TupleInt,
-    tti: TupleTupleInt,
-    tti1: TupleTupleInt,
+    vs2: Vec[TupleInt],
+    ti: TupleTupleInt,
+    k: i64,
+    lt: Callable[[], TupleTupleInt],
+    lf: Callable[[], TupleTupleInt],
 ):
-    yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs))
-    yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
-    yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length)))
-
-    # cons access
-    yield rewrite(TupleTupleInt.EMPTY.length()).to(Int(0))
-    yield rewrite(TupleTupleInt.EMPTY[idx]).to(TupleInt.NEVER)
-    yield rewrite(tti.append(ti).length()).to(tti.length() + 1)
-    yield rewrite(tti.append(ti)[idx]).to(TupleInt.if_(idx == tti.length(), ti, tti[idx]))
-
-    # functional to cons
-    yield rewrite(TupleTupleInt(0, idx_fn), subsume=True).to(TupleTupleInt.EMPTY)
-    yield rewrite(TupleTupleInt(Int(k), idx_fn), subsume=True).to(
-        TupleTupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
-    )
-    # cons to vec
-    yield rewrite(TupleTupleInt.EMPTY).to(TupleTupleInt.from_vec(Vec[TupleInt]()))
-    yield rewrite(TupleTupleInt.from_vec(vs).append(ti)).to(TupleTupleInt.from_vec(vs.append(Vec(ti))))
-    # fold value
-    yield rewrite(TupleTupleInt.EMPTY.foldl_value(f, i), subsume=True).to(i)
-    yield rewrite(tti.append(ti).foldl_value(f, i), subsume=True).to(f(tti.foldl_value(f, i), ti))
-
-    # unify append
-    yield rule(eq(tti.append(ti)).to(tti1.append(ti1))).then(union(tti).with_(tti1), union(ti).with_(ti1))
+    yield rule(eq(ti).to(TupleTupleInt(vs)), eq(ti).to(TupleTupleInt(vs2)), vs != vs2).then(vs | vs2)
+    yield rewrite(TupleTupleInt.fn(i2, idx_fn).length(), subsume=False).to(i2)
+    yield rewrite(TupleTupleInt.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
-class OptionalInt(Expr, ruleset=array_api_ruleset):
-    none: ClassVar[OptionalInt]
-
-    @classmethod
-    def some(cls, value: Int) -> OptionalInt: ...
-
+    yield rewrite(TupleTupleInt(vs).length(), subsume=False).to(Int(vs.length()))
+    yield rewrite(TupleTupleInt(vs)[Int(k)], subsume=False).to(vs[k])
-OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
+    yield rewrite(TupleTupleInt.fn(Int(k), idx_fn), subsume=True).to(
+        TupleTupleInt(k.range().map(lambda i: idx_fn(Int(i))))
+    )
-converter(type(None), OptionalInt, lambda _: OptionalInt.none)
-converter(Int, OptionalInt, OptionalInt.some)
+    yield rewrite(TupleTupleInt.if_(TRUE, lt, lf), subsume=True).to(lt())
+    yield rewrite(TupleTupleInt.if_(FALSE, lt, lf), subsume=True).to(lf())
 class DType(Expr, ruleset=array_api_ruleset):
@@ -734,7 +1031,7 @@ def __eq__(self, other: DType) -> Boolean:  # type: ignore[override]
 @array_api_ruleset.register
 def _():
     for l, r in itertools.product(_DTYPES, repeat=2):
-        yield rewrite(l == r).to(TRUE if l is r else FALSE)
+        yield rewrite(l == r, subsume=False).to(TRUE if l is r else FALSE)
 class IsDtypeKind(Expr, ruleset=array_api_ruleset):
@@ -787,31 +1084,36 @@ def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
     ]
-# TODO: Add pushdown for math on scalars to values
-# and add replacements
-
-
 class Value(Expr, ruleset=array_api_ruleset):
     NEVER: ClassVar[Value]
     @classmethod
-    def int(cls, i: IntLike) -> Value: ...
+    def var(cls, name: StringLike) -> Value: ...
+
+    @classmethod
+    def from_int(cls, i: IntLike) -> Value: ...
     @classmethod
-    def float(cls, f: FloatLike) -> Value: ...
+    def from_float(cls, f: FloatLike) -> Value: ...
     @classmethod
-    def bool(cls, b: BooleanLike) -> Value: ...
+    def from_bool(cls, b: BooleanLike) -> Value: ...
     def isfinite(self) -> Boolean: ...
     def __lt__(self, other: ValueLike) -> Value: ...
+    def __le__(self, other: ValueLike) -> Boolean: ...
+    def __gt__(self, other: ValueLike) -> Boolean: ...
+    def __ge__(self, other: ValueLike) -> Boolean: ...
+    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
     def __truediv__(self, other: ValueLike) -> Value: ...
-    def __mul__(self, other: ValueLike) -> Value: ...
-    def __add__(self, other: ValueLike) -> Value: ...
+    def __sub__(self, other: ValueLike) -> Value: ...
+    def __pow__(self, other: ValueLike) -> Value: ...
+
+    def __abs__(self) -> Value: ...
     def astype(self, dtype: DType) -> Value: ...
@@ -824,10 +1126,10 @@ def dtype(self) -> DType:
         """
     @property
-    def to_bool(self) -> Boolean: ...
+    def to_int(self) -> Int: ...
     @property
-    def to_int(self) -> Int: ...
+    def to_bool(self) -> Boolean: ...
     @property
     def to_truthy_value(self) -> Value:
@@ -842,87 +1144,249 @@ def real(self) -> Value: ...
     def sqrt(self) -> Value: ...
     @classmethod
-    def if_(cls, b: BooleanLike, i: ValueLike, j: ValueLike) -> Value: ...
+    def if_(cls, b: BooleanLike, i: Callable[[], Value], j: Callable[[], Value]) -> Value: ...
-    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    @method(preserve=True)  # type: ignore[prop-decorator]
+    @property
+    def value(self) -> bool | int | float | Fraction:
+        match get_callable_args(self, Value.from_int):
+            case (i,):
+                return cast("Int", i).value
+        match get_callable_args(self, Value.from_float):
+            case (f,):
+                return cast("Float", f).value
+        match get_callable_args(self, Value.from_bool):
+            case (b,):
+                return cast("Boolean", b).value
+        raise ExprValueError(self, "Value.int|float|bool(...)")
+
+    @method(cost=100000000)
+    def diff(self, v: Value) -> Value:
+        """
+        Differentiate self with respect to v.
+
+        >>> x = Value.var("x")
+        >>> int(x.diff(x).to_int)
+        1
+        >>> int(x.diff(Value.var("y")).to_int)
+        0
+        >>> int((x + Value.from_int(2)).diff(x).to_int)
+        1
+        """
 ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike
-converter(Int, Value, Value.int)
-converter(Float, Value, Value.float)
-converter(Boolean, Value, Value.bool)
+converter(Int, Value, lambda x: Value.from_int(x))
+converter(Float, Value, lambda x: Value.from_float(x))
+converter(Boolean, Value, lambda x: Value.from_bool(x))
 converter(Value, Int, lambda x: x.to_int, 10)
 @array_api_ruleset.register
-def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float, b1: Boolean):
+def _value(
+    i: Int,
+    f: Float,
+    b: Boolean,
+    v: Value,
+    v1: Value,
+    v2: Value,
+    v3: Value,
+    i1: Int,
+    f1: Float,
+    b1: Boolean,
+    vt: Callable[[], Value],
+    v1t: Callable[[], Value],
+    s: String,
+    s1: String,
+    i_: i64,
+):
     # Default dtypes
     # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
-    yield rewrite(Value.int(i).dtype).to(DType.int64)
-    yield rewrite(Value.float(f).dtype).to(DType.float64)
-    yield rewrite(Value.bool(b).dtype).to(DType.bool)
+    yield rewrite(Value.from_int(i).dtype).to(DType.int64)
+    yield rewrite(Value.from_float(f).dtype).to(DType.float64)
+    yield rewrite(Value.from_bool(b).dtype).to(DType.bool)
-    yield rewrite(Value.bool(b).to_bool).to(b)
-    yield rewrite(Value.int(i).to_int).to(i)
+    yield rewrite(Value.from_int(i).to_int).to(i)
+    yield rewrite(Value.from_bool(b).to_bool).to(b)
-    yield rewrite(Value.bool(b).to_truthy_value).to(Value.bool(b))
+    yield rewrite(Value.from_bool(b).to_truthy_value).to(Value.from_bool(b))
     # TODO: Add more rules for to_bool_value
-    yield rewrite(Value.float(f).conj()).to(Value.float(f))
-    yield rewrite(Value.float(f).real()).to(Value.float(f))
-    yield rewrite(Value.int(i).real()).to(Value.int(i))
-    yield rewrite(Value.int(i).conj()).to(Value.int(i))
+    yield rewrite(Value.from_float(f).conj()).to(Value.from_float(f))
+    yield rewrite(Value.from_float(f).real()).to(Value.from_float(f))
+    yield rewrite(Value.from_int(i).real()).to(Value.from_int(i))
+    yield rewrite(Value.from_int(i).conj()).to(Value.from_int(i))
-    yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))
+    yield rewrite(Value.from_float(f).sqrt()).to(Value.from_float(f ** (0.5)))
-    yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)
+    yield rewrite(Value.from_float(Float.rational(BigRat(0, 1))) + v).to(v)
-    yield rewrite(Value.if_(TRUE, v, v1)).to(v)
-    yield rewrite(Value.if_(FALSE, v, v1)).to(v1)
+    yield rewrite(Value.if_(TRUE, vt, v1t), subsume=True).to(vt())
+    yield rewrite(Value.if_(FALSE, vt, v1t), subsume=True).to(v1t())
     # ==
-    yield rewrite(Value.int(i) == Value.int(i1)).to(i == i1)
-    yield rewrite(Value.float(f) == Value.float(f1)).to(f == f1)
-    yield rewrite(Value.bool(b) == Value.bool(b1)).to(b == b1)
+    yield rewrite(Value.from_int(i) == Value.from_int(i1)).to(i == i1)
+    yield rewrite(Value.from_float(f) == Value.from_float(f1)).to(f == f1)
+    yield rewrite(Value.from_bool(b) == Value.from_bool(b1)).to(b == b1)
+    # >=
+    yield rewrite(Value.from_int(i) >= Value.from_int(i1)).to(i >= i1)
+    yield rewrite(Value.from_float(f) >= Value.from_float(f1)).to(f >= f1)
+    # <=
+    yield rewrite(Value.from_int(i) <= Value.from_int(i1)).to(i <= i1)
+    yield rewrite(Value.from_float(f) <= Value.from_float(f1)).to(f <= f1)
+    # >
+    yield rewrite(Value.from_int(i) > Value.from_int(i1)).to(i > i1)
+    yield rewrite(Value.from_float(f) > Value.from_float(f1)).to(f > f1)
+    # <
+    yield rewrite(Value.from_int(i) < Value.from_int(i1)).to(Value.from_bool(i < i1))
+    yield rewrite(Value.from_float(f) < Value.from_float(f1)).to(Value.from_bool(f < f1))
+
+    # /
+    yield rewrite(Value.from_float(f) / Value.from_float(f1)).to(Value.from_float(f / f1))
+    # *
+    yield rewrite(Value.from_float(f) * Value.from_float(f1)).to(Value.from_float(f * f1))
+    yield rewrite(Value.from_int(i) * Value.from_int(i1)).to(Value.from_int(i * i1))
+    # +
+    yield rewrite(Value.from_float(f) + Value.from_float(f1)).to(Value.from_float(f + f1))
+    yield rewrite(Value.from_int(i) + Value.from_int(i1)).to(Value.from_int(i + i1))
+    # -
+    yield rewrite(Value.from_float(f) - Value.from_float(f1)).to(Value.from_float(f - f1))
+    yield rewrite(Value.from_int(i) - Value.from_int(i1)).to(Value.from_int(i - i1))
+    # **
+    yield rewrite(Value.from_float(f) ** Value.from_float(f1)).to(Value.from_float(f**f1))
+    yield rewrite(Value.from_int(i) ** Value.from_int(i1)).to(Value.from_int(i**i1))
+    yield rewrite(Value.from_int(i) ** Value.from_float(f1)).to(Value.from_float(Float.from_int(i) ** f1))
+
+    # abs
+    yield rewrite(Value.from_int(i).__abs__()).to(Value.from_int(i.__abs__()))
+    yield rewrite(Value.from_float(f).__abs__()).to(Value.from_float(f.__abs__()))
+    # abs(x) **2 = x**2
+    yield rewrite(v.__abs__() ** Value.from_float(Float.rational(BigRat(2, 1)))).to(v ** Value.from_float(2))
+
+    # ** distributes over division
+    yield rewrite((v1 / v) ** v2, subsume=False).to(v1**v2 / (v**v2))
+    # x ** y ** z = x ** (y * z)
+    yield rewrite((v**v1) ** v2, subsume=False).to(v ** (v1 * v2))
+    yield rewrite(Value.from_float(f) * Value.from_int(i)).to(Value.from_float(f * Float.from_int(i)))
+    yield rewrite(v ** Value.from_float(Float.rational(BigRat(1, 1)))).to(v)
+    yield rewrite(Value.from_float(Float.from_int(i))).to(Value.from_int(i))
+
+    # Upcast binary op
+    yield rewrite(Value.from_int(i) * Value.from_float(f)).to(Value.from_float(Float.from_int(i)) * Value.from_float(f))
+
+    # Integer identities / annihilators
+    yield rewrite(v + Value.from_int(0)).to(v)
+    yield rewrite(Value.from_int(0) + v).to(v)
+    yield rewrite(v * Value.from_int(1)).to(v)
+    yield rewrite(Value.from_int(1) * v).to(v)
+    yield rewrite(v * Value.from_int(0)).to(Value.from_int(0))
+    yield rewrite(Value.from_int(0) * v).to(Value.from_int(0))
+    yield rewrite(v - Value.from_int(0)).to(v)
+    yield rewrite(v**1).to(v)
+
+    # Differentiation rules
+    yield rewrite(v.diff(v)).to(Value.from_int(1))
+    yield rewrite((v1 + v2).diff(v3)).to(v1.diff(v3) + v2.diff(v3))
+    yield rewrite((v1 - v2).diff(v3)).to(v1.diff(v3) - v2.diff(v3))
+    yield rewrite((v1 * v2).diff(v3)).to(v1.diff(v3) * v2 + v1 * v2.diff(v3))
+    yield rewrite((v1 / v2).diff(v3)).to((v1.diff(v3) * v2 - v1 * v2.diff(v3)) / (v2 * v2))
+    yield rewrite((v1**i_).diff(v3)).to((v1 * v1 ** (i_ - 1)).diff(v3), i_ > 1)
+    yield rewrite(Value.var(s).diff(Value.var(s1))).to(Value.from_int(0), s != s1)
+    yield rewrite(Value.from_int(i_).diff(Value.var(s))).to(Value.from_int(0))
 class TupleValue(Expr, ruleset=array_api_ruleset):
-    EMPTY: ClassVar[TupleValue]
+    def __init__(self, vec: VecLike[Value, ValueLike] = ()) -> None: ...
+    @classmethod
+    def fn(cls, length: IntLike, idx_fn: Callable[[Int], Value]) -> TupleValue: ...
+    def length(self) -> Int: ...
+    def __getitem__(self, i: IntLike) -> Value: ...
+    @method(preserve=True)
+    def __len__(self) -> int:
+        return self.length().eval()
-    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Value]) -> None: ...
+    @method(preserve=True)
+    def __iter__(self) -> Iterator[Value]:
+        return iter(self.eval())
-    def append(self, i: ValueLike) -> TupleValue: ...
+    @method(preserve=True)
+    def eval(self) -> tuple[Value, ...]:
+        return try_evaling(self)
-    @classmethod
-    def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...
+    @method(preserve=True)  # type: ignore[prop-decorator]
+    @property
+    def value(self) -> tuple[Value, ...]:
+        match get_callable_args(self, TupleValue):
+            case (vec,):
+                return tuple(cast("Vec[Value]", vec))
+        raise ExprValueError(self, "TupleValue(vec)")
+
+    @method(unextractable=True)
+    def append(self, i: ValueLike) -> TupleValue:
+        return TupleValue.fn(
+            self.length() + 1, lambda j: Value.if_(j == self.length(), lambda: cast("Value", i), lambda: self[j])
+        )
+    @method(unextractable=True)
     def __add__(self, other: TupleValueLike) -> TupleValue:
         other = cast("TupleValue", other)
-        return TupleValue(
+        return TupleValue.fn(
             self.length() + other.length(),
-            lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
+            lambda i: Value.if_(i < self.length(), lambda: self[i], lambda: other[i - self.length()]),
         )
-    def length(self) -> Int: ...
+    @method(unextractable=True)
+    def last(self) -> Value:
+        return self[self.length() - 1]
-    def __getitem__(self, i: Int) -> Value: ...
+    @method(unextractable=True)
+    def drop_last(self) -> TupleValue:
+        return TupleValue.fn(self.length() - 1, self.__getitem__)
-    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...
+    @method(unextractable=True)
+    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean:
+        return Boolean.if_(
+            self.length() == 0,
+            lambda: cast("Boolean", init),
+            lambda: f(self.drop_last().foldl_boolean(f, init), self.last()),
+        )
+
+    @method(subsume=False)
+    def foldl_value(self, f: Callable[[Value, Value], Value], init: ValueLike) -> Value:
+        return Value.if_(
+            self.length() == 0,
+            lambda: cast("Value", init),
+            lambda: f(self.drop_last().foldl_value(f, init), self.last()),
+        )
+    @method(unextractable=True)
+    def map_value(self, f: Callable[[Value], Value]) -> TupleValue:
+        return TupleValue.fn(self.length(), lambda i: f(self[i]))
+
+    @method(unextractable=True)
     def contains(self, value: ValueLike) -> Boolean:
         value = cast("Value", value)
         return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)
-    @method(subsume=True)
+    @method(unextractable=True)
     @classmethod
     def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
         ti = cast("TupleInt", ti)
-        return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
+        return TupleValue.fn(ti.length(), lambda i: Value.from_int(ti[i]))
+    @classmethod
+    def if_(cls, b: BooleanLike, i: Callable[[], TupleValue], j: Callable[[], TupleValue]) -> TupleValue: ...
-converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
+
+converter(Vec[Value], TupleValue, TupleValue)
 converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))
 TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
@@ -930,43 +1394,28 @@ def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
 @array_api_ruleset.register
 def _tuple_value(
-    length: Int,
+    i: Int,
+    i2: Int,
     idx_fn: Callable[[Int], Value],
-    k: i64,
-    idx: Int,
     vs: Vec[Value],
-    v: Value,
-    v1: Value,
-    tv: TupleValue,
-    tv1: TupleValue,
-    bool_f: Callable[[Boolean, Value], Boolean],
-    b: Boolean,
+    vs2: Vec[Value],
+    ti: TupleValue,
+    k: i64,
+    lt: Callable[[], TupleValue],
+    lf: Callable[[], TupleValue],
 ):
-    yield rewrite(TupleValue(length, idx_fn).length()).to(length)
-    yield rewrite(TupleValue(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length)))
-
-    # cons access
-    yield rewrite(TupleValue.EMPTY.length()).to(Int(0))
-    yield rewrite(TupleValue.EMPTY[idx]).to(Value.NEVER)
-    yield rewrite(tv.append(v).length()).to(tv.length() + 1)
-    yield rewrite(tv.append(v)[idx]).to(Value.if_(idx == tv.length(), v, tv[idx]))
-
-    # functional to cons
-    yield rewrite(TupleValue(0, idx_fn), subsume=True).to(TupleValue.EMPTY)
-    yield rewrite(TupleValue(Int(k), idx_fn), subsume=True).to(
-        TupleValue(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
-    )
+    yield rule(eq(ti).to(TupleValue(vs)), eq(ti).to(TupleValue(vs2)), vs != vs2).then(vs | vs2)
-    # cons to vec
-    yield rewrite(TupleValue.EMPTY).to(TupleValue.from_vec(Vec[Value]()))
-    yield rewrite(TupleValue.from_vec(vs).append(v)).to(TupleValue.from_vec(vs.append(Vec(v))))
+    yield rewrite(TupleValue.fn(i2, idx_fn).length(), subsume=False).to(i2)
+    yield rewrite(TupleValue.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
-    # fold boolean
-    yield rewrite(TupleValue.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b)
-    yield rewrite(tv.append(v).foldl_boolean(bool_f, b), subsume=True).to(bool_f(tv.foldl_boolean(bool_f, b), v))
+    yield rewrite(TupleValue(vs).length(), subsume=False).to(Int(vs.length()))
+    yield rewrite(TupleValue(vs)[Int(k)], subsume=False).to(vs[k], k >= 0, k < vs.length())
-    # unify append
-    yield rule(eq(tv.append(v)).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
+    yield rewrite(TupleValue.fn(Int(k), idx_fn), subsume=True).to(TupleValue(k.range().map(lambda i: idx_fn(Int(i)))))
+
+    yield rewrite(TupleValue.if_(TRUE, lt, lf), subsume=True).to(lt())
+    yield rewrite(TupleValue.if_(FALSE, lt, lf), subsume=True).to(lf())
 @function
@@ -1007,8 +1456,8 @@ def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ...
 converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS)
 converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NONE)
-converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
-converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)
+converter(Int, MultiAxisIndexKeyItem, lambda i: MultiAxisIndexKeyItem.int(i))
+converter(Slice, MultiAxisIndexKeyItem, lambda s: MultiAxisIndexKeyItem.slice(s))
 MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
@@ -1041,53 +1490,266 @@ class IndexKey(Expr, ruleset=array_api_ruleset):
     https://data-apis.org/array-api/2022.12/API_specification/indexing.html
-    It is equivalent to the following type signature:
+    It is equivalent to the following type signature:
+
+    Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
+    """
+
+    ELLIPSIS: ClassVar[IndexKey]
+
+    @classmethod
+    def int(cls, i: Int) -> IndexKey: ...
+
+    @classmethod
+    def slice(cls, slice: Slice) -> IndexKey: ...
+
+    # Disabled until we support late binding
+    # @classmethod
+    # def boolean_array(cls, b: NDArray) -> IndexKey:
+    #     ...
+
+    @classmethod
+    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...
+
+    @classmethod
+    def ndarray(cls, key: NDArray) -> IndexKey:
+        """
+        Indexes by a masked array
+        """
+
+
+IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"
+
+
+converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
+converter(Int, IndexKey, lambda i: IndexKey.int(i))
+converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
+converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
+
+
+class Device(Expr, ruleset=array_api_ruleset): ...
+
+
+ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
+
+
+class RecursiveValue(Expr):
+    """
+    Either a value or vec of RecursiveValues
+
+    >>> convert(Value.from_int(42), RecursiveValue)
+    RecursiveValue(Value.from_int(Int(42)))
+    >>> convert((1, 2, 3), RecursiveValue)
+    RecursiveValue.vec(Vec(RecursiveValue(Value.from_int(Int(1))), RecursiveValue(Value.from_int(Int(2))), RecursiveValue(Value.from_int(Int(3)))))
+    >>> convert(((1,), (2,)), RecursiveValue)
+    RecursiveValue.vec(Vec(RecursiveValue.vec(Vec(RecursiveValue(Value.from_int(Int(1))))), RecursiveValue.vec(Vec(RecursiveValue(Value.from_int(Int(2)))))))
+    """
+
+    def __init__(self, value: ValueLike) -> None: ...
+
+    @classmethod
+    def vec(cls, vec: VecLike[RecursiveValue, RecursiveValueLike]) -> RecursiveValue: ...
+
+    def __getitem__(self, index: VecLike[Int, IntLike]) -> Value:
+        """
+        Index into the RecursiveValue with the given indices. It should match the shape.
+
+        >>> rv = convert(((1, 2), (3, 4)), RecursiveValue)
+        >>> int(rv[[0, 1]].to_int)
+        2
+        """
+
+    @property
+    def shape(self) -> TupleInt:
+        """
+        Shape of the RecursiveValue.
+
+        >>> rv = convert(((1,), (3,)), RecursiveValue)
+        >>> list(rv.shape)
+        [Int(2), Int(1)]
+        """
+
+    @method(preserve=True)  # type: ignore[prop-decorator]
+    @property
+    def value(self) -> PyTupleValuesRecursive:
+        """
+        Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
+
+        >>> convert(((1, 2), (3, 4)), RecursiveValue).value
+        ((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
+        """
+        match get_callable_args(self, RecursiveValue):
+            case (value,):
+                return cast("Value", value)
+        match get_callable_args(self, RecursiveValue.vec):
+            case (vec,):
+                return tuple(v.value for v in cast("Vec[RecursiveValue]", vec))
+        raise ExprValueError(self, "RecursiveValue or RecursiveValue.vec")
+
+    __match_args__ = ("value",)
+
+    @method(preserve=True)
+    def eval(self) -> PyTupleValuesRecursive:
+        """
+        Evals to a nested tuple of values representing the RecursiveValue.
+        """
+        return try_evaling(self)
+
+    @classmethod
+    def from_index_and_shape(cls, shape: Vec[Int], idx_fn: Callable[[TupleInt], Value]) -> RecursiveValue: ...
+
+
+PyTupleValuesRecursive: TypeAlias = Value | tuple["PyTupleValuesRecursive", ...]
+
+RecursiveValueLike: TypeAlias = RecursiveValue | VecLike[RecursiveValue, "RecursiveValueLike"] | ValueLike
+
+converter(Vec[RecursiveValue], RecursiveValue, lambda x: RecursiveValue.vec(x))
+converter(Value, RecursiveValue, RecursiveValue)
+
+
+@array_api_ruleset.register
+def _recursive_value(
+    v: Value,
+    vs: Vec[RecursiveValue],
+    k: i64,
+    vi: Vec[Int],
+    vi1: Vec[Int],
+    rv: RecursiveValue,
+    rv1: RecursiveValue,
+    idx_fn: Callable[[TupleInt], Value],
+):
+    yield rewrite(RecursiveValue(v).shape).to(TupleInt(()))
+    yield rewrite(RecursiveValue.vec(vs).shape).to(TupleInt((vs.length(),)) + vs[0].shape, vs.length() > 0)
+    yield rewrite(RecursiveValue.vec(vs).shape).to(TupleInt((0,)), vs.length() == i64(0))
+
+    yield rewrite(RecursiveValue(v)[vi], subsume=False).to(v)  # Assume ti is empty
+
+    # indexing
+    yield rule(
+        eq(rv).to(RecursiveValue.vec(vs)),
+        eq(v).to(rv[vi]),
+        vi.length() > 0,
+        eq(vi[0]).to(Int(k)),
+        eq(rv1).to(vs[k]),
+        eq(vi1).to(vi.remove(0)),
+    ).then(
+        union(v).with_(rv1[vi1]),
+        subsume(rv[vi]),
+    )
+    # from idx fn
+    yield rule(
+        eq(rv).to(RecursiveValue.from_index_and_shape(vi, idx_fn)),
+        vi.length() > 0,
+        eq(vi[0]).to(Int(k)),
+        eq(vi1).to(vi.remove(0)),
+    ).then(
+        union(rv).with_(
+            RecursiveValue.vec(
+                k.range().map(
+                    lambda i: RecursiveValue.from_index_and_shape(
+                        vi1,
+                        lambda rest_indices: idx_fn(rest_indices.append_start(Int(i))),
+                    )
+                )
+            )
+        ),
+        subsume(RecursiveValue.from_index_and_shape(vi, idx_fn)),
+    )
+    yield rule(
+        eq(rv).to(RecursiveValue.from_index_and_shape(vi, idx_fn)),
+        vi.length() == i64(0),
+    ).then(
+        union(rv).with_(RecursiveValue(idx_fn(TupleInt(())))),
+        subsume(RecursiveValue.from_index_and_shape(vi, idx_fn)),
+    )
+
-    Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
+class NDArray(Expr, ruleset=array_api_ruleset):
+    """
+    NDArray implementation following the Array API Standard.
+
+    >>> NDArray((1, 2, 3)).eval()
+    (Value.from_int(Int(1)), Value.from_int(Int(2)), Value.from_int(Int(3)))
+    >>> NDArray((1, 2, 3)).eval_numpy("int64")
+    array([1, 2, 3])
+    >>> NDArray(((1, 2), (3, 4))).eval_numpy("int64")
+    array([[1, 2],
+           [3, 4]])
     """
-    ELLIPSIS: ClassVar[IndexKey]
-
-    @classmethod
-    def int(cls, i: Int) -> IndexKey: ...
+    def __init__(self, values: RecursiveValueLike) -> None: ...
     @classmethod
-    def slice(cls, slice: Slice) -> IndexKey: ...
-
-    # Disabled until we support late binding
-    # @classmethod
-    # def boolean_array(cls, b: NDArray) -> IndexKey:
-    #     ...
+    def fn(cls, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> NDArray: ...
-    @classmethod
-    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...
+    NEVER: ClassVar[NDArray]
+    @method(unextractable=True)
     @classmethod
-    def ndarray(cls, key: NDArray) -> IndexKey:
-        """
-        Indexes by a masked array
+    def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
         """
+        Creates an vector NDArray from a tuple of values.
+        >>> NDArray.from_tuple_value((1, 2)).eval_numpy("int64")
+        array([1, 2])
+        """
+        tv = cast("TupleValue", tv)
+        return NDArray.fn(
+            TupleInt((tv.length(),)),
+            tv[0].dtype,
+            lambda idx: tv[idx[0]],
+        )
-IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"
-
-
-converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
-converter(Int, IndexKey, lambda i: IndexKey.int(i))
-converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
-converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
+    @method(unextractable=True)
+    def to_tuple_values(self) -> TupleValue:
+        """
+        Turns a vector array into a tuple value.
+        """
+        return TupleValue.fn(self.shape[0], lambda i: self.index((i,)))
+    @method(preserve=True)  # type: ignore[prop-decorator]
+    @property
+    def value(self) -> PyTupleValuesRecursive:
+        """
+        Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
-class Device(Expr, ruleset=array_api_ruleset): ...
+
+        >>> convert(((1, 2), (3, 4)), RecursiveValue).value
+        ((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
+        """
+        match get_callable_args(self, NDArray):
+            case (RecursiveValue(value),):
+                return value
+        raise ExprValueError(self, "NDArray(recursive_value)")
+    __match_args__ = ("value",)
-ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
+    @method(preserve=True)
+    def eval(self) -> PyTupleValuesRecursive:
+        """
+        Evals to a nested tuple of values representing the RecursiveValue.
+        """
+        return try_evaling(self)
+    @method(preserve=True)
+    def eval_numpy(self, dtype: np.dtype | None = None) -> np.ndarray:
+        """
+        Evals to a numpy ndarray.
+        """
+        return np.array(self.eval(), dtype=dtype)
-class NDArray(Expr, ruleset=array_api_ruleset):
-    def __init__(self, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...
+    @method(preserve=True)
+    def __array__(self, dtype=None, copy=None) -> np.ndarray:
+        if copy is False:
+            msg = "NDArray.__array__ with copy=False is not supported"
+            raise NotImplementedError(msg)
+        return self.eval_numpy(dtype=dtype)
-    NEVER: ClassVar[NDArray]
+    def __int__(self) -> int:
+        res = self.eval()
+        if isinstance(res, tuple):
+            msg = "Cannot convert a non-scalar array to int"
+            raise TypeError(msg)
+        return int(res)
     @method(cost=200)
     @classmethod
@@ -1098,7 +1760,8 @@ def __array_namespace__(self, api_version: object = None) -> ModuleType:
         return sys.modules[__name__]
     @property
-    def ndim(self) -> Int: ...
+    def ndim(self) -> Int:
+        return self.shape.length()
     @property
     def dtype(self) -> DType: ...
@@ -1111,7 +1774,10 @@ def shape(self) -> TupleInt: ...
     @method(preserve=True)
     def __bool__(self) -> bool:
-        return self.to_value().to_bool.eval()
+        # Special case bool so it works when comparing to arrays outside of tracing, like when indexing
+        if not _CURRENT_EGRAPH and (args := get_callable_args(self, NDArray.__eq__)) is not None:
+            return bool(eq(args[0]).to(cast("NDArray", args[1])))
+        return self.index(()).to_bool.eval()
     @property
     def size(self) -> Int: ...
@@ -1120,10 +1786,22 @@ def size(self) -> Int: ...
     def __len__(self) -> int:
         return self.size.eval()
+    @method(egg_fn="sum")
+    def sum(self, axis: OptionalIntOrTupleLike = None) -> NDArray: ...
+    @method(preserve=True)
     def __iter__(self) -> Iterator[NDArray]:
-        for i in range(len(self)):
-            yield self[IndexKey.int(Int(i))]
+        """
+        Only for 1D arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__getitem__.html
+
+        >>> list(NDArray((1, 2, 3)))
+        [NDArray(RecursiveValue(Value.from_int(Int(1)))), NDArray(RecursiveValue(Value.from_int(Int(2)))), NDArray(RecursiveValue(Value.from_int(Int(3))))]
+        """
+        inner = self.eval()
+        if isinstance(inner, Value):
+            msg = "Cannot iterate over a 0D array"
+            raise TypeError(msg)
+        return map(NDArray, inner)
     def __getitem__(self, key: IndexKeyLike) -> NDArray: ...
@@ -1198,28 +1876,19 @@ def __rxor__(self, other: NDArray) -> NDArray: ...
     def __ror__(self, other: NDArray) -> NDArray: ...
-    @classmethod
-    def scalar(cls, value: Value) -> NDArray:
-        return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)
-
-    def to_value(self) -> Value:
-        """
-        Returns the value if this is a scalar.
-        """
-
-    def to_values(self) -> TupleValue:
-        """
-        Returns the value if this is a vector.
-        """
+    def __abs__(self) -> NDArray: ...
     @property
     def T(self) -> NDArray:
         """
        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
         """
-
-    @classmethod
-    def vector(cls, values: TupleValueLike) -> NDArray: ...
+        # Only works on 2D arrays
+        return NDArray.fn(
+            (self.shape[1], self.shape[0]),
+            self.dtype,
+            lambda idx: self.index((idx[1], idx[0])),
+        )
     def index(self, indices: TupleIntLike) -> Value:
         """
@@ -1227,85 +1896,107 @@ def index(self, indices: TupleIntLike) -> Value:
         """
     @classmethod
-    def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
+    def if_(cls, b: BooleanLike, i: Callable[[], NDArray], j: Callable[[], NDArray]) -> NDArray: ...
+
+    @method(unextractable=True)
+    def diff(self, v: NDArrayLike) -> NDArray:
+        """
+        Differentiate self with respect to v.
+
+        It will have the shape of the concat of both input shapes. On the outside are the indices of the variable array
+        and on the inside the indices of the value array.
+
+        >>> v = Value.var("v")
+        >>> int(NDArray(v).diff(v))
+        1
+        >>> int(NDArray(v + v).diff(v))
+        2
+        >>> int(NDArray(v * 3).diff(v))
+        3
+        >>> tuple(map(int, NDArray((v, v * 2, v * 3)).diff(v)))
+        (1, 2, 3)
+        >>> tuple(map(int, NDArray(v * 2).diff(NDArray([v, Value.var("w")]))))
+        (2, 0)
+        """
+        v = cast("NDArray", v)
+        return NDArray.fn(
+            v.shape + self.shape,
+            self.dtype,
+            lambda idx: self.index(idx.drop(v.shape.length())).diff(v.index(idx.take(v.shape.length()))),
+        )
+
+VecValuesRecursive: TypeAlias = "Value | Vec[VecValuesRecursive]"
-NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
+NDArrayLike: TypeAlias = NDArray | RecursiveValueLike
 converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
-converter(Value, NDArray, lambda v: NDArray.scalar(v))
+converter(RecursiveValue, NDArray, NDArray)
 # Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
 # to prefer upcasting in the other direction when we can, which is safer at runtime
-converter(NDArray, Value, lambda n: n.to_value(), 100)
-converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
-converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
+converter(NDArray, Value, lambda n: n.index(()), 100)
 @array_api_ruleset.register
 def _ndarray(
     x: NDArray,
-    x1: NDArray,
-    b: Boolean,
-    f: Float,
-    fi1: f64,
-    fi2: f64,
     shape: TupleInt,
     dtype: DType,
     idx_fn: Callable[[TupleInt], Value],
     idx: TupleInt,
-    tv: TupleValue,
+    v: Value,
+    v1: Value,
+    xt: Callable[[], NDArray],
+    x1t: Callable[[], NDArray],
+    rv: RecursiveValue,
+    vi: Vec[Int],
+    i: i64,
 ):
     return [
-        rewrite(NDArray(shape, dtype, idx_fn).shape).to(shape),
-        rewrite(NDArray(shape, dtype, idx_fn).dtype).to(dtype),
-        rewrite(NDArray(shape, dtype, idx_fn).index(idx), subsume=True).to(idx_fn(idx)),
-        rewrite(x.ndim).to(x.shape.length()),
-        # rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
-        # Converting to a value requires a scalar bool value
-        rewrite(x.to_value()).to(x.index(TupleInt.EMPTY)),
-        rewrite(NDArray.vector(tv).to_values()).to(tv),
-        # TODO: Push these down to float
-        rewrite(NDArray.scalar(Value.float(f)) / NDArray.scalar(Value.float(f))).to(
-            NDArray.scalar(Value.float(Float(1.0)))
-        ),
-        rewrite(NDArray.scalar(Value.float(f)) - NDArray.scalar(Value.float(f))).to(
-            NDArray.scalar(Value.float(Float(0.0)))
-        ),
-        rewrite(NDArray.scalar(Value.float(Float(fi1))) > NDArray.scalar(Value.float(Float(fi2)))).to(
-            NDArray.scalar(Value.bool(TRUE)), fi1 > fi2
-        ),
-        rewrite(NDArray.scalar(Value.float(Float(fi1))) > NDArray.scalar(Value.float(Float(fi2)))).to(
-            NDArray.scalar(Value.bool(FALSE)), fi1 <= fi2
-        ),
-        # Transpose of tranpose is the original array
-        rewrite(x.T.T).to(x),
+        rewrite(NDArray.fn(shape, dtype, idx_fn).shape, subsume=False).to(shape),
+        rewrite(NDArray.fn(shape, dtype, idx_fn).dtype, subsume=False).to(dtype),
+        rewrite(NDArray.fn(shape, dtype, idx_fn).index(idx), subsume=True).to(idx_fn(idx)),
+        rewrite(NDArray(rv).shape, subsume=False).to(rv.shape),
+        rewrite(NDArray(rv).index(TupleInt(vi)), subsume=False).to(rv[vi]),
+        # TODO: Special case scalar ops for now
+        rewrite(NDArray(v) / NDArray(v1), subsume=False).to(NDArray(v / v1)),
+        rewrite(NDArray(v) + NDArray(v1), subsume=False).to(NDArray(v + v1)),
+        rewrite(NDArray(v) * NDArray(v1), subsume=False).to(NDArray(v * v1)),
+        rewrite(NDArray(v) ** NDArray(v1), subsume=False).to(NDArray(v**v1)),
+        rewrite(NDArray(v) - NDArray(v1), subsume=False).to(NDArray(v - v1)),
+        # Comparisons
+        rewrite(NDArray(v) < NDArray(v1), subsume=False).to(NDArray(v < v1)),
+        rewrite(NDArray(v) <= NDArray(v1), subsume=False).to(NDArray(v <= v1)),
+        rewrite(NDArray(v) == NDArray(v1), subsume=False).to(NDArray(v == v1)),
+        rewrite(NDArray(v) > NDArray(v1), subsume=False).to(NDArray(v > v1)),
+        rewrite(NDArray(v) >= NDArray(v1), subsume=False).to(NDArray(v >= v1)),
+        # Transpose of transpose is the original array
+        rewrite(x.T.T, subsume=False).to(x),
         # if_
-        rewrite(NDArray.if_(TRUE, x, x1)).to(x),
-        rewrite(NDArray.if_(FALSE, x, x1)).to(x1),
+        rewrite(NDArray.if_(TRUE, xt, x1t), subsume=True).to(xt()),
+        rewrite(NDArray.if_(FALSE, xt, x1t), subsume=True).to(x1t()),
+        # to RecursiveValue,
+        # only trigger if size smaller than 20 to avoid blowing up
+        rule(
+            eq(x).to(NDArray.fn(TupleInt(vi), dtype, idx_fn)),
+        ).then(TupleInt(vi).product()),
+        rule(
+            eq(x).to(NDArray.fn(TupleInt(vi), dtype, idx_fn)),
+            eq(TupleInt(vi).product()).to(Int(i)),
+            i <= 20,
+        ).then(
+            union(x).with_(NDArray(RecursiveValue.from_index_and_shape(vi, idx_fn))),
+            subsume(NDArray.fn(TupleInt(vi), dtype, idx_fn)),
+        ),
     ]
 class TupleNDArray(Expr, ruleset=array_api_ruleset):
-    EMPTY: ClassVar[TupleNDArray]
-
-    def __init__(self, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> None: ...
-
-    def append(self, i: NDArrayLike) -> TupleNDArray: ...
-
+    def __init__(self, vec: VecLike[NDArray, NDArrayLike] = ()) -> None: ...
     @classmethod
-    def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...
-
-    def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
-        other = cast("TupleNDArray", other)
-        return TupleNDArray(
-            self.length() + other.length(),
-            lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
-        )
-
+    def fn(cls, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> TupleNDArray: ...
     def length(self) -> Int: ...
-    def __getitem__(self, i: IntLike) -> NDArray: ...
-
     @method(preserve=True)
     def __len__(self) -> int:
         return self.length().eval()
@@ -1314,53 +2005,68 @@ def __len__(self) -> int:
     def __iter__(self) -> Iterator[NDArray]:
         return iter(self.eval())
-    @property
-    def to_vec(self) -> Vec[NDArray]: ...
- @method(preserve=True) def eval(self) -> tuple[NDArray, ...]: - return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec) + return try_evaling(self) + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> tuple[NDArray, ...]: + match get_callable_args(self, TupleNDArray): + case (vec,): + return tuple(cast("Vec[NDArray]", vec)) + raise ExprValueError(self, "TupleNDArray(vec)") + + @method(unextractable=True) + def append(self, i: NDArrayLike) -> TupleNDArray: + return TupleNDArray.fn( + self.length() + 1, lambda j: NDArray.if_(j == self.length(), lambda: cast("NDArray", i), lambda: self[j]) + ) + + @method(unextractable=True) + def __add__(self, other: TupleValueLike) -> TupleNDArray: + other = cast("TupleNDArray", other) + return TupleNDArray.fn( + self.length() + other.length(), + lambda i: NDArray.if_(i < self.length(), lambda: self[i], lambda: other[i - self.length()]), + ) + + @method(unextractable=True) + def drop_last(self) -> TupleNDArray: + return TupleNDArray.fn(self.length() - 1, self.__getitem__) + + @method(unextractable=True) + def last(self) -> NDArray: + return self[self.length() - 1] -converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x)) +converter(Vec[NDArray], TupleNDArray, TupleNDArray) TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike] @array_api_ruleset.register def _tuple_ndarray( - length: Int, + i: Int, + i2: Int, idx_fn: Callable[[Int], NDArray], - k: i64, - idx: Int, vs: Vec[NDArray], - v: NDArray, - v1: NDArray, - tv: TupleNDArray, - tv1: TupleNDArray, - b: Boolean, + vs2: Vec[NDArray], + ti: TupleNDArray, + k: i64, + lt: Callable[[], TupleNDArray], + lf: Callable[[], TupleNDArray], ): - yield rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs)) - yield rewrite(TupleNDArray(length, idx_fn).length()).to(length) - yield rewrite(TupleNDArray(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length))) - - # cons access - 
yield rewrite(TupleNDArray.EMPTY.length()).to(Int(0)) - yield rewrite(TupleNDArray.EMPTY[idx]).to(NDArray.NEVER) - yield rewrite(tv.append(v).length()).to(tv.length() + 1) - yield rewrite(tv.append(v)[idx]).to(NDArray.if_(idx == tv.length(), v, tv[idx])) - # functional to cons - yield rewrite(TupleNDArray(0, idx_fn), subsume=True).to(TupleNDArray.EMPTY) - yield rewrite(TupleNDArray(Int(k), idx_fn), subsume=True).to( - TupleNDArray(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0 - ) + yield rule(eq(ti).to(TupleNDArray(vs)), eq(ti).to(TupleNDArray(vs2)), vs != vs2).then(vs | vs2) + yield rewrite(TupleNDArray.fn(i2, idx_fn).length(), subsume=False).to(i2) + yield rewrite(TupleNDArray.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i))) - # cons to vec - yield rewrite(TupleNDArray.EMPTY).to(TupleNDArray.from_vec(Vec[NDArray]())) - yield rewrite(TupleNDArray.from_vec(vs).append(v)).to(TupleNDArray.from_vec(vs.append(Vec(v)))) + yield rewrite(TupleNDArray(vs).length(), subsume=False).to(Int(vs.length())) + yield rewrite(TupleNDArray(vs)[Int(k)], subsume=False).to(vs[k], k >= 0, k < vs.length()) - # unify append - yield rule(eq(tv.append(v)).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1)) + yield rewrite(TupleNDArray.fn(Int(k), idx_fn), subsume=True).to( + TupleNDArray(k.range().map(lambda i: idx_fn(Int(i)))), k >= 0 + ) class OptionalBool(Expr, ruleset=array_api_ruleset): @@ -1407,29 +2113,21 @@ def some(cls, value: TupleIntLike) -> OptionalTupleInt: ... converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x)) -class IntOrTuple(Expr, ruleset=array_api_ruleset): - none: ClassVar[IntOrTuple] +class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset): + none: ClassVar[OptionalIntOrTuple] @classmethod - def int(cls, value: Int) -> IntOrTuple: ... + def int(cls, value: Int) -> OptionalIntOrTuple: ... @classmethod - def tuple(cls, value: TupleIntLike) -> IntOrTuple: ... 
- - -converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v)) -converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v)) - + def tuple(cls, value: TupleIntLike) -> OptionalIntOrTuple: ... -class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset): - none: ClassVar[OptionalIntOrTuple] - - @classmethod - def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ... +OptionalIntOrTupleLike: TypeAlias = OptionalIntOrTuple | None | IntLike | TupleIntLike converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none) -converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v)) +converter(Int, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.int(v)) +converter(TupleInt, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.tuple(v)) @function @@ -1442,7 +2140,7 @@ def asarray( @array_api_ruleset.register -def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool): +def _asarray(a: NDArray, d: OptionalDType, ob: OptionalBool): yield rewrite(asarray(a, d, ob).ndim).to(a.ndim) # asarray doesn't change ndim yield rewrite(asarray(a)).to(a) @@ -1452,7 +2150,7 @@ def isfinite(x: NDArray) -> NDArray: ... @function -def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: +def sum(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none) -> NDArray: """ https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum """ @@ -1461,29 +2159,36 @@ def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArr @array_api_ruleset.register def _sum(x: NDArray, y: NDArray, v: Value, dtype: DType): return [ - rewrite(sum(x / NDArray.scalar(v))).to(sum(x) / NDArray.scalar(v)), + rewrite(sum(x / NDArray(v))).to(sum(x) / NDArray(v)), # Sum of 0D array is ] -@function -def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray: ... 
- - -# @function -# def reshape_transform_index(original_shape: TupleInt, shape: TupleInt, index: TupleInt) -> TupleInt: -# """ -# Transforms an indexing operation on a reshaped array to an indexing operation on the original array. -# """ -# ... +@function(ruleset=array_api_ruleset) +def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray: + shape = cast("TupleInt", shape) + resolved_shape = normalize_reshape_shape(x.shape, shape) + return NDArray.if_( + # If we are reshaping to the same shape, just return the original array to avoid unnecessary indexing + resolved_shape == x.shape, + lambda: x, + lambda: NDArray.fn( + resolved_shape, + x.dtype, + lambda idx: x.index(unravel_index(ravel_index(idx, resolved_shape), x.shape)), + ), + ) -# @function -# def reshape_transform_shape(original_shape: TupleInt, shape: TupleInt) -> TupleInt: -# """ -# Transforms the shape of an array to one that is reshaped, by replacing -1 with the correct value. -# """ -# ... +@function(ruleset=array_api_ruleset, unextractable=True) +def normalize_reshape_shape(original_shape: TupleIntLike, shape: TupleIntLike) -> TupleInt: + """ + Replace a single inferred `-1` dimension with the corresponding concrete dimension. + """ + original_shape = cast("TupleInt", original_shape) + shape = cast("TupleInt", shape) + inferred_dim = original_shape.product() // shape.filter(lambda d: ~(d == Int(-1))).product() + return shape.map(lambda d: Int.if_(d == Int(-1), lambda: inferred_dim, lambda: d)) # @array_api_ruleset.register @@ -1529,10 +2234,10 @@ def concat(arrays: TupleNDArrayLike, axis: OptionalInt = OptionalInt.none) -> ND @array_api_ruleset.register -def _concat(x: NDArray): +def _concat(vs: Vec[NDArray]): return [ # only support no-op concat for now - rewrite(concat(TupleNDArray.EMPTY.append(x))).to(x), + rewrite(concat(TupleNDArray(vs))).to(vs[0], vs.length() == i64(1)), ] @@ -1544,13 +2249,11 @@ def astype(x: NDArray, dtype: DType) -> NDArray: ... 
def _astype(x: NDArray, dtype: DType, i: i64): return [ rewrite(astype(x, dtype).dtype).to(dtype), - rewrite(astype(NDArray.scalar(Value.int(Int(i))), float64)).to( - NDArray.scalar(Value.float(Float(f64.from_i64(i)))) - ), + rewrite(astype(NDArray(Value.from_int(Int(i))), float64)).to(NDArray(Value.from_float(Float(f64.from_i64(i))))), ] -@function +@function(unextractable=True, ruleset=array_api_ruleset) def unique_counts(x: NDArray) -> TupleNDArray: """ Returns the unique elements of an input array x and the corresponding counts for each unique element in x. @@ -1558,18 +2261,25 @@ def unique_counts(x: NDArray) -> TupleNDArray: https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html """ + return TupleNDArray((unique_counts_elements(x), unique_counts_counts(x))) + + +@function +def unique_counts_elements(x: NDArray) -> NDArray: ... + + +@function +def unique_counts_counts(x: NDArray) -> NDArray: ... @array_api_ruleset.register def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType): return [ - # rewrite(unique_counts(x).length()).to(Int(2)), - rewrite(unique_counts(x)).to(TupleNDArray(2, unique_counts(x).__getitem__)), # Sum of all unique counts is the size of the array - rewrite(sum(unique_counts(x)[Int(1)])).to(NDArray.scalar(Value.int(x.size))), + rewrite(sum(unique_counts_counts(x))).to(NDArray(Value.from_int(x.size))), # Same but with astype in the middle # TODO: Replace - rewrite(sum(astype(unique_counts(x)[Int(1)], dtype))).to(astype(NDArray.scalar(Value.int(x.size)), dtype)), + rewrite(sum(astype(unique_counts_counts(x), dtype))).to(astype(NDArray(Value.from_int(x.size)), dtype)), ] @@ -1592,26 +2302,29 @@ def log(x: NDArray) -> NDArray: ... 
@array_api_ruleset.register def _abs(f: Float): return [ - rewrite(abs(NDArray.scalar(Value.float(f)))).to(NDArray.scalar(Value.float(f.abs()))), + rewrite(abs(NDArray(Value.from_float(f)))).to(NDArray(Value.from_float(f.abs()))), ] -@function +@function(ruleset=array_api_ruleset, unextractable=True) def unique_inverse(x: NDArray) -> TupleNDArray: """ Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x. https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html """ + return TupleNDArray((unique_values(x), unique_inverse_inverse_indices(x))) + + +@function +def unique_inverse_inverse_indices(x: NDArray) -> NDArray: ... @array_api_ruleset.register def _unique_inverse(x: NDArray, i: Int): return [ - # rewrite(unique_inverse(x).length()).to(Int(2)), - rewrite(unique_inverse(x)).to(TupleNDArray(2, unique_inverse(x).__getitem__)), # Shape of unique_inverse first element is same as shape of unique_values - rewrite(unique_inverse(x)[Int(0)]).to(unique_values(x)), + rewrite(unique_values(x)[Int(0)]).to(unique_values(x)), ] @@ -1626,7 +2339,7 @@ def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ... @function -def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ... +def mean(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ... # TODO: Possibly change names to include modules. @@ -1635,7 +2348,7 @@ def sqrt(x: NDArray) -> NDArray: ... @function -def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ... +def std(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none) -> NDArray: ... @function @@ -1646,25 +2359,97 @@ def real(x: NDArray) -> NDArray: ... def conj(x: NDArray) -> NDArray: ... 
+@function(ruleset=array_api_ruleset, unextractable=True) +def vecdot(x1: NDArrayLike, x2: NDArrayLike) -> NDArray: + """ + https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.vecdot.html + https://numpy.org/doc/stable/reference/generated/numpy.vecdot.html + + TODO: Support axis, complex numbers, broadcasting, and more than matrix-vector + + >>> v = NDArray([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]]) + >>> n = NDArray([0., 0.6, 0.8]) + >>> vecdot(v, n).eval_numpy("float64") + array([ 3., 8., 10.]) + """ + x1 = cast("NDArray", x1) + x2 = cast("NDArray", x2) + + return NDArray.fn( + x1.shape.drop_last(), + x1.dtype, + lambda idx: ( + TupleInt.range(x1.shape.last()) + .map_value(lambda i: x1.index(idx.append(i)) * x2.index((i,))) + .foldl_value(Value.__add__, Value.from_float(0)) + ), + ) + + +@function(ruleset=array_api_ruleset, unextractable=True) +def vector_norm(x: NDArrayLike) -> NDArray: + """ + https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.vector_norm.html + TODO: support axis + # >>> x = NDArray([1, 2, 3, 4, 5, 6, 7, 8, 9]) + # >>> vector_norm(x).eval_numpy("float64") + # array(16.88194302) + """ + # https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html#numpy.linalg.norm + # sum(abs(x)**ord)**(1./ord) where ord=2 + x = cast("NDArray", x) + # Only works on vectors + return NDArray( + TupleInt.range(x.shape[0]).foldl_value( + lambda acc, i: acc + (x.index((i,)).__abs__() ** Value.from_float(Float(2.0))), + Value.from_float(Float(0.0)), + ) + ** Value.from_float(Float(0.5)) + ) + + +@function(ruleset=array_api_ruleset, unextractable=True) +def cross(a: NDArrayLike, b: NDArrayLike) -> NDArray: + """ + https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.cross.html + TODO: support axis, and more than two vecs + + >>> x = NDArray([1, 2, 3]) + >>> y = NDArray([4, 5, 6]) + >>> cross(x, y).eval_numpy("int64") + array([-3, 6, -3]) + """ + a = cast("NDArray", a) + b = 
cast("NDArray", b) + return NDArray.fn( + (3,), + a.dtype, + lambda idx: ( + (a.index(((idx[0] + 1) % 3,)) * b.index(((idx[0] + 2) % 3,))) + - (a.index(((idx[0] + 2) % 3,)) * b.index(((idx[0] + 1) % 3,))) + ), + ) + + linalg = sys.modules[__name__] -@function -def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray: +def svd(x: NDArray, full_matrices: Boolean = TRUE) -> tuple[NDArray, NDArray, NDArray]: """ https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html """ + res = svd_(x, full_matrices) + return (res[0], res[1], res[2]) -@array_api_ruleset.register -def _linalg(x: NDArray, full_matrices: Boolean): - return [ - # rewrite(svd(x, full_matrices).length()).to(Int(3)), - rewrite(svd(x, full_matrices)).to(TupleNDArray(3, svd(x, full_matrices).__getitem__)), - ] +@function +def svd_(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray: + """ + https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html + """ -@function(ruleset=array_api_ruleset) +@function(ruleset=array_api_ruleset, unextractable=True) def ndindex(shape: TupleIntLike) -> TupleTupleInt: """ https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html @@ -1676,7 +2461,7 @@ def ndindex(shape: TupleIntLike) -> TupleTupleInt: ## # Interval analysis # -# to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0))) < NDArray.scalar(Value.int(Int(0)))).bool()`` +# to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray(Value.float(Float(150.0))) < NDArray(Value.from_int(Int(0)))).bool()`` ## greater_zero = relation("greater_zero", Value) @@ -1697,9 +2482,9 @@ def ndindex(shape: TupleIntLike) -> TupleTupleInt: # ... 
-# any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).to_bool() +# any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray(Value.float(Float(150.0)))) < NDArray(Value.from_int(Int(0)))).to_bool() -# sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.int(Int(150)))) +# sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray(Value.from_int(Int(150)))) # And also # def @@ -1739,23 +2524,25 @@ def _interval_analaysis( x_value = x.index(broadcast_index(x.shape, res_shape, idx)) y_value = y.index(broadcast_index(y.shape, res_shape, idx)) return [ - # Calling any on an array gives back a sclar, which is true if any of the values are truthy - rewrite(any(x)).to( - NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE)))) + # Calling any on an array gives back a scalar, which is true if any of the values are truthy + rewrite(any(x), subsume=False).to( + NDArray( + Value.from_bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.from_bool(TRUE))) + ), ), # Indexing x < y is the same as broadcasting the index and then indexing both and then comparing rewrite((x < y).index(idx)).to(x_value < y_value), # Same for x / y rewrite((x / y).index(idx)).to(x_value / y_value), # Indexing a scalar is the same as the scalar - rewrite(NDArray.scalar(v).index(idx)).to(v), + rewrite(NDArray(v).index(idx)).to(v), # Indexing of astype is same as astype of indexing rewrite(astype(x, dtype).index(idx)).to(x.index(idx).astype(dtype)), - # rule(eq(y).to(x < NDArray.scalar(Value.int(Int(0)))), ndarray_all_greater_0(x)).then(ndarray_all_false(y)), - # rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray.scalar(Value.bool(FALSE)))), + # rule(eq(y).to(x < NDArray(Value.from_int(Int(0)))), ndarray_all_greater_0(x)).then(ndarray_all_false(y)), + # 
rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray(Value.bool(FALSE)))), # Indexing into unique counts counts are all positive rule( - eq(v).to(unique_counts(x)[Int(1)].index(idx)), + eq(v).to(unique_counts_counts(x).index(idx)), ).then(greater_zero(v)), # Min value preserved over astype rule( @@ -1765,9 +2552,9 @@ def _interval_analaysis( greater_zero(v1), ), # Min value of scalar is scalar itself - rule(eq(v).to(Value.float(Float(f))), f > 0.0).then(greater_zero(v)), - rule(eq(v).to(Value.int(Int(i))), i > 0).then(greater_zero(v)), - # If we have divison of v and v1, and both greater than zero, then the result is greater than zero + rule(eq(v).to(Value.from_float(Float(f))), f > 0.0).then(greater_zero(v)), + rule(eq(v).to(Value.from_int(Int(i))), i > 0).then(greater_zero(v)), + # If we have division of v and v1, and both greater than zero, then the result is greater than zero rule( greater_zero(v), greater_zero(v1), @@ -1778,12 +2565,12 @@ def _interval_analaysis( # Define v < 0 to be false, if greater_zero(v) rule( greater_zero(v), - eq(v1).to(v < Value.int(Int(0))), + eq(v1).to(v < Value.from_int(Int(0))), ).then( - union(v1).with_(Value.bool(FALSE)), + union(v1).with_(Value.from_bool(FALSE)), ), # possible values of bool is bool - rewrite(possible_values(Value.bool(b))).to(TupleValue.EMPTY.append(Value.bool(b))), + rewrite(possible_values(Value.from_bool(b))).to(TupleValue([Value.from_bool(b)])), # casting to a type preserves if > 0 rule( eq(v1).to(v.astype(dtype)), @@ -1804,46 +2591,10 @@ def _interval_analaysis( ## -def _demand_shape(compound: NDArray, inner: NDArray) -> Command: - __a = var("__a", NDArray) - return rule(eq(__a).to(compound)).then(inner.shape, inner.shape.length()) - - -@array_api_ruleset.register -def _scalar_math(v: Value, vs: TupleValue, i: Int): - yield rewrite(NDArray.scalar(v).shape).to(TupleInt.EMPTY) - yield rewrite(NDArray.scalar(v).dtype).to(v.dtype) - yield 
rewrite(NDArray.scalar(v).index(TupleInt.EMPTY)).to(v) - - -@array_api_ruleset.register -def _vector_math(v: Value, vs: TupleValue, ti: TupleInt): - yield rewrite(NDArray.vector(vs).shape).to(TupleInt.single(vs.length())) - yield rewrite(NDArray.vector(vs).dtype).to(vs[Int(0)].dtype) - yield rewrite(NDArray.vector(vs).index(ti)).to(vs[ti[0]]) - - -@array_api_ruleset.register -def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool): - res = reshape(x, shape, copy) - - yield _demand_shape(res, x) - # Demand shape length and index - yield rule(res).then(shape.length(), shape[0]) - - # Reshaping a vec to a vec is the same as the vec - yield rewrite(res).to( - x, - eq(x.shape.length()).to(Int(1)), - eq(shape.length()).to(Int(1)), - eq(shape[0]).to(Int(-1)), - ) - - @array_api_ruleset.register def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int): # rewrite full getitem to indexec - yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(TupleInt.single(i)))) + yield rewrite(x[IndexKey.int(i)]).to(NDArray(x.index(TupleInt([i])))) # TODO: Multi index rewrite as well if all are ints @@ -1900,7 +2651,7 @@ def _isfinite(x: NDArray, ti: TupleInt): yield rewrite(x.shape).to(orig_x.shape) yield rewrite(x.dtype).to(orig_x.dtype) yield rewrite(x.index(ti)).to(orig_x.index(ti)) - # But say that any indixed value is finite + # But say that any indexed value is finite yield rewrite(x.index(ti).isfinite()).to(TRUE) @@ -1927,17 +2678,17 @@ def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt): @array_api_ruleset.register def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean): - yield rewrite(Value.int(i).isfinite()).to(TRUE) - yield rewrite(Value.bool(b).isfinite()).to(TRUE) - yield rewrite(Value.float(Float(f)).isfinite()).to(TRUE, ne(f).to(f64(math.nan))) + yield rewrite(Value.from_int(i).isfinite()).to(TRUE) + yield rewrite(Value.from_bool(b).isfinite()).to(TRUE) + yield 
rewrite(Value.from_float(Float(f)).isfinite()).to(TRUE, ne(f).to(f64(math.nan))) # a sum of an array is finite if all the values are finite - yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite()))) + yield rewrite(isfinite(sum(arr))).to(NDArray(Value.from_bool(arr.index(ALL_INDICES).isfinite()))) @array_api_ruleset.register def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool): - yield rewrite(unique_values(x=a)).to(NDArray.vector(possible_values(a.index(ALL_INDICES)))) + yield rewrite(unique_values(x=a)).to(NDArray.from_tuple_value(possible_values(a.index(ALL_INDICES)))) # yield rewrite( # possible_values(reshape(a.index(shape, copy), ALL_INDICES)), # ).to(possible_values(a.index(ALL_INDICES))) @@ -1948,36 +2699,60 @@ def _size(x: NDArray): yield rewrite(x.size).to(x.shape.foldl(Int.__mul__, Int(1))) -# Seperate rulseset so we can use it in program gen -@ruleset -def array_api_vec_to_cons_ruleset( - vs: Vec[Int], - vv: Vec[Value], - vn: Vec[NDArray], - vt: Vec[TupleInt], -): - yield rewrite(TupleInt.from_vec(vs)).to(TupleInt.EMPTY, eq(vs.length()).to(i64(0))) - yield rewrite(TupleInt.from_vec(vs)).to( - TupleInt.from_vec(vs.remove(vs.length() - 1)).append(vs[vs.length() - 1]), ne(vs.length()).to(i64(0)) - ) +@function(ruleset=array_api_ruleset) +def ravel_index(index: TupleIntLike, shape: TupleIntLike) -> Int: + """ + Convert a multi-dimensional index to a flat index. 
+ + https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html#numpy.ravel_multi_index + + >>> int(ravel_index((3, 4), (7, 6))) + 22 + >>> int(ravel_index((6, 5), (7, 6))) + 41 + >>> int(ravel_index((6, 1), (7, 6))) + 37 + >>> int(ravel_index((3, 1, 4, 1), (6, 7, 8, 9))) + 1621 + """ + index = cast("TupleInt", index) + shape = cast("TupleInt", shape) - yield rewrite(TupleValue.from_vec(vv)).to(TupleValue.EMPTY, eq(vv.length()).to(i64(0))) - yield rewrite(TupleValue.from_vec(vv)).to( - TupleValue.from_vec(vv.remove(vv.length() - 1)).append(vv[vv.length() - 1]), ne(vv.length()).to(i64(0)) - ) + return TupleInt.range(shape.length()).foldl(lambda res, i: res * shape[i] + index[i], Int(0)) - yield rewrite(TupleTupleInt.from_vec(vt)).to(TupleTupleInt.EMPTY, eq(vt.length()).to(i64(0))) - yield rewrite(TupleTupleInt.from_vec(vt)).to( - TupleTupleInt.from_vec(vt.remove(vt.length() - 1)).append(vt[vt.length() - 1]), ne(vt.length()).to(i64(0)) - ) - yield rewrite(TupleNDArray.from_vec(vn)).to(TupleNDArray.EMPTY, eq(vn.length()).to(i64(0))) - yield rewrite(TupleNDArray.from_vec(vn)).to( - TupleNDArray.from_vec(vn.remove(vn.length() - 1)).append(vn[vn.length() - 1]), ne(vn.length()).to(i64(0)) + +@function(ruleset=array_api_ruleset) +def unravel_index(flat_index: IntLike, shape: TupleIntLike) -> TupleInt: + """ + Convert a flat index to a multi-dimensional index. 
+ + https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html + + >>> tuple(map(int, unravel_index(22, (7, 6)))) + (3, 4) + >>> tuple(map(int, unravel_index(41, (7, 6)))) + (6, 5) + >>> tuple(map(int, unravel_index(37, (7, 6)))) + (6, 1) + >>> tuple(map(int, unravel_index(1621, (6, 7, 8, 9)))) + (3, 1, 4, 1) + """ + shape = cast("TupleInt", shape) + + return ( + shape.reverse() + .foldl_tuple_int( + # Store remainder as last item in accumulator + lambda acc, dim: acc.drop_last().append((r := acc.last()) % dim).append(r // dim), + TupleInt([flat_index]), + ) + .drop_last() + .reverse() ) -array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset -array_api_schedule = array_api_combined_ruleset.saturate() +array_api_combined_ruleset = array_api_ruleset +array_api_schedule = (array_api_combined_ruleset + run()).saturate() _CURRENT_EGRAPH: None | EGraph = None @@ -1990,30 +2765,216 @@ def set_array_api_egraph(egraph: EGraph) -> Iterator[None]: global _CURRENT_EGRAPH assert _CURRENT_EGRAPH is None _CURRENT_EGRAPH = egraph - yield - _CURRENT_EGRAPH = None + try: + yield + finally: + _CURRENT_EGRAPH = None def _get_current_egraph() -> EGraph: - return _CURRENT_EGRAPH or EGraph() + return _CURRENT_EGRAPH or EGraph(save_egglog_string=True) + + +T_co = TypeVar("T_co", covariant=True) + + +class ExprWithValue(Protocol[T_co]): + @property + def value(self) -> T_co: ... -def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any: +@_TRACER.start_as_current_span("try_evaling") +def try_evaling(expr: ExprWithValue[T_co]) -> T_co: """ - Try evaling the expression that will result in a primitive expression being fill. - if it fails, display the egraph and raise an error. + Evaluate an expression in the current e-graph, then re-extract it in a fresh + e-graph to avoid cross-expression contradictions. 
""" - try: - extracted = egraph.extract(prim_expr) - except EggSmolError: - # If this primitive doesn't exist in the egraph, we need to try to create it by - # registering the expression and running the schedule - egraph.register(expr) - egraph.run(schedule) - try: - extracted = egraph.extract(prim_expr) - except BaseException as e: - # egraph.display(n_inline_leaves=1, split_primitive_outputs=True) - e.add_note(f"Cannot evaluate {egraph.extract(expr)}") - raise - return extracted.value # type: ignore[attr-defined] + egraph = _get_current_egraph() + egraph.register(expr) # type: ignore[arg-type] + egraph.run(array_api_schedule) + return egraph.extract(expr).value # type: ignore[call-overload] + + +# Polynomials +## + + +@function +def polynomial(x: MultiSetLike[MultiSet[Value], MultiSetLike[Value, ValueLike]]) -> Value: ... + + +@function(merge=lambda old, new: new) +def get_monomial(x: Value) -> MultiSet[Value]: + """ + Should be defined on all polynomials with one monomial created in `to_polynomial_ruleset`: + + get_monomial(polynomial(MultiSet(xs))) => xs + """ + + +@function(merge=lambda old, new: new) +def get_sole_polynomial(xs: MultiSet[Value]) -> MultiSet[MultiSet[Value]]: + """ + Should be defined on all monomials that contain a single polynomial created in `to_polynomial_ruleset`: + + get_sole_polynomial(MultiSet(polynomial(xss))) => xss + """ + + +@ruleset +def to_polynomial_ruleset( + n1: Value, + n2: Value, + n3: Value, + i: i64, + ms: MultiSet[Value], + mss: MultiSet[MultiSet[Value]], + mss1: MultiSet[MultiSet[Value]], +): + yield rule( + eq(n3).to(n1 + n2), + eq(mss).to(MultiSet(MultiSet(n1), MultiSet(n2))), + name="add", + ).then( + union(n3).with_(polynomial(mss)), + set_(get_sole_polynomial(MultiSet(polynomial(mss)))).to(mss), + delete(n1 + n2), + ) + yield rule( + eq(n3).to(n1 * n2), + eq(ms).to(MultiSet(n1, n2)), + name="mul", + ).then( + union(n3).with_(polynomial(MultiSet(ms))), + set_(get_monomial(polynomial(MultiSet(ms)))).to(ms), + 
delete(n1 * n2), + ) + yield rule( + eq(n3).to(n1**i), + i >= 0, + eq(ms).to(MultiSet.single(n1, i)), + name="pow", + ).then( + union(n3).with_(polynomial(MultiSet(ms))), + set_(get_monomial(polynomial(MultiSet(ms)))).to(ms), + delete(n1**i), + ) + + yield rule( + eq(n1).to(polynomial(mss)), + # For each monomial, if any of its terms is a polynomial with a single monomial, just flatten + # that into the monomial + mss1 == mss.map(partial(multiset_flat_map, get_monomial)), + mss != mss1, # skip if this is a no-op + name="unwrap monomial", + ).then( + union(n1).with_(polynomial(mss1)), + delete(polynomial(mss)), + set_(get_sole_polynomial(MultiSet(n1))).to(mss1), + ) + yield rule( + eq(n1).to(polynomial(mss)), + # If any of the monomials just has a single item which is a polynomial, then flatten that into the outer polynomial + mss1 == multiset_flat_map(UnstableFn(get_sole_polynomial), mss), + mss != mss1, + name="unwrap polynomial", + ).then( + union(n1).with_(polynomial(mss1)), + delete(polynomial(mss)), + set_(get_sole_polynomial(MultiSet(n1))).to(mss1), + ) + + +@ruleset +def factor_ruleset( + n: Value, + mss: MultiSet[MultiSet[Value]], + counts: MultiSet[Value], + picked_term: Value, + picked: MultiSet[MultiSet[Value]], + divided: MultiSet[MultiSet[Value]], + factor: MultiSet[Value], + remainder: MultiSet[MultiSet[Value]], +): + yield rule( + eq(n).to(polynomial(mss)), + # Find factor that shows up in most monomials, at least two of them + counts == MultiSet.sum_multisets(mss.map(MultiSet.reset_counts)), + eq(picked_term).to(counts.pick_max()), + # Only factor out if it term appears in more than one monomial + counts.count(picked_term) > 1, + # The factor we choose is the largest intersection between all the monomials that have the picked term + picked == mss.filter(partial(multiset_contains_swapped, picked_term)), + factor == multiset_fold(MultiSet.__and__, picked.pick(), picked), + divided == picked.map(partial(multiset_subtract_swapped, factor)), + # 
the remainder is those monomials that do not contain the factor + remainder == mss.filter(partial(multiset_not_contains_swapped, picked_term)), + name="factor", + ).then( + union(n).with_(polynomial(MultiSet(factor.insert(polynomial(divided))) + remainder)), + delete(polynomial(mss)), + ) + + +@ruleset +def from_polynomial_ruleset(mss: MultiSet[MultiSet[Value]], n1: Value, n: Value, i: i64): + mul: Callable[[Value, Value], Value] = Value.__mul__ + + yield rule( + eq(n).to(polynomial(mss)), + ).then( + union(n).with_( + multiset_fold( + Value.__add__, + Value.from_int(0), + mss.map( + partial(multiset_fold, mul, Value.from_int(1)), + ), + ) + ), + delete(polynomial(mss)), + ) + + # TODO: change this to emit a more efficient form in the future + + # Clean up exponents + yield rule( + eq(n1).to(n * n), + ).then( + union(n1).with_(n**2), + delete(n * n), + ) + yield rule( + eq(n1).to(n**i * n), + ).then( + union(n1).with_(n ** (i + 1)), + delete(n**i * n), + ) + yield rule( + eq(n1).to(n * n**i), + ).then( + union(n1).with_(n ** (i + 1)), + delete(n * n**i), + ) + # Clean up repeated additions + yield rule( + eq(n1).to(n + n), + ).then( + union(n1).with_(Value.from_int(2) * n), + delete(n + n), + ) + yield rule( + eq(n1).to(Value.from_int(i) * n + n), + ).then( + union(n1).with_(Value.from_int(i + 1) * n), + delete(Value.from_int(i) * n + n), + ) + yield rule( + eq(n1).to(n + Value.from_int(i) * n), + ).then( + union(n1).with_(Value.from_int(i + 1) * n), + delete(n + Value.from_int(i) * n), + ) + + +polynomial_schedule = to_polynomial_ruleset.saturate() + factor_ruleset.saturate() + from_polynomial_ruleset.saturate() diff --git a/python/egglog/exp/array_api_jit.py b/python/egglog/exp/array_api_jit.py index 529b4a41..1fa9624e 100644 --- a/python/egglog/exp/array_api_jit.py +++ b/python/egglog/exp/array_api_jit.py @@ -3,17 +3,20 @@ from typing import TypeVar, cast import numpy as np +from opentelemetry import trace -from egglog import EGraph, greedy_dag_cost_model -from egglog.exp.array_api 
import NDArray, set_array_api_egraph, try_evaling +from egglog import EGraph, bindings, greedy_dag_cost_model +from egglog.exp.array_api import NDArray, set_array_api_egraph from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program from .program_gen import Program X = TypeVar("X", bound=Callable) +_TRACER = trace.get_tracer(__name__) +@_TRACER.start_as_current_span("jit") def jit( fn: X, *, @@ -24,20 +27,34 @@ def jit( Jit compiles a function """ egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False) + egraph = EGraph() if handle_expr: handle_expr(res) if handle_optimized_expr: handle_optimized_expr(res_optimized) fn_program = EvalProgram(program, {"np": np}) - return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object)) + egraph.register(fn_program) + egraph.run(array_api_program_gen_schedule) + try: + return cast("X", egraph.extract(fn_program.as_py_object).value) + except bindings.EggSmolError as e: + try: + debug_program = egraph.extract(fn_program) + except bindings.EggSmolError: + debug_program = fn_program + e.add_note(f"Failed to get py object from {debug_program}") + raise + + +@_TRACER.start_as_current_span("function_to_program") def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]: sig = inspect.signature(fn) arg1, arg2 = sig.parameters.keys() egraph = EGraph(save_egglog_string=save_egglog_string) with egraph: - with set_array_api_egraph(egraph): + with _TRACER.start_as_current_span("call_function"), set_array_api_egraph(egraph): res = fn(NDArray.var(arg1), NDArray.var(arg2)) egraph.register(res) egraph.run(array_api_numba_schedule) diff --git a/python/egglog/exp/array_api_loopnest.py b/python/egglog/exp/array_api_loopnest.py index 046c3c1a..0273c9bb 100644 --- 
a/python/egglog/exp/array_api_loopnest.py +++ b/python/egglog/exp/array_api_loopnest.py @@ -61,14 +61,15 @@ def get_dims(self) -> TupleInt: ... @array_api_ruleset.register -def _loopnest_api_ruleset(lna: LoopNestAPI, dim: Int, ti: TupleInt, idx_fn: Callable[[Int], Int], i: i64): +def _loopnest_api_ruleset(lna: LoopNestAPI, dim: Int, ti: TupleInt, idx_fn: Callable[[Int], Int], i: i64, vs: Vec[Int]): # from_tuple - yield rewrite(LoopNestAPI.from_tuple(TupleInt.EMPTY), subsume=True).to(OptionalLoopNestAPI.NONE) - yield rewrite(LoopNestAPI.from_tuple(ti.append(dim)), subsume=True).to( - OptionalLoopNestAPI(LoopNestAPI(dim, LoopNestAPI.from_tuple(ti))) + yield rewrite(LoopNestAPI.from_tuple(TupleInt(())), subsume=True).to(OptionalLoopNestAPI.NONE) + yield rewrite(LoopNestAPI.from_tuple(TupleInt(vs)), subsume=True).to( + OptionalLoopNestAPI(LoopNestAPI(vs[vs.length() - 1], LoopNestAPI.from_tuple(TupleInt(vs.pop())))), + vs.length() > 0, ) # get_dims - yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims(), subsume=True).to(TupleInt.single(dim)) + yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims(), subsume=True).to(TupleInt((dim,))) yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims(), subsume=True).to(lna.get_dims().append(dim)) # unwrap yield rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna) diff --git a/python/egglog/exp/array_api_numba.py b/python/egglog/exp/array_api_numba.py index ddb320cc..82b1dd1a 100644 --- a/python/egglog/exp/array_api_numba.py +++ b/python/egglog/exp/array_api_numba.py @@ -17,24 +17,19 @@ # Rewrite mean(x, , ) to use sum b/c numba cant do mean with axis # https://github.com/numba/numba/issues/1269 @array_api_numba_ruleset.register -def _mean(y: NDArray, x: NDArray, i: Int): - axis = OptionalIntOrTuple.some(IntOrTuple.int(i)) - res = sum(x, axis) / NDArray.scalar(Value.int(x.shape[i])) +def _mean(y: NDArray, x: NDArray, axis: Int): + res = sum(x, axis) / x.shape[axis] yield rewrite(mean(x, axis, 
FALSE), subsume=True).to(res) - yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, i)) + yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, axis)) # Rewrite std(x, ) to use mean and sum b/c numba cant do std with axis @array_api_numba_ruleset.register -def _std(y: NDArray, x: NDArray, i: Int): - axis = OptionalIntOrTuple.some(IntOrTuple.int(i)) +def _std(y: NDArray, x: NDArray, axis: Int): # https://numpy.org/doc/stable/reference/generated/numpy.std.html # "std = sqrt(mean(x)), where x = abs(a - a.mean())**2." - yield rewrite( - std(x, axis), - subsume=True, - ).to( + yield rewrite(std(x, axis), subsume=True).to( sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)), ) @@ -47,14 +42,16 @@ def count_values(x: NDArrayLike, values: TupleValueLike) -> TupleValue: """ x = cast(NDArray, x) values = cast(TupleValue, values) - return TupleValue(values.length(), lambda i: sum(x == values[i]).to_value()) + return TupleValue.fn(values.length(), lambda i: sum(x == values[i]).index(())) @array_api_numba_ruleset.register def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value): return [ # The unique counts are the count of all the unique values - rewrite(unique_counts(x)[1], subsume=True).to(NDArray.vector(count_values(x, unique_values(x).to_values()))), + rewrite(unique_counts_counts(x), subsume=True).to( + NDArray.from_tuple_value(count_values(x, unique_values(x).to_tuple_values())) + ), ] @@ -63,7 +60,5 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value): def _unique_inverse(x: NDArray, i: Int): return [ # Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values - rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to( - x == NDArray.scalar(unique_values(x).index((i,))) - ), + rewrite(unique_inverse_inverse_indices(x) == i, subsume=True).to(x == unique_values(x).index((i,))), ] diff --git 
a/python/egglog/exp/array_api_program_gen.py b/python/egglog/exp/array_api_program_gen.py index 57027f1c..bdafffd3 100644 --- a/python/egglog/exp/array_api_program_gen.py +++ b/python/egglog/exp/array_api_program_gen.py @@ -15,12 +15,9 @@ array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset") array_api_program_gen_combined_ruleset = ( - array_api_program_gen_ruleset - | program_gen_ruleset - | array_api_program_gen_eval_ruleset - | array_api_vec_to_cons_ruleset + array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset ) -array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_rulseset).saturate() +array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_ruleset).saturate() @function @@ -38,7 +35,7 @@ def int_program(x: Int) -> Program: ... @array_api_program_gen_ruleset.register -def _int_program(i64_: i64, i: Int, j: Int, s: String): +def _int_program(i64_: i64, i: Int, j: Int, s: String, b: Boolean, ti: Callable[[], Int], ti1: Callable[[], Int]): yield rewrite(int_program(Int.var(s))).to(Program(s, True)) yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string())) yield rewrite(int_program(~i)).to(Program("~") + int_program(i)) @@ -60,27 +57,57 @@ def _int_program(i64_: i64, i: Int, j: Int, s: String): yield rewrite(int_program(i >> j)).to(Program("(") + int_program(i) + " >> " + int_program(j) + ")") yield rewrite(int_program(i // j)).to(Program("(") + int_program(i) + " // " + int_program(j) + ")") + assigned = int_program(j).assign() + yield rewrite(int_program(check_index(i, j)), subsume=True).to( + assigned.statement(Program("assert ") + assigned + " < " + int_program(i)) + ) + + yield rewrite(int_program(Int.if_(b, ti, ti1))).to( + int_program(ti()) + " if " + bool_program(b) + " else " + int_program(ti1()) + ) + @function +def program_if(b: BooleanLike, t: Callable[[], Program], f: Callable[[], Program]) -> Program: ... 
+ + +@function(ruleset=array_api_program_gen_ruleset) def tuple_int_foldl_program(xs: TupleIntLike, f: Callable[[Program, Int], Program], init: ProgramLike) -> Program: ... @function(ruleset=array_api_program_gen_ruleset) -def tuple_int_program(x: TupleIntLike) -> Program: - return tuple_int_foldl_program(x, lambda acc, i: acc + int_program(i) + ", ", "(") + ")" +def tuple_int_program(x: TupleIntLike) -> Program: ... @array_api_program_gen_ruleset.register -def _tuple_int_program(i: Int, ti: TupleInt, ti2: TupleInt, f: Callable[[Program, Int], Program], init: Program): +def _tuple_int_program( + i: Int, + ti: TupleInt, + ti2: TupleInt, + f: Callable[[Program, Int], Program], + init: Program, + b: Boolean, + tt: Callable[[], Program], + ft: Callable[[], Program], + vi: Vec[Int], +): yield rewrite(int_program(ti[i])).to(tuple_int_program(ti) + "[" + int_program(i) + "]") yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_int_program(ti) + ")") - yield rewrite(tuple_int_foldl_program(TupleInt.EMPTY, f, init)).to(init) - yield rewrite(tuple_int_foldl_program(ti.append(i), f, init)).to(f(tuple_int_foldl_program(ti, f, init), i)) + yield rewrite(program_if(True, tt, ft)).to(tt()) + yield rewrite(program_if(False, tt, ft)).to(ft()) yield rewrite(tuple_int_program(ti + ti2)).to( Program("(") + tuple_int_program(ti) + " + " + tuple_int_program(ti2) + ")" ) + yield rewrite(tuple_int_program(ti)).to( + tuple_int_foldl_program(ti, lambda acc, i: acc + int_program(i) + ", ", "(") + ")" + ) + + yield rewrite(tuple_int_foldl_program(TupleInt(()), f, init)).to(init) + yield rewrite(tuple_int_foldl_program(TupleInt(vi), f, init)).to( + f(tuple_int_foldl_program(vi.pop(), f, init), vi[vi.length() - 1]), vi.length() > 0 + ) @function @@ -139,41 +166,50 @@ def value_program(x: Value) -> Program: ... 
@array_api_program_gen_ruleset.register def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Value, xs: NDArray, ti: TupleInt): - yield rewrite(value_program(Value.int(i))).to(int_program(i)) - yield rewrite(value_program(Value.bool(b))).to(bool_program(b)) - yield rewrite(value_program(Value.float(f))).to(float_program(f)) + yield rewrite(value_program(Value.from_int(i))).to(int_program(i)) + yield rewrite(value_program(Value.from_bool(b))).to(bool_program(b)) + yield rewrite(value_program(Value.from_float(f))).to(float_program(f)) # Could add .item() but we usually dont need it. - yield rewrite(value_program(x.to_value())).to(ndarray_program(x)) + # yield rewrite(value_program(x.to_value())).to(ndarray_program(x)) yield rewrite(value_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")") yield rewrite(value_program(v1 / v2)).to(Program("(") + value_program(v1) + " / " + value_program(v2) + ")") yield rewrite(value_program(v1 + v2)).to(Program("(") + value_program(v1) + " + " + value_program(v2) + ")") yield rewrite(value_program(v1 * v2)).to(Program("(") + value_program(v1) + " * " + value_program(v2) + ")") yield rewrite(bool_program(v1.to_bool)).to(value_program(v1)) yield rewrite(int_program(v1.to_int)).to(value_program(v1)) - yield rewrite(value_program(xs.index(ti))).to((ndarray_program(xs) + "[" + tuple_int_program(ti) + "]").assign()) + yield rewrite(value_program(xs.index(ti))).to( + (ndarray_program(xs) + "[" + tuple_int_program(ti) + "]").assign(), ne(ti).to(TupleInt(())) + ) + yield rewrite(value_program(xs.index(TupleInt(())))).to(ndarray_program(xs)) yield rewrite(value_program(v1.sqrt())).to(Program("np.sqrt(") + value_program(v1) + ")") yield rewrite(value_program(v1.real())).to(Program("np.real(") + value_program(v1) + ")") yield rewrite(value_program(v1.conj())).to(Program("np.conj(") + value_program(v1) + ")") -@function +@function(ruleset=array_api_program_gen_ruleset) def 
tuple_value_foldl_program( xs: TupleValueLike, f: Callable[[Program, Value], Program], init: ProgramLike ) -> Program: ... -@function(ruleset=array_api_program_gen_ruleset) -def tuple_value_program(x: TupleValueLike) -> Program: - return tuple_value_foldl_program(x, lambda acc, i: acc + value_program(i) + ", ", "(") + ")" +@function +def tuple_value_program(x: TupleValueLike) -> Program: ... @array_api_program_gen_ruleset.register -def _tuple_value_program(i: Int, ti: TupleValue, f: Callable[[Program, Value], Program], v: Value, init: Program): +def _tuple_value_program( + i: Int, ti: TupleValue, f: Callable[[Program, Value], Program], v: Value, init: Program, vv: Vec[Value] +): yield rewrite(value_program(ti[i])).to(tuple_value_program(ti) + "[" + int_program(i) + "]") yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_value_program(ti) + ")") + yield rewrite(tuple_value_program(ti)).to( + tuple_value_foldl_program(ti, lambda acc, i: acc + value_program(i) + ", ", "(") + ")" + ) - yield rewrite(tuple_value_foldl_program(TupleValue.EMPTY, f, init)).to(init) - yield rewrite(tuple_value_foldl_program(ti.append(v), f, init)).to(f(tuple_value_foldl_program(ti, f, init), v)) + yield rewrite(tuple_value_foldl_program(TupleValue(()), f, init)).to(init) + yield rewrite(tuple_value_foldl_program(TupleValue(vv), f, init)).to( + f(tuple_value_foldl_program(vv.pop(), f, init), vv[vv.length() - 1]), vv.length() > 0 + ) @function @@ -189,13 +225,15 @@ def tuple_ndarray_program(x: TupleNDArrayLike) -> Program: @array_api_program_gen_ruleset.register def _tuple_ndarray_program( - i: Int, ti: TupleNDArray, f: Callable[[Program, NDArray], Program], v: NDArray, init: Program + i: Int, ti: TupleNDArray, f: Callable[[Program, NDArray], Program], v: NDArray, init: Program, vn: Vec[NDArray] ): yield rewrite(ndarray_program(ti[i])).to(tuple_ndarray_program(ti) + "[" + int_program(i) + "]") yield rewrite(int_program(ti.length())).to(Program("len(") + 
tuple_ndarray_program(ti) + ")") - yield rewrite(tuple_ndarray_foldl_program(TupleNDArray.EMPTY, f, init)).to(init) - yield rewrite(tuple_ndarray_foldl_program(ti.append(v), f, init)).to(f(tuple_ndarray_foldl_program(ti, f, init), v)) + yield rewrite(tuple_ndarray_foldl_program(TupleNDArray(()), f, init)).to(init) + yield rewrite(tuple_ndarray_foldl_program(TupleNDArray(vn), f, init)).to( + f(tuple_ndarray_foldl_program(vn.pop(), f, init), vn[vn.length() - 1]), vn.length() > 0 + ) @function @@ -302,23 +340,14 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray): yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a)) -@function -def int_or_tuple_program(x: IntOrTuple) -> Program: ... - - -@array_api_program_gen_ruleset.register -def _int_or_tuple_program(x: Int, t: TupleInt): - yield rewrite(int_or_tuple_program(IntOrTuple.int(x))).to(int_program(x)) - yield rewrite(int_or_tuple_program(IntOrTuple.tuple(t))).to(tuple_int_program(t)) - - @function def optional_int_or_tuple_program(x: OptionalIntOrTuple) -> Program: ... 
@array_api_program_gen_ruleset.register -def _optional_int_or_tuple_program(it: IntOrTuple): - yield rewrite(optional_int_or_tuple_program(OptionalIntOrTuple.some(it))).to(int_or_tuple_program(it)) +def _optional_int_or_tuple_program(i: Int, ti: TupleInt): + yield rewrite(optional_int_or_tuple_program(OptionalIntOrTuple.int(i))).to(int_program(i)) + yield rewrite(optional_int_or_tuple_program(OptionalIntOrTuple.tuple(ti))).to(tuple_int_program(ti)) yield rewrite(optional_int_or_tuple_program(OptionalIntOrTuple.none)).to(Program("None")) @@ -332,18 +361,18 @@ def _ndarray_program( ti: TupleInt, i: Int, tv: TupleValue, - v: Value, + rv: RecursiveValue, ob: OptionalBool, tnd: TupleNDArray, optional_device_: OptionalDevice, - int_or_tuple_: IntOrTuple, + optional_int_or_tuple_: OptionalIntOrTuple, idx: IndexKey, odtype: OptionalDType, ): # Var yield rewrite(ndarray_program(NDArray.var(s))).to(Program(s, True)) - # Asssume dtype + # Assume dtype z_assumed_dtype = copy(z) assume_dtype(z_assumed_dtype, dtype) z_program = ndarray_program(z) @@ -395,10 +424,8 @@ def _ndarray_program( # Tuple ndarray indexing yield rewrite(ndarray_program(tnd[i])).to(tuple_ndarray_program(tnd) + "[" + int_program(i) + "]") - # ndarray scalar - # TODO: Use dtype and shape and indexing instead? - # TODO: SPecify dtype? 
- yield rewrite(ndarray_program(NDArray.scalar(v))).to(Program("np.array(") + value_program(v) + ")") + # literal array + yield rewrite(ndarray_program(NDArray(rv))).to(Program("np.array(") + recursive_value_program(rv) + ")") # zeros yield rewrite(ndarray_program(zeros(ti, OptionalDType.none, optional_device_))).to( @@ -457,20 +484,28 @@ def bin_op(res: NDArray, op: str) -> Command: # mean(x, axis) yield rewrite(ndarray_program(mean(x))).to((Program("np.mean(") + ndarray_program(x) + ")").assign()) yield rewrite( - ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), FALSE)), + ndarray_program(mean(x, optional_int_or_tuple_, FALSE)), ).to( - (Program("np.mean(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign(), + ( + Program("np.mean(") + + ndarray_program(x) + + ", axis=" + + optional_int_or_tuple_program(optional_int_or_tuple_) + + ")" + ).assign(), + optional_int_or_tuple_ != OptionalIntOrTuple.none, ) yield rewrite( - ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), TRUE)), + ndarray_program(mean(x, optional_int_or_tuple_, TRUE)), ).to( ( Program("np.mean(") + ndarray_program(x) + ", axis=" - + int_or_tuple_program(int_or_tuple_) + + optional_int_or_tuple_program(optional_int_or_tuple_) + ", keepdims=True)" ).assign(), + optional_int_or_tuple_ != OptionalIntOrTuple.none, ) # Concat @@ -480,16 +515,21 @@ def bin_op(res: NDArray, op: str) -> Command: yield rewrite(ndarray_program(concat(tnd, OptionalInt.some(i)))).to( (Program("np.concatenate(") + tuple_ndarray_program(tnd) + ", axis=" + int_program(i) + ")").assign() ) - # Vector - yield rewrite(ndarray_program(NDArray.vector(tv))).to(Program("np.array(") + tuple_value_program(tv) + ")") # std yield rewrite(ndarray_program(std(x))).to((Program("np.std(") + ndarray_program(x) + ")").assign()) - yield rewrite(ndarray_program(std(x, OptionalIntOrTuple.some(int_or_tuple_)))).to( - (Program("np.std(") + ndarray_program(x) + ", axis=" + 
int_or_tuple_program(int_or_tuple_) + ")").assign(), + yield rewrite(ndarray_program(std(x, optional_int_or_tuple_))).to( + ( + Program("np.std(") + + ndarray_program(x) + + ", axis=" + + optional_int_or_tuple_program(optional_int_or_tuple_) + + ")" + ).assign(), + optional_int_or_tuple_ != OptionalIntOrTuple.none, ) # svd - yield rewrite(tuple_ndarray_program(svd(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign()) - yield rewrite(tuple_ndarray_program(svd(x, FALSE))).to( + yield rewrite(tuple_ndarray_program(svd_(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign()) + yield rewrite(tuple_ndarray_program(svd_(x, FALSE))).to( (Program("np.linalg.svd(") + ndarray_program(x) + ", full_matrices=False)").assign() ) # sqrt @@ -498,8 +538,15 @@ def bin_op(res: NDArray, op: str) -> Command: yield rewrite(ndarray_program(x.T)).to(ndarray_program(x) + ".T") # sum yield rewrite(ndarray_program(sum(x))).to((Program("np.sum(") + ndarray_program(x) + ")").assign()) - yield rewrite(ndarray_program(sum(x, OptionalIntOrTuple.some(int_or_tuple_)))).to( - (Program("np.sum(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign() + yield rewrite(ndarray_program(sum(x, optional_int_or_tuple_))).to( + ( + Program("np.sum(") + + ndarray_program(x) + + ", axis=" + + optional_int_or_tuple_program(optional_int_or_tuple_) + + ")" + ).assign(), + optional_int_or_tuple_ != OptionalIntOrTuple.none, ) yield rewrite(tuple_int_program(x.shape)).to(ndarray_program(x) + ".shape") yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign()) @@ -508,3 +555,26 @@ def bin_op(res: NDArray, op: str) -> Command: yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to( Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")" ) + + +@function +def recursive_value_program(x: RecursiveValue) -> Program: ... 
+ + +@array_api_program_gen_ruleset.register +def _recursive_value_program(v: Value, vv: Vec[RecursiveValue]): + yield rewrite(recursive_value_program(RecursiveValue(v))).to(value_program(v)) + yield rewrite(recursive_value_program(RecursiveValue.vec(vv))).to("(" + vec_recursive_value_program(vv) + ")") + + +@function +def vec_recursive_value_program(x: Vec[RecursiveValue]) -> Program: ... + + +@array_api_program_gen_ruleset.register +def _vec_recursive_value_program(vv: Vec[RecursiveValue]): + yield rewrite(vec_recursive_value_program(Vec[RecursiveValue].empty())).to(Program("")) + yield rewrite(vec_recursive_value_program(vv)).to( + recursive_value_program(vv[0]) + ", " + vec_recursive_value_program(vv.remove(0)), + vv.length() > 0, + ) diff --git a/python/egglog/exp/polynomials.py b/python/egglog/exp/polynomials.py new file mode 100644 index 00000000..2f2a6269 --- /dev/null +++ b/python/egglog/exp/polynomials.py @@ -0,0 +1,274 @@ +""" +Helpers for the polynomial container examples in the containers docs. 
+""" + +from __future__ import annotations + +import time +from dataclasses import dataclass + +import numpy as np + +import egglog +import egglog.exp.array_api as enp + +__all__ = [ + "Report", + "TotalReport", + "bending_function", + "distribute", + "factoring", + "remove_subtraction", + "run_example", + "symbolic_bending_examples", + "symbolic_bending_inputs", + "try_example", +] + + +def bending_function(Q, Bp, Bpp): + xp = Q.__array_namespace__() + QM = xp.reshape(Q, (4, 3)).T + + yip = xp.vecdot(QM, Bp) + yipp = xp.vecdot(QM, Bpp) + num = xp.linalg.vector_norm(xp.cross(yip, yipp)) + den = xp.linalg.vector_norm(yip) ** 3 + return (num / den) ** 2 + + +def symbolic_bending_inputs() -> tuple[enp.NDArray, enp.NDArray, enp.NDArray]: + bp = enp.NDArray([enp.Value.var(f"bp{i}") for i in range(1, 5)]) + bpp = enp.NDArray([enp.Value.var(f"bpp{i}") for i in range(1, 5)]) + q = enp.NDArray([enp.Value.var(f"q{i}") for i in range(1, 13)]) + return bp, bpp, q + + +def symbolic_bending_examples() -> tuple[enp.NDArray, enp.NDArray]: + bp, bpp, q = symbolic_bending_inputs() + function_bending = enp.NDArray(bending_function(q, bp, bpp).eval()) + gradient_bending = enp.NDArray(function_bending.diff(q).eval()) + return function_bending, gradient_bending + + +@egglog.ruleset +def remove_subtraction(a: enp.Value, b: enp.Value): + yield egglog.rewrite(a - b, subsume=True).to(a + enp.Value.from_int(-1) * b) + + +@egglog.ruleset +def distribute(a: enp.Value, b: enp.Value, c: enp.Value): + yield egglog.rewrite((a + b) * c, subsume=True).to(a * c + b * c) + yield egglog.rewrite(c * (a + b), subsume=True).to(c * a + c * b) + + +@egglog.ruleset +def factoring(a: enp.Value, b: enp.Value, c: enp.Value): + yield egglog.birewrite((a + b) * c).to(a * c + b * c) + yield egglog.rewrite(a * b).to(b * a) + yield egglog.rewrite(a + b).to(b + a) + yield egglog.birewrite(a * (b * c)).to((a * b) * c) + + +@dataclass(frozen=True) +class Report: + register_sec: float + run_sec: float + extract_sec: 
float + extracted: enp.NDArray + cost: int + function_sizes: list[tuple[egglog.ExprCallable, int]] + updated: bool + + @property + def total_sec(self) -> float: + return self.register_sec + self.run_sec + self.extract_sec + + @property + def total_size(self) -> int: + return sum(size for _, size in self.function_sizes) + + +def run_example( + ruleset: egglog.Schedule | egglog.Ruleset, input: enp.NDArray, egraph: egglog.EGraph | None = None +) -> Report: + if egraph is None: + egraph = egglog.EGraph() + + start = time.perf_counter() + egraph.register(input) + register_sec = time.perf_counter() - start + + start = time.perf_counter() + run_report = egraph.run(ruleset) + run_sec = time.perf_counter() - start + + start = time.perf_counter() + extracted, cost = egraph.extract(input, include_cost=True) + extract_sec = time.perf_counter() - start + + return Report(register_sec, run_sec, extract_sec, extracted, cost, egraph.all_function_sizes(), run_report.updated) + + +@dataclass(frozen=True) +class TotalReport: + original: Report + distributed: Report + factored: list[Report] + polynomial_multisets: Report + polynomial_multisets_factored: Report + polynomial: Report + + @property + def combined_factored(self) -> Report: + if not self.factored: + return self.distributed + return Report( + register_sec=self.factored[0].register_sec, + run_sec=sum(r.run_sec for r in self.factored), + extract_sec=self.factored[-1].extract_sec, + extracted=self.factored[-1].extracted, + cost=self.factored[-1].cost, + function_sizes=self.factored[-1].function_sizes, + updated=self.factored[-1].updated, + ) + + @property + def combined_polynomial(self) -> Report: + return Report( + register_sec=self.polynomial_multisets.register_sec, + run_sec=self.polynomial_multisets.run_sec + + self.polynomial_multisets_factored.run_sec + + self.polynomial.run_sec, + extract_sec=self.polynomial.extract_sec, + extracted=self.polynomial.extracted, + cost=self.polynomial.cost, + 
function_sizes=self.polynomial.function_sizes, + updated=self.polynomial.updated, + ) + + def __str__(self) -> str: + return f"""Costs: +* original: {self.original.cost:,} +* distributed: {self.distributed.cost:,} +* factored: {self.combined_factored.cost:,} +* horner multisets: {self.combined_polynomial.cost:,} + + +Number of nodes: +* original: {self.original.total_size:,} +* distributed: {self.distributed.total_size:,} +* factored: {self.combined_factored.total_size:,} +* horner multisets: {self.combined_polynomial.total_size:,} + +Time: +* original: {self.original.total_sec:.2f}s +* distributed: {self.distributed.total_sec:.2f}s +* factored: {self.combined_factored.total_sec:.2f}s +* horner multisets: {self.combined_polynomial.total_sec:.2f}s +""" + + +def try_example( + expr: enp.NDArray, + *, + max_factoring_iters: int = 20, + max_factoring_sec: float = 10.0, +) -> TotalReport: + original_report = run_example(remove_subtraction, expr) + print(f"original cost: {original_report.cost:,}") + distributed_report = run_example(distribute.saturate(), original_report.extracted) + print(f"distributed cost: {distributed_report.cost:,}") + + egraph = egglog.EGraph() + polynomial_multisets_report = run_example( + enp.to_polynomial_ruleset.saturate(), distributed_report.extracted, egraph + ) + polynomial_multisets_factored_report = run_example( + enp.factor_ruleset.saturate(), polynomial_multisets_report.extracted, egraph + ) + polynomial_report = run_example( + enp.from_polynomial_ruleset.saturate(), polynomial_multisets_factored_report.extracted, egraph + ) + print(f"polynomial cost: {polynomial_report.cost:,}") + + egraph = egglog.EGraph() + factored_reports: list[Report] = [] + for i in range(max_factoring_iters): + res = run_example(factoring, distributed_report.extracted, egraph) + if not res.updated or res.run_sec > max_factoring_sec: + break + print(f"factoring iteration {i}, cost: {res.cost:,}") + factored_reports.append(res) + print("Finished\n") + + return 
TotalReport( + original_report, + distributed_report, + factored_reports, + polynomial_multisets_report, + polynomial_multisets_factored_report, + polynomial_report, + ) + + +def main() -> None: + """ + Run the end-to-end polynomial container example. + + >>> main() # doctest: +ELLIPSIS + original cost: ... + distributed cost: ... + polynomial cost: ... + Finished + + Costs: + * original: ... + * distributed: ... + * factored: ... + * horner multisets: ... + + + Number of nodes: + * original: ... + * distributed: ... + * factored: ... + * horner multisets: ... + + Time: + * original: ... + * distributed: ... + * factored: ... + * horner multisets: ... + + gradient remove_subtraction cost: ... + """ + rng = np.random.default_rng(0) + q = rng.random(12) + bp = rng.random(4) + bpp = rng.random(4) + + qm = np.reshape(q, (4, 3)).T + yip = qm @ bp + yipp = qm @ bpp + expected = (np.linalg.norm(np.cross(yip, yipp)) / np.linalg.norm(yip) ** 3) ** 2 + result = bending_function(q, bp, bpp) + assert np.isclose(result, expected) + + function_bending, gradient_bending = symbolic_bending_examples() + + function_report = try_example(function_bending, max_factoring_iters=0) + assert function_report.original.cost > 0 + assert function_report.polynomial.cost > 0 + + gradient_report = run_example(remove_subtraction, gradient_bending) + assert gradient_report.cost > 0 + assert gradient_report.total_sec >= 0.0 + assert gradient_report.total_size > 0 + + print(function_report) + print("gradient remove_subtraction cost:", gradient_report.cost) + + +if __name__ == "__main__": + main() diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index eb5a2480..984869ea 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -12,7 +12,7 @@ class Program(Expr): """ - Semanticallly represents an expression with a number of ordered statements that it depends on to run. 
+ Semantically represents an expression with a number of ordered statements that it depends on to run. The expression and statements are all represented as strings. """ @@ -91,7 +91,7 @@ def parent(self) -> Program: """ @property - def is_identifer(self) -> Bool: + def is_identifier(self) -> Bool: """ Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers. """ @@ -119,7 +119,7 @@ def as_py_object(self) -> PyObject: @ruleset -def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject): +def eval_program_ruleset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject): # When we evaluate a program, we first want to compile to a string yield rule(EvalProgram(p, g)).then(p.compile()) # Then we want to evaluate the statements/expr @@ -164,7 +164,7 @@ def program_gen_ruleset( set_(p.expr).to(s), set_(p.statements).to(String("")), set_(p.next_sym).to(i), - set_(p.is_identifer).to(b), + set_(p.is_identifier).to(b), ) ## @@ -178,7 +178,7 @@ def program_gen_ruleset( ## stmt = eq(p).to(p1.expr_to_statement()) # 1. Set parent and is_identifier to false, since its empty - yield rule(stmt, p.compile(i)).then(set_(p1.parent).to(p), set_(p.is_identifer).to(Bool(False))) + yield rule(stmt, p.compile(i)).then(set_(p1.parent).to(p), set_(p.is_identifier).to(Bool(False))) # 2. Compile p1 if parent set yield rule(stmt, p.compile(i), eq(p1.parent).to(p)).then(p1.compile(i)) # 3.a. 
If parent not set, set statements to expr @@ -215,9 +215,9 @@ def program_gen_ruleset( # If the resulting expression is either of the inputs, then its an identifer if those are # Otherwise, if its not equal to either input, its not an identifier - yield rule(program_add, eq(p.expr).to(p1.expr), eq(b).to(p1.is_identifer)).then(set_(p.is_identifer).to(b)) - yield rule(program_add, eq(p.expr).to(p2.expr), eq(b).to(p2.is_identifer)).then(set_(p.is_identifer).to(b)) - yield rule(program_add, ne(p.expr).to(p1.expr), ne(p.expr).to(p2.expr)).then(set_(p.is_identifer).to(Bool(False))) + yield rule(program_add, eq(p.expr).to(p1.expr), eq(b).to(p1.is_identifier)).then(set_(p.is_identifier).to(b)) + yield rule(program_add, eq(p.expr).to(p2.expr), eq(b).to(p2.is_identifier)).then(set_(p.is_identifier).to(b)) + yield rule(program_add, ne(p.expr).to(p1.expr), ne(p.expr).to(p2.expr)).then(set_(p.is_identifier).to(Bool(False))) # Set parent of p1 yield rule(program_add, p.compile(i)).then( @@ -299,7 +299,7 @@ def program_gen_ruleset( # expression as the gensym, and setting is_identifier to true program_assign = eq(p).to(p1.assign()) # Set parent - yield rule(program_assign, p.compile(i)).then(set_(p1.parent).to(p), set_(p.is_identifer).to(Bool(True))) + yield rule(program_assign, p.compile(i)).then(set_(p1.parent).to(p), set_(p.is_identifier).to(Bool(True))) # If parent set, compile the expression yield rule(program_assign, p.compile(i), eq(p1.parent).to(p)).then(p1.compile(i)) @@ -313,7 +313,7 @@ def program_gen_ruleset( eq(s1).to(p1.statements), eq(i).to(p1.next_sym), eq(s2).to(p1.expr), - eq(p1.is_identifer).to(Bool(False)), + eq(p1.is_identifier).to(Bool(False)), ).then( set_(p.statements).to(join(s1, symbol, " = ", s2, "\n")), set_(p.expr).to(symbol), @@ -325,7 +325,7 @@ def program_gen_ruleset( ne(p1.parent).to(p), p.compile(i), eq(s2).to(p1.expr), - eq(p1.is_identifer).to(Bool(False)), + eq(p1.is_identifier).to(Bool(False)), ).then( set_(p.statements).to(join(symbol, " = ", 
s2, "\n")), set_(p.expr).to(symbol), @@ -341,7 +341,7 @@ def program_gen_ruleset( eq(s1).to(p1.statements), eq(i).to(p1.next_sym), eq(s2).to(p1.expr), - eq(p1.is_identifer).to(Bool(True)), + eq(p1.is_identifier).to(Bool(True)), ).then( set_(p.statements).to(s1), set_(p.expr).to(s2), @@ -353,7 +353,7 @@ def program_gen_ruleset( ne(p1.parent).to(p), p.compile(i), eq(s2).to(p1.expr), - eq(p1.is_identifer).to(Bool(True)), + eq(p1.is_identifier).to(Bool(True)), ).then( set_(p.statements).to(String("")), set_(p.expr).to(s2), @@ -376,7 +376,7 @@ def program_gen_ruleset( p2.compile(i), p3.compile(i), p1.compile(i), - set_(p.is_identifer).to(Bool(True)), + set_(p.is_identifier).to(Bool(True)), ) # 2. Set statements to function body and the next sym to i yield rule( @@ -408,7 +408,7 @@ def program_gen_ruleset( p3.compile(i), p1.compile(i), p4.compile(i), - set_(p.is_identifer).to(Bool(True)), + set_(p.is_identifier).to(Bool(True)), ) yield rule( fn_three, diff --git a/python/egglog/ipython_magic.py b/python/egglog/ipython_magic.py index 49111806..2f2101e7 100644 --- a/python/egglog/ipython_magic.py +++ b/python/egglog/ipython_magic.py @@ -1,3 +1,4 @@ +from ._tracing import call_with_current_trace from .bindings import EGraph EGRAPH_VAR = "_MAGIC_EGRAPH" @@ -33,7 +34,7 @@ def egglog(line, cell, local_ns): e = EGraph() local_ns[EGRAPH_VAR] = e cmds = e.parse_program(cell) - res = e.run_program(*cmds) + res = call_with_current_trace(e.run_program, *cmds) if "output" in line: print("\n".join(res)) if "graph" in line: diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py index 556d629e..008fb1bd 100644 --- a/python/egglog/pretty.py +++ b/python/egglog/pretty.py @@ -1,5 +1,5 @@ """ -Pretty printing for declerations. +Pretty printing for declarations. 
""" from __future__ import annotations @@ -77,7 +77,15 @@ } AllDecls: TypeAlias = ( - RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl + RulesetDecl + | CombinedRulesetDecl + | CommandDecl + | ActionDecl + | FactDecl + | ExprDecl + | ScheduleDecl + | BackOffDecl + | EGraphDecl ) @@ -96,6 +104,12 @@ def pretty_decl( if wrapping_fn: expr = f"{wrapping_fn}({expr})" program = "\n".join([*pretty.statements, expr]) + # First unparse AST to get consistent formatting, then use black to format it nicely + try: + ast_tree = ast.parse(program, mode="exec") + except SyntaxError: + return program + program = ast.unparse(ast_tree) try: # TODO: Try replacing with ruff for speed # https://github.com/amyreese/ruff-api @@ -221,6 +235,11 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 self(schedule) case GetCostDecl(ref, args): self(CallDecl(ref, args)) + case DummyDecl(): + pass + case EGraphDecl() as eg: + for a in eg.to_actions: + self(a) case _: assert_never(decl) @@ -252,7 +271,7 @@ def __call__( if decl in self.names: return self.names[decl] expr, tp_name = self.uncached(decl, unwrap_lit=unwrap_lit, parens=parens, ruleset_ident=ruleset_ident) - # We use a heuristic to decide whether to name this sub-expression as a variable + # We use a heuristic to decide whether to name this sub-expression as a variable. # The rough goal is to reduce the number of newlines, given our line length of ~180 # We determine it's worth making a new line for this expression if the total characters # it would take up is > than some constant (~ line length). @@ -271,8 +290,8 @@ def uncached( # noqa: C901, PLR0911, PLR0912 self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_ident: Ident | None ) -> tuple[str, str]: """ - Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version - for de-duplication. 
+ Returns a tuple of a string value of the declaration and the "type" to use when creating a memoized cached version + for deduplication. """ match decl: case LitDecl(value): @@ -373,8 +392,12 @@ def uncached( # noqa: C901, PLR0911, PLR0912 return f"back_off({', '.join(list_args)})", "scheduler" case ValueDecl(value): return str(value), "value" + case DummyDecl(): + return "__InternalDummyValueShouldNotBeSeenOpenAnIssue()", "dummy" case GetCostDecl(ref, args): return f"get_cost({self(CallDecl(ref, args))})", "get_cost" + case EGraphDecl() as eg: + return f"EGraph({', '.join(map(self, eg.to_actions))}).freeze()", "egraph" assert_never(decl) def _call( diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 56f68ea7..2ec5a453 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -15,9 +15,8 @@ import operator import types from collections.abc import Callable -from dataclasses import InitVar, dataclass, replace +from dataclasses import InitVar, dataclass, field, replace from inspect import Parameter, Signature -from itertools import zip_longest from typing import TYPE_CHECKING, Any, TypeVar, Union, assert_never, cast, get_args, get_origin import cloudpickle @@ -34,6 +33,7 @@ __all__ = [ "ALWAYS_MUTATES_SELF", "ALWAYS_PRESERVED", + "DUMMY_VALUE", "LIT_IDENTS", "NUMERIC_BINARY_METHODS", "RuntimeClass", @@ -142,14 +142,14 @@ def resolve_type_annotation_mutate(decls: Declarations, tp: object) -> TypeOrVarRef: """ - Wrap resolve_type_annotation to mutate decls, as a helper for internal use in sitations where that is more ergonomic. + Wrap resolve_type_annotation to mutate decls, as a helper for internal use in situations where that is more ergonomic. """ new_decls, tp = resolve_type_annotation(tp) decls |= new_decls return tp -def resolve_type_annotation(tp: object) -> tuple[DeclerationsLike, TypeOrVarRef]: +def resolve_type_annotation(tp: object) -> tuple[DeclarationsLike, TypeOrVarRef]: """ Resolves a type object into a type reference.
@@ -157,7 +157,7 @@ def resolve_type_annotation(tp: object) -> tuple[DeclerationsLike, TypeOrVarRef] resolve the decls if need be. """ if isinstance(tp, TypeVar): - return None, ClassTypeVarRef.from_type_var(tp) + return None, TypeVarRef.from_type_var(tp) # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type. if get_origin(tp) == Union: first, *_rest = get_args(tp) @@ -181,7 +181,7 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp: """ Inverse of resolve_type_annotation """ - if isinstance(tp, ClassTypeVarRef): + if isinstance(tp, TypeVarRef): return tp.to_type_var() return RuntimeClass(decls_thunk, tp) @@ -238,14 +238,7 @@ class RuntimeClassDescriptor: def __get__(self, obj: object, owner: RuntimeClass | None = None) -> Callable: if owner is None: raise AttributeError(f"Can only access {self.name} on the class, not an instance") - cls_decl = owner.__egg_decls__._classes[owner.__egg_tp__.ident] - if self.name in cls_decl.class_methods: - return RuntimeFunction( - owner.__egg_decls_thunk__, Thunk.value(ClassMethodRef(owner.__egg_tp__.ident, self.name)), None - ) - if self.name in cls_decl.preserved_methods: - return cls_decl.preserved_methods[self.name] - raise AttributeError(f"Class {owner.__egg_tp__.ident} has no method {self.name}") from None + return RuntimeClass.__getattr__(owner, self.name) RUNTIME_CLASS_DESCRIPTORS: dict[str, RuntimeClassDescriptor] = { @@ -254,10 +247,13 @@ def __get__(self, obj: object, owner: RuntimeClass | None = None) -> Callable: @dataclass(match_args=False) -class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory): +class RuntimeClass(DelayedDeclarations, metaclass=ClassFactory): __egg_tp__: TypeRefWithVars # True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly. 
_egg_has_params: InitVar[bool] = False + __egg_attr_cache__: dict[str, RuntimeFunction | RuntimeExpr | Callable] = field( + init=False, repr=False, default_factory=dict + ) def __post_init__(self, _egg_has_params: bool) -> None: global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS @@ -302,16 +298,19 @@ def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None: # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred # 1. Call it with the partial args, and use untyped vars for the rest of the args - res = cast("Callable", fn_arg)(*partial_args, _egg_partial_function=True) + res = cast("Callable", fn_arg)(*partial_args, _egg_function_types=self.__egg_tp__.args) assert res is not None, "Mutable partial functions not supported" # 2. Use the inferred return type and inferred rest arg types as the types of the function, and # the partially applied args as the args. call = (res_typed_expr := res.__egg_typed_expr__).expr return_tp = res_typed_expr.tp assert isinstance(call, CallDecl), "partial function must be a call" - n_args = len(partial_args) - value = PartialCallDecl(replace(call, args=call.args[:n_args])) - remaining_arg_types = [a.tp for a in call.args[n_args:]] + # Clip off the remaining arguments + captured_args = len(partial_args) + int( + isinstance(fn_arg, RuntimeFunction) and isinstance(fn_arg.__egg_bound__, RuntimeExpr) + ) + value = PartialCallDecl(replace(call, args=call.args[:captured_args])) + remaining_arg_types = [a.tp for a in call.args[captured_args:]] type_ref = JustTypeRef(Ident.builtin("UnstableFn"), (return_tp, *remaining_arg_types)) return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value)) @@ -345,13 +344,13 @@ def __getitem__(self, args: object) -> RuntimeClass: # defer resolving decls so that we can do generic instantiation for converters before all # method types are defined. 
decls_like, new_args = cast( - "tuple[tuple[DeclerationsLike, ...], tuple[TypeOrVarRef, ...]]", + "tuple[tuple[DeclarationsLike, ...], tuple[TypeOrVarRef, ...]]", zip(*(resolve_type_annotation(arg) for arg in args), strict=False), ) - # if we already have some args bound and some not, then we shold replace all existing args of typevars with new + # if we already have some args bound and some not, then we should replace all existing args of typevars with new # args if old_args := self.__egg_tp__.args: - is_typevar = [isinstance(arg, ClassTypeVarRef) for arg in old_args] + is_typevar = [isinstance(arg, TypeVarRef) for arg in old_args] if sum(is_typevar) != len(new_args): raise TypeError(f"Expected {sum(is_typevar)} typevars, got {len(new_args)}") new_args_list = list(new_args) @@ -361,7 +360,7 @@ def __getitem__(self, args: object) -> RuntimeClass: tp = TypeRefWithVars(self.__egg_tp__.ident, final_args) return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True) - def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: + def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: # noqa: C901 if not isinstance(name, str): raise TypeError(f"Attribute name must be a string, got {name!r}") if name == "__origin__" and self.__egg_tp__.args: @@ -379,6 +378,11 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: }: raise AttributeError + try: + return self.__egg_attr_cache__[name] + except KeyError: + pass + try: cls_decl = self.__egg_decls__._classes[self.__egg_tp__.ident] except Exception as e: @@ -387,29 +391,33 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: preserved_methods = cls_decl.preserved_methods if name in preserved_methods: - return preserved_methods[name] - + res = preserved_methods[name] # if this is a class variable, return an expr for it, otherwise, assume it's a method - if name in cls_decl.class_variables: + elif name 
in cls_decl.class_variables: return_tp = cls_decl.class_variables[name] - return RuntimeExpr( + res = RuntimeExpr( self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.ident, name)))), ) - if name in cls_decl.class_methods: - return RuntimeFunction( + else: + if name in cls_decl.class_methods: + callable_ref: CallableRef = ClassMethodRef(self.__egg_tp__.ident, name) + # allow referencing properties and methods as class variables as well + elif name in cls_decl.properties: + callable_ref = PropertyRef(self.__egg_tp__.ident, name) + elif name in cls_decl.methods: + callable_ref = MethodRef(self.__egg_tp__.ident, name) + else: + msg = f"Class {self.__egg_tp__.ident} has no method {name}" + raise AttributeError(msg) from None + res = RuntimeFunction( self.__egg_decls_thunk__, - Thunk.value(ClassMethodRef(self.__egg_tp__.ident, name)), - self.__egg_tp__.to_just(), + Thunk.value(callable_ref), + self.__egg_tp__.to_just() if self.__egg_tp__.args else None, ) - # allow referencing properties and methods as class variables as well - if name in cls_decl.properties: - return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.ident, name))) - if name in cls_decl.methods: - return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.ident, name))) - msg = f"Class {self.__egg_tp__.ident} has no method {name}" - raise AttributeError(msg) from None + self.__egg_attr_cache__[name] = res + return res def __str__(self) -> str: return str(self.__egg_tp__) @@ -466,8 +474,9 @@ def __getattribute__(cls, name: str) -> Any: @dataclass -class RuntimeFunction(DelayedDeclerations, metaclass=RuntimeFunctionMeta): +class RuntimeFunction(DelayedDeclarations, metaclass=RuntimeFunctionMeta): __egg_ref_thunk__: Callable[[], CallableRef] + # Either the bound class for something like `Vec[Int].create` or a RuntimeExpr for bound methods # bound methods need to store RuntimeExpr not
just TypedExprDecl, so they can mutate the expr if required on self __egg_bound__: JustTypeRef | RuntimeExpr | None = None @@ -495,7 +504,9 @@ def __hash__(self) -> int: def __egg_ref__(self) -> CallableRef: return self.__egg_ref_thunk__() - def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None: + def __call__( # noqa: C901,PLR0912 + self, *args: object, _egg_function_types: tuple[TypeOrVarRef, ...] | None = None, **kwargs: object + ) -> RuntimeExpr | None: from .conversion import resolve_literal # noqa: PLC0415 if isinstance(self.__egg_bound__, RuntimeExpr): @@ -528,47 +539,59 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: assert isinstance(signature, FunctionSignature) # Turn all keyword args into positional args - py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function) + py_signature = to_py_signature(signature, self.__egg_decls__, optional_args=_egg_function_types is not None) try: bound = py_signature.bind(*args, **kwargs) except TypeError as err: - raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err + err.add_note(f"when calling {self} with args {args} and kwargs {kwargs}") + raise del kwargs bound.apply_defaults() assert not bound.kwargs args = bound.args - tcs = TypeConstraintSolver(decls) - bound_tp = ( - None - if self.__egg_bound__ is None - else self.__egg_bound__.__egg_typed_expr__.tp - if isinstance(self.__egg_bound__, RuntimeExpr) - else self.__egg_bound__ - ) - if ( - bound_tp - and bound_tp.args - # Don't bind class if we have a first class function arg, b/c we don't support that yet - and not function_value - ): - tcs.bind_class(bound_tp) + tcs = TypeConstraintSolver() + if isinstance(self.__egg_bound__, JustTypeRef) and self.__egg_bound__.args: + if function_value: + msg = "Cannot have both bound type params and function value" + raise ValueError(msg) + 
tcs.bind_class(self.__egg_bound__, decls) + bound_tp_params = self.__egg_bound__.args + else: + bound_tp_params = () assert (operator.ge if signature.var_arg_type else operator.eq)(len(args), len(signature.arg_types)) - cls_ident = bound_tp.ident if bound_tp else None + # Hack to allow being explicit on function types when casting. # noqa: FIX004 + for _fn_tp in _egg_function_types or (): + try: + _fn_tp_just = _fn_tp.to_just() + except TypeVarError: + continue + tcs.bind_class(_fn_tp_just, decls) + if _fn_tp_just.args: + pass + # Try using any runtime expressions passed in to help infer typevars + for arg, tp in zip(args, signature.all_args, strict=False): + if not isinstance(arg, RuntimeExpr): + continue + try: + tcs.infer_typevars(tp, arg.__egg_typed_expr__.tp) + # If this leads to an incompatibility, just skip it, since it could need to be upcasted + except TypeConstraintError: + continue + # Now at this point we should be able to resolve all the typevars upcasted_args = [ - resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs=tcs, cls_ident=cls_ident) - for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type) + resolve_literal( + tcs.substitute_typevars_try_function(tp, arg, Thunk.value(decls)).to_var(), arg, Thunk.value(decls) + ) + for arg, tp in zip(args, signature.all_args, strict=False) ] decls.update(*upcasted_args) arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) - return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_ident) - bound_params = ( - cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else () - ) - # If we were using unstable-app to call a funciton, add that function back as the first arg. + return_tp = tcs.substitute_typevars(signature.semantic_return_type) + # If we were using unstable-app to call a function, add that function back as the first arg. 
if function_value: arg_exprs = (function_value, *arg_exprs) - expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params) + expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_tp_params) typed_expr_decl = TypedExprDecl(return_tp, expr_decl) # If there is not return type, we are mutating the first arg if not signature.return_type: @@ -598,12 +621,16 @@ def __doc__(self) -> str | None: # type: ignore[override] return None +# sentinel that will be upcasted to any type with a DummyDecl value +DUMMY_VALUE = object() + + def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature: """ Convert to a Python signature. If optional_args is true, then all args will be treated as optional, as if a default was provided that makes them - a var with that arg name as the value. + `DUMMY_VALUE` Used for partial application to try binding a function with only some of its args. """ @@ -611,9 +638,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: Parameter( n, Parameter.POSITIONAL_OR_KEYWORD, - default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d or LetRefDecl(n))) - if d is not None or optional_args - else Parameter.empty, + default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d)) + if d is not None + else (DUMMY_VALUE if optional_args else Parameter.empty), ) for n, d, t in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True) ] @@ -626,7 +653,7 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: @dataclass -class RuntimeExpr(DelayedDeclerations): +class RuntimeExpr(DelayedDeclarations): __egg_typed_expr_thunk__: Callable[[], TypedExprDecl] def __post_init__(self) -> None: @@ -762,7 +789,7 @@ def define_expr_method(name: str) -> None: """ Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method. 
- Call this if you need a method to be defined on the type itself where overrding with `__getattr__` does not suffice, + Call this if you need a method to be defined on the type itself where overriding with `__getattr__` does not suffice, like for NumPy's `__array_ufunc__`. """ @@ -877,8 +904,7 @@ def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | Run case InitRef(name): return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name)) case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef(): - bound = JustTypeRef(ref.ident) if isinstance(ref, ClassMethodRef) else None - return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound) + return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), None) case ConstantRef(name): tp = decls._constants[name].type_ref case ClassVariableRef(cls_name, var_name): diff --git a/python/egglog/thunk.py b/python/egglog/thunk.py index 4363449e..89e900d1 100644 --- a/python/egglog/thunk.py +++ b/python/egglog/thunk.py @@ -63,7 +63,7 @@ def __call__(self) -> T: res = fn(*args) except Exception as e: self.state = Error(e, context) - raise e from None + raise else: self.state = Resolved(res) return res diff --git a/python/egglog/type_constraint_solver.py b/python/egglog/type_constraint_solver.py index 4a333179..1f2efc8c 100644 --- a/python/egglog/type_constraint_solver.py +++ b/python/egglog/type_constraint_solver.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass, field from itertools import chain, repeat from typing import TYPE_CHECKING, assert_never @@ -26,24 +26,22 @@ class TypeConstraintSolver: Given some typevars and types, solves the constraints to resolve the typevars. 
""" - _decls: Declarations = field(repr=False) - # Mapping of class ident to mapping of bound class typevar to type - _cls_typevar_index_to_type: defaultdict[Ident, dict[ClassTypeVarRef, JustTypeRef]] = field( - default_factory=lambda: defaultdict(dict) - ) + # Mapping of typevar index to inferred type for each class + _typevar_to_type: dict[Ident, JustTypeRef] = field(default_factory=dict, init=False) - def bind_class(self, ref: JustTypeRef) -> None: + def bind_class(self, ref: JustTypeRef, decls: Declarations) -> None: """ Bind the typevars of a class to the given types. Used for a situation like Map[int, str].create(). + + This is the same as binding the typevars of the class to the given types. """ - name = ref.ident - cls_typevars = self._decls.get_class_decl(name).type_vars - if len(cls_typevars) != len(ref.args): - raise TypeConstraintError(f"Mismatch of typevars {cls_typevars} and {ref}") - bound_typevars = self._cls_typevar_index_to_type[name] - for i, arg in enumerate(ref.args): - bound_typevars[cls_typevars[i]] = arg + try: + cls_typevars = decls.get_class_decl(ref.ident).type_vars + except KeyError: + cls_typevars = () + for typevar, arg in zip(cls_typevars, ref.args, strict=True): + self.infer_typevars(typevar, arg) def infer_arg_types( self, @@ -51,61 +49,86 @@ def infer_arg_types( fn_return: TypeOrVarRef, fn_var_args: TypeOrVarRef | None, return_: JustTypeRef, - cls_ident: Ident | None, - ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]: + ) -> Iterable[JustTypeRef]: """ Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable. - - Also returns the bound type params if the class name is passed in. 
""" - self.infer_typevars(fn_return, return_, cls_ident) - arg_types: Iterable[JustTypeRef] = [self.substitute_typevars(a, cls_ident) for a in fn_args] - if fn_var_args: - # Need to be generator so it can be infinite for variable args - arg_types = chain(arg_types, repeat(self.substitute_typevars(fn_var_args, cls_ident))) - bound_typevars = ( - tuple( - v - # Sort by the index of the typevar in the class - for _, v in sorted( - self._cls_typevar_index_to_type[cls_ident].items(), - key=lambda kv: self._decls.get_class_decl(cls_ident).type_vars.index(kv[0]), - ) - ) - if cls_ident - else () - ) - return arg_types, bound_typevars - - def infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef, cls_ident: Ident | None = None) -> None: + self.infer_typevars(fn_return, return_) + arg_types = [self.substitute_typevars(fn_arg) for fn_arg in fn_args] + if fn_var_args is None: + return arg_types + var_arg_type = self.substitute_typevars(fn_var_args) + return chain(arg_types, repeat(var_arg_type)) + + def infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef) -> None: + """ + Infer typevars from a function argument and a given type, raises TypeConstraintError if they are incompatible. 
+ """ match fn_arg: case TypeRefWithVars(cls_ident, fn_args): if cls_ident != arg.ident: raise TypeConstraintError(f"Expected {cls_ident}, got {arg.ident}") for inner_fn_arg, inner_arg in zip(fn_args, arg.args, strict=True): - self.infer_typevars(inner_fn_arg, inner_arg, cls_ident) - case ClassTypeVarRef(): - if cls_ident is None: - msg = "Cannot infer typevar without class name" - raise RuntimeError(msg) - - class_typevars = self._cls_typevar_index_to_type[cls_ident] - if fn_arg in class_typevars: - if class_typevars[fn_arg] != arg: - raise TypeConstraintError(f"Expected {class_typevars[fn_arg]}, got {arg}") + self.infer_typevars(inner_fn_arg, inner_arg) + case TypeVarRef(typevar_ident): + if typevar_ident in self._typevar_to_type: + if self._typevar_to_type[typevar_ident] != arg: + raise TypeConstraintError(f"Expected {self._typevar_to_type[typevar_ident]}, got {arg}") else: - class_typevars[fn_arg] = arg + self._typevar_to_type[typevar_ident] = arg case _: assert_never(fn_arg) - def substitute_typevars(self, tp: TypeOrVarRef, cls_ident: Ident | None = None) -> JustTypeRef: + def substitute_typevars(self, tp: TypeOrVarRef) -> JustTypeRef: + """ + Substitute typevars in a type with their inferred types, raises TypeConstraintError if a typevar is unresolved. 
+ """ match tp: - case ClassTypeVarRef(): - assert cls_ident is not None + case TypeVarRef(typevar_ident): try: - return self._cls_typevar_index_to_type[cls_ident][tp] + return self._typevar_to_type[typevar_ident] except KeyError as e: - raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_ident}") from e + raise TypeConstraintError(f"Unresolved type variable: {typevar_ident}") from e case TypeRefWithVars(name, args): - return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_ident) for arg in args)) + return JustTypeRef(name, tuple(self.substitute_typevars(arg) for arg in args)) assert_never(tp) + + def substitute_typevars_try_function( + self, tp: TypeOrVarRef, value: Callable, decls: Callable[[], Declarations] + ) -> JustTypeRef: + """ + Try to substitute typevars in a type with their inferred types. + + If this fails and we have an UnstableFn type and a function value, we can try to infer the typevars by calling + it with the input types, if we can resolve those + """ + from .egraph import set_current_ruleset # noqa: PLC0415 + from .runtime import RuntimeExpr # noqa: PLC0415 + + try: + return self.substitute_typevars(tp) + except TypeConstraintError: + if isinstance(tp, TypeVarRef) or tp.ident != Ident.builtin("UnstableFn") or not callable(value): + raise + # Probe against an isolated copy of the declarations with no ambient ruleset so any temporary + # unnamed-function rewrites created while inferring types are discarded after the probe. 
+ probe_decls = decls().copy() + dummy_args = [ + RuntimeExpr.__from_values__( + probe_decls, + TypedExprDecl(self.substitute_typevars(arg_tp), DummyDecl()), + ) + for arg_tp in tp.args[1:] + ] + try: + with set_current_ruleset(None): + result = value(*dummy_args) + except Exception as e: + e.add_note(f"While trying to infer return type of {value} by calling it") + raise + if not isinstance(result, RuntimeExpr): + raise TypeConstraintError( + f"Function {value} did not return a RuntimeExpr, got {type(result)}, so cannot infer return type" + ) + self.infer_typevars(tp.args[0], result.__egg_typed_expr__.tp) + return self.substitute_typevars(tp) diff --git a/python/egglog/visualizer_widget.py b/python/egglog/visualizer_widget.py index c096da99..5ece0cd9 100644 --- a/python/egglog/visualizer_widget.py +++ b/python/egglog/visualizer_widget.py @@ -1,11 +1,13 @@ +import base64 +import json import pathlib import webbrowser import anywidget import traitlets from IPython.display import display -from ipywidgets.embed import embed_minimal_html +# from ipywidgets.embed import embed_minimal_html from .ipython_magic import IN_IPYTHON CURRENT_DIR = pathlib.Path(__file__).parent @@ -24,16 +26,24 @@ class VisualizerWidget(anywidget.AnyWidget): def display_or_open(self) -> None: """ - Displays the widget if we are in a Jupyter environment, otherwise saves it to a file and opens it. + Displays the widget if we are in a Jupyter environment, otherwise opens a standalone HTML page. """ if IN_IPYTHON: display(self) return - # 1. Create a temporary html file that will stay open after close - # 2. Write the widget to it with embed_minimal_html - # 3. 
Open the file using the open function from graphviz - file = pathlib.Path.cwd() / "tmp.html" - # https://github.com/manzt/anywidget/issues/339#issuecomment-1755654547 - embed_minimal_html(file, views=[self], drop_defaults=False) - print("Visualizer widget saved to", file) - webbrowser.open(file.as_uri()) + payload = json.dumps(self.egraphs).replace(" + + +""" diff --git a/python/tests/__snapshots__/test_array_api/TestLoopNest.test_index_codegen[expr].py b/python/tests/__snapshots__/test_array_api/TestLoopNest.test_index_codegen[expr].py index a80e7566..859323b2 100644 --- a/python/tests/__snapshots__/test_array_api/TestLoopNest.test_index_codegen[expr].py +++ b/python/tests/__snapshots__/test_array_api/TestLoopNest.test_index_codegen[expr].py @@ -1,13 +1,14 @@ -_Value_1 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), Int.var("i"), Int.var("j")))) -_Value_2 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), Int.var("i"), Int.var("j")))) -_Value_3 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), Int.var("i"), Int.var("j")))) -_Value_4 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), Int.var("i"), Int.var("j")))) -_Value_5 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), Int.var("i"), Int.var("j")))) -_Value_6 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), Int.var("i"), Int.var("j")))) +_Value_1 = NDArray.var("X").index(TupleInt(Vec(Int(0), Int(0), Int.var("i"), Int.var("j")))) +_Value_2 = NDArray.var("X").index(TupleInt(Vec(Int(0), Int(1), Int.var("i"), Int.var("j")))) +_Value_3 = NDArray.var("X").index(TupleInt(Vec(Int(1), Int(0), Int.var("i"), Int.var("j")))) +_Value_4 = NDArray.var("X").index(TupleInt(Vec(Int(1), Int(1), Int.var("i"), Int.var("j")))) +_Value_5 = NDArray.var("X").index(TupleInt(Vec(Int(2), Int(0), Int.var("i"), Int.var("j")))) +_Value_6 = NDArray.var("X").index(TupleInt(Vec(Int(2), Int(1), Int.var("i"), Int.var("j")))) ( - ( - 
((((_Value_1.conj() * _Value_1).real() + (_Value_2.conj() * _Value_2).real()) + (_Value_3.conj() * _Value_3).real()) + (_Value_4.conj() * _Value_4).real()) - + (_Value_5.conj() * _Value_5).real() - ) + (_Value_1.conj() * _Value_1).real() + + (_Value_2.conj() * _Value_2).real() + + (_Value_3.conj() * _Value_3).real() + + (_Value_4.conj() * _Value_4).real() + + (_Value_5.conj() * _Value_5).real() + (_Value_6.conj() * _Value_6).real() ).sqrt() \ No newline at end of file diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py index 7bb9f1ea..bf59a430 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py @@ -12,7 +12,7 @@ def __fn(X, y): _4 = y == np.array(2) _5 = np.sum(_4) _6 = np.array((_1, _3, _5, )).astype(np.dtype(np.float64)) - _7 = _6 / np.array(float(150)) + _7 = _6 / np.array(150) _8 = np.zeros((3, 4, ), dtype=np.dtype(np.float64)) _9 = np.sum(X[_0], axis=0) _10 = _9 / np.array(X[_0].shape[0]) @@ -39,7 +39,7 @@ def __fn(X, y): _28 = _27 / np.array(_26.shape[0]) _29 = np.sqrt(_28) _30 = _29 == np.array(0) - _29[_30] = np.array(float(1)) + _29[_30] = np.array((150 / 150)) _31 = _21 / _29 _32 = _17 * _31 _33 = np.linalg.svd(_32, full_matrices=False) diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py index 0de1df3d..cda4a169 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py @@ -1,48 +1,43 @@ _NDArray_1 = NDArray.var("X") assume_dtype(_NDArray_1, DType.float64) -assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4)))) +assume_shape(_NDArray_1, TupleInt(Vec(Int(150), Int(4)))) assume_isfinite(_NDArray_1) _NDArray_2 = NDArray.var("y") assume_dtype(_NDArray_2, DType.int64) 
-assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150)))) -assume_value_one_of(_NDArray_2, TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2))))) +assume_shape(_NDArray_2, TupleInt(Vec(Int(150)))) +assume_value_one_of(_NDArray_2, TupleValue(Vec(Value.from_int(Int(0)), Value.from_int(Int(1)), Value.from_int(Int(2))))) _NDArray_3 = astype( - NDArray.vector( - TupleValue.from_vec( - Vec[Value]( - sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value(), - sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value(), - sum(_NDArray_2 == NDArray.scalar(Value.int(Int(2)))).to_value(), + NDArray( + RecursiveValue.vec( + Vec( + RecursiveValue(sum(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(0))))).index(TupleInt())), + RecursiveValue(sum(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(1))))).index(TupleInt())), + RecursiveValue(sum(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(2))))).index(TupleInt())), ) ) ), DType.float64, -) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1"))))) -_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) +) / NDArray(RecursiveValue(Value.from_int(Int(150)))) +_NDArray_4 = zeros(TupleInt(Vec(Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) _MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice()) -_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) -_NDArray_5 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] -_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0))) -_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)])) -_IndexKey_2 = 
IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) -_NDArray_6 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] -_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)])) -_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) -_NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] -_NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)])) -_NDArray_8 = concat( - TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0)) -) -_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))) -_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)]))) +_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) +_NDArray_5 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(0)))))] +_NDArray_4[_IndexKey_1] = sum(_NDArray_5, OptionalIntOrTuple.int(Int(0))) / NDArray(RecursiveValue(Value.from_int(_NDArray_5.shape[Int(0)]))) +_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) +_NDArray_6 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(1)))))] +_NDArray_4[_IndexKey_2] = sum(_NDArray_6, OptionalIntOrTuple.int(Int(0))) / NDArray(RecursiveValue(Value.from_int(_NDArray_6.shape[Int(0)]))) +_IndexKey_3 = 
IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) +_NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray(RecursiveValue(Value.from_int(Int(2)))))] +_NDArray_4[_IndexKey_3] = sum(_NDArray_7, OptionalIntOrTuple.int(Int(0))) / NDArray(RecursiveValue(Value.from_int(_NDArray_7.shape[Int(0)]))) +_NDArray_8 = concat(TupleNDArray(Vec(_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0))) +_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, OptionalIntOrTuple.int(Int(0))) / NDArray(RecursiveValue(Value.from_int(_NDArray_8.shape[Int(0)]))))) +_NDArray_10 = sqrt(sum(_NDArray_9, OptionalIntOrTuple.int(Int(0))) / NDArray(RecursiveValue(Value.from_int(_NDArray_9.shape[Int(0)])))) _NDArray_11 = copy(_NDArray_10) -_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( +_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray(RecursiveValue(Value.from_int(Int(0)))))] = NDArray(RecursiveValue(Value.from_int(Int(150)) / Value.from_int(Int(150)))) +_TupleNDArray_1 = svd_( sqrt( asarray( - NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), + NDArray(RecursiveValue(Value.from_float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147")))))), OptionalDType.some(DType.float64), OptionalBool.none, OptionalDevice.some(_NDArray_1.device), @@ -51,34 +46,44 @@ * (_NDArray_8 / _NDArray_11), Boolean(False), ) -_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) +_Slice_1 = Slice( + OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > 
NDArray(RecursiveValue(Value.from_float(Float(0.0001)))), DType.int32)).index(TupleInt()).to_int) +) _NDArray_12 = ( - _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] - / _NDArray_11 + _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_11 ).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] -_TupleNDArray_2 = svd( +_TupleNDArray_2 = svd_( ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T + sqrt( + NDArray(RecursiveValue(Value.from_int(Int(150)))) + * _NDArray_3 + * NDArray(RecursiveValue(Value.from_float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) + ) + * (_NDArray_4 - _NDArray_3 @ _NDArray_4).T ).T @ _NDArray_12, Boolean(False), ) ( - (_NDArray_1 - (_NDArray_3 @ _NDArray_4)) + (_NDArray_1 - _NDArray_3 @ _NDArray_4) @ ( _NDArray_12 @ _TupleNDArray_2[Int(2)].T[ IndexKey.multi_axis( MultiAxisIndexKey.from_vec( - Vec[MultiAxisIndexKeyItem]( + Vec( _MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice( Slice( OptionalInt.none, OptionalInt.some( - sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32)) - .to_value() + sum( + astype( + _TupleNDArray_2[Int(1)] > NDArray(RecursiveValue(Value.from_float(Float(0.0001)))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))], + DType.int32, + ) + ) + .index(TupleInt()) .to_int ), ) @@ -88,8 +93,4 @@ ) ] ) -)[ - IndexKey.multi_axis( - MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2)))))) - ) -] \ No newline at end of file 
+)[IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2)))))))] \ No newline at end of file diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py index 1601d464..58d07e28 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py @@ -1,101 +1,78 @@ _NDArray_1 = NDArray.var("X") assume_dtype(_NDArray_1, DType.float64) -assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4)))) +assume_shape(_NDArray_1, TupleInt(Vec(Int(150), Int(4)))) assume_isfinite(_NDArray_1) _NDArray_2 = NDArray.var("y") assume_dtype(_NDArray_2, DType.int64) -assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150)))) -_TupleValue_1 = TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2)))) -assume_value_one_of(_NDArray_2, _TupleValue_1) -_NDArray_3 = zeros( - TupleInt.from_vec(Vec[Int](NDArray.vector(_TupleValue_1).shape[Int(0)], asarray(_NDArray_1).shape[Int(1)])), +assume_shape(_NDArray_2, TupleInt(Vec(Int(150)))) +assume_value_one_of(_NDArray_2, TupleValue(Vec(Value.from_int(Int(0)), Value.from_int(Int(1)), Value.from_int(Int(2))))) +_NDArray_3 = astype(unique_counts_counts(_NDArray_2), asarray(_NDArray_1).dtype) / NDArray(RecursiveValue(Value.from_float(Float(150.0)))) +_NDArray_4 = zeros( + TupleInt( + Vec( + NDArray(RecursiveValue.vec(Vec(RecursiveValue(Value.from_int(Int(0))), RecursiveValue(Value.from_int(Int(1))), RecursiveValue(Value.from_int(Int(2)))))).shape[Int(0)], + asarray(_NDArray_1).shape[Int(1)], + ) + ), OptionalDType.some(asarray(_NDArray_1).dtype), OptionalDevice.some(asarray(_NDArray_1).device), ) _MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice()) _IndexKey_1 = 
IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) -_IndexKey_2 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0)))) -_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0))) -_NDArray_3[_IndexKey_1] = mean(asarray(_NDArray_1)[_IndexKey_2], _OptionalIntOrTuple_1) -_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) -_IndexKey_4 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1)))) -_NDArray_3[_IndexKey_3] = mean(asarray(_NDArray_1)[_IndexKey_4], _OptionalIntOrTuple_1) -_IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) -_IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2)))) -_NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1) -_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) -_IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) -_NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1) -_IndexKey_8 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) -_NDArray_4[_IndexKey_8] = mean(_NDArray_1[_IndexKey_4], _OptionalIntOrTuple_1) -_IndexKey_9 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) -_NDArray_4[_IndexKey_9] = mean(_NDArray_1[_IndexKey_6], _OptionalIntOrTuple_1) -_NDArray_5 = concat( - TupleNDArray.from_vec( - Vec[NDArray]( - _NDArray_1[IndexKey.ndarray(_NDArray_2 == 
NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_7], - _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_8], - _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] - _NDArray_4[_IndexKey_9], +_NDArray_5 = NDArray(RecursiveValue(Value.from_int(Int(0)))) +_NDArray_4[_IndexKey_1] = mean(asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse_inverse_indices(_NDArray_2) == _NDArray_5)], OptionalIntOrTuple.int(Int(0))) +_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) +_NDArray_4[_IndexKey_2] = mean( + asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse_inverse_indices(_NDArray_2) == NDArray(RecursiveValue(Value.from_int(Int(1)))))], OptionalIntOrTuple.int(Int(0)) +) +_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) +_NDArray_4[_IndexKey_3] = mean( + asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse_inverse_indices(_NDArray_2) == NDArray(RecursiveValue(Value.from_int(Int(2)))))], OptionalIntOrTuple.int(Int(0)) +) +_NDArray_6 = asarray(reshape(asarray(_NDArray_2), TupleInt(Vec(Int(-1))))) +_Int_1 = unique_values(concat(TupleNDArray(Vec(unique_values(asarray(_NDArray_6)))))).shape[Int(0)] +_NDArray_7 = concat( + TupleNDArray( + Vec( + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_6 == _NDArray_5)] - _NDArray_4[_IndexKey_1], + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_6 == NDArray(RecursiveValue(Value.from_int(Int(1)))))] - _NDArray_4[_IndexKey_2], + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_6 == NDArray(RecursiveValue(Value.from_int(Int(2)))))] - _NDArray_4[_IndexKey_3], ) ), OptionalInt.some(Int(0)), ) -_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1) -_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - 
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( +_NDArray_8 = std(_NDArray_7, OptionalIntOrTuple.int(Int(0))) +_NDArray_8[IndexKey.ndarray(std(_NDArray_7, OptionalIntOrTuple.int(Int(0))) == _NDArray_5)] = NDArray(RecursiveValue(Value.from_float(Float(1.0)))) +_TupleNDArray_1 = svd_( sqrt( asarray( - NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), - OptionalDType.some(DType.float64), + NDArray(RecursiveValue(Value.from_float(Float(1.0) / Float.from_int(Int(150) - _Int_1)))), + OptionalDType.some(asarray(_NDArray_1).dtype), OptionalBool.none, - OptionalDevice.some(_NDArray_1.device), + OptionalDevice.some(asarray(_NDArray_1).device), ) ) - * (_NDArray_5 / _NDArray_6), + * (_NDArray_7 / _NDArray_8), Boolean(False), ) -_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) -_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1))))) -_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7)))))) -_NDArray_9 = std( - concat( - TupleNDArray.from_vec( - Vec[NDArray]( - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(0))])] - _NDArray_3[_IndexKey_1], - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(1))])] - _NDArray_3[_IndexKey_3], - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(2))])] - _NDArray_3[_IndexKey_5], - ) - ), - OptionalInt.some(Int(0)), - ), - _OptionalIntOrTuple_1, +_Slice_1 = Slice( + OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray(RecursiveValue(Value.from_float(Float(0.0001)))), DType.int32)).index(TupleInt()).to_int) ) -_NDArray_10 = copy(_NDArray_9) -_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] 
= NDArray.scalar(Value.float(Float(1.0))) -_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1"))))) -_TupleNDArray_2 = svd( +_NDArray_9 = ( + _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_8 +).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] +_TupleNDArray_2 = svd_( ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T + sqrt(NDArray(RecursiveValue(Value.from_int(Int(150)))) * _NDArray_3 * NDArray(RecursiveValue(Value.from_float(Float(1.0) / Float.from_int(_Int_1 - Int(1)))))) + * (_NDArray_4 - _NDArray_3 @ _NDArray_4).T ).T - @ ( - ( - _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] - / _NDArray_6 - ).T - / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] - ), + @ _NDArray_9, Boolean(False), ) ( - (asarray(_NDArray_1) - ((astype(unique_counts(_NDArray_2)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(150.0)))) @ _NDArray_3)) + (asarray(_NDArray_1) - _NDArray_3 @ _NDArray_4) @ ( - ( - (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_10).T - / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] - ) + _NDArray_9 @ _TupleNDArray_2[Int(2)].T[ IndexKey.multi_axis( MultiAxisIndexKey.from_vec( @@ -105,8 +82,13 @@ Slice( OptionalInt.none, OptionalInt.some( - sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32)) - .to_value() + sum( + astype( + 
_TupleNDArray_2[Int(1)] > NDArray(RecursiveValue(Value.from_float(Float(0.0001)))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))], + DType.int32, + ) + ) + .index(TupleInt()) .to_int ), ) diff --git a/python/tests/__snapshots__/test_array_api/test_jit[tuple][code].py b/python/tests/__snapshots__/test_array_api/test_jit[tuple][code].py deleted file mode 100644 index 71501846..00000000 --- a/python/tests/__snapshots__/test_array_api/test_jit[tuple][code].py +++ /dev/null @@ -1,2 +0,0 @@ -def __fn(x, y): - return x[(x.shape + (1, 2, ))[100]] diff --git a/python/tests/__snapshots__/test_array_api/test_jit[tuple][expr].py b/python/tests/__snapshots__/test_array_api/test_jit[tuple][expr].py index fd442e81..e23fad80 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[tuple][expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[tuple][expr].py @@ -1 +1,6 @@ -NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])] \ No newline at end of file +_Int_1 = check_index(NDArray.var("x").shape.length() + Int(2), Int(100)) +NDArray.var("x")[ + IndexKey.int( + Int.if_(_Int_1 < NDArray.var("x").shape.length(), lambda: NDArray.var("x").shape[_Int_1], lambda: TupleInt(Vec(Int(1), Int(2)))[_Int_1 - NDArray.var("x").shape.length()]) + ) +] \ No newline at end of file diff --git a/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py index fd442e81..1549c166 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py @@ -1 +1 @@ -NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])] \ No newline at end of file +NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt(Vec(Int(1), Int(2))))[Int(100)])] \ No newline at end of file diff --git 
a/python/tests/__snapshots__/test_array_api/test_program_compile[tuple][expr].py b/python/tests/__snapshots__/test_array_api/test_program_compile[tuple][expr].py index 327edb60..86a09db5 100644 --- a/python/tests/__snapshots__/test_array_api/test_program_compile[tuple][expr].py +++ b/python/tests/__snapshots__/test_array_api/test_program_compile[tuple][expr].py @@ -1 +1 @@ -tuple_value_program(TupleValue.from_vec(Vec[Value](Value.int(Int(1)), Value.int(Int(2))))) \ No newline at end of file +tuple_value_program(TupleValue(Vec(Value.from_int(Int(1)), Value.from_int(Int(2))))) \ No newline at end of file diff --git a/python/tests/conftest.py b/python/tests/conftest.py index ee5df955..6a0b230d 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -3,23 +3,24 @@ import pytest from syrupy.extensions.single_file import SingleFileSnapshotExtension -import egglog.conversion -import egglog.exp.array_api - @pytest.fixture(autouse=True) def _reset_conversions(): - old_conversions = copy.copy(egglog.conversion.CONVERSIONS) - old_conversion_decls = copy.copy(egglog.conversion._TO_PROCESS_DECLS) + from egglog import conversion # noqa: PLC0415 + + old_conversions = copy.copy(conversion.CONVERSIONS) + old_conversion_decls = copy.copy(conversion._TO_PROCESS_DECLS) yield - egglog.conversion.CONVERSIONS = old_conversions - egglog.conversion._TO_PROCESS_DECLS = old_conversion_decls + conversion.CONVERSIONS = old_conversions + conversion._TO_PROCESS_DECLS = old_conversion_decls @pytest.fixture(autouse=True) def _reset_current_egraph(): + from egglog.exp import array_api # noqa: PLC0415 + yield - egglog.exp.array_api._CURRENT_EGRAPH = None + array_api._CURRENT_EGRAPH = None class PythonSnapshotExtension(SingleFileSnapshotExtension): diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 34564ddc..4f0369cf 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -2,11 +2,12 @@ import inspect from collections.abc 
import Callable from functools import partial -from itertools import product +from itertools import product, repeat from pathlib import Path from types import FunctionType import numba +import numpy as np import pytest from sklearn import config_context, datasets from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -32,6 +33,20 @@ def test_upcast_order(): assert Int(2) > round(0.5 * Int(2)) # type: ignore[operator] +def test_cross_eval_numpy_uses_fresh_egraph(): + x = NDArray([1, 2, 3]) + y = NDArray([4, 5, 6]) + + assert cross(x, y).eval_numpy(np.dtype("int64")).tolist() == [-3, 6, -3] + + +def test_vecdot_eval_numpy_uses_fresh_egraph(): + v = NDArray([[0.0, 5.0, 0.0], [0.0, 0.0, 10.0], [0.0, 6.0, 8.0]]) + n = NDArray([0.0, 0.6, 0.8]) + + assert vecdot(v, n).eval_numpy(np.dtype("float64")).tolist() == pytest.approx([3.0, 8.0, 10.0]) + + @function(ruleset=array_api_ruleset) def is_even(x: Int) -> Boolean: return x % 2 == 0 @@ -39,9 +54,9 @@ def is_even(x: Int) -> Boolean: class TestTupleValue: def test_includes(self): - x = TupleValue.EMPTY.append(Value.bool(FALSE)) - check_eq(x.contains(Value.bool(FALSE)), TRUE, array_api_schedule) - check_eq(x.contains(Value.bool(TRUE)), FALSE, array_api_schedule) + x = TupleValue((FALSE,)) + check_eq(x.contains(Value.from_bool(FALSE)), TRUE, array_api_schedule) + check_eq(x.contains(Value.from_bool(TRUE)), FALSE, array_api_schedule) class TestTupleInt: @@ -49,33 +64,25 @@ def test_conversion(self): @function def f(x: TupleIntLike) -> TupleInt: ... 
- assert expr_parts(f((1, 2))) == expr_parts(f(TupleInt.from_vec(Vec[Int](Int(1), Int(2))))) + assert expr_parts(f((1, 2))) == expr_parts(f(TupleInt(Vec(Int(1), Int(2))))) def test_cons_to_vec(self): check_eq( - TupleInt.EMPTY.append(2), - TupleInt.from_vec(Vec(Int(2))), - array_api_schedule, - add_second=False, - ) - - def test_vec_to_cons(self): - check_eq( - TupleInt.from_vec(Vec(Int(1), Int(2))), - TupleInt.EMPTY.append(1).append(2), + TupleInt(()).append(2), + TupleInt(Vec(Int(2))), array_api_schedule, add_second=False, ) def test_indexing_cons(self): - check_eq(TupleInt.EMPTY.append(1).append(2)[Int(0)], Int(1), array_api_schedule) - check_eq(TupleInt.EMPTY.append(1).append(2)[Int(1)], Int(2), array_api_schedule) + check_eq(TupleInt((1, 2))[Int(0)], Int(1), array_api_schedule) + check_eq(TupleInt((1, 2))[Int(1)], Int(2), array_api_schedule) def test_length_cons(self): - check_eq(TupleInt.EMPTY.append(1).append(2).length(), Int(2), array_api_schedule) + check_eq(TupleInt((1, 2)).length(), Int(2), array_api_schedule) def test_fn_to_cons(self): - check_eq(TupleInt(2, lambda i: i), TupleInt.EMPTY.append(0).append(1), array_api_schedule, add_second=False) + check_eq(TupleInt.fn(2, lambda i: i), TupleInt((0, 1)), array_api_schedule, add_second=False) def test_range_length(self): check_eq(TupleInt.range(some_length).length(), some_length, array_api_schedule) @@ -86,27 +93,36 @@ def test_range_index(self): ) def test_not_contains_example(self): - check_eq(TupleInt.from_vec(Vec(Int(0), Int(1))).contains(Int(3)), FALSE, array_api_schedule) + check_eq(TupleInt(Vec(Int(0), Int(1))).contains(Int(3)), FALSE, array_api_schedule) def test_contains_example(self): - check_eq(TupleInt.from_vec(Vec(Int(0), Int(3))).contains(Int(3)), TRUE, array_api_schedule) + check_eq(TupleInt(Vec(Int(0), Int(3))).contains(Int(3)), TRUE, array_api_schedule) def test_filter_append(self): check_eq( - TupleInt.EMPTY.append(1).append(2).filter(is_even), - TupleInt.EMPTY.append(2), + TupleInt((1, 
2)).filter(is_even), + TupleInt((2,)), array_api_schedule, add_second=False, ) def test_filter_range(self): - check_eq(TupleInt.range(4).filter(is_even), TupleInt.from_vec(Vec(Int(0), Int(2))), array_api_schedule) + check_eq(TupleInt.range(4).filter(is_even), TupleInt(Vec(Int(0), Int(2))), array_api_schedule) def test_filter_lambda_length(self): with set_current_ruleset(array_api_ruleset): x = TupleInt.range(5).filter(lambda i: i < 2).length() check_eq(x, Int(2), array_api_schedule) + def test_eq_true(self): + check_eq(TupleInt((1, 2)) == TupleInt((1, 2)), TRUE, array_api_schedule) + + def test_eq_false_length(self): + check_eq(TupleInt((1, 2)) == TupleInt((1, 2, 3)), FALSE, array_api_schedule) + + def test_eq_false_element(self): + check_eq(TupleInt((1, 2)) == TupleInt((1, 3)), FALSE, array_api_schedule) + @function def some_array_idx_fn(x: TupleInt) -> Value: ... @@ -114,11 +130,11 @@ def some_array_idx_fn(x: TupleInt) -> Value: ... class TestNDArray: def test_index(self): - x = NDArray(some_shape, some_dtype, some_array_idx_fn) + x = NDArray.fn(some_shape, some_dtype, some_array_idx_fn) check_eq(x.index(some_index), some_array_idx_fn(some_index), array_api_schedule) def test_shape(self): - x = NDArray(some_shape, some_dtype, some_array_idx_fn) + x = NDArray.fn(some_shape, some_dtype, some_array_idx_fn) check_eq(x.shape, some_shape, array_api_schedule) def test_simplify_any_unique(self): @@ -126,36 +142,73 @@ def test_simplify_any_unique(self): any( ( astype(unique_counts(NDArray.var("X"))[Int(1)], DType.float64) - / NDArray.scalar(Value.float(Float(150.0))) + / NDArray(Value.from_float(Float(150.0))) ) - < NDArray.scalar(Value.int(Int(0))) + < NDArray(Value.from_int(Int(0))) ) - .to_value() + .index(TupleInt()) .to_bool ) - check_eq(res, FALSE, array_api_schedule) + check_eq(res, Boolean(False), array_api_schedule) + + def test_other(self): + _NDArray_1 = NDArray.var("y") + assume_dtype(_NDArray_1, DType.int64) + assume_shape(_NDArray_1, TupleInt(Vec(Int(150)))) 
+ assume_value_one_of( + _NDArray_1, + TupleValue(Vec(Value.from_int(Int(0)), Value.from_int(Int(1)), Value.from_int(Int(2)))), + ) + res = ( + any( + astype(unique_counts(_NDArray_1)[1], DType.float64) / NDArray(RecursiveValue(Value.from_int(Int(150)))) + < NDArray(RecursiveValue(Value.from_int(Int(0)))) + ) + .index(TupleInt()) + .to_bool + ) + check_eq(res, Boolean(False), array_api_schedule) def test_reshape_index(self): # Verify that it doesn't expand forever x = NDArray.var("x") - new_shape = TupleInt.single(Int(-1)) - res = reshape(x, new_shape).index(TupleInt.single(Int(1)) + TupleInt.single(Int(2))) + new_shape = TupleInt((-1,)) + res = reshape(x, new_shape).index(TupleInt((Int(1),)) + TupleInt((Int(2),))) egraph = EGraph() egraph.register(res) egraph.run(array_api_schedule) equiv_expr = egraph.extract_multiple(res, 10) assert len(equiv_expr) < 10 + def test_normalize_reshape_shape(self): + check_eq( + normalize_reshape_shape(TupleInt((Int(5),)), TupleInt((-1,))), + TupleInt((Int(5),)), + array_api_schedule, + ) + + def test_reshape_after_schedule_decls_access(self): + _ = array_api_schedule.__egg_decls__ + + x = NDArray.var("x") + assume_shape(x, TupleInt((Int(5),))) + res = reshape(x, TupleInt((-1,))) + egraph = EGraph() + egraph.register(res) + egraph.run(array_api_schedule) + + egraph.check(eq(res).to(x)) + def test_reshape_vec_noop(self): x = NDArray.var("x") - assume_shape(x, TupleInt.single(Int(5))) - res = reshape(x, TupleInt.single(Int(-1))) + assume_shape(x, TupleInt((Int(5),))) + res = reshape(x, TupleInt((-1,))) egraph = EGraph() egraph.register(res) egraph.run(array_api_schedule) equiv_expr = egraph.extract_multiple(res, 10) - assert len(equiv_expr) == 2 + assert len(equiv_expr) <= 3 egraph.check(eq(res).to(x)) @@ -169,11 +222,11 @@ def some_tuple_tuple_int_reduce_value_fn(carry: Value, x: TupleInt) -> Value: .. 
 class TestTupleTupleInt:
     def test_reduce_value_zero(self):
-        x = TupleTupleInt(0, some_tuple_tuple_int_idx_fn)
+        x = TupleTupleInt.fn(0, some_tuple_tuple_int_idx_fn)
         check_eq(x.foldl_value(some_tuple_tuple_int_reduce_value_fn, some_value), some_value, array_api_schedule)
 
     def test_reduce_value_one(self):
-        x = TupleTupleInt(1, some_tuple_tuple_int_idx_fn)
+        x = TupleTupleInt.fn(1, some_tuple_tuple_int_idx_fn)
         check_eq(
             x.foldl_value(some_tuple_tuple_int_reduce_value_fn, some_value),
             some_tuple_tuple_int_reduce_value_fn(some_value, some_tuple_tuple_int_idx_fn(Int(0))),
@@ -188,7 +241,7 @@ def test_product_example(self):
         aka product((0, 1, 2, 3), (4, 5)) ==
         """
-        # TODO: Increase size, but for now check doesnt terminate at larger sizes for some reason
+        # TODO: Increase size, but for now check doesn't terminate at larger sizes for some reason
         # input = ((0, 1, 2, 3), (4, 5))
         input = ((0, 1), (4, 5))
         expected_output = tuple(product(*input))
@@ -207,7 +260,7 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
     # get only the inner shape for reduction
     reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
 
-    return NDArray(
+    return NDArray.fn(
         outshape,
         X.dtype,
         lambda k: LoopNestAPI.from_tuple(reduce_axis)
@@ -221,7 +274,7 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
 @function(ruleset=array_api_ruleset, subsume=True)
 def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
     X = cast(NDArray, X)
-    return NDArray(
+    return NDArray.fn(
         X.shape.deselect(axis),
         X.dtype,
         lambda k: ndindex(X.shape.select(axis))
@@ -238,11 +291,11 @@ def linalg_val(X: NDArray, linalg_fn: Callable[[NDArray, TupleIntLike], NDArray]
 class TestLoopNest:
     @pytest.mark.parametrize("linalg_fn", [linalg_norm, linalg_norm_v2])
     def test_shape(self, linalg_fn):
-        X = np.random.random((3, 2, 3, 4))
+        X = np.random.default_rng(0).random((3, 2, 3, 4))
         expect = np.linalg.norm(X, axis=(0, 1))
         assert expect.shape == (3, 4)
-        check_eq(linalg_val(constant("X", NDArray), linalg_fn).shape, TupleInt.from_vec((3, 4)), array_api_schedule)
+        check_eq(linalg_val(constant("X", NDArray), linalg_fn).shape, TupleInt((3, 4)), array_api_schedule)
 
     @pytest.mark.parametrize("linalg_fn", [linalg_norm, linalg_norm_v2])
     def test_abstract_index(self, linalg_fn):
@@ -251,12 +304,12 @@ def test_abstract_index(self, linalg_fn):
         X = constant("X", NDArray)
         idxed = linalg_val(X, linalg_fn).index((i, j))
-        _Value_1 = X.index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), i, j)))
-        _Value_2 = X.index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), i, j)))
-        _Value_3 = X.index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), i, j)))
-        _Value_4 = X.index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), i, j)))
-        _Value_5 = X.index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), i, j)))
-        _Value_6 = X.index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), i, j)))
+        _Value_1 = X.index(TupleInt((Int(0), Int(0), i, j)))
+        _Value_2 = X.index(TupleInt((Int(0), Int(1), i, j)))
+        _Value_3 = X.index(TupleInt((Int(1), Int(0), i, j)))
+        _Value_4 = X.index(TupleInt((Int(1), Int(1), i, j)))
+        _Value_5 = X.index(TupleInt((Int(2), Int(0), i, j)))
+        _Value_6 = X.index(TupleInt((Int(2), Int(1), i, j)))
         res = (
             (
                 (
@@ -287,11 +340,14 @@ def test_index_codegen(self, snapshot_py):
             value_program(simplified_index).function_three(ndarray_program(X), int_program(i), int_program(j)),
             {"np": np},
         )
-        fn = cast(FunctionType, try_evaling(EGraph(), array_api_program_gen_schedule, res, res.as_py_object))
+        egraph = EGraph()
+        egraph.register(res)
+        egraph.run(array_api_program_gen_schedule)
+        fn = cast(FunctionType, egraph.extract(res.as_py_object).value)
 
         assert inspect.getsource(fn) == snapshot_py(name="code")
 
-        X = np.random.random((3, 2, 3, 4))
+        X = np.random.default_rng(0).random((3, 2, 3, 4))
         expect = np.linalg.norm(X, axis=(0, 1))
 
         for idxs in np.ndindex(*expect.shape):
@@ -344,11 +400,24 @@ def lda(X: NDArray, y: NDArray):
     return run_lda(X, y)
 
 
+def test_lda_symbolic_build_cold_schedule():
+    X_arr = NDArray.var("X")
+    y_arr = NDArray.var("y")
+    egraph = EGraph()
+    with set_array_api_egraph(egraph):
+        res = lda(X_arr, y_arr)
+    assert isinstance(res, NDArray)
+
+
 @pytest.mark.parametrize(
     "program",
     [
         pytest.param(lambda x, y: x + y, id="add"),
-        pytest.param(lambda x, y: x[(x.shape + TupleInt.from_vec((1, 2)))[100]], id="tuple"),
+        pytest.param(
+            lambda x, y: x[(x.shape + TupleInt((1, 2)))[100]],
+            id="tuple",
+            marks=pytest.mark.xfail(reason="functions aren't applied yet in to string"),
+        ),
         pytest.param(lda, id="lda"),
     ],
 )
@@ -383,13 +452,41 @@ def test_run_lda(fn_thunk, benchmark):
     benchmark(fn, X_np, y_np)
 
 
+x, y, z, q, r = map(constant, ("x", "y", "z", "q", "r"), repeat(Value))
+
+
+@pytest.mark.parametrize(
+    ("input_expr", "expected"),
+    [
+        pytest.param(x * x, x**2, id="exp"),
+        pytest.param(x * y + x * z, x * (y + z), id="factor"),
+        pytest.param(x * y + x * z * q + x * z * r, x * (y + z * (q + r)), id="factor most first"),
+        pytest.param(
+            x**2 * z * y + x**2 * q * z * y, x**2 * z * y * (q + Value.from_int(1)), id="factor out all terms equal"
+        ),
+    ],
+)
+def test_polynomial_factoring(input_expr: Value, expected: Value):
+    egraph = EGraph()
+    x = egraph.let("x", input_expr)
+    egraph.run(polynomial_schedule)
+    equiv_expr = egraph.extract(x)
+    # Normalized them both so that we don't have to worry about term order.
+    normalized = EGraph()
+    extracted_ref = normalized.let("extracted", equiv_expr)
+    expected_ref = normalized.let("expected", expected)
+    normalized.run(to_polynomial_ruleset.saturate())
+    normalized.check(eq(extracted_ref).to(expected_ref))
+
+
 # if calling as script, print out egglog source for test
 # similar to jit, but don't include pyobject parts so it works in vanilla egglog
 if __name__ == "__main__":
     print("Generating egglog source for test")
     egraph, _, _, program = function_to_program(lda, True)
     egraph.register(program.compile())
-    try_evaling(egraph, array_api_program_gen_combined_ruleset.saturate(), program, program.statements)
+    egraph.run(array_api_program_gen_combined_ruleset.saturate())
+    egraph.extract(program.statements)
     name = "python.egg"
     print("Saving to", name)
     Path(name).write_text(egraph.as_egglog_string)
diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py
index 0e4f240b..f7763842 100644
--- a/python/tests/test_bindings.py
+++ b/python/tests/test_bindings.py
@@ -63,6 +63,17 @@ def test_example(example_file: pathlib.Path):
 
 BLACK_MODE = black.Mode(line_length=88)
 
+REPO_ROOT = pathlib.Path(__file__).resolve().parents[2]
+GENERIC_FRESH_EGRAPH_REPRO = REPO_ROOT / "test-data" / "generic-fresh-egraph-repro.egg"
+GENERIC_FRESH_EGRAPH_ROOT = '(let $__expr_0 (unstable-vec-map (unstable-fn "f1") (vec-range 1)))'
+GENERIC_FRESH_EGRAPH_RESOLVED = "(vec-of (int 1))"
+
+
+def extract_best_term(program: str) -> str:
+    egraph = EGraph(record=True)
+    outputs = egraph.run_program(*egraph.parse_program(program))
+    extract = next(output for output in outputs if isinstance(output, ExtractBest))
+    return extract.termdag.to_string(extract.term)
+
 
 class TestEGraph:
diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py
index d658844e..0efd66f8 100644
--- a/python/tests/test_high_level.py
+++ b/python/tests/test_high_level.py
@@ -60,6 +60,17 @@ def __mul__(self, other: Math) -> Math: ...
     egraph.check(eq(expr1).to(expr2))
 
 
+def test_let_auto_prefixes_global_names(capfd: pytest.CaptureFixture[str]):
+    egraph = EGraph(save_egglog_string=True)
+
+    x = egraph.let("x", i64(1))
+    egraph.check(eq(x).to(i64(1)))
+
+    captured = capfd.readouterr()
+    assert "should start with `$`" not in captured.err
+    assert "(let $x " in egraph.as_egglog_string
+
+
 def test_fib():
     egraph = EGraph()
@@ -160,6 +171,33 @@ def test_extract_include_cost():
     assert cost == 1
 
 
+def test_egraph_constructor_registers_actions():
+    class ConstructorExpr(Expr):
+        def __init__(self) -> None: ...
+
+    @function
+    def constructor_cost() -> i64: ...
+
+    @function
+    def constructor_lhs() -> ConstructorExpr: ...
+
+    @function
+    def constructor_rhs() -> ConstructorExpr: ...
+
+    constructed = EGraph(
+        ConstructorExpr(), set_(constructor_cost()).to(i64(1)), eq(constructor_lhs()).to(constructor_rhs())
+    )
+
+    expected = EGraph()
+    expected.register(
+        ConstructorExpr(),
+        set_(constructor_cost()).to(i64(1)),
+        eq(constructor_lhs()).to(constructor_rhs()),
+    )
+
+    assert str(constructed.freeze()) == str(expected.freeze())
+
+
 def test_relation():
     egraph = EGraph()
@@ -658,6 +696,19 @@ def f() -> A:
 
         check_eq(f(), A(), r)
 
+    def test_function_ruleset_can_run_after_materialization_without_registration(self):
+        r = ruleset()
+
+        @function(ruleset=r)
+        def f() -> A:
+            return A()
+
+        # Materialize the function once so its default rewrite is added to the ruleset,
+        # but do not register any expression that would separately add `f` to the egraph.
+        f()
+        egraph = EGraph()
+        assert not egraph.run(r).updated
+
     def test_constant(self):
         a = constant("a", A, A())
         check_eq(a, A(), run())
@@ -781,16 +832,16 @@ def test_vec_like_conversion():
     @function
     def my_fn(xs: VecLike[i64, i64Like]) -> Unit: ...
 
-    assert expr_parts(my_fn((1, 2))) == expr_parts(my_fn(Vec[i64](i64(1), i64(2))))
-    assert expr_parts(my_fn([])) == expr_parts(my_fn(Vec[i64]()))
+    assert expr_parts(my_fn((1, 2))) == expr_parts(my_fn(Vec(i64(1), i64(2))))
+    assert expr_parts(my_fn([])) == expr_parts(my_fn(Vec[i64].empty()))
 
 
 def test_set_like_conversion():
     @function
     def my_fn(xs: SetLike[i64, i64Like]) -> Unit: ...
 
-    assert expr_parts(my_fn({1, 2})) == expr_parts(my_fn(Set[i64](i64(1), i64(2))))
-    assert expr_parts(my_fn(set())) == expr_parts(my_fn(Set[i64]()))
+    assert expr_parts(my_fn({1, 2})) == expr_parts(my_fn(Set(i64(1), i64(2))))
+    assert expr_parts(my_fn(set())) == expr_parts(my_fn(Set[i64].empty()))
 
 
 def test_map_like_conversion():
@@ -1093,6 +1144,20 @@ def __sub__(self, other: E) -> E: ...
 
 
 class TestScheduler:
+    def test_seq_schedule_decls_track_ruleset_updates(self):
+        egraph = EGraph()
+
+        rel = relation("rel_live", i64)
+        live_rules = ruleset(name="live-rules")
+        schedule = seq(live_rules, run()).saturate()
+        _ = schedule.__egg_decls__
+
+        live_rules.register(rule(rel(i64(0))).then(rel(i64(1))))
+
+        egraph.register(rel(i64(0)))
+        egraph.run(schedule)
+        egraph.check(rel(i64(1)))
+
     def test_sequence_repeat_saturate(self):
         """
         Mirrors the scheduling example: alternate step-right and step-left,
@@ -1245,7 +1310,7 @@ def my_f(xs: Vec[i64]) -> E: ...
     # cost = 2
     x = i64(10)
     # cost = 3 + 2 = 5
-    xs = Vec[i64](x)
+    xs = Vec(x)
     # cost = 100
     res = E()
     # cost = 1 + 5 = 6
diff --git a/python/tests/test_pretty.py b/python/tests/test_pretty.py
index 3c19c85f..d2a4dfc7 100644
--- a/python/tests/test_pretty.py
+++ b/python/tests/test_pretty.py
@@ -10,6 +10,7 @@
 import pytest
 
 from egglog import *
+from egglog.declarations import EGraphDecl
 
 if TYPE_CHECKING:
     from egglog.runtime import RuntimeExpr
@@ -59,6 +60,14 @@ def __floor__(self) -> A: ...
     def __ceil__(self) -> A: ...
 
 
+class Box(Expr):
+    def __init__(self, vec: Vec[A] = Vec[A].empty()) -> None: ...
+
+
+class Wrapper(Expr):
+    def __init__(self, box: Box) -> None: ...
+
+
 @function
 def f(x: A) -> A: ...
@@ -100,13 +109,20 @@ def binary(x: A, y: A) -> A: ...
 setitem_a[g()] = h()
 
 b = constant("b", A)
+c_i64 = constant("c_i64", i64)
+rel = relation("rel", A)
 
 
 @function
 def my_very_long_function_name() -> A: ...
 
 
-long_line = my_very_long_function_name() + my_very_long_function_name() + my_very_long_function_name()
+long_line = (
+    my_very_long_function_name()
+    + my_very_long_function_name()
+    + my_very_long_function_name()
+    + my_very_long_function_name()
+)
 
 r = ruleset(name="r")
@@ -150,7 +166,7 @@ def __repr__(self) -> str:
         pytest.param(has_default(A()), "has_default()", id="has default"),
         pytest.param(
             rewrite(long_line).to(long_line),
-            "_A_1 = (my_very_long_function_name() + my_very_long_function_name()) + my_very_long_function_name()\nrewrite(_A_1).to(_A_1)",
+            "_A_1 = my_very_long_function_name() + my_very_long_function_name() + my_very_long_function_name() + my_very_long_function_name()\nrewrite(_A_1).to(_A_1)",
             id="wrap long line",
         ),
         pytest.param(A() - A(), "A() - A()", id="subtraction"),
@@ -234,3 +250,48 @@ def __repr__(self) -> str:
 @pytest.mark.parametrize(("x", "s"), PARAMS)
 def test_str(x: RuntimeExpr, s: str) -> None:
     assert str(x) == s
+
+
+FREEZE_PARAMS = [
+    pytest.param((A(),), "EGraph(A()).freeze()", id="freeze add"),
+    pytest.param((b,), "EGraph(b).freeze()", id="freeze constant"),
+    pytest.param((set_(p()).to(i64(1)),), "EGraph(set_(p()).to(i64(1))).freeze()", id="freeze set"),
+    pytest.param((set_(c_i64).to(i64(1)),), "EGraph(set_(c_i64).to(i64(1))).freeze()", id="freeze constant set"),
+    pytest.param((union(g()).with_(h()),), "EGraph(union(g()).with_(h())).freeze()", id="freeze union"),
+    pytest.param((rel(g()),), "EGraph(rel(g())).freeze()", id="freeze relation"),
+    pytest.param((g(), subsume(g())), "EGraph(subsume(g())).freeze()", id="freeze subsume"),
+    pytest.param((g(), set_cost(g(), 10)), "EGraph(set_cost(g(), 10)).freeze()", id="freeze set cost"),
+]
+
+
+@pytest.mark.parametrize(("actions", "s"), FREEZE_PARAMS)
+def test_frozen_egraph_str(actions: tuple[ActionLike, ...], s: str) -> None:
+    egraph = EGraph(*actions)
+    frozen = egraph.freeze()
+    assert isinstance(frozen.decl, EGraphDecl)
+    assert str(frozen) == s
+
+
+def test_frozen_egraph_str_grounding() -> None:
+    egraph = EGraph()
+    egraph.register(f(g()), union(g()).with_(h()))
+    frozen = egraph.freeze()
+    assert isinstance(frozen.decl, EGraphDecl)
+    assert str(frozen) == "EGraph(union(g()).with_(h()), f(g())).freeze()"
+
+
+def test_frozen_egraph_str_let_vec_constructor() -> None:
+    egraph = EGraph()
+    egraph.register(let("$x", Vec(A())))
+    frozen = egraph.freeze()
+    assert isinstance(frozen.decl, EGraphDecl)
+    assert str(frozen) == 'EGraph(let("$x", Vec(A()))).freeze()'
+
+
+def test_frozen_egraph_str_nested_vec_constructor() -> None:
+    egraph = EGraph()
+    egraph.register(Wrapper(Box(Vec(A()))))
+    frozen = egraph.freeze()
+    assert isinstance(frozen.decl, EGraphDecl)
+    assert "Value(" not in str(frozen)
+    assert str(frozen) == "EGraph(Wrapper(Box(Vec(A())))).freeze()"
diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py
index a93fc52a..4f6e61a6 100644
--- a/python/tests/test_program_gen.py
+++ b/python/tests/test_program_gen.py
@@ -83,7 +83,7 @@ def test_py_object():
     evalled = EvalProgram(fn, {"z": 10})
     egraph = EGraph()
     egraph.register(evalled)
-    egraph.run((to_program_ruleset | eval_program_rulseset | program_gen_ruleset).saturate())
+    egraph.run((to_program_ruleset | eval_program_ruleset | program_gen_ruleset).saturate())
     res = cast("FunctionType", egraph.extract(evalled.as_py_object).value)
     assert res(1, 2) == 13
     assert inspect.getsource(res)
diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py
index e747ac79..8f2bd983 100644
--- a/python/tests/test_runtime.py
+++ b/python/tests/test_runtime.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import doctest
+
 import pytest
 
 from egglog.declarations import *
+from egglog.exp import array_api
 from egglog.runtime import *
 from egglog.thunk import *
 from egglog.type_constraint_solver import *
@@ -12,9 +15,7 @@ def test_type_str():
     decls = Declarations(
         _classes={
             Ident.builtin("i64"): ClassDecl(),
-            Ident.builtin("Map"): ClassDecl(
-                type_vars=(ClassTypeVarRef(Ident.builtin("K")), ClassTypeVarRef(Ident.builtin("V")))
-            ),
+            Ident.builtin("Map"): ClassDecl(type_vars=(TypeVarRef(Ident.builtin("K")), TypeVarRef(Ident.builtin("V")))),
         }
     )
     i64 = RuntimeClass(Thunk.value(decls), TypeRefWithVars(Ident.builtin("i64")))
@@ -40,7 +41,7 @@ def test_function_call():
 
 
 def test_classmethod_call():
-    K, V = ClassTypeVarRef(Ident.builtin("K")), ClassTypeVarRef(Ident.builtin("V"))
+    K, V = TypeVarRef(Ident.builtin("K")), TypeVarRef(Ident.builtin("V"))
     decls = Declarations(
         _classes={
             Ident.builtin("i64"): ClassDecl(),
@@ -127,3 +128,19 @@ def test_class_variable():
     assert one.__egg_typed_expr__ == TypedExprDecl(
         JustTypeRef(Ident.builtin("i64")), CallDecl(ClassVariableRef(Ident.builtin("i64"), "one"))
     )
+
+
+def test_runtime_class_attr_lookup_is_stable():
+    assert array_api.TupleInt.__getitem__ is array_api.TupleInt.__getitem__
+    assert array_api.TupleInt.__dict__["__getitem__"] is array_api.TupleInt.__dict__["__getitem__"]
+
+
+def test_doctest_finder_collects_runtime_function_docstrings():
+    names = {test.name for test in doctest.DocTestFinder().find(array_api)}
+    assert {
+        "egglog.exp.array_api.TupleInt.__getitem__",
+        "egglog.exp.array_api.TupleInt.if_",
+        "egglog.exp.array_api.TupleTupleInt.product",
+        "egglog.exp.array_api.Value.diff",
+        "egglog.exp.array_api.RecursiveValue.__getitem__",
+    } <= names
diff --git a/python/tests/test_tracing.py b/python/tests/test_tracing.py
new file mode 100644
index 00000000..5de27b6b
--- /dev/null
+++ b/python/tests/test_tracing.py
@@ -0,0 +1,318 @@
+from __future__ import annotations
+
+import re
+import subprocess
+import sys
+import textwrap
+from unittest.mock import patch
+
+from egglog_pytest_otel import DEFAULT_OTLP_ENDPOINT, configure_pytest_otel, get_pytest_otel_config
+
+
+def _console_tracing_script(body: str, *, extra_imports: str = "") -> str:
+    return "\n\n".join(
+        part
+        for part in (
+            textwrap.dedent(
+                """
+                from opentelemetry import trace
+                from opentelemetry.sdk.resources import Resource
+                from opentelemetry.sdk.trace import TracerProvider
+                from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
+
+                from egglog import bindings
+
+                provider = TracerProvider(resource=Resource.create({"service.name": "egglog"}))
+                provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
+                trace.set_tracer_provider(provider)
+                bindings.setup_tracing(exporter="console")
+                """
+            ).strip(),
+            textwrap.dedent(extra_imports).strip(),
+            textwrap.dedent(body).strip(),
+            textwrap.dedent(
+                """
+                bindings.shutdown_tracing()
+                provider.shutdown()
+                """
+            ).strip(),
+        )
+        if part
+    )
+
+
+HIGH_LEVEL_TRACE_SCRIPT = _console_tracing_script(
+    """
+    from egglog import EGraph, i64
+
+    tracer = trace.get_tracer("test")
+    with tracer.start_as_current_span("parent"):
+        EGraph().extract(i64(0))
+    """
+)
+
+LOW_LEVEL_TRACE_SCRIPT = _console_tracing_script(
+    """
+    tracer = trace.get_tracer("test")
+    egraph = bindings.EGraph()
+
+    def current_headers():
+        carrier = {}
+        propagate.inject(carrier)
+        return carrier.get("traceparent"), carrier.get("tracestate")
+
+    with tracer.start_as_current_span("run_program_parent"):
+        traceparent, tracestate = current_headers()
+        egraph.run_program(bindings.Push(1), traceparent=traceparent, tracestate=tracestate)
+
+    lit = bindings.Lit(bindings.PanicSpan(), bindings.Int(0))
+    with tracer.start_as_current_span("eval_expr_parent"):
+        traceparent, tracestate = current_headers()
+        sort, value = egraph.eval_expr(lit, traceparent=traceparent, tracestate=tracestate)
+
+    with tracer.start_as_current_span("serialize_parent"):
+        traceparent, tracestate = current_headers()
+        egraph.serialize([], traceparent=traceparent, tracestate=tracestate)
+
+    cost_model = bindings.CostModel(
+        lambda head, head_cost, children_costs: head_cost + sum(children_costs),
+        lambda func, args: 1,
+        lambda sort, value, element_costs: 1,
+        lambda sort, value: 1,
+    )
+    with tracer.start_as_current_span("extractor_new_parent"):
+        traceparent, tracestate = current_headers()
+        extractor = bindings.Extractor([sort], egraph, cost_model, traceparent=traceparent, tracestate=tracestate)
+
+    with tracer.start_as_current_span("extract_best_parent"):
+        traceparent, tracestate = current_headers()
+        extractor.extract_best(egraph, bindings.TermDag(), value, sort, traceparent=traceparent, tracestate=tracestate)
+    """,
+    extra_imports="from opentelemetry import propagate",
+)
+
+SETUP_STATE_MACHINE_SCRIPT = textwrap.dedent(
+    """
+    from egglog import bindings
+
+    bindings.setup_tracing(exporter="console")
+    bindings.setup_tracing(exporter="console")
+
+    try:
+        bindings.setup_tracing(exporter="http", endpoint="http://127.0.0.1:4318/v1/traces")
+    except RuntimeError:
+        print("reconfigure_error")
+    else:
+        raise SystemExit("expected reconfigure error")
+
+    egraph = bindings.EGraph()
+    egraph.run_program(bindings.Push(1))
+    bindings.shutdown_tracing()
+
+    try:
+        bindings.setup_tracing(exporter="console")
+    except RuntimeError:
+        print("shutdown_error")
+    else:
+        raise SystemExit("expected shutdown error")
+    """
+)
+
+HTTP_SETUP_SCRIPT = textwrap.dedent(
+    """
+    import threading
+    from http.server import BaseHTTPRequestHandler, HTTPServer
+
+    from egglog import bindings
+
+    class Handler(BaseHTTPRequestHandler):
+        def do_POST(self):
+            length = int(self.headers.get("content-length", "0"))
+            if length:
+                self.rfile.read(length)
+            self.send_response(200)
+            self.send_header("Content-Length", "0")
+            self.send_header("Content-Type", "application/x-protobuf")
+            self.send_header("Connection", "close")
+            self.end_headers()
+
+        def log_message(self, format, *args):
+            pass
+
+    server = HTTPServer(("127.0.0.1", 0), Handler)
+    thread = threading.Thread(target=server.serve_forever, daemon=True)
+    thread.start()
+
+    try:
+        endpoint = f"http://127.0.0.1:{server.server_port}/v1/traces"
+        bindings.setup_tracing(exporter="http", endpoint=endpoint)
+        bindings.shutdown_tracing()
+    finally:
+        server.shutdown()
+        thread.join()
+        server.server_close()
+
+    print("http_ok")
+    """
+)
+
+HTTP_MISSING_ENDPOINT_SCRIPT = textwrap.dedent(
+    """
+    from egglog import bindings
+
+    try:
+        bindings.setup_tracing(exporter="http")
+    except RuntimeError:
+        print("missing_endpoint")
+    else:
+        raise SystemExit("expected missing endpoint error")
+    """
+)
+
+
+class DummyConfig:
+    def __init__(self, *, traces: str, endpoint: str | None = None) -> None:
+        self._options = {
+            "--otel-otlp-endpoint": endpoint,
+            "--otel-traces": traces,
+        }
+
+    def getoption(self, name: str):
+        return self._options[name]
+
+
+def _run_script(script: str) -> subprocess.CompletedProcess[str]:
+    return subprocess.run(
+        [sys.executable, "-c", script],
+        capture_output=True,
+        check=True,
+        text=True,
+    )
+
+
+def _parse_python_span(stdout: str, name: str) -> tuple[str, str, str | None]:
+    match = re.search(
+        rf'"name": "{re.escape(name)}".*?"trace_id": "0x([0-9a-f]+)".*?"span_id": "0x([0-9a-f]+)".*?"parent_id": (null|"0x([0-9a-f]+)")',
+        stdout,
+        re.DOTALL,
+    )
+    assert match is not None
+    return match.group(1), match.group(2), match.group(4)
+
+
+def _parse_rust_span(stdout: str, name: str) -> tuple[str, str | None]:
+    match = re.search(
+        rf"Name\s*: {re.escape(name)}\s+TraceId\s*: ([0-9a-f]+)\s+SpanId\s*: [0-9a-f]+\s+TraceFlags\s*: .*?\s+ParentSpanId: ([0-9a-f]+|None)",
+        stdout,
+        re.DOTALL,
+    )
+    assert match is not None
+    parent_span_id = None if match.group(2) == "None" else match.group(2)
+    return match.group(1), parent_span_id
+
+
+def test_get_pytest_otel_config_defaults_to_off() -> None:
+    config = get_pytest_otel_config(DummyConfig(traces="off"))
+    assert config.traces == "off"
+    assert config.endpoint is None
+
+
+def test_get_pytest_otel_config_uses_default_jaeger_endpoint() -> None:
+    config = get_pytest_otel_config(DummyConfig(traces="jaeger"))
+    assert config.traces == "jaeger"
+    assert config.endpoint == DEFAULT_OTLP_ENDPOINT
+
+
+def test_get_pytest_otel_config_preserves_explicit_endpoint() -> None:
+    config = get_pytest_otel_config(DummyConfig(traces="jaeger", endpoint="http://127.0.0.1:9999/v1/traces"))
+    assert config.traces == "jaeger"
+    assert config.endpoint == "http://127.0.0.1:9999/v1/traces"
+
+
+def test_configure_pytest_otel_console_uses_bindings_setup() -> None:
+    with (
+        patch("egglog.bindings.setup_tracing") as setup_tracing,
+        patch("opentelemetry.trace.set_tracer_provider"),
+    ):
+        provider = configure_pytest_otel(DummyConfig(traces="console"))
+
+    assert provider is not None
+    setup_tracing.assert_called_once_with(exporter="console")
+
+
+def test_configure_pytest_otel_jaeger_uses_http_setup() -> None:
+    with (
+        patch("egglog.bindings.setup_tracing") as setup_tracing,
+        patch("opentelemetry.trace.set_tracer_provider"),
+    ):
+        provider = configure_pytest_otel(DummyConfig(traces="jaeger"))
+
+    assert provider is not None
+    setup_tracing.assert_called_once_with(exporter="http", endpoint=DEFAULT_OTLP_ENDPOINT)
+
+
+def test_bindings_setup_state_machine() -> None:
+    result = _run_script(SETUP_STATE_MACHINE_SCRIPT)
+    assert "bindings.run_program" in result.stdout
+    assert "reconfigure_error" in result.stdout
+    assert "shutdown_error" in result.stdout
+
+
+def test_bindings_http_setup_requires_endpoint() -> None:
+    result = _run_script(HTTP_MISSING_ENDPOINT_SCRIPT)
+    assert "missing_endpoint" in result.stdout
+
+
+def test_bindings_http_setup_succeeds_with_explicit_endpoint() -> None:
+    result = _run_script(HTTP_SETUP_SCRIPT)
+    assert "http_ok" in result.stdout
+
+
+def test_console_export_smoke_uses_high_level_wrapper() -> None:
+    result = _run_script(HIGH_LEVEL_TRACE_SCRIPT)
+    assert '"name": "parent"' in result.stdout
+    assert '"name": "extract"' in result.stdout
+    assert "bindings.run_program" in result.stdout
+
+    parent_trace_id, parent_span_id, parent_parent_id = _parse_python_span(result.stdout, "parent")
+    extract_trace_id, extract_span_id, extract_parent_id = _parse_python_span(result.stdout, "extract")
+    rust_trace_id, rust_parent_id = _parse_rust_span(result.stdout, "bindings.run_program")
+
+    assert parent_parent_id is None
+    assert extract_trace_id == parent_trace_id
+    assert extract_parent_id == parent_span_id
+    assert rust_trace_id == parent_trace_id
+    assert rust_parent_id == extract_span_id
+
+
+def test_low_level_explicit_context_propagates_to_rust_spans() -> None:
+    result = _run_script(LOW_LEVEL_TRACE_SCRIPT)
+
+    run_program_trace_id, run_program_span_id, _ = _parse_python_span(result.stdout, "run_program_parent")
+    rust_run_trace_id, rust_run_parent_id = _parse_rust_span(result.stdout, "bindings.run_program")
+    assert rust_run_trace_id == run_program_trace_id
+    assert rust_run_parent_id == run_program_span_id
+
+    eval_trace_id, eval_span_id, _ = _parse_python_span(result.stdout, "eval_expr_parent")
+    rust_eval_trace_id, rust_eval_parent_id = _parse_rust_span(result.stdout, "bindings.eval_expr")
+    assert rust_eval_trace_id == eval_trace_id
+    assert rust_eval_parent_id == eval_span_id
+
+    serialize_trace_id, serialize_span_id, _ = _parse_python_span(result.stdout, "serialize_parent")
+    rust_serialize_trace_id, rust_serialize_parent_id = _parse_rust_span(result.stdout, "bindings.serialize")
+    assert rust_serialize_trace_id == serialize_trace_id
+    assert rust_serialize_parent_id == serialize_span_id
+
+    extractor_new_trace_id, extractor_new_span_id, _ = _parse_python_span(result.stdout, "extractor_new_parent")
+    rust_extractor_new_trace_id, rust_extractor_new_parent_id = _parse_rust_span(
+        result.stdout, "bindings.extractor.new"
+    )
+    assert rust_extractor_new_trace_id == extractor_new_trace_id
+    assert rust_extractor_new_parent_id == extractor_new_span_id
+
+    extract_best_trace_id, extract_best_span_id, _ = _parse_python_span(result.stdout, "extract_best_parent")
+    rust_extract_best_trace_id, rust_extract_best_parent_id = _parse_rust_span(
+        result.stdout, "bindings.extractor.extract_best"
+    )
+    assert rust_extract_best_trace_id == extract_best_trace_id
+    assert rust_extract_best_parent_id == extract_best_span_id
diff --git a/python/tests/test_type_constraint_solver.py b/python/tests/test_type_constraint_solver.py
index f809d511..799f6551 100644
--- a/python/tests/test_type_constraint_solver.py
+++ b/python/tests/test_type_constraint_solver.py
@@ -7,47 +7,47 @@
 
 i64 = JustTypeRef(Ident("i64"))
 unit = JustTypeRef(Ident("Unit"))
-K, V = ClassTypeVarRef(Ident("K")), ClassTypeVarRef(Ident("V"))
+K, V = TypeVarRef(Ident("K")), TypeVarRef(Ident("V"))
 map = TypeRefWithVars(Ident("Map"), (K, V))
 map_i64_unit = JustTypeRef(Ident("Map"), (i64, unit))
 decls = Declarations(_classes={Ident("Map"): ClassDecl(type_vars=(K, V))})
 
 
 def test_simple() -> None:
-    tcs = TypeConstraintSolver(Declarations())
+    tcs = TypeConstraintSolver()
     tcs.infer_typevars(i64.to_var(), i64)
     assert tcs.substitute_typevars(i64.to_var()) == i64
 
 
 def test_wrong_arg() -> None:
-    tcs = TypeConstraintSolver(Declarations())
+    tcs = TypeConstraintSolver()
     with pytest.raises(TypeConstraintError):
         tcs.infer_typevars(i64.to_var(), unit)
 
 
 def test_generic() -> None:
-    tcs = TypeConstraintSolver(Declarations())
-    tcs.infer_typevars(map, map_i64_unit, Ident("Map"))
-    tcs.infer_typevars(K, i64, Ident("Map"))
-    assert tcs.substitute_typevars(V, Ident("Map")) == unit
+    tcs = TypeConstraintSolver()
+    tcs.infer_typevars(map, map_i64_unit)
+    tcs.infer_typevars(K, i64)
+    assert tcs.substitute_typevars(V) == unit
 
 
 def test_generic_wrong() -> None:
-    tcs = TypeConstraintSolver(Declarations())
-    tcs.infer_typevars(map, map_i64_unit, Ident("Map"))
+    tcs = TypeConstraintSolver()
+    tcs.infer_typevars(map, map_i64_unit)
     with pytest.raises(TypeConstraintError):
-        tcs.infer_typevars(K, unit, Ident("Map"))
+        tcs.infer_typevars(K, unit)
 
 
 def test_bound() -> None:
-    bound_cs = TypeConstraintSolver(decls)
-    bound_cs.bind_class(map_i64_unit)
-    bound_cs.infer_typevars(K, i64, Ident("Map"))
-    assert bound_cs.substitute_typevars(V, Ident("Map")) == unit
+    bound_cs = TypeConstraintSolver()
+    bound_cs.bind_class(map_i64_unit, decls)
+    bound_cs.infer_typevars(K, i64)
+    assert bound_cs.substitute_typevars(V) == unit
 
 
 def test_bound_wrong():
-    bound_cs = TypeConstraintSolver(decls)
-    bound_cs.bind_class(map_i64_unit)
+    bound_cs = TypeConstraintSolver()
+    bound_cs.bind_class(map_i64_unit, decls)
     with pytest.raises(TypeConstraintError):
-        bound_cs.infer_typevars(K, unit, Ident("Map"))
+        bound_cs.infer_typevars(K, unit)
diff --git a/python/tests/test_unstable_fn.py b/python/tests/test_unstable_fn.py
index eaaab64b..6d3b9b3b 100644
--- a/python/tests/test_unstable_fn.py
+++ b/python/tests/test_unstable_fn.py
@@ -53,6 +53,15 @@ def __mul__(self, x: MathLike) -> MathList: ...
 converter(type(None), MathList, lambda _: MathList.NIL)
 
 
+class Pair(Expr):
+    def __init__(self, x: i64Like) -> None: ...
+
+    def add(self, y: i64Like) -> Pair: ...
+
+
+converter(i64, Pair, Pair)
+
+
 @function
 def square(x: MathLike) -> Math: ...
@@ -74,6 +83,12 @@ def test_string_fn_partial():
     assert str(UnstableFn(Math.__mul__, Math(2))) == "partial(Math.__mul__, Math(2))"
 
 
+def test_bound_runtime_function_partial():
+    pair = Pair(2)
+    assert expr_parts(UnstableFn(pair.add)) == expr_parts(UnstableFn(Pair.add, pair))
+    assert expr_parts(UnstableFn(pair.add, 3)) == expr_parts(UnstableFn(Pair.add, pair, 3))
+
+
 @ruleset
 def map_ruleset(f: MathFn, x: Math, xs: MathList):
     yield rewrite(MathList.NIL.map(f)).to(MathList.NIL)
@@ -363,3 +378,16 @@ def apply_C(f: Callable[[C], C], x: C) -> C:
         egraph.run(10)
         egraph.check(eq(x).to(A()))
         egraph.check(eq(y).to(C()))
+
+    def test_different_parameter_names_get_different_names(self):
+        @function
+        def apply_f(f: Callable[[A], A], x: A) -> A:
+            return f(x)
+
+        egraph = EGraph(save_egglog_string=True)
+        egraph.register(apply_f(lambda left: A(), A()))
+        egraph.register(apply_f(lambda right: A(), A()))
+
+        egglog = egraph.as_egglog_string
+        assert "_lambda_0" in egglog
+        assert "_lambda_1" in egglog
diff --git a/src/conversions.rs b/src/conversions.rs
index bf799506..00d46f87 100644
--- a/src/conversions.rs
+++ b/src/conversions.rs
@@ -138,13 +138,16 @@ convert_enums!(
             variants: variants.iter().map(|v| v.into()).collect()
         };
     Sort(span: Span, name: String, presort_and_args: Option<(String, Vec<Expr>)>)
-        s -> egglog::ast::Command::Sort(
-            s.span.clone().into(),
-            (&s.name).into(),
-            s.presort_and_args.as_ref().map(|(p, a)| (p.into(), a.iter().map(|e| e.into()).collect()))
-        ),
-        egglog::ast::Command::Sort(span, n, presort_and_args) => Sort {
-            name: n.to_string(),
+        s -> egglog::ast::Command::Sort {
+            span: s.span.clone().into(),
+            name: (&s.name).into(),
+            presort_and_args: s.presort_and_args.as_ref().map(|(p, a)| (p.into(), a.iter().map(|e| e.into()).collect())),
+            uf: None,
+            proof_func: None,
+            unionable: true
+        },
+        egglog::ast::Command::Sort { span, name, presort_and_args, .. } => Sort {
+            name: name.to_string(),
             presort_and_args: presort_and_args.as_ref().map(|(p, a)| (p.to_string(), a.iter().map(|e| e.into()).collect())),
             span: span.into()
         };
@@ -153,9 +156,11 @@ convert_enums!(
             span: f.span.clone().into(),
             name: (&f.name).into(),
             schema: (&f.schema).into(),
-            merge: f.merge.as_ref().map(|e| e.into())
+            merge: f.merge.as_ref().map(|e| e.into()),
+            hidden: false,
+            let_binding: false
         },
-        egglog::ast::Command::Function {span, name, schema, merge} => FunctionCommand {
+        egglog::ast::Command::Function {span, name, schema, merge, .. } => FunctionCommand {
             span: span.into(),
             name: name.to_string(),
             schema: schema.into(),
@@ -210,6 +215,18 @@ convert_enums!(
     Check(span: Span, facts: Vec<Fact>)
        c -> egglog::ast::Command::Check(c.span.clone().into(), c.facts.iter().map(|f| f.into()).collect()),
        egglog::ast::Command::Check(span, facts) => Check { span: span.into(), facts: facts.iter().map(|f| f.into()).collect() };
+    ProveCommand[name="Prove"](span: Span, facts: Vec<Fact>)
+        p -> egglog::ast::Command::Prove(p.span.clone().into(), p.facts.iter().map(|f| f.into()).collect()),
+        egglog::ast::Command::Prove(span, facts) => ProveCommand {
+            span: span.into(),
+            facts: facts.iter().map(|f| f.into()).collect()
+        };
+    ProveExistsCommand[name="ProveExists"](span: Span, expr: String)
+        p -> egglog::ast::Command::ProveExists(p.span.clone().into(), (&p.expr).into()),
+        egglog::ast::Command::ProveExists(span, expr) => ProveExistsCommand {
+            span: span.into(),
+            expr: expr.to_string()
+        };
     PrintFunction(span: Span, name: String, length: Option<usize>, filename: Option<String>, mode: PrintFunctionMode)
         p -> egglog::ast::Command::PrintFunction(p.span.clone().into(), (&p.name).into(), p.length, p.filename.clone(), p.mode.clone().into()),
         egglog::ast::Command::PrintFunction(span, n, l, f, m) => PrintFunction {
@@ -262,9 +279,12 @@ convert_enums!(
             name: (&c.name).into(),
             schema: (&c.schema).into(),
             cost: c.cost,
-            unextractable: c.unextractable
+            unextractable: c.unextractable,
+            hidden: false,
+            let_binding: false,
+            term_constructor: None
         },
-        egglog::ast::Command::Constructor {span, name, schema, cost, unextractable} => Constructor {
+        egglog::ast::Command::Constructor {span, name, schema, cost, unextractable, .. } => Constructor {
             span: span.into(),
             name: name.to_string(),
             schema: schema.into(),
@@ -364,25 +384,30 @@ convert_enums!(
     PrintAllFunctionsSize(sizes: Vec<(String, usize)>)
         b -> egglog::CommandOutput::PrintAllFunctionsSize(b.sizes.clone()),
        egglog::CommandOutput::PrintAllFunctionsSize(sizes) => PrintAllFunctionsSize {sizes: sizes.clone()};
-    ExtractBest(termdag: TermDag, cost: DefaultCost, term: Term)
+    ExtractBest(termdag: TermDag, cost: DefaultCost, term: usize)
         b -> egglog::CommandOutput::ExtractBest(
             b.termdag.0.clone(),
             b.cost,
-            (&b.term).into()
+            b.term
         ),
         egglog::CommandOutput::ExtractBest(termdag, cost, term) => ExtractBest {
             termdag: TermDag(termdag.clone()),
             cost: *cost,
-            term: term.into()
+            term: *term
         };
-    ExtractVariants(termdag: TermDag, terms: Vec<Term>)
+    ExtractVariants(termdag: TermDag, terms: Vec<usize>)
         v -> egglog::CommandOutput::ExtractVariants(
             v.termdag.0.clone(),
-            v.terms.iter().map(|v| v.into()).collect()
+            v.terms.clone()
         ),
         egglog::CommandOutput::ExtractVariants(termdag, terms) => ExtractVariants {
             termdag: TermDag(termdag.clone()),
-            terms: terms.iter().map(|v| v.into()).collect()
+            terms: terms.clone()
+        };
+    ProveExistsOutput(proof: String)
+        _p -> panic!("Converting Python proof output back into egglog is unsupported"),
+        egglog::CommandOutput::ProveExists { proof_store, proof_id } => ProveExistsOutput {
+            proof: proof_store.proof_to_string(*proof_id)
         };
     OverallStatistics(report: RunReport)
         b -> egglog::CommandOutput::OverallStatistics(b.report.clone().into()),
@@ -390,17 +415,17 @@ convert_enums!(
     RunScheduleOutput(report: RunReport)
         b -> egglog::CommandOutput::RunSchedule(b.report.clone().into()),
         egglog::CommandOutput::RunSchedule(report) => RunScheduleOutput {report: report.into()};
-    PrintFunctionOutput(function: Function, termdag: TermDag, terms: Vec<(Term, Term)>, mode: PrintFunctionMode)
+    PrintFunctionOutput(function: Function, termdag: TermDag, terms: Vec<(usize, usize)>, mode: PrintFunctionMode)
         v -> egglog::CommandOutput::PrintFunction(
             v.function.0.clone(),
             v.termdag.0.clone(),
-            v.terms.iter().map(|(l, r)| (l.into(), r.into())).collect(),
+            v.terms.clone(),
             v.mode.clone().into()
         ),
         egglog::CommandOutput::PrintFunction(function, termdag, terms, mode) => PrintFunctionOutput {
             function: Function(function.clone()),
             termdag: TermDag(termdag.clone()),
-            terms: terms.iter().map(|(l, r)| (l.into(), r.into())).collect(),
+            terms: terms.clone(),
             mode: mode.into()
         };
     UserDefinedOutput(output: UserDefinedCommandOutput)
diff --git a/src/egraph.rs b/src/egraph.rs
index 19e9f172..4cad3d30 100644
--- a/src/egraph.rs
+++ b/src/egraph.rs
@@ -2,8 +2,10 @@
 use crate::conversions::*;
 use crate::error::{EggResult, WrappedError};
+use crate::freeze::FrozenEGraph;
 use crate::py_object_sort::{PyObjectSort, PyPickledValue, load};
 use crate::serialize::SerializedEGraph;
+use crate::tracing_otel;
 
 use egglog::prelude::{RustSpan, Span, add_base_sort};
 use egglog::{SerializeConfig, span};
@@ -52,17 +54,29 @@ impl EGraph {
     /// Run a series of commands on the EGraph.
     /// Returns a list of strings representing the output.
     /// An EggSmolError is raised if there is problem parsing or executing.
-    #[pyo3(signature=(*commands))]
+    #[pyo3(signature=(*commands, traceparent=None, tracestate=None))]
     fn run_program(
         &mut self,
         py: Python<'_>,
         commands: Vec,
+        traceparent: Option,
+        tracestate: Option,
     ) -> EggResult> {
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
         let commands: Vec = commands.into_iter().map(|x| x.into()).collect();
         let mut cmds_str = String::new();
+
         for cmd in &commands {
-            cmds_str = cmds_str + &cmd.to_string() + "\n";
+            let cmd_string = cmd.to_string();
+            cmds_str = cmds_str + &cmd_string + "\n";
         }
+        let span = tracing::info_span!(
+            "bindings.run_program",
+            command_count = commands.len(),
+            commands = tracing::field::display(cmds_str.trim_end())
+        );
+        let _entered = span.enter();
         info!("Running commands:\n{}", cmds_str);
         let res = py.detach(|| self.egraph.run_program(commands));
         if let Some(err) = PyErr::take(py) {
@@ -74,7 +88,8 @@ impl EGraph {
             if let Some(cmds) = &mut self.cmds {
                 cmds.push_str(&cmds_str);
             }
-            Ok(outputs.into_iter().map(|o| o.into()).collect())
+            let outputs = outputs.into_iter().map(|o| o.into()).collect();
+            Ok(outputs)
         }
     }
 }
@@ -87,33 +102,43 @@ impl EGraph {
     /// Serialize the EGraph to a SerializedEGraph object.
     #[pyo3(
-        signature = (root_eclasses, *, max_functions=None, max_calls_per_function=None, include_temporary_functions=false),
-        text_signature = "(self, root_eclasses, *, max_functions=None, max_calls_per_function=None, include_temporary_functions=False)"
+        signature = (root_eclasses, *, max_functions=None, max_calls_per_function=None, include_temporary_functions=false, traceparent=None, tracestate=None),
+        text_signature = "(self, root_eclasses, *, max_functions=None, max_calls_per_function=None, include_temporary_functions=False, traceparent=None, tracestate=None)"
     )]
     fn serialize(
         &mut self,
-        py: Python<'_>,
         root_eclasses: Vec,
         max_functions: Option,
        max_calls_per_function: Option,
         include_temporary_functions: bool,
+        traceparent: Option,
+        tracestate: Option,
     ) -> SerializedEGraph {
-        py.detach(|| {
-            let root_eclasses: Vec<_> = root_eclasses
-                .into_iter()
-                .map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
-                .collect();
-            let res = self.egraph.serialize(SerializeConfig {
-                max_functions,
-                max_calls_per_function,
-                include_temporary_functions,
-                root_eclasses,
-            });
-            SerializedEGraph {
-                egraph: res.egraph,
-                truncated_functions: res.truncated_functions,
-                discarded_functions: res.discarded_functions,
-            }
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
+        let span = tracing::info_span!(
+            "bindings.serialize",
+            root_eclass_count = root_eclasses.len()
+        );
+        let _entered = span.enter();
+        Python::attach(|py| {
+            py.detach(|| {
+                let root_eclasses: Vec<_> = root_eclasses
+                    .into_iter()
+                    .map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
+                    .collect();
+                let res = self.egraph.serialize(SerializeConfig {
+                    max_functions,
+                    max_calls_per_function,
+                    include_temporary_functions,
+                    root_eclasses,
+                });
+                SerializedEGraph {
+                    egraph: res.egraph,
+                    truncated_functions: res.truncated_functions,
+                    discarded_functions: res.discarded_functions,
+                }
+            })
         })
     }
@@ -130,7 +155,18 @@ impl EGraph {
         .map(Value)
     }
-    fn eval_expr(&mut self, py: Python<'_>, expr: Expr) -> EggResult<(String, Value)> {
+    #[pyo3(signature = (expr, *, traceparent=None, tracestate=None))]
+    fn eval_expr(
+        &mut self,
+        py: Python<'_>,
+        expr: Expr,
+        traceparent: Option,
+        tracestate: Option,
+    ) -> EggResult<(String, Value)> {
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
+        let span = tracing::info_span!("bindings.eval_expr");
+        let _entered = span.enter();
         let expr: egglog::ast::Expr = expr.into();
         let res = py.detach(|| {
             self.egraph
@@ -227,6 +263,10 @@ impl EGraph {
         )
     }
+    fn freeze(&self) -> FrozenEGraph {
+        FrozenEGraph::from_egraph(&self.egraph)
+    }
+
     // fn dynamic_cost_model_enode_cost(
     //     &self,
     //     func: String,
@@ -246,5 +286,5 @@ impl EGraph {
 /// Wrapper around Egglog Value. Represents either a primitive base value or a reference to an e-class.
 #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Clone)]
-#[pyclass(eq, frozen, hash, str = "{0:?}")]
+#[pyclass(eq, frozen, ord, hash, str = "{0:?}")]
 pub struct Value(pub egglog::Value);
diff --git a/src/extract.rs b/src/extract.rs
index 1c6cde41..5385610a 100644
--- a/src/extract.rs
+++ b/src/extract.rs
@@ -2,7 +2,8 @@
 use std::cmp::Ordering;
 use pyo3::{exceptions::PyValueError, prelude::*};
-use crate::{conversions::Term, egraph::EGraph, egraph::Value, termdag::TermDag};
+use crate::{egraph::EGraph, egraph::Value, termdag::TermDag, tracing_otel};
+use egglog::TermId;
 #[derive(Debug)]
 // We have to store the result, since the cost model does not return errors
@@ -181,12 +182,22 @@ impl Extractor {
     ///
     /// For convenience, if the rootsorts is `None`, it defaults to extract all extractable rootsorts.
     #[new]
+    #[pyo3(signature = (rootsorts, egraph, cost_model, *, traceparent=None, tracestate=None))]
     fn new(
         py: Python<'_>,
         rootsorts: Option>,
         egraph: &EGraph,
         cost_model: CostModel,
+        traceparent: Option,
+        tracestate: Option,
     ) -> PyResult {
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
+        let span = tracing::info_span!(
+            "bindings.extractor.new",
+            has_rootsorts = rootsorts.is_some()
+        );
+        let _entered = span.enter();
         let egraph = &egraph.egraph;
         // Transforms sorts to arcsorts, returning an error if any are unknown
         let rootsorts = rootsorts
@@ -209,6 +220,7 @@ impl Extractor {
     ///
     /// This function expects the sort to be already computed,
     /// which can be one of the rootsorts, or reachable from rootsorts, or primitives, or containers of computed sorts.
+    #[pyo3(signature = (egraph, termdag, value, sort, *, traceparent=None, tracestate=None))]
     fn extract_best(
         &self,
         py: Python<'_>,
@@ -216,7 +228,13 @@ impl Extractor {
         termdag: &mut TermDag,
         value: Value,
         sort: String,
-    ) -> PyResult<(Py, Term)> {
+        traceparent: Option,
+        tracestate: Option,
+    ) -> PyResult<(Py, TermId)> {
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
+        let span = tracing::info_span!("bindings.extractor.extract_best", sort = %sort);
+        let _entered = span.enter();
         let sort = egraph
             .egraph
             .get_sort_by_name(&sort)
@@ -225,13 +243,14 @@ impl Extractor {
             .0
             .extract_best_with_sort(&egraph.egraph, &mut termdag.0, value.0, sort.clone())
             .ok_or(PyValueError::new_err("Unextractable root".to_string()))?;
-        Ok((cost.0.clone_ref(py), term.into()))
+        Ok((cost.0.clone_ref(py), term))
     }
     /// Extract variants of an e-class.
     ///
     /// The variants are selected by first picking `nvariants` e-nodes with the lowest cost from the e-class
     /// and then extracting a term from each e-node.
+    #[pyo3(signature = (egraph, termdag, value, nvariants, sort, *, traceparent=None, tracestate=None))]
     fn extract_variants(
         &self,
         py: Python<'_>,
@@ -240,7 +259,13 @@ impl Extractor {
         value: Value,
         nvariants: usize,
         sort: String,
-    ) -> PyResult, Term)>> {
+        traceparent: Option,
+        tracestate: Option,
+    ) -> PyResult, TermId)>> {
+        let _context_guard =
+            tracing_otel::attach_parent_context(traceparent.as_deref(), tracestate.as_deref());
+        let span = tracing::info_span!("bindings.extractor.extract_variants", sort = %sort, variant_count = nvariants);
+        let _entered = span.enter();
         let sort = egraph
             .egraph
             .get_sort_by_name(&sort)
@@ -254,7 +279,7 @@ impl Extractor {
         );
         Ok(variants
             .into_iter()
-            .map(|(cost, term)| (cost.0.clone_ref(py), term.into()))
+            .map(|(cost, term)| (cost.0.clone_ref(py), term))
             .collect())
     }
 }
diff --git a/src/freeze.rs b/src/freeze.rs
new file mode 100644
index 00000000..28d1a1e8
--- /dev/null
+++ b/src/freeze.rs
@@ -0,0 +1,67 @@
+// Freeze an egglog, turning it into an immutable structure that can be printed, serialized, or added back to an e-graph.
+
+use egglog::EGraph;
+use indexmap::IndexMap;
+use pyo3::prelude::*;
+
+use crate::egraph::Value;
+
+#[pyclass(eq, frozen, get_all)]
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub struct FrozenRow {
+    subsumed: bool,
+    inputs: Vec,
+    output: Value,
+}
+
+#[pyclass(eq, frozen, hash, get_all)]
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub struct FrozenFunction {
+    input_sorts: Vec,
+    output_sort: String,
+    is_let_binding: bool,
+    rows: Vec,
+}
+
+#[pyclass(eq, frozen, get_all)]
+#[derive(PartialEq, Eq, Clone)]
+pub struct FrozenEGraph {
+    functions: IndexMap,
+}
+
+impl FrozenEGraph {
+    /// Convert a live `EGraph` into an immutable `FrozenEGraph` snapshot.
+    pub fn from_egraph(egraph: &EGraph) -> FrozenEGraph {
+        let mut functions = IndexMap::new();
+        for fname in egraph.get_function_names() {
+            let mut rows = Vec::new();
+            egraph.function_for_each(&fname, |row| {
+                let frozen_row = FrozenRow {
+                    subsumed: row.subsumed,
+                    inputs: row.vals[..row.vals.len() - 1]
+                        .iter()
+                        .cloned()
+                        .map(Value)
+                        .collect(),
+                    output: Value(*row.vals.last().unwrap()),
+                };
+                rows.push(frozen_row);
+            }).unwrap();
+            let func = egraph.get_function(&fname).unwrap();
+            let frozen_function = FrozenFunction {
+                input_sorts: func
+                    .schema()
+                    .input
+                    .iter()
+                    .map(|s| s.name().to_string())
+                    .collect(),
+                output_sort: func.schema().output.name().to_string(),
+                rows,
+                is_let_binding: func.is_let_binding(),
+            };
+            functions.insert(fname.clone(), frozen_function);
+        }
+
+        FrozenEGraph { functions }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 723f5059..7590ddc7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,12 +2,30 @@
 mod conversions;
 mod egraph;
 mod error;
 mod extract;
+mod freeze;
 mod py_object_sort;
 mod serialize;
 mod termdag;
+mod tracing_otel;
 mod utils;
 use pyo3::prelude::*;
+use pyo3::wrap_pyfunction;
+
+#[pyfunction]
+#[pyo3(signature = (*, exporter, endpoint=None))]
+fn setup_tracing(py: Python<'_>, exporter: &str, endpoint: Option<&str>) -> PyResult<()> {
+    let exporter = exporter.to_string();
+    let endpoint = endpoint.map(str::to_string);
+    py.detach(move || crate::tracing_otel::setup_tracing(&exporter, endpoint.as_deref()))
+        .map_err(pyo3::exceptions::PyRuntimeError::new_err)
+}
+
+#[pyfunction]
+fn shutdown_tracing(py: Python<'_>) -> PyResult<()> {
+    py.detach(crate::tracing_otel::shutdown_tracing)
+        .map_err(pyo3::exceptions::PyRuntimeError::new_err)
+}
 /// Bindings for egglog rust library
 #[pymodule]
@@ -33,6 +51,11 @@ fn bindings(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_function(wrap_pyfunction!(setup_tracing, m)?)?;
+    m.add_function(wrap_pyfunction!(shutdown_tracing, m)?)?;
     crate::conversions::add_structs_to_module(m)?;
     crate::conversions::add_enums_to_module(m)?;
diff --git a/src/py_object_sort.rs b/src/py_object_sort.rs
index 51de73f9..1c5e78df 100644
--- a/src/py_object_sort.rs
+++ b/src/py_object_sort.rs
@@ -9,7 +9,7 @@ use core::fmt;
 ///
 ///
 use egglog::{
-    BaseValue, Term, TermDag, Value, add_primitive,
+    BaseValue, TermDag, TermId, Value, add_primitive,
     ast::Literal,
     prelude::{BaseSort, EGraph},
     sort::{BaseValues, S},
@@ -195,7 +195,7 @@ impl BaseSort for PyObjectSort {
         base_values: &BaseValues,
         value: Value,
         termdag: &mut TermDag,
-    ) -> Term {
+    ) -> TermId {
         let ident = base_values.unwrap::(value);
         let arg = termdag.lit(Literal::String(STANDARD.encode(&ident.0).into()));
         termdag.app("py-object".into(), vec![arg])
diff --git a/src/serialize.rs b/src/serialize.rs
index a18281ce..29b91554 100644
--- a/src/serialize.rs
+++ b/src/serialize.rs
@@ -35,7 +35,7 @@ impl SerializedEGraph {
         serde_json::to_string(&self.egraph).unwrap()
     }
-    /// Split all primitive nodes, as well as other ops that match, into seperate e-classes
+    /// Split all primitive nodes, as well as other ops that match, into separate e-classes
     fn split_classes(&mut self, egraph: &EGraph, ops: HashSet) {
         self.egraph.split_classes(|id, node| {
             egraph.egraph.from_node_id(id).is_primitive() || ops.contains(&node.op)
diff --git a/src/termdag.rs b/src/termdag.rs
index c4027a1f..238868e0 100644
--- a/src/termdag.rs
+++ b/src/termdag.rs
@@ -36,22 +36,20 @@ impl TermDag {
     /// and insert into the DAG if it is not already present.
     ///
     /// Panics if any of the children are not already in the DAG.
-    pub fn app(&mut self, sym: String, children: Vec) -> Term {
-        self.0
-            .app(sym, children.into_iter().map(|c| c.into()).collect())
-            .into()
+    pub fn app(&mut self, sym: String, children: Vec) -> TermId {
+        self.0.app(sym, children)
     }
-    /// Make and return a [`Term::Lit`] with the given literal, and insert into
-    /// the DAG if it is not already present.
-    pub fn lit(&mut self, lit: Literal) -> Term {
-        self.0.lit(lit.into()).into()
+    /// Make a [`Term::Lit`] with the given literal and return its id,
+    /// inserting it into the DAG if it is not already present.
+    pub fn lit(&mut self, lit: Literal) -> TermId {
+        self.0.lit(lit.into())
     }
-    /// Make and return a [`Term::Var`] with the given symbol, and insert into
+    /// Make and return a [`Term::Var`] id with the given symbol, and insert into
     /// the DAG if it is not already present.
-    pub fn var(&mut self, sym: String) -> Term {
-        self.0.var(sym).into()
+    pub fn var(&mut self, sym: String) -> TermId {
+        self.0.var(sym)
     }
     /// Recursively converts the given expression to a term.
@@ -59,21 +57,21 @@
     /// This involves inserting every subexpression into this DAG. Because
     /// TermDags are hashconsed, the resulting term is guaranteed to maximally
     /// share subterms.
-    pub fn expr_to_term(&mut self, expr: Expr) -> Term {
-        self.0.expr_to_term(&expr.into()).into()
+    pub fn expr_to_term(&mut self, expr: Expr) -> TermId {
+        self.0.expr_to_term(&expr.into())
     }
     /// Recursively converts the given term to an expression.
     ///
     /// Panics if the term contains subterms that are not in the DAG.
-    pub fn term_to_expr(&self, term: Term, span: Span) -> Expr {
-        self.0.term_to_expr(&term.into(), span.into()).into()
+    pub fn term_to_expr(&self, term: TermId, span: Span) -> Expr {
+        self.0.term_to_expr(&term, span.into()).into()
     }
     /// Converts the given term to a string.
     ///
     /// Panics if the term or any of its subterms are not in the DAG.
-    pub fn to_string(&self, term: Term) -> String {
-        self.0.to_string(&term.into())
+    pub fn to_string(&self, term: TermId) -> String {
+        self.0.to_string(term)
     }
 }
diff --git a/src/tracing_otel.rs b/src/tracing_otel.rs
new file mode 100644
index 00000000..17098780
--- /dev/null
+++ b/src/tracing_otel.rs
@@ -0,0 +1,198 @@
+use std::collections::HashMap;
+use std::sync::Mutex;
+use std::time::Instant;
+
+use opentelemetry::{
+    Context, ContextGuard,
+    propagation::TextMapPropagator,
+    trace::{TraceContextExt, Tracer as _, TracerProvider as _},
+};
+use opentelemetry_otlp::WithExportConfig as _;
+use opentelemetry_sdk::{Resource, propagation::TraceContextPropagator, trace::SdkTracerProvider};
+use tracing_opentelemetry::OpenTelemetryLayer;
+use tracing_subscriber::layer::SubscriberExt;
+
+const SERVICE_NAME: &str = "egglog";
+const TRACER_NAME: &str = "egglog.rust";
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum TracingConfig {
+    Console,
+    Http { endpoint: String },
+}
+
+#[derive(Default)]
+struct TracingBackend {
+    config: Option,
+    provider: Option,
+    shutdown: bool,
+}
+
+static TRACING_BACKEND: Mutex = Mutex::new(TracingBackend {
+    config: None,
+    provider: None,
+    shutdown: false,
+});
+
+pub(crate) fn attach_parent_context(
+    traceparent: Option<&str>,
+    tracestate: Option<&str>,
+) -> Option {
+    extract_context_from_headers(traceparent, tracestate).map(|context| context.attach())
+}
+
+pub(crate) fn setup_tracing(exporter: &str, endpoint: Option<&str>) -> Result<(), String> {
+    let config = parse_config(exporter, endpoint)?;
+    {
+        let backend = TRACING_BACKEND.lock().unwrap();
+        if backend.shutdown {
+            return Err("egglog tracing has already been shut down for this process".to_string());
+        }
+        if let Some(existing) = &backend.config {
+            return if existing == &config {
+                Ok(())
+            } else {
+                Err(format!(
+                    "egglog tracing is already configured as {existing:?}; cannot reconfigure to {config:?}"
+                ))
+            };
+        }
+    }
+
+    warmup_tracer_provider(&config)?;
+    let provider = create_tracer_provider(&config)?;
+    let otel_layer = OpenTelemetryLayer::new(provider.tracer(TRACER_NAME));
+    let subscriber = tracing_subscriber::registry().with(otel_layer);
+    tracing::subscriber::set_global_default(subscriber)
+        .map_err(|err| format!("could not install egglog tracing subscriber: {err}"))?;
+
+    let mut backend = TRACING_BACKEND.lock().unwrap();
+    if backend.shutdown {
+        let _ = provider.shutdown();
+        return Err("egglog tracing has already been shut down for this process".to_string());
+    }
+    match &backend.config {
+        None => {
+            backend.config = Some(config);
+            backend.provider = Some(provider);
+            Ok(())
+        }
+        Some(existing) if existing == &config => {
+            let _ = provider.shutdown();
+            Ok(())
+        }
+        Some(existing) => {
+            let _ = provider.shutdown();
+            Err(format!(
+                "egglog tracing is already configured as {existing:?}; cannot reconfigure"
+            ))
+        }
+    }
+}
+
+pub(crate) fn shutdown_tracing() -> Result<(), String> {
+    let provider = {
+        let mut backend = TRACING_BACKEND.lock().unwrap();
+        if backend.shutdown {
+            return Ok(());
+        }
+        backend.shutdown = true;
+        backend.provider.take()
+    };
+
+    if let Some(provider) = provider {
+        provider.shutdown().map_err(|err| err.to_string())?;
+    }
+    Ok(())
+}
+
+fn parse_config(exporter: &str, endpoint: Option<&str>) -> Result {
+    match exporter {
+        "console" => Ok(TracingConfig::Console),
+        "http" => endpoint
+            .map(|endpoint| TracingConfig::Http {
+                endpoint: endpoint.to_string(),
+            })
+            .ok_or_else(|| "setup_tracing(exporter='http') requires an endpoint".to_string()),
+        _ => Err(format!("unsupported tracing exporter {exporter:?}")),
+    }
+}
+
+fn create_tracer_provider(config: &TracingConfig) -> Result {
+    let resource = Resource::builder().with_service_name(SERVICE_NAME).build();
+
+    match config {
+        TracingConfig::Console => Ok(SdkTracerProvider::builder()
+            .with_resource(resource)
+            .with_simple_exporter(opentelemetry_stdout::SpanExporter::default())
+            .build()),
+        TracingConfig::Http { endpoint } => {
+            let exporter = build_http_exporter(endpoint)?;
+            Ok(SdkTracerProvider::builder()
+                .with_resource(resource)
+                .with_batch_exporter(exporter)
+                .build())
+        }
+    }
+}
+
+fn warmup_tracer_provider(config: &TracingConfig) -> Result<(), String> {
+    let TracingConfig::Http { endpoint } = config else {
+        return Ok(());
+    };
+
+    log::info!("warming up egglog rust tracing exporter at {endpoint}");
+    let start = Instant::now();
+    let resource = Resource::builder().with_service_name(SERVICE_NAME).build();
+    let exporter = build_http_exporter(endpoint)?;
+    let provider = SdkTracerProvider::builder()
+        .with_resource(resource)
+        .with_simple_exporter(exporter)
+        .build();
+    let tracer = provider.tracer(TRACER_NAME);
+    tracer.in_span("bindings.setup_tracing", |_| {});
+    provider.shutdown().map_err(|err| {
+        log::warn!(
+            "egglog rust tracing exporter warmup failed after {:?}: {err}",
+            start.elapsed()
+        );
+        err.to_string()
+    })?;
+    log::info!(
+        "warmed up egglog rust tracing exporter at {endpoint} in {:?}",
+        start.elapsed()
+    );
+    Ok(())
+}
+
+fn build_http_exporter(endpoint: &str) -> Result {
+    opentelemetry_otlp::SpanExporter::builder()
+        .with_http()
+        .with_endpoint(endpoint)
+        .build()
+        .map_err(|err| err.to_string())
+}
+
+fn extract_context_from_headers(
+    traceparent: Option<&str>,
+    tracestate: Option<&str>,
+) -> Option {
+    let mut headers = HashMap::new();
+    if let Some(traceparent) = traceparent {
+        headers.insert("traceparent".to_string(), traceparent.to_string());
+    }
+    if let Some(tracestate) = tracestate {
+        headers.insert("tracestate".to_string(), tracestate.to_string());
+    }
+    if headers.is_empty() {
+        return None;
+    }
+
+    let propagator = TraceContextPropagator::new();
+    let context = propagator.extract(&headers);
+    if context.span().span_context().is_valid() {
+        Some(context)
+    } else {
+        None
+    }
+}
diff --git a/uv.lock b/uv.lock
index 11b06639..d97e6708 100644
--- a/uv.lock
+++ b/uv.lock
@@ -608,6 +608,7 @@ dependencies = [
 { name = "black" },
 { name = "cloudpickle" },
 { name = "graphviz" },
+{ name = "opentelemetry-api" },
 { name = "typing-extensions" },
 ]
@@ -634,6 +635,8 @@ dev = [
 { name = "nbconvert" },
 { name = "numba" },
 { name = "numpy" },
+{ name = "opentelemetry-exporter-otlp-proto-http" },
+{ name = "opentelemetry-sdk" },
 { name = "pre-commit" },
 { name = "pydata-sphinx-theme" },
 { name = "pytest" },
@@ -673,6 +676,8 @@ test = [
 { name = "mypy" },
 { name = "numba" },
 { name = "numpy" },
+{ name = "opentelemetry-exporter-otlp-proto-http" },
+{ name = "opentelemetry-sdk" },
 { name = "pytest" },
 { name = "pytest-benchmark" },
 { name = "pytest-codspeed" },
@@ -684,6 +689,7 @@ test = [
 [package.dev-dependencies]
 dev = [
 { name = "py-spy" },
+{ name = "sympy" },
 ]
 
 [package.metadata]
@@ -711,6 +717,9 @@ requires-dist = [
 { name = "nbconvert", marker = "extra == 'docs'" },
 { name = "numba", marker = "extra == 'array'", specifier = ">=0.59.1" },
 { name = "numpy", marker = "extra == 'array'", specifier = ">2" },
+{ name = "opentelemetry-api" },
+{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
+{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
 { name = "pre-commit", marker = "extra == 'dev'" },
 { name = "pydata-sphinx-theme", marker = "extra == 'docs'" },
 { name = "pytest", marker = "extra == 'test'" },
@@ -729,7 +738,10 @@ requires-dist = [
 provides-extras = ["array", "dev", "docs", "test"]
 
 [package.metadata.requires-dev]
-dev = [{ name = "py-spy", specifier = ">=0.4.1" }]
+dev = [
+    { name = "py-spy", specifier = ">=0.4.1" },
+    { name = "sympy", specifier = ">=1.14.0" },
+]
 
 [[package]]
 name = "execnet"
@@ -835,6 +847,18 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014", size = 9121, upload-time = "2021-03-11T07:16:28.351Z" },
 ]
 
+[[package]]
+name = "googleapis-common-protos"
+version = "1.73.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "protobuf" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" },
+]
+
 [[package]]
 name = "graphviz"
 version = "0.21"
@@ -1806,6 +1830,15 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/7a/f0/8282d9641415e9e33df173516226b404d367a0fc55e1a60424a152913abc/mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d", size = 53481, upload-time = "2025-08-29T07:20:42.218Z" },
 ]
 
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
+]
+
 [[package]]
 name = "mypy"
 version = "1.18.2"
@@ -2084,6 +2117,88 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/95/8e/2844c3959ce9a63acc7c8e50881133d86666f0420bcde695e115ced0920f/numpy-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:81b3a59793523e552c4a96109dde028aa4448ae06ccac5a76ff6532a85558a7f", size = 12973130, upload-time = "2025-10-15T16:18:09.397Z" },
 ]
 
+[[package]]
+name = "opentelemetry-api"
+version = "1.40.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "importlib-metadata" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" },
+]
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-common"
+version = "1.40.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "opentelemetry-proto" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" },
+]
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-http"
+version = "1.40.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "googleapis-common-protos" },
+    { name = "opentelemetry-api" },
+    { name = "opentelemetry-exporter-otlp-proto-common" },
+    { name = "opentelemetry-proto" },
+    { name = "opentelemetry-sdk" },
+    { name = "requests" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" },
+]
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.40.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "protobuf" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, upload-time = "2026-03-04T14:17:31.194Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" },
+]
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.40.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "opentelemetry-api" },
+    { name = "opentelemetry-semantic-conventions" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" },
+]
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.61b0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "opentelemetry-api" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" },
+]
+
 [[package]]
 name = "overrides"
 version = "7.7.0"
@@ -2337,6 +2452,21 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" },
 ]
 
+[[package]]
+name = "protobuf"
+version = "6.33.5"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ba/25/7c72c307aafc96fa87062aa6291d9f7c94836e43214d43722e86037aac02/protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c", size = 444465, upload-time = "2026-01-29T21:51:33.494Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b1/79/af92d0a8369732b027e6d6084251dd8e782c685c72da161bd4a2e00fbabb/protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b", size = 425769, upload-time = "2026-01-29T21:51:21.751Z" },
+    { url = "https://files.pythonhosted.org/packages/55/75/bb9bc917d10e9ee13dee8607eb9ab963b7cf8be607c46e7862c748aa2af7/protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c", size = 437118, upload-time = "2026-01-29T21:51:24.022Z" },
+    { url = "https://files.pythonhosted.org/packages/a2/6b/e48dfc1191bc5b52950246275bf4089773e91cb5ba3592621723cdddca62/protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5", size = 427766, upload-time = "2026-01-29T21:51:25.413Z" },
+    { url = "https://files.pythonhosted.org/packages/4e/b1/c79468184310de09d75095ed1314b839eb2f72df71097db9d1404a1b2717/protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190", size = 324638, upload-time = "2026-01-29T21:51:26.423Z" },
+    { url = "https://files.pythonhosted.org/packages/c5/f5/65d838092fd01c44d16037953fd4c2cc851e783de9b8f02b27ec4ffd906f/protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd", size = 339411, upload-time = "2026-01-29T21:51:27.446Z" },
+    { url = "https://files.pythonhosted.org/packages/9b/53/a9443aa3ca9ba8724fdfa02dd1887c1bcd8e89556b715cfbacca6b63dbec/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0", size = 323465, upload-time = "2026-01-29T21:51:28.925Z" },
+    { url = "https://files.pythonhosted.org/packages/57/bf/2086963c69bdac3d7cff1cc7ff79b8ce5ea0bec6797a017e1be338a46248/protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02", size = 170687, upload-time = "2026-01-29T21:51:32.557Z" },
+]
+
 [[package]]
 name = "psutil"
 version = "7.1.3"
@@ -2510,26 +2640,26 @@ wheels = [
 [[package]]
 name = "pytest-codspeed"
-version = "4.2.0"
+version = "4.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
 { name = "cffi" },
 { name = "pytest" },
 { name = "rich" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/e2/e8/27fcbe6516a1c956614a4b61a7fccbf3791ea0b992e07416e8948184327d/pytest_codspeed-4.2.0.tar.gz", hash = "sha256:04b5d0bc5a1851ba1504d46bf9d7dbb355222a69f2cd440d54295db721b331f7", size = 113263, upload-time = "2025-10-24T09:02:55.704Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/98/ab/eca41967d11c95392829a8b4bfa9220a51cffc4a33ec4653358000356918/pytest_codspeed-4.3.0.tar.gz", hash = "sha256:5230d9d65f39063a313ed1820df775166227ec5c20a1122968f85653d5efee48", size = 124745, upload-time = "2026-02-09T15:23:34.745Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/b9/2d/f0083a2f14ecf008d961d40439a71da0ae0d568e5f8dc2fccd3e8a2ab3e4/pytest_codspeed-4.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2de87bde9fbc6fd53f0fd21dcf2599c89e0b8948d49f9bad224edce51c47e26b", size = 261960, upload-time = "2025-10-24T09:02:40.665Z" },
-    { url =
"https://files.pythonhosted.org/packages/5f/0c/1f514c553db4ea5a69dfbe2706734129acd0eca8d5101ec16f1dd00dbc0f/pytest_codspeed-4.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95aeb2479ca383f6b18e2cc9ebcd3b03ab184980a59a232aea6f370bbf59a1e3", size = 250808, upload-time = "2025-10-24T09:02:42.07Z" }, - { url = "https://files.pythonhosted.org/packages/81/04/479905bd6653bc981c0554fcce6df52d7ae1594e1eefd53e6cf31810ec7f/pytest_codspeed-4.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7d4fefbd4ae401e2c60f6be920a0be50eef0c3e4a1f0a1c83962efd45be38b39", size = 262084, upload-time = "2025-10-24T09:02:43.155Z" }, - { url = "https://files.pythonhosted.org/packages/d2/46/d6f345d7907bac6cbb6224bd697ecbc11cf7427acc9e843c3618f19e3476/pytest_codspeed-4.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:309b4227f57fcbb9df21e889ea1ae191d0d1cd8b903b698fdb9ea0461dbf1dfe", size = 251100, upload-time = "2025-10-24T09:02:44.168Z" }, - { url = "https://files.pythonhosted.org/packages/de/dc/e864f45e994a50390ff49792256f1bdcbf42f170e3bc0470ee1a7d2403f3/pytest_codspeed-4.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72aab8278452a6d020798b9e4f82780966adb00f80d27a25d1274272c54630d5", size = 262057, upload-time = "2025-10-24T09:02:45.791Z" }, - { url = "https://files.pythonhosted.org/packages/1d/1c/f1d2599784486879cf6579d8d94a3e22108f0e1f130033dab8feefd29249/pytest_codspeed-4.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:684fcd9491d810ded653a8d38de4835daa2d001645f4a23942862950664273f8", size = 251013, upload-time = "2025-10-24T09:02:46.937Z" }, - { url = 
"https://files.pythonhosted.org/packages/0c/fd/eafd24db5652a94b4d00fe9b309b607de81add0f55f073afb68a378a24b6/pytest_codspeed-4.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:50794dabea6ec90d4288904452051e2febace93e7edf4ca9f2bce8019dd8cd37", size = 262065, upload-time = "2025-10-24T09:02:48.018Z" }, - { url = "https://files.pythonhosted.org/packages/f9/14/8d9340d7dc0ae647991b28a396e16b3403e10def883cde90d6b663d3f7ec/pytest_codspeed-4.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0ebd87f2a99467a1cfd8e83492c4712976e43d353ee0b5f71cbb057f1393aca", size = 251057, upload-time = "2025-10-24T09:02:49.102Z" }, - { url = "https://files.pythonhosted.org/packages/4b/39/48cf6afbca55bc7c8c93c3d4ae926a1068bcce3f0241709db19b078d5418/pytest_codspeed-4.2.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dbbb2d61b85bef8fc7e2193f723f9ac2db388a48259d981bbce96319043e9830", size = 267983, upload-time = "2025-10-24T09:02:50.558Z" }, - { url = "https://files.pythonhosted.org/packages/33/86/4407341efb5dceb3e389635749ce1d670542d6ca148bd34f9d5334295faf/pytest_codspeed-4.2.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:748411c832147bfc85f805af78a1ab1684f52d08e14aabe22932bbe46c079a5f", size = 256732, upload-time = "2025-10-24T09:02:51.603Z" }, - { url = "https://files.pythonhosted.org/packages/25/0e/8cb71fd3ed4ed08c07aec1245aea7bc1b661ba55fd9c392db76f1978d453/pytest_codspeed-4.2.0-py3-none-any.whl", hash = "sha256:e81bbb45c130874ef99aca97929d72682733527a49f84239ba575b5cb843bab0", size = 113726, upload-time = "2025-10-24T09:02:54.785Z" }, + { url = "https://files.pythonhosted.org/packages/d9/15/ec0ac1f022173b3134c9638f2a35f21fbb3142c75da066d9e49e5a8bb4bd/pytest_codspeed-4.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:dbeff1eb2f2e36df088658b556fa993e6937bf64ffb07406de4db16fd2b26874", size = 347076, upload-time = "2026-02-09T15:23:19.989Z" }, + { url = "https://files.pythonhosted.org/packages/a5/e8/1fe375794ad02b7835f378a7bcfa8fbac9acadefe600a782a7c4a7064db7/pytest_codspeed-4.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:878aad5e4bb7b401ad8d82f3af5186030cd2bd0d0446782e10dabb9db8827466", size = 342215, upload-time = "2026-02-09T15:23:20.954Z" }, + { url = "https://files.pythonhosted.org/packages/09/58/50df94e9a78e1c77818a492c90557eeb1309af025120c9a21e6375950c52/pytest_codspeed-4.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:527a3a02eaa3e4d4583adc4ba2327eef79628f3e1c682a4b959439551a72588e", size = 347395, upload-time = "2026-02-09T15:23:21.986Z" }, + { url = "https://files.pythonhosted.org/packages/e4/56/7dfbd3eefd112a14e6fb65f9ff31dacf2e9c381cb94b27332b81d2b13f8d/pytest_codspeed-4.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9858c2a6e1f391d5696757e7b6e9484749a7376c46f8b4dd9aebf093479a9667", size = 342625, upload-time = "2026-02-09T15:23:23.035Z" }, + { url = "https://files.pythonhosted.org/packages/7f/53/7255f6a25bc56ff1745b254b21545dfe0be2268f5b91ce78f7e8a908f0ad/pytest_codspeed-4.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:34f2fd8497456eefbd325673f677ea80d93bb1bc08a578c1fa43a09cec3d1879", size = 347325, upload-time = "2026-02-09T15:23:23.998Z" }, + { url = "https://files.pythonhosted.org/packages/2e/f8/82ae570d8b9ad30f33c9d4002a7a1b2740de0e090540c69a28e4f711ebe2/pytest_codspeed-4.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:df6a36a2a9da1406bc50428437f657f0bd8c842ae54bee5fb3ad30e01d50c0f5", size = 342558, upload-time = "2026-02-09T15:23:25.656Z" }, + { url = 
"https://files.pythonhosted.org/packages/b3/e1/55cfe9474f91d174c7a4b04d257b5fc6d4d06f3d3680f2da672ee59ccc10/pytest_codspeed-4.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bec30f4fc9c4973143cd80f0d33fa780e9fa3e01e4dbe8cedf229e72f1212c62", size = 347383, upload-time = "2026-02-09T15:23:26.68Z" }, + { url = "https://files.pythonhosted.org/packages/7f/3b/8fd781d959bbe789b3de8ce4c50d5706a684a0df377147dfb27b200c20c1/pytest_codspeed-4.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6584e641cadf27d894ae90b87c50377232a97cbfd76ee0c7ecd0c056fa3f7f4", size = 342481, upload-time = "2026-02-09T15:23:27.686Z" }, + { url = "https://files.pythonhosted.org/packages/bb/0c/368045133c6effa2c665b1634b7b8a9c88b307f877fa31f1f8df47885b51/pytest_codspeed-4.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df0d1f6ea594f29b745c634d66d5f5f1caa1c3abd2af82fea49d656038e8fc77", size = 353680, upload-time = "2026-02-09T15:23:28.726Z" }, + { url = "https://files.pythonhosted.org/packages/59/21/e543abcd72244294e25ae88ec3a9311ade24d6913f8c8f42569d671700bc/pytest_codspeed-4.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a2f5bb6d8898bea7db45e3c8b916ee48e36905b929477bb511b79c5a3ccacda4", size = 347888, upload-time = "2026-02-09T15:23:30.443Z" }, + { url = "https://files.pythonhosted.org/packages/55/d9/b8a53c20cf5b41042c205bb9d36d37da00418d30fd1a94bf9eb147820720/pytest_codspeed-4.3.0-py3-none-any.whl", hash = "sha256:05baff2a61dc9f3e92b92b9c2ab5fb45d9b802438f5373073f5766a91319ed7a", size = 125224, upload-time = "2026-02-09T15:23:33.774Z" }, ] [[package]] @@ -2970,73 +3100,73 @@ wheels = [ [[package]] name = "scipy" -version = "1.16.3" +version = "1.17.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/5f/6f37d7439de1455ce9c5a556b8d1db0979f03a796c030bafdf08d35b7bf9/scipy-1.16.3-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97", size = 36630881, upload-time = "2025-10-28T17:31:47.104Z" }, - { url = "https://files.pythonhosted.org/packages/7c/89/d70e9f628749b7e4db2aa4cd89735502ff3f08f7b9b27d2e799485987cd9/scipy-1.16.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511", size = 28941012, upload-time = "2025-10-28T17:31:53.411Z" }, - { url = "https://files.pythonhosted.org/packages/a8/a8/0e7a9a6872a923505dbdf6bb93451edcac120363131c19013044a1e7cb0c/scipy-1.16.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005", size = 20931935, upload-time = "2025-10-28T17:31:57.361Z" }, - { url = "https://files.pythonhosted.org/packages/bd/c7/020fb72bd79ad798e4dbe53938543ecb96b3a9ac3fe274b7189e23e27353/scipy-1.16.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb", size = 23534466, upload-time = "2025-10-28T17:32:01.875Z" }, - { url = "https://files.pythonhosted.org/packages/be/a0/668c4609ce6dbf2f948e167836ccaf897f95fb63fa231c87da7558a374cd/scipy-1.16.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876", size = 33593618, upload-time = "2025-10-28T17:32:06.902Z" }, - { url = 
"https://files.pythonhosted.org/packages/ca/6e/8942461cf2636cdae083e3eb72622a7fbbfa5cf559c7d13ab250a5dbdc01/scipy-1.16.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2", size = 35899798, upload-time = "2025-10-28T17:32:12.665Z" }, - { url = "https://files.pythonhosted.org/packages/79/e8/d0f33590364cdbd67f28ce79368b373889faa4ee959588beddf6daef9abe/scipy-1.16.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e", size = 36226154, upload-time = "2025-10-28T17:32:17.961Z" }, - { url = "https://files.pythonhosted.org/packages/39/c1/1903de608c0c924a1749c590064e65810f8046e437aba6be365abc4f7557/scipy-1.16.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733", size = 38878540, upload-time = "2025-10-28T17:32:23.907Z" }, - { url = "https://files.pythonhosted.org/packages/f1/d0/22ec7036ba0b0a35bccb7f25ab407382ed34af0b111475eb301c16f8a2e5/scipy-1.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78", size = 38722107, upload-time = "2025-10-28T17:32:29.921Z" }, - { url = "https://files.pythonhosted.org/packages/7b/60/8a00e5a524bb3bf8898db1650d350f50e6cffb9d7a491c561dc9826c7515/scipy-1.16.3-cp311-cp311-win_arm64.whl", hash = "sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184", size = 25506272, upload-time = "2025-10-28T17:32:34.577Z" }, - { url = "https://files.pythonhosted.org/packages/40/41/5bf55c3f386b1643812f3a5674edf74b26184378ef0f3e7c7a09a7e2ca7f/scipy-1.16.3-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6", size = 36659043, upload-time = "2025-10-28T17:32:40.285Z" }, - { url = 
"https://files.pythonhosted.org/packages/1e/0f/65582071948cfc45d43e9870bf7ca5f0e0684e165d7c9ef4e50d783073eb/scipy-1.16.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07", size = 28898986, upload-time = "2025-10-28T17:32:45.325Z" }, - { url = "https://files.pythonhosted.org/packages/96/5e/36bf3f0ac298187d1ceadde9051177d6a4fe4d507e8f59067dc9dd39e650/scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9", size = 20889814, upload-time = "2025-10-28T17:32:49.277Z" }, - { url = "https://files.pythonhosted.org/packages/80/35/178d9d0c35394d5d5211bbff7ac4f2986c5488b59506fef9e1de13ea28d3/scipy-1.16.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686", size = 23565795, upload-time = "2025-10-28T17:32:53.337Z" }, - { url = "https://files.pythonhosted.org/packages/fa/46/d1146ff536d034d02f83c8afc3c4bab2eddb634624d6529a8512f3afc9da/scipy-1.16.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203", size = 33349476, upload-time = "2025-10-28T17:32:58.353Z" }, - { url = "https://files.pythonhosted.org/packages/79/2e/415119c9ab3e62249e18c2b082c07aff907a273741b3f8160414b0e9193c/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1", size = 35676692, upload-time = "2025-10-28T17:33:03.88Z" }, - { url = "https://files.pythonhosted.org/packages/27/82/df26e44da78bf8d2aeaf7566082260cfa15955a5a6e96e6a29935b64132f/scipy-1.16.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe", size = 36019345, upload-time = "2025-10-28T17:33:09.773Z" }, - { url = 
"https://files.pythonhosted.org/packages/82/31/006cbb4b648ba379a95c87262c2855cd0d09453e500937f78b30f02fa1cd/scipy-1.16.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70", size = 38678975, upload-time = "2025-10-28T17:33:15.809Z" }, - { url = "https://files.pythonhosted.org/packages/c2/7f/acbd28c97e990b421af7d6d6cd416358c9c293fc958b8529e0bd5d2a2a19/scipy-1.16.3-cp312-cp312-win_amd64.whl", hash = "sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc", size = 38555926, upload-time = "2025-10-28T17:33:21.388Z" }, - { url = "https://files.pythonhosted.org/packages/ce/69/c5c7807fd007dad4f48e0a5f2153038dc96e8725d3345b9ee31b2b7bed46/scipy-1.16.3-cp312-cp312-win_arm64.whl", hash = "sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2", size = 25463014, upload-time = "2025-10-28T17:33:25.975Z" }, - { url = "https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, - { url = "https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, - { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, - { url = 
"https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, - { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, - { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, - { url = "https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 36021756, upload-time = "2025-10-28T17:34:13.482Z" }, - { url = "https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, - { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, - { url = 
"https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, - { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, - { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, - { url = "https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, - { url = "https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, - { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, - { url = 
"https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, - { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, - { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, - { url = "https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, - { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, - { url = "https://files.pythonhosted.org/packages/99/f6/99b10fd70f2d864c1e29a28bbcaa0c6340f9d8518396542d9ea3b4aaae15/scipy-1.16.3-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6", size = 36606469, upload-time = "2025-10-28T17:36:08.741Z" }, - { url = 
"https://files.pythonhosted.org/packages/4d/74/043b54f2319f48ea940dd025779fa28ee360e6b95acb7cd188fad4391c6b/scipy-1.16.3-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657", size = 28872043, upload-time = "2025-10-28T17:36:16.599Z" }, - { url = "https://files.pythonhosted.org/packages/4d/e1/24b7e50cc1c4ee6ffbcb1f27fe9f4c8b40e7911675f6d2d20955f41c6348/scipy-1.16.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26", size = 20862952, upload-time = "2025-10-28T17:36:22.966Z" }, - { url = "https://files.pythonhosted.org/packages/dd/3a/3e8c01a4d742b730df368e063787c6808597ccb38636ed821d10b39ca51b/scipy-1.16.3-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc", size = 23508512, upload-time = "2025-10-28T17:36:29.731Z" }, - { url = "https://files.pythonhosted.org/packages/1f/60/c45a12b98ad591536bfe5330cb3cfe1850d7570259303563b1721564d458/scipy-1.16.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22", size = 33413639, upload-time = "2025-10-28T17:36:37.982Z" }, - { url = "https://files.pythonhosted.org/packages/71/bc/35957d88645476307e4839712642896689df442f3e53b0fa016ecf8a3357/scipy-1.16.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc", size = 35704729, upload-time = "2025-10-28T17:36:46.547Z" }, - { url = "https://files.pythonhosted.org/packages/3b/15/89105e659041b1ca11c386e9995aefacd513a78493656e57789f9d9eab61/scipy-1.16.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0", size = 36086251, upload-time = "2025-10-28T17:36:55.161Z" }, - { url = 
"https://files.pythonhosted.org/packages/1a/87/c0ea673ac9c6cc50b3da2196d860273bc7389aa69b64efa8493bdd25b093/scipy-1.16.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800", size = 38716681, upload-time = "2025-10-28T17:37:04.1Z" }, - { url = "https://files.pythonhosted.org/packages/91/06/837893227b043fb9b0d13e4bd7586982d8136cb249ffb3492930dab905b8/scipy-1.16.3-cp314-cp314-win_amd64.whl", hash = "sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d", size = 39358423, upload-time = "2025-10-28T17:38:20.005Z" }, - { url = "https://files.pythonhosted.org/packages/95/03/28bce0355e4d34a7c034727505a02d19548549e190bedd13a721e35380b7/scipy-1.16.3-cp314-cp314-win_arm64.whl", hash = "sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f", size = 26135027, upload-time = "2025-10-28T17:38:24.966Z" }, - { url = "https://files.pythonhosted.org/packages/b2/6f/69f1e2b682efe9de8fe9f91040f0cd32f13cfccba690512ba4c582b0bc29/scipy-1.16.3-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c", size = 37028379, upload-time = "2025-10-28T17:37:14.061Z" }, - { url = "https://files.pythonhosted.org/packages/7c/2d/e826f31624a5ebbab1cd93d30fd74349914753076ed0593e1d56a98c4fb4/scipy-1.16.3-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40", size = 29400052, upload-time = "2025-10-28T17:37:21.709Z" }, - { url = "https://files.pythonhosted.org/packages/69/27/d24feb80155f41fd1f156bf144e7e049b4e2b9dd06261a242905e3bc7a03/scipy-1.16.3-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d", size = 21391183, upload-time = "2025-10-28T17:37:29.559Z" }, - { url = 
"https://files.pythonhosted.org/packages/f8/d3/1b229e433074c5738a24277eca520a2319aac7465eea7310ea6ae0e98ae2/scipy-1.16.3-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa", size = 23930174, upload-time = "2025-10-28T17:37:36.306Z" }, - { url = "https://files.pythonhosted.org/packages/16/9d/d9e148b0ec680c0f042581a2be79a28a7ab66c0c4946697f9e7553ead337/scipy-1.16.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8", size = 33497852, upload-time = "2025-10-28T17:37:42.228Z" }, - { url = "https://files.pythonhosted.org/packages/2f/22/4e5f7561e4f98b7bea63cf3fd7934bff1e3182e9f1626b089a679914d5c8/scipy-1.16.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353", size = 35798595, upload-time = "2025-10-28T17:37:48.102Z" }, - { url = "https://files.pythonhosted.org/packages/83/42/6644d714c179429fc7196857866f219fef25238319b650bb32dde7bf7a48/scipy-1.16.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146", size = 36186269, upload-time = "2025-10-28T17:37:53.72Z" }, - { url = "https://files.pythonhosted.org/packages/ac/70/64b4d7ca92f9cf2e6fc6aaa2eecf80bb9b6b985043a9583f32f8177ea122/scipy-1.16.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d", size = 38802779, upload-time = "2025-10-28T17:37:59.393Z" }, - { url = "https://files.pythonhosted.org/packages/61/82/8d0e39f62764cce5ffd5284131e109f07cf8955aef9ab8ed4e3aa5e30539/scipy-1.16.3-cp314-cp314t-win_amd64.whl", hash = "sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7", size = 39471128, upload-time = "2025-10-28T17:38:05.259Z" }, - { url = 
"https://files.pythonhosted.org/packages/64/47/a494741db7280eae6dc033510c319e34d42dd41b7ac0c7ead39354d1a2b5/scipy-1.16.3-cp314-cp314t-win_arm64.whl", hash = "sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562", size = 26464127, upload-time = "2025-10-28T17:38:11.34Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/75/b4ce781849931fef6fd529afa6b63711d5a733065722d0c3e2724af9e40a/scipy-1.17.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1f95b894f13729334fb990162e911c9e5dc1ab390c58aa6cbecb389c5b5e28ec", size = 31613675, upload-time = "2026-02-23T00:16:00.13Z" }, + { url = "https://files.pythonhosted.org/packages/f7/58/bccc2861b305abdd1b8663d6130c0b3d7cc22e8d86663edbc8401bfd40d4/scipy-1.17.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:e18f12c6b0bc5a592ed23d3f7b891f68fd7f8241d69b7883769eb5d5dfb52696", size = 28162057, upload-time = "2026-02-23T00:16:09.456Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ee/18146b7757ed4976276b9c9819108adbc73c5aad636e5353e20746b73069/scipy-1.17.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a3472cfbca0a54177d0faa68f697d8ba4c80bbdc19908c3465556d9f7efce9ee", size = 20334032, upload-time = "2026-02-23T00:16:17.358Z" }, + { url = "https://files.pythonhosted.org/packages/ec/e6/cef1cf3557f0c54954198554a10016b6a03b2ec9e22a4e1df734936bd99c/scipy-1.17.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:766e0dc5a616d026a3a1cffa379af959671729083882f50307e18175797b3dfd", size = 22709533, upload-time = "2026-02-23T00:16:25.791Z" }, + { url = 
"https://files.pythonhosted.org/packages/4d/60/8804678875fc59362b0fb759ab3ecce1f09c10a735680318ac30da8cd76b/scipy-1.17.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:744b2bf3640d907b79f3fd7874efe432d1cf171ee721243e350f55234b4cec4c", size = 33062057, upload-time = "2026-02-23T00:16:36.931Z" }, + { url = "https://files.pythonhosted.org/packages/09/7d/af933f0f6e0767995b4e2d705a0665e454d1c19402aa7e895de3951ebb04/scipy-1.17.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43af8d1f3bea642559019edfe64e9b11192a8978efbd1539d7bc2aaa23d92de4", size = 35349300, upload-time = "2026-02-23T00:16:49.108Z" }, + { url = "https://files.pythonhosted.org/packages/b4/3d/7ccbbdcbb54c8fdc20d3b6930137c782a163fa626f0aef920349873421ba/scipy-1.17.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cd96a1898c0a47be4520327e01f874acfd61fb48a9420f8aa9f6483412ffa444", size = 35127333, upload-time = "2026-02-23T00:17:01.293Z" }, + { url = "https://files.pythonhosted.org/packages/e8/19/f926cb11c42b15ba08e3a71e376d816ac08614f769b4f47e06c3580c836a/scipy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4eb6c25dd62ee8d5edf68a8e1c171dd71c292fdae95d8aeb3dd7d7de4c364082", size = 37741314, upload-time = "2026-02-23T00:17:12.576Z" }, + { url = "https://files.pythonhosted.org/packages/95/da/0d1df507cf574b3f224ccc3d45244c9a1d732c81dcb26b1e8a766ae271a8/scipy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:d30e57c72013c2a4fe441c2fcb8e77b14e152ad48b5464858e07e2ad9fbfceff", size = 36607512, upload-time = "2026-02-23T00:17:23.424Z" }, + { url = "https://files.pythonhosted.org/packages/68/7f/bdd79ceaad24b671543ffe0ef61ed8e659440eb683b66f033454dcee90eb/scipy-1.17.1-cp311-cp311-win_arm64.whl", hash = "sha256:9ecb4efb1cd6e8c4afea0daa91a87fbddbce1b99d2895d151596716c0b2e859d", size = 24599248, upload-time = "2026-02-23T00:17:34.561Z" }, + { url = 
"https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" }, + { url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" }, + { url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" }, + { url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" }, + { url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" }, + { url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" }, + { url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" }, + { url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" }, + { url = "https://files.pythonhosted.org/packages/76/27/07ee1b57b65e92645f219b37148a7e7928b82e2b5dbeccecb4dff7c64f0b/scipy-1.17.1-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:5e3c5c011904115f88a39308379c17f91546f77c1667cea98739fe0fccea804c", size = 31590199, upload-time = "2026-02-23T00:19:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6fac755ca3d2c3edcb22f479fceaa241704111414831ddd3bc6056e18516892f", size = 28154001, upload-time = "2026-02-23T00:19:22.241Z" }, + { url = 
"https://files.pythonhosted.org/packages/5b/58/3ce96251560107b381cbd6e8413c483bbb1228a6b919fa8652b0d4090e7f/scipy-1.17.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7ff200bf9d24f2e4d5dc6ee8c3ac64d739d3a89e2326ba68aaf6c4a2b838fd7d", size = 20325719, upload-time = "2026-02-23T00:19:26.329Z" }, + { url = "https://files.pythonhosted.org/packages/b2/83/15087d945e0e4d48ce2377498abf5ad171ae013232ae31d06f336e64c999/scipy-1.17.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:4b400bdc6f79fa02a4d86640310dde87a21fba0c979efff5248908c6f15fad1b", size = 22683595, upload-time = "2026-02-23T00:19:30.304Z" }, + { url = "https://files.pythonhosted.org/packages/b4/e0/e58fbde4a1a594c8be8114eb4aac1a55bcd6587047efc18a61eb1f5c0d30/scipy-1.17.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b64ca7d4aee0102a97f3ba22124052b4bd2152522355073580bf4845e2550b6", size = 32896429, upload-time = "2026-02-23T00:19:35.536Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:581b2264fc0aa555f3f435a5944da7504ea3a065d7029ad60e7c3d1ae09c5464", size = 35203952, upload-time = "2026-02-23T00:19:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a5/9afd17de24f657fdfe4df9a3f1ea049b39aef7c06000c13db1530d81ccca/scipy-1.17.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:beeda3d4ae615106d7094f7e7cef6218392e4465cc95d25f900bebabfded0950", size = 34979063, upload-time = "2026-02-23T00:19:47.547Z" }, + { url = "https://files.pythonhosted.org/packages/8b/13/88b1d2384b424bf7c924f2038c1c409f8d88bb2a8d49d097861dd64a57b2/scipy-1.17.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6609bc224e9568f65064cfa72edc0f24ee6655b47575954ec6339534b2798369", size = 37598449, upload-time = "2026-02-23T00:19:53.238Z" }, + { url = 
"https://files.pythonhosted.org/packages/35/e5/d6d0e51fc888f692a35134336866341c08655d92614f492c6860dc45bb2c/scipy-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:37425bc9175607b0268f493d79a292c39f9d001a357bebb6b88fdfaff13f6448", size = 36510943, upload-time = "2026-02-23T00:20:50.89Z" }, + { url = "https://files.pythonhosted.org/packages/2a/fd/3be73c564e2a01e690e19cc618811540ba5354c67c8680dce3281123fb79/scipy-1.17.1-cp313-cp313-win_arm64.whl", hash = "sha256:5cf36e801231b6a2059bf354720274b7558746f3b1a4efb43fcf557ccd484a87", size = 24545621, upload-time = "2026-02-23T00:20:55.871Z" }, + { url = "https://files.pythonhosted.org/packages/6f/6b/17787db8b8114933a66f9dcc479a8272e4b4da75fe03b0c282f7b0ade8cd/scipy-1.17.1-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:d59c30000a16d8edc7e64152e30220bfbd724c9bbb08368c054e24c651314f0a", size = 31936708, upload-time = "2026-02-23T00:19:58.694Z" }, + { url = "https://files.pythonhosted.org/packages/38/2e/524405c2b6392765ab1e2b722a41d5da33dc5c7b7278184a8ad29b6cb206/scipy-1.17.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:010f4333c96c9bb1a4516269e33cb5917b08ef2166d5556ca2fd9f082a9e6ea0", size = 28570135, upload-time = "2026-02-23T00:20:03.934Z" }, + { url = "https://files.pythonhosted.org/packages/fd/c3/5bd7199f4ea8556c0c8e39f04ccb014ac37d1468e6cfa6a95c6b3562b76e/scipy-1.17.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:2ceb2d3e01c5f1d83c4189737a42d9cb2fc38a6eeed225e7515eef71ad301dce", size = 20741977, upload-time = "2026-02-23T00:20:07.935Z" }, + { url = "https://files.pythonhosted.org/packages/d9/b8/8ccd9b766ad14c78386599708eb745f6b44f08400a5fd0ade7cf89b6fc93/scipy-1.17.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:844e165636711ef41f80b4103ed234181646b98a53c8f05da12ca5ca289134f6", size = 23029601, upload-time = "2026-02-23T00:20:12.161Z" }, + { url = 
"https://files.pythonhosted.org/packages/6d/a0/3cb6f4d2fb3e17428ad2880333cac878909ad1a89f678527b5328b93c1d4/scipy-1.17.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:158dd96d2207e21c966063e1635b1063cd7787b627b6f07305315dd73d9c679e", size = 33019667, upload-time = "2026-02-23T00:20:17.208Z" }, + { url = "https://files.pythonhosted.org/packages/f3/c3/2d834a5ac7bf3a0c806ad1508efc02dda3c8c61472a56132d7894c312dea/scipy-1.17.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74cbb80d93260fe2ffa334efa24cb8f2f0f622a9b9febf8b483c0b865bfb3475", size = 35264159, upload-time = "2026-02-23T00:20:23.087Z" }, + { url = "https://files.pythonhosted.org/packages/4d/77/d3ed4becfdbd217c52062fafe35a72388d1bd82c2d0ba5ca19d6fcc93e11/scipy-1.17.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:dbc12c9f3d185f5c737d801da555fb74b3dcfa1a50b66a1a93e09190f41fab50", size = 35102771, upload-time = "2026-02-23T00:20:28.636Z" }, + { url = "https://files.pythonhosted.org/packages/bd/12/d19da97efde68ca1ee5538bb261d5d2c062f0c055575128f11a2730e3ac1/scipy-1.17.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:94055a11dfebe37c656e70317e1996dc197e1a15bbcc351bcdd4610e128fe1ca", size = 37665910, upload-time = "2026-02-23T00:20:34.743Z" }, + { url = "https://files.pythonhosted.org/packages/06/1c/1172a88d507a4baaf72c5a09bb6c018fe2ae0ab622e5830b703a46cc9e44/scipy-1.17.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e30bdeaa5deed6bc27b4cc490823cd0347d7dae09119b8803ae576ea0ce52e4c", size = 36562980, upload-time = "2026-02-23T00:20:40.575Z" }, + { url = "https://files.pythonhosted.org/packages/70/b0/eb757336e5a76dfa7911f63252e3b7d1de00935d7705cf772db5b45ec238/scipy-1.17.1-cp313-cp313t-win_arm64.whl", hash = "sha256:a720477885a9d2411f94a93d16f9d89bad0f28ca23c3f8daa521e2dcc3f44d49", size = 24856543, upload-time = "2026-02-23T00:20:45.313Z" }, + { url = 
"https://files.pythonhosted.org/packages/cf/83/333afb452af6f0fd70414dc04f898647ee1423979ce02efa75c3b0f2c28e/scipy-1.17.1-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:a48a72c77a310327f6a3a920092fa2b8fd03d7deaa60f093038f22d98e096717", size = 31584510, upload-time = "2026-02-23T00:21:01.015Z" }, + { url = "https://files.pythonhosted.org/packages/ed/a6/d05a85fd51daeb2e4ea71d102f15b34fedca8e931af02594193ae4fd25f7/scipy-1.17.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:45abad819184f07240d8a696117a7aacd39787af9e0b719d00285549ed19a1e9", size = 28170131, upload-time = "2026-02-23T00:21:05.888Z" }, + { url = "https://files.pythonhosted.org/packages/db/7b/8624a203326675d7746a254083a187398090a179335b2e4a20e2ddc46e83/scipy-1.17.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3fd1fcdab3ea951b610dc4cef356d416d5802991e7e32b5254828d342f7b7e0b", size = 20342032, upload-time = "2026-02-23T00:21:09.904Z" }, + { url = "https://files.pythonhosted.org/packages/c9/35/2c342897c00775d688d8ff3987aced3426858fd89d5a0e26e020b660b301/scipy-1.17.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7bdf2da170b67fdf10bca777614b1c7d96ae3ca5794fd9587dce41eb2966e866", size = 22678766, upload-time = "2026-02-23T00:21:14.313Z" }, + { url = "https://files.pythonhosted.org/packages/ef/f2/7cdb8eb308a1a6ae1e19f945913c82c23c0c442a462a46480ce487fdc0ac/scipy-1.17.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adb2642e060a6549c343603a3851ba76ef0b74cc8c079a9a58121c7ec9fe2350", size = 32957007, upload-time = "2026-02-23T00:21:19.663Z" }, + { url = "https://files.pythonhosted.org/packages/0b/2e/7eea398450457ecb54e18e9d10110993fa65561c4f3add5e8eccd2b9cd41/scipy-1.17.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eee2cfda04c00a857206a4330f0c5e3e56535494e30ca445eb19ec624ae75118", size = 35221333, upload-time = "2026-02-23T00:21:25.278Z" }, + { url = 
"https://files.pythonhosted.org/packages/d9/77/5b8509d03b77f093a0d52e606d3c4f79e8b06d1d38c441dacb1e26cacf46/scipy-1.17.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d2650c1fb97e184d12d8ba010493ee7b322864f7d3d00d3f9bb97d9c21de4068", size = 35042066, upload-time = "2026-02-23T00:21:31.358Z" }, + { url = "https://files.pythonhosted.org/packages/f9/df/18f80fb99df40b4070328d5ae5c596f2f00fffb50167e31439e932f29e7d/scipy-1.17.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:08b900519463543aa604a06bec02461558a6e1cef8fdbb8098f77a48a83c8118", size = 37612763, upload-time = "2026-02-23T00:21:37.247Z" }, + { url = "https://files.pythonhosted.org/packages/4b/39/f0e8ea762a764a9dc52aa7dabcfad51a354819de1f0d4652b6a1122424d6/scipy-1.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:3877ac408e14da24a6196de0ddcace62092bfc12a83823e92e49e40747e52c19", size = 37290984, upload-time = "2026-02-23T00:22:35.023Z" }, + { url = "https://files.pythonhosted.org/packages/7c/56/fe201e3b0f93d1a8bcf75d3379affd228a63d7e2d80ab45467a74b494947/scipy-1.17.1-cp314-cp314-win_arm64.whl", hash = "sha256:f8885db0bc2bffa59d5c1b72fad7a6a92d3e80e7257f967dd81abb553a90d293", size = 25192877, upload-time = "2026-02-23T00:22:39.798Z" }, + { url = "https://files.pythonhosted.org/packages/96/ad/f8c414e121f82e02d76f310f16db9899c4fcde36710329502a6b2a3c0392/scipy-1.17.1-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:1cc682cea2ae55524432f3cdff9e9a3be743d52a7443d0cba9017c23c87ae2f6", size = 31949750, upload-time = "2026-02-23T00:21:42.289Z" }, + { url = "https://files.pythonhosted.org/packages/7c/b0/c741e8865d61b67c81e255f4f0a832846c064e426636cd7de84e74d209be/scipy-1.17.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:2040ad4d1795a0ae89bfc7e8429677f365d45aa9fd5e4587cf1ea737f927b4a1", size = 28585858, upload-time = "2026-02-23T00:21:47.706Z" }, + { url = 
"https://files.pythonhosted.org/packages/ed/1b/3985219c6177866628fa7c2595bfd23f193ceebbe472c98a08824b9466ff/scipy-1.17.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:131f5aaea57602008f9822e2115029b55d4b5f7c070287699fe45c661d051e39", size = 20757723, upload-time = "2026-02-23T00:21:52.039Z" }, + { url = "https://files.pythonhosted.org/packages/c0/19/2a04aa25050d656d6f7b9e7b685cc83d6957fb101665bfd9369ca6534563/scipy-1.17.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9cdc1a2fcfd5c52cfb3045feb399f7b3ce822abdde3a193a6b9a60b3cb5854ca", size = 23043098, upload-time = "2026-02-23T00:21:56.185Z" }, + { url = "https://files.pythonhosted.org/packages/86/f1/3383beb9b5d0dbddd030335bf8a8b32d4317185efe495374f134d8be6cce/scipy-1.17.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e3dcd57ab780c741fde8dc68619de988b966db759a3c3152e8e9142c26295ad", size = 33030397, upload-time = "2026-02-23T00:22:01.404Z" }, + { url = "https://files.pythonhosted.org/packages/41/68/8f21e8a65a5a03f25a79165ec9d2b28c00e66dc80546cf5eb803aeeff35b/scipy-1.17.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9956e4d4f4a301ebf6cde39850333a6b6110799d470dbbb1e25326ac447f52a", size = 35281163, upload-time = "2026-02-23T00:22:07.024Z" }, + { url = "https://files.pythonhosted.org/packages/84/8d/c8a5e19479554007a5632ed7529e665c315ae7492b4f946b0deb39870e39/scipy-1.17.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:a4328d245944d09fd639771de275701ccadf5f781ba0ff092ad141e017eccda4", size = 35116291, upload-time = "2026-02-23T00:22:12.585Z" }, + { url = "https://files.pythonhosted.org/packages/52/52/e57eceff0e342a1f50e274264ed47497b59e6a4e3118808ee58ddda7b74a/scipy-1.17.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a77cbd07b940d326d39a1d1b37817e2ee4d79cb30e7338f3d0cddffae70fcaa2", size = 37682317, upload-time = "2026-02-23T00:22:18.513Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/2f/b29eafe4a3fbc3d6de9662b36e028d5f039e72d345e05c250e121a230dd4/scipy-1.17.1-cp314-cp314t-win_amd64.whl", hash = "sha256:eb092099205ef62cd1782b006658db09e2fed75bffcae7cc0d44052d8aa0f484", size = 37345327, upload-time = "2026-02-23T00:22:24.442Z" }, + { url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" }, ] [[package]] @@ -3278,6 +3408,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + [[package]] name = "syrupy" version = "5.0.0"