fix: add separator option to cgi api
AdamGold committed Jan 25, 2021
1 parent 93730b7 commit 8e1e361
Showing 3 changed files with 37 additions and 17 deletions.
23 changes: 14 additions & 9 deletions Lib/cgi.py
@@ -115,7 +115,8 @@ def closelog():
 # 0 ==> unlimited input
 maxlen = 0
 
-def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0):
+def parse(fp=None, environ=os.environ, keep_blank_values=0,
+          strict_parsing=0, separator='&'):
     """Parse a query in the environment or from a file (default stdin)
 
         Arguments, all optional:
@@ -134,6 +135,9 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0):
         strict_parsing: flag indicating what to do with parsing errors.
             If false (the default), errors are silently ignored.
             If true, errors raise a ValueError exception.
+
+        separator: str. The symbol to use for separating the query arguments.
+            Defaults to &.
     """
     if fp is None:
         fp = sys.stdin
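For context, a minimal sketch of how the new argument could be exercised once this patch is applied, on an interpreter that still ships the cgi module; the environ dict and query string below are illustrative, not taken from the patch:

import cgi

# Hypothetical CGI environment; REQUEST_METHOD and QUERY_STRING are standard CGI variables.
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'a=1;b=2;b=3'}

# With the default separator ('&'), the ';' characters stay inside a single value.
print(cgi.parse(environ=environ))                   # {'a': ['1;b=2;b=3']}

# Splitting on ';' is now an explicit opt-in.
print(cgi.parse(environ=environ, separator=';'))    # {'a': ['1'], 'b': ['2', '3']}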
@@ -154,7 +158,7 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0):
     if environ['REQUEST_METHOD'] == 'POST':
         ctype, pdict = parse_header(environ['CONTENT_TYPE'])
         if ctype == 'multipart/form-data':
-            return parse_multipart(fp, pdict)
+            return parse_multipart(fp, pdict, separator=separator)
         elif ctype == 'application/x-www-form-urlencoded':
             clength = int(environ['CONTENT_LENGTH'])
             if maxlen and clength > maxlen:
@@ -178,10 +182,10 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0):
             qs = ""
         environ['QUERY_STRING'] = qs    # XXX Shouldn't, really
     return urllib.parse.parse_qs(qs, keep_blank_values, strict_parsing,
-                                 encoding=encoding)
+                                 encoding=encoding, separator=separator)
 
 
-def parse_multipart(fp, pdict, encoding="utf-8", errors="replace"):
+def parse_multipart(fp, pdict, encoding="utf-8", errors="replace", separator='&'):
     """Parse multipart input.
 
     Arguments:
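Since cgi.parse() ultimately defers to urllib.parse.parse_qs(), the effect of the new default can be sketched directly against that function; the output shown assumes an interpreter carrying the separator support on this branch:

import urllib.parse

# Default separator '&': the ';' is kept inside the value rather than splitting fields.
print(urllib.parse.parse_qs("a=1;a=2"))                  # {'a': ['1;a=2']}

# Splitting on ';' now requires an explicit opt-in.
print(urllib.parse.parse_qs("a=1;a=2", separator=';'))   # {'a': ['1', '2']}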
@@ -205,7 +209,7 @@ def parse_multipart(fp, pdict, encoding="utf-8", errors="replace"):
     except KeyError:
         pass
     fs = FieldStorage(fp, headers=headers, encoding=encoding, errors=errors,
-        environ={'REQUEST_METHOD': 'POST'})
+        environ={'REQUEST_METHOD': 'POST'}, separator=separator)
     return {k: fs.getlist(k) for k in fs}
 
 def _parseparam(s):
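A rough sketch of calling parse_multipart() with the new keyword; the boundary, field name, and body are made up for illustration. The separator only really matters when the enclosing request also carries urlencoded data; here it just shows the argument being threaded through to FieldStorage:

import cgi
import io

# Hypothetical multipart/form-data body with a single text field.
body = (b"--BOUNDARY\r\n"
        b'Content-Disposition: form-data; name="title"\r\n'
        b"\r\n"
        b"hello world\r\n"
        b"--BOUNDARY--\r\n")

fields = cgi.parse_multipart(
    io.BytesIO(body),
    {"boundary": b"BOUNDARY", "CONTENT-LENGTH": str(len(body))},
    separator="&",   # forwarded to the FieldStorage built inside parse_multipart
)
print(fields)   # expected: {'title': ['hello world']}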
@@ -315,7 +319,7 @@ class FieldStorage:
     def __init__(self, fp=None, headers=None, outerboundary=b'',
                  environ=os.environ, keep_blank_values=0, strict_parsing=0,
                  limit=None, encoding='utf-8', errors='replace',
-                 max_num_fields=None):
+                 max_num_fields=None, separator='&'):
         """Constructor.  Read multipart/* until last part.
 
         Arguments, all optional:
@@ -363,6 +367,7 @@ def __init__(self, fp=None, headers=None, outerboundary=b'',
         self.keep_blank_values = keep_blank_values
         self.strict_parsing = strict_parsing
         self.max_num_fields = max_num_fields
+        self.separator = separator
         if 'REQUEST_METHOD' in environ:
             method = environ['REQUEST_METHOD'].upper()
         self.qs_on_post = None
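A small sketch of the FieldStorage path with an urlencoded POST body, assuming the patched module; the body and field names are invented for illustration:

import cgi
import io

# Hypothetical urlencoded POST body that uses ';' between fields.
body = b"a=1;a=2;b=3"
environ = {
    "REQUEST_METHOD": "POST",
    "CONTENT_TYPE": "application/x-www-form-urlencoded",
    "CONTENT_LENGTH": str(len(body)),
}

fs = cgi.FieldStorage(fp=io.BytesIO(body), environ=environ, separator=";")
print(fs.getlist("a"), fs.getfirst("b"))   # expected: ['1', '2'] 3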
@@ -589,7 +594,7 @@ def read_urlencoded(self):
         query = urllib.parse.parse_qsl(
             qs, self.keep_blank_values, self.strict_parsing,
             encoding=self.encoding, errors=self.errors,
-            max_num_fields=self.max_num_fields)
+            max_num_fields=self.max_num_fields, separator=self.separator)
         self.list = [MiniFieldStorage(key, value) for key, value in query]
         self.skip_lines()
 
@@ -605,7 +610,7 @@ def read_multi(self, environ, keep_blank_values, strict_parsing):
             query = urllib.parse.parse_qsl(
                 self.qs_on_post, self.keep_blank_values, self.strict_parsing,
                 encoding=self.encoding, errors=self.errors,
-                max_num_fields=self.max_num_fields)
+                max_num_fields=self.max_num_fields, separator=self.separator)
             self.list.extend(MiniFieldStorage(key, value) for key, value in query)
 
         klass = self.FieldStorageClass or self.__class__
@@ -649,7 +654,7 @@ def read_multi(self, environ, keep_blank_values, strict_parsing):
                 else self.limit - self.bytes_read
             part = klass(self.fp, headers, ib, environ, keep_blank_values,
                          strict_parsing, limit,
-                         self.encoding, self.errors, max_num_fields)
+                         self.encoding, self.errors, max_num_fields, self.separator)
 
             if max_num_fields is not None:
                 max_num_fields -= 1
27 changes: 19 additions & 8 deletions Lib/test/test_urlparse.py
@@ -32,6 +32,10 @@
     (b"&a=b", [(b'a', b'b')]),
     (b"a=a+b&b=b+c", [(b'a', b'a b'), (b'b', b'b c')]),
     (b"a=1&a=2", [(b'a', b'1'), (b'a', b'2')]),
+    (";a=b", [(';a', 'b')]),
+    ("a=a+b;b=b+c", [('a', 'a b;b=b c')]),
+    (b";a=b", [(b';a', b'b')]),
+    (b"a=a+b;b=b+c", [(b'a', b'a b;b=b c')]),
 ]
 
 # Each parse_qs testcase is a two-tuple that contains
@@ -58,6 +62,10 @@
     (b"&a=b", {b'a': [b'b']}),
     (b"a=a+b&b=b+c", {b'a': [b'a b'], b'b': [b'b c']}),
     (b"a=1&a=2", {b'a': [b'1', b'2']}),
+    (";a=b", {';a': ['b']}),
+    ("a=a+b;b=b+c", {'a': ['a b;b=b c']}),
+    (b";a=b", {b';a': [b'b']}),
+    (b"a=a+b;b=b+c", {b'a':[ b'a b;b=b c']}),
 ]
 
 class UrlParseTestCase(unittest.TestCase):
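Illustrating the default-separator behaviour these new cases pin down; the inputs and expected values mirror the vectors above, assuming a build with the separator change in urllib.parse:

import urllib.parse

# With the default separator, ';' no longer splits fields; it survives as data.
print(urllib.parse.parse_qsl(";a=b"))          # [(';a', 'b')]
print(urllib.parse.parse_qsl("a=a+b;b=b+c"))   # [('a', 'a b;b=b c')]
print(urllib.parse.parse_qs("a=a+b;b=b+c"))    # {'a': ['a b;b=b c']}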
@@ -869,7 +877,7 @@ def test_parse_qsl_max_num_fields(self):
             urllib.parse.parse_qs('&'.join(['a=a']*10), max_num_fields=10)
 
     def test_parse_qs_separator(self):
-        semicolon_cases = [
+        parse_qs_semicolon_cases = [
             (";", {}),
             (";;", {}),
             (";a=b", {'a': ['b']}),
@@ -881,13 +889,14 @@ def test_parse_qs_separator(self):
             (b"a=a+b;b=b+c", {b'a': [b'a b'], b'b': [b'b c']}),
             (b"a=1;a=2", {b'a': [b'1', b'2']}),
         ]
-        for orig, expect in semicolon_cases:
-            result = urllib.parse.parse_qs(orig, separator=';')
-            self.assertEqual(result, expect, "Error parsing %r" % orig)
+        for orig, expect in parse_qs_semicolon_cases:
+            with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
+                result = urllib.parse.parse_qs(orig, separator=';')
+                self.assertEqual(result, expect, "Error parsing %r" % orig)
 
 
     def test_parse_qsl_separator(self):
-        semicolon_cases = [
+        parse_qsl_semicolon_cases = [
             (";", []),
             (";;", []),
             (";a=b", [('a', 'b')]),
@@ -899,9 +908,11 @@ def test_parse_qsl_separator(self):
             (b"a=a+b;b=b+c", [(b'a', b'a b'), (b'b', b'b c')]),
             (b"a=1;a=2", [(b'a', b'1'), (b'a', b'2')]),
         ]
-        for orig, expect in semicolon_cases:
-            result = urllib.parse.parse_qsl(orig, separator=';')
-            self.assertEqual(result, expect, "Error parsing %r" % orig)
+        for orig, expect in parse_qsl_semicolon_cases:
+            with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
+                result = urllib.parse.parse_qsl(orig, separator=';')
+                self.assertEqual(result, expect, "Error parsing %r" % orig)
 
 
     def test_urlencode_sequences(self):
         # Other tests incidentally urlencode things; test non-covered cases:
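For reference, one possible way to run just the two affected tests, assuming a patched checkout where Lib/test is importable as the test package:

import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames([
    "test.test_urlparse.UrlParseTestCase.test_parse_qs_separator",
    "test.test_urlparse.UrlParseTestCase.test_parse_qsl_separator",
])
unittest.TextTestRunner(verbosity=2).run(suite)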
4 changes: 4 additions & 0 deletions Lib/urllib/parse.py
@@ -734,6 +734,10 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
     """
     qs, _coerce_result = _coerce_args(qs)
 
+    if not separator or (not isinstance(separator, str)
+                         and not isinstance(separator, bytes)):
+        raise ValueError("Separator must be of type string or bytes.")
+
     # If max_num_fields is defined then check that the number of fields
     # is less than max_num_fields. This prevents a memory exhaustion DOS
     # attack via post bodies with many fields.
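A quick sketch of the guard added above; passing anything other than a non-empty str or bytes separator is expected to raise:

import urllib.parse

print(urllib.parse.parse_qsl("a=1&b=2", separator="&"))    # [('a', '1'), ('b', '2')]

try:
    urllib.parse.parse_qsl("a=1&b=2", separator=None)      # rejected by the new check
except ValueError as exc:
    print(exc)   # Separator must be of type string or bytes.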
