Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix](Nereids)add nereids load function in read fields of GlobalFunctionMgr and Database #23249

Merged
merged 3 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ public void readFields(DataInput in) throws IOException {
dbState = DbState.valueOf(Text.readString(in));
attachDbName = Text.readString(in);

FunctionUtil.readFields(in, name2Function);
FunctionUtil.readFields(in, this.getFullName(), name2Function);

// read encryptKeys
if (Env.getCurrentEnvJournalVersion() >= FeMetaVersion.VERSION_102) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ public String getCandidateHint(String name, List<FunctionBuilder> candidateBuild
.collect(Collectors.joining(", ", "[", "]"));
}


public void addUdf(String dbName, String name, UdfBuilder builder) {
if (dbName == null) {
dbName = GLOBAL_FUNCTION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ public static void write(DataOutput out, ConcurrentMap<String, ImmutableList<Fun
}
}

public static void readFields(DataInput in, ConcurrentMap<String, ImmutableList<Function>> name2Function)
public static void readFields(DataInput in, String dbName,
ConcurrentMap<String, ImmutableList<Function>> name2Function)
throws IOException {
int numEntries = in.readInt();
for (int i = 0; i < numEntries; ++i) {
Expand All @@ -191,7 +192,11 @@ public static void readFields(DataInput in, ConcurrentMap<String, ImmutableList<
for (int j = 0; j < numFunctions; ++j) {
builder.add(Function.read(in));
}
name2Function.put(name, builder.build());
ImmutableList<Function> functions = builder.build();
name2Function.put(name, functions);
for (Function f : functions) {
translateToNereids(dbName, f);
}
}
}

Expand Down Expand Up @@ -234,7 +239,8 @@ public static boolean translateToNereids(String dbName, Function function) {
JavaUdaf.translateToNereidsFunction(dbName, ((AggregateFunction) function));
}
} catch (Exception e) {
LOG.warn("Nereids create function {}:{} failed", dbName, function.getFunctionName().getFunction(), e);
LOG.warn("Nereids create function {}:{} failed, caused by: {}", dbName == null ? "_global_" : dbName,
function.getFunctionName().getFunction(), e);
}
return true;
}
Expand All @@ -246,7 +252,8 @@ public static boolean dropFromNereids(String dbName, FunctionSearchDesc function
.collect(Collectors.toList());
Env.getCurrentEnv().getFunctionRegistry().dropUdf(dbName, fnName, argTypes);
} catch (Exception e) {
LOG.warn("Nereids drop function {}:{} failed", dbName, function.getName(), e);
LOG.warn("Nereids drop function {}:{} failed, caused by: {}", dbName == null ? "_global_" : dbName,
function.getName(), e);
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void write(DataOutput out) throws IOException {
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
FunctionUtil.readFields(in, name2Function);
FunctionUtil.readFields(in, null, name2Function);
}

public synchronized void addFunction(Function function, boolean ifNotExists) throws UserException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@
import org.apache.doris.nereids.util.PlanPatternMatchSupported;
import org.apache.doris.utframe.TestWithFeService;

import avro.shaded.com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;

public class UdfTest extends TestWithFeService implements PlanPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
Expand Down Expand Up @@ -171,4 +177,20 @@ public void testParameterUseMoreThanOneTime() throws Exception {
&& relation.getProjects().get(0).child(0).equals(expected))
);
}

@Test
public void testReadFromStream() throws Exception {
createFunction("create global alias function f8(int) with parameter(n) as hours_add(now(3), n)");
Env.getCurrentEnv().getFunctionRegistry().dropUdf(null, "f8",
ImmutableList.of(IntegerType.INSTANCE));

ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
Env.getCurrentEnv().getGlobalFunctionMgr().write(new DataOutputStream(outputStream));
byte[] buffer = outputStream.toByteArray();
ByteArrayInputStream inputStream = new ByteArrayInputStream(buffer);
Env.getCurrentEnv().getGlobalFunctionMgr().readFields(new DataInputStream(inputStream));

Assertions.assertEquals(1, Env.getCurrentEnv().getFunctionRegistry()
.findUdfBuilder(connectContext.getDatabase(), "f8").size());
}
}
Loading