Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite Java API Table.readJSON to return the output from libcudf read_json directly #17180

Merged
merged 5 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 21 additions & 224 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -1092,224 +1092,6 @@ public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer) {
return readJSON(schema, opts, buffer, 0, buffer.length);
}

/**
 * Result of one level of the recursive JSON column-gather pass. Exactly one of the
 * two outcomes is set: either a replacement column was built (changeWasNeeded) or
 * the input column already matched the requested schema (noChangeNeeded).
 */
private static class DidViewChange {
// Replacement column when a rearrangement was required; null otherwise.
ColumnVector changeWasNeeded;
// True when the caller should keep using the original column as-is.
boolean noChangeNeeded;

private DidViewChange(ColumnVector changed, boolean unchanged) {
this.changeWasNeeded = changed;
this.noChangeNeeded = unchanged;
}

/** A change was required; {@code cv} is the replacement column (ownership passes to caller). */
public static DidViewChange yes(ColumnVector cv) {
return new DidViewChange(cv, false);
}

/** No change needed; the original column can be used unchanged. */
public static DidViewChange no() {
return new DidViewChange(null, true);
}
}

/**
 * Recursively reconcile a column read from JSON with the requested schema by
 * reordering (and null-filling missing) children. Data types are never changed;
 * on a type mismatch the input is returned unchanged.
 *
 * Ownership: when the result is {@code DidViewChange.yes(...)} the returned
 * ColumnVector is a new column owned by the caller; when it is
 * {@code DidViewChange.no()} the caller keeps using (and owning) {@code cv}.
 *
 * @param schema   the desired schema for this column.
 * @param children the names/children actually found by the reader at this level
 *                 (may be null, treated as no children).
 * @param cv       the column as read; not closed by this method.
 * @return whether a replacement column was needed, and if so the replacement.
 */
private static DidViewChange gatherJSONColumns(Schema schema, TableWithMeta.NestedChildren children,
ColumnView cv) {
// We need to do this recursively to be sure it all matches as expected.
// If we run into problems where the data types don't match, we are not
// going to fix up the data types. We are only going to reorder the columns.
if (schema.getType() == DType.STRUCT) {
if (cv.getType() != DType.STRUCT) {
// The types don't match so just return the input unchanged...
return DidViewChange.no();
} else {
// Map each found child name to its position so we can reorder by name.
String[] foundNames;
if (children == null) {
foundNames = new String[0];
} else {
foundNames = children.getNames();
}
HashMap<String, Integer> indices = new HashMap<>();
for (int i = 0; i < foundNames.length; i++) {
indices.put(foundNames[i], i);
}
// We might need to rearrange the columns to match what we want.
DType[] types = schema.getChildTypes();
String[] neededNames = schema.getColumnNames();
ColumnView[] columns = new ColumnView[neededNames.length];
try {
// Track whether anything actually differs; if nothing does we avoid a copy.
boolean somethingChanged = false;
if (columns.length != foundNames.length) {
somethingChanged = true;
}
for (int i = 0; i < columns.length; i++) {
String neededColumnName = neededNames[i];
Integer index = indices.get(neededColumnName);
Schema childSchema = schema.getChild(i);
if (index != null) {
if (childSchema.isStructOrHasStructDescendant()) {
// Recurse: the child may itself need its grandchildren reordered.
ColumnView child = cv.getChildColumnView(index);
// Only close the child view here if the recursion produced a
// replacement (or threw); if it is reused as columns[i] the
// finally block below closes it instead.
boolean shouldCloseChild = true;
try {
if (index != i) {
somethingChanged = true;
}
DidViewChange childResult = gatherJSONColumns(schema.getChild(i),
children.getChild(index), child);
if (childResult.noChangeNeeded) {
shouldCloseChild = false;
columns[i] = child;
} else {
somethingChanged = true;
columns[i] = childResult.changeWasNeeded;
}
} finally {
if (shouldCloseChild) {
child.close();
}
}
} else {
// Leaf (or non-struct) child: just take it, noting any reorder.
if (index != i) {
somethingChanged = true;
}
columns[i] = cv.getChildColumnView(index);
}
} else {
// Requested child was not found: synthesize an all-null column of
// the requested type with the same row count.
somethingChanged = true;
if (types[i] == DType.LIST) {
try (Scalar s = Scalar.listFromNull(childSchema.getChild(0).asHostDataType())) {
columns[i] = ColumnVector.fromScalar(s, (int) cv.getRowCount());
}
} else if (types[i] == DType.STRUCT) {
int numStructChildren = childSchema.getNumChildren();
HostColumnVector.DataType[] structChildren = new HostColumnVector.DataType[numStructChildren];
for (int structChildIndex = 0; structChildIndex < numStructChildren; structChildIndex++) {
structChildren[structChildIndex] = childSchema.getChild(structChildIndex).asHostDataType();
}
try (Scalar s = Scalar.structFromNull(structChildren)) {
columns[i] = ColumnVector.fromScalar(s, (int) cv.getRowCount());
}
} else {
try (Scalar s = Scalar.fromNull(types[i])) {
columns[i] = ColumnVector.fromScalar(s, (int) cv.getRowCount());
}
}
}
}
if (somethingChanged) {
// Build a struct view over the rearranged children and materialize it
// into a ColumnVector the caller will own.
try (ColumnView ret = new ColumnView(cv.type, cv.rows, Optional.of(cv.nullCount),
cv.getValid(), null, columns)) {
return DidViewChange.yes(ret.copyToColumnVector());
}
} else {
return DidViewChange.no();
}
} finally {
// Close every child view/column gathered above; the copy (if any) owns
// its own data, and on the no-change path the original cv is untouched.
for (ColumnView c: columns) {
if (c != null) {
c.close();
}
}
}
}
} else if (schema.getType() == DType.LIST && cv.getType() == DType.LIST) {
if (schema.isStructOrHasStructDescendant()) {
// Lists read from JSON carry exactly two children: "offsets" and "element".
// NOTE(review): this shape appears to come from the JNI metadata layout —
// confirm against TableWithMeta.NestedChildren.
String [] childNames = children.getNames();
if (childNames.length == 2 &&
"offsets".equals(childNames[0]) &&
"element".equals(childNames[1])) {
try (ColumnView child = cv.getChildColumnView(0)){
// Recurse on the element column (metadata child 1 describes it).
DidViewChange listResult = gatherJSONColumns(schema.getChild(0),
children.getChild(1), child);
if (listResult.noChangeNeeded) {
return DidViewChange.no();
} else {
// Rebuild the list around the replaced element column, reusing the
// original offsets and validity, then materialize a copy.
try (ColumnView listView = new ColumnView(cv.type, cv.rows,
Optional.of(cv.nullCount), cv.getValid(), cv.getOffsets(),
new ColumnView[]{listResult.changeWasNeeded})) {
return DidViewChange.yes(listView.copyToColumnVector());
} finally {
listResult.changeWasNeeded.close();
}
}
}
}
}
// Nothing to change so just return the input, but we need to inc a ref count to really
// make it work, so for now we are going to turn it into a ColumnVector.
return DidViewChange.no();
} else {
// Nothing to change so just return the input, but we need to inc a ref count to really
// make it work, so for now we are going to turn it into a ColumnVector.
return DidViewChange.no();
}
}

/**
 * Reorder (and null-fill) the top-level columns of a table read from JSON so they
 * match the requested schema, recursing into struct columns as needed.
 *
 * Ownership: this releases the table out of {@code twm}; every column placed in
 * the result is either a fresh column or an incRefCount'd original, and the
 * finally block closes the local references so the returned Table is the sole
 * extra owner.
 *
 * @param schema        the desired schema; if it names no columns the table is
 *                      returned exactly as read.
 * @param twm           the table plus column-name metadata from the native reader.
 * @param emptyRowCount row count to use when the reader produced no table
 *                      (no columns read); pass a negative value to disallow that case.
 * @return a table whose columns match {@code schema}'s names and order.
 * @throws IllegalStateException if no table was read and emptyRowCount is negative.
 */
private static Table gatherJSONColumns(Schema schema, TableWithMeta twm, int emptyRowCount) {
String[] neededColumns = schema.getColumnNames();
if (neededColumns == null || neededColumns.length == 0) {
// No explicit schema: hand back whatever was read, unchanged.
return twm.releaseTable();
} else {
// Map each column name the reader found to its position.
String[] foundNames = twm.getColumnNames();
HashMap<String, Integer> indices = new HashMap<>();
for (int i = 0; i < foundNames.length; i++) {
indices.put(foundNames[i], i);
}
// We might need to rearrange the columns to match what we want.
DType[] types = schema.getChildTypes();
ColumnVector[] columns = new ColumnVector[neededColumns.length];
try (Table tbl = twm.releaseTable()) {
// A null table means nothing was read; fall back to the provided row count.
int rowCount = tbl == null ? emptyRowCount : (int)tbl.getRowCount();
if (rowCount < 0) {
throw new IllegalStateException(
"No empty row count provided and the table read has no row count or columns");
}
for (int i = 0; i < columns.length; i++) {
String neededColumnName = neededColumns[i];
Integer index = indices.get(neededColumnName);
if (index != null) {
if (schema.getChild(i).isStructOrHasStructDescendant()) {
// Struct-bearing column: recurse to reorder its descendants.
DidViewChange gathered = gatherJSONColumns(schema.getChild(i), twm.getChild(index),
tbl.getColumn(index));
if (gathered.noChangeNeeded) {
// Keep the original column; bump its refcount since tbl will close it.
columns[i] = tbl.getColumn(index).incRefCount();
} else {
// The recursion produced a replacement column we now own.
columns[i] = gathered.changeWasNeeded;
}
} else {
columns[i] = tbl.getColumn(index).incRefCount();
}
} else {
// Column missing from the data: synthesize an all-null column of the
// requested type.
if (types[i] == DType.LIST) {
Schema listSchema = schema.getChild(i);
Schema elementSchema = listSchema.getChild(0);
try (Scalar s = Scalar.listFromNull(elementSchema.asHostDataType())) {
columns[i] = ColumnVector.fromScalar(s, rowCount);
}
} else if (types[i] == DType.STRUCT) {
Schema structSchema = schema.getChild(i);
int numStructChildren = structSchema.getNumChildren();
DataType[] structChildrenTypes = new DataType[numStructChildren];
for (int j = 0; j < numStructChildren; j++) {
structChildrenTypes[j] = structSchema.getChild(j).asHostDataType();
}
try (Scalar s = Scalar.structFromNull(structChildrenTypes)) {
columns[i] = ColumnVector.fromScalar(s, rowCount);
}
} else {
try (Scalar s = Scalar.fromNull(types[i])) {
columns[i] = ColumnVector.fromScalar(s, rowCount);
}
}
}
}
// The Table constructor takes its own references to the columns.
return new Table(columns);
} finally {
// Drop our local references; on success the returned Table still holds its
// own, and on failure this prevents a leak.
for (ColumnVector c: columns) {
if (c != null) {
c.close();
}
}
}
}
}

/**
* Read a JSON file.
* @param schema the schema of the file. You may use Schema.INFERRED to infer the schema.
Expand Down Expand Up @@ -1340,7 +1122,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, File path) {
opts.experimental(),
opts.getLineDelimiter()))) {

return gatherJSONColumns(schema, twm, -1);
return twm.releaseTable();
}
}

Expand All @@ -1361,6 +1143,10 @@ public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, lon

/**
* Read JSON formatted data.
*
* @deprecated This method is deprecated since emptyRowCount is not used. Use the method without
* emptyRowCount instead.
*
* @param schema the schema of the data. You may use Schema.INFERRED to infer the schema.
* @param opts various JSON parsing options.
* @param buffer raw UTF8 formatted bytes.
Expand All @@ -1370,6 +1156,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, lon
* @param emptyRowCount the number of rows to return if no columns were read.
* @return the data parsed as a table on the GPU.
*/
@SuppressWarnings("unused")
public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, long offset,
long len, HostMemoryAllocator hostMemoryAllocator,
int emptyRowCount) {
Expand All @@ -1381,14 +1168,14 @@ public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, lon
assert offset >= 0 && offset < buffer.length;
try (HostMemoryBuffer newBuf = hostMemoryAllocator.allocate(len)) {
newBuf.setBytes(0, buffer, offset, len);
return readJSON(schema, opts, newBuf, 0, len, emptyRowCount);
return readJSON(schema, opts, newBuf, 0, len);
}
}

@SuppressWarnings("unused")
public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, long offset,
long len, int emptyRowCount) {
return readJSON(schema, opts, buffer, offset, len, DefaultHostMemoryAllocator.get(),
emptyRowCount);
return readJSON(schema, opts, buffer, offset, len, DefaultHostMemoryAllocator.get());
}

public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer, long offset,
Expand Down Expand Up @@ -1470,6 +1257,10 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b

/**
* Read JSON formatted data.
*
* @deprecated This method is deprecated since emptyRowCount is not used. Use the method without
* emptyRowCount instead.
*
* @param schema the schema of the data. You may use Schema.INFERRED to infer the schema.
* @param opts various JSON parsing options.
* @param buffer raw UTF8 formatted bytes.
Expand All @@ -1478,6 +1269,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b
* @param emptyRowCount the number of rows to use if no columns were found.
* @return the data parsed as a table on the GPU.
*/
@SuppressWarnings("unused")
public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer buffer,
long offset, long len, int emptyRowCount) {
if (len <= 0) {
Expand Down Expand Up @@ -1508,7 +1300,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b
cudfPruneSchema,
opts.experimental(),
opts.getLineDelimiter()))) {
return gatherJSONColumns(schema, twm, emptyRowCount);
return twm.releaseTable();
}
}

Expand All @@ -1525,12 +1317,17 @@ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds) {

/**
* Read JSON formatted data.
*
* @deprecated This method is deprecated since emptyRowCount is not used. Use the method without
* emptyRowCount instead.
*
* @param schema the schema of the data. You may use Schema.INFERRED to infer the schema.
* @param opts various JSON parsing options.
* @param ds the DataSource to read from.
* @param emptyRowCount the number of rows to return if no columns were read.
* @return the data parsed as a table on the GPU.
*/
@SuppressWarnings("unused")
public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds, int emptyRowCount) {
long dsHandle = DataSourceHelper.createWrapperDataSource(ds);
// only prune the schema if one is provided
Expand All @@ -1554,7 +1351,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds, int
opts.experimental(),
opts.getLineDelimiter(),
dsHandle))) {
return gatherJSONColumns(schema, twm, emptyRowCount);
return twm.releaseTable();
} finally {
DataSourceHelper.destroyWrapperDataSource(dsHandle);
}
Expand Down
24 changes: 18 additions & 6 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,21 +1037,23 @@ cudf::io::schema_element read_schema_element(int& index,
if (d_type.id() == cudf::type_id::STRUCT || d_type.id() == cudf::type_id::LIST) {
std::map<std::string, cudf::io::schema_element> child_elems;
int num_children = children[index];
std::vector<std::string> child_names(num_children);
// go to the next entry, so recursion can parse it.
index++;
for (int i = 0; i < num_children; i++) {
auto const name = std::string{names.get(index).get()};
auto name = std::string{names.get(index).get()};
child_elems.insert(
std::pair{name, cudf::jni::read_schema_element(index, children, names, types, scales)});
child_names[i] = std::move(name);
}
return cudf::io::schema_element{d_type, std::move(child_elems)};
return cudf::io::schema_element{d_type, std::move(child_elems), {std::move(child_names)}};
} else {
if (children[index] != 0) {
throw std::invalid_argument("found children for a type that should have none");
}
// go to the next entry before returning...
index++;
return cudf::io::schema_element{d_type, {}};
return cudf::io::schema_element{d_type, {}, std::nullopt};
}
}

Expand Down Expand Up @@ -1886,13 +1888,18 @@ Java_ai_rapids_cudf_Table_readJSONFromDataSource(JNIEnv* env,
}

std::map<std::string, cudf::io::schema_element> data_types;
std::vector<std::string> name_order;
int at = 0;
while (at < n_types.size()) {
auto const name = std::string{n_col_names.get(at).get()};
data_types.insert(std::pair{
name, cudf::jni::read_schema_element(at, n_children, n_col_names, n_types, n_scales)});
name_order.push_back(name);
}
opts.dtypes(data_types);

cudf::io::schema_element structs{
cudf::data_type{cudf::type_id::STRUCT}, std::move(data_types), {std::move(name_order)}};
opts.dtypes(structs);
} else {
// should infer the types
}
Expand Down Expand Up @@ -2001,13 +2008,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON(JNIEnv* env,
}

std::map<std::string, cudf::io::schema_element> data_types;
std::vector<std::string> name_order;
name_order.reserve(n_types.size());
int at = 0;
while (at < n_types.size()) {
auto const name = std::string{n_col_names.get(at).get()};
auto name = std::string{n_col_names.get(at).get()};
data_types.insert(std::pair{
name, cudf::jni::read_schema_element(at, n_children, n_col_names, n_types, n_scales)});
name_order.emplace_back(std::move(name));
}
opts.dtypes(data_types);
cudf::io::schema_element structs{
cudf::data_type{cudf::type_id::STRUCT}, std::move(data_types), {std::move(name_order)}};
opts.dtypes(structs);
} else {
// should infer the types
}
Expand Down
Loading