diff --git a/src/descriptor.rs b/src/descriptor.rs index cb40fda..8bc37ca 100644 --- a/src/descriptor.rs +++ b/src/descriptor.rs @@ -43,6 +43,104 @@ pub struct DescriptorWriter<'a> { write_iads: bool, } +/// Write arbitrary sequences of bytes to descriptors with tools for delayed length writing +/// If there is an error it is delayed until the end of writing when finish is called +pub struct ByteWriter<'a, 'b> { + parent_writer: &'a mut DescriptorWriter<'b>, + result: Result<()>, +} + +impl<'a, 'b> ByteWriter<'a, 'b> { + fn new(parent: &'a mut DescriptorWriter<'b>) -> Self { + ByteWriter { + parent_writer: parent, + result: Ok(()), + } + } + + /// Write a single byte + pub fn byte(&mut self, b: u8) -> &mut Self { + if self.result.is_ok() { + let mut parent = &mut self.parent_writer; + if (parent.position + 1) >= parent.buf.len() { + self.result = Err(UsbError::BufferOverflow); + return self; // Delayed Error + } + + parent.buf[parent.position] = b; + parent.position += 1; + } + self + } + + /// Write a slice of bytes + pub fn arr(&mut self, bs: &[u8]) -> &mut Self { + if self.result.is_ok() { + let mut parent = &mut self.parent_writer; + let length = bs.len(); + if (parent.position + length) >= parent.buf.len() { + self.result = Err(UsbError::BufferOverflow); + return self; // Delayed Error + } + + let start = parent.position; + parent.buf[start..start + length].copy_from_slice(bs); + parent.position += length; + } + self + } + + /// current position in the parent buffer + pub fn position(&self) -> usize { + self.parent_writer.position + } + + /// Rserve space, do some writes, produce a value, then write something in that space using that value + pub fn delayed_write(&mut self, size: usize, w: W, delayed: D) -> &mut Self + where + W: FnOnce(&mut ByteWriter<'_, '_>) -> V, + D: FnOnce(&mut [u8], V), + { + if self.result.is_ok() { + if (self.parent_writer.position + size) >= self.parent_writer.buf.len() { + self.result = Err(UsbError::BufferOverflow); + return self; // Delayed Error + } + + let start = self.parent_writer.position; + self.parent_writer.position += size; + + let v = w(self); + delayed(&mut self.parent_writer.buf[start..start + size], v); + } + + self + } + + /// Store the current position and advance by 1, write some bytes + /// then when leaving the functon write the length of those bytes, including the length byte + pub fn delayed_length(&mut self, w: W) -> &mut Self + where + W: FnOnce(&mut ByteWriter<'_, '_>), + { + let start = self.position(); + self.delayed_write( + 1, + |bw| { + w(bw); + (bw.position() - start) as u8 + }, + |buf, length| buf[0] = length, + ); + self + } + + #[must_use] + fn finish(self) -> Result<()> { + self.result + } +} + impl DescriptorWriter<'_> { pub(crate) fn new(buf: &mut [u8]) -> DescriptorWriter<'_> { DescriptorWriter { @@ -79,6 +177,20 @@ impl DescriptorWriter<'_> { Ok(()) } + /// Write an arbiutrary sequence of bytes without needing to know the length at compile time + pub fn writer(&mut self, w: W) -> Result<()> + where + W: FnOnce(&mut ByteWriter<'_, '_>), + { + if self.position > self.buf.len() { + return Err(UsbError::BufferOverflow); + } + + let mut bb = ByteWriter::new(self); + w(&mut bb); + bb.finish() + } + pub(crate) fn device(&mut self, config: &device::Config) -> Result<()> { self.write( descriptor_type::DEVICE,