Skip to content

Commit

Permalink
Fix generate_non_canonical_map_case, fix MapArray equality (#1476)
Browse files Browse the repository at this point in the history
* Revamp list_equal for map type

* Canonicalize schema

* Add nullability and metadata
  • Loading branch information
viirya authored Mar 27, 2022
1 parent 6bf3b3a commit c5442cf
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 24 deletions.
102 changes: 80 additions & 22 deletions arrow/src/array/equal/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use crate::datatypes::DataType;
use crate::{
array::ArrayData,
array::{data::count_nulls, OffsetSizeTrait},
buffer::Buffer,
util::bit_util::get_bit,
};

use super::{equal_range, utils::child_logical_null_buffer};
use super::{
equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
};

fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
// invariant from `base_equal`
Expand Down Expand Up @@ -58,22 +61,47 @@ fn offset_value_equal<T: OffsetSizeTrait>(
lhs_pos: usize,
rhs_pos: usize,
len: usize,
data_type: &DataType,
) -> bool {
let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap();
let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap();
let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos];
let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos];

lhs_len == rhs_len
&& equal_range(
lhs_values,
rhs_values,
lhs_nulls,
rhs_nulls,
lhs_start,
rhs_start,
lhs_len.to_usize().unwrap(),
)
lhs_len == rhs_len && {
match data_type {
DataType::Map(_, _) => {
// Don't use `equal_range` which calls `utils::base_equal` that checks
// struct fields, but we don't enforce struct field names.
equal_nulls(
lhs_values,
rhs_values,
lhs_nulls,
rhs_nulls,
lhs_start,
rhs_start,
lhs_len.to_usize().unwrap(),
) && equal_values(
lhs_values,
rhs_values,
lhs_nulls,
rhs_nulls,
lhs_start,
rhs_start,
lhs_len.to_usize().unwrap(),
)
}
_ => equal_range(
lhs_values,
rhs_values,
lhs_nulls,
rhs_nulls,
lhs_start,
rhs_start,
lhs_len.to_usize().unwrap(),
),
}
}
}

pub(super) fn list_equal<T: OffsetSizeTrait>(
Expand Down Expand Up @@ -131,17 +159,46 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
lengths_equal(
&lhs_offsets[lhs_start..lhs_start + len],
&rhs_offsets[rhs_start..rhs_start + len],
) && equal_range(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
)
) && {
match lhs.data_type() {
DataType::Map(_, _) => {
// Don't use `equal_range` which calls `utils::base_equal` that checks
// struct fields, but we don't enforce struct field names.
equal_nulls(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
) && equal_values(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
)
}
_ => equal_range(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
),
}
}
} else {
// get a ref of the parent null buffer bytes, to use in testing for nullness
let lhs_null_bytes = lhs_nulls.unwrap().as_slice();
Expand All @@ -166,6 +223,7 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
lhs_pos,
rhs_pos,
1,
lhs.data_type(),
)
})
}
Expand Down
30 changes: 29 additions & 1 deletion arrow/src/array/equal/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,35 @@ pub(super) fn equal_nulls(

/// Returns whether two arrays have an equal logical type and equal length —
/// the cheap precondition checked before comparing buffers element-wise.
///
/// For `DataType::Map` the type comparison is structural rather than
/// nominal: the names of the inner `entries` struct's key/value fields are
/// not part of the logical map type, so only their data types, nullability
/// and metadata are compared, together with the map's own `sorted` flag.
/// All other data types fall back to plain `DataType` equality.
///
/// # Panics
///
/// Panics if a `Map` type's child is not a `Struct` with exactly two fields,
/// which would violate the Arrow format invariant for maps.
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
    let equal_type = match (lhs.data_type(), rhs.data_type()) {
        (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => {
            let field_equal = match (l_field.data_type(), r_field.data_type()) {
                (DataType::Struct(l_fields), DataType::Struct(r_fields))
                    if l_fields.len() == 2 && r_fields.len() == 2 =>
                {
                    // Indexing is safe: the guard above established both
                    // structs have exactly two fields (key at 0, value at 1).
                    let l_key_field = &l_fields[0];
                    let r_key_field = &r_fields[0];
                    let l_value_field = &l_fields[1];
                    let r_value_field = &r_fields[1];

                    // We don't enforce the equality of field names
                    let data_type_equal = l_key_field.data_type()
                        == r_key_field.data_type()
                        && l_value_field.data_type() == r_value_field.data_type();
                    let nullability_equal = l_key_field.is_nullable()
                        == r_key_field.is_nullable()
                        && l_value_field.is_nullable() == r_value_field.is_nullable();
                    let metadata_equal = l_key_field.metadata() == r_key_field.metadata()
                        && l_value_field.metadata() == r_value_field.metadata();
                    data_type_equal && nullability_equal && metadata_equal
                }
                _ => panic!("Map type should have 2 fields Struct in its field"),
            };
            field_equal && l_sorted == r_sorted
        }
        (l_data_type, r_data_type) => l_data_type == r_data_type,
    };
    equal_type && lhs.len() == rhs.len()
}

// whether the two memory regions are equal
Expand Down
45 changes: 44 additions & 1 deletion integration-testing/src/bin/arrow-json-integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::Schema;
use arrow::datatypes::{DataType, Field};
use arrow::error::{ArrowError, Result};
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
Expand Down Expand Up @@ -107,6 +109,47 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()>
Ok(())
}

/// Returns a copy of `schema` in which every top-level `Map` field is
/// rewritten to use the canonical child names: the entries struct is named
/// `"entries"` and its two children `"key"` and `"value"`.
///
/// Arrow implementations may choose arbitrary names for a map's entries
/// struct and its key/value fields; those names are not part of the logical
/// type, so they are normalized here before two schemas are compared for
/// equality. Data types, nullability and schema metadata are preserved.
///
/// NOTE(review): only fields at the top level of the schema are rewritten;
/// a map nested inside a list or struct is left untouched — confirm this is
/// sufficient for the integration-test schemas being validated.
///
/// # Panics
///
/// Panics if a `Map` field's child is not a `Struct` with exactly two
/// fields, which would violate the Arrow format invariant for maps.
fn canonicalize_schema(schema: &Schema) -> Schema {
    let fields = schema
        .fields()
        .iter()
        .map(|field| match field.data_type() {
            DataType::Map(child_field, sorted) => match child_field.data_type() {
                DataType::Struct(fields) if fields.len() == 2 => {
                    // Indexing is safe: the guard above established exactly
                    // two fields (key at index 0, value at index 1).
                    let first_field = &fields[0];
                    let key_field = Field::new(
                        "key",
                        first_field.data_type().clone(),
                        first_field.is_nullable(),
                    );
                    let second_field = &fields[1];
                    let value_field = Field::new(
                        "value",
                        second_field.data_type().clone(),
                        second_field.is_nullable(),
                    );

                    // Rebuild the map with canonical names, keeping the
                    // original nullability and `sorted` flag intact.
                    let struct_type = DataType::Struct(vec![key_field, value_field]);
                    let child_field =
                        Field::new("entries", struct_type, child_field.is_nullable());

                    Field::new(
                        field.name().as_str(),
                        DataType::Map(Box::new(child_field), *sorted),
                        field.is_nullable(),
                    )
                }
                _ => panic!(
                    "The child field of Map type should be Struct type with 2 fields."
                ),
            },
            _ => field.clone(),
        })
        .collect::<Vec<_>>();

    Schema::new(fields).with_metadata(schema.metadata().clone())
}

fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
if verbose {
eprintln!("Validating {} and {}", arrow_name, json_name);
Expand All @@ -121,7 +164,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
let arrow_schema = arrow_reader.schema().as_ref().to_owned();

// compare schemas
if json_file.schema != arrow_schema {
if canonicalize_schema(&json_file.schema) != canonicalize_schema(&arrow_schema) {
return Err(ArrowError::ComputeError(format!(
"Schemas do not match. JSON: {:?}. Arrow: {:?}",
json_file.schema, arrow_schema
Expand Down

0 comments on commit c5442cf

Please sign in to comment.