Skip to content

Commit

Permalink
KTNB-794: Initialize all bean variables on Spring kernel startup
Browse files Browse the repository at this point in the history
  • Loading branch information
ileasile committed Oct 18, 2024
1 parent 099c0a4 commit aeb7679
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jetbrains.kotlinx.jupyter.api

/**
* Evaluates code. Returns rendered result or null in case of the error.
*/
fun interface CodeEvaluator {
fun eval(code: Code): Any?
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ interface KernelRunMode {

val streamSubstitutionType: StreamSubstitutionType

fun initializeSession(notebook: Notebook) = Unit
fun initializeSession(
notebook: Notebook,
evaluator: CodeEvaluator,
) = Unit
}

abstract class AbstractKernelRunMode(override val name: String) : KernelRunMode {
Expand All @@ -50,8 +53,4 @@ object EmbeddedKernelRunMode : AbstractKernelRunMode("Embedded") {
override val isRunInsideIntellijProcess: Boolean get() = false
override val streamSubstitutionType: StreamSubstitutionType
get() = StreamSubstitutionType.BLOCKING

override fun initializeSession(notebook: Notebook) {
notebook.sessionOptions.serializeScriptData = true
}
}
5 changes: 5 additions & 0 deletions jupyter-lib/spring-starter/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.publisher.composeOfTaskOutputs

plugins {
kotlin("libs.publisher")
kotlin("jupyter.api")
kotlin("jvm")
kotlin("kapt")
alias(libs.plugins.spring.boot)
Expand Down Expand Up @@ -82,6 +83,10 @@ val springKernelJar =
},
)

tasks.processJupyterApiResources {
libraryProducers = listOf("org.jetbrains.kotlinx.jupyter.spring.starter.SpringJupyterIntegration")
}

kotlinPublications {
publication {
publicationName.set("spring-starter")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@ package org.jetbrains.kotlinx.jupyter.spring.starter
import jupyter.kotlin.ScriptTemplateWithDisplayHelpers
import jupyter.kotlin.USE
import org.jetbrains.kotlinx.jupyter.api.VariableDeclaration
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
import java.util.Locale
import kotlin.reflect.KClass
import kotlin.reflect.KVisibility
import kotlin.reflect.full.starProjectedType

@Suppress("unused")
fun ScriptTemplateWithDisplayHelpers.declareAllBeans() {
USE {
declareAllBeansInLibrary()
}
}

@Suppress("unused")
fun JupyterIntegration.Builder.declareAllBeansInLibrary() {
declareBeansByNames(springContext.beanDefinitionNames.toList())
}

@Suppress("unused")
fun ScriptTemplateWithDisplayHelpers.declareBeansByClasses(beanClasses: Iterable<KClass<*>>) {
fun JupyterIntegration.Builder.declareBeansByClasses(beanClasses: Iterable<KClass<*>>) {
val beanInstances =
buildMap {
beanClasses.forEachSafe { beanClass ->
Expand All @@ -29,14 +37,14 @@ fun ScriptTemplateWithDisplayHelpers.declareBeansByClasses(beanClasses: Iterable
declareBeanInstances(beanInstances)
}

fun ScriptTemplateWithDisplayHelpers.declareBeansByNames(beanNames: Iterable<String>) {
fun JupyterIntegration.Builder.declareBeansByNames(beanNames: Iterable<String>) {
val beanInstances =
buildMap {
beanNames.forEachSafe { beanName ->
val varName = beanName.substringAfterLast(".")
if (varName.contains("$")) return@forEachSafe

val beanClass = springContext.getType(beanName).kotlin
val beanClass = springContext.getType(beanName)?.kotlin ?: return@forEachSafe

val qualifiedName = beanClass.qualifiedName ?: return@forEachSafe
if (qualifiedName.contains("$") || qualifiedName.startsWith("com.sun.")) return@forEachSafe
Expand All @@ -51,11 +59,9 @@ fun ScriptTemplateWithDisplayHelpers.declareBeansByNames(beanNames: Iterable<Str
declareBeanInstances(beanInstances)
}

private fun ScriptTemplateWithDisplayHelpers.declareBeanInstances(beanInstances: Map<String, VariableDeclaration>) {
USE {
onLoaded {
declare(beanInstances.values)
}
private fun JupyterIntegration.Builder.declareBeanInstances(beanInstances: Map<String, VariableDeclaration>) {
onLoaded {
declare(beanInstances.values)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.jetbrains.kotlinx.jupyter.spring.starter

import org.jetbrains.kotlinx.jupyter.api.EmbeddedKernelRunMode
import org.jetbrains.kotlinx.jupyter.api.JupyterClientType
import org.jetbrains.kotlinx.jupyter.config.DefaultKernelLoggerFactory
import org.jetbrains.kotlinx.jupyter.libraries.DefaultResolutionInfoProviderFactory
Expand Down Expand Up @@ -43,7 +42,7 @@ class KotlinJupyterKernelService(
thread {
startKernel(
DefaultKernelLoggerFactory,
EmbeddedKernelRunMode,
SpringProcessKernelRunMode,
kernelConfig,
DefaultResolutionInfoProviderFactory,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.jetbrains.kotlinx.jupyter.spring.starter

import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration

@Suppress("unused")
class SpringJupyterIntegration : JupyterIntegration() {
override fun Builder.onLoaded() {
declareAllBeansInLibrary()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jetbrains.kotlinx.jupyter.spring.starter

import org.jetbrains.kotlinx.jupyter.api.CodeEvaluator
import org.jetbrains.kotlinx.jupyter.api.EmbeddedKernelRunMode
import org.jetbrains.kotlinx.jupyter.api.KernelRunMode
import org.jetbrains.kotlinx.jupyter.api.Notebook

object SpringProcessKernelRunMode : KernelRunMode by EmbeddedKernelRunMode {
override fun initializeSession(
notebook: Notebook,
evaluator: CodeEvaluator,
) {
notebook.sessionOptions.serializeScriptData = true
evaluator.eval("1")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ class ReplForJupyterImpl(
classpathProvider,
).also {
notebook.sharedReplContext = it
kernelRunMode.initializeSession(notebook)
commHandlers.requireUniqueTargets()
commHandlers.forEach { handler -> installCommHandler(handler) }
}
Expand All @@ -435,6 +434,10 @@ class ReplForJupyterImpl(

private val executor: CellExecutor = CellExecutorImpl(sharedContext)

init {
kernelRunMode.initializeSession(notebook, ::eval)
}

private fun onAnnotationsHandler(context: ScriptConfigurationRefinementContext): ResultWithDiagnostics<ScriptCompilationConfiguration> {
return if (evalContextEnabled) {
fileAnnotationsProcessor.process(context, hostProvider.host!!)
Expand All @@ -443,6 +446,12 @@ class ReplForJupyterImpl(
}
}

private fun eval(code: Code): Any? {
val requestData = EvalRequestData(code)
val result = evalEx(requestData)
return (result as? EvalResultEx.Success)?.renderedValue
}

override fun evalEx(evalData: EvalRequestData): EvalResultEx {
return withEvalContext {
evalExImpl(evalData)
Expand Down

0 comments on commit aeb7679

Please sign in to comment.