diff --git a/lib/antlr/visitors/exportVisitor.ts b/lib/antlr/visitors/exportVisitor.ts index fb266b5..ecb891d 100644 --- a/lib/antlr/visitors/exportVisitor.ts +++ b/lib/antlr/visitors/exportVisitor.ts @@ -8,6 +8,7 @@ import { ContractDefinitionContext, EnumDefinitionContext, ErrorDefinitionContext, + FunctionDefinitionContext, InheritanceSpecifierListContext, InterfaceDefinitionContext, LibraryDefinitionContext, @@ -372,4 +373,29 @@ class ExportVisitor implements SolidityParserListener { typeName: typeName.text, }); } + + enterFunctionDefinition(ctx: FunctionDefinitionContext): void { + if (!(ctx.parent instanceof SourceUnitContext)) { + return; + } + + if (!ctx.stop) { + return; + } + + const start = ctx.start.startIndex; + const end = ctx.stop.stopIndex; + const name = ctx.identifier(); + + if (!name?.stop) { + return; + } + + this.#onVisit({ + type: ExportType.function, + start, + end, + name: name.text, + }); + } } diff --git a/lib/antlr/visitors/types.ts b/lib/antlr/visitors/types.ts index e97c3a3..8809e4a 100644 --- a/lib/antlr/visitors/types.ts +++ b/lib/antlr/visitors/types.ts @@ -18,7 +18,8 @@ export interface ImportVisitNamedImport { export type ExportVisitResult = | ExportVisitResultContractLike - | ExportVisitResultConstant; + | ExportVisitResultConstant + | ExportVisitResultFunction; export interface ExportVisitResultContractLike extends RangeVisitResult { abstract: boolean; @@ -35,4 +36,9 @@ export interface ExportVisitResultConstant extends RangeVisitResult { name: string; } +export interface ExportVisitResultFunction extends RangeVisitResult { + type: ExportType.function; + name: string; +} + export type VisitCallback = (v: T) => void; diff --git a/lib/exportsAnalyzer.ts b/lib/exportsAnalyzer.ts index d28363d..b67a5bc 100644 --- a/lib/exportsAnalyzer.ts +++ b/lib/exportsAnalyzer.ts @@ -1,13 +1,17 @@ import Debug from 'debug'; import { SolidityExportVisitor } from './antlr/visitors/exportVisitor'; -import { ExportVisitResultConstant } from './antlr/visitors/types'; +import { + ExportVisitResultConstant, + ExportVisitResultFunction, +} from './antlr/visitors/types'; import { ContractLikeExportType, ExportType } from './types'; const error = Debug('sol-merger:error'); export type ExportsAnalyzerResult = | ExportsAnalyzerResultContractLike - | ExportsAnalyzerResultConstant; + | ExportsAnalyzerResultConstant + | ExportsAnalyzerResultFunction; export interface ExportsAnalyzerResultContractLike { abstract: boolean; @@ -24,6 +28,12 @@ export interface ExportsAnalyzerResultConstant { typeName: string; } +export interface ExportsAnalyzerResultFunction { + type: ExportType.function; + name: string; + body: string; +} + export class ExportsAnalyzer { constructor(private contents: string) {} @@ -47,6 +57,12 @@ export class ExportsAnalyzer { results.push(constantExport); return; } + + if (e.type === ExportType.function) { + const functionExport = this.analyzeExportFunction(e); + results.push(functionExport); + return; + } results.push({ abstract: e.abstract, type: e.type, @@ -74,4 +90,14 @@ export class ExportsAnalyzer { typeName: e.typeName, }; } + + private analyzeExportFunction( + e: ExportVisitResultFunction, + ): ExportsAnalyzerResultFunction { + return { + body: this.contents.substring(e.start, e.end + 1), + name: e.name, + type: ExportType.function, + }; + } } diff --git a/lib/fileAnalyzer.ts b/lib/fileAnalyzer.ts index d74e33b..02f52a9 100644 --- a/lib/fileAnalyzer.ts +++ b/lib/fileAnalyzer.ts @@ -24,6 +24,10 @@ export class FileAnalyzer { return `${e.typeName} ${e.type} ${e.name}${e.body}`; } + if (e.type === ExportType.function) { + return e.body; + } + let is = e.is; if (is) { globalRenames.forEach((i) => { diff --git a/lib/types.ts b/lib/types.ts index 063a6c5..39904f2 100644 --- a/lib/types.ts +++ b/lib/types.ts @@ -10,11 +10,12 @@ export enum ExportType { error = 'error', constant = 'constant', function = 'function', + userDefinedValueType = 'userDefinedValueType', } export type ContractLikeExportType = Exclude< ExportType, - ExportType.constant | ExportType.function + ExportType.constant | ExportType.function | ExportType.userDefinedValueType >; export interface ExportPluginProcessor { diff --git a/test/compiled/ContractWithTopLevelFunction.sol b/test/compiled/ContractWithTopLevelFunction.sol new file mode 100644 index 0000000..a10b743 --- /dev/null +++ b/test/compiled/ContractWithTopLevelFunction.sol @@ -0,0 +1,19 @@ +pragma solidity >=0.7.1 <0.9.0; + + +// SPDX-License-Identifier: GPL-3.0 +function sum(uint[] memory _arr) pure returns (uint s) { + for (uint i = 0; i < _arr.length; i++) + s += _arr[i]; +} + +contract ArrayExample { + bool found; + function f(uint[] memory _arr) public { + // This calls the free function internally. + // The compiler will add its code to the contract. + uint s = sum(_arr); + require(s >= 10); + found = true; + } +} diff --git a/test/contracts/ContractWithTopLevelFunction.sol b/test/contracts/ContractWithTopLevelFunction.sol new file mode 100644 index 0000000..07b523f --- /dev/null +++ b/test/contracts/ContractWithTopLevelFunction.sol @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity >=0.7.1 <0.9.0; + +function sum(uint[] memory _arr) pure returns (uint s) { + for (uint i = 0; i < _arr.length; i++) + s += _arr[i]; +} + +contract ArrayExample { + bool found; + function f(uint[] memory _arr) public { + // This calls the free function internally. + // The compiler will add its code to the contract. + uint s = sum(_arr); + require(s >= 10); + found = true; + } +} diff --git a/test/exportsAnalyzer.spec.ts b/test/exportsAnalyzer.spec.ts index 099f382..c42da63 100644 --- a/test/exportsAnalyzer.spec.ts +++ b/test/exportsAnalyzer.spec.ts @@ -145,5 +145,27 @@ describe('ExportsAnalyzer', () => { assert.deepEqual(exports, []); }); + + it('should analyze function export', () => { + const exportsAnalyzer = new ExportsAnalyzer(` + function sum(uint[] memory _arr) pure returns (uint s) { + for (uint i = 0; i < _arr.length; i++) + s += _arr[i]; + } + `); + + const exports = exportsAnalyzer.analyzeExports(); + + assert.deepEqual(exports, [ + { + name: 'sum', + type: ExportType.function, + body: `function sum(uint[] memory _arr) pure returns (uint s) { + for (uint i = 0; i < _arr.length; i++) + s += _arr[i]; + }`, + }, + ]); + }); }); }); diff --git a/test/index.spec.ts b/test/index.spec.ts index aec0f38..dc28eb2 100644 --- a/test/index.spec.ts +++ b/test/index.spec.ts @@ -124,4 +124,8 @@ describe('Solidity Merger', () => { it('should compile file with constants at root level (0.8 support)', async () => { await testFile('ContractWithConstants'); }); + + it('should compile file with functions at root level (0.8 support)', async () => { + await testFile('ContractWithTopLevelFunction'); + }); });