diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 97a8ed0480bc..9bc842a12af4 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1449,6 +1449,36 @@ fn from_substrait_type( )?, } } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_field = Arc::new(Field::new( + "key", + from_substrait_type(key_type, dfs_names, name_idx)?, + false, + )); + let value_field = Arc::new(Field::new( + "value", + from_substrait_type(value_type, dfs_names, name_idx)?, + true, + )); + match map.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + Ok(DataType::Map(Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), false)) + }, + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )?, + } + } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6c00a326291a..302f38606bfb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1629,6 +1629,27 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, DataType::Struct(fields) => { let field_types = fields .iter() @@ -2326,6 +2347,19 @@ mod test { Field::new_list_field(DataType::Int32, true).into(), ))?; + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true),