Skip to content

Commit

Permalink
[FP16] Implement unary operations. (#6867)
Browse files Browse the repository at this point in the history
  • Loading branch information
brendandahl authored Aug 27, 2024
1 parent 459bc07 commit 6c2d0e2
Show file tree
Hide file tree
Showing 17 changed files with 493 additions and 70 deletions.
7 changes: 7 additions & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@
("i64x2.extmul_high_i32x4_s", "makeBinary(BinaryOp::ExtMulHighSVecI64x2)"),
("i64x2.extmul_low_i32x4_u", "makeBinary(BinaryOp::ExtMulLowUVecI64x2)"),
("i64x2.extmul_high_i32x4_u", "makeBinary(BinaryOp::ExtMulHighUVecI64x2)"),
("f16x8.abs", "makeUnary(UnaryOp::AbsVecF16x8)"),
("f16x8.neg", "makeUnary(UnaryOp::NegVecF16x8)"),
("f16x8.sqrt", "makeUnary(UnaryOp::SqrtVecF16x8)"),
("f16x8.add", "makeBinary(BinaryOp::AddVecF16x8)"),
("f16x8.sub", "makeBinary(BinaryOp::SubVecF16x8)"),
("f16x8.mul", "makeBinary(BinaryOp::MulVecF16x8)"),
Expand All @@ -460,6 +463,10 @@
("f16x8.max", "makeBinary(BinaryOp::MaxVecF16x8)"),
("f16x8.pmin", "makeBinary(BinaryOp::PMinVecF16x8)"),
("f16x8.pmax", "makeBinary(BinaryOp::PMaxVecF16x8)"),
("f16x8.ceil", "makeUnary(UnaryOp::CeilVecF16x8)"),
("f16x8.floor", "makeUnary(UnaryOp::FloorVecF16x8)"),
("f16x8.trunc", "makeUnary(UnaryOp::TruncVecF16x8)"),
("f16x8.nearest", "makeUnary(UnaryOp::NearestVecF16x8)"),
("f32x4.abs", "makeUnary(UnaryOp::AbsVecF32x4)"),
("f32x4.neg", "makeUnary(UnaryOp::NegVecF32x4)"),
("f32x4.sqrt", "makeUnary(UnaryOp::SqrtVecF32x4)"),
Expand Down
68 changes: 60 additions & 8 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,26 @@ switch (buf[0]) {
switch (buf[1]) {
case '1': {
switch (buf[6]) {
case 'a':
if (op == "f16x8.add"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::AddVecF16x8));
case 'a': {
switch (buf[7]) {
case 'b':
if (op == "f16x8.abs"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::AbsVecF16x8));
return Ok{};
}
goto parse_error;
case 'd':
if (op == "f16x8.add"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::AddVecF16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
case 'c':
if (op == "f16x8.ceil"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::CeilVecF16x8));
return Ok{};
}
goto parse_error;
Expand All @@ -338,6 +355,12 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 'f':
if (op == "f16x8.floor"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::FloorVecF16x8));
return Ok{};
}
goto parse_error;
case 'g': {
switch (buf[7]) {
case 'e':
Expand Down Expand Up @@ -395,12 +418,29 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 'n':
if (op == "f16x8.ne"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::NeVecF16x8));
return Ok{};
case 'n': {
switch (buf[8]) {
case '\0':
if (op == "f16x8.ne"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::NeVecF16x8));
return Ok{};
}
goto parse_error;
case 'a':
if (op == "f16x8.nearest"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::NearestVecF16x8));
return Ok{};
}
goto parse_error;
case 'g':
if (op == "f16x8.neg"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::NegVecF16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
case 'p': {
switch (buf[8]) {
case 'a':
Expand Down Expand Up @@ -432,6 +472,12 @@ switch (buf[0]) {
return Ok{};
}
goto parse_error;
case 'q':
if (op == "f16x8.sqrt"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::SqrtVecF16x8));
return Ok{};
}
goto parse_error;
case 'u':
if (op == "f16x8.sub"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::SubVecF16x8));
Expand All @@ -441,6 +487,12 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 't':
if (op == "f16x8.trunc"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::TruncVecF16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/ir/child-typer.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
case NegVecI16x8:
case NegVecI32x4:
case NegVecI64x2:
case AbsVecF16x8:
case NegVecF16x8:
case SqrtVecF16x8:
case CeilVecF16x8:
case FloorVecF16x8:
case TruncVecF16x8:
case NearestVecF16x8:
case AbsVecF32x4:
case NegVecF32x4:
case SqrtVecF32x4:
Expand Down
7 changes: 7 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case NegVecI64x2:
case AllTrueVecI64x2:
case BitmaskVecI64x2:
case AbsVecF16x8:
case NegVecF16x8:
case SqrtVecF16x8:
case CeilVecF16x8:
case FloorVecF16x8:
case TruncVecF16x8:
case NearestVecF16x8:
case AbsVecF32x4:
case NegVecF32x4:
case SqrtVecF32x4:
Expand Down
7 changes: 7 additions & 0 deletions src/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ class Literal {
Literal extMulHighSI64x2(const Literal& other) const;
Literal extMulLowUI64x2(const Literal& other) const;
Literal extMulHighUI64x2(const Literal& other) const;
Literal absF16x8() const;
Literal negF16x8() const;
Literal sqrtF16x8() const;
Literal addF16x8(const Literal& other) const;
Literal subF16x8(const Literal& other) const;
Literal mulF16x8(const Literal& other) const;
Expand All @@ -626,6 +629,10 @@ class Literal {
Literal maxF16x8(const Literal& other) const;
Literal pminF16x8(const Literal& other) const;
Literal pmaxF16x8(const Literal& other) const;
Literal ceilF16x8() const;
Literal floorF16x8() const;
Literal truncF16x8() const;
Literal nearestF16x8() const;
Literal absF32x4() const;
Literal negF32x4() const;
Literal sqrtF32x4() const;
Expand Down
21 changes: 21 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,27 @@ struct PrintExpressionContents
case BitmaskVecI64x2:
o << "i64x2.bitmask";
break;
case AbsVecF16x8:
o << "f16x8.abs";
break;
case NegVecF16x8:
o << "f16x8.neg";
break;
case SqrtVecF16x8:
o << "f16x8.sqrt";
break;
case CeilVecF16x8:
o << "f16x8.ceil";
break;
case FloorVecF16x8:
o << "f16x8.floor";
break;
case TruncVecF16x8:
o << "f16x8.trunc";
break;
case NearestVecF16x8:
o << "f16x8.nearest";
break;
case AbsVecF32x4:
o << "f32x4.abs";
break;
Expand Down
61 changes: 36 additions & 25 deletions src/tools/fuzzing/fuzzing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3109,31 +3109,42 @@ Expression* TranslateToFuzzReader::makeUnary(Type type) {
case 3:
return buildUnary({SplatVecF64x2, make(Type::f64)});
case 4:
return buildUnary({pick(NotVec128,
// TODO: add additional SIMD instructions
NegVecI8x16,
NegVecI16x8,
NegVecI32x4,
NegVecI64x2,
AbsVecF32x4,
NegVecF32x4,
SqrtVecF32x4,
AbsVecF64x2,
NegVecF64x2,
SqrtVecF64x2,
TruncSatSVecF32x4ToVecI32x4,
TruncSatUVecF32x4ToVecI32x4,
ConvertSVecI32x4ToVecF32x4,
ConvertUVecI32x4ToVecF32x4,
ExtendLowSVecI8x16ToVecI16x8,
ExtendHighSVecI8x16ToVecI16x8,
ExtendLowUVecI8x16ToVecI16x8,
ExtendHighUVecI8x16ToVecI16x8,
ExtendLowSVecI16x8ToVecI32x4,
ExtendHighSVecI16x8ToVecI32x4,
ExtendLowUVecI16x8ToVecI32x4,
ExtendHighUVecI16x8ToVecI32x4),
make(Type::v128)});
return buildUnary(
{pick(FeatureOptions<UnaryOp>()
.add(FeatureSet::SIMD,
NotVec128,
// TODO: add additional SIMD instructions
NegVecI8x16,
NegVecI16x8,
NegVecI32x4,
NegVecI64x2,
AbsVecF32x4,
NegVecF32x4,
SqrtVecF32x4,
AbsVecF64x2,
NegVecF64x2,
SqrtVecF64x2,
TruncSatSVecF32x4ToVecI32x4,
TruncSatUVecF32x4ToVecI32x4,
ConvertSVecI32x4ToVecF32x4,
ConvertUVecI32x4ToVecF32x4,
ExtendLowSVecI8x16ToVecI16x8,
ExtendHighSVecI8x16ToVecI16x8,
ExtendLowUVecI8x16ToVecI16x8,
ExtendHighUVecI8x16ToVecI16x8,
ExtendLowSVecI16x8ToVecI32x4,
ExtendHighSVecI16x8ToVecI32x4,
ExtendLowUVecI16x8ToVecI32x4,
ExtendHighUVecI16x8ToVecI32x4)
.add(FeatureSet::FP16,
AbsVecF16x8,
NegVecF16x8,
SqrtVecF16x8,
CeilVecF16x8,
FloorVecF16x8,
TruncVecF16x8,
NearestVecF16x8)),
make(Type::v128)});
}
WASM_UNREACHABLE("invalid value");
}
Expand Down
7 changes: 7 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,13 @@ enum ASTNodes {
F16x8Splat = 0x120,
F16x8ExtractLane = 0x121,
F16x8ReplaceLane = 0x122,
F16x8Abs = 0x130,
F16x8Neg = 0x131,
F16x8Sqrt = 0x132,
F16x8Ceil = 0x133,
F16x8Floor = 0x134,
F16x8Trunc = 0x135,
F16x8Nearest = 0x136,
F16x8Eq = 0x137,
F16x8Ne = 0x138,
F16x8Lt = 0x139,
Expand Down
14 changes: 14 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,20 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return value.allTrueI64x2();
case BitmaskVecI64x2:
return value.bitmaskI64x2();
case AbsVecF16x8:
return value.absF16x8();
case NegVecF16x8:
return value.negF16x8();
case SqrtVecF16x8:
return value.sqrtF16x8();
case CeilVecF16x8:
return value.ceilF16x8();
case FloorVecF16x8:
return value.floorF16x8();
case TruncVecF16x8:
return value.truncF16x8();
case NearestVecF16x8:
return value.nearestF16x8();
case AbsVecF32x4:
return value.absF32x4();
case NegVecF32x4:
Expand Down
7 changes: 7 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ enum UnaryOp {
NegVecI64x2,
AllTrueVecI64x2,
BitmaskVecI64x2,
AbsVecF16x8,
NegVecF16x8,
SqrtVecF16x8,
CeilVecF16x8,
FloorVecF16x8,
TruncVecF16x8,
NearestVecF16x8,
AbsVecF32x4,
NegVecF32x4,
SqrtVecF32x4,
Expand Down
36 changes: 29 additions & 7 deletions src/wasm/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1842,13 +1842,19 @@ Literal Literal::replaceLaneF64x2(const Literal& other, uint8_t index) const {
return replace<2, &Literal::getLanesF64x2>(*this, other, index);
}

// Identity lane post-processor: hands back the literal it was given,
// unchanged.
static Literal passThrough(const Literal& literal) {
  return literal;
}
// Lane post-processor that yields the result of calling convertF32ToF16()
// on the given literal.
static Literal toFP16(const Literal& literal) {
  Literal converted = literal.convertF32ToF16();
  return converted;
}

// Lane-wise unary helper. Splits `val` into `Lanes` lanes by calling the
// member function `IntoLanes`, applies the member function `UnaryOp` to each
// lane, and constructs the returned Literal from the updated lane array.
// `Convert` post-processes each lane result and defaults to `passThrough`.
// NOTE(review): the UnaryOp parameter line and the loop-body assignment each
// appear twice below (without and with Convert) — this looks like the pre- and
// post-change lines of a diff shown together; confirm against the real file.
template<int Lanes,
LaneArray<Lanes> (Literal::*IntoLanes)() const,
Literal (Literal::*UnaryOp)(void) const>
Literal (Literal::*UnaryOp)(void) const,
Literal (*Convert)(const Literal&) = passThrough>
static Literal unary(const Literal& val) {
LaneArray<Lanes> lanes = (val.*IntoLanes)();
for (size_t i = 0; i < Lanes; ++i) {
lanes[i] = (lanes[i].*UnaryOp)();
lanes[i] = Convert((lanes[i].*UnaryOp)());
}
return Literal(lanes);
}
Expand Down Expand Up @@ -1885,6 +1891,27 @@ Literal Literal::negI32x4() const {
// Applies Literal::neg to each of the two lanes produced by getLanesI64x2().
Literal Literal::negI64x2() const {
  const Literal& self = *this;
  return unary<2, &Literal::getLanesI64x2, &Literal::neg>(self);
}
// Applies Literal::abs to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::absF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::abs, &toFP16>(self);
}
// Applies Literal::neg to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::negF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::neg, &toFP16>(self);
}
// Applies Literal::sqrt to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::sqrtF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::sqrt, &toFP16>(self);
}
// Applies Literal::ceil to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::ceilF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::ceil, &toFP16>(self);
}
// Applies Literal::floor to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::floorF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::floor, &toFP16>(self);
}
// Applies Literal::trunc to each of the eight lanes produced by
// getLanesF16x8(); every lane result is passed through toFP16.
Literal Literal::truncF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::trunc, &toFP16>(self);
}
// Implements the "nearest" operation by applying Literal::nearbyint to each
// of the eight lanes produced by getLanesF16x8(); every lane result is passed
// through toFP16.
Literal Literal::nearestF16x8() const {
  const Literal& self = *this;
  return unary<8, &Literal::getLanesF16x8, &Literal::nearbyint, &toFP16>(self);
}
// Applies Literal::abs to each of the four lanes produced by getLanesF32x4().
Literal Literal::absF32x4() const {
  const Literal& self = *this;
  return unary<4, &Literal::getLanesF32x4, &Literal::abs>(self);
}
Expand Down Expand Up @@ -2271,11 +2298,6 @@ Literal Literal::geF64x2(const Literal& other) const {
other);
}

// Identity lane post-processor: returns the literal it was given, unchanged.
static Literal passThrough(const Literal& literal) { return literal; }
// Lane post-processor: returns the result of convertF32ToF16() on the literal.
static Literal toFP16(const Literal& literal) {
return literal.convertF32ToF16();
}

template<int Lanes,
LaneArray<Lanes> (Literal::*IntoLanes)() const,
Literal (Literal::*BinaryOp)(const Literal&) const,
Expand Down
Loading

0 comments on commit 6c2d0e2

Please sign in to comment.