diff --git a/extensions/netty-loom-adaptor/deployment/src/main/java/io/quarkus/netty/loom/adaptor/NettyLoomAdaptorProcessor.java b/extensions/netty-loom-adaptor/deployment/src/main/java/io/quarkus/netty/loom/adaptor/NettyLoomAdaptorProcessor.java index bb0c9168ec2443..8a4aca522db0b4 100644 --- a/extensions/netty-loom-adaptor/deployment/src/main/java/io/quarkus/netty/loom/adaptor/NettyLoomAdaptorProcessor.java +++ b/extensions/netty-loom-adaptor/deployment/src/main/java/io/quarkus/netty/loom/adaptor/NettyLoomAdaptorProcessor.java @@ -41,6 +41,9 @@ import org.jboss.jandex.DotName; import org.jboss.logging.Logger; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; import io.quarkus.builder.item.EmptyBuildItem; import io.quarkus.deployment.annotations.BuildProducer; @@ -54,10 +57,6 @@ import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem; import io.smallrye.common.annotation.RunOnVirtualThread; -import org.objectweb.asm.ClassVisitor; -import org.objectweb.asm.Label; -import org.objectweb.asm.MethodVisitor; - public class NettyLoomAdaptorProcessor { static Logger LOG = Logger.getLogger(NettyLoomAdaptorProcessor.class); @@ -66,6 +65,76 @@ public FeatureBuildItem feature() { return new FeatureBuildItem("netty-Loom-adaptor"); } + /** + * This extension is designed to stop using Netty's {@link io.netty.buffer.PooledByteBufAllocator.PoolThreadLocalCache + * PoolThreadLocalCache}, extending {@link io.netty.util.concurrent.FastThreadLocal FastThreadLocal} in the + * {@link io.netty.buffer.PooledByteBufAllocator#newDirectBuffer(int, int)} newDirectBuffer(int,int)} method and to replace + * them with a {@link java.util.concurrent.ConcurrentHashMap ConcurrentHashMap} using the carrier thread's name as a key. + * + * we want to instrument the source ({@link io.netty.buffer.PooledByteBufAllocator#newDirectBuffer(int, int)} + * newDirectBuffer(int,int)} to get: + * + * protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + * boolean isVirtual = false; + * PoolThreadCache cache=null; + * if(canUseVirtual){ + * try { + * isVirtual = (boolean) isVirtualMethod.invoke(Thread.currentThread()); + * } catch (IllegalAccessException | InvocationTargetException e) { + * System.err.println(e); + * } + * if(isVirtual){ + * cache = createCache(initialCapacity, maxCapacity); + * } + * } + * if(cache == null){ + * cache = threadCache.get(); + * } + * PoolArena directArena = cache.directArena; + * + * final ByteBuf buf; + * if (directArena != null) { + * buf = directArena.allocate(cache, initialCapacity, maxCapacity); + * } else { + * buf = PlatformDependent.hasUnsafe() ? + * UnsafeByteBufUtil.newUnsafeDirectByteBuf(this, initialCapacity, maxCapacity) : + * new UnpooledDirectByteBuf(this, initialCapacity, maxCapacity); + * } + * + * return toLeakAwareBuffer(buf); + * } + * + * private PoolThreadCache createCache(int initialCapacity, int maxCapacity){ + * PoolThreadCache cache; + * Thread currentCarrierThread; + * try { + * currentCarrierThread = (Thread) getCurrentCarrierMethod.invoke(null); + * } catch (InvocationTargetException | IllegalAccessException e) { + * System.out.println(e); + * return null; + * } + * if(threadCaches.containsKey(currentCarrierThread)){ + * return threadCaches.get(currentCarrierThread); + * }else{ + * PoolArena heapArena = leastUsedArena(heapArenas); + * PoolArena directArena = leastUsedArena(directArenas); + * + * cache = new PoolThreadCache( + * heapArena, directArena, smallCacheSize, normalCacheSize, + * DEFAULT_MAX_CACHED_BUFFER_CAPACITY, DEFAULT_CACHE_TRIM_INTERVAL); + * threadCaches.put(currentCarrierThread, cache); + * if (DEFAULT_CACHE_TRIM_INTERVAL_MILLIS > 0) { + * EventExecutor executor = ThreadExecutorMap.currentExecutor(); + * if (executor != null) { + * executor.scheduleAtFixedRate(trimTask, DEFAULT_CACHE_TRIM_INTERVAL_MILLIS, + * DEFAULT_CACHE_TRIM_INTERVAL_MILLIS, TimeUnit.MILLISECONDS); + * } + * } + * } + * return cache; + * } + * + */ @Produce(EmptyBuildItem.class) @Consume(MinNettyAllocatorMaxOrderBuildItem.class) @BuildStep @@ -103,6 +172,8 @@ public MethodVisitor visitMethod( if (cv != null) { MethodVisitor mv = cv.visitMethod(access, name, descriptor, signature, exceptions); if (name.equals("")) { + // we need to augment the method to assigned the different static fields we added to the + // {@link io.netty.buffer.PooledByteBufAllocator PooledByteBufAllocator} class mv = new MethodVisitor(Gizmo.ASM_API_VERSION, mv) { @Override public void visitInsn(int opcode) { @@ -114,6 +185,7 @@ public void visitInsn(int opcode) { Label LthreadCaches = new Label(); Label lcanUseVirtual = new Label(); + // set canUseVirtual to true mv.visitLabel(lcanUseVirtual); mv.visitInsn(ICONST_1); mv.visitMethodInsn(INVOKESTATIC, "java/lang/Boolean", "valueOf", @@ -121,6 +193,8 @@ public void visitInsn(int opcode) { mv.visitFieldInsn(PUTSTATIC, "io/netty/buffer/PooledByteBufAllocator", "canUseVirtual", "Ljava/lang/Boolean;"); + // fetch the currentCarrierThread method and put it inside the getCurrentCarrierMethod field + // to avoid having to fetch it everytime we need to invoke it mv.visitLabel(L0); mv.visitLdcInsn("java.lang.Thread"); mv.visitMethodInsn(INVOKESTATIC, "java/lang/Class", "forName", @@ -136,8 +210,12 @@ public void visitInsn(int opcode) { mv.visitFieldInsn(GETSTATIC, "io/netty/buffer/PooledByteBufAllocator", "getCurrentCarrierMethod", "Ljava/lang/reflect/Method;"); mv.visitInsn(ICONST_1); + // make it accessible mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/reflect/Method", "setAccessible", "(Z)V", false); + + // fetch the isVirtual method and put it inside the isVirtualMethod field to avoid + // having to fetch it everytime we need to invoke it mv.visitLdcInsn("java.lang.Thread"); mv.visitMethodInsn(INVOKESTATIC, "java/lang/Class", "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false); @@ -153,6 +231,9 @@ public void visitInsn(int opcode) { mv.visitLabel(L1); mv.visitJumpInsn(GOTO, LthreadCaches); + // catch block of reflective calls to fetch isVirtual and currentCarrierThread. + // we set the canUseVirtual field to false if we enter the catch block: + // if these methods can't be found the jdk is not quarkus-loom compliant mv.visitLabel(L2); mv.visitVarInsn(ASTORE, 0); mv.visitInsn(ICONST_0); @@ -161,6 +242,7 @@ public void visitInsn(int opcode) { mv.visitFieldInsn(PUTSTATIC, "io/netty/buffer/PooledByteBufAllocator", "canUseVirtual", "Ljava/lang/Boolean;"); + // create the static concurrentHashMap that will be populated mv.visitLabel(LthreadCaches); mv.visitTypeInsn(NEW, "java/util/concurrent/ConcurrentHashMap"); mv.visitInsn(DUP); @@ -180,6 +262,7 @@ public void visitInsn(int opcode) { return mv; } if (name.equals("newDirectBuffer")) { + // this is the actual method we want to modify mv = new CurrentThreadMethodAdaptor(Gizmo.ASM_API_VERSION, mv); mv.visitMaxs(4, 4); return mv; @@ -189,6 +272,12 @@ public void visitInsn(int opcode) { return null; } + /** + * this method contains logic that was previously in + * {@link io.netty.buffer.PooledByteBufAllocator#newDirectBuffer(int, int)} newDirectBuffer(int, int) + * The FastThreadLocals are used to store thread cache, they are hence created with an initial value that needs a + * {@link io.netty.buffer.PoolArena}, this is + */ public void createLeastUsedArenaMethod() { var L0 = new Label(); var L1 = new Label(); @@ -201,7 +290,7 @@ public void createLeastUsedArenaMethod() { var L8 = new Label(); var L9 = new Label(); var L10 = new Label(); - var mv = cv.visitMethod(2, "leastUsedArena", + var mv = cv.visitMethod(ACC_PRIVATE, "leastUsedArena", "([Lio/netty/buffer/PoolArena;)Lio/netty/buffer/PoolArena;", null, null); mv.visitLabel(L0); mv.visitVarInsn(ALOAD, 1); @@ -274,6 +363,12 @@ public void createLeastUsedArenaMethod() { mv.visitMaxs(2, 5); } + /** + * this method contains logic that was previously in + * {@link io.netty.buffer.PooledByteBufAllocator#newDirectBuffer(int, int)} newDirectBuffer(int, int) + * it was a method of {@link io.netty.buffer.PooledByteBufAllocator.PoolThreadLocalCache PoolThreadLocalCache}, + * we need to reimplement it outside of this subclass that we don't use anymore + */ public void createCacheMethod() { Label L0 = new Label(); Label L1 = new Label(); @@ -650,6 +745,7 @@ public void visitCode() { mv.visitVarInsn(ALOAD, 6); mv.visitMethodInsn(INVOKESTATIC, "io/netty/buffer/PooledByteBufAllocator", "toLeakAwareBuffer", "(Lio/netty/buffer/ByteBuf;)Lio/netty/buffer/ByteBuf;", false); + mv.visitInsn(ARETURN); mv.visitLabel(LEnd);