Skip to content

Commit

Permalink
Remove verify from bytecode now that is is on attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Jan 31, 2023
1 parent c5b7a11 commit 2162f56
Showing 1 changed file with 0 additions and 37 deletions.
37 changes: 0 additions & 37 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,18 +963,6 @@ void VhloBytecodeInterface::write(TypeExtensionsV1Attr attr,
// Forked Attributes
//===----------------------------------------------------------------------===//

namespace {
template <typename AttrOrType>
void verifyFromVhlo(AttrOrType val) {
if (val.getDialect().getNamespace() !=
vhlo::VhloDialect::getDialectNamespace()) {
LLVM_DEBUG(llvm::dbgs() << "Not vhlo: " << val << '\n');
llvm::report_fatal_error(
"All types and attributes must be VHLO for bytecode.");
}
}
} // namespace

//===----------------------------------------------------------------------===//
// ArrayV1Attr

Expand All @@ -983,14 +971,11 @@ ArrayV1Attr VhloBytecodeInterface::readArrayV1Attr(
LOG_READ_CALL;
SmallVector<Attribute> elements;
if (failed(reader.readAttributes(elements))) return ArrayV1Attr();

llvm::for_each(elements, verifyFromVhlo<Attribute>);
return ArrayV1Attr::get(getContext(), elements);
}

void VhloBytecodeInterface::write(ArrayV1Attr attr,
DialectBytecodeWriter &writer) const {
llvm::for_each(attr.getValue(), verifyFromVhlo<Attribute>);
writer.writeVarInt(vhlo_encoding::kArrayAttr);
writer.writeAttributes(attr.getValue());
}
Expand All @@ -1006,13 +991,11 @@ VhloBytecodeInterface::readDenseIntOrFPElementsV1Attr(
ArrayRef<char> blob;
if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
return DenseIntOrFPElementsV1Attr();
verifyFromVhlo(type);
return DenseIntOrFPElementsV1Attr::get(getContext(), type, blob);
}

void VhloBytecodeInterface::write(DenseIntOrFPElementsV1Attr attr,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(attr.getType());
writer.writeVarInt(vhlo_encoding::kDenseIntOrFPElementsAttr);
writer.writeType(attr.getType());
writer.writeOwnedBlob(attr.getRawData());
Expand All @@ -1037,7 +1020,6 @@ FloatV1Attr VhloBytecodeInterface::readFloatV1Attr(
LOG_READ_CALL;
Type type;
if (failed(reader.readType(type))) return FloatV1Attr();
verifyFromVhlo(type);

FailureOr<APFloat> value =
reader.readAPFloatWithKnownSemantics(getFloatSemantics(type));
Expand All @@ -1048,7 +1030,6 @@ FloatV1Attr VhloBytecodeInterface::readFloatV1Attr(

void VhloBytecodeInterface::write(FloatV1Attr attr,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(attr.getType());
writer.writeVarInt(vhlo_encoding::kFloatAttr);
writer.writeType(attr.getType());
writer.writeAPFloatWithKnownSemantics(attr.getValue());
Expand All @@ -1062,13 +1043,11 @@ FlatSymbolRefV1Attr VhloBytecodeInterface::readFlatSymbolRefV1Attr(
LOG_READ_CALL;
Attribute rootReference;
if (failed(reader.readAttribute(rootReference))) return FlatSymbolRefV1Attr();
verifyFromVhlo(rootReference);
return FlatSymbolRefV1Attr::get(getContext(), rootReference);
}

void VhloBytecodeInterface::write(FlatSymbolRefV1Attr attr,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(attr.getRootReference());
writer.writeVarInt(vhlo_encoding::kFlatSymbolRefAttr);
writer.writeAttribute(attr.getRootReference());
}
Expand All @@ -1093,7 +1072,6 @@ IntegerV1Attr VhloBytecodeInterface::readIntegerV1Attr(
LOG_READ_CALL;
Type type;
if (failed(reader.readType(type))) return IntegerV1Attr();
verifyFromVhlo(type);

// Extract the value storage width from the type.
unsigned bitWidth;
Expand All @@ -1110,7 +1088,6 @@ IntegerV1Attr VhloBytecodeInterface::readIntegerV1Attr(

void VhloBytecodeInterface::write(IntegerV1Attr attr,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(attr.getType());
writer.writeVarInt(vhlo_encoding::kIntegerAttr);
writer.writeType(attr.getType());
writer.writeAPIntWithKnownWidth(attr.getValue());
Expand Down Expand Up @@ -1311,13 +1288,11 @@ ComplexV1Type VhloBytecodeInterface::readComplexType(
LOG_READ_CALL;
Type elementType;
if (failed(reader.readType(elementType))) return ComplexV1Type();
verifyFromVhlo(elementType);
return ComplexV1Type::get(getContext(), elementType);
}

void VhloBytecodeInterface::write(ComplexV1Type type,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(type.getElementType());
writer.writeVarInt(vhlo_encoding::kComplexType);
writer.writeType(type.getElementType());
}
Expand All @@ -1331,23 +1306,19 @@ RankedTensorV1Type VhloBytecodeInterface::readRankedTensorType(
Attribute encoding;
if (hasEncoding && failed(reader.readAttribute(encoding)))
return RankedTensorV1Type();
if (hasEncoding) verifyFromVhlo(encoding);

SmallVector<int64_t> shape;
Type elementType;
if (failed(reader.readSignedVarInts(shape)) ||
failed(reader.readType(elementType)))
return RankedTensorV1Type();
verifyFromVhlo(elementType);

return RankedTensorV1Type::get(getContext(), shape, elementType, encoding);
}

void VhloBytecodeInterface::write(RankedTensorV1Type type,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(type.getElementType());
if (Attribute encoding = type.getEncoding()) {
verifyFromVhlo(encoding);
writer.writeVarInt(vhlo_encoding::kRankedTensorTypeWithEncoding);
writer.writeAttribute(encoding);
} else {
Expand All @@ -1366,13 +1337,11 @@ TupleV1Type VhloBytecodeInterface::readTupleType(
SmallVector<Type> elements;
if (failed(reader.readTypes(elements))) return TupleV1Type();

llvm::for_each(elements, verifyFromVhlo<Type>);
return TupleV1Type::get(getContext(), elements);
}

void VhloBytecodeInterface::write(TupleV1Type type,
DialectBytecodeWriter &writer) const {
llvm::for_each(type.getTypes(), verifyFromVhlo<Type>);
writer.writeVarInt(vhlo_encoding::kTupleType);
writer.writeTypes(type.getTypes());
}
Expand All @@ -1398,17 +1367,13 @@ UniformQuantizedV1Type VhloBytecodeInterface::readUniformQuantizedType(
return reader.emitError("invalid UniformQuantizedType"),
UniformQuantizedV1Type();

verifyFromVhlo(storageType);
verifyFromVhlo(expressedType);
return UniformQuantizedV1Type::get(getContext(), flags, storageType,
expressedType, scale.value(), zeroPoint,
storageTypeMin, storageTypeMax);
}

void VhloBytecodeInterface::write(UniformQuantizedV1Type type,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(type.getStorageType());
verifyFromVhlo(type.getExpressedType());
writer.writeVarInt(vhlo_encoding::kUniformQuantizedType);
writer.writeVarInt(type.getFlags());
writer.writeType(type.getStorageType());
Expand All @@ -1428,13 +1393,11 @@ UnrankedTensorV1Type VhloBytecodeInterface::readUnrankedTensorType(
Type elementType;
if (failed(reader.readType(elementType))) return UnrankedTensorV1Type();

verifyFromVhlo(elementType);
return UnrankedTensorV1Type::get(getContext(), elementType);
}

void VhloBytecodeInterface::write(UnrankedTensorV1Type type,
DialectBytecodeWriter &writer) const {
verifyFromVhlo(type.getElementType());
writer.writeVarInt(vhlo_encoding::kUnrankedTensorType);
writer.writeType(type.getElementType());
}
Expand Down

0 comments on commit 2162f56

Please sign in to comment.