diff --git a/Cargo.toml b/Cargo.toml index 39b6d4edd..0e827bc10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ connection-string = "0.1.10" percent-encoding = "2" tracing-core = "0.1" async-trait = "0.1" +enumflags2 = "0.7" thiserror = "1.0" once_cell = "1.3" num_cpus = "1.12" diff --git a/db/test.db b/db/test.db index 939858e8e..69516cea2 100644 Binary files a/db/test.db and b/db/test.db differ diff --git a/src/ast.rs b/src/ast.rs index 03f9bc234..6d8d033c9 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -5,6 +5,7 @@ //! actual query building is in the [visitor](../visitor/index.html) module. //! //! For prelude, all important imports are in `quaint::ast::*`. +mod castable; mod column; mod compare; mod conditions; @@ -29,6 +30,7 @@ mod union; mod update; mod values; +pub use castable::*; pub use column::{Column, DefaultValue, TypeDataLength, TypeFamily}; pub use compare::{Comparable, Compare, JsonCompare, JsonType}; pub use conditions::ConditionTree; diff --git a/src/ast/castable.rs b/src/ast/castable.rs new file mode 100644 index 000000000..5c8d65cec --- /dev/null +++ b/src/ast/castable.rs @@ -0,0 +1,295 @@ +use enumflags2::{bitflags, BitFlags}; +use std::borrow::Cow; + +#[bitflags] +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq)] +enum CastDatabase { + Postgres = 1 << 0, + Mysql = 1 << 1, + Mssql = 1 << 2, +} + +/// A typecast for an expression. +/// +/// By default, casting is performed on all databases. To restrict this +/// behavior, use the corresponding methods +/// [on_postgres](struct.CastType.html#method.on_postgres), +/// [on_mysql](struct.CastType.html#method.on_mysql) or +/// [on_sql_server](struct.CastType.html#method.on_sql_server). +/// +/// Always a no-op on SQLite. +#[derive(Debug, Clone, PartialEq)] +pub struct CastType<'a> { + kind: CastKind<'a>, + on_databases: BitFlags, +} + +impl<'a> CastType<'a> { + /// A 16-bit integer. + /// + /// - PostgreSQL: `int2` + /// - MySQL: `signed` + /// - SQL Server: `smallint` + pub fn int2() -> Self { + Self { + kind: CastKind::Int2, + on_databases: BitFlags::all(), + } + } + + /// A 32-bit integer (int) + /// + /// - PostgreSQL: `int4` + /// - MySQL: `signed` + /// - SQL Server: `int` + pub fn int4() -> Self { + Self { + kind: CastKind::Int4, + on_databases: BitFlags::all(), + } + } + + /// A 64-bit integer (bigint) + /// + /// - PostgreSQL: `int8` + /// - MySQL: `signed` + /// - SQL Server: `bigint` + pub fn int8() -> Self { + Self { + kind: CastKind::Int8, + on_databases: BitFlags::all(), + } + } + + /// A 32-bit floating point number + /// + /// - PostgreSQL: `float4` + /// - MySQL: `decimal` + /// - SQL Server: `real` + pub fn float4() -> Self { + Self { + kind: CastKind::Float4, + on_databases: BitFlags::all(), + } + } + + /// A 64-bit floating point number + /// + /// - PostgreSQL: `float8` + /// - MySQL: `decimal` + /// - SQL Server: `float` + pub fn float8() -> Self { + Self { + kind: CastKind::Float8, + on_databases: BitFlags::all(), + } + } + + /// An arbitrary-precision numeric type + /// + /// - PostgreSQL: `numeric` + /// - MySQL: `decimal` + /// - SQL Server: `numeric` + pub fn decimal() -> Self { + Self { + kind: CastKind::Decimal, + on_databases: BitFlags::all(), + } + } + + /// True or false (or a bit) + /// + /// - PostgreSQL: `boolean` + /// - MySQL: `unsigned` + /// - SQL Server: `bit` + pub fn boolean() -> Self { + Self { + kind: CastKind::Boolean, + on_databases: BitFlags::all(), + } + } + + /// A unique identifier + /// + /// - PostgreSQL: `uuid` + /// - MySQL: `char` + /// - SQL Server: `uniqueidentifier` + pub fn uuid() -> Self { + Self { + kind: CastKind::Uuid, + on_databases: BitFlags::all(), + } + } + + /// Json data + /// + /// - PostgreSQL: `json` + /// - MySQL: `nchar` + /// - SQL Server: `nvarchar` + pub fn json() -> Self { + Self { + kind: CastKind::Json, + on_databases: BitFlags::all(), + } + } + + /// Jsonb data + /// + /// - PostgreSQL: `jsonb` + /// - MySQL: `nchar` + /// - SQL Server: `nvarchar` + pub fn jsonb() -> Self { + Self { + kind: CastKind::Jsonb, + on_databases: BitFlags::all(), + } + } + + /// Date value + /// + /// - PostgreSQL: `date` + /// - MySQL: `date` + /// - SQL Server: `date` + pub fn date() -> Self { + Self { + kind: CastKind::Date, + on_databases: BitFlags::all(), + } + } + + /// Time value + /// + /// - PostgreSQL: `time` + /// - MySQL: `time` + /// - SQL Server: `time` + pub fn time() -> Self { + Self { + kind: CastKind::Time, + on_databases: BitFlags::all(), + } + } + + /// Datetime value + /// + /// - PostgreSQL: `datetime` + /// - MySQL: `datetime` + /// - SQL Server: `datetime2` + pub fn datetime() -> Self { + Self { + kind: CastKind::DateTime, + on_databases: BitFlags::all(), + } + } + + /// Byte blob + /// + /// - PostgreSQL: `bytea` + /// - MySQL: `binary` + /// - SQL Server: `bytes` + pub fn bytes() -> Self { + Self { + kind: CastKind::Bytes, + on_databases: BitFlags::all(), + } + } + + /// Textual data + /// + /// - PostgreSQL: `text` + /// - MySQL: `nchar` + /// - SQL Server: `nvarchar` + pub fn text() -> Self { + Self { + kind: CastKind::Text, + on_databases: BitFlags::all(), + } + } + + /// Creates a new custom cast type. + pub fn custom(r#type: impl Into>) -> Self { + Self { + kind: CastKind::Custom(r#type.into()), + on_databases: BitFlags::all(), + } + } + + /// Perform the given cast on PostgreSQL. + pub fn on_postgres(mut self) -> Self { + self.maybe_clear_databases(); + self.on_databases.insert(CastDatabase::Postgres); + + self + } + + /// Perform the given cast on MySQL. + pub fn on_mysql(mut self) -> Self { + self.maybe_clear_databases(); + self.on_databases.insert(CastDatabase::Mysql); + + self + } + + /// Perform the given cast on SQL Server. + pub fn on_sql_server(mut self) -> Self { + self.maybe_clear_databases(); + self.on_databases.insert(CastDatabase::Mssql); + + self + } + + #[cfg(feature = "postgresql")] + pub(crate) fn postgres_enabled(&self) -> bool { + self.on_databases.contains(CastDatabase::Postgres) + } + + #[cfg(feature = "mysql")] + pub(crate) fn mysql_enabled(&self) -> bool { + self.on_databases.contains(CastDatabase::Mysql) + } + + #[cfg(feature = "mssql")] + pub(crate) fn mssql_enabled(&self) -> bool { + self.on_databases.contains(CastDatabase::Mssql) + } + + #[cfg(any(feature = "mssql", feature = "mysql", feature = "mssql"))] + pub(crate) fn kind(&self) -> &CastKind<'a> { + &self.kind + } + + fn maybe_clear_databases(&mut self) { + if self.on_databases.is_all() { + self.on_databases.remove(BitFlags::all()); + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum CastKind<'a> { + Int2, + Int4, + Int8, + Float4, + Float8, + Decimal, + Boolean, + Uuid, + Json, + Jsonb, + Date, + Time, + DateTime, + Bytes, + Text, + Custom(Cow<'a, str>), +} + +/// An item that can be cast to a different type. +pub trait Castable<'a, T> +where + T: Sized, +{ + /// Map the result of the underlying item into a different type. + fn cast_as(self, r#type: CastType<'a>) -> T; +} diff --git a/src/ast/column.rs b/src/ast/column.rs index 87342bd56..56008be41 100644 --- a/src/ast/column.rs +++ b/src/ast/column.rs @@ -5,22 +5,36 @@ use crate::{ }; use std::borrow::Cow; +/// The maximum length of the column. #[derive(Debug, Clone, Copy)] pub enum TypeDataLength { + /// Number of either bytes or characters. Constant(u16), + /// Stored outside of the row in the heap, usually either two or four + /// gigabytes. Maximum, } +/// The type family of the column. #[derive(Debug, Clone, Copy)] pub enum TypeFamily { + /// Textual data with an optional length. Text(Option), + /// Integers. Int, + /// Floating point values, 32-bit. Float, + /// Floating point values, 64-bit. Double, + /// Trues and falses. Boolean, + /// Unique identifiers. Uuid, + /// Date, time and datetime. DateTime, + /// Numerics with an arbitrary scale and precision. Decimal(Option<(u8, u8)>), + /// Blobs with an optional length. Bytes(Option), } @@ -104,6 +118,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Column(Box::new(col)), alias: None, + cast: None, } } } diff --git a/src/ast/compare.rs b/src/ast/compare.rs index e71aba39e..62158334b 100644 --- a/src/ast/compare.rs +++ b/src/ast/compare.rs @@ -230,6 +230,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Compare(cmp), alias: None, + cast: None, } } } diff --git a/src/ast/conditions.rs b/src/ast/conditions.rs index 00a6f110f..014a4bb94 100644 --- a/src/ast/conditions.rs +++ b/src/ast/conditions.rs @@ -129,6 +129,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::ConditionTree(ct), alias: None, + cast: None, } } } @@ -138,6 +139,7 @@ impl<'a> From> for ConditionTree<'a> { let exp = Expression { kind: ExpressionKind::Value(Box::new(sel.into())), alias: None, + cast: None, }; ConditionTree::single(exp) diff --git a/src/ast/expression.rs b/src/ast/expression.rs index 4ea640ab9..57f207fa1 100644 --- a/src/ast/expression.rs +++ b/src/ast/expression.rs @@ -4,12 +4,15 @@ use crate::ast::*; use query::SelectQuery; use std::borrow::Cow; +use super::castable::{CastType, Castable}; + /// An expression that can be positioned in a query. Can be a single value or a /// statement that is evaluated into a value. #[derive(Debug, Clone, PartialEq)] pub struct Expression<'a> { pub(crate) kind: ExpressionKind<'a>, pub(crate) alias: Option>, + pub(crate) cast: Option>, } impl<'a> Expression<'a> { @@ -28,6 +31,7 @@ impl<'a> Expression<'a> { Self { kind: ExpressionKind::Row(row), alias: None, + cast: None, } } @@ -40,6 +44,7 @@ impl<'a> Expression<'a> { Self { kind: ExpressionKind::Selection(selection), alias: None, + cast: None, } } @@ -114,6 +119,7 @@ impl<'a> Expression<'a> { let expr = Expression { kind: ExpressionKind::Selection(selection), alias: self.alias, + cast: self.cast, }; (expr, ctes) @@ -124,6 +130,7 @@ impl<'a> Expression<'a> { let expr = Expression { kind: ExpressionKind::Compare(compare), alias: self.alias, + cast: self.cast, }; (expr, Vec::new()) @@ -133,6 +140,7 @@ impl<'a> Expression<'a> { let expr = Expression { kind: ExpressionKind::Compare(comp), alias: self.alias, + cast: self.cast, }; (expr, ctes) @@ -144,6 +152,7 @@ impl<'a> Expression<'a> { let expr = Expression { kind: ExpressionKind::ConditionTree(tree), alias: self.alias, + cast: self.cast, }; (expr, ctes) @@ -153,6 +162,17 @@ impl<'a> Expression<'a> { } } +impl<'a, E> Castable<'a, Expression<'a>> for E +where + E: Into>, +{ + fn cast_as(self, r#type: castable::CastType<'a>) -> Expression<'a> { + let mut exp = self.into(); + exp.cast = Some(r#type); + exp + } +} + /// An expression we can compare and use in database queries. #[derive(Debug, Clone, PartialEq)] pub enum ExpressionKind<'a> { @@ -199,6 +219,7 @@ pub fn asterisk() -> Expression<'static> { Expression { kind: ExpressionKind::Asterisk(None), alias: None, + cast: None, } } @@ -207,6 +228,7 @@ pub fn default_value() -> Expression<'static> { Expression { kind: ExpressionKind::Default, alias: None, + cast: None, } } @@ -217,6 +239,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Function(Box::new(f)), alias: None, + cast: None, } } } @@ -226,6 +249,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::RawValue(r), alias: None, + cast: None, } } } @@ -235,6 +259,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Values(Box::new(p)), alias: None, + cast: None, } } } @@ -244,6 +269,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Op(Box::new(p)), alias: None, + cast: None, } } } @@ -256,6 +282,7 @@ where Expression { kind: ExpressionKind::Parameterized(p.into()), alias: None, + cast: None, } } } @@ -272,7 +299,11 @@ where impl<'a> From> for Expression<'a> { fn from(kind: ExpressionKind<'a>) -> Self { - Self { kind, alias: None } + Self { + kind, + alias: None, + cast: None, + } } } diff --git a/src/ast/select.rs b/src/ast/select.rs index 7bdb5bc80..73631941d 100644 --- a/src/ast/select.rs +++ b/src/ast/select.rs @@ -21,6 +21,7 @@ impl<'a> From> for Expression<'a> { Expression { kind: ExpressionKind::Selection(SelectQuery::Select(Box::new(sel))), alias: None, + cast: None, } } } diff --git a/src/ast/table.rs b/src/ast/table.rs index 0a1a46e75..f7c1e36c4 100644 --- a/src/ast/table.rs +++ b/src/ast/table.rs @@ -54,6 +54,7 @@ impl<'a> Table<'a> { Expression { kind: ExpressionKind::Asterisk(Some(Box::new(self))), alias: None, + cast: None, } } diff --git a/src/ast/values.rs b/src/ast/values.rs index 945664b2e..473e506ad 100644 --- a/src/ast/values.rs +++ b/src/ast/values.rs @@ -2,7 +2,7 @@ use crate::ast::*; use crate::error::{Error, ErrorKind}; #[cfg(feature = "bigdecimal")] -use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; +use bigdecimal::{BigDecimal, FromPrimitive}; #[cfg(feature = "chrono")] use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; #[cfg(feature = "json")] @@ -180,7 +180,7 @@ impl<'a> From> for serde_json::Value { v.map(|v| serde_json::Value::Array(v.into_iter().map(serde_json::Value::from).collect())) } #[cfg(feature = "bigdecimal")] - Value::Numeric(d) => d.map(|d| serde_json::to_value(d.to_f64().unwrap()).unwrap()), + Value::Numeric(d) => d.map(|d| serde_json::Value::String(d.to_string())), #[cfg(feature = "json")] Value::Json(v) => v, #[cfg(feature = "uuid")] diff --git a/src/connector/postgres/conversion.rs b/src/connector/postgres/conversion.rs index e8da0e564..068268bd9 100644 --- a/src/connector/postgres/conversion.rs +++ b/src/connector/postgres/conversion.rs @@ -511,36 +511,28 @@ impl<'a> ToSql for Value<'a> { (Value::Integer(integer), &PostgresType::OID) => integer.map(|integer| (integer as u32).to_sql(ty, out)), (Value::Integer(integer), _) => integer.map(|integer| (integer as i64).to_sql(ty, out)), (Value::Float(float), &PostgresType::FLOAT8) => float.map(|float| (float as f64).to_sql(ty, out)), - #[cfg(feature = "bigdecimal")] - (Value::Float(float), &PostgresType::NUMERIC) => float - .map(|float| BigDecimal::from_f32(float).unwrap()) - .map(DecimalWrapper) - .map(|dw| dw.to_sql(ty, out)), + (Value::Float(_), &PostgresType::NUMERIC) => { + let kind = ErrorKind::conversion(format!( + "Writing a float to a {} column is unstable.", + PostgresType::NUMERIC + )); + return Err(Error::builder(kind).build().into()); + } (Value::Float(float), _) => float.map(|float| float.to_sql(ty, out)), (Value::Double(double), &PostgresType::FLOAT4) => double.map(|double| (double as f32).to_sql(ty, out)), - #[cfg(feature = "bigdecimal")] - (Value::Double(double), &PostgresType::NUMERIC) => double - .map(|double| BigDecimal::from_f64(double).unwrap()) - .map(DecimalWrapper) - .map(|dw| dw.to_sql(ty, out)), + (Value::Double(_), &PostgresType::NUMERIC) => { + let kind = ErrorKind::conversion(format!( + "Writing a double to a {} column is unstable.", + PostgresType::NUMERIC + )); + return Err(Error::builder(kind).build().into()); + } (Value::Double(double), _) => double.map(|double| double.to_sql(ty, out)), - #[cfg(feature = "bigdecimal")] - (Value::Numeric(decimal), &PostgresType::FLOAT4) => decimal.as_ref().map(|decimal| { - let f = decimal.to_string().parse::().expect("decimal to f32 conversion"); - f.to_sql(ty, out) - }), - #[cfg(feature = "bigdecimal")] - (Value::Numeric(decimal), &PostgresType::FLOAT8) => decimal.as_ref().map(|decimal| { - let f = decimal.to_string().parse::().expect("decimal to f64 conversion"); - f.to_sql(ty, out) - }), - #[cfg(feature = "bigdecimal")] (Value::Array(values), &PostgresType::FLOAT4_ARRAY) => values.as_ref().map(|values| { let mut floats = Vec::with_capacity(values.len()); for value in values.iter() { let float = match value { - Value::Numeric(n) => n.as_ref().and_then(|n| n.to_string().parse::().ok()), Value::Float(f) => *f, Value::Double(d) => d.map(|d| d as f32), v => { @@ -558,13 +550,11 @@ impl<'a> ToSql for Value<'a> { floats.to_sql(ty, out) }), - #[cfg(feature = "bigdecimal")] (Value::Array(values), &PostgresType::FLOAT8_ARRAY) => values.as_ref().map(|values| { let mut floats = Vec::with_capacity(values.len()); for value in values.iter() { let float = match value { - Value::Numeric(n) => n.as_ref().and_then(|n| n.to_string().parse::().ok()), Value::Float(f) => f.map(|f| f as f64), Value::Double(d) => *d, v => { @@ -598,9 +588,10 @@ impl<'a> ToSql for Value<'a> { .as_ref() .map(|decimal| DecimalWrapper(decimal.clone()).to_sql(ty, out)), #[cfg(feature = "bigdecimal")] - (Value::Numeric(float), _) => float - .as_ref() - .map(|float| DecimalWrapper(float.clone()).to_sql(ty, out)), + (Value::Numeric(_), typ) => { + let kind = ErrorKind::conversion(format!("Writing a decimal to a {} column is unstable.", typ)); + return Err(Error::builder(kind).build().into()); + } #[cfg(feature = "uuid")] (Value::Text(string), &PostgresType::UUID) => string.as_ref().map(|string| { let parsed_uuid: Uuid = string.parse()?; diff --git a/src/connector/sqlite/conversion.rs b/src/connector/sqlite/conversion.rs index 922d0ea5c..cd3375ab8 100644 --- a/src/connector/sqlite/conversion.rs +++ b/src/connector/sqlite/conversion.rs @@ -166,12 +166,6 @@ impl<'a> GetRow for SqliteRow<'a> { } _ => Value::integer(i), }, - #[cfg(feature = "bigdecimal")] - ValueRef::Real(f) if column.is_real() => { - use bigdecimal::{BigDecimal, FromPrimitive}; - - Value::numeric(BigDecimal::from_f64(f).unwrap()) - } ValueRef::Real(f) => Value::double(f), #[cfg(feature = "chrono")] ValueRef::Text(bytes) if column.is_datetime() => { diff --git a/src/macros.rs b/src/macros.rs index 6289fe0ba..44a8dd5ec 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -152,6 +152,7 @@ macro_rules! expression { Expression { kind: ExpressionKind::$paramkind(that), alias: None, + cast: None, } } } diff --git a/src/tests/query.rs b/src/tests/query.rs index 42f6f6e55..97ec28b12 100644 --- a/src/tests/query.rs +++ b/src/tests/query.rs @@ -1898,7 +1898,6 @@ async fn ints_read_write_to_numeric(api: &mut dyn TestApi) -> crate::Result<()> let table = api.create_table("id int, value numeric(12,2)").await?; let insert = Insert::multi_into(&table, &["id", "value"]) - .values(vec![Value::integer(1), Value::double(1234.5)]) .values(vec![Value::integer(2), Value::integer(1234)]) .values(vec![Value::integer(3), Value::integer(12345)]); @@ -1909,8 +1908,7 @@ async fn ints_read_write_to_numeric(api: &mut dyn TestApi) -> crate::Result<()> for (i, row) in rows.into_iter().enumerate() { match i { - 0 => assert_eq!(Value::numeric(BigDecimal::from_str("1234.5").unwrap()), row["value"]), - 1 => assert_eq!(Value::numeric(BigDecimal::from_str("1234.0").unwrap()), row["value"]), + 0 => assert_eq!(Value::numeric(BigDecimal::from_str("1234.0").unwrap()), row["value"]), _ => assert_eq!(Value::numeric(BigDecimal::from_str("12345.0").unwrap()), row["value"]), } } @@ -1918,32 +1916,6 @@ async fn ints_read_write_to_numeric(api: &mut dyn TestApi) -> crate::Result<()> Ok(()) } -#[cfg(feature = "bigdecimal")] -#[test_each_connector(tags("postgresql"))] -async fn bigdecimal_read_write_to_floating(api: &mut dyn TestApi) -> crate::Result<()> { - use bigdecimal::BigDecimal; - use std::str::FromStr; - - let table = api.create_table("id int, a float4, b float8").await?; - let val = BigDecimal::from_str("0.1").unwrap(); - - let insert = Insert::multi_into(&table, &["id", "a", "b"]).values(vec![ - Value::integer(1), - Value::numeric(val.clone()), - Value::numeric(val.clone()), - ]); - - api.conn().execute(insert.into()).await?; - - let select = Select::from_table(&table); - let row = api.conn().select(select).await?.into_single()?; - - assert_eq!(Value::float(0.1), row["a"]); - assert_eq!(Value::double(0.1), row["b"]); - - Ok(()) -} - #[test_each_connector] async fn coalesce_fun(api: &mut dyn TestApi) -> crate::Result<()> { let exprs: Vec = vec![Value::Text(None).into(), Value::text("Individual").into()]; @@ -2351,3 +2323,221 @@ async fn json_array_not_ends_into_fun(api: &mut dyn TestApi) -> crate::Result<() Ok(()) } + +#[cfg(feature = "mysql")] +#[test_each_connector(tags("mysql"))] +async fn mysql_type_casts(api: &mut dyn TestApi) -> crate::Result<()> { + #[cfg(feature = "bigdecimal")] + use std::str::FromStr; + + let casts = vec![ + (Value::integer(1), CastType::int2(), Value::integer(1)), + (Value::integer(1), CastType::int4(), Value::integer(1)), + (Value::integer(1), CastType::int8(), Value::integer(1)), + // You cannot cast to boolean :P + (Value::boolean(true), CastType::boolean(), Value::integer(1)), + (Value::text("asdf"), CastType::text(), Value::text("asdf")), + ( + Value::bytes(b"DEADBEEF".to_vec()), + CastType::bytes(), + Value::bytes(b"DEADBEEF".to_vec()), + ), + #[cfg(feature = "uuid")] + ( + Value::uuid(uuid::Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap()), + CastType::uuid(), + Value::text("936da01f-9abd-4d9d-80c7-02af85c822a8"), + ), + #[cfg(feature = "bigdecimal")] + ( + Value::float(1.0), + CastType::float4(), + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + ), + #[cfg(feature = "bigdecimal")] + ( + Value::float(1.0), + CastType::float8(), + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + ), + #[cfg(feature = "bigdecimal")] + ( + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + CastType::decimal(), + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + ), + #[cfg(feature = "chrono")] + ( + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + CastType::time(), + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + ), + #[cfg(feature = "chrono")] + ( + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + CastType::datetime(), + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + ), + ]; + + for (input, cast, output) in casts.into_iter() { + let select = Select::default().value(input.cast_as(cast)); + let result = api.conn().select(select).await?.into_single()?.into_single()?; + + assert_eq!(output, result); + } + + Ok(()) +} + +#[cfg(feature = "postgresql")] +#[test_each_connector(tags("postgresql"))] +async fn postgres_type_casts(api: &mut dyn TestApi) -> crate::Result<()> { + #[cfg(feature = "bigdecimal")] + use std::str::FromStr; + + let casts = vec![ + (Value::integer(1), CastType::int2(), Value::integer(1)), + (Value::integer(1), CastType::int4(), Value::integer(1)), + (Value::integer(1), CastType::int8(), Value::integer(1)), + (Value::float(1.0), CastType::float4(), Value::float(1.0)), + (Value::double(1.0), CastType::float8(), Value::double(1.0)), + (Value::boolean(true), CastType::boolean(), Value::boolean(true)), + (Value::text("asdf"), CastType::text(), Value::text("asdf")), + ( + Value::bytes(b"DEADBEEF".to_vec()), + CastType::bytes(), + Value::bytes(b"DEADBEEF".to_vec()), + ), + #[cfg(feature = "uuid")] + ( + Value::uuid(uuid::Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap()), + CastType::uuid(), + Value::uuid(uuid::Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap()), + ), + #[cfg(feature = "bigdecimal")] + ( + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + CastType::decimal(), + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + ), + #[cfg(feature = "json")] + ( + Value::json(serde_json::json!({"a": "b"})), + CastType::json(), + Value::json(serde_json::json!({"a": "b"})), + ), + #[cfg(feature = "json")] + ( + Value::json(serde_json::json!({"a": "b"})), + CastType::jsonb(), + Value::json(serde_json::json!({"a": "b"})), + ), + #[cfg(feature = "chrono")] + ( + Value::date(chrono::NaiveDate::from_ymd(2015, 3, 14)), + CastType::date(), + Value::date(chrono::NaiveDate::from_ymd(2015, 3, 14)), + ), + #[cfg(feature = "chrono")] + ( + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + CastType::time(), + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + ), + #[cfg(feature = "chrono")] + ( + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + CastType::datetime(), + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + ), + ]; + + for (input, cast, output) in casts.into_iter() { + let select = Select::default().value(input.cast_as(cast)); + let result = api.conn().select(select).await?.into_single()?.into_single()?; + + assert_eq!(output, result); + } + + Ok(()) +} + +#[cfg(feature = "mssql")] +#[test_each_connector(tags("mssql"))] +async fn mssql_type_casts(api: &mut dyn TestApi) -> crate::Result<()> { + #[cfg(feature = "bigdecimal")] + use std::str::FromStr; + + let casts = vec![ + (Value::integer(1), CastType::int2(), Value::integer(1)), + (Value::integer(1), CastType::int4(), Value::integer(1)), + (Value::integer(1), CastType::int8(), Value::integer(1)), + (Value::float(1.0), CastType::float4(), Value::float(1.0)), + (Value::double(1.0), CastType::float8(), Value::double(1.0)), + (Value::boolean(true), CastType::boolean(), Value::boolean(true)), + (Value::text("asdf"), CastType::text(), Value::text("asdf")), + ( + Value::bytes(b"DEADBEEF".to_vec()), + CastType::bytes(), + Value::bytes(b"DEADBEEF".to_vec()), + ), + #[cfg(feature = "uuid")] + ( + Value::uuid(uuid::Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap()), + CastType::uuid(), + Value::uuid(uuid::Uuid::parse_str("936DA01F9ABD4d9d80C702AF85C822A8").unwrap()), + ), + #[cfg(feature = "bigdecimal")] + ( + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + CastType::decimal(), + Value::numeric(bigdecimal::BigDecimal::from_str("1.0").unwrap()), + ), + #[cfg(feature = "chrono")] + ( + Value::date(chrono::NaiveDate::from_ymd(2015, 3, 14)), + CastType::date(), + Value::date(chrono::NaiveDate::from_ymd(2015, 3, 14)), + ), + #[cfg(feature = "chrono")] + ( + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + CastType::time(), + Value::time(chrono::NaiveTime::from_hms(16, 20, 0)), + ), + #[cfg(feature = "chrono")] + ( + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + CastType::datetime(), + Value::datetime(chrono::DateTime::from_utc( + chrono::NaiveDate::from_ymd(2015, 3, 14).and_hms(16, 20, 0), + chrono::Utc, + )), + ), + ]; + + for (input, cast, output) in casts.into_iter() { + let select = Select::default().value(input.cast_as(cast)); + let result = api.conn().select(select).await?.into_single()?.into_single()?; + + assert_eq!(output, result); + } + + Ok(()) +} diff --git a/src/tests/types/postgres/bigdecimal.rs b/src/tests/types/postgres/bigdecimal.rs index ca99857fd..2cbc1c9e0 100644 --- a/src/tests/types/postgres/bigdecimal.rs +++ b/src/tests/types/postgres/bigdecimal.rs @@ -186,23 +186,3 @@ test_type!(money_array( Value::Array(None), Value::array(vec![BigDecimal::from_str("1.12")?, BigDecimal::from_str("1.12")?]) )); - -test_type!(float4( - postgresql, - "float4", - (Value::Numeric(None), Value::Float(None)), - ( - Value::numeric(BigDecimal::from_str("1.123456")?), - Value::float(1.123456) - ) -)); - -test_type!(float8( - postgresql, - "float8", - (Value::Numeric(None), Value::Double(None)), - ( - Value::numeric(BigDecimal::from_str("1.123456")?), - Value::double(1.123456) - ) -)); diff --git a/src/tests/types/sqlite.rs b/src/tests/types/sqlite.rs index b26ae1753..48b598cd0 100644 --- a/src/tests/types/sqlite.rs +++ b/src/tests/types/sqlite.rs @@ -3,7 +3,7 @@ use crate::tests::test_api::sqlite_test_api; use crate::tests::test_api::TestApi; #[cfg(feature = "chrono")] use crate::{ast::*, connector::Queryable}; -#[cfg(feature = "bigdecimal")] +#[cfg(feature = "chrono")] use std::str::FromStr; test_type!(integer( @@ -20,36 +20,6 @@ test_type!(integer( Value::integer(i64::MAX) )); -#[cfg(feature = "bigdecimal")] -test_type!(real( - sqlite, - "REAL", - Value::Numeric(None), - Value::numeric(bigdecimal::BigDecimal::from_str("1.12345").unwrap()) -)); - -#[cfg(feature = "bigdecimal")] -test_type!(float_decimal( - sqlite, - "FLOAT", - (Value::Numeric(None), Value::Float(None)), - ( - Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), - Value::double(3.14) - ) -)); - -#[cfg(feature = "bigdecimal")] -test_type!(double_decimal( - sqlite, - "DOUBLE", - (Value::Numeric(None), Value::Double(None)), - ( - Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), - Value::double(3.14) - ) -)); - test_type!(text(sqlite, "TEXT", Value::Text(None), Value::text("foobar huhuu"))); test_type!(blob( diff --git a/src/visitor.rs b/src/visitor.rs index af3012d38..f42f03e97 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -117,6 +117,9 @@ pub trait Visitor<'a> { #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))] fn visit_json_type_equals(&mut self, left: Expression<'a>, json_type: JsonType) -> Result; + /// Visit an expression with a type cast. + fn visit_cast_expression(&mut self, value: Expression<'a>, cast: CastType<'a>) -> Result; + /// A visit to a value we parameterize fn visit_parameterized(&mut self, value: Value<'a>) -> Result { self.add_parameter(value); @@ -441,38 +444,42 @@ pub trait Visitor<'a> { } /// A visit to a value used in an expression - fn visit_expression(&mut self, value: Expression<'a>) -> Result { - match value.kind { - ExpressionKind::Value(value) => self.visit_expression(*value)?, - ExpressionKind::ConditionTree(tree) => self.visit_conditions(tree)?, - ExpressionKind::Compare(compare) => self.visit_compare(compare)?, - ExpressionKind::Parameterized(val) => self.visit_parameterized(val)?, - ExpressionKind::RawValue(val) => self.visit_raw_value(val.0)?, - ExpressionKind::Column(column) => self.visit_column(*column)?, - ExpressionKind::Row(row) => self.visit_row(row)?, - ExpressionKind::Selection(selection) => { - self.surround_with("(", ")", |ref mut s| s.visit_selection(selection))? - } - ExpressionKind::Function(function) => self.visit_function(*function)?, - ExpressionKind::Op(op) => self.visit_operation(*op)?, - ExpressionKind::Values(values) => self.visit_values(*values)?, - ExpressionKind::Asterisk(table) => match table { - Some(table) => { - self.visit_table(*table, false)?; - self.write(".*")? + fn visit_expression(&mut self, mut value: Expression<'a>) -> Result { + match value.cast.take() { + None => { + match value.kind { + ExpressionKind::Value(value) => self.visit_expression(*value)?, + ExpressionKind::ConditionTree(tree) => self.visit_conditions(tree)?, + ExpressionKind::Compare(compare) => self.visit_compare(compare)?, + ExpressionKind::Parameterized(val) => self.visit_parameterized(val)?, + ExpressionKind::RawValue(val) => self.visit_raw_value(val.0)?, + ExpressionKind::Column(column) => self.visit_column(*column)?, + ExpressionKind::Row(row) => self.visit_row(row)?, + ExpressionKind::Selection(selection) => { + self.surround_with("(", ")", |ref mut s| s.visit_selection(selection))? + } + ExpressionKind::Function(function) => self.visit_function(*function)?, + ExpressionKind::Op(op) => self.visit_operation(*op)?, + ExpressionKind::Values(values) => self.visit_values(*values)?, + ExpressionKind::Asterisk(table) => match table { + Some(table) => { + self.visit_table(*table, false)?; + self.write(".*")? + } + None => self.write("*")?, + }, + ExpressionKind::Default => self.write("DEFAULT")?, } - None => self.write("*")?, - }, - ExpressionKind::Default => self.write("DEFAULT")?, - } - if let Some(alias) = value.alias { - self.write(" AS ")?; + if let Some(alias) = value.alias { + self.write(" AS ")?; + self.delimited_identifiers(&[&*alias])?; + }; - self.delimited_identifiers(&[&*alias])?; - }; - - Ok(()) + Ok(()) + } + Some(cast) => self.visit_cast_expression(value, cast), + } } fn visit_multiple_tuple_comparison(&mut self, left: Row<'a>, right: Values<'a>, negate: bool) -> Result { diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs index 34a5089b0..4ffc9226f 100644 --- a/src/visitor/mssql.rs +++ b/src/visitor/mssql.rs @@ -1,15 +1,10 @@ use super::Visitor; +use crate::error::{Error, ErrorKind}; +use crate::prelude::Aliasable; +use crate::prelude::Query; #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))] use crate::prelude::{JsonExtract, JsonType}; -use crate::{ - ast::{ - Column, Comparable, Expression, ExpressionKind, Insert, IntoRaw, Join, JoinData, Joinable, Merge, OnConflict, - Order, Ordering, Row, Table, TypeDataLength, TypeFamily, Values, - }, - error::{Error, ErrorKind}, - prelude::{Aliasable, Average, Query}, - visitor, Value, -}; +use crate::{ast::*, prelude::Average, visitor, Value}; use std::{convert::TryFrom, fmt::Write, iter}; static GENERATED_KEYS: &str = "@generated_keys"; @@ -343,6 +338,45 @@ impl<'a> Visitor<'a> for Mssql<'a> { } } + fn visit_cast_expression(&mut self, mut value: Expression<'a>, cast: CastType<'a>) -> visitor::Result { + if cast.mssql_enabled() { + let alias = value.alias.take(); + + self.surround_with("CAST(", ")", |this| { + this.visit_expression(value)?; + this.write(" AS ")?; + + match cast.kind() { + CastKind::Int2 => this.write("smallint"), + CastKind::Int4 => this.write("int"), + CastKind::Int8 => this.write("bigint"), + CastKind::Float4 => this.write("real"), + CastKind::Float8 => this.write("float"), + CastKind::Decimal => this.write("numeric"), + CastKind::Boolean => this.write("bit"), + CastKind::Uuid => this.write("uniqueidentifier"), + CastKind::Json => this.write("nvarchar"), + CastKind::Jsonb => this.write("nvarchar"), + CastKind::Date => this.write("date"), + CastKind::Time => this.write("time"), + CastKind::DateTime => this.write("datetime2"), + CastKind::Bytes => this.write("varbinary"), + CastKind::Text => this.write("nvarchar"), + CastKind::Custom(r#type) => this.write(r#type), + } + })?; + + if let Some(alias) = alias { + self.write(" AS ")?; + self.delimited_identifiers(&[&alias])?; + } + + Ok(()) + } else { + self.visit_expression(value) + } + } + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { let add_ordering = |this: &mut Self| { if !this.order_by_set { @@ -1741,4 +1775,12 @@ mod tests { sql ); } + + #[test] + fn type_casts_smoke() { + let select = Select::default().value(1.cast_as(CastType::int2()).alias("val")); + let (sql, _) = Mssql::build(select).unwrap(); + + assert_eq!("SELECT CAST(@P1 AS smallint) AS [val]", sql); + } } diff --git a/src/visitor/mysql.rs b/src/visitor/mysql.rs index c8323458a..ac563e5e5 100644 --- a/src/visitor/mysql.rs +++ b/src/visitor/mysql.rs @@ -189,6 +189,45 @@ impl<'a> Visitor<'a> for Mysql<'a> { self.parameters.push(value); } + fn visit_cast_expression(&mut self, mut value: Expression<'a>, cast: CastType<'a>) -> visitor::Result { + if cast.mysql_enabled() { + let alias = value.alias.take(); + + self.surround_with("CAST(", ")", |this| { + this.visit_expression(value)?; + this.write(" AS ")?; + + match cast.kind() { + CastKind::Int2 => this.write("signed"), + CastKind::Int4 => this.write("signed"), + CastKind::Int8 => this.write("signed"), + CastKind::Float4 => this.write("decimal"), + CastKind::Float8 => this.write("decimal"), + CastKind::Decimal => this.write("decimal"), + CastKind::Boolean => this.write("unsigned"), + CastKind::Uuid => this.write("char"), + CastKind::Json => this.write("nchar"), + CastKind::Jsonb => this.write("nchar"), + CastKind::Date => this.write("date"), + CastKind::Time => this.write("time"), + CastKind::DateTime => this.write("datetime"), + CastKind::Bytes => this.write("binary"), + CastKind::Text => this.write("nchar"), + CastKind::Custom(r#type) => this.write(r#type), + } + })?; + + if let Some(alias) = alias { + self.write(" AS ")?; + self.delimited_identifiers(&[&alias])?; + } + + Ok(()) + } else { + self.visit_expression(value) + } + } + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => { @@ -692,4 +731,12 @@ mod tests { sql ); } + + #[test] + fn type_casts_smoke() { + let select = Select::default().value(1.cast_as(CastType::int2()).alias("val")); + let (sql, _) = Mysql::build(select).unwrap(); + + assert_eq!("SELECT CAST(? AS signed) AS `val`", sql); + } } diff --git a/src/visitor/postgres.rs b/src/visitor/postgres.rs index 0c533cb0b..4753476db 100644 --- a/src/visitor/postgres.rs +++ b/src/visitor/postgres.rs @@ -48,6 +48,43 @@ impl<'a> Visitor<'a> for Postgres<'a> { self.write(self.parameters.len()) } + fn visit_cast_expression(&mut self, mut value: Expression<'a>, cast: CastType<'a>) -> visitor::Result { + if cast.postgres_enabled() { + let alias = value.alias.take(); + + self.surround_with("(", ")", |this| this.visit_expression(value))?; + self.write("::")?; + + match cast.kind() { + CastKind::Int2 => self.write("int2")?, + CastKind::Int4 => self.write("int4")?, + CastKind::Int8 => self.write("int8")?, + CastKind::Float4 => self.write("float4")?, + CastKind::Float8 => self.write("float8")?, + CastKind::Decimal => self.write("numeric")?, + CastKind::Boolean => self.write("boolean")?, + CastKind::Uuid => self.write("uuid")?, + CastKind::Json => self.write("json")?, + CastKind::Jsonb => self.write("jsonb")?, + CastKind::Date => self.write("date")?, + CastKind::Time => self.write("time")?, + CastKind::DateTime => self.write("timestamp")?, + CastKind::Bytes => self.write("bytea")?, + CastKind::Text => self.write("text")?, + CastKind::Custom(r#type) => self.write(r#type)?, + } + + if let Some(alias) = alias { + self.write(" AS ")?; + self.delimited_identifiers(&[&alias])?; + } + + Ok(()) + } else { + self.visit_expression(value) + } + } + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => { @@ -743,4 +780,12 @@ mod tests { assert_eq!("SELECT \"User\".*, \"Toto\".* FROM \"User\" LEFT JOIN \"Post\" AS \"p\" ON \"p\".\"userId\" = \"User\".\"id\", \"Toto\"", sql); } + + #[test] + fn type_casts_smoke() { + let select = Select::default().value(1.cast_as(CastType::int2()).alias("val")); + let (sql, _) = Postgres::build(select).unwrap(); + + assert_eq!("SELECT ($1)::int2 AS \"val\"", sql); + } } diff --git a/src/visitor/sqlite.rs b/src/visitor/sqlite.rs index bf1cd7d75..8c96bb2c7 100644 --- a/src/visitor/sqlite.rs +++ b/src/visitor/sqlite.rs @@ -188,6 +188,10 @@ impl<'a> Visitor<'a> for Sqlite<'a> { self.parameters.push(value); } + fn visit_cast_expression(&mut self, value: Expression<'a>, _cast: CastType<'a>) -> visitor::Result { + self.visit_expression(value) + } + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => {