Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3896,19 +3896,20 @@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() {
{ static_cast<uint32_t>(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" },
{ static_cast<uint32_t>(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" },
{ static_cast<uint32_t>(QueryNodeType::CTE_NODE), "CTE_NODE" },
{ static_cast<uint32_t>(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" }
{ static_cast<uint32_t>(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" },
{ static_cast<uint32_t>(QueryNodeType::INSERT_QUERY_NODE), "INSERT_QUERY_NODE" }
};
return values;
}

template<>
const char* EnumUtil::ToChars<QueryNodeType>(QueryNodeType value) {
return StringUtil::EnumToString(GetQueryNodeTypeValues(), 6, "QueryNodeType", static_cast<uint32_t>(value));
return StringUtil::EnumToString(GetQueryNodeTypeValues(), 7, "QueryNodeType", static_cast<uint32_t>(value));
}

template<>
QueryNodeType EnumUtil::FromString<QueryNodeType>(const char *value) {
return static_cast<QueryNodeType>(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 6, "QueryNodeType", value));
return static_cast<QueryNodeType>(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 7, "QueryNodeType", value));
}

const StringUtil::EnumStringLiteral *GetQueryResultMemoryTypeValues() {
Expand Down
3 changes: 2 additions & 1 deletion src/include/duckdb/parser/query_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ enum class QueryNodeType : uint8_t {
BOUND_SUBQUERY_NODE = 3,
RECURSIVE_CTE_NODE = 4,
CTE_NODE = 5,
STATEMENT_NODE = 6
STATEMENT_NODE = 6,
INSERT_QUERY_NODE = 7
};

struct CommonTableExpressionInfo;
Expand Down
68 changes: 68 additions & 0 deletions src/include/duckdb/parser/query_node/insert_query_node.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// duckdb/parser/query_node/insert_query_node.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "duckdb/parser/parsed_expression.hpp"
#include "duckdb/parser/query_node.hpp"
#include "duckdb/parser/statement/select_statement.hpp"
#include "duckdb/parser/tableref.hpp"

namespace duckdb {

class OnConflictInfo;
enum class InsertColumnOrder : uint8_t;

//! InsertQueryNode represents an INSERT statement with RETURNING used in a CTE
class InsertQueryNode : public QueryNode {
public:
static constexpr const QueryNodeType TYPE = QueryNodeType::INSERT_QUERY_NODE;

public:
InsertQueryNode();

//! The select statement to insert from
unique_ptr<SelectStatement> select_statement;
//! Column names to insert into
vector<string> columns;

//! Table name to insert to
string table;
//! Schema name to insert to
string schema;
//! The catalog name to insert to
string catalog;

//! The RETURNING clause expressions
vector<unique_ptr<ParsedExpression>> returning_list;

//! ON CONFLICT info
unique_ptr<OnConflictInfo> on_conflict_info;
//! Table reference (for qualified names)
unique_ptr<TableRef> table_ref;

//! Whether or not this is a DEFAULT VALUES insert
bool default_values = false;

//! INSERT BY POSITION or INSERT BY NAME
InsertColumnOrder column_order;

public:
//! Convert the query node to a string
string ToString() const override;

bool Equals(const QueryNode *other) const override;

//! Create a copy of this InsertQueryNode
unique_ptr<QueryNode> Copy() const override;

void Serialize(Serializer &serializer) const override;
static unique_ptr<QueryNode> Deserialize(Deserializer &deserializer);
};

} // namespace duckdb
1 change: 1 addition & 0 deletions src/include/duckdb/parser/query_node/list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/query_node/set_operation_node.hpp"
#include "duckdb/parser/query_node/statement_node.hpp"
#include "duckdb/parser/query_node/insert_query_node.hpp"
1 change: 1 addition & 0 deletions src/include/duckdb/parser/tokens.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class SetOperationNode;
class RecursiveCTENode;
class CTENode;
class StatementNode;
class InsertQueryNode;

//===--------------------------------------------------------------------===//
// Expressions
Expand Down
1 change: 1 addition & 0 deletions src/include/duckdb/planner/binder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ class Binder : public enable_shared_from_this<Binder> {
BoundStatement BindNode(RecursiveCTENode &node);
BoundStatement BindNode(QueryNode &node);
BoundStatement BindNode(StatementNode &node);
BoundStatement BindNode(InsertQueryNode &node);

unique_ptr<LogicalOperator> VisitQueryNode(BoundQueryNode &node, unique_ptr<LogicalOperator> root);
unique_ptr<LogicalOperator> CreatePlan(BoundSelectNode &statement);
Expand Down
32 changes: 32 additions & 0 deletions src/parser/parsed_expression_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "duckdb/parser/query_node/recursive_cte_node.hpp"
#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/query_node/set_operation_node.hpp"
#include "duckdb/parser/query_node/insert_query_node.hpp"
#include "duckdb/parser/statement/insert_statement.hpp"
#include "duckdb/parser/tableref/list.hpp"

namespace duckdb {
Expand Down Expand Up @@ -298,6 +300,36 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren(
}
break;
}
case QueryNodeType::INSERT_QUERY_NODE: {
auto &insert_node = node.Cast<InsertQueryNode>();
for (auto &expr : insert_node.returning_list) {
expr_callback(expr);
}
if (insert_node.table_ref) {
EnumerateTableRefChildren(*insert_node.table_ref, expr_callback, ref_callback);
}
if (insert_node.select_statement && insert_node.select_statement->node) {
EnumerateQueryNodeChildren(*insert_node.select_statement->node, expr_callback, ref_callback);
}
// Traverse on_conflict_info expressions
if (insert_node.on_conflict_info) {
if (insert_node.on_conflict_info->condition) {
expr_callback(insert_node.on_conflict_info->condition);
}
if (insert_node.on_conflict_info->set_info) {
for (auto &expr : insert_node.on_conflict_info->set_info->expressions) {
expr_callback(expr);
}
if (insert_node.on_conflict_info->set_info->condition) {
expr_callback(insert_node.on_conflict_info->set_info->condition);
}
}
}
break;
}
case QueryNodeType::STATEMENT_NODE:
case QueryNodeType::CTE_NODE:
case QueryNodeType::BOUND_SUBQUERY_NODE:
default:
throw NotImplementedException("QueryNode type not implemented for traversal");
}
Expand Down
3 changes: 2 additions & 1 deletion src/parser/query_node/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ add_library_unity(
cte_node.cpp
select_node.cpp
set_operation_node.cpp
statement_node.cpp)
statement_node.cpp
insert_query_node.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_query_node>
PARENT_SCOPE)
174 changes: 174 additions & 0 deletions src/parser/query_node/insert_query_node.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#include "duckdb/parser/query_node/insert_query_node.hpp"
#include "duckdb/parser/statement/insert_statement.hpp"
#include "duckdb/parser/statement/update_statement.hpp"
#include "duckdb/parser/expression_util.hpp"
#include "duckdb/common/serializer/serializer.hpp"

namespace duckdb {

static bool UpdateSetInfoEquals(const unique_ptr<UpdateSetInfo> &left, const unique_ptr<UpdateSetInfo> &right) {
if (left && right) {
if (left->columns != right->columns) {
return false;
}
if (!ExpressionUtil::ListEquals(left->expressions, right->expressions)) {
return false;
}
if (!ParsedExpression::Equals(left->condition, right->condition)) {
return false;
}
return true;
}
// One is null, the other is not
return left.get() == right.get();
}

static bool OnConflictInfoEquals(const unique_ptr<OnConflictInfo> &left, const unique_ptr<OnConflictInfo> &right) {
if (left && right) {
if (left->action_type != right->action_type) {
return false;
}
if (left->indexed_columns != right->indexed_columns) {
return false;
}
if (!UpdateSetInfoEquals(left->set_info, right->set_info)) {
return false;
}
if (!ParsedExpression::Equals(left->condition, right->condition)) {
return false;
}
return true;
}
// One is null, the other is not
return left.get() == right.get();
}

InsertQueryNode::InsertQueryNode()
: QueryNode(QueryNodeType::INSERT_QUERY_NODE), default_values(false),
column_order(InsertColumnOrder::INSERT_BY_POSITION) {
}

string InsertQueryNode::ToString() const {
string result;
result += "INSERT INTO ";
if (!catalog.empty() && catalog != INVALID_CATALOG) {
result += KeywordHelper::WriteOptionallyQuoted(catalog) + ".";
}
if (!schema.empty() && schema != DEFAULT_SCHEMA) {
result += KeywordHelper::WriteOptionallyQuoted(schema) + ".";
}
result += KeywordHelper::WriteOptionallyQuoted(table);
if (table_ref && !table_ref->alias.empty()) {
result += StringUtil::Format(" AS %s", KeywordHelper::WriteOptionallyQuoted(table_ref->alias));
}
if (column_order == InsertColumnOrder::INSERT_BY_NAME) {
result += " BY NAME";
}
if (!columns.empty()) {
result += " (";
for (idx_t i = 0; i < columns.size(); i++) {
if (i > 0) {
result += ", ";
}
result += KeywordHelper::WriteOptionallyQuoted(columns[i]);
}
result += ")";
}
result += " ";
if (select_statement) {
result += select_statement->ToString();
} else if (default_values) {
result += "DEFAULT VALUES";
}
if (!returning_list.empty()) {
result += " RETURNING ";
for (idx_t i = 0; i < returning_list.size(); i++) {
if (i > 0) {
result += ", ";
}
result += returning_list[i]->ToString();
if (!returning_list[i]->GetAlias().empty()) {
result += StringUtil::Format(" AS %s",
KeywordHelper::WriteOptionallyQuoted(returning_list[i]->GetAlias()));
}
}
}
return result + ResultModifiersToString();
}

bool InsertQueryNode::Equals(const QueryNode *other_p) const {
if (!QueryNode::Equals(other_p)) {
return false;
}
if (this == other_p) {
return true;
}
auto &other = other_p->Cast<InsertQueryNode>();

if (catalog != other.catalog || schema != other.schema || table != other.table) {
return false;
}
if (columns != other.columns) {
return false;
}
if (default_values != other.default_values) {
return false;
}
if (column_order != other.column_order) {
return false;
}
if (!ExpressionUtil::ListEquals(returning_list, other.returning_list)) {
return false;
}
if (!TableRef::Equals(table_ref, other.table_ref)) {
return false;
}
// Compare select_statement
if (select_statement && other.select_statement) {
if (!select_statement->Equals(*other.select_statement)) {
return false;
}
} else if (select_statement || other.select_statement) {
return false;
}
// Compare on_conflict_info
if (!OnConflictInfoEquals(on_conflict_info, other.on_conflict_info)) {
return false;
}
return true;
}

unique_ptr<QueryNode> InsertQueryNode::Copy() const {
auto result = make_uniq<InsertQueryNode>();
result->catalog = catalog;
result->schema = schema;
result->table = table;
result->columns = columns;
result->default_values = default_values;
result->column_order = column_order;
if (select_statement) {
result->select_statement = unique_ptr_cast<SQLStatement, SelectStatement>(select_statement->Copy());
}
for (auto &expr : returning_list) {
result->returning_list.push_back(expr->Copy());
}
if (on_conflict_info) {
result->on_conflict_info = on_conflict_info->Copy();
}
if (table_ref) {
result->table_ref = table_ref->Copy();
}
this->CopyProperties(*result);
return std::move(result);
}

void InsertQueryNode::Serialize(Serializer &serializer) const {
// For now, disallow serialization of DML CTEs (views cannot contain them)
throw NotImplementedException("INSERT in CTE cannot be serialized - views cannot contain data-modifying CTEs");
}

unique_ptr<QueryNode> InsertQueryNode::Deserialize(Deserializer &deserializer) {
throw NotImplementedException("INSERT in CTE cannot be deserialized");
}

} // namespace duckdb
21 changes: 16 additions & 5 deletions src/parser/transform/helpers/transform_cte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "duckdb/parser/query_node/cte_node.hpp"
#include "duckdb/parser/query_node/recursive_cte_node.hpp"
#include "duckdb/parser/query_node/statement_node.hpp"
#include "duckdb/parser/query_node/insert_query_node.hpp"
#include "duckdb/parser/statement/select_statement.hpp"
#include "duckdb/parser/statement/insert_statement.hpp"
#include "duckdb/parser/statement/update_statement.hpp"
Expand Down Expand Up @@ -115,12 +116,22 @@ void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause,
if (insert->returning_list.empty()) {
throw ParserException("INSERT in a CTE must have a RETURNING clause");
}
// Wrap the DML statement in a SelectStatement via StatementNode
// Create InsertQueryNode with fields from the InsertStatement
auto insert_node = make_uniq<InsertQueryNode>();
insert_node->catalog = insert->catalog;
insert_node->schema = insert->schema;
insert_node->table = insert->table;
insert_node->columns = std::move(insert->columns);
insert_node->default_values = insert->default_values;
insert_node->column_order = insert->column_order;
insert_node->select_statement = std::move(insert->select_statement);
insert_node->returning_list = std::move(insert->returning_list);
insert_node->on_conflict_info = std::move(insert->on_conflict_info);
insert_node->table_ref = std::move(insert->table_ref);
// Copy inner CTEs (for WITH ... INSERT ... syntax)
insert_node->cte_map = insert->cte_map.Copy();
info->query = make_uniq<SelectStatement>();
// Copy inner CTEs before moving the statement (for WITH ... INSERT ... syntax)
auto inner_ctes = insert->cte_map.Copy();
info->query->node = make_uniq<StatementNode>(std::move(insert));
info->query->node->cte_map = std::move(inner_ctes);
info->query->node = std::move(insert_node);
is_dml_cte = true;
break;
}
Expand Down
Loading