From b9c68de9264df0a444fc52d747c90ea2d2231987 Mon Sep 17 00:00:00 2001 From: ibado Date: Sun, 12 Feb 2023 12:34:52 +0100 Subject: [PATCH] Fix #2909 --- .../arrow/optics/plugin/internals/dsl.kt | 160 ++++++++++-------- .../kotlin/arrow/optics/plugin/DSLTests.kt | 15 ++ 2 files changed, 109 insertions(+), 66 deletions(-) diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt index a42decca9e3..35aca64ed08 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt @@ -1,39 +1,50 @@ package arrow.optics.plugin.internals -fun generateLensDsl(ele: ADT, optic: DataClassDsl): Snippet = - Snippet( +import com.google.devtools.ksp.getDeclaredProperties + +fun generateLensDsl(ele: ADT, optic: DataClassDsl): Snippet { + val (className, import) = resolveClassName(ele) + return Snippet( `package` = ele.packageName, name = ele.simpleName, - content = processLensSyntax(ele, optic.foci) + content = processLensSyntax(ele, optic.foci, className), + imports = setOf(import) ) +} -fun generateOptionalDsl(ele: ADT, optic: DataClassDsl): Snippet = - Snippet( +fun generateOptionalDsl(ele: ADT, optic: DataClassDsl): Snippet { + val (className, import) = resolveClassName(ele) + return Snippet( `package` = ele.packageName, name = ele.simpleName, - content = processOptionalSyntax(ele, optic) + content = processOptionalSyntax(ele, optic, className), + imports = setOf(import) ) +} -fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet = - Snippet( +fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet { + val (className, import) = resolveClassName(ele) + return Snippet( `package` = ele.packageName, name = ele.simpleName, - content = processPrismSyntax(ele, isoOptic) + content = processPrismSyntax(ele, isoOptic, className), + imports = setOf(import) ) +} -private fun processLensSyntax(ele: ADT, foci: List): String = - if (ele.typeParameters.isEmpty()) { +private fun processLensSyntax(ele: ADT, foci: List, className: String): String { + return if (ele.typeParameters.isEmpty()) { foci.joinToString(separator = "\n") { focus -> - """ - |${ele.visibilityModifierName} inline val $Iso.${focus.lensParamName()}: $Lens inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Lens.${focus.lensParamName()}: $Lens inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Optional.${focus.lensParamName()}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Prism.${focus.lensParamName()}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Getter.${focus.lensParamName()}: $Getter inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Setter.${focus.lensParamName()}: $Setter inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Traversal.${focus.lensParamName()}: $Traversal inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Fold.${focus.lensParamName()}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} - |${ele.visibilityModifierName} inline val $Every.${focus.lensParamName()}: $Every inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} + """ + |${ele.visibilityModifierName} inline val $Iso.${focus.lensParamName()}: $Lens inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Lens.${focus.lensParamName()}: $Lens inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Optional.${focus.lensParamName()}: $Optional inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Prism.${focus.lensParamName()}: $Optional inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Getter.${focus.lensParamName()}: $Getter inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Setter.${focus.lensParamName()}: $Setter inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Traversal.${focus.lensParamName()}: $Traversal inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Fold.${focus.lensParamName()}: $Fold inline get() = this + ${className}.${focus.lensParamName()} + |${ele.visibilityModifierName} inline val $Every.${focus.lensParamName()}: $Every inline get() = this + ${className}.${focus.lensParamName()} |""".trimMargin() } } else { @@ -41,20 +52,21 @@ private fun processLensSyntax(ele: ADT, foci: List): String = val joinedTypeParams = ele.typeParameters.joinToString(separator=",") foci.joinToString(separator = "\n") { focus -> """ - |${ele.visibilityModifierName} inline fun $Iso.${focus.lensParamName()}(): $Lens = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Lens.${focus.lensParamName()}(): $Lens = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Optional.${focus.lensParamName()}(): $Optional = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Prism.${focus.lensParamName()}(): $Optional = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Getter.${focus.lensParamName()}(): $Getter = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Setter.${focus.lensParamName()}(): $Setter = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Traversal.${focus.lensParamName()}(): $Traversal = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Fold.${focus.lensParamName()}(): $Fold = this + ${ele.sourceClassName}.${focus.lensParamName()}() - |${ele.visibilityModifierName} inline fun $Every.${focus.lensParamName()}(): $Every = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Iso.${focus.lensParamName()}(): $Lens = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.lensParamName()}(): $Lens = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.lensParamName()}(): $Optional = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.lensParamName()}(): $Optional = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Getter.${focus.lensParamName()}(): $Getter = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.lensParamName()}(): $Setter = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.lensParamName()}(): $Traversal = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.lensParamName()}(): $Fold = this + ${className}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Every.${focus.lensParamName()}(): $Every = this + ${className}.${focus.lensParamName()}() |""".trimMargin() } } +} -private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl): String { +private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl, className: String): String { val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}" val joinedTypeParams = ele.typeParameters.joinToString(separator=",") return optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus -> @@ -66,42 +78,42 @@ private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl): String { } if (ele.typeParameters.isEmpty()) { """ - |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Optional.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Prism.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Setter.${focus.paramName}: $Setter inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Traversal.${focus.paramName}: $Traversal inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${ele.sourceClassName}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Optional.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Prism.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Setter.${focus.paramName}: $Setter inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Traversal.${focus.paramName}: $Traversal inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${className}.${focus.paramName} |""".trimMargin() } else { """ - |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${className}.${focus.paramName}() |""".trimMargin() } } } -private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String = - if (ele.typeParameters.isEmpty()) { +private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl, className: String): String { + return if (ele.typeParameters.isEmpty()) { dsl.foci.joinToString(separator = "\n\n") { focus -> - """ - |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Prism inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Optional.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Prism.${focus.paramName}: $Prism inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Setter.${focus.paramName}: $Setter inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Traversal.${focus.paramName}: $Traversal inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.paramName} - |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${ele.sourceClassName}.${focus.paramName} + """ + |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Prism inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Optional.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Prism.${focus.paramName}: $Prism inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Setter.${focus.paramName}: $Setter inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Traversal.${focus.paramName}: $Traversal inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${className}.${focus.paramName} |""".trimMargin() } } else { @@ -111,15 +123,31 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String = focus.refinedArguments.isEmpty() -> "" else -> focus.refinedArguments.joinToString(separator=",") } - """ - |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Prism = this + ${ele.sourceClassName}.${focus.paramName}() + """ + |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Prism = this + ${className}.${focus.paramName}() |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Prism = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${ele.sourceClassName}.${focus.paramName}() - |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Prism = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${className}.${focus.paramName}() |""".trimMargin() } } +} + +private fun resolveClassName(ele: ADT): Pair = if (hasPackageCollisions(ele)) { + val classNameAlias = ele.sourceClassName.replace(".", "") + val aliasImport = "import ${ele.sourceClassName} as $classNameAlias" + classNameAlias to aliasImport +} else ele.sourceClassName to "" + +private fun hasPackageCollisions(ele: ADT): Boolean = + ele.declaration.getDeclaredProperties().let { properties -> + ele.packageName + .split(".") + .any { p -> + properties.any { it.simpleName.asString() == p } + } + } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/DSLTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/DSLTests.kt index dffa7ddf7b6..9fe6cfd5c8f 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/DSLTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/DSLTests.kt @@ -46,5 +46,20 @@ class DSLTests { """.compilationSucceeds() } + @Test + fun `DSL for a data class with property named as a package directive`() { + """ + |package main.program + | + |$imports + | + |@optics + |data class Source(val program: String) { + | companion object + |} + | + """.compilationSucceeds() + } + // Db.content.at(At.map(), One).set(db, None) }