Skip to content

Commit

Permalink
Fixes Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
XComp committed Sep 5, 2024
1 parent 8b9db44 commit d420659
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 28 deletions.
16 changes: 8 additions & 8 deletions flink-python/pyflink/common/restart_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from py4j.java_gateway import get_java_class

from pyflink.java_gateway import get_gateway
from pyflink.util.java_utils import to_j_flink_time, from_j_flink_time
from pyflink.util.java_utils import to_j_duration, from_j_duration

__all__ = ['RestartStrategies', 'RestartStrategyConfiguration']

Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self, restart_attempts=None, delay_between_attempts_interval=None,
self._j_restart_strategy_configuration = \
gateway.jvm.RestartStrategies\
.fixedDelayRestart(
restart_attempts, to_j_flink_time(delay_between_attempts_interval))
restart_attempts, to_j_duration(delay_between_attempts_interval))
super(RestartStrategies.FixedDelayRestartStrategyConfiguration, self)\
.__init__(self._j_restart_strategy_configuration)
else:
Expand All @@ -103,7 +103,7 @@ def get_restart_attempts(self) -> int:
return self._j_restart_strategy_configuration.getRestartAttempts()

def get_delay_between_attempts_interval(self) -> timedelta:
return from_j_flink_time(
return from_j_duration(
self._j_restart_strategy_configuration.getDelayBetweenAttemptsInterval())

class FailureRateRestartStrategyConfiguration(RestartStrategyConfiguration):
Expand All @@ -129,8 +129,8 @@ def __init__(self,
self._j_restart_strategy_configuration = \
gateway.jvm.RestartStrategies\
.FailureRateRestartStrategyConfiguration(max_failure_rate,
to_j_flink_time(failure_interval),
to_j_flink_time(
to_j_duration(failure_interval),
to_j_duration(
delay_between_attempts_interval))
super(RestartStrategies.FailureRateRestartStrategyConfiguration, self)\
.__init__(self._j_restart_strategy_configuration)
Expand All @@ -142,11 +142,11 @@ def get_max_failure_rate(self) -> int:
return self._j_restart_strategy_configuration.getMaxFailureRate()

def get_failure_interval(self) -> timedelta:
return from_j_flink_time(self._j_restart_strategy_configuration.getFailureInterval())
return from_j_duration(self._j_restart_strategy_configuration.getFailureInterval())

def get_delay_between_attempts_interval(self) -> timedelta:
return from_j_flink_time(self._j_restart_strategy_configuration
.getDelayBetweenAttemptsInterval())
return from_j_duration(self._j_restart_strategy_configuration
.getDelayBetweenAttemptsInterval())

class FallbackRestartStrategyConfiguration(RestartStrategyConfiguration):
"""
Expand Down
4 changes: 2 additions & 2 deletions flink-python/pyflink/fn_execution/embedded/java_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

# Java StateTtlConfig
JStateTtlConfig = findClass('org.apache.flink.api.common.state.StateTtlConfig')
JTime = findClass('org.apache.flink.api.common.time.Time')
JDuration = findClass('java.time.Duration')
JUpdateType = findClass('org.apache.flink.api.common.state.StateTtlConfig$UpdateType')
JStateVisibility = findClass('org.apache.flink.api.common.state.StateTtlConfig$StateVisibility')

Expand Down Expand Up @@ -140,7 +140,7 @@ def to_java_typeinfo(type_info: TypeInformation):

def to_java_state_ttl_config(ttl_config: StateTtlConfig):
j_ttl_config_builder = JStateTtlConfig.newBuilder(
JTime.milliseconds(ttl_config.get_ttl().to_milliseconds()))
JDuration.ofMillis(ttl_config.get_ttl().to_milliseconds()))

update_type = ttl_config.get_update_type()
if update_type == StateTtlConfig.UpdateType.Disabled:
Expand Down
6 changes: 3 additions & 3 deletions flink-python/pyflink/table/table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def set_idle_state_retention_time(self,
least 5 minutes greater than minTime. Set to
0 (zero) to never clean-up the state.
"""
j_time_class = get_gateway().jvm.org.apache.flink.api.common.time.Time
j_min_time = j_time_class.milliseconds(long(round(min_time.total_seconds() * 1000)))
j_max_time = j_time_class.milliseconds(long(round(max_time.total_seconds() * 1000)))
j_duration_class = get_gateway().jvm.java.time.Duration
j_min_time = j_duration_class.ofMillis(long(round(min_time.total_seconds() * 1000)))
j_max_time = j_duration_class.ofMillis(long(round(max_time.total_seconds() * 1000)))
self._j_table_config.setIdleStateRetentionTime(j_min_time, j_max_time)

def set_idle_state_retention(self, duration: datetime.timedelta):
Expand Down
16 changes: 7 additions & 9 deletions flink-python/pyflink/util/java_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,19 @@ def to_jarray(j_type, arr):
return j_arr


def to_j_flink_time(time_delta):
def to_j_duration(time_delta):
gateway = get_gateway()
TimeUnit = gateway.jvm.java.util.concurrent.TimeUnit
Time = gateway.jvm.org.apache.flink.api.common.time.Time
Duration = gateway.jvm.java.time.Duration
if isinstance(time_delta, timedelta):
total_microseconds = round(time_delta.total_seconds() * 1000 * 1000)
return Time.of(total_microseconds, TimeUnit.MICROSECONDS)
total_milliseconds = round(time_delta.total_seconds() * 1000)
else:
# time delta in milliseconds
total_milliseconds = time_delta
return Time.milliseconds(total_milliseconds)

return Duration.ofMillis(total_milliseconds)

def from_j_flink_time(j_flink_time):
total_milliseconds = j_flink_time.toMilliseconds()

def from_j_duration(j_duration):
total_milliseconds = j_duration.toMillis()
return timedelta(milliseconds=total_milliseconds)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
Expand All @@ -38,6 +37,7 @@
import com.google.protobuf.ByteString;
import org.apache.beam.model.pipeline.v1.RunnerApi;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -439,7 +439,7 @@ private static FlinkFnApi.CoderInfoDescriptor createCoderInfoDescriptorProto(
public static StateTtlConfig parseStateTtlConfigFromProto(
FlinkFnApi.StateDescriptor.StateTTLConfig stateTTLConfigProto) {
StateTtlConfig.Builder builder =
StateTtlConfig.newBuilder(Time.milliseconds(stateTTLConfigProto.getTtl()))
StateTtlConfig.newBuilder(Duration.ofMillis(stateTTLConfigProto.getTtl()))
.setUpdateType(
parseUpdateTypeFromProto(stateTTLConfigProto.getUpdateType()))
.setStateVisibility(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
package org.apache.flink.streaming.api.utils;

import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.python.util.ProtoUtils;

import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.concurrent.TimeUnit;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -82,7 +80,7 @@ void testParseStateTtlConfigFromProto() {
.build();
FlinkFnApi.StateDescriptor.StateTTLConfig stateTTLConfigProto =
FlinkFnApi.StateDescriptor.StateTTLConfig.newBuilder()
.setTtl(Time.of(1000, TimeUnit.MILLISECONDS).toMilliseconds())
.setTtl(1000)
.setUpdateType(
FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType
.OnCreateAndWrite)
Expand All @@ -99,7 +97,7 @@ void testParseStateTtlConfigFromProto() {
.isEqualTo(StateTtlConfig.UpdateType.OnCreateAndWrite);
assertThat(stateTTLConfig.getStateVisibility())
.isEqualTo(StateTtlConfig.StateVisibility.NeverReturnExpired);
assertThat(stateTTLConfig.getTtl()).isEqualTo(Time.milliseconds(1000));
assertThat(stateTTLConfig.getTimeToLive()).isEqualTo(Duration.ofMillis(1000));
assertThat(stateTTLConfig.getTtlTimeCharacteristic())
.isEqualTo(StateTtlConfig.TtlTimeCharacteristic.ProcessingTime);

Expand Down

0 comments on commit d420659

Please sign in to comment.