diff --git a/changelog.md b/changelog.md index 95a69107..50b7d806 100644 --- a/changelog.md +++ b/changelog.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- [Add compiler plugin validation to validate spread-field config initialization](https://github.com/ballerina-platform/ballerina-standard-library/issues/4594) ### Changed diff --git a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/oracledb/compiler/CompilerPluginTest.java b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/oracledb/compiler/CompilerPluginTest.java index 5f9c12fb..4377d8ab 100644 --- a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/oracledb/compiler/CompilerPluginTest.java +++ b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/oracledb/compiler/CompilerPluginTest.java @@ -146,4 +146,22 @@ public void testOptionsWithVariables() { Assert.assertEquals(availableErrors, 0); } + @Test + public void negativeTestConnectionPoolWithSpreadField() { + Package currentPackage = loadPackage("sample7"); + PackageCompilation compilation = currentPackage.getCompilation(); + DiagnosticResult diagnosticResult = compilation.diagnosticResult(); + List diagnosticErrorStream = diagnosticResult.diagnostics().stream() + .filter(r -> r.diagnosticInfo().severity().equals(DiagnosticSeverity.ERROR)) + .collect(Collectors.toList()); + long availableErrors = diagnosticErrorStream.size(); + Assert.assertEquals(availableErrors, 3); + Assert.assertEquals(diagnosticErrorStream.get(0).diagnosticInfo().messageFormat(), + "invalid value: expected value is greater than one"); + Assert.assertEquals(diagnosticErrorStream.get(1).diagnosticInfo().messageFormat(), + "invalid value: expected value is greater than or equal to 30"); + Assert.assertEquals(diagnosticErrorStream.get(2).diagnosticInfo().messageFormat(), + "invalid value: expected value is greater than zero"); + } + } diff --git a/compiler-plugin-tests/src/test/resources/diagnostics/sample7/Ballerina.toml b/compiler-plugin-tests/src/test/resources/diagnostics/sample7/Ballerina.toml new file mode 100644 index 00000000..e7f7c149 --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/diagnostics/sample7/Ballerina.toml @@ -0,0 +1,4 @@ +[package] +org = "oracledb_test" +name = "sample7" +version = "0.1.0" diff --git a/compiler-plugin-tests/src/test/resources/diagnostics/sample7/main.bal b/compiler-plugin-tests/src/test/resources/diagnostics/sample7/main.bal new file mode 100644 index 00000000..5ec97846 --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/diagnostics/sample7/main.bal @@ -0,0 +1,59 @@ +// Copyright (c) 2023, WSO2 LLC. (http://www.wso2.com). All Rights Reserved. +// +// This software is the property of WSO2 LLC. and its suppliers, if any. +// Dissemination of any information or reproduction of any material contained +// herein in any form is strictly forbidden, unless permitted by WSO2 expressly. +// You may not alter or remove any copyright or other notice from copies of this content. + +import ballerinax/oracledb; +import ballerinax/oracledb.driver as _; + +# sql:ConnectionPool parameter record with default optimized values +# +# + maxOpenConnections - The maximum open connections +# + maxConnectionLifeTime - The maximum lifetime of a connection +# + minIdleConnections - The minimum idle time of a connection +type SqlConnectionPoolConfig record {| + int maxOpenConnections = -10; + decimal maxConnectionLifeTime = -180; + int minIdleConnections = -5; +|}; + +# mysql:Options parameter record with default optimized values +# +# + connectTimeout - Timeout to be used when establishing a connection +type MysqlOptionsConfig record {| + decimal connectTimeout = 10; +|}; + +# [Configurable] Allocation MySQL Database +# +# + hostname - database hostname +# + username - database username +# + password - database password +# + database - database name +# + port - database port +# + connectionPool - sql:ConnectionPool configurations, type: SqlConnectionPoolConfig +# + mysqlOptions - mysql:Options configurations, type: MysqlOptionsConfig +type AllocationDatabase record {| + string hostname; + string username; + string password; + string database; + int port = 3306; + SqlConnectionPoolConfig connectionPool; + MysqlOptionsConfig mysqlOptions; +|}; + +configurable AllocationDatabase allocationDatabase = ?; + +final oracledb:Client allocationDbClient = check new ( + host = allocationDatabase.hostname, + user = allocationDatabase.username, + password = allocationDatabase.password, + port = allocationDatabase.port, + database = allocationDatabase.database, + connectionPool = { + ...allocationDatabase.connectionPool + } +); diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/Utils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/Utils.java index 1b516737..9e7b7cd3 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/Utils.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/Utils.java @@ -18,24 +18,38 @@ package io.ballerina.stdlib.oracledb.compiler; import io.ballerina.compiler.api.symbols.ModuleSymbol; +import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.SymbolKind; import io.ballerina.compiler.api.symbols.TypeDescKind; import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol; import io.ballerina.compiler.api.symbols.TypeSymbol; import io.ballerina.compiler.api.symbols.UnionTypeSymbol; import io.ballerina.compiler.syntax.tree.BasicLiteralNode; +import io.ballerina.compiler.syntax.tree.ChildNodeEntry; import io.ballerina.compiler.syntax.tree.ExpressionNode; import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; import io.ballerina.compiler.syntax.tree.MappingFieldNode; +import io.ballerina.compiler.syntax.tree.ModulePartNode; import io.ballerina.compiler.syntax.tree.Node; -import io.ballerina.compiler.syntax.tree.SeparatedNodeList; +import io.ballerina.compiler.syntax.tree.NodeList; +import io.ballerina.compiler.syntax.tree.NonTerminalNode; +import io.ballerina.compiler.syntax.tree.RecordFieldNode; +import io.ballerina.compiler.syntax.tree.RecordFieldWithDefaultValueNode; +import io.ballerina.compiler.syntax.tree.RecordTypeDescriptorNode; +import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode; import io.ballerina.compiler.syntax.tree.SpecificFieldNode; +import io.ballerina.compiler.syntax.tree.SpreadFieldNode; +import io.ballerina.compiler.syntax.tree.TypeDefinitionNode; +import io.ballerina.compiler.syntax.tree.TypedBindingPatternNode; import io.ballerina.compiler.syntax.tree.UnaryExpressionNode; import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; import io.ballerina.tools.diagnostics.Diagnostic; import io.ballerina.tools.diagnostics.DiagnosticFactory; import io.ballerina.tools.diagnostics.DiagnosticInfo; import io.ballerina.tools.diagnostics.DiagnosticSeverity; +import io.ballerina.tools.diagnostics.Location; +import java.util.List; import java.util.Optional; import static io.ballerina.stdlib.oracledb.compiler.Constants.UNNECESSARY_CHARS_REGEX; @@ -98,31 +112,45 @@ public static boolean isOracleDBObject(TypeReferenceTypeSymbol typeReference, St } } - public static void validateOptions(SyntaxNodeAnalysisContext ctx, MappingConstructorExpressionNode options) { - SeparatedNodeList fields = options.fields(); - for (MappingFieldNode field : fields) { - String name = ((SpecificFieldNode) field).fieldName().toString() - .trim().replaceAll(UNNECESSARY_CHARS_REGEX, ""); - ExpressionNode valueNode = ((SpecificFieldNode) field).valueExpr().get(); - switch (name) { - case Constants.Options.CONNECT_TIMEOUT: - case Constants.Options.LOGIN_TIMEOUT: - case Constants.Options.SOCKET_TIMEOUT: - float timeoutVal = Float.parseFloat(getTerminalNodeValue(valueNode, "0")); - if (timeoutVal < 0) { - DiagnosticInfo diagnosticInfo = new DiagnosticInfo(ORACLEDB_101.getCode(), - ORACLEDB_101.getMessage(), ORACLEDB_101.getSeverity()); - ctx.reportDiagnostic( - DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); + public static void validateOptionConfig(SyntaxNodeAnalysisContext ctx, MappingConstructorExpressionNode options) { + for (MappingFieldNode field: options.fields()) { + if (field instanceof SpecificFieldNode) { + SpecificFieldNode specificFieldNode = ((SpecificFieldNode) field); + validateOptions(ctx, specificFieldNode.fieldName().toString().trim(). + replaceAll(UNNECESSARY_CHARS_REGEX, ""), specificFieldNode.valueExpr().get()); + } else if (field instanceof SpreadFieldNode) { + NodeList recordFields = Utils.getSpreadFieldType(ctx, ((SpreadFieldNode) field)); + for (Node recordField : recordFields) { + if (recordField instanceof RecordFieldWithDefaultValueNode) { + RecordFieldWithDefaultValueNode fieldWithDefaultValueNode = + (RecordFieldWithDefaultValueNode) recordField; + validateOptions(ctx, fieldWithDefaultValueNode.fieldName().toString(). + trim().replaceAll(UNNECESSARY_CHARS_REGEX, ""), + fieldWithDefaultValueNode.expression()); } - break; - default: - // Can ignore all the other fields - continue; + } } } } + public static void validateOptions(SyntaxNodeAnalysisContext ctx, String name, ExpressionNode valueNode) { + switch (name) { + case Constants.Options.CONNECT_TIMEOUT: + case Constants.Options.LOGIN_TIMEOUT: + case Constants.Options.SOCKET_TIMEOUT: + float timeoutVal = Float.parseFloat(getTerminalNodeValue(valueNode, "0")); + if (timeoutVal < 0) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo(ORACLEDB_101.getCode(), + ORACLEDB_101.getMessage(), ORACLEDB_101.getSeverity()); + ctx.reportDiagnostic( + DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); + } + break; + default: + // Can ignore all the other fields + } + } + public static String getTerminalNodeValue(Node valueNode, String defaultValue) { String value = defaultValue; if (valueNode instanceof BasicLiteralNode) { @@ -157,4 +185,117 @@ public static DiagnosticInfo addDiagnosticsForInvalidTypes(String objectName, Ty return null; } } + + public static NodeList getSpreadFieldType(SyntaxNodeAnalysisContext ctx, SpreadFieldNode spreadFieldNode) { + List symbols = ctx.semanticModel().moduleSymbols(); + Object[] entries = spreadFieldNode.valueExpr().childEntries().toArray(); + ModulePartNode modulePartNode = ctx.syntaxTree().rootNode(); + ChildNodeEntry type = Utils.getVariableType(symbols, entries, modulePartNode); + RecordTypeDescriptorNode typeDescriptor = Utils.getFirstSpreadFieldRecordTypeDescriptorNode(symbols, + type, modulePartNode); + typeDescriptor = Utils.getEndSpreadFieldRecordType(symbols, entries, modulePartNode, + typeDescriptor); + return typeDescriptor.fields(); + } + + public static ChildNodeEntry getVariableType(List symbols, Object[] entries, + ModulePartNode modulePartNode) { + for (Symbol symbol : symbols) { + if (!symbol.kind().equals(SymbolKind.VARIABLE)) { + continue; + } + Optional symbolName = symbol.getName(); + Optional childNodeEntry = ((ChildNodeEntry) entries[0]).node(); + if (symbolName.isPresent() && childNodeEntry.isPresent() && + symbolName.get().equals(childNodeEntry.get().toString())) { + Optional location = symbol.getLocation(); + if (location.isPresent()) { + Location loc = location.get(); + NonTerminalNode node = modulePartNode.findNode(loc.textRange()); + if (node instanceof TypedBindingPatternNode) { + TypedBindingPatternNode typedBindingPatternNode = (TypedBindingPatternNode) node; + return (ChildNodeEntry) typedBindingPatternNode.childEntries().toArray()[0]; + } + } + } + } + return null; + } + + public static RecordTypeDescriptorNode getFirstSpreadFieldRecordTypeDescriptorNode(List symbols, + ChildNodeEntry type, + ModulePartNode modulePartNode) { + if (type != null && type.node().isPresent()) { + for (Symbol symbol : symbols) { + if (!symbol.kind().equals(SymbolKind.TYPE_DEFINITION)) { + continue; + } + if (symbol.getName().isPresent() && + symbol.getName().get().equals(type.node().get().toString().trim())) { + Optional loc = symbol.getLocation(); + if (loc.isPresent()) { + Location location = loc.get(); + Node node = modulePartNode.findNode(location.textRange()); + if (node instanceof TypeDefinitionNode) { + TypeDefinitionNode typeDefinitionNode = (TypeDefinitionNode) node; + return (RecordTypeDescriptorNode) typeDefinitionNode.typeDescriptor(); + } + } + } + } + } + return null; + } + + public static RecordTypeDescriptorNode getEndSpreadFieldRecordType(List symbols, Object[] entries, + ModulePartNode modulePartNode, + RecordTypeDescriptorNode typeDescriptor) { + if (typeDescriptor != null) { + for (int i = 1; i < entries.length; i++) { + String childNodeEntry = ((ChildNodeEntry) entries[i]).node().get().toString(); + NodeList recordFields = typeDescriptor.fields(); + if (childNodeEntry.equals(".")) { + continue; + } + for (Node recordField : recordFields) { + String fieldName; + Node fieldType; + if (recordField instanceof RecordFieldWithDefaultValueNode) { + RecordFieldWithDefaultValueNode fieldWithDefaultValueNode = + (RecordFieldWithDefaultValueNode) recordField; + fieldName = fieldWithDefaultValueNode.fieldName().text().trim(); + fieldType = fieldWithDefaultValueNode.typeName(); + } else { + RecordFieldNode fieldNode = (RecordFieldNode) recordField; + fieldName = fieldNode.fieldName().text().trim(); + fieldType = fieldNode.typeName(); + } + if (fieldName.equals(childNodeEntry.trim())) { + if (fieldType instanceof SimpleNameReferenceNode) { + SimpleNameReferenceNode nameReferenceNode = (SimpleNameReferenceNode) fieldType; + for (Symbol symbol : symbols) { + if (!symbol.kind().equals(SymbolKind.TYPE_DEFINITION)) { + continue; + } + if (symbol.getName().isPresent() && + symbol.getName().get().equals(nameReferenceNode.name().text().trim())) { + Optional loc = symbol.getLocation(); + if (loc.isPresent()) { + Location location = loc.get(); + Node node = modulePartNode.findNode(location.textRange()); + if (node instanceof TypeDefinitionNode) { + TypeDefinitionNode typeDefinitionNode = (TypeDefinitionNode) node; + typeDescriptor = (RecordTypeDescriptorNode) typeDefinitionNode. + typeDescriptor(); + } + } + } + } + } + } + } + } + } + return typeDescriptor; + } } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/InitializerParamAnalyzer.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/InitializerParamAnalyzer.java index 56983e55..b85093be 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/InitializerParamAnalyzer.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/InitializerParamAnalyzer.java @@ -24,9 +24,13 @@ import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; import io.ballerina.compiler.syntax.tree.MappingFieldNode; import io.ballerina.compiler.syntax.tree.NamedArgumentNode; +import io.ballerina.compiler.syntax.tree.Node; +import io.ballerina.compiler.syntax.tree.NodeList; import io.ballerina.compiler.syntax.tree.PositionalArgumentNode; +import io.ballerina.compiler.syntax.tree.RecordFieldWithDefaultValueNode; import io.ballerina.compiler.syntax.tree.SeparatedNodeList; import io.ballerina.compiler.syntax.tree.SpecificFieldNode; +import io.ballerina.compiler.syntax.tree.SpreadFieldNode; import io.ballerina.projects.plugins.AnalysisTask; import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; import io.ballerina.stdlib.oracledb.compiler.Constants; @@ -44,7 +48,6 @@ import static io.ballerina.stdlib.oracledb.compiler.OracleDBDiagnosticsCode.SQL_102; import static io.ballerina.stdlib.oracledb.compiler.OracleDBDiagnosticsCode.SQL_103; import static io.ballerina.stdlib.oracledb.compiler.Utils.getTerminalNodeValue; -import static io.ballerina.stdlib.oracledb.compiler.Utils.validateOptions; /** * Validate fields of sql:Connection pool fields. @@ -95,55 +98,66 @@ public void perform(SyntaxNodeAnalysisContext ctx) { } if (options instanceof MappingConstructorExpressionNode) { - validateOptions(ctx, (MappingConstructorExpressionNode) options); + Utils.validateOptionConfig(ctx, (MappingConstructorExpressionNode) options); } if (connectionPool instanceof MappingConstructorExpressionNode) { - validateConnectionPool(ctx, (MappingConstructorExpressionNode) connectionPool); + SeparatedNodeList fields = ((MappingConstructorExpressionNode) connectionPool).fields(); + for (MappingFieldNode field: fields) { + if (field instanceof SpecificFieldNode) { + SpecificFieldNode specificFieldNode = ((SpecificFieldNode) field); + validateConnectionPool(ctx, specificFieldNode.fieldName().toString().trim(). + replaceAll(UNNECESSARY_CHARS_REGEX, ""), specificFieldNode.valueExpr().get()); + } else if (field instanceof SpreadFieldNode) { + NodeList recordFields = Utils.getSpreadFieldType(ctx, (SpreadFieldNode) field); + for (Node recordField : recordFields) { + if (recordField instanceof RecordFieldWithDefaultValueNode) { + RecordFieldWithDefaultValueNode fieldWithDefaultValueNode = + (RecordFieldWithDefaultValueNode) recordField; + validateConnectionPool(ctx, fieldWithDefaultValueNode.fieldName().toString(). + trim().replaceAll(UNNECESSARY_CHARS_REGEX, ""), + fieldWithDefaultValueNode.expression()); + } + } + } + } } } - private void validateConnectionPool(SyntaxNodeAnalysisContext ctx, MappingConstructorExpressionNode pool) { - SeparatedNodeList fields = pool.fields(); - for (MappingFieldNode field : fields) { - String name = ((SpecificFieldNode) field).fieldName().toString() - .trim().replaceAll(UNNECESSARY_CHARS_REGEX, ""); - ExpressionNode valueNode = ((SpecificFieldNode) field).valueExpr().get(); - switch (name) { - case Constants.ConnectionPool.MAX_OPEN_CONNECTIONS: - int maxOpenConnections = Integer.parseInt(getTerminalNodeValue(valueNode, "1")); - if (maxOpenConnections < 1) { - DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_101.getCode(), SQL_101.getMessage(), - SQL_101.getSeverity()); + private void validateConnectionPool(SyntaxNodeAnalysisContext ctx, String name, ExpressionNode valueNode) { + switch (name) { + case Constants.ConnectionPool.MAX_OPEN_CONNECTIONS: + int maxOpenConnections = Integer.parseInt(getTerminalNodeValue(valueNode, "1")); + if (maxOpenConnections < 1) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_101.getCode(), SQL_101.getMessage(), + SQL_101.getSeverity()); - ctx.reportDiagnostic( - DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); + ctx.reportDiagnostic( + DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); - } - break; - case Constants.ConnectionPool.MIN_IDLE_CONNECTIONS: - int minIdleConnection = Integer.parseInt(getTerminalNodeValue(valueNode, "0")); - if (minIdleConnection < 0) { - DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_102.getCode(), SQL_102.getMessage(), - SQL_102.getSeverity()); - ctx.reportDiagnostic( - DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); + } + break; + case Constants.ConnectionPool.MIN_IDLE_CONNECTIONS: + int minIdleConnection = Integer.parseInt(getTerminalNodeValue(valueNode, "0")); + if (minIdleConnection < 0) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_102.getCode(), SQL_102.getMessage(), + SQL_102.getSeverity()); + ctx.reportDiagnostic( + DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); - } - break; - case Constants.ConnectionPool.MAX_CONNECTION_LIFE_TIME: - float maxConnectionTime = Float.parseFloat(getTerminalNodeValue(valueNode, "30")); - if (maxConnectionTime < 30) { - DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_103.getCode(), SQL_103.getMessage(), - SQL_103.getSeverity()); - ctx.reportDiagnostic( - DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); + } + break; + case Constants.ConnectionPool.MAX_CONNECTION_LIFE_TIME: + float maxConnectionTime = Float.parseFloat(getTerminalNodeValue(valueNode, "30")); + if (maxConnectionTime < 30) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo(SQL_103.getCode(), SQL_103.getMessage(), + SQL_103.getSeverity()); + ctx.reportDiagnostic( + DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location())); - } - break; - default: - // Can ignore all the other fields - continue; - } + } + break; + default: + // Can ignore all the other fields } } } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/RecordAnalyzer.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/RecordAnalyzer.java index 27958b99..4c855f35 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/RecordAnalyzer.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/oracledb/compiler/analyzer/RecordAnalyzer.java @@ -37,7 +37,7 @@ import static io.ballerina.stdlib.oracledb.compiler.Constants.BALLERINAX; import static io.ballerina.stdlib.oracledb.compiler.Constants.ORACLEDB; -import static io.ballerina.stdlib.oracledb.compiler.Utils.validateOptions; +import static io.ballerina.stdlib.oracledb.compiler.Utils.validateOptionConfig; /** * Analyser for validation oracledb:Options. @@ -58,7 +58,7 @@ public void perform(SyntaxNodeAnalysisContext ctx) { if (recordNode.isEmpty()) { return; } - validateOptions(ctx, recordNode.get()); + validateOptionConfig(ctx, recordNode.get()); } } }