Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions src/formatter/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,35 +344,65 @@ impl<'a> Formatter<'a> {
fn format_c_expr(&self, node: Node<'a>) -> String {
let mut parts = Vec::new();
let mut has_block_subquery = false;
let mut paren_depth: u32 = 0;
let mut paren_parts: Vec<String> = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named() {
match child.kind() {
"columnref" => parts.push(self.format_columnref(child)),
"AexprConst" => parts.push(self.format_const(child)),
"func_expr" | "func_application" => parts.push(self.format_func(child)),
"case_expr" => parts.push(self.format_case_expr(child)),
let formatted = match child.kind() {
"columnref" => self.format_columnref(child),
"AexprConst" => self.format_const(child),
"func_expr" | "func_application" => self.format_func(child),
"case_expr" => self.format_case_expr(child),
"select_with_parens" => {
let formatted = self.format_select_with_parens(child);
if formatted.starts_with("(\n") {
let f = self.format_select_with_parens(child);
if f.starts_with("(\n") {
has_block_subquery = true;
}
parts.push(formatted);
}
"kw_exists" => parts.push(self.kw("EXISTS")),
"kw_row" => parts.push(self.kw("ROW")),
_ if child.kind().starts_with("kw_") => {
parts.push(self.format_keyword_node(child));
f
}
_ => parts.push(self.format_expr(child)),
"kw_exists" => self.kw("EXISTS"),
"kw_row" => self.kw("ROW"),
_ if child.kind().starts_with("kw_") => self.format_keyword_node(child),
_ => self.format_expr(child),
};
if paren_depth > 0 {
paren_parts.push(formatted);
} else {
parts.push(formatted);
}
} else {
let text = self.text(child).trim();
if !text.is_empty() {
parts.push(text.to_string());
if text == "(" {
if paren_depth > 0 {
// Nested paren — include it as content.
paren_parts.push("(".to_string());
}
paren_depth += 1;
} else if text == ")" && paren_depth > 0 {
paren_depth -= 1;
if paren_depth == 0 {
// Close outermost paren group.
let inner = paren_parts.join(" ");
parts.push(format!("({inner})"));
paren_parts.clear();
} else {
// Closing a nested paren.
paren_parts.push(")".to_string());
}
} else if !text.is_empty() {
if paren_depth > 0 {
paren_parts.push(text.to_string());
} else {
parts.push(text.to_string());
}
}
}
}
// Unclosed parens — flush as-is.
if paren_depth > 0 {
parts.push(format!("({}", paren_parts.join(" ")));
}

// For block-format subqueries (left-aligned styles), join with simple
// spaces without column-based multiline indentation; the subquery
Expand Down
39 changes: 19 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,27 +157,26 @@ pub fn format_plpgsql(code: &str, style: Style) -> Result<String, FormatError> {
fmt.format_plpgsql_root(root)
}

/// Check whether the parse tree has a structural error (ERROR node wrapping
/// significant content). Small ERROR nodes (e.g., unparsed decimal fraction
/// like ".00") are tolerable and can be passed through as-is.
fn has_structural_error(node: &tree_sitter::Node) -> bool {
if node.is_error() {
// Small leaf ERROR nodes (e.g., ".00" decimal part = 3 bytes) are
// tolerable grammar gaps. Anything larger likely indicates a genuine
// parse failure that would produce garbled output.
let size = node.end_byte() - node.start_byte();
return size > 4;
}
if node.is_missing() {
return true;
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.has_error() && has_structural_error(&child) {
return true;
}
/// Check whether the parse tree has a structural error that would prevent
/// meaningful formatting.
///
/// The tree-sitter-postgres grammar has known limitations that produce ERROR
/// nodes for valid SQL (e.g., `IS NOT NULL AND`, parenthesized boolean
/// expressions, dollar-quoted function bodies). We only reject input when
/// the parser couldn't produce any valid statement structure at all.
fn has_structural_error(root: &tree_sitter::Node) -> bool {
// If the parser produced at least one valid toplevel_stmt, the errors
// are grammar limitations (expression-level conflicts, dollar-quoted
// bodies, etc.) — not fundamentally broken SQL. Format what we can.
let mut cursor = root.walk();
let has_valid_stmt = root
.named_children(&mut cursor)
.any(|c| c.kind() == "toplevel_stmt");
if has_valid_stmt {
return false;
}
false
// No valid statements at all — this is genuinely broken input.
true
}

fn find_error_message(node: &tree_sitter::Node, source: &str) -> String {
Expand Down
137 changes: 137 additions & 0 deletions tests/all_fixtures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use libpgfmt::{format, style::Style};
use std::path::Path;

fn run_fixture(style: Style, style_name: &str, name: &str) {
let base = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures")
.join(style_name);
let sql_path = base.join(format!("{name}.sql"));
let expected_path = base.join(format!("{name}.expected"));

let sql = std::fs::read_to_string(&sql_path)
.unwrap_or_else(|e| panic!("Failed to read {}: {e}", sql_path.display()));
let expected = std::fs::read_to_string(&expected_path)
.unwrap_or_else(|e| panic!("Failed to read {}: {e}", expected_path.display()));

let result = format(sql.trim(), style);
match result {
Ok(formatted) => {
pretty_assertions::assert_eq!(
formatted.trim(),
expected.trim(),
"\n\nStyle: {style_name}, Fixture: {name}"
);
}
Err(e) => {
panic!("Failed to format {style_name}/{name}: {e}");
}
}
}

/// Known fixtures that don't match expected output yet due to grammar
/// limitations or incomplete formatting support. These parse successfully
/// but produce different output than the pgfmt reference.
const KNOWN_FAILING: &[&str] = &[
"river/create_domain",
"river/create_foreign_table",
"river/create_function",
"river/create_matview",
"river/create_table_with",
"river/create_view_cte",
"aweber/select_case_join",
"aweber/select_cte_nested",
];

/// Discover all .sql files in each style directory and run them.
#[test]
fn all_fixture_pairs() {
let fixtures_dir = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures");

let styles: &[(&str, Style)] = &[
("river", Style::River),
("mozilla", Style::Mozilla),
("aweber", Style::Aweber),
("dbt", Style::Dbt),
("gitlab", Style::Gitlab),
("kickstarter", Style::Kickstarter),
("mattmc3", Style::Mattmc3),
];

let mut total = 0;
let mut passed = 0;
let mut failures = Vec::new();

for (style_name, style) in styles {
let style_dir = fixtures_dir.join(style_name);
if !style_dir.exists() {
continue;
}
let mut entries: Vec<_> = std::fs::read_dir(&style_dir)
.unwrap()
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().is_some_and(|ext| ext == "sql"))
.collect();
entries.sort_by_key(|e| e.file_name());

for entry in entries {
let stem = entry
.path()
.file_stem()
.unwrap()
.to_string_lossy()
.to_string();
let expected_path = style_dir.join(format!("{stem}.expected"));
if !expected_path.exists() {
eprintln!("SKIP {style_name}/{stem}: no .expected file");
continue;
}
let fixture_key = format!("{style_name}/{stem}");
let is_known_failing = KNOWN_FAILING.contains(&fixture_key.as_str());
total += 1;
let result = std::panic::catch_unwind(|| {
run_fixture(*style, style_name, &stem);
});
match result {
Ok(()) => {
passed += 1;
if is_known_failing {
eprintln!("UNEXPECTED PASS {fixture_key}: remove from KNOWN_FAILING");
}
}
Err(e) => {
if is_known_failing {
eprintln!("EXPECTED FAIL {fixture_key}");
passed += 1; // Don't count as failure.
} else {
let msg = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else {
"unknown panic".to_string()
};
let short = if msg.chars().count() > 200 {
let truncated: String = msg.chars().take(200).collect();
format!("{truncated}...")
} else {
msg
};
failures.push(format!("{fixture_key}: {short}"));
}
}
}
}
}

eprintln!("\n=== Fixture Results: {passed}/{total} passed ===");
if !failures.is_empty() {
eprintln!("\nFailures:");
for f in &failures {
eprintln!(" FAIL: {f}");
}
panic!("{} of {} fixtures failed", failures.len(), total);
}
}
28 changes: 28 additions & 0 deletions tests/test_parens.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use libpgfmt::{format, style::Style};

#[test]
fn preserve_parens_around_or_in_and() {
let sql = "SELECT 1 FROM t WHERE (a IS NULL OR b > 1) AND c = 'x'";
let result = format(sql, Style::River).unwrap();
assert!(
result.contains("(a IS NULL OR b > 1)"),
"Parentheses around OR were dropped:\n{result}"
);
}

#[test]
fn no_unnecessary_parens() {
let sql = "SELECT 1 FROM t WHERE a = 1 AND b = 2";
let result = format(sql, Style::River).unwrap();
assert!(!result.contains('('), "Unexpected parens added:\n{result}");
}

#[test]
fn preserve_adjacent_parens() {
let sql = "SELECT 1 FROM t WHERE (a = 1) AND (b = 2)";
let result = format(sql, Style::River).unwrap();
assert!(
result.contains("(a = 1)") && result.contains("(b = 2)"),
"Adjacent parens corrupted:\n{result}"
);
}
Loading