Skip to content

Commit

Permalink
Integrate the f8e8m0 with Constant and Convert (openvinotoolkit#25105)
Browse files Browse the repository at this point in the history
### Details:
 * Add f16, bf16, f32 <-> f8e8m0 conversion in Convert operator
 * Add f8e8m0 support in Constant operator

### Tickets:
 - [*CVS-141563*](https://jira.devtools.intel.com/browse/CVS-141563)
  • Loading branch information
barnasm1 authored and allnes committed Jun 26, 2024
1 parent 201a1ea commit 33a51b0
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 19 deletions.
61 changes: 61 additions & 0 deletions src/bindings/python/tests/test_graph/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,36 @@ def test_float_to_f8e4m3_constant(ov_type, numpy_dtype):
assert np.allclose(result, target, equal_nan=True)


@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
(Type.f32, np.float32),
(Type.f16, np.float16),
])
def test_float_to_f8e8m0_constant(ov_type, numpy_dtype):
from openvino.runtime import opset12 as opset
import openvino as ov
data = np.array([4.75, 4.5, 5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5,
0.6, 0.7, 0.8, 0.9, 1, -0.0, 1.1, 1.2, 1.3,
1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 448, 512, np.nan], dtype=numpy_dtype)

compressed_const = opset.constant(data, dtype=ov.Type.f8e8m0, name="f8e8m0_constant")
convert = opset.convert(compressed_const, data.dtype)
parameter = opset.parameter(ov.PartialShape([-1]), ov_type)
add_op = opset.add(parameter, convert)
model = ov.Model([add_op], [parameter])

compiled = ov.compile_model(model)
tensor = np.zeros(data.shape, dtype=numpy_dtype)
result = compiled(tensor)[0]

target = [4.0, 4.0, 4.0, 0.0, 0.125, 0.25, 0.25,
0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0,
0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0,
2.0, 2.0, 2.0, 2.0, 512, 512, np.nan]
target = np.array(target, dtype=numpy_dtype)

assert np.allclose(result, target, equal_nan=True)


@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
(Type.f32, np.float32),
(Type.f16, np.float16),
Expand Down Expand Up @@ -535,6 +565,37 @@ def test_float_to_f8e4m3_convert(ov_type, numpy_dtype):
assert np.allclose(result, target, equal_nan=True)


@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
(Type.f32, np.float32),
(Type.f16, np.float16),
])
def test_float_to_f8e8m0_convert(ov_type, numpy_dtype):
from openvino.runtime import opset12 as opset
import openvino as ov
data = np.array([4.75, 4.5, 5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5,
0.6, 0.7, 0.8, 0.9, 1, -0.0, 1.1, 1.2, 1.3,
1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 448, 512, np.nan], dtype=numpy_dtype)

compressed_const = opset.constant(data, dtype=ov_type, name="fx_constant")
convert_to_fp8 = opset.convert(compressed_const, Type.f8e8m0)
convert_back = opset.convert(convert_to_fp8, ov_type)
parameter = opset.parameter(ov.PartialShape([-1]), ov_type)
add_op = opset.add(parameter, convert_back)
model = ov.Model([add_op], [parameter])

compiled = ov.compile_model(model)
tensor = np.zeros(data.shape, dtype=numpy_dtype)
result = compiled(tensor)[0]

target = [4.0, 4.0, 4.0, 0.0, 0.125, 0.25, 0.25,
0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0,
0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0,
2.0, 2.0, 2.0, 2.0, 512, 512, np.nan]
target = np.array(target, dtype=numpy_dtype)

assert np.allclose(result, target, equal_nan=True)


@pytest.mark.parametrize(
("src_dtype"),
[
Expand Down
23 changes: 23 additions & 0 deletions src/core/include/openvino/op/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class OPENVINO_API Constant : public Op {
fill_lp_data<Type_t::f4e2m1>(value);
break;
case Type_t::f8e8m0:
fill_data<Type_t::f8e8m0>(value);
break;
case Type_t::undefined:
case Type_t::dynamic:
OPENVINO_THROW("unsupported type");
Expand Down Expand Up @@ -370,6 +372,9 @@ class OPENVINO_API Constant : public Op {
case Type_t::f4e2m1:
cast_lp_vector<Type_t::f4e2m1>(rc, num_elements_to_cast);
break;
case Type_t::f8e8m0:
cast_vector<Type_t::f8e8m0>(rc, num_elements_to_cast);
break;
default:
OPENVINO_THROW("unsupported type");
}
Expand Down Expand Up @@ -729,6 +734,8 @@ class OPENVINO_API Constant : public Op {
write_lp_buffer<Type_t::f4e2m1>(source);
break;
case Type_t::f8e8m0:
write_buffer<Type_t::f8e8m0>(source);
break;
case Type_t::undefined:
case Type_t::dynamic:
OPENVINO_THROW("unsupported type");
Expand Down Expand Up @@ -865,6 +872,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(u1, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(u1, float)
Expand All @@ -884,6 +892,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(u2, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(u2, float)
Expand All @@ -903,6 +912,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(u3, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(u3, float)
Expand All @@ -922,6 +932,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(u4, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(u4, float)
Expand All @@ -941,6 +952,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(u6, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(u6, float)
Expand All @@ -960,6 +972,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(i4, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(i4, float)
Expand All @@ -979,6 +992,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(nf4, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(nf4, float)
Expand All @@ -998,6 +1012,7 @@ CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, long long)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, unsigned long long)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, float8_e4m3)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, float8_e5m2)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, float8_e8m0)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, float16)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, bfloat16)
CONSTANT_FILL_DATA_SPECIALIZATION(f4e2m1, float)
Expand Down Expand Up @@ -1148,6 +1163,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u1, float)
Expand All @@ -1167,6 +1183,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u2, float)
Expand All @@ -1186,6 +1203,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u3, float)
Expand All @@ -1205,6 +1223,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u4, float)
Expand All @@ -1224,6 +1243,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(u6, float)
Expand All @@ -1243,6 +1263,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(i4, float)
Expand All @@ -1262,6 +1283,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(nf4, float)
Expand All @@ -1281,6 +1303,7 @@ CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, unsigned long long)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, float8_e4m3)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, float8_e5m2)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, float8_e8m0)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, float16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, bfloat16)
CONSTANT_WRITE_BUFFER_SPECIALIZATION(f4e2m1, float)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ constexpr bool is_floating_point() {
using U = typename std::decay<T>::type;
return std::is_floating_point<U>::value || std::is_same<float16, U>::value || std::is_same<bfloat16, U>::value ||
std::is_same<float8_e4m3, U>::value || std::is_same<float8_e5m2, U>::value ||
std::is_same<float4_e2m1, U>::value;
std::is_same<float4_e2m1, U>::value || std::is_same<float8_e8m0, U>::value;
}
} // namespace ov
22 changes: 20 additions & 2 deletions src/core/src/op/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ std::string Constant::convert_value_to_string(size_t index) const {
f8e4m3,
f8e5m2,
string,
f4e2m1>::apply<ValueToString>(get_element_type(), get_data_ptr(), index);
f4e2m1,
f8e8m0>::apply<ValueToString>(get_element_type(), get_data_ptr(), index);
}

size_t Constant::get_byte_size() const {
Expand Down Expand Up @@ -410,7 +411,8 @@ std::vector<std::string> Constant::get_value_strings() const {
u64,
nf4,
string,
f4e2m1>::apply<ValuesToString>(get_element_type(), get_data_ptr(), shape_size(m_shape), out);
f4e2m1,
f8e8m0>::apply<ValuesToString>(get_element_type(), get_data_ptr(), shape_size(m_shape), out);
return out;
}

Expand Down Expand Up @@ -777,6 +779,7 @@ CONSTANT_FILL_DATA(u1, long long)
CONSTANT_FILL_DATA(u1, unsigned long long)
CONSTANT_FILL_DATA(u1, float8_e4m3)
CONSTANT_FILL_DATA(u1, float8_e5m2)
CONSTANT_FILL_DATA(u1, float8_e8m0)
CONSTANT_FILL_DATA(u1, float16)
CONSTANT_FILL_DATA(u1, bfloat16)
CONSTANT_FILL_DATA(u1, float)
Expand All @@ -796,6 +799,7 @@ CONSTANT_FILL_DATA(u2, long long)
CONSTANT_FILL_DATA(u2, unsigned long long)
CONSTANT_FILL_DATA(u2, float8_e4m3)
CONSTANT_FILL_DATA(u2, float8_e5m2)
CONSTANT_FILL_DATA(u2, float8_e8m0)
CONSTANT_FILL_DATA(u2, float16)
CONSTANT_FILL_DATA(u2, bfloat16)
CONSTANT_FILL_DATA(u2, float)
Expand All @@ -815,6 +819,7 @@ CONSTANT_FILL_DATA(u3, long long)
CONSTANT_FILL_DATA(u3, unsigned long long)
CONSTANT_FILL_DATA(u3, float8_e4m3)
CONSTANT_FILL_DATA(u3, float8_e5m2)
CONSTANT_FILL_DATA(u3, float8_e8m0)
CONSTANT_FILL_DATA(u3, float16)
CONSTANT_FILL_DATA(u3, bfloat16)
CONSTANT_FILL_DATA(u3, float)
Expand All @@ -834,6 +839,7 @@ CONSTANT_FILL_DATA(u4, long long)
CONSTANT_FILL_DATA(u4, unsigned long long)
CONSTANT_FILL_DATA(u4, float8_e4m3)
CONSTANT_FILL_DATA(u4, float8_e5m2)
CONSTANT_FILL_DATA(u4, float8_e8m0)
CONSTANT_FILL_DATA(u4, float16)
CONSTANT_FILL_DATA(u4, bfloat16)
CONSTANT_FILL_DATA(u4, float)
Expand All @@ -853,6 +859,7 @@ CONSTANT_FILL_DATA(u6, long long)
CONSTANT_FILL_DATA(u6, unsigned long long)
CONSTANT_FILL_DATA(u6, float8_e4m3)
CONSTANT_FILL_DATA(u6, float8_e5m2)
CONSTANT_FILL_DATA(u6, float8_e8m0)
CONSTANT_FILL_DATA(u6, float16)
CONSTANT_FILL_DATA(u6, bfloat16)
CONSTANT_FILL_DATA(u6, float)
Expand All @@ -872,6 +879,7 @@ CONSTANT_FILL_DATA(i4, long long)
CONSTANT_FILL_DATA(i4, unsigned long long)
CONSTANT_FILL_DATA(i4, float8_e4m3)
CONSTANT_FILL_DATA(i4, float8_e5m2)
CONSTANT_FILL_DATA(i4, float8_e8m0)
CONSTANT_FILL_DATA(i4, float16)
CONSTANT_FILL_DATA(i4, bfloat16)
CONSTANT_FILL_DATA(i4, float)
Expand All @@ -891,6 +899,7 @@ CONSTANT_FILL_DATA(nf4, long long)
CONSTANT_FILL_DATA(nf4, unsigned long long)
CONSTANT_FILL_DATA(nf4, float8_e4m3)
CONSTANT_FILL_DATA(nf4, float8_e5m2)
CONSTANT_FILL_DATA(nf4, float8_e8m0)
CONSTANT_FILL_DATA(nf4, float16)
CONSTANT_FILL_DATA(nf4, bfloat16)
CONSTANT_FILL_DATA(nf4, float)
Expand All @@ -910,6 +919,7 @@ CONSTANT_FILL_DATA(f4e2m1, long long)
CONSTANT_FILL_DATA(f4e2m1, unsigned long long)
CONSTANT_FILL_DATA(f4e2m1, float8_e4m3)
CONSTANT_FILL_DATA(f4e2m1, float8_e5m2)
CONSTANT_FILL_DATA(f4e2m1, float8_e8m0)
CONSTANT_FILL_DATA(f4e2m1, float16)
CONSTANT_FILL_DATA(f4e2m1, bfloat16)
CONSTANT_FILL_DATA(f4e2m1, float)
Expand Down Expand Up @@ -1065,6 +1075,7 @@ CONSTANT_WRITE_BUFFER(u1, long long)
CONSTANT_WRITE_BUFFER(u1, unsigned long long)
CONSTANT_WRITE_BUFFER(u1, float8_e4m3)
CONSTANT_WRITE_BUFFER(u1, float8_e5m2)
CONSTANT_WRITE_BUFFER(u1, float8_e8m0)
CONSTANT_WRITE_BUFFER(u1, float16)
CONSTANT_WRITE_BUFFER(u1, bfloat16)
CONSTANT_WRITE_BUFFER(u1, float)
Expand All @@ -1084,6 +1095,7 @@ CONSTANT_WRITE_BUFFER(u2, long long)
CONSTANT_WRITE_BUFFER(u2, unsigned long long)
CONSTANT_WRITE_BUFFER(u2, float8_e4m3)
CONSTANT_WRITE_BUFFER(u2, float8_e5m2)
CONSTANT_WRITE_BUFFER(u2, float8_e8m0)
CONSTANT_WRITE_BUFFER(u2, float16)
CONSTANT_WRITE_BUFFER(u2, bfloat16)
CONSTANT_WRITE_BUFFER(u2, float)
Expand All @@ -1103,6 +1115,7 @@ CONSTANT_WRITE_BUFFER(u3, long long)
CONSTANT_WRITE_BUFFER(u3, unsigned long long)
CONSTANT_WRITE_BUFFER(u3, float8_e4m3)
CONSTANT_WRITE_BUFFER(u3, float8_e5m2)
CONSTANT_WRITE_BUFFER(u3, float8_e8m0)
CONSTANT_WRITE_BUFFER(u3, float16)
CONSTANT_WRITE_BUFFER(u3, bfloat16)
CONSTANT_WRITE_BUFFER(u3, float)
Expand All @@ -1122,6 +1135,7 @@ CONSTANT_WRITE_BUFFER(u4, long long)
CONSTANT_WRITE_BUFFER(u4, unsigned long long)
CONSTANT_WRITE_BUFFER(u4, float8_e4m3)
CONSTANT_WRITE_BUFFER(u4, float8_e5m2)
CONSTANT_WRITE_BUFFER(u4, float8_e8m0)
CONSTANT_WRITE_BUFFER(u4, float16)
CONSTANT_WRITE_BUFFER(u4, bfloat16)
CONSTANT_WRITE_BUFFER(u4, float)
Expand All @@ -1141,6 +1155,7 @@ CONSTANT_WRITE_BUFFER(u6, long long)
CONSTANT_WRITE_BUFFER(u6, unsigned long long)
CONSTANT_WRITE_BUFFER(u6, float8_e4m3)
CONSTANT_WRITE_BUFFER(u6, float8_e5m2)
CONSTANT_WRITE_BUFFER(u6, float8_e8m0)
CONSTANT_WRITE_BUFFER(u6, float16)
CONSTANT_WRITE_BUFFER(u6, bfloat16)
CONSTANT_WRITE_BUFFER(u6, float)
Expand All @@ -1160,6 +1175,7 @@ CONSTANT_WRITE_BUFFER(i4, long long)
CONSTANT_WRITE_BUFFER(i4, unsigned long long)
CONSTANT_WRITE_BUFFER(i4, float8_e4m3)
CONSTANT_WRITE_BUFFER(i4, float8_e5m2)
CONSTANT_WRITE_BUFFER(i4, float8_e8m0)
CONSTANT_WRITE_BUFFER(i4, float16)
CONSTANT_WRITE_BUFFER(i4, bfloat16)
CONSTANT_WRITE_BUFFER(i4, float)
Expand All @@ -1179,6 +1195,7 @@ CONSTANT_WRITE_BUFFER(nf4, long long)
CONSTANT_WRITE_BUFFER(nf4, unsigned long long)
CONSTANT_WRITE_BUFFER(nf4, float8_e4m3)
CONSTANT_WRITE_BUFFER(nf4, float8_e5m2)
CONSTANT_WRITE_BUFFER(nf4, float8_e8m0)
CONSTANT_WRITE_BUFFER(nf4, float16)
CONSTANT_WRITE_BUFFER(nf4, bfloat16)
CONSTANT_WRITE_BUFFER(nf4, float)
Expand All @@ -1198,6 +1215,7 @@ CONSTANT_WRITE_BUFFER(f4e2m1, long long)
CONSTANT_WRITE_BUFFER(f4e2m1, unsigned long long)
CONSTANT_WRITE_BUFFER(f4e2m1, float8_e4m3)
CONSTANT_WRITE_BUFFER(f4e2m1, float8_e5m2)
CONSTANT_WRITE_BUFFER(f4e2m1, float8_e8m0)
CONSTANT_WRITE_BUFFER(f4e2m1, float16)
CONSTANT_WRITE_BUFFER(f4e2m1, bfloat16)
CONSTANT_WRITE_BUFFER(f4e2m1, float)
Expand Down
Loading

0 comments on commit 33a51b0

Please sign in to comment.