Skip to content

Commit

Permalink
add support for f16
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Nov 2, 2021
1 parent 2cf3178 commit 4b9ff7c
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 24 deletions.
4 changes: 3 additions & 1 deletion arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ serde_json = { version = "1.0", features = ["preserve_order"] }
indexmap = "1.6"
rand = { version = "0.8", optional = true }
num = "0.4"
half = { version = "1.8", optional = true }
csv_crate = { version = "1.1", optional = true, package="csv" }
regex = "1.3"
lazy_static = "1.4"
Expand All @@ -58,9 +59,10 @@ multiversion = "0.6.1"
bitflags = "1.2.1"

[features]
default = ["csv", "ipc", "test_utils"]
default = ["csv", "ipc", "test_utils", "f16"]
avx512 = []
csv = ["csv_crate"]
f16 = ["half"]
ipc = ["flatbuffers"]
simd = ["packed_simd"]
prettyprint = ["comfy-table"]
Expand Down
4 changes: 4 additions & 0 deletions arrow/src/alloc/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// under the License.

use crate::datatypes::DataType;
#[cfg(feature = "f16")]
use half::f16;

/// A type that Rust's custom allocator knows how to allocate and deallocate.
/// This is implemented for all Arrow's physical types whose in-memory representation
Expand Down Expand Up @@ -67,5 +69,7 @@ create_native!(
i64,
DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _)
);
#[cfg(feature = "f16")]
create_native!(f16, DataType::Float16);
create_native!(f32, DataType::Float32);
create_native!(f64, DataType::Float64);
22 changes: 20 additions & 2 deletions arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,16 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef,
DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef,
DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef,
DataType::Float16 => panic!("Float16 datatype not supported"),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
Arc::new(Float16Array::from(data)) as ArrayRef
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef,
DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef,
DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef,
Expand Down Expand Up @@ -393,7 +402,16 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef {
DataType::UInt8 => new_null_sized_array::<UInt8Type>(data_type, length),
DataType::Int16 => new_null_sized_array::<Int16Type>(data_type, length),
DataType::UInt16 => new_null_sized_array::<UInt16Type>(data_type, length),
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
new_null_sized_array::<Float16Type>(data_type, length)
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
DataType::Int32 => new_null_sized_array::<Int32Type>(data_type, length),
DataType::UInt32 => new_null_sized_array::<UInt32Type>(data_type, length),
DataType::Float32 => new_null_sized_array::<Float32Type>(data_type, length),
Expand Down
32 changes: 27 additions & 5 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
//! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates
//! common attributes and operations for Arrow array.

use std::mem;
use std::sync::Arc;

use crate::datatypes::{DataType, IntervalUnit};
use crate::error::Result;
use crate::{bitmap::Bitmap, datatypes::ArrowNativeType};
use crate::{
buffer::{Buffer, MutableBuffer},
util::bit_util,
};
#[cfg(feature = "f16")]
use half::f16;
use std::mem;
use std::sync::Arc;

use super::equal::equal;

Expand Down Expand Up @@ -88,6 +89,19 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<i64>()),
empty_buffer,
],
DataType::Float16 => {
#[cfg(feature = "f16")]
{
[
MutableBuffer::new(capacity * mem::size_of::<f16>()),
empty_buffer,
]
}
#[cfg(not(feature = "f16"))]
{
unimplemented!()
}
}
DataType::Float32 => [
MutableBuffer::new(capacity * mem::size_of::<f32>()),
empty_buffer,
Expand Down Expand Up @@ -177,7 +191,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
],
_ => unreachable!(),
},
DataType::Float16 => unreachable!(),
DataType::FixedSizeList(_, _) | DataType::Struct(_) => {
[empty_buffer, MutableBuffer::new(0)]
}
Expand Down Expand Up @@ -543,7 +556,16 @@ impl ArrayData {
DataType::Dictionary(_, data_type) => {
vec![Self::new_empty(data_type)]
}
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
vec![]
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
};

// Data was constructed correctly above
Expand Down
14 changes: 13 additions & 1 deletion arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,19 @@ fn equal_values(
),
_ => unreachable!(),
},
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
use half::f16;
primitive_equal::<f16>(
lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
)
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
DataType::Map(_, _) => {
list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}
Expand Down
11 changes: 11 additions & 0 deletions arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,17 @@ pub type UInt64Array = PrimitiveArray<UInt64Type>;
/// A [`PrimitiveArray`] whose element type is [`Float16Type`], i.e. values of
/// the `half` crate's `f16` type.
///
/// This alias only exists when the `f16` cargo feature is enabled.
///
/// # Example: Using `collect`
/// ```
/// # #[cfg(feature = "f16")] {
/// # use arrow::array::Float16Array;
/// use half::f16;
/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect();
/// # }
/// ```
#[cfg(feature = "f16")]
pub type Float16Array = PrimitiveArray<Float16Type>;
///
/// # Example: Using `collect`
/// ```
/// # use arrow::array::Float32Array;
/// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect();
/// ```
Expand Down
46 changes: 37 additions & 9 deletions arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use super::{
data::{into_buffers, new_buffers},
ArrayData, ArrayDataBuilder,
};
use crate::array::StringOffsetSizeTrait;
use crate::{
buffer::MutableBuffer,
datatypes::DataType,
error::{ArrowError, Result},
util::bit_util,
};
#[cfg(feature = "f16")]
use half::f16;
use std::mem;

use super::{
data::{into_buffers, new_buffers},
ArrayData, ArrayDataBuilder,
};
use crate::array::StringOffsetSizeTrait;

mod boolean;
mod fixed_binary;
mod list;
Expand Down Expand Up @@ -266,7 +267,16 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
primitive::build_extend::<f16>(array)
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
Expand Down Expand Up @@ -315,7 +325,16 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
},
DataType::Struct(_) => structure::extend_nulls,
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
primitive::extend_nulls::<f16>
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
Expand Down Expand Up @@ -467,7 +486,16 @@ impl<'a> MutableArrayData<'a> {
}
// the dictionary type just appends keys and clones the values.
DataType::Dictionary(_, _) => vec![],
DataType::Float16 => unreachable!(),
DataType::Float16 => {
#[cfg(feature = "f16")]
{
vec![]
}
#[cfg(not(feature = "f16"))]
{
unimplemented!("Float16 datatype not supported")
}
}
DataType::Struct(fields) => match capacities {
Capacities::Struct(capacity, Some(ref child_capacities)) => {
array_capacity = capacity;
Expand Down
14 changes: 12 additions & 2 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use serde_json::{Number, Value};

use super::DataType;
#[cfg(feature = "f16")]
use half::f16;
use serde_json::{Number, Value};

/// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.).
pub trait JsonSerializable: 'static {
Expand Down Expand Up @@ -293,6 +294,13 @@ impl ArrowNativeType for u64 {
}
}

#[cfg(feature = "f16")]
impl JsonSerializable for f16 {
    /// Converts this half-precision float into a JSON number.
    ///
    /// The value is widened to `f64` and rounded to three decimal places
    /// before being handed to `Number::from_f64`. The result is `None`
    /// whenever `Number::from_f64` declines the value (serde_json documents
    /// this for non-finite inputs).
    fn into_json_value(self) -> Option<Value> {
        let widened = f64::from(self);
        let rounded = (widened * 1000.0).round() / 1000.0;
        Number::from_f64(rounded).map(Value::Number)
    }
}

impl JsonSerializable for f32 {
fn into_json_value(self) -> Option<Value> {
Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number)
Expand All @@ -305,6 +313,8 @@ impl JsonSerializable for f64 {
}
}

// Marks `half::f16` as an Arrow native type. The impl body is empty, so every
// trait method falls back to the default provided by `ArrowNativeType`.
// Compiled only when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl ArrowNativeType for f16 {}
impl ArrowNativeType for f32 {}
impl ArrowNativeType for f64 {}

Expand Down
4 changes: 4 additions & 0 deletions arrow/src/datatypes/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// under the License.

use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit};
#[cfg(feature = "f16")]
use half::f16;

// BooleanType is special: its bit-width is not the size of the primitive type, and its `index`
// operation assumes bit-packing.
Expand Down Expand Up @@ -46,6 +48,8 @@ make_type!(UInt8Type, u8, DataType::UInt8);
make_type!(UInt16Type, u16, DataType::UInt16);
make_type!(UInt32Type, u32, DataType::UInt32);
make_type!(UInt64Type, u64, DataType::UInt64);
#[cfg(feature = "f16")]
make_type!(Float16Type, f16, DataType::Float16);
make_type!(Float32Type, f32, DataType::Float32);
make_type!(Float64Type, f64, DataType::Float64);
make_type!(
Expand Down
10 changes: 9 additions & 1 deletion arrow/src/json/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,15 @@ impl Decoder {
DataType::UInt32 => self.read_primitive_list_values::<UInt32Type>(rows),
DataType::UInt64 => self.read_primitive_list_values::<UInt64Type>(rows),
DataType::Float16 => {
    // Decoding Float16 list values from JSON is not implemented, whether or
    // not the `f16` feature is enabled. Surface that as a recoverable
    // `ArrowError` so callers reading untrusted or unexpected schemas get an
    // `Err` instead of a panic from `unimplemented!()`.
    return Err(ArrowError::JsonError("Float16 not supported".to_string()));
}
DataType::Float32 => self.read_primitive_list_values::<Float32Type>(rows),
DataType::Float64 => self.read_primitive_list_values::<Float64Type>(rows),
Expand Down
12 changes: 9 additions & 3 deletions arrow/src/util/data_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,15 @@ pub fn create_random_array(
UInt32 => Arc::new(create_primitive_array::<UInt32Type>(size, null_density)),
UInt64 => Arc::new(create_primitive_array::<UInt64Type>(size, null_density)),
Float16 => {
    // Random Float16 generation is not implemented, with or without the
    // `f16` feature. Return an error rather than panicking so callers can
    // handle the unsupported type through the function's `Result`.
    return Err(ArrowError::NotYetImplemented(
        "Float16 is not implemented".to_string(),
    ));
}
Float32 => Arc::new(create_primitive_array::<Float32Type>(size, null_density)),
Float64 => Arc::new(create_primitive_array::<Float64Type>(size, null_density)),
Expand Down

0 comments on commit 4b9ff7c

Please sign in to comment.