diff --git a/connectors/riot-faker/riot-faker.gradle b/connectors/riot-faker/riot-faker.gradle index 7f32d6b20..d63577bec 100644 --- a/connectors/riot-faker/riot-faker.gradle +++ b/connectors/riot-faker/riot-faker.gradle @@ -18,7 +18,7 @@ dependencies { implementation project(':riot-core') implementation 'org.springframework.batch:spring-batch-core' - implementation group: 'net.datafaker', name: 'datafaker', version: datafakerVersion + api group: 'net.datafaker', name: 'datafaker', version: datafakerVersion } compileJava { diff --git a/connectors/riot-faker/src/main/java/com/redis/riot/faker/FakerItemReader.java b/connectors/riot-faker/src/main/java/com/redis/riot/faker/FakerItemReader.java index 8fcf400a7..19855daad 100644 --- a/connectors/riot-faker/src/main/java/com/redis/riot/faker/FakerItemReader.java +++ b/connectors/riot-faker/src/main/java/com/redis/riot/faker/FakerItemReader.java @@ -1,20 +1,18 @@ package com.redis.riot.faker; +import java.util.AbstractMap; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Locale; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.Map.Entry; +import java.util.stream.Collectors; import org.springframework.batch.item.ItemReader; import org.springframework.batch.item.support.AbstractItemCountingItemStreamItemReader; -import org.springframework.expression.spel.support.DataBindingMethodResolver; -import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import com.redis.riot.core.Expression; - import net.datafaker.Faker; /** @@ -26,9 +24,11 @@ public class FakerItemReader extends AbstractItemCountingItemStreamItemReader fields = new LinkedHashMap<>(); + private Map expressions = new LinkedHashMap<>(); private Locale locale = DEFAULT_LOCALE; - private StandardEvaluationContext evaluationContext; + + private Faker faker; + private Map fields; public FakerItemReader() { setName(ClassUtils.getShortName(getClass())); @@ -38,66 +38,37 @@ public void setLocale(Locale locale) { this.locale = locale; } - public void setFields(Map fields) { - this.fields = fields; + public void setExpressions(Map fields) { + this.expressions = fields; } @Override protected synchronized void doOpen() throws Exception { - if (evaluationContext == null) { - Assert.notEmpty(fields, "No field specified"); - evaluationContext = new StandardEvaluationContext(); - evaluationContext.addPropertyAccessor(new ReflectivePropertyAccessor()); - evaluationContext.addMethodResolver(DataBindingMethodResolver.forInstanceMethodInvocation()); - evaluationContext.setRootObject(new AugmentedFaker(locale)); + Assert.notEmpty(expressions, "No field specified"); + if (fields == null) { + fields = expressions.entrySet().stream().map(this::normalizeField) + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + } + faker = new Faker(locale); + } + + private Entry normalizeField(Entry field) { + if (field.getValue().startsWith("#{")) { + return field; } + return new AbstractMap.SimpleEntry<>(field.getKey(), "#{" + field.getValue() + "}"); } @Override protected Map doRead() throws Exception { Map map = new HashMap<>(); - fields.forEach((k, v) -> map.put(k, v.getValue(evaluationContext))); + fields.forEach((k, v) -> map.put(k, faker.expression(v))); return map; } @Override protected synchronized void doClose() { - evaluationContext = null; - } - - public class AugmentedFaker extends Faker { - - private final AtomicInteger threadCount = new AtomicInteger(); - private final ThreadLocal threadId = ThreadLocal.withInitial(threadCount::incrementAndGet); - - public AugmentedFaker(Locale locale) { - super(locale); - } - - public void setThread(int id) { - threadId.set(id); - } - - public void removeThread() { - threadId.remove(); - } - - public int getIndex() { - return index(); - } - - public int index() { - return getCurrentItemCount(); - } - - public int getThread() { - return thread(); - } - - public int thread() { - return threadId.get(); - } - + faker = null; } } diff --git a/connectors/riot-faker/src/test/java/com/redis/riot/faker/FakerReaderTests.java b/connectors/riot-faker/src/test/java/com/redis/riot/faker/FakerReaderTests.java index 4c016455f..9e1d14791 100644 --- a/connectors/riot-faker/src/test/java/com/redis/riot/faker/FakerReaderTests.java +++ b/connectors/riot-faker/src/test/java/com/redis/riot/faker/FakerReaderTests.java @@ -10,8 +10,6 @@ import org.springframework.batch.item.ExecutionContext; import org.springframework.batch.item.ItemReader; -import com.redis.riot.core.Expression; - class FakerReaderTests { public static List readAll(ItemReader reader) throws Exception { @@ -27,19 +25,17 @@ public static List readAll(ItemReader reader) throws Exception { void fakerReader() throws Exception { int count = 100; FakerItemReader reader = new FakerItemReader(); - Map fields = new LinkedHashMap<>(); - fields.put("index", Expression.parse("index")); - fields.put("firstName", Expression.parse("name.firstName")); - fields.put("lastName", Expression.parse("name.lastName")); - fields.put("thread", Expression.parse("thread")); - reader.setFields(fields); + Map fields = new LinkedHashMap<>(); + fields.put("firstName", "Name.first_name"); + fields.put("lastName", "Name.last_name"); + reader.setExpressions(fields); reader.setMaxItemCount(count); reader.open(new ExecutionContext()); List> items = readAll(reader); reader.close(); Assertions.assertEquals(count, items.size()); - Assertions.assertEquals(1, items.get(0).get("index")); - Assertions.assertEquals(1, (Integer) items.get(0).get("thread")); + Assertions.assertTrue(items.get(0).containsKey("firstName")); + Assertions.assertTrue(items.get(0).containsKey("lastName")); } }