diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java index 9e23d2c0412c3..9f10739486238 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java @@ -29,15 +29,13 @@ public class InstrumentationServiceImpl implements InstrumentationService { @Override - public Instrumenter newInstrumenter(Map checkMethods) { - return InstrumenterImpl.create(checkMethods); + public Instrumenter newInstrumenter(Class clazz, Map methods) { + return InstrumenterImpl.create(clazz, methods); } @Override - public Map lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, - IOException { + public Map lookupMethods(Class checkerClass) throws IOException { var methodsToInstrument = new HashMap(); - var checkerClass = Class.forName(entitlementCheckerClassName); var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass); ClassReader reader = new ClassReader(classFileInfo.bytecodes()); ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) { diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java index 57e30c01c5c28..00efab829b2bb 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java @@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter { this.checkMethods = checkMethods; } - static String getCheckerClassName() { - int javaVersion = Runtime.version().feature(); - final String classNamePrefix; - if (javaVersion >= 23) { - classNamePrefix = "Java23"; - } else { - classNamePrefix = ""; - } - return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker"; - } - - public static InstrumenterImpl create(Map checkMethods) { - String checkerClass = getCheckerClassName(); - String handleClass = checkerClass + "Handle"; - String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass)); + public static InstrumenterImpl create(Class checkerClass, Map checkMethods) { + Type checkerClassType = Type.getType(checkerClass); + String handleClass = checkerClassType.getInternalName() + "Handle"; + String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType); return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods); } - public ClassFileInfo instrumentClassFile(Class clazz) throws IOException { - ClassFileInfo initial = getClassFileInfo(clazz); - return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes())); - } - - public static ClassFileInfo getClassFileInfo(Class clazz) throws IOException { + static ClassFileInfo getClassFileInfo(Class clazz) throws IOException { String internalName = Type.getInternalName(clazz); String fileName = "/" + internalName + ".class"; byte[] originalBytecodes; @@ -306,5 +290,5 @@ protected void pushEntitlementChecker(MethodVisitor mv) { mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false); } - public record ClassFileInfo(String fileName, byte[] bytecodes) {} + record ClassFileInfo(String fileName, byte[] bytecodes) {} } diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java index 9ccb72637d463..e3285cec8f883 100644 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java +++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java @@ -51,8 +51,8 @@ interface TestCheckerCtors { void check$org_example_TestTargetClass$(Class clazz, int x, String y); } - public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException { - Map checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName()); + public void testInstrumentationTargetLookup() throws IOException { + Map checkMethods = instrumentationService.lookupMethods(TestChecker.class); assertThat(checkMethods, aMapWithSize(3)); assertThat( @@ -116,8 +116,8 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE ); } - public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException { - Map checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName()); + public void testInstrumentationTargetLookupWithOverloads() throws IOException { + Map checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class); assertThat(checkMethods, aMapWithSize(2)); assertThat( @@ -148,8 +148,8 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException, C ); } - public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException { - Map checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName()); + public void testInstrumentationTargetLookupWithCtors() throws IOException { + Map checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class); assertThat(checkMethods, aMapWithSize(2)); assertThat( diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java index 75102b0bf260d..e9af1d152dd35 100644 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java +++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java @@ -12,31 +12,64 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.entitlement.instrumentation.CheckMethod; import org.elasticsearch.entitlement.instrumentation.MethodKey; +import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; +import org.junit.Before; import org.objectweb.asm.Type; +import java.io.IOException; +import java.lang.reflect.AccessFlag; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; import java.lang.reflect.InvocationTargetException; -import java.util.List; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text; import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo; -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod; -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod; -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.equalTo; /** - * This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check - * some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.) + * This tests {@link InstrumenterImpl} can instrument various method signatures + * (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.) */ @ESTestCase.WithoutSecurityManager public class InstrumenterTests extends ESTestCase { private static final Logger logger = LogManager.getLogger(InstrumenterTests.class); + static class TestLoader extends ClassLoader { + final byte[] testClassBytes; + final Class testClass; + + TestLoader(String testClassName, byte[] testClassBytes) { + super(InstrumenterTests.class.getClassLoader()); + this.testClassBytes = testClassBytes; + this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length); + } + + Method getSameMethod(Method method) { + try { + return testClass.getMethod(method.getName(), method.getParameterTypes()); + } catch (NoSuchMethodException e) { + throw new AssertionError(e); + } + } + + Constructor getSameConstructor(Constructor ctor) { + try { + return testClass.getConstructor(ctor.getParameterTypes()); + } catch (NoSuchMethodException e) { + throw new AssertionError(e); + } + } + } + /** * Contains all the virtual methods from {@link TestClassToInstrument}, * allowing this test to call them on the dynamically loaded instrumented class. @@ -80,13 +113,15 @@ public static void anotherStaticMethod(int arg) {} public interface MockEntitlementChecker { void checkSomeStaticMethod(Class clazz, int arg); - void checkSomeStaticMethod(Class clazz, int arg, String anotherArg); + void checkSomeStaticMethodOverload(Class clazz, int arg, String anotherArg); + + void checkAnotherStaticMethod(Class clazz, int arg); void checkSomeInstanceMethod(Class clazz, Testable that, int arg, String anotherArg); void checkCtor(Class clazz); - void checkCtor(Class clazz, int arg); + void checkCtorOverload(Class clazz, int arg); } public static class TestEntitlementCheckerHolder { @@ -105,6 +140,7 @@ public static class TestEntitlementChecker implements MockEntitlementChecker { volatile boolean isActive; int checkSomeStaticMethodIntCallCount = 0; + int checkAnotherStaticMethodIntCallCount = 0; int checkSomeStaticMethodIntStringCallCount = 0; int checkSomeInstanceMethodCallCount = 0; @@ -120,28 +156,33 @@ private void throwIfActive() { @Override public void checkSomeStaticMethod(Class callerClass, int arg) { checkSomeStaticMethodIntCallCount++; - assertSame(TestMethodUtils.class, callerClass); + assertSame(InstrumenterTests.class, callerClass); assertEquals(123, arg); throwIfActive(); } @Override - public void checkSomeStaticMethod(Class callerClass, int arg, String anotherArg) { + public void checkSomeStaticMethodOverload(Class callerClass, int arg, String anotherArg) { checkSomeStaticMethodIntStringCallCount++; - assertSame(TestMethodUtils.class, callerClass); + assertSame(InstrumenterTests.class, callerClass); assertEquals(123, arg); assertEquals("abc", anotherArg); throwIfActive(); } + @Override + public void checkAnotherStaticMethod(Class callerClass, int arg) { + checkAnotherStaticMethodIntCallCount++; + assertSame(InstrumenterTests.class, callerClass); + assertEquals(123, arg); + throwIfActive(); + } + @Override public void checkSomeInstanceMethod(Class callerClass, Testable that, int arg, String anotherArg) { checkSomeInstanceMethodCallCount++; assertSame(InstrumenterTests.class, callerClass); - assertThat( - that.getClass().getName(), - startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument") - ); + assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName())); assertEquals(123, arg); assertEquals("def", anotherArg); throwIfActive(); @@ -155,7 +196,7 @@ public void checkCtor(Class callerClass) { } @Override - public void checkCtor(Class callerClass, int arg) { + public void checkCtorOverload(Class callerClass, int arg) { checkCtorIntCallCount++; assertSame(InstrumenterTests.class, callerClass); assertEquals(123, arg); @@ -163,206 +204,83 @@ public void checkCtor(Class callerClass, int arg) { } } - public void testClassIsInstrumented() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class); - Map checkMethods = Map.of( - methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)), - checkMethod - ); - - var instrumenter = createInstrumenter(checkMethods); - - byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); - - if (logger.isTraceEnabled()) { - logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); - } - - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW", - newBytecode - ); + @Before + public void resetInstance() { + TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker(); + } - TestEntitlementCheckerHolder.checkerInstance.isActive = false; + public void testStaticMethod() throws Exception { + Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class); + TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod))); // Before checking is active, nothing should throw - callStaticMethod(newClass, "someStaticMethod", 123); - - TestEntitlementCheckerHolder.checkerInstance.isActive = true; - + assertStaticMethod(loader, targetMethod, 123); // After checking is activated, everything should throw - assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123)); + assertStaticMethodThrows(loader, targetMethod, 123); } - public void testClassIsNotInstrumentedTwice() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class); - Map checkMethods = Map.of( - methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)), - checkMethod - ); - - var instrumenter = createInstrumenter(checkMethods); - - InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); - var internalClassName = Type.getInternalName(classToInstrument); - - byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes()); - byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode); + public void testNotInstrumentedTwice() throws Exception { + Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class); + var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)); - logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode))); + var loader1 = instrumentTestClass(instrumenter); + byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes); logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); + var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode); - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW_NEW", - instrumentedTwiceBytecode - ); - - TestEntitlementCheckerHolder.checkerInstance.isActive = true; - TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0; - - assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123)); + assertStaticMethodThrows(loader2, targetMethod, 123); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount); } - public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class); - Map checkMethods = Map.of( - methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)), - checkMethod, - methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)), - checkMethod - ); - - var instrumenter = createInstrumenter(checkMethods); - - InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); - var internalClassName = Type.getInternalName(classToInstrument); + public void testMultipleMethods() throws Exception { + Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class); + Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class); - byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes()); - byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode); + var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2)); + var loader = instrumentTestClass(instrumenter); - logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode))); - logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); - - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW_NEW", - instrumentedTwiceBytecode - ); - - TestEntitlementCheckerHolder.checkerInstance.isActive = true; - TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0; - - assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123)); + assertStaticMethodThrows(loader, targetMethod1, 123); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount); - - assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123)); - assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount); + assertStaticMethodThrows(loader, targetMethod2, 123); + assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount); } - public void testInstrumenterWorksWithOverloads() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - Map checkMethods = Map.of( - methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)), - getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class), - methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)), - getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class) + public void testStaticMethodOverload() throws Exception { + Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class); + Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class); + var instrumenter = createInstrumenter( + Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2) ); + var loader = instrumentTestClass(instrumenter); - var instrumenter = createInstrumenter(checkMethods); - - byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); - - if (logger.isTraceEnabled()) { - logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); - } - - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW", - newBytecode - ); - - TestEntitlementCheckerHolder.checkerInstance.isActive = true; - TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0; - TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0; - - // After checking is activated, everything should throw - assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123)); - assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc")); - + assertStaticMethodThrows(loader, targetMethod1, 123); + assertStaticMethodThrows(loader, targetMethod2, 123, "abc"); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount); } - public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - Map checkMethods = Map.of( - methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)), - getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class) - ); - - var instrumenter = createInstrumenter(checkMethods); - - byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); - - if (logger.isTraceEnabled()) { - logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); - } - - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW", - newBytecode - ); + public void testInstanceMethodOverload() throws Exception { + Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class); + var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod)); + var loader = instrumentTestClass(instrumenter); TestEntitlementCheckerHolder.checkerInstance.isActive = true; - TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0; - - Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance()); + Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance()); // This overload is not instrumented, so it will not throw testTargetClass.someMethod(123); - assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def")); + expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def")); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount); } - public void testInstrumenterWorksWithConstructors() throws Exception { - var classToInstrument = TestClassToInstrument.class; - - Map checkMethods = Map.of( - new MethodKey(classToInstrument.getName().replace('.', '/'), "", List.of()), - getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class), - new MethodKey(classToInstrument.getName().replace('.', '/'), "", List.of("I")), - getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class) - ); - - var instrumenter = createInstrumenter(checkMethods); - - byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); - - if (logger.isTraceEnabled()) { - logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); - } - - Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - classToInstrument.getName() + "_NEW", - newBytecode - ); - - TestEntitlementCheckerHolder.checkerInstance.isActive = true; - - var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance()); - assertThat(ex.getCause(), instanceOf(TestException.class)); - var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123)); - assertThat(ex2.getCause(), instanceOf(TestException.class)); + public void testConstructors() throws Exception { + Constructor ctor1 = TestClassToInstrument.class.getConstructor(); + Constructor ctor2 = TestClassToInstrument.class.getConstructor(int.class); + var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2))); + assertCtorThrows(loader, ctor1); + assertCtorThrows(loader, ctor2, 123); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount); assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount); } @@ -373,11 +291,107 @@ public void testInstrumenterWorksWithConstructors() throws Exception { * MethodKey and instrumentationMethod with slightly different signatures (using the common interface * Testable) which is not what would happen when it's run by the agent. */ - private InstrumenterImpl createInstrumenter(Map checkMethods) { + private static InstrumenterImpl createInstrumenter(Map methods) throws NoSuchMethodException { + Map checkMethods = new HashMap<>(); + for (var entry : methods.entrySet()) { + checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue())); + } String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class); String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class); String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass)); - return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods); + return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods); + } + + private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException { + var clazz = TestClassToInstrument.class; + ClassFileInfo initial = getClassFileInfo(clazz); + byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes()); + if (logger.isTraceEnabled()) { + logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); + } + return new TestLoader(clazz.getName(), newBytecode); + } + + private static MethodKey getMethodKey(Executable method) { + logger.info("method key: {}", method.getName()); + String methodName = method instanceof Constructor ? "" : method.getName(); + return new MethodKey( + Type.getInternalName(method.getDeclaringClass()), + methodName, + Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList() + ); + } + + private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException { + Set flags = targetMethod.accessFlags(); + boolean isInstance = flags.contains(AccessFlag.STATIC) == false && targetMethod instanceof Method; + int extraArgs = 1; // caller class + if (isInstance) { + ++extraArgs; + } + Class[] targetParameterTypes = targetMethod.getParameterTypes(); + Class[] checkParameterTypes = new Class[targetParameterTypes.length + extraArgs]; + checkParameterTypes[0] = Class.class; + if (isInstance) { + checkParameterTypes[1] = Testable.class; + } + System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length); + var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes); + return new CheckMethod( + Type.getInternalName(MockEntitlementChecker.class), + checkMethod.getName(), + Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList() + ); + } + + private static void unwrapInvocationException(InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof TestException n) { + // Sometimes we're expecting this one! + throw n; + } else { + throw new AssertionError(cause); + } + } + + /** + * Calling a static method of a dynamically loaded class is significantly more cumbersome + * than calling a virtual method. + */ + static void callStaticMethod(Method method, Object... args) { + try { + method.invoke(null, args); + } catch (InvocationTargetException e) { + unwrapInvocationException(e); + } catch (IllegalAccessException e) { + throw new AssertionError(e); + } + } + + private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) { + Method testMethod = loader.getSameMethod(method); + TestEntitlementCheckerHolder.checkerInstance.isActive = true; + expectThrows(TestException.class, () -> callStaticMethod(testMethod, args)); + } + + private void assertStaticMethod(TestLoader loader, Method method, Object... args) { + Method testMethod = loader.getSameMethod(method); + TestEntitlementCheckerHolder.checkerInstance.isActive = false; + callStaticMethod(testMethod, args); + } + + private void assertCtorThrows(TestLoader loader, Constructor ctor, Object... args) { + Constructor testCtor = loader.getSameConstructor(ctor); + TestEntitlementCheckerHolder.checkerInstance.isActive = true; + expectThrows(TestException.class, () -> { + try { + testCtor.newInstance(args); + } catch (InvocationTargetException e) { + unwrapInvocationException(e); + } catch (IllegalAccessException | InstantiationException e) { + throw new AssertionError(e); + } + }); } } diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java deleted file mode 100644 index 9eb8e9328ecba..0000000000000 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.instrumentation.impl; - -class TestLoader extends ClassLoader { - TestLoader(ClassLoader parent) { - super(parent); - } - - public Class defineClassFromBytes(String name, byte[] bytes) { - return defineClass(name, bytes, 0, bytes.length); - } -} diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java deleted file mode 100644 index de7822fea926e..0000000000000 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.instrumentation.impl; - -import org.elasticsearch.entitlement.instrumentation.CheckMethod; -import org.elasticsearch.entitlement.instrumentation.MethodKey; -import org.objectweb.asm.Type; - -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Stream; - -class TestMethodUtils { - - /** - * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline - */ - static MethodKey methodKeyForTarget(Method targetMethod) { - Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod)); - return new MethodKey( - Type.getInternalName(targetMethod.getDeclaringClass()), - targetMethod.getName(), - Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList() - ); - } - - static MethodKey methodKeyForConstructor(Class classToInstrument, List params) { - return new MethodKey(classToInstrument.getName().replace('.', '/'), "", params); - } - - static CheckMethod getCheckMethod(Class clazz, String methodName, Class... parameterTypes) throws NoSuchMethodException { - var method = clazz.getMethod(methodName, parameterTypes); - return new CheckMethod( - Type.getInternalName(clazz), - method.getName(), - Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList() - ); - } - - /** - * Calling a static method of a dynamically loaded class is significantly more cumbersome - * than calling a virtual method. - */ - static void callStaticMethod(Class c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException { - try { - c.getMethod(methodName, int.class).invoke(null, arg); - } catch (InvocationTargetException e) { - Throwable cause = e.getCause(); - if (cause instanceof TestException n) { - // Sometimes we're expecting this one! - throw n; - } else { - throw new AssertionError(cause); - } - } - } - - static void callStaticMethod(Class c, String methodName, int arg1, String arg2) throws NoSuchMethodException, - IllegalAccessException { - try { - c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2); - } catch (InvocationTargetException e) { - Throwable cause = e.getCause(); - if (cause instanceof TestException n) { - // Sometimes we're expecting this one! - throw n; - } else { - throw new AssertionError(cause); - } - } - } -} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java index 2956efa8eec31..9118f67cdc145 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java @@ -14,6 +14,7 @@ import org.elasticsearch.entitlement.bridge.EntitlementChecker; import org.elasticsearch.entitlement.instrumentation.CheckMethod; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; +import org.elasticsearch.entitlement.instrumentation.Instrumenter; import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.elasticsearch.entitlement.instrumentation.Transformer; import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker; @@ -64,13 +65,12 @@ public static EntitlementChecker checker() { public static void initialize(Instrumentation inst) throws Exception { manager = initChecker(); - Map checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument( - "org.elasticsearch.entitlement.bridge.EntitlementChecker" - ); + Map checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class); var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet()); - inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true); + Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods); + inst.addTransformer(new Transformer(instrumenter, classesToTransform), true); // TODO: should we limit this array somehow? var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new); inst.retransformClasses(classesToRetransform); diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java index d0331d756d2b2..66d8ad9488cfa 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java @@ -16,7 +16,7 @@ * The SPI service entry point for instrumentation. */ public interface InstrumentationService { - Instrumenter newInstrumenter(Map checkMethods); + Instrumenter newInstrumenter(Class clazz, Map methods); - Map lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException; + Map lookupMethods(Class clazz) throws IOException; }