Skip to content
This repository has been archived by the owner on Sep 4, 2024. It is now read-only.

Reuse socket #72

Merged
merged 1 commit into from
Nov 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 75 additions & 41 deletions src/simple_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

#[cfg(feature = "proxy")]
use socks::Socks5Stream;
use std::io::{BufRead, BufReader, Write};
#[cfg(not(feature = "proxy"))]
use std::io::{BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{error, fmt, io, net, thread};

Expand Down Expand Up @@ -38,6 +38,7 @@ pub struct SimpleHttpTransport {
proxy_addr: net::SocketAddr,
#[cfg(feature = "proxy")]
proxy_auth: Option<(String, String)>,
sock: Arc<Mutex<Option<TcpStream>>>,
}

impl Default for SimpleHttpTransport {
Expand All @@ -57,6 +58,7 @@ impl Default for SimpleHttpTransport {
),
#[cfg(feature = "proxy")]
proxy_auth: None,
sock: Arc::new(Mutex::new(None)),
}
}
}
Expand All @@ -73,29 +75,58 @@ impl SimpleHttpTransport {
}

fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
// `try_request` should not panic, so the mutex shouldn't be poisoned
// and unwrapping should be safe
let mut sock = self.sock.lock().expect("poisoned mutex");
match self.try_request(req, &mut sock) {
Ok(response) => Ok(response),
Err(err) => {
*sock = None;
Err(err)
}
}
}

fn try_request<R>(
&self,
req: impl serde::Serialize,
sock: &mut Option<TcpStream>,
) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
// Open connection
let request_deadline = Instant::now() + self.timeout;
#[cfg(feature = "proxy")]
let mut sock = if let Some((username, password)) = &self.proxy_auth {
Socks5Stream::connect_with_password(
self.proxy_addr,
self.addr,
username.as_str(),
password.as_str(),
)?
.into_inner()
} else {
Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner()
};

#[cfg(not(feature = "proxy"))]
let mut sock = TcpStream::connect_timeout(&self.addr, self.timeout)?;
if sock.is_none() {
*sock = Some({
#[cfg(feature = "proxy")]
{
if let Some((username, password)) = &self.proxy_auth {
Socks5Stream::connect_with_password(
self.proxy_addr,
self.addr,
username.as_str(),
password.as_str(),
)?
.into_inner()
} else {
Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner()
}
}

sock.set_read_timeout(Some(self.timeout))?;
sock.set_write_timeout(Some(self.timeout))?;
#[cfg(not(feature = "proxy"))]
{
let stream = TcpStream::connect_timeout(&self.addr, self.timeout)?;
stream.set_read_timeout(Some(self.timeout))?;
stream.set_write_timeout(Some(self.timeout))?;
stream
}
})
};
let sock = sock.as_mut().unwrap();

// Serialize the body first so we can set the Content-Length header.
let body = serde_json::to_vec(&req)?;
Expand All @@ -105,7 +136,6 @@ impl SimpleHttpTransport {
sock.write_all(self.path.as_bytes())?;
sock.write_all(b" HTTP/1.1\r\n")?;
// Write headers
sock.write_all(b"Connection: Close\r\n")?;
sock.write_all(b"Content-Type: application/json\r\n")?;
sock.write_all(b"Content-Length: ")?;
sock.write_all(body.len().to_string().as_bytes())?;
Expand Down Expand Up @@ -133,18 +163,39 @@ impl SimpleHttpTransport {
Err(_) => return Err(Error::HttpParseError),
};

// Skip response header fields
while get_line(&mut reader, request_deadline)? != "\r\n" {}
// Parse response header fields
let mut content_length = None;
loop {
let line = get_line(&mut reader, request_deadline)?;

if line == "\r\n" {
break;
}

const CONTENT_LENGTH: &str = "content-length: ";
if line.to_lowercase().starts_with(CONTENT_LENGTH) {
content_length = Some(
line[CONTENT_LENGTH.len()..]
.trim()
.parse::<usize>()
.map_err(|_| Error::HttpParseError)?,
);
}
}

if response_code == 401 {
// There is no body in a 401 response, so don't try to read it
return Err(Error::HttpErrorCode(response_code));
}

let content_length = content_length.ok_or(Error::HttpParseError)?;

let mut buffer = vec![0; content_length];
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely need to put a limit on this, but it's fine for now :). We'll do a followup PR to audit this because I think get_line may also be subject to memory exhaustion attacks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that sounds prudent


// Even if it's != 200, we parse the response as we may get a JSONRPC error instead
// of the less meaningful HTTP error code.
let resp_body = get_lines(&mut reader)?;
match serde_json::from_str(&resp_body) {
reader.read_exact(&mut buffer)?;
match serde_json::from_slice(&buffer) {
Ok(s) => Ok(s),
Err(e) => {
if response_code != 200 {
Expand Down Expand Up @@ -261,23 +312,6 @@ fn get_line<R: BufRead>(reader: &mut R, deadline: Instant) -> Result<String, Err
Err(Error::Timeout)
}

/// Read all lines from a buffered reader.
fn get_lines<R: BufRead>(reader: &mut R) -> Result<String, Error> {
let mut body: String = String::new();

for line in reader.lines() {
match line {
Ok(l) => body.push_str(&l),
// io error occurred, abort
Err(e) => return Err(Error::SocketError(e)),
}
}
// remove whitespace
body.retain(|c| !c.is_whitespace());

Ok(body)
}

/// Do some very basic manual URL parsing because the uri/url crates
/// all have unicode-normalization as a dependency and that's broken.
fn check_url(url: &str) -> Result<(SocketAddr, String), Error> {
Expand Down