Skip to content

Commit

Permalink
add support for f16 (#888)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored Nov 29, 2021
1 parent 434e3f7 commit f6908bf
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 22 deletions.
1 change: 1 addition & 0 deletions arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ serde_json = { version = "1.0", features = ["preserve_order"] }
indexmap = "1.6"
rand = { version = "0.8", optional = true }
num = "0.4"
half = "1.8"
csv_crate = { version = "1.1", optional = true, package="csv" }
regex = "1.3"
lazy_static = "1.4"
Expand Down
2 changes: 2 additions & 0 deletions arrow/src/alloc/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::datatypes::DataType;
use half::f16;

/// A type that Rust's custom allocator knows how to allocate and deallocate.
/// This is implemented for all Arrow's physical types whose in-memory representation
Expand Down Expand Up @@ -67,5 +68,6 @@ create_native!(
i64,
DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _)
);
create_native!(f16, DataType::Float16);
create_native!(f32, DataType::Float32);
create_native!(f64, DataType::Float64);
4 changes: 2 additions & 2 deletions arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef,
DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef,
DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef,
DataType::Float16 => panic!("Float16 datatype not supported"),
DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef,
DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef,
DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef,
DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef,
Expand Down Expand Up @@ -393,7 +393,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef {
DataType::UInt8 => new_null_sized_array::<UInt8Type>(data_type, length),
DataType::Int16 => new_null_sized_array::<Int16Type>(data_type, length),
DataType::UInt16 => new_null_sized_array::<UInt16Type>(data_type, length),
DataType::Float16 => unreachable!(),
DataType::Float16 => new_null_sized_array::<Float16Type>(data_type, length),
DataType::Int32 => new_null_sized_array::<Int32Type>(data_type, length),
DataType::UInt32 => new_null_sized_array::<UInt32Type>(data_type, length),
DataType::Float32 => new_null_sized_array::<Float32Type>(data_type, length),
Expand Down
17 changes: 10 additions & 7 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
//! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates
//! common attributes and operations for Arrow array.

use std::convert::TryInto;
use std::mem;
use std::sync::Arc;

use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::{bitmap::Bitmap, datatypes::ArrowNativeType};
use crate::{
buffer::{Buffer, MutableBuffer},
util::bit_util,
};
use half::f16;
use std::convert::TryInto;
use std::mem;
use std::sync::Arc;

use super::equal::equal;

Expand Down Expand Up @@ -89,6 +89,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<i64>()),
empty_buffer,
],
DataType::Float16 => [
MutableBuffer::new(capacity * mem::size_of::<f16>()),
empty_buffer,
],
DataType::Float32 => [
MutableBuffer::new(capacity * mem::size_of::<f32>()),
empty_buffer,
Expand Down Expand Up @@ -178,7 +182,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
],
_ => unreachable!(),
},
DataType::Float16 => unreachable!(),
DataType::FixedSizeList(_, _) | DataType::Struct(_) => {
[empty_buffer, MutableBuffer::new(0)]
}
Expand Down Expand Up @@ -319,7 +322,7 @@ impl ArrayData {
buffers: Vec<Buffer>,
child_data: Vec<ArrayData>,
) -> Result<Self> {
// Safetly justification: `validate` is (will be) called below
// Safety justification: `validate` is (will be) called below
let new_self = unsafe {
Self::new_unchecked(
data_type,
Expand Down Expand Up @@ -519,6 +522,7 @@ impl ArrayData {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Date32
Expand Down Expand Up @@ -554,7 +558,6 @@ impl ArrayData {
DataType::Dictionary(_, data_type) => {
vec![Self::new_empty(data_type)]
}
DataType::Float16 => unreachable!(),
};

// Data was constructed correctly above
Expand Down
6 changes: 4 additions & 2 deletions arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ use super::{
GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray,
StringOffsetSizeTrait, StructArray,
};

use crate::{
buffer::Buffer,
datatypes::{ArrowPrimitiveType, DataType, IntervalUnit},
};
use half::f16;

mod boolean;
mod decimal;
Expand Down Expand Up @@ -251,7 +251,9 @@ fn equal_values(
),
_ => unreachable!(),
},
DataType::Float16 => unreachable!(),
DataType::Float16 => primitive_equal::<f16>(
lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
),
DataType::Map(_, _) => {
list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}
Expand Down
8 changes: 8 additions & 0 deletions arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ pub type UInt64Array = PrimitiveArray<UInt64Type>;
/// An array of `half::f16` values (Arrow `DataType::Float16`).
///
/// # Example: Using `collect`
/// ```
/// # use arrow::array::Float16Array;
/// use half::f16;
/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect();
/// ```
pub type Float16Array = PrimitiveArray<Float16Type>;
///
/// # Example: Using `collect`
/// ```
/// # use arrow::array::Float32Array;
/// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect();
/// ```
Expand Down
18 changes: 9 additions & 9 deletions arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
// specific language governing permissions and limitations
// under the License.

use super::{
data::{into_buffers, new_buffers},
ArrayData, ArrayDataBuilder,
};
use crate::array::StringOffsetSizeTrait;
use crate::{
buffer::MutableBuffer,
datatypes::DataType,
error::{ArrowError, Result},
util::bit_util,
};
use half::f16;
use std::mem;

use super::{
data::{into_buffers, new_buffers},
ArrayData, ArrayDataBuilder,
};
use crate::array::StringOffsetSizeTrait;

mod boolean;
mod fixed_binary;
mod list;
Expand Down Expand Up @@ -266,7 +266,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => unreachable!(),
DataType::Float16 => primitive::build_extend::<f16>(array),
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
Expand Down Expand Up @@ -315,7 +315,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
},
DataType::Struct(_) => structure::extend_nulls,
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => unreachable!(),
DataType::Float16 => primitive::extend_nulls::<f16>,
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
Expand Down Expand Up @@ -429,6 +429,7 @@ impl<'a> MutableArrayData<'a> {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Date32
Expand Down Expand Up @@ -467,7 +468,6 @@ impl<'a> MutableArrayData<'a> {
}
// the dictionary type just appends keys and clones the values.
DataType::Dictionary(_, _) => vec![],
DataType::Float16 => unreachable!(),
DataType::Struct(fields) => match capacities {
Capacities::Struct(capacity, Some(ref child_capacities)) => {
array_capacity = capacity;
Expand Down
11 changes: 9 additions & 2 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use serde_json::{Number, Value};

use super::DataType;
use half::f16;
use serde_json::{Number, Value};

/// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.).
pub trait JsonSerializable: 'static {
Expand Down Expand Up @@ -293,6 +293,12 @@ impl ArrowNativeType for u64 {
}
}

impl JsonSerializable for f16 {
    /// Converts this half-precision value into a JSON number.
    ///
    /// The value is widened to `f64` and rounded to three decimal places
    /// before being handed to `Number::from_f64`. Returns `None` whenever
    /// `Number::from_f64` does (presumably for NaN/infinite inputs; confirm
    /// against the serde_json documentation).
    fn into_json_value(self) -> Option<Value> {
        let widened = f64::from(self);
        // Scale up, round to the nearest integer, then scale back down so
        // that only three fractional digits survive.
        let rounded = (widened * 1000.0).round() / 1000.0;
        Number::from_f64(rounded).map(Value::Number)
    }
}

impl JsonSerializable for f32 {
fn into_json_value(self) -> Option<Value> {
Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number)
Expand All @@ -305,6 +311,7 @@ impl JsonSerializable for f64 {
}
}

// Floating-point native types: each impl body is empty, so these types add no
// overrides of their own and rely on whatever the trait itself supplies.
impl ArrowNativeType for f16 {}
impl ArrowNativeType for f32 {}
impl ArrowNativeType for f64 {}

Expand Down
2 changes: 2 additions & 0 deletions arrow/src/datatypes/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit};
use half::f16;

// BooleanType is special: its bit-width is not the size of the primitive type, and its `index`
// operation assumes bit-packing.
Expand Down Expand Up @@ -46,6 +47,7 @@ make_type!(UInt8Type, u8, DataType::UInt8);
make_type!(UInt16Type, u16, DataType::UInt16);
make_type!(UInt32Type, u32, DataType::UInt32);
make_type!(UInt64Type, u64, DataType::UInt64);
make_type!(Float16Type, f16, DataType::Float16);
make_type!(Float32Type, f32, DataType::Float32);
make_type!(Float64Type, f64, DataType::Float64);
make_type!(
Expand Down

0 comments on commit f6908bf

Please sign in to comment.