Skip to content

Commit

Permalink
feat: Add conversions itobool, ifrombool (#101)
Browse files Browse the repository at this point in the history
Closes #22

---------

Co-authored-by: Craig Roy <[email protected]>
Co-authored-by: Craig Roy <[email protected]>
Co-authored-by: Mark Koch <[email protected]>
  • Loading branch information
4 people authored Sep 16, 2024
1 parent ede3f51 commit 99ad943
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 3 deletions.
142 changes: 139 additions & 3 deletions src/custom/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use anyhow::{anyhow, Result};

use hugr::{
extension::{
prelude::{sum_with_error, ConstError},
prelude::{sum_with_error, ConstError, BOOL_T},
simple_op::MakeExtensionOp,
ExtensionId,
},
Expand All @@ -13,18 +13,19 @@ use hugr::{
conversions::{self, ConvertOpDef, ConvertOpType},
int_types::INT_TYPES,
},
types::{CustomType, TypeArg},
types::{CustomType, TypeArg, TypeEnum},
HugrView,
};

use inkwell::{types::BasicTypeEnum, values::BasicValue, FloatPredicate};
use inkwell::{types::BasicTypeEnum, values::BasicValue, FloatPredicate, IntPredicate};

use crate::{
emit::{
func::EmitFuncContext,
ops::{emit_custom_unary_op, emit_value},
EmitOp, EmitOpArgs,
},
sum::LLVMSumValue,
types::TypingSession,
};

Expand Down Expand Up @@ -186,6 +187,52 @@ impl<'c, H: HugrView> EmitOp<'c, ExtensionOp, H> for ConversionsEmitter<'c, '_,
ConvertOpDef::itousize | ConvertOpDef::ifromusize => {
emit_custom_unary_op(self.0, args, |_, arg, _| Ok(vec![arg]))
}
ConvertOpDef::itobool | ConvertOpDef::ifrombool => {
assert!(conversion_op.type_args().is_empty()); // Always 1-bit int <-> bool
let i0_ty = self
.0
.typing_session()
.llvm_type(&INT_TYPES[0])?
.into_int_type();
let sum_ty =
self.0
.typing_session()
.llvm_sum_type(match BOOL_T.as_type_enum() {
TypeEnum::Sum(st) => st.clone(),
_ => panic!("Hugr prelude BOOL_T not a Sum"),
})?;

emit_custom_unary_op(self.0, args, |ctx, arg, _| {
let res = if conversion_op.def() == &ConvertOpDef::itobool {
let is1 = ctx.builder().build_int_compare(
IntPredicate::EQ,
arg.into_int_value(),
i0_ty.const_int(1, false),
"eq1",
)?;
let sum_f = sum_ty.build_tag(ctx.builder(), 0, vec![])?;
let sum_t = sum_ty.build_tag(ctx.builder(), 1, vec![])?;
ctx.builder().build_select(is1, sum_t, sum_f, "")?
} else {
let tag_ty = sum_ty.get_tag_type();
let tag =
LLVMSumValue::try_new(arg, sum_ty)?.build_get_tag(ctx.builder())?;
let is_true = ctx.builder().build_int_compare(
IntPredicate::EQ,
tag,
tag_ty.const_int(1, false),
"",
)?;
ctx.builder().build_select(
is_true,
i0_ty.const_int(1, false),
i0_ty.const_int(0, false),
"",
)?
};
Ok(vec![res])
})
}
_ => Err(anyhow!(
"Conversion op not implemented: {:?}",
args.node().as_ref()
Expand Down Expand Up @@ -237,6 +284,7 @@ mod test {
use crate::emit::test::{SimpleHugrConfig, DFGW};
use crate::test::{exec_ctx, llvm_ctx, TestContext};
use hugr::builder::SubContainer;
use hugr::std_extensions::arithmetic::int_types::ConstInt;
use hugr::{
builder::{Dataflow, DataflowSubContainer},
extension::prelude::{ConstUsize, PRELUDE_REGISTRY, USIZE_T},
Expand Down Expand Up @@ -308,6 +356,40 @@ mod test {
check_emission!(op_name, hugr, llvm_ctx);
}

#[rstest]
#[case("itobool", true)]
#[case("ifrombool", false)]
fn test_intbool_emit(
mut llvm_ctx: TestContext,
#[case] op_name: &str,
#[case] input_int: bool,
) {
let mut tys = [INT_TYPES[0].clone(), BOOL_T];
if !input_int {
tys.reverse()
};
let [in_t, out_t] = tys;
llvm_ctx.add_extensions(add_int_extensions);
llvm_ctx.add_extensions(add_float_extensions);
llvm_ctx.add_extensions(add_conversions_extension);
let hugr = SimpleHugrConfig::new()
.with_ins(vec![in_t])
.with_outs(vec![out_t])
.with_extensions(CONVERT_OPS_REGISTRY.to_owned())
.finish(|mut hugr_builder| {
let [in1] = hugr_builder.input_wires_arr();
let ext_op = EXTENSION
.instantiate_extension_op(op_name, [], &CONVERT_OPS_REGISTRY)
.unwrap();
let [out1] = hugr_builder
.add_dataflow_op(ext_op, [in1])
.unwrap()
.outputs_arr();
hugr_builder.finish_with_outputs([out1]).unwrap()
});
check_emission!(op_name, hugr, llvm_ctx);
}

#[rstest]
fn my_test_exec(mut exec_ctx: TestContext) {
let hugr = SimpleHugrConfig::new()
Expand Down Expand Up @@ -430,4 +512,58 @@ mod test {
let hugr = roundtrip_hugr(val);
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
}

#[rstest]
fn itobool_cond(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
use hugr::type_row;

let hugr = SimpleHugrConfig::new()
.with_outs(vec![USIZE_T])
.with_extensions(CONVERT_OPS_REGISTRY.to_owned())
.finish(|mut builder| {
let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
let ext_op = EXTENSION
.instantiate_extension_op("itobool", [], &CONVERT_OPS_REGISTRY)
.unwrap();
let [b] = builder.add_dataflow_op(ext_op, [i]).unwrap().outputs_arr();
let mut cond = builder
.conditional_builder(([type_row![], type_row![]], b), [], type_row![USIZE_T])
.unwrap();
let mut case_false = cond.case_builder(0).unwrap();
let false_result = case_false.add_load_value(ConstUsize::new(1));
case_false.finish_with_outputs([false_result]).unwrap();
let mut case_true = cond.case_builder(1).unwrap();
let true_result = case_true.add_load_value(ConstUsize::new(6));
case_true.finish_with_outputs([true_result]).unwrap();
let res = cond.finish_sub_container().unwrap();
builder.finish_with_outputs(res.outputs()).unwrap()
});
exec_ctx.add_extensions(add_conversions_extension);
exec_ctx.add_extensions(add_default_prelude_extensions);
exec_ctx.add_extensions(add_int_extensions);
assert_eq!(i * 5 + 1, exec_ctx.exec_hugr_u64(hugr, "main"));
}

#[rstest]
fn itobool_roundtrip(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
let hugr = SimpleHugrConfig::new()
.with_outs(vec![INT_TYPES[0].clone()])
.with_extensions(CONVERT_OPS_REGISTRY.to_owned())
.finish(|mut builder| {
let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
let i2b = EXTENSION
.instantiate_extension_op("itobool", [], &CONVERT_OPS_REGISTRY)
.unwrap();
let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr();
let b2i = EXTENSION
.instantiate_extension_op("ifrombool", [], &CONVERT_OPS_REGISTRY)
.unwrap();
let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr();
builder.finish_with_outputs([i]).unwrap()
});
exec_ctx.add_extensions(add_conversions_extension);
exec_ctx.add_extensions(add_default_prelude_extensions);
exec_ctx.add_extensions(add_int_extensions);
assert_eq!(i, exec_ctx.exec_hugr_u64(hugr, "main"));
}
}
17 changes: 17 additions & 0 deletions src/custom/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
source: src/custom/conversions.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1({ i32, {}, {} } %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%1 = extractvalue { i32, {}, {} } %0, 0
%2 = icmp eq i32 %1, 1
%3 = select i1 %2, i8 1, i8 0
ret i8 %3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
source: src/custom/conversions.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1({ i32, {}, {} } %0) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca { i32, {}, {} }, align 8
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store { i32, {}, {} } %0, { i32, {}, {} }* %"2_0", align 4
%"2_01" = load { i32, {}, {} }, { i32, {}, {} }* %"2_0", align 4
%1 = extractvalue { i32, {}, {} } %"2_01", 0
%2 = icmp eq i32 %1, 1
%3 = select i1 %2, i8 1, i8 0
store i8 %3, i8* %"4_0", align 1
%"4_02" = load i8, i8* %"4_0", align 1
store i8 %"4_02", i8* %"0", align 1
%"03" = load i8, i8* %"0", align 1
ret i8 %"03"
}
16 changes: 16 additions & 0 deletions src/custom/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
source: src/custom/conversions.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { i32, {}, {} } @_hl.main.1(i8 %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%eq1 = icmp eq i8 %0, 1
%1 = select i1 %eq1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
ret { i32, {}, {} } %1
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
---
source: src/custom/conversions.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { i32, {}, {} } @_hl.main.1(i8 %0) {
alloca_block:
%"0" = alloca { i32, {}, {} }, align 8
%"2_0" = alloca i8, align 1
%"4_0" = alloca { i32, {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%eq1 = icmp eq i8 %"2_01", 1
%1 = select i1 %eq1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
store { i32, {}, {} } %1, { i32, {}, {} }* %"4_0", align 4
%"4_02" = load { i32, {}, {} }, { i32, {}, {} }* %"4_0", align 4
store { i32, {}, {} } %"4_02", { i32, {}, {} }* %"0", align 4
%"03" = load { i32, {}, {} }, { i32, {}, {} }* %"0", align 4
ret { i32, {}, {} } %"03"
}

0 comments on commit 99ad943

Please sign in to comment.