Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String generation from complex regex in integration tests #8594

Merged
merged 12 commits into from
Jul 17, 2023
59 changes: 31 additions & 28 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,13 @@ def modify():
_MAX_CHOICES = 1 << 64
class StringGen(DataGen):
"""Generate strings that match a pattern"""
def __init__(self, pattern="(.|\n){1,30}", flags=0, charset=sre_yield.CHARSET, nullable=True):
def __init__(self, pattern=None, flags=0, charset=sre_yield.CHARSET, nullable=True):
super().__init__(StringType(), nullable=nullable)
self.base_strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES)
# save pattern and charset for cache repr
charsetrepr = '[' + ','.join(charset) + ']' if charset != sre_yield.CHARSET else 'sre_yield.CHARSET'
self.stringrepr = pattern + ',' + str(flags) + ',' + charsetrepr
self.stringrepr = str(pattern) + ',' + str(flags) + ',' + charsetrepr
self.pattern = pattern
self.flags = flags
self.charset = charset

def _cache_repr(self):
return super()._cache_repr() + '(' + self.stringrepr + ')'
Expand All @@ -192,19 +193,19 @@ def with_special_pattern(self, pattern, flags=0, charset=sre_yield.CHARSET, weig
instead of a hard coded string value.
"""
strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES)
try:
length = int(len(strs))
except OverflowError:
length = _MAX_CHOICES
return self.with_special_case(lambda rand : strs[rand.randrange(0, length)], weight=weight)
length = strs.__len__()
return self.with_special_case(lambda rand : strs[rand.randint(0, length-1)], weight=weight)

def start(self, rand):
strs = self.base_strs
try:
length = int(len(strs))
except OverflowError:
length = _MAX_CHOICES
self._start(rand, lambda : strs[rand.randrange(0, length)])
if self.pattern == None and self.charset == sre_yield.CHARSET:
def gen_default_str():
# use "(.|\n){1,30}" as default pattern
return ''.join(rand.choice(sre_yield.CHARSET + ['\n']) for _ in range(30))
self._start(rand, gen_default_str)
else:
strs = sre_yield.AllStrings(self.pattern, flags=self.flags, charset=self.charset, max_count=_MAX_CHOICES)
length = strs.__len__()
self._start(rand, lambda : strs[rand.randint(0, length-1)])

BYTE_MIN = -(1 << 7)
BYTE_MAX = (1 << 7) - 1
Expand Down Expand Up @@ -269,23 +270,25 @@ def __init__(self, precision=None, scale=None, nullable=True, special_cases=None
super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases)
self.scale = scale
self.precision = precision
negative_pattern = "-" if avoid_positive_values else "-?"
self.pattern = negative_pattern + "[0-9]{1,"+ str(precision) + "}e" + str(-scale)
self.base_strs = sre_yield.AllStrings(self.pattern, flags=0, charset=sre_yield.CHARSET, max_count=_MAX_CHOICES)
self.avoid_positive_values = avoid_positive_values

def __repr__(self):
return super().__repr__() + '(' + str(self.precision) + ',' + str(self.scale) + ')'

def _cache_repr(self):
return super()._cache_repr() + '(' + self.pattern + ')'
return super()._cache_repr() + '(' + str(self.precision) + ',' + str(self.scale) + ',' + str(self.avoid_positive_values) + ')'

def start(self, rand):
strs = self.base_strs
try:
length = int(strs.length)
except OverflowError:
length = _MAX_CHOICES
self._start(rand, lambda : Decimal(strs[rand.randrange(0, length)]))
def random_decimal(rand):
if self.avoid_positive_values:
sign = "-"
else:
sign = rand.choice(["-", ""])
int_part = "".join([rand.choice("0123456789") for _ in range(self.precision)])
result = f"{sign}{int_part}e{str(-self.scale)}"
return Decimal(result)

self._start(rand, lambda : random_decimal(rand))

LONG_MIN = -(1 << 63)
LONG_MAX = (1 << 63) - 1
Expand Down Expand Up @@ -317,7 +320,7 @@ def next_val(self):
return self._current_val

def _cache_repr(self):
return super()._cache_repr() + '(' + str(self._current_val) + ')'
return super()._cache_repr()

def start(self, rand):
self._current_val = 0
Expand All @@ -337,7 +340,7 @@ def __repr__(self):
return super().__repr__() + '(' + str(self._child) + ')'

def _cache_repr(self):
return super()._cache_repr() + '(' + self._child._cache_repr() + ',' + str(self._length) + str(self._index) + ')'
return super()._cache_repr() + '(' + self._child._cache_repr() + ',' + str(self._length) + ')'

def _loop_values(self):
ret = self._vals[self._index]
Expand Down Expand Up @@ -944,7 +947,7 @@ def get_null_lit_string(spark_type):

def _convert_to_sql(spark_type, data):
if isinstance(data, str):
d = "'" + data.replace("'", "\\'") + "'"
d = "'" + data.replace("\\", "\\\\").replace("\'", "\\\'") + "'"
elif isinstance(data, datetime):
d = "'" + data.strftime('%Y-%m-%d T%H:%M:%S.%f').zfill(26) + "'"
elif isinstance(data, date):
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/get_json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def mk_json_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"bicycle":\{"price":[1-9]\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
def test_get_json_object(json_str_pattern):
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/json_tuple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def mk_json_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

json_str_patterns = [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"bicycle":\{"price":[1-9]\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
r'\{"a": "[a-z]{1,3}", "b\$":"[b-z]{1,3}"\}']

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def test_regexp_extract_all_idx_negative():

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_out_of_bounds():
gen = mk_str_gen('[abcd]{0,3}')
gen = StringGen('.{0,10}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change being made? The original regexp was made to match closely with the regexp_extract_all pattern below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reply.

The test fails after this PR because the error message "Regex group count is 2, but the specified group index is 3" is only raised when the data matches the pattern.

In the previous code, what actually triggered this error message was the '.{0,10}' special case in mk_str_gen, so the test passed only by luck.

I have now changed the pattern to '[a-d]{1,2}.{0,1}[0-9]{1,2}' to make sure the generated strings can match the pattern below.

assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract_all(a, "([a-d]+).*([0-9])", 3)'
Expand Down