Skip to content

Commit

Permalink
Rework on now function implementation (#113)
Browse files Browse the repository at this point in the history
Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand committed Sep 15, 2022
1 parent de3dac1 commit 425a998
Show file tree
Hide file tree
Showing 22 changed files with 360 additions and 289 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ public class QueryContext {
*/
private static final String REQUEST_ID_KEY = "request_id";

/**
* Timestamp when SQL plugin started to process current request.
*/
private static final String REQUEST_PROCESSING_STARTED = "request_processing_started";

/**
* Generates a random UUID and adds to the {@link ThreadContext} as the request id.
* <p>
Expand Down Expand Up @@ -56,22 +51,6 @@ public static String getRequestId() {
return id;
}

public static void recordProcessingStarted() {
ThreadContext.put(REQUEST_PROCESSING_STARTED, LocalDateTime.now().toString());
}

/**
* Get recorded previously time indicating when processing started for the current query.
* @return A LocalDateTime object
*/
public static LocalDateTime getProcessingStartedTime() {
if (ThreadContext.containsKey(REQUEST_PROCESSING_STARTED)) {
return LocalDateTime.parse(ThreadContext.get(REQUEST_PROCESSING_STARTED));
}
// This shouldn't happen outside of unit tests
return LocalDateTime.now();
}

/**
* Wraps a given instance of {@link Runnable} into a new one which gets all the
* entries from current ThreadContext map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
package org.opensearch.sql.analysis;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.Getter;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.NamedExpression;

/**
Expand All @@ -23,13 +26,26 @@ public class AnalysisContext {
@Getter
private final List<NamedExpression> namedParseExpressions;

/**
* Storage for values of functions which return a constant value.
* We are storing the values there to use it in sequential calls to those functions.
* For example, `now` function should the same value during processing a query.
*/
@Getter
private final Map<String, Expression> constantFunctionValues;

public AnalysisContext() {
this(new TypeEnvironment(null));
}

/**
* Class CTOR.
* @param environment Env to set to a new instance.
*/
public AnalysisContext(TypeEnvironment environment) {
this.environment = environment;
this.namedParseExpressions = new ArrayList<>();
this.constantFunctionValues = new HashMap<>();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.ConstantFunction;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -169,6 +170,19 @@ public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisConte
ImmutableMap.copyOf(node.getFieldList())));
}

@Override
public Expression visitConstantFunction(ConstantFunction node, AnalysisContext context) {
var valueName = node.getFuncName();
if (context.getConstantFunctionValues().containsKey(valueName)) {
return context.getConstantFunctionValues().get(valueName);
}

var value = visitFunction(node, context);
value = DSL.literal(value.valueOf(null));
context.getConstantFunctionValues().put(valueName, value);
return value;
}

@Override
public Expression visitFunction(Function node, AnalysisContext context) {
FunctionName functionName = FunctionName.of(node.getFuncName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.ConstantFunction;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -116,6 +117,10 @@ public T visitRelevanceFieldList(RelevanceFieldList node, C context) {
return visitChildren(node, context);
}

public T visitConstantFunction(ConstantFunction node, C context) {
return visitChildren(node, context);
}

public T visitUnresolvedAttribute(UnresolvedAttribute node, C context) {
return visitChildren(node, context);
}
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.ConstantFunction;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -234,6 +235,10 @@ public static Function function(String funcName, UnresolvedExpression... funcArg
return new Function(funcName, Arrays.asList(funcArgs));
}

public static Function constantFunction(String funcName, UnresolvedExpression... funcArgs) {
return new ConstantFunction(funcName, Arrays.asList(funcArgs));
}

/**
* CASE
* WHEN search_condition THEN result_expr
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.expression;

import java.util.List;
import lombok.EqualsAndHashCode;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node that holds a function which should be replaced by its constant[1] value.
* [1] Constant at execution time.
*/
@EqualsAndHashCode(callSuper = false)
public class ConstantFunction extends Function {

public ConstantFunction(String funcName, List<UnresolvedExpression> funcArgs) {
super(funcName, funcArgs);
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitConstantFunction(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
import java.time.format.TextStyle;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.data.model.ExprDateValue;
import org.opensearch.sql.data.model.ExprDatetimeValue;
import org.opensearch.sql.data.model.ExprIntegerValue;
Expand Down Expand Up @@ -105,15 +103,12 @@ public void register(BuiltinFunctionRepository repository) {

/**
* NOW() returns a constant time that indicates the time at which the statement began to execute.
* `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and
* `now(y) return different values.
*/
private LocalDateTime now(@Nullable Integer fsp) {
return formatLocalDateTime(QueryContext::getProcessingStartedTime, fsp);
}

private FunctionResolver now(FunctionName functionName) {
return define(functionName,
impl(() -> new ExprDatetimeValue(now((Integer)null)), DATETIME),
impl((v) -> new ExprDatetimeValue(now(v.integerValue())), DATETIME, INTEGER)
impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME)
);
}

Expand All @@ -136,21 +131,19 @@ private FunctionResolver localtime() {
/**
* SYSDATE() returns the time at which it executes.
*/
private LocalDateTime sysDate(@Nullable Integer fsp) {
return formatLocalDateTime(LocalDateTime::now, fsp);
}

private FunctionResolver sysdate() {
return define(BuiltinFunctionName.SYSDATE.getName(),
impl(() -> new ExprDatetimeValue(sysDate(null)), DATETIME),
impl((v) -> new ExprDatetimeValue(sysDate(v.integerValue())), DATETIME, INTEGER)
impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME),
impl((v) -> new ExprDatetimeValue(formatNow(v.integerValue())), DATETIME, INTEGER)
);
}

/**
* Synonym for @see `now`.
*/
private FunctionResolver curtime(FunctionName functionName) {
return define(functionName,
impl(() -> new ExprTimeValue(sysDate(null).toLocalTime()), TIME),
impl((v) -> new ExprTimeValue(sysDate(v.integerValue()).toLocalTime()), TIME, INTEGER)
impl(() -> new ExprTimeValue(formatNow(null).toLocalTime()), TIME)
);
}

Expand All @@ -164,7 +157,7 @@ private FunctionResolver current_time() {

private FunctionResolver curdate(FunctionName functionName) {
return define(functionName,
impl(() -> new ExprDateValue(sysDate(null).toLocalDate()), DATE)
impl(() -> new ExprDateValue(formatNow(null).toLocalDate()), DATE)
);
}

Expand Down Expand Up @@ -832,17 +825,15 @@ private ExprValue exprYear(ExprValue date) {
}

/**
* Prepare LocalDateTime value.
* @param supplier A function which returns LocalDateTime to format.
* Prepare LocalDateTime value. Truncate fractional second part according to the argument.
* @param fsp argument is given to specify a fractional seconds precision from 0 to 6,
* the return value includes a fractional seconds part of that many digits.
* @return LocalDateTime object.
*/
private LocalDateTime formatLocalDateTime(Supplier<LocalDateTime> supplier,
@Nullable Integer fsp) {
var res = supplier.get();
private LocalDateTime formatNow(@Nullable Integer fsp) {
var res = LocalDateTime.now();
if (fsp == null) {
return res;
fsp = 0;
}
var defaultPrecision = 9; // There are 10^9 nanoseconds in one second
if (fsp < 0 || fsp > 6) { // Check that the argument is in the allowed range [0, 6]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.ast.dsl.AstDSL.field;
import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral;
import static org.opensearch.sql.ast.dsl.AstDSL.function;
Expand All @@ -27,14 +29,10 @@
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.dsl.AstDSL;
Expand All @@ -53,6 +51,7 @@
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -552,47 +551,45 @@ public void match_phrase_prefix_all_params() {
);
}

private static Stream<Arguments> functionNames() {
var dsl = new DSL(new ExpressionConfig().functionRepository());
return Stream.of(
Arguments.of((Function<Expression[], FunctionExpression>)dsl::now,
"now", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::current_timestamp,
"current_timestamp", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::localtimestamp,
"localtimestamp", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::localtime,
"localtime", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::sysdate,
"sysdate", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::curtime,
"curtime", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::current_time,
"current_time", true),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::curdate,
"curdate", false),
Arguments.of((Function<Expression[], FunctionExpression>)dsl::current_date,
"current_date", false));
}

@ParameterizedTest(name = "{1}")
@MethodSource("functionNames")
public void now_like_functions(Function<Expression[], FunctionExpression> function,
String name,
Boolean hasFsp) {
assertAnalyzeEqual(
function.apply(new Expression[]{}),
AstDSL.function(name));

if (hasFsp) {
assertAnalyzeEqual(
function.apply(new Expression[]{DSL.ref("integer_value", INTEGER)}),
AstDSL.function(name, field("integer_value")));

assertAnalyzeEqual(
function.apply(new Expression[]{DSL.literal(3)}),
AstDSL.function(name, intLiteral(3)));
}
@Test
public void constant_function_is_calculated_on_analyze() {
// Actually, we can call any function as ConstantFunction to be calculated on analyze stage
assertTrue(analyze(AstDSL.constantFunction("now")) instanceof LiteralExpression);
assertTrue(analyze(AstDSL.constantFunction("localtime")) instanceof LiteralExpression);
}

@Test
public void function_isnt_calculated_on_analyze() {
assertTrue(analyze(function("now")) instanceof FunctionExpression);
assertTrue(analyze(AstDSL.function("localtime")) instanceof FunctionExpression);
}

@Test
public void constant_function_returns_constant_cached_value() {
var values = List.of(analyze(AstDSL.constantFunction("now")),
analyze(AstDSL.constantFunction("now")), analyze(AstDSL.constantFunction("now")));
assertTrue(values.stream().allMatch(v ->
v.valueOf(null) == analyze(AstDSL.constantFunction("now")).valueOf(null)));
}

@Test
public void function_returns_non_constant_value() {
// Even a function returns the same values - they are calculated on each call
// `sysdate()` which returns `LocalDateTime.now()` shouldn't be cached and should return always
// different values
var values = List.of(analyze(function("sysdate")), analyze(function("sysdate")),
analyze(function("sysdate")), analyze(function("sysdate")));
var referenceValue = analyze(function("sysdate")).valueOf(null);
assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue));
}

@Test
public void now_as_a_function_not_cached() {
// // We can call `now()` as a function, in that case nothing should be cached
var values = List.of(analyze(function("now")), analyze(function("now")),
analyze(function("now")), analyze(function("now")));
var referenceValue = analyze(function("now")).valueOf(null);
assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue));
}

@Test
Expand Down
Loading

0 comments on commit 425a998

Please sign in to comment.