diff --git a/Lib/cgi.py b/Lib/cgi.py index 6018c3608697af..6c72507c2087de 100755 --- a/Lib/cgi.py +++ b/Lib/cgi.py @@ -115,7 +115,8 @@ def closelog(): # 0 ==> unlimited input maxlen = 0 -def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0): +def parse(fp=None, environ=os.environ, keep_blank_values=0, + strict_parsing=0, separator='&'): """Parse a query in the environment or from a file (default stdin) Arguments, all optional: @@ -134,6 +135,9 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0): strict_parsing: flag indicating what to do with parsing errors. If false (the default), errors are silently ignored. If true, errors raise a ValueError exception. + + separator: str. The symbol to use for separating the query arguments. + Defaults to &. """ if fp is None: fp = sys.stdin @@ -154,7 +158,7 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0): if environ['REQUEST_METHOD'] == 'POST': ctype, pdict = parse_header(environ['CONTENT_TYPE']) if ctype == 'multipart/form-data': - return parse_multipart(fp, pdict) + return parse_multipart(fp, pdict, separator=separator) elif ctype == 'application/x-www-form-urlencoded': clength = int(environ['CONTENT_LENGTH']) if maxlen and clength > maxlen: @@ -178,10 +182,10 @@ def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0): qs = "" environ['QUERY_STRING'] = qs # XXX Shouldn't, really return urllib.parse.parse_qs(qs, keep_blank_values, strict_parsing, - encoding=encoding) + encoding=encoding, separator=separator) -def parse_multipart(fp, pdict, encoding="utf-8", errors="replace"): +def parse_multipart(fp, pdict, encoding="utf-8", errors="replace", separator='&'): """Parse multipart input. Arguments: @@ -205,7 +209,7 @@ def parse_multipart(fp, pdict, encoding="utf-8", errors="replace"): except KeyError: pass fs = FieldStorage(fp, headers=headers, encoding=encoding, errors=errors, - environ={'REQUEST_METHOD': 'POST'}) + environ={'REQUEST_METHOD': 'POST'}, separator=separator) return {k: fs.getlist(k) for k in fs} def _parseparam(s): @@ -315,7 +319,7 @@ class FieldStorage: def __init__(self, fp=None, headers=None, outerboundary=b'', environ=os.environ, keep_blank_values=0, strict_parsing=0, limit=None, encoding='utf-8', errors='replace', - max_num_fields=None): + max_num_fields=None, separator='&'): """Constructor. Read multipart/* until last part. Arguments, all optional: @@ -363,6 +367,7 @@ def __init__(self, fp=None, headers=None, outerboundary=b'', self.keep_blank_values = keep_blank_values self.strict_parsing = strict_parsing self.max_num_fields = max_num_fields + self.separator = separator if 'REQUEST_METHOD' in environ: method = environ['REQUEST_METHOD'].upper() self.qs_on_post = None @@ -589,7 +594,7 @@ def read_urlencoded(self): query = urllib.parse.parse_qsl( qs, self.keep_blank_values, self.strict_parsing, encoding=self.encoding, errors=self.errors, - max_num_fields=self.max_num_fields) + max_num_fields=self.max_num_fields, separator=self.separator) self.list = [MiniFieldStorage(key, value) for key, value in query] self.skip_lines() @@ -605,7 +610,7 @@ def read_multi(self, environ, keep_blank_values, strict_parsing): query = urllib.parse.parse_qsl( self.qs_on_post, self.keep_blank_values, self.strict_parsing, encoding=self.encoding, errors=self.errors, - max_num_fields=self.max_num_fields) + max_num_fields=self.max_num_fields, separator=self.separator) self.list.extend(MiniFieldStorage(key, value) for key, value in query) klass = self.FieldStorageClass or self.__class__ @@ -649,7 +654,7 @@ def read_multi(self, environ, keep_blank_values, strict_parsing): else self.limit - self.bytes_read part = klass(self.fp, headers, ib, environ, keep_blank_values, strict_parsing, limit, - self.encoding, self.errors, max_num_fields) + self.encoding, self.errors, max_num_fields, self.separator) if max_num_fields is not None: max_num_fields -= 1 diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py index 28cc3e1ba5fee4..3b1c360625b5a6 100644 --- a/Lib/test/test_urlparse.py +++ b/Lib/test/test_urlparse.py @@ -32,6 +32,10 @@ (b"&a=b", [(b'a', b'b')]), (b"a=a+b&b=b+c", [(b'a', b'a b'), (b'b', b'b c')]), (b"a=1&a=2", [(b'a', b'1'), (b'a', b'2')]), + (";a=b", [(';a', 'b')]), + ("a=a+b;b=b+c", [('a', 'a b;b=b c')]), + (b";a=b", [(b';a', b'b')]), + (b"a=a+b;b=b+c", [(b'a', b'a b;b=b c')]), ] # Each parse_qs testcase is a two-tuple that contains @@ -58,6 +62,10 @@ (b"&a=b", {b'a': [b'b']}), (b"a=a+b&b=b+c", {b'a': [b'a b'], b'b': [b'b c']}), (b"a=1&a=2", {b'a': [b'1', b'2']}), + (";a=b", {';a': ['b']}), + ("a=a+b;b=b+c", {'a': ['a b;b=b c']}), + (b";a=b", {b';a': [b'b']}), + (b"a=a+b;b=b+c", {b'a':[ b'a b;b=b c']}), ] class UrlParseTestCase(unittest.TestCase): @@ -869,7 +877,7 @@ def test_parse_qsl_max_num_fields(self): urllib.parse.parse_qs('&'.join(['a=a']*10), max_num_fields=10) def test_parse_qs_separator(self): - semicolon_cases = [ + parse_qs_semicolon_cases = [ (";", {}), (";;", {}), (";a=b", {'a': ['b']}), @@ -881,13 +889,14 @@ def test_parse_qs_separator(self): (b"a=a+b;b=b+c", {b'a': [b'a b'], b'b': [b'b c']}), (b"a=1;a=2", {b'a': [b'1', b'2']}), ] - for orig, expect in semicolon_cases: - result = urllib.parse.parse_qs(orig, separator=';') - self.assertEqual(result, expect, "Error parsing %r" % orig) + for orig, expect in parse_qs_semicolon_cases: + with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"): + result = urllib.parse.parse_qs(orig, separator=';') + self.assertEqual(result, expect, "Error parsing %r" % orig) def test_parse_qsl_separator(self): - semicolon_cases = [ + parse_qsl_semicolon_cases = [ (";", []), (";;", []), (";a=b", [('a', 'b')]), @@ -899,9 +908,11 @@ def test_parse_qsl_separator(self): (b"a=a+b;b=b+c", [(b'a', b'a b'), (b'b', b'b c')]), (b"a=1;a=2", [(b'a', b'1'), (b'a', b'2')]), ] - for orig, expect in semicolon_cases: - result = urllib.parse.parse_qsl(orig, separator=';') - self.assertEqual(result, expect, "Error parsing %r" % orig) + for orig, expect in parse_qsl_semicolon_cases: + with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"): + result = urllib.parse.parse_qsl(orig, separator=';') + self.assertEqual(result, expect, "Error parsing %r" % orig) + def test_urlencode_sequences(self): # Other tests incidentally urlencode things; test non-covered cases: diff --git a/Lib/urllib/parse.py b/Lib/urllib/parse.py index e07db8368616f2..5bd067895bfa3d 100644 --- a/Lib/urllib/parse.py +++ b/Lib/urllib/parse.py @@ -734,6 +734,10 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False, """ qs, _coerce_result = _coerce_args(qs) + if not separator or (not isinstance(separator, str) + and not isinstance(separator, bytes)): + raise ValueError("Separator must be of type string or bytes.") + # If max_num_fields is defined then check that the number of fields # is less than max_num_fields. This prevents a memory exhaustion DOS # attack via post bodies with many fields.