From 3bcbdec6d48ad476672a8d7250dab60c8370a463 Mon Sep 17 00:00:00 2001 From: Morgen Peschke Date: Wed, 26 Jun 2024 09:51:10 -0700 Subject: [PATCH] Always handle the MDC --- .../slf4j/internal/Slf4jLoggerInternal.scala | 36 +++-- .../internal/Slf4jLoggerInternalSuite.scala | 124 ++++++++++-------- 2 files changed, 84 insertions(+), 76 deletions(-) diff --git a/slf4j/src/main/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternal.scala b/slf4j/src/main/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternal.scala index ab54f05e..c97e2be5 100644 --- a/slf4j/src/main/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternal.scala +++ b/slf4j/src/main/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternal.scala @@ -34,6 +34,12 @@ private[slf4j] object Slf4jLoggerInternal { def apply(t: Throwable)(msg: => String): F[Unit] } + // Need this to make sure MDC is correctly cleared before logging + private[this] def noContextLog[F[_]](isEnabled: F[Boolean], logging: () => Unit)(implicit + F: Sync[F] + ): F[Unit] = + contextLog[F](isEnabled, Map.empty, logging) + private[this] def contextLog[F[_]]( isEnabled: F[Boolean], ctx: Map[String, String], @@ -85,55 +91,45 @@ private[slf4j] object Slf4jLoggerInternal { override def isErrorEnabled: F[Boolean] = F.delay(logger.isErrorEnabled) override def trace(t: Throwable)(msg: => String): F[Unit] = - isTraceEnabled - .ifM(F.suspend(sync)(logger.trace(msg, t)), F.unit) + noContextLog(isTraceEnabled, () => logger.trace(msg, t)) override def trace(msg: => String): F[Unit] = - isTraceEnabled - .ifM(F.suspend(sync)(logger.trace(msg)), F.unit) + noContextLog(isTraceEnabled, () => logger.trace(msg)) override def trace(ctx: Map[String, String])(msg: => String): F[Unit] = contextLog(isTraceEnabled, ctx, () => logger.trace(msg)) override def trace(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] = contextLog(isTraceEnabled, ctx, () => logger.trace(msg, t)) override def debug(t: Throwable)(msg: => String): F[Unit] = - isDebugEnabled - .ifM(F.suspend(sync)(logger.debug(msg, t)), F.unit) + noContextLog(isDebugEnabled, () => logger.debug(msg, t)) override def debug(msg: => String): F[Unit] = - isDebugEnabled - .ifM(F.suspend(sync)(logger.debug(msg)), F.unit) + noContextLog(isDebugEnabled, () => logger.debug(msg)) override def debug(ctx: Map[String, String])(msg: => String): F[Unit] = contextLog(isDebugEnabled, ctx, () => logger.debug(msg)) override def debug(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] = contextLog(isDebugEnabled, ctx, () => logger.debug(msg, t)) override def info(t: Throwable)(msg: => String): F[Unit] = - isInfoEnabled - .ifM(F.suspend(sync)(logger.info(msg, t)), F.unit) + noContextLog(isInfoEnabled, () => logger.info(msg, t)) override def info(msg: => String): F[Unit] = - isInfoEnabled - .ifM(F.suspend(sync)(logger.info(msg)), F.unit) + noContextLog(isInfoEnabled, () => logger.info(msg)) override def info(ctx: Map[String, String])(msg: => String): F[Unit] = contextLog(isInfoEnabled, ctx, () => logger.info(msg)) override def info(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] = contextLog(isInfoEnabled, ctx, () => logger.info(msg, t)) override def warn(t: Throwable)(msg: => String): F[Unit] = - isWarnEnabled - .ifM(F.suspend(sync)(logger.warn(msg, t)), F.unit) + noContextLog(isWarnEnabled, () => logger.warn(msg, t)) override def warn(msg: => String): F[Unit] = - isWarnEnabled - .ifM(F.suspend(sync)(logger.warn(msg)), F.unit) + noContextLog(isWarnEnabled, () => logger.warn(msg)) override def warn(ctx: Map[String, String])(msg: => String): F[Unit] = contextLog(isWarnEnabled, ctx, () => logger.warn(msg)) override def warn(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] = contextLog(isWarnEnabled, ctx, () => logger.warn(msg, t)) override def error(t: Throwable)(msg: => String): F[Unit] = - isErrorEnabled - .ifM(F.suspend(sync)(logger.error(msg, t)), F.unit) + noContextLog(isErrorEnabled, () => logger.error(msg, t)) override def error(msg: => String): F[Unit] = - isErrorEnabled - .ifM(F.suspend(sync)(logger.error(msg)), F.unit) + noContextLog(isErrorEnabled, () => logger.error(msg)) override def error(ctx: Map[String, String])(msg: => String): F[Unit] = contextLog(isErrorEnabled, ctx, () => logger.error(msg)) override def error(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] = diff --git a/slf4j/src/test/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternalSuite.scala b/slf4j/src/test/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternalSuite.scala index b4ed249c..5528edd6 100644 --- a/slf4j/src/test/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternalSuite.scala +++ b/slf4j/src/test/scala/org/typelevel/log4cats/slf4j/internal/Slf4jLoggerInternalSuite.scala @@ -60,12 +60,12 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { } private def testLoggerFixture( - traceEnabled: Boolean = true, - debugEnabled: Boolean = true, - infoEnabled: Boolean = true, - warnEnabled: Boolean = true, - errorEnabled: Boolean = true - ): SyncIO[FunFixture[JTestLogger]] = + traceEnabled: Boolean = true, + debugEnabled: Boolean = true, + infoEnabled: Boolean = true, + warnEnabled: Boolean = true, + errorEnabled: Boolean = true + ): SyncIO[FunFixture[JTestLogger]] = ResourceFunFixture( Resource.eval( IO( @@ -128,7 +128,8 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { test("Slf4jLoggerInternal resets after exceptions") { prepareMDC >> - Slf4jLogger.getLogger[IO] + Slf4jLogger + .getLogger[IO] .info(Map("foo" -> "bar"))(die()) .interceptMessage[IllegalStateException]("dead") >> validateMDC @@ -151,9 +152,13 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { testLoggerFixture().test("Slf4jLoggerInternal correctly sets the MDC") { testLogger => prepareMDC >> Slf4jLogger.getLoggerFromSlf4j[IO](testLogger).info(Map("foo" -> "bar"))("A log went here") >> - IO(testLogger.logs()).map(toDeferredLogs).assertEquals(List( - DeferredLogMessage.info(Map("foo" -> "bar"), none, () => "A log went here") - )) >> + IO(testLogger.logs()) + .map(toDeferredLogs) + .assertEquals( + List( + DeferredLogMessage.info(Map("foo" -> "bar"), none, () => "A log went here") + ) + ) >> validateMDC } @@ -162,10 +167,12 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { ) { testLogger => prepareMDC >> Slf4jLogger.getLoggerFromSlf4j[IO](testLogger).info(Map("bar" -> "baz"))("A log went here") >> - IO(testLogger.logs()).map(toDeferredLogs).assertEquals( - List(DeferredLogMessage.info(Map("bar" -> "baz"), none, () => "A log went here")), - clue("Context should not include foo->yellow") - ) >> + IO(testLogger.logs()) + .map(toDeferredLogs) + .assertEquals( + List(DeferredLogMessage.info(Map("bar" -> "baz"), none, () => "A log went here")), + clue("Context should not include foo->yellow") + ) >> validateMDC } @@ -246,7 +253,8 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { warnEnabled = false, errorEnabled = false ).test("Slf4jLoggerInternal.withModifiedString is still lazy") { testLogger => - val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger).withModifiedString(_.toUpperCase) + val slf4jLogger = + Slf4jLogger.getLoggerFromSlf4j[IO](testLogger).withModifiedString(_.toUpperCase) val ctx = tag("lazy") // If these are lazy the way they need to be, the message won't be evaluated until // after the log level has been checked @@ -331,51 +339,55 @@ class Slf4jLoggerInternalSuite extends CatsEffectSuite { validateMDC } - testLoggerFixture().test("Slf4jLoggerInternal gets the dispatching right (msg + error)") { testLogger => - val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger) - prepareMDC >> - slf4jLogger.trace(throwable)("trace").assert >> - slf4jLogger.debug(throwable)("debug").assert >> - slf4jLogger.info(throwable)("info").assert >> - slf4jLogger.warn(throwable)("warn").assert >> - slf4jLogger.error(throwable)("error").assert >> - IO(testLogger.logs()) - .map(toDeferredLogs) - .assertEquals( - List( - DeferredLogMessage.trace(Map.empty, throwable.some, () => "trace"), - DeferredLogMessage.debug(Map.empty, throwable.some, () => "debug"), - DeferredLogMessage.info(Map.empty, throwable.some, () => "info"), - DeferredLogMessage.warn(Map.empty, throwable.some, () => "warn"), - DeferredLogMessage.error(Map.empty, throwable.some, () => "error") - ) - ) >> - validateMDC + testLoggerFixture().test("Slf4jLoggerInternal gets the dispatching right (msg + error)") { + testLogger => + val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger) + prepareMDC >> + slf4jLogger.trace(throwable)("trace").assert >> + slf4jLogger.debug(throwable)("debug").assert >> + slf4jLogger.info(throwable)("info").assert >> + slf4jLogger.warn(throwable)("warn").assert >> + slf4jLogger.error(throwable)("error").assert >> + IO(testLogger.logs()) + .map(toDeferredLogs) + .assertEquals( + List( + DeferredLogMessage.trace(Map.empty, throwable.some, () => "trace"), + DeferredLogMessage.debug(Map.empty, throwable.some, () => "debug"), + DeferredLogMessage.info(Map.empty, throwable.some, () => "info"), + DeferredLogMessage.warn(Map.empty, throwable.some, () => "warn"), + DeferredLogMessage.error(Map.empty, throwable.some, () => "error") + ) + ) >> + validateMDC } - testLoggerFixture().test("Slf4jLoggerInternal gets the dispatching right (msg + context)") { testLogger => - val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger) - prepareMDC >> - slf4jLogger.trace(tag("trace"))("trace").assert >> - slf4jLogger.debug(tag("debug"))("debug").assert >> - slf4jLogger.info(tag("info"))("info").assert >> - slf4jLogger.warn(tag("warn"))("warn").assert >> - slf4jLogger.error(tag("error"))("error").assert >> - IO(testLogger.logs()) - .map(toDeferredLogs) - .assertEquals( - List( - DeferredLogMessage.trace(tag("trace"), none, () => "trace"), - DeferredLogMessage.debug(tag("debug"), none, () => "debug"), - DeferredLogMessage.info(tag("info"), none, () => "info"), - DeferredLogMessage.warn(tag("warn"), none, () => "warn"), - DeferredLogMessage.error(tag("error"), none, () => "error") - ) - ) >> - validateMDC + testLoggerFixture().test("Slf4jLoggerInternal gets the dispatching right (msg + context)") { + testLogger => + val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger) + prepareMDC >> + slf4jLogger.trace(tag("trace"))("trace").assert >> + slf4jLogger.debug(tag("debug"))("debug").assert >> + slf4jLogger.info(tag("info"))("info").assert >> + slf4jLogger.warn(tag("warn"))("warn").assert >> + slf4jLogger.error(tag("error"))("error").assert >> + IO(testLogger.logs()) + .map(toDeferredLogs) + .assertEquals( + List( + DeferredLogMessage.trace(tag("trace"), none, () => "trace"), + DeferredLogMessage.debug(tag("debug"), none, () => "debug"), + DeferredLogMessage.info(tag("info"), none, () => "info"), + DeferredLogMessage.warn(tag("warn"), none, () => "warn"), + DeferredLogMessage.error(tag("error"), none, () => "error") + ) + ) >> + validateMDC } - testLoggerFixture().test("Slf4jLoggerInternal gets the dispatching right (msg + context + error") { testLogger => + testLoggerFixture().test( + "Slf4jLoggerInternal gets the dispatching right (msg + context + error" + ) { testLogger => val slf4jLogger = Slf4jLogger.getLoggerFromSlf4j[IO](testLogger) prepareMDC >> slf4jLogger.trace(tag("trace"), throwable)("trace").assert >>