diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 24ae3776a217b..e2a97acb5e2a7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -786,7 +786,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -802,23 +802,38 @@ def trigger(self, processingTime=None, once=None): >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..25483594f2725 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1538,6 +1538,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds')