Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the need for operation type aliasing in codegen #1710

Merged
merged 6 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ references = ["smithy-rs#1647", "smithy-rs#1112"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client"}
author = "Velfi"

[[smithy-rs]]
message = "Removed the need to generate operation output and retry aliases in codegen."
references = ["smithy-rs#976", "smithy-rs#1710"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "Velfi"

[[smithy-rs]]
message = "Added `writable` property to `RustType` and `RuntimeType` that returns them in `Writable` form"
references = ["smithy-rs#1710"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "all" }
author = "Velfi"

[[smithy-rs]]
message = "Smithy IDL v2 mixins are now supported"
references = ["smithy-rs#1680"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private class AwsClientGenerics(private val types: Types) : FluentClientGenerics
override val bounds = writable { }

/** Bounds for generated `send()` functions */
override fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable = writable { }
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryPolicy: Writable): Writable = writable { }

override fun toGenericsGenerator(): GenericsGenerator {
return GenericsGenerator()
Expand All @@ -98,7 +98,7 @@ class AwsFluentClientDecorator : RustCodegenDecorator<ClientCodegenContext> {
AwsPresignedFluentBuilderMethod(runtimeConfig),
AwsFluentClientDocs(codegenContext),
),
retryPolicyType = runtimeConfig.awsHttp().asType().member("retry::AwsErrorRetryPolicy"),
retryPolicy = runtimeConfig.awsHttp().asType().member("retry::AwsErrorRetryPolicy").writable,
).render(rustCrate)
rustCrate.withModule(FluentClientGenerator.customizableOperationModule) { writer ->
renderCustomizableOperationSendMethod(runtimeConfig, generics, writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,46 @@ sealed class RustType {

open val namespace: kotlin.String? = null

/**
* Get a writable for this `RustType`
*
* ```kotlin
* // Declare a RustType
* val t = RustType.Unit.writable
* // Then, invoke the writable directly
* t.invoke(writer)
* // OR template it out
* writer.rustInlineTemplate("#{t:W}", "t" to t)
* ```
*
* When formatted, the converted type will appear as such:
*
* | Type | Formatted |
* | -------------------------------------------------- | ------------------------------------------------------------------- |
* | RustType.Unit | () |
* | RustType.Bool | bool |
* | RustType.Float(32) | f32 |
* | RustType.Float(64) | f64 |
* | RustType.Integer(8) | i8 |
* | RustType.Integer(16) | i16 |
* | RustType.Integer(32) | i32 |
* | RustType.Integer(64) | i64 |
* | RustType.String | std::string::String |
* | RustType.Vec(RustType.String) | std::vec::Vec<std::string::String> |
* | RustType.Slice(RustType.String) | [std::string::String] |
* | RustType.HashMap(RustType.String, RustType.String) | std::collections::HashMap<std::string::String, std::string::String> |
* | RustType.HashSet(RustType.String) | std::vec::Vec<std::string::String> |
* | RustType.Reference("&", RustType.String) | &std::string::String |
* | RustType.Reference("&mut", RustType.String) | &mut std::string::String |
* | RustType.Reference("&'static", RustType.String) | &'static std::string::String |
* | RustType.Option(RustType.String) | std::option::Option<std::string::String> |
* | RustType.Box(RustType.String) | std::boxed::Box<std::string::String> |
* | RustType.Opaque("SoCool", "zelda_is") | zelda_is::SoCool |
* | RustType.Opaque("SoCool") | SoCool |
* | RustType.Dyn(RustType.Opaque("Foo", "foo")) | dyn foo::Foo |
*/
val writable = writable { rustInlineTemplate("#{this}", "this" to this@RustType) }

object Unit : RustType() {
override val name: kotlin.String = "()"
}
Expand Down Expand Up @@ -186,7 +226,13 @@ fun RustType.render(fullyQualified: Boolean = true): String {
is RustType.Slice -> "[${this.member.render(fullyQualified)}]"
is RustType.HashMap -> "${this.name}<${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}>"
is RustType.HashSet -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Reference -> "&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}"
is RustType.Reference -> {
if (this.lifetime == "&") {
"&${this.member.render(fullyQualified)}"
} else {
"&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}"
}
}
is RustType.Option -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,8 @@ class RustWriter private constructor(
* Formatter to enable formatting any [writable] with the #W formatter.
*/
inner class RustWriteableInjector : BiFunction<Any, String, String> {
@Suppress("UNCHECKED_CAST")
override fun apply(t: Any, u: String): String {
val func = t as RustWriter.() -> Unit
val func = t as? RustWriter.() -> Unit ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)")
val innerWriter = RustWriter(filename, namespace, printWarning = false)
func(innerWriter)
innerWriter.dependencies.forEach { addDependency(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package software.amazon.smithy.rust.codegen.rustlang

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.GenericsGenerator
import software.amazon.smithy.rust.codegen.util.PANIC

typealias Writable = RustWriter.() -> Unit

Expand Down Expand Up @@ -56,13 +56,21 @@ fun rustTypeParameters(
val iterator: Iterator<Any> = typeParameters.iterator()
while (iterator.hasNext()) {
when (val typeParameter = iterator.next()) {
is Symbol, is RustType.Unit, is RuntimeType -> rustInlineTemplate("#{it}", "it" to typeParameter)
is Symbol, is RuntimeType, is RustType -> rustInlineTemplate("#{it}", "it" to typeParameter)
is String -> rustInlineTemplate(typeParameter)
is GenericsGenerator -> rustInlineTemplate(
"#{gg:W}",
"gg" to typeParameter.declaration(withAngleBrackets = false),
)
else -> PANIC("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
else -> {
// Check if it's a writer. If it is, invoke it; Else, throw a codegen error.
val func = typeParameter as? RustWriter.() -> Unit
if (func != null) {
func.invoke(this)
} else {
throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
}
}
}

if (iterator.hasNext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rustInlineTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.util.orNull
import java.util.Optional

Expand Down Expand Up @@ -123,6 +125,11 @@ data class RuntimeConfig(
* name, but also ensure that we automatically add any dependencies **as they are used**.
*/
data class RuntimeType(val name: String?, val dependency: RustDependency?, val namespace: String) {
/**
* Get a writable for this `RuntimeType`
*/
val writable = writable { rustInlineTemplate("#{this:T}", "this" to this@RuntimeType) }

/**
* Convert this [RuntimeType] into a [Symbol].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ class PaginatorGenerator private constructor(
service: ServiceShape,
operation: OperationShape,
private val generics: FluentClientGenerics,
retryPolicy: Writable = RustType.Unit.writable,
) {

companion object {
fun paginatorType(
coreCodegenContext: CoreCodegenContext,
generics: FluentClientGenerics,
operationShape: OperationShape,
retryPolicy: Writable = RustType.Unit.writable,
): RuntimeType? {
return if (operationShape.isPaginated(coreCodegenContext.model)) {
PaginatorGenerator(
Expand All @@ -63,6 +64,7 @@ class PaginatorGenerator private constructor(
coreCodegenContext.serviceShape,
operationShape,
generics,
retryPolicy,
).paginatorType()
} else {
null
Expand All @@ -82,7 +84,8 @@ class PaginatorGenerator private constructor(
)

private val inputType = symbolProvider.toSymbol(operation.inputShape(model))
private val outputType = operation.outputShape(model)
private val outputShape = operation.outputShape(model)
private val outputType = symbolProvider.toSymbol(outputShape)
private val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)

private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun(
Expand All @@ -95,12 +98,12 @@ class PaginatorGenerator private constructor(
"generics" to generics.decl,
"bounds" to generics.bounds,
"page_size_setter" to pageSizeSetter(),
"send_bounds" to generics.sendBounds(inputType, symbolProvider.toSymbol(outputType), errorType),
"send_bounds" to generics.sendBounds(symbolProvider.toSymbol(operation), outputType, errorType, retryPolicy),

// Operation Types
"operation" to symbolProvider.toSymbol(operation),
"Input" to inputType,
"Output" to symbolProvider.toSymbol(outputType),
"Output" to outputType,
"Error" to errorType,
"Builder" to operation.inputShape(model).builderSymbol(symbolProvider),

Expand All @@ -118,7 +121,7 @@ class PaginatorGenerator private constructor(
/** Generate the paginator struct & impl **/
private fun generate() = writable {
val outputTokenLens = NestedAccessorGenerator(symbolProvider).generateBorrowingAccessor(
outputType,
outputShape,
paginationInfo.outputTokenMemberPath,
)
val inputTokenMember = symbolProvider.toMemberName(paginationInfo.inputTokenMember)
Expand Down Expand Up @@ -173,7 +176,7 @@ class PaginatorGenerator private constructor(
let done = match resp {
Ok(ref resp) => {
let new_token = #{output_token}(resp);
let is_empty = ${nextTokenEmpty("new_token")};
let is_empty = new_token.map(|token| token.is_empty()).unwrap_or(true);
if !is_empty && new_token == input.$inputTokenMember.as_ref() {
let _ = tx.send(Err(#{SdkError}::ConstructionFailure("next token did not change, aborting paginator. This indicates an SDK or AWS service bug.".into()))).await;
return;
Expand Down Expand Up @@ -259,18 +262,14 @@ class PaginatorGenerator private constructor(

""",
"extract_items" to NestedAccessorGenerator(symbolProvider).generateOwnedAccessor(
outputType,
outputShape,
paginationInfo.itemsMemberPath,
),
*codegenScope,
)
}
}

private fun nextTokenEmpty(token: String): String {
return "$token.map(|token|token.is_empty()).unwrap_or(true)"
}

private fun pageSizeSetter() = writable {
paginationInfo.pageSizeMember.orNull()?.also {
val memberName = symbolProvider.toMemberName(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asArgumentType
import software.amazon.smithy.rust.codegen.rustlang.asOptional
import software.amazon.smithy.rust.codegen.rustlang.asType
Expand Down Expand Up @@ -66,7 +67,7 @@ class FluentClientGenerator(
client = CargoDependency.SmithyClient(codegenContext.runtimeConfig).asType(),
),
private val customizations: List<FluentClientCustomization> = emptyList(),
private val retryPolicyType: RuntimeType? = null,
private val retryPolicy: Writable = RustType.Unit.writable,
) {
companion object {
fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String =
Expand Down Expand Up @@ -274,7 +275,6 @@ class FluentClientGenerator(
"client" to clientDep.asType(),
"bounds" to generics.bounds,
) {
val inputType = symbolProvider.toSymbol(operation.inputShape(model))
val outputType = symbolProvider.toSymbol(operation.outputShape(model))
val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)

Expand Down Expand Up @@ -322,14 +322,14 @@ class FluentClientGenerator(
"OperationOutput" to outputType,
"SdkError" to runtimeConfig.smithyHttp().member("result::SdkError"),
"SdkSuccess" to runtimeConfig.smithyHttp().member("result::SdkSuccess"),
"send_bounds" to generics.sendBounds(inputType, outputType, errorType),
"send_bounds" to generics.sendBounds(operationSymbol, outputType, errorType, retryPolicy),
"customizable_op_type_params" to rustTypeParameters(
symbolProvider.toSymbol(operation),
retryPolicyType ?: RustType.Unit,
retryPolicy,
generics.toGenericsGenerator(),
),
)
PaginatorGenerator.paginatorType(codegenContext, generics, operation)?.also { paginatorType ->
PaginatorGenerator.paginatorType(codegenContext, generics, operation, retryPolicy)?.also { paginatorType ->
rustTemplate(
"""
/// Create a paginator for this request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ interface FluentClientGenerics {
val bounds: Writable

/** Bounds for generated `send()` functions */
fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable
fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType, retryPolicy: Writable): Writable

/** Convert this `FluentClientGenerics` into the more general `GenericsGenerator` */
fun toGenericsGenerator(): GenericsGenerator
Expand Down Expand Up @@ -70,21 +70,22 @@ data class FlexibleClientGenerics(
}

/** Bounds for generated `send()` functions */
override fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable = writable {
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryPolicy: Writable): Writable = writable {
rustTemplate(
"""
where
R::Policy: #{client}::bounds::SmithyRetryPolicy<
#{Input}OperationOutputAlias,
#{Output},
#{Error},
#{Input}OperationRetryAlias
#{Operation},
#{OperationOutput},
#{OperationError},
#{RetryPolicy:W}
>
""",
"client" to client,
"Input" to input,
"Output" to output,
"Error" to error,
"Operation" to operation,
"OperationOutput" to operationOutput,
"OperationError" to operationError,
"RetryPolicy" to retryPolicy,
)
}

Expand Down
Loading