diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/CodegenVisitor.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/CodegenVisitor.java deleted file mode 100644 index d66695a2..00000000 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/CodegenVisitor.java +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.smithy.python.codegen; - -import static java.lang.String.format; - -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.ServiceLoader; -import java.util.Set; -import java.util.logging.Logger; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import software.amazon.smithy.build.FileManifest; -import software.amazon.smithy.build.PluginContext; -import software.amazon.smithy.codegen.core.CodegenException; -import software.amazon.smithy.codegen.core.SmithyIntegration; -import software.amazon.smithy.codegen.core.Symbol; -import software.amazon.smithy.codegen.core.SymbolProvider; -import software.amazon.smithy.codegen.core.TopologicalIndex; -import software.amazon.smithy.model.Model; -import software.amazon.smithy.model.neighbor.Walker; -import software.amazon.smithy.model.shapes.CollectionShape; -import software.amazon.smithy.model.shapes.ListShape; -import software.amazon.smithy.model.shapes.MapShape; -import software.amazon.smithy.model.shapes.ServiceShape; -import software.amazon.smithy.model.shapes.SetShape; -import software.amazon.smithy.model.shapes.Shape; -import software.amazon.smithy.model.shapes.ShapeId; -import software.amazon.smithy.model.shapes.ShapeVisitor; -import software.amazon.smithy.model.shapes.StringShape; -import software.amazon.smithy.model.shapes.StructureShape; -import software.amazon.smithy.model.shapes.UnionShape; -import software.amazon.smithy.model.traits.EnumTrait; -import software.amazon.smithy.model.transform.ModelTransformer; -import software.amazon.smithy.python.codegen.integration.ProtocolGenerator; -import software.amazon.smithy.python.codegen.integration.PythonIntegration; -import software.amazon.smithy.utils.CodeInterceptor; -import software.amazon.smithy.utils.CodeSection; - -/** - * Orchestrates Python client generation. - */ -final class CodegenVisitor extends ShapeVisitor.Default { - - private static final Logger LOGGER = Logger.getLogger(CodegenVisitor.class.getName()); - - private final PythonSettings settings; - private final Model model; - private final Model modelWithoutTraitShapes; - private final ServiceShape service; - private final FileManifest fileManifest; - private final SymbolProvider symbolProvider; - private final PythonDelegator writers; - private Set recursiveShapes; - private final List integrations; - private final GenerationContext generationContext; - private final ProtocolGenerator protocolGenerator; - private final ApplicationProtocol applicationProtocol; - - CodegenVisitor(PluginContext context) { - // Load all integrations. - ClassLoader loader = context.getPluginClassLoader().orElse(getClass().getClassLoader()); - LOGGER.info("Attempting to discover PythonIntegrations from the classpath..."); - List loadedIntegrations = new ArrayList<>(); - ServiceLoader.load(PythonIntegration.class, loader) - .forEach(integration -> { - LOGGER.info(() -> "Adding PythonIntegration: " + integration.getClass().getName()); - loadedIntegrations.add(integration); - }); - integrations = Collections.unmodifiableList(SmithyIntegration.sort(loadedIntegrations)); - - // Allow integrations to modify the model before generation - PythonSettings pythonSettings = PythonSettings.from(context.getSettings()); - ModelTransformer transformer = ModelTransformer.create(); - Model modifiedModel = transformer.createDedicatedInputAndOutput(context.getModel(), "Input", "Output"); - for (PythonIntegration integration : integrations) { - modifiedModel = integration.preprocessModel(modifiedModel, pythonSettings); - } - - settings = pythonSettings; - model = modifiedModel; - modelWithoutTraitShapes = transformer.getModelWithoutTraitShapes(model); - service = settings.getService(model); - fileManifest = context.getFileManifest(); - LOGGER.info(() -> "Generating Python client for service " + service.getId()); - - // Decorate the symbol provider using integrations. - SymbolProvider resolvedProvider = PythonCodegenPlugin.createSymbolProvider(model, settings); - for (PythonIntegration integration : integrations) { - resolvedProvider = integration.decorateSymbolProvider(model, settings, resolvedProvider); - } - symbolProvider = SymbolProvider.cache(resolvedProvider); - - // Resolve the nullable protocol generator and application protocol. - protocolGenerator = resolveProtocolGenerator(integrations, service, settings); - applicationProtocol = protocolGenerator == null - ? ApplicationProtocol.createDefaultHttpApplicationProtocol() - : protocolGenerator.getApplicationProtocol(); - - // Finalize the generation context - generationContext = GenerationContext.builder() - .model(model) - .settings(settings) - .symbolProvider(symbolProvider) - .fileManifest(fileManifest) - .build(); - - // Gather all registered interceptors from integrations - List> interceptors = new ArrayList<>(); - for (PythonIntegration integration : integrations) { - interceptors.addAll(integration.interceptors(generationContext)); - } - - writers = new PythonDelegator(fileManifest, symbolProvider, settings); - writers.setInterceptors(interceptors); - } - - private ProtocolGenerator resolveProtocolGenerator( - List integrations, - ServiceShape service, - PythonSettings settings - ) { - // Collect all the supported protocol generators. - Map generators = new HashMap<>(); - for (PythonIntegration integration : integrations) { - for (ProtocolGenerator generator : integration.getProtocolGenerators()) { - generators.put(generator.getProtocol(), generator); - } - } - - ShapeId protocolName; - try { - protocolName = settings.resolveServiceProtocol(model, service, generators.keySet()); - } catch (CodegenException e) { - LOGGER.warning("Unable to find a protocol generator for " + service.getId() + ": " + e.getMessage()); - protocolName = null; - } - - return protocolName != null ? generators.get(protocolName) : null; - } - - void execute() { - // Generate models that are connected to the service being generated. - LOGGER.fine("Walking shapes from " + service.getId() + " to find shapes to generate"); - Collection shapeSet = new Walker(modelWithoutTraitShapes).walkShapes(service); - Model prunedModel = Model.builder().addShapes(shapeSet).build(); - - generateDefaultTimestamp(prunedModel); - generateServiceErrors(); - - // Sort shapes in a reverse topological order so that we can reduce the - // number of necessary forward references. - var topologicalIndex = TopologicalIndex.of(prunedModel); - recursiveShapes = topologicalIndex.getRecursiveShapes(); - for (Shape shape : topologicalIndex.getOrderedShapes()) { - shape.accept(this); - } - for (Shape shape : topologicalIndex.getRecursiveShapes()) { - shape.accept(this); - } - - SetupGenerator.generateSetup(settings, writers); - - LOGGER.fine("Flushing python writers"); - writers.flushWriters(); - generateInits(); - - // Allows integrations to interact with the generated output files - // in the file manifest. - for (PythonIntegration integration : integrations) { - integration.customize(generationContext); - } - - postProcess(); - } - - private void generateServiceErrors() { - var serviceError = CodegenUtils.getServiceError(settings); - writers.useFileWriter(serviceError.getDefinitionFile(), serviceError.getNamespace(), writer -> { - // TODO: subclass a shared error - writer.openBlock("class $L(Exception):", "", serviceError.getName(), () -> { - writer.writeDocs("Base error for all errors in the service."); - writer.write("pass"); - }); - }); - - var apiError = CodegenUtils.getApiError(settings); - writers.useFileWriter(apiError.getDefinitionFile(), apiError.getNamespace(), writer -> { - writer.addStdlibImport("typing", "Generic"); - writer.addStdlibImport("typing", "TypeVar"); - writer.write("T = TypeVar('T')"); - writer.openBlock("class $L($T, Generic[T]):", "", apiError.getName(), serviceError, () -> { - writer.writeDocs("Base error for all api errors in the service."); - writer.write("code: T"); - writer.openBlock("def __init__(self, message: str):", "", () -> { - writer.write("super().__init__(message)"); - writer.write("self.message = message"); - }); - }); - - var unknownApiError = CodegenUtils.getUnknownApiError(settings); - writer.addStdlibImport("typing", "Literal"); - writer.openBlock("class $L($T[Literal['Unknown']]):", "", unknownApiError.getName(), apiError, () -> { - writer.writeDocs("Error representing any unknown api errors"); - writer.write("code: Literal['Unknown'] = 'Unknown'"); - }); - }); - - - } - - /** - * Creates __init__.py files where not already present. - */ - private void generateInits() { - var directories = fileManifest.getFiles().stream() - .filter(path -> !path.getParent().equals(fileManifest.getBaseDir())) - .collect(Collectors.groupingBy(Path::getParent, Collectors.toSet())); - for (var entry : directories.entrySet()) { - var initPath = entry.getKey().resolve("__init__.py"); - if (!entry.getValue().contains(initPath)) { - fileManifest.writeFile(initPath, "# Code generated by smithy-python-codegen DO NOT EDIT.\n"); - } - } - } - - private void postProcess() { - Pattern versionPattern = Pattern.compile("Python \\d\\.(?\\d+)\\.(?\\d+)"); - - String output; - try { - LOGGER.info("Attempting to discover python version"); - output = CodegenUtils.runCommand("python3 --version", fileManifest.getBaseDir()).strip(); - } catch (CodegenException e) { - LOGGER.warning("Unable to find python on the path. Skipping formatting and type checking."); - return; - } - var matcher = versionPattern.matcher(output); - if (!matcher.find()) { - LOGGER.warning("Unable to parse python version string. Skipping formatting and type checking."); - } - int minorVersion = Integer.parseInt(matcher.group("minor")); - if (minorVersion < 9) { - LOGGER.warning(format(""" - Found incompatible python version 3.%s.%s, expected 3.9.0 or greater. \ - Skipping formatting and type checking.""", - matcher.group("minor"), matcher.group("patch"))); - return; - } - LOGGER.info("Verifying python files"); - for (var file : fileManifest.getFiles()) { - var fileName = file.getFileName(); - if (fileName == null || !fileName.endsWith(".py")) { - continue; - } - CodegenUtils.runCommand("python3 " + file, fileManifest.getBaseDir()); - } - formatCode(); - runMypy(); - } - - private void formatCode() { - try { - CodegenUtils.runCommand("python3 -m black -h", fileManifest.getBaseDir()); - } catch (CodegenException e) { - LOGGER.warning("Unable to find the python package black. Skipping formatting."); - return; - } - LOGGER.info("Running code formatter on generated code"); - CodegenUtils.runCommand("python3 -m black . --exclude \"\"", fileManifest.getBaseDir()); - } - - private void runMypy() { - try { - CodegenUtils.runCommand("python3 -m mypy -h", fileManifest.getBaseDir()); - } catch (CodegenException e) { - LOGGER.warning("Unable to find the python package mypy. Skipping type checking."); - return; - } - LOGGER.info("Running mypy on generated code"); - CodegenUtils.runCommand("python3 -m mypy .", fileManifest.getBaseDir()); - } - - private void generateDefaultTimestamp(Model model) { - var timestamp = CodegenUtils.getDefaultTimestamp(settings); - if (!model.getTimestampShapes().isEmpty()) { - writers.useFileWriter(timestamp.getDefinitionFile(), timestamp.getNamespace(), writer -> { - writer.addStdlibImport("datetime", "datetime"); - writer.write("$L = datetime(1970, 1, 1)", timestamp.getName()); - }); - } - } - - @Override - protected Void getDefault(Shape shape) { - return null; - } - - @Override - public Void stringShape(StringShape shape) { - if (shape.hasTrait(EnumTrait.class)) { - writers.useShapeWriter(shape, writer -> new EnumGenerator(model, symbolProvider, writer, shape).run()); - } - return null; - } - - @Override - public Void structureShape(StructureShape shape) { - writers.useShapeWriter(shape, writer -> new StructureGenerator( - model, settings, symbolProvider, writer, shape, recursiveShapes).run()); - return null; - } - - @Override - public Void unionShape(UnionShape shape) { - writers.useShapeWriter(shape, writer -> new UnionGenerator( - model, symbolProvider, writer, shape, recursiveShapes).run()); - return null; - } - - @Override - public Void listShape(ListShape shape) { - return collectionShape(shape); - } - - @Override - public Void setShape(SetShape shape) { - return collectionShape(shape); - } - - private Void collectionShape(CollectionShape shape) { - var optionalAsDictSymbol = symbolProvider.toSymbol(shape).getProperty("asDict", Symbol.class); - optionalAsDictSymbol.ifPresent(asDictSymbol -> { - writers.useFileWriter(asDictSymbol.getDefinitionFile(), asDictSymbol.getNamespace(), writer -> { - new CollectionGenerator(model, symbolProvider, writer, shape).run(); - }); - }); - return null; - } - - @Override - public Void mapShape(MapShape shape) { - var optionalAsDictSymbol = symbolProvider.toSymbol(shape).getProperty("asDict", Symbol.class); - optionalAsDictSymbol.ifPresent(asDictSymbol -> { - writers.useFileWriter(asDictSymbol.getDefinitionFile(), asDictSymbol.getNamespace(), writer -> { - new MapGenerator(model, symbolProvider, writer, shape).run(); - }); - }); - return null; - } -} diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonCodegen.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonCodegen.java new file mode 100644 index 00000000..06fca539 --- /dev/null +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonCodegen.java @@ -0,0 +1,283 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.python.codegen; + +import static java.lang.String.format; + +import java.nio.file.Path; +import java.util.logging.Logger; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import software.amazon.smithy.build.FileManifest; +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.codegen.core.TopologicalIndex; +import software.amazon.smithy.codegen.core.WriterDelegator; +import software.amazon.smithy.codegen.core.directed.CreateContextDirective; +import software.amazon.smithy.codegen.core.directed.CreateSymbolProviderDirective; +import software.amazon.smithy.codegen.core.directed.CustomizeDirective; +import software.amazon.smithy.codegen.core.directed.DirectedCodegen; +import software.amazon.smithy.codegen.core.directed.GenerateEnumDirective; +import software.amazon.smithy.codegen.core.directed.GenerateErrorDirective; +import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective; +import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective; +import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.ListShape; +import software.amazon.smithy.model.shapes.MapShape; +import software.amazon.smithy.model.shapes.SetShape; +import software.amazon.smithy.utils.SmithyUnstableApi; + +@SmithyUnstableApi +final class DirectedPythonCodegen implements DirectedCodegen { + + private static final Logger LOGGER = Logger.getLogger(DirectedPythonCodegen.class.getName()); + + @Override + public SymbolProvider createSymbolProvider(CreateSymbolProviderDirective directive) { + return new SymbolVisitor(directive.model(), directive.settings()); + } + + @Override + public GenerationContext createContext(CreateContextDirective directive) { + return GenerationContext.builder() + .model(directive.model()) + .settings(directive.settings()) + .symbolProvider(directive.symbolProvider()) + .fileManifest(directive.fileManifest()) + .writerDelegator(new PythonDelegator( + directive.fileManifest(), directive.symbolProvider(), directive.settings())) + .build(); + } + + @Override + public void generateService(GenerateServiceDirective directive) { + generateDefaultTimestamp(directive.model(), directive.settings(), directive.context().writerDelegator()); + generateServiceErrors(directive.settings(), directive.context().writerDelegator()); + } + + private void generateDefaultTimestamp(Model model, PythonSettings settings, WriterDelegator writers) { + var timestamp = CodegenUtils.getDefaultTimestamp(settings); + if (!model.getTimestampShapes().isEmpty()) { + writers.useFileWriter(timestamp.getDefinitionFile(), timestamp.getNamespace(), writer -> { + writer.addStdlibImport("datetime", "datetime"); + writer.write("$L = datetime(1970, 1, 1)", timestamp.getName()); + }); + } + } + + private void generateServiceErrors(PythonSettings settings, WriterDelegator writers) { + var serviceError = CodegenUtils.getServiceError(settings); + writers.useFileWriter(serviceError.getDefinitionFile(), serviceError.getNamespace(), writer -> { + // TODO: subclass a shared error + writer.openBlock("class $L(Exception):", "", serviceError.getName(), () -> { + writer.writeDocs("Base error for all errors in the service."); + writer.write("pass"); + }); + }); + + var apiError = CodegenUtils.getApiError(settings); + writers.useFileWriter(apiError.getDefinitionFile(), apiError.getNamespace(), writer -> { + writer.addStdlibImport("typing", "Generic"); + writer.addStdlibImport("typing", "TypeVar"); + writer.write("T = TypeVar('T')"); + writer.openBlock("class $L($T, Generic[T]):", "", apiError.getName(), serviceError, () -> { + writer.writeDocs("Base error for all api errors in the service."); + writer.write("code: T"); + writer.openBlock("def __init__(self, message: str):", "", () -> { + writer.write("super().__init__(message)"); + writer.write("self.message = message"); + }); + }); + + var unknownApiError = CodegenUtils.getUnknownApiError(settings); + writer.addStdlibImport("typing", "Literal"); + writer.openBlock("class $L($T[Literal['Unknown']]):", "", unknownApiError.getName(), apiError, () -> { + writer.writeDocs("Error representing any unknown api errors"); + writer.write("code: Literal['Unknown'] = 'Unknown'"); + }); + }); + } + + @Override + public void generateStructure(GenerateStructureDirective directive) { + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + StructureGenerator generator = new StructureGenerator( + directive.model(), + directive.settings(), + directive.symbolProvider(), + writer, + directive.shape(), + TopologicalIndex.of(directive.model()).getRecursiveShapes() + ); + generator.run(); + }); + } + + @Override + public void generateError(GenerateErrorDirective directive) { + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + StructureGenerator generator = new StructureGenerator( + directive.model(), + directive.settings(), + directive.symbolProvider(), + writer, + directive.shape(), + TopologicalIndex.of(directive.model()).getRecursiveShapes() + ); + generator.run(); + }); + } + + @Override + public void generateUnion(GenerateUnionDirective directive) { + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + UnionGenerator generator = new UnionGenerator( + directive.model(), + directive.symbolProvider(), + writer, + directive.shape(), + TopologicalIndex.of(directive.model()).getRecursiveShapes() + ); + generator.run(); + }); + } + + @Override + public void generateEnumShape(GenerateEnumDirective directive) { + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + EnumGenerator generator = new EnumGenerator( + directive.model(), + directive.symbolProvider(), + writer, + directive.shape().asStringShape().get() + ); + generator.run(); + }); + } + + @Override + public void customizeBeforeIntegrations(CustomizeDirective directive) { + generateDictHelpers(directive.context()); + generateInits(directive.fileManifest()); + } + + private void generateDictHelpers(GenerationContext context) { + context.model().shapes(ListShape.class).forEach(shape -> generateCollectionDictHelpers(context, shape)); + context.model().shapes(SetShape.class).forEach(shape -> generateCollectionDictHelpers(context, shape)); + context.model().shapes(MapShape.class).forEach(shape -> generateMapDictHelpers(context, shape)); + } + + private Void generateCollectionDictHelpers(GenerationContext context, CollectionShape shape) { + SymbolProvider symbolProvider = context.symbolProvider(); + WriterDelegator writers = context.writerDelegator(); + var optionalAsDictSymbol = symbolProvider.toSymbol(shape).getProperty("asDict", Symbol.class); + optionalAsDictSymbol.ifPresent(asDictSymbol -> { + writers.useFileWriter(asDictSymbol.getDefinitionFile(), asDictSymbol.getNamespace(), writer -> { + new CollectionGenerator(context.model(), symbolProvider, writer, shape).run(); + }); + }); + return null; + } + + public Void generateMapDictHelpers(GenerationContext context, MapShape shape) { + SymbolProvider symbolProvider = context.symbolProvider(); + WriterDelegator writers = context.writerDelegator(); + var optionalAsDictSymbol = symbolProvider.toSymbol(shape).getProperty("asDict", Symbol.class); + optionalAsDictSymbol.ifPresent(asDictSymbol -> { + writers.useFileWriter(asDictSymbol.getDefinitionFile(), asDictSymbol.getNamespace(), writer -> { + new MapGenerator(context.model(), symbolProvider, writer, shape).run(); + }); + }); + return null; + } + + /** + * Creates __init__.py files where not already present. + */ + private void generateInits(FileManifest fileManifest) { + var directories = fileManifest.getFiles().stream() + .filter(path -> !path.getParent().equals(fileManifest.getBaseDir())) + .collect(Collectors.groupingBy(Path::getParent, Collectors.toSet())); + for (var entry : directories.entrySet()) { + var initPath = entry.getKey().resolve("__init__.py"); + if (!entry.getValue().contains(initPath)) { + fileManifest.writeFile(initPath, "# Code generated by smithy-python-codegen DO NOT EDIT.\n"); + } + } + } + + @Override + public void customizeAfterIntegrations(CustomizeDirective directive) { + Pattern versionPattern = Pattern.compile("Python \\d\\.(?\\d+)\\.(?\\d+)"); + FileManifest fileManifest = directive.fileManifest(); + + String output; + try { + LOGGER.info("Attempting to discover python version"); + output = CodegenUtils.runCommand("python3 --version", fileManifest.getBaseDir()).strip(); + } catch (CodegenException e) { + LOGGER.warning("Unable to find python on the path. Skipping formatting and type checking."); + return; + } + var matcher = versionPattern.matcher(output); + if (!matcher.find()) { + LOGGER.warning("Unable to parse python version string. Skipping formatting and type checking."); + } + int minorVersion = Integer.parseInt(matcher.group("minor")); + if (minorVersion < 9) { + LOGGER.warning(format(""" + Found incompatible python version 3.%s.%s, expected 3.9.0 or greater. \ + Skipping formatting and type checking.""", + matcher.group("minor"), matcher.group("patch"))); + return; + } + LOGGER.info("Verifying python files"); + for (var file : fileManifest.getFiles()) { + var fileName = file.getFileName(); + if (fileName == null || !fileName.endsWith(".py")) { + continue; + } + CodegenUtils.runCommand("python3 " + file, fileManifest.getBaseDir()); + } + formatCode(fileManifest); + runMypy(fileManifest); + } + + private void formatCode(FileManifest fileManifest) { + try { + CodegenUtils.runCommand("python3 -m black -h", fileManifest.getBaseDir()); + } catch (CodegenException e) { + LOGGER.warning("Unable to find the python package black. Skipping formatting."); + return; + } + LOGGER.info("Running code formatter on generated code"); + CodegenUtils.runCommand("python3 -m black . --exclude \"\"", fileManifest.getBaseDir()); + } + + private void runMypy(FileManifest fileManifest) { + try { + CodegenUtils.runCommand("python3 -m mypy -h", fileManifest.getBaseDir()); + } catch (CodegenException e) { + LOGGER.warning("Unable to find the python package mypy. Skipping type checking."); + return; + } + LOGGER.info("Running mypy on generated code"); + CodegenUtils.runCommand("python3 -m mypy .", fileManifest.getBaseDir()); + } +} diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/GenerationContext.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/GenerationContext.java index 55474db2..027976a1 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/GenerationContext.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/GenerationContext.java @@ -66,7 +66,7 @@ public FileManifest fileManifest() { } @Override - public WriterDelegator writerDelegator() { + public WriterDelegator writerDelegator() { return delegator; } @@ -83,7 +83,8 @@ public SmithyBuilder toBuilder() { .model(model) .settings(settings) .symbolProvider(symbolProvider) - .fileManifest(fileManifest); + .fileManifest(fileManifest) + .writerDelegator(delegator); } /** diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/MapGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/MapGenerator.java index e3e3f7fa..cb827d88 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/MapGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/MapGenerator.java @@ -82,7 +82,7 @@ private void writeFromDict() { writer.addStdlibImport("typing", "Dict"); writer.openBlock("def $L(given: Dict[str, Any]) -> $T:", "", fromDictSymbol.getName(), symbol, () -> { if (target.isUnionShape() || target.isStructureShape()) { - writer.write("return {k: $T.from_dict(v) for k, v in given.items()}"); + writer.write("return {k: $T.from_dict(v) for k, v in given.items()}", targetSymbol); } else if (target.isMapShape() || target instanceof CollectionShape) { var targetFromDictSymbol = targetSymbol.expectProperty("fromDict", Symbol.class); writer.write("return {k: $T(v) for k, v in given.items()}", targetFromDictSymbol); diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/PythonCodegenPlugin.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/PythonCodegenPlugin.java index d020af55..45c504c2 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/PythonCodegenPlugin.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/PythonCodegenPlugin.java @@ -17,8 +17,8 @@ import software.amazon.smithy.build.PluginContext; import software.amazon.smithy.build.SmithyBuildPlugin; -import software.amazon.smithy.codegen.core.SymbolProvider; -import software.amazon.smithy.model.Model; +import software.amazon.smithy.codegen.core.directed.CodegenDirector; +import software.amazon.smithy.python.codegen.integration.PythonIntegration; /** * Plugin to trigger Python code generation. @@ -31,16 +31,18 @@ public String getName() { @Override public void execute(PluginContext context) { - new CodegenVisitor(context).execute(); - } + CodegenDirector runnner + = new CodegenDirector<>(); - /** - * Creates a Python symbol provider. - * @param model The model to generate symbols for. - * @param settings The settings for the plugin. - * @return Returns the created provider. - */ - public static SymbolProvider createSymbolProvider(Model model, PythonSettings settings) { - return new SymbolVisitor(model, settings); + PythonSettings settings = PythonSettings.from(context.getSettings()); + runnner.settings(settings); + runnner.directedCodegen(new DirectedPythonCodegen()); + runnner.fileManifest(context.getFileManifest()); + runnner.service(settings.getService()); + runnner.model(context.getModel()); + runnner.integrationClass(PythonIntegration.class); + runnner.performDefaultCodegenTransforms(); + runnner.createDedicatedInputsAndOutputs(); + runnner.run(); } } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SymbolVisitor.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SymbolVisitor.java index 92d3898d..bccb528f 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SymbolVisitor.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SymbolVisitor.java @@ -182,8 +182,7 @@ private Symbol createCollectionSymbol(CollectionShape shape) { .addReference(createStdlibReference("List", "typing")) .addReference(reference); - var target = model.expectShape(shape.getMember().getTarget()); - if (!(target instanceof SimpleShape)) { + if (needsDictHelpers(shape)) { builder.putProperty("asDict", createAsDictFunctionSymbol(shape)) .putProperty("fromDict", createFromDictFunctionSymbol(shape)); } @@ -197,14 +196,41 @@ public Symbol mapShape(MapShape shape) { .addReference(createStdlibReference("Dict", "typing")) .addReference(reference); - var target = model.expectShape(shape.getValue().getTarget()); - if (!(target instanceof SimpleShape)) { + if (needsDictHelpers(shape)) { builder.putProperty("asDict", createAsDictFunctionSymbol(shape)) .putProperty("fromDict", createFromDictFunctionSymbol(shape)); } return builder.build(); } + private boolean needsDictHelpers(MapShape shape) { + Shape target = model.expectShape(shape.getValue().getTarget()); + return targetRequiresDictHelpers(target); + } + + private boolean needsDictHelpers(CollectionShape shape) { + Shape target = model.expectShape(shape.getMember().getTarget()); + return targetRequiresDictHelpers(target); + } + + /** + * Maps and collections are already dict compatible, so if a given map or + * collection only ever transitively reference dict compatible shapes, + * they don't need these dict helpers. + */ + private boolean targetRequiresDictHelpers(Shape target) { + if (target instanceof SimpleShape) { + return false; + } + if (target instanceof CollectionShape) { + return needsDictHelpers((CollectionShape) target); + } + if (target.isMapShape()) { + return needsDictHelpers((MapShape) target); + } + return true; + } + private Symbol createAsDictFunctionSymbol(Shape shape) { return Symbol.builder() .name(String.format("_%s_as_dict", CaseUtils.toSnakeCase(shape.getId().getName())))