diff --git a/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java b/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java
index 2e3a0a5dde..6c63c0d7a9 100644
--- a/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java
+++ b/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java
@@ -89,13 +89,20 @@ public PDone expand(PCollection<FeatureRow> input) {
     switch (storeType) {
       case REDIS:
         RedisConfig redisConfig = getStore().getRedisConfig();
-        input
+        PCollection<FailedElement> redisWriteResult = input
             .apply(
                 "FeatureRowToRedisMutation",
                 ParDo.of(new FeatureRowToRedisMutationDoFn(getFeatureSetSpecs())))
             .apply(
                 "WriteRedisMutationToRedis",
-                RedisCustomIO.write(redisConfig.getHost(), redisConfig.getPort()));
+                RedisCustomIO.write(redisConfig));
+        if (options.getDeadLetterTableSpec() != null) {
+          redisWriteResult.apply(
+              WriteFailedElementToBigQuery.newBuilder()
+                  .setTableSpec(options.getDeadLetterTableSpec())
+                  .setJsonSchema(ResourceUtil.getDeadletterTableSchemaJson())
+                  .build());
+        }
         break;
       case BIGQUERY:
         BigQueryConfig bigqueryConfig = getStore().getBigqueryConfig();
diff --git a/ingestion/src/main/java/feast/retry/BackOffExecutor.java b/ingestion/src/main/java/feast/retry/BackOffExecutor.java
new file mode 100644
index 0000000000..7e38a3cf70
--- /dev/null
+++ b/ingestion/src/main/java/feast/retry/BackOffExecutor.java
@@ -0,0 +1,43 @@
+package feast.retry;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.joda.time.Duration;
+
+/** Runs a {@link Retriable} with exponential backoff until it succeeds or retries are exhausted. */
+public class BackOffExecutor implements Serializable {
+
+  // FluentBackoff is not Serializable, so only its parameters are stored and the
+  // backoff is rebuilt for every execution.
+  private final Integer maxRetries;
+  private final Duration initialBackOff;
+
+  public BackOffExecutor(Integer maxRetries, Duration initialBackOff) {
+    this.maxRetries = maxRetries;
+    this.initialBackOff = initialBackOff;
+  }
+
+  public void execute(Retriable retriable) throws Exception {
+    Sleeper sleeper = Sleeper.DEFAULT;
+    BackOff backOff = FluentBackoff.DEFAULT
+        .withMaxRetries(maxRetries)
+        .withInitialBackoff(initialBackOff)
+        .backoff();
+    while (true) {
+      try {
+        retriable.execute();
+        break;
+      } catch (Exception e) {
+        // BackOffUtils.next sleeps for the next interval and returns false once retries are exhausted.
+        if (retriable.isExceptionRetriable(e) && BackOffUtils.next(sleeper, backOff)) {
+          retriable.cleanUpAfterFailure();
+        } else {
+          throw e;
+        }
+      }
+    }
+  }
+}
diff --git a/ingestion/src/main/java/feast/retry/Retriable.java b/ingestion/src/main/java/feast/retry/Retriable.java
new file mode 100644
index 0000000000..8fd76fedbb
--- /dev/null
+++ b/ingestion/src/main/java/feast/retry/Retriable.java
@@ -0,0 +1,11 @@
+package feast.retry;
+
+/** A unit of work that a {@link BackOffExecutor} can retry after transient failures. */
+public interface Retriable {
+  /** Performs the work; may throw on failure. */
+  void execute();
+  /** Returns true if the given exception is transient and the work should be retried. */
+  Boolean isExceptionRetriable(Exception e);
+  /** Restores any state (e.g. connections) needed before the next attempt. */
+  void cleanUpAfterFailure();
+}
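The two new classes above compose as a pair: the executor owns the retry loop, while the `Retriable` supplies the work, the retry predicate, and the recovery hook. A minimal standalone sketch of that contract (not part of this change; `jedis` and `reconnect()` are hypothetical stand-ins):

    BackOffExecutor executor = new BackOffExecutor(3, Duration.millis(100));
    executor.execute(new Retriable() {
      @Override
      public void execute() {
        jedis.ping(); // hypothetical operation that may throw JedisConnectionException
      }

      @Override
      public Boolean isExceptionRetriable(Exception e) {
        // Only connection-level failures are worth retrying; anything else propagates.
        return e instanceof JedisConnectionException;
      }

      @Override
      public void cleanUpAfterFailure() {
        reconnect(); // hypothetical: rebuild the connection before the next attempt
      }
    });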
diff --git a/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java b/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java
index 2e5d4c9452..20afc43d76 100644
--- a/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java
+++ b/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java
@@ -16,19 +16,31 @@
  */
 package feast.store.serving.redis;
 
+import feast.core.StoreProto;
+import feast.ingestion.values.FailedElement;
+import feast.retry.BackOffExecutor;
+import feast.retry.Retriable;
 import org.apache.avro.reflect.Nullable;
 import org.apache.beam.sdk.coders.AvroCoder;
 import org.apache.beam.sdk.coders.DefaultCoder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PDone;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import redis.clients.jedis.Jedis;
 import redis.clients.jedis.Pipeline;
 import redis.clients.jedis.Response;
+import redis.clients.jedis.exceptions.JedisConnectionException;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 
 public class RedisCustomIO {
@@ -39,8 +51,8 @@ public class RedisCustomIO {
 
   private RedisCustomIO() {}
 
-  public static Write write(String host, int port) {
-    return new Write(host, port);
+  public static Write write(StoreProto.Store.RedisConfig redisConfig) {
+    return new Write(redisConfig);
   }
 
   public enum Method {
@@ -152,12 +164,12 @@ public void setScore(@Nullable Long score) {
   }
 
   /** ServingStoreWrite data to a Redis server. */
-  public static class Write extends PTransform<PCollection<RedisMutation>, PDone> {
+  public static class Write extends PTransform<PCollection<RedisMutation>, PCollection<FailedElement>> {
 
     private WriteDoFn dofn;
 
-    private Write(String host, int port) {
-      this.dofn = new WriteDoFn(host, port);
+    private Write(StoreProto.Store.RedisConfig redisConfig) {
+      this.dofn = new WriteDoFn(redisConfig);
     }
 
     public Write withBatchSize(int batchSize) {
@@ -171,24 +183,28 @@ public Write withTimeout(int timeout) {
     }
 
     @Override
-    public PDone expand(PCollection<RedisMutation> input) {
-      input.apply(ParDo.of(dofn));
-      return PDone.in(input.getPipeline());
+    public PCollection<FailedElement> expand(PCollection<RedisMutation> input) {
+      return input.apply(ParDo.of(dofn));
     }
 
-    public static class WriteDoFn extends DoFn<RedisMutation, Void> {
+    public static class WriteDoFn extends DoFn<RedisMutation, FailedElement> {
 
       private final String host;
-      private int port;
+      private final int port;
+      private final BackOffExecutor backOffExecutor;
+      private final List<RedisMutation> mutations = new ArrayList<>();
+
       private Jedis jedis;
       private Pipeline pipeline;
-      private int batchCount;
       private int batchSize = DEFAULT_BATCH_SIZE;
       private int timeout = DEFAULT_TIMEOUT;
 
-      WriteDoFn(String host, int port) {
-        this.host = host;
-        this.port = port;
+      WriteDoFn(StoreProto.Store.RedisConfig redisConfig) {
+        this.host = redisConfig.getHost();
+        this.port = redisConfig.getPort();
+        long backoffMs = redisConfig.getInitialBackoffMs() > 0 ?
+            redisConfig.getInitialBackoffMs() : 1;
+        this.backOffExecutor = new BackOffExecutor(redisConfig.getMaxRetries(),
+            Duration.millis(backoffMs));
       }
 
       public WriteDoFn withBatchSize(int batchSize) {
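One consequence of the constructor above worth noting: proto3 scalar fields default to 0, so a store that sets neither new field gets `max_retries = 0` plus the 1 ms floor on the initial backoff. A small sketch of what that default means in practice (illustrative, based on the executor's loop):

    // With max_retries = 0 the executor makes exactly one attempt:
    // BackOffUtils.next(...) returns false on the first failure, the exception
    // is rethrown by executeBatch, and the batch is emitted as FailedElements
    // (dead-lettered) rather than retried.
    BackOffExecutor singleAttempt = new BackOffExecutor(0, Duration.millis(1));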
@@ -212,24 +228,69 @@ public void setup() {
 
       @StartBundle
       public void startBundle() {
+        mutations.clear();
         pipeline = jedis.pipelined();
-        pipeline.multi();
-        batchCount = 0;
+      }
+
+      private void executeBatch() throws Exception {
+        backOffExecutor.execute(new Retriable() {
+          @Override
+          public void execute() {
+            pipeline.multi();
+            mutations.forEach(mutation -> {
+              writeRecord(mutation);
+              if (mutation.getExpiryMillis() != null && mutation.getExpiryMillis() > 0) {
+                pipeline.pexpire(mutation.getKey(), mutation.getExpiryMillis());
+              }
+            });
+            pipeline.exec();
+            pipeline.sync();
+            mutations.clear();
+          }
+
+          @Override
+          public Boolean isExceptionRetriable(Exception e) {
+            return e instanceof JedisConnectionException;
+          }
+
+          @Override
+          public void cleanUpAfterFailure() {
+            try {
+              pipeline.close();
+            } catch (IOException e) {
+              log.error(String.format("Error while closing pipeline: %s", e.getMessage()));
+            }
+            jedis = new Jedis(host, port, timeout);
+            pipeline = jedis.pipelined();
+          }
+        });
+      }
+
+      private FailedElement toFailedElement(RedisMutation mutation, Exception exception, String jobName) {
+        return FailedElement.newBuilder()
+            .setJobName(jobName)
+            .setTransformName("RedisCustomIO")
+            .setPayload(mutation.getValue().toString())
+            .setErrorMessage(exception.getMessage())
+            .setStackTrace(ExceptionUtils.getStackTrace(exception))
+            .build();
       }
 
       @ProcessElement
       public void processElement(ProcessContext context) {
         RedisMutation mutation = context.element();
-        writeRecord(mutation);
-        if (mutation.getExpiryMillis() != null && mutation.getExpiryMillis() > 0) {
-          pipeline.pexpire(mutation.getKey(), mutation.getExpiryMillis());
-        }
-        batchCount++;
-        if (batchCount >= batchSize) {
-          pipeline.exec();
-          pipeline.sync();
-          pipeline.multi();
-          batchCount = 0;
+        mutations.add(mutation);
+        if (mutations.size() >= batchSize) {
+          try {
+            executeBatch();
+          } catch (Exception e) {
+            mutations.forEach(failedMutation -> {
+              FailedElement failedElement = toFailedElement(
+                  failedMutation, e, context.getPipelineOptions().getJobName());
+              context.output(failedElement);
+            });
+            mutations.clear();
+          }
         }
       }
@@ -254,12 +315,19 @@ private Response<?> writeRecord(RedisMutation mutation) {
       }
 
       @FinishBundle
-      public void finishBundle() {
-        if (pipeline.isInMulti()) {
-          pipeline.exec();
-          pipeline.sync();
+      public void finishBundle(FinishBundleContext context) throws IOException, InterruptedException {
+        if (mutations.size() > 0) {
+          try {
+            executeBatch();
+          } catch (Exception e) {
+            mutations.forEach(failedMutation -> {
+              FailedElement failedElement = toFailedElement(
+                  failedMutation, e, context.getPipelineOptions().getJobName());
+              context.output(failedElement, Instant.now(), GlobalWindow.INSTANCE);
+            });
+            mutations.clear();
+          }
         }
-        batchCount = 0;
       }
 
       @Teardown
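Because `Write` now returns the failed elements instead of `PDone`, callers other than `WriteToStore` can attach their own handling to the output. A hedged sketch (assumes `FailedElement` exposes a `getErrorMessage()` accessor matching its builder; `mutations` and `log` are stand-ins for the caller's own pipeline and logger):

    PCollection<FailedElement> failedWrites =
        mutations.apply("WriteRedisMutationToRedis", RedisCustomIO.write(redisConfig));
    failedWrites.apply(
        "LogFailedRedisWrites",
        ParDo.of(
            new DoFn<FailedElement, Void>() {
              @ProcessElement
              public void processElement(ProcessContext context) {
                // Surface the error without failing the pipeline.
                log.error("Redis write failed: {}", context.element().getErrorMessage());
              }
            }));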
diff --git a/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java b/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java
index a35e63386d..94167059b4 100644
--- a/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java
+++ b/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java
@@ -16,52 +16,66 @@
  */
 package feast.store.serving.redis;
 
-import static feast.test.TestUtil.field;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.MatcherAssert.assertThat;
-
+import feast.core.StoreProto;
 import feast.storage.RedisProto.RedisKey;
 import feast.store.serving.redis.RedisCustomIO.Method;
 import feast.store.serving.redis.RedisCustomIO.RedisMutation;
 import feast.types.FeatureRowProto.FeatureRow;
 import feast.types.ValueProto.ValueType.Enum;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.stream.Collectors;
+import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import redis.clients.jedis.Jedis;
 import redis.embedded.Redis;
 import redis.embedded.RedisServer;
 
-public class RedisCustomIOTest {
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import static feast.test.TestUtil.field;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
 
-  @Rule public transient TestPipeline p = TestPipeline.create();
+public class RedisCustomIOTest {
+  @Rule
+  public transient TestPipeline p = TestPipeline.create();
 
+  private static String REDIS_HOST = "localhost";
   private static int REDIS_PORT = 51234;
-  private static Redis redis;
-  private static Jedis jedis;
+  private Redis redis;
+  private Jedis jedis;
 
-  @BeforeClass
-  public static void setUp() throws IOException {
+  @Before
+  public void setUp() throws IOException {
     redis = new RedisServer(REDIS_PORT);
     redis.start();
-    jedis = new Jedis("localhost", REDIS_PORT);
+    jedis = new Jedis(REDIS_HOST, REDIS_PORT);
   }
 
-  @AfterClass
-  public static void teardown() {
+  @After
+  public void teardown() {
     redis.stop();
   }
 
   @Test
   public void shouldWriteToRedis() {
+    StoreProto.Store.RedisConfig redisConfig = StoreProto.Store.RedisConfig.newBuilder()
+        .setHost(REDIS_HOST)
+        .setPort(REDIS_PORT)
+        .build();
     HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
     kvs.put(
         RedisKey.newBuilder()
@@ -96,7 +110,8 @@ public void shouldWriteToRedis() {
                     null))
             .collect(Collectors.toList());
 
-    p.apply(Create.of(featureRowWrites)).apply(RedisCustomIO.write("localhost", REDIS_PORT));
+    p.apply(Create.of(featureRowWrites))
+        .apply(RedisCustomIO.write(redisConfig));
     p.run();
 
     kvs.forEach(
@@ -105,4 +120,74 @@ public void shouldWriteToRedis() {
           assertThat(actual, equalTo(value.toByteArray()));
         });
   }
+
+  @Test(timeout = 10000)
+  public void shouldRetryFailConnection() throws InterruptedException {
+    StoreProto.Store.RedisConfig redisConfig = StoreProto.Store.RedisConfig.newBuilder()
+        .setHost(REDIS_HOST)
+        .setPort(REDIS_PORT)
+        .setMaxRetries(4)
+        .setInitialBackoffMs(2000)
+        .build();
+    HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
+    kvs.put(RedisKey.newBuilder().setFeatureSet("fs:1")
+            .addEntities(field("entity", 1, Enum.INT64)).build(),
+        FeatureRow.newBuilder().setFeatureSet("fs:1")
+            .addFields(field("entity", 1, Enum.INT64))
+            .addFields(field("feature", "one", Enum.STRING)).build());
+
+    List<RedisMutation> featureRowWrites = kvs.entrySet().stream()
+        .map(kv -> new RedisMutation(Method.SET, kv.getKey().toByteArray(),
+            kv.getValue().toByteArray(),
+            null, null)
+        )
+        .collect(Collectors.toList());
+
+    PCollection<Long> failedElementCount = p.apply(Create.of(featureRowWrites))
+        .apply(RedisCustomIO.write(redisConfig))
+        .apply(Count.globally());
+
+    redis.stop();
+    final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1);
+    ScheduledFuture<?> scheduledRedisRestart = redisRestartExecutor.schedule(() -> {
+      redis.start();
+    }, 3, TimeUnit.SECONDS);
+
+    PAssert.that(failedElementCount).containsInAnyOrder(0L);
+    p.run();
+    scheduledRedisRestart.cancel(true);
+
+    kvs.forEach((key, value) -> {
+      byte[] actual = jedis.get(key.toByteArray());
+      assertThat(actual, equalTo(value.toByteArray()));
+    });
+  }
+
+  @Test
+  public void shouldProduceFailedElementIfRetryExceeded() {
+    StoreProto.Store.RedisConfig redisConfig = StoreProto.Store.RedisConfig.newBuilder()
+        .setHost(REDIS_HOST)
+        .setPort(REDIS_PORT)
+        .build();
+    HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
+    kvs.put(RedisKey.newBuilder().setFeatureSet("fs:1")
+            .addEntities(field("entity", 1, Enum.INT64)).build(),
+        FeatureRow.newBuilder().setFeatureSet("fs:1")
+            .addFields(field("entity", 1, Enum.INT64))
+            .addFields(field("feature", "one", Enum.STRING)).build());
+
+    List<RedisMutation> featureRowWrites = kvs.entrySet().stream()
+        .map(kv -> new RedisMutation(Method.SET, kv.getKey().toByteArray(),
+            kv.getValue().toByteArray(),
+            null, null)
+        ).collect(Collectors.toList());
+
+    PCollection<Long> failedElementCount = p.apply(Create.of(featureRowWrites))
+        .apply(RedisCustomIO.write(redisConfig))
+        .apply(Count.globally());
+
+    redis.stop();
+    PAssert.that(failedElementCount).containsInAnyOrder(1L);
+    p.run();
+  }
 }
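A rough timing check on shouldRetryFailConnection (assuming Beam's FluentBackoff default ~1.5x growth per attempt; jitter shifts the exact values):

    // initialBackoffMs = 2000, maxRetries = 4 → worst-case sleeps of roughly
    // 2.0s, 3.0s, 4.5s and 6.75s. The embedded Redis is restarted after ~3s,
    // so the batch normally succeeds on the first or second retry (~2-5s in),
    // comfortably inside the @Test(timeout = 10000) window.
    double wait = 2000, total = 0;
    for (int attempt = 1; attempt <= 4; attempt++) {
      total += wait;
      System.out.printf("attempt %d: ~%.2fs (cumulative ~%.2fs)%n", attempt, wait / 1000, total / 1000);
      wait *= 1.5; // assumed FluentBackoff default multiplier
    }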
diff --git a/protos/feast/core/Store.proto b/protos/feast/core/Store.proto
index e1b8c581a3..a6a01f5d82 100644
--- a/protos/feast/core/Store.proto
+++ b/protos/feast/core/Store.proto
@@ -110,6 +110,11 @@ message Store {
   message RedisConfig {
     string host = 1;
     int32 port = 2;
+    // Optional. The number of milliseconds to wait before retrying a failed Redis connection.
+    // Feast uses an exponential backoff policy; "initial_backoff_ms" sets the initial wait duration.
+    int32 initial_backoff_ms = 3;
+    // Optional. Maximum total number of retries for connecting to Redis. Defaults to zero retries.
+    int32 max_retries = 4;
   }
 
   message BigQueryConfig {
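For reference, a hedged example of populating the new fields when declaring a Redis store in Java builder form (names and values are illustrative; the surrounding Store fields are assumed from the rest of Store.proto):

    StoreProto.Store store =
        StoreProto.Store.newBuilder()
            .setName("online")
            .setType(StoreProto.Store.StoreType.REDIS)
            .setRedisConfig(
                StoreProto.Store.RedisConfig.newBuilder()
                    .setHost("localhost")
                    .setPort(6379)
                    .setInitialBackoffMs(1000) // first retry after ~1s, growing exponentially
                    .setMaxRetries(3))         // then give up and emit FailedElements
            .build();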