Skip to content

Commit

Permalink
Remove the need for operation type aliasing in codegen (#1710)
Browse files Browse the repository at this point in the history
* remove: need for operation type aliasing
rename: FluentClientGenerics.sendBounds params to be more accurate
update: FlexibleClientGenerics.sendBounds impl for readability
update: type of FluentClientGenerator input param `retryPolicyType` to be `Any` with a default of `RustType.Unit`
update: PaginatorGenerator to take retryPolicy as an input
chore: fix some spelling and grammar issues
remove: redundant `nextTokenEmpty` function from PaginatorGenerator

* Update CHANGELOG.next.toml

Co-authored-by: John DiSanti <[email protected]>

* add: `writable` property to RustType that returns the type as a Writable
add: test for RustType writable
add: `writable` property to RuntimeType that returns the type as a Writable
update: FluentClientGenerator to take a writable for retry

* format: run formatter

Co-authored-by: John DiSanti <[email protected]>
  • Loading branch information
Velfi and jdisanti authored Sep 7, 2022
1 parent 96e9f61 commit 4809a5b
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 76 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ references = ["smithy-rs#1647", "smithy-rs#1112"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client"}
author = "Velfi"

[[smithy-rs]]
message = "Removed the need to generate operation output and retry aliases in codegen."
references = ["smithy-rs#976", "smithy-rs#1710"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "Velfi"

[[smithy-rs]]
message = "Added `writable` property to `RustType` and `RuntimeType` that returns them in `Writable` form"
references = ["smithy-rs#1710"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "all" }
author = "Velfi"

[[smithy-rs]]
message = "Smithy IDL v2 mixins are now supported"
references = ["smithy-rs#1680"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private class AwsClientGenerics(private val types: Types) : FluentClientGenerics
override val bounds = writable { }

/** Bounds for generated `send()` functions */
override fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable = writable { }
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryPolicy: Writable): Writable = writable { }

override fun toGenericsGenerator(): GenericsGenerator {
return GenericsGenerator()
Expand All @@ -98,7 +98,7 @@ class AwsFluentClientDecorator : RustCodegenDecorator<ClientCodegenContext> {
AwsPresignedFluentBuilderMethod(runtimeConfig),
AwsFluentClientDocs(codegenContext),
),
retryPolicyType = runtimeConfig.awsHttp().asType().member("retry::AwsErrorRetryPolicy"),
retryPolicy = runtimeConfig.awsHttp().asType().member("retry::AwsErrorRetryPolicy").writable,
).render(rustCrate)
rustCrate.withModule(FluentClientGenerator.customizableOperationModule) { writer ->
renderCustomizableOperationSendMethod(runtimeConfig, generics, writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,46 @@ sealed class RustType {

open val namespace: kotlin.String? = null

/**
* Get a writable for this `RustType`
*
* ```kotlin
* // Declare a RustType
* val t = RustType.Unit.writable
* // Then, invoke the writable directly
* t.invoke(writer)
* // OR template it out
* writer.rustInlineTemplate("#{t:W}", "t" to t)
* ```
*
* When formatted, the converted type will appear as such:
*
* | Type | Formatted |
* | -------------------------------------------------- | ------------------------------------------------------------------- |
* | RustType.Unit | () |
* | RustType.Bool | bool |
* | RustType.Float(32) | f32 |
* | RustType.Float(64) | f64 |
* | RustType.Integer(8) | i8 |
* | RustType.Integer(16) | i16 |
* | RustType.Integer(32) | i32 |
* | RustType.Integer(64) | i64 |
* | RustType.String | std::string::String |
* | RustType.Vec(RustType.String) | std::vec::Vec<std::string::String> |
* | RustType.Slice(RustType.String) | [std::string::String] |
* | RustType.HashMap(RustType.String, RustType.String) | std::collections::HashMap<std::string::String, std::string::String> |
* | RustType.HashSet(RustType.String) | std::vec::Vec<std::string::String> |
* | RustType.Reference("&", RustType.String) | &std::string::String |
* | RustType.Reference("&mut", RustType.String) | &mut std::string::String |
* | RustType.Reference("&'static", RustType.String) | &'static std::string::String |
* | RustType.Option(RustType.String) | std::option::Option<std::string::String> |
* | RustType.Box(RustType.String) | std::boxed::Box<std::string::String> |
* | RustType.Opaque("SoCool", "zelda_is") | zelda_is::SoCool |
* | RustType.Opaque("SoCool") | SoCool |
* | RustType.Dyn(RustType.Opaque("Foo", "foo")) | dyn foo::Foo |
*/
val writable = writable { rustInlineTemplate("#{this}", "this" to this@RustType) }

object Unit : RustType() {
override val name: kotlin.String = "()"
}
Expand Down Expand Up @@ -186,7 +226,13 @@ fun RustType.render(fullyQualified: Boolean = true): String {
is RustType.Slice -> "[${this.member.render(fullyQualified)}]"
is RustType.HashMap -> "${this.name}<${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}>"
is RustType.HashSet -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Reference -> "&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}"
is RustType.Reference -> {
if (this.lifetime == "&") {
"&${this.member.render(fullyQualified)}"
} else {
"&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}"
}
}
is RustType.Option -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>"
is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,8 @@ class RustWriter private constructor(
* Formatter to enable formatting any [writable] with the #W formatter.
*/
inner class RustWriteableInjector : BiFunction<Any, String, String> {
@Suppress("UNCHECKED_CAST")
override fun apply(t: Any, u: String): String {
val func = t as RustWriter.() -> Unit
val func = t as? RustWriter.() -> Unit ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)")
val innerWriter = RustWriter(filename, namespace, printWarning = false)
func(innerWriter)
innerWriter.dependencies.forEach { addDependency(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package software.amazon.smithy.rust.codegen.rustlang

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.GenericsGenerator
import software.amazon.smithy.rust.codegen.util.PANIC

typealias Writable = RustWriter.() -> Unit

Expand Down Expand Up @@ -56,13 +56,21 @@ fun rustTypeParameters(
val iterator: Iterator<Any> = typeParameters.iterator()
while (iterator.hasNext()) {
when (val typeParameter = iterator.next()) {
is Symbol, is RustType.Unit, is RuntimeType -> rustInlineTemplate("#{it}", "it" to typeParameter)
is Symbol, is RuntimeType, is RustType -> rustInlineTemplate("#{it}", "it" to typeParameter)
is String -> rustInlineTemplate(typeParameter)
is GenericsGenerator -> rustInlineTemplate(
"#{gg:W}",
"gg" to typeParameter.declaration(withAngleBrackets = false),
)
else -> PANIC("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
else -> {
// Check if it's a writer. If it is, invoke it; Else, throw a codegen error.
val func = typeParameter as? RustWriter.() -> Unit
if (func != null) {
func.invoke(this)
} else {
throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
}
}
}

if (iterator.hasNext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rustInlineTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.util.orNull
import java.util.Optional

Expand Down Expand Up @@ -123,6 +125,11 @@ data class RuntimeConfig(
* name, but also ensure that we automatically add any dependencies **as they are used**.
*/
data class RuntimeType(val name: String?, val dependency: RustDependency?, val namespace: String) {
/**
* Get a writable for this `RuntimeType`
*/
val writable = writable { rustInlineTemplate("#{this:T}", "this" to this@RuntimeType) }

/**
* Convert this [RuntimeType] into a [Symbol].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ class PaginatorGenerator private constructor(
service: ServiceShape,
operation: OperationShape,
private val generics: FluentClientGenerics,
retryPolicy: Writable = RustType.Unit.writable,
) {

companion object {
fun paginatorType(
coreCodegenContext: CoreCodegenContext,
generics: FluentClientGenerics,
operationShape: OperationShape,
retryPolicy: Writable = RustType.Unit.writable,
): RuntimeType? {
return if (operationShape.isPaginated(coreCodegenContext.model)) {
PaginatorGenerator(
Expand All @@ -63,6 +64,7 @@ class PaginatorGenerator private constructor(
coreCodegenContext.serviceShape,
operationShape,
generics,
retryPolicy,
).paginatorType()
} else {
null
Expand All @@ -82,7 +84,8 @@ class PaginatorGenerator private constructor(
)

private val inputType = symbolProvider.toSymbol(operation.inputShape(model))
private val outputType = operation.outputShape(model)
private val outputShape = operation.outputShape(model)
private val outputType = symbolProvider.toSymbol(outputShape)
private val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)

private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun(
Expand All @@ -95,12 +98,12 @@ class PaginatorGenerator private constructor(
"generics" to generics.decl,
"bounds" to generics.bounds,
"page_size_setter" to pageSizeSetter(),
"send_bounds" to generics.sendBounds(inputType, symbolProvider.toSymbol(outputType), errorType),
"send_bounds" to generics.sendBounds(symbolProvider.toSymbol(operation), outputType, errorType, retryPolicy),

// Operation Types
"operation" to symbolProvider.toSymbol(operation),
"Input" to inputType,
"Output" to symbolProvider.toSymbol(outputType),
"Output" to outputType,
"Error" to errorType,
"Builder" to operation.inputShape(model).builderSymbol(symbolProvider),

Expand All @@ -118,7 +121,7 @@ class PaginatorGenerator private constructor(
/** Generate the paginator struct & impl **/
private fun generate() = writable {
val outputTokenLens = NestedAccessorGenerator(symbolProvider).generateBorrowingAccessor(
outputType,
outputShape,
paginationInfo.outputTokenMemberPath,
)
val inputTokenMember = symbolProvider.toMemberName(paginationInfo.inputTokenMember)
Expand Down Expand Up @@ -173,7 +176,7 @@ class PaginatorGenerator private constructor(
let done = match resp {
Ok(ref resp) => {
let new_token = #{output_token}(resp);
let is_empty = ${nextTokenEmpty("new_token")};
let is_empty = new_token.map(|token| token.is_empty()).unwrap_or(true);
if !is_empty && new_token == input.$inputTokenMember.as_ref() {
let _ = tx.send(Err(#{SdkError}::ConstructionFailure("next token did not change, aborting paginator. This indicates an SDK or AWS service bug.".into()))).await;
return;
Expand Down Expand Up @@ -259,18 +262,14 @@ class PaginatorGenerator private constructor(
""",
"extract_items" to NestedAccessorGenerator(symbolProvider).generateOwnedAccessor(
outputType,
outputShape,
paginationInfo.itemsMemberPath,
),
*codegenScope,
)
}
}

private fun nextTokenEmpty(token: String): String {
return "$token.map(|token|token.is_empty()).unwrap_or(true)"
}

private fun pageSizeSetter() = writable {
paginationInfo.pageSizeMember.orNull()?.also {
val memberName = symbolProvider.toMemberName(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asArgumentType
import software.amazon.smithy.rust.codegen.rustlang.asOptional
import software.amazon.smithy.rust.codegen.rustlang.asType
Expand Down Expand Up @@ -66,7 +67,7 @@ class FluentClientGenerator(
client = CargoDependency.SmithyClient(codegenContext.runtimeConfig).asType(),
),
private val customizations: List<FluentClientCustomization> = emptyList(),
private val retryPolicyType: RuntimeType? = null,
private val retryPolicy: Writable = RustType.Unit.writable,
) {
companion object {
fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String =
Expand Down Expand Up @@ -274,7 +275,6 @@ class FluentClientGenerator(
"client" to clientDep.asType(),
"bounds" to generics.bounds,
) {
val inputType = symbolProvider.toSymbol(operation.inputShape(model))
val outputType = symbolProvider.toSymbol(operation.outputShape(model))
val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)

Expand Down Expand Up @@ -322,14 +322,14 @@ class FluentClientGenerator(
"OperationOutput" to outputType,
"SdkError" to runtimeConfig.smithyHttp().member("result::SdkError"),
"SdkSuccess" to runtimeConfig.smithyHttp().member("result::SdkSuccess"),
"send_bounds" to generics.sendBounds(inputType, outputType, errorType),
"send_bounds" to generics.sendBounds(operationSymbol, outputType, errorType, retryPolicy),
"customizable_op_type_params" to rustTypeParameters(
symbolProvider.toSymbol(operation),
retryPolicyType ?: RustType.Unit,
retryPolicy,
generics.toGenericsGenerator(),
),
)
PaginatorGenerator.paginatorType(codegenContext, generics, operation)?.also { paginatorType ->
PaginatorGenerator.paginatorType(codegenContext, generics, operation, retryPolicy)?.also { paginatorType ->
rustTemplate(
"""
/// Create a paginator for this request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ interface FluentClientGenerics {
val bounds: Writable

/** Bounds for generated `send()` functions */
fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable
fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType, retryPolicy: Writable): Writable

/** Convert this `FluentClientGenerics` into the more general `GenericsGenerator` */
fun toGenericsGenerator(): GenericsGenerator
Expand Down Expand Up @@ -70,21 +70,22 @@ data class FlexibleClientGenerics(
}

/** Bounds for generated `send()` functions */
override fun sendBounds(input: Symbol, output: Symbol, error: RuntimeType): Writable = writable {
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryPolicy: Writable): Writable = writable {
rustTemplate(
"""
where
R::Policy: #{client}::bounds::SmithyRetryPolicy<
#{Input}OperationOutputAlias,
#{Output},
#{Error},
#{Input}OperationRetryAlias
#{Operation},
#{OperationOutput},
#{OperationError},
#{RetryPolicy:W}
>
""",
"client" to client,
"Input" to input,
"Output" to output,
"Error" to error,
"Operation" to operation,
"OperationOutput" to operationOutput,
"OperationError" to operationError,
"RetryPolicy" to retryPolicy,
)
}

Expand Down
Loading

0 comments on commit 4809a5b

Please sign in to comment.