Handle HTTP redirects using reqwest Client builder #29

Merged (5 commits) on Nov 29, 2023
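The core of this PR is to build a single shared `reqwest::Client` with an explicit redirect policy, instead of creating `reqwest::Client::new()` at each download site. Below is a minimal sketch of that pattern, assuming the `reqwest::redirect::Policy` API used in the diff; `Policy::limited` and `Policy::none` are mentioned only as illustrative alternatives and are not part of this PR.

```rust
use reqwest::{Client, redirect::Policy};

fn main() -> Result<(), reqwest::Error> {
    // reqwest's default policy follows up to 10 redirects before failing.
    let client = Client::builder()
        .redirect(Policy::default())
        .build()?;

    // Stricter alternatives would be Policy::limited(5) or Policy::none().
    // The PR keeps the default and threads `client` through every download
    // helper instead of constructing a new client per call site.
    let _ = client;
    Ok(())
}
```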
4 changes: 3 additions & 1 deletion examples/download_test.rs
@@ -1,12 +1,14 @@
use std::error::Error;
use url::Url;
use std::str::FromStr;

use ue_rs::download_and_hash;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();

let url = std::env::args().nth(1).expect("missing URL (second argument)");
let url = Url::from_str(std::env::args().nth(1).expect("missing URL (second argument)").as_str())?;

println!("fetching {}...", url);

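The example now parses its argument into a `url::Url` before any network work, matching the `Url: From<U>` bound on `download_and_hash` (see src/download.rs below). A small usage sketch under those assumptions; the URL literal and the in-memory `Vec<u8>` sink are stand-ins, since the rest of the example is truncated here.

```rust
use std::error::Error;
use std::str::FromStr;
use url::Url;

use ue_rs::download_and_hash;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let client = reqwest::Client::new();

    // Parsing up front surfaces malformed URLs before any request is sent.
    let url = Url::from_str("https://example.com/payload.gz")?;

    // Any std::io::Write works as the sink; a Vec keeps the sketch self-contained.
    let mut sink: Vec<u8> = Vec::new();
    download_and_hash(&client, url, &mut sink).await?;
    println!("downloaded {} bytes", sink.len());

    Ok(())
}
```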
189 changes: 108 additions & 81 deletions src/bin/download_sysext.rs
@@ -1,22 +1,27 @@
use std::error::Error;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::ffi::OsStr;
use std::fs::File;
use std::fs;
use std::io;
use std::io::{Read, Seek, SeekFrom};
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::str::FromStr;

#[macro_use]
extern crate log;

use anyhow::{Context, Result, bail};
use anyhow::{Context, Result, bail, anyhow};
use argh::FromArgs;
use globset::{Glob, GlobSet, GlobSetBuilder};
use hard_xml::XmlRead;
use argh::FromArgs;
use omaha::FileSize;
use reqwest::Client;
use reqwest::redirect::Policy;
use url::Url;

use update_format_crau::delta_update;
use ue_rs::hash_on_disk_sha256;

#[derive(Debug)]
enum PackageStatus {
@@ -44,51 +49,7 @@ impl<'a> Package<'a> {
// If maxlen is None, a simple read to the end of the file.
// If maxlen is Some, read only until the given length.
fn hash_on_disk(&mut self, path: &Path, maxlen: Option<usize>) -> Result<omaha::Hash<omaha::Sha256>> {
use sha2::{Sha256, Digest};

let file = File::open(path).context({
format!("failed to open path({:?})", path.display())
})?;
let mut hasher = Sha256::new();

let filelen = file.metadata().unwrap().len() as usize;

let mut maxlen_to_read: usize = match maxlen {
Some(len) => {
if filelen < len {
filelen
} else {
len
}
}
None => filelen,
};

const CHUNKLEN: usize = 10485760; // 10M

let mut freader = BufReader::new(file);
let mut chunklen: usize;

freader.seek(SeekFrom::Start(0)).context("failed to seek(0)".to_string())?;
while maxlen_to_read > 0 {
if maxlen_to_read < CHUNKLEN {
chunklen = maxlen_to_read;
} else {
chunklen = CHUNKLEN;
}

let mut databuf = vec![0u8; chunklen];

freader.read_exact(&mut databuf).context(format!("failed to read_exact(chunklen {:?})", chunklen))?;

maxlen_to_read -= chunklen;

hasher.update(&databuf);
}

Ok(omaha::Hash::from_bytes(
hasher.finalize().into()
))
hash_on_disk_sha256(path, maxlen)
}

#[rustfmt::skip]
@@ -146,7 +107,7 @@ impl<'a> Package<'a>
let path = into_dir.join(&*self.name);
let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?;

let res = match ue_rs::download_and_hash(&client, self.url.clone(), &mut file).await {
let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file).await {
Ok(ok) => ok,
Err(err) => {
error!("Downloading failed with error {}", err);
@@ -188,7 +149,7 @@ impl<'a> Package<'a> {
let sigbytes = delta_update::get_signatures_bytes(upfreader, &header, &mut delta_archive_manifest).context(format!("failed to get_signatures_bytes path ({:?})", from_path.display()))?;

// tmp dir == "/var/tmp/outdir/.tmp"
let tmpdirpathbuf = from_path.parent().unwrap().parent().unwrap().join(".tmp");
let tmpdirpathbuf = from_path.parent().ok_or(anyhow!("unable to get parent dir"))?.parent().ok_or(anyhow!("unable to get parent dir"))?.join(".tmp");
let tmpdir = tmpdirpathbuf.as_path();
let datablobspath = tmpdir.join("ue_data_blobs");

@@ -281,6 +242,44 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet)
Ok(to_download)
}

// Read data from remote URL into File
async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client) -> Result<Package<'a>>
where
U: reqwest::IntoUrl + From<U> + std::clone::Clone + std::fmt::Debug,
Url: From<U>,
{
let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?;

ue_rs::download_and_hash(client, input_url.clone(), &mut file).await.context(format!("unable to download data(url {:?})", input_url))?;

Ok(Package {
name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")),
hash: hash_on_disk_sha256(path, None)?,
size: FileSize::from_bytes(file.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize),
url: input_url.into(),
status: PackageStatus::Unverified,
})
}

async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client) -> Result<()> {
pkg.check_download(unverified_dir)?;

pkg.download(unverified_dir, client).await.context(format!("unable to download \"{:?}\"", pkg.name))?;

// Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz".
// Verified payload is stored in e.g. "output_dir/oem.raw".
let pkg_unverified = unverified_dir.join(&*pkg.name);
let pkg_verified = output_dir.join(pkg_unverified.with_extension("raw").file_name().unwrap_or_default());

let datablobspath = pkg.verify_signature_on_disk(&pkg_unverified, pubkey_file).context(format!("unable to verify signature \"{}\"", pkg.name))?;

// write extracted data into the final data.
debug!("data blobs written into file {:?}", pkg_verified);
fs::rename(datablobspath, pkg_verified)?;

Ok(())
}

#[derive(FromArgs, Debug)]
/// Parse an update-engine Omaha XML response to extract sysext images, then download and verify
/// their signatures.
@@ -291,7 +290,11 @@ struct Args {

/// path to the Omaha XML file, or - to read from stdin
#[argh(option, short = 'i')]
input_xml: String,
input_xml: Option<String>,

/// URL to fetch remote update payload
#[argh(option, short = 'u')]
payload_url: Option<String>,

/// path to the public key file
#[argh(option, short = 'p')]
@@ -324,14 +327,6 @@ async fn main() -> Result<(), Box<dyn Error>> {

let glob_set = args.image_match_glob_set()?;

let response_text = match &*args.input_xml {
"-" => io::read_to_string(io::stdin())?,
path => {
let file = File::open(path)?;
io::read_to_string(file)?
}
};

let output_dir = Path::new(&*args.output_dir);
if !output_dir.try_exists()? {
return Err(format!("output directory `{}` does not exist", args.output_dir).into());
@@ -342,6 +337,58 @@ async fn main() -> Result<(), Box<dyn Error>> {
fs::create_dir_all(&unverified_dir)?;
fs::create_dir_all(&temp_dir)?;

// The default policy of reqwest Client supports max 10 attempts on HTTP redirect.
let client = Client::builder().redirect(Policy::default()).build()?;

// If input_xml exists, simply read it.
// If not, try to read from payload_url.
let res_local = match args.input_xml {
Some(name) => {
if name == "-" {
Some(io::read_to_string(io::stdin())?)
} else {
let file = File::open(name)?;
Some(io::read_to_string(file)?)
}
}
None => None,
};

match (&res_local, args.payload_url) {
(Some(_), Some(_)) => {
return Err("Only one of the options can be given, --input-xml or --payload-url.".into());
}
(Some(res), None) => res,
(None, Some(url)) => {
let u = Url::parse(&url)?;
let fname = u.path_segments().ok_or(anyhow!("failed to get path segments, url ({:?})", u))?.next_back().ok_or(anyhow!("failed to get path segments, url ({:?})", u))?;
let mut pkg_fake: Package;

let temp_payload_path = unverified_dir.join(fname);
pkg_fake = fetch_url_to_file(
&temp_payload_path,
Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?,
&client,
)
.await?;
do_download_verify(
&mut pkg_fake,
output_dir,
unverified_dir.as_path(),
args.pubkey_file.as_str(),
&client,
)
.await?;

// verify only a fake package, early exit and skip the rest.
return Ok(());
}
(None, None) => return Err("Either --input-xml or --payload-url must be given.".into()),
};

let response_text = res_local.ok_or(anyhow!("failed to get response text"))?;
debug!("response_text: {:?}", response_text);
////
// parse response
////
@@ -355,29 +402,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
////
// download
////
let client = reqwest::Client::new();

for pkg in pkgs_to_dl.iter_mut() {
pkg.check_download(&unverified_dir)?;

match pkg.download(&unverified_dir, &client).await {
Ok(_) => (),
_ => return Err(format!("unable to download \"{}\"", pkg.name).into()),
};

// Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz".
// Verified payload is stored in e.g. "output_dir/oem.raw".
let pkg_unverified = unverified_dir.join(&*pkg.name);
let pkg_verified = output_dir.join(pkg_unverified.with_extension("raw").file_name().unwrap_or_default());

match pkg.verify_signature_on_disk(&pkg_unverified, &args.pubkey_file) {
Ok(datablobspath) => {
// write extracted data into the final data.
fs::rename(datablobspath, pkg_verified.clone())?;
debug!("data blobs written into file {:?}", pkg_verified);
}
_ => return Err(format!("unable to verify signature \"{}\"", pkg.name).into()),
};
do_download_verify(pkg, output_dir, unverified_dir.as_path(), args.pubkey_file.as_str(), &client).await?;
}

// clean up data
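With this change the CLI takes `--input-xml` and `--payload-url` as optional, mutually exclusive sources. A condensed sketch of that selection logic follows; the `Source` enum and `pick_source` helper are hypothetical names for illustration, not code from the PR.

```rust
use anyhow::{Result, bail};

enum Source {
    /// Path to a local Omaha XML response, or "-" for stdin.
    Xml(String),
    /// Remote update payload to fetch directly.
    PayloadUrl(String),
}

fn pick_source(input_xml: Option<String>, payload_url: Option<String>) -> Result<Source> {
    match (input_xml, payload_url) {
        (Some(_), Some(_)) => bail!("only one of --input-xml or --payload-url may be given"),
        (Some(path), None) => Ok(Source::Xml(path)),
        (None, Some(url)) => Ok(Source::PayloadUrl(url)),
        (None, None) => bail!("either --input-xml or --payload-url must be given"),
    }
}

fn main() -> Result<()> {
    // Example invocation: payload URL given, XML omitted.
    match pick_source(None, Some("https://example.com/oem.gz".to_string()))? {
        Source::Xml(path) => println!("reading Omaha response from {path}"),
        Source::PayloadUrl(url) => println!("fetching payload directly from {url}"),
    }
    Ok(())
}
```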
61 changes: 54 additions & 7 deletions src/download.rs
@@ -1,7 +1,10 @@
use anyhow::{Context, Result, bail};
use std::io::Write;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::io;
use log::warn;
use std::fs::File;
use std::path::Path;
use log::info;
use url::Url;

use reqwest::StatusCode;

@@ -12,10 +15,53 @@ pub struct DownloadResult<W: std::io::Write> {
pub data: W,
}

pub fn hash_on_disk_sha256(path: &Path, maxlen: Option<usize>) -> Result<omaha::Hash<omaha::Sha256>> {
let file = File::open(path).context(format!("failed to open path({:?})", path.display()))?;
let mut hasher = Sha256::new();

let filelen = file.metadata().context(format!("failed to get metadata of {:?}", path.display()))?.len() as usize;

let mut maxlen_to_read: usize = match maxlen {
Some(len) => {
if filelen < len {
filelen
} else {
len
}
}
None => filelen,
};

const CHUNKLEN: usize = 10485760; // 10M

let mut freader = BufReader::new(file);
let mut chunklen: usize;

freader.seek(SeekFrom::Start(0)).context("failed to seek(0)".to_string())?;
while maxlen_to_read > 0 {
if maxlen_to_read < CHUNKLEN {
chunklen = maxlen_to_read;
} else {
chunklen = CHUNKLEN;
}

let mut databuf = vec![0u8; chunklen];

freader.read_exact(&mut databuf).context(format!("failed to read_exact(chunklen {:?})", chunklen))?;

maxlen_to_read -= chunklen;

hasher.update(&databuf);
}

Ok(omaha::Hash::from_bytes(hasher.finalize().into()))
}

pub async fn download_and_hash<U, W>(client: &reqwest::Client, url: U, mut data: W) -> Result<DownloadResult<W>>
where
U: reqwest::IntoUrl + Clone,
W: io::Write,
Url: From<U>,
{
let client_url = url.clone();

@@ -25,14 +71,15 @@ where
.await
.context(format!("client get and send({:?}) failed", client_url.as_str()))?;

// Redirect was already handled at this point, so there is no need to touch
// response or url again. Simply print info and continue.
if <U as Into<Url>>::into(client_url) != *res.url() {
info!("redirected to URL {:?}", res.url());
}

// Return immediately on download failure on the client side.
let status = res.status();

// TODO: handle redirect with retrying with a new URL or Attempt follow.
if status.is_redirection() {
warn!("redirect with status code {:?}", status);
}

if !status.is_success() {
match status {
StatusCode::FORBIDDEN | StatusCode::NOT_FOUND => {
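Because the client follows redirects itself, `download_and_hash` only needs to compare the URL it requested with `Response::url()` to notice that a redirect happened. A standalone sketch of that check with a hypothetical URL; reqwest resolves redirects before `send()` returns, so the comparison sees the final URL.

```rust
use std::error::Error;
use url::Url;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let client = reqwest::Client::new();

    let requested = Url::parse("http://example.com/payload.gz")?;
    let res = client.get(requested.clone()).send().await?;

    // res.url() is the URL the response finally came from; if the server
    // redirected us, it differs from the URL we originally asked for.
    if requested != *res.url() {
        println!("redirected to {}", res.url());
    }

    println!("status: {}", res.status());
    Ok(())
}
```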
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -1,4 +1,6 @@
mod download;
pub use download::DownloadResult;
pub use download::download_and_hash;
pub use download::hash_on_disk_sha256;

pub mod request;
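With `hash_on_disk_sha256` exported from the crate root, callers outside src/download.rs (such as the `Package` methods above) can hash a whole file or just a prefix of it. A small usage sketch; the path is hypothetical.

```rust
use std::path::Path;

use anyhow::Result;
use ue_rs::hash_on_disk_sha256;

fn main() -> Result<()> {
    // Hypothetical payload location.
    let path = Path::new("/var/tmp/outdir/.unverified/oem.gz");

    // None hashes the entire file ...
    let full = hash_on_disk_sha256(path, None)?;

    // ... while Some(n) hashes at most the first n bytes (capped at the file length).
    let prefix = hash_on_disk_sha256(path, Some(4096))?;

    // Both values are omaha::Hash<omaha::Sha256>, ready to compare against the
    // expected hashes from an Omaha response.
    let _ = (full, prefix);
    Ok(())
}
```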