Skip to content

Commit

Permalink
refactor(virtio-net): migrate header to virtio-spec
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Kröning <[email protected]>
  • Loading branch information
mkroening committed May 16, 2024
1 parent bc0b245 commit dd7923b
Showing 1 changed file with 22 additions and 55 deletions.
77 changes: 22 additions & 55 deletions src/drivers/net/virtio_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use align_address::Align;
use pci_types::InterruptLine;
use smoltcp::phy::{Checksum, ChecksumCapabilities};
use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN};
use virtio_spec::net::{HdrF, HdrGso};
use zerocopy::AsBytes;
use virtio_spec::net::{Hdr, HdrF};

use self::constants::{Status, MAX_NUM_VQ};
use self::error::VirtioNetError;
Expand Down Expand Up @@ -43,37 +42,6 @@ pub(crate) struct NetDevCfg {
pub features: virtio_spec::net::F,
}

#[derive(AsBytes, Debug)]
#[repr(C)]
pub struct VirtioNetHdr {
flags: HdrF,
gso_type: HdrGso,
/// Ethernet + IP + tcp/udp hdrs
hdr_len: u16,
/// Bytes to append to hdr_len per frame
gso_size: u16,
/// Position to start checksumming from
csum_start: u16,
/// Offset after that to place checksum
csum_offset: u16,
/// Number of buffers this Packet consists of
num_buffers: u16,
}

impl Default for VirtioNetHdr {
fn default() -> Self {
Self {
flags: HdrF::empty(),
gso_type: HdrGso::empty(),
hdr_len: 0,
gso_size: 0,
csum_start: 0,
csum_offset: 0,
num_buffers: 0,
}
}
}

pub struct CtrlQueue(Option<Rc<dyn Virtq>>);

impl CtrlQueue {
Expand Down Expand Up @@ -182,10 +150,10 @@ impl RxQueues {
let num_buff: u16 = vq.size().into();

let rx_size = if dev_cfg.features.contains(virtio_spec::net::F::MRG_RXBUF) {
(1514 + mem::size_of::<VirtioNetHdr>())
(1514 + mem::size_of::<Hdr>())
.align_up(core::mem::size_of::<crossbeam_utils::CachePadded<u8>>())
} else {
dev_cfg.raw.get_mtu() as usize + mem::size_of::<VirtioNetHdr>()
dev_cfg.raw.get_mtu() as usize + mem::size_of::<Hdr>()
};

// See Virtio specification v1.1 - 5.1.6.3.1
Expand Down Expand Up @@ -330,7 +298,7 @@ impl TxQueues {
// Header and data are added as ONE output descriptor to the transmitvq.
// Hence we are interpreting this, as the fact, that send packets must be inside a single descriptor.
// As usize is currently safe as the minimal usize is defined as 16bit in rust.
let buff_def = Bytes::new(mem::size_of::<VirtioNetHdr>() + 65550).unwrap();
let buff_def = Bytes::new(mem::size_of::<Hdr>() + 65550).unwrap();
let spec = BuffSpec::Single(buff_def);

let num_buff: u16 = vq.size().into();
Expand All @@ -340,7 +308,7 @@ impl TxQueues {
vq.clone()
.prep_buffer(Some(spec.clone()), None)
.unwrap()
.write_seq(Some(&VirtioNetHdr::default()), None::<&VirtioNetHdr>)
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap(),
)
}
Expand All @@ -350,8 +318,7 @@ impl TxQueues {
// Hence we are interpreting this, as the fact, that send packets must be inside a single descriptor.
// As usize is currently safe as the minimal usize is defined as 16bit in rust.
let buff_def =
Bytes::new(mem::size_of::<VirtioNetHdr>() + dev_cfg.raw.get_mtu() as usize)
.unwrap();
Bytes::new(mem::size_of::<Hdr>() + dev_cfg.raw.get_mtu() as usize).unwrap();
let spec = BuffSpec::Single(buff_def);

let num_buff: u16 = vq.size().into();
Expand All @@ -361,7 +328,7 @@ impl TxQueues {
vq.clone()
.prep_buffer(Some(spec.clone()), None)
.unwrap()
.write_seq(Some(&VirtioNetHdr::default()), None::<&VirtioNetHdr>)
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap(),
)
}
Expand Down Expand Up @@ -480,21 +447,19 @@ impl NetworkDriver for VirtioNetDriver {
where
F: FnOnce(&mut [u8]) -> R,
{
if let Some((mut buff_tkn, _vq_index)) = self
.send_vqs
.get_tkn(len + core::mem::size_of::<VirtioNetHdr>())
if let Some((mut buff_tkn, _vq_index)) =
self.send_vqs.get_tkn(len + core::mem::size_of::<Hdr>())
{
let (send_ptrs, _) = buff_tkn.raw_ptrs();
// Currently we have single Buffers in the TxQueue of size: MTU + ETHERNET_HEADER_LEN + VIRTIO_NET_HDR
// see TxQueue.add()
let (buff_ptr, _) = send_ptrs.unwrap()[0];

// Do not show smoltcp the memory region for VirtioNetHdr.
let header = unsafe { &mut *(buff_ptr as *mut VirtioNetHdr) };
// Do not show smoltcp the memory region for Hdr.
let header = unsafe { &mut *(buff_ptr as *mut Hdr) };
*header = Default::default();
let buff_ptr = unsafe {
buff_ptr.offset(isize::try_from(core::mem::size_of::<VirtioNetHdr>()).unwrap())
};
let buff_ptr =
unsafe { buff_ptr.offset(isize::try_from(core::mem::size_of::<Hdr>()).unwrap()) };

let buf_slice: &'static mut [u8] =
unsafe { core::slice::from_raw_parts_mut(buff_ptr, len) };
Expand Down Expand Up @@ -524,12 +489,14 @@ impl NetworkDriver for VirtioNetDriver {
protocol = None;
}
}
header.csum_start = u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len;
header.csum_start =
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
header.csum_offset = match protocol {
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
Some(smoltcp::wire::IpProtocol::Udp) => 6,
_ => 0,
};
}
.into();
}

buff_tkn
Expand Down Expand Up @@ -560,7 +527,7 @@ impl NetworkDriver for VirtioNetDriver {
if recv_data.len() == 1 {
let mut vec_data: Vec<u8> = Vec::with_capacity(self.mtu.into());
let num_buffers = {
const HEADER_SIZE: usize = mem::size_of::<VirtioNetHdr>();
const HEADER_SIZE: usize = mem::size_of::<Hdr>();
let packet = recv_data.pop().unwrap();

// drop packets with invalid packet size
Expand All @@ -574,14 +541,14 @@ impl NetworkDriver for VirtioNetDriver {
}

let header = unsafe {
core::mem::transmute::<[u8; HEADER_SIZE], VirtioNetHdr>(
core::mem::transmute::<[u8; HEADER_SIZE], Hdr>(
packet[..HEADER_SIZE].try_into().unwrap(),
)
};
trace!("Header: {:?}", header);
let num_buffers = header.num_buffers;

vec_data.extend_from_slice(&packet[mem::size_of::<VirtioNetHdr>()..]);
vec_data.extend_from_slice(&packet[mem::size_of::<Hdr>()..]);
transfer
.reset()
.provide()
Expand All @@ -590,7 +557,7 @@ impl NetworkDriver for VirtioNetDriver {
num_buffers
};

for _ in 1..num_buffers {
for _ in 1..num_buffers.get() {
let transfer =
match RxQueues::post_processing(self.recv_vqs.get_next().unwrap()) {
Ok(trf) => trf,
Expand All @@ -615,7 +582,7 @@ impl NetworkDriver for VirtioNetDriver {
error!("Empty transfer, or with wrong buffer layout. Reusing and returning error to user-space network driver...");
transfer
.reset()
.write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::default()))
.write_seq(None::<&Hdr>, Some(&Hdr::default()))
.unwrap()
.provide()
.dispatch_await(self.recv_vqs.poll_sender.clone(), false);
Expand Down

0 comments on commit dd7923b

Please sign in to comment.