Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated get_connection deprecated method to avoid warning #5

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 48 additions & 51 deletions requests_to_curl/requests_to_curl.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,50 @@
# coding: utf-8
import sys
from copy import deepcopy

from copy import deepcopy
import requests

if sys.version_info.major >= 3:
# Use `shlex.quote` for Python 3, and `pipes.quote` for Python 2
try:
from shlex import quote
else:
except ImportError:
from pipes import quote


HEADER_BLOCKLIST = {
"Content-Length",
}


def parse(
request_or_response, compressed=False, verify=True, return_it=False, print_it=True
):
def parse(request_or_response, compressed=False, verify=True, return_it=False, print_it=True):
"""
Args:
request_or_response: requests.models.Request|requests.models.Response
return_it: False=return None. True=return the string
print_it: False=Do Nothing. True=Print parsed string to stdout.
request_or_response: requests.models.Request | requests.models.Response
return_it: bool - If True, return the generated curl string; otherwise, return None.
print_it: bool - If True, print the generated curl string to stdout.

Returns:
str or None: The generated curl command if return_it is True; otherwise, None.

Print:
curl command
Prints:
The generated curl command if print_it is True.
"""

def _build_url(connection_pool, request):
scheme = connection_pool.scheme if connection_pool.scheme in ("http", "https") else "http"
host = "[{}]".format(connection_pool.host) if ":" in connection_pool.host else connection_pool.host
return "{}://{}:{}{}".format(scheme, host, connection_pool.port, request.path_url)

if isinstance(request_or_response, requests.models.Response):
request = deepcopy(request_or_response.request)
connection_pool = request_or_response.connection.get_connection(request.url)
http_scheme = connection_pool.scheme
if http_scheme not in ("http", "https"):
http_scheme = "http"
host = connection_pool.host
is_ipv6 = ":" in host
if is_ipv6:
host = "[{}]".format(host)
request.url = "{scheme}://{host}:{port}{path_url}".format(
scheme=http_scheme,
host=host,
port=connection_pool.port,
path_url=request.path_url,
)
elif isinstance(
request_or_response, (requests.models.Request, requests.models.PreparedRequest)
):
connection_pool = request_or_response.connection.get_connection_with_tls_context(request, verify)
request.url = _build_url(connection_pool, request)
elif isinstance(request_or_response, (requests.models.Request, requests.models.PreparedRequest)):
request = deepcopy(request_or_response)
else:
raise Exception(
"`parse` needs a request or response, not {}".format(
type(request_or_response)
)
)
raise TypeError("`parse` needs a request or response, not {}".format(type(request_or_response).__name__))

curl_string = _parse_request(request=request, compressed=compressed, verify=verify)

if print_it:
print(curl_string)
if return_it:
Expand All @@ -62,24 +53,30 @@ def parse(

def _parse_request(request, compressed=False, verify=True):
parts = [("curl", None), ("-X", request.method)]
for k, v in sorted(request.headers.items()):
if k in HEADER_BLOCKLIST:
continue
parts += [("-H", "{0}: {1}".format(k, v))]

# Add headers, skipping blocklisted ones
headers = [
("-H", "{}: {}".format(k, v))
for k, v in sorted(request.headers.items())
if k not in HEADER_BLOCKLIST
]
parts.extend(headers)

# Add body if present
if request.body:
body = request.body
if isinstance(body, bytes):
body = body.decode("utf-8")
parts += [("-d", body)]
body = request.body.decode("utf-8") if isinstance(request.body, bytes) else request.body
parts.append(("-d", body))

# Add optional flags
if compressed:
parts += [("--compressed", None)]
parts.append(("--compressed", None))
if not verify:
parts += [("--insecure", None)]
parts += [(None, request.url)]
flat_parts = []
for k, v in parts:
if k:
flat_parts.append(quote(k))
if v:
flat_parts.append(quote(v))
parts.append(("--insecure", None))

# Add the request URL
parts.append((None, request.url))

# Flatten parts and return the command string
flat_parts = [quote(k) for k, v in parts if k] + [quote(v) for k, v in parts if v]

return " ".join(flat_parts)
Loading