Skip to content

Commit

Permalink
Fix endToMarker call with inline lambda return
Browse files Browse the repository at this point in the history
endToMarker call was not generated correctly when trace markers were disabled

Fixes: [346808602](https://issuetracker.google.com/346808602)
Relnote: Fixes `endToMarker` generation in early return from inline lambdas that caused start/end imbalance.
  • Loading branch information
ShikaSD authored and Space Cloud committed Jun 16, 2024
1 parent 173d6e6 commit f64fc3a
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,25 @@ class ControlFlowTransformTestsNoSource(
}
"""
)

@Test
fun returnFromIfInlineNoinline() = verifyGoldenComposeIrTransform(
extra = """
import androidx.compose.runtime.*
@Composable fun OuterComposableFunction(content: @Composable () -> Unit) { content() }
""",
source = """
import androidx.compose.runtime.*
import androidx.compose.foundation.layout.*
@Composable
fun Label(test: Boolean) {
OuterComposableFunction {
Column {
if (test) return@OuterComposableFunction
}
}
}
"""
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package androidx.compose.compiler.plugins.kotlin

import androidx.compose.compiler.plugins.kotlin.facade.SourceFile
import org.jetbrains.kotlin.config.CompilerConfiguration
import org.junit.runner.RunWith
import org.junit.runner.Runner
import org.junit.runners.BlockJUnit4ClassRunner
Expand All @@ -30,19 +31,18 @@ class RuntimeTests {

private fun createRuntimeRunners(cls: Class<*>): List<Runner> {
AbstractCompilerTest.setSystemProperties()
val k1Classes = RuntimeTestCompiler(useFir = false).run {
val classes = compileRuntimeClasses()
disposeTestRootDisposable()
classes
}
val k2Classes = RuntimeTestCompiler(useFir = true).run {
val classes = compileRuntimeClasses()
disposeTestRootDisposable()
classes
}
val compilers = listOf(
RuntimeTestCompiler(useFir = false, sourceInformation = false),
RuntimeTestCompiler(useFir = false, sourceInformation = true),
RuntimeTestCompiler(useFir = true, sourceInformation = false),
RuntimeTestCompiler(useFir = true, sourceInformation = true)
)

return k1Classes.map { FirVariantRunner(it, type = "[k1]") } +
k2Classes.map { FirVariantRunner(it, type = "[k2]") }
return compilers.flatMap { compiler ->
val classes = compiler.compileRuntimeClasses()
compiler.disposeTestRootDisposable()
classes.map { FirVariantRunner(it, compiler.description) }
}
}

private class FirVariantRunner(private val cls: Class<*>, val type: String) : BlockJUnit4ClassRunner(cls) {
Expand All @@ -54,7 +54,17 @@ private val runtimeTestSourceRoot = File(RUNTIME_TEST_ROOT)
private val runtimeTestFiles = runtimeTestSourceRoot.walk().toSet()

@Ignore
private class RuntimeTestCompiler(useFir: Boolean) : AbstractCodegenTest(useFir) {
private class RuntimeTestCompiler(
useFir: Boolean,
private val sourceInformation: Boolean
) : AbstractCodegenTest(useFir) {
val description: String = "[k${if (useFir) "1" else "2"}][source=$sourceInformation]"

override fun CompilerConfiguration.updateConfiguration() {
put(ComposeConfiguration.SOURCE_INFORMATION_ENABLED_KEY, sourceInformation)
put(ComposeConfiguration.TRACE_MARKERS_ENABLED_KEY, sourceInformation)
}

fun compileRuntimeClasses() =
compileRuntimeTestClasses(
runtimeTestSourceRoot,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//
// Source
// ------------------------------------------

import androidx.compose.runtime.*
import androidx.compose.foundation.layout.*

@Composable
fun Label(test: Boolean) {
OuterComposableFunction {
Column {
if (test) return@OuterComposableFunction
}
}
}

//
// Transformed IR
// ------------------------------------------

@Composable
@ComposableTarget(applier = "androidx.compose.ui.UiComposable")
fun Label(test: Boolean, %composer: Composer?, %changed: Int) {
%composer = %composer.startRestartGroup(<>)
val %dirty = %changed
if (%changed and 0b1110 == 0) {
%dirty = %dirty or if (%composer.changed(test)) 0b0100 else 0b0010
}
if (%dirty and 0b1011 != 0b0010 || !%composer.skipping) {
OuterComposableFunction(rememberComposableLambda(<>, true, { %composer: Composer?, %changed: Int ->
val tmp0_marker = %composer.currentMarker
if (%changed and 0b1011 != 0b0010 || !%composer.skipping) {
Column(null, null, null, { %composer: Composer?, %changed: Int ->
%composer.startReplaceGroup(<>)
if (test) {
%composer.endToMarker(tmp0_marker)
return@rememberComposableLambda
}
%composer.endReplaceGroup()
}, %composer, 0, 0b0111)
} else {
%composer.skipToGroupEnd()
}
}, %composer, 0b00110110), %composer, 0b0110)
} else {
%composer.skipToGroupEnd()
}
%composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
Label(test, %composer, updateChangedFlags(%changed or 0b0001))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//
// Source
// ------------------------------------------

import androidx.compose.runtime.*
import androidx.compose.foundation.layout.*

@Composable
fun Label(test: Boolean) {
OuterComposableFunction {
Column {
if (test) return@OuterComposableFunction
}
}
}

//
// Transformed IR
// ------------------------------------------

@Composable
@ComposableTarget(applier = "androidx.compose.ui.UiComposable")
fun Label(test: Boolean, %composer: Composer?, %changed: Int) {
%composer = %composer.startRestartGroup(<>)
val %dirty = %changed
if (%changed and 0b1110 == 0) {
%dirty = %dirty or if (%composer.changed(test)) 0b0100 else 0b0010
}
if (%dirty and 0b1011 != 0b0010 || !%composer.skipping) {
OuterComposableFunction(rememberComposableLambda(<>, true, { %composer: Composer?, %changed: Int ->
val tmp0_marker = %composer.currentMarker
if (%changed and 0b1011 != 0b0010 || !%composer.skipping) {
Column(null, null, null, { %composer: Composer?, %changed: Int ->
%composer.startReplaceGroup(<>)
if (test) {
%composer.endToMarker(tmp0_marker)
return@rememberComposableLambda
}
%composer.endReplaceGroup()
}, %composer, 0, 0b0111)
} else {
%composer.skipToGroupEnd()
}
}, %composer, 0b00110110), %composer, 0b0110)
} else {
%composer.skipToGroupEnd()
}
%composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
Label(test, %composer, updateChangedFlags(%changed or 0b0001))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@

package androidx.compose.compiler.test

import androidx.compose.runtime.Composable
import androidx.compose.runtime.NonRestartableComposable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.*
import androidx.compose.runtime.mock.InlineLinear
import androidx.compose.runtime.mock.Text
import androidx.compose.runtime.mock.compositionTest
import androidx.compose.runtime.mock.validate
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberUpdatedState
import androidx.compose.runtime.setValue
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.StandardTestDispatcher
Expand Down Expand Up @@ -117,6 +112,21 @@ class CompositionTests {
state = false
advance()
}

@Test
fun returnFromIfInlineNoinline() = compositionTest {
var state by mutableStateOf(true)
compose {
OuterComposable {
InlineLinear {
if (state) return@OuterComposable
}
}
}

state = false
advance()
}
}

@Composable
Expand Down Expand Up @@ -150,3 +160,6 @@ fun DefaultValueClass(
) {
println(data)
}

@Composable
fun OuterComposable(content: @Composable () -> Unit) = content()
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,10 @@

package androidx.compose.compiler.plugins.kotlin.lower

import androidx.compose.compiler.plugins.kotlin.ComposeCallableIds
import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
import androidx.compose.compiler.plugins.kotlin.ComposeNames
import androidx.compose.compiler.plugins.kotlin.FeatureFlag
import androidx.compose.compiler.plugins.kotlin.FeatureFlags
import androidx.compose.compiler.plugins.kotlin.FunctionMetrics
import androidx.compose.compiler.plugins.kotlin.ModuleMetrics
import androidx.compose.compiler.plugins.kotlin.analysis.ComposeWritableSlices
import androidx.compose.compiler.plugins.kotlin.analysis.Stability
import androidx.compose.compiler.plugins.kotlin.analysis.StabilityInferencer
import androidx.compose.compiler.plugins.kotlin.analysis.isUncertain
import androidx.compose.compiler.plugins.kotlin.analysis.knownStable
import androidx.compose.compiler.plugins.kotlin.analysis.knownUnstable
import androidx.compose.compiler.plugins.kotlin.irTrace
import androidx.compose.compiler.plugins.kotlin.*
import androidx.compose.compiler.plugins.kotlin.analysis.*
import androidx.compose.compiler.plugins.kotlin.lower.ComposerParamTransformer.ComposeDefaultValueStubOrigin
import androidx.compose.compiler.plugins.kotlin.lower.decoys.DecoyFqNames
import kotlin.math.abs
import kotlin.math.absoluteValue
import kotlin.math.ceil
import kotlin.math.min
import org.jetbrains.kotlin.backend.common.FileLoweringPass
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
Expand All @@ -55,93 +39,16 @@ import org.jetbrains.kotlin.ir.builders.irBlockBody
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irReturn
import org.jetbrains.kotlin.ir.declarations.IrAnonymousInitializer
import org.jetbrains.kotlin.ir.declarations.IrAttributeContainer
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationBase
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrField
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrLocalDelegatedProperty
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.declarations.IrPackageFragment
import org.jetbrains.kotlin.ir.declarations.IrProperty
import org.jetbrains.kotlin.ir.declarations.IrScript
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.IrTypeAlias
import org.jetbrains.kotlin.ir.declarations.IrTypeParameter
import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.declarations.IrVariable
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.declarations.impl.IrVariableImpl
import org.jetbrains.kotlin.ir.declarations.name
import org.jetbrains.kotlin.ir.expressions.IrBlock
import org.jetbrains.kotlin.ir.expressions.IrBody
import org.jetbrains.kotlin.ir.expressions.IrBreakContinue
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrComposite
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrConstKind
import org.jetbrains.kotlin.ir.expressions.IrContainerExpression
import org.jetbrains.kotlin.ir.expressions.IrContinue
import org.jetbrains.kotlin.ir.expressions.IrDoWhileLoop
import org.jetbrains.kotlin.ir.expressions.IrElseBranch
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression
import org.jetbrains.kotlin.ir.expressions.IrGetValue
import org.jetbrains.kotlin.ir.expressions.IrLoop
import org.jetbrains.kotlin.ir.expressions.IrReturn
import org.jetbrains.kotlin.ir.expressions.IrSpreadElement
import org.jetbrains.kotlin.ir.expressions.IrStatementContainer
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.IrVararg
import org.jetbrains.kotlin.ir.expressions.IrWhen
import org.jetbrains.kotlin.ir.expressions.IrWhileLoop
import org.jetbrains.kotlin.ir.expressions.impl.IrBlockImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrBranchImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrCompositeImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrElseBranchImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrReturnImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrSpreadElementImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrVarargImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrWhenImpl
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.IrReturnTargetSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.symbols.impl.IrVariableSymbolImpl
import org.jetbrains.kotlin.ir.types.IrSimpleType
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.IrTypeArgument
import org.jetbrains.kotlin.ir.types.classOrNull
import org.jetbrains.kotlin.ir.types.classifierOrNull
import org.jetbrains.kotlin.ir.types.defaultType
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.types.isClassWithFqName
import org.jetbrains.kotlin.ir.types.isMarkedNullable
import org.jetbrains.kotlin.ir.types.isNothing
import org.jetbrains.kotlin.ir.types.isUnit
import org.jetbrains.kotlin.ir.types.makeNullable
import org.jetbrains.kotlin.ir.util.DeepCopySymbolRemapper
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.ir.util.file
import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable
import org.jetbrains.kotlin.ir.util.functions
import org.jetbrains.kotlin.ir.util.getPropertyGetter
import org.jetbrains.kotlin.ir.util.hasAnnotation
import org.jetbrains.kotlin.ir.util.isLocal
import org.jetbrains.kotlin.ir.util.isOverridableOrOverrides
import org.jetbrains.kotlin.ir.util.isVararg
import org.jetbrains.kotlin.ir.util.kotlinFqName
import org.jetbrains.kotlin.ir.util.parentClassOrNull
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.util.properties
import org.jetbrains.kotlin.ir.util.statements
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.FqNameUnsafe
Expand All @@ -151,6 +58,10 @@ import org.jetbrains.kotlin.platform.isJs
import org.jetbrains.kotlin.platform.jvm.isJvm
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlin.utils.IDEAPluginsCompatibilityAPI
import kotlin.math.abs
import kotlin.math.absoluteValue
import kotlin.math.ceil
import kotlin.math.min

/**
* An enum of the different "states" a parameter of a composable function can have relating to
Expand Down Expand Up @@ -2640,16 +2551,12 @@ class ComposableFunctionBodyTransformer(
} else {
val functionScope = scope
val targetScope = currentScope as? Scope.BlockScope ?: functionScope
val marker = irGet(functionScope.allocateMarker())
extraEndLocation(irEndToMarker(marker, targetScope))
if (functionScope.isInlinedLambda) {
val marker = irGet(functionScope.allocateMarker())
extraEndLocation(irEndToMarker(marker, targetScope))
scope.hasInlineEarlyReturn = true
} else {
val marker = functionScope.allocateMarker()
functionScope.markReturn {
extraEndLocation(irEndToMarker(irGet(marker), targetScope))
extraEndLocation(it)
}
functionScope.markReturn(extraEndLocation)
}
}
break@loop
Expand Down

0 comments on commit f64fc3a

Please sign in to comment.