Skip to content

Commit

Permalink
chore(datastore): remove schema checking when query or scan (#2729)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Sep 11, 2023
1 parent 4e15002 commit d0fb07a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import ai.starwhale.mlops.datastore.impl.MemoryTableImpl;
import ai.starwhale.mlops.datastore.impl.RecordEncoder;
import ai.starwhale.mlops.datastore.type.BaseValue;
import ai.starwhale.mlops.datastore.type.ListValue;
import ai.starwhale.mlops.datastore.type.MapValue;
import ai.starwhale.mlops.datastore.type.ObjectValue;
import ai.starwhale.mlops.exception.SwProcessException;
import ai.starwhale.mlops.exception.SwProcessException.ErrorType;
import ai.starwhale.mlops.exception.SwValidationException;
Expand Down Expand Up @@ -245,12 +242,7 @@ public RecordList query(DataStoreQueryRequest req) {
}
var records = results.stream()
.filter(r -> !r.isDeleted())
.map(r -> {
if (columnSchemaMap != null) {
checkRecord(columnSchemaMap, r.getValues());
}
return RecordEncoder.encodeRecord(r.getValues(), req.isRawResult(), req.isEncodeWithType());
})
.map(r -> RecordEncoder.encodeRecord(r.getValues(), req.isRawResult(), req.isEncodeWithType()))
.collect(Collectors.toList());
return new RecordList(columnSchemaMap,
table.getColumnStatistics(columns).entrySet().stream()
Expand Down Expand Up @@ -420,9 +412,6 @@ record = null;
if (record == null) {
record = new HashMap<>();
}
if (columnSchemaMap != null && r.record.getValues() != null) {
checkRecord(columnSchemaMap, r.record.getValues());
}
record.putAll(
RecordEncoder.encodeRecord(r.record.getValues(),
req.isRawResult(),
Expand Down Expand Up @@ -644,46 +633,4 @@ public void terminate() {
}
}
}

/**
 * Validates every value of a record against the schema registered for its column.
 *
 * <p>For each entry of {@code record}, the schema is looked up in {@code columnSchemaMap}
 * by column name and handed to {@code checkValue} together with the value.
 *
 * @param columnSchemaMap column name to schema mapping used for the lookup
 * @param record          column name to value mapping to validate
 */
private static void checkRecord(Map<String, ColumnSchema> columnSchemaMap, Map<String, BaseValue> record) {
    for (var column : record.entrySet()) {
        var expectedSchema = columnSchemaMap.get(column.getKey());
        checkValue(expectedSchema, column.getValue());
    }
}

/**
 * Recursively verifies that {@code value} conforms to {@code schema}.
 *
 * <p>A {@code null} value is always accepted. Otherwise the value's column type must equal
 * the schema's type; LIST/TUPLE elements, MAP keys and values, and OBJECT attributes are
 * checked recursively against the corresponding nested schema. For OBJECT values the python
 * type must also match, and every attribute must have a schema entry.
 *
 * <p>A non-null value with no schema (for example a column that is absent from the schema
 * map) is reported as a validation error rather than failing with a
 * {@link NullPointerException}, consistent with how a missing attribute schema is handled.
 *
 * @param schema the expected schema for the value; may be {@code null} when none is known
 * @param value  the value to check; may be {@code null}
 * @throws SwValidationException if the value does not match the schema
 */
private static void checkValue(ColumnSchema schema, BaseValue value) {
    if (value == null) {
        return;
    }
    // A value without any schema cannot match: treat it like a type mismatch instead of
    // dereferencing a null schema below.
    if (schema == null || value.getColumnType() != schema.getType()) {
        throw mixedColumnTypeError();
    }
    switch (value.getColumnType()) {
        case LIST:
        case TUPLE:
            ((ListValue) value).forEach(e -> checkValue(schema.getElementSchema(), e));
            break;
        case MAP:
            ((MapValue) value).forEach((k, v) -> {
                checkValue(schema.getKeySchema(), k);
                checkValue(schema.getValueSchema(), v);
            });
            break;
        case OBJECT:
            if (!((ObjectValue) value).getPythonType().equals(schema.getPythonType())) {
                throw mixedColumnTypeError();
            }
            ((ObjectValue) value).forEach((k, v) -> {
                // The attribute must be declared in the schema even when its value is null,
                // so this lookup is checked before recursing.
                var attrSchema = schema.getAttributesSchema().get(k);
                if (attrSchema == null) {
                    throw mixedColumnTypeError();
                }
                checkValue(attrSchema, v);
            });
            break;
        default:
            break;
    }
}

/** Builds the validation error raised whenever a value does not match its column schema. */
private static SwValidationException mixedColumnTypeError() {
    return new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
            "mixed column type. try to query/scan with encodeWithType=true");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand Down Expand Up @@ -275,10 +276,27 @@ public void testUpdate() throws InterruptedException {
}));
}
};
assertThrows(SwValidationException.class, () -> this.controller.scanTable(req));
req.setEncodeWithType(true);
req.setEncodeWithType(false);
var resp = this.controller.scanTable(req);
assertThat("t1", resp.getStatusCode().is2xxSuccessful(), is(true));
assertThat("t1", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), notNullValue());
assertThat("t1",
Objects.requireNonNull(resp.getBody()).getData().getRecords(),
is(List.of(Map.of("k", "00000000", "b", "1"), Map.of("k", "00000001", "b", "00000002"))));
assertThat("test",
Objects.requireNonNull(resp.getBody()).getData().getColumnHints(),
is(Map.of("k", ColumnHintsDesc.builder()
.typeHints(List.of("INT32"))
.columnValueHints(List.of("0", "1", "4"))
.build(),
"b", ColumnHintsDesc.builder()
.typeHints(List.of("INT32", "STRING"))
.columnValueHints(List.of("1", "2"))
.build())));

req.setEncodeWithType(true);
resp = this.controller.scanTable(req);
assertThat("t1", resp.getStatusCode().is2xxSuccessful(), is(true));
assertThat("t1",
Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(),
nullValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,12 @@ public void testQuery() {
.operator(Operator.EQUAL)
.operands(List.of(new TableQueryFilter.Column("k"), new Constant(ColumnType.STRING, "0")))
.build());
assertThrows(SwValidationException.class, () -> this.dataStore.query(req.build()));

recordList = this.dataStore.query(req.build());
assertThat("typed", recordList.getColumnSchemaMap(), notNullValue());
assertThat("typed", recordList.getRecords(), is(List.of(Map.of("k", "00000000"), Map.of("k", "0"))));


recordList = this.dataStore.query(req.encodeWithType(true).build());
assertThat("mixed", recordList.getColumnSchemaMap(), nullValue());
assertThat("test",
Expand Down Expand Up @@ -388,7 +393,7 @@ public Object apply(String str, Boolean rawResult) {
}

var encodeString = new EncodeString();
var testParams = new boolean[]{true, false, true, false, true, false};
var testParams = new boolean[] {true, false, true, false, true, false};
for (boolean rawResult : testParams) {
var recordList = this.dataStore.query(DataStoreQueryRequest.builder()
.tableName("t1")
Expand Down Expand Up @@ -1374,11 +1379,28 @@ public void testAllTypes() throws Exception {
this.dataStore.update("t",
new TableSchemaDesc("key", List.of(ColumnSchemaDesc.builder().name("key").type("INT32").build())),
List.of(Map.of("key", "1")));
assertThrows(SwValidationException.class, () -> this.dataStore.scan(DataStoreScanRequest.builder()
.tables(List.of(DataStoreScanRequest.TableInfo.builder()
.tableName("t")
.build()))
.build()));
result = this.dataStore.scan(DataStoreScanRequest.builder()
.tables(List.of(DataStoreScanRequest.TableInfo.builder().tableName("t").build())).build());
assertThat(result.getColumnSchemaMap().entrySet().stream()
.collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())),
is(new HashMap<>() {{
put("key", ColumnType.INT32);
put("a", ColumnType.BOOL);
put("b", ColumnType.INT8);
put("c", ColumnType.INT16);
put("d", ColumnType.INT32);
put("e", ColumnType.INT64);
put("f", ColumnType.FLOAT32);
put("g", ColumnType.FLOAT64);
put("h", ColumnType.BYTES);
put("i", ColumnType.UNKNOWN);
put("j", ColumnType.LIST);
put("k", ColumnType.OBJECT);
put("l", ColumnType.TUPLE);
put("m", ColumnType.MAP);
put("complex", ColumnType.OBJECT);
}}
));
result = this.dataStore.scan(DataStoreScanRequest.builder()
.tables(List.of(DataStoreScanRequest.TableInfo.builder()
.tableName("t")
Expand Down

0 comments on commit d0fb07a

Please sign in to comment.