diff --git a/cprover_bindings/src/goto_program/expr.rs b/cprover_bindings/src/goto_program/expr.rs index e8feeb386a01..1b87ed8737c5 100644 --- a/cprover_bindings/src/goto_program/expr.rs +++ b/cprover_bindings/src/goto_program/expr.rs @@ -1328,7 +1328,7 @@ impl Expr { /// `self == 0` pub fn is_zero(self) -> Self { - assert!(self.typ.is_numeric()); + assert!(self.typ.is_numeric() || self.typ.is_pointer()); let typ = self.typ.clone(); self.eq(typ.zero()) } diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/operand.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/operand.rs index 723fa4a36606..8c700d2d07ce 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/operand.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/operand.rs @@ -608,12 +608,16 @@ impl<'tcx> GotocCtx<'tcx> { /// fetch the niche value (as both left and right value) pub fn codegen_get_niche(&self, v: Expr, offset: Size, niche_ty: Type) -> Expr { - v // t: T - .address_of() // &t: T* - .cast_to(Type::unsigned_int(8).to_pointer()) // (u8 *)&t: u8 * - .plus(Expr::int_constant(offset.bytes_usize(), Type::size_t())) // ((u8 *)&t) + offset: u8 * - .cast_to(niche_ty.to_pointer()) // (N *)(((u8 *)&t) + offset): N * - .dereference() // *(N *)(((u8 *)&t) + offset): N + if offset == Size::ZERO { + v.reinterpret_cast(niche_ty) + } else { + v // t: T + .address_of() // &t: T* + .cast_to(Type::unsigned_int(8).to_pointer()) // (u8 *)&t: u8 * + .plus(Expr::int_constant(offset.bytes(), Type::size_t())) // ((u8 *)&t) + offset: u8 * + .cast_to(niche_ty.to_pointer()) // (N *)(((u8 *)&t) + offset): N * + .dereference() // *(N *)(((u8 *)&t) + offset): N + } } /// Ensure that the given instance is in the symbol table, returning the symbol. diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs index 4d03e34acd4c..9c32072b84aa 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs @@ -14,7 +14,7 @@ use rustc_middle::mir::{AggregateKind, BinOp, CastKind, NullOp, Operand, Place, use rustc_middle::ty::adjustment::PointerCast; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, Instance, IntTy, Ty, TyCtxt, UintTy, VtblEntry}; -use rustc_target::abi::{FieldsShape, Primitive, TagEncoding, Variants}; +use rustc_target::abi::{FieldsShape, TagEncoding, Variants}; use tracing::{debug, warn}; impl<'tcx> GotocCtx<'tcx> { @@ -503,12 +503,14 @@ impl<'tcx> GotocCtx<'tcx> { .map_or(index.as_u32() as u128, |discr| discr.val); Expr::int_constant(discr_val, self.codegen_ty(res_ty)) } - Variants::Multiple { tag, tag_encoding, .. } => match tag_encoding { + Variants::Multiple { tag_encoding, .. } => match tag_encoding { TagEncoding::Direct => { self.codegen_discriminant_field(e, ty).cast_to(self.codegen_ty(res_ty)) } TagEncoding::Niche { dataful_variant, niche_variants, niche_start } => { - // This code follows the logic in the cranelift codegen backend: + // This code follows the logic in the ssa codegen backend: + // https://github.com/rust-lang/rust/blob/fee75fbe11b1fad5d93c723234178b2a329a3c03/compiler/rustc_codegen_ssa/src/mir/place.rs#L247 + // See also the cranelift backend: // https://github.com/rust-lang/rust/blob/05d22212e89588e7c443cc6b9bc0e4e02fdfbc8d/compiler/rustc_codegen_cranelift/src/discriminant.rs#L116 let offset = match &layout.fields { FieldsShape::Arbitrary { offsets, .. } => offsets[0], @@ -523,36 +525,35 @@ impl<'tcx> GotocCtx<'tcx> { // https://github.com/rust-lang/rust/blob/fee75fbe11b1fad5d93c723234178b2a329a3c03/compiler/rustc_codegen_ssa/src/mir/place.rs#L247 // // Note: niche_variants can only represent values that fit in a u32. + let result_type = self.codegen_ty(res_ty); let discr_mir_ty = self.codegen_enum_discr_typ(ty); let discr_type = self.codegen_ty(discr_mir_ty); - let niche_val = self.codegen_get_niche(e, offset, discr_type.clone()); + let niche_val = self.codegen_get_niche(e, offset, discr_type); let relative_discr = wrapping_sub(&niche_val, u64::try_from(*niche_start).unwrap()); let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32(); - let is_niche = if tag.primitive() == Primitive::Pointer { - tracing::trace!(?tag, "Primitive::Pointer"); - discr_type.null().eq(relative_discr.clone()) + let is_niche = if relative_max == 0 { + relative_discr.clone().is_zero() } else { - tracing::trace!(?tag, "Not Primitive::Pointer"); relative_discr .clone() .le(Expr::int_constant(relative_max, relative_discr.typ().clone())) }; let niche_discr = { let relative_discr = if relative_max == 0 { - self.codegen_ty(res_ty).zero() + result_type.zero() } else { - relative_discr.cast_to(self.codegen_ty(res_ty)) + relative_discr.cast_to(result_type.clone()) }; relative_discr.plus(Expr::int_constant( niche_variants.start().as_u32(), - self.codegen_ty(res_ty), + result_type.clone(), )) }; is_niche.ternary( niche_discr, - Expr::int_constant(dataful_variant.as_u32(), self.codegen_ty(res_ty)), + Expr::int_constant(dataful_variant.as_u32(), result_type), ) } },