Skip to content

Commit

Permalink
Add missing iterators
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Saveau <[email protected]>
  • Loading branch information
SUPERCILEX authored and TheDan64 committed Jan 14, 2024
1 parent b545f7c commit 4909eaa
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 88 deletions.
22 changes: 22 additions & 0 deletions src/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ impl<'ctx> BasicBlock<'ctx> {
unsafe { Some(InstructionValue::new(value)) }
}

/// Get an instruction iterator
pub fn get_instructions(self) -> InstructionIter<'ctx> {
InstructionIter(self.get_first_instruction())
}

/// Removes this `BasicBlock` from its parent `FunctionValue`.
/// It returns `Err(())` when it has no parent to remove from.
///
Expand Down Expand Up @@ -597,3 +602,20 @@ impl fmt::Debug for BasicBlock<'_> {
.finish()
}
}

/// Iterate over all `InstructionValue`s in a basic block.
#[derive(Debug)]
pub struct InstructionIter<'ctx>(Option<InstructionValue<'ctx>>);

impl<'ctx> Iterator for InstructionIter<'ctx> {
type Item = InstructionValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(instr) = self.0 {
self.0 = instr.get_next_instruction();
Some(instr)
} else {
None
}
}
}
78 changes: 14 additions & 64 deletions src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1573,96 +1573,46 @@ pub enum FlagBehavior {

/// Iterate over all `FunctionValue`s in an llvm module
#[derive(Debug)]
pub struct FunctionIterator<'ctx>(FunctionIteratorInner<'ctx>);

/// Inner type so the variants are not publicly visible
#[derive(Debug)]
enum FunctionIteratorInner<'ctx> {
Empty,
Start(FunctionValue<'ctx>),
Previous(FunctionValue<'ctx>),
}
pub struct FunctionIterator<'ctx>(Option<FunctionValue<'ctx>>);

impl<'ctx> FunctionIterator<'ctx> {
fn from_module(module: &Module<'ctx>) -> Self {
use FunctionIteratorInner::*;

match module.get_first_function() {
None => Self(Empty),
Some(first) => Self(Start(first)),
}
Self(module.get_first_function())
}
}

impl<'ctx> Iterator for FunctionIterator<'ctx> {
type Item = FunctionValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
use FunctionIteratorInner::*;

match self.0 {
Empty => None,
Start(first) => {
self.0 = Previous(first);

Some(first)
},
Previous(prev) => match prev.get_next_function() {
Some(current) => {
self.0 = Previous(current);

Some(current)
},
None => None,
},
if let Some(func) = self.0 {
self.0 = func.get_next_function();
Some(func)
} else {
None
}
}
}

/// Iterate over all `GlobalValue`s in an llvm module
#[derive(Debug)]
pub struct GlobalIterator<'ctx>(GlobalIteratorInner<'ctx>);

/// Inner type so the variants are not publicly visible
#[derive(Debug)]
enum GlobalIteratorInner<'ctx> {
Empty,
Start(GlobalValue<'ctx>),
Previous(GlobalValue<'ctx>),
}
pub struct GlobalIterator<'ctx>(Option<GlobalValue<'ctx>>);

impl<'ctx> GlobalIterator<'ctx> {
fn from_module(module: &Module<'ctx>) -> Self {
use GlobalIteratorInner::*;

match module.get_first_global() {
None => Self(Empty),
Some(first) => Self(Start(first)),
}
Self(module.get_first_global())
}
}

impl<'ctx> Iterator for GlobalIterator<'ctx> {
type Item = GlobalValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
use GlobalIteratorInner::*;

match self.0 {
Empty => None,
Start(first) => {
self.0 = Previous(first);

Some(first)
},
Previous(prev) => match prev.get_next_global() {
Some(current) => {
self.0 = Previous(current);

Some(current)
},
None => None,
},
if let Some(global) = self.0 {
self.0 = global.get_next_global();
Some(global)
} else {
None
}
}
}
21 changes: 21 additions & 0 deletions src/values/fn_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ impl<'ctx> FunctionValue<'ctx> {
unsafe { LLVMCountBasicBlocks(self.as_value_ref()) }
}

pub fn get_basic_block_iter(self) -> BasicBlockIter<'ctx> {
BasicBlockIter(self.get_first_basic_block())
}

pub fn get_basic_blocks(self) -> Vec<BasicBlock<'ctx>> {
let count = self.count_basic_blocks();
let mut raw_vec: Vec<LLVMBasicBlockRef> = Vec::with_capacity(count as usize);
Expand Down Expand Up @@ -552,6 +556,23 @@ impl fmt::Debug for FunctionValue<'_> {
}
}

/// Iterate over all `BasicBlock`s in a function.
#[derive(Debug)]
pub struct BasicBlockIter<'ctx>(Option<BasicBlock<'ctx>>);

impl<'ctx> Iterator for BasicBlockIter<'ctx> {
type Item = BasicBlock<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(bb) = self.0 {
self.0 = bb.get_next_basic_block();
Some(bb)
} else {
None
}
}
}

#[derive(Debug)]
pub struct ParamValueIter<'ctx> {
param_iter_value: LLVMValueRef,
Expand Down
50 changes: 26 additions & 24 deletions tests/all/test_basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,33 @@ fn test_basic_block_ordering() {
let basic_block2 = context.insert_basic_block_after(basic_block, "block2");
let basic_block3 = context.prepend_basic_block(basic_block4, "block3");

let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 4);
assert_eq!(basic_blocks[0], basic_block);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block3);
assert_eq!(basic_blocks[3], basic_block4);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 4);
assert_eq!(basic_blocks[0], basic_block);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block3);
assert_eq!(basic_blocks[3], basic_block4);
}

assert!(basic_block3.move_before(basic_block2).is_ok());
assert!(basic_block.move_after(basic_block4).is_ok());

let basic_block5 = context.prepend_basic_block(basic_block, "block5");
let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 5);
assert_eq!(basic_blocks[0], basic_block3);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block4);
assert_eq!(basic_blocks[3], basic_block5);
assert_eq!(basic_blocks[4], basic_block);

assert_ne!(basic_blocks[0], basic_block);
assert_ne!(basic_blocks[1], basic_block3);
assert_ne!(basic_blocks[2], basic_block2);
assert_ne!(basic_blocks[3], basic_block4);
assert_ne!(basic_blocks[4], basic_block5);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 5);
assert_eq!(basic_blocks[0], basic_block3);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block4);
assert_eq!(basic_blocks[3], basic_block5);
assert_eq!(basic_blocks[4], basic_block);

assert_ne!(basic_blocks[0], basic_block);
assert_ne!(basic_blocks[1], basic_block3);
assert_ne!(basic_blocks[2], basic_block2);
assert_ne!(basic_blocks[3], basic_block4);
assert_ne!(basic_blocks[4], basic_block5);
}

context.append_basic_block(function, "block6");

Expand Down Expand Up @@ -89,6 +90,7 @@ fn test_get_basic_blocks() {

assert!(function.get_last_basic_block().is_none());
assert_eq!(function.get_basic_blocks().len(), 0);
assert_eq!(function.get_basic_block_iter().count(), 0);

let basic_block = context.append_basic_block(function, "entry");

Expand All @@ -98,10 +100,10 @@ fn test_get_basic_blocks() {

assert_eq!(last_basic_block, basic_block);

let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 1);
assert_eq!(basic_blocks[0], basic_block);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 1);
assert_eq!(basic_blocks[0], basic_block);
}
}

#[test]
Expand Down

0 comments on commit 4909eaa

Please sign in to comment.