diff --git a/postgres-protocol/src/types/mod.rs b/postgres-protocol/src/types/mod.rs index 46000c407..0a93692ff 100644 --- a/postgres-protocol/src/types/mod.rs +++ b/postgres-protocol/src/types/mod.rs @@ -3,6 +3,8 @@ use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; use fallible_iterator::FallibleIterator; use std::boxed::Box as StdBox; use std::error::Error; +use std::io::Read; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str; use crate::{write_nullable, FromUsize, IsNull, Oid}; @@ -16,6 +18,9 @@ const RANGE_UPPER_INCLUSIVE: u8 = 0b0000_0100; const RANGE_LOWER_INCLUSIVE: u8 = 0b0000_0010; const RANGE_EMPTY: u8 = 0b0000_0001; +const PGSQL_AF_INET: u8 = 2; +const PGSQL_AF_INET6: u8 = 3; + /// Serializes a `BOOL` value. #[inline] pub fn bool_to_sql(v: bool, buf: &mut Vec) { @@ -956,3 +961,86 @@ impl<'a> FallibleIterator for PathPoints<'a> { (len, Some(len)) } } + +/// Serializes a Postgres inet. +#[inline] +pub fn inet_to_sql(addr: IpAddr, netmask: u8, buf: &mut Vec) { + let family = match addr { + IpAddr::V4(_) => PGSQL_AF_INET, + IpAddr::V6(_) => PGSQL_AF_INET6, + }; + buf.push(family); + buf.push(netmask); + buf.push(0); // is_cidr + match addr { + IpAddr::V4(addr) => { + buf.push(4); + buf.extend_from_slice(&addr.octets()); + } + IpAddr::V6(addr) => { + buf.push(16); + buf.extend_from_slice(&addr.octets()); + } + } +} + +/// Deserializes a Postgres inet. +#[inline] +pub fn inet_from_sql(mut buf: &[u8]) -> Result> { + let family = buf.read_u8()?; + let netmask = buf.read_u8()?; + buf.read_u8()?; // is_cidr + let len = buf.read_u8()?; + + let addr = match family { + PGSQL_AF_INET => { + if netmask > 32 { + return Err("invalid IPv4 netmask".into()); + } + if len != 4 { + return Err("invalid IPv4 address length".into()); + } + let mut addr = [0; 4]; + buf.read_exact(&mut addr)?; + IpAddr::V4(Ipv4Addr::from(addr)) + } + PGSQL_AF_INET6 => { + if netmask > 128 { + return Err("invalid IPv6 netmask".into()); + } + if len != 16 { + return Err("invalid IPv6 address length".into()); + } + let mut addr = [0; 16]; + buf.read_exact(&mut addr)?; + IpAddr::V6(Ipv6Addr::from(addr)) + } + _ => return Err("invalid IP family".into()), + }; + + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + + Ok(Inet { addr, netmask }) +} + +/// A Postgres network address. +pub struct Inet { + addr: IpAddr, + netmask: u8, +} + +impl Inet { + /// Returns the IP address. + #[inline] + pub fn addr(&self) -> IpAddr { + self.addr + } + + /// Returns the netmask. + #[inline] + pub fn netmask(&self) -> u8 { + self.netmask + } +} diff --git a/tokio-postgres/src/types/mod.rs b/tokio-postgres/src/types/mod.rs index 145a2530b..42599171e 100644 --- a/tokio-postgres/src/types/mod.rs +++ b/tokio-postgres/src/types/mod.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; use std::error::Error; use std::fmt; use std::hash::BuildHasher; +use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -248,6 +249,7 @@ impl WrongType { /// | `&[u8]`/`Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | +/// | `IpAddr` | INET | /// /// In addition, some implementations are provided for types in third party /// crates. These are disabled by default; to opt into one of these @@ -469,6 +471,15 @@ impl<'a> FromSql<'a> for SystemTime { accepts!(TIMESTAMP, TIMESTAMPTZ); } +impl<'a> FromSql<'a> for IpAddr { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(inet.addr()) + } + + accepts!(INET); +} + /// An enum representing the nullability of a Postgres value. pub enum IsNull { /// The value is NULL. @@ -498,6 +509,7 @@ pub enum IsNull { /// | `&[u8]`/Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | +/// | `IpAddr` | INET | /// /// In addition, some implementations are provided for types in third party /// crates. These are disabled by default; to opt into one of these @@ -771,6 +783,21 @@ impl ToSql for SystemTime { to_sql_checked!(); } +impl ToSql for IpAddr { + fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { + let netmask = match self { + IpAddr::V4(_) => 32, + IpAddr::V6(_) => 128, + }; + types::inet_to_sql(*self, netmask, w); + Ok(IsNull::No) + } + + accepts!(INET); + + to_sql_checked!(); +} + fn downcast(len: usize) -> Result> { if len > i32::max_value() as usize { Err("value too large to transmit".into()) diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index bf9870043..18858d568 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -4,6 +4,7 @@ use std::error::Error; use std::f32; use std::f64; use std::fmt; +use std::net::IpAddr; use std::result; use std::time::{Duration, UNIX_EPOCH}; use tokio::runtime::current_thread::Runtime; @@ -624,3 +625,33 @@ fn system_time() { ], ); } + +#[test] +fn inet() { + test_type( + "INET", + &[ + (Some("127.0.0.1".parse::().unwrap()), "'127.0.0.1'"), + ( + Some("127.0.0.1".parse::().unwrap()), + "'127.0.0.1/32'", + ), + ( + Some( + "2001:4f8:3:ba:2e0:81ff:fe22:d1f1" + .parse::() + .unwrap(), + ), + "'2001:4f8:3:ba:2e0:81ff:fe22:d1f1'", + ), + ( + Some( + "2001:4f8:3:ba:2e0:81ff:fe22:d1f1" + .parse::() + .unwrap(), + ), + "'2001:4f8:3:ba:2e0:81ff:fe22:d1f1/128'", + ), + ], + ); +}