diff --git a/README.md b/README.md index 8066798..b6f7308 100644 --- a/README.md +++ b/README.md @@ -30,5 +30,5 @@ Use the `file://` protocol to load data into a file instead. ## Limitations -- Supported datatypes: bool, char, int2, int4, int8, float4, float8, timestamp, timestamptz, text, bytea. Cast the columns in your query to `text` or another supported type if your query returns different types +- Supported datatypes: bool, char, int2, int4, int8, float4, float8, timestamp, timestamptz, date, text, bytea. Cast the columns in your query to `text` or another supported type if your query returns different types - Doesn't support appending to tables, only writing new Delta Tables (pass `-o` to overwrite) diff --git a/src/pg_arrow_source.rs b/src/pg_arrow_source.rs index fb76c3a..90c6b6b 100644 --- a/src/pg_arrow_source.rs +++ b/src/pg_arrow_source.rs @@ -23,6 +23,7 @@ pub enum ArrowBuilder { Float32Builder(array::Float32Builder), Float64Builder(array::Float64Builder), TimestampMicrosecondBuilder(array::TimestampMicrosecondBuilder), + DateBuilder(array::Date32Builder), StringBuilder(array::StringBuilder), BinaryBuilder(array::BinaryBuilder), } @@ -30,13 +31,35 @@ use crate::{ArrowBuilder::*, DataLoadingError}; // tokio-postgres provides awkward Rust type conversions for Postgres TIMESTAMP and TIMESTAMPTZ values // It's easier just to handle the raw values ourselves +struct UnixEpochDayOffset(i32); +// Number of days from 1970-01-01 to 2000-01-01 +const J2000_EPOCH_DAYS: i32 = 10957; + +impl FromSql<'_> for UnixEpochDayOffset { + fn from_sql(_ty: &Type, buf: &[u8]) -> Result> { + let byte_array: [u8; 4] = buf.try_into()?; + let offset = i32::from_be_bytes(byte_array) + J2000_EPOCH_DAYS; + Ok(Self(offset)) + } + + fn accepts(ty: &Type) -> bool { + *ty == Type::DATE + } +} +impl From for i32 { + fn from(val: UnixEpochDayOffset) -> Self { + val.0 + } +} + struct UnixEpochMicrosecondOffset(i64); -const J2000_EPOCH_OFFSET: i64 = 946_684_800_000_000; // Number of us from 1970-01-01 to 2000-01-01 +// Number of us from 1970-01-01 (Unix epoch) to 2000-01-01 (Postgres epoch) +const J2000_EPOCH_MICROSECONDS: i64 = J2000_EPOCH_DAYS as i64 * 86400 * 1000000; impl FromSql<'_> for UnixEpochMicrosecondOffset { fn from_sql(_ty: &Type, buf: &[u8]) -> Result> { let byte_array: [u8; 8] = buf.try_into()?; - let offset = i64::from_be_bytes(byte_array) + J2000_EPOCH_OFFSET; + let offset = i64::from_be_bytes(byte_array) + J2000_EPOCH_MICROSECONDS; Ok(Self(offset)) } @@ -69,6 +92,7 @@ impl ArrowBuilder { Some("UTC".into()), )), ), + Type::DATE => DateBuilder(array::Date32Builder::new()), Type::TEXT => StringBuilder(array::StringBuilder::new()), Type::BYTEA => BinaryBuilder(array::BinaryBuilder::new()), _ => panic!("Unsupported type: {}", pg_type), @@ -102,6 +126,10 @@ impl ArrowBuilder { row.get::>(column_idx) .map(UnixEpochMicrosecondOffset::into), ), + DateBuilder(ref mut builder) => builder.append_option( + row.get::>(column_idx) + .map(UnixEpochDayOffset::into), + ), StringBuilder(ref mut builder) => { builder.append_option(row.get::>(column_idx)) } @@ -120,6 +148,7 @@ impl ArrowBuilder { Float32Builder(builder) => Arc::new(builder.finish()), Float64Builder(builder) => Arc::new(builder.finish()), TimestampMicrosecondBuilder(builder) => Arc::new(builder.finish()), + DateBuilder(builder) => Arc::new(builder.finish()), StringBuilder(builder) => Arc::new(builder.finish()), BinaryBuilder(builder) => Arc::new(builder.finish()), } @@ -137,6 +166,7 @@ fn pg_type_to_arrow_type(pg_type: &Type) -> DataType { Type::FLOAT8 => DataType::Float64, Type::TIMESTAMP => DataType::Timestamp(TimeUnit::Microsecond, None), Type::TIMESTAMPTZ => DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + Type::DATE => DataType::Date32, Type::TEXT => DataType::Utf8, Type::BYTEA => DataType::Binary, _ => panic!("Unsupported type: {}. Explicitly cast the relevant columns to text in order to store them as strings.", pg_type), @@ -230,17 +260,17 @@ impl PgArrowSource { mod tests { use postgres::types::{FromSql, Type}; - use super::UnixEpochMicrosecondOffset; + use super::*; #[test] - fn test_just_after_j2000() { + fn test_timestamp_just_after_j2000() { let offset = UnixEpochMicrosecondOffset::from_sql(&Type::TIMESTAMP, &[0, 0, 0, 0, 0, 0, 1, 2]) .unwrap(); assert_eq!(offset.0, 946_684_800_000_000 + 256 + 2); } #[test] - fn test_just_before_j2000() { + fn test_timestamp_just_before_j2000() { let offset = UnixEpochMicrosecondOffset::from_sql( &Type::TIMESTAMP, &[255, 255, 255, 255, 255, 255, 255, 255], @@ -248,4 +278,14 @@ mod tests { .unwrap(); assert_eq!(offset.0, 946_684_800_000_000 - 1); } + #[test] + fn test_date_just_after_j2000() { + let offset = UnixEpochDayOffset::from_sql(&Type::DATE, &[0, 0, 1, 2]).unwrap(); + assert_eq!(offset.0, 10957 + 256 + 2); + } + #[test] + fn test_date_just_before_j2000() { + let offset = UnixEpochDayOffset::from_sql(&Type::DATE, &[255, 255, 255, 255]).unwrap(); + assert_eq!(offset.0, 10957 - 1); + } } diff --git a/tests/basic_integration.rs b/tests/basic_integration.rs index 956b16a..59688cf 100644 --- a/tests/basic_integration.rs +++ b/tests/basic_integration.rs @@ -1,6 +1,6 @@ use arrow::array::{ - Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, + Array, BinaryArray, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, }; use clap::Parser; use futures::{StreamExt, TryStreamExt}; @@ -195,9 +195,21 @@ async fn test_pg_arrow_source() { (elapsed_days * seconds_per_day + 2) * 1000000 ); + // THEN the first 3 date values should be as expected + let cdate_array = rb1 + .column(10) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(cdate_array.is_null(0)); + assert!(!cdate_array.is_null(1)); + assert_eq!(cdate_array.value(1), elapsed_days as i32 + 1); + assert!(!cdate_array.is_null(2)); + assert_eq!(cdate_array.value(2), elapsed_days as i32 + 2); + // THEN the first 3 text values should be as expected let ctext_array = rb1 - .column(10) + .column(11) .as_any() .downcast_ref::() .unwrap(); @@ -209,7 +221,7 @@ async fn test_pg_arrow_source() { // THEN the first 3 bytea values should be as expected let cbytea_array = rb1 - .column(11) + .column(12) .as_any() .downcast_ref::() .unwrap(); diff --git a/tests/postgres-init-scripts/init-pg-data.sql b/tests/postgres-init-scripts/init-pg-data.sql index 43816c8..af69a2e 100755 --- a/tests/postgres-init-scripts/init-pg-data.sql +++ b/tests/postgres-init-scripts/init-pg-data.sql @@ -9,6 +9,7 @@ CREATE TABLE t1( cfloat8 DOUBLE PRECISION, ctimestamp TIMESTAMP, ctimestamptz TIMESTAMPTZ, + cdate DATE, ctext TEXT, cbytea BYTEA ); @@ -25,6 +26,7 @@ INSERT INTO t1( cfloat8, ctimestamp, ctimestamptz, + cdate, ctext, cbytea ) SELECT @@ -37,6 +39,7 @@ INSERT INTO t1( s + 0.5, '2024-01-01'::TIMESTAMP + s * INTERVAL '1 second', '2024-01-01 00:00:00+00'::TIMESTAMPTZ + s * INTERVAL '1 second', + '2024-01-01'::DATE + s, s::TEXT, int4send(s::INT) FROM generate_series(1, 25000) AS s;