diff --git a/app/src/main/java/org/astraea/app/common/Utils.java b/app/src/main/java/org/astraea/app/common/Utils.java index c8f0fe3b90..bc770546ac 100644 --- a/app/src/main/java/org/astraea/app/common/Utils.java +++ b/app/src/main/java/org/astraea/app/common/Utils.java @@ -17,6 +17,7 @@ package org.astraea.app.common; import java.io.File; +import java.lang.reflect.Constructor; import java.net.InetAddress; import java.net.ServerSocket; import java.nio.file.Files; @@ -34,6 +35,8 @@ import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.astraea.app.cost.CostFunction; +import org.astraea.app.partitioner.Configuration; public final class Utils { @@ -253,5 +256,23 @@ public static CompletableFuture> sequence(Collection futures.stream().map(CompletableFuture::join).collect(Collectors.toList())); } + public static CostFunction constructCostFunction( + Class costClass, Configuration configuration) { + for (Constructor constructor : costClass.getConstructors()) { + Class[] types = constructor.getParameterTypes(); + if (types.length == 1 && types[0].isAssignableFrom(Configuration.class)) { + return (CostFunction) Utils.packException(() -> constructor.newInstance(configuration)); + } + } + for (Constructor constructor : costClass.getConstructors()) { + Class[] types = constructor.getParameterTypes(); + if (types.length == 0) { + return (CostFunction) Utils.packException(() -> constructor.newInstance()); + } + } + throw new IllegalArgumentException( + "No suitable constructor found for class " + costClass.getName()); + } + private Utils() {} } diff --git a/app/src/test/java/org/astraea/app/common/UtilsTest.java b/app/src/test/java/org/astraea/app/common/UtilsTest.java index 9ff5b8863f..47db461dc4 100644 --- a/app/src/test/java/org/astraea/app/common/UtilsTest.java +++ b/app/src/test/java/org/astraea/app/common/UtilsTest.java @@ -24,8 +24,12 @@ import java.util.concurrent.ExecutionException; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.astraea.app.cost.CostFunction; +import org.astraea.app.partitioner.Configuration; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; public class UtilsTest { @@ -82,4 +86,41 @@ void testNonEmpty() { Assertions.assertThrows(IllegalArgumentException.class, () -> Utils.requireNonEmpty("")); Assertions.assertThrows(NullPointerException.class, () -> Utils.requireNonEmpty(null)); } + + public static class TestConfigCostFunction implements CostFunction { + public TestConfigCostFunction(Configuration configuration) {} + } + + public static class TestCostFunction implements CostFunction { + public TestCostFunction() {} + } + + public static class TestBadCostFunction implements CostFunction { + public TestBadCostFunction(int value) {} + } + + @ParameterizedTest + @ValueSource(classes = {TestCostFunction.class, TestConfigCostFunction.class}) + void constructCostFunction(Class aClass) { + // arrange + var config = Configuration.of(Map.of()); + + // act + var costFunction = Utils.constructCostFunction(aClass, config); + + // assert + Assertions.assertInstanceOf(CostFunction.class, costFunction); + Assertions.assertInstanceOf(aClass, costFunction); + } + + @Test + void constructCostFunctionException() { + // arrange + var aClass = TestBadCostFunction.class; + var config = Configuration.of(Map.of()); + + // act, assert + Assertions.assertThrows( + IllegalArgumentException.class, () -> Utils.constructCostFunction(aClass, config)); + } }