Skip to content

Commit

Permalink
Enable the crate reorganization for generic clients
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti committed Mar 11, 2023
1 parent b3b1182 commit f9c74d1
Show file tree
Hide file tree
Showing 26 changed files with 60 additions and 45 deletions.
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/CodegenTestCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private fun generateSmithyBuild(projectDir: String, pluginName: String, tests: L
"relativePath": "$projectDir/rust-runtime"
},
"codegen": {
"enableNewCrateOrganizationScheme": false
"enableNewCrateOrganizationScheme": true
},
"service": "${it.service}",
"module": "${it.module}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDocs
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.EscapeFor
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
Expand Down Expand Up @@ -149,7 +150,7 @@ class ClientModuleDocProvider(
by calling the `customize()` method on the builder returned from a client
operation call. For example, this can be used to add an additional HTTP header:
```no_run
```ignore
## async fn wrapper() -> Result<(), $moduleUseName::Error> {
## let client: $moduleUseName::Client = unimplemented!();
use #{http}::header::{HeaderName, HeaderValue};
Expand Down Expand Up @@ -217,7 +218,7 @@ object ClientModuleProvider : ModuleProvider {
val operationShape = shape.findOperation(context.model)
val contextName = operationShape.contextName(context.serviceShape)
val operationModuleName =
RustReservedWords.escapeIfNeeded(contextName.toSnakeCase())
RustReservedWords.escapeIfNeeded(contextName.toSnakeCase(), EscapeFor.ModuleName)
return RustModule.public(
operationModuleName,
parent = ClientRustModule.Operation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ data class ClientCodegenConfig(
private const val defaultIncludeFluentClient = true
private const val defaultAddMessageToErrors = true
private val defaultEventStreamAllowList: Set<String> = emptySet()
private const val defaultEnableNewCrateOrganizationScheme = false
private const val defaultEnableNewCrateOrganizationScheme = true

fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional<ObjectNode>) =
if (node.isPresent) {
Expand All @@ -109,7 +109,7 @@ data class ClientCodegenConfig(
renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", defaultRenameExceptions),
includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", defaultIncludeFluentClient),
addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors),
enableNewCrateOrganizationScheme = node.get().getBooleanMemberOrDefault("enableNewCrateOrganizationScheme", false),
enableNewCrateOrganizationScheme = node.get().getBooleanMemberOrDefault("enableNewCrateOrganizationScheme", defaultEnableNewCrateOrganizationScheme),
)
} else {
ClientCodegenConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ErrorGenerator(
override fun section(section: BuilderSection): Writable = writable {
when (section) {
is BuilderSection.AdditionalFields -> {
rust("meta: Option<#T>,", errorMetadata(runtimeConfig))
rust("meta: std::option::Option<#T>,", errorMetadata(runtimeConfig))
}

is BuilderSection.AdditionalMethods -> {
Expand All @@ -102,7 +102,7 @@ class ErrorGenerator(
}
/// Sets error metadata
pub fn set_meta(&mut self, meta: Option<#{error_metadata}>) -> &mut Self {
pub fn set_meta(&mut self, meta: std::option::Option<#{error_metadata}>) -> &mut Self {
self.meta = meta;
self
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class OperationErrorGenerator(
"""
fn create_unhandled_error(
source: Box<dyn std::error::Error + Send + Sync + 'static>,
meta: Option<#T>
meta: std::option::Option<#T>
) -> Self
""",
errorMetadata,
Expand Down Expand Up @@ -152,11 +152,11 @@ class OperationErrorGenerator(
"impl #T for ${errorSymbol.name}",
RuntimeType.provideErrorKind(symbolProvider.config.runtimeConfig),
) {
rustBlock("fn code(&self) -> Option<&str>") {
rustBlock("fn code(&self) -> std::option::Option<&str>") {
rust("#T::code(self)", RuntimeType.provideErrorMetadataTrait(runtimeConfig))
}

rustBlock("fn retryable_error_kind(&self) -> Option<#T>", retryErrorKindT) {
rustBlock("fn retryable_error_kind(&self) -> std::option::Option<#T>", retryErrorKindT) {
val retryableVariants = errors.filter { it.hasTrait<RetryableTrait>() }
if (retryableVariants.isEmpty()) {
rust("None")
Expand Down Expand Up @@ -216,7 +216,7 @@ class OperationErrorGenerator(
}

writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) {
rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
rustBlock("fn source(&self) -> std::option::Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
delegateToVariants(errors) {
writable {
rust("Some(_inner)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class RequestBindingGenerator(
val combinedArgs = listOf(formatString, *args.toTypedArray())
writer.addImport(RuntimeType.stdFmt.resolve("Write").toSymbol(), null)
writer.rustBlockTemplate(
"fn uri_base(_input: &#{Input}, output: &mut String) -> Result<(), #{BuildError}>",
"fn uri_base(_input: &#{Input}, output: &mut String) -> std::result::Result<(), #{BuildError}>",
*codegenScope,
) {
httpTrait.uri.labels.map { label ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ open class MakeOperationGenerator(
Attribute.AllowClippyLetAndReturn.render(implBlockWriter)
// Allows builders that don’t consume the input borrow
Attribute.AllowClippyNeedlessBorrow.render(implBlockWriter)

implBlockWriter.rustBlockTemplate(
"$fnType $functionName($self, _config: &#{config}::Config) -> $returnType",
*codegenScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ internal class ApiKeyAuthDecoratorTest {
let conf = $moduleName::Config::builder()
.api_key(AuthApiKey::new(api_key_value))
.build();
let operation = $moduleName::input::SomeOperationInput::builder()
let operation = $moduleName::operation::some_operation::SomeOperationInput::builder()
.build()
.expect("input is valid")
.make_operation(&conf)
Expand Down Expand Up @@ -87,7 +87,7 @@ internal class ApiKeyAuthDecoratorTest {
let conf = $moduleName::Config::builder()
.api_key(AuthApiKey::new(api_key_value))
.build();
let operation = $moduleName::input::SomeOperationInput::builder()
let operation = $moduleName::operation::some_operation::SomeOperationInput::builder()
.build()
.expect("input is valid")
.make_operation(&conf)
Expand Down Expand Up @@ -149,7 +149,7 @@ internal class ApiKeyAuthDecoratorTest {
let conf = $moduleName::Config::builder()
.api_key(AuthApiKey::new(api_key_value))
.build();
let operation = $moduleName::input::SomeOperationInput::builder()
let operation = $moduleName::operation::some_operation::SomeOperationInput::builder()
.build()
.expect("input is valid")
.make_operation(&conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ internal class HttpVersionListGeneratorTest {
"""
async fn test_http_version_list_defaults() {
let conf = $moduleName::Config::builder().build();
let op = $moduleName::input::SayHelloInput::builder()
let op = $moduleName::operation::say_hello::SayHelloInput::builder()
.greeting("hello")
.build().expect("valid operation")
.make_operation(&conf).await.expect("hello is a valid prefix");
Expand Down Expand Up @@ -113,7 +113,7 @@ internal class HttpVersionListGeneratorTest {
"""
async fn test_http_version_list_defaults() {
let conf = $moduleName::Config::builder().build();
let op = $moduleName::input::SayHelloInput::builder()
let op = $moduleName::operation::say_hello::SayHelloInput::builder()
.greeting("hello")
.build().expect("valid operation")
.make_operation(&conf).await.expect("hello is a valid prefix");
Expand Down Expand Up @@ -181,7 +181,7 @@ internal class HttpVersionListGeneratorTest {
"""
async fn test_http_version_list_defaults() {
let conf = $moduleName::Config::builder().build();
let op = $moduleName::input::SayHelloInput::builder()
let op = $moduleName::operation::say_hello::SayHelloInput::builder()
.build().expect("valid operation")
.make_operation(&conf).await.unwrap();
let properties = op.properties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class EndpointsDecoratorTest {
"""
async fn endpoint_params_are_set() {
let conf = $moduleName::Config::builder().a_string_param("hello").a_bool_param(false).build();
let operation = $moduleName::input::TestOperationInput::builder()
let operation = $moduleName::operation::test_operation::TestOperationInput::builder()
.bucket("bucket-name").build().expect("input is valid")
.make_operation(&conf).await.expect("valid operation");
use $moduleName::endpoint::{Params};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class ClientEnumGeneratorTest {
"""
assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue);
assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_);
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned())));
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::primitives::UnknownVariantValue("SomethingNew".to_owned())));
""".trimIndent(),
)
}
Expand Down Expand Up @@ -150,7 +150,7 @@ class ClientEnumGeneratorTest {
assert_eq!(instance.as_str(), "t2.micro");
assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano);
// round trip unknown variants:
assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned())));
assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::primitives::UnknownVariantValue("other".to_owned())));
assert_eq!(InstanceType::from("other").as_str(), "other");
""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ internal class EndpointTraitBindingsTest {
"""
async fn test_endpoint_prefix() {
let conf = $moduleName::Config::builder().build();
$moduleName::input::SayHelloInput::builder()
$moduleName::operation::say_hello::SayHelloInput::builder()
.greeting("hey there!").build().expect("input is valid")
.make_operation(&conf).await.expect_err("no spaces or exclamation points in ep prefixes");
let op = $moduleName::input::SayHelloInput::builder()
let op = $moduleName::operation::say_hello::SayHelloInput::builder()
.greeting("hello")
.build().expect("valid operation")
.make_operation(&conf).await.expect("hello is a valid prefix");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal class PaginatorGeneratorTest {
clientIntegrationTest(model) { clientCodegenContext, rustCrate ->
rustCrate.integrationTest("paginators_generated") {
Attribute.AllowUnusedImports.render(this)
rust("use ${clientCodegenContext.moduleUseName()}::paginator::PaginatedListPaginator;")
rust("use ${clientCodegenContext.moduleUseName()}::operation::paginated_list::paginator::PaginatedListPaginator;")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ErrorGeneratorTest {
@Test
fun `generate error structure and builder`() {
clientIntegrationTest(model) { _, rustCrate ->
rustCrate.withFile("src/error.rs") {
rustCrate.withFile("src/types/error.rs") {
rust(
"""
##[test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class OperationErrorGeneratorTest {
unitTest(
name = "generates_combined_error_enums",
test = """
use crate::operation::greeting::GreetingError;
let error = GreetingError::InvalidGreeting(
InvalidGreeting::builder()
.message("an error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class ProtocolTestGeneratorTest {
private fun testService(
httpRequestBuilder: String,
body: String = "${correctBody.dq()}.to_string()",
correctResponse: String = """Ok(crate::output::SayHelloOutput::builder().value("hey there!").build())""",
correctResponse: String = """Ok(crate::operation::say_hello::SayHelloOutput::builder().value("hey there!").build())""",
): Path {
val codegenDecorator = object : ClientCodegenDecorator {
override val name: String = "mock"
Expand Down Expand Up @@ -256,7 +256,7 @@ class ProtocolTestGeneratorTest {
.header("X-Greeting", "Hi")
.method("POST")
""",
correctResponse = "Ok(crate::output::SayHelloOutput::builder().build())",
correctResponse = "Ok(crate::operation::say_hello::SayHelloOutput::builder().build())",
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class AwsQueryCompatibleTest {
}"##,
)
.unwrap();
let some_operation = $moduleName::operation::SomeOperation::new();
let some_operation = $moduleName::operation::some_operation::SomeOperation::new();
let error = some_operation
.parse(&response.map(bytes::Bytes::from))
.err()
Expand Down Expand Up @@ -136,7 +136,7 @@ class AwsQueryCompatibleTest {
}"##,
)
.unwrap();
let some_operation = $moduleName::operation::SomeOperation::new();
let some_operation = $moduleName::operation::some_operation::SomeOperation::new();
let error = some_operation
.parse(&response.map(bytes::Bytes::from))
.err()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class ClientEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
clientIntegrationTest(testCase.model) { _, rustCrate ->
clientIntegrationTest(testCase.model) { codegenContext, rustCrate ->
rustCrate.testModule {
writeMarshallTestCases(testCase, optionalBuilderInputs = false)
writeMarshallTestCases(codegenContext, testCase, optionalBuilderInputs = false)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ class ClientEventStreamUnmarshallerGeneratorTest {
clientIntegrationTest(
testCase.model,
IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true),
) { _, rustCrate ->
) { codegenContext, rustCrate ->
val generator = "crate::event_stream_serde::TestStreamUnmarshaller"

rustCrate.testModule {
rust("##![allow(unused_imports, dead_code)]")
writeUnmarshallTestCases(testCase, optionalBuilderInputs = false)
writeUnmarshallTestCases(codegenContext, testCase, optionalBuilderInputs = false)

unitTest(
"unknown_message",
Expand All @@ -52,7 +52,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap()) {
TestStreamError::Unhandled(err) => {
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
let message = format!("{}", crate::error::DisplayErrorContext(&err));
let expected = "message: \"unmodeled error\"";
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.lookup

object EventStreamMarshallTestCases {
fun RustWriter.writeMarshallTestCases(
codegenContext: CodegenContext,
testCase: EventStreamTestModels.TestCase,
optionalBuilderInputs: Boolean,
) {
Expand All @@ -29,12 +32,13 @@ object EventStreamMarshallTestCases {
vararg ctx: Pair<String, Any>,
): Writable = conditionalBuilderInput(input, conditional = optionalBuilderInputs, ctx = ctx)

val typesModule = codegenContext.symbolProvider.moduleForShape(codegenContext.model.lookup("test#TestStruct"))
rustTemplate(
"""
use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage};
use std::collections::HashMap;
use aws_smithy_types::{Blob, DateTime};
use crate::model::*;
use ${typesModule.fullyQualifiedPath()}::*;
use #{validate_body};
use #{MediaType};
Expand Down
Loading

0 comments on commit f9c74d1

Please sign in to comment.