diff --git a/src/lib.rs b/src/lib.rs index e23ada4..ed34492 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ use std::str::FromStr; -use std::sync::{Arc, LazyLock}; +use std::sync::{Arc, LazyLock, Mutex}; use std::time::Duration; use ahash::RandomState; @@ -11,7 +11,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; use pythonize::depythonize; use rquest::{ - header::{HeaderMap, HeaderName, HeaderValue, COOKIE}, + header::{HeaderValue, COOKIE}, multipart, redirect::Policy, tls::Impersonate, @@ -23,6 +23,9 @@ use tokio::runtime::{self, Runtime}; mod response; use response::Response; +mod traits; +use traits::HeadersTraits; + mod utils; use utils::load_ca_certs; @@ -37,10 +40,14 @@ static RUNTIME: LazyLock = LazyLock::new(|| { #[pyclass] /// HTTP client that can impersonate web browsers. pub struct Client { - client: Arc, + client: Arc>, + #[pyo3(get, set)] auth: Option<(String, Option)>, + #[pyo3(get, set)] auth_bearer: Option, + #[pyo3(get, set)] params: Option>, + #[pyo3(get)] cookies: Option>, } @@ -133,15 +140,7 @@ impl Client { // Headers if let Some(headers) = headers { - let headers_new = headers - .iter() - .filter_map(|(k, v)| { - HeaderName::from_bytes(k.as_bytes()) - .ok() - .and_then(|name| HeaderValue::from_str(v).ok().map(|value| (name, value))) - }) - .collect::(); - client_builder = client_builder.default_headers(headers_new); + client_builder = client_builder.default_headers(headers.to_headermap()); } // Cookie_store @@ -195,7 +194,7 @@ impl Client { client_builder = client_builder.http2_only(); } - let client = Arc::new(client_builder.build()?); + let client = Arc::new(Mutex::new(client_builder.build()?)); Ok(Client { client, @@ -206,6 +205,28 @@ impl Client { }) } + #[getter] + pub fn get_headers(&self) -> Result> { + let headers = self.client.lock().unwrap().headers_mut().to_indexmap(); + Ok(headers) + } + + #[setter] + pub fn set_headers( + &self, + new_headers: Option>, + ) -> Result<()> { + let mut client = self.client.lock().unwrap(); + let headers = client.headers_mut(); + headers.clear(); + if let Some(new_headers) = new_headers { + for (k, v) in new_headers { + headers.insert_key_value(k, v)? + } + } + Ok(()) + } + /// Constructs an HTTP request with the given method, URL, and optionally sets a timeout, headers, and query parameters. /// Sends the request and returns a `Response` object containing the server's response. /// @@ -274,7 +295,7 @@ impl Client { let future = async move { // Create request builder - let mut request_builder = client.request(method, url); + let mut request_builder = client.lock().unwrap().request(method, url); // Params if let Some(params) = params { @@ -283,15 +304,7 @@ impl Client { // Headers if let Some(headers) = headers { - let headers_new = headers - .iter() - .filter_map(|(k, v)| { - HeaderName::from_bytes(k.as_bytes()).ok().and_then(|name| { - HeaderValue::from_str(v).ok().map(|value| (name, value)) - }) - }) - .collect::(); - request_builder = request_builder.headers(headers_new); + request_builder = request_builder.headers(headers.to_headermap()); } // Cookies @@ -351,11 +364,7 @@ impl Client { .cookies() .map(|cookie| (cookie.name().to_string(), cookie.value().to_string())) .collect(); - let headers: IndexMap = resp - .headers() - .iter() - .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); + let headers: IndexMap = resp.headers().to_indexmap(); let status_code = resp.status().as_u16(); let url = resp.url().to_string(); let buf = resp.bytes().await?;