diff --git a/src/common/enum_util.cpp b/src/common/enum_util.cpp index 9e04cf9130e2..9aca1288a8c3 100644 --- a/src/common/enum_util.cpp +++ b/src/common/enum_util.cpp @@ -3896,19 +3896,20 @@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { { static_cast(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" }, { static_cast(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" }, { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" }, - { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" } + { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" }, + { static_cast(QueryNodeType::INSERT_QUERY_NODE), "INSERT_QUERY_NODE" } }; return values; } template<> const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 6, "QueryNodeType", static_cast(value)); + return StringUtil::EnumToString(GetQueryNodeTypeValues(), 7, "QueryNodeType", static_cast(value)); } template<> QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 6, "QueryNodeType", value)); + return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 7, "QueryNodeType", value)); } const StringUtil::EnumStringLiteral *GetQueryResultMemoryTypeValues() { diff --git a/src/include/duckdb/parser/query_node.hpp b/src/include/duckdb/parser/query_node.hpp index 4981e8afcecb..0fec98d76922 100644 --- a/src/include/duckdb/parser/query_node.hpp +++ b/src/include/duckdb/parser/query_node.hpp @@ -25,7 +25,8 @@ enum class QueryNodeType : uint8_t { BOUND_SUBQUERY_NODE = 3, RECURSIVE_CTE_NODE = 4, CTE_NODE = 5, - STATEMENT_NODE = 6 + STATEMENT_NODE = 6, + INSERT_QUERY_NODE = 7 }; struct CommonTableExpressionInfo; diff --git a/src/include/duckdb/parser/query_node/insert_query_node.hpp b/src/include/duckdb/parser/query_node/insert_query_node.hpp new file mode 100644 index 000000000000..360979b45f73 --- /dev/null +++ b/src/include/duckdb/parser/query_node/insert_query_node.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/insert_query_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/tableref.hpp" + +namespace duckdb { + +class OnConflictInfo; +enum class InsertColumnOrder : uint8_t; + +//! InsertQueryNode represents an INSERT statement with RETURNING used in a CTE +class InsertQueryNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::INSERT_QUERY_NODE; + +public: + InsertQueryNode(); + + //! The select statement to insert from + unique_ptr select_statement; + //! Column names to insert into + vector columns; + + //! Table name to insert to + string table; + //! Schema name to insert to + string schema; + //! The catalog name to insert to + string catalog; + + //! The RETURNING clause expressions + vector> returning_list; + + //! ON CONFLICT info + unique_ptr on_conflict_info; + //! Table reference (for qualified names) + unique_ptr table_ref; + + //! Whether or not this is a DEFAULT VALUES insert + bool default_values = false; + + //! INSERT BY POSITION or INSERT BY NAME + InsertColumnOrder column_order; + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + + //! Create a copy of this InsertQueryNode + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/include/duckdb/parser/query_node/list.hpp b/src/include/duckdb/parser/query_node/list.hpp index 3a2894cc4c06..8476e9ab02db 100644 --- a/src/include/duckdb/parser/query_node/list.hpp +++ b/src/include/duckdb/parser/query_node/list.hpp @@ -3,3 +3,4 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/set_operation_node.hpp" #include "duckdb/parser/query_node/statement_node.hpp" +#include "duckdb/parser/query_node/insert_query_node.hpp" diff --git a/src/include/duckdb/parser/tokens.hpp b/src/include/duckdb/parser/tokens.hpp index d5646739cc3e..ef184dcb436b 100644 --- a/src/include/duckdb/parser/tokens.hpp +++ b/src/include/duckdb/parser/tokens.hpp @@ -54,6 +54,7 @@ class SetOperationNode; class RecursiveCTENode; class CTENode; class StatementNode; +class InsertQueryNode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 523ba484ab00..f4ee340cf427 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -411,6 +411,7 @@ class Binder : public enable_shared_from_this { BoundStatement BindNode(RecursiveCTENode &node); BoundStatement BindNode(QueryNode &node); BoundStatement BindNode(StatementNode &node); + BoundStatement BindNode(InsertQueryNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); unique_ptr CreatePlan(BoundSelectNode &statement); diff --git a/src/parser/parsed_expression_iterator.cpp b/src/parser/parsed_expression_iterator.cpp index f5746f9f7044..d0a1874bea0f 100644 --- a/src/parser/parsed_expression_iterator.cpp +++ b/src/parser/parsed_expression_iterator.cpp @@ -6,6 +6,8 @@ #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" #include "duckdb/parser/tableref/list.hpp" namespace duckdb { @@ -298,6 +300,36 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( } break; } + case QueryNodeType::INSERT_QUERY_NODE: { + auto &insert_node = node.Cast(); + for (auto &expr : insert_node.returning_list) { + expr_callback(expr); + } + if (insert_node.table_ref) { + EnumerateTableRefChildren(*insert_node.table_ref, expr_callback, ref_callback); + } + if (insert_node.select_statement && insert_node.select_statement->node) { + EnumerateQueryNodeChildren(*insert_node.select_statement->node, expr_callback, ref_callback); + } + // Traverse on_conflict_info expressions + if (insert_node.on_conflict_info) { + if (insert_node.on_conflict_info->condition) { + expr_callback(insert_node.on_conflict_info->condition); + } + if (insert_node.on_conflict_info->set_info) { + for (auto &expr : insert_node.on_conflict_info->set_info->expressions) { + expr_callback(expr); + } + if (insert_node.on_conflict_info->set_info->condition) { + expr_callback(insert_node.on_conflict_info->set_info->condition); + } + } + } + break; + } + case QueryNodeType::STATEMENT_NODE: + case QueryNodeType::CTE_NODE: + case QueryNodeType::BOUND_SUBQUERY_NODE: default: throw NotImplementedException("QueryNode type not implemented for traversal"); } diff --git a/src/parser/query_node/CMakeLists.txt b/src/parser/query_node/CMakeLists.txt index 45b6a23de62c..3f97e144c789 100644 --- a/src/parser/query_node/CMakeLists.txt +++ b/src/parser/query_node/CMakeLists.txt @@ -5,7 +5,8 @@ add_library_unity( cte_node.cpp select_node.cpp set_operation_node.cpp - statement_node.cpp) + statement_node.cpp + insert_query_node.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ PARENT_SCOPE) diff --git a/src/parser/query_node/insert_query_node.cpp b/src/parser/query_node/insert_query_node.cpp new file mode 100644 index 000000000000..b7e0c6940439 --- /dev/null +++ b/src/parser/query_node/insert_query_node.cpp @@ -0,0 +1,174 @@ +#include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/expression_util.hpp" +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { + +static bool UpdateSetInfoEquals(const unique_ptr &left, const unique_ptr &right) { + if (left && right) { + if (left->columns != right->columns) { + return false; + } + if (!ExpressionUtil::ListEquals(left->expressions, right->expressions)) { + return false; + } + if (!ParsedExpression::Equals(left->condition, right->condition)) { + return false; + } + return true; + } + // One is null, the other is not + return left.get() == right.get(); +} + +static bool OnConflictInfoEquals(const unique_ptr &left, const unique_ptr &right) { + if (left && right) { + if (left->action_type != right->action_type) { + return false; + } + if (left->indexed_columns != right->indexed_columns) { + return false; + } + if (!UpdateSetInfoEquals(left->set_info, right->set_info)) { + return false; + } + if (!ParsedExpression::Equals(left->condition, right->condition)) { + return false; + } + return true; + } + // One is null, the other is not + return left.get() == right.get(); +} + +InsertQueryNode::InsertQueryNode() + : QueryNode(QueryNodeType::INSERT_QUERY_NODE), default_values(false), + column_order(InsertColumnOrder::INSERT_BY_POSITION) { +} + +string InsertQueryNode::ToString() const { + string result; + result += "INSERT INTO "; + if (!catalog.empty() && catalog != INVALID_CATALOG) { + result += KeywordHelper::WriteOptionallyQuoted(catalog) + "."; + } + if (!schema.empty() && schema != DEFAULT_SCHEMA) { + result += KeywordHelper::WriteOptionallyQuoted(schema) + "."; + } + result += KeywordHelper::WriteOptionallyQuoted(table); + if (table_ref && !table_ref->alias.empty()) { + result += StringUtil::Format(" AS %s", KeywordHelper::WriteOptionallyQuoted(table_ref->alias)); + } + if (column_order == InsertColumnOrder::INSERT_BY_NAME) { + result += " BY NAME"; + } + if (!columns.empty()) { + result += " ("; + for (idx_t i = 0; i < columns.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(columns[i]); + } + result += ")"; + } + result += " "; + if (select_statement) { + result += select_statement->ToString(); + } else if (default_values) { + result += "DEFAULT VALUES"; + } + if (!returning_list.empty()) { + result += " RETURNING "; + for (idx_t i = 0; i < returning_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += returning_list[i]->ToString(); + if (!returning_list[i]->GetAlias().empty()) { + result += StringUtil::Format(" AS %s", + KeywordHelper::WriteOptionallyQuoted(returning_list[i]->GetAlias())); + } + } + } + return result + ResultModifiersToString(); +} + +bool InsertQueryNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + + if (catalog != other.catalog || schema != other.schema || table != other.table) { + return false; + } + if (columns != other.columns) { + return false; + } + if (default_values != other.default_values) { + return false; + } + if (column_order != other.column_order) { + return false; + } + if (!ExpressionUtil::ListEquals(returning_list, other.returning_list)) { + return false; + } + if (!TableRef::Equals(table_ref, other.table_ref)) { + return false; + } + // Compare select_statement + if (select_statement && other.select_statement) { + if (!select_statement->Equals(*other.select_statement)) { + return false; + } + } else if (select_statement || other.select_statement) { + return false; + } + // Compare on_conflict_info + if (!OnConflictInfoEquals(on_conflict_info, other.on_conflict_info)) { + return false; + } + return true; +} + +unique_ptr InsertQueryNode::Copy() const { + auto result = make_uniq(); + result->catalog = catalog; + result->schema = schema; + result->table = table; + result->columns = columns; + result->default_values = default_values; + result->column_order = column_order; + if (select_statement) { + result->select_statement = unique_ptr_cast(select_statement->Copy()); + } + for (auto &expr : returning_list) { + result->returning_list.push_back(expr->Copy()); + } + if (on_conflict_info) { + result->on_conflict_info = on_conflict_info->Copy(); + } + if (table_ref) { + result->table_ref = table_ref->Copy(); + } + this->CopyProperties(*result); + return std::move(result); +} + +void InsertQueryNode::Serialize(Serializer &serializer) const { + // For now, disallow serialization of DML CTEs (views cannot contain them) + throw NotImplementedException("INSERT in CTE cannot be serialized - views cannot contain data-modifying CTEs"); +} + +unique_ptr InsertQueryNode::Deserialize(Deserializer &deserializer) { + throw NotImplementedException("INSERT in CTE cannot be deserialized"); +} + +} // namespace duckdb diff --git a/src/parser/transform/helpers/transform_cte.cpp b/src/parser/transform/helpers/transform_cte.cpp index 309891bad2e2..73817a2b7752 100644 --- a/src/parser/transform/helpers/transform_cte.cpp +++ b/src/parser/transform/helpers/transform_cte.cpp @@ -4,6 +4,7 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/parser/query_node/statement_node.hpp" +#include "duckdb/parser/query_node/insert_query_node.hpp" #include "duckdb/parser/statement/select_statement.hpp" #include "duckdb/parser/statement/insert_statement.hpp" #include "duckdb/parser/statement/update_statement.hpp" @@ -115,12 +116,22 @@ void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, if (insert->returning_list.empty()) { throw ParserException("INSERT in a CTE must have a RETURNING clause"); } - // Wrap the DML statement in a SelectStatement via StatementNode + // Create InsertQueryNode with fields from the InsertStatement + auto insert_node = make_uniq(); + insert_node->catalog = insert->catalog; + insert_node->schema = insert->schema; + insert_node->table = insert->table; + insert_node->columns = std::move(insert->columns); + insert_node->default_values = insert->default_values; + insert_node->column_order = insert->column_order; + insert_node->select_statement = std::move(insert->select_statement); + insert_node->returning_list = std::move(insert->returning_list); + insert_node->on_conflict_info = std::move(insert->on_conflict_info); + insert_node->table_ref = std::move(insert->table_ref); + // Copy inner CTEs (for WITH ... INSERT ... syntax) + insert_node->cte_map = insert->cte_map.Copy(); info->query = make_uniq(); - // Copy inner CTEs before moving the statement (for WITH ... INSERT ... syntax) - auto inner_ctes = insert->cte_map.Copy(); - info->query->node = make_uniq(std::move(insert)); - info->query->node->cte_map = std::move(inner_ctes); + info->query->node = std::move(insert_node); is_dml_cte = true; break; } diff --git a/src/parser/transform/statement/transform_create_view.cpp b/src/parser/transform/statement/transform_create_view.cpp index 992e410504f9..e29d495dc826 100644 --- a/src/parser/transform/statement/transform_create_view.cpp +++ b/src/parser/transform/statement/transform_create_view.cpp @@ -11,9 +11,10 @@ static bool ContainsWritableCTEInNode(const QueryNode &node); static bool ContainsWritableCTEInCTEMap(const CommonTableExpressionMap &cte_map) { for (auto &entry : cte_map.map) { auto &cte_info = entry.second; - // Check if this CTE's query node is a StatementNode (writable CTE) + // Check if this CTE's query node is a DML query node (writable CTE) if (cte_info->query && cte_info->query->node) { - if (cte_info->query->node->type == QueryNodeType::STATEMENT_NODE) { + if (cte_info->query->node->type == QueryNodeType::STATEMENT_NODE || + cte_info->query->node->type == QueryNodeType::INSERT_QUERY_NODE) { return true; } // Recursively check CTEs within this CTE's query diff --git a/src/planner/binder/query_node/CMakeLists.txt b/src/planner/binder/query_node/CMakeLists.txt index 709d23ce1749..935fcf0f03af 100644 --- a/src/planner/binder/query_node/CMakeLists.txt +++ b/src/planner/binder/query_node/CMakeLists.txt @@ -6,6 +6,7 @@ add_library_unity( bind_recursive_cte_node.cpp bind_cte_node.cpp bind_statement_node.cpp + bind_dml_query_node.cpp bind_table_macro_node.cpp plan_query_node.cpp plan_select_node.cpp diff --git a/src/planner/binder/query_node/bind_cte_node.cpp b/src/planner/binder/query_node/bind_cte_node.cpp index 573cc77f31de..0fb5851ed069 100644 --- a/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/planner/binder/query_node/bind_cte_node.cpp @@ -42,6 +42,9 @@ BoundStatement Binder::BindNode(QueryNode &node) { case QueryNodeType::STATEMENT_NODE: result = current_binder.get().BindNode(node.Cast()); break; + case QueryNodeType::INSERT_QUERY_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; default: throw InternalException("Unsupported query node type"); } @@ -116,9 +119,11 @@ BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo result.materialized = statement.materialized; result.setop_index = GenerateTableIndex(); - // Check if this CTE contains a DML statement (StatementNode wraps INSERT/UPDATE/DELETE) + // Check if this CTE contains a DML statement (INSERT/UPDATE/DELETE with RETURNING) // DML CTEs must be bound even if unreferenced to ensure side effects occur - result.is_dml_cte = statement.query->node->type == QueryNodeType::STATEMENT_NODE; + auto node_type = statement.query->node->type; + result.is_dml_cte = node_type == QueryNodeType::INSERT_QUERY_NODE || + node_type == QueryNodeType::STATEMENT_NODE; // Legacy support // instead of eagerly binding the CTE here we add the CTE bind state to the list of CTE bindings // the CTE is bound lazily - when referenced for the first time we perform the binding diff --git a/src/planner/binder/query_node/bind_dml_query_node.cpp b/src/planner/binder/query_node/bind_dml_query_node.cpp new file mode 100644 index 000000000000..5db040f1e828 --- /dev/null +++ b/src/planner/binder/query_node/bind_dml_query_node.cpp @@ -0,0 +1,32 @@ +#include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BoundStatement Binder::BindNode(InsertQueryNode &node) { + // Convert InsertQueryNode back to InsertStatement for binding + InsertStatement insert; + insert.catalog = node.catalog; + insert.schema = node.schema; + insert.table = node.table; + insert.columns = node.columns; + insert.default_values = node.default_values; + insert.column_order = node.column_order; + if (node.select_statement) { + insert.select_statement = unique_ptr_cast(node.select_statement->Copy()); + } + for (auto &expr : node.returning_list) { + insert.returning_list.push_back(expr->Copy()); + } + if (node.on_conflict_info) { + insert.on_conflict_info = node.on_conflict_info->Copy(); + } + if (node.table_ref) { + insert.table_ref = node.table_ref->Copy(); + } + // Note: We don't copy node.cte_map because the CTEs are handled at the outer level + return Bind(insert); +} + +} // namespace duckdb diff --git a/test/api/CMakeLists.txt b/test/api/CMakeLists.txt index 9423611dba1a..0c01591a68b9 100644 --- a/test/api/CMakeLists.txt +++ b/test/api/CMakeLists.txt @@ -15,6 +15,7 @@ set(TEST_API_OBJECTS test_custom_allocator.cpp test_extension_setting_autoload.cpp test_instance_cache.cpp + test_insert_query_node.cpp test_storage_extension_alias.cpp test_results.cpp test_reset.cpp diff --git a/test/api/test_insert_query_node.cpp b/test/api/test_insert_query_node.cpp new file mode 100644 index 000000000000..309cbfeaa3ca --- /dev/null +++ b/test/api/test_insert_query_node.cpp @@ -0,0 +1,176 @@ +#include "catch.hpp" +#include "test_helpers.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" + +using namespace duckdb; +using namespace std; + +TEST_CASE("Test InsertQueryNode::Equals compares on_conflict_info", "[api]") { + DuckDB db(nullptr); + Connection con(db); + + // Create a table with a primary key for ON CONFLICT to work + REQUIRE_NO_FAIL(con.Query("CREATE TABLE test(id INT PRIMARY KEY, val INT)")); + + // Parse INSERT with ON CONFLICT DO UPDATE - will be transformed into InsertQueryNode + auto sql_with_conflict = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET val = 20 RETURNING *) SELECT * FROM cte"; + auto stmts_with = con.ExtractStatements(sql_with_conflict); + REQUIRE(stmts_with.size() == 1); + + auto &select_stmt_with = stmts_with[0]->Cast(); + // The top-level node is a SELECT node, CTEs are stored in the cte_map + auto &cte_map_with = select_stmt_with.node->cte_map; + REQUIRE(cte_map_with.map.size() == 1); + auto &cte_info_with = cte_map_with.map.begin()->second; + REQUIRE(cte_info_with->query->node->type == QueryNodeType::INSERT_QUERY_NODE); + auto &insert_node_with = cte_info_with->query->node->Cast(); + REQUIRE(insert_node_with.on_conflict_info != nullptr); + + // Parse INSERT without ON CONFLICT + auto sql_without_conflict = "WITH cte AS (INSERT INTO test VALUES (2, 20) RETURNING *) SELECT * FROM cte"; + auto stmts_without = con.ExtractStatements(sql_without_conflict); + REQUIRE(stmts_without.size() == 1); + + auto &select_stmt_without = stmts_without[0]->Cast(); + auto &cte_map_without = select_stmt_without.node->cte_map; + REQUIRE(cte_map_without.map.size() == 1); + auto &cte_info_without = cte_map_without.map.begin()->second; + REQUIRE(cte_info_without->query->node->type == QueryNodeType::INSERT_QUERY_NODE); + auto &insert_node_without = cte_info_without->query->node->Cast(); + REQUIRE(insert_node_without.on_conflict_info == nullptr); + + // These two InsertQueryNode instances differ in on_conflict_info + // Equals() should return false + REQUIRE(!insert_node_with.Equals(&insert_node_without)); +} + +TEST_CASE("Test InsertQueryNode::Equals compares different on_conflict_info actions", "[api]") { + DuckDB db(nullptr); + Connection con(db); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE test(id INT PRIMARY KEY, val INT)")); + + // Parse INSERT with ON CONFLICT DO UPDATE + auto sql_do_update = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET val = 20 RETURNING *) SELECT * FROM cte"; + auto stmts_update = con.ExtractStatements(sql_do_update); + auto &select_stmt_update = stmts_update[0]->Cast(); + auto &cte_info_update = select_stmt_update.node->cte_map.map.begin()->second; + auto &insert_node_update = cte_info_update->query->node->Cast(); + + // Parse INSERT with ON CONFLICT DO NOTHING + auto sql_do_nothing = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO NOTHING RETURNING *) SELECT * FROM cte"; + auto stmts_nothing = con.ExtractStatements(sql_do_nothing); + auto &select_stmt_nothing = stmts_nothing[0]->Cast(); + auto &cte_info_nothing = select_stmt_nothing.node->cte_map.map.begin()->second; + auto &insert_node_nothing = cte_info_nothing->query->node->Cast(); + + // Both have on_conflict_info, but with different actions + REQUIRE(insert_node_update.on_conflict_info != nullptr); + REQUIRE(insert_node_nothing.on_conflict_info != nullptr); + + // They should not be equal + REQUIRE(!insert_node_update.Equals(&insert_node_nothing)); +} + +TEST_CASE("Test expression iterator traverses on_conflict_info expressions", "[api]") { + DuckDB db(nullptr); + Connection con(db); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE test(id INT PRIMARY KEY, val INT)")); + + // Parse INSERT with ON CONFLICT DO UPDATE SET that contains expressions + // The expression "val + 100" should be traversed + auto sql = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET val = val + 100 RETURNING *) SELECT * FROM cte"; + auto stmts = con.ExtractStatements(sql); + auto &select_stmt = stmts[0]->Cast(); + auto &cte_info = select_stmt.node->cte_map.map.begin()->second; + auto &insert_node = cte_info->query->node->Cast(); + + REQUIRE(insert_node.on_conflict_info != nullptr); + REQUIRE(insert_node.on_conflict_info->set_info != nullptr); + REQUIRE(!insert_node.on_conflict_info->set_info->expressions.empty()); + + // Count all expressions traversed by the iterator + idx_t expression_count = 0; + bool found_addition_expr = false; + + ParsedExpressionIterator::EnumerateQueryNodeChildren( + *cte_info->query->node, + [&](duckdb::unique_ptr &expr) { + expression_count++; + // Check if we found the addition expression from ON CONFLICT SET + if (expr->ToString().find("+") != string::npos) { + found_addition_expr = true; + } + }, + [](TableRef &) {}); + + // The expression iterator should have found the "val + 100" expression + // This test will fail until the bug is fixed (on_conflict_info expressions are not traversed) + REQUIRE(found_addition_expr); +} + +TEST_CASE("Test expression iterator traverses on_conflict_info condition", "[api]") { + DuckDB db(nullptr); + Connection con(db); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE test(id INT PRIMARY KEY, val INT)")); + + // Parse INSERT with ON CONFLICT DO UPDATE with a WHERE condition + auto sql = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET val = 20 WHERE val < 100 RETURNING *) SELECT * FROM cte"; + auto stmts = con.ExtractStatements(sql); + auto &select_stmt = stmts[0]->Cast(); + auto &cte_info = select_stmt.node->cte_map.map.begin()->second; + auto &insert_node = cte_info->query->node->Cast(); + + REQUIRE(insert_node.on_conflict_info != nullptr); + REQUIRE(insert_node.on_conflict_info->set_info != nullptr); + REQUIRE(insert_node.on_conflict_info->set_info->condition != nullptr); + + // Count expressions and check for the WHERE condition + bool found_comparison_expr = false; + + ParsedExpressionIterator::EnumerateQueryNodeChildren( + *cte_info->query->node, + [&](duckdb::unique_ptr &expr) { + // Check if we found the comparison expression from WHERE clause + if (expr->ToString().find("<") != string::npos) { + found_comparison_expr = true; + } + }, + [](TableRef &) {}); + + // The expression iterator should have found the "val < 100" expression + // This test will fail until the bug is fixed + REQUIRE(found_comparison_expr); +} + +TEST_CASE("Test InsertQueryNode::Copy preserves on_conflict_info", "[api]") { + DuckDB db(nullptr); + Connection con(db); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE test(id INT PRIMARY KEY, val INT)")); + + auto sql = "WITH cte AS (INSERT INTO test VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET val = 20 RETURNING *) SELECT * FROM cte"; + auto stmts = con.ExtractStatements(sql); + auto &select_stmt = stmts[0]->Cast(); + auto &cte_info = select_stmt.node->cte_map.map.begin()->second; + auto &insert_node = cte_info->query->node->Cast(); + + REQUIRE(insert_node.on_conflict_info != nullptr); + + // Copy the node + auto copied = insert_node.Copy(); + auto &copied_insert = copied->Cast(); + + // Verify the copy has on_conflict_info + REQUIRE(copied_insert.on_conflict_info != nullptr); + + // Verify they are equal (after fixing Equals, this should pass) + REQUIRE(insert_node.Equals(copied.get())); +}