Skip to content

Commit

Permalink
add 'TypeStrategy' to types (apache#11888)
Browse files Browse the repository at this point in the history
* add TypeStrategy - value comparators and binary serialization for any TypeSignature

(cherry picked from commit e583033)
  • Loading branch information
clintropolis authored and sachinsagare committed Oct 31, 2022
1 parent 736f3e0 commit ccf3b33
Show file tree
Hide file tree
Showing 32 changed files with 1,885 additions and 1,452 deletions.
17 changes: 11 additions & 6 deletions core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.TypeStrategy;

import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Objects;

Expand Down Expand Up @@ -434,14 +434,19 @@ public String stringify()
if (value == null) {
return StringUtils.format("complex_decode_base64('%s', %s)", outputType.getComplexTypeName(), NULL_LITERAL);
}
ObjectByteStrategy strategy = Types.getStrategy(outputType.getComplexTypeName());
if (strategy == null) {
throw new IAE("Cannot stringify type[%s]", outputType.asTypeString());
TypeStrategy strategy = outputType.getStrategy();
byte[] bytes = new byte[strategy.estimateSizeBytes(value)];
ByteBuffer wrappedBytes = ByteBuffer.wrap(bytes);
int remaining = strategy.write(wrappedBytes, 0, value, bytes.length);
if (remaining < 0) {
bytes = new byte[bytes.length - remaining];
wrappedBytes = ByteBuffer.wrap(bytes);
strategy.write(wrappedBytes, 0, value, bytes.length);
}
return StringUtils.format(
"complex_decode_base64('%s', '%s')",
outputType.getComplexTypeName(),
StringUtils.encodeBase64String(strategy.toBytes(value))
StringUtils.encodeBase64String(bytes)
);
}

Expand Down
95 changes: 32 additions & 63 deletions core/src/main/java/org/apache/druid/math/expr/ExprEval.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.NullableTypeStrategy;
import org.apache.druid.segment.column.TypeStrategies;
import org.apache.druid.segment.column.TypeStrategy;

import javax.annotation.Nullable;
import java.nio.ByteBuffer;
Expand All @@ -50,36 +51,17 @@ public static ExprEval deserialize(ByteBuffer buffer, int offset, ExpressionType
{
switch (type.getType()) {
case LONG:
if (Types.isNullableNull(buffer, offset)) {
if (TypeStrategies.isNullableNull(buffer, offset)) {
return ofLong(null);
}
return of(Types.readNullableLong(buffer, offset));
return of(TypeStrategies.readNotNullNullableLong(buffer, offset));
case DOUBLE:
if (Types.isNullableNull(buffer, offset)) {
if (TypeStrategies.isNullableNull(buffer, offset)) {
return ofDouble(null);
}
return of(Types.readNullableDouble(buffer, offset));
case STRING:
if (Types.isNullableNull(buffer, offset)) {
return of(null);
}
final byte[] stringBytes = Types.readNullableVariableBlob(buffer, offset);
return of(StringUtils.fromUtf8(stringBytes));
case ARRAY:
switch (type.getElementType().getType()) {
case LONG:
return ofLongArray(Types.readNullableLongArray(buffer, offset));
case DOUBLE:
return ofDoubleArray(Types.readNullableDoubleArray(buffer, offset));
case STRING:
return ofStringArray(Types.readNullableStringArray(buffer, offset));
default:
throw new UOE("Cannot deserialize expression array of type %s", type);
}
case COMPLEX:
return ofComplex(type, Types.readNullableComplexType(buffer, offset, type));
return of(TypeStrategies.readNotNullNullableDouble(buffer, offset));
default:
throw new UOE("Cannot deserialize expression type %s", type);
return ofType(type, type.getNullableStrategy().read(buffer, offset));
}
}

Expand All @@ -90,55 +72,39 @@ public static ExprEval deserialize(ByteBuffer buffer, int offset, ExpressionType
*
* This should be refactored to be consolidated with some of the standard type handling of aggregators probably
*/
public static void serialize(ByteBuffer buffer, int position, ExprEval<?> eval, int maxSizeBytes)
public static void serialize(ByteBuffer buffer, int position, ExpressionType type, ExprEval<?> eval, int maxSizeBytes)
{
int offset = position;
switch (eval.type().getType()) {
switch (type.getType()) {
case LONG:
if (eval.isNumericNull()) {
Types.writeNull(buffer, offset);
TypeStrategies.writeNull(buffer, offset);
} else {
Types.writeNullableLong(buffer, offset, eval.asLong());
TypeStrategies.writeNotNullNullableLong(buffer, offset, eval.asLong());
}
break;
case DOUBLE:
if (eval.isNumericNull()) {
Types.writeNull(buffer, offset);
TypeStrategies.writeNull(buffer, offset);
} else {
Types.writeNullableDouble(buffer, offset, eval.asDouble());
TypeStrategies.writeNotNullNullableDouble(buffer, offset, eval.asDouble());
}
break;
case STRING:
final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString());
if (stringBytes != null) {
Types.writeNullableVariableBlob(buffer, offset, stringBytes, eval.type(), maxSizeBytes);
} else {
Types.writeNull(buffer, offset);
default:
final NullableTypeStrategy strategy = type.getNullableStrategy();
// Cast only when the eval's own type differs from the requested serialization type, so that the value
// handed to the nullable strategy (obtained from 'type') is of the type that strategy was looked up for.
// The previous condition was inverted: it cast when the types already matched and skipped the cast on a
// mismatch, which contradicts the stated intent.
if (!type.equals(eval.type())) {
  eval = eval.castTo(type);
}
break;
case ARRAY:
switch (eval.type().getElementType().getType()) {
case LONG:
Long[] longs = eval.asLongArray();
Types.writeNullableLongArray(buffer, offset, longs, maxSizeBytes);
break;
case DOUBLE:
Double[] doubles = eval.asDoubleArray();
Types.writeNullableDoubleArray(buffer, offset, doubles, maxSizeBytes);
break;
case STRING:
String[] strings = eval.asStringArray();
Types.writeNullableStringArray(buffer, offset, strings, maxSizeBytes);
break;
default:
throw new UOE("Cannot serialize expression array type %s", eval.type());
int written = strategy.write(buffer, offset, eval.value(), maxSizeBytes);
if (written < 0) {
throw new ISE(
"Unable to serialize [%s], max size bytes is [%s], but need at least [%s] bytes to write entire value",
type.asTypeString(),
maxSizeBytes,
maxSizeBytes - written
);
}
break;
case COMPLEX:
Types.writeNullableComplexType(buffer, offset, eval.type(), eval.value(), maxSizeBytes);
break;
default:
throw new UOE("Cannot serialize expression type %s", eval.type());
}
}

Expand Down Expand Up @@ -453,10 +419,10 @@ public static ExprEval ofType(@Nullable ExpressionType type, @Nullable Object va
}

if (bytes != null) {
ObjectByteStrategy<?> strategy = Types.getStrategy(type.getComplexTypeName());
TypeStrategy<?> strategy = type.getStrategy();
assert strategy != null;
ByteBuffer bb = ByteBuffer.wrap(bytes);
return ofComplex(type, strategy.fromByteBuffer(bb, bytes.length));
return ofComplex(type, strategy.read(bb));
}

return ofComplex(type, value);
Expand Down Expand Up @@ -1208,6 +1174,9 @@ public ExprEval castTo(ExpressionType castTo)
}
return ExprEval.ofType(castTo, null);
}
if (type().equals(castTo)) {
return this;
}
switch (castTo.getType()) {
case STRING:
if (value.length == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public ExpressionType(
@JsonProperty("elementType") @Nullable ExpressionType elementType
)
{
super(exprType, complexTypeName, elementType);
super(ExpressionTypeFactory.getInstance(), exprType, complexTypeName, elementType);
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

import com.google.common.collect.Interner;
import com.google.common.collect.Interners;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.TypeFactory;
import org.apache.druid.segment.column.TypeStrategies;
import org.apache.druid.segment.column.TypeStrategy;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -85,4 +89,34 @@ public ExpressionType ofComplex(@Nullable String complexTypeName)
{
return INTERNER.intern(new ExpressionType(ExprType.COMPLEX, complexTypeName, null));
}

/**
 * Resolves the {@link TypeStrategy} to use for the given {@link ExpressionType}.
 *
 * LONG, DOUBLE and STRING map to the shared constants on {@link TypeStrategies}; ARRAY gets a new
 * {@link TypeStrategies.ArrayTypeStrategy} constructed from the given type; COMPLEX is looked up by its complex
 * type name.
 *
 * @param expressionType the type to resolve a strategy for
 * @return the strategy for the type
 * @throws IAE if the type is COMPLEX and no strategy is found for its complex type name
 * @throws ISE if the type is not one of the cases handled here
 */
@Override
public <T> TypeStrategy<T> getTypeStrategy(ExpressionType expressionType)
{
  // the lookup is raw-typed on purpose: each branch yields a differently parameterized strategy
  return lookupRawStrategy(expressionType);
}

/**
 * Picks the strategy for the expression type, returning straight out of each switch branch.
 */
private static TypeStrategy lookupRawStrategy(ExpressionType expressionType)
{
  switch (expressionType.getType()) {
    case LONG:
      return TypeStrategies.LONG;
    case DOUBLE:
      return TypeStrategies.DOUBLE;
    case STRING:
      return TypeStrategies.STRING;
    case ARRAY:
      return new TypeStrategies.ArrayTypeStrategy(expressionType);
    case COMPLEX:
      final TypeStrategy<?> found = TypeStrategies.getComplex(expressionType.getComplexTypeName());
      if (found != null) {
        return found;
      }
      throw new IAE("Cannot find strategy for type [%s]", expressionType.asTypeString());
    default:
      throw new ISE("Unsupported column type[%s]", expressionType.getType());
  }
}
}
33 changes: 17 additions & 16 deletions core/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import org.apache.druid.math.expr.vector.VectorMathProcessors;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.math.expr.vector.VectorStringProcessors;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.TypeStrategy;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.format.DateTimeFormat;
Expand All @@ -59,7 +58,7 @@
/**
* Base interface describing the mechanism used to evaluate a {@link FunctionExpr}. All {@link Function} implementations
* are immutable.
* <p>
*
* Do NOT remove "unused" members in this class. They are used by generated Antlr
*/
@SuppressWarnings("unused")
Expand Down Expand Up @@ -1977,9 +1976,9 @@ protected ExprEval eval(ExprEval x, ExprEval y)
public Set<Expr> getScalarInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
ExpressionType castTo = ExpressionType.fromString(
StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())
);
switch (castTo.getType()) {
case ARRAY:
return Collections.emptySet();
Expand All @@ -1995,9 +1994,9 @@ public Set<Expr> getScalarInputs(List<Expr> args)
public Set<Expr> getArrayInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
ExpressionType castTo = ExpressionType.fromString(
StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())
);
switch (castTo.getType()) {
case LONG:
case DOUBLE:
Expand Down Expand Up @@ -3679,14 +3678,16 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
name()
);
}
ExpressionType complexType = ExpressionTypeFactory.getInstance()
.ofComplex((String) args.get(0).getLiteralValue());
ObjectByteStrategy strategy = Types.getStrategy(complexType.getComplexTypeName());
if (strategy == null) {
ExpressionType type = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
TypeStrategy strategy;
try {
strategy = type.getStrategy();
}
catch (IAE illegal) {
throw new IAE(
"Function[%s] first argument must be a valid complex type name, unknown complex type [%s]",
name(),
complexType.asTypeString()
type.asTypeString()
);
}
ExprEval base64String = args.get(1).eval(bindings);
Expand All @@ -3697,11 +3698,11 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
);
}
if (base64String.value() == null) {
return ExprEval.ofComplex(complexType, null);
return ExprEval.ofComplex(type, null);
}

final byte[] base64 = StringUtils.decodeBase64String(base64String.asString());
return ExprEval.ofComplex(complexType, strategy.fromByteBuffer(ByteBuffer.wrap(base64), base64.length));
return ExprEval.ofComplex(type, strategy.read(ByteBuffer.wrap(base64)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.apache.druid.segment.column;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

import javax.annotation.Nullable;
import java.util.Objects;
Expand All @@ -34,7 +36,11 @@ public abstract class BaseTypeSignature<Type extends TypeDescriptor> implements
@Nullable
protected final TypeSignature<Type> elementType;

private final Supplier<TypeStrategy> typeStrategy;
private final Supplier<NullableTypeStrategy> nullableTypeStrategy;

public BaseTypeSignature(
TypeFactory typeFactory,
Type type,
@Nullable String complexTypeName,
@Nullable TypeSignature<Type> elementType
Expand All @@ -43,6 +49,8 @@ public BaseTypeSignature(
this.type = type;
this.complexTypeName = complexTypeName;
this.elementType = elementType;
this.typeStrategy = Suppliers.memoize(() -> typeFactory.getTypeStrategy(this));
this.nullableTypeStrategy = Suppliers.memoize(() -> new NullableTypeStrategy<>(typeStrategy.get()));
}

@Override
Expand All @@ -68,6 +76,18 @@ public TypeSignature<Type> getElementType()
return elementType;
}

/**
 * Returns the {@link TypeStrategy} for this type signature, as produced by the memoized supplier that the
 * constructor builds from the type factory.
 */
@Override
public <T> TypeStrategy<T> getStrategy()
{
  final TypeStrategy memoized = typeStrategy.get();
  return memoized;
}

/**
 * Returns the {@link NullableTypeStrategy} for this type signature, as produced by the memoized supplier that
 * the constructor builds by wrapping this signature's {@link TypeStrategy}.
 */
@Override
public <T> NullableTypeStrategy<T> getNullableStrategy()
{
  final NullableTypeStrategy memoized = nullableTypeStrategy.get();
  return memoized;
}

@Override
public boolean equals(Object o)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public ColumnType(
@JsonProperty("elementType") @Nullable ColumnType elementType
)
{
super(type, complexTypeName, elementType);
super(ColumnTypeFactory.getInstance(), type, complexTypeName, elementType);
}

@Nullable
Expand Down
Loading

0 comments on commit ccf3b33

Please sign in to comment.