diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java index aaafe096..b10797b4 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java @@ -29,9 +29,10 @@ import org.apache.doris.spark.exception.ConnectedFailedException; import org.apache.doris.spark.exception.DorisException; import org.apache.doris.spark.exception.DorisInternalException; -import org.apache.doris.spark.util.ErrorMessages; import org.apache.doris.spark.cfg.Settings; import org.apache.doris.spark.serialization.Routing; +import org.apache.doris.spark.util.ErrorMessages; + import org.apache.thrift.TConfiguration; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java index 798ec8cf..c941fdfa 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java @@ -25,6 +25,7 @@ import org.apache.doris.spark.exception.IllegalArgumentException; import org.apache.doris.spark.util.ErrorMessages; import org.apache.doris.spark.util.IOUtils; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index ac920cd0..9ecfa405 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -22,8 +22,6 @@ import org.apache.doris.spark.rest.RestService; import org.apache.doris.spark.rest.models.BackendV2; import org.apache.doris.spark.rest.models.RespContent; -import org.apache.doris.spark.util.DataUtil; -import org.apache.doris.spark.util.ListUtils; import org.apache.doris.spark.util.ResponseUtil; import com.fasterxml.jackson.core.JsonProcessingException; @@ -39,71 +37,72 @@ import org.apache.http.HttpStatus; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.BufferedHttpEntity; -import org.apache.http.entity.StringEntity; +import org.apache.http.entity.InputStreamEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.Serializable; import java.nio.charset.StandardCharsets; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Calendar; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; /** * DorisStreamLoad **/ public 
class DorisStreamLoad implements Serializable { - private String FIELD_DELIMITER; - private final String LINE_DELIMITER; - private static final String NULL_VALUE = "\\N"; private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final static List DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout")); - private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?"; - private static String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?"; + private static final String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?"; + + private static final String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?"; - private String user; - private String passwd; private String loadUrlStr; - private String db; - private String tbl; - private String authEncoded; - private String columns; - private String maxFilterRatio; - private Map streamLoadProp; + private final String db; + private final String tbl; + private final String authEncoded; + private final String columns; + private final String maxFilterRatio; + private final Map streamLoadProp; private static final long cacheExpireTimeout = 4 * 60; private final LoadingCache> cache; private final String fileType; - - private boolean readJsonByLine = false; - + private String FIELD_DELIMITER; + private final String LINE_DELIMITER; private boolean streamingPassthrough = false; + private final Integer batchSize; + private boolean enable2PC; public DorisStreamLoad(SparkSettings settings) { String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\."); this.db = dbTable[0]; this.tbl = dbTable[1]; - this.user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER); - this.passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD); + String user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER); + String passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD); this.authEncoded = getAuthEncoded(user, passwd); this.columns = settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS); this.maxFilterRatio = settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO); @@ -113,18 +112,15 @@ public DorisStreamLoad(SparkSettings settings) { if ("csv".equals(fileType)) { FIELD_DELIMITER = escapeString(streamLoadProp.getOrDefault("column_separator", "\t")); } else if ("json".equalsIgnoreCase(fileType)) { - readJsonByLine = Boolean.parseBoolean(streamLoadProp.getOrDefault("read_json_by_line", "false")); - boolean stripOuterArray = Boolean.parseBoolean(streamLoadProp.getOrDefault("strip_outer_array", "false")); - if (readJsonByLine && stripOuterArray) { - throw new IllegalArgumentException("Only one of options 'read_json_by_line' and 'strip_outer_array' can be set to true"); - } else if (!readJsonByLine && !stripOuterArray) { - LOG.info("set default json mode: strip_outer_array"); - streamLoadProp.put("strip_outer_array", "true"); - } + streamLoadProp.put("read_json_by_line", "true"); } LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n")); this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH, ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT); + this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, + ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT); + this.enable2PC = 
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, + ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); } public String getLoadUrlStr() { @@ -141,9 +137,7 @@ private CloseableHttpClient getHttpClient() { private HttpPut getHttpPut(String label, String loadUrlStr, Boolean enable2PC) { HttpPut httpPut = new HttpPut(loadUrlStr); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("label", label); if (StringUtils.isNotBlank(columns)) { httpPut.setHeader("columns", columns); @@ -155,7 +149,11 @@ private HttpPut getHttpPut(String label, String loadUrlStr, Boolean enable2PC) { httpPut.setHeader("two_phase_commit", "true"); } if (MapUtils.isNotEmpty(streamLoadProp)) { - streamLoadProp.forEach(httpPut::setHeader); + streamLoadProp.forEach((k, v) -> { + if (!"strip_outer_array".equalsIgnoreCase(k)) { + httpPut.setHeader(k, v); + } + }); } return httpPut; } @@ -165,10 +163,10 @@ public static class LoadResponse { public String respMsg; public String respContent; - public LoadResponse(int status, String respMsg, String respContent) { - this.status = status; - this.respMsg = respMsg; - this.respContent = respContent; + public LoadResponse(HttpResponse response) throws IOException { + this.status = response.getStatusLine().getStatusCode(); + this.respMsg = response.getStatusLine().getReasonPhrase(); + this.respContent = EntityUtils.toString(new BufferedHttpEntity(response.getEntity()), StandardCharsets.UTF_8); } @Override @@ -177,95 +175,34 @@ public String toString() { } } - public List loadV2(List> rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { - - List loadData = parseLoadData(rows, dfColumns); - List txnIds = new ArrayList<>(loadData.size()); - - try { - for (String data : loadData) { - txnIds.add(load(data, enable2PC)); - } - } catch (StreamLoadException e) { - if (enable2PC && !txnIds.isEmpty()) { - LOG.error("load batch failed, abort previously pre-committed transactions"); - for (Integer txnId : txnIds) { - abort(txnId); - } - } - throw e; - } - - return txnIds; - - } - - public List loadStream(List> rows, String[] dfColumns, Boolean enable2PC) + public int load(Iterator rows, StructType schema) throws StreamLoadException, JsonProcessingException { - List loadData; - - if (this.streamingPassthrough) { - handleStreamPassThrough(); - loadData = passthrough(rows); - } else { - loadData = parseLoadData(rows, dfColumns); - } - - List txnIds = new ArrayList<>(loadData.size()); - - try { - for (String data : loadData) { - txnIds.add(load(data, enable2PC)); - } - } catch (StreamLoadException e) { - if (enable2PC && !txnIds.isEmpty()) { - LOG.error("load batch failed, abort previously pre-committed transactions"); - for (Integer txnId : txnIds) { - abort(txnId); - } - } - throw e; - } - - return txnIds; - - } - - public int load(String value, Boolean enable2PC) throws StreamLoadException { - String label = generateLoadLabel(); - LoadResponse loadResponse; - int responseHttpStatus = -1; try (CloseableHttpClient httpClient = getHttpClient()) { String loadUrlStr = String.format(loadUrlPattern, getBackend(), db, tbl); - LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, value); - // only to record the BE node in case of an exception this.loadUrlStr = loadUrlStr; - HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); - 
httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8)); + RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows) + .batchSize(batchSize) + .format(fileType) + .sep(FIELD_DELIMITER) + .delim(LINE_DELIMITER) + .schema(schema).build(), streamingPassthrough); + httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream)); HttpResponse httpResponse = httpClient.execute(httpPut); - responseHttpStatus = httpResponse.getStatusLine().getStatusCode(); - String respMsg = httpResponse.getStatusLine().getReasonPhrase(); - String response = EntityUtils.toString(new BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8); - loadResponse = new LoadResponse(responseHttpStatus, respMsg, response); + loadResponse = new LoadResponse(httpResponse); } catch (IOException e) { - e.printStackTrace(); - String err = "http request exception,load url : " + loadUrlStr + ",failed to execute spark stream load with label: " + label; - LOG.warn(err, e); - loadResponse = new LoadResponse(responseHttpStatus, e.getMessage(), err); + throw new RuntimeException(e); } if (loadResponse.status != HttpStatus.SC_OK) { LOG.info("Stream load Response HTTP Status Error:{}", loadResponse); - // throw new StreamLoadException("stream load error: " + loadResponse.respContent); throw new StreamLoadException("stream load error"); } else { - ObjectMapper obj = new ObjectMapper(); try { - RespContent respContent = obj.readValue(loadResponse.respContent, RespContent.class); + RespContent respContent = MAPPER.readValue(loadResponse.respContent, RespContent.class); if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) { LOG.error("Stream load Response RES STATUS Error:{}", loadResponse); throw new StreamLoadException("stream load error"); @@ -279,6 +216,14 @@ public int load(String value, Boolean enable2PC) throws StreamLoadException { } + public Integer loadStream(Iterator rows, StructType schema) + throws StreamLoadException, JsonProcessingException { + if (this.streamingPassthrough) { + handleStreamPassThrough(); + } + return load(rows, schema); + } + public void commit(int txnId) throws StreamLoadException { try (CloseableHttpClient client = getHttpClient()) { @@ -286,9 +231,7 @@ public void commit(int txnId) throws StreamLoadException { String backend = getBackend(); String abortUrl = String.format(abortUrlPattern, backend, db, tbl); HttpPut httpPut = new HttpPut(abortUrl); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("txn_operation", "commit"); httpPut.setHeader("txn_id", String.valueOf(txnId)); @@ -306,10 +249,9 @@ public void commit(int txnId) throws StreamLoadException { throw new StreamLoadException("stream load error: " + reasonPhrase); } - ObjectMapper mapper = new ObjectMapper(); if (response.getEntity() != null) { String loadResult = EntityUtils.toString(response.getEntity()); - Map res = mapper.readValue(loadResult, new TypeReference>() { + Map res = MAPPER.readValue(loadResult, new TypeReference>() { }); if (res.get("status").equals("Fail") && !ResponseUtil.isCommitted(res.get("msg"))) { throw new StreamLoadException("Commit failed " + loadResult); @@ -331,9 +273,7 @@ public void abort(int txnId) throws StreamLoadException { try (CloseableHttpClient client = getHttpClient()) { String abortUrl = String.format(abortUrlPattern, getBackend(), db, 
tbl); HttpPut httpPut = new HttpPut(abortUrl); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("txn_operation", "abort"); httpPut.setHeader("txn_id", String.valueOf(txnId)); @@ -344,9 +284,8 @@ public void abort(int txnId) throws StreamLoadException { throw new StreamLoadException("Fail to abort transaction " + txnId + " with url " + abortUrl); } - ObjectMapper mapper = new ObjectMapper(); String loadResult = EntityUtils.toString(response.getEntity()); - Map res = mapper.readValue(loadResult, new TypeReference>() { + Map res = MAPPER.readValue(loadResult, new TypeReference>() { }); if (!"Success".equals(res.get("status"))) { if (ResponseUtil.isCommitted(res.get("msg"))) { @@ -410,52 +349,13 @@ public List load(String key) throws Exception { } - private List parseLoadData(List> rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { - - List loadDataList; - - switch (fileType.toUpperCase()) { - - case "CSV": - loadDataList = Collections.singletonList( - rows.stream() - .map(row -> row.stream() - .map(DataUtil::handleColumnValue) - .map(Object::toString) - .collect(Collectors.joining(FIELD_DELIMITER)) - ).collect(Collectors.joining(LINE_DELIMITER))); - break; - case "JSON": - List> dataList = new ArrayList<>(); - try { - for (List row : rows) { - Map dataMap = new HashMap<>(); - if (dfColumns.length == row.size()) { - for (int i = 0; i < dfColumns.length; i++) { - dataMap.put(dfColumns[i], DataUtil.handleColumnValue(row.get(i))); - } - } - dataList.add(dataMap); - } - } catch (Exception e) { - throw new StreamLoadException("The number of configured columns does not match the number of data columns."); - } - // splits large collections to normal collection to avoid the "Requested array size exceeds VM limit" exception - loadDataList = ListUtils.getSerializedList(dataList, readJsonByLine ? 
LINE_DELIMITER : null); - break; - default: - throw new StreamLoadException(String.format("Unsupported file format in stream load: %s.", fileType)); - - } - - return loadDataList; - - } - private String generateLoadLabel() { Calendar calendar = Calendar.getInstance(); - return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), UUID.randomUUID().toString().replaceAll("-", "")); + return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", + calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), + calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), + UUID.randomUUID().toString().replaceAll("-", "")); } @@ -478,6 +378,12 @@ private String escapeString(String hexData) { return hexData; } + private void addCommonHeader(HttpRequestBase httpReq) { + httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); + httpReq.setHeader(HttpHeaders.EXPECT, "100-continue"); + httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + } + private void handleStreamPassThrough() { if ("json".equalsIgnoreCase(fileType)) { @@ -488,8 +394,4 @@ private void handleStreamPassThrough() { } - private List passthrough(List> values) { - return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList()); - } - } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java new file mode 100644 index 00000000..779c057d --- /dev/null +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.spark.load;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+
+/**
+ * Wrapper Object for batch loading
+ */
+public class RecordBatch {
+
+    private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
+
+    /**
+     * Spark row data iterator
+     */
+    private final Iterator<InternalRow> iterator;
+
+    /**
+     * batch size for single load
+     */
+    private final int batchSize;
+
+    /**
+     * stream load format
+     */
+    private final String format;
+
+    /**
+     * column separator, only used when the format is csv
+     */
+    private final String sep;
+
+    /**
+     * line delimiter
+     */
+    private final byte[] delim;
+
+    /**
+     * schema of row
+     */
+    private final StructType schema;
+
+    private RecordBatch(Iterator<InternalRow> iterator, int batchSize, String format, String sep, byte[] delim,
+                        StructType schema) {
+        this.iterator = iterator;
+        this.batchSize = batchSize;
+        this.format = format;
+        this.sep = sep;
+        this.delim = delim;
+        this.schema = schema;
+    }
+
+    public Iterator<InternalRow> getIterator() {
+        return iterator;
+    }
+
+    public int getBatchSize() {
+        return batchSize;
+    }
+
+    public String getFormat() {
+        return format;
+    }
+
+    public String getSep() {
+        return sep;
+    }
+
+    public byte[] getDelim() {
+        return delim;
+    }
+
+    public StructType getSchema() {
+        return schema;
+    }
+    public static Builder newBuilder(Iterator<InternalRow> iterator) {
+        return new Builder(iterator);
+    }
+
+    /**
+     * RecordBatch Builder
+     */
+    public static class Builder {
+
+        private final Iterator<InternalRow> iterator;
+
+        private int batchSize;
+
+        private String format;
+
+        private String sep;
+
+        private byte[] delim;
+
+        private StructType schema;
+
+        public Builder(Iterator<InternalRow> iterator) {
+            this.iterator = iterator;
+        }
+
+        public Builder batchSize(int batchSize) {
+            this.batchSize = batchSize;
+            return this;
+        }
+
+        public Builder format(String format) {
+            this.format = format;
+            return this;
+        }
+
+        public Builder sep(String sep) {
+            this.sep = sep;
+            return this;
+        }
+
+        public Builder delim(String delim) {
+            this.delim = delim.getBytes(DEFAULT_CHARSET);
+            return this;
+        }
+
+        public Builder schema(StructType schema) {
+            this.schema = schema;
+            return this;
+        }
+
+        public RecordBatch build() {
+            return new RecordBatch(iterator, batchSize, format, sep, delim, schema);
+        }
+
+    }
+
+}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
new file mode 100644
index 00000000..9444c1da
--- /dev/null
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.spark.load; + +import org.apache.doris.spark.exception.DorisException; +import org.apache.doris.spark.exception.IllegalArgumentException; +import org.apache.doris.spark.exception.ShouldNeverHappenException; +import org.apache.doris.spark.util.DataUtil; + +import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.spark.sql.catalyst.InternalRow; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; + +/** + * InputStream for batch load + */ +public class RecordBatchInputStream extends InputStream { + + public static final Logger LOG = LoggerFactory.getLogger(RecordBatchInputStream.class); + + private static final int DEFAULT_BUF_SIZE = 4096; + + /** + * Load record batch + */ + private final RecordBatch recordBatch; + + /** + * first line flag + */ + private boolean isFirst = true; + + /** + * record buffer + */ + private ByteBuffer buffer = ByteBuffer.allocate(0); + + /** + * record count has been read + */ + private int readCount = 0; + + /** + * streaming mode pass through data without process + */ + private final boolean passThrough; + + public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) { + this.recordBatch = recordBatch; + this.passThrough = passThrough; + } + + @Override + public int read() throws IOException { + try { + if (buffer.remaining() == 0 && endOfBatch()) { + return -1; // End of stream + } + } catch (DorisException e) { + throw new IOException(e); + } + return buffer.get() & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + try { + if (buffer.remaining() == 0 && endOfBatch()) { + return -1; // End of stream + } + } catch (DorisException e) { + throw new IOException(e); + } + int bytesRead = Math.min(len, buffer.remaining()); + buffer.get(b, off, bytesRead); + return bytesRead; + } + + /** + * Check if the current batch read is over. + * If the number of reads is greater than or equal to the batch size or there is no next record, return false, + * otherwise return true. + * + * @return Whether the current batch read is over + * @throws DorisException + */ + public boolean endOfBatch() throws DorisException { + Iterator iterator = recordBatch.getIterator(); + if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) { + return true; + } + readNext(iterator); + return false; + } + + /** + * read next record into buffer + * + * @param iterator row iterator + * @throws DorisException + */ + private void readNext(Iterator iterator) throws DorisException { + if (!iterator.hasNext()) { + throw new ShouldNeverHappenException(); + } + byte[] delim = recordBatch.getDelim(); + byte[] rowBytes = rowToByte(iterator.next()); + if (isFirst) { + ensureCapacity(rowBytes.length); + buffer.put(rowBytes); + buffer.flip(); + isFirst = false; + } else { + ensureCapacity(delim.length + rowBytes.length); + buffer.put(delim); + buffer.put(rowBytes); + buffer.flip(); + } + readCount++; + } + + /** + * Check if the buffer has enough capacity. 
+ * + * @param need required buffer space + */ + private void ensureCapacity(int need) { + + int capacity = buffer.capacity(); + + if (need <= capacity) { + buffer.clear(); + return; + } + + // need to extend + int newCapacity = calculateNewCapacity(capacity, need); + LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need, capacity, newCapacity); + buffer = ByteBuffer.allocate(newCapacity); + + } + + /** + * Calculate new capacity for buffer expansion. + * + * @param capacity current buffer capacity + * @param minCapacity required min buffer space + * @return new capacity + */ + private int calculateNewCapacity(int capacity, int minCapacity) { + int newCapacity; + if (capacity == 0) { + newCapacity = DEFAULT_BUF_SIZE; + while (newCapacity < minCapacity) { + newCapacity = newCapacity << 1; + } + } else { + newCapacity = capacity << 1; + } + return newCapacity; + } + + /** + * Convert Spark row data to byte array + * + * @param row row data + * @return byte array + * @throws DorisException + */ + private byte[] rowToByte(InternalRow row) throws DorisException { + + byte[] bytes; + + if (passThrough) { + bytes = row.getString(0).getBytes(StandardCharsets.UTF_8); + return bytes; + } + + switch (recordBatch.getFormat().toLowerCase()) { + case "csv": + bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), recordBatch.getSep()); + break; + case "json": + try { + bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema()); + } catch (JsonProcessingException e) { + throw new DorisException("parse row to json bytes failed", e); + } + break; + default: + throw new IllegalArgumentException("format", recordBatch.getFormat()); + } + + return bytes; + + } + + +} diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java index faa8ef58..3d66db52 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java @@ -17,19 +17,11 @@ package org.apache.doris.spark.serialization; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.charset.StandardCharsets; -import java.sql.Date; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.List; -import java.util.NoSuchElementException; +import org.apache.doris.sdk.thrift.TScanBatchResult; +import org.apache.doris.spark.exception.DorisException; +import org.apache.doris.spark.rest.models.Schema; +import com.google.common.base.Preconditions; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; @@ -47,17 +39,21 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.Types; - -import org.apache.doris.sdk.thrift.TScanBatchResult; -import org.apache.doris.spark.exception.DorisException; -import org.apache.doris.spark.rest.models.Schema; - import org.apache.commons.lang3.ArrayUtils; import org.apache.spark.sql.types.Decimal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.base.Preconditions; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.math.BigDecimal; +import 
java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; /** * row batch data container. @@ -128,7 +124,11 @@ public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisExceptio } public boolean hasNext() { - return offsetInRowBatch < readRowCount; + if (offsetInRowBatch >= readRowCount) { + rowBatch.clear(); + return false; + } + return true; } private void addValueToRow(int rowIndex, Object obj) { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java index 58774474..aea6ddee 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java @@ -17,35 +17,48 @@ package org.apache.doris.spark.util; -import scala.collection.JavaConversions; -import scala.collection.mutable.WrappedArray; +import org.apache.doris.spark.sql.SchemaUtils; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.Arrays; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; -public class DataUtil { +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; - public static final String NULL_VALUE = "\\N"; +public class DataUtil { - public static Object handleColumnValue(Object value) { + private static final ObjectMapper MAPPER = new ObjectMapper(); - if (value == null) { - return NULL_VALUE; - } + public static final String NULL_VALUE = "\\N"; - if (value instanceof Date || value instanceof Timestamp) { - return value.toString(); + public static byte[] rowToCsvBytes(InternalRow row, StructType schema, String sep) { + StringBuilder builder = new StringBuilder(); + StructField[] fields = schema.fields(); + int n = row.numFields(); + if (n > 0) { + builder.append(SchemaUtils.rowColumnValue(row, 0, fields[0].dataType())); + int i = 1; + while (i < n) { + builder.append(sep); + builder.append(SchemaUtils.rowColumnValue(row, i, fields[i].dataType())); + i++; + } } + return builder.toString().getBytes(StandardCharsets.UTF_8); + } - if (value instanceof WrappedArray) { - - Object[] arr = JavaConversions.seqAsJavaList((WrappedArray) value).toArray(); - return Arrays.toString(arr); + public static byte[] rowToJsonBytes(InternalRow row, StructType schema) + throws JsonProcessingException { + StructField[] fields = schema.fields(); + Map rowMap = new HashMap<>(row.numFields()); + for (int i = 0; i < fields.length; i++) { + rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i, fields[i].dataType())); } - - return value; - + return MAPPER.writeValueAsBytes(rowMap); } } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java index d8d31b9e..fbfab9a5 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java @@ -34,7 +34,7 @@ public class ListUtils { private static final ObjectMapper MAPPER = new ObjectMapper(); public static List getSerializedList(List> 
batch, - String lineDelimiter) throws JsonProcessingException { + String lineDelimiter) throws JsonProcessingException { List result = new ArrayList<>(); divideAndSerialize(batch, result, lineDelimiter); return result; diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala index d08bdc0d..9dee5158 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala @@ -18,18 +18,18 @@ package org.apache.doris import scala.language.implicitConversions - import org.apache.doris.spark.rdd.DorisSpark import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD package object spark { - implicit def sparkContextFunctions(sc: SparkContext) = new SparkContextFunctions(sc) + implicit def sparkContextFunctions(sc: SparkContext): SparkContextFunctions = new SparkContextFunctions(sc) class SparkContextFunctions(sc: SparkContext) extends Serializable { def dorisRDD( tableIdentifier: Option[String] = None, query: Option[String] = None, - cfg: Option[Map[String, String]] = None) = + cfg: Option[Map[String, String]] = None): RDD[AnyRef] = DorisSpark.dorisRDD(sc, tableIdentifier, query, cfg) } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala index 06f5ca30..ec8f887a 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala @@ -27,7 +27,7 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row { /** No-arg constructor for Kryo serialization. 
*/ def this() = this(null) - def iterator = values.iterator + def iterator: Iterator[Any] = values.iterator override def length: Int = values.length @@ -51,9 +51,9 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row { override def getByte(i: Int): Byte = getAs[Byte](i) - override def getString(i: Int): String = get(i).toString() + override def getString(i: Int): String = get(i).toString override def copy(): Row = this - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala index c8aa0349..f5a6a159 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala @@ -18,16 +18,23 @@ package org.apache.doris.spark.sql import org.apache.doris.sdk.thrift.TScanColumnDesc - -import scala.collection.JavaConversions._ +import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD} import org.apache.doris.spark.cfg.Settings import org.apache.doris.spark.exception.DorisException import org.apache.doris.spark.rest.RestService import org.apache.doris.spark.rest.models.{Field, Schema} -import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD} +import org.apache.doris.spark.util.DataUtil +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.slf4j.LoggerFactory +import java.sql.Timestamp +import java.time.{LocalDateTime, ZoneOffset} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable + private[spark] object SchemaUtils { private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$")) @@ -137,4 +144,49 @@ private[spark] object SchemaUtils { tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))) schema } + + def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = { + + dataType match { + case NullType => DataUtil.NULL_VALUE + case BooleanType => row.getBoolean(ordinal) + case ByteType => row.getByte(ordinal) + case ShortType => row.getShort(ordinal) + case IntegerType => row.getInt(ordinal) + case LongType => row.getLong(ordinal) + case FloatType => row.getFloat(ordinal) + case DoubleType => row.getDouble(ordinal) + case StringType => row.getUTF8String(ordinal).toString + case TimestampType => + LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC) + new Timestamp(row.getLong(ordinal) / 1000).toString + case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString + case BinaryType => row.getBinary(ordinal) + case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale) + case at: ArrayType => + val arrayData = row.getArray(ordinal) + var i = 0 + val buffer = mutable.Buffer[Any]() + while (i < arrayData.numElements()) { + if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType) + i += 1 + } + s"[${buffer.mkString(",")}]" + case mt: MapType => + val mapData = row.getMap(ordinal) + val keys = mapData.keyArray() + val values = mapData.valueArray() + var i = 0 + val map = mutable.Map[Any, Any]() + while 
(i < keys.numElements()) { + map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType) + i += 1 + } + map.toMap.asJava + case st: StructType => row.getStruct(ordinal, st.length) + case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}") + } + + } + } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index e32267ee..b278a385 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -39,8 +39,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter]) - val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, - ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT) private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES, ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT) private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE) @@ -55,45 +53,16 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings) def write(dataFrame: DataFrame): Unit = { + doWrite(dataFrame, dorisStreamLoader.load) + } - val sc = dataFrame.sqlContext.sparkContext - val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") - if (enable2PC) { - sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) - } + def writeStream(dataFrame: DataFrame): Unit = { + doWrite(dataFrame, dorisStreamLoader.loadStream) + } - var resultRdd = dataFrame.rdd - val dfColumns = dataFrame.columns - if (Objects.nonNull(sinkTaskPartitionSize)) { - resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) - } - resultRdd - .map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava) - .foreachPartition(partition => { - partition - .grouped(batchSize) - .foreach(batch => flush(batch, dfColumns)) - }) - - /** - * flush data to Doris and do retry when flush error - * - */ - def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = { - Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC) - } match { - case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) - case Failure(e) => - if (enable2PC) handleLoadFailure(preCommittedTxnAcc) - throw new IOException( - s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) - } - } + private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = { - } - def writeStream(dataFrame: DataFrame): Unit = { val sc = dataFrame.sqlContext.sparkContext val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") @@ -103,47 +72,31 @@ class DorisWriter(settings: SparkSettings) extends Serializable { var resultRdd = dataFrame.queryExecution.toRdd val schema = dataFrame.schema - val dfColumns = dataFrame.columns if (Objects.nonNull(sinkTaskPartitionSize)) 
{ resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) } - resultRdd - .foreachPartition(partition => { - partition - .grouped(batchSize) - .foreach(batch => - flush(batch, dfColumns)) - }) - - /** - * flush data to Doris and do retry when flush error - * - */ - def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = { - Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC) - } match { - case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) - case Failure(e) => - if (enable2PC) handleLoadFailure(preCommittedTxnAcc) - throw new IOException( - s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + resultRdd.foreachPartition(iterator => { + while (iterator.hasNext) { + // do load batch with retries + Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { + loadFunc(iterator.asJava, schema) + } match { + case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc) + case Failure(e) => + if (enable2PC) handleLoadFailure(preCommittedTxnAcc) + throw new IOException( + s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + } } - } - - def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = { - rows.map(row => { - row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava - }).asJava - } + }) } - private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = { - txnIds.foreach(txnId => acc.add(txnId)) + private def handleLoadSuccess(txnId: Int, acc: CollectionAccumulator[Int]): Unit = { + acc.add(txnId) } - def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = { + private def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = { // if task run failed, acc value will not be returned to driver, // should abort all pre committed transactions inside the task logger.info("load task failed, start aborting previously pre-committed transactions") diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java deleted file mode 100644 index 020a241c..00000000 --- a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java +++ /dev/null @@ -1,32 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
-
-package org.apache.doris.spark.util;
-
-import junit.framework.TestCase;
-import org.junit.Assert;
-import scala.collection.mutable.WrappedArray;
-
-import java.sql.Timestamp;
-
-public class DataUtilTest extends TestCase {
-
-    public void testHandleColumnValue() {
-        Assert.assertEquals("2023-08-14 18:00:00.0", DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00")));
-        Assert.assertEquals("[1, 2, 3]", DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3})));
-    }
-}
\ No newline at end of file
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
new file mode 100644
index 00000000..e3868cbc
--- /dev/null
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.sql
+
+import org.apache.spark.sql.SparkSession
+import org.junit.{Assert, Ignore, Test}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+
+@Ignore
+class SchemaUtilsTest {
+
+  @Test
+  def rowColumnValueTest(): Unit = {
+
+    val spark = SparkSession.builder().master("local").getOrCreate()
+
+    val df = spark.createDataFrame(Seq(
+      (1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08 17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1"))
+    )).toDF("c1", "c2", "c3", "c4", "c5")
+
+    val schema = df.schema
+
+    df.queryExecution.toRdd.foreach(row => {
+
+      val fields = schema.fields
+      Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0, fields(0).dataType))
+      Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1, fields(1).dataType))
+      Assert.assertEquals("2023-09-08 17:00:00.0", SchemaUtils.rowColumnValue(row, 2, fields(2).dataType))
+      Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3, fields(3).dataType))
+      println(SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+      Assert.assertEquals(Map("a" -> "1").asJava, SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+
+    })
+
+  }
+
+}
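Reviewer note: a minimal sketch of how the new write path composes outside of DorisStreamLoad, as this patch structures it — RecordBatch only carries the row iterator plus batch/format settings, and RecordBatchInputStream serializes rows on demand while HttpClient drains the request entity, instead of materializing a whole batch as one String. The class name, URL handling and the literal batch size/separator values below are illustrative assumptions, not part of the patch.

```java
import org.apache.doris.spark.load.RecordBatch;
import org.apache.doris.spark.load.RecordBatchInputStream;

import org.apache.http.client.methods.HttpPut;
import org.apache.http.entity.InputStreamEntity;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;

import java.util.Iterator;

public class StreamLoadSketch {

    // Builds a stream-load request body from a partition's rows (hypothetical helper).
    HttpPut buildPut(String loadUrl, Iterator<InternalRow> rows, StructType schema) {
        RecordBatch batch = RecordBatch.newBuilder(rows)
                .batchSize(4096)   // from doris.sink.batch.size; value here is illustrative
                .format("csv")     // or "json" (the patch forces read_json_by_line for json)
                .sep("\t")         // column_separator, csv only
                .delim("\n")       // line_delimiter
                .schema(schema)
                .build();
        // Rows are converted lazily inside read(), so memory stays bounded by the buffer,
        // not by the serialized batch size. Second argument is the passthrough flag.
        RecordBatchInputStream body = new RecordBatchInputStream(batch, false);
        HttpPut put = new HttpPut(loadUrl);
        put.setEntity(new InputStreamEntity(body));
        return put;
    }
}
```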
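Likewise, a small hypothetical exercise of the new DataUtil helpers on a hand-built InternalRow; the schema, values and expected output strings are assumptions for illustration, not tests taken from this patch.

```java
import org.apache.doris.spark.util.DataUtil;

import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

import java.nio.charset.StandardCharsets;

public class DataUtilSketch {

    public static void main(String[] args) throws JsonProcessingException {
        StructType schema = new StructType()
                .add("id", DataTypes.IntegerType)
                .add("name", DataTypes.StringType);
        // InternalRow stores strings as UTF8String
        GenericInternalRow row = new GenericInternalRow(
                new Object[]{1, UTF8String.fromString("doris")});

        // expected: "1\tdoris"
        String csv = new String(DataUtil.rowToCsvBytes(row, schema, "\t"), StandardCharsets.UTF_8);
        // e.g. {"id":1,"name":"doris"} (key order depends on the backing HashMap)
        String json = new String(DataUtil.rowToJsonBytes(row, schema), StandardCharsets.UTF_8);

        System.out.println(csv);
        System.out.println(json);
    }
}
```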