diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..2250cb6 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,48 @@ +name: Publish to crates.io + +on: + release: + types: [created] + +jobs: + publish: + runs-on: ubuntu-latest + + environment: + name: crates-io + url: https://crates.io/crates/libpgfmt + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Install stable Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry and build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-release-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-release- + + - name: Run tests + run: cargo test + + - name: Verify version matches tag + run: | + TAG="${GITHUB_REF#refs/tags/v}" + CARGO_VERSION=$(grep '^version' Cargo.toml | head -1 | sed 's/.*"\(.*\)"/\1/') + if [ "$TAG" != "$CARGO_VERSION" ]; then + echo "Tag v$TAG does not match Cargo.toml version $CARGO_VERSION" + exit 1 + fi + + - name: Publish to crates.io + run: cargo publish + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml new file mode 100644 index 0000000..42e747c --- /dev/null +++ b/.github/workflows/testing.yaml @@ -0,0 +1,46 @@ +name: Testing +on: + pull_request: + push: + branches: ["*"] + paths-ignore: + - "*.md" + tags-ignore: ["*"] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + rust: [stable, nightly] + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Install Rust ${{ matrix.rust }} + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + components: clippy, rustfmt + + - name: Cache cargo registry and build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os 
}}-cargo-${{ matrix.rust }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-${{ matrix.rust }}- + + - name: Check formatting + run: cargo fmt --check + + - name: Run clippy + run: cargo clippy -- -D warnings + + - name: Run tests + run: cargo test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8ac5135 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: local + hooks: + - id: cargo-fmt + name: cargo fmt + entry: cargo fmt --check + language: system + types: [rust] + pass_filenames: false + + - id: cargo-clippy + name: cargo clippy + entry: cargo clippy -- -D warnings + language: system + types: [rust] + pass_filenames: false + + - id: cargo-test + name: cargo test + entry: cargo test + language: system + types: [rust] + pass_filenames: false + stages: [pre-push] diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..f1b11c8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,18 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +libpgfmt is a Rust library for formatting PostgreSQL-specific SQL and PL/pgSQL. + +## Build Commands + +```sh +cargo build +cargo test +cargo test TESTNAME # run a single test +cargo clippy # lint +cargo fmt --check # check formatting +cargo fmt # auto-format +``` diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..e8dec8c --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,248 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "cc" +version = "1.2.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "libpgfmt" +version = "1.0.0" +dependencies = [ + "pretty_assertions", + "tree-sitter", + "tree-sitter-postgres", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "indexmap", + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tree-sitter" +version = "0.26.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7a6592b1aec0109df37b6bafea77eb4e61466e37b0a5a98bef4f89bfb81b7a2" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "serde_json", + "streaming-iterator", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" + +[[package]] +name = "tree-sitter-postgres" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "b50731216f0c2594ce8e3070a85374412801f1d221d3c5796f2233bee4543388" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7fd3905 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "libpgfmt" +version = "1.0.0" +edition = "2024" +rust-version = "1.88" +description = "A Rust library for formatting PostgreSQL SQL and PL/pgSQL" +license = "BSD-3-Clause" +repository = "https://github.com/gmr/libpgfmt" +keywords = ["postgresql", "sql", "formatter", "plpgsql"] +categories = ["text-processing", "development-tools"] + +[dependencies] +tree-sitter = "0.26" +tree-sitter-postgres = "0.1" + +[dev-dependencies] +pretty_assertions = "1" diff --git a/Justfile b/Justfile new file mode 100644 index 0000000..9f1db3d --- /dev/null +++ b/Justfile @@ -0,0 +1,64 @@ +# Default recipe: run checks +default: check + +# Run all checks (format, lint, test) +check: fmt-check lint test + +# Build the library +build: + cargo build + +# Run tests +test: + cargo test + +# Run clippy lints +lint: + cargo clippy -- -D warnings + +# Check formatting +fmt-check: + cargo fmt --check + +# Auto-format code +fmt: + cargo fmt + +# Run all checks then build in release mode +release-build: check + cargo build --release + +# Set the release version in Cargo.toml +set-version version: + #!/usr/bin/env bash + set -euo pipefail + 
current=$(grep '^version' Cargo.toml | head -1 | sed 's/.*"\(.*\)"/\1/') + if [ "{{version}}" = "$current" ]; then + echo "Version is already {{version}}" + exit 1 + fi + # Use a temp file for portability (BSD sed -i requires arg, GNU doesn't) + tmp=$(mktemp) + sed 's/^version = ".*"/version = "{{version}}"/' Cargo.toml > "$tmp" + mv "$tmp" Cargo.toml + cargo check + echo "Updated version: $current -> {{version}}" + +# Tag a release (sets version, commits, tags, pushes) +release version: (set-version version) + git add Cargo.toml Cargo.lock + git commit -m "Release v{{version}}" + git tag -a "v{{version}}" -m "v{{version}}" + git push origin main --tags + +# Publish to crates.io (dry run) +publish-dry: + cargo publish --dry-run + +# Publish to crates.io +publish: + cargo publish + +# Clean build artifacts +clean: + cargo clean diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..b97eee9 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,21 @@ +use std::fmt; + +/// Errors that can occur during formatting. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FormatError { + /// The tree-sitter parser could not be initialized. + Parser(String), + /// The input SQL contains a syntax error. + Syntax(String), +} + +impl fmt::Display for FormatError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FormatError::Parser(msg) => write!(f, "Parser error: {msg}"), + FormatError::Syntax(msg) => write!(f, "Syntax error: {msg}"), + } + } +} + +impl std::error::Error for FormatError {} diff --git a/src/formatter/expr.rs b/src/formatter/expr.rs new file mode 100644 index 0000000..7997716 --- /dev/null +++ b/src/formatter/expr.rs @@ -0,0 +1,1137 @@ +/// Expression formatting — converts expression AST nodes to inline SQL text. +use crate::node_helpers::{NodeExt, flatten_list}; +use tree_sitter::Node; + +use super::Formatter; + +/// SQL built-in aggregate and function names that should follow keyword casing. 
+/// User-defined function names are preserved as-is. +const SQL_BUILTIN_FUNCTIONS: &[&str] = &[ + "abs", + "avg", + "array_agg", + "bit_and", + "bit_or", + "bool_and", + "bool_or", + "cardinality", + "cast", + "ceil", + "ceiling", + "char_length", + "character_length", + "coalesce", + "concat", + "concat_ws", + "convert", + "corr", + "count", + "covar_pop", + "covar_samp", + "cume_dist", + "current_date", + "current_time", + "current_timestamp", + "date_part", + "date_trunc", + "dense_rank", + "every", + "exists", + "exp", + "extract", + "first_value", + "floor", + "format", + "generate_series", + "greatest", + "json_agg", + "json_object_agg", + "jsonb_agg", + "jsonb_object_agg", + "lag", + "last_value", + "lead", + "least", + "left", + "length", + "ln", + "localtime", + "localtimestamp", + "log", + "lower", + "lpad", + "ltrim", + "max", + "min", + "mod", + "now", + "nth_value", + "ntile", + "nullif", + "octet_length", + "overlay", + "percent_rank", + "position", + "power", + "rank", + "regexp_matches", + "regexp_replace", + "regexp_split_to_array", + "regexp_split_to_table", + "repeat", + "replace", + "reverse", + "right", + "round", + "row_number", + "rpad", + "rtrim", + "sign", + "split_part", + "sqrt", + "stddev", + "stddev_pop", + "stddev_samp", + "string_agg", + "strpos", + "substr", + "substring", + "sum", + "to_char", + "to_date", + "to_number", + "to_timestamp", + "translate", + "trim", + "trunc", + "unnest", + "upper", + "var_pop", + "var_samp", + "variance", + "width_bucket", + "xmlagg", +]; + +fn is_sql_builtin_function(name: &str) -> bool { + let lower = name.to_lowercase(); + SQL_BUILTIN_FUNCTIONS.contains(&lower.as_str()) +} + +/// PostgreSQL internal type name → standard SQL type name mapping. 
+const PG_TYPE_MAP: &[(&str, &str)] = &[ + ("bigint", "BIGINT"), + ("bigserial", "BIGSERIAL"), + ("bool", "BOOLEAN"), + ("boolean", "BOOLEAN"), + ("bytea", "BYTEA"), + ("char", "CHAR"), + ("character", "CHAR"), + ("character varying", "VARCHAR"), + ("date", "DATE"), + ("double precision", "DOUBLE PRECISION"), + ("float4", "REAL"), + ("float8", "DOUBLE PRECISION"), + ("int", "INTEGER"), + ("int2", "SMALLINT"), + ("int4", "INTEGER"), + ("int8", "BIGINT"), + ("integer", "INTEGER"), + ("interval", "INTERVAL"), + ("json", "JSON"), + ("jsonb", "JSONB"), + ("name", "NAME"), + ("numeric", "NUMERIC"), + ("oid", "OID"), + ("real", "REAL"), + ("serial", "SERIAL"), + ("serial4", "SERIAL"), + ("serial8", "BIGSERIAL"), + ("smallint", "SMALLINT"), + ("smallserial", "SMALLSERIAL"), + ("text", "TEXT"), + ("time", "TIME"), + ("timestamp", "TIMESTAMP"), + ("timestamptz", "TIMESTAMP WITH TIME ZONE"), + ("timetz", "TIME WITH TIME ZONE"), + ("trigger", "TRIGGER"), + ("uuid", "UUID"), + ("varchar", "VARCHAR"), + ("xml", "XML"), +]; + +impl<'a> Formatter<'a> { + /// Format any expression node into inline SQL text. 
+ pub(crate) fn format_expr(&self, node: Node<'a>) -> String { + match node.kind() { + "a_expr" => self.format_a_expr(node), + "a_expr_prec" => self.format_a_expr_prec(node), + "c_expr" => self.format_c_expr(node), + "columnref" => self.format_columnref(node), + "AexprConst" => self.format_const(node), + "func_expr" | "func_application" => self.format_func(node), + "case_expr" => self.format_case_expr(node), + "target_el" => self.format_target_el(node), + "Typename" => self.format_typename(node), + "SimpleTypename" => self.format_simple_typename(node), + "select_with_parens" => self.format_select_with_parens(node), + "Sconst" | "string_literal" => self.format_string_const(node), + "Iconst" | "integer_literal" => self.text(node).to_string(), + "Fconst" | "float_literal" => self.text(node).to_string(), + "sortby" => self.format_sortby(node), + "identifier" => self.text(node).to_string(), + "type_function_name" => self.format_first_named_child(node), + "ColId" => self.format_col_id(node), + "ColLabel" => self.format_first_named_child(node), + "qualified_name" => self.format_qualified_name(node), + "indirection" => self.format_indirection(node), + "indirection_el" => self.format_indirection_el(node), + "attr_name" => self.format_first_named_child(node), + "relation_expr" => self.format_relation_expr(node), + "func_name" => self.format_func_name(node), + "Numeric" | "GenericType" => self.format_typename_inner(node), + // Unreserved keywords used as identifiers should preserve casing. 
+ "unreserved_keyword" => self.text(node).to_string(), + "expr_list" => { + let items = flatten_list(node, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + formatted.join(", ") + } + "func_arg_expr" => self.format_first_named_child(node), + "opt_alias_clause" | "alias_clause" => self.format_alias(node), + "group_by_item" => self.format_first_named_child(node), + "ERROR" => self.text(node).to_string(), + _ if node.kind().starts_with("kw_") => self.format_keyword_node(node), + _ => { + // Fallback: reconstruct from children or use source text. + if node.named_child_count() == 0 { + self.text(node).to_string() + } else { + self.format_first_named_child(node) + } + } + } + } + + /// Format an a_expr node (the main expression type with operators). + fn format_a_expr(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + // Check if this a_expr contains an inline expr_list (e.g., IN (...)). + // If so, skip unnamed parens since we format them with the expr_list. 
+ let has_expr_list = node.find_child("expr_list").is_some(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "a_expr_prec" | "a_expr" | "c_expr" => { + parts.push(self.format_expr(child)); + } + "kw_and" => parts.push(self.kw("AND")), + "kw_or" => parts.push(self.kw("OR")), + "kw_not" => parts.push(self.kw("NOT")), + "kw_is" => parts.push(self.kw("IS")), + "kw_null" => parts.push(self.kw("NULL")), + "kw_true" => parts.push(self.kw("TRUE")), + "kw_false" => parts.push(self.kw("FALSE")), + "kw_in" => parts.push(self.kw("IN")), + "kw_any" => parts.push(self.kw("ANY")), + "kw_all" => parts.push(self.kw("ALL")), + "kw_some" => parts.push(self.kw("SOME")), + "kw_like" => parts.push(self.kw("LIKE")), + "kw_ilike" => parts.push(self.kw("ILIKE")), + "kw_between" => parts.push(self.kw("BETWEEN")), + "kw_exists" => parts.push(self.kw("EXISTS")), + "kw_as" => parts.push(self.kw("AS")), + "select_with_parens" => { + parts.push(self.format_select_with_parens(child)); + } + "in_expr" => { + parts.push(self.format_in_expr(child)); + } + "expr_list" => { + // Inline expr_list (e.g., IN (a, b, c)) — format with parens. + let items = flatten_list(child, "expr_list"); + let formatted: Vec<_> = + items.iter().map(|i| self.format_expr(*i)).collect(); + parts.push(format!("({})", formatted.join(", "))); + } + "qual_all_Op" | "all_Op" | "MathOp" | "sub_type" | "qual_Op" => { + let op_text = self.text(child).trim(); + let normalized = if op_text == "!=" { "<>" } else { op_text }; + parts.push(normalized.to_string()); + } + _ if child.kind().starts_with("kw_") => { + parts.push(self.format_keyword_node(child)); + } + _ => parts.push(self.format_expr(child)), + } + } else { + // Unnamed children are operators like =, <, >, !=, etc. + let text = self.text(child).trim(); + if !text.is_empty() { + // Skip parens that surround an expr_list (handled inline). 
+ if has_expr_list && (text == "(" || text == ")") { + continue; + } + // Normalize != to <> + let op = if text == "!=" { "<>" } else { text }; + parts.push(op.to_string()); + } + } + } + Self::join_with_multiline_indent(&parts) + } + + /// Join parts with spaces, properly indenting multi-line parts. + /// When a part contains newlines, continuation lines are indented + /// to align with where that part starts in the joined output. + fn join_with_multiline_indent(parts: &[String]) -> String { + if parts.is_empty() { + return String::new(); + } + // Fast path: no multi-line parts. + if !parts.iter().any(|p| p.contains('\n')) { + return parts.join(" "); + } + + let mut result = String::new(); + let mut col = 0usize; + for (i, part) in parts.iter().enumerate() { + if i > 0 { + result.push(' '); + col += 1; + } + if part.contains('\n') { + // This part starts at column `col` in the output. + // Indent continuation lines by `col` spaces. + let indent_str = " ".repeat(col); + let mut lines = part.lines(); + if let Some(first) = lines.next() { + result.push_str(first); + // Update col to end of first line (for any subsequent parts). 
+ col += first.len(); + } + for line in lines { + result.push('\n'); + result.push_str(&indent_str); + result.push_str(line); + } + } else { + result.push_str(part); + col += part.len(); + } + } + result + } + + fn format_a_expr_prec(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + parts.push(self.format_expr(child)); + } else { + let text = self.text(child).trim(); + if !text.is_empty() { + let op = if text == "!=" { "<>" } else { text }; + parts.push(op.to_string()); + } + } + } + Self::join_with_multiline_indent(&parts) + } + + fn format_c_expr(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut has_block_subquery = false; + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "columnref" => parts.push(self.format_columnref(child)), + "AexprConst" => parts.push(self.format_const(child)), + "func_expr" | "func_application" => parts.push(self.format_func(child)), + "case_expr" => parts.push(self.format_case_expr(child)), + "select_with_parens" => { + let formatted = self.format_select_with_parens(child); + if formatted.starts_with("(\n") { + has_block_subquery = true; + } + parts.push(formatted); + } + "kw_exists" => parts.push(self.kw("EXISTS")), + "kw_row" => parts.push(self.kw("ROW")), + _ if child.kind().starts_with("kw_") => { + parts.push(self.format_keyword_node(child)); + } + _ => parts.push(self.format_expr(child)), + } + } else { + let text = self.text(child).trim(); + if !text.is_empty() { + parts.push(text.to_string()); + } + } + } + + // For block-format subqueries (left-aligned styles), join with simple + // spaces without column-based multiline indentation; the subquery + // already has proper internal indentation. 
+ let result = if has_block_subquery { + parts.join(" ") + } else { + Self::join_with_multiline_indent(&parts) + }; + + // Clean up double spaces on each line, preserving leading whitespace + // and spaces inside quoted strings. + result + .lines() + .map(|line| { + let leading = line.len() - line.trim_start().len(); + let prefix = &line[..leading]; + let cleaned = collapse_whitespace_outside_quotes(&line[leading..]); + format!("{prefix}{cleaned}") + }) + .collect::>() + .join("\n") + } + + fn format_columnref(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "ColId" => parts.push(self.format_col_id(child)), + "indirection" => parts.push(self.format_indirection(child)), + _ => parts.push(self.format_expr(child)), + } + } + parts.join("") + } + + pub(crate) fn format_col_id(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + if let Some(child) = node.named_children(&mut cursor).next() { + return match child.kind() { + "identifier" | "unreserved_keyword" => self.text(child).to_string(), + _ => self.format_expr(child), + }; + } + self.text(node).to_string() + } + + fn format_indirection(&self, node: Node<'a>) -> String { + let mut result = String::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + result.push_str(&self.format_indirection_el(child)); + } + result + } + + fn format_indirection_el(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + parts.push(self.format_expr(child)); + } else { + parts.push(self.text(child).to_string()); + } + } + parts.join("") + } + + fn format_const(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + if let Some(child) = node.named_children(&mut cursor).next() { + return match child.kind() { + "Sconst" | "string_literal" => 
self.format_string_const(child), + "Iconst" | "integer_literal" | "Fconst" | "float_literal" => { + self.text(child).to_string() + } + "kw_true" => self.kw("TRUE"), + "kw_false" => self.kw("FALSE"), + "kw_null" => self.kw("NULL"), + _ if child.kind().starts_with("kw_") => self.format_keyword_node(child), + _ => self.format_expr(child), + }; + } + self.text(node).to_string() + } + + fn format_string_const(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + if let Some(child) = node + .named_children(&mut cursor) + .find(|c| c.kind() == "string_literal") + { + return self.text(child).to_string(); + } + self.text(node).to_string() + } + + pub(crate) fn format_func(&self, node: Node<'a>) -> String { + match node.kind() { + "func_expr" => { + if let Some(app) = node.find_child("func_application") { + return self.format_func(app); + } + // func_expr_common_subexpr or other variants. + self.format_func_expr_common(node) + } + "func_application" => self.format_func_application(node), + _ => self.text(node).to_string(), + } + } + + fn format_func_application(&self, node: Node<'a>) -> String { + let name = node + .find_child("func_name") + .map(|n| self.format_func_name(n)) + .unwrap_or_default(); + + // Check for special forms: COUNT(*), etc. 
+ let mut cursor = node.walk(); + let children: Vec<_> = node.children(&mut cursor).collect(); + + let mut args = String::new(); + let mut has_star = false; + let mut has_distinct = false; + let mut over_clause = None; + + for child in &children { + if !child.is_named() { + let text = self.text(*child); + if text == "*" { + has_star = true; + } + } else { + match child.kind() { + "func_arg_list" => { + let items = flatten_list(*child, "func_arg_list"); + let formatted: Vec<_> = + items.iter().map(|i| self.format_expr(*i)).collect(); + args = formatted.join(", "); + } + "distinct_clause" | "kw_distinct" => has_distinct = true, + "over_clause" => over_clause = Some(*child), + "func_name" => {} // already handled + _ => {} + } + } + } + + // Apply keyword casing only to SQL built-in functions; preserve + // user-defined function names as-is. + let cased_name = if is_sql_builtin_function(&name) { + self.kw(&name) + } else { + name + }; + let inner = if has_star { + "*".to_string() + } else if has_distinct { + format!("{} {args}", self.kw("DISTINCT")) + } else { + args + }; + + let mut result = format!("{cased_name}({inner})"); + + if let Some(over) = over_clause { + result.push(' '); + result.push_str(&self.format_over_clause(over)); + } + + result + } + + fn format_func_expr_common(&self, node: Node<'a>) -> String { + // Handle COALESCE, GREATEST, LEAST, NULLIF, CURRENT_TIMESTAMP, etc. 
+ let mut cursor = node.walk(); + let mut parts = Vec::new(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "func_application" => return self.format_func(child), + _ if child.kind().starts_with("kw_") => { + parts.push(self.format_keyword_node(child)); + } + _ => parts.push(self.format_expr(child)), + } + } else { + let text = self.text(child).trim(); + if !text.is_empty() { + parts.push(text.to_string()); + } + } + } + parts.join(" ") + } + + pub(crate) fn format_func_name(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "type_function_name" => parts.push(self.format_first_named_child(child)), + "ColId" => parts.push(self.format_col_id(child)), + "indirection" => parts.push(self.format_indirection(child)), + _ => parts.push(self.format_expr(child)), + } + } + parts.join("") + } + + fn format_over_clause(&self, node: Node<'a>) -> String { + let mut parts = vec![self.kw("OVER")]; + parts.push("(".to_string()); + + let mut inner = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "opt_partition_clause" => { + inner.push(self.format_partition_clause(child)); + } + "opt_sort_clause" | "sort_clause" => { + inner.push(self.format_sort_clause_inline(child)); + } + "kw_over" => {} // skip + _ => inner.push(self.format_expr(child)), + } + } + parts.push(inner.join(" ")); + parts.push(")".to_string()); + parts.join("") + } + + fn format_partition_clause(&self, node: Node<'a>) -> String { + let mut parts = vec![self.kw("PARTITION"), self.kw("BY")]; + if let Some(list) = node.find_child("expr_list") { + let items = flatten_list(list, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + parts.push(formatted.join(", ")); + } + parts.join(" ") + } + + fn format_sort_clause_inline(&self, node: Node<'a>) 
-> String { + let actual = if node.kind() == "opt_sort_clause" { + node.find_child("sort_clause").unwrap_or(node) + } else { + node + }; + let mut parts = vec![self.kw("ORDER"), self.kw("BY")]; + if let Some(list) = actual.find_child("sortby_list") { + let items = flatten_list(list, "sortby_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_sortby(*i)).collect(); + parts.push(formatted.join(", ")); + } + parts.join(" ") + } + + pub(crate) fn format_sortby(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "a_expr" | "c_expr" => parts.push(self.format_expr(child)), + "opt_asc_desc" => { + if let Some(kw) = child.find_child_any(&["kw_asc", "kw_desc"]) { + parts.push(self.format_keyword_node(kw)); + } + } + "opt_nulls_order" => { + parts.push(self.kw("NULLS")); + if child.has_child("kw_first") { + parts.push(self.kw("FIRST")); + } else { + parts.push(self.kw("LAST")); + } + } + _ => {} + } + } + parts.join(" ") + } + + fn format_case_expr(&self, node: Node<'a>) -> String { + let mut parts = vec![self.kw("CASE")]; + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "kw_case" | "kw_end" => {} + "case_arg" => { + if let Some(expr) = child.find_child_any(&["a_expr", "c_expr"]) { + parts.push(self.format_expr(expr)); + } + } + "when_clause_list" => { + let clauses = flatten_list(child, "when_clause_list"); + for clause in clauses { + parts.push(self.format_when_clause(clause)); + } + } + "case_default" => { + if let Some(expr) = child.find_child_any(&["a_expr", "c_expr", "a_expr_prec"]) { + parts.push(self.kw("ELSE")); + parts.push(self.format_expr(expr)); + } + } + _ => {} + } + } + parts.push(self.kw("END")); + parts.join(" ") + } + + fn format_when_clause(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let named = node.named_children_vec(); + for child in &named 
{ + match child.kind() { + "kw_when" => parts.push(self.kw("WHEN")), + "kw_then" => parts.push(self.kw("THEN")), + "a_expr" | "c_expr" | "a_expr_prec" => { + parts.push(self.format_expr(*child)); + } + _ => {} + } + } + parts.join(" ") + } + + fn format_in_expr(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "select_with_parens" => { + return self.format_select_with_parens(child); + } + "expr_list" => { + let items = flatten_list(child, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + return format!("({})", formatted.join(", ")); + } + _ => {} + } + } + self.text(node).to_string() + } + + pub(crate) fn format_select_with_parens(&self, node: Node<'a>) -> String { + // Contains a sub-SELECT in parentheses. + if let Some(snp) = node.find_child("select_no_parens") { + let inner = self.format_select_no_parens(snp); + let lines: Vec<&str> = inner.lines().collect(); + if lines.len() <= 1 { + return format!("({inner})"); + } + + if !self.config.river { + // Left-aligned styles: block format with opening paren on its own, + // indented body, and closing paren on its own line. + let indent = self.config.indent; + let mut result = String::from("(\n"); + for line in &lines { + if line.is_empty() { + result.push('\n'); + } else { + result.push_str(indent); + result.push_str(line); + result.push('\n'); + } + } + result.push(')'); + return result; + } + + // River styles: inline with continuation lines indented after '('. + let mut result = format!("({}", lines[0]); + let paren_indent = " "; + for line in &lines[1..] 
{ + result.push('\n'); + result.push_str(paren_indent); + result.push_str(line); + } + result.push(')'); + result + } else { + format!("({})", self.text(node).trim()) + } + } + + pub(crate) fn format_target_el(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "a_expr" | "c_expr" => parts.push(self.format_expr(child)), + "kw_as" => parts.push(self.kw("AS")), + "ColLabel" => parts.push(self.format_expr(child)), + _ => parts.push(self.format_expr(child)), + } + } else { + let text = self.text(child).trim(); + if text == "*" { + parts.push("*".to_string()); + } + } + } + Self::join_with_multiline_indent(&parts) + } + + pub(crate) fn format_typename(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + let mut has_setof = false; + let mut has_array = false; + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "SimpleTypename" => parts.push(self.format_simple_typename(child)), + "kw_setof" => { + has_setof = true; + } + "opt_array_bounds" => has_array = true, + _ => parts.push(self.format_expr(child)), + } + } + } + let mut result = String::new(); + if has_setof { + result.push_str(&self.kw("SETOF")); + result.push(' '); + } + result.push_str(&parts.join(" ")); + if has_array { + result.push_str("[]"); + } + result + } + + pub(crate) fn format_simple_typename(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + if let Some(child) = node.named_children(&mut cursor).next() { + return match child.kind() { + "Numeric" | "GenericType" | "Bit" | "Character" | "ConstDatetime" + | "ConstInterval" => self.format_typename_inner(child), + _ => self.format_expr(child), + }; + } + self.text(node).to_string() + } + + fn format_typename_inner(&self, node: Node<'a>) -> String { + // Get the base type name. 
+ let mut base = String::new(); + let mut modifiers = String::new(); + let mut extra_keywords = Vec::new(); + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "kw_integer" | "kw_int" | "kw_smallint" | "kw_bigint" | "kw_real" + | "kw_boolean" | "kw_float" | "kw_decimal" => { + base = self.map_type_name(&self.text(child).to_lowercase()); + } + "kw_double" => base = "DOUBLE".to_string(), + "kw_precision" => { + if base == "DOUBLE" { + base = "DOUBLE PRECISION".to_string(); + } else { + extra_keywords.push("PRECISION".to_string()); + } + } + "kw_varying" => extra_keywords.push("VARYING".to_string()), + "kw_with" => extra_keywords.push(self.kw("WITH")), + "kw_without" => extra_keywords.push(self.kw("WITHOUT")), + "kw_time" => { + if base.is_empty() { + base = self.kw("TIME"); + } else { + extra_keywords.push(self.kw("TIME")); + } + } + "kw_zone" => extra_keywords.push(self.kw("ZONE")), + "kw_timestamp" => base = self.kw("TIMESTAMP"), + "type_function_name" | "unreserved_keyword" => { + let name = self.format_first_named_child(child); + base = self.map_type_name(&name.to_lowercase()); + } + "opt_type_modifiers" => { + modifiers = self.format_type_modifiers(child); + } + "attrs" => { + base.push_str(&self.format_attrs(child)); + } + "opt_float" => { + // FLOAT(n) precision. + let text = self.text(child); + if !text.trim().is_empty() { + modifiers = text.to_string(); + } + } + _ if child.kind().starts_with("kw_") => { + let kw_text = self.text(child); + extra_keywords.push(self.kw(kw_text)); + } + _ => { + if base.is_empty() { + base = self.format_expr(child); + } + } + } + } else { + let text = self.text(child).trim(); + if text == "(" || text == ")" || text == "," { + // Part of modifiers — handled by opt_type_modifiers. 
+ } + } + } + + let mut result = if self.config.upper_keywords { + base.to_uppercase() + } else { + base.to_lowercase() + }; + + if !extra_keywords.is_empty() { + result.push(' '); + result.push_str(&extra_keywords.join(" ")); + } + if !modifiers.is_empty() { + result.push_str(&modifiers); + } + result + } + + fn format_type_modifiers(&self, node: Node<'a>) -> String { + let mut items = Vec::new(); + if let Some(list) = node.find_child("expr_list") { + let exprs = flatten_list(list, "expr_list"); + for expr in exprs { + items.push(self.format_expr(expr)); + } + } + if items.is_empty() { + return String::new(); + } + format!("({})", items.join(", ")) + } + + fn format_attrs(&self, node: Node<'a>) -> String { + let mut result = String::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "attr_name" => result.push_str(&self.format_expr(child)), + _ => result.push_str(&self.format_expr(child)), + } + } else { + result.push_str(self.text(child)); + } + } + result + } + + fn map_type_name(&self, name: &str) -> String { + for (pg_name, std_name) in PG_TYPE_MAP { + if *pg_name == name { + return if self.config.upper_keywords { + std_name.to_string() + } else { + std_name.to_lowercase() + }; + } + } + // If not in the map, return the name with proper casing. 
+ if self.config.upper_keywords { + name.to_uppercase() + } else { + name.to_lowercase() + } + } + + pub(crate) fn format_qualified_name(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "ColId" => parts.push(self.format_col_id(child)), + "indirection" => parts.push(self.format_indirection(child)), + "attr_name" => { + // Schema-qualified: schema.name + parts.push(format!(".{}", self.format_expr(child))); + } + _ => parts.push(self.format_expr(child)), + } + } else { + let text = self.text(child).trim(); + if text == "." { + parts.push(".".to_string()); + } + } + } + // Join without extra spaces (dots already included). + let result = parts.join(""); + // Clean up any double dots. + result.replace("..", ".") + } + + pub(crate) fn format_relation_expr(&self, node: Node<'a>) -> String { + if let Some(qn) = node.find_child("qualified_name") { + return self.format_qualified_name(qn); + } + self.text(node).to_string() + } + + fn format_alias(&self, node: Node<'a>) -> String { + if node.kind() == "opt_alias_clause" + && let Some(ac) = node.find_child("alias_clause") + { + return self.format_alias(ac); + } + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "kw_as" => parts.push(self.kw("AS")), + "ColId" => parts.push(self.format_col_id(child)), + _ => parts.push(self.format_expr(child)), + } + } + parts.join(" ") + } + + pub(crate) fn format_keyword_node(&self, node: Node<'a>) -> String { + self.kw(self.text(node)) + } + + fn format_first_named_child(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + if let Some(child) = node.named_children(&mut cursor).next() { + return self.format_expr(child); + } + self.text(node).to_string() + } + + /// Format a table reference (for FROM clause), returning the table name with alias. 
/// Collapse runs of whitespace to single spaces, but preserve the text inside
/// single-quoted strings, double-quoted identifiers and dollar-quoted strings
/// (`$$...$$` or `$tag$...$tag$`) exactly as written.
///
/// A doubled quote character inside a quoted run (`''` or `""`) is treated as
/// an escaped quote and does not end the run. An unterminated quoted or
/// dollar-quoted run is copied verbatim up to the end of the input.
///
/// Leading and trailing whitespace is collapsed like any other run; it is not
/// trimmed.
fn collapse_whitespace_outside_quotes(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let len = chars.len();
    let mut result = String::with_capacity(s.len());
    let mut prev_was_space = false;
    let mut i = 0;

    while i < len {
        let ch = chars[i];

        if ch == '\'' || ch == '"' {
            // Copy the whole quoted run verbatim, honouring doubled quotes.
            result.push(ch);
            i += 1;
            while i < len {
                let inner = chars[i];
                result.push(inner);
                i += 1;
                if inner == ch {
                    if i < len && chars[i] == ch {
                        // Escaped quote: keep both characters, stay inside.
                        result.push(ch);
                        i += 1;
                    } else {
                        break;
                    }
                }
            }
            prev_was_space = false;
            continue;
        }

        if ch == '$' {
            // A dollar-quote opener is `$`, an optional tag of ASCII
            // alphanumerics/underscores, and a closing `$`.
            let mut tag_end = i + 1;
            while tag_end < len
                && (chars[tag_end].is_ascii_alphanumeric() || chars[tag_end] == '_')
            {
                tag_end += 1;
            }
            if tag_end < len && chars[tag_end] == '$' {
                let tag = &chars[i..=tag_end];
                let body_start = tag_end + 1;
                // Locate the closing tag by comparing slices in place; this
                // avoids re-materialising the remaining input per character.
                let close = chars[body_start..]
                    .windows(tag.len())
                    .position(|window| window == tag)
                    .map_or(len, |offset| body_start + offset + tag.len());
                result.extend(chars[i..close].iter());
                i = close;
                prev_was_space = false;
                continue;
            }
            // Not a dollar-quote, just a dollar sign.
            result.push('$');
            prev_was_space = false;
            i += 1;
            continue;
        }

        if ch.is_whitespace() {
            if !prev_was_space {
                result.push(' ');
                prev_was_space = true;
            }
        } else {
            result.push(ch);
            prev_was_space = false;
        }
        i += 1;
    }

    result
}
+ pub join_on_same_line: bool, + /// Blank lines inside CTE bodies (GitLab). + pub blank_lines_in_ctes: bool, + /// Strip INNER keyword from INNER JOIN (mattmc3: use plain JOIN). + pub strip_inner_join: bool, +} + +impl StyleConfig { + pub fn from_style(style: Style) -> Self { + match style { + Style::River => Self { + upper_keywords: true, + indent: " ", + leading_commas: false, + joins_in_river: false, + explicit_inner_join: false, + blank_lines_between_clauses: false, + river: true, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: false, + strip_inner_join: false, + }, + Style::Mozilla => Self { + upper_keywords: true, + indent: " ", + leading_commas: false, + joins_in_river: false, + explicit_inner_join: false, + blank_lines_between_clauses: false, + river: false, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: false, + strip_inner_join: false, + }, + Style::Aweber => Self { + upper_keywords: true, + indent: " ", + leading_commas: false, + joins_in_river: true, + explicit_inner_join: false, + blank_lines_between_clauses: false, + river: true, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: false, + strip_inner_join: false, + }, + Style::Dbt => Self { + upper_keywords: false, + indent: " ", + leading_commas: false, + joins_in_river: false, + explicit_inner_join: true, + blank_lines_between_clauses: true, + river: false, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: false, + strip_inner_join: false, + }, + Style::Gitlab => Self { + upper_keywords: true, + indent: " ", + leading_commas: false, + joins_in_river: false, + explicit_inner_join: true, + blank_lines_between_clauses: false, + river: false, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: true, + strip_inner_join: false, + }, + Style::Kickstarter => Self { + upper_keywords: true, + indent: " ", + leading_commas: false, + joins_in_river: false, + explicit_inner_join: true, + 
blank_lines_between_clauses: false, + river: false, + compact_ctes: true, + join_on_same_line: true, + blank_lines_in_ctes: false, + strip_inner_join: false, + }, + Style::Mattmc3 => Self { + upper_keywords: false, + indent: " ", + leading_commas: true, + joins_in_river: true, + explicit_inner_join: false, + blank_lines_between_clauses: false, + river: true, + compact_ctes: false, + join_on_same_line: false, + blank_lines_in_ctes: false, + strip_inner_join: true, + }, + } + } +} + +/// The core SQL formatter. +pub(crate) struct Formatter<'a> { + pub source: &'a str, + pub style: Style, + pub config: StyleConfig, +} + +impl<'a> Formatter<'a> { + pub fn new(source: &'a str, style: Style) -> Self { + Self { + source, + style, + config: StyleConfig::from_style(style), + } + } + + /// Format the root `source_file` node containing one or more statements. + pub fn format_root(&self, root: Node<'a>) -> Result { + let mut results = Vec::new(); + let mut cursor = root.walk(); + for child in root.named_children(&mut cursor) { + if child.kind() == "toplevel_stmt" + && let Some(stmt) = child.find_child("stmt") + { + results.push(self.format_stmt(stmt)?); + } + } + if results.is_empty() { + return Ok(String::new()); + } + Ok(results.join("\n\n")) + } + + /// Format a PL/pgSQL root node. + pub fn format_plpgsql_root(&self, root: Node<'a>) -> Result { + if let Some(block) = root.find_child("pl_block") { + let mut body = self.format_plpgsql_block(block, 0); + // The outermost PL/pgSQL block should end with "END;" (semicolon + // before the closing $$ delimiter). + if !body.trim_end().ends_with(';') { + body.push(';'); + } + return Ok(body); + } + // Fallback: return normalized source. + Ok(root.text(self.source).to_string()) + } + + /// Apply keyword casing. + pub fn kw(&self, keyword: &str) -> String { + if self.config.upper_keywords { + keyword.to_uppercase() + } else { + keyword.to_lowercase() + } + } + + /// Format a two-word keyword pair like "GROUP BY" or "ORDER BY". 
+ pub fn kw_pair(&self, first: &str, second: &str) -> String { + format!("{} {}", self.kw(first), self.kw(second)) + } + + /// Get the text of a node. + pub fn text(&self, node: Node<'a>) -> &'a str { + node.text(self.source) + } +} diff --git a/src/formatter/plpgsql.rs b/src/formatter/plpgsql.rs new file mode 100644 index 0000000..31fd8d2 --- /dev/null +++ b/src/formatter/plpgsql.rs @@ -0,0 +1,397 @@ +/// PL/pgSQL formatting. +use crate::node_helpers::NodeExt; +use tree_sitter::Node; + +use super::Formatter; + +impl<'a> Formatter<'a> { + /// Format a PL/pgSQL block. + pub(crate) fn format_plpgsql_block(&self, node: Node<'a>, indent_level: usize) -> String { + let indent = " ".repeat(indent_level); + let mut lines = Vec::new(); + + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "decl_sect" => { + lines.push(format!("{indent}{}", self.kw("DECLARE"))); + self.format_decl_sect(child, indent_level + 1, &mut lines); + } + "kw_begin" => { + lines.push(format!("{indent}{}", self.kw("BEGIN"))); + } + "proc_sect" => { + self.format_proc_sect(child, indent_level + 1, &mut lines); + } + "exception_sect" => { + lines.push(format!("{indent}{}", self.kw("EXCEPTION"))); + self.format_exception_sect(child, indent_level + 1, &mut lines); + } + "kw_end" => { + lines.push(format!("{indent}{}", self.kw("END"))); + } + _ => {} + } + } + + lines.join("\n") + } + + fn format_decl_sect(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + if child.kind() == "decl_stmt" { + self.format_decl_stmt(child, indent_level, lines); + } + } + } + + fn format_decl_stmt(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + if let Some(decl) = node.find_child("decl_statement") { + let var_name = decl + .find_child("decl_varname") + .map(|n| self.text(n).trim().to_string()) + .unwrap_or_default(); + + 
let mut parts = vec![var_name]; + + // Constant? + if decl.has_child("kw_constant") { + parts.push(self.kw("CONSTANT")); + } + + // Data type. + if let Some(dt) = decl.find_child("decl_datatype") { + let type_text = self.text(dt).trim().to_string(); + parts.push(type_text); + } + + // Collation. + if let Some(coll) = decl.find_child("decl_collate") { + let coll_text = self.text(coll).trim().to_string(); + parts.push(coll_text); + } + + // NOT NULL. + if decl.has_child("kw_not") { + parts.push(format!("{} {}", self.kw("NOT"), self.kw("NULL"))); + } + + // Default value. + if let Some(defval) = decl.find_child("decl_defval") { + let def_text = self.text(defval).trim().to_string(); + parts.push(def_text); + } + + lines.push(format!("{indent}{};", parts.join(" "))); + } + } + + fn format_proc_sect(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + if child.kind() == "proc_stmt" { + self.format_proc_stmt(child, indent_level, lines); + } + } + } + + fn format_proc_stmt(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "stmt_if" => self.format_stmt_if(child, indent_level, lines), + "stmt_loop" | "stmt_while" => self.format_stmt_loop(child, indent_level, lines), + "stmt_for" => self.format_stmt_for(child, indent_level, lines), + "stmt_foreach_a" => self.format_stmt_foreach(child, indent_level, lines), + "stmt_case" => self.format_stmt_case(child, indent_level, lines), + "stmt_return" => { + let expr = child + .find_child("sql_expression") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + if expr.is_empty() { + lines.push(format!("{indent}{};", self.kw("RETURN"))); + } else { + lines.push(format!("{indent}{} {expr};", self.kw("RETURN"))); + } + } + "stmt_raise" => self.format_stmt_raise(child, indent_level, lines), + 
"stmt_null" => { + lines.push(format!("{indent}{};", self.kw("NULL"))); + } + // Statements whose source text is passed through with indentation. + "stmt_assign" | "stmt_execsql" | "stmt_perform" | "stmt_call" + | "stmt_dynexecute" | "stmt_exit" | "stmt_continue" | "stmt_open" + | "stmt_close" | "stmt_fetch" | "stmt_move" | "stmt_commit" | "stmt_rollback" + | "stmt_assert" => { + let text = self.text(child).trim(); + lines.push(format!("{indent}{text}")); + } + "pl_block" => { + let block_text = self.format_plpgsql_block(child, indent_level); + lines.push(block_text); + } + _ => { + let text = self.text(child).trim(); + if !text.is_empty() { + lines.push(format!("{indent}{text}")); + } + } + } + } + } + + fn format_stmt_if(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let inner_indent_level = indent_level + 1; + + let mut cursor = node.walk(); + let children: Vec<_> = node.named_children(&mut cursor).collect(); + + let mut i = 0; + while i < children.len() { + let child = children[i]; + match child.kind() { + "kw_if" => { + if i == 0 { + // Opening IF. + let cond = children + .get(i + 1) + .filter(|c| c.kind() == "sql_expression") + .map(|c| self.text(*c).trim()) + .unwrap_or(""); + lines.push(format!( + "{indent}{} {cond} {}", + self.kw("IF"), + self.kw("THEN") + )); + i += 3; // skip cond and THEN + } + // Closing IF (END IF). 
+ } + "sql_expression" => { + i += 1; // handled with IF/ELSIF + } + "kw_then" => { + i += 1; // handled with IF/ELSIF + } + "proc_sect" => { + self.format_proc_sect(child, inner_indent_level, lines); + i += 1; + } + "elsif_clause" => { + self.format_elsif_clause(child, indent_level, lines); + i += 1; + } + "else_clause" => { + lines.push(format!("{indent}{}", self.kw("ELSE"))); + if let Some(proc) = child.find_child("proc_sect") { + self.format_proc_sect(proc, inner_indent_level, lines); + } + i += 1; + } + "kw_end" => { + i += 1; // END IF handled below + } + _ => { + i += 1; + } + } + } + lines.push(format!("{indent}{} {};", self.kw("END"), self.kw("IF"))); + } + + fn format_elsif_clause(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let inner_indent_level = indent_level + 1; + + let cond = node + .find_child("sql_expression") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + lines.push(format!( + "{indent}{} {cond} {}", + self.kw("ELSIF"), + self.kw("THEN") + )); + + if let Some(proc) = node.find_child("proc_sect") { + self.format_proc_sect(proc, inner_indent_level, lines); + } + } + + fn format_stmt_loop(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let inner_indent_level = indent_level + 1; + + // WHILE condition or just LOOP. 
+ if node.kind() == "stmt_while" { + let cond = node + .find_child("sql_expression") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + lines.push(format!( + "{indent}{} {cond} {}", + self.kw("WHILE"), + self.kw("LOOP") + )); + } else { + lines.push(format!("{indent}{}", self.kw("LOOP"))); + } + + if let Some(body) = node.find_child("loop_body") + && let Some(proc) = body.find_child("proc_sect") + { + self.format_proc_sect(proc, inner_indent_level, lines); + } + + lines.push(format!("{indent}{} {};", self.kw("END"), self.kw("LOOP"))); + } + + fn format_stmt_for(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let inner_indent_level = indent_level + 1; + + let var = node + .find_child("for_variable") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + + // Determine if it's a FOR ... IN range or FOR ... IN query. + let in_clause = if let Some(range) = node.find_child("for_integer_range") { + self.text(range).trim().to_string() + } else if let Some(query) = node.find_child("for_control") { + self.text(query).trim().to_string() + } else { + // Fallback: reconstruct from source. + let text = self.text(node); + if let Some(start) = text.find("IN") { + if let Some(end) = text.find("LOOP") { + text[start + 2..end].trim().to_string() + } else { + String::new() + } + } else { + String::new() + } + }; + + let for_kw = self.kw("FOR"); + let in_kw = self.kw("IN"); + let loop_kw = self.kw("LOOP"); + lines.push(format!( + "{indent}{for_kw} {var} {in_kw} {in_clause} {loop_kw}" + )); + + if let Some(body) = node.find_child("loop_body") + && let Some(proc) = body.find_child("proc_sect") + { + self.format_proc_sect(proc, inner_indent_level, lines); + } + + lines.push(format!("{indent}{} {};", self.kw("END"), self.kw("LOOP"))); + } + + fn format_stmt_foreach(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + // FOREACH ... SLICE ... IN ARRAY ... LOOP ... 
END LOOP + let text = self.text(node); + lines.push(format!("{indent}{text}")); + } + + fn format_stmt_case(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let inner_indent_level = indent_level + 1; + let inner_indent = " ".repeat(inner_indent_level); + + let expr = node + .find_child("sql_expression") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + lines.push(format!("{indent}{} {expr}", self.kw("CASE"))); + + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + if child.kind() == "case_when" { + let when_expr = child + .find_child("sql_expression") + .map(|n| self.text(n).trim()) + .unwrap_or(""); + lines.push(format!( + "{inner_indent}{} {when_expr} {}", + self.kw("WHEN"), + self.kw("THEN") + )); + if let Some(proc) = child.find_child("proc_sect") { + self.format_proc_sect(proc, inner_indent_level + 1, lines); + } + } else if child.kind() == "else_clause" { + lines.push(format!("{inner_indent}{}", self.kw("ELSE"))); + if let Some(proc) = child.find_child("proc_sect") { + self.format_proc_sect(proc, inner_indent_level + 1, lines); + } + } + } + + lines.push(format!("{indent}{} {};", self.kw("END"), self.kw("CASE"))); + } + + fn format_stmt_raise(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let mut parts = vec![self.kw("RAISE")]; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "kw_raise" => {} // already handled + "raise_level" => { + let level = self.text(child).trim(); + parts.push(self.kw(level)); + } + "string_literal" => { + parts.push(self.text(child).to_string()); + } + "sql_expression" => { + parts.push(self.text(child).trim().to_string()); + } + _ => {} + } + } else { + let text = self.text(child).trim(); + if text == "," { + // Append comma to the previous part instead of adding a separate token. 
+ if let Some(last) = parts.last_mut() { + last.push(','); + } + } + } + } + + lines.push(format!("{indent}{};", parts.join(" "))); + } + + fn format_exception_sect(&self, node: Node<'a>, indent_level: usize, lines: &mut Vec) { + let indent = " ".repeat(indent_level); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "proc_conditions" => { + let cond_text = self.text(child).trim(); + lines.push(format!( + "{indent}{} {cond_text} {}", + self.kw("WHEN"), + self.kw("THEN") + )); + } + "proc_sect" => { + self.format_proc_sect(child, indent_level + 1, lines); + } + _ => {} + } + } + } +} diff --git a/src/formatter/select.rs b/src/formatter/select.rs new file mode 100644 index 0000000..77a0fa4 --- /dev/null +++ b/src/formatter/select.rs @@ -0,0 +1,1626 @@ +/// SELECT statement formatting — the most complex formatter. +use crate::node_helpers::{NodeExt, flatten_list}; +use crate::style::Style; +use tree_sitter::Node; + +use super::Formatter; + +/// Collected clauses from a SELECT statement. +pub(crate) struct SelectClauses<'a> { + pub distinct: Option>, + pub targets: Vec>, + pub from: Option>, + pub where_clause: Option>, + pub group_clause: Option>, + pub having_clause: Option>, + pub sort_clause: Option>, + pub limit_clause: Option>, + pub offset_clause: Option>, + pub with_clause: Option>, + /// For UNION / INTERSECT / EXCEPT. + pub set_op: Option>, + /// VALUES clause (for INSERT ... VALUES). + pub values_clause: Option>, +} + +pub(crate) struct SetOp<'a> { + pub keyword: String, + pub quantifier: Option, + pub right: Node<'a>, + /// Pre-collected clauses for the right side (ERROR recovery). + pub right_clauses: Option>>, +} + +impl<'a> Formatter<'a> { + /// Format a SelectStmt node. + pub(crate) fn format_select_stmt(&self, node: Node<'a>) -> String { + let snp = node.find_child("select_no_parens").unwrap_or(node); + self.format_select_no_parens(snp) + } + + /// Format a select_no_parens node. 
+ pub(crate) fn format_select_no_parens(&self, node: Node<'a>) -> String { + let clauses = self.collect_select_clauses(node); + if clauses.values_clause.is_some() { + return self.format_values_only(&clauses); + } + if self.config.river { + self.format_select_river(&clauses) + } else { + self.format_select_left_aligned(&clauses) + } + } + + /// Collect all clauses from a select_no_parens (or simple_select) node tree. + fn collect_select_clauses(&self, node: Node<'a>) -> SelectClauses<'a> { + let mut clauses = SelectClauses { + distinct: None, + targets: Vec::new(), + from: None, + where_clause: None, + group_clause: None, + having_clause: None, + sort_clause: None, + limit_clause: None, + offset_clause: None, + with_clause: None, + set_op: None, + values_clause: None, + }; + self.collect_clauses_recursive(node, &mut clauses); + clauses + } + + fn collect_clauses_recursive(&self, node: Node<'a>, clauses: &mut SelectClauses<'a>) { + let mut cursor = node.walk(); + let children: Vec<_> = node.named_children(&mut cursor).collect(); + let mut seen_set_op = false; + for child in &children { + // Once we've seen a set operation keyword, stop collecting into + // the left-side clauses. The right side is handled separately. 
+ if seen_set_op { + continue; + } + match child.kind() { + "with_clause" => clauses.with_clause = Some(*child), + "simple_select" => self.collect_clauses_recursive(*child, clauses), + "select_clause" => self.collect_clauses_recursive(*child, clauses), + "distinct_clause" => clauses.distinct = Some(*child), + "target_list" => { + clauses.targets = flatten_list(*child, "target_list"); + } + "opt_target_list" => { + if let Some(tl) = child.find_child("target_list") { + clauses.targets = flatten_list(tl, "target_list"); + } + } + "from_clause" => clauses.from = Some(*child), + "where_clause" => clauses.where_clause = Some(*child), + "group_clause" => clauses.group_clause = Some(*child), + "having_clause" => clauses.having_clause = Some(*child), + "opt_sort_clause" => { + if let Some(sc) = child.find_child("sort_clause") { + clauses.sort_clause = Some(sc); + } + } + "sort_clause" => clauses.sort_clause = Some(*child), + "select_limit" => { + self.collect_limit_clauses(*child, clauses); + } + "limit_clause" => { + self.collect_limit_clauses(*child, clauses); + } + "offset_clause" => clauses.offset_clause = Some(*child), + "kw_union" | "kw_intersect" | "kw_except" => { + seen_set_op = true; + // Set operation — find the right side. + let keyword = self.kw(self.text(*child)); + // Look for quantifier (ALL/DISTINCT) and right select_clause. 
+ let mut quantifier = None; + let mut found_keyword = false; + let mut right_clause = None; + for sib in &children { + if sib.id() == child.id() { + found_keyword = true; + continue; + } + if found_keyword { + match sib.kind() { + "set_quantifier" => { + quantifier = Some(self.format_set_quantifier(*sib)); + } + "select_clause" | "simple_select" => { + right_clause = Some(*sib); + break; + } + _ => {} + } + } + } + if let Some(right) = right_clause { + clauses.set_op = Some(SetOp { + keyword, + quantifier, + right, + right_clauses: None, + }); + } else { + // ERROR recovery: the right side tokens are loose children + // of the same node (not wrapped in select_clause). Collect + // the right-side clauses from children after the set op keyword. + let mut right_clauses = SelectClauses { + distinct: None, + targets: Vec::new(), + from: None, + where_clause: None, + group_clause: None, + having_clause: None, + sort_clause: None, + limit_clause: None, + offset_clause: None, + with_clause: None, + set_op: None, + values_clause: None, + }; + for sib in &children { + if sib.id() == child.id() { + found_keyword = true; + continue; + } + if !found_keyword { + continue; + } + match sib.kind() { + "set_quantifier" => {} // already handled + "opt_target_list" => { + if let Some(tl) = sib.find_child("target_list") { + right_clauses.targets = flatten_list(tl, "target_list"); + } + } + "target_list" => { + right_clauses.targets = flatten_list(*sib, "target_list"); + } + "from_clause" => right_clauses.from = Some(*sib), + "where_clause" => right_clauses.where_clause = Some(*sib), + "group_clause" => right_clauses.group_clause = Some(*sib), + "having_clause" => right_clauses.having_clause = Some(*sib), + "sort_clause" => right_clauses.sort_clause = Some(*sib), + _ => {} + } + } + clauses.set_op = Some(SetOp { + keyword, + quantifier, + right_clauses: Some(Box::new(right_clauses)), + right: node, // unused in this case + }); + } + } + "values_clause" => clauses.values_clause = 
Some(*child), + _ => {} + } + } + } + + fn collect_limit_clauses(&self, node: Node<'a>, clauses: &mut SelectClauses<'a>) { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "limit_clause" => { + clauses.limit_clause = Some(child); + } + "offset_clause" => { + clauses.offset_clause = Some(child); + } + "kw_limit" => { + // This limit_clause node itself is what we want. + clauses.limit_clause = Some(node); + } + "kw_offset" => { + clauses.offset_clause = Some(node); + } + _ => {} + } + } + } + + fn format_set_quantifier(&self, node: Node<'a>) -> String { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + if child.kind() == "kw_all" { + return self.kw("ALL"); + } + if child.kind() == "kw_distinct" { + return self.kw("DISTINCT"); + } + } + String::new() + } + + fn format_values_only(&self, clauses: &SelectClauses<'a>) -> String { + if let Some(vc) = clauses.values_clause { + return self.format_values_clause(vc); + } + String::new() + } + + // ── River-style SELECT ────────────────────────────────────────────── + + fn format_select_river(&self, clauses: &SelectClauses<'a>) -> String { + let mut lines = Vec::new(); + + // Calculate river width from all keywords that will appear. + let keywords = self.collect_river_keywords(clauses); + let river_width = keywords.iter().map(|k| k.len()).max().unwrap_or(6); + + // WITH clause. + if let Some(with) = clauses.with_clause { + lines.push(self.format_with_clause_river(with, river_width)); + } + + // SELECT [DISTINCT] targets. + let select_kw = if clauses.distinct.is_some() { + let distinct_text = clauses + .distinct + .map(|d| self.format_distinct(d)) + .unwrap_or_else(|| self.kw("DISTINCT")); + format!("{} {}", self.kw("SELECT"), distinct_text) + } else { + self.kw("SELECT") + }; + self.append_river_targets(&select_kw, &clauses.targets, river_width, &mut lines); + + // FROM clause with JOINs. 
+ if let Some(from) = clauses.from { + self.format_from_river(from, river_width, &mut lines); + } + + // WHERE clause. + if let Some(where_c) = clauses.where_clause { + self.format_where_river(where_c, river_width, &mut lines); + } + + // GROUP BY clause. + if let Some(group) = clauses.group_clause { + self.format_group_by_river(group, river_width, &mut lines); + } + + // HAVING clause. + if let Some(having) = clauses.having_clause { + self.format_having_river(having, river_width, &mut lines); + } + + // ORDER BY clause. + if let Some(sort) = clauses.sort_clause { + self.format_order_by_river(sort, river_width, &mut lines); + } + + // LIMIT / OFFSET. + if let Some(limit) = clauses.limit_clause { + self.format_limit_river(limit, river_width, &mut lines); + } + if let Some(offset) = clauses.offset_clause { + self.format_offset_river(offset, river_width, &mut lines); + } + + let mut result = lines.join("\n"); + + // Set operations (UNION, INTERSECT, EXCEPT). + if let Some(ref set_op) = clauses.set_op { + let op_text = if let Some(ref q) = set_op.quantifier { + format!("{} {q}", set_op.keyword) + } else { + set_op.keyword.clone() + }; + result.push_str("\n\n"); + result.push_str(&op_text); + result.push_str("\n\n"); + if let Some(ref rc) = set_op.right_clauses { + result.push_str(&self.format_select_river(rc)); + } else { + let right_clauses = self.collect_select_clauses(set_op.right); + result.push_str(&self.format_select_river(&right_clauses)); + } + } + + result + } + + /// Collect all keywords that will appear in this SELECT for river width calculation. + fn collect_river_keywords(&self, clauses: &SelectClauses<'a>) -> Vec { + let mut keywords = Vec::new(); + + // For river width, always use just SELECT (not SELECT DISTINCT). + // DISTINCT is part of the content, not the river keyword. + keywords.push(self.kw("SELECT")); + + if clauses.from.is_some() { + keywords.push(self.kw("FROM")); + // Collect JOIN keywords if they participate in the river. 
+ if let Some(from) = clauses.from { + self.collect_join_keywords_for_river(from, &mut keywords); + } + } + if clauses.where_clause.is_some() { + keywords.push(self.kw("WHERE")); + // AND/OR keywords participate in river for conditions. + if let Some(where_c) = clauses.where_clause { + self.collect_condition_keywords(where_c, &mut keywords); + } + } + if clauses.group_clause.is_some() { + keywords.push(self.kw_pair("GROUP", "BY")); + } + if clauses.having_clause.is_some() { + keywords.push(self.kw("HAVING")); + } + if clauses.sort_clause.is_some() { + keywords.push(self.kw_pair("ORDER", "BY")); + } + if clauses.limit_clause.is_some() { + keywords.push(self.kw("LIMIT")); + } + if clauses.offset_clause.is_some() { + keywords.push(self.kw("OFFSET")); + } + keywords + } + + fn collect_join_keywords_for_river(&self, from_node: Node<'a>, keywords: &mut Vec) { + if !self.config.joins_in_river { + return; + } + if let Some(from_list) = from_node.find_child("from_list") { + let tables = flatten_list(from_list, "from_list"); + for table in tables { + if table.kind() == "table_ref" { + if let Some(jt) = table.find_child("joined_table") { + self.collect_join_keywords_inner(jt, keywords); + } + } else { + self.collect_join_keywords_inner(table, keywords); + } + } + } + } + + fn collect_join_keywords_inner(&self, node: Node<'a>, keywords: &mut Vec) { + if node.kind() == "joined_table" { + // Get the join keyword. + let join_kw = self.get_join_keyword(node); + keywords.push(join_kw); + + // Check for ON/USING. + if let Some(qual) = node.find_child("join_qual") { + if qual.has_child("kw_on") { + keywords.push(self.kw("ON")); + // Collect AND keywords from the ON condition. + if let Some(expr) = qual.find_child_any(&["a_expr", "c_expr"]) { + self.collect_condition_keywords_from_expr(expr, keywords); + } + } else if qual.has_child("kw_using") { + keywords.push(self.kw("USING")); + } + } + + // Recurse into left side. 
+ let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + if child.kind() == "table_ref" + && let Some(jt) = child.find_child("joined_table") + { + self.collect_join_keywords_inner(jt, keywords); + } + } + } + } + + fn collect_condition_keywords(&self, clause_node: Node<'a>, keywords: &mut Vec) { + if let Some(expr) = clause_node.find_child_any(&["a_expr", "c_expr"]) { + self.collect_condition_keywords_from_expr(expr, keywords); + } + } + + fn collect_condition_keywords_from_expr(&self, node: Node<'a>, keywords: &mut Vec) { + if node.kind() == "a_expr" { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "kw_and" => keywords.push(self.kw("AND")), + "kw_or" => keywords.push(self.kw("OR")), + _ => {} + } + } + } + } + + /// Create a river-aligned line: right-align keyword in the given width. + pub(crate) fn river_line(&self, keyword: &str, content: &str, width: usize) -> String { + let padding = if keyword.len() < width { + " ".repeat(width - keyword.len()) + } else { + String::new() + }; + let first_line = format!("{padding}{keyword} "); + if content.contains('\n') { + // Multi-line content (e.g. subqueries): indent continuation lines + // to align with the content start column. + let indent = " ".repeat(first_line.len()); + let mut lines = content.lines(); + let mut result = format!("{first_line}{}", lines.next().unwrap_or("")); + for line in lines { + result.push('\n'); + result.push_str(&indent); + result.push_str(line); + } + result + } else { + format!("{first_line}{content}") + } + } + + /// Append target list items in river style. 
+ fn append_river_targets( + &self, + select_kw: &str, + targets: &[Node<'a>], + width: usize, + lines: &mut Vec, + ) { + if targets.is_empty() { + lines.push(self.river_line(select_kw, "*", width)); + return; + } + + let first = self.format_target_el(targets[0]); + if targets.len() == 1 { + lines.push(self.river_line(select_kw, &first, width)); + return; + } + + // First item on the SELECT line. + if self.config.leading_commas { + lines.push(self.river_line(select_kw, &first, width)); + let content_col = width + 1; // where content starts + for target in &targets[1..] { + let formatted = self.format_target_el(*target); + let padding = " ".repeat(content_col - 2); + if formatted.contains('\n') { + let mut target_lines = formatted.lines(); + let first_target_line = target_lines.next().unwrap_or(""); + lines.push(format!("{padding}, {first_target_line}")); + let cont_padding = " ".repeat(content_col); + for line in target_lines { + lines.push(format!("{cont_padding}{line}")); + } + } else { + lines.push(format!("{padding}, {formatted}")); + } + } + } else { + lines.push(self.river_line(select_kw, &format!("{first},"), width)); + let content_col = width + 1; + for (i, target) in targets[1..].iter().enumerate() { + let formatted = self.format_target_el(*target); + let padding = " ".repeat(content_col); + let suffix = if i < targets.len() - 2 { "," } else { "" }; + if formatted.contains('\n') { + // Multi-line target: indent continuation lines. + let mut target_lines = formatted.lines(); + let first_target_line = target_lines.next().unwrap_or(""); + lines.push(format!("{padding}{first_target_line}")); + for line in target_lines { + lines.push(format!("{padding}{line}")); + } + // Add suffix to last line. 
+ if !suffix.is_empty() + && let Some(last) = lines.last_mut() + { + last.push_str(suffix); + } + } else { + lines.push(format!("{padding}{formatted}{suffix}")); + } + } + } + } + + fn format_distinct(&self, node: Node<'a>) -> String { + let kw = self.kw("DISTINCT"); + // Check for DISTINCT ON (expr_list). + if let Some(list) = node.find_child("expr_list") { + let items = flatten_list(list, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + format!("{kw} ON ({})", formatted.join(", ")) + } else { + kw + } + } + + fn format_from_river(&self, from_node: Node<'a>, width: usize, lines: &mut Vec) { + if let Some(from_list) = from_node.find_child("from_list") { + let tables = flatten_list(from_list, "from_list"); + if tables.is_empty() { + return; + } + let has_multiple_non_join = tables + .iter() + .filter(|t| !(t.kind() == "table_ref" && t.has_child("joined_table"))) + .count() + > 1; + for (i, table) in tables.iter().enumerate() { + if table.kind() == "table_ref" && table.has_child("joined_table") { + // Table with JOINs. + let jt = table.find_child("joined_table").unwrap(); + self.format_joined_table_river(jt, width, i == 0, lines); + } else { + let mut text = self.format_table_ref(*table); + // Append comma between non-join tables in a comma-separated FROM list. + if has_multiple_non_join && i < tables.len() - 1 { + text = format!("{text},"); + } + if i == 0 { + lines.push(self.river_line(&self.kw("FROM"), &text, width)); + } else { + let content_col = width + 1; + let padding = " ".repeat(content_col); + lines.push(format!("{padding}{text}")); + } + } + } + } + } + + fn format_joined_table_river( + &self, + node: Node<'a>, + width: usize, + is_first_from: bool, + lines: &mut Vec, + ) { + // A joined_table has: left table_ref, join_type, JOIN, right table_ref, join_qual. + // Recursion: left table_ref may itself contain a joined_table. + let named = node.named_children_vec(); + + // Find components. 
+ let mut left_table: Option = None; + let mut right_table: Option = None; + let mut join_type_node: Option = None; + let mut join_qual_node: Option = None; + let mut table_count = 0; + + for child in &named { + match child.kind() { + "table_ref" => { + if table_count == 0 { + left_table = Some(*child); + } else { + right_table = Some(*child); + } + table_count += 1; + } + "join_type" => join_type_node = Some(*child), + "join_qual" => join_qual_node = Some(*child), + _ => {} + } + } + + // Format left side first (may be recursive). + if let Some(left) = left_table { + if let Some(inner_jt) = left.find_child("joined_table") { + self.format_joined_table_river(inner_jt, width, is_first_from, lines); + } else { + let text = self.format_table_ref(left); + if is_first_from { + lines.push(self.river_line(&self.kw("FROM"), &text, width)); + } else { + let content_col = width + 1; + lines.push(format!("{}{text}", " ".repeat(content_col))); + } + } + } + + // Format the JOIN keyword and right table. + let join_kw = self.get_join_keyword(node); + + if let Some(right) = right_table { + let right_text = self.format_table_ref(right); + + if self.config.joins_in_river { + // JOINs participate in the river (AWeber/mattmc3 style). + lines.push(self.river_line(&join_kw, &right_text, width)); + } else if join_type_node.is_some() { + // Typed JOINs (INNER/LEFT/etc.) are indented under FROM content. + let content_col = width + 1; + let padding = " ".repeat(content_col); + // Blank line between typed JOINs (not before the first one). + // Detect prior JOIN by checking if any previous line contains a JOIN keyword. + let has_prior_join = lines.iter().any(|l| { + let trimmed = l.trim(); + trimmed.contains("JOIN ") + }); + if has_prior_join { + lines.push(String::new()); + } + lines.push(format!("{padding}{join_kw} {right_text}")); + } else { + // Plain JOIN is river-aligned (like FROM). + lines.push(self.river_line(&join_kw, &right_text, width)); + } + } + + // Format ON/USING. 
+ let is_typed_join = join_type_node.is_some(); + if let Some(qual) = join_qual_node { + self.format_join_qual_river(qual, width, is_typed_join, lines); + } + } + + fn format_join_qual_river( + &self, + node: Node<'a>, + width: usize, + is_typed_join: bool, + lines: &mut Vec, + ) { + if let Some(on_expr) = node.find_child_any(&["a_expr", "c_expr"]) { + if self.config.joins_in_river { + // AWeber/mattmc3: ON participates in the river. + let conditions = self.split_top_level_conditions(on_expr); + if conditions.len() <= 1 { + lines.push(self.river_line(&self.kw("ON"), &conditions[0].1, width)); + } else { + lines.push(self.river_line(&self.kw("ON"), &conditions[0].1, width)); + for (op, cond_text) in &conditions[1..] { + lines.push(self.river_line(op, cond_text, width)); + } + } + } else if is_typed_join { + // Typed JOINs (INNER/LEFT/etc.): ON indented under FROM content. + let content_col = width + 1; + let conditions = self.split_top_level_conditions(on_expr); + let on_kw = self.kw("ON"); + let padding = " ".repeat(content_col); + lines.push(format!("{padding}{on_kw} {}", conditions[0].1)); + if conditions.len() > 1 { + // AND aligns with the start of ON's condition text. + let and_indent = " ".repeat(content_col + on_kw.len() + 1); + for (op, cond_text) in &conditions[1..] { + lines.push(format!("{and_indent}{op} {cond_text}")); + } + } + } else { + // Plain JOIN: ON is river-aligned. + let conditions = self.split_top_level_conditions(on_expr); + lines.push(self.river_line(&self.kw("ON"), &conditions[0].1, width)); + if conditions.len() > 1 { + for (op, cond_text) in &conditions[1..] { + lines.push(self.river_line(op, cond_text, width)); + } + } + } + } else if node.has_child("kw_using") { + // USING clause. 
+ let using_text = self.format_using_clause(node); + if self.config.joins_in_river { + lines.push(self.river_line(&self.kw("USING"), &using_text, width)); + } else if is_typed_join { + let content_col = width + 1; + let padding = " ".repeat(content_col); + lines.push(format!("{padding}{} {using_text}", self.kw("USING"))); + } else { + lines.push(self.river_line(&self.kw("USING"), &using_text, width)); + } + } + } + + fn format_using_clause(&self, node: Node<'a>) -> String { + if let Some(list) = node.find_child("name_list") { + let items = flatten_list(list, "name_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + format!("({})", formatted.join(", ")) + } else if let Some(list) = node.find_child("columnList") { + let items = flatten_list(list, "columnList"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + format!("({})", formatted.join(", ")) + } else { + // Fallback. + let text = self.text(node); + if let Some(start) = text.find('(') { + text[start..].to_string() + } else { + text.to_string() + } + } + } + + pub(crate) fn format_where_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + self.format_condition_clause_river(node, "WHERE", width, lines); + } + + fn format_having_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + self.format_condition_clause_river(node, "HAVING", width, lines); + } + + /// Shared river-style formatting for WHERE and HAVING clauses. + fn format_condition_clause_river( + &self, + node: Node<'a>, + keyword: &str, + width: usize, + lines: &mut Vec, + ) { + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + lines.push(self.river_line(&self.kw(keyword), &conditions[0].1, width)); + for (op, cond_text) in &conditions[1..] 
{ + lines.push(self.river_line(op, cond_text, width)); + } + } + } + + fn format_group_by_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + let kw = self.kw_pair("GROUP", "BY"); + if let Some(list) = node.find_child("group_by_list") { + let items = flatten_list(list, "group_by_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + lines.push(self.river_line(&kw, &formatted.join(", "), width)); + } + } + + fn format_order_by_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + let kw = self.kw_pair("ORDER", "BY"); + if let Some(list) = node.find_child("sortby_list") { + let items = flatten_list(list, "sortby_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_sortby(*i)).collect(); + lines.push(self.river_line(&kw, &formatted.join(", "), width)); + } + } + + fn format_limit_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + let value = self.extract_limit_value(node); + if !value.is_empty() { + lines.push(self.river_line(&self.kw("LIMIT"), &value, width)); + } + } + + fn format_offset_river(&self, node: Node<'a>, width: usize, lines: &mut Vec) { + let value = self.extract_offset_value(node); + if !value.is_empty() { + lines.push(self.river_line(&self.kw("OFFSET"), &value, width)); + } + } + + fn extract_limit_value(&self, node: Node<'a>) -> String { + if let Some(val) = node.find_child("select_limit_value") { + return self.format_expr(val); + } + // Try finding the expression directly. 
+ if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + return self.format_expr(expr); + } + String::new() + } + + fn extract_offset_value(&self, node: Node<'a>) -> String { + if let Some(val) = node.find_child("select_offset_value") { + return self.format_expr(val); + } + if let Some(val) = node.find_child("select_fetch_first_value") { + return self.format_expr(val); + } + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + return self.format_expr(expr); + } + String::new() + } + + /// Split a WHERE/HAVING expression into individual conditions separated by AND/OR. + /// + /// Returns a vector of (operator, formatted_text) pairs. + /// The first entry has operator "". + /// + /// The tree-sitter grammar can nest AND/OR deeper than expected due to + /// operator precedence. This method formats the full expression first, + /// then splits the resulting text on top-level AND/OR boundaries. + pub(crate) fn split_top_level_conditions(&self, node: Node<'a>) -> Vec<(String, String)> { + let full_text = self.format_expr(node); + let and_kw = self.kw("AND"); + let or_kw = self.kw("OR"); + + // Split on AND/OR that appear as whole words outside strings and parens. + // Handles: single-quoted strings (''), double-quoted identifiers (""), + // dollar-quoted strings ($$..$$, $tag$..$tag$), E'...' escape strings, + // and BETWEEN...AND (the AND in BETWEEN is not a boolean operator). + let between_kw = self.kw("BETWEEN"); + let mut conditions = Vec::new(); + let mut current = String::new(); + let mut current_op = String::new(); + let mut paren_depth: u32 = 0; + let mut in_between = false; + let chars: Vec = full_text.chars().collect(); + let len = chars.len(); + let mut i = 0; + let mut buf = String::new(); + + while i < len { + let ch = chars[i]; + + // Single-quoted string: 'text' with '' as escape. 
+ if ch == '\'' { + buf.push(ch); + i += 1; + while i < len { + let c = chars[i]; + buf.push(c); + i += 1; + if c == '\'' { + if i < len && chars[i] == '\'' { + buf.push('\''); + i += 1; + } else { + break; + } + } + } + continue; + } + + // Double-quoted identifier: "name". + if ch == '"' { + buf.push(ch); + i += 1; + while i < len { + let c = chars[i]; + buf.push(c); + i += 1; + if c == '"' { + if i < len && chars[i] == '"' { + buf.push('"'); + i += 1; + } else { + break; + } + } + } + continue; + } + + // Dollar-quoted string: $$...$$ or $tag$...$tag$. + if ch == '$' { + let tag_start = i; + let mut tag_end = i + 1; + while tag_end < len + && (chars[tag_end].is_ascii_alphanumeric() || chars[tag_end] == '_') + { + tag_end += 1; + } + if tag_end < len && chars[tag_end] == '$' { + let tag: String = chars[tag_start..=tag_end].iter().collect(); + buf.push_str(&tag); + i = tag_end + 1; + // Scan for closing tag. + while i < len { + if chars[i] == '$' { + let remaining: String = chars[i..].iter().collect(); + if remaining.starts_with(&tag) { + buf.push_str(&tag); + i += tag.len(); + break; + } + } + buf.push(chars[i]); + i += 1; + } + continue; + } + // Not a dollar-quote; fall through. + } + + buf.push(ch); + i += 1; + + if ch == '(' { + paren_depth = paren_depth.saturating_add(1); + } else if ch == ')' { + paren_depth = paren_depth.saturating_sub(1); + } + if paren_depth > 0 { + continue; + } + + // Check for keyword at word boundary (next char is space or end). + let next_is_boundary = i >= len || chars[i] == ' '; + if !next_is_boundary { + continue; + } + + // Detect BETWEEN keyword (to skip the AND in BETWEEN...AND). + if buf.ends_with(&format!(" {between_kw}")) { + in_between = true; + continue; + } + + let is_and = buf.ends_with(&format!(" {and_kw}")); + let is_or = !is_and && buf.ends_with(&format!(" {or_kw}")); + + if is_and || is_or { + // Skip the AND that belongs to BETWEEN...AND. 
+ if is_and && in_between { + in_between = false; + continue; + } + + let kw = if is_and { &and_kw } else { &or_kw }; + let kw_with_space = kw.len() + 1; // " " + kw + let cond_text = &buf[..buf.len() - kw_with_space]; + current.push_str(cond_text); + conditions.push((current_op.clone(), current.trim().to_string())); + current = String::new(); + current_op = kw.to_string(); + buf.clear(); + // Skip space after the keyword. + if i < len && chars[i] == ' ' { + i += 1; + } + } + } + // Push remaining text. + current.push_str(&buf); + let remaining = current.trim().to_string(); + if !remaining.is_empty() { + conditions.push((current_op, remaining)); + } + + if conditions.len() <= 1 { + vec![(String::new(), full_text)] + } else { + conditions + } + } + + /// Get the JOIN keyword for a joined_table node. + pub(crate) fn get_join_keyword(&self, node: Node<'a>) -> String { + let join_type = node.find_child("join_type"); + + if let Some(jt) = join_type { + let mut parts = Vec::new(); + let mut cursor = jt.walk(); + for child in jt.named_children(&mut cursor) { + match child.kind() { + "kw_inner" => { + if !self.config.strip_inner_join { + parts.push(self.kw("INNER")); + } + } + "kw_left" => parts.push(self.kw("LEFT")), + "kw_right" => parts.push(self.kw("RIGHT")), + "kw_full" => parts.push(self.kw("FULL")), + "kw_cross" => parts.push(self.kw("CROSS")), + "kw_natural" => parts.push(self.kw("NATURAL")), + "kw_outer" => {} // implicit, skip + _ => parts.push(self.format_keyword_node(child)), + } + } + parts.push(self.kw("JOIN")); + parts.join(" ") + } else { + // Plain JOIN — decide whether to make it INNER JOIN or leave as JOIN. 
+ if self.config.explicit_inner_join { + format!("{} {}", self.kw("INNER"), self.kw("JOIN")) + } else { + self.kw("JOIN") + } + } + } + + // ── Left-aligned SELECT (Mozilla, dbt, GitLab, Kickstarter) ───────── + + fn format_select_left_aligned(&self, clauses: &SelectClauses<'a>) -> String { + let mut lines = Vec::new(); + let indent = self.config.indent; + let blank = self.config.blank_lines_between_clauses; + + // WITH clause. + if let Some(with) = clauses.with_clause { + lines.push(self.format_with_clause_left(with)); + // Blank line after CTE block before SELECT for all styles + // except compact CTEs (Kickstarter). + if !self.config.compact_ctes { + lines.push(String::new()); + } + } + + // SELECT [DISTINCT] targets. + let select_kw = if clauses.distinct.is_some() { + let distinct_text = clauses + .distinct + .map(|d| self.format_distinct(d)) + .unwrap_or_else(|| self.kw("DISTINCT")); + format!("{} {}", self.kw("SELECT"), distinct_text) + } else { + self.kw("SELECT") + }; + + if clauses.targets.len() <= 1 { + let target_text = clauses + .targets + .first() + .map(|t| self.format_target_el(*t)) + .unwrap_or_else(|| "*".to_string()); + lines.push(format!("{select_kw} {target_text}")); + } else { + lines.push(select_kw); + for (i, target) in clauses.targets.iter().enumerate() { + let formatted = self.format_target_el(*target); + if i < clauses.targets.len() - 1 { + lines.push(format!("{indent}{formatted},")); + } else { + lines.push(format!("{indent}{formatted}")); + } + } + } + + // FROM clause. + if let Some(from) = clauses.from { + if blank { + lines.push(String::new()); + } + self.format_from_left_aligned(from, &mut lines); + } + + // WHERE clause. + if let Some(where_c) = clauses.where_clause { + if blank { + lines.push(String::new()); + } + self.format_where_left_aligned(where_c, &mut lines); + } + + // GROUP BY. 
+ if let Some(group) = clauses.group_clause { + if blank { + lines.push(String::new()); + } + self.format_group_by_left_aligned(group, &mut lines); + } + + // HAVING. + if let Some(having) = clauses.having_clause { + if blank { + lines.push(String::new()); + } + self.format_having_left_aligned(having, &mut lines); + } + + // ORDER BY. + if let Some(sort) = clauses.sort_clause { + if blank { + lines.push(String::new()); + } + self.format_order_by_left_aligned(sort, &mut lines); + } + + // LIMIT. + if let Some(limit) = clauses.limit_clause { + if blank { + lines.push(String::new()); + } + let value = self.extract_limit_value(limit); + if !value.is_empty() { + lines.push(format!("{} {value}", self.kw("LIMIT"))); + } + } + + // OFFSET. + if let Some(offset) = clauses.offset_clause { + let value = self.extract_offset_value(offset); + if !value.is_empty() { + lines.push(format!("{} {value}", self.kw("OFFSET"))); + } + } + + let mut result = lines.join("\n"); + + // Set operations. + if let Some(ref set_op) = clauses.set_op { + let op_text = if let Some(ref q) = set_op.quantifier { + format!("{} {q}", set_op.keyword) + } else { + set_op.keyword.clone() + }; + if blank { + result.push_str("\n\n"); + } else { + result.push('\n'); + } + result.push_str(&op_text); + result.push('\n'); + if let Some(ref rc) = set_op.right_clauses { + result.push_str(&self.format_select_left_aligned(rc)); + } else { + let right_clauses = self.collect_select_clauses(set_op.right); + result.push_str(&self.format_select_left_aligned(&right_clauses)); + } + } + + result + } + + fn format_from_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + let indent = self.config.indent; + if let Some(from_list) = node.find_child("from_list") { + let tables = flatten_list(from_list, "from_list"); + if tables.is_empty() { + return; + } + + let has_multiple_non_join = tables + .iter() + .filter(|t| !(t.kind() == "table_ref" && t.has_child("joined_table"))) + .count() + > 1; + + // Check if any table has 
joins. + let first = tables[0]; + if first.kind() == "table_ref" && first.has_child("joined_table") { + let jt = first.find_child("joined_table").unwrap(); + self.format_joined_table_left_aligned(jt, lines, true); + } else { + let mut text = self.format_table_ref(first); + if has_multiple_non_join && tables.len() > 1 { + text = format!("{text},"); + } + if tables.len() == 1 && !first.has_child("joined_table") { + lines.push(format!("{} {text}", self.kw("FROM"))); + } else { + lines.push(self.kw("FROM")); + lines.push(format!("{indent}{text}")); + } + } + + for (i, table) in tables[1..].iter().enumerate() { + if table.kind() == "table_ref" && table.has_child("joined_table") { + let jt = table.find_child("joined_table").unwrap(); + self.format_joined_table_left_aligned(jt, lines, false); + } else { + let mut text = self.format_table_ref(*table); + // Append comma if not the last non-join table. + if has_multiple_non_join && i < tables.len() - 2 { + text = format!("{text},"); + } + lines.push(format!("{indent}{text}")); + } + } + } + } + + fn format_joined_table_left_aligned( + &self, + node: Node<'a>, + lines: &mut Vec, + is_first: bool, + ) { + let indent = self.config.indent; + let named = node.named_children_vec(); + + let mut left_table: Option = None; + let mut right_table: Option = None; + let mut join_qual_node: Option = None; + let mut table_count = 0; + + for child in &named { + match child.kind() { + "table_ref" => { + if table_count == 0 { + left_table = Some(*child); + } else { + right_table = Some(*child); + } + table_count += 1; + } + "join_qual" => join_qual_node = Some(*child), + _ => {} + } + } + + // Format left side. + if let Some(left) = left_table { + if let Some(inner_jt) = left.find_child("joined_table") { + self.format_joined_table_left_aligned(inner_jt, lines, is_first); + } else { + let text = self.format_table_ref(left); + if is_first { + // GitLab/Kickstarter: FROM and table on same line. 
+ // Mozilla/dbt: FROM on its own line, table indented. + if self.style == Style::Gitlab || self.style == Style::Kickstarter { + lines.push(format!("{} {text}", self.kw("FROM"))); + } else { + lines.push(self.kw("FROM")); + lines.push(format!("{indent}{text}")); + } + } else { + lines.push(format!("{indent}{text}")); + } + } + } + + // Format JOIN. + let join_kw = self.get_join_keyword(node); + + if let Some(right) = right_table { + let right_text = self.format_table_ref(right); + if self.config.join_on_same_line { + // Kickstarter: JOIN ... ON on same line. + let mut join_line = format!("{join_kw} {right_text}"); + if let Some(qual) = join_qual_node { + let qual_text = self.format_join_qual_inline(qual); + join_line.push_str(&format!(" {qual_text}")); + // If there are multiple AND conditions, wrap them. + lines.push(join_line); + // Additional conditions on indented lines. + self.format_extra_join_conditions(qual, lines); + return; + } + lines.push(join_line); + } else { + lines.push(join_kw.to_string()); + lines.push(format!("{indent}{right_text}")); + } + } + + // Format ON/USING (non-Kickstarter). 
+ if !self.config.join_on_same_line + && let Some(qual) = join_qual_node + { + self.format_join_qual_left_aligned(qual, lines); + } + } + + fn format_join_qual_inline(&self, node: Node<'a>) -> String { + if node.has_child("kw_on") { + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + if !conditions.is_empty() { + return format!("{} {}", self.kw("ON"), conditions[0].1); + } + } + } else if node.has_child("kw_using") { + let using_text = self.format_using_clause(node); + return format!("{} {using_text}", self.kw("USING")); + } + String::new() + } + + fn format_extra_join_conditions(&self, node: Node<'a>, lines: &mut Vec) { + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + if conditions.len() > 1 { + let indent = self.config.indent; + for (op, cond_text) in &conditions[1..] { + lines.push(format!("{indent}{op} {cond_text}")); + } + } + } + } + + fn format_join_qual_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + let indent = self.config.indent; + if node.has_child("kw_on") { + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + if conditions.len() <= 1 { + lines.push(format!("{indent}{} {}", self.kw("ON"), conditions[0].1)); + } else { + lines.push(format!("{indent}{} {}", self.kw("ON"), conditions[0].1)); + for (op, cond_text) in &conditions[1..] 
{ + lines.push(format!("{indent}{op} {cond_text}")); + } + } + } + } else if node.has_child("kw_using") { + let using_text = self.format_using_clause(node); + lines.push(format!("{indent}{} {using_text}", self.kw("USING"))); + } + } + + pub(crate) fn format_where_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + let indent = self.config.indent; + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + if conditions.len() <= 1 { + let text = &conditions[0].1; + if self.style == Style::Mozilla + || self.style == Style::Dbt + || self.style == Style::Gitlab + || self.style == Style::Kickstarter + { + lines.push(self.kw("WHERE")); + Self::push_indented_multiline(lines, indent, text); + } else { + lines.push(format!("{} {text}", self.kw("WHERE"))); + } + } else { + lines.push(self.kw("WHERE")); + Self::push_indented_multiline(lines, indent, &conditions[0].1); + for (op, cond_text) in &conditions[1..] { + Self::push_indented_multiline(lines, indent, &format!("{op} {cond_text}")); + } + } + } + } + + /// Push a potentially multi-line text with each line prefixed by indent. + fn push_indented_multiline(lines: &mut Vec, indent: &str, text: &str) { + for line in text.lines() { + if line.is_empty() { + lines.push(String::new()); + } else { + lines.push(format!("{indent}{line}")); + } + } + } + + fn format_having_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + self.format_condition_clause_left_aligned(node, "HAVING", lines); + } + + /// Shared left-aligned formatting for HAVING (and potentially other) clauses. 
+ fn format_condition_clause_left_aligned( + &self, + node: Node<'a>, + keyword: &str, + lines: &mut Vec, + ) { + let indent = self.config.indent; + if let Some(expr) = node.find_child_any(&["a_expr", "c_expr"]) { + let conditions = self.split_top_level_conditions(expr); + lines.push(self.kw(keyword)); + lines.push(format!("{indent}{}", conditions[0].1)); + for (op, cond_text) in &conditions[1..] { + lines.push(format!("{indent}{op} {cond_text}")); + } + } + } + + fn format_group_by_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + let kw = self.kw_pair("GROUP", "BY"); + if let Some(list) = node.find_child("group_by_list") { + let items = flatten_list(list, "group_by_list"); + if items.len() <= 1 || self.style == Style::Kickstarter { + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + lines.push(format!("{kw} {}", formatted.join(", "))); + } else { + // dbt/GitLab: each on its own line. + lines.push(kw); + let indent = self.config.indent; + for (i, item) in items.iter().enumerate() { + let formatted = self.format_expr(*item); + if i < items.len() - 1 { + lines.push(format!("{indent}{formatted},")); + } else { + lines.push(format!("{indent}{formatted}")); + } + } + } + } + } + + fn format_order_by_left_aligned(&self, node: Node<'a>, lines: &mut Vec) { + let kw = self.kw_pair("ORDER", "BY"); + if let Some(list) = node.find_child("sortby_list") { + let items = flatten_list(list, "sortby_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_sortby(*i)).collect(); + lines.push(format!("{kw} {}", formatted.join(", "))); + } + } + + // ── WITH / CTE formatting ─────────────────────────────────────────── + + fn format_with_clause_river(&self, node: Node<'a>, river_width: usize) -> String { + let mut lines = Vec::new(); + if let Some(cte_list) = node.find_child("cte_list") { + let ctes = flatten_list(cte_list, "cte_list"); + for (i, cte) in ctes.iter().enumerate() { + let cte_text = self.format_cte_river(*cte, river_width); 
+ if i == 0 { + lines.push(format!("{} {cte_text}", self.kw("WITH"))); + } else { + lines.push(cte_text); + } + } + } + lines.join(",\n") + } + + fn format_cte_river(&self, node: Node<'a>, _river_width: usize) -> String { + let name = node + .find_child("name") + .map(|n| self.format_expr(n)) + .unwrap_or_default(); + + let body = self.format_cte_body(node); + + format!("{name} {} (\n{body}\n)", self.kw("AS")) + } + + /// Extract and format the body of a CTE, handling SELECT, INSERT, UPDATE, + /// DELETE, and any other PreparableStmt type. + fn format_cte_body(&self, node: Node<'a>) -> String { + if let Some(prep) = node.find_child("PreparableStmt") { + if let Some(select) = prep.find_child("SelectStmt") { + return self.format_select_stmt(select); + } + if let Some(insert) = prep.find_child("InsertStmt") { + return self.format_insert_stmt(insert); + } + if let Some(update) = prep.find_child("UpdateStmt") { + return self.format_update_stmt(update); + } + if let Some(delete) = prep.find_child("DeleteStmt") { + return self.format_delete_stmt(delete); + } + // Fallback: return the raw text of the PreparableStmt. + return self.text(prep).trim().to_string(); + } + String::new() + } + + fn format_with_clause_left(&self, node: Node<'a>) -> String { + let mut lines = Vec::new(); + let indent = self.config.indent; + let blank_in_ctes = self.config.blank_lines_in_ctes; + + if let Some(cte_list) = node.find_child("cte_list") { + let ctes = flatten_list(cte_list, "cte_list"); + + if self.config.blank_lines_between_clauses { + // dbt style: with\n\nname as (\n...) 
+ lines.push(format!("{}\n", self.kw("with"))); + } + + for (i, cte) in ctes.iter().enumerate() { + let name = cte + .find_child("name") + .map(|n| self.format_expr(n)) + .unwrap_or_default(); + + let body = self.format_cte_body(*cte); + + let indented_body = body + .lines() + .map(|l| { + if l.is_empty() { + String::new() + } else { + format!("{indent}{l}") + } + }) + .collect::>() + .join("\n"); + + let cte_prefix = if self.config.compact_ctes && i > 0 { + format!("), {name} {} (", self.kw("AS")) + } else { + let as_line = format!("{name} {} (", self.kw("AS")); + if i == 0 && !self.config.blank_lines_between_clauses { + format!("{} {as_line}", self.kw("WITH")) + } else { + as_line + } + }; + + if self.config.compact_ctes && i > 0 { + lines.push(cte_prefix); + } else { + if i > 0 && !self.config.compact_ctes { + // Close previous CTE. + // Already handled below. + } + lines.push(cte_prefix); + } + + if blank_in_ctes || self.config.blank_lines_between_clauses { + lines.push(String::new()); + } + lines.push(indented_body); + if blank_in_ctes || self.config.blank_lines_between_clauses { + lines.push(String::new()); + } + + if !self.config.compact_ctes { + lines.push(")".to_string()); + } + } + + if self.config.compact_ctes { + lines.push(")".to_string()); + } + } + lines.join("\n") + } + + // ── VALUES clause ─────────────────────────────────────────────────── + + pub(crate) fn format_values_clause(&self, node: Node<'a>) -> String { + // values_clause contains multiple (expr_list) groups. + let mut value_groups = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "values_clause" => { + // Recursive left-linked. 
+ let inner_groups = self.collect_value_groups(child); + value_groups.extend(inner_groups); + } + "expr_list" => { + let items = flatten_list(child, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + value_groups.push(format!("({})", formatted.join(", "))); + } + _ => {} + } + } + let kw = self.kw("VALUES"); + if self.config.leading_commas { + let mut lines = vec![format!("{kw} {}", value_groups[0])]; + for group in &value_groups[1..] { + let padding = " ".repeat(kw.len() + 1 - 2); + lines.push(format!("{padding}, {group}")); + } + lines.join("\n") + } else { + let mut lines = Vec::new(); + if self.config.river { + // River style: VALUES aligned. + for (i, group) in value_groups.iter().enumerate() { + if i == 0 { + if value_groups.len() > 1 { + lines.push(format!("{kw} {group},")); + } else { + lines.push(format!("{kw} {group}")); + } + } else { + let padding = " ".repeat(kw.len() + 1); + if i < value_groups.len() - 1 { + lines.push(format!("{padding}{group},")); + } else { + lines.push(format!("{padding}{group}")); + } + } + } + } else { + // Left-aligned: VALUES on its own line. 
+ lines.push(kw); + let indent = self.config.indent; + for (i, group) in value_groups.iter().enumerate() { + if i < value_groups.len() - 1 { + lines.push(format!("{indent}{group},")); + } else { + lines.push(format!("{indent}{group}")); + } + } + } + lines.join("\n") + } + } + + fn collect_value_groups(&self, node: Node<'a>) -> Vec { + let mut groups = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "values_clause" => { + groups.extend(self.collect_value_groups(child)); + } + "expr_list" => { + let items = flatten_list(child, "expr_list"); + let formatted: Vec<_> = items.iter().map(|i| self.format_expr(*i)).collect(); + groups.push(format!("({})", formatted.join(", "))); + } + _ => {} + } + } + groups + } +} diff --git a/src/formatter/stmt.rs b/src/formatter/stmt.rs new file mode 100644 index 0000000..3197a0f --- /dev/null +++ b/src/formatter/stmt.rs @@ -0,0 +1,976 @@ +/// Statement-level formatting: dispatches to specific statement formatters. +use crate::error::FormatError; +use crate::node_helpers::{NodeExt, flatten_list}; +use tree_sitter::Node; + +use super::Formatter; + +/// Classification of table elements for river-style CREATE TABLE. +enum TableElementKind { + /// PRIMARY KEY constraint (should be first). + PrimaryKey(String), + /// Column definition: (name, typename, constraints_text). + Column(String, String, String), + /// Table constraint: (optional_name, body). + Constraint(Option, String), +} + +impl<'a> Formatter<'a> { + /// Format a `stmt` node, dispatching based on the statement type. 
+ pub(crate) fn format_stmt(&self, node: Node<'a>) -> Result { + let mut cursor = node.walk(); + if let Some(child) = node.named_children(&mut cursor).next() { + let result = match child.kind() { + "SelectStmt" => self.format_select_stmt(child), + "InsertStmt" => self.format_insert_stmt(child), + "UpdateStmt" => self.format_update_stmt(child), + "DeleteStmt" => self.format_delete_stmt(child), + "CreateStmt" => self.format_create_table_stmt(child), + "ViewStmt" => self.format_view_stmt(child), + "CreateFunctionStmt" => self.format_create_function_stmt(child), + "CreateDomainStmt" => self.format_create_domain_stmt(child), + "CreateForeignTableStmt" => self.format_create_foreign_table_stmt(child), + "CreateTableAsStmt" | "CreateMatViewStmt" => { + self.format_create_table_as_stmt(child) + } + _ => { + let text = self.text(child); + normalize_whitespace(text) + } + }; + let trimmed = result.trim_end_matches(';'); + // If the last line contains a line comment (--), appending ; + // directly would put the semicolon inside the comment. + let needs_newline = trimmed + .lines() + .last() + .map(|line| line.contains("--")) + .unwrap_or(false); + return if needs_newline { + Ok(format!("{trimmed}\n;")) + } else { + Ok(format!("{trimmed};")) + }; + } + Ok(String::new()) + } + + // ── INSERT ────────────────────────────────────────────────────────── + + pub(crate) fn format_insert_stmt(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + + // INSERT INTO target. + let target = node + .find_child("insert_target") + .map(|n| self.format_qualified_name_from(n)) + .unwrap_or_default(); + parts.push(format!( + "{} {} {target}", + self.kw("INSERT"), + self.kw("INTO") + )); + + // Column list. 
+ let insert_rest = node.find_child("insert_rest"); + if let Some(rest) = insert_rest { + if let Some(col_list) = rest.find_child("insert_column_list") { + let cols = flatten_list(col_list, "insert_column_list"); + let formatted: Vec<_> = cols.iter().map(|c| self.format_expr(*c)).collect(); + parts[0] = format!("{} ({})", parts[0], formatted.join(", ")); + } + + // VALUES or SELECT. + if let Some(select) = rest.find_child("SelectStmt") { + let formatted = self.format_select_stmt(select); + // Check if it's VALUES or a sub-SELECT. + let select_text = formatted.trim_end_matches(';'); + let values_kw = self.kw("VALUES"); + let is_values = select_text.trim_start().starts_with(&values_kw); + + if is_values && self.config.river { + // River: VALUES aligned with INSERT INTO. + // Compute padding to right-align VALUES with INSERT INTO. + let insert_kw_len = parts[0] + .split(' ') + .take(2) + .collect::>() + .join(" ") + .len(); + let river_width = std::cmp::max(insert_kw_len, values_kw.len()); + + // Strip VALUES keyword and any pre-existing indentation from + // the formatter's multi-line output. + let raw_content = select_text.trim_start_matches(&values_kw); + let trimmed_lines: Vec<_> = raw_content.lines().map(|l| l.trim()).collect(); + let content = trimmed_lines.join("\n"); + + if self.config.leading_commas && trimmed_lines.len() > 1 { + // For leading commas, handle continuation lines manually: + // the `, ` replaces 2 chars of the indent padding. + let kw_padding = if values_kw.len() < river_width { + " ".repeat(river_width - values_kw.len()) + } else { + String::new() + }; + let first_line_content = trimmed_lines[0].trim(); + parts.push(format!("{kw_padding}{values_kw} {first_line_content}")); + let content_col = river_width + 1; // where content starts + for line in &trimmed_lines[1..] { + let trimmed = line.trim(); + if trimmed.starts_with(',') { + // Leading comma: put it 2 chars before content col. 
+ let padding = " ".repeat(content_col - 2); + parts.push(format!("{padding}{trimmed}")); + } else if !trimmed.is_empty() { + let padding = " ".repeat(content_col); + parts.push(format!("{padding}{trimmed}")); + } + } + } else { + parts.push(self.river_line(&values_kw, content.trim(), river_width)); + } + } else { + // Real SELECT or non-river VALUES: emit as-is. + parts.push(select_text.to_string()); + } + } + } + + parts.join("\n") + } + + // ── UPDATE ────────────────────────────────────────────────────────── + + pub(crate) fn format_update_stmt(&self, node: Node<'a>) -> String { + let table = node + .find_child("relation_expr_opt_alias") + .map(|n| self.format_relation_expr_opt_alias(n)) + .unwrap_or_default(); + + let mut lines = Vec::new(); + + if self.config.river { + // Collect keywords for river width. + let mut keywords = vec![self.kw("UPDATE"), self.kw("SET")]; + if node.has_child("where_or_current_clause") { + keywords.push(self.kw("WHERE")); + } + let width = keywords.iter().map(|k| k.len()).max().unwrap_or(6); + + lines.push(self.river_line(&self.kw("UPDATE"), &table, width)); + + // SET clause. + if let Some(set_list) = node.find_child("set_clause_list") { + let clauses = flatten_list(set_list, "set_clause_list"); + let formatted: Vec<_> = + clauses.iter().map(|c| self.format_set_clause(*c)).collect(); + if formatted.len() == 1 { + lines.push(self.river_line(&self.kw("SET"), &formatted[0], width)); + } else if self.config.leading_commas { + // Leading commas: first item without comma, subsequent with leading ", ". + lines.push(self.river_line(&self.kw("SET"), &formatted[0], width)); + let content_col = width + 1; + for clause in &formatted[1..] 
{ + let padding = " ".repeat(content_col - 2); + lines.push(format!("{padding}, {clause}")); + } + } else { + lines.push(self.river_line( + &self.kw("SET"), + &format!("{},", formatted[0]), + width, + )); + let content_col = width + 1; + for (i, clause) in formatted[1..].iter().enumerate() { + let padding = " ".repeat(content_col); + if i < formatted.len() - 2 { + lines.push(format!("{padding}{clause},")); + } else { + lines.push(format!("{padding}{clause}")); + } + } + } + } + + // WHERE clause. + if let Some(where_c) = node.find_child("where_or_current_clause") { + self.format_where_river(where_c, width, &mut lines); + } + } else { + lines.push(format!("{} {table}", self.kw("UPDATE"))); + + // SET clause. + let indent = self.config.indent; + if let Some(set_list) = node.find_child("set_clause_list") { + let clauses = flatten_list(set_list, "set_clause_list"); + let formatted: Vec<_> = + clauses.iter().map(|c| self.format_set_clause(*c)).collect(); + lines.push(self.kw("SET")); + for (i, clause) in formatted.iter().enumerate() { + if i < formatted.len() - 1 { + lines.push(format!("{indent}{clause},")); + } else { + lines.push(format!("{indent}{clause}")); + } + } + } + + // WHERE clause. 
+ if let Some(where_c) = node.find_child("where_or_current_clause") { + self.format_where_left_aligned(where_c, &mut lines); + } + } + + lines.join("\n") + } + + fn format_set_clause(&self, node: Node<'a>) -> String { + let target = node + .find_child("set_target") + .map(|n| self.format_expr(n)) + .unwrap_or_default(); + let value = node + .find_child_any(&["a_expr", "c_expr"]) + .map(|n| self.format_expr(n)) + .unwrap_or_default(); + format!("{target} = {value}") + } + + // ── DELETE ────────────────────────────────────────────────────────── + + pub(crate) fn format_delete_stmt(&self, node: Node<'a>) -> String { + let table = node + .find_child("relation_expr_opt_alias") + .map(|n| self.format_relation_expr_opt_alias(n)) + .unwrap_or_default(); + + let mut lines = Vec::new(); + + if self.config.river { + let delete_kw = self.kw("DELETE"); + let mut keywords = vec![delete_kw.clone(), self.kw("FROM")]; + if node.has_child("where_or_current_clause") { + keywords.push(self.kw("WHERE")); + } + let width = keywords.iter().map(|k| k.len()).max().unwrap_or(6); + + lines.push(delete_kw); + lines.push(self.river_line(&self.kw("FROM"), &table, width)); + + if let Some(where_c) = node.find_child("where_or_current_clause") { + self.format_where_river(where_c, width, &mut lines); + } + } else { + lines.push(format!("{} {} {table}", self.kw("DELETE"), self.kw("FROM"))); + if let Some(where_c) = node.find_child("where_or_current_clause") { + self.format_where_left_aligned(where_c, &mut lines); + } + } + + lines.join("\n") + } + + // ── CREATE TABLE ──────────────────────────────────────────────────── + + fn format_create_table_stmt(&self, node: Node<'a>) -> String { + let table_name = node + .find_child("qualified_name") + .map(|n| self.format_qualified_name(n)) + .unwrap_or_default(); + + let mut lines = Vec::new(); + lines.push(format!( + "{} {} {table_name} (", + self.kw("CREATE"), + self.kw("TABLE") + )); + + // Column definitions and constraints. 
+ if let Some(elem_list) = node + .find_child("OptTableElementList") + .and_then(|n| n.find_child("TableElementList")) + { + let elements = flatten_list(elem_list, "TableElementList"); + let indent = self.config.indent; + + if self.config.river { + // River style: PRIMARY KEY first, padded columns, constraint + // on separate indented line. + let mut pk_elements = Vec::new(); + let mut col_elements = Vec::new(); + let mut constraint_elements = Vec::new(); + + for e in &elements { + let elem = self.classify_table_element(*e); + match elem { + TableElementKind::PrimaryKey(text) => pk_elements.push(text), + TableElementKind::Column(name, typename, constraints) => { + col_elements.push((name, typename, constraints)); + } + TableElementKind::Constraint(name, body) => { + constraint_elements.push((name, body)); + } + } + } + + // Calculate max column name and type widths for alignment. + let max_name_len = col_elements + .iter() + .map(|(n, _, _)| n.len()) + .max() + .unwrap_or(0); + let max_type_len = col_elements + .iter() + .map(|(_, t, _)| t.len()) + .max() + .unwrap_or(0); + + // Build ordered list: PKs first, then columns, then constraints. + let mut all_items: Vec = Vec::new(); + for pk in &pk_elements { + all_items.push(pk.clone()); + } + for (name, typename, constraints) in &col_elements { + let padded_name = format!("{:width$}", name, width = max_name_len); + let padded_type = format!("{:width$}", typename, width = max_type_len); + let mut item = format!("{padded_name} {padded_type}"); + if !constraints.is_empty() { + item = format!("{item} {constraints}"); + } + all_items.push(item); + } + // Table constraints: CONSTRAINT name on one line, + // CHECK(...) on the next, both aligned with the type column. 
+ for (name, body) in &constraint_elements { + let constraint_padding = " ".repeat(max_name_len + 1); + if let Some(cname) = name { + all_items.push(format!( + "{constraint_padding}{} {cname}\n{constraint_padding}{body}", + self.kw("CONSTRAINT") + )); + } else { + all_items.push(format!("{constraint_padding}{body}")); + } + } + + for (i, item) in all_items.iter().enumerate() { + let comma = if i < all_items.len() - 1 { "," } else { "" }; + if item.contains('\n') { + // Multi-line item (constraint): only add comma to last line. + let item_lines: Vec<&str> = item.lines().collect(); + for (j, line) in item_lines.iter().enumerate() { + if j == item_lines.len() - 1 { + lines.push(format!("{indent}{line}{comma}")); + } else { + lines.push(format!("{indent}{line}")); + } + } + } else { + lines.push(format!("{indent}{item}{comma}")); + } + } + } else { + let formatted: Vec<_> = elements + .iter() + .map(|e| self.format_table_element(*e)) + .collect(); + + for (i, elem) in formatted.iter().enumerate() { + let comma = if i < formatted.len() - 1 { "," } else { "" }; + lines.push(format!("{indent}{elem}{comma}")); + } + } + } + + lines.push(")".to_string()); + + // WITH clause for storage parameters. + if let Some(with) = node.find_child("OptWith") { + let text = self.text(with); + if !text.trim().is_empty() { + lines.push(format!("{} {}", self.kw("WITH"), text.trim())); + } + } + + lines.join("\n") + } + + /// Classify a table element for river-style CREATE TABLE formatting. 
+ fn classify_table_element(&self, node: Node<'a>) -> TableElementKind { + match node.kind() { + "TableElement" => { + if let Some(col) = node.find_child("columnDef") { + let name = col + .find_child("ColId") + .map(|n| self.format_col_id(n)) + .unwrap_or_default(); + let typename = col + .find_child("Typename") + .map(|n| self.format_typename(n)) + .unwrap_or_default(); + let mut constraint_parts = Vec::new(); + if let Some(qual_list) = col.find_child("ColQualList") { + let mut cursor = qual_list.walk(); + for child in qual_list.named_children(&mut cursor) { + if child.kind() == "ColConstraint" { + constraint_parts.push(self.format_col_constraint(child)); + } + } + } + return TableElementKind::Column(name, typename, constraint_parts.join(" ")); + } + if let Some(constraint) = node.find_child("TableConstraint") { + return self.classify_table_constraint(constraint); + } + TableElementKind::Column(self.text(node).to_string(), String::new(), String::new()) + } + _ => { + TableElementKind::Column(self.text(node).to_string(), String::new(), String::new()) + } + } + } + + fn classify_table_constraint(&self, node: Node<'a>) -> TableElementKind { + let constraint_name = node.find_child("name").map(|n| self.format_expr(n)); + + if let Some(elem) = node.find_child("ConstraintElem") { + // Check if it's PRIMARY KEY. 
+ let mut is_pk = false; + let mut cursor = elem.walk(); + for child in elem.named_children(&mut cursor) { + if child.kind() == "kw_primary" { + is_pk = true; + break; + } + } + if is_pk { + let formatted = self.format_constraint_elem(elem); + return if let Some(cname) = constraint_name { + TableElementKind::PrimaryKey(format!( + "{} {cname} {formatted}", + self.kw("CONSTRAINT") + )) + } else { + TableElementKind::PrimaryKey(formatted) + }; + } + let body = self.format_constraint_elem(elem); + return TableElementKind::Constraint(constraint_name, body); + } + TableElementKind::Constraint(constraint_name, self.text(node).to_string()) + } + + fn format_table_element(&self, node: Node<'a>) -> String { + match node.kind() { + "TableElement" => { + if let Some(col) = node.find_child("columnDef") { + return self.format_column_def(col); + } + if let Some(constraint) = node.find_child("TableConstraint") { + return self.format_table_constraint(constraint); + } + self.text(node).to_string() + } + _ => self.text(node).to_string(), + } + } + + fn format_column_def(&self, node: Node<'a>) -> String { + let name = node + .find_child("ColId") + .map(|n| self.format_col_id(n)) + .unwrap_or_default(); + let typename = node + .find_child("Typename") + .map(|n| self.format_typename(n)) + .unwrap_or_default(); + + let mut parts = vec![name, typename]; + + // Column constraints. + if let Some(qual_list) = node.find_child("ColQualList") { + let mut cursor = qual_list.walk(); + for child in qual_list.named_children(&mut cursor) { + if child.kind() == "ColConstraint" { + parts.push(self.format_col_constraint(child)); + } + } + } + + parts.join(" ") + } + + fn format_col_constraint(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + // Optional constraint name. 
+ if let Some(name) = node.find_child("name") { + parts.push(self.kw("CONSTRAINT")); + parts.push(self.format_expr(name)); + } + if let Some(elem) = node.find_child("ColConstraintElem") { + parts.push(self.format_col_constraint_elem(elem)); + } + parts.join(" ") + } + + fn format_col_constraint_elem(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "kw_not" => parts.push(self.kw("NOT")), + "kw_null" => parts.push(self.kw("NULL")), + "kw_primary" => parts.push(self.kw("PRIMARY")), + "kw_key" => parts.push(self.kw("KEY")), + "kw_unique" => parts.push(self.kw("UNIQUE")), + "kw_default" => parts.push(self.kw("DEFAULT")), + "kw_check" => parts.push(self.kw("CHECK")), + "kw_references" => parts.push(self.kw("REFERENCES")), + "a_expr" | "c_expr" | "b_expr" => { + parts.push(self.format_expr(child)); + } + _ if child.kind().starts_with("kw_") => { + parts.push(self.kw(self.text(child))); + } + _ => parts.push(self.format_expr(child)), + } + } + parts.join(" ") + } + + fn format_table_constraint(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + if let Some(name) = node.find_child("name") { + parts.push(self.kw("CONSTRAINT")); + parts.push(self.format_expr(name)); + } + if let Some(elem) = node.find_child("ConstraintElem") { + parts.push(self.format_constraint_elem(elem)); + } + parts.join(" ") + } + + fn format_constraint_elem(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut has_check = false; + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() { + match child.kind() { + "kw_primary" => parts.push(self.kw("PRIMARY")), + "kw_key" => parts.push(self.kw("KEY")), + "kw_unique" => parts.push(self.kw("UNIQUE")), + "kw_check" => { + has_check = true; + parts.push(self.kw("CHECK")); + } + "kw_foreign" => parts.push(self.kw("FOREIGN")), + "kw_references" => 
parts.push(self.kw("REFERENCES")), + "columnList" => { + let items = flatten_list(child, "columnList"); + let formatted: Vec<_> = + items.iter().map(|i| self.format_expr(*i)).collect(); + parts.push(format!("({})", formatted.join(", "))); + } + "a_expr" | "c_expr" => { + let expr_text = format!("({})", self.format_expr(child)); + if has_check && self.config.river { + // River style: CHECK(expr) without space. + if let Some(last) = parts.last_mut() { + *last = format!("{last}{expr_text}"); + } + } else { + parts.push(expr_text); + } + } + _ if child.kind().starts_with("kw_") => { + parts.push(self.kw(self.text(child))); + } + _ => parts.push(self.format_expr(child)), + } + } else { + let text = self.text(child).trim(); + if text == "(" || text == ")" { + // Handled by columnList formatting. + } + } + } + parts.join(" ") + } + + // ── CREATE VIEW ───────────────────────────────────────────────────── + + fn format_view_stmt(&self, node: Node<'a>) -> String { + let mut prefix = format!("{} {}", self.kw("CREATE"), self.kw("VIEW")); + + // View name. + let name = node + .find_child("qualified_name") + .or_else(|| node.find_child("view_name")) + .map(|n| self.format_qualified_name(n)) + .unwrap_or_default(); + prefix = format!("{prefix} {name} {} ", self.kw("AS")); + + // The SELECT body. + if let Some(select) = node.find_child("SelectStmt") { + let body = self.format_select_stmt(select); + format!("{prefix}\n{}", body.trim_end_matches(';')) + } else { + prefix + } + } + + // ── CREATE TABLE AS / CREATE MATERIALIZED VIEW ────────────────────── + + fn format_create_table_as_stmt(&self, node: Node<'a>) -> String { + let kind = node.kind(); + let mut prefix_parts = vec![self.kw("CREATE")]; + + if kind == "CreateMatViewStmt" { + prefix_parts.push(self.kw("MATERIALIZED")); + prefix_parts.push(self.kw("VIEW")); + } else { + // Could be CREATE TABLE AS or CREATE MATERIALIZED VIEW AS. 
+ if node.has_child("kw_materialized") { + prefix_parts.push(self.kw("MATERIALIZED")); + prefix_parts.push(self.kw("VIEW")); + } else { + prefix_parts.push(self.kw("TABLE")); + } + } + + let name = self.find_name_in_create(node); + prefix_parts.push(name); + prefix_parts.push(self.kw("AS")); + + let prefix = prefix_parts.join(" "); + + // The SELECT body. + let mut body = String::new(); + if let Some(select) = node.find_child("SelectStmt") { + body = self.format_select_stmt(select); + } else if let Some(query) = node.find_child("create_as_target") + && let Some(select) = query.find_child("SelectStmt") + { + body = self.format_select_stmt(select); + } + + let body = body.trim_end_matches(';'); + + // Check for WITH NO DATA. + let mut suffix = String::new(); + if node.has_child("kw_no") || self.text(node).contains("WITH NO DATA") { + suffix = format!( + "\n{} {} {}", + self.kw("WITH"), + self.kw("NO"), + self.kw("DATA") + ); + } + + format!("{prefix}\n{body}{suffix}") + } + + // ── CREATE FUNCTION ───────────────────────────────────────────────── + + fn format_create_function_stmt(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + + // CREATE FUNCTION/PROCEDURE name(args) + let mut header = vec![self.kw("CREATE")]; + if node.has_child("kw_procedure") { + header.push(self.kw("PROCEDURE")); + } else { + header.push(self.kw("FUNCTION")); + } + + let func_name = node + .find_child("func_name") + .map(|n| self.format_func_name(n)) + .unwrap_or_default(); + header.push(func_name); + + // Arguments. + if let Some(args) = node.find_child("func_args_with_defaults") { + let args_text = self.format_func_args(args); + let last = header.last_mut().unwrap(); + *last = format!("{last}{args_text}"); + } + + // RETURNS type. + if let Some(ret) = node.find_child("func_return") { + let ret_type = self.format_func_return(ret); + header.push(format!("{} {ret_type}", self.kw("RETURNS"))); + } + + parts.push(header.join(" ")); + + // Function options (LANGUAGE, AS, etc.). 
+ let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "opt_createfunc_opt_list" | "createfunc_opt_list" => { + self.format_createfunc_opts(child, &mut parts); + } + _ => {} + } + } + + parts.join("\n ") + } + + fn format_func_args(&self, node: Node<'a>) -> String { + // Reconstruct the function arguments. + let text = self.text(node); + // For now, normalize whitespace in the args. + + normalize_whitespace(text) + } + + fn format_func_return(&self, node: Node<'a>) -> String { + if let Some(ft) = node.find_child("func_type") + && let Some(tn) = ft.find_child("Typename") + { + return self.format_typename(tn); + } + self.text(node).trim().to_string() + } + + fn format_createfunc_opts(&self, node: Node<'a>, parts: &mut Vec) { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "createfunc_opt_item" => { + self.format_createfunc_opt_item(child, parts); + } + "createfunc_opt_list" => { + self.format_createfunc_opts(child, parts); + } + _ => {} + } + } + } + + fn format_createfunc_opt_item(&self, node: Node<'a>, parts: &mut Vec) { + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "kw_language" => { + if let Some(lang) = node.find_child("NonReservedWord_or_Sconst") { + parts.push(format!("{} {}", self.kw("LANGUAGE"), self.text(lang))); + } + } + "func_as" => { + // AS $$ ... $$ + parts.push(format!("{}\n{}", self.kw("AS"), self.text(child))); + } + _ => {} + } + } + } + + // ── CREATE DOMAIN ─────────────────────────────────────────────────── + + fn format_create_domain_stmt(&self, node: Node<'a>) -> String { + let name = self.find_name_in_create(node); + let mut parts = vec![format!( + "{} {} {name}", + self.kw("CREATE"), + self.kw("DOMAIN") + )]; + + // AS typename. 
+ if let Some(tn) = node.find_child("Typename") { + parts[0] = format!( + "{} {} {}", + parts[0], + self.kw("AS"), + self.format_typename(tn) + ); + } + + // Constraints. + if let Some(constraints) = node.find_child("ColQualList") { + let mut cursor = constraints.walk(); + for child in constraints.named_children(&mut cursor) { + if child.kind() == "ColConstraint" { + let indent = self.config.indent; + parts.push(format!("{indent}{}", self.format_col_constraint(child))); + } + } + } + + parts.join("\n") + } + + // ── CREATE FOREIGN TABLE ──────────────────────────────────────────── + + fn format_create_foreign_table_stmt(&self, node: Node<'a>) -> String { + // Similar to CREATE TABLE but with SERVER and OPTIONS. + let text = self.text(node); + normalize_whitespace(text) + } + + // ── Helpers ───────────────────────────────────────────────────────── + + fn format_relation_expr_opt_alias(&self, node: Node<'a>) -> String { + let mut parts = Vec::new(); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + match child.kind() { + "relation_expr" => parts.push(self.format_relation_expr(child)), + "opt_alias_clause" | "alias_clause" => { + // alias_clause already includes the AS keyword. + let alias = self.format_expr(child); + if !alias.is_empty() { + parts.push(alias); + } + } + "ColId" => { + // Bare identifier alias without AS keyword. + let alias = self.format_expr(child); + if !alias.is_empty() { + parts.push(format!("{} {alias}", self.kw("AS"))); + } + } + _ => parts.push(self.format_expr(child)), + } + } + parts.join(" ") + } + + fn format_qualified_name_from(&self, node: Node<'a>) -> String { + // insert_target wraps a qualified_name. + if let Some(qn) = node.find_child("qualified_name") { + return self.format_expr(qn); + } + self.format_expr(node) + } + + fn find_name_in_create(&self, node: Node<'a>) -> String { + // Look for qualified_name, any_name, or create_as_target. 
/// Collapse runs of whitespace to single spaces, but preserve whitespace
/// inside single-quoted strings, double-quoted identifiers, and dollar-quoted
/// strings so that literal content is not altered.
///
/// Leading and trailing whitespace of the result is removed. Unterminated
/// quoted sections are copied verbatim up to the end of the input.
fn normalize_whitespace(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let len = chars.len();
    let mut result = String::with_capacity(s.len());
    let mut i = 0;
    let mut in_space_run = false;

    while i < len {
        let ch = chars[i];

        // Single-quoted string or double-quoted identifier.
        if ch == '\'' || ch == '"' {
            in_space_run = false;
            i = copy_quoted(&chars, i, &mut result);
            continue;
        }

        // Dollar-quoted string; a lone `$` (e.g. the parameter `$1`) falls
        // through to the ordinary character handling below.
        let tag_len = if ch == '$' {
            dollar_quote_tag_len(&chars[i..])
        } else {
            None
        };
        if let Some(tag_len) = tag_len {
            in_space_run = false;
            i = copy_dollar_quoted(&chars, i, tag_len, &mut result);
            continue;
        }

        // Normal whitespace collapsing.
        if ch.is_whitespace() {
            if !in_space_run && !result.is_empty() {
                result.push(' ');
            }
            in_space_run = true;
        } else {
            in_space_run = false;
            result.push(ch);
        }
        i += 1;
    }

    result.trim().to_string()
}

/// Copy the quoted section starting at `chars[start]` (which must be the
/// opening quote) into `out`, treating a doubled quote as an escaped quote.
/// Returns the index just past the closing quote, or `chars.len()` when the
/// section is unterminated.
fn copy_quoted(chars: &[char], start: usize, out: &mut String) -> usize {
    let quote = chars[start];
    out.push(quote);
    let mut i = start + 1;
    while i < chars.len() {
        let c = chars[i];
        out.push(c);
        i += 1;
        if c == quote {
            if chars.get(i) == Some(&quote) {
                // Doubled quote: an escaped quote character, keep scanning.
                out.push(quote);
                i += 1;
            } else {
                break;
            }
        }
    }
    i
}

/// If `rest` (which must start with `$`) opens a dollar-quote delimiter such
/// as `$$` or `$tag$`, return the delimiter's length in characters.
fn dollar_quote_tag_len(rest: &[char]) -> Option<usize> {
    let name_len = rest
        .iter()
        .skip(1)
        .take_while(|c| c.is_ascii_alphanumeric() || **c == '_')
        .count();
    (rest.get(name_len + 1) == Some(&'$')).then_some(name_len + 2)
}

/// Copy a dollar-quoted section whose opening delimiter occupies
/// `chars[start..start + tag_len]` into `out`, up to and including the
/// matching closing delimiter. Returns the index just past the section, or
/// `chars.len()` when it is unterminated.
fn copy_dollar_quoted(chars: &[char], start: usize, tag_len: usize, out: &mut String) -> usize {
    let tag = &chars[start..start + tag_len];
    out.extend(tag);
    let mut i = start + tag_len;
    while i < chars.len() {
        // Compare in place instead of materialising the remaining input.
        if chars[i..].starts_with(tag) {
            out.extend(tag);
            return i + tag_len;
        }
        out.push(chars[i]);
        i += 1;
    }
    i
}
+ let input = if trimmed.ends_with(';') { + trimmed.to_string() + } else if trimmed.lines().last().is_some_and(|l| l.contains("--")) { + format!("{trimmed}\n;") + } else { + format!("{trimmed};") + }; + let mut parser = Parser::new(); + parser + .set_language(&LANGUAGE.into()) + .map_err(|e| FormatError::Parser(e.to_string()))?; + let tree = parser + .parse(&input, None) + .ok_or_else(|| FormatError::Parser("Failed to parse SQL".into()))?; + let root = tree.root_node(); + // The tree-sitter-postgres grammar doesn't handle some valid SQL + // constructs (e.g., decimal literals like 800.00). When errors are + // limited to leaf nodes, attempt to format anyway so the rest of the + // statement is still properly styled. Only bail out when the tree + // structure is fundamentally broken (ERROR at the top level wrapping + // major statement parts). + if root.has_error() && has_structural_error(&root) { + return Err(FormatError::Syntax(find_error_message(&root, &input))); + } + let fmt = Formatter::new(&input, style); + fmt.format_root(root) +} + +/// Format PL/pgSQL code according to the specified style. +/// +/// The input should be the body of a PL/pgSQL function (the content between +/// the dollar-quote delimiters, typically starting with DECLARE or BEGIN). +pub fn format_plpgsql(code: &str, style: Style) -> Result { + let trimmed = code.trim(); + if trimmed.is_empty() { + return Ok(String::new()); + } + let mut parser = Parser::new(); + parser + .set_language(&LANGUAGE_PLPGSQL.into()) + .map_err(|e| FormatError::Parser(e.to_string()))?; + let tree = parser + .parse(trimmed, None) + .ok_or_else(|| FormatError::Parser("Failed to parse PL/pgSQL".into()))?; + let root = tree.root_node(); + if root.has_error() { + return Err(FormatError::Syntax(find_error_message(&root, trimmed))); + } + let fmt = Formatter::new(trimmed, style); + fmt.format_plpgsql_root(root) +} + +/// Check whether the parse tree has a structural error (ERROR node wrapping +/// significant content). 
Small ERROR nodes (e.g., unparsed decimal fraction +/// like ".00") are tolerable and can be passed through as-is. +fn has_structural_error(node: &tree_sitter::Node) -> bool { + if node.is_error() { + // Small leaf ERROR nodes (e.g., ".00" decimal part = 3 bytes) are + // tolerable grammar gaps. Anything larger likely indicates a genuine + // parse failure that would produce garbled output. + let size = node.end_byte() - node.start_byte(); + return size > 4; + } + if node.is_missing() { + return true; + } + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.has_error() && has_structural_error(&child) { + return true; + } + } + false +} + +fn find_error_message(node: &tree_sitter::Node, source: &str) -> String { + if node.is_error() || node.is_missing() { + let start = node.start_position(); + return format!( + "Syntax error at line {}, column {}: {:?}", + start.row + 1, + start.column + 1, + &source[node.byte_range()] + ); + } + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.has_error() { + return find_error_message(&child, source); + } + } + "Unknown syntax error".into() +} diff --git a/src/node_helpers.rs b/src/node_helpers.rs new file mode 100644 index 0000000..618841d --- /dev/null +++ b/src/node_helpers.rs @@ -0,0 +1,70 @@ +use tree_sitter::Node; + +/// Extension trait for tree-sitter nodes. +pub(crate) trait NodeExt<'a> { + /// Get the source text for this node. + fn text(&self, source: &'a str) -> &'a str; + + /// Find the first named child with the given kind. + fn find_child(&self, kind: &str) -> Option>; + + /// Find first child matching any of the given kinds. + fn find_child_any(&self, kinds: &[&str]) -> Option>; + + /// Get all named children. + fn named_children_vec(&self) -> Vec>; + + /// Check if this node has a named child with the given kind. 
/// SQL formatting style.
///
/// Each style has a canonical lowercase name, used by both its `Display`
/// and `FromStr` implementations, so `style.to_string().parse()` always
/// round-trips.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Style {
    /// Simon Holywell's river style — keywords right-aligned to form a visual river.
    #[default]
    River,
    /// Mozilla style — keywords left-aligned, content indented 4 spaces.
    Mozilla,
    /// AWeber style — river style with JOINs participating in keyword alignment.
    Aweber,
    /// dbt style — Mozilla-like with lowercase keywords and blank lines between clauses.
    Dbt,
    /// GitLab style — Mozilla-like with 2-space indent and uppercase keywords.
    Gitlab,
    /// Kickstarter style — Mozilla-like with 2-space indent and compact JOINs.
    Kickstarter,
    /// mattmc3 style — lowercase river with leading commas.
    Mattmc3,
}

impl Style {
    /// All available styles.
    pub const ALL: &'static [Style] = &[
        Style::River,
        Style::Mozilla,
        Style::Aweber,
        Style::Dbt,
        Style::Gitlab,
        Style::Kickstarter,
        Style::Mattmc3,
    ];

    /// Canonical lowercase name of the style; the single source of truth
    /// for both `Display` and `FromStr`.
    const fn name(self) -> &'static str {
        match self {
            Style::River => "river",
            Style::Mozilla => "mozilla",
            Style::Aweber => "aweber",
            Style::Dbt => "dbt",
            Style::Gitlab => "gitlab",
            Style::Kickstarter => "kickstarter",
            Style::Mattmc3 => "mattmc3",
        }
    }
}

impl fmt::Display for Style {
    /// Writes the canonical lowercase name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write!` with a literal) honours width, fill and
        // alignment flags such as `{:>12}`.
        f.pad(self.name())
    }
}

impl FromStr for Style {
    type Err = String;

    /// Parses a style name case-insensitively.
    ///
    /// # Errors
    ///
    /// Returns `Unsupported style: '<input>'` when the name matches no style.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let wanted = s.to_lowercase();
        Self::ALL
            .iter()
            .copied()
            .find(|style| style.name() == wanted)
            .ok_or_else(|| format!("Unsupported style: '{s}'"))
    }
}
+LEFT JOIN public.invoices AS inv + ON inv.invoice_id = bal.invoice_id +LEFT JOIN public.account_packages AS ap + ON bal.pkg_id = ap.pkg_id +LEFT JOIN public.packages AS pkg + ON pkg.pkg_id = ap.pkg_id + WHERE inv.status = 'Paid' + AND bal.pkg_id IS NOT NULL + AND bal.amount > 0; diff --git a/tests/fixtures/aweber/select_case_join.sql b/tests/fixtures/aweber/select_case_join.sql new file mode 100644 index 0000000..c781eed --- /dev/null +++ b/tests/fixtures/aweber/select_case_join.sql @@ -0,0 +1 @@ +SELECT DISTINCT bal.account_id, bal.invoice_id, CASE WHEN pkg.name LIKE '%Free%' THEN 'Free' WHEN pkg.name LIKE '%Pro%' THEN 'Pro' WHEN pkg.name LIKE '%Plus%' THEN 'Plus' ELSE pkg.name END AS package_type, ap.pkg_id, bal.amount FROM report.balances_vw AS bal LEFT JOIN public.invoices AS inv ON inv.invoice_id = bal.invoice_id LEFT JOIN public.account_packages AS ap ON bal.pkg_id = ap.pkg_id LEFT JOIN public.packages AS pkg ON pkg.pkg_id = ap.pkg_id WHERE inv.status = 'Paid' AND bal.pkg_id IS NOT NULL AND bal.amount > 0 diff --git a/tests/fixtures/aweber/select_cte_nested.expected b/tests/fixtures/aweber/select_cte_nested.expected new file mode 100644 index 0000000..67221fb --- /dev/null +++ b/tests/fixtures/aweber/select_cte_nested.expected @@ -0,0 +1,44 @@ +WITH totals AS ( + SELECT account_id, + SUM(amount) AS lifetime_value + FROM public.invoices + WHERE status = 'Paid' + GROUP BY account_id +), +latest_emails AS ( + WITH ranked AS ( + SELECT contact_id, + email::TEXT AS email, + RANK() OVER (PARTITION BY contact_id ORDER BY id DESC) AS rank + FROM public.contact_emails + ) + SELECT contact_id, + email + FROM ranked + WHERE rank = 1 +), +current_plan AS ( + SELECT a.account_id, + p.plan_name, + CASE WHEN p.plan_name ~ 'Free' THEN TRUE + ELSE FALSE + END AS is_free + FROM public.accounts AS a + JOIN public.packages AS p + USING (pkg_id) + WHERE a.status_id <> 7 +) + SELECT a.account_id, + a.date_opened, + CASE WHEN cp.is_free THEN 'Free' + ELSE 'Paid' + END AS 
account_type, + COALESCE(t.lifetime_value, 0.00) AS lifetime_value, + le.email AS contact_email + FROM public.accounts AS a + JOIN current_plan AS cp + USING (account_id) +LEFT JOIN totals AS t + USING (account_id) +LEFT JOIN latest_emails AS le + USING (contact_id); diff --git a/tests/fixtures/aweber/select_cte_nested.sql b/tests/fixtures/aweber/select_cte_nested.sql new file mode 100644 index 0000000..46b8ae0 --- /dev/null +++ b/tests/fixtures/aweber/select_cte_nested.sql @@ -0,0 +1 @@ +WITH totals AS (SELECT account_id, SUM(amount) AS lifetime_value FROM public.invoices WHERE status = 'Paid' GROUP BY account_id), latest_emails AS (WITH ranked AS (SELECT contact_id, CAST(email AS TEXT) AS email, rank() OVER (PARTITION BY contact_id ORDER BY id DESC) AS rank FROM public.contact_emails) SELECT contact_id, email FROM ranked WHERE rank = 1), current_plan AS (SELECT a.account_id, p.plan_name, CASE WHEN p.plan_name ~ 'Free' THEN TRUE ELSE FALSE END AS is_free FROM public.accounts AS a JOIN public.packages AS p USING (pkg_id) WHERE a.status_id != 7) SELECT a.account_id, a.date_opened, CASE WHEN cp.is_free THEN 'Free' ELSE 'Paid' END AS account_type, COALESCE(t.lifetime_value, 0.00) AS lifetime_value, le.email AS contact_email FROM public.accounts AS a JOIN current_plan AS cp USING (account_id) LEFT JOIN totals AS t USING (account_id) LEFT JOIN latest_emails AS le USING (contact_id) diff --git a/tests/fixtures/aweber/select_cte_union.expected b/tests/fixtures/aweber/select_cte_union.expected new file mode 100644 index 0000000..7ebbeb0 --- /dev/null +++ b/tests/fixtures/aweber/select_cte_union.expected @@ -0,0 +1,43 @@ +WITH active_items AS ( + SELECT i.created_at AS logged_at, + i.created_by AS author, + 'created' AS action, + i.id AS item_id, + i.name AS item_name + FROM items AS i + WHERE i.modified_at IS NULL +ORDER BY i.created_at DESC + LIMIT 100 +), +modified_items AS ( + SELECT i.modified_at AS logged_at, + i.modified_by AS author, + 'updated' AS action, + i.id AS 
item_id, + i.name AS item_name + FROM items AS i + WHERE i.modified_at IS NOT NULL +ORDER BY i.modified_at DESC + LIMIT 100 +), +combined AS ( +SELECT * + FROM active_items + +UNION + +SELECT * + FROM modified_items +) + SELECT c.author, + u.display_name, + c.action, + c.item_id, + c.item_name, + MAX(logged_at) AS logged_at + FROM combined AS c + JOIN users AS u + ON u.username = c.author +GROUP BY c.author, u.display_name, c.action, c.item_id, c.item_name +ORDER BY MAX(c.logged_at) DESC + LIMIT 100; diff --git a/tests/fixtures/aweber/select_cte_union.sql b/tests/fixtures/aweber/select_cte_union.sql new file mode 100644 index 0000000..022e4cd --- /dev/null +++ b/tests/fixtures/aweber/select_cte_union.sql @@ -0,0 +1 @@ +WITH active_items AS (SELECT i.created_at AS logged_at, i.created_by AS author, 'created' AS action, i.id AS item_id, i.name AS item_name FROM items AS i WHERE i.modified_at IS NULL ORDER BY i.created_at DESC LIMIT 100), modified_items AS (SELECT i.modified_at AS logged_at, i.modified_by AS author, 'updated' AS action, i.id AS item_id, i.name AS item_name FROM items AS i WHERE i.modified_at IS NOT NULL ORDER BY i.modified_at DESC LIMIT 100), combined AS (SELECT * FROM active_items UNION SELECT * FROM modified_items) SELECT c.author, u.display_name, c.action, c.item_id, c.item_name, MAX(logged_at) AS logged_at FROM combined AS c JOIN users AS u ON u.username = c.author GROUP BY c.author, u.display_name, c.action, c.item_id, c.item_name ORDER BY MAX(c.logged_at) DESC LIMIT 100 diff --git a/tests/fixtures/aweber/select_join.expected b/tests/fixtures/aweber/select_join.expected new file mode 100644 index 0000000..b1fdf82 --- /dev/null +++ b/tests/fixtures/aweber/select_join.expected @@ -0,0 +1,8 @@ + SELECT r.last_name + FROM riders AS r +INNER JOIN bikes AS b + ON r.bike_vin_num = b.vin_num + AND b.engines > 2 +INNER JOIN crew AS c + ON r.crew_chief_last_name = c.last_name + AND c.chief = 'Y'; diff --git a/tests/fixtures/aweber/select_join.sql 
b/tests/fixtures/aweber/select_join.sql new file mode 100644 index 0000000..9f3cd6a --- /dev/null +++ b/tests/fixtures/aweber/select_join.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r INNER JOIN bikes AS b ON r.bike_vin_num = b.vin_num AND b.engines > 2 INNER JOIN crew AS c ON r.crew_chief_last_name = c.last_name AND c.chief = 'Y' diff --git a/tests/fixtures/aweber/select_left_join.expected b/tests/fixtures/aweber/select_left_join.expected new file mode 100644 index 0000000..11bfabf --- /dev/null +++ b/tests/fixtures/aweber/select_left_join.expected @@ -0,0 +1,7 @@ + SELECT r.id, + r.name, + COUNT(o.id) AS order_count + FROM recent AS r +LEFT JOIN orders AS o + ON r.id = o.user_id + GROUP BY r.id, r.name; diff --git a/tests/fixtures/aweber/select_left_join.sql b/tests/fixtures/aweber/select_left_join.sql new file mode 100644 index 0000000..5996df2 --- /dev/null +++ b/tests/fixtures/aweber/select_left_join.sql @@ -0,0 +1 @@ +SELECT r.id, r.name, COUNT(o.id) AS order_count FROM recent AS r LEFT JOIN orders AS o ON r.id = o.user_id GROUP BY r.id, r.name diff --git a/tests/fixtures/aweber/select_or.expected b/tests/fixtures/aweber/select_or.expected new file mode 100644 index 0000000..f0aa381 --- /dev/null +++ b/tests/fixtures/aweber/select_or.expected @@ -0,0 +1,6 @@ +SELECT a.title, + a.released_on, + a.recorded_on + FROM albums AS a + WHERE a.title = 'Charcoal Lane' + OR a.title = 'The New Danger'; diff --git a/tests/fixtures/aweber/select_or.sql b/tests/fixtures/aweber/select_or.sql new file mode 100644 index 0000000..e89a06b --- /dev/null +++ b/tests/fixtures/aweber/select_or.sql @@ -0,0 +1 @@ +SELECT a.title, a.released_on, a.recorded_on FROM albums AS a WHERE a.title = 'Charcoal Lane' OR a.title = 'The New Danger' diff --git a/tests/fixtures/aweber/select_simple.expected b/tests/fixtures/aweber/select_simple.expected new file mode 100644 index 0000000..c72ec93 --- /dev/null +++ b/tests/fixtures/aweber/select_simple.expected @@ -0,0 +1,3 @@ +SELECT 
file_hash + FROM file_system + WHERE file_name = '.vimrc'; diff --git a/tests/fixtures/aweber/select_simple.sql b/tests/fixtures/aweber/select_simple.sql new file mode 100644 index 0000000..41a369a --- /dev/null +++ b/tests/fixtures/aweber/select_simple.sql @@ -0,0 +1 @@ +SELECT file_hash FROM file_system WHERE file_name = '.vimrc' diff --git a/tests/fixtures/aweber/select_subquery.expected b/tests/fixtures/aweber/select_subquery.expected new file mode 100644 index 0000000..9bb3883 --- /dev/null +++ b/tests/fixtures/aweber/select_subquery.expected @@ -0,0 +1,10 @@ +SELECT r.last_name, + (SELECT MAX(YEAR(championship_date)) + FROM champions AS c + WHERE c.last_name = r.last_name + AND c.confirmed = 'Y') AS last_championship_year + FROM riders AS r + WHERE r.last_name IN (SELECT c.last_name + FROM champions AS c + WHERE YEAR(championship_date) > '2008' + AND c.confirmed = 'Y'); diff --git a/tests/fixtures/aweber/select_subquery.sql b/tests/fixtures/aweber/select_subquery.sql new file mode 100644 index 0000000..0c73c3b --- /dev/null +++ b/tests/fixtures/aweber/select_subquery.sql @@ -0,0 +1 @@ +SELECT r.last_name, (SELECT MAX(YEAR(championship_date)) FROM champions AS c WHERE c.last_name = r.last_name AND c.confirmed = 'Y') AS last_championship_year FROM riders AS r WHERE r.last_name IN (SELECT c.last_name FROM champions AS c WHERE YEAR(championship_date) > '2008' AND c.confirmed = 'Y') diff --git a/tests/fixtures/dbt/select_cte.expected b/tests/fixtures/dbt/select_cte.expected new file mode 100644 index 0000000..43de374 --- /dev/null +++ b/tests/fixtures/dbt/select_cte.expected @@ -0,0 +1,33 @@ +with + +recent as ( + + select + id, + name + + from users + + where + active = true + + order by created_at desc + + limit 10 + +) + +select + r.id, + r.name, + count(o.id) as order_count + +from + recent as r +left join + orders as o + on r.id = o.user_id + +group by + r.id, + r.name; diff --git a/tests/fixtures/dbt/select_cte.sql b/tests/fixtures/dbt/select_cte.sql new 
file mode 100644 index 0000000..2728200 --- /dev/null +++ b/tests/fixtures/dbt/select_cte.sql @@ -0,0 +1 @@ +WITH recent AS (SELECT id, name FROM users WHERE active = TRUE ORDER BY created_at DESC LIMIT 10) SELECT r.id, r.name, COUNT(o.id) AS order_count FROM recent AS r LEFT JOIN orders AS o ON r.id = o.user_id GROUP BY r.id, r.name diff --git a/tests/fixtures/dbt/select_group_order.expected b/tests/fixtures/dbt/select_group_order.expected new file mode 100644 index 0000000..8b3da9a --- /dev/null +++ b/tests/fixtures/dbt/select_group_order.expected @@ -0,0 +1,12 @@ +select + x, + count(*) as cnt + +from t + +group by x + +having + count(*) > 1 + +order by cnt desc; diff --git a/tests/fixtures/dbt/select_group_order.sql b/tests/fixtures/dbt/select_group_order.sql new file mode 100644 index 0000000..0c8a5bf --- /dev/null +++ b/tests/fixtures/dbt/select_group_order.sql @@ -0,0 +1 @@ +SELECT x, COUNT(*) AS cnt FROM t GROUP BY x HAVING COUNT(*) > 1 ORDER BY cnt DESC diff --git a/tests/fixtures/dbt/select_join.expected b/tests/fixtures/dbt/select_join.expected new file mode 100644 index 0000000..ba26211 --- /dev/null +++ b/tests/fixtures/dbt/select_join.expected @@ -0,0 +1,8 @@ +select r.last_name + +from + riders as r +inner join + bikes as b + on r.bike_vin_num = b.vin_num + and b.engines > 2; diff --git a/tests/fixtures/dbt/select_join.sql b/tests/fixtures/dbt/select_join.sql new file mode 100644 index 0000000..74e5e80 --- /dev/null +++ b/tests/fixtures/dbt/select_join.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r INNER JOIN bikes AS b ON r.bike_vin_num = b.vin_num AND b.engines > 2 diff --git a/tests/fixtures/dbt/select_simple.expected b/tests/fixtures/dbt/select_simple.expected new file mode 100644 index 0000000..3f607f4 --- /dev/null +++ b/tests/fixtures/dbt/select_simple.expected @@ -0,0 +1,11 @@ +select + client_id, + submission_date + +from main_summary + +where + submission_date > '20180101' + and sample_id = '42' + +limit 10; diff --git 
a/tests/fixtures/dbt/select_simple.sql b/tests/fixtures/dbt/select_simple.sql new file mode 100644 index 0000000..1d6621b --- /dev/null +++ b/tests/fixtures/dbt/select_simple.sql @@ -0,0 +1 @@ +SELECT client_id, submission_date FROM main_summary WHERE submission_date > '20180101' AND sample_id = '42' LIMIT 10 diff --git a/tests/fixtures/gitlab/select_cte.expected b/tests/fixtures/gitlab/select_cte.expected new file mode 100644 index 0000000..bc680bb --- /dev/null +++ b/tests/fixtures/gitlab/select_cte.expected @@ -0,0 +1,14 @@ +WITH important_list AS ( + + SELECT DISTINCT specific_column + FROM other_table + WHERE + specific_column <> 'foo' + +) + +SELECT primary_table.column_1 +FROM primary_table +INNER JOIN + important_list + ON primary_table.column_3 = important_list.specific_column; diff --git a/tests/fixtures/gitlab/select_cte.sql b/tests/fixtures/gitlab/select_cte.sql new file mode 100644 index 0000000..6a6c280 --- /dev/null +++ b/tests/fixtures/gitlab/select_cte.sql @@ -0,0 +1 @@ +WITH important_list AS (SELECT DISTINCT specific_column FROM other_table WHERE specific_column != 'foo') SELECT primary_table.column_1 FROM primary_table INNER JOIN important_list ON primary_table.column_3 = important_list.specific_column diff --git a/tests/fixtures/gitlab/select_group_order.expected b/tests/fixtures/gitlab/select_group_order.expected new file mode 100644 index 0000000..27f7d74 --- /dev/null +++ b/tests/fixtures/gitlab/select_group_order.expected @@ -0,0 +1,8 @@ +SELECT + x, + COUNT(*) AS cnt +FROM t +GROUP BY x +HAVING + COUNT(*) > 1 +ORDER BY cnt DESC; diff --git a/tests/fixtures/gitlab/select_group_order.sql b/tests/fixtures/gitlab/select_group_order.sql new file mode 100644 index 0000000..0c8a5bf --- /dev/null +++ b/tests/fixtures/gitlab/select_group_order.sql @@ -0,0 +1 @@ +SELECT x, COUNT(*) AS cnt FROM t GROUP BY x HAVING COUNT(*) > 1 ORDER BY cnt DESC diff --git a/tests/fixtures/gitlab/select_join.expected b/tests/fixtures/gitlab/select_join.expected new 
file mode 100644 index 0000000..5b131b6 --- /dev/null +++ b/tests/fixtures/gitlab/select_join.expected @@ -0,0 +1,15 @@ +SELECT + a.title, + COUNT(*) AS cnt +FROM albums AS a +LEFT JOIN + orders AS o + ON a.id = o.album_id +WHERE + a.year > 2000 + AND a.genre = 'rock' +GROUP BY a.title +HAVING + COUNT(*) > 1 +ORDER BY cnt DESC +LIMIT 10; diff --git a/tests/fixtures/gitlab/select_join.sql b/tests/fixtures/gitlab/select_join.sql new file mode 100644 index 0000000..47a82a5 --- /dev/null +++ b/tests/fixtures/gitlab/select_join.sql @@ -0,0 +1 @@ +SELECT a.title, COUNT(*) AS cnt FROM albums AS a LEFT JOIN orders AS o ON a.id = o.album_id WHERE a.year > 2000 AND a.genre = 'rock' GROUP BY a.title HAVING COUNT(*) > 1 ORDER BY cnt DESC LIMIT 10 diff --git a/tests/fixtures/gitlab/select_simple.expected b/tests/fixtures/gitlab/select_simple.expected new file mode 100644 index 0000000..b5413db --- /dev/null +++ b/tests/fixtures/gitlab/select_simple.expected @@ -0,0 +1,8 @@ +SELECT + client_id, + submission_date +FROM main_summary +WHERE + submission_date > '20180101' + AND sample_id = '42' +LIMIT 10; diff --git a/tests/fixtures/gitlab/select_simple.sql b/tests/fixtures/gitlab/select_simple.sql new file mode 100644 index 0000000..1d6621b --- /dev/null +++ b/tests/fixtures/gitlab/select_simple.sql @@ -0,0 +1 @@ +SELECT client_id, submission_date FROM main_summary WHERE submission_date > '20180101' AND sample_id = '42' LIMIT 10 diff --git a/tests/fixtures/kickstarter/select_cte.expected b/tests/fixtures/kickstarter/select_cte.expected new file mode 100644 index 0000000..162e12f --- /dev/null +++ b/tests/fixtures/kickstarter/select_cte.expected @@ -0,0 +1,14 @@ +WITH backings_per_category AS ( + SELECT + category_id, + deadline + FROM app.backings +), backers AS ( + SELECT + backer_id, + COUNT(id) AS projects_backed + FROM app.backings + GROUP BY backer_id +) +SELECT * +FROM backers; diff --git a/tests/fixtures/kickstarter/select_cte.sql b/tests/fixtures/kickstarter/select_cte.sql 
new file mode 100644 index 0000000..0b8cb87 --- /dev/null +++ b/tests/fixtures/kickstarter/select_cte.sql @@ -0,0 +1 @@ +WITH backings_per_category AS (SELECT category_id, deadline FROM app.backings), backers AS (SELECT backer_id, COUNT(id) AS projects_backed FROM app.backings GROUP BY backer_id) SELECT * FROM backers diff --git a/tests/fixtures/kickstarter/select_join.expected b/tests/fixtures/kickstarter/select_join.expected new file mode 100644 index 0000000..3981ffe --- /dev/null +++ b/tests/fixtures/kickstarter/select_join.expected @@ -0,0 +1,10 @@ +SELECT + p.name AS project_name, + COUNT(b.id) AS backing_count +FROM app.projects AS p +INNER JOIN app.backings AS b ON p.id = b.project_id + AND b.country <> 'US' +LEFT JOIN app.rewards AS rewards ON b.id = rewards.backing_id +WHERE + p.country = 'US' + AND p.deadline >= '2015-01-01'; diff --git a/tests/fixtures/kickstarter/select_join.sql b/tests/fixtures/kickstarter/select_join.sql new file mode 100644 index 0000000..fe0e222 --- /dev/null +++ b/tests/fixtures/kickstarter/select_join.sql @@ -0,0 +1 @@ +SELECT p.name AS project_name, COUNT(b.id) AS backing_count FROM app.projects AS p INNER JOIN app.backings AS b ON p.id = b.project_id AND b.country != 'US' LEFT JOIN app.rewards AS rewards ON b.id = rewards.backing_id WHERE p.country = 'US' AND p.deadline >= '2015-01-01' diff --git a/tests/fixtures/kickstarter/select_simple.expected b/tests/fixtures/kickstarter/select_simple.expected new file mode 100644 index 0000000..0026e38 --- /dev/null +++ b/tests/fixtures/kickstarter/select_simple.expected @@ -0,0 +1,10 @@ +SELECT + id, + name, + email, + created_at +FROM app.users +WHERE + active = TRUE +ORDER BY created_at DESC +LIMIT 10; diff --git a/tests/fixtures/kickstarter/select_simple.sql b/tests/fixtures/kickstarter/select_simple.sql new file mode 100644 index 0000000..42cbf1c --- /dev/null +++ b/tests/fixtures/kickstarter/select_simple.sql @@ -0,0 +1 @@ +SELECT id, name, email, created_at FROM app.users WHERE 
active = TRUE ORDER BY created_at DESC LIMIT 10 diff --git a/tests/fixtures/kickstarter/select_where.expected b/tests/fixtures/kickstarter/select_where.expected new file mode 100644 index 0000000..6af8299 --- /dev/null +++ b/tests/fixtures/kickstarter/select_where.expected @@ -0,0 +1,10 @@ +SELECT + id, + name, + status +FROM app.projects +WHERE + country = 'US' + AND deadline >= '2015-01-01' + AND state = 'live' +ORDER BY deadline; diff --git a/tests/fixtures/kickstarter/select_where.sql b/tests/fixtures/kickstarter/select_where.sql new file mode 100644 index 0000000..d09bd42 --- /dev/null +++ b/tests/fixtures/kickstarter/select_where.sql @@ -0,0 +1 @@ +SELECT id, name, status FROM app.projects WHERE country = 'US' AND deadline >= '2015-01-01' AND state = 'live' ORDER BY deadline diff --git a/tests/fixtures/mattmc3/insert_values.expected b/tests/fixtures/mattmc3/insert_values.expected new file mode 100644 index 0000000..e1bf139 --- /dev/null +++ b/tests/fixtures/mattmc3/insert_values.expected @@ -0,0 +1,3 @@ +insert into currencies (code, name, modified_date) + values ('XBT', 'Bitcoin', now()) + , ('ETH', 'Ethereum', now()); diff --git a/tests/fixtures/mattmc3/insert_values.sql b/tests/fixtures/mattmc3/insert_values.sql new file mode 100644 index 0000000..9ab2b0d --- /dev/null +++ b/tests/fixtures/mattmc3/insert_values.sql @@ -0,0 +1 @@ +INSERT INTO currencies (code, name, modified_date) VALUES ('XBT', 'Bitcoin', now()), ('ETH', 'Ethereum', now()) diff --git a/tests/fixtures/mattmc3/select_join.expected b/tests/fixtures/mattmc3/select_join.expected new file mode 100644 index 0000000..52ea4e0 --- /dev/null +++ b/tests/fixtures/mattmc3/select_join.expected @@ -0,0 +1,8 @@ +select r.last_name + from riders as r + join bikes as b + on r.bike_vin_num = b.vin_num + and b.engines > 2 + join crew as c + on r.crew_chief_last_name = c.last_name + and c.chief = 'Y'; diff --git a/tests/fixtures/mattmc3/select_join.sql b/tests/fixtures/mattmc3/select_join.sql new file mode 
100644 index 0000000..9f3cd6a --- /dev/null +++ b/tests/fixtures/mattmc3/select_join.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r INNER JOIN bikes AS b ON r.bike_vin_num = b.vin_num AND b.engines > 2 INNER JOIN crew AS c ON r.crew_chief_last_name = c.last_name AND c.chief = 'Y' diff --git a/tests/fixtures/mattmc3/select_or.expected b/tests/fixtures/mattmc3/select_or.expected new file mode 100644 index 0000000..6a67862 --- /dev/null +++ b/tests/fixtures/mattmc3/select_or.expected @@ -0,0 +1,6 @@ +select a.title + , a.released_on + , a.recorded_on + from albums as a + where a.title = 'Charcoal Lane' + or a.title = 'The New Danger'; diff --git a/tests/fixtures/mattmc3/select_or.sql b/tests/fixtures/mattmc3/select_or.sql new file mode 100644 index 0000000..e89a06b --- /dev/null +++ b/tests/fixtures/mattmc3/select_or.sql @@ -0,0 +1 @@ +SELECT a.title, a.released_on, a.recorded_on FROM albums AS a WHERE a.title = 'Charcoal Lane' OR a.title = 'The New Danger' diff --git a/tests/fixtures/mattmc3/select_simple.expected b/tests/fixtures/mattmc3/select_simple.expected new file mode 100644 index 0000000..ad6831e --- /dev/null +++ b/tests/fixtures/mattmc3/select_simple.expected @@ -0,0 +1,6 @@ +select p.name + , p.product_number + , p.color + , p.list_price + from products as p + where p.list_price < 800; diff --git a/tests/fixtures/mattmc3/select_simple.sql b/tests/fixtures/mattmc3/select_simple.sql new file mode 100644 index 0000000..6072bfd --- /dev/null +++ b/tests/fixtures/mattmc3/select_simple.sql @@ -0,0 +1 @@ +SELECT p.name, p.product_number, p.color, p.list_price FROM products AS p WHERE p.list_price < 800 diff --git a/tests/fixtures/mattmc3/update_multi.expected b/tests/fixtures/mattmc3/update_multi.expected new file mode 100644 index 0000000..dee58c9 --- /dev/null +++ b/tests/fixtures/mattmc3/update_multi.expected @@ -0,0 +1,5 @@ +update products + set list_price = list_price + 100 + , modified_date = now() + where category = 'sale' + and active = true; 
diff --git a/tests/fixtures/mattmc3/update_multi.sql b/tests/fixtures/mattmc3/update_multi.sql new file mode 100644 index 0000000..ac72a43 --- /dev/null +++ b/tests/fixtures/mattmc3/update_multi.sql @@ -0,0 +1 @@ +UPDATE products SET list_price = list_price + 100, modified_date = now() WHERE category = 'sale' AND active = TRUE diff --git a/tests/fixtures/mozilla/create_table.expected b/tests/fixtures/mozilla/create_table.expected new file mode 100644 index 0000000..d769c8d --- /dev/null +++ b/tests/fixtures/mozilla/create_table.expected @@ -0,0 +1,7 @@ +CREATE TABLE staff ( + staff_num INTEGER NOT NULL, + first_name TEXT NOT NULL, + pens_in_drawer INTEGER NOT NULL, + CONSTRAINT pens_in_drawer_range CHECK (pens_in_drawer >= 1 AND pens_in_drawer < 100), + PRIMARY KEY (staff_num) +); diff --git a/tests/fixtures/mozilla/create_table.sql b/tests/fixtures/mozilla/create_table.sql new file mode 100644 index 0000000..b4e09d3 --- /dev/null +++ b/tests/fixtures/mozilla/create_table.sql @@ -0,0 +1 @@ +CREATE TABLE staff (staff_num INTEGER NOT NULL, first_name TEXT NOT NULL, pens_in_drawer INTEGER NOT NULL, CONSTRAINT pens_in_drawer_range CHECK(pens_in_drawer >= 1 AND pens_in_drawer < 100), PRIMARY KEY (staff_num)) diff --git a/tests/fixtures/mozilla/delete_and.expected b/tests/fixtures/mozilla/delete_and.expected new file mode 100644 index 0000000..5d5fcbd --- /dev/null +++ b/tests/fixtures/mozilla/delete_and.expected @@ -0,0 +1,4 @@ +DELETE FROM albums +WHERE + id = 1 + AND active = FALSE; diff --git a/tests/fixtures/mozilla/delete_and.sql b/tests/fixtures/mozilla/delete_and.sql new file mode 100644 index 0000000..a029af8 --- /dev/null +++ b/tests/fixtures/mozilla/delete_and.sql @@ -0,0 +1 @@ +DELETE FROM albums WHERE id = 1 AND active = FALSE diff --git a/tests/fixtures/mozilla/insert_multi.expected b/tests/fixtures/mozilla/insert_multi.expected new file mode 100644 index 0000000..d1484ec --- /dev/null +++ b/tests/fixtures/mozilla/insert_multi.expected @@ -0,0 +1,4 @@ 
+INSERT INTO albums (title, release_date) +VALUES + ('Charcoal Lane', '1990-01-01'), + ('The New Danger', '2008-01-01'); diff --git a/tests/fixtures/mozilla/insert_multi.sql b/tests/fixtures/mozilla/insert_multi.sql new file mode 100644 index 0000000..16f0a85 --- /dev/null +++ b/tests/fixtures/mozilla/insert_multi.sql @@ -0,0 +1 @@ +INSERT INTO albums (title, release_date) VALUES ('Charcoal Lane', '1990-01-01'), ('The New Danger', '2008-01-01') diff --git a/tests/fixtures/mozilla/select_cte.expected b/tests/fixtures/mozilla/select_cte.expected new file mode 100644 index 0000000..894e2ad --- /dev/null +++ b/tests/fixtures/mozilla/select_cte.expected @@ -0,0 +1,23 @@ +WITH recent AS ( + SELECT + id, + name + FROM users + WHERE + active = TRUE + ORDER BY created_at DESC + LIMIT 10 +) + +SELECT + r.id, + r.name, + COUNT(o.id) AS order_count +FROM + recent AS r +LEFT JOIN + orders AS o + ON r.id = o.user_id +GROUP BY + r.id, + r.name; diff --git a/tests/fixtures/mozilla/select_cte.sql b/tests/fixtures/mozilla/select_cte.sql new file mode 100644 index 0000000..2728200 --- /dev/null +++ b/tests/fixtures/mozilla/select_cte.sql @@ -0,0 +1 @@ +WITH recent AS (SELECT id, name FROM users WHERE active = TRUE ORDER BY created_at DESC LIMIT 10) SELECT r.id, r.name, COUNT(o.id) AS order_count FROM recent AS r LEFT JOIN orders AS o ON r.id = o.user_id GROUP BY r.id, r.name diff --git a/tests/fixtures/mozilla/select_group_order.expected b/tests/fixtures/mozilla/select_group_order.expected new file mode 100644 index 0000000..245d71d --- /dev/null +++ b/tests/fixtures/mozilla/select_group_order.expected @@ -0,0 +1,8 @@ +SELECT + x, + COUNT(*) AS cnt +FROM t +GROUP BY x +HAVING + COUNT(*) > 1 +ORDER BY cnt DESC; diff --git a/tests/fixtures/mozilla/select_group_order.sql b/tests/fixtures/mozilla/select_group_order.sql new file mode 100644 index 0000000..0c8a5bf --- /dev/null +++ b/tests/fixtures/mozilla/select_group_order.sql @@ -0,0 +1 @@ +SELECT x, COUNT(*) AS cnt FROM t GROUP BY x 
HAVING COUNT(*) > 1 ORDER BY cnt DESC diff --git a/tests/fixtures/mozilla/select_join.expected b/tests/fixtures/mozilla/select_join.expected new file mode 100644 index 0000000..ab2855c --- /dev/null +++ b/tests/fixtures/mozilla/select_join.expected @@ -0,0 +1,7 @@ +SELECT r.last_name +FROM + riders AS r +INNER JOIN + bikes AS b + ON r.bike_vin_num = b.vin_num + AND b.engines > 2; diff --git a/tests/fixtures/mozilla/select_join.sql b/tests/fixtures/mozilla/select_join.sql new file mode 100644 index 0000000..74e5e80 --- /dev/null +++ b/tests/fixtures/mozilla/select_join.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r INNER JOIN bikes AS b ON r.bike_vin_num = b.vin_num AND b.engines > 2 diff --git a/tests/fixtures/mozilla/select_simple.expected b/tests/fixtures/mozilla/select_simple.expected new file mode 100644 index 0000000..c1f466f --- /dev/null +++ b/tests/fixtures/mozilla/select_simple.expected @@ -0,0 +1,8 @@ +SELECT + client_id, + submission_date +FROM main_summary +WHERE + submission_date > '20180101' + AND sample_id = '42' +LIMIT 10; diff --git a/tests/fixtures/mozilla/select_simple.sql b/tests/fixtures/mozilla/select_simple.sql new file mode 100644 index 0000000..1d6621b --- /dev/null +++ b/tests/fixtures/mozilla/select_simple.sql @@ -0,0 +1 @@ +SELECT client_id, submission_date FROM main_summary WHERE submission_date > '20180101' AND sample_id = '42' LIMIT 10 diff --git a/tests/fixtures/mozilla/select_single_col.expected b/tests/fixtures/mozilla/select_single_col.expected new file mode 100644 index 0000000..f983c21 --- /dev/null +++ b/tests/fixtures/mozilla/select_single_col.expected @@ -0,0 +1,3 @@ +SELECT name +FROM users +LIMIT 10; diff --git a/tests/fixtures/mozilla/select_single_col.sql b/tests/fixtures/mozilla/select_single_col.sql new file mode 100644 index 0000000..c3f6569 --- /dev/null +++ b/tests/fixtures/mozilla/select_single_col.sql @@ -0,0 +1 @@ +SELECT name FROM users LIMIT 10 diff --git a/tests/fixtures/mozilla/select_subquery.expected 
b/tests/fixtures/mozilla/select_subquery.expected new file mode 100644 index 0000000..e92e38c --- /dev/null +++ b/tests/fixtures/mozilla/select_subquery.expected @@ -0,0 +1,9 @@ +SELECT x +FROM t +WHERE + EXISTS ( + SELECT 1 + FROM s + WHERE + s.id = t.id + ); diff --git a/tests/fixtures/mozilla/select_subquery.sql b/tests/fixtures/mozilla/select_subquery.sql new file mode 100644 index 0000000..e5b55f1 --- /dev/null +++ b/tests/fixtures/mozilla/select_subquery.sql @@ -0,0 +1 @@ +SELECT x FROM t WHERE EXISTS (SELECT 1 FROM s WHERE s.id = t.id) diff --git a/tests/fixtures/mozilla/select_union.expected b/tests/fixtures/mozilla/select_union.expected new file mode 100644 index 0000000..7003c05 --- /dev/null +++ b/tests/fixtures/mozilla/select_union.expected @@ -0,0 +1,5 @@ +SELECT x +FROM t +UNION ALL +SELECT y +FROM s; diff --git a/tests/fixtures/mozilla/select_union.sql b/tests/fixtures/mozilla/select_union.sql new file mode 100644 index 0000000..2eeea9f --- /dev/null +++ b/tests/fixtures/mozilla/select_union.sql @@ -0,0 +1 @@ +SELECT x FROM t UNION ALL SELECT y FROM s diff --git a/tests/fixtures/mozilla/select_using_join.expected b/tests/fixtures/mozilla/select_using_join.expected new file mode 100644 index 0000000..9dbdd52 --- /dev/null +++ b/tests/fixtures/mozilla/select_using_join.expected @@ -0,0 +1,13 @@ +SELECT + sp.account_id, + p.plan_name +FROM + subscriptions AS sp +JOIN + plans AS p + USING (plan_id) +JOIN + features AS f + USING (feature_id) +WHERE + f.feature_name = 'notifications'; diff --git a/tests/fixtures/mozilla/select_using_join.sql b/tests/fixtures/mozilla/select_using_join.sql new file mode 100644 index 0000000..9b42940 --- /dev/null +++ b/tests/fixtures/mozilla/select_using_join.sql @@ -0,0 +1 @@ +SELECT sp.account_id, p.plan_name FROM subscriptions AS sp JOIN plans AS p USING (plan_id) JOIN features AS f USING (feature_id) WHERE f.feature_name = 'notifications' diff --git a/tests/fixtures/mozilla/update_multi.expected 
b/tests/fixtures/mozilla/update_multi.expected new file mode 100644 index 0000000..4a54401 --- /dev/null +++ b/tests/fixtures/mozilla/update_multi.expected @@ -0,0 +1,6 @@ +UPDATE albums +SET + release_date = '1990-01-01', + title = 'test' +WHERE + id = 1; diff --git a/tests/fixtures/mozilla/update_multi.sql b/tests/fixtures/mozilla/update_multi.sql new file mode 100644 index 0000000..727d575 --- /dev/null +++ b/tests/fixtures/mozilla/update_multi.sql @@ -0,0 +1 @@ +UPDATE albums SET release_date = '1990-01-01', title = 'test' WHERE id = 1 diff --git a/tests/fixtures/river/create_domain.expected b/tests/fixtures/river/create_domain.expected new file mode 100644 index 0000000..f1849d9 --- /dev/null +++ b/tests/fixtures/river/create_domain.expected @@ -0,0 +1,2 @@ +CREATE DOMAIN public.status_type AS TEXT + CONSTRAINT valid_values CHECK (value = ANY (ARRAY['active'::TEXT, 'inactive'::TEXT, 'pending'::TEXT])); diff --git a/tests/fixtures/river/create_domain.sql b/tests/fixtures/river/create_domain.sql new file mode 100644 index 0000000..0326241 --- /dev/null +++ b/tests/fixtures/river/create_domain.sql @@ -0,0 +1 @@ +CREATE DOMAIN public.status_type AS text CONSTRAINT valid_values CHECK ((VALUE = ANY (ARRAY['active'::text, 'inactive'::text, 'pending'::text]))) diff --git a/tests/fixtures/river/create_foreign_table.expected b/tests/fixtures/river/create_foreign_table.expected new file mode 100644 index 0000000..219fdf0 --- /dev/null +++ b/tests/fixtures/river/create_foreign_table.expected @@ -0,0 +1,12 @@ +CREATE FOREIGN TABLE fdw_reporting.metrics ( + account_id INTEGER NOT NULL, + total_amount NUMERIC(10, 2), + avg_amount NUMERIC(10, 2), + generated_on DATE NOT NULL +) +SERVER reporting_server +OPTIONS ( + batch_size '100000', + schema_name 'report', + table_name 'metrics' +); diff --git a/tests/fixtures/river/create_foreign_table.sql b/tests/fixtures/river/create_foreign_table.sql new file mode 100644 index 0000000..8972327 --- /dev/null +++ 
b/tests/fixtures/river/create_foreign_table.sql @@ -0,0 +1 @@ +CREATE FOREIGN TABLE fdw_reporting.metrics (account_id integer NOT NULL, total_amount numeric(10,2), avg_amount numeric(10,2), generated_on date NOT NULL) SERVER reporting_server OPTIONS (batch_size '100000', schema_name 'report', table_name 'metrics') diff --git a/tests/fixtures/river/create_function.expected b/tests/fixtures/river/create_function.expected new file mode 100644 index 0000000..5787c41 --- /dev/null +++ b/tests/fixtures/river/create_function.expected @@ -0,0 +1,5 @@ +CREATE FUNCTION app.update_modified_at() RETURNS TRIGGER + LANGUAGE plpgsql + AS $$ + BEGIN IF (TG_OP = 'UPDATE') THEN NEW.modified_at = CURRENT_TIMESTAMP; END IF; RETURN NEW; END; +$$; diff --git a/tests/fixtures/river/create_function.sql b/tests/fixtures/river/create_function.sql new file mode 100644 index 0000000..836ac7f --- /dev/null +++ b/tests/fixtures/river/create_function.sql @@ -0,0 +1 @@ +CREATE FUNCTION app.update_modified_at() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN IF (TG_OP = 'UPDATE') THEN NEW.modified_at = CURRENT_TIMESTAMP; END IF; RETURN NEW; END; $$ diff --git a/tests/fixtures/river/create_matview.expected b/tests/fixtures/river/create_matview.expected new file mode 100644 index 0000000..a419701 --- /dev/null +++ b/tests/fixtures/river/create_matview.expected @@ -0,0 +1,29 @@ +CREATE MATERIALIZED VIEW report.service_subscription_info AS + SELECT sp.account_id, + sp.subscription_id, + sp.plan_id, + bt.term, + p.plan_name, + pf.units_included, + pf.units_increment, + pf.price_increment, + sp.cancelled_at + FROM public.subscriptions AS sp + JOIN public.plans AS p + USING (plan_id) + + INNER JOIN public.plan_details AS pd + USING (plan_id) + + INNER JOIN public.pricing_fees AS pf + USING (plan_detail_id) + + INNER JOIN public.features AS f + USING (feature_id) + + INNER JOIN public.billing_terms AS bt + USING (billing_term_id) + WHERE (sp.cancelled_at IS NULL OR sp.cancelled_at > CURRENT_TIMESTAMP) + AND
f.feature_name = 'notifications'::TEXT +ORDER BY sp.cancelled_at DESC, sp.started_at +WITH NO DATA; diff --git a/tests/fixtures/river/create_matview.sql b/tests/fixtures/river/create_matview.sql new file mode 100644 index 0000000..4beaa82 --- /dev/null +++ b/tests/fixtures/river/create_matview.sql @@ -0,0 +1 @@ +CREATE MATERIALIZED VIEW report.service_subscription_info AS SELECT sp.account_id, sp.subscription_id, sp.plan_id, bt.term, p.plan_name, pf.units_included, pf.units_increment, pf.price_increment, sp.cancelled_at FROM public.subscriptions sp JOIN public.plans p USING (plan_id) JOIN public.plan_details pd USING (plan_id) JOIN public.pricing_fees pf USING (plan_detail_id) JOIN public.features f USING (feature_id) JOIN public.billing_terms bt USING (billing_term_id) WHERE (sp.cancelled_at IS NULL OR sp.cancelled_at > CURRENT_TIMESTAMP) AND f.feature_name = 'notifications'::text ORDER BY sp.cancelled_at DESC, sp.started_at WITH NO DATA diff --git a/tests/fixtures/river/create_table.expected b/tests/fixtures/river/create_table.expected new file mode 100644 index 0000000..d0db889 --- /dev/null +++ b/tests/fixtures/river/create_table.expected @@ -0,0 +1,8 @@ +CREATE TABLE staff ( + PRIMARY KEY (staff_num), + staff_num INTEGER NOT NULL, + first_name TEXT NOT NULL, + pens_in_drawer INTEGER NOT NULL, + CONSTRAINT pens_in_drawer_range + CHECK(pens_in_drawer >= 1 AND pens_in_drawer < 100) +); diff --git a/tests/fixtures/river/create_table.sql b/tests/fixtures/river/create_table.sql new file mode 100644 index 0000000..b4e09d3 --- /dev/null +++ b/tests/fixtures/river/create_table.sql @@ -0,0 +1 @@ +CREATE TABLE staff (staff_num INTEGER NOT NULL, first_name TEXT NOT NULL, pens_in_drawer INTEGER NOT NULL, CONSTRAINT pens_in_drawer_range CHECK(pens_in_drawer >= 1 AND pens_in_drawer < 100), PRIMARY KEY (staff_num)) diff --git a/tests/fixtures/river/create_table_with.expected b/tests/fixtures/river/create_table_with.expected new file mode 100644 index 0000000..8bbdbb0 --- 
/dev/null +++ b/tests/fixtures/river/create_table_with.expected @@ -0,0 +1,7 @@ +CREATE TABLE mikkoo.audit ( + message_id UUID NOT NULL, + event_id BIGINT NOT NULL, + queue TEXT NOT NULL, + published_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() NOT NULL +) +WITH (autovacuum_vacuum_threshold='100', autovacuum_vacuum_scale_factor='0'); diff --git a/tests/fixtures/river/create_table_with.sql b/tests/fixtures/river/create_table_with.sql new file mode 100644 index 0000000..5b4fec5 --- /dev/null +++ b/tests/fixtures/river/create_table_with.sql @@ -0,0 +1 @@ +CREATE TABLE mikkoo.audit (message_id uuid NOT NULL, event_id bigint NOT NULL, queue text NOT NULL, published_at timestamp with time zone DEFAULT now() NOT NULL) WITH (autovacuum_vacuum_threshold='100', autovacuum_vacuum_scale_factor='0') diff --git a/tests/fixtures/river/create_view_cte.expected b/tests/fixtures/river/create_view_cte.expected new file mode 100644 index 0000000..1887d46 --- /dev/null +++ b/tests/fixtures/river/create_view_cte.expected @@ -0,0 +1,16 @@ +CREATE VIEW report.order_summary_vw AS +WITH recent_orders AS ( + SELECT DISTINCT ON (orders.account_id) orders.account_id, + orders.order_id + FROM public.orders +ORDER BY orders.account_id, orders.placed_at DESC +) + SELECT ro.account_id, + MAX(CASE WHEN sr.label::TEXT ~~ '%Cancelled%'::TEXT THEN 'Cancelled'::TEXT ELSE 'Active'::TEXT END) AS account_status + FROM recent_orders AS ro + LEFT JOIN public.order_status AS os + ON ro.order_id = os.order_id + + LEFT JOIN public.status_reason AS sr + ON os.reason_id = sr.reason_id +GROUP BY ro.account_id; diff --git a/tests/fixtures/river/create_view_cte.sql b/tests/fixtures/river/create_view_cte.sql new file mode 100644 index 0000000..5f0f80b --- /dev/null +++ b/tests/fixtures/river/create_view_cte.sql @@ -0,0 +1 @@ +CREATE VIEW report.order_summary_vw AS WITH recent_orders AS (SELECT DISTINCT ON (orders.account_id) orders.account_id, orders.order_id FROM public.orders ORDER BY orders.account_id, 
orders.placed_at DESC) SELECT ro.account_id, max(CASE WHEN (sr.label)::text ~~ '%Cancelled%'::text THEN 'Cancelled'::text ELSE 'Active'::text END) AS account_status FROM recent_orders AS ro LEFT JOIN public.order_status AS os ON ro.order_id = os.order_id LEFT JOIN public.status_reason AS sr ON os.reason_id = sr.reason_id GROUP BY ro.account_id diff --git a/tests/fixtures/river/delete_simple.expected b/tests/fixtures/river/delete_simple.expected new file mode 100644 index 0000000..335b52a --- /dev/null +++ b/tests/fixtures/river/delete_simple.expected @@ -0,0 +1,3 @@ +DELETE + FROM albums + WHERE id = 1; diff --git a/tests/fixtures/river/delete_simple.sql b/tests/fixtures/river/delete_simple.sql new file mode 100644 index 0000000..7dee3ba --- /dev/null +++ b/tests/fixtures/river/delete_simple.sql @@ -0,0 +1 @@ +DELETE FROM albums WHERE id = 1 diff --git a/tests/fixtures/river/insert_values.expected b/tests/fixtures/river/insert_values.expected new file mode 100644 index 0000000..7ee8b3e --- /dev/null +++ b/tests/fixtures/river/insert_values.expected @@ -0,0 +1,3 @@ +INSERT INTO albums (title, release_date, recording_date) + VALUES ('Charcoal Lane', '1990-01-01 01:01:01.00000', '1990-01-01 01:01:01.00000'), + ('The New Danger', '2008-01-01 01:01:01.00000', '1990-01-01 01:01:01.00000'); diff --git a/tests/fixtures/river/insert_values.sql b/tests/fixtures/river/insert_values.sql new file mode 100644 index 0000000..a12e263 --- /dev/null +++ b/tests/fixtures/river/insert_values.sql @@ -0,0 +1 @@ +INSERT INTO albums (title, release_date, recording_date) VALUES ('Charcoal Lane', '1990-01-01 01:01:01.00000', '1990-01-01 01:01:01.00000'), ('The New Danger', '2008-01-01 01:01:01.00000', '1990-01-01 01:01:01.00000') diff --git a/tests/fixtures/river/select_agg_functions.expected b/tests/fixtures/river/select_agg_functions.expected new file mode 100644 index 0000000..77df8ef --- /dev/null +++ b/tests/fixtures/river/select_agg_functions.expected @@ -0,0 +1,2 @@ +SELECT 
SUM(s.monitor_tally) AS monitor_total + FROM staff AS s; diff --git a/tests/fixtures/river/select_agg_functions.sql b/tests/fixtures/river/select_agg_functions.sql new file mode 100644 index 0000000..1187422 --- /dev/null +++ b/tests/fixtures/river/select_agg_functions.sql @@ -0,0 +1 @@ +SELECT SUM(s.monitor_tally) AS monitor_total FROM staff AS s diff --git a/tests/fixtures/river/select_alias.expected b/tests/fixtures/river/select_alias.expected new file mode 100644 index 0000000..bea1208 --- /dev/null +++ b/tests/fixtures/river/select_alias.expected @@ -0,0 +1,4 @@ +SELECT first_name AS fn + FROM staff AS s1 + JOIN students AS s2 + ON s2.mentor_id = s1.staff_num; diff --git a/tests/fixtures/river/select_alias.sql b/tests/fixtures/river/select_alias.sql new file mode 100644 index 0000000..86ac68a --- /dev/null +++ b/tests/fixtures/river/select_alias.sql @@ -0,0 +1 @@ +SELECT first_name AS fn FROM staff AS s1 JOIN students AS s2 ON s2.mentor_id = s1.staff_num diff --git a/tests/fixtures/river/select_and.expected b/tests/fixtures/river/select_and.expected new file mode 100644 index 0000000..2f6ecf5 --- /dev/null +++ b/tests/fixtures/river/select_and.expected @@ -0,0 +1,4 @@ +SELECT model_num + FROM phones AS p + WHERE p.released_on >= '2014-09-30' + AND p.manufacturer = 'Apple'; diff --git a/tests/fixtures/river/select_and.sql b/tests/fixtures/river/select_and.sql new file mode 100644 index 0000000..cd01eba --- /dev/null +++ b/tests/fixtures/river/select_and.sql @@ -0,0 +1 @@ +SELECT model_num FROM phones AS p WHERE p.released_on >= '2014-09-30' AND p.manufacturer = 'Apple' diff --git a/tests/fixtures/river/select_case.expected b/tests/fixtures/river/select_case.expected new file mode 100644 index 0000000..6dea3d9 --- /dev/null +++ b/tests/fixtures/river/select_case.expected @@ -0,0 +1,5 @@ +SELECT CASE postcode WHEN 'BN1' THEN 'Brighton' WHEN 'EH1' THEN 'Edinburgh' END AS city + FROM office_locations + WHERE country = 'United Kingdom' + AND opening_time BETWEEN 8 
AND 9 + AND postcode IN ('EH1', 'BN1', 'NN1', 'KW1'); diff --git a/tests/fixtures/river/select_case.sql b/tests/fixtures/river/select_case.sql new file mode 100644 index 0000000..09e7ba9 --- /dev/null +++ b/tests/fixtures/river/select_case.sql @@ -0,0 +1 @@ +SELECT CASE postcode WHEN 'BN1' THEN 'Brighton' WHEN 'EH1' THEN 'Edinburgh' END AS city FROM office_locations WHERE country = 'United Kingdom' AND opening_time BETWEEN 8 AND 9 AND postcode IN ('EH1', 'BN1', 'NN1', 'KW1') diff --git a/tests/fixtures/river/select_cte.expected b/tests/fixtures/river/select_cte.expected new file mode 100644 index 0000000..357ae90 --- /dev/null +++ b/tests/fixtures/river/select_cte.expected @@ -0,0 +1,15 @@ +WITH recent AS ( + SELECT id, + name + FROM users + WHERE active = TRUE +ORDER BY created_at DESC + LIMIT 10 +) + SELECT r.id, + r.name, + COUNT(o.id) AS order_count + FROM recent AS r + LEFT JOIN orders AS o + ON r.id = o.user_id +GROUP BY r.id, r.name; diff --git a/tests/fixtures/river/select_cte.sql b/tests/fixtures/river/select_cte.sql new file mode 100644 index 0000000..2728200 --- /dev/null +++ b/tests/fixtures/river/select_cte.sql @@ -0,0 +1 @@ +WITH recent AS (SELECT id, name FROM users WHERE active = TRUE ORDER BY created_at DESC LIMIT 10) SELECT r.id, r.name, COUNT(o.id) AS order_count FROM recent AS r LEFT JOIN orders AS o ON r.id = o.user_id GROUP BY r.id, r.name diff --git a/tests/fixtures/river/select_distinct.expected b/tests/fixtures/river/select_distinct.expected new file mode 100644 index 0000000..6756609 --- /dev/null +++ b/tests/fixtures/river/select_distinct.expected @@ -0,0 +1,4 @@ +SELECT DISTINCT x, + y + FROM t + WHERE z = 1; diff --git a/tests/fixtures/river/select_distinct.sql b/tests/fixtures/river/select_distinct.sql new file mode 100644 index 0000000..0373fc6 --- /dev/null +++ b/tests/fixtures/river/select_distinct.sql @@ -0,0 +1 @@ +SELECT DISTINCT x, y FROM t WHERE z = 1 diff --git a/tests/fixtures/river/select_group_by.expected 
b/tests/fixtures/river/select_group_by.expected new file mode 100644 index 0000000..b6aedd5 --- /dev/null +++ b/tests/fixtures/river/select_group_by.expected @@ -0,0 +1,6 @@ + SELECT f.species_name, + AVG(f.height) AS average_height + FROM flora AS f + WHERE f.species_name = 'Banksia' + OR f.species_name = 'Sheoak' +GROUP BY f.species_name, f.observation_date; diff --git a/tests/fixtures/river/select_group_by.sql b/tests/fixtures/river/select_group_by.sql new file mode 100644 index 0000000..cef8525 --- /dev/null +++ b/tests/fixtures/river/select_group_by.sql @@ -0,0 +1 @@ +SELECT f.species_name, AVG(f.height) AS average_height FROM flora AS f WHERE f.species_name = 'Banksia' OR f.species_name = 'Sheoak' GROUP BY f.species_name, f.observation_date diff --git a/tests/fixtures/river/select_having.expected b/tests/fixtures/river/select_having.expected new file mode 100644 index 0000000..6024ac9 --- /dev/null +++ b/tests/fixtures/river/select_having.expected @@ -0,0 +1,5 @@ + SELECT x, + COUNT(*) AS cnt + FROM t +GROUP BY x + HAVING COUNT(*) > 1; diff --git a/tests/fixtures/river/select_having.sql b/tests/fixtures/river/select_having.sql new file mode 100644 index 0000000..45ee404 --- /dev/null +++ b/tests/fixtures/river/select_having.sql @@ -0,0 +1 @@ +SELECT x, COUNT(*) AS cnt FROM t GROUP BY x HAVING COUNT(*) > 1 diff --git a/tests/fixtures/river/select_join.expected b/tests/fixtures/river/select_join.expected new file mode 100644 index 0000000..082da09 --- /dev/null +++ b/tests/fixtures/river/select_join.expected @@ -0,0 +1,9 @@ +SELECT r.last_name + FROM riders AS r + INNER JOIN bikes AS b + ON r.bike_vin_num = b.vin_num + AND b.engines > 2 + + INNER JOIN crew AS c + ON r.crew_chief_last_name = c.last_name + AND c.chief = 'Y'; diff --git a/tests/fixtures/river/select_join.sql b/tests/fixtures/river/select_join.sql new file mode 100644 index 0000000..9f3cd6a --- /dev/null +++ b/tests/fixtures/river/select_join.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r 
INNER JOIN bikes AS b ON r.bike_vin_num = b.vin_num AND b.engines > 2 INNER JOIN crew AS c ON r.crew_chief_last_name = c.last_name AND c.chief = 'Y' diff --git a/tests/fixtures/river/select_or.expected b/tests/fixtures/river/select_or.expected new file mode 100644 index 0000000..f0aa381 --- /dev/null +++ b/tests/fixtures/river/select_or.expected @@ -0,0 +1,6 @@ +SELECT a.title, + a.released_on, + a.recorded_on + FROM albums AS a + WHERE a.title = 'Charcoal Lane' + OR a.title = 'The New Danger'; diff --git a/tests/fixtures/river/select_or.sql b/tests/fixtures/river/select_or.sql new file mode 100644 index 0000000..e89a06b --- /dev/null +++ b/tests/fixtures/river/select_or.sql @@ -0,0 +1 @@ +SELECT a.title, a.released_on, a.recorded_on FROM albums AS a WHERE a.title = 'Charcoal Lane' OR a.title = 'The New Danger' diff --git a/tests/fixtures/river/select_order_limit.expected b/tests/fixtures/river/select_order_limit.expected new file mode 100644 index 0000000..4f1ad8e --- /dev/null +++ b/tests/fixtures/river/select_order_limit.expected @@ -0,0 +1,5 @@ + SELECT x + FROM t +ORDER BY y DESC NULLS LAST + LIMIT 10 + OFFSET 5; diff --git a/tests/fixtures/river/select_order_limit.sql b/tests/fixtures/river/select_order_limit.sql new file mode 100644 index 0000000..7cb7636 --- /dev/null +++ b/tests/fixtures/river/select_order_limit.sql @@ -0,0 +1 @@ +SELECT x FROM t ORDER BY y DESC NULLS LAST LIMIT 10 OFFSET 5 diff --git a/tests/fixtures/river/select_simple.expected b/tests/fixtures/river/select_simple.expected new file mode 100644 index 0000000..c72ec93 --- /dev/null +++ b/tests/fixtures/river/select_simple.expected @@ -0,0 +1,3 @@ +SELECT file_hash + FROM file_system + WHERE file_name = '.vimrc'; diff --git a/tests/fixtures/river/select_simple.sql b/tests/fixtures/river/select_simple.sql new file mode 100644 index 0000000..41a369a --- /dev/null +++ b/tests/fixtures/river/select_simple.sql @@ -0,0 +1 @@ +SELECT file_hash FROM file_system WHERE file_name = '.vimrc' diff --git 
a/tests/fixtures/river/select_subquery_exists.expected b/tests/fixtures/river/select_subquery_exists.expected new file mode 100644 index 0000000..77575f3 --- /dev/null +++ b/tests/fixtures/river/select_subquery_exists.expected @@ -0,0 +1,5 @@ +SELECT x + FROM t + WHERE EXISTS (SELECT 1 + FROM s + WHERE s.id = t.id); diff --git a/tests/fixtures/river/select_subquery_exists.sql b/tests/fixtures/river/select_subquery_exists.sql new file mode 100644 index 0000000..e5b55f1 --- /dev/null +++ b/tests/fixtures/river/select_subquery_exists.sql @@ -0,0 +1 @@ +SELECT x FROM t WHERE EXISTS (SELECT 1 FROM s WHERE s.id = t.id) diff --git a/tests/fixtures/river/select_subquery_in.expected b/tests/fixtures/river/select_subquery_in.expected new file mode 100644 index 0000000..a8ef8fc --- /dev/null +++ b/tests/fixtures/river/select_subquery_in.expected @@ -0,0 +1,6 @@ +SELECT r.last_name + FROM riders AS r + WHERE r.last_name IN (SELECT c.last_name + FROM champions AS c + WHERE YEAR(championship_date) > '2008' + AND c.confirmed = 'Y'); diff --git a/tests/fixtures/river/select_subquery_in.sql b/tests/fixtures/river/select_subquery_in.sql new file mode 100644 index 0000000..b7e9b74 --- /dev/null +++ b/tests/fixtures/river/select_subquery_in.sql @@ -0,0 +1 @@ +SELECT r.last_name FROM riders AS r WHERE r.last_name IN (SELECT c.last_name FROM champions AS c WHERE YEAR(championship_date) > '2008' AND c.confirmed = 'Y') diff --git a/tests/fixtures/river/select_subquery_nested.expected b/tests/fixtures/river/select_subquery_nested.expected new file mode 100644 index 0000000..9bb3883 --- /dev/null +++ b/tests/fixtures/river/select_subquery_nested.expected @@ -0,0 +1,10 @@ +SELECT r.last_name, + (SELECT MAX(YEAR(championship_date)) + FROM champions AS c + WHERE c.last_name = r.last_name + AND c.confirmed = 'Y') AS last_championship_year + FROM riders AS r + WHERE r.last_name IN (SELECT c.last_name + FROM champions AS c + WHERE YEAR(championship_date) > '2008' + AND c.confirmed = 'Y'); diff 
--git a/tests/fixtures/river/select_subquery_nested.sql b/tests/fixtures/river/select_subquery_nested.sql new file mode 100644 index 0000000..0c73c3b --- /dev/null +++ b/tests/fixtures/river/select_subquery_nested.sql @@ -0,0 +1 @@ +SELECT r.last_name, (SELECT MAX(YEAR(championship_date)) FROM champions AS c WHERE c.last_name = r.last_name AND c.confirmed = 'Y') AS last_championship_year FROM riders AS r WHERE r.last_name IN (SELECT c.last_name FROM champions AS c WHERE YEAR(championship_date) > '2008' AND c.confirmed = 'Y') diff --git a/tests/fixtures/river/select_subquery_scalar.expected b/tests/fixtures/river/select_subquery_scalar.expected new file mode 100644 index 0000000..fa41ec2 --- /dev/null +++ b/tests/fixtures/river/select_subquery_scalar.expected @@ -0,0 +1,6 @@ +SELECT x, + (SELECT MAX(y) + FROM t2 + WHERE t2.id = t1.id) AS max_y + FROM t1 + WHERE z = 1; diff --git a/tests/fixtures/river/select_subquery_scalar.sql b/tests/fixtures/river/select_subquery_scalar.sql new file mode 100644 index 0000000..bf61c8d --- /dev/null +++ b/tests/fixtures/river/select_subquery_scalar.sql @@ -0,0 +1 @@ +SELECT x, (SELECT MAX(y) FROM t2 WHERE t2.id = t1.id) AS max_y FROM t1 WHERE z = 1 diff --git a/tests/fixtures/river/select_union.expected b/tests/fixtures/river/select_union.expected new file mode 100644 index 0000000..8654504 --- /dev/null +++ b/tests/fixtures/river/select_union.expected @@ -0,0 +1,7 @@ +SELECT x + FROM t + +UNION ALL + +SELECT y + FROM s; diff --git a/tests/fixtures/river/select_union.sql b/tests/fixtures/river/select_union.sql new file mode 100644 index 0000000..2eeea9f --- /dev/null +++ b/tests/fixtures/river/select_union.sql @@ -0,0 +1 @@ +SELECT x FROM t UNION ALL SELECT y FROM s diff --git a/tests/fixtures/river/update_multi_set.expected b/tests/fixtures/river/update_multi_set.expected new file mode 100644 index 0000000..2aebcbb --- /dev/null +++ b/tests/fixtures/river/update_multi_set.expected @@ -0,0 +1,4 @@ +UPDATE file_system + SET 
file_modified_at = '1980-02-22 13:19:01.00000', + file_size = 209732 + WHERE file_name = '.vimrc'; diff --git a/tests/fixtures/river/update_multi_set.sql b/tests/fixtures/river/update_multi_set.sql new file mode 100644 index 0000000..073ea9f --- /dev/null +++ b/tests/fixtures/river/update_multi_set.sql @@ -0,0 +1 @@ +UPDATE file_system SET file_modified_at = '1980-02-22 13:19:01.00000', file_size = 209732 WHERE file_name = '.vimrc' diff --git a/tests/fixtures/river/update_simple.expected b/tests/fixtures/river/update_simple.expected new file mode 100644 index 0000000..583b7b0 --- /dev/null +++ b/tests/fixtures/river/update_simple.expected @@ -0,0 +1,3 @@ +UPDATE albums + SET release_date = '1990-01-01 01:01:01.00000' + WHERE title = 'The New Danger'; diff --git a/tests/fixtures/river/update_simple.sql b/tests/fixtures/river/update_simple.sql new file mode 100644 index 0000000..694e0bd --- /dev/null +++ b/tests/fixtures/river/update_simple.sql @@ -0,0 +1 @@ +UPDATE albums SET release_date = '1990-01-01 01:01:01.00000' WHERE title = 'The New Danger' diff --git a/tests/fixtures_test.rs b/tests/fixtures_test.rs new file mode 100644 index 0000000..820fc6d --- /dev/null +++ b/tests/fixtures_test.rs @@ -0,0 +1,328 @@ +use libpgfmt::{format, style::Style}; +use std::path::Path; + +fn run_fixture(style: Style, name: &str) { + let style_dir = match style { + Style::River => "river", + Style::Mozilla => "mozilla", + Style::Aweber => "aweber", + Style::Dbt => "dbt", + Style::Gitlab => "gitlab", + Style::Kickstarter => "kickstarter", + Style::Mattmc3 => "mattmc3", + }; + let base = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join(style_dir); + let sql_path = base.join(format!("{name}.sql")); + let expected_path = base.join(format!("{name}.expected")); + + let sql = std::fs::read_to_string(&sql_path) + .unwrap_or_else(|e| panic!("Failed to read {}: {e}", sql_path.display())); + let expected = std::fs::read_to_string(&expected_path) + 
.unwrap_or_else(|e| panic!("Failed to read {}: {e}", expected_path.display())); + + let result = format(sql.trim(), style) + .unwrap_or_else(|e| panic!("Failed to format {style_dir}/{name}: {e}")); + + pretty_assertions::assert_eq!( + result.trim(), + expected.trim(), + "\n\nStyle: {style_dir}, Fixture: {name}" + ); +} + +// ── River fixtures ────────────────────────────────────────────────────── + +#[test] +fn river_select_simple() { + run_fixture(Style::River, "select_simple"); +} + +#[test] +fn river_select_and() { + run_fixture(Style::River, "select_and"); +} + +#[test] +fn river_select_or() { + run_fixture(Style::River, "select_or"); +} + +#[test] +fn river_select_alias() { + run_fixture(Style::River, "select_alias"); +} + +#[test] +fn river_select_agg_functions() { + run_fixture(Style::River, "select_agg_functions"); +} + +#[test] +fn river_select_case() { + run_fixture(Style::River, "select_case"); +} + +#[test] +fn river_select_distinct() { + run_fixture(Style::River, "select_distinct"); +} + +#[test] +fn river_select_group_by() { + run_fixture(Style::River, "select_group_by"); +} + +#[test] +fn river_select_having() { + run_fixture(Style::River, "select_having"); +} + +#[test] +fn river_select_join() { + run_fixture(Style::River, "select_join"); +} + +#[test] +fn river_select_order_limit() { + run_fixture(Style::River, "select_order_limit"); +} + +#[test] +fn river_select_cte() { + run_fixture(Style::River, "select_cte"); +} + +#[test] +fn river_select_subquery_exists() { + run_fixture(Style::River, "select_subquery_exists"); +} + +#[test] +fn river_select_subquery_in() { + run_fixture(Style::River, "select_subquery_in"); +} + +#[test] +fn river_select_subquery_scalar() { + run_fixture(Style::River, "select_subquery_scalar"); +} + +#[test] +fn river_select_subquery_nested() { + run_fixture(Style::River, "select_subquery_nested"); +} + +#[test] +fn river_select_union() { + run_fixture(Style::River, "select_union"); +} + +#[test] +fn river_insert_values() { 
+ run_fixture(Style::River, "insert_values"); +} + +#[test] +fn river_update_simple() { + run_fixture(Style::River, "update_simple"); +} + +#[test] +fn river_update_multi_set() { + run_fixture(Style::River, "update_multi_set"); +} + +#[test] +fn river_delete_simple() { + run_fixture(Style::River, "delete_simple"); +} + +#[test] +fn river_create_table() { + run_fixture(Style::River, "create_table"); +} + +// ── Mozilla fixtures ──────────────────────────────────────────────────── + +#[test] +fn mozilla_select_simple() { + run_fixture(Style::Mozilla, "select_simple"); +} + +#[test] +fn mozilla_select_single_col() { + run_fixture(Style::Mozilla, "select_single_col"); +} + +#[test] +fn mozilla_select_join() { + run_fixture(Style::Mozilla, "select_join"); +} + +#[test] +fn mozilla_select_group_order() { + run_fixture(Style::Mozilla, "select_group_order"); +} + +#[test] +fn mozilla_select_cte() { + run_fixture(Style::Mozilla, "select_cte"); +} + +#[test] +fn mozilla_select_subquery() { + run_fixture(Style::Mozilla, "select_subquery"); +} + +#[test] +fn mozilla_select_union() { + run_fixture(Style::Mozilla, "select_union"); +} + +#[test] +fn mozilla_insert_multi() { + run_fixture(Style::Mozilla, "insert_multi"); +} + +#[test] +fn mozilla_update_multi() { + run_fixture(Style::Mozilla, "update_multi"); +} + +#[test] +fn mozilla_delete_and() { + run_fixture(Style::Mozilla, "delete_and"); +} + +#[test] +fn mozilla_create_table() { + run_fixture(Style::Mozilla, "create_table"); +} + +#[test] +fn mozilla_select_using_join() { + run_fixture(Style::Mozilla, "select_using_join"); +} + +// ── AWeber fixtures ───────────────────────────────────────────────────── + +#[test] +fn aweber_select_simple() { + run_fixture(Style::Aweber, "select_simple"); +} + +#[test] +fn aweber_select_or() { + run_fixture(Style::Aweber, "select_or"); +} + +#[test] +fn aweber_select_join() { + run_fixture(Style::Aweber, "select_join"); +} + +#[test] +fn aweber_select_left_join() { + 
run_fixture(Style::Aweber, "select_left_join"); +} + +#[test] +fn aweber_select_subquery() { + run_fixture(Style::Aweber, "select_subquery"); +} + +// ── dbt fixtures ──────────────────────────────────────────────────────── + +#[test] +fn dbt_select_simple() { + run_fixture(Style::Dbt, "select_simple"); +} + +#[test] +fn dbt_select_join() { + run_fixture(Style::Dbt, "select_join"); +} + +#[test] +fn dbt_select_group_order() { + run_fixture(Style::Dbt, "select_group_order"); +} + +#[test] +fn dbt_select_cte() { + run_fixture(Style::Dbt, "select_cte"); +} + +// ── GitLab fixtures ───────────────────────────────────────────────────── + +#[test] +fn gitlab_select_simple() { + run_fixture(Style::Gitlab, "select_simple"); +} + +#[test] +fn gitlab_select_join() { + run_fixture(Style::Gitlab, "select_join"); +} + +#[test] +fn gitlab_select_group_order() { + run_fixture(Style::Gitlab, "select_group_order"); +} + +#[test] +fn gitlab_select_cte() { + run_fixture(Style::Gitlab, "select_cte"); +} + +// ── Kickstarter fixtures ──────────────────────────────────────────────── + +#[test] +fn kickstarter_select_simple() { + run_fixture(Style::Kickstarter, "select_simple"); +} + +#[test] +fn kickstarter_select_join() { + run_fixture(Style::Kickstarter, "select_join"); +} + +#[test] +fn kickstarter_select_where() { + run_fixture(Style::Kickstarter, "select_where"); +} + +#[test] +fn kickstarter_select_cte() { + run_fixture(Style::Kickstarter, "select_cte"); +} + +// ── mattmc3 fixtures ──────────────────────────────────────────────────── + +#[test] +fn mattmc3_select_simple() { + run_fixture(Style::Mattmc3, "select_simple"); +} + +#[test] +fn mattmc3_select_or() { + run_fixture(Style::Mattmc3, "select_or"); +} + +#[test] +fn mattmc3_select_join() { + run_fixture(Style::Mattmc3, "select_join"); +} + +#[test] +fn mattmc3_insert_values() { + run_fixture(Style::Mattmc3, "insert_values"); +} + +#[test] +fn mattmc3_update_multi() { + run_fixture(Style::Mattmc3, "update_multi"); +} diff 
--git a/tests/smoke_test.rs b/tests/smoke_test.rs new file mode 100644 index 0000000..4c90af9 --- /dev/null +++ b/tests/smoke_test.rs @@ -0,0 +1,83 @@ +use libpgfmt::{format, style::Style}; + +#[test] +fn select_simple_river() { + let sql = "SELECT file_hash FROM file_system WHERE file_name = '.vimrc'"; + let result = format(sql, Style::River).unwrap(); + let expected = "\ +SELECT file_hash + FROM file_system + WHERE file_name = '.vimrc';"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +#[test] +fn select_multi_col_or_river() { + let sql = "SELECT a.title, a.released_on, a.recorded_on FROM albums AS a WHERE a.title = 'Charcoal Lane' OR a.title = 'The New Danger'"; + let result = format(sql, Style::River).unwrap(); + let expected = "\ +SELECT a.title, + a.released_on, + a.recorded_on + FROM albums AS a + WHERE a.title = 'Charcoal Lane' + OR a.title = 'The New Danger';"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +#[test] +fn delete_simple_river() { + let sql = "DELETE FROM albums WHERE id = 1"; + let result = format(sql, Style::River).unwrap(); + let expected = "\ +DELETE + FROM albums + WHERE id = 1;"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +#[test] +fn update_simple_river() { + let sql = "UPDATE albums SET release_date = '1990-01-01 01:01:01.00000' WHERE title = 'The New Danger'"; + let result = format(sql, Style::River).unwrap(); + let expected = "\ +UPDATE albums + SET release_date = '1990-01-01 01:01:01.00000' + WHERE title = 'The New Danger';"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +#[test] +fn select_simple_mozilla() { + let sql = "SELECT client_id, submission_date FROM main_summary WHERE submission_date > '20180101' AND sample_id = '42' LIMIT 10"; + let result = format(sql, Style::Mozilla).unwrap(); + let expected = "\ +SELECT + client_id, + submission_date +FROM main_summary +WHERE + submission_date > '20180101' + AND sample_id = '42' +LIMIT 10;"; + assert_eq!(result, expected, 
"\nGot:\n{result}"); +} + +#[test] +fn select_simple_dbt() { + let sql = "SELECT client_id, submission_date FROM main_summary WHERE submission_date > '20180101' AND sample_id = '42' LIMIT 10"; + let result = format(sql, Style::Dbt).unwrap(); + let expected = "\ +select + client_id, + submission_date + +from main_summary + +where + submission_date > '20180101' + and sample_id = '42' + +limit 10;"; + assert_eq!(result, expected, "\nGot:\n{result}"); +}