Skip to content

Commit

Permalink
If a classname matches two imports, then try to resolve the type to f…
Browse files Browse the repository at this point in the history
…ind the correct one.

Fixes #154
  • Loading branch information
vRallev committed Nov 11, 2020
1 parent 42f8441 commit 83dbd5a
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,31 @@ internal fun PsiElement.requireFqName(
// First look in the imports for the reference name. If the class is imported, then we know the
// fully qualified name.
importPaths
.filter { it.alias == null }
.firstOrNull {
it.fqName.shortName().asString() == classReference
.filter { it.alias == null && it.fqName.shortName().asString() == classReference }
.also { matchingImportPaths ->
when {
matchingImportPaths.size == 1 ->
return matchingImportPaths[0].fqName
matchingImportPaths.size > 1 ->
return matchingImportPaths.first { importPath ->
module.resolveClassByFqName(importPath.fqName, FROM_BACKEND) != null
}.fqName
}
}
?.let { return it.fqName }

importPaths
.filter { it.alias == null }
.firstOrNull {
it.fqName.shortName().asString() == classReferenceOuter
.filter { it.alias == null && it.fqName.shortName().asString() == classReferenceOuter }
.also { matchingImportPaths ->
when {
matchingImportPaths.size == 1 ->
return FqName("${matchingImportPaths[0].fqName.parent()}.$classReference")
matchingImportPaths.size > 1 ->
return matchingImportPaths.first { importPath ->
val fqName = FqName("${importPath.fqName.parent()}.$classReference")
module.resolveClassByFqName(fqName, FROM_BACKEND) != null
}.fqName
}
}
?.let { return FqName("${it.fqName.parent()}.$classReference") }

// If there is no import, then try to resolve the class with the same package as this file.
module.findClassOrTypeAlias(containingKtFile.packageFqName, classReference)
Expand Down Expand Up @@ -289,9 +302,9 @@ internal fun PsiElement.requireFqName(

// Check if it's a named import.
containingKtFile.importDirectives
.firstOrNull { classReference == it.importPath?.importedName?.asString() }
?.importedFqName
?.let { return it }
.firstOrNull { classReference == it.importPath?.importedName?.asString() }
?.importedFqName
?.let { return it }

// Everything else isn't supported.
throw AnvilCompilationException(
Expand Down Expand Up @@ -390,7 +403,7 @@ fun KtUserType.isTypeParameter(): Boolean {

fun KtUserType.findExtendsBound(): List<FqName> {
return parents.filterIsInstance<KtClassOrObject>()
.first()
.typeParameters
.mapNotNull { it.fqName }
.first()
.typeParameters
.mapNotNull { it.fqName }
}
17 changes: 13 additions & 4 deletions compiler/src/test/java/com/squareup/anvil/compiler/TestUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import java.util.Locale.US
import kotlin.reflect.KClass

internal fun compile(
source: String,
vararg sources: String,
enableDaggerAnnotationProcessor: Boolean = false,
generateDaggerFactories: Boolean = false,
block: Result.() -> Unit = { }
Expand Down Expand Up @@ -50,10 +50,19 @@ internal fun compile(
)
)

val name = "${workingDir.absolutePath}/sources/src/main/java/com/squareup/test/Source.kt"
check(File(name).parentFile.mkdirs())
this.sources = sources.map { content ->
val packageDir = content.lines()
.first { it.trim().startsWith("package ") }
.substringAfter("package ")
.replace('.', '/')

sources = listOf(SourceFile.kotlin(name, contents = source, trimIndent = true))
val name = "${workingDir.absolutePath}/sources/src/main/java/$packageDir/Source.kt"
with(File(name).parentFile) {
check(exists() || mkdirs())
}

SourceFile.kotlin(name, contents = content, trimIndent = true)
}
}
.compile()
.also(block)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ class ComponentDetectorCheckTest {
}

private fun compile(
source: String,
vararg sources: String,
block: Result.() -> Unit = { }
): Result = com.squareup.anvil.compiler.compile(
source = source,
sources = *sources,
generateDaggerFactories = true,
block = block
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1304,10 +1304,10 @@ public final class InjectClass_Factory<T extends CharSequence> implements Factor
}

private fun compile(
source: String,
vararg sources: String,
block: Result.() -> Unit = { }
): Result = com.squareup.anvil.compiler.compile(
source = source,
sources = *sources,
enableDaggerAnnotationProcessor = useDagger,
generateDaggerFactories = !useDagger,
block = block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,10 @@ public final class InjectClass_MembersInjector<T, U, V> implements MembersInject
}

private fun compile(
source: String,
vararg sources: String,
block: Result.() -> Unit = { }
): Result = com.squareup.anvil.compiler.compile(
source = source,
sources = *sources,
enableDaggerAnnotationProcessor = useDagger,
generateDaggerFactories = !useDagger,
block = block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2399,11 +2399,52 @@ public final class DaggerComponentInterface implements ComponentInterface {
}
}

@Test
fun `a factory class is generated for an uppercase factory function`() {
compile(
"""
package com.squareup.test.a
import com.squareup.test.b.User
fun User(): User = User(42)
""",
"""
package com.squareup.test.b
data class User(val age: Int)
""",
"""
package com.squareup.test
import com.squareup.test.a.User
import com.squareup.test.b.User
import dagger.Module
import dagger.Provides
@Module
object DaggerModule1 {
@Provides fun user(): User = User()
}
"""
) {
val factoryClass = daggerModule1.moduleFactoryClass("user")

val constructor = factoryClass.declaredConstructors.single()
assertThat(constructor.parameterTypes.toList()).isEmpty()

val staticMethods = factoryClass.declaredMethods.filter { it.isStatic }

val userProvider = staticMethods.single { it.name == "user" }
assertThat(userProvider.invoke(null)).isNotNull()
}
}

private fun compile(
source: String,
vararg sources: String,
block: Result.() -> Unit = { }
): Result = com.squareup.anvil.compiler.compile(
source,
sources = *sources,
enableDaggerAnnotationProcessor = useDagger,
generateDaggerFactories = !useDagger,
block = block
Expand Down

0 comments on commit 83dbd5a

Please sign in to comment.