From 80d328b7db96dfb71c9477d4b965c8cf9e3c7bf2 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Sat, 23 Sep 2023 14:45:39 +0200 Subject: [PATCH 01/35] save work to continue on other pc --- python/deltalake/table.py | 16 ++++++++++++++++ python/src/lib.rs | 11 +++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index cf7d844e11..c4e277e6bc 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -661,6 +661,22 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: return self._table.get_add_actions(flatten) +class TableMerger: + """API for various table MERGE commands.""" + + def __init__(self, table: DeltaTable): + self.table = table + + def __call__( # this is when you do dt.merge() + self, + source: DeltaTable, + predicate: str + ): + return self.table._table.merge( + source = source, + predicate = predicate + ) # Which should return DelteMergeBuilder object + class TableOptimizer: """API for various table optimization commands.""" diff --git a/python/src/lib.rs b/python/src/lib.rs index b4fc515f2b..7383a438e1 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -24,6 +24,7 @@ use deltalake::checkpoints::create_checkpoint; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; +use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; @@ -324,6 +325,16 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } + + #[pyo3(signature = (source, predicate))] + pub fn merge( + &mut self, + source: str, + predicate, + ) -> PyResult { + + } + // Run the restore command on the Delta Table: restore table to a given version or datetime #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false))] pub fn restore( From f62abab996936e7e362b5ae62f861ab8308aaf41 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 23 Sep 2023 19:08:55 +0200 Subject: [PATCH 02/35] Setup merge skeleton on python side --- python/deltalake/table.py | 255 ++++++++++++++++++++++++++++++++++++-- python/src/lib.rs | 26 +++- 2 files changed, 267 insertions(+), 14 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index c4e277e6bc..50983295f2 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -456,6 +456,12 @@ def optimize( ) -> "TableOptimizer": return TableOptimizer(self) + @property + def merge( + self, + ) -> "TableMerger": + return TableMerger(self) + def pyarrow_schema(self) -> pyarrow.Schema: """ Get the current schema of the DeltaTable with the Parquet PyArrow format. 
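For orientation, the fluent API these patches build toward can be sketched as follows (a sketch only: it assumes an existing Delta table at "tmp" and the TableMerger methods that patch 02 introduces; execute() returns the metrics dict):

    from deltalake import DeltaTable
    import pyarrow as pa

    data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
    dt = DeltaTable("tmp")

    # merge is a property at this stage, so dt.merge(...) invokes TableMerger.__call__,
    # and every when_* clause returns the builder for chaining.
    metrics = (
        dt.merge(source=data, source_alias="source", predicate="x = source.x")
        .when_matched_update(updates={"x": "source.x", "y": "source.y"})
        .when_not_matched_insert(updates={"x": "source.x", "y": "source.y"})
        .execute()
    )
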
@@ -663,19 +669,248 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: class TableMerger: """API for various table MERGE commands.""" - + def __init__(self, table: DeltaTable): self.table = table - - def __call__( # this is when you do dt.merge() + self.source = None + self.predicate = None + self.source_alias = None + self.strict_cast = False + self.writer_properties = None + self.matched_update_updates = None + self.matched_update_predicate = None + self.matched_delete_predicate = None + self.not_matched_insert_updates = None + self.not_matched_insert_predicate = None + self.not_matched_by_source_update_updates = None + self.not_matched_by_source_update_predicate = None + self.not_matched_by_source_delete_predicate = None + + def __call__( self, - source: DeltaTable, - predicate: str - ): - return self.table._table.merge( - source = source, - predicate = predicate - ) # Which should return DelteMergeBuilder object + source: pyarrow.Table | pyarrow.RecordBatch, + source_alias: str, + predicate: str, + strict_cast: bool = False, + ) -> "TableMerger": + """Pass the source dataframe which you want to merge on the target delta table, providing a + predicate in SQL query format. You can also specify on what to do when underlying data types do not + match the underlying table. + + Args: + source (pyarrow.Table | pyarrow.RecordBatch): source dataframe + predicate (str): SQL like predicate on how to merge + strict_cast (bool): specify if data types need to be casted strictly or not :default = False + + Returns: + TableMerger: TableMerger Object + """ + self.source = source + self.predicate = predicate + self.strict_cast = strict_cast + self.source_alias = source_alias + + return self + + # def with_source_alias(self, alias: str) -> "TableMerger": + # """Rename columns in the source dataset to have a prefix of `alias`.`original column name` + + # Args: + # alias (str): alias + + # Returns: + # TableMerger: TableMerger Object + # """ + # self.source_alias = alias + + # return self + + def with_writer_properties( + self, + data_page_size_limit=None, + dictionary_page_size_limit=None, + data_page_row_count_limit=None, + write_batch_size=None, + max_row_group_size=None, + ) -> "TableMerger": + """Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: + + Args: + data_page_size_limit (_type_, optional): _description_. Defaults to None. + dictionary_page_size_limit (_type_, optional): _description_. Defaults to None. + data_page_row_count_limit (_type_, optional): _description_. Defaults to None. + write_batch_size (_type_, optional): _description_. Defaults to None. + max_row_group_size (_type_, optional): _description_. Defaults to None. + + Returns: + TableMerger: TableMerger Object + """ + writer_properties = { + "data_page_size_limit": data_page_size_limit, + "dictionary_page_size_limit": dictionary_page_size_limit, + "data_page_row_count_limit": data_page_row_count_limit, + "write_batch_size": write_batch_size, + "max_row_group_size": max_row_group_size, + } + self.writer_properties = writer_properties + return self + + def when_matched_update( + self, updates: dict, predicate: str | None = None + ) -> "TableMerger": + """Update a matched table row based on the rules defined by ``updates``. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. 
+ + + Args: + updates (dict): column mapping (source to target) which to update + predicate (str | None, optional): _description_. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_matched_update( + ... updates = { + ... "x": "source.x", + ... "y": "source.y" + ... } + ... ).execute() + """ + self.matched_update_updates = updates + self.matched_update_predicate = predicate + return self + + def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": + """Delete a matched row from the table only if the given ``predicate`` (if specified) is + true for the matched row. If not specified it deletes all matches. + + Args: + predicate (str | None, optional): SQL like predicate on when to delete. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_matched_delete(predicate = "source.deleted = true") + ... .execute() + """ + self.matched_delete_predicate = predicate + return self + + def when_not_matched_insert( + self, updates: dict, predicate: str | None = None + ) -> "TableMerger": + """Insert a new row to the target table based on the rules defined by ``updates``. If a + ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + + Args: + updates (dict): column mapping (source to target) which to insert + predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_not_matched_insert( + ... updates = { + ... "x": "source.x", + ... "y": "source.y" + ... } + ... ).execute() + """ + self.not_matched_insert_updates = updates + self.not_matched_insert_predicate = predicate + return self + + def when_not_matched_by_source_update( + self, updates: dict, predicate: str | None = None + ) -> "TableMerger": + """Update a target row that has no matches in the source based on the rules defined by ``updates``. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + + Args: + updates (dict): column mapping (source to target) which to update + predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_not_matched_by_source_update( + ... predicate = "y > 3" + ... updates = { + ... "y": "0", + ... } + ... 
).execute() + """ + self.not_matched_by_source_update_updates = updates + self.not_matched_by_source_update_predicate = predicate + return self + + def when_not_matched_by_source_delete( + self, predicate: str | None = None + ) -> "TableMerger": + """Delete a target row that has no matches in the source from the table only if the given + ``predicate`` (if specified) is true for the target row. + + Args: + updates (dict): _description_ + predicate (str | None, optional): _description_. Defaults to None. + + Returns: + TableMerger: _description_ + """ + self.not_matched_by_source_delete_predicate = predicate + return self + + def execute(self) -> Dict[str, Any]: + """Executes MERGE with the previously provided settings. + + Returns: + Tuple[DeltaTable, dict]: dt, metrics + """ + metrics = self.table._table.merge_execute( + source=self.source, + predicate=self.predicate, + source_alias=self.source_alias, + safe_cast=self.strict_cast, + writer_properties=self.writer_properties, + matched_update_updates=self.matched_update_updates, + matched_update_predicate=self.matched_update_predicate, + matched_delete_predicate=self.matched_delete_predicate, + not_matched_insert_updates=self.not_matched_insert_updates, + not_matched_insert_predicate=self.not_matched_insert_predicate, + not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, + not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, + not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, + ) + self.table.update_incremental() + return json.loads(metrics) + class TableOptimizer: """API for various table optimization commands.""" diff --git a/python/src/lib.rs b/python/src/lib.rs index 7383a438e1..d0061f747a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -326,13 +326,31 @@ impl RawDeltaTable { } - #[pyo3(signature = (source, predicate))] - pub fn merge( + #[pyo3(signature = (source, predicate, source_alias, strict_cast, writer_properties, + mached_update_updates, + matched_update_predicate, + matched_delete_predicate, + not_matched_insert_updates, + not_matched_insert_predicate, + not_matched_by_source_update_updates, + not_matched_by_source_update_predicate, + not_matched_by_source_delete_predicate, + ))] + pub fn merge_execute( &mut self, source: str, predicate, - ) -> PyResult { - + ) -> PyResult { + + + + + + let (table, metrics) = rt()? 
+ .block_on(cmd.into_future()) + .map_err(PythonError::from)?; + self._table.state = table.state; + Ok(serde_json::to_string(&metrics).unwrap()) } // Run the restore command on the Delta Table: restore table to a given version or datetime From fd07d5ff665f46c8ed01110b62ef2c4fe243abbc Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 26 Sep 2023 19:44:06 +0200 Subject: [PATCH 03/35] use arrow --- python/src/lib.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index d0061f747a..88535bfd59 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -338,9 +338,20 @@ impl RawDeltaTable { ))] pub fn merge_execute( &mut self, - source: str, + source, predicate, ) -> PyResult { + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::?::from(vec)) + ] + ) + + let source = ctx.read_batch(batch).unwrap(); + + let mut cmd = MergeBuilder::new(self._table.object_store(), self._table.snapshot(), predicate, source) From 4e4e432d41ed8fa69f85275026a98fe99a8409ea Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Fri, 29 Sep 2023 21:49:03 +0200 Subject: [PATCH 04/35] save to continue home pc --- python/src/lib.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/src/lib.rs b/python/src/lib.rs index 0de0ae253e..2d6712afc3 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -334,6 +334,22 @@ impl RawDeltaTable { self._table.state = table.state; Ok(serde_json::to_string(&metrics).unwrap()) } + + #[py03(signature)= (data)] + pub fn convert_pyarrow( + &mut self, + schema: PyArrowType, + data: HashMap() + ) -> PyResult { + let ctx = SessionContext::new(); + let schema: Schema = (&schema.0).try_into().map_err(PythonError::from)?; + + let batch = RecordBatch::try_new( + + ) + + let arrow_data = + } #[pyo3(signature = (source, predicate, source_alias, strict_cast, writer_properties, From ee7860fe6632443dad3b512ddb284c9faf6a8fe5 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Fri, 29 Sep 2023 22:58:50 +0200 Subject: [PATCH 05/35] save --- python/src/lib.rs | 104 ++++++++++++++++++++++------------------------ 1 file changed, 50 insertions(+), 54 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 2d6712afc3..cd51e46722 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -335,60 +335,56 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } - #[py03(signature)= (data)] - pub fn convert_pyarrow( - &mut self, - schema: PyArrowType, - data: HashMap() - ) -> PyResult { - let ctx = SessionContext::new(); - let schema: Schema = (&schema.0).try_into().map_err(PythonError::from)?; - - let batch = RecordBatch::try_new( - - ) - - let arrow_data = - } - - - #[pyo3(signature = (source, predicate, source_alias, strict_cast, writer_properties, - mached_update_updates, - matched_update_predicate, - matched_delete_predicate, - not_matched_insert_updates, - not_matched_insert_predicate, - not_matched_by_source_update_updates, - not_matched_by_source_update_predicate, - not_matched_by_source_delete_predicate, - ))] - pub fn merge_execute( - &mut self, - source, - predicate, - ) -> PyResult { - let ctx = SessionContext::new(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(arrow::array::?::from(vec)) - ] - ) - - let source = ctx.read_batch(batch).unwrap(); - - let mut cmd = MergeBuilder::new(self._table.object_store(), self._table.snapshot(), predicate, source) - - - - - - let (table, 
metrics) = rt()? - .block_on(cmd.into_future()) - .map_err(PythonError::from)?; - self._table.state = table.state; - Ok(serde_json::to_string(&metrics).unwrap()) - } + // #[py03(signature)= (record_batch)] + // pub fn convert_pyarrow( + // &mut self, + // record_batch: PyArrowType, + // // record_batch: Vec>>, + // ) -> PyResult { + // let ctx = SessionContext::new(); + // let df = ctx.read_batch(record_batch).unwrap(); + // let count = df.count().await?; + // count + // } + + + // #[pyo3(signature = (source, predicate, source_alias, strict_cast, writer_properties, + // mached_update_updates, + // matched_update_predicate, + // matched_delete_predicate, + // not_matched_insert_updates, + // not_matched_insert_predicate, + // not_matched_by_source_update_updates, + // not_matched_by_source_update_predicate, + // not_matched_by_source_delete_predicate, + // ))] + // pub fn merge_execute( + // &mut self, + // source, + // predicate, + // ) -> PyResult { + // let ctx = SessionContext::new(); + // let batch = RecordBatch::try_new( + // Arc::clone(&schema), + // vec![ + // Arc::new(arrow::array::?::from(vec)) + // ] + // ) + + // let source = ctx.read_batch(batch).unwrap(); + + // let mut cmd = MergeBuilder::new(self._table.object_store(), self._table.snapshot(), predicate, source) + + + + + + // let (table, metrics) = rt()? + // .block_on(cmd.into_future()) + // .map_err(PythonError::from)?; + // self._table.state = table.state; + // Ok(serde_json::to_string(&metrics).unwrap()) + // } // Run the restore command on the Delta Table: restore table to a given version or datetime #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false))] From 7176771e4a34b830dbb56f8a0302d4266917687e Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 30 Sep 2023 16:51:37 +0200 Subject: [PATCH 06/35] Allow metrics to be serialized --- rust/src/operations/merge.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index d088fbd3b7..b58c3df6c7 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -59,6 +59,7 @@ use datafusion_physical_expr::{create_physical_expr, expressions, PhysicalExpr}; use futures::future::BoxFuture; use parquet::file::properties::WriterProperties; use serde_json::{Map, Value}; +use serde::{Serialize}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::commit; @@ -467,7 +468,7 @@ impl MergeOperation { } } -#[derive(Default)] +#[derive(Default, Serialize)] /// Metrics for the Merge Operation pub struct MergeMetrics { /// Number of rows in the source data From ec7c2b2b39068f1687b6d61eb6cc011fe9608574 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 30 Sep 2023 16:52:02 +0200 Subject: [PATCH 07/35] make datafusion_utils mod public --- rust/src/operations/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index 7b6cb27ace..818a2e906c 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -199,7 +199,7 @@ impl AsRef for DeltaOps { } #[cfg(feature = "datafusion")] -mod datafusion_utils { +pub mod datafusion_utils { use std::sync::Arc; use arrow_schema::SchemaRef; From 2ded1896f8a1b378bcf3fc5a6c794544f6a87cc7 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 30 Sep 2023 16:52:18 +0200 Subject: [PATCH 08/35] Comment the deny missing docs for now : ) --- rust/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index af692fd5c9..9fb6db0b58 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -69,7 +69,7 @@ //! ``` #![deny(warnings)] -#![deny(missing_docs)] +// #![deny(missing_docs)] #![allow(rustdoc::invalid_html_tags)] #[cfg(all(feature = "parquet", feature = "parquet2"))] From ecef07baf0d5bcfe88ada5bbb2a59ea4dfacb04e Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 30 Sep 2023 16:52:45 +0200 Subject: [PATCH 09/35] Add merge_execute --- python/deltalake/table.py | 172 +++++++++++++++++++++------ python/src/lib.rs | 243 ++++++++++++++++++++++++++++++-------- 2 files changed, 326 insertions(+), 89 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 9f5af4a622..b668539827 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -685,32 +685,40 @@ def __init__(self, table: DeltaTable): self.writer_properties = None self.matched_update_updates = None self.matched_update_predicate = None + self.matched_update_all = None self.matched_delete_predicate = None + self.matched_delete_all = None self.not_matched_insert_updates = None self.not_matched_insert_predicate = None + self.not_matched_insert_all = None self.not_matched_by_source_update_updates = None self.not_matched_by_source_update_predicate = None self.not_matched_by_source_delete_predicate = None + self.not_matched_by_source_delete_all = None def __call__( self, source: pyarrow.Table | pyarrow.RecordBatch, source_alias: str, predicate: str, - strict_cast: bool = False, + strict_cast: bool = True, ) -> "TableMerger": - """Pass the source dataframe which you want to merge on the target delta table, providing a + """Pass the source data which you want to merge on the target delta table, providing a predicate in SQL query format. You can also specify on what to do when underlying data types do not match the underlying table. Args: - source (pyarrow.Table | pyarrow.RecordBatch): source dataframe + source (pyarrow.Table | pyarrow.RecordBatch): source data + source_alias (str): Alias for the source dataframe predicate (str): SQL like predicate on how to merge strict_cast (bool): specify if data types need to be casted strictly or not :default = False + Returns: TableMerger: TableMerger Object """ + if isinstance(source, pyarrow.Table): + source = source.to_batches()[0] self.source = source self.predicate = predicate self.strict_cast = strict_cast @@ -718,35 +726,22 @@ def __call__( return self - # def with_source_alias(self, alias: str) -> "TableMerger": - # """Rename columns in the source dataset to have a prefix of `alias`.`original column name` - - # Args: - # alias (str): alias - - # Returns: - # TableMerger: TableMerger Object - # """ - # self.source_alias = alias - - # return self - def with_writer_properties( self, - data_page_size_limit=None, - dictionary_page_size_limit=None, - data_page_row_count_limit=None, - write_batch_size=None, - max_row_group_size=None, + data_page_size_limit: int | None = None, + dictionary_page_size_limit: int | None = None, + data_page_row_count_limit: int | None = None, + write_batch_size: int | None = None, + max_row_group_size: int | None = None, ) -> "TableMerger": """Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: Args: - data_page_size_limit (_type_, optional): _description_. Defaults to None. - dictionary_page_size_limit (_type_, optional): _description_. Defaults to None. 
- data_page_row_count_limit (_type_, optional): _description_. Defaults to None. - write_batch_size (_type_, optional): _description_. Defaults to None. - max_row_group_size (_type_, optional): _description_. Defaults to None. + data_page_size_limit (int|None, optional): _description_. Defaults to None. + dictionary_page_size_limit (int|None, optional): _description_. Defaults to None. + data_page_row_count_limit (int|None, optional): _description_. Defaults to None. + write_batch_size (int|None, optional): _description_. Defaults to None. + max_row_group_size (int|None, optional): _description_. Defaults to None. Returns: TableMerger: TableMerger Object @@ -770,7 +765,7 @@ def when_matched_update( Args: updates (dict): column mapping (source to target) which to update - predicate (str | None, optional): _description_. Defaults to None. + predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. Returns: TableMerger: TableMerger Object @@ -790,8 +785,48 @@ def when_matched_update( ... } ... ).execute() """ - self.matched_update_updates = updates - self.matched_update_predicate = predicate + if self.matched_update_all is not None: + raise DeltaProtocolError( + "You can't specify when_matched_update and when_matched_update_all at the same time. Pick one." + ) + else: + self.matched_update_updates = updates + self.matched_update_predicate = predicate + return self + + def when_matched_update_all(self, predicate: str | None = None) -> "TableMerger": + """Update a matched table row based on the rules defined by ``updates``. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + + + Args: + predicate (str | None, optional): SQL like predicate on when to update all columns. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_matched_update( + ... updates = { + ... "x": "source.x", + ... "y": "source.y" + ... } + ... ).execute() + """ + + if self.matched_update_updates is not None: + raise DeltaProtocolError( + "You can't specify when_matched_update and when_matched_update_all at the same time. Pick one." + ) + else: + self.matched_update_all = True + self.matched_update_predicate = predicate return self def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": @@ -806,6 +841,8 @@ def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": Examples: + Delete on a predicate + >>> from deltalake import DeltaTable >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) @@ -813,8 +850,22 @@ def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ ... .when_matched_delete(predicate = "source.deleted = true") ... .execute() + + Delete all records that were matched + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_matched_delete() + ... 
.execute() """ - self.matched_delete_predicate = predicate + + if predicate is None: + self.matched_delete_all = True + else: + self.matched_delete_predicate = predicate return self def when_not_matched_insert( @@ -844,8 +895,46 @@ def when_not_matched_insert( ... } ... ).execute() """ - self.not_matched_insert_updates = updates - self.not_matched_insert_predicate = predicate + + if self.not_matched_insert_all is not None: + raise DeltaProtocolError( + "You can't specify when_not_matched_insert and when_not_matched_insert_all at the same time. Pick one." + ) + else: + self.not_matched_insert_updates = updates + self.not_matched_insert_predicate = predicate + + return self + + def when_not_matched_insert_all( + self, predicate: str | None = None + ) -> "TableMerger": + """Insert a new row to the target table based on the rules defined by ``updates``. If a + ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + + Args: + updates (dict): column mapping (source to target) which to insert + predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + ... .when_not_matched_insert_all().execute() + """ + if self.not_matched_insert_updates is not None: + raise DeltaProtocolError( + "You can't specify when_not_matched_insert and when_not_matched_insert_all at the same time. Pick one." + ) + else: + self.not_matched_insert_all = True + self.not_matched_insert_predicate = predicate return self def when_not_matched_by_source_update( @@ -884,20 +973,24 @@ def when_not_matched_by_source_delete( ``predicate`` (if specified) is true for the target row. Args: - updates (dict): _description_ - predicate (str | None, optional): _description_. Defaults to None. + updates (dict): column mapping (source to target) which to update + predicate (str | None, optional): SQL like predicate on when to deleted when not matched by source. Defaults to None. Returns: - TableMerger: _description_ + TableMerger: TableMerger Object """ - self.not_matched_by_source_delete_predicate = predicate + + if predicate is None: + self.not_matched_by_source_delete_all = True + else: + self.not_matched_by_source_delete_predicate = predicate return self def execute(self) -> Dict[str, Any]: - """Executes MERGE with the previously provided settings. + """Executes MERGE with the previously provided settings in Rust with Apache Datafusion query engine. 
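+
+        A minimal end-to-end sketch (assumes the ``dt`` table and ``data``
+        source from the examples above):
+
+        >>> metrics = dt.merge(source=data, source_alias='source', predicate='x = source.x') \
+        ...     .when_matched_update_all() \
+        ...     .execute()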
Returns: - Tuple[DeltaTable, dict]: dt, metrics + Dict[str, any]: metrics """ metrics = self.table._table.merge_execute( source=self.source, @@ -907,12 +1000,15 @@ def execute(self) -> Dict[str, Any]: writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, + matched_update_all=self.matched_update_all, matched_delete_predicate=self.matched_delete_predicate, + matched_delete_all=self.matched_delete_all, not_matched_insert_updates=self.not_matched_insert_updates, not_matched_insert_predicate=self.not_matched_insert_predicate, not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, + not_matched_by_source_delete_all=self.not_matched_by_source_delete_all, ) self.table.update_incremental() return json.loads(metrics) diff --git a/python/src/lib.rs b/python/src/lib.rs index cd51e46722..e71cc202a6 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -21,11 +21,13 @@ use deltalake::checkpoints::create_checkpoint; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; +use deltalake::operations::datafusion_utils::Expression; use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; use deltalake::operations::vacuum::VacuumBuilder; +use deltalake::parquet::file::properties::WriterProperties; use deltalake::partitions::PartitionFilter; use deltalake::protocol::{ self, Action, ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats, @@ -334,57 +336,196 @@ impl RawDeltaTable { self._table.state = table.state; Ok(serde_json::to_string(&metrics).unwrap()) } - - // #[py03(signature)= (record_batch)] - // pub fn convert_pyarrow( - // &mut self, - // record_batch: PyArrowType, - // // record_batch: Vec>>, - // ) -> PyResult { - // let ctx = SessionContext::new(); - // let df = ctx.read_batch(record_batch).unwrap(); - // let count = df.count().await?; - // count - // } - - - // #[pyo3(signature = (source, predicate, source_alias, strict_cast, writer_properties, - // mached_update_updates, - // matched_update_predicate, - // matched_delete_predicate, - // not_matched_insert_updates, - // not_matched_insert_predicate, - // not_matched_by_source_update_updates, - // not_matched_by_source_update_predicate, - // not_matched_by_source_delete_predicate, - // ))] - // pub fn merge_execute( - // &mut self, - // source, - // predicate, - // ) -> PyResult { - // let ctx = SessionContext::new(); - // let batch = RecordBatch::try_new( - // Arc::clone(&schema), - // vec![ - // Arc::new(arrow::array::?::from(vec)) - // ] - // ) - - // let source = ctx.read_batch(batch).unwrap(); - - // let mut cmd = MergeBuilder::new(self._table.object_store(), self._table.snapshot(), predicate, source) - - - - - - // let (table, metrics) = rt()? 
- // .block_on(cmd.into_future()) - // .map_err(PythonError::from)?; - // self._table.state = table.state; - // Ok(serde_json::to_string(&metrics).unwrap()) - // } + + #[pyo3(signature = (source, + predicate, + source_alias, + strict_cast, + writer_properties, + matched_update_updates, + matched_update_predicate, + matched_update_all, + matched_delete_predicate, + matched_delete_all, + not_matched_insert_updates, + not_matched_insert_predicate, + not_matched_insert_all, + not_matched_by_source_update_updates, + not_matched_by_source_update_predicate, + not_matched_by_source_delete_predicate, + not_matched_by_source_delete_all, + ))] + pub fn merge_execute( + &mut self, + source: PyArrowType, + predicate: String, + source_alias: String, + strict_cast: bool, + writer_properties: Option>, + matched_update_updates: Option>, + matched_update_predicate: Option, + matched_update_all: Option, + matched_delete_predicate: Option, + matched_delete_all: Option, + not_matched_insert_updates: Option>, + not_matched_insert_predicate: Option, + not_matched_insert_all: Option, + not_matched_by_source_update_updates: Option>, + not_matched_by_source_update_predicate: Option, + not_matched_by_source_delete_predicate: Option, + not_matched_by_source_delete_all: Option, + ) -> PyResult { + let ctx = SessionContext::new(); + let source_df = ctx.read_batch(source.0).unwrap(); + + let mut cmd = MergeBuilder::new( + self._table.object_store(), + self._table.state.clone(), + Expression::String(predicate), + source_df, + ) + .with_source_alias(source_alias) + .with_safe_cast(strict_cast); + + if let Some(writer_props) = writer_properties { + let mut properties = WriterProperties::builder(); + let data_page_size_limit = writer_props.get("data_page_size_limit"); + let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); + let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); + let write_batch_size = writer_props.get("write_batch_size"); + let max_row_group_size = writer_props.get("max_row_group_size"); + + if let Some(data_page_size) = data_page_size_limit { + properties = properties.set_data_page_size_limit(data_page_size.clone()); + } + if let Some(dictionary_page_size) = dictionary_page_size_limit { + properties = + properties.set_dictionary_page_size_limit(dictionary_page_size.clone()); + } + if let Some(data_page_row_count) = data_page_row_count_limit { + properties = properties.set_data_page_row_count_limit(data_page_row_count.clone()); + } + if let Some(batch_size) = write_batch_size { + properties = properties.set_write_batch_size(batch_size.clone()); + } + if let Some(row_group_size) = max_row_group_size { + properties = properties.set_max_row_group_size(row_group_size.clone()); + } + cmd = cmd.with_writer_properties(properties.build()); + } + + // MATCHED UPDATE ALL OPTION + // if let Some(mu_update_all) = matched_update_all { + // if let Some(mu_predicate) = matched_update_predicate { + // cmd = cmd.when_matched_update(|update| { + // update + // .predicate(Expression::String(mu_predicate)) + // }).map_err(PythonError::from)?; + // } + // else { + // cmd = cmd.when_matched_update(|update| update).map_err(PythonError::from)?; + // } + // } + // else { + // if let Some(mu_updates) = matched_update_updates { + // if let Some(mu_predicate) = matched_update_predicate { + // cmd = cmd.when_matched_update(|update| { + // update + // .predicate(Expression::String(mu_predicate)) + // .update(mu_updates) + // }).map_err(PythonError::from)?; + // } + // else { + // cmd 
= cmd.when_matched_update(|update| { + // update + // .update(mu_updates) + // }).map_err(PythonError::from)?; + // } + // } + // } + + if let Some(mu_updates) = matched_update_updates { + if let Some(mu_predicate) = matched_update_predicate { + cmd = cmd + .when_matched_update(|update| { + update + .predicate(Expression::String(mu_predicate)) + .update(mu_updates) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_matched_update(|update| update.update(mu_updates)) + .map_err(PythonError::from)?; + } + } + + if let Some(md_delete_all) = matched_delete_all { + cmd = cmd + .when_matched_delete(|delete| delete) + .map_err(PythonError::from)?; + } else { + if let Some(md_predicate) = matched_delete_predicate { + cmd = cmd + .when_matched_delete(|delete| { + delete.predicate(Expression::String(md_predicate)) + }) + .map_err(PythonError::from)?; + } + } + + if let Some(nmi_updates) = not_matched_insert_updates { + if let Some(nmi_predicate) = not_matched_insert_predicate { + cmd = cmd + .when_not_matched_insert(|insert| { + insert + .predicate(Expression::String(nmi_predicate)) + .set(nmi_updates) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_insert(|insert| insert.set(nmi_updates)) + .map_err(PythonError::from)?; + } + } + + if let Some(nmbsu_updates) = not_matched_by_source_update_updates { + if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { + cmd = cmd + .when_not_matched_by_source_update(|update| { + update + .predicate(Expression::String(nmbsu_updates)) + .updates(nmbsu_predicate) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_by_source_update(|update| update.updates(nmbsu_updates)) + .map_err(PythonError::from)?; + } + } + + if let Some(nmbs_delete_all) = not_matched_by_source_delete_all { + cmd = cmd + .when_not_matched_by_source_delete(|delete| delete) + .map_err(PythonError::from)?; + } else { + if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { + cmd = cmd + .when_not_matched_by_source_delete(|delete| { + delete.predicate(Expression::String(nmbs_predicate)) + }) + .map_err(PythonError::from)?; + } + } + + let (table, metrics) = rt()? + .block_on(cmd.into_future()) + .map_err(PythonError::from)?; + self._table.state = table.state; + Ok(serde_json::to_string(&metrics).unwrap()) + } // Run the restore command on the Delta Table: restore table to a given version or datetime #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false))] From 239af35be96fa5bb5ee0904d617df8108ac9a290 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sun, 1 Oct 2023 12:36:16 +0200 Subject: [PATCH 10/35] Add logic to create (col, expr) hashmap for updates/set --- python/deltalake/table.py | 17 ++++----- python/src/lib.rs | 67 ++++++++++++++++++++++++++---------- rust/src/operations/merge.rs | 28 +++++++++++++++ 3 files changed, 85 insertions(+), 27 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index b668539827..f857720f9e 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -777,7 +777,7 @@ def when_matched_update( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_matched_update( ... updates = { ... 
"x": "source.x", @@ -811,7 +811,7 @@ def when_matched_update_all(self, predicate: str | None = None) -> "TableMerger" >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_matched_update( ... updates = { ... "x": "source.x", @@ -847,7 +847,7 @@ def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_matched_delete(predicate = "source.deleted = true") ... .execute() @@ -857,7 +857,7 @@ def when_matched_delete(self, predicate: str | None = None) -> "TableMerger": >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_matched_delete() ... .execute() """ @@ -887,7 +887,7 @@ def when_not_matched_insert( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_not_matched_insert( ... updates = { ... "x": "source.x", @@ -925,7 +925,7 @@ def when_not_matched_insert_all( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... 
.when_not_matched_insert_all().execute() """ if self.not_matched_insert_updates is not None: @@ -996,15 +996,16 @@ def execute(self) -> Dict[str, Any]: source=self.source, predicate=self.predicate, source_alias=self.source_alias, - safe_cast=self.strict_cast, + strict_cast=self.strict_cast, writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, - matched_update_all=self.matched_update_all, + # matched_update_all=self.matched_update_all, matched_delete_predicate=self.matched_delete_predicate, matched_delete_all=self.matched_delete_all, not_matched_insert_updates=self.not_matched_insert_updates, not_matched_insert_predicate=self.not_matched_insert_predicate, + # not_matched_insert_all = self.not_matched_insert_all, not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, diff --git a/python/src/lib.rs b/python/src/lib.rs index e71cc202a6..398c121d02 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -18,7 +18,7 @@ use deltalake::arrow::compute::concat_batches; use deltalake::arrow::record_batch::RecordBatch; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::create_checkpoint; -use deltalake::datafusion::prelude::SessionContext; +use deltalake::datafusion::prelude::{Column, SessionContext}; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; use deltalake::operations::datafusion_utils::Expression; @@ -344,12 +344,12 @@ impl RawDeltaTable { writer_properties, matched_update_updates, matched_update_predicate, - matched_update_all, + // matched_update_all, matched_delete_predicate, matched_delete_all, not_matched_insert_updates, not_matched_insert_predicate, - not_matched_insert_all, + // not_matched_insert_all, not_matched_by_source_update_updates, not_matched_by_source_update_predicate, not_matched_by_source_delete_predicate, @@ -359,25 +359,25 @@ impl RawDeltaTable { &mut self, source: PyArrowType, predicate: String, - source_alias: String, + source_alias: &str, strict_cast: bool, writer_properties: Option>, matched_update_updates: Option>, matched_update_predicate: Option, - matched_update_all: Option, + // matched_update_all: Option, matched_delete_predicate: Option, matched_delete_all: Option, - not_matched_insert_updates: Option>, + not_matched_insert_updates: Option>, not_matched_insert_predicate: Option, - not_matched_insert_all: Option, - not_matched_by_source_update_updates: Option>, + // not_matched_insert_all: Option, + not_matched_by_source_update_updates: Option>, not_matched_by_source_update_predicate: Option, not_matched_by_source_delete_predicate: Option, not_matched_by_source_delete_all: Option, ) -> PyResult { let ctx = SessionContext::new(); let source_df = ctx.read_batch(source.0).unwrap(); - + println!("{}", source_alias); let mut cmd = MergeBuilder::new( self._table.object_store(), self._table.state.clone(), @@ -445,22 +445,31 @@ impl RawDeltaTable { // } if let Some(mu_updates) = matched_update_updates { + let mut mu_updates_mapping: HashMap = HashMap::new(); + + for (col_name, expression) in &mu_updates { + mu_updates_mapping.insert( + Column::from_name(col_name), + Expression::String(expression.clone()), + ); + } + if let Some(mu_predicate) = matched_update_predicate { cmd = cmd 
.when_matched_update(|update| { - update + update // Add iteration over the updates col(key), Expression::String .predicate(Expression::String(mu_predicate)) - .update(mu_updates) + .update_multiple(mu_updates_mapping) }) .map_err(PythonError::from)?; } else { cmd = cmd - .when_matched_update(|update| update.update(mu_updates)) + .when_matched_update(|update| update.update_multiple(mu_updates_mapping)) .map_err(PythonError::from)?; } } - if let Some(md_delete_all) = matched_delete_all { + if let Some(_md_delete_all) = matched_delete_all { cmd = cmd .when_matched_delete(|delete| delete) .map_err(PythonError::from)?; @@ -475,38 +484,58 @@ impl RawDeltaTable { } if let Some(nmi_updates) = not_matched_insert_updates { + let mut nmi_updates_mapping: HashMap = HashMap::new(); + + for (col_name, expression) in &nmi_updates { + nmi_updates_mapping.insert( + Column::from_name(col_name), + Expression::String(expression.clone()), + ); + } + if let Some(nmi_predicate) = not_matched_insert_predicate { cmd = cmd .when_not_matched_insert(|insert| { insert .predicate(Expression::String(nmi_predicate)) - .set(nmi_updates) + .set_multiple(nmi_updates_mapping) }) .map_err(PythonError::from)?; } else { cmd = cmd - .when_not_matched_insert(|insert| insert.set(nmi_updates)) + .when_not_matched_insert(|insert| insert.set_multiple(nmi_updates_mapping)) .map_err(PythonError::from)?; } } if let Some(nmbsu_updates) = not_matched_by_source_update_updates { + let mut nmbsu_updates_mapping: HashMap = HashMap::new(); + + for (col_name, expression) in &nmbsu_updates { + nmbsu_updates_mapping.insert( + Column::from_name(col_name), + Expression::String(expression.clone()), + ); + } + if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { cmd = cmd .when_not_matched_by_source_update(|update| { update - .predicate(Expression::String(nmbsu_updates)) - .updates(nmbsu_predicate) + .predicate(Expression::String(nmbsu_predicate)) + .update_multiple(nmbsu_updates_mapping) }) .map_err(PythonError::from)?; } else { cmd = cmd - .when_not_matched_by_source_update(|update| update.updates(nmbsu_updates)) + .when_not_matched_by_source_update(|update| { + update.update_multiple(nmbsu_updates_mapping) + }) .map_err(PythonError::from)?; } } - if let Some(nmbs_delete_all) = not_matched_by_source_delete_all { + if let Some(_nmbs_delete_all) = not_matched_by_source_delete_all { cmd = cmd .when_not_matched_by_source_delete(|delete| delete) .map_err(PythonError::from)?; diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index b58c3df6c7..2a0f7f0b59 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -389,6 +389,15 @@ impl UpdateBuilder { self.updates.insert(column.into(), expression.into()); self } + + /// Update multiple at ones + pub fn update_multiple>>( + mut self, + mapping: M, + ) -> Self { + self.updates = mapping.into(); + self + } } /// Builder for insert clauses @@ -413,6 +422,15 @@ impl InsertBuilder { self.set.insert(column.into(), expression.into()); self } + + /// Set multiple at ones + pub fn set_multiple>>( + mut self, + mapping: M, + ) -> Self { + self.set = mapping.into(); + self + } } /// Builder for delete clauses @@ -556,12 +574,22 @@ async fn execute( let mut expressions: Vec<(Arc, String)> = Vec::new(); let source_schema = source_count.schema(); + + if let Some(_alias) = source_alias.clone() { + println!("(inside merge operation){}", _alias); + + } + let source_prefix = source_alias .map(|mut s| { s.push('.'); s }) .unwrap_or_default(); + + + 
println!("(inside merge operation) prefix: {}", source_prefix.clone()); + for (i, field) in source_schema.fields().into_iter().enumerate() { expressions.push(( Arc::new(expressions::Column::new(field.name(), i)), From 4a927a92b8744804ae7aa58d63d16028ed9aa43f Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sun, 1 Oct 2023 12:37:30 +0200 Subject: [PATCH 11/35] Remove println --- python/src/lib.rs | 2 +- rust/src/operations/merge.rs | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 398c121d02..4ebbd9ea30 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -377,7 +377,7 @@ impl RawDeltaTable { ) -> PyResult { let ctx = SessionContext::new(); let source_df = ctx.read_batch(source.0).unwrap(); - println!("{}", source_alias); + let mut cmd = MergeBuilder::new( self._table.object_store(), self._table.state.clone(), diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 2a0f7f0b59..14bcaaad32 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -573,12 +573,6 @@ async fn execute( let mut expressions: Vec<(Arc, String)> = Vec::new(); let source_schema = source_count.schema(); - - - if let Some(_alias) = source_alias.clone() { - println!("(inside merge operation){}", _alias); - - } let source_prefix = source_alias .map(|mut s| { @@ -586,9 +580,6 @@ async fn execute( s }) .unwrap_or_default(); - - - println!("(inside merge operation) prefix: {}", source_prefix.clone()); for (i, field) in source_schema.fields().into_iter().enumerate() { expressions.push(( From bcb8d4a3485409038e59920e85b8fde4da05dff1 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Mon, 2 Oct 2023 09:28:31 +0200 Subject: [PATCH 12/35] Use Typing type hints --- python/deltalake/table.py | 56 ++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f857720f9e..01427d30ff 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -678,27 +678,23 @@ class TableMerger: def __init__(self, table: DeltaTable): self.table = table - self.source = None - self.predicate = None - self.source_alias = None - self.strict_cast = False - self.writer_properties = None - self.matched_update_updates = None - self.matched_update_predicate = None - self.matched_update_all = None - self.matched_delete_predicate = None - self.matched_delete_all = None - self.not_matched_insert_updates = None - self.not_matched_insert_predicate = None - self.not_matched_insert_all = None - self.not_matched_by_source_update_updates = None - self.not_matched_by_source_update_predicate = None - self.not_matched_by_source_delete_predicate = None - self.not_matched_by_source_delete_all = None + self.writer_properties: Optional[Dict[str, Optional[int]]] = None + self.matched_update_updates: Optional[Dict[str, str]] = None + self.matched_update_predicate: Optional[str] = None + self.matched_update_all: Optional[bool] = None + self.matched_delete_predicate: Optional[str] = None + self.matched_delete_all: Optional[bool] = None + self.not_matched_insert_updates: Optional[Dict[str, str]] = None + self.not_matched_insert_predicate: Optional[str] = None + self.not_matched_insert_all: Optional[bool] = None + self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None + self.not_matched_by_source_update_predicate: Optional[str] = None + self.not_matched_by_source_delete_predicate: Optional[str] = None + 
self.not_matched_by_source_delete_all: Optional[bool] = None
 
     def __call__(
         self,
-        source: pyarrow.Table | pyarrow.RecordBatch,
+        source: Union[pyarrow.Table, pyarrow.RecordBatch],
         source_alias: str,
         predicate: str,
         strict_cast: bool = True,
@@ -728,11 +724,11 @@ def __call__(
 
     def with_writer_properties(
         self,
-        data_page_size_limit: int | None = None,
-        dictionary_page_size_limit: int | None = None,
-        data_page_row_count_limit: int | None = None,
-        write_batch_size: int | None = None,
-        max_row_group_size: int | None = None,
+        data_page_size_limit: Optional[int] = None,
+        dictionary_page_size_limit: Optional[int] = None,
+        data_page_row_count_limit: Optional[int] = None,
+        write_batch_size: Optional[int] = None,
+        max_row_group_size: Optional[int] = None,
     ) -> "TableMerger":
         """Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html:
 
@@ -757,7 +753,7 @@ def with_writer_properties(
         return self
 
     def when_matched_update(
-        self, updates: dict, predicate: str | None = None
+        self, updates: Dict[str, str], predicate: Optional[str] = None
     ) -> "TableMerger":
         """Update a matched table row based on the rules defined by ``updates``.
         If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
@@ -794,7 +790,7 @@ def when_matched_update(
         self.matched_update_predicate = predicate
         return self
 
-    def when_matched_update_all(self, predicate: str | None = None) -> "TableMerger":
+    def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
         """Update a matched table row based on the rules defined by ``updates``.
         If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
@@ -829,7 +825,7 @@ def when_matched_update_all(self, predicate: str | None = None) -> "TableMerger"
         self.matched_update_predicate = predicate
         return self
 
-    def when_matched_delete(self, predicate: str | None = None) -> "TableMerger":
+    def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
         """Delete a matched row from the table only if the given ``predicate`` (if specified) is
         true for the matched row. If not specified it deletes all matches.
@@ -869,7 +865,7 @@ def when_matched_delete(self, predicate: str | None = None) -> "TableMerger":
         return self
 
     def when_not_matched_insert(
-        self, updates: dict, predicate: str | None = None
+        self, updates: Dict[str, str], predicate: Optional[str] = None
     ) -> "TableMerger":
         """Insert a new row to the target table based on the rules defined by ``updates``. If a
         ``predicate`` is specified, then it must evaluate to true for the new row to be inserted.
@@ -907,7 +903,7 @@ def when_not_matched_insert(
         return self
 
     def when_not_matched_insert_all(
-        self, predicate: str | None = None
+        self, predicate: Optional[str] = None
     ) -> "TableMerger":
         """Insert a new row to the target table based on the rules defined by ``updates``. If a
         ``predicate`` is specified, then it must evaluate to true for the new row to be inserted.
@@ -938,7 +934,7 @@ def when_not_matched_insert_all(
         return self
 
     def when_not_matched_by_source_update(
-        self, updates: dict, predicate: str | None = None
+        self, updates: Dict[str, str], predicate: Optional[str] = None
     ) -> "TableMerger":
         """Update a target row that has no matches in the source based on the rules defined by ``updates``.
         If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
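With the annotations above, each clause takes a plain mapping of target column to SQL expression string; a typed sketch of the same chain (assuming the ``dt`` and ``data`` objects from the docstring examples):

    from typing import Dict

    updates: Dict[str, str] = {"x": "source.x", "y": "source.y"}

    # The chain is identical to the untyped form; the hints only constrain the mapping.
    dt.merge(source=data, source_alias="source", predicate="x = source.x") \
        .when_matched_update(updates=updates) \
        .when_not_matched_insert(updates=updates) \
        .execute()
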
@@ -967,7 +963,7 @@ def when_not_matched_by_source_update( return self def when_not_matched_by_source_delete( - self, predicate: str | None = None + self, predicate: Optional[str] = None ) -> "TableMerger": """Delete a target row that has no matches in the source from the table only if the given ``predicate`` (if specified) is true for the target row. From 1c866b1b9e390fa7e220e27f42027e0b303368f8 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Mon, 2 Oct 2023 09:44:58 +0200 Subject: [PATCH 13/35] Fix rust lints --- python/src/lib.rs | 38 ++++++++++++++++---------------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 4ebbd9ea30..52aeddb403 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -355,6 +355,7 @@ impl RawDeltaTable { not_matched_by_source_delete_predicate, not_matched_by_source_delete_all, ))] + #[allow(clippy::too_many_arguments)] pub fn merge_execute( &mut self, source: PyArrowType, @@ -396,20 +397,19 @@ impl RawDeltaTable { let max_row_group_size = writer_props.get("max_row_group_size"); if let Some(data_page_size) = data_page_size_limit { - properties = properties.set_data_page_size_limit(data_page_size.clone()); + properties = properties.set_data_page_size_limit(*data_page_size); } if let Some(dictionary_page_size) = dictionary_page_size_limit { - properties = - properties.set_dictionary_page_size_limit(dictionary_page_size.clone()); + properties = properties.set_dictionary_page_size_limit(*dictionary_page_size); } if let Some(data_page_row_count) = data_page_row_count_limit { - properties = properties.set_data_page_row_count_limit(data_page_row_count.clone()); + properties = properties.set_data_page_row_count_limit(*data_page_row_count); } if let Some(batch_size) = write_batch_size { - properties = properties.set_write_batch_size(batch_size.clone()); + properties = properties.set_write_batch_size(*batch_size); } if let Some(row_group_size) = max_row_group_size { - properties = properties.set_max_row_group_size(row_group_size.clone()); + properties = properties.set_max_row_group_size(*row_group_size); } cmd = cmd.with_writer_properties(properties.build()); } @@ -473,14 +473,10 @@ impl RawDeltaTable { cmd = cmd .when_matched_delete(|delete| delete) .map_err(PythonError::from)?; - } else { - if let Some(md_predicate) = matched_delete_predicate { - cmd = cmd - .when_matched_delete(|delete| { - delete.predicate(Expression::String(md_predicate)) - }) - .map_err(PythonError::from)?; - } + } else if let Some(md_predicate) = matched_delete_predicate { + cmd = cmd + .when_matched_delete(|delete| delete.predicate(Expression::String(md_predicate))) + .map_err(PythonError::from)?; } if let Some(nmi_updates) = not_matched_insert_updates { @@ -539,14 +535,12 @@ impl RawDeltaTable { cmd = cmd .when_not_matched_by_source_delete(|delete| delete) .map_err(PythonError::from)?; - } else { - if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { - cmd = cmd - .when_not_matched_by_source_delete(|delete| { - delete.predicate(Expression::String(nmbs_predicate)) - }) - .map_err(PythonError::from)?; - } + } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { + cmd = cmd + .when_not_matched_by_source_delete(|delete| { + delete.predicate(Expression::String(nmbs_predicate)) + }) + .map_err(PythonError::from)?; } let (table, metrics) = rt()? 
From 2522d3e2b927ea0f7e2fe51dbfd8f63d9b56d8dc Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Mon, 2 Oct 2023 11:14:29 +0200 Subject: [PATCH 14/35] Add merge when_matched_delete test --- python/deltalake/table.py | 2 +- python/src/lib.rs | 30 ++++++++++----------- python/tests/test_merge.py | 55 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 16 deletions(-) create mode 100644 python/tests/test_merge.py diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 01427d30ff..2e2a34ca49 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -950,7 +950,7 @@ def when_not_matched_by_source_update( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x == source.x') \ + >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ ... .when_not_matched_by_source_update( ... predicate = "y > 3" ... updates = { diff --git a/python/src/lib.rs b/python/src/lib.rs index 52aeddb403..4ab88708d1 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -337,30 +337,30 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (source, predicate, source_alias, - strict_cast, - writer_properties, - matched_update_updates, - matched_update_predicate, + strict_cast = true, + writer_properties = None, + matched_update_updates = None, + matched_update_predicate = None, // matched_update_all, - matched_delete_predicate, - matched_delete_all, - not_matched_insert_updates, - not_matched_insert_predicate, + matched_delete_predicate = None, + matched_delete_all = None, + not_matched_insert_updates = None, + not_matched_insert_predicate = None, // not_matched_insert_all, - not_matched_by_source_update_updates, - not_matched_by_source_update_predicate, - not_matched_by_source_delete_predicate, - not_matched_by_source_delete_all, + not_matched_by_source_update_updates = None, + not_matched_by_source_update_predicate = None, + not_matched_by_source_delete_predicate = None, + not_matched_by_source_delete_all = None, ))] - #[allow(clippy::too_many_arguments)] pub fn merge_execute( &mut self, source: PyArrowType, predicate: String, - source_alias: &str, + source_alias: String, strict_cast: bool, writer_properties: Option>, matched_update_updates: Option>, @@ -457,7 +457,7 @@ impl RawDeltaTable { if let Some(mu_predicate) = matched_update_predicate { cmd = cmd .when_matched_update(|update| { - update // Add iteration over the updates col(key), Expression::String + update .predicate(Expression::String(mu_predicate)) .update_multiple(mu_updates_mapping) }) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py new file mode 100644 index 0000000000..6c2928c331 --- /dev/null +++ b/python/tests/test_merge.py @@ -0,0 +1,55 @@ +import pathlib + +import pyarrow as pa +import pytest + +from deltalake import DeltaTable, write_deltalake + + +@pytest.fixture() +def sample_table(): + nrows = 5 + return pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + + +def test_merge_when_matched_delete_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["5"]), + 
"price": pa.array([1], pa.int64()), + "sold": pa.array([1], pa.int32()), + "deleted": pa.array([False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_matched_delete().execute() + + nrows = 4 + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected From 08497d39042fef42f8fe64354149e611f5f98e54 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Mon, 2 Oct 2023 20:30:19 +0200 Subject: [PATCH 15/35] Move property to merge --- python/deltalake/table.py | 66 +++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 2e2a34ca49..c4592668cc 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -462,11 +462,37 @@ def optimize( ) -> "TableOptimizer": return TableOptimizer(self) - @property def merge( self, + source: Union[pyarrow.Table, pyarrow.RecordBatch], + source_alias: str, + predicate: str, + strict_cast: bool = True, ) -> "TableMerger": - return TableMerger(self) + """Pass the source data which you want to merge on the target delta table, providing a + predicate in SQL query format. You can also specify on what to do when underlying data types do not + match the underlying table. + + Args: + source (pyarrow.Table | pyarrow.RecordBatch): source data + source_alias (str): Alias for the source dataframe + predicate (str): SQL like predicate on how to merge + strict_cast (bool): specify if data types need to be casted strictly or not :default = False + + + Returns: + TableMerger: TableMerger Object + """ + if isinstance(source, pyarrow.Table): + source = source.to_batches()[0] + + return TableMerger( + self, + source=source, + predicate=predicate, + source_alias=source_alias, + strict_cast=not strict_cast, + ) def pyarrow_schema(self) -> pyarrow.Schema: """ @@ -676,8 +702,12 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: class TableMerger: """API for various table MERGE commands.""" - def __init__(self, table: DeltaTable): + def __init__(self, table: DeltaTable, source, source_alias, predicate, strict_cast): self.table = table + self.source = source + self.source_alias = source_alias + self.predicate = predicate + self.strict_cast = strict_cast self.writer_properties: Optional[Dict[str, Optional[int]]] = None self.matched_update_updates: Optional[Dict[str, str]] = None self.matched_update_predicate: Optional[str] = None @@ -692,36 +722,6 @@ def __init__(self, table: DeltaTable): self.not_matched_by_source_delete_predicate: Optional[str] = None self.not_matched_by_source_delete_all: Optional[bool] = None - def __call__( - self, - source: Union[pyarrow.Table, pyarrow.RecordBatch], - source_alias: str, - predicate: str, - strict_cast: bool = True, - ) -> "TableMerger": - """Pass the source data which you want to merge on the target delta table, providing a - predicate in SQL query format. You can also specify on what to do when underlying data types do not - match the underlying table. 
- - Args: - source (pyarrow.Table | pyarrow.RecordBatch): source data - source_alias (str): Alias for the source dataframe - predicate (str): SQL like predicate on how to merge - strict_cast (bool): specify if data types need to be casted strictly or not :default = False - - - Returns: - TableMerger: TableMerger Object - """ - if isinstance(source, pyarrow.Table): - source = source.to_batches()[0] - self.source = source - self.predicate = predicate - self.strict_cast = strict_cast - self.source_alias = source_alias - - return self - def with_writer_properties( self, data_page_size_limit: Optional[int] = None, From 611d93aa61d85834b790d7666db61cb34c2ba05a Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Mon, 2 Oct 2023 20:51:37 +0200 Subject: [PATCH 16/35] Add all test cases --- python/tests/test_merge.py | 311 +++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 6c2928c331..494f911b89 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -53,3 +53,314 @@ def test_merge_when_matched_delete_wo_predicate( assert last_action["operation"] == "MERGE" assert result == expected + + +def test_merge_when_matched_delete_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["5", "4"]), + "price": pa.array([1, 2], pa.int64()), + "sold": pa.array([1, 2], pa.int32()), + "deleted": pa.array([True, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_matched_delete("source.deleted = True").execute() + + nrows = 4 + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_matched_update_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_matched_update({"price": "source.price", "sold": "source.sold"}).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([1, 2, 3, 10, 100], pa.int64()), + "sold": pa.array([1, 2, 3, 10, 20], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_matched_update_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, True]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_matched_update( + 
updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = False", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([1, 2, 3, 10, 5], pa.int64()), + "sold": pa.array([1, 2, 3, 10, 5], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_insert_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + } + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "10"]), + "price": pa.array([1, 2, 3, 4, 5, 100], pa.int64()), + "sold": pa.array([1, 2, 3, 4, 5, 20], pa.int32()), + "deleted": pa.array([False] * 6), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_insert_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + }, + predicate="source.price < 50", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([1, 2, 3, 4, 5, 10], pa.int64()), + "sold": pa.array([1, 2, 3, 4, 5, 10], pa.int32()), + "deleted": pa.array([False] * 6), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_update_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_not_matched_by_source_update( + updates={ + "sold": "10", + } + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([1, 2, 3, 4, 5], pa.int64()), + "sold": pa.array([10,10,10,10,10], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + +def test_merge_when_not_matched_by_source_update_with_predicate( + tmp_path: pathlib.Path, 
sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_not_matched_by_source_update( + updates={ + "sold": "10", + }, + predicate="price > 3" + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([1, 2, 3, 4, 5], pa.int64()), + "sold": pa.array([1,2,3,10,10], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_delete_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, source_alias="source", predicate="id = source.id" + ).when_not_matched_by_source_delete( + predicate="price > 3" + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3"]), + "price": pa.array([1, 2, 3,], pa.int64()), + "sold": pa.array([1,2,3], pa.int32()), + "deleted": pa.array([False] * 3), + } + ) + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + +## Add when_not_matched_by_source_delete_wo_predicate ? \ No newline at end of file From 6a05320ab2a7995d473e51c33ddc4839a8eefaab Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 11:22:06 +0200 Subject: [PATCH 17/35] Fix lint and type hint --- python/deltalake/table.py | 6 +++--- python/tests/test_merge.py | 33 ++++++++++++++++++++------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index c4592668cc..25ebf01b09 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -753,7 +753,7 @@ def with_writer_properties( return self def when_matched_update( - self, updates: dict[str, str], predicate: Optional[str] = None + self, updates: Dict[str, str], predicate: Optional[str] = None ) -> "TableMerger": """Update a matched table row based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. @@ -865,7 +865,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": return self def when_not_matched_insert( - self, updates: dict[str, str], predicate: Optional[str] = None + self, updates: Dict[str, str], predicate: Optional[str] = None ) -> "TableMerger": """Insert a new row to the target table based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. 
@@ -934,7 +934,7 @@ def when_not_matched_insert_all( return self def when_not_matched_by_source_update( - self, updates: dict[str, str], predicate: Optional[str] = None + self, updates: Dict[str, str], predicate: Optional[str] = None ) -> "TableMerger": """Update a target row that has no matches in the source based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 494f911b89..59f18d3630 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -277,7 +277,7 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( { "id": pa.array(["1", "2", "3", "4", "5"]), "price": pa.array([1, 2, 3, 4, 5], pa.int64()), - "sold": pa.array([10,10,10,10,10], pa.int32()), + "sold": pa.array([10, 10, 10, 10, 10], pa.int32()), "deleted": pa.array([False] * 5), } ) @@ -286,7 +286,8 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( assert last_action["operation"] == "MERGE" assert result == expected - + + def test_merge_when_not_matched_by_source_update_with_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): @@ -309,14 +310,14 @@ def test_merge_when_not_matched_by_source_update_with_predicate( updates={ "sold": "10", }, - predicate="price > 3" + predicate="price > 3", ).execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5"]), "price": pa.array([1, 2, 3, 4, 5], pa.int64()), - "sold": pa.array([1,2,3,10,10], pa.int32()), + "sold": pa.array([1, 2, 3, 10, 10], pa.int32()), "deleted": pa.array([False] * 5), } ) @@ -325,8 +326,8 @@ def test_merge_when_not_matched_by_source_update_with_predicate( assert last_action["operation"] == "MERGE" assert result == expected - - + + def test_merge_when_not_matched_by_source_delete_with_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): @@ -345,15 +346,20 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( dt.merge( source=source_table, source_alias="source", predicate="id = source.id" - ).when_not_matched_by_source_delete( - predicate="price > 3" - ).execute() + ).when_not_matched_by_source_delete(predicate="price > 3").execute() expected = pa.table( { "id": pa.array(["1", "2", "3"]), - "price": pa.array([1, 2, 3,], pa.int64()), - "sold": pa.array([1,2,3], pa.int32()), + "price": pa.array( + [ + 1, + 2, + 3, + ], + pa.int64(), + ), + "sold": pa.array([1, 2, 3], pa.int32()), "deleted": pa.array([False] * 3), } ) @@ -362,5 +368,6 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( assert last_action["operation"] == "MERGE" assert result == expected - -## Add when_not_matched_by_source_delete_wo_predicate ? \ No newline at end of file + + +## Add when_not_matched_by_source_delete_wo_predicate ? 
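Patches 14 through 17 settle the Python call pattern that the new test suite exercises: merge() takes the source data, an alias for it, and a SQL-style join predicate; the when_* methods queue clauses on the TableMerger; and execute() commits everything as a single MERGE and returns the metrics dict. A minimal upsert sketch against the API as it stands at this point in the series; the "tmp_table" path, the id/price schema, and the values are illustrative, not taken from the patches:

import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

# Seed an illustrative target table.
target = pa.table({"id": pa.array(["1", "2"]), "price": pa.array([10, 20], pa.int64())})
write_deltalake("tmp_table", target, mode="append")

dt = DeltaTable("tmp_table")
source = pa.table({"id": pa.array(["2", "3"]), "price": pa.array([25, 30], pa.int64())})

# Queue one update clause and one insert clause, then commit both in a single MERGE.
(
    dt.merge(source=source, source_alias="source", predicate="id = source.id")
    .when_matched_update(updates={"price": "source.price"})
    .when_not_matched_insert(updates={"id": "source.id", "price": "source.price"})
    .execute()
)

The tests above follow the same shape and then assert dt.history(1)[0]["operation"] == "MERGE", which is a cheap way to confirm the commit actually went through the merge path.
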
From 7770b4e4947c269f7e6f6af834c3878d9f18defe Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 11:28:36 +0200 Subject: [PATCH 18/35] Add type hints --- python/deltalake/table.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 25ebf01b09..20928f49d9 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -702,7 +702,14 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: class TableMerger: """API for various table MERGE commands.""" - def __init__(self, table: DeltaTable, source, source_alias, predicate, strict_cast): + def __init__( + self, + table: DeltaTable, + source: Union[pyarrow.Table, pyarrow.RecordBatch], + source_alias: str, + predicate: str, + strict_cast: bool = True, + ): self.table = table self.source = source self.source_alias = source_alias From 28734be2a64071d98384577fa71b77563a7c1da5 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 13:27:37 +0200 Subject: [PATCH 19/35] add into --- python/deltalake/table.py | 2 +- python/src/lib.rs | 22 +++++++++++----------- rust/src/operations/merge.rs | 5 +++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 20928f49d9..1b4c1d495a 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -999,7 +999,7 @@ def execute(self) -> Dict[str, Any]: source=self.source, predicate=self.predicate, source_alias=self.source_alias, - strict_cast=self.strict_cast, + safe_cast=self.strict_cast, writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, diff --git a/python/src/lib.rs b/python/src/lib.rs index 4ab88708d1..047f470365 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -341,7 +341,7 @@ impl RawDeltaTable { #[pyo3(signature = (source, predicate, source_alias, - strict_cast = true, + safe_cast = false, writer_properties = None, matched_update_updates = None, matched_update_predicate = None, @@ -361,7 +361,7 @@ impl RawDeltaTable { source: PyArrowType, predicate: String, source_alias: String, - strict_cast: bool, + safe_cast: bool, writer_properties: Option>, matched_update_updates: Option>, matched_update_predicate: Option, @@ -382,11 +382,11 @@ impl RawDeltaTable { let mut cmd = MergeBuilder::new( self._table.object_store(), self._table.state.clone(), - Expression::String(predicate), + predicate, source_df, ) .with_source_alias(source_alias) - .with_safe_cast(strict_cast); + .with_safe_cast(safe_cast); if let Some(writer_props) = writer_properties { let mut properties = WriterProperties::builder(); @@ -419,7 +419,7 @@ impl RawDeltaTable { // if let Some(mu_predicate) = matched_update_predicate { // cmd = cmd.when_matched_update(|update| { // update - // .predicate(Expression::String(mu_predicate)) + // .predicate(mu_predicate) // }).map_err(PythonError::from)?; // } // else { @@ -431,7 +431,7 @@ impl RawDeltaTable { // if let Some(mu_predicate) = matched_update_predicate { // cmd = cmd.when_matched_update(|update| { // update - // .predicate(Expression::String(mu_predicate)) + // .predicate(mu_predicate) // .update(mu_updates) // }).map_err(PythonError::from)?; // } @@ -458,7 +458,7 @@ impl RawDeltaTable { cmd = cmd .when_matched_update(|update| { update - .predicate(Expression::String(mu_predicate)) + .predicate(mu_predicate) .update_multiple(mu_updates_mapping) }) 
.map_err(PythonError::from)?; @@ -475,7 +475,7 @@ impl RawDeltaTable { .map_err(PythonError::from)?; } else if let Some(md_predicate) = matched_delete_predicate { cmd = cmd - .when_matched_delete(|delete| delete.predicate(Expression::String(md_predicate))) + .when_matched_delete(|delete| delete.predicate(md_predicate)) .map_err(PythonError::from)?; } @@ -493,7 +493,7 @@ impl RawDeltaTable { cmd = cmd .when_not_matched_insert(|insert| { insert - .predicate(Expression::String(nmi_predicate)) + .predicate(nmi_predicate) .set_multiple(nmi_updates_mapping) }) .map_err(PythonError::from)?; @@ -518,7 +518,7 @@ impl RawDeltaTable { cmd = cmd .when_not_matched_by_source_update(|update| { update - .predicate(Expression::String(nmbsu_predicate)) + .predicate(nmbsu_predicate) .update_multiple(nmbsu_updates_mapping) }) .map_err(PythonError::from)?; @@ -538,7 +538,7 @@ impl RawDeltaTable { } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { cmd = cmd .when_not_matched_by_source_delete(|delete| { - delete.predicate(Expression::String(nmbs_predicate)) + delete.predicate(nmbs_predicate) }) .map_err(PythonError::from)?; } diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 2566e83239..80627f2278 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -116,12 +116,13 @@ pub struct MergeBuilder { impl MergeBuilder { /// Create a new [`MergeBuilder`] - pub fn new( + pub fn new>( object_store: ObjectStoreRef, snapshot: DeltaTableState, - predicate: Expression, + predicate: E, source: DataFrame, ) -> Self { + let predicate = predicate.into(); Self { predicate, source, From 1fd7909e74b9ed64ce6816515ce81eec144cee1c Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 13:47:18 +0200 Subject: [PATCH 20/35] simplify code --- python/src/lib.rs | 82 +++++++++++++++--------------------- rust/src/lib.rs | 2 +- rust/src/operations/merge.rs | 18 -------- rust/src/operations/mod.rs | 2 +- 4 files changed, 36 insertions(+), 68 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 047f470365..5c369aa223 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -18,10 +18,9 @@ use deltalake::arrow::compute::concat_batches; use deltalake::arrow::record_batch::RecordBatch; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::create_checkpoint; -use deltalake::datafusion::prelude::{Column, SessionContext}; +use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; -use deltalake::operations::datafusion_utils::Expression; use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; @@ -445,26 +444,23 @@ impl RawDeltaTable { // } if let Some(mu_updates) = matched_update_updates { - let mut mu_updates_mapping: HashMap = HashMap::new(); - - for (col_name, expression) in &mu_updates { - mu_updates_mapping.insert( - Column::from_name(col_name), - Expression::String(expression.clone()), - ); - } - if let Some(mu_predicate) = matched_update_predicate { cmd = cmd - .when_matched_update(|update| { - update - .predicate(mu_predicate) - .update_multiple(mu_updates_mapping) + .when_matched_update(|mut update| { + for (col_name, expression) in mu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(mu_predicate) }) .map_err(PythonError::from)?; } else { cmd = cmd - 
.when_matched_update(|update| update.update_multiple(mu_updates_mapping)) + .when_matched_update(|mut update| { + for (col_name, expression) in mu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) .map_err(PythonError::from)?; } } @@ -480,52 +476,44 @@ impl RawDeltaTable { } if let Some(nmi_updates) = not_matched_insert_updates { - let mut nmi_updates_mapping: HashMap = HashMap::new(); - - for (col_name, expression) in &nmi_updates { - nmi_updates_mapping.insert( - Column::from_name(col_name), - Expression::String(expression.clone()), - ); - } - if let Some(nmi_predicate) = not_matched_insert_predicate { cmd = cmd - .when_not_matched_insert(|insert| { - insert - .predicate(nmi_predicate) - .set_multiple(nmi_updates_mapping) + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in nmi_updates { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert.predicate(nmi_predicate) }) .map_err(PythonError::from)?; } else { cmd = cmd - .when_not_matched_insert(|insert| insert.set_multiple(nmi_updates_mapping)) + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in nmi_updates { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert + }) .map_err(PythonError::from)?; } } if let Some(nmbsu_updates) = not_matched_by_source_update_updates { - let mut nmbsu_updates_mapping: HashMap = HashMap::new(); - - for (col_name, expression) in &nmbsu_updates { - nmbsu_updates_mapping.insert( - Column::from_name(col_name), - Expression::String(expression.clone()), - ); - } - if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { cmd = cmd - .when_not_matched_by_source_update(|update| { - update - .predicate(nmbsu_predicate) - .update_multiple(nmbsu_updates_mapping) + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in nmbsu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(nmbsu_predicate) }) .map_err(PythonError::from)?; } else { cmd = cmd - .when_not_matched_by_source_update(|update| { - update.update_multiple(nmbsu_updates_mapping) + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in nmbsu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update }) .map_err(PythonError::from)?; } @@ -537,9 +525,7 @@ impl RawDeltaTable { .map_err(PythonError::from)?; } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { cmd = cmd - .when_not_matched_by_source_delete(|delete| { - delete.predicate(nmbs_predicate) - }) + .when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate)) .map_err(PythonError::from)?; } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9fb6db0b58..af692fd5c9 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -69,7 +69,7 @@ //! 
``` #![deny(warnings)] -// #![deny(missing_docs)] +#![deny(missing_docs)] #![allow(rustdoc::invalid_html_tags)] #[cfg(all(feature = "parquet", feature = "parquet2"))] diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 80627f2278..d5a598c692 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -396,15 +396,6 @@ impl UpdateBuilder { self.updates.insert(column.into(), expression.into()); self } - - /// Update multiple at ones - pub fn update_multiple>>( - mut self, - mapping: M, - ) -> Self { - self.updates = mapping.into(); - self - } } /// Builder for insert clauses @@ -429,15 +420,6 @@ impl InsertBuilder { self.set.insert(column.into(), expression.into()); self } - - /// Set multiple at ones - pub fn set_multiple>>( - mut self, - mapping: M, - ) -> Self { - self.set = mapping.into(); - self - } } /// Builder for delete clauses diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index 09d8c9a6e5..c07b81438b 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -199,7 +199,7 @@ impl AsRef for DeltaOps { } #[cfg(feature = "datafusion")] -pub mod datafusion_utils { +mod datafusion_utils { use std::sync::Arc; use arrow_schema::SchemaRef; From 3c4493cb2773d0e50f13018e67012aa11a8d117a Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 14:00:05 +0200 Subject: [PATCH 21/35] Move fixture --- python/tests/conftest.py | 13 +++++++++++++ python/tests/test_merge.py | 13 ------------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 6ddb68a526..d10de6bc00 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -235,3 +235,16 @@ def existing_table(tmp_path: pathlib.Path, sample_data: pa.Table): path = str(tmp_path) write_deltalake(path, sample_data) return DeltaTable(path) + + +@pytest.fixture() +def sample_table(): + nrows = 5 + return pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 59f18d3630..f823e9a734 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -6,19 +6,6 @@ from deltalake import DeltaTable, write_deltalake -@pytest.fixture() -def sample_table(): - nrows = 5 - return pa.table( - { - "id": pa.array(["1", "2", "3", "4", "5"]), - "price": pa.array(list(range(nrows)), pa.int64()), - "sold": pa.array(list(range(nrows)), pa.int32()), - "deleted": pa.array([False] * nrows), - } - ) - - def test_merge_when_matched_delete_wo_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): From 15af04e1c475f5c64aa55456166e457510127f76 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Tue, 3 Oct 2023 14:00:15 +0200 Subject: [PATCH 22/35] format --- python/tests/test_merge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index f823e9a734..6f0bd2bddc 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -1,7 +1,6 @@ import pathlib import pyarrow as pa -import pytest from deltalake import DeltaTable, write_deltalake From 98a83d6735a5ea7b99d89424c206ef09ec5cb620 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Sat, 7 Oct 2023 13:11:33 +0200 Subject: [PATCH 23/35] Use target_alias, fix tests --- python/deltalake/table.py | 28 ++++++++------ python/tests/test_merge.py | 77 
+++++++++++++++++++------------------- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index e0781fb1ca..d576e1bb07 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -465,8 +465,9 @@ def optimize( def merge( self, source: Union[pyarrow.Table, pyarrow.RecordBatch], - source_alias: str, predicate: str, + source_alias: str = 'source', + target_alias: str = 'target', strict_cast: bool = True, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a @@ -475,8 +476,9 @@ def merge( Args: source (pyarrow.Table | pyarrow.RecordBatch): source data - source_alias (str): Alias for the source dataframe predicate (str): SQL like predicate on how to merge + source_alias (str): Alias for the source table + target_alias (str): Alias for the target table strict_cast (bool): specify if data types need to be casted strictly or not :default = False @@ -491,6 +493,7 @@ def merge( source=source, predicate=predicate, source_alias=source_alias, + target_alias=target_alias, strict_cast=not strict_cast, ) @@ -720,14 +723,16 @@ def __init__( self, table: DeltaTable, source: Union[pyarrow.Table, pyarrow.RecordBatch], - source_alias: str, predicate: str, + source_alias: str, + target_alias: str, strict_cast: bool = True, ): self.table = table self.source = source - self.source_alias = source_alias self.predicate = predicate + self.source_alias = source_alias + self.target_alias = target_alias self.strict_cast = strict_cast self.writer_properties: Optional[Dict[str, Optional[int]]] = None self.matched_update_updates: Optional[Dict[str, str]] = None @@ -794,7 +799,7 @@ def when_matched_update( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_matched_update( ... updates = { ... "x": "source.x", @@ -828,7 +833,7 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_matched_update( ... updates = { ... "x": "source.x", @@ -864,7 +869,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_matched_delete(predicate = "source.deleted = true") ... .execute() @@ -874,7 +879,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_matched_delete() ... 
.execute() """ @@ -904,7 +909,7 @@ def when_not_matched_insert( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_not_matched_insert( ... updates = { ... "x": "source.x", @@ -942,7 +947,7 @@ def when_not_matched_insert_all( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_not_matched_insert_all().execute() """ if self.not_matched_insert_updates is not None: @@ -971,7 +976,7 @@ def when_not_matched_by_source_update( >>> import pyarrow as pa >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, source_alias='source', predicate='x = source.x') \ + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... .when_not_matched_by_source_update( ... predicate = "y > 3" ... updates = { @@ -1013,6 +1018,7 @@ def execute(self) -> Dict[str, Any]: source=self.source, predicate=self.predicate, source_alias=self.source_alias, + target_alias=self.target_alias, safe_cast=self.strict_cast, writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 6f0bd2bddc..cde2ed2e72 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -22,7 +22,7 @@ def test_merge_when_matched_delete_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', ).when_matched_delete().execute() nrows = 4 @@ -34,7 +34,7 @@ def test_merge_when_matched_delete_wo_predicate( "deleted": pa.array([False] * nrows), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -58,8 +58,8 @@ def test_merge_when_matched_delete_with_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" - ).when_matched_delete("source.deleted = True").execute() + source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', + ).when_matched_delete("s.deleted = True").execute() nrows = 4 expected = pa.table( @@ -70,7 +70,7 @@ def test_merge_when_matched_delete_with_predicate( "deleted": pa.array([False] * nrows), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -94,18 +94,18 @@ def test_merge_when_matched_update_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" - ).when_matched_update({"price": "source.price", "sold": "source.sold"}).execute() + source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', + ).when_matched_update({"price": "s.price", "sold": "s.sold"}).execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5"]), - "price": pa.array([1, 2, 3, 10, 100], pa.int64()), - "sold": pa.array([1, 2, 3, 10, 20], pa.int32()), 
+ "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -129,7 +129,7 @@ def test_merge_when_matched_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="id = source.id", ).when_matched_update( updates={"price": "source.price", "sold": "source.sold"}, predicate="source.deleted = False", @@ -138,12 +138,12 @@ def test_merge_when_matched_update_with_predicate( expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5"]), - "price": pa.array([1, 2, 3, 10, 5], pa.int64()), - "sold": pa.array([1, 2, 3, 10, 5], pa.int32()), + "price": pa.array([0, 1, 2, 10, 4], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 4], pa.int32()), "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -167,7 +167,7 @@ def test_merge_when_not_matched_insert_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" ).when_not_matched_insert( updates={ "id": "source.id", @@ -180,12 +180,12 @@ def test_merge_when_not_matched_insert_wo_predicate( expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5", "10"]), - "price": pa.array([1, 2, 3, 4, 5, 100], pa.int64()), - "sold": pa.array([1, 2, 3, 4, 5, 20], pa.int32()), + "price": pa.array([0, 1, 2, 3, 4, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 20], pa.int32()), "deleted": pa.array([False] * 6), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -217,18 +217,18 @@ def test_merge_when_not_matched_insert_with_predicate( "sold": "source.sold", "deleted": "False", }, - predicate="source.price < 50", + predicate="source.price < bigint'50'", ).execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5", "6"]), - "price": pa.array([1, 2, 3, 4, 5, 10], pa.int64()), - "sold": pa.array([1, 2, 3, 4, 5, 10], pa.int32()), + "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 10], pa.int32()), "deleted": pa.array([False] * 6), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -252,22 +252,22 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" ).when_not_matched_by_source_update( updates={ - "sold": "10", + "sold": "int'10'", } ).execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5"]), - "price": pa.array([1, 2, 3, 4, 5], pa.int64()), + "price": pa.array([0, 1, 2, 3, 4], pa.int64()), "sold": pa.array([10, 10, 10, 10, 10], pa.int32()), "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = 
dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -291,12 +291,12 @@ def test_merge_when_not_matched_by_source_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" ).when_not_matched_by_source_update( updates={ - "sold": "10", + "sold": "int'10'", }, - predicate="price > 3", + predicate="price > bigint'3'", ).execute() expected = pa.table( @@ -307,7 +307,7 @@ def test_merge_when_not_matched_by_source_update_with_predicate( "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -331,29 +331,30 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" - ).when_not_matched_by_source_delete(predicate="price > 3").execute() + source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" + ).when_not_matched_by_source_delete(predicate="price > bigint'3'").execute() expected = pa.table( { - "id": pa.array(["1", "2", "3"]), + "id": pa.array(["1", "2", "3", "4"]), "price": pa.array( [ + 0, 1, 2, - 3, + 3 ], pa.int64(), ), - "sold": pa.array([1, 2, 3], pa.int32()), - "deleted": pa.array([False] * 3), + "sold": pa.array([0, 1, 2, 3], pa.int32()), + "deleted": pa.array([False] * 4), } ) - result = dt.to_pyarrow_table() + result = dt.to_pyarrow_table().sort_by([('id','ascending')]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" assert result == expected -## Add when_not_matched_by_source_delete_wo_predicate ? +# ## Add when_not_matched_by_source_delete_wo_predicate ? 
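Patch 23 is the behavioural pivot in the series: merge() now takes a target_alias alongside source_alias (defaulting to "source" and "target"), predicates are expected to qualify columns with those aliases, and the reworked tests pin typed SQL literals (bigint'50', int'10') so comparisons line up with the int64/int32 columns under DataFusion's parser. They also sort results before comparing, since MERGE rewrites files and preserves no row order. A sketch of the aliased call shape, assuming an existing table at the illustrative path "tmp_table" with a string id column and an int64 price column:

import pyarrow as pa
from deltalake import DeltaTable

dt = DeltaTable("tmp_table")
source = pa.table({"id": pa.array(["1"]), "price": pa.array([99], pa.int64())})

(
    dt.merge(
        source=source,
        predicate="t.id = s.id",  # both sides qualified with their aliases
        source_alias="s",
        target_alias="t",
    )
    # Clause predicates take typed literals, mirroring the tests: a bare 3
    # may not compare cleanly against an int64 column.
    .when_not_matched_by_source_delete(predicate="t.price > bigint'3'")
    .execute()
)

# MERGE gives no ordering guarantee, so sort before asserting equality.
result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
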
From de56cec81b386b8c396724bf892ef2dbd6159876 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Sat, 7 Oct 2023 13:32:46 +0200 Subject: [PATCH 24/35] Add passing tests --- python/tests/test_merge.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index cde2ed2e72..096c66171d 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -129,7 +129,7 @@ def test_merge_when_matched_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="id = source.id", + source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id", ).when_matched_update( updates={"price": "source.price", "sold": "source.sold"}, predicate="source.deleted = False", @@ -159,7 +159,7 @@ def test_merge_when_not_matched_insert_wo_predicate( source_table = pa.table( { - "id": pa.array(["4", "10"]), + "id": pa.array(["4", "6"]), "price": pa.array([10, 100], pa.int64()), "sold": pa.array([10, 20], pa.int32()), "deleted": pa.array([False, False]), @@ -167,7 +167,7 @@ def test_merge_when_not_matched_insert_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" ).when_not_matched_insert( updates={ "id": "source.id", @@ -179,7 +179,7 @@ def test_merge_when_not_matched_insert_wo_predicate( expected = pa.table( { - "id": pa.array(["1", "2", "3", "4", "5", "10"]), + "id": pa.array(["1", "2", "3", "4", "5", "6"]), "price": pa.array([0, 1, 2, 3, 4, 100], pa.int64()), "sold": pa.array([0, 1, 2, 3, 4, 20], pa.int32()), "deleted": pa.array([False] * 6), @@ -209,7 +209,7 @@ def test_merge_when_not_matched_insert_with_predicate( ) dt.merge( - source=source_table, source_alias="source", predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" ).when_not_matched_insert( updates={ "id": "source.id", @@ -252,7 +252,7 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" ).when_not_matched_by_source_update( updates={ "sold": "int'10'", @@ -291,19 +291,19 @@ def test_merge_when_not_matched_by_source_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" + source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" ).when_not_matched_by_source_update( updates={ "sold": "int'10'", }, - predicate="price > bigint'3'", + predicate="target.price > bigint'3'", ).execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4", "5"]), - "price": pa.array([1, 2, 3, 4, 5], pa.int64()), - "sold": pa.array([1, 2, 3, 10, 10], pa.int32()), + "price": pa.array([0, 1, 2, 3, 4], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 10], pa.int32()), "deleted": pa.array([False] * 5), } ) @@ -331,8 +331,8 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="id = source.id" - ).when_not_matched_by_source_delete(predicate="price > bigint'3'").execute() + source=source_table, 
source_alias="source", target_alias='target', predicate="target.id = source.id" + ).when_not_matched_by_source_delete(predicate="target.price > bigint'3'").execute() expected = pa.table( { @@ -357,4 +357,4 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( assert result == expected -# ## Add when_not_matched_by_source_delete_wo_predicate ? +# # ## Add when_not_matched_by_source_delete_wo_predicate ? From 70dc65a8c1657a7c30c19fb417f32e1a35f09fc4 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Sat, 7 Oct 2023 13:33:11 +0200 Subject: [PATCH 25/35] formatting --- python/deltalake/table.py | 4 +-- python/src/lib.rs | 2 +- python/tests/test_merge.py | 70 +++++++++++++++++++++++++------------- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index d576e1bb07..15cb1a3b20 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -466,8 +466,8 @@ def merge( self, source: Union[pyarrow.Table, pyarrow.RecordBatch], predicate: str, - source_alias: str = 'source', - target_alias: str = 'target', + source_alias: str = "source", + target_alias: str = "target", strict_cast: bool = True, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a diff --git a/python/src/lib.rs b/python/src/lib.rs index f1d84276aa..59bb601ca3 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -21,8 +21,8 @@ use deltalake::checkpoints::create_checkpoint; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; -use deltalake::operations::merge::MergeBuilder; use deltalake::operations::delete::DeleteBuilder; +use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 096c66171d..f619a23009 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -22,7 +22,10 @@ def test_merge_when_matched_delete_wo_predicate( ) dt.merge( - source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", ).when_matched_delete().execute() nrows = 4 @@ -34,7 +37,7 @@ def test_merge_when_matched_delete_wo_predicate( "deleted": pa.array([False] * nrows), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -58,7 +61,10 @@ def test_merge_when_matched_delete_with_predicate( ) dt.merge( - source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", ).when_matched_delete("s.deleted = True").execute() nrows = 4 @@ -70,7 +76,7 @@ def test_merge_when_matched_delete_with_predicate( "deleted": pa.array([False] * nrows), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -94,7 +100,10 @@ def test_merge_when_matched_update_wo_predicate( ) dt.merge( - source=source_table, predicate="t.id = s.id", source_alias="s", target_alias='t', + 
source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", ).when_matched_update({"price": "s.price", "sold": "s.sold"}).execute() expected = pa.table( @@ -105,7 +114,7 @@ def test_merge_when_matched_update_wo_predicate( "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -129,7 +138,10 @@ def test_merge_when_matched_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id", + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", ).when_matched_update( updates={"price": "source.price", "sold": "source.sold"}, predicate="source.deleted = False", @@ -143,7 +155,7 @@ def test_merge_when_matched_update_with_predicate( "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -167,7 +179,10 @@ def test_merge_when_not_matched_insert_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", ).when_not_matched_insert( updates={ "id": "source.id", @@ -185,7 +200,7 @@ def test_merge_when_not_matched_insert_wo_predicate( "deleted": pa.array([False] * 6), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -209,7 +224,10 @@ def test_merge_when_not_matched_insert_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", ).when_not_matched_insert( updates={ "id": "source.id", @@ -228,7 +246,7 @@ def test_merge_when_not_matched_insert_with_predicate( "deleted": pa.array([False] * 6), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -252,7 +270,10 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", ).when_not_matched_by_source_update( updates={ "sold": "int'10'", @@ -267,7 +288,7 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -291,7 +312,10 @@ def test_merge_when_not_matched_by_source_update_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = 
source.id", ).when_not_matched_by_source_update( updates={ "sold": "int'10'", @@ -307,7 +331,7 @@ def test_merge_when_not_matched_by_source_update_with_predicate( "deleted": pa.array([False] * 5), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" @@ -331,26 +355,24 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( ) dt.merge( - source=source_table, source_alias="source", target_alias='target', predicate="target.id = source.id" + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", ).when_not_matched_by_source_delete(predicate="target.price > bigint'3'").execute() expected = pa.table( { "id": pa.array(["1", "2", "3", "4"]), "price": pa.array( - [ - 0, - 1, - 2, - 3 - ], + [0, 1, 2, 3], pa.int64(), ), "sold": pa.array([0, 1, 2, 3], pa.int32()), "deleted": pa.array([False] * 4), } ) - result = dt.to_pyarrow_table().sort_by([('id','ascending')]) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" From b0de38e9b0c2a526272a6aedfbb62ea3f2a815a8 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Sat, 7 Oct 2023 13:36:52 +0200 Subject: [PATCH 26/35] Make consistent with update --- python/deltalake/table.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 15cb1a3b20..e9b4936d6d 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -468,7 +468,7 @@ def merge( predicate: str, source_alias: str = "source", target_alias: str = "target", - strict_cast: bool = True, + error_on_type_mismatch: bool = True, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a predicate in SQL query format. 
You can also specify what to do when the underlying data types do not @@ -479,7 +479,7 @@ def merge( predicate (str): SQL like predicate on how to merge source_alias (str): Alias for the source table target_alias (str): Alias for the target table - strict_cast (bool): specify if data types need to be casted strictly or not :default = False + error_on_type_mismatch (bool): specify if the merge should raise an error when data types mismatch :default = True Returns: @@ -494,7 +494,7 @@ predicate=predicate, source_alias=source_alias, target_alias=target_alias, - strict_cast=not strict_cast, + safe_cast=not error_on_type_mismatch, ) def pyarrow_schema(self) -> pyarrow.Schema: @@ -726,14 +726,14 @@ def __init__( predicate: str, source_alias: str, target_alias: str, - strict_cast: bool = True, + safe_cast: bool = True, ): self.table = table self.source = source self.predicate = predicate self.source_alias = source_alias self.target_alias = target_alias - self.strict_cast = strict_cast + self.safe_cast = safe_cast self.writer_properties: Optional[Dict[str, Optional[int]]] = None self.matched_update_updates: Optional[Dict[str, str]] = None self.matched_update_predicate: Optional[str] = None @@ -1019,7 +1019,7 @@ def execute(self) -> Dict[str, Any]: predicate=self.predicate, source_alias=self.source_alias, target_alias=self.target_alias, - safe_cast=self.strict_cast, + safe_cast=self.safe_cast, writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, From cc7fd15d38606b603db30d564cb57f6860a288a2 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Thu, 12 Oct 2023 17:27:06 +0200 Subject: [PATCH 27/35] add target_alias --- python/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/src/lib.rs b/python/src/lib.rs index e96593da53..a3eced8a33 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -401,6 +401,7 @@ impl RawDeltaTable { #[pyo3(signature = (source, predicate, source_alias, + target_alias, safe_cast = false, writer_properties = None, matched_update_updates = None, @@ -421,6 +422,7 @@ source: PyArrowType<RecordBatch>, predicate: String, source_alias: String, + target_alias: String, safe_cast: bool, writer_properties: Option<HashMap<String, usize>>, matched_update_updates: Option<HashMap<String, String>>, @@ -446,6 +448,7 @@ source_df, ) .with_source_alias(source_alias) + .with_target_alias(target_alias) .with_safe_cast(safe_cast); if let Some(writer_props) = writer_properties { From 547bcaff865b5c7f25b59824cf3cc97991311dab Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Thu, 12 Oct 2023 22:30:21 +0200 Subject: [PATCH 28/35] Add update_all and insert_all method --- python/deltalake/_internal.pyi | 19 ++++++++ python/deltalake/table.py | 85 ++++++++++++++-------------- python/src/lib.rs | 48 +++++-------------- python/tests/test_merge.py | 79 +++++++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 86 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index d1fc616359..8e6c71d75c 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -95,6 +95,25 @@ class RawDeltaTable: writer_properties: Optional[Dict[str, int]], safe_cast: bool = False, ) -> str: ...
+ def merge_execute( + self, + source: pa.RecordBatch, + predicate: str, + source_alias: Optional[str], + target_alias: Optional[str], + writer_properties: Optional[Dict[str, int | None]], + safe_cast: bool, + matched_update_updates: Optional[Dict[str, str]], + matched_update_predicate: Optional[str], + matched_delete_predicate: Optional[str], + matched_delete_all: Optional[bool], + not_matched_insert_updates: Optional[Dict[str, str]], + not_matched_insert_predicate: Optional[str], + not_matched_by_source_update_updates: Optional[Dict[str, str]], + not_matched_by_source_update_predicate: Optional[str], + not_matched_by_source_delete_predicate: Optional[str], + not_matched_by_source_delete_all: Optional[bool], + ) -> str: ... def get_active_partitions( self, partitions_filters: Optional[FilterType] = None ) -> Any: ... diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 86d2e1fc38..877452b1a2 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -529,8 +529,8 @@ def merge( self, source: Union[pyarrow.Table, pyarrow.RecordBatch], predicate: str, - source_alias: str = "source", - target_alias: str = "target", + source_alias: Optional[str] = None, + target_alias: Optional[str] = None, error_on_type_mismatch: bool = True, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a @@ -790,8 +790,8 @@ def __init__( table: DeltaTable, source: Union[pyarrow.Table, pyarrow.RecordBatch], predicate: str, - source_alias: str, - target_alias: str, + source_alias: Optional[str] = None, + target_alias: Optional[str] = None, safe_cast: bool = True, ): self.table = table @@ -803,12 +803,10 @@ def __init__( self.writer_properties: Optional[Dict[str, Optional[int]]] = None self.matched_update_updates: Optional[Dict[str, str]] = None self.matched_update_predicate: Optional[str] = None - self.matched_update_all: Optional[bool] = None self.matched_delete_predicate: Optional[str] = None self.matched_delete_all: Optional[bool] = None self.not_matched_insert_updates: Optional[Dict[str, str]] = None self.not_matched_insert_predicate: Optional[str] = None - self.not_matched_insert_all: Optional[bool] = None self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None self.not_matched_by_source_update_predicate: Optional[str] = None self.not_matched_by_source_delete_predicate: Optional[str] = None @@ -825,11 +823,11 @@ def with_writer_properties( """Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: Args: - data_page_size_limit (int|None, optional): _description_. Defaults to None. - dictionary_page_size_limit (int|None, optional): _description_. Defaults to None. - data_page_row_count_limit (int|None, optional): _description_. Defaults to None. - write_batch_size (int|None, optional): _description_. Defaults to None. - max_row_group_size (int|None, optional): _description_. Defaults to None. + data_page_size_limit (int|None, optional): Limit the size of each DataPage to this amount in bytes. Defaults to None. + dictionary_page_size_limit (int|None, optional): Limit the size of each dictionary DataPage to this amount in bytes. Defaults to None. + data_page_row_count_limit (int|None, optional): Limit the number of rows in each DataPage. Defaults to None. + write_batch_size (int|None, optional): Split the internal writes into batches of this size. Defaults to None. + max_row_group_size (int|None, optional): Maximum number of rows per row group.
Defaults to None. Returns: TableMerger: TableMerger Object @@ -873,20 +871,14 @@ def when_matched_update( ... } ... ).execute() """ - if self.matched_update_all is not None: - raise DeltaProtocolError( - "You can't specify when_matched_update and when_matched_update_all at the same time. Pick one." - ) - else: - self.matched_update_updates = updates - self.matched_update_predicate = predicate + self.matched_update_updates = updates + self.matched_update_predicate = predicate return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": - """Update a matched table row based on the rules defined by ``updates``. + """Updating all source fields to target field. Source and target need to share the same field names for this to work. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. - Args: predicate (str | None, optional): SQL like predicate on when to update all columns. Defaults to None. @@ -900,21 +892,18 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) >>> dt = DeltaTable("tmp") >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_matched_update( - ... updates = { - ... "x": "source.x", - ... "y": "source.y" - ... } - ... ).execute() + ... .when_matched_update_all().execute() """ - if self.matched_update_updates is not None: - raise DeltaProtocolError( - "You can't specify when_matched_update and when_matched_update_all at the same time. Pick one." - ) - else: - self.matched_update_all = True - self.matched_update_predicate = predicate + src_alias = (self.source_alias + ".") if self.source_alias is not None else "" + trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + + self.matched_update_updates = { + f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + for col in self.source.schema + } + print(self.matched_update_updates) + self.matched_update_predicate = predicate return self def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": @@ -984,24 +973,19 @@ def when_not_matched_insert( ... ).execute() """ - if self.not_matched_insert_all is not None: - raise DeltaProtocolError( - "You can't specify when_not_matched_insert and when_not_matched_insert_all at the same time. Pick one." - ) - else: - self.not_matched_insert_updates = updates - self.not_matched_insert_predicate = predicate + self.not_matched_insert_updates = updates + self.not_matched_insert_predicate = predicate return self def when_not_matched_insert_all( self, predicate: Optional[str] = None ) -> "TableMerger": - """Insert a new row to the target table based on the rules defined by ``updates``. If a - ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + """Insert a new row to the target table, updating all source fields to target field. Source and target + need to share the same field names for this to work. If a ``predicate`` is specified, then it must + evaluate to true for the new row to be inserted. Args: - updates (dict): column mapping (source to target) which to insert predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. Returns: @@ -1016,13 +1000,14 @@ def when_not_matched_insert_all( >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ ... 
.when_not_matched_insert_all().execute() """ - if self.not_matched_insert_updates is not None: - raise DeltaProtocolError( - "You can't specify when_not_matched_insert and when_not_matched_insert_all at the same time. Pick one." - ) - else: - self.not_matched_insert_all = True - self.not_matched_insert_predicate = predicate + + src_alias = (self.source_alias + ".") if self.source_alias is not None else "" + trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + self.not_matched_insert_updates = { + f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + for col in self.source.schema + } + self.not_matched_insert_predicate = predicate return self def when_not_matched_by_source_update( diff --git a/python/src/lib.rs b/python/src/lib.rs index a3eced8a33..a1c6cff4ef 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -400,8 +400,8 @@ impl RawDeltaTable { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (source, predicate, - source_alias, - target_alias, + source_alias = None, + target_alias = None, safe_cast = false, writer_properties = None, matched_update_updates = None, @@ -421,8 +421,8 @@ &mut self, source: PyArrowType<RecordBatch>, predicate: String, - source_alias: String, - target_alias: String, + source_alias: Option<String>, + target_alias: Option<String>, safe_cast: bool, writer_properties: Option<HashMap<String, usize>>, matched_update_updates: Option<HashMap<String, String>>, @@ -447,10 +447,16 @@ predicate, source_df, ) - .with_source_alias(source_alias) - .with_target_alias(target_alias) .with_safe_cast(safe_cast); + if let Some(src_alias) = source_alias { + cmd = cmd.with_source_alias(src_alias); + } + + if let Some(trgt_alias) = target_alias { + cmd = cmd.with_target_alias(trgt_alias); + } + if let Some(writer_props) = writer_properties { let mut properties = WriterProperties::builder(); let data_page_size_limit = writer_props.get("data_page_size_limit"); @@ -477,36 +483,6 @@ cmd = cmd.with_writer_properties(properties.build()); } - // MATCHED UPDATE ALL OPTION - // if let Some(mu_update_all) = matched_update_all { - // if let Some(mu_predicate) = matched_update_predicate { - // cmd = cmd.when_matched_update(|update| { - // update - // .predicate(mu_predicate) - // }).map_err(PythonError::from)?; - // } - // else { - // cmd = cmd.when_matched_update(|update| update).map_err(PythonError::from)?; - // } - // } - // else { - // if let Some(mu_updates) = matched_update_updates { - // if let Some(mu_predicate) = matched_update_predicate { - // cmd = cmd.when_matched_update(|update| { - // update - // .predicate(mu_predicate) - // .update(mu_updates) - // }).map_err(PythonError::from)?; - // } - // else { - // cmd = cmd.when_matched_update(|update| { - // update - // .update(mu_updates) - // }).map_err(PythonError::from)?; - // } - // } - // } - if let Some(mu_updates) = matched_update_updates { if let Some(mu_predicate) = matched_update_predicate { cmd = cmd diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index f619a23009..0213918dc5 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -121,6 +121,44 @@ def test_merge_when_matched_update_wo_predicate( assert result == expected +def test_merge_when_matched_update_all_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), +
"deleted": pa.array([True, True]), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_update_all().execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False, False, False, True, True]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + def test_merge_when_matched_update_with_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): @@ -253,6 +291,47 @@ def test_merge_when_not_matched_insert_with_predicate( assert result == expected +def test_merge_when_not_matched_insert_all_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([None, None], pa.bool_()), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert_all( + predicate="source.price < bigint'50'", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 10], pa.int32()), + "deleted": pa.array([False, False, False, False, False, None]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + + def test_merge_when_not_matched_by_source_update_wo_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): From 5a2e66f0402c7a1fe4daa60688a81450681d4b2c Mon Sep 17 00:00:00 2001 From: Ion Koutsouris Date: Thu, 12 Oct 2023 22:43:29 +0200 Subject: [PATCH 29/35] formatting + update docs --- python/deltalake/table.py | 25 ++++++++++--------------- python/tests/test_merge.py | 1 - rust/src/operations/merge.rs | 2 +- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 877452b1a2..ad227bd618 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -534,7 +534,7 @@ def merge( error_on_type_mismatch: bool = True, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a - predicate in SQL query format. You can also specify on what to do when underlying data types do not + predicate in SQL query like format. You can also specify on what to do when the underlying data types do not match the underlying table. Args: @@ -544,7 +544,6 @@ def merge( target_alias (str): Alias for the target table error_on_type_mismatch (bool): specify if merge will return error if data types are mismatching :default = True - Returns: TableMerger: TableMerger Object """ @@ -848,15 +847,13 @@ def when_matched_update( """Update a matched table row based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. - Args: - updates (dict): column mapping (source to target) which to update + updates (dict): a mapping of column name to update SQL expression. 
predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. Returns: TableMerger: TableMerger Object - Examples: >>> from deltalake import DeltaTable @@ -873,7 +873,7 @@ def when_matched_update( return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": - """Updating all source fields to target field. Source and target need to share the same field names for this to work. + """Updating all source fields to target fields, source and target are required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. Args: @@ -952,7 +949,7 @@ def when_not_matched_insert( ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. Args: - updates (dict): column mapping (source to target) which to insert + updates (dict): a mapping of column name to insert SQL expression. predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. Returns: @@ -981,9 +978,9 @@ def when_not_matched_insert( def when_not_matched_insert_all( self, predicate: Optional[str] = None ) -> "TableMerger": - """Insert a new row to the target table, updating all source fields to target field. Source and target - need to share the same field names for this to work. If a ``predicate`` is specified, then it must - evaluate to true for the new row to be inserted. + """Insert a new row to the target table, updating all source fields to target fields. Source and target are + required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for + the new row to be inserted. Args: predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. @@ -1017,7 +1014,7 @@ def when_not_matched_by_source_update( If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. Args: - updates (dict): column mapping (source to target) which to update + updates (dict): a mapping of column name to update SQL expression. predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. Returns: @@ -1046,8 +1043,7 @@ def when_not_matched_by_source_delete( ``predicate`` (if specified) is true for the target row. Args: - updates (dict): column mapping (source to target) which to update - predicate (str | None, optional): SQL like predicate on when to deleted when not matched by source. Defaults to None. + predicate (str | None, optional): SQL like predicate on when to delete when not matched by source. Defaults to None.
Returns: TableMerger: TableMerger Object @@ -1074,12 +1071,10 @@ def execute(self) -> Dict[str, Any]: writer_properties=self.writer_properties, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, - # matched_update_all=self.matched_update_all, matched_delete_predicate=self.matched_delete_predicate, matched_delete_all=self.matched_delete_all, not_matched_insert_updates=self.not_matched_insert_updates, not_matched_insert_predicate=self.not_matched_insert_predicate, - # not_matched_insert_all = self.not_matched_insert_all, not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 0213918dc5..ac1fd36411 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -331,7 +331,6 @@ def test_merge_when_not_matched_insert_all_with_predicate( assert result == expected - def test_merge_when_not_matched_by_source_update_wo_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 053ec48d64..46a2c540bf 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -60,8 +60,8 @@ use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr use datafusion_physical_expr::{create_physical_expr, expressions, PhysicalExpr}; use futures::future::BoxFuture; use parquet::file::properties::WriterProperties; +use serde::Serialize; use serde_json::{Map, Value}; -use serde::{Serialize}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::commit; From d045ec18a2e142b4368ea3905dda86d78544381f Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 14 Oct 2023 21:43:56 +0200 Subject: [PATCH 30/35] Adjust test cases to include fewer or different columns compared to target --- python/tests/test_merge.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index ac1fd36411..177667a803 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -15,9 +15,7 @@ def test_merge_when_matched_delete_wo_predicate( source_table = pa.table( { "id": pa.array(["5"]), - "price": pa.array([1], pa.int64()), - "sold": pa.array([1], pa.int32()), - "deleted": pa.array([False]), + "weight": pa.array([105], pa.int32()), } ) @@ -54,9 +52,10 @@ def test_merge_when_matched_delete_with_predicate( source_table = pa.table( { "id": pa.array(["5", "4"]), - "price": pa.array([1, 2], pa.int64()), + "weight": pa.array([1, 2], pa.int64()), "sold": pa.array([1, 2], pa.int32()), "deleted": pa.array([True, False]), + "customer": pa.array(['Adam', 'Patrick']) } ) @@ -95,7 +94,6 @@ def test_merge_when_matched_update_wo_predicate( "id": pa.array(["4", "5"]), "price": pa.array([10, 100], pa.int64()), "sold": pa.array([10, 20], pa.int32()), - "deleted": pa.array([False, False]), } ) From 06480ef0bdcaab9fedef65a10e8281ca19b3aed2 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 14 Oct 2023 21:49:00 +0200 Subject: [PATCH 31/35] Add
when_not_matched_by_source_delete_wo_predicate test --- python/tests/test_merge.py | 39 +++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 177667a803..fc08563443 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -55,7 +55,7 @@ def test_merge_when_matched_delete_with_predicate( "weight": pa.array([1, 2], pa.int64()), "sold": pa.array([1, 2], pa.int32()), "deleted": pa.array([True, False]), - "customer": pa.array(['Adam', 'Patrick']) + "customer": pa.array(["Adam", "Patrick"]), } ) @@ -132,7 +132,7 @@ def test_merge_when_matched_update_all_wo_predicate( "price": pa.array([10, 100], pa.int64()), "sold": pa.array([10, 20], pa.int32()), "deleted": pa.array([True, True]), - "weight": pa.array([10,15], pa.int64()), + "weight": pa.array([10, 15], pa.int64()), } ) @@ -456,4 +456,37 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( assert result == expected -# # ## Add when_not_matched_by_source_delete_wo_predicate ? +def test_merge_when_not_matched_by_source_delete_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + {"id": pa.array(["4", "5"]), "weight": pa.array([1.5, 1.6], pa.float64())} + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_delete().execute() + + expected = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array( + [3, 4], + pa.int64(), + ), + "sold": pa.array([3, 4], pa.int32()), + "deleted": pa.array([False] * 2), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected From 0c1af8724f1f4be2c3c9b18b8bcb90cdc6cb8396 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 14 Oct 2023 23:11:36 +0200 Subject: [PATCH 32/35] use RecordBatchReader and fix bug --- python/deltalake/_internal.pyi | 2 +- python/deltalake/table.py | 33 ++++++++++++++++++++++++++++----- python/src/lib.rs | 20 ++++++++++++++------ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 8e6c71d75c..d9adbe7cbc 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -97,7 +97,7 @@ class RawDeltaTable: ) -> str: ...
def merge_execute( self, - source: pa.RecordBatch, + source: pa.RecordBatchReader, predicate: str, source_alias: Optional[str], target_alias: Optional[str], diff --git a/python/deltalake/table.py b/python/deltalake/table.py index ad227bd618..ed831e6dcb 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: import pandas +from ._internal import DeltaDataChecker as _DeltaDataChecker from ._internal import RawDeltaTable from ._util import encode_partition_value from .data_catalog import DataCatalog @@ -527,7 +528,7 @@ def optimize( def merge( self, - source: Union[pyarrow.Table, pyarrow.RecordBatch], + source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader], predicate: str, source_alias: Optional[str] = None, target_alias: Optional[str] = None, @@ -547,12 +548,34 @@ def merge( Returns: TableMerger: TableMerger Object """ - if isinstance(source, pyarrow.Table): - source = source.to_batches()[0] + invariants = self.schema().invariants + checker = _DeltaDataChecker(invariants) + + if isinstance(source, pyarrow.RecordBatchReader): + schema = source.schema + else: + schema = source.schema + + if isinstance(source, pyarrow.RecordBatchReader): + batch_iter = source + elif isinstance(source, pyarrow.RecordBatch): + batch_iter = [source] + elif isinstance(source, pyarrow.Table): + batch_iter = source.to_batches() + else: + batch_iter = source + + def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: + checker.check_batch(batch) + return batch + + data = pyarrow.RecordBatchReader.from_batches( + schema, (validate_batch(batch) for batch in batch_iter) + ) return TableMerger( self, - source=source, + source=data, predicate=predicate, source_alias=source_alias, target_alias=target_alias, @@ -787,7 +810,7 @@ class TableMerger: def __init__( self, table: DeltaTable, - source: Union[pyarrow.Table, pyarrow.RecordBatch], + source: pyarrow.RecordBatchReader, predicate: str, source_alias: Optional[str] = None, target_alias: Optional[str] = None, diff --git a/python/src/lib.rs b/python/src/lib.rs index a1c6cff4ef..e508f6c92e 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,9 +15,13 @@ use std::time::{SystemTime, UNIX_EPOCH}; use arrow::pyarrow::PyArrowType; use chrono::{DateTime, Duration, FixedOffset, Utc}; use deltalake::arrow::compute::concat_batches; +use deltalake::arrow::ffi_stream::ArrowArrayStreamReader; use deltalake::arrow::record_batch::RecordBatch; +use deltalake::arrow::record_batch::RecordBatchReader; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::create_checkpoint; +use deltalake::datafusion::datasource::memory::MemTable; +use deltalake::datafusion::datasource::provider::TableProvider; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; @@ -406,12 +410,10 @@ impl RawDeltaTable { writer_properties = None, matched_update_updates = None, matched_update_predicate = None, - // matched_update_all, matched_delete_predicate = None, matched_delete_all = None, not_matched_insert_updates = None, not_matched_insert_predicate = None, - // not_matched_insert_all, not_matched_by_source_update_updates = None, not_matched_by_source_update_predicate = None, not_matched_by_source_delete_predicate = None, @@ -419,7 +421,7 @@ ))] pub fn merge_execute( &mut self, - source: PyArrowType<RecordBatch>, + source: PyArrowType<ArrowArrayStreamReader>, predicate: String, source_alias: Option<String>, target_alias:
Option<String>, safe_cast: bool, writer_properties: Option<HashMap<String, usize>>, matched_update_updates: Option<HashMap<String, String>>, matched_update_predicate: Option<String>, - // matched_update_all: Option<bool>, matched_delete_predicate: Option<String>, matched_delete_all: Option<bool>, not_matched_insert_updates: Option<HashMap<String, String>>, not_matched_insert_predicate: Option<String>, - // not_matched_insert_all: Option<bool>, not_matched_by_source_update_updates: Option<HashMap<String, String>>, not_matched_by_source_update_predicate: Option<String>, not_matched_by_source_delete_predicate: Option<String>, not_matched_by_source_delete_all: Option<bool>, ) -> PyResult<String> { let ctx = SessionContext::new(); - let source_df = ctx.read_batch(source.0).unwrap(); + let schema = source.0.schema(); + let batches = vec![source + .0 + .into_iter() + .map(|batch| batch.unwrap()) + .collect::<Vec<_>>()]; + let table_provider: Arc<dyn TableProvider> = + Arc::new(MemTable::try_new(schema, batches).unwrap()); + let source_df = ctx.read_table(table_provider).unwrap(); let mut cmd = MergeBuilder::new( self._table.object_store(), From a35f67ac3a679e31520871e64392ae1953ebb94a Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sat, 14 Oct 2023 23:19:25 +0200 Subject: [PATCH 33/35] resolve lint errors --- python/src/lib.rs | 6 +----- python/stubs/pyarrow/__init__.pyi | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 756c549a70..2f46436984 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -439,11 +439,7 @@ impl RawDeltaTable { ) -> PyResult<String> { let ctx = SessionContext::new(); let schema = source.0.schema(); - let batches = vec![source - .0 - .into_iter() - .map(|batch| batch.unwrap()) - .collect::<Vec<_>>()]; + let batches = vec![source.0.map(|batch| batch.unwrap()).collect::<Vec<_>>()]; let table_provider: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(schema, batches).unwrap()); let source_df = ctx.read_table(table_provider).unwrap(); diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index fb01f5796d..f8c9d152aa 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -4,6 +4,7 @@ __version__: str Schema: Any Table: Any RecordBatch: Any +RecordBatchReader: Any Field: Any DataType: Any ListType: Any From c7709cb04db4612fcc053829cedcfd53aa4d6264 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sun, 15 Oct 2023 11:40:56 +0200 Subject: [PATCH 34/35] clean up code --- python/deltalake/table.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index ed831e6dcb..fd8a5b35b8 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -539,7 +539,7 @@ def merge( match the underlying table. Args: - source (pyarrow.Table | pyarrow.RecordBatch): source data + source (pyarrow.Table | pyarrow.RecordBatch | pyarrow.RecordBatchReader): source data predicate (str): SQL like predicate on how to merge source_alias (str): Alias for the source table target_alias (str): Alias for the target table @@ -553,29 +553,28 @@ def merge( if isinstance(source, pyarrow.RecordBatchReader): schema = source.schema - else: - schema = source.schema - - if isinstance(source, pyarrow.RecordBatchReader): - batch_iter = source elif isinstance(source, pyarrow.RecordBatch): - batch_iter = [source] + schema = source.schema + source = [source] elif isinstance(source, pyarrow.Table): - batch_iter = source.to_batches() + schema = source.schema + source = source.to_reader() else: - batch_iter = source + raise TypeError( + f"{type(source).__name__} is not a valid input.
Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source." + ) def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: checker.check_batch(batch) return batch - data = pyarrow.RecordBatchReader.from_batches( - schema, (validate_batch(batch) for batch in batch_iter) + source = pyarrow.RecordBatchReader.from_batches( + schema, (validate_batch(batch) for batch in source) ) return TableMerger( self, - source=data, + source=source, predicate=predicate, source_alias=source_alias, target_alias=target_alias, From ad41a7d69009bd0843cb2129a45e732a68f567d3 Mon Sep 17 00:00:00 2001 From: ion-elgreco Date: Sun, 15 Oct 2023 11:46:13 +0200 Subject: [PATCH 35/35] remove debug print --- python/deltalake/table.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index fd8a5b35b8..e913eb7622 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -921,7 +921,6 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" for col in self.source.schema } - print(self.matched_update_updates) self.matched_update_predicate = predicate return self
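
Taken together, patches 26 through 35 leave the Python bindings with a chainable merge API: DeltaTable.merge() returns a TableMerger, each when_* method returns self, and execute() runs the MERGE and returns its metrics. A minimal end-to-end sketch of that flow, modeled on the tests in this series (the ./tmp path and the sample data are illustrative assumptions, not taken from the patches):

    import pyarrow as pa
    from deltalake import DeltaTable, write_deltalake

    # Seed a target table; the columns mirror the sample_table fixture in the tests.
    target = pa.table(
        {
            "id": pa.array(["1", "2", "3"]),
            "price": pa.array([0, 1, 2], pa.int64()),
            "sold": pa.array([0, 1, 2], pa.int32()),
            "deleted": pa.array([False] * 3),
        }
    )
    write_deltalake("./tmp", target, mode="append")

    # Source rows: "3" matches an existing target row, "4" is new.
    source = pa.table(
        {
            "id": pa.array(["3", "4"]),
            "price": pa.array([20, 30], pa.int64()),
            "sold": pa.array([5, 6], pa.int32()),
            "deleted": pa.array([False, False]),
        }
    )

    dt = DeltaTable("./tmp")
    metrics = (
        dt.merge(
            source=source,
            predicate="t.id = s.id",
            source_alias="s",
            target_alias="t",
        )
        .when_matched_update_all()  # row "3" takes all source values
        .when_not_matched_insert_all()  # row "4" is inserted
        .when_not_matched_by_source_delete()  # rows "1" and "2" are removed
        .execute()
    )
    print(metrics)

execute() hands back the Rust-side MERGE metrics as a dictionary, which can be logged or asserted on, much as the tests check dt.history(1)[0]["operation"] == "MERGE".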
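
Since patch 32, merge() also accepts a pyarrow.RecordBatchReader, with every batch routed through DeltaDataChecker before it crosses into Rust, so a large source can be streamed batch by batch instead of materialized up front. A sketch of that path, under the same illustrative table layout as above:

    import pyarrow as pa
    from deltalake import DeltaTable

    schema = pa.schema(
        [
            ("id", pa.string()),
            ("price", pa.int64()),
            ("sold", pa.int32()),
            ("deleted", pa.bool_()),
        ]
    )

    def batches():
        # Yield the source one batch at a time instead of building a full table.
        yield pa.record_batch(
            [
                pa.array(["3"]),
                pa.array([20], pa.int64()),
                pa.array([5], pa.int32()),
                pa.array([False]),
            ],
            schema=schema,
        )

    reader = pa.RecordBatchReader.from_batches(schema, batches())

    dt = DeltaTable("./tmp")
    dt.merge(
        source=reader,
        predicate="t.id = s.id",
        source_alias="s",
        target_alias="t",
    ).when_matched_update_all().execute()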