From 15e7613b778fa6faa0c1e6688eeca5e505b1aeec Mon Sep 17 00:00:00 2001 From: jruaux Date: Wed, 20 Nov 2024 21:38:08 -0800 Subject: [PATCH] refactor: Added file reader and writer registries --- .../riot/core/AbstractCallableCommand.java | 18 +- .../redis/riot/core/AbstractJobCommand.java | 55 +++-- .../riot/core/RiotExecutionException.java | 18 ++ .../core/RiotInitializationException.java | 18 ++ .../riot/file/AbstractFactoryRegistry.java | 184 --------------- .../riot/file/AbstractReaderFactory.java | 83 +++++++ .../com/redis/riot/file/AbstractRegistry.java | 43 ++++ .../riot/file/AbstractWriterFactory.java | 58 +++++ .../riot/file/DelimitedReaderFactory.java | 27 +++ .../riot/file/DelimitedWriterFactory.java | 30 +++ .../java/com/redis/riot/file/FileOptions.java | 40 ++-- .../redis/riot/file/FileReaderRegistry.java | 222 +++--------------- .../com/redis/riot/file/FileReaderResult.java | 37 +++ .../java/com/redis/riot/file/FileUtils.java | 2 +- .../redis/riot/file/FileWriterRegistry.java | 181 ++++---------- .../com/redis/riot/file/FileWriterResult.java | 36 +++ .../riot/file/FixedWidthReaderFactory.java | 27 +++ .../riot/file/FormattedWriterFactory.java | 23 ++ .../redis/riot/file/GoogleStorageOptions.java | 59 ++++- .../file/GoogleStorageProtocolResolver.java | 17 +- .../riot/file/JsonLinesReaderFactory.java | 30 +++ .../riot/file/JsonLinesWriterFactory.java | 18 ++ .../redis/riot/file/JsonReaderFactory.java | 27 +++ .../redis/riot/file/JsonWriterFactory.java | 29 +++ .../com/redis/riot/file/ReaderFactory.java | 10 + .../file/{Factory.java => ResourceMap.java} | 4 +- .../com/redis/riot/file/ResourceTypeMap.java | 64 ----- ...rceLoader.java => RiotResourceLoader.java} | 77 +++--- .../com/redis/riot/file/RiotResourceMap.java | 74 ++++++ .../java/com/redis/riot/file/S3Options.java | 27 ++- .../redis/riot/file/S3ProtocolResolver.java | 45 ++-- .../riot/file/StdInProtocolResolver.java | 10 +- .../com/redis/riot/file/WriterFactory.java | 9 + .../com/redis/riot/file/XmlReaderFactory.java | 26 ++ .../com/redis/riot/file/XmlWriterFactory.java | 28 +++ .../java/com/redis/riot/file/ReaderTests.java | 21 +- .../com/redis/riot/AbstractExportCommand.java | 13 +- .../com/redis/riot/AbstractFileExport.java | 48 ++-- .../com/redis/riot/AbstractFileImport.java | 64 +++-- .../com/redis/riot/AbstractImportCommand.java | 27 ++- .../com/redis/riot/AbstractRedisCommand.java | 13 +- .../AbstractRedisTargetExportCommand.java | 36 +-- .../main/java/com/redis/riot/FileArgs.java | 13 +- .../com/redis/riot/GoogleStorageArgs.java | 1 - .../main/java/com/redis/riot/RedisArgs.java | 111 ++++++++- .../java/com/redis/riot/RedisClientArgs.java | 33 +++ .../java/com/redis/riot/RedisContext.java | 29 ++- .../src/main/java/com/redis/riot/S3Args.java | 18 +- .../java/com/redis/riot/SourceRedisArgs.java | 104 +++++++- .../src/main/java/com/redis/riot/SslArgs.java | 120 ---------- .../java/com/redis/riot/TargetRedisArgs.java | 99 +++++++- .../java/com/redis/riot/StackRiotTests.java | 4 +- .../riot/src/test/resources/file-import-gcs | 2 +- .../riot/src/test/resources/file-import-s3 | 2 +- 54 files changed, 1469 insertions(+), 945 deletions(-) create mode 100644 core/riot-core/src/main/java/com/redis/riot/core/RiotExecutionException.java create mode 100644 core/riot-core/src/main/java/com/redis/riot/core/RiotInitializationException.java delete mode 100644 core/riot-file/src/main/java/com/redis/riot/file/AbstractFactoryRegistry.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/AbstractReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/AbstractRegistry.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/AbstractWriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/DelimitedReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/DelimitedWriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/FileReaderResult.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/FileWriterResult.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/FixedWidthReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/FormattedWriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/JsonLinesReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/JsonLinesWriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/JsonReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/JsonWriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/ReaderFactory.java rename core/riot-file/src/main/java/com/redis/riot/file/{Factory.java => ResourceMap.java} (50%) delete mode 100644 core/riot-file/src/main/java/com/redis/riot/file/ResourceTypeMap.java rename core/riot-file/src/main/java/com/redis/riot/file/{DefaultResourceLoader.java => RiotResourceLoader.java} (59%) create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/RiotResourceMap.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/WriterFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/XmlReaderFactory.java create mode 100644 core/riot-file/src/main/java/com/redis/riot/file/XmlWriterFactory.java delete mode 100644 plugins/riot/src/main/java/com/redis/riot/SslArgs.java diff --git a/core/riot-core/src/main/java/com/redis/riot/core/AbstractCallableCommand.java b/core/riot-core/src/main/java/com/redis/riot/core/AbstractCallableCommand.java index b719c1369..6bca756c5 100644 --- a/core/riot-core/src/main/java/com/redis/riot/core/AbstractCallableCommand.java +++ b/core/riot-core/src/main/java/com/redis/riot/core/AbstractCallableCommand.java @@ -18,14 +18,26 @@ public abstract class AbstractCallableCommand extends BaseCommand implements Cal @Override public Integer call() throws Exception { + initialize(); + try { + execute(); + } finally { + teardown(); + } + return 0; + } + + protected void initialize() throws RiotInitializationException { if (log == null) { log = LoggerFactory.getLogger(getClass()); } - execute(); - return 0; } - protected abstract void execute() throws Exception; + protected abstract void execute() throws RiotExecutionException; + + protected void teardown() { + // do nothing + } public Logger getLog() { return log; diff --git a/core/riot-core/src/main/java/com/redis/riot/core/AbstractJobCommand.java b/core/riot-core/src/main/java/com/redis/riot/core/AbstractJobCommand.java index 87d2025c8..1572f5978 100644 --- a/core/riot-core/src/main/java/com/redis/riot/core/AbstractJobCommand.java +++ b/core/riot-core/src/main/java/com/redis/riot/core/AbstractJobCommand.java @@ -65,7 +65,32 @@ public abstract class AbstractJobCommand extends AbstractCallableCommand { private PlatformTransactionManager transactionManager; private JobLauncher jobLauncher; - private TaskExecutorJobLauncher taskExecutorJobLauncher() throws Exception { + @Override + protected void initialize() throws RiotInitializationException { + super.initialize(); + if (jobName == null) { + jobName = jobName(); + } + if (jobRepository == null) { + try { + jobRepository = JobUtils.jobRepositoryFactoryBean(jobRepositoryName).getObject(); + } catch (Exception e) { + throw new RiotInitializationException("Could not create job repository", e); + } + } + if (transactionManager == null) { + transactionManager = JobUtils.resourcelessTransactionManager(); + } + if (jobLauncher == null) { + try { + jobLauncher = jobLauncher(); + } catch (Exception e) { + throw new RiotInitializationException("Could not create job launcher", e); + } + } + } + + private JobLauncher jobLauncher() throws Exception { TaskExecutorJobLauncher launcher = new TaskExecutorJobLauncher(); launcher.setJobRepository(jobRepository); launcher.setTaskExecutor(new SyncTaskExecutor()); @@ -82,26 +107,20 @@ private JobBuilder jobBuilder() { } @Override - protected void execute() throws Exception { - if (jobName == null) { - jobName = jobName(); - } - if (jobRepository == null) { - jobRepository = JobUtils.jobRepositoryFactoryBean(jobRepositoryName).getObject(); - } - if (transactionManager == null) { - transactionManager = JobUtils.resourcelessTransactionManager(); - } - if (jobLauncher == null) { - jobLauncher = taskExecutorJobLauncher(); + protected void execute() throws RiotExecutionException { + Job job = job(); + JobExecution jobExecution; + try { + jobExecution = jobLauncher.run(job, new JobParameters()); + } catch (JobExecutionException e) { + throw new RiotExecutionException("Could not run job " + job.getName(), e); } - JobExecution jobExecution = jobLauncher.run(job(), new JobParameters()); if (JobUtils.isFailed(jobExecution.getExitStatus())) { for (StepExecution stepExecution : jobExecution.getStepExecutions()) { ExitStatus stepExitStatus = stepExecution.getExitStatus(); if (JobUtils.isFailed(stepExitStatus)) { if (CollectionUtils.isEmpty(stepExecution.getFailureExceptions())) { - throw new JobExecutionException(stepExitStatus.getExitDescription()); + throw new RiotExecutionException(stepExitStatus.getExitDescription()); } throw wrapException(stepExecution.getFailureExceptions()); } @@ -117,11 +136,11 @@ private String jobName() { return commandSpec.name(); } - private JobExecutionException wrapException(List throwables) { + private RiotExecutionException wrapException(List throwables) { if (throwables.isEmpty()) { - return new JobExecutionException("Job failed"); + return new RiotExecutionException("Job failed"); } - return new JobExecutionException("Job failed", throwables.get(0)); + return new RiotExecutionException("Job failed", throwables.get(0)); } protected Job job(Step... steps) { diff --git a/core/riot-core/src/main/java/com/redis/riot/core/RiotExecutionException.java b/core/riot-core/src/main/java/com/redis/riot/core/RiotExecutionException.java new file mode 100644 index 000000000..2c5bd8126 --- /dev/null +++ b/core/riot-core/src/main/java/com/redis/riot/core/RiotExecutionException.java @@ -0,0 +1,18 @@ +package com.redis.riot.core; + +@SuppressWarnings("serial") +public class RiotExecutionException extends Exception { + + public RiotExecutionException(String message, Throwable cause) { + super(message, cause); + } + + public RiotExecutionException(String message) { + super(message); + } + + public RiotExecutionException(Throwable cause) { + super(cause); + } + +} diff --git a/core/riot-core/src/main/java/com/redis/riot/core/RiotInitializationException.java b/core/riot-core/src/main/java/com/redis/riot/core/RiotInitializationException.java new file mode 100644 index 000000000..4cd8656dd --- /dev/null +++ b/core/riot-core/src/main/java/com/redis/riot/core/RiotInitializationException.java @@ -0,0 +1,18 @@ +package com.redis.riot.core; + +@SuppressWarnings("serial") +public class RiotInitializationException extends Exception { + + public RiotInitializationException(String message, Throwable cause) { + super(message, cause); + } + + public RiotInitializationException(String message) { + super(message); + } + + public RiotInitializationException(Throwable cause) { + super(cause); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/AbstractFactoryRegistry.java b/core/riot-file/src/main/java/com/redis/riot/file/AbstractFactoryRegistry.java deleted file mode 100644 index 3a068b1b1..000000000 --- a/core/riot-file/src/main/java/com/redis/riot/file/AbstractFactoryRegistry.java +++ /dev/null @@ -1,184 +0,0 @@ -package com.redis.riot.file; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.util.Arrays; -import java.util.Base64; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - -import org.springframework.batch.item.file.transform.DelimitedLineTokenizer; -import org.springframework.core.io.ProtocolResolver; -import org.springframework.core.io.Resource; -import org.springframework.core.io.ResourceLoader; -import org.springframework.util.MimeType; - -import com.google.auth.oauth2.GoogleCredentials; -import com.google.cloud.ServiceOptions; -import com.google.cloud.spring.autoconfigure.storage.GcpStorageAutoConfiguration; -import com.google.cloud.spring.core.GcpScope; -import com.google.cloud.spring.core.UserAgentHeaderProvider; -import com.google.cloud.storage.Storage; -import com.google.cloud.storage.StorageOptions; - -import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.S3ClientBuilder; - -public abstract class AbstractFactoryRegistry { - - public static final String DELIMITER_PIPE = "|"; - private final Map delimiterMap = defaultDelimiterMap(); - private ResourceTypeMap resourceTypeMap = ResourceTypeMap.defaultResourceTypeMap(); - private final Map> factories = new HashMap<>(); - - public void registerDelimiter(MimeType type, String delimiter) { - delimiterMap.put(type, delimiter); - } - - private static Map defaultDelimiterMap() { - Map map = new HashMap<>(); - map.put(FileUtils.CSV, DelimitedLineTokenizer.DELIMITER_COMMA); - map.put(FileUtils.PSV, DELIMITER_PIPE); - map.put(FileUtils.TSV, DelimitedLineTokenizer.DELIMITER_TAB); - return map; - } - - public ResourceTypeMap getResourceTypeMap() { - return resourceTypeMap; - } - - public void setResourceTypeMap(ResourceTypeMap map) { - this.resourceTypeMap = map; - } - - public void register(MimeType type, Factory factory) { - factories.put(type, factory); - } - - protected String delimiter(Resource resource, O options) { - if (options.getDelimiter() == null) { - return delimiterMap.get(getType(resource.getFilename(), options)); - } - return options.getDelimiter(); - } - - private Resource normalize(Resource resource, O options) throws IOException { - if (options.isGzipped() || FileUtils.isGzip(resource.getFilename())) { - return gzip(resource); - } - return resource; - } - - protected abstract Resource gzip(Resource resource) throws IOException; - - public T get(String location, O options) throws IOException { - Resource resource = resource(location, options); - return get(resource, options); - } - - public Resource resource(String location, O options) { - return resourceLoader(options).getResource(location); - } - - protected ResourceLoader resourceLoader(O options) { - DefaultResourceLoader loader = new DefaultResourceLoader(); - protocolResolvers(options).forEach(loader::addProtocolResolver); - return loader; - } - - protected Collection protocolResolvers(O options) { - S3ProtocolResolver s3ProtocolResolver = new S3ProtocolResolver(); - s3ProtocolResolver.setClientSupplier(() -> s3Client(options.getS3Options())); - GoogleStorageProtocolResolver googleStorageProtocolResolver = new GoogleStorageProtocolResolver(); - googleStorageProtocolResolver.setStorageSupplier(() -> googleStorage(options.getGoogleStorageOptions())); - return Arrays.asList(s3ProtocolResolver, googleStorageProtocolResolver); - } - - private S3Client s3Client(S3Options options) { - S3ClientBuilder clientBuilder = S3Client.builder(); - if (options.getRegion() != null) { - clientBuilder.region(options.getRegion()); - } - if (options.getEndpoint() != null) { - clientBuilder.endpointOverride(options.getEndpoint()); - } - clientBuilder.credentialsProvider(credentialsProvider(options)); - return clientBuilder.build(); - } - - private AwsCredentialsProvider credentialsProvider(S3Options options) { - if (options.getAccessKey() == null && options.getSecretKey() == null) { - return AnonymousCredentialsProvider.create(); - } - return StaticCredentialsProvider - .create(AwsBasicCredentials.create(options.getAccessKey(), options.getSecretKey())); - } - - private Storage googleStorage(GoogleStorageOptions options) { - StorageOptions.Builder builder = StorageOptions.newBuilder(); - builder.setProjectId(ServiceOptions.getDefaultProjectId()); - builder.setHeaderProvider(new UserAgentHeaderProvider(GcpStorageAutoConfiguration.class)); - if (options.getKeyFile() != null) { - InputStream inputStream; - try { - inputStream = Files.newInputStream(options.getKeyFile()); - } catch (IOException e) { - throw new RuntimeIOException("Could not read key file", e); - } - builder.setCredentials(credentials(inputStream, options)); - } - if (options.getEncodedKey() != null) { - byte[] bytes = Base64.getDecoder().decode(options.getEncodedKey()); - builder.setCredentials(credentials(new ByteArrayInputStream(bytes), options)); - } - if (options.getProjectId() != null) { - builder.setProjectId(options.getProjectId()); - } - return builder.build().getService(); - } - - private GoogleCredentials credentials(InputStream inputStream, GoogleStorageOptions options) { - GoogleCredentials credentials; - try { - credentials = GoogleCredentials.fromStream(inputStream); - } catch (IOException e) { - throw new RuntimeIOException("Could not create Google credentials", e); - } - credentials.createScoped(googleStorageScope().getUrl()); - return credentials; - } - - protected abstract GcpScope googleStorageScope(); - - public T get(Resource resource, O options) throws IOException { - MimeType type = getType(resource.getFilename(), options); - Factory factory = factories.get(type); - if (factory == null) { - return null; - } - return factory.create(normalize(resource, options), options); - } - - public MimeType getType(String filename) { - return resourceTypeMap.getContentType(FileUtils.normalize(filename)); - } - - public MimeType getType(String filename, O options) { - return getType(filename, options.getType()); - } - - public MimeType getType(String filename, MimeType type) { - if (type == null) { - return getType(filename); - } - return type; - } - -} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/AbstractReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/AbstractReaderFactory.java new file mode 100644 index 000000000..77522c4eb --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/AbstractReaderFactory.java @@ -0,0 +1,83 @@ +package com.redis.riot.file; + +import java.util.Map; + +import org.springframework.batch.item.file.FlatFileItemReader; +import org.springframework.batch.item.file.builder.FlatFileItemReaderBuilder; +import org.springframework.batch.item.file.separator.DefaultRecordSeparatorPolicy; +import org.springframework.batch.item.file.separator.RecordSeparatorPolicy; +import org.springframework.batch.item.file.transform.AbstractLineTokenizer; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.module.SimpleModule; + +public abstract class AbstractReaderFactory implements ReaderFactory { + + protected int[] includedFields(ReadOptions options) { + return options.getIncludedFields().stream().mapToInt(Integer::intValue).toArray(); + } + + protected FlatFileItemReader> flatFileReader(Resource resource, ReadOptions options, + AbstractLineTokenizer tokenizer) { + if (ObjectUtils.isEmpty(options.getFields())) { + Assert.isTrue(options.isHeader(), + String.format("Could not create reader for file '%s': no header or field names specified", + resource.getFilename())); + } else { + tokenizer.setNames(options.getFields().toArray(new String[0])); + } + FlatFileItemReaderBuilder> builder = flatFileReader(options); + builder.resource(resource); + builder.fieldSetMapper(new MapFieldSetMapper()); + builder.lineTokenizer(tokenizer); + builder.skippedLinesCallback(new HeaderCallbackHandler(tokenizer, headerIndex(options))); + return builder.build(); + } + + protected FlatFileItemReaderBuilder flatFileReader(ReadOptions options) { + FlatFileItemReaderBuilder builder = new FlatFileItemReaderBuilder<>(); + if (options.getMaxItemCount() > 0) { + builder.maxItemCount(options.getMaxItemCount()); + } + builder.encoding(options.getEncoding()); + builder.recordSeparatorPolicy(recordSeparatorPolicy(options)); + builder.linesToSkip(linesToSkip(options)); + builder.saveState(false); + return builder; + } + + private RecordSeparatorPolicy recordSeparatorPolicy(ReadOptions options) { + String quoteCharacter = String.valueOf(options.getQuoteCharacter()); + return new DefaultRecordSeparatorPolicy(quoteCharacter, options.getContinuationString()); + } + + private int headerIndex(ReadOptions options) { + if (options.getHeaderLine() != null) { + return options.getHeaderLine(); + } + return linesToSkip(options) - 1; + } + + private int linesToSkip(ReadOptions options) { + if (options.getLinesToSkip() != null) { + return options.getLinesToSkip(); + } + if (options.isHeader()) { + return 1; + } + return 0; + } + + protected T objectMapper(T objectMapper, ReadOptions options) { + objectMapper.configure(DeserializationFeature.USE_LONG_FOR_INTS, true); + SimpleModule module = new SimpleModule(); + options.getDeserializers().forEach(module::addDeserializer); + objectMapper.registerModule(module); + return objectMapper; + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/AbstractRegistry.java b/core/riot-file/src/main/java/com/redis/riot/file/AbstractRegistry.java new file mode 100644 index 000000000..c4cc68c74 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/AbstractRegistry.java @@ -0,0 +1,43 @@ +package com.redis.riot.file; + +import java.util.HashSet; +import java.util.Set; + +import org.springframework.core.io.ProtocolResolver; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeType; + +public class AbstractRegistry { + + private ResourceMap resourceMap = new RiotResourceMap(); + private Set protocolResolvers = new HashSet<>(); + + protected MimeType type(Resource resource, FileOptions options) { + if (options.getContentType() == null) { + return MimeType.valueOf(resourceMap.getContentTypeFor(resource)); + } + return options.getContentType(); + } + + public void addProtocolResolver(ProtocolResolver protocolResolver) { + protocolResolvers.add(protocolResolver); + } + + protected Resource resource(String location, FileOptions options) { + RiotResourceLoader resourceLoader = new RiotResourceLoader(); + protocolResolvers.forEach(resourceLoader::addProtocolResolver); + resourceLoader.getS3ProtocolResolver().setClientSupplier(options.getS3Options()::client); + resourceLoader.getGoogleStorageProtocolResolver() + .setStorageSupplier(options.getGoogleStorageOptions()::storage); + return resourceLoader.getResource(location); + } + + public ResourceMap getResourceMap() { + return resourceMap; + } + + public void setResourceMap(ResourceMap resourceMap) { + this.resourceMap = resourceMap; + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/AbstractWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/AbstractWriterFactory.java new file mode 100644 index 000000000..c123b74be --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/AbstractWriterFactory.java @@ -0,0 +1,58 @@ +package com.redis.riot.file; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.batch.item.file.transform.LineAggregator; +import org.springframework.core.io.WritableResource; +import org.springframework.util.CollectionUtils; + +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.redis.riot.resource.FlatFileItemWriter; +import com.redis.riot.resource.FlatFileItemWriterBuilder; + +public abstract class AbstractWriterFactory implements WriterFactory { + + protected T objectMapper(T objectMapper) { + objectMapper.setSerializationInclusion(Include.NON_NULL); + objectMapper.setSerializationInclusion(Include.NON_DEFAULT); + return objectMapper; + } + + protected FlatFileItemWriterBuilder flatFileWriter(WritableResource resource, WriteOptions options) { + FlatFileItemWriterBuilder builder = new FlatFileItemWriterBuilder<>(); + builder.name(resource.getFilename()); + builder.resource(resource); + builder.append(options.isAppend()); + builder.encoding(options.getEncoding()); + builder.forceSync(options.isForceSync()); + builder.lineSeparator(options.getLineSeparator()); + builder.saveState(false); + builder.shouldDeleteIfEmpty(options.isShouldDeleteIfEmpty()); + builder.shouldDeleteIfExists(options.isShouldDeleteIfExists()); + builder.transactional(options.isTransactional()); + return builder; + } + + protected FlatFileItemWriter> flatFileWriter(WriteOptions options, + FlatFileItemWriterBuilder> writer, LineAggregator> aggregator) { + writer.lineAggregator(aggregator); + if (options.isHeader()) { + Map headerRecord = options.getHeaderSupplier().get(); + if (!CollectionUtils.isEmpty(headerRecord)) { + List fields = new ArrayList<>(headerRecord.keySet()); + Collections.sort(fields); + Map fieldMap = new LinkedHashMap<>(); + fields.forEach(f -> fieldMap.put(f, f)); + String headerLine = aggregator.aggregate(fieldMap); + writer.headerCallback(w -> w.write(headerLine)); + } + } + return writer.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/DelimitedReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/DelimitedReaderFactory.java new file mode 100644 index 000000000..d57c7eb6e --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/DelimitedReaderFactory.java @@ -0,0 +1,27 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemReader; +import org.springframework.batch.item.file.transform.DelimitedLineTokenizer; +import org.springframework.core.io.Resource; +import org.springframework.util.ObjectUtils; + +public class DelimitedReaderFactory extends AbstractReaderFactory { + + private final String delimiter; + + public DelimitedReaderFactory(String delimiter) { + this.delimiter = delimiter; + } + + @Override + public ItemReader create(Resource resource, ReadOptions options) { + DelimitedLineTokenizer tokenizer = new DelimitedLineTokenizer(); + tokenizer.setDelimiter(options.getDelimiter() == null ? delimiter : options.getDelimiter()); + tokenizer.setQuoteCharacter(options.getQuoteCharacter()); + if (!ObjectUtils.isEmpty(options.getIncludedFields())) { + tokenizer.setIncludedFields(includedFields(options)); + } + return flatFileReader(resource, options, tokenizer); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/DelimitedWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/DelimitedWriterFactory.java new file mode 100644 index 000000000..a03c1f9b9 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/DelimitedWriterFactory.java @@ -0,0 +1,30 @@ +package com.redis.riot.file; + +import java.util.Map; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.batch.item.file.transform.PassThroughFieldExtractor; +import org.springframework.core.io.WritableResource; + +import com.redis.riot.resource.FlatFileItemWriterBuilder; +import com.redis.riot.resource.FlatFileItemWriterBuilder.DelimitedBuilder; + +public class DelimitedWriterFactory extends AbstractWriterFactory { + + private final String delimiter; + + public DelimitedWriterFactory(String delimiter) { + this.delimiter = delimiter; + } + + @Override + public ItemWriter create(WritableResource resource, WriteOptions options) { + FlatFileItemWriterBuilder> writer = flatFileWriter(resource, options); + DelimitedBuilder> delimitedBuilder = writer.delimited(); + delimitedBuilder.delimiter(options.getDelimiter() == null ? delimiter : options.getDelimiter()); + delimitedBuilder.fieldExtractor(new PassThroughFieldExtractor<>()); + delimitedBuilder.quoteCharacter(String.valueOf(options.getQuoteCharacter())); + return flatFileWriter(options, writer, delimitedBuilder.build()); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileOptions.java b/core/riot-file/src/main/java/com/redis/riot/file/FileOptions.java index 1c755b602..3be10c115 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/FileOptions.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileOptions.java @@ -2,6 +2,7 @@ import java.nio.charset.StandardCharsets; +import org.springframework.batch.item.file.transform.DelimitedLineTokenizer; import org.springframework.util.MimeType; import lombok.ToString; @@ -9,40 +10,51 @@ @ToString public class FileOptions { + public static final String DELIMITER_PIPE = "|"; + public static final String DELIMITER_COMMA = DelimitedLineTokenizer.DELIMITER_COMMA; + public static final String DELIMITER_TAB = DelimitedLineTokenizer.DELIMITER_TAB; public static final String DEFAULT_ENCODING = StandardCharsets.UTF_8.name(); public static final char DEFAULT_QUOTE_CHARACTER = '"'; - private MimeType type; + private boolean gzip; private S3Options s3Options = new S3Options(); private GoogleStorageOptions googleStorageOptions = new GoogleStorageOptions(); - private boolean gzipped; + private MimeType contentType; private String encoding = DEFAULT_ENCODING; private boolean header; private String delimiter; private char quoteCharacter = DEFAULT_QUOTE_CHARACTER; - public S3Options getS3Options() { - return s3Options; + public boolean isGzip() { + return gzip; } - public void setS3Options(S3Options s3Options) { - this.s3Options = s3Options; + public void setGzip(boolean gzip) { + this.gzip = gzip; } public GoogleStorageOptions getGoogleStorageOptions() { return googleStorageOptions; } + public S3Options getS3Options() { + return s3Options; + } + + public void setS3Options(S3Options s3Options) { + this.s3Options = s3Options; + } + public void setGoogleStorageOptions(GoogleStorageOptions googleStorageOptions) { this.googleStorageOptions = googleStorageOptions; } - public MimeType getType() { - return type; + public MimeType getContentType() { + return contentType; } - public void setType(MimeType type) { - this.type = type; + public void setContentType(MimeType type) { + this.contentType = type; } public String getDelimiter() { @@ -77,12 +89,4 @@ public void setEncoding(String encoding) { this.encoding = encoding; } - public boolean isGzipped() { - return gzipped; - } - - public void setGzipped(boolean gzipped) { - this.gzipped = gzipped; - } - } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileReaderRegistry.java b/core/riot-file/src/main/java/com/redis/riot/file/FileReaderRegistry.java index d39f4b316..07e335c44 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/FileReaderRegistry.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileReaderRegistry.java @@ -1,210 +1,58 @@ package com.redis.riot.file; import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; +import java.util.HashMap; import java.util.Map; import java.util.zip.GZIPInputStream; -import org.springframework.batch.item.ItemReader; -import org.springframework.batch.item.file.FlatFileItemReader; -import org.springframework.batch.item.file.builder.FlatFileItemReaderBuilder; -import org.springframework.batch.item.file.mapping.JsonLineMapper; -import org.springframework.batch.item.file.separator.DefaultRecordSeparatorPolicy; -import org.springframework.batch.item.file.separator.RecordSeparatorPolicy; -import org.springframework.batch.item.file.transform.AbstractLineTokenizer; -import org.springframework.batch.item.file.transform.DelimitedLineTokenizer; -import org.springframework.batch.item.file.transform.FixedLengthTokenizer; -import org.springframework.batch.item.file.transform.Range; -import org.springframework.batch.item.file.transform.RangeArrayPropertyEditor; -import org.springframework.batch.item.json.JacksonJsonObjectReader; -import org.springframework.batch.item.json.JsonItemReader; -import org.springframework.batch.item.json.builder.JsonItemReaderBuilder; -import org.springframework.core.io.ProtocolResolver; import org.springframework.core.io.Resource; -import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; +import org.springframework.util.MimeType; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.module.SimpleModule; -import com.fasterxml.jackson.dataformat.xml.XmlMapper; -import com.google.cloud.spring.core.GcpScope; -import com.redis.riot.file.xml.XmlItemReader; -import com.redis.riot.file.xml.XmlItemReaderBuilder; -import com.redis.riot.file.xml.XmlObjectReader; +public class FileReaderRegistry extends AbstractRegistry { -public class FileReaderRegistry extends AbstractFactoryRegistry, ReadOptions> { + private final Map factories = new HashMap<>(); - private ProtocolResolver stdInProtocolResolver = new StdInProtocolResolver(); - - public ProtocolResolver getStdInProtocolResolver() { - return stdInProtocolResolver; - } - - public void setStdInProtocolResolver(ProtocolResolver stdInProtocolResolver) { - this.stdInProtocolResolver = stdInProtocolResolver; - } - - @Override - protected Collection protocolResolvers(ReadOptions options) { - List resolvers = new ArrayList<>(super.protocolResolvers(options)); - resolvers.add(stdInProtocolResolver); - return resolvers; - } - - @Override - protected GcpScope googleStorageScope() { - return GcpScope.STORAGE_READ_ONLY; - } - - @Override - protected Resource gzip(Resource resource) throws IOException { - InputStream inputStream = resource.getInputStream(); - GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream); - return new NamedInputStreamResource(gzipInputStream, resource.getFilename(), resource.getDescription()); - } - - public FlatFileItemReader> delimited(Resource resource, ReadOptions options) { - return flatFileReader(resource, delimitedLineTokenizer(delimiter(resource, options), options), options); - } - - public FlatFileItemReader> fixedWidth(Resource resource, ReadOptions options) { - FixedLengthTokenizer tokenizer = new FixedLengthTokenizer(); - RangeArrayPropertyEditor editor = new RangeArrayPropertyEditor(); - List columnRanges = options.getColumnRanges(); - Assert.notEmpty(columnRanges, "Column ranges are required"); - editor.setAsText(String.join(",", columnRanges)); - Range[] ranges = (Range[]) editor.getValue(); - Assert.notEmpty(ranges, "Invalid ranges specified: " + columnRanges); - tokenizer.setColumns(ranges); - return flatFileReader(resource, tokenizer, options); - } - - public JsonItemReader json(Resource resource, ReadOptions options) { - JsonItemReaderBuilder builder = new JsonItemReaderBuilder<>(); - builder.name(resource.getFilename() + "-json-file-reader"); - builder.resource(resource); - builder.saveState(false); - JacksonJsonObjectReader objectReader = new JacksonJsonObjectReader<>(options.getItemType()); - objectReader.setMapper(objectMapper(new ObjectMapper(), options)); - builder.jsonObjectReader(objectReader); - if (options.getMaxItemCount() > 0) { - builder.maxItemCount(options.getMaxItemCount()); - } - return builder.build(); - } - - public FlatFileItemReader jsonLines(Resource resource, ReadOptions options) { - if (Map.class.isAssignableFrom(options.getItemType())) { - FlatFileItemReaderBuilder> reader = flatFileReader(resource, options); - reader.lineMapper(new JsonLineMapper()); - reader.fieldSetMapper(new MapFieldSetMapper()); - return reader.build(); - } - FlatFileItemReaderBuilder reader = flatFileReader(resource, options); - ObjectMapper objectMapper = objectMapper(new ObjectMapper(), options); - reader.lineMapper(new ObjectMapperLineMapper<>(objectMapper, options.getItemType())); - return reader.build(); - } - - public XmlItemReader xml(Resource resource, ReadOptions options) { - XmlItemReaderBuilder builder = new XmlItemReaderBuilder<>(); - builder.name(resource.getFilename() + "-xml-file-reader"); - builder.resource(resource); - XmlObjectReader objectReader = new XmlObjectReader<>(options.getItemType()); - objectReader.setMapper(objectMapper(new XmlMapper(), options)); - builder.xmlObjectReader(objectReader); - if (options.getMaxItemCount() > 0) { - builder.maxItemCount(options.getMaxItemCount()); - } - return builder.build(); + public void register(MimeType type, ReaderFactory factory) { + factories.put(type, factory); } - private DelimitedLineTokenizer delimitedLineTokenizer(String delimiter, ReadOptions options) { - DelimitedLineTokenizer tokenizer = new DelimitedLineTokenizer(); - tokenizer.setDelimiter(delimiter); - tokenizer.setQuoteCharacter(options.getQuoteCharacter()); - if (!ObjectUtils.isEmpty(options.getIncludedFields())) { - tokenizer.setIncludedFields(includedFields(options)); + public FileReaderResult find(String location, ReadOptions options) { + Resource resource = resource(location, options); + MimeType type = type(resource, options); + FileReaderResult reader = new FileReaderResult(); + reader.setResource(resource); + reader.setType(type); + ReaderFactory factory = factories.get(type); + if (factory != null) { + reader.setReader(factory.create(resource, options)); } - return tokenizer; - } - - private int[] includedFields(ReadOptions options) { - return options.getIncludedFields().stream().mapToInt(Integer::intValue).toArray(); + return reader; } - private FlatFileItemReader> flatFileReader(Resource resource, AbstractLineTokenizer tokenizer, - ReadOptions options) { - if (ObjectUtils.isEmpty(options.getFields())) { - Assert.isTrue(options.isHeader(), - String.format("Could not create reader for file '%s': no header or field names specified", - resource.getFilename())); - } else { - tokenizer.setNames(options.getFields().toArray(new String[0])); - } - FlatFileItemReaderBuilder> builder = flatFileReader(resource, options); - builder.fieldSetMapper(new MapFieldSetMapper()); - builder.lineTokenizer(tokenizer); - builder.skippedLinesCallback(new HeaderCallbackHandler(tokenizer, headerIndex(options))); - return builder.build(); - } - - private FlatFileItemReaderBuilder flatFileReader(Resource resource, ReadOptions options) { - FlatFileItemReaderBuilder builder = new FlatFileItemReaderBuilder<>(); - builder.resource(resource); - if (options.getMaxItemCount() > 0) { - builder.maxItemCount(options.getMaxItemCount()); - } - builder.encoding(options.getEncoding()); - builder.recordSeparatorPolicy(recordSeparatorPolicy(options)); - builder.linesToSkip(linesToSkip(options)); - builder.saveState(false); - return builder; - } - - private RecordSeparatorPolicy recordSeparatorPolicy(ReadOptions options) { - String quoteCharacter = String.valueOf(options.getQuoteCharacter()); - return new DefaultRecordSeparatorPolicy(quoteCharacter, options.getContinuationString()); - } - - private int headerIndex(ReadOptions options) { - if (options.getHeaderLine() != null) { - return options.getHeaderLine(); - } - return linesToSkip(options) - 1; - } - - private int linesToSkip(ReadOptions options) { - if (options.getLinesToSkip() != null) { - return options.getLinesToSkip(); - } - if (options.isHeader()) { - return 1; + @Override + protected Resource resource(String location, FileOptions options) { + Resource resource = super.resource(location, options); + if (options.isGzip() || FileUtils.isGzip(resource.getFilename())) { + GZIPInputStream gzipInputStream; + try { + gzipInputStream = new GZIPInputStream(resource.getInputStream()); + } catch (IOException e) { + throw new RuntimeIOException("Could not create GZip input stream", e); + } + return new NamedInputStreamResource(gzipInputStream, resource.getFilename(), resource.getDescription()); } - return 0; - } - - private T objectMapper(T objectMapper, ReadOptions options) { - objectMapper.configure(DeserializationFeature.USE_LONG_FOR_INTS, true); - SimpleModule module = new SimpleModule(); - options.getDeserializers().forEach(module::addDeserializer); - objectMapper.registerModule(module); - return objectMapper; + return resource; } public static FileReaderRegistry defaultReaderRegistry() { FileReaderRegistry registry = new FileReaderRegistry(); - registry.register(FileUtils.JSON, registry::json); - registry.register(FileUtils.JSON_LINES, registry::jsonLines); - registry.register(FileUtils.XML, registry::xml); - registry.register(FileUtils.CSV, registry::delimited); - registry.register(FileUtils.PSV, registry::delimited); - registry.register(FileUtils.TSV, registry::delimited); - registry.register(FileUtils.TEXT, registry::fixedWidth); + registry.register(FileUtils.JSON, new JsonReaderFactory()); + registry.register(FileUtils.JSON_LINES, new JsonLinesReaderFactory()); + registry.register(FileUtils.XML, new XmlReaderFactory()); + registry.register(FileUtils.CSV, new DelimitedReaderFactory(FileOptions.DELIMITER_COMMA)); + registry.register(FileUtils.PSV, new DelimitedReaderFactory(FileOptions.DELIMITER_PIPE)); + registry.register(FileUtils.TSV, new DelimitedReaderFactory(FileOptions.DELIMITER_TAB)); + registry.register(FileUtils.TEXT, new FixedWidthReaderFactory()); return registry; } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileReaderResult.java b/core/riot-file/src/main/java/com/redis/riot/file/FileReaderResult.java new file mode 100644 index 000000000..45fdcfbea --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileReaderResult.java @@ -0,0 +1,37 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemReader; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeType; + +public class FileReaderResult { + + private Resource resource; + private MimeType type; + private ItemReader reader; + + public Resource getResource() { + return resource; + } + + public void setResource(Resource resource) { + this.resource = resource; + } + + public MimeType getType() { + return type; + } + + public void setType(MimeType mimeType) { + this.type = mimeType; + } + + public ItemReader getReader() { + return reader; + } + + public void setReader(ItemReader itemReader) { + this.reader = itemReader; + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileUtils.java b/core/riot-file/src/main/java/com/redis/riot/file/FileUtils.java index 59552387a..d9f1a2b47 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/FileUtils.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileUtils.java @@ -22,7 +22,7 @@ public static boolean isGzip(String filename) { return filename.endsWith(GZ_SUFFIX); } - public static String normalize(String filename) { + public static String stripGzipSuffix(String filename) { if (isGzip(filename)) { return filename.substring(0, filename.length() - GZ_SUFFIX.length()); } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileWriterRegistry.java b/core/riot-file/src/main/java/com/redis/riot/file/FileWriterRegistry.java index 7ab430391..ce748f781 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/FileWriterRegistry.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileWriterRegistry.java @@ -1,168 +1,71 @@ package com.redis.riot.file; import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; +import java.util.HashMap; import java.util.Map; import java.util.zip.GZIPOutputStream; -import org.springframework.batch.item.ItemWriter; -import org.springframework.batch.item.file.transform.LineAggregator; -import org.springframework.batch.item.file.transform.PassThroughFieldExtractor; -import org.springframework.batch.item.json.JacksonJsonObjectMarshaller; -import org.springframework.core.io.ProtocolResolver; import org.springframework.core.io.Resource; import org.springframework.core.io.WritableResource; -import org.springframework.util.CollectionUtils; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.xml.XmlMapper; -import com.google.cloud.spring.core.GcpScope; -import com.redis.riot.file.xml.XmlResourceItemWriterBuilder; -import com.redis.riot.resource.FlatFileItemWriter; -import com.redis.riot.resource.FlatFileItemWriterBuilder; -import com.redis.riot.resource.FlatFileItemWriterBuilder.DelimitedBuilder; -import com.redis.riot.resource.FlatFileItemWriterBuilder.FormattedBuilder; -import com.redis.riot.resource.JsonFileItemWriterBuilder; +public class FileWriterRegistry extends AbstractRegistry { -public class FileWriterRegistry extends AbstractFactoryRegistry, WriteOptions> { + private final Map factories = new HashMap<>(); + private ResourceMap resourceMap = new RiotResourceMap(); - private ProtocolResolver stdOutProtocolResolver = new StdOutProtocolResolver(); - - public ProtocolResolver getStdOutProtocolResolver() { - return stdOutProtocolResolver; + public ResourceMap getResourceMap() { + return resourceMap; } - public void setStdOutProtocolResolver(ProtocolResolver resolver) { - this.stdOutProtocolResolver = resolver; + public void setResourceMap(ResourceMap resourceMap) { + this.resourceMap = resourceMap; } - @Override - protected GcpScope googleStorageScope() { - return GcpScope.STORAGE_READ_WRITE; + public void register(MimeType type, WriterFactory factory) { + factories.put(type, factory); } - @Override - protected Collection protocolResolvers(WriteOptions options) { - List resolvers = new ArrayList<>(super.protocolResolvers(options)); - resolvers.add(stdOutProtocolResolver); - return resolvers; + public FileWriterResult find(String location, WriteOptions options) { + WritableResource resource = resource(location, options); + MimeType type = type(resource, options); + FileWriterResult result = new FileWriterResult(); + result.setResource(resource); + result.setType(type); + WriterFactory factory = factories.get(type); + if (factory != null) { + result.setWriter(factory.create(resource, options)); + } + return result; } @Override - protected Resource gzip(Resource resource) throws IOException { - OutputStream outputStream = ((WritableResource) resource).getOutputStream(); - GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream); - return new OutputStreamResource(gzipOutputStream, resource.getFilename(), resource.getDescription()); - } - - private FlatFileItemWriter> delimited(Resource resource, WriteOptions options) { - FlatFileItemWriterBuilder> writer = flatFileWriter(resource, options); - DelimitedBuilder> delimitedBuilder = writer.delimited(); - delimitedBuilder.delimiter(delimiter(resource, options)); - delimitedBuilder.fieldExtractor(new PassThroughFieldExtractor<>()); - delimitedBuilder.quoteCharacter(String.valueOf(options.getQuoteCharacter())); - return flatFileWriter(writer, delimitedBuilder.build(), options); - } - - private FlatFileItemWriter> formatted(Resource resource, WriteOptions options) { - FlatFileItemWriterBuilder> writer = flatFileWriter(resource, options); - FormattedBuilder> formattedBuilder = writer.formatted(); - formattedBuilder.format(options.getFormatterString()); - formattedBuilder.fieldExtractor(new PassThroughFieldExtractor<>()); - return flatFileWriter(writer, formattedBuilder.build(), options); - } - - private FlatFileItemWriter jsonLines(Resource resource, WriteOptions options) { - FlatFileItemWriterBuilder builder = flatFileWriter(resource, options); - builder.lineAggregator(new JsonLineAggregator<>(new ObjectMapper())); - return builder.build(); - } - - private ItemWriter json(Resource resource, WriteOptions options) { - JsonFileItemWriterBuilder writer = new JsonFileItemWriterBuilder<>(); - writer.name(resource.getFilename()); - writer.resource((WritableResource) resource); - writer.append(options.isAppend()); - writer.encoding(options.getEncoding()); - writer.forceSync(options.isForceSync()); - writer.lineSeparator(options.getLineSeparator()); - writer.saveState(false); - writer.shouldDeleteIfEmpty(options.isShouldDeleteIfEmpty()); - writer.shouldDeleteIfExists(options.isShouldDeleteIfExists()); - writer.transactional(options.isTransactional()); - writer.jsonObjectMarshaller(new JacksonJsonObjectMarshaller<>(objectMapper(new ObjectMapper()))); - return writer.build(); - } - - private T objectMapper(T objectMapper) { - objectMapper.setSerializationInclusion(Include.NON_NULL); - objectMapper.setSerializationInclusion(Include.NON_DEFAULT); - return objectMapper; - } - - private ItemWriter xml(Resource resource, WriteOptions options) { - XmlResourceItemWriterBuilder writer = new XmlResourceItemWriterBuilder<>(); - writer.name(resource.getFilename()); - writer.append(options.isAppend()); - writer.encoding(options.getEncoding()); - writer.lineSeparator(options.getLineSeparator()); - writer.rootName(options.getRootName()); - writer.resource((WritableResource) resource); - writer.saveState(false); - XmlMapper mapper = objectMapper(new XmlMapper()); - mapper.setConfig(mapper.getSerializationConfig().withRootName(options.getElementName())); - writer.xmlObjectMarshaller(new JacksonJsonObjectMarshaller<>(mapper)); - return writer.build(); - } - - private FlatFileItemWriterBuilder flatFileWriter(Resource resource, WriteOptions options) { - FlatFileItemWriterBuilder builder = new FlatFileItemWriterBuilder<>(); - builder.name(resource.getFilename()); - builder.resource((WritableResource) resource); - builder.append(options.isAppend()); - builder.encoding(options.getEncoding()); - builder.forceSync(options.isForceSync()); - builder.lineSeparator(options.getLineSeparator()); - builder.saveState(false); - builder.shouldDeleteIfEmpty(options.isShouldDeleteIfEmpty()); - builder.shouldDeleteIfExists(options.isShouldDeleteIfExists()); - builder.transactional(options.isTransactional()); - return builder; - } - - private FlatFileItemWriter> flatFileWriter( - FlatFileItemWriterBuilder> writer, LineAggregator> aggregator, - WriteOptions options) { - writer.lineAggregator(aggregator); - if (options.isHeader()) { - Map headerRecord = options.getHeaderSupplier().get(); - if (!CollectionUtils.isEmpty(headerRecord)) { - List fields = new ArrayList<>(headerRecord.keySet()); - Collections.sort(fields); - Map fieldMap = new LinkedHashMap<>(); - fields.forEach(f -> fieldMap.put(f, f)); - String headerLine = aggregator.aggregate(fieldMap); - writer.headerCallback(w -> w.write(headerLine)); + protected WritableResource resource(String location, FileOptions options) { + Resource resource = super.resource(location, options); + Assert.isInstanceOf(WritableResource.class, resource, "Resource is not writable"); + WritableResource writableResource = (WritableResource) resource; + if (options.isGzip() || FileUtils.isGzip(resource.getFilename())) { + GZIPOutputStream gzipOutputStream; + try { + gzipOutputStream = new GZIPOutputStream(writableResource.getOutputStream()); + } catch (IOException e) { + throw new RuntimeIOException("Could not create GZip output stream", e); } + return new OutputStreamResource(gzipOutputStream, resource.getFilename(), resource.getDescription()); } - return writer.build(); + return writableResource; } public static FileWriterRegistry defaultWriterRegistry() { FileWriterRegistry registry = new FileWriterRegistry(); - registry.register(FileUtils.JSON, registry::json); - registry.register(FileUtils.JSON_LINES, registry::jsonLines); - registry.register(FileUtils.XML, registry::xml); - registry.register(FileUtils.CSV, registry::delimited); - registry.register(FileUtils.PSV, registry::delimited); - registry.register(FileUtils.TSV, registry::delimited); - registry.register(FileUtils.TEXT, registry::formatted); + registry.register(FileUtils.JSON, new JsonWriterFactory()); + registry.register(FileUtils.JSON_LINES, new JsonLinesWriterFactory()); + registry.register(FileUtils.XML, new XmlWriterFactory()); + registry.register(FileUtils.CSV, new DelimitedWriterFactory(FileOptions.DELIMITER_COMMA)); + registry.register(FileUtils.PSV, new DelimitedWriterFactory(FileOptions.DELIMITER_PIPE)); + registry.register(FileUtils.TSV, new DelimitedWriterFactory(FileOptions.DELIMITER_TAB)); + registry.register(FileUtils.TEXT, new FormattedWriterFactory()); return registry; } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FileWriterResult.java b/core/riot-file/src/main/java/com/redis/riot/file/FileWriterResult.java new file mode 100644 index 000000000..2da481d91 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/FileWriterResult.java @@ -0,0 +1,36 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeType; + +public class FileWriterResult { + + private Resource resource; + private MimeType type; + private ItemWriter writer; + + public Resource getResource() { + return resource; + } + + public void setResource(Resource resource) { + this.resource = resource; + } + + public MimeType getType() { + return type; + } + + public void setType(MimeType mimeType) { + this.type = mimeType; + } + + public ItemWriter getWriter() { + return writer; + } + + public void setWriter(ItemWriter writer) { + this.writer = writer; + } +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FixedWidthReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/FixedWidthReaderFactory.java new file mode 100644 index 000000000..c448fb149 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/FixedWidthReaderFactory.java @@ -0,0 +1,27 @@ +package com.redis.riot.file; + +import java.util.List; + +import org.springframework.batch.item.ItemReader; +import org.springframework.batch.item.file.transform.FixedLengthTokenizer; +import org.springframework.batch.item.file.transform.Range; +import org.springframework.batch.item.file.transform.RangeArrayPropertyEditor; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; + +public class FixedWidthReaderFactory extends AbstractReaderFactory { + + @Override + public ItemReader create(Resource resource, ReadOptions options) { + FixedLengthTokenizer tokenizer = new FixedLengthTokenizer(); + RangeArrayPropertyEditor editor = new RangeArrayPropertyEditor(); + List columnRanges = options.getColumnRanges(); + Assert.notEmpty(columnRanges, "Column ranges are required"); + editor.setAsText(String.join(",", columnRanges)); + Range[] ranges = (Range[]) editor.getValue(); + Assert.notEmpty(ranges, "Invalid ranges specified: " + columnRanges); + tokenizer.setColumns(ranges); + return flatFileReader(resource, options, tokenizer); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/FormattedWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/FormattedWriterFactory.java new file mode 100644 index 000000000..55bdaf08d --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/FormattedWriterFactory.java @@ -0,0 +1,23 @@ +package com.redis.riot.file; + +import java.util.Map; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.batch.item.file.transform.PassThroughFieldExtractor; +import org.springframework.core.io.WritableResource; + +import com.redis.riot.resource.FlatFileItemWriterBuilder; +import com.redis.riot.resource.FlatFileItemWriterBuilder.FormattedBuilder; + +public class FormattedWriterFactory extends AbstractWriterFactory { + + @Override + public ItemWriter create(WritableResource resource, WriteOptions options) { + FlatFileItemWriterBuilder> writer = flatFileWriter(resource, options); + FormattedBuilder> formattedBuilder = writer.formatted(); + formattedBuilder.format(options.getFormatterString()); + formattedBuilder.fieldExtractor(new PassThroughFieldExtractor<>()); + return flatFileWriter(options, writer, formattedBuilder.build()); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageOptions.java b/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageOptions.java index 035c63079..3a7fb7c3b 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageOptions.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageOptions.java @@ -1,15 +1,62 @@ package com.redis.riot.file; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; import java.nio.file.Path; +import java.util.Base64; -import lombok.ToString; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.ServiceOptions; +import com.google.cloud.spring.autoconfigure.storage.GcpStorageAutoConfiguration; +import com.google.cloud.spring.core.GcpScope; +import com.google.cloud.spring.core.UserAgentHeaderProvider; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; -@ToString(exclude = "encodedKey") public class GoogleStorageOptions { + public static final GcpScope DEFAULT_SCOPE = GcpScope.STORAGE_READ_ONLY; + private Path keyFile; private String projectId; private String encodedKey; + private GcpScope scope = DEFAULT_SCOPE; + + public Storage storage() { + StorageOptions.Builder builder = StorageOptions.newBuilder(); + builder.setProjectId(ServiceOptions.getDefaultProjectId()); + builder.setHeaderProvider(new UserAgentHeaderProvider(GcpStorageAutoConfiguration.class)); + if (keyFile != null) { + InputStream inputStream; + try { + inputStream = Files.newInputStream(keyFile); + } catch (IOException e) { + throw new RuntimeIOException("Could not read key file", e); + } + builder.setCredentials(credentials(inputStream)); + } + if (encodedKey != null) { + byte[] bytes = Base64.getDecoder().decode(encodedKey); + builder.setCredentials(credentials(new ByteArrayInputStream(bytes))); + } + if (projectId != null) { + builder.setProjectId(projectId); + } + return builder.build().getService(); + } + + private GoogleCredentials credentials(InputStream inputStream) { + GoogleCredentials credentials; + try { + credentials = GoogleCredentials.fromStream(inputStream); + } catch (IOException e) { + throw new RuntimeIOException("Could not create Google credentials", e); + } + credentials.createScoped(scope.getUrl()); + return credentials; + } public Path getKeyFile() { return keyFile; @@ -35,4 +82,12 @@ public void setEncodedKey(String encodedKey) { this.encodedKey = encodedKey; } + public GcpScope getScope() { + return scope; + } + + public void setScope(GcpScope scope) { + this.scope = scope; + } + } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageProtocolResolver.java b/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageProtocolResolver.java index 020f88452..e23444530 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageProtocolResolver.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/GoogleStorageProtocolResolver.java @@ -14,14 +14,6 @@ public class GoogleStorageProtocolResolver implements ProtocolResolver { private Supplier storageSupplier; private Storage storage; - public void setStorageSupplier(Supplier supplier) { - this.storageSupplier = supplier; - } - - public void setStorage(Storage storage) { - this.storage = storage; - } - @Override public Resource resolve(String location, ResourceLoader resourceLoader) { if (location.startsWith(com.google.cloud.spring.storage.GoogleStorageProtocolResolver.PROTOCOL)) { @@ -36,4 +28,13 @@ private Storage storage() { } return storage; } + + public void setStorage(Storage storage) { + this.storage = storage; + } + + public void setStorageSupplier(Supplier storageSupplier) { + this.storageSupplier = storageSupplier; + } + } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesReaderFactory.java new file mode 100644 index 000000000..ac5e279a9 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesReaderFactory.java @@ -0,0 +1,30 @@ +package com.redis.riot.file; + +import java.util.Map; + +import org.springframework.batch.item.ItemReader; +import org.springframework.batch.item.file.builder.FlatFileItemReaderBuilder; +import org.springframework.batch.item.file.mapping.JsonLineMapper; +import org.springframework.core.io.Resource; + +import com.fasterxml.jackson.databind.ObjectMapper; + +public class JsonLinesReaderFactory extends AbstractReaderFactory { + + @Override + public ItemReader create(Resource resource, ReadOptions options) { + if (Map.class.isAssignableFrom(options.getItemType())) { + FlatFileItemReaderBuilder> reader = flatFileReader(options); + reader.resource(resource); + reader.lineMapper(new JsonLineMapper()); + reader.fieldSetMapper(new MapFieldSetMapper()); + return reader.build(); + } + FlatFileItemReaderBuilder reader = flatFileReader(options); + reader.resource(resource); + ObjectMapper objectMapper = objectMapper(new ObjectMapper(), options); + reader.lineMapper(new ObjectMapperLineMapper<>(objectMapper, options.getItemType())); + return reader.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesWriterFactory.java new file mode 100644 index 000000000..ab5109584 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/JsonLinesWriterFactory.java @@ -0,0 +1,18 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.core.io.WritableResource; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.redis.riot.resource.FlatFileItemWriterBuilder; + +public class JsonLinesWriterFactory extends AbstractWriterFactory { + + @Override + public ItemWriter create(WritableResource resource, WriteOptions options) { + FlatFileItemWriterBuilder builder = flatFileWriter(resource, options); + builder.lineAggregator(new JsonLineAggregator<>(new ObjectMapper())); + return builder.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/JsonReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/JsonReaderFactory.java new file mode 100644 index 000000000..08e0457af --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/JsonReaderFactory.java @@ -0,0 +1,27 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemReader; +import org.springframework.batch.item.json.JacksonJsonObjectReader; +import org.springframework.batch.item.json.builder.JsonItemReaderBuilder; +import org.springframework.core.io.Resource; + +import com.fasterxml.jackson.databind.ObjectMapper; + +public class JsonReaderFactory extends AbstractReaderFactory { + + @Override + public ItemReader create(Resource resource, ReadOptions options) { + JsonItemReaderBuilder builder = new JsonItemReaderBuilder<>(); + builder.name(resource.getFilename() + "-json-file-reader"); + builder.resource(resource); + builder.saveState(false); + JacksonJsonObjectReader objectReader = new JacksonJsonObjectReader<>(options.getItemType()); + objectReader.setMapper(objectMapper(new ObjectMapper(), options)); + builder.jsonObjectReader(objectReader); + if (options.getMaxItemCount() > 0) { + builder.maxItemCount(options.getMaxItemCount()); + } + return builder.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/JsonWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/JsonWriterFactory.java new file mode 100644 index 000000000..e9123d40b --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/JsonWriterFactory.java @@ -0,0 +1,29 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.batch.item.json.JacksonJsonObjectMarshaller; +import org.springframework.core.io.WritableResource; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.redis.riot.resource.JsonFileItemWriterBuilder; + +public class JsonWriterFactory extends AbstractWriterFactory { + + @Override + public ItemWriter create(WritableResource resource, WriteOptions options) { + JsonFileItemWriterBuilder writer = new JsonFileItemWriterBuilder<>(); + writer.name(resource.getFilename()); + writer.resource((WritableResource) resource); + writer.append(options.isAppend()); + writer.encoding(options.getEncoding()); + writer.forceSync(options.isForceSync()); + writer.lineSeparator(options.getLineSeparator()); + writer.saveState(false); + writer.shouldDeleteIfEmpty(options.isShouldDeleteIfEmpty()); + writer.shouldDeleteIfExists(options.isShouldDeleteIfExists()); + writer.transactional(options.isTransactional()); + writer.jsonObjectMarshaller(new JacksonJsonObjectMarshaller<>(objectMapper(new ObjectMapper()))); + return writer.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/ReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/ReaderFactory.java new file mode 100644 index 000000000..202b1407b --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/ReaderFactory.java @@ -0,0 +1,10 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemReader; +import org.springframework.core.io.Resource; + +public interface ReaderFactory { + + ItemReader create(Resource resource, ReadOptions options); + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/Factory.java b/core/riot-file/src/main/java/com/redis/riot/file/ResourceMap.java similarity index 50% rename from core/riot-file/src/main/java/com/redis/riot/file/Factory.java rename to core/riot-file/src/main/java/com/redis/riot/file/ResourceMap.java index 1de784f51..6254a9cf7 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/Factory.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/ResourceMap.java @@ -2,8 +2,8 @@ import org.springframework.core.io.Resource; -public interface Factory { +public interface ResourceMap { - T create(Resource resource, O options); + String getContentTypeFor(Resource resource); } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/ResourceTypeMap.java b/core/riot-file/src/main/java/com/redis/riot/file/ResourceTypeMap.java deleted file mode 100644 index 706ec15dc..000000000 --- a/core/riot-file/src/main/java/com/redis/riot/file/ResourceTypeMap.java +++ /dev/null @@ -1,64 +0,0 @@ -package com.redis.riot.file; - -import java.io.IOException; -import java.net.FileNameMap; -import java.net.URLConnection; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Iterator; -import java.util.LinkedHashSet; -import java.util.Set; - -import org.springframework.util.Assert; -import org.springframework.util.MimeType; - -public class ResourceTypeMap { - - private final Set fileNameMaps = new LinkedHashSet<>(); - - public void addFileNameMap(FileNameMap map) { - fileNameMaps.add(map); - } - - public MimeType getContentType(String filename) { - String normalizedFilename = FileUtils.normalize(filename); - try { - String type = Files.probeContentType(Path.of(normalizedFilename)); - if (type == null) { - type = URLConnection.guessContentTypeFromName(normalizedFilename); - Iterator maps = fileNameMaps.iterator(); - while (type == null && maps.hasNext()) { - type = maps.next().getContentTypeFor(normalizedFilename); - } - } - Assert.notNull(type, () -> "Could not determine type of " + normalizedFilename); - return MimeType.valueOf(type); - } catch (IOException e) { - throw new RuntimeIOException("Could not determine type of " + normalizedFilename, e); - } - } - - public static ResourceTypeMap defaultResourceTypeMap() { - ResourceTypeMap map = new ResourceTypeMap(); - map.addFileNameMap(new JsonLinesFileNameMap()); - return map; - } - - private static class JsonLinesFileNameMap implements FileNameMap { - - public static final String JSONL_SUFFIX = ".jsonl"; - - @Override - public String getContentTypeFor(String fileName) { - if (fileName == null) { - return null; - } - if (fileName.endsWith(JSONL_SUFFIX)) { - return FileUtils.JSON_LINES.toString(); - } - return null; - } - - } - -} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/DefaultResourceLoader.java b/core/riot-file/src/main/java/com/redis/riot/file/RiotResourceLoader.java similarity index 59% rename from core/riot-file/src/main/java/com/redis/riot/file/DefaultResourceLoader.java rename to core/riot-file/src/main/java/com/redis/riot/file/RiotResourceLoader.java index 865f95fe8..abc1bcc38 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/DefaultResourceLoader.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/RiotResourceLoader.java @@ -2,44 +2,27 @@ import java.net.MalformedURLException; import java.net.URL; +import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashSet; -import java.util.Map; +import java.util.List; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.FileUrlResource; import org.springframework.core.io.ProtocolResolver; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; -import org.springframework.core.io.UrlResource; -import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ResourceUtils; -/** - * Implementation of the {@link ResourceLoader} interface. - * - * - *

- * Will return a {@link UrlResource} if the location value is a URL, and a - * {@link ClassPathResource} if it is a non-URL path or a "classpath:" - * pseudo-URL. - * - */ -public class DefaultResourceLoader implements ResourceLoader { +public class RiotResourceLoader implements ResourceLoader { - private final Set protocolResolvers = new LinkedHashSet<>(4); - private final Map, Map> resourceCaches = new ConcurrentHashMap<>(4); + private GoogleStorageProtocolResolver googleStorageProtocolResolver = new GoogleStorageProtocolResolver(); + private S3ProtocolResolver s3ProtocolResolver = new S3ProtocolResolver(); - @Override - @Nullable - public ClassLoader getClassLoader() { - return ClassUtils.getDefaultClassLoader(); - } + private Set protocolResolvers = new LinkedHashSet<>(4); /** * Register the given resolver with this resource loader, allowing for @@ -55,6 +38,22 @@ public void addProtocolResolver(ProtocolResolver resolver) { this.protocolResolvers.add(resolver); } + public GoogleStorageProtocolResolver getGoogleStorageProtocolResolver() { + return googleStorageProtocolResolver; + } + + public void setGoogleStorageProtocolResolver(GoogleStorageProtocolResolver resolver) { + this.googleStorageProtocolResolver = resolver; + } + + public S3ProtocolResolver getS3ProtocolResolver() { + return s3ProtocolResolver; + } + + public void setS3ProtocolResolver(S3ProtocolResolver resolver) { + this.s3ProtocolResolver = resolver; + } + /** * Return the collection of currently registered protocol resolvers, allowing * for introspection as well as modification. @@ -65,31 +64,15 @@ public Collection getProtocolResolvers() { return this.protocolResolvers; } - /** - * Obtain a cache for the given value type, keyed by {@link Resource}. - * - * @param valueType the value type, e.g. an ASM {@code MetadataReader} - * @return the cache {@link Map}, shared at the {@code ResourceLoader} level - */ - @SuppressWarnings("unchecked") - public Map getResourceCache(Class valueType) { - return (Map) this.resourceCaches.computeIfAbsent(valueType, key -> new ConcurrentHashMap<>()); - } - - /** - * Clear all resource caches in this resource loader. - * - * @since 5.0 - * @see #getResourceCache - */ - public void clearResourceCaches() { - this.resourceCaches.clear(); + @Override + public ClassLoader getClassLoader() { + return ClassUtils.getDefaultClassLoader(); } @Override public Resource getResource(String location) { Assert.notNull(location, "Location must not be null"); - for (ProtocolResolver protocolResolver : getProtocolResolvers()) { + for (ProtocolResolver protocolResolver : allProtocolResolvers()) { Resource resource = protocolResolver.resolve(location, this); if (resource != null) { return resource; @@ -105,4 +88,12 @@ public Resource getResource(String location) { } } + private Iterable allProtocolResolvers() { + List resolvers = new ArrayList<>(); + resolvers.add(s3ProtocolResolver); + resolvers.add(googleStorageProtocolResolver); + resolvers.addAll(protocolResolvers); + return resolvers; + } + } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/RiotResourceMap.java b/core/riot-file/src/main/java/com/redis/riot/file/RiotResourceMap.java new file mode 100644 index 000000000..ee8fae3bf --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/RiotResourceMap.java @@ -0,0 +1,74 @@ +package com.redis.riot.file; + +import java.io.IOException; +import java.net.FileNameMap; +import java.net.URLConnection; +import java.nio.file.Files; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.springframework.core.io.Resource; + +public class RiotResourceMap implements ResourceMap { + + private final Set fileNameMaps = defaultFileNameMaps(); + + public void addFileNameMap(FileNameMap map) { + fileNameMaps.add(map); + } + + public static Set defaultFileNameMaps() { + Set maps = new LinkedHashSet<>(); + maps.add(new JsonLinesFileNameMap()); + return maps; + } + + @Override + public String getContentTypeFor(Resource resource) { + String type = null; + if (resource.isFile()) { + try { + type = Files.probeContentType(resource.getFile().toPath()); + } catch (IOException e) { + // ignore + } + } + if (type == null) { + return getContentTypeFor(resource.getFilename()); + } + return type; + } + + public String getContentTypeFor(String filename) { + String normalizedFilename = FileUtils.stripGzipSuffix(filename); + String type = URLConnection.guessContentTypeFromName(normalizedFilename); + if (type != null) { + return type; + } + for (FileNameMap nameMap : fileNameMaps) { + String mapType = nameMap.getContentTypeFor(normalizedFilename); + if (mapType != null) { + return mapType; + } + } + throw new IllegalArgumentException("Could not determine type of " + filename); + } + + private static class JsonLinesFileNameMap implements FileNameMap { + + public static final String JSONL_SUFFIX = ".jsonl"; + + @Override + public String getContentTypeFor(String fileName) { + if (fileName == null) { + return null; + } + if (fileName.endsWith(JSONL_SUFFIX)) { + return FileUtils.JSON_LINES.toString(); + } + return null; + } + + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/S3Options.java b/core/riot-file/src/main/java/com/redis/riot/file/S3Options.java index c0a5edf5c..72d789542 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/S3Options.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/S3Options.java @@ -2,10 +2,14 @@ import java.net.URI; -import lombok.ToString; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; -@ToString public class S3Options { private String accessKey; @@ -13,6 +17,25 @@ public class S3Options { private Region region; private URI endpoint; + public S3Client client() { + S3ClientBuilder clientBuilder = S3Client.builder(); + if (region != null) { + clientBuilder.region(region); + } + if (endpoint != null) { + clientBuilder.endpointOverride(endpoint); + } + clientBuilder.credentialsProvider(credentialsProvider()); + return clientBuilder.build(); + } + + private AwsCredentialsProvider credentialsProvider() { + if (accessKey == null && secretKey == null) { + return AnonymousCredentialsProvider.create(); + } + return StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey, secretKey)); + } + public String getAccessKey() { return accessKey; } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/S3ProtocolResolver.java b/core/riot-file/src/main/java/com/redis/riot/file/S3ProtocolResolver.java index 3102e2513..e95c2dae7 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/S3ProtocolResolver.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/S3ProtocolResolver.java @@ -5,39 +5,37 @@ import org.springframework.core.io.ProtocolResolver; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; +import org.springframework.util.ClassUtils; import io.awspring.cloud.s3.InMemoryBufferingS3OutputStreamProvider; import io.awspring.cloud.s3.Location; import io.awspring.cloud.s3.PropertiesS3ObjectContentTypeResolver; +import io.awspring.cloud.s3.S3ObjectContentTypeResolver; import io.awspring.cloud.s3.S3OutputStreamProvider; import io.awspring.cloud.s3.S3Resource; import software.amazon.awssdk.services.s3.S3Client; -public class S3ProtocolResolver implements ProtocolResolver { +public class S3ProtocolResolver implements ProtocolResolver, ResourceLoader { private Supplier clientSupplier; - private S3Client client; private S3OutputStreamProvider outputStreamProvider; - public void setClientSupplier(Supplier client) { - this.clientSupplier = client; - } - - public void setClient(S3Client client) { - this.client = client; + @Override + public Resource resolve(String location, ResourceLoader resourceLoader) { + if (isS3(location)) { + return getResource(location); + } + return null; } - public void setOutputStreamProvider(S3OutputStreamProvider outputStreamProvider) { - this.outputStreamProvider = outputStreamProvider; + private boolean isS3(String location) { + return location.startsWith(Location.S3_PROTOCOL_PREFIX); } @Override - public Resource resolve(String location, ResourceLoader resourceLoader) { - if (location.startsWith(Location.S3_PROTOCOL_PREFIX)) { - return new S3Resource(location, client(), outputStreamProvider()); - } - return null; + public Resource getResource(String location) { + return new S3Resource(location, client(), outputStreamProvider()); } private S3Client client() { @@ -47,12 +45,25 @@ private S3Client client() { return client; } - private S3OutputStreamProvider outputStreamProvider() { + public S3OutputStreamProvider outputStreamProvider() { if (outputStreamProvider == null) { - PropertiesS3ObjectContentTypeResolver contentTypeResolver = new PropertiesS3ObjectContentTypeResolver(); + S3ObjectContentTypeResolver contentTypeResolver = new PropertiesS3ObjectContentTypeResolver(); outputStreamProvider = new InMemoryBufferingS3OutputStreamProvider(client(), contentTypeResolver); } return outputStreamProvider; } + @Override + public ClassLoader getClassLoader() { + return ClassUtils.getDefaultClassLoader(); + } + + public void setClient(S3Client client) { + this.client = client; + } + + public void setClientSupplier(Supplier clientSupplier) { + this.clientSupplier = clientSupplier; + } + } diff --git a/core/riot-file/src/main/java/com/redis/riot/file/StdInProtocolResolver.java b/core/riot-file/src/main/java/com/redis/riot/file/StdInProtocolResolver.java index 7daf0544a..641707cb3 100644 --- a/core/riot-file/src/main/java/com/redis/riot/file/StdInProtocolResolver.java +++ b/core/riot-file/src/main/java/com/redis/riot/file/StdInProtocolResolver.java @@ -8,7 +8,15 @@ public class StdInProtocolResolver implements ProtocolResolver { public static final String DEFAULT_FILENAME = SystemInResource.FILENAME; - private String filename = DEFAULT_FILENAME; + private String filename; + + public StdInProtocolResolver() { + setFilename(DEFAULT_FILENAME); + } + + public String getFilename() { + return filename; + } public void setFilename(String filename) { this.filename = filename; diff --git a/core/riot-file/src/main/java/com/redis/riot/file/WriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/WriterFactory.java new file mode 100644 index 000000000..996d8e3ac --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/WriterFactory.java @@ -0,0 +1,9 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.core.io.WritableResource; + +public interface WriterFactory { + + ItemWriter create(WritableResource resource, WriteOptions options); +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/XmlReaderFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/XmlReaderFactory.java new file mode 100644 index 000000000..1fca51ced --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/XmlReaderFactory.java @@ -0,0 +1,26 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemReader; +import org.springframework.core.io.Resource; + +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.redis.riot.file.xml.XmlItemReaderBuilder; +import com.redis.riot.file.xml.XmlObjectReader; + +public class XmlReaderFactory extends AbstractReaderFactory { + + @Override + public ItemReader create(Resource resource, ReadOptions options) { + XmlItemReaderBuilder builder = new XmlItemReaderBuilder<>(); + builder.name(resource.getFilename() + "-xml-file-reader"); + builder.resource(resource); + XmlObjectReader objectReader = new XmlObjectReader<>(options.getItemType()); + objectReader.setMapper(objectMapper(new XmlMapper(), options)); + builder.xmlObjectReader(objectReader); + if (options.getMaxItemCount() > 0) { + builder.maxItemCount(options.getMaxItemCount()); + } + return builder.build(); + } + +} diff --git a/core/riot-file/src/main/java/com/redis/riot/file/XmlWriterFactory.java b/core/riot-file/src/main/java/com/redis/riot/file/XmlWriterFactory.java new file mode 100644 index 000000000..152362f74 --- /dev/null +++ b/core/riot-file/src/main/java/com/redis/riot/file/XmlWriterFactory.java @@ -0,0 +1,28 @@ +package com.redis.riot.file; + +import org.springframework.batch.item.ItemWriter; +import org.springframework.batch.item.json.JacksonJsonObjectMarshaller; +import org.springframework.core.io.WritableResource; + +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.redis.riot.file.xml.XmlResourceItemWriterBuilder; + +public class XmlWriterFactory extends AbstractWriterFactory { + + @Override + public ItemWriter create(WritableResource resource, WriteOptions options) { + XmlResourceItemWriterBuilder writer = new XmlResourceItemWriterBuilder<>(); + writer.name(resource.getFilename()); + writer.append(options.isAppend()); + writer.encoding(options.getEncoding()); + writer.lineSeparator(options.getLineSeparator()); + writer.rootName(options.getRootName()); + writer.resource(resource); + writer.saveState(false); + XmlMapper mapper = objectMapper(new XmlMapper()); + mapper.setConfig(mapper.getSerializationConfig().withRootName(options.getElementName())); + writer.xmlObjectMarshaller(new JacksonJsonObjectMarshaller<>(mapper)); + return writer.build(); + } + +} diff --git a/core/riot-file/src/test/java/com/redis/riot/file/ReaderTests.java b/core/riot-file/src/test/java/com/redis/riot/file/ReaderTests.java index df8ad1a0f..ca27d5a7a 100644 --- a/core/riot-file/src/test/java/com/redis/riot/file/ReaderTests.java +++ b/core/riot-file/src/test/java/com/redis/riot/file/ReaderTests.java @@ -32,7 +32,8 @@ public class ReaderTests { public static final String JSONL_URL = BUCKET_URL + JSONL_FILE; public static final String CSV_FILE = "beers.csv"; public static final String CSV_URL = BUCKET_URL + CSV_FILE; - public static final String JSON_S3_URL = "s3://riot-bucket-jrx/beers.json"; + public static final String S3_BUCKET_URL = "s3://riot-bucket-jrx"; + public static final String JSON_S3_URL = S3_BUCKET_URL + "/beers.json"; public static final String JSON_GOOGLE_STORAGE_URL = "gs://riot-bucket-jrx/beers.json"; public static final String JSON_GZ_URL = "http://storage.googleapis.com/jrx/beers.json.gz"; @@ -40,12 +41,12 @@ public class ReaderTests { @Test void readJsonUrl() throws Exception { - assertRead(JSON_URL, new ReadOptions(), JsonItemReader.class, 216); + assertRead(JSON_URL, JsonItemReader.class, 216); } @Test void readJsonGzUrl() throws Exception { - assertRead(JSON_GZ_URL, new ReadOptions(), JsonItemReader.class, 216); + assertRead(JSON_GZ_URL, JsonItemReader.class, 216); } @Test @@ -57,7 +58,7 @@ void readJsonS3Url() throws Exception { @Test void readJsonGoogleStorageUrl() throws Exception { - assertRead(JSON_GOOGLE_STORAGE_URL, new ReadOptions(), JsonItemReader.class, 4432); + assertRead(JSON_GOOGLE_STORAGE_URL, JsonItemReader.class, 4432); } @Test @@ -66,7 +67,7 @@ void readJsonFile() throws Exception { try (FileOutputStream outputStream = new FileOutputStream(file.toFile())) { StreamUtils.copy(urlInputStream(JSON_URL), outputStream); } - assertRead(file.toFile().getAbsolutePath(), new ReadOptions(), JsonItemReader.class, 216); + assertRead(file.toFile().getAbsolutePath(), JsonItemReader.class, 216); } private InputStream urlInputStream(String url) throws MalformedURLException, IOException, URISyntaxException { @@ -75,7 +76,7 @@ private InputStream urlInputStream(String url) throws MalformedURLException, IOE @Test void readJsonLinesUrl() throws Exception { - assertRead(JSONL_URL, new ReadOptions(), FlatFileItemReader.class, 6); + assertRead(JSONL_URL, FlatFileItemReader.class, 6); } @Test @@ -87,15 +88,19 @@ void readCsvUrl() throws Exception { @Test void readStdIn() throws Exception { - DefaultResourceLoader resourceLoader = new DefaultResourceLoader(); + RiotResourceLoader resourceLoader = new RiotResourceLoader(); resourceLoader.addProtocolResolver(new StdInProtocolResolver()); Resource resource = resourceLoader.getResource(SystemInResource.FILENAME); Assertions.assertInstanceOf(SystemInResource.class, resource); } + private void assertRead(String location, Class expectedType, int expectedCount) throws Exception { + assertRead(location, new ReadOptions(), expectedType, expectedCount); + } + private void assertRead(String location, ReadOptions options, Class expectedType, int expectedCount) throws Exception { - ItemReader reader = registry.get(location, options); + ItemReader reader = registry.find(location, options).getReader(); Assertions.assertNotNull(reader); List items = readAll(reader); Assertions.assertEquals(expectedCount, items.size()); diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractExportCommand.java b/plugins/riot/src/main/java/com/redis/riot/AbstractExportCommand.java index 7f3c2f425..696cc85a4 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractExportCommand.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractExportCommand.java @@ -11,6 +11,7 @@ import com.redis.lettucemod.RedisModulesUtils; import com.redis.lettucemod.api.StatefulRedisModulesConnection; import com.redis.riot.core.AbstractJobCommand; +import com.redis.riot.core.RiotInitializationException; import com.redis.riot.core.Step; import com.redis.spring.batch.item.redis.RedisItemReader; import com.redis.spring.batch.item.redis.RedisItemReader.ReaderMode; @@ -34,14 +35,18 @@ public abstract class AbstractExportCommand extends AbstractJobCommand { private RedisContext sourceRedisContext; @Override - protected void execute() throws Exception { + protected void initialize() throws RiotInitializationException { + super.initialize(); sourceRedisContext = sourceRedisContext(); sourceRedisContext.afterPropertiesSet(); - try { - super.execute(); - } finally { + } + + @Override + protected void teardown() { + if (sourceRedisContext != null) { sourceRedisContext.close(); } + super.teardown(); } protected void configure(StandardEvaluationContext context) { diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractFileExport.java b/plugins/riot/src/main/java/com/redis/riot/AbstractFileExport.java index 76d5f1040..fd85703d2 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractFileExport.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractFileExport.java @@ -1,6 +1,5 @@ package com.redis.riot; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -8,16 +7,16 @@ import java.util.Set; import org.springframework.batch.core.Job; -import org.springframework.batch.core.step.builder.StepBuilderException; import org.springframework.batch.item.ExecutionContext; import org.springframework.batch.item.ItemProcessor; import org.springframework.batch.item.ItemStreamException; -import org.springframework.batch.item.ItemWriter; import org.springframework.util.MimeType; +import com.redis.riot.core.RiotInitializationException; import com.redis.riot.core.Step; import com.redis.riot.file.FileUtils; import com.redis.riot.file.FileWriterRegistry; +import com.redis.riot.file.FileWriterResult; import com.redis.riot.file.StdOutProtocolResolver; import com.redis.riot.file.WriteOptions; import com.redis.spring.batch.item.redis.RedisItemReader; @@ -31,8 +30,6 @@ @Command(name = "file-export", description = "Export Redis data to files.") public abstract class AbstractFileExport extends AbstractRedisExportCommand { - private FileWriterRegistry writerRegistry = FileWriterRegistry.defaultWriterRegistry(); - private Set flatFileTypes = new HashSet<>( Arrays.asList(FileUtils.CSV, FileUtils.PSV, FileUtils.TSV, FileUtils.TEXT)); @@ -45,6 +42,23 @@ public abstract class AbstractFileExport extends AbstractRedisExportCommand { @Option(names = "--content-type", description = "Type of exported content: ${COMPLETION-CANDIDATES}.", paramLabel = "") private ContentType contentType = ContentType.STRUCT; + private FileWriterRegistry writerRegistry; + private WriteOptions writeOptions; + + @Override + protected void initialize() throws RiotInitializationException { + super.initialize(); + writerRegistry = FileWriterRegistry.defaultWriterRegistry(); + writeOptions = writeOptions(); + } + + private WriteOptions writeOptions() { + WriteOptions writeOptions = fileWriterArgs.fileWriterOptions(); + writeOptions.setContentType(getFileType()); + writeOptions.setHeaderSupplier(this::headerRecord); + return writeOptions; + } + @Override protected Job job() { return job(step()); @@ -58,16 +72,8 @@ public void setFlatFileTypes(MimeType... types) { @SuppressWarnings("unchecked") private Step step() { - WriteOptions writerOptions = fileWriterArgs.fileWriterOptions(); - writerOptions.setType(getFileType()); - writerOptions.setHeaderSupplier(this::headerRecord); - ItemWriter writer; - try { - writer = writerRegistry.get(file, writerOptions); - } catch (IOException e) { - throw new StepBuilderException(e); - } - return step(writer).processor(processor()); + FileWriterResult writer = writerRegistry.find(file, writeOptions); + return step(writer.getWriter()).processor(processor(writer.getType())); } @Override @@ -75,24 +81,18 @@ protected boolean shouldShowProgress() { return super.shouldShowProgress() && file != null; } - private ContentType contentType() { - MimeType type = writerRegistry.getType(file, getFileType()); - return isFlatFile(type) ? ContentType.MAP : contentType; - } - private boolean isFlatFile(MimeType type) { return flatFileTypes.contains(type); } @SuppressWarnings("rawtypes") - private ItemProcessor processor() { - if (contentType() == ContentType.MAP) { + private ItemProcessor processor(MimeType type) { + if (isFlatFile(type) || contentType == ContentType.MAP) { return mapProcessor(); } return null; } - @SuppressWarnings("unchecked") private Map headerRecord() { RedisItemReader reader = RedisItemReader.struct(); configureSourceRedisReader(reader); @@ -103,7 +103,7 @@ private Map headerRecord() { if (keyValue == null) { return Collections.emptyMap(); } - return ((ItemProcessor, Map>) processor()).process(keyValue); + return ((ItemProcessor, Map>) mapProcessor()).process(keyValue); } catch (Exception e) { throw new ItemStreamException("Could not read header record", e); } diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractFileImport.java b/plugins/riot/src/main/java/com/redis/riot/AbstractFileImport.java index 90acaa68e..799ac579f 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractFileImport.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractFileImport.java @@ -1,6 +1,5 @@ package com.redis.riot; -import java.io.IOException; import java.text.ParseException; import java.util.ArrayList; import java.util.Arrays; @@ -12,19 +11,20 @@ import java.util.stream.Collectors; import org.springframework.batch.core.Job; -import org.springframework.batch.core.step.builder.StepBuilderException; import org.springframework.batch.item.ItemProcessor; -import org.springframework.batch.item.ItemReader; import org.springframework.batch.item.function.FunctionItemProcessor; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; +import com.redis.riot.core.RiotInitializationException; import com.redis.riot.core.RiotUtils; import com.redis.riot.core.Step; import com.redis.riot.core.processor.RegexNamedGroupFunction; -import com.redis.riot.file.ReadOptions; import com.redis.riot.file.FileReaderRegistry; +import com.redis.riot.file.FileReaderResult; +import com.redis.riot.file.ReadOptions; +import com.redis.riot.file.StdInProtocolResolver; import com.redis.riot.function.MapToFieldFunction; import com.redis.spring.batch.item.redis.RedisItemWriter; import com.redis.spring.batch.item.redis.common.KeyValue; @@ -37,7 +37,7 @@ @Command(name = "file-import", description = "Import data from files.") public abstract class AbstractFileImport extends AbstractRedisImportCommand { - private FileReaderRegistry readerRegistry = FileReaderRegistry.defaultReaderRegistry(); + public static final String STDIN_FILENAME = "-"; @Parameters(arity = "1..*", description = "Files or URLs to import. Use '-' to read from stdin.", paramLabel = "FILE") private List files; @@ -48,23 +48,37 @@ public abstract class AbstractFileImport extends AbstractRedisImportCommand { @Option(arity = "1..*", names = "--regex", description = "Regular expressions used to extract values from fields in the form field1=\"regex\" field2=\"regex\"...", paramLabel = "") private Map regexes = new LinkedHashMap<>(); + private FileReaderRegistry readerRegistry; + private ReadOptions readOptions; + @Override - protected Job job() { + protected void initialize() throws RiotInitializationException { + super.initialize(); Assert.notEmpty(files, "No file specified"); - ReadOptions options = readOptions(); - return job(files.stream().map(f -> step(f, options)).collect(Collectors.toList())); + readerRegistry = readerRegistry(); + readOptions = readOptions(); } - private Step step(String location, ReadOptions options) { - ItemReader reader; - try { - reader = readerRegistry.get(location, options); - } catch (IOException e) { - throw new StepBuilderException(e); - } + private FileReaderRegistry readerRegistry() { + FileReaderRegistry registry = FileReaderRegistry.defaultReaderRegistry(); + StdInProtocolResolver stdInProtocolResolver = new StdInProtocolResolver(); + stdInProtocolResolver.setFilename(STDIN_FILENAME); + registry.addProtocolResolver(stdInProtocolResolver); + return registry; + } + + @Override + protected Job job() { + return job(files.stream().map(this::step).collect(Collectors.toList())); + } + + private Step step(String location) { + FileReaderResult reader = readerRegistry.find(location, readOptions); + Assert.notNull(reader.getReader(), + () -> String.format("No reader found for type %s and file %s", reader.getType(), location)); RedisItemWriter writer = writer(); configureTargetRedisWriter(writer); - Step step = new Step<>(reader, writer); + Step step = new Step<>(reader.getReader(), writer); step.name(location); if (hasOperations()) { step.processor(RiotUtils.processor(processor(), regexProcessor())); @@ -73,24 +87,30 @@ protected Job job() { step.skip(org.springframework.batch.item.ParseException.class); step.noRetry(ParseException.class); step.noRetry(org.springframework.batch.item.ParseException.class); - step.taskName(String.format("Importing %s", location)); + step.taskName(String.format("Importing %s", reader.getResource().getFilename())); return step; } private ReadOptions readOptions() { ReadOptions options = fileReaderArgs.readOptions(); - options.setType(getFileType()); + options.setContentType(getFileType()); options.setItemType(itemType()); options.addDeserializer(KeyValue.class, new KeyValueDeserializer()); return options; } private Class itemType() { - return hasOperations() ? Map.class : KeyValue.class; + if (hasOperations()) { + return Map.class; + } + return KeyValue.class; } private RedisItemWriter writer() { - return hasOperations() ? operationWriter() : RedisItemWriter.struct(); + if (hasOperations()) { + return operationWriter(); + } + return RedisItemWriter.struct(); } protected abstract MimeType getFileType(); @@ -138,6 +158,10 @@ public void setRegexes(Map regexes) { this.regexes = regexes; } + public FileReaderRegistry getReaderRegistry() { + return readerRegistry; + } + public void setReaderRegistry(FileReaderRegistry registry) { this.readerRegistry = registry; } diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractImportCommand.java b/plugins/riot/src/main/java/com/redis/riot/AbstractImportCommand.java index 938e8d3a6..921b19d96 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractImportCommand.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractImportCommand.java @@ -16,6 +16,7 @@ import com.redis.riot.core.AbstractJobCommand; import com.redis.riot.core.QuietMapAccessor; +import com.redis.riot.core.RiotInitializationException; import com.redis.riot.core.RiotUtils; import com.redis.riot.core.Step; import com.redis.riot.core.processor.PredicateOperator; @@ -65,6 +66,21 @@ public abstract class AbstractImportCommand extends AbstractJobCommand { */ private List importOperationCommands = new ArrayList<>(); + @Override + protected void initialize() throws RiotInitializationException { + super.initialize(); + targetRedisContext = targetRedisContext(); + targetRedisContext.afterPropertiesSet(); + } + + @Override + protected void teardown() { + if (targetRedisContext != null) { + targetRedisContext.close(); + } + super.teardown(); + } + protected List, Object>> operations() { return importOperationCommands.stream().map(OperationCommand::operation).collect(Collectors.toList()); } @@ -83,17 +99,6 @@ protected Step, Map> step(ItemReader, Map> processor() { log.info("Creating SpEL evaluation context with {}", evaluationContextArgs); StandardEvaluationContext evaluationContext = evaluationContextArgs.evaluationContext(); diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractRedisCommand.java b/plugins/riot/src/main/java/com/redis/riot/AbstractRedisCommand.java index b71686dd1..2b063729e 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractRedisCommand.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractRedisCommand.java @@ -2,6 +2,7 @@ import com.redis.lettucemod.api.sync.RedisModulesCommands; import com.redis.riot.core.AbstractJobCommand; +import com.redis.riot.core.RiotInitializationException; import com.redis.spring.batch.item.redis.RedisItemWriter; import picocli.CommandLine.ArgGroup; @@ -14,14 +15,18 @@ public abstract class AbstractRedisCommand extends AbstractJobCommand { private RedisContext redisContext; @Override - protected void execute() throws Exception { + protected void initialize() throws RiotInitializationException { + super.initialize(); redisContext = RedisContext.of(redisArgs); redisContext.afterPropertiesSet(); - try { - super.execute(); - } finally { + } + + @Override + protected void teardown() { + if (redisContext != null) { redisContext.close(); } + super.teardown(); } protected RedisModulesCommands commands() { diff --git a/plugins/riot/src/main/java/com/redis/riot/AbstractRedisTargetExportCommand.java b/plugins/riot/src/main/java/com/redis/riot/AbstractRedisTargetExportCommand.java index 9345125b3..6e3673483 100644 --- a/plugins/riot/src/main/java/com/redis/riot/AbstractRedisTargetExportCommand.java +++ b/plugins/riot/src/main/java/com/redis/riot/AbstractRedisTargetExportCommand.java @@ -2,6 +2,7 @@ import org.springframework.expression.spel.support.StandardEvaluationContext; +import com.redis.riot.core.RiotInitializationException; import com.redis.spring.batch.item.redis.RedisItemReader; import com.redis.spring.batch.item.redis.RedisItemWriter; @@ -14,9 +15,6 @@ public abstract class AbstractRedisTargetExportCommand extends AbstractExportCom public static final int DEFAULT_TARGET_POOL_SIZE = RedisItemReader.DEFAULT_POOL_SIZE; private static final String VAR_TARGET = "target"; - @ArgGroup(exclusive = false, heading = "TLS options%n") - private SslArgs sslArgs = new SslArgs(); - @Parameters(arity = "1", index = "0", description = "Source server URI or endpoint in the form host:port.", paramLabel = "SOURCE") private RedisURI sourceRedisUri; @@ -32,29 +30,29 @@ public abstract class AbstractRedisTargetExportCommand extends AbstractExportCom private RedisContext targetRedisContext; @Override - protected void execute() throws Exception { + protected void initialize() throws RiotInitializationException { + super.initialize(); targetRedisContext = targetRedisContext(); targetRedisContext.afterPropertiesSet(); - try { - super.execute(); - } finally { + } + + @Override + protected void teardown() { + if (targetRedisContext != null) { targetRedisContext.close(); } + super.teardown(); } @Override protected RedisContext sourceRedisContext() { - log.info("Creating source Redis context with {} {} {}", sourceRedisUri, sourceRedisArgs, sslArgs); - RedisContext context = RedisContext.of(sourceRedisUri, sourceRedisArgs); - context.sslOptions(sslArgs.sslOptions()); - return context; + log.info("Creating source Redis context with {} {} {}", sourceRedisUri, sourceRedisArgs); + return RedisContext.of(sourceRedisUri, sourceRedisArgs); } protected RedisContext targetRedisContext() { - log.info("Creating target Redis context with {} {} {}", targetRedisUri, targetRedisArgs, sslArgs); - RedisContext context = RedisContext.of(targetRedisUri, targetRedisArgs); - context.sslOptions(sslArgs.sslOptions()); - return context; + log.info("Creating target Redis context with {} {} {}", targetRedisUri, targetRedisArgs); + return RedisContext.of(targetRedisUri, targetRedisArgs); } @Override @@ -104,12 +102,4 @@ public void setTargetRedisArgs(TargetRedisArgs targetRedisArgs) { this.targetRedisArgs = targetRedisArgs; } - public SslArgs getSslArgs() { - return sslArgs; - } - - public void setSslArgs(SslArgs sslArgs) { - this.sslArgs = sslArgs; - } - } diff --git a/plugins/riot/src/main/java/com/redis/riot/FileArgs.java b/plugins/riot/src/main/java/com/redis/riot/FileArgs.java index e52567c05..2d7b923a6 100644 --- a/plugins/riot/src/main/java/com/redis/riot/FileArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/FileArgs.java @@ -12,6 +12,12 @@ public class FileArgs { @Option(names = { "-z", "--gzip" }, description = "File is gzip compressed.") private boolean gzipped; + @ArgGroup(exclusive = false) + private S3Args s3Args = new S3Args(); + + @ArgGroup(exclusive = false) + private GoogleStorageArgs googleStorageArgs = new GoogleStorageArgs(); + @Option(names = "--delimiter", description = "Delimiter character.", paramLabel = "") private String delimiter; @@ -24,12 +30,6 @@ public class FileArgs { @Option(names = "--quote", description = "Escape character for CSV files (default: ${DEFAULT-VALUE}).", paramLabel = "") private char quoteCharacter = FileOptions.DEFAULT_QUOTE_CHARACTER; - @ArgGroup(exclusive = false) - private S3Args s3Args = new S3Args(); - - @ArgGroup(exclusive = false) - private GoogleStorageArgs googleStorageArgs = new GoogleStorageArgs(); - public S3Args getS3Args() { return s3Args; } @@ -87,7 +87,6 @@ public void setHeader(boolean header) { } public void apply(FileOptions options) { - options.setGzipped(gzipped); options.setDelimiter(delimiter); options.setEncoding(encoding); options.setHeader(header); diff --git a/plugins/riot/src/main/java/com/redis/riot/GoogleStorageArgs.java b/plugins/riot/src/main/java/com/redis/riot/GoogleStorageArgs.java index edecf2dca..0b7d040a1 100644 --- a/plugins/riot/src/main/java/com/redis/riot/GoogleStorageArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/GoogleStorageArgs.java @@ -50,5 +50,4 @@ public GoogleStorageOptions googleStorageOptions() { options.setProjectId(projectId); return options; } - } diff --git a/plugins/riot/src/main/java/com/redis/riot/RedisArgs.java b/plugins/riot/src/main/java/com/redis/riot/RedisArgs.java index 6c0c97aee..7f17b3d6b 100644 --- a/plugins/riot/src/main/java/com/redis/riot/RedisArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/RedisArgs.java @@ -1,14 +1,14 @@ package com.redis.riot; +import java.io.File; import java.time.Duration; import io.lettuce.core.RedisURI; import io.lettuce.core.protocol.ProtocolVersion; import lombok.ToString; -import picocli.CommandLine.ArgGroup; import picocli.CommandLine.Option; -@ToString(exclude = "password") +@ToString(exclude = { "password", "keystorePassword", "truststorePassword", "keyPassword" }) public class RedisArgs implements RedisClientArgs { @Option(names = { "-u", "--uri" }, description = "Redis server URI.", paramLabel = "") @@ -54,15 +54,108 @@ public class RedisArgs implements RedisClientArgs { @Option(names = "--resp", description = "Redis protocol version used to connect to Redis: ${COMPLETION-CANDIDATES} (default: ${DEFAULT-VALUE}).", paramLabel = "") private ProtocolVersion protocolVersion = DEFAULT_PROTOCOL_VERSION; - @ArgGroup(exclusive = false) - private SslArgs sslArgs = new SslArgs(); - @Option(names = "--pool", description = "Max number of Redis connections (default: ${DEFAULT-VALUE}).", paramLabel = "") private int poolSize = DEFAULT_POOL_SIZE; @Option(names = "--read-from", description = "Which Redis cluster nodes to read from: ${COMPLETION-CANDIDATES} (default: ${DEFAULT-VALUE}).", paramLabel = "") private ReadFrom readFrom = DEFAULT_READ_FROM; + @Option(names = "--keystore", description = "Path to keystore.", paramLabel = "", hidden = true) + private File keystore; + + @Option(names = "--keystore-pass", arity = "0..1", interactive = true, description = "Keystore password.", paramLabel = "", hidden = true) + private char[] keystorePassword; + + @Option(names = "--trust", description = "Path to truststore.", paramLabel = "", hidden = true) + private File truststore; + + @Option(names = "--trust-pass", arity = "0..1", interactive = true, description = "Truststore password.", paramLabel = "", hidden = true) + private char[] truststorePassword; + + @Option(names = "--cert", description = "Client certificate to authenticate with (X.509 PEM).", paramLabel = "") + private File keyCert; + + @Option(names = "--key", description = "Private key file to authenticate with (PKCS#8 PEM).", paramLabel = "") + private File key; + + @Option(names = "--key-pass", arity = "0..1", interactive = true, description = "Private key password.", paramLabel = "") + private char[] keyPassword; + + @Option(names = "--cacert", description = "CA Certificate file to verify with (X.509).", paramLabel = "") + private File trustedCerts; + + @Override + public File getKeystore() { + return keystore; + } + + public void setKeystore(File keystore) { + this.keystore = keystore; + } + + @Override + public char[] getKeystorePassword() { + return keystorePassword; + } + + public void setKeystorePassword(char[] keystorePassword) { + this.keystorePassword = keystorePassword; + } + + @Override + public File getTruststore() { + return truststore; + } + + public void setTruststore(File truststore) { + this.truststore = truststore; + } + + @Override + public char[] getTruststorePassword() { + return truststorePassword; + } + + public void setTruststorePassword(char[] truststorePassword) { + this.truststorePassword = truststorePassword; + } + + @Override + public File getKeyCert() { + return keyCert; + } + + public void setKeyCert(File keyCert) { + this.keyCert = keyCert; + } + + @Override + public File getKey() { + return key; + } + + public void setKey(File key) { + this.key = key; + } + + @Override + public char[] getKeyPassword() { + return keyPassword; + } + + public void setKeyPassword(char[] keyPassword) { + this.keyPassword = keyPassword; + } + + @Override + public File getTrustedCerts() { + return trustedCerts; + } + + public void setTrustedCerts(File trustedCerts) { + this.trustedCerts = trustedCerts; + } + @Override public boolean isCluster() { return cluster; @@ -81,14 +174,6 @@ public void setProtocolVersion(ProtocolVersion version) { this.protocolVersion = version; } - public SslArgs getSslArgs() { - return sslArgs; - } - - public void setSslArgs(SslArgs sslArgs) { - this.sslArgs = sslArgs; - } - public RedisURI getUri() { return uri; } diff --git a/plugins/riot/src/main/java/com/redis/riot/RedisClientArgs.java b/plugins/riot/src/main/java/com/redis/riot/RedisClientArgs.java index 1f2fc5837..291c3459e 100644 --- a/plugins/riot/src/main/java/com/redis/riot/RedisClientArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/RedisClientArgs.java @@ -1,5 +1,6 @@ package com.redis.riot; +import java.io.File; import java.time.Duration; import com.redis.lettucemod.RedisURIBuilder; @@ -73,4 +74,36 @@ default ReadFrom getReadFrom() { return DEFAULT_READ_FROM; } + default File getKeystore() { + return null; + } + + default char[] getKeystorePassword() { + return null; + } + + default File getTruststore() { + return null; + } + + default char[] getTruststorePassword() { + return null; + } + + default File getKeyCert() { + return null; + } + + default File getKey() { + return null; + } + + default char[] getKeyPassword() { + return null; + } + + default File getTrustedCerts() { + return null; + } + } diff --git a/plugins/riot/src/main/java/com/redis/riot/RedisContext.java b/plugins/riot/src/main/java/com/redis/riot/RedisContext.java index 546f43174..1a563ce86 100644 --- a/plugins/riot/src/main/java/com/redis/riot/RedisContext.java +++ b/plugins/riot/src/main/java/com/redis/riot/RedisContext.java @@ -14,6 +14,7 @@ import io.lettuce.core.ReadFrom; import io.lettuce.core.RedisURI; import io.lettuce.core.SslOptions; +import io.lettuce.core.SslOptions.Resource; import io.lettuce.core.SslVerifyMode; import io.lettuce.core.cluster.ClusterClientOptions; import io.lettuce.core.protocol.ProtocolVersion; @@ -99,9 +100,31 @@ public static RedisContext of(RedisURI uri, RedisClientArgs args) { context.protocolVersion(args.getProtocolVersion()); context.readFrom(args.getReadFrom().getReadFrom()); context.uri(uriBuilder(args).uri(uri).build()); + context.sslOptions(sslOptions(args)); return context; } + public static RedisContext of(RedisArgs args) { + return of(args.getUri(), args); + } + + private static SslOptions sslOptions(RedisClientArgs args) { + SslOptions.Builder ssl = SslOptions.builder(); + if (args.getKey() != null) { + ssl.keyManager(args.getKeyCert(), args.getKey(), args.getKeyPassword()); + } + if (args.getKeystore() != null) { + ssl.keystore(args.getKeystore(), args.getKeystorePassword()); + } + if (args.getTruststore() != null) { + ssl.truststore(Resource.from(args.getTruststore()), args.getTruststorePassword()); + } + if (args.getTrustedCerts() != null) { + ssl.trustManager(args.getTrustedCerts()); + } + return ssl.build(); + } + public AbstractRedisClient getClient() { return client; } @@ -173,10 +196,4 @@ public RedisContext clientResources(ClientResources clientResources) { return this; } - public static RedisContext of(RedisArgs args) { - RedisContext context = of(args.getUri(), args); - context.sslOptions(args.getSslArgs().sslOptions()); - return context; - } - } diff --git a/plugins/riot/src/main/java/com/redis/riot/S3Args.java b/plugins/riot/src/main/java/com/redis/riot/S3Args.java index d6ae44e24..711b4bdd0 100644 --- a/plugins/riot/src/main/java/com/redis/riot/S3Args.java +++ b/plugins/riot/src/main/java/com/redis/riot/S3Args.java @@ -23,15 +23,6 @@ public class S3Args { @Option(names = "--s3-endpoint", description = "Service endpoint with which the AWS client should communicate (e.g. https://sns.us-west-1.amazonaws.com).", paramLabel = "") private URI endpoint; - public S3Options s3Options() { - S3Options options = new S3Options(); - options.setAccessKey(accessKey); - options.setSecretKey(secretKey); - options.setEndpoint(endpoint); - options.setRegion(region); - return options; - } - public Region getRegion() { return region; } @@ -64,4 +55,13 @@ public void setSecretKey(String secretKey) { this.secretKey = secretKey; } + public S3Options s3Options() { + S3Options options = new S3Options(); + options.setAccessKey(accessKey); + options.setSecretKey(secretKey); + options.setEndpoint(endpoint); + options.setRegion(region); + return options; + } + } diff --git a/plugins/riot/src/main/java/com/redis/riot/SourceRedisArgs.java b/plugins/riot/src/main/java/com/redis/riot/SourceRedisArgs.java index 232ffa151..7607c39f7 100644 --- a/plugins/riot/src/main/java/com/redis/riot/SourceRedisArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/SourceRedisArgs.java @@ -1,12 +1,13 @@ package com.redis.riot; +import java.io.File; import java.time.Duration; import io.lettuce.core.protocol.ProtocolVersion; import lombok.ToString; import picocli.CommandLine.Option; -@ToString(exclude = "password") +@ToString(exclude = { "password", "keystorePassword", "truststorePassword", "keyPassword" }) public class SourceRedisArgs implements RedisClientArgs { @Option(names = "--source-user", description = "Source ACL style 'AUTH username pass'. Needs password.", paramLabel = "") @@ -39,6 +40,30 @@ public class SourceRedisArgs implements RedisClientArgs { @Option(names = "--source-read-from", description = "Which source Redis cluster nodes to read from: ${COMPLETION-CANDIDATES} (default: ${DEFAULT-VALUE}).", paramLabel = "") private ReadFrom readFrom = DEFAULT_READ_FROM; + @Option(names = "--source-keystore", description = "Path to keystore.", paramLabel = "", hidden = true) + private File keystore; + + @Option(names = "--source-keystore-pass", arity = "0..1", interactive = true, description = "Keystore password.", paramLabel = "", hidden = true) + private char[] keystorePassword; + + @Option(names = "--source-trust", description = "Path to truststore.", paramLabel = "", hidden = true) + private File truststore; + + @Option(names = "--source-trust-pass", arity = "0..1", interactive = true, description = "Truststore password.", paramLabel = "", hidden = true) + private char[] truststorePassword; + + @Option(names = "--source-cert", description = "Client certificate to authenticate with (X.509 PEM).", paramLabel = "") + private File keyCert; + + @Option(names = "--source-key", description = "Private key file to authenticate with (PKCS#8 PEM).", paramLabel = "") + private File key; + + @Option(names = "--source-key-pass", arity = "0..1", interactive = true, description = "Private key password.", paramLabel = "") + private char[] keyPassword; + + @Option(names = "--source-cacert", description = "CA Certificate file to verify with (X.509).", paramLabel = "") + private File trustedCerts; + @Override public String getUsername() { return username; @@ -66,6 +91,7 @@ public void setInsecure(boolean insecure) { this.insecure = insecure; } + @Override public boolean isCluster() { return cluster; } @@ -74,6 +100,7 @@ public void setCluster(boolean cluster) { this.cluster = cluster; } + @Override public ProtocolVersion getProtocolVersion() { return protocolVersion; } @@ -100,6 +127,7 @@ public void setTimeout(Duration timeout) { this.timeout = timeout.toSeconds(); } + @Override public boolean isTls() { return tls; } @@ -108,6 +136,7 @@ public void setTls(boolean tls) { this.tls = tls; } + @Override public String getClientName() { return clientName; } @@ -116,6 +145,7 @@ public void setClientName(String clientName) { this.clientName = clientName; } + @Override public ReadFrom getReadFrom() { return readFrom; } @@ -124,4 +154,76 @@ public void setReadFrom(ReadFrom readFrom) { this.readFrom = readFrom; } + @Override + public File getKeystore() { + return keystore; + } + + public void setKeystore(File keystore) { + this.keystore = keystore; + } + + @Override + public char[] getKeystorePassword() { + return keystorePassword; + } + + public void setKeystorePassword(char[] keystorePassword) { + this.keystorePassword = keystorePassword; + } + + @Override + public File getTruststore() { + return truststore; + } + + public void setTruststore(File truststore) { + this.truststore = truststore; + } + + @Override + public char[] getTruststorePassword() { + return truststorePassword; + } + + public void setTruststorePassword(char[] truststorePassword) { + this.truststorePassword = truststorePassword; + } + + @Override + public File getKeyCert() { + return keyCert; + } + + public void setKeyCert(File keyCert) { + this.keyCert = keyCert; + } + + @Override + public File getKey() { + return key; + } + + public void setKey(File key) { + this.key = key; + } + + @Override + public char[] getKeyPassword() { + return keyPassword; + } + + public void setKeyPassword(char[] keyPassword) { + this.keyPassword = keyPassword; + } + + @Override + public File getTrustedCerts() { + return trustedCerts; + } + + public void setTrustedCerts(File trustedCerts) { + this.trustedCerts = trustedCerts; + } + } diff --git a/plugins/riot/src/main/java/com/redis/riot/SslArgs.java b/plugins/riot/src/main/java/com/redis/riot/SslArgs.java deleted file mode 100644 index 791304fc3..000000000 --- a/plugins/riot/src/main/java/com/redis/riot/SslArgs.java +++ /dev/null @@ -1,120 +0,0 @@ -package com.redis.riot; - -import java.io.File; - -import io.lettuce.core.SslOptions; -import io.lettuce.core.SslOptions.Builder; -import io.lettuce.core.SslOptions.Resource; -import lombok.ToString; -import picocli.CommandLine.Option; - -@ToString(exclude = { "keystorePassword", "truststorePassword", "keyPassword" }) - -public class SslArgs { - - @Option(names = "--keystore", description = "Path to keystore.", paramLabel = "", hidden = true) - private File keystore; - - @Option(names = "--keystore-pass", arity = "0..1", interactive = true, description = "Keystore password.", paramLabel = "", hidden = true) - private char[] keystorePassword; - - @Option(names = "--trust", description = "Path to truststore.", paramLabel = "", hidden = true) - private File truststore; - - @Option(names = "--trust-pass", arity = "0..1", interactive = true, description = "Truststore password.", paramLabel = "", hidden = true) - private char[] truststorePassword; - - @Option(names = "--cert", description = "Client certificate to authenticate with (X.509 PEM).", paramLabel = "") - private File keyCert; - - @Option(names = "--key", description = "Private key file to authenticate with (PKCS#8 PEM).", paramLabel = "") - private File key; - - @Option(names = "--key-pass", arity = "0..1", interactive = true, description = "Private key password.", paramLabel = "") - private char[] keyPassword; - - @Option(names = "--cacert", description = "CA Certificate file to verify with (X.509).", paramLabel = "") - private File trustedCerts; - - public SslOptions sslOptions() { - Builder ssl = SslOptions.builder(); - if (key != null) { - ssl.keyManager(keyCert, key, keyPassword); - } - if (keystore != null) { - ssl.keystore(keystore, keystorePassword); - } - if (truststore != null) { - ssl.truststore(Resource.from(truststore), truststorePassword); - } - if (trustedCerts != null) { - ssl.trustManager(trustedCerts); - } - return ssl.build(); - } - - public File getKeystore() { - return keystore; - } - - public void setKeystore(File keystore) { - this.keystore = keystore; - } - - public char[] getKeystorePassword() { - return keystorePassword; - } - - public void setKeystorePassword(char[] keystorePassword) { - this.keystorePassword = keystorePassword; - } - - public File getTruststore() { - return truststore; - } - - public void setTruststore(File truststore) { - this.truststore = truststore; - } - - public char[] getTruststorePassword() { - return truststorePassword; - } - - public void setTruststorePassword(char[] truststorePassword) { - this.truststorePassword = truststorePassword; - } - - public File getKeyCert() { - return keyCert; - } - - public void setKeyCert(File keyCert) { - this.keyCert = keyCert; - } - - public File getKey() { - return key; - } - - public void setKey(File key) { - this.key = key; - } - - public char[] getKeyPassword() { - return keyPassword; - } - - public void setKeyPassword(char[] keyPassword) { - this.keyPassword = keyPassword; - } - - public File getTrustedCerts() { - return trustedCerts; - } - - public void setTrustedCerts(File trustedCerts) { - this.trustedCerts = trustedCerts; - } - -} diff --git a/plugins/riot/src/main/java/com/redis/riot/TargetRedisArgs.java b/plugins/riot/src/main/java/com/redis/riot/TargetRedisArgs.java index 5d96b591e..c0cb9919c 100644 --- a/plugins/riot/src/main/java/com/redis/riot/TargetRedisArgs.java +++ b/plugins/riot/src/main/java/com/redis/riot/TargetRedisArgs.java @@ -1,12 +1,13 @@ package com.redis.riot; +import java.io.File; import java.time.Duration; import io.lettuce.core.protocol.ProtocolVersion; import lombok.ToString; import picocli.CommandLine.Option; -@ToString(exclude = "password") +@ToString(exclude = { "password", "keystorePassword", "truststorePassword", "keyPassword" }) public class TargetRedisArgs implements RedisClientArgs { @Option(names = "--target-user", description = "Target ACL style 'AUTH username pass'. Needs password.", paramLabel = "") @@ -39,6 +40,30 @@ public class TargetRedisArgs implements RedisClientArgs { @Option(names = "--target-read-from", description = "Which target Redis cluster nodes to read from: ${COMPLETION-CANDIDATES} (default: ${DEFAULT-VALUE}).", paramLabel = "") private ReadFrom readFrom = DEFAULT_READ_FROM; + @Option(names = "--target-keystore", description = "Path to keystore.", paramLabel = "", hidden = true) + private File keystore; + + @Option(names = "--target-keystore-pass", arity = "0..1", interactive = true, description = "Keystore password.", paramLabel = "", hidden = true) + private char[] keystorePassword; + + @Option(names = "--target-trust", description = "Path to truststore.", paramLabel = "", hidden = true) + private File truststore; + + @Option(names = "--target-trust-pass", arity = "0..1", interactive = true, description = "Truststore password.", paramLabel = "", hidden = true) + private char[] truststorePassword; + + @Option(names = "--target-cert", description = "Client certificate to authenticate with (X.509 PEM).", paramLabel = "") + private File keyCert; + + @Option(names = "--target-key", description = "Private key file to authenticate with (PKCS#8 PEM).", paramLabel = "") + private File key; + + @Option(names = "--target-key-pass", arity = "0..1", interactive = true, description = "Private key password.", paramLabel = "") + private char[] keyPassword; + + @Option(names = "--target-cacert", description = "CA Certificate file to verify with (X.509).", paramLabel = "") + private File trustedCerts; + @Override public String getUsername() { return username; @@ -129,4 +154,76 @@ public void setReadFrom(ReadFrom readFrom) { this.readFrom = readFrom; } + @Override + public File getKeystore() { + return keystore; + } + + public void setKeystore(File keystore) { + this.keystore = keystore; + } + + @Override + public char[] getKeystorePassword() { + return keystorePassword; + } + + public void setKeystorePassword(char[] keystorePassword) { + this.keystorePassword = keystorePassword; + } + + @Override + public File getTruststore() { + return truststore; + } + + public void setTruststore(File truststore) { + this.truststore = truststore; + } + + @Override + public char[] getTruststorePassword() { + return truststorePassword; + } + + public void setTruststorePassword(char[] truststorePassword) { + this.truststorePassword = truststorePassword; + } + + @Override + public File getKeyCert() { + return keyCert; + } + + public void setKeyCert(File keyCert) { + this.keyCert = keyCert; + } + + @Override + public File getKey() { + return key; + } + + public void setKey(File key) { + this.key = key; + } + + @Override + public char[] getKeyPassword() { + return keyPassword; + } + + public void setKeyPassword(char[] keyPassword) { + this.keyPassword = keyPassword; + } + + @Override + public File getTrustedCerts() { + return trustedCerts; + } + + public void setTrustedCerts(File trustedCerts) { + this.trustedCerts = trustedCerts; + } + } diff --git a/plugins/riot/src/test/java/com/redis/riot/StackRiotTests.java b/plugins/riot/src/test/java/com/redis/riot/StackRiotTests.java index 9943a11ec..4c50cb99c 100644 --- a/plugins/riot/src/test/java/com/redis/riot/StackRiotTests.java +++ b/plugins/riot/src/test/java/com/redis/riot/StackRiotTests.java @@ -383,14 +383,14 @@ void fileImportBad(TestInfo info) throws Exception { @Test void fileImportGCS(TestInfo info) throws Exception { - testImport(info, "file-import-gcs", "beer:*", 4432); + testImport(info, "file-import-gcs", "beer:*", 216); Map beer1 = redisCommands.hgetall("beer:1"); Assertions.assertEquals("Hocus Pocus", name(beer1)); } @Test void fileImportS3(TestInfo info) throws Exception { - testImport(info, "file-import-s3", "beer:*", 4432); + testImport(info, "file-import-s3", "beer:*", 216); Map beer1 = redisCommands.hgetall("beer:1"); Assertions.assertEquals("Hocus Pocus", name(beer1)); } diff --git a/plugins/riot/src/test/resources/file-import-gcs b/plugins/riot/src/test/resources/file-import-gcs index 191279d8f..0a4fb91c8 100755 --- a/plugins/riot/src/test/resources/file-import-gcs +++ b/plugins/riot/src/test/resources/file-import-gcs @@ -1 +1 @@ -riot file-import gs://riot-bucket-jrx/beers.json hset --keyspace beer --key id \ No newline at end of file +riot file-import gs://riotx/beers.json hset --keyspace beer --key id \ No newline at end of file diff --git a/plugins/riot/src/test/resources/file-import-s3 b/plugins/riot/src/test/resources/file-import-s3 index f4db643a3..b402d6bd1 100755 --- a/plugins/riot/src/test/resources/file-import-s3 +++ b/plugins/riot/src/test/resources/file-import-s3 @@ -1 +1 @@ -riot file-import s3://riot-bucket-jrx/beers.json --s3-region us-west-1 hset --keyspace beer --key id \ No newline at end of file +riot file-import s3://riotx/beers.json --s3-region us-west-1 hset --keyspace beer --key id \ No newline at end of file