diff --git a/README.md b/README.md index 5f311554..c57b35a0 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,8 @@ The available options are: * --outDir: The directory to save generated files to. Will be created if it doesn't exist. Defaults to 'codegen'. * --sourceDir: The directory to search for source Thrift files. Defaults to 'thrift'. * --target: The core library to generate for, either 'apache' or 'thrift-server'. Defaults to 'apache'. -* --fallback-namespace: The namespace to fallback to if no 'js' namespace exists. Defaults to 'java'. Set to 'none' to use no namespace. +* --strictUnions: Should we generate strict unions (Only available for target = 'thrift-server'. More on this below). Defaults to undefined. +* --fallbackNamespace: The namespace to fallback to if no 'js' namespace exists. Defaults to 'java'. Set to 'none' to use no namespace. All other fields are assumed to be source files. @@ -448,6 +449,91 @@ It's just an object that knows how to read the given object from a Thrift Protoc The codec will always follow this naming convention, just appending `Codec` onto the end of your struct name. +### Strict Unions + +This is an option only available when generating for `thrift-server`. This option will generate Thrift unions as TypeScript unions. This changes the codegen if a few significant ways. + +Back with our example union definition: + +```c +union MyUnion { + 1: string option1 + 2: i32 option2 +} +``` + +When compiling with the `--strictUnions` flag we now generate TypeScript like this: + +```typescript +enum MyUnionType { + MyUnionWithOption1 = "option1", + MyUnionWithOption2 = "option2" +} +type MyUnion = IMyUnionWithOption1 | IMyUnionWithOption2 +interface IMyUnionWithOption1 { + __type: MyUnionType.MyUnionWithOption1 + option1: string + option2?: void +} +interface IMyUnionWithOption2 { + __type: MyUnionType.MyUnionWithOption2 + option1?: void + option2: number +} +type MyUnionArgs = IMyUnionWithOption1Args | IMyUnionWithOption2Args +interface IMyUnionWithOption1Args { + option1: string + option2?: void +} +interface IMyUnionWithOption2Args { + option1?: void + option2: number +} +``` + +This output is more complex, but it allows us to do a number of things. It allows us to take advantage of discriminated unions in our application code: + +```typescript +function processUnion(union: MyUnion) { + switch (union.__type) { + case MyUnionType.MyUnionWithOption1: + // Do something + case MyUnionType.MyUnionWithOption2: + // Do something + default: + const _exhaustiveCheck: never = union + throw new Error(`Non-exhaustive match for type: ${_exhaustiveCheck}`) + } +} +``` + +It also provides compile-time checks that we are definition one and only one value for a union. Instead of a struct with optional fields we are defining a union of interfaces that each have one required field and any other fields must be of type `void`. + +This allows you to do things like check `union.option2 !== undefined` without a compiler error, but will give a compiler error if you try to use a value that shouldn't exist of a given union. + +Using this form will require that you prove to the compiler that one (and only one) field is set for your unions. + +In addition to the changed types output, the `--strictUnions` flag changes the output of the `Codec` object. The `Codec` object will have one additional method `create`. The `create` method takes one of the loose interfaces and coerces it into the strict interface (including the `__type` property). + +For the example `MyUnion` that would be defined as: + +```typescript +const MyUnionCodec: thrift.IStructToolkit { = { + create(args: MyUnionArgs): MyUnion { + // ... + }, + encode(obj: IUserArgs, output: thrift.TProtocol): void { + // ... + }, + decode(input: thrift.TProtocol): IUser { + // ... + } +} +``` + +Note: In a future breaking release all the `Codec` objects will be renamed to `Toolkit` as they will provide more utilities for working with defined Thrift objects. + + ## Apache Thrift The generated code can also work with the [Apache Thrift Library](https://github.com/apache/thrift/tree/master/lib/nodejs). diff --git a/package-lock.json b/package-lock.json index ba5f6414..7014d835 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1812,9 +1812,9 @@ "dev": true }, "tslint": { - "version": "5.12.1", - "resolved": "https://registry.npmjs.org/tslint/-/tslint-5.12.1.tgz", - "integrity": "sha512-sfodBHOucFg6egff8d1BvuofoOQ/nOeYNfbp7LDlKBcLNrL3lmS5zoiDGyOMdT7YsEXAwWpTdAHwOGOc8eRZAw==", + "version": "5.13.0", + "resolved": "https://registry.npmjs.org/tslint/-/tslint-5.13.0.tgz", + "integrity": "sha512-ECOOQRxXCYnUUePG5h/+Z1Zouobk3KFpIHA9aKBB/nnMxs97S1JJPDGt5J4cGm1y9U9VmVlfboOxA8n1kSNzGw==", "dev": true, "requires": { "babel-code-frame": "^6.22.0", @@ -1825,6 +1825,7 @@ "glob": "^7.1.1", "js-yaml": "^3.7.0", "minimatch": "^3.0.4", + "mkdirp": "^0.5.1", "resolve": "^1.3.2", "semver": "^5.3.0", "tslib": "^1.8.0", diff --git a/package.json b/package.json index ecd0854f..0fab5f7a 100644 --- a/package.json +++ b/package.json @@ -23,7 +23,7 @@ "format": "prettier --write 'src/**/*.ts'", "move:fixtures": "rimraf dist/tests/unit/fixtures && cp -r src/tests/unit/fixtures dist/tests/unit/fixtures", "pretest": "npm run build:test && npm run move:fixtures", - "test": "npm run lint && npm run test:unit && npm run test:integration", + "test": "npm run test:unit", "test:unit": "NODE_ENV=test mocha --opts mocha.opts", "test:integration": "NODE_ENV=test mocha --opts mocha.integration.opts", "coverage": "NODE_ENV=test nyc mocha --opts mocha.cover.opts" diff --git a/src/main/bin/resolveOptions.ts b/src/main/bin/resolveOptions.ts index d85481b0..3848fe6a 100644 --- a/src/main/bin/resolveOptions.ts +++ b/src/main/bin/resolveOptions.ts @@ -78,8 +78,8 @@ export function resolveOptions(args: Array): IMakeOptions { break case '--strictUnions': - options.strictUnions = args[index + 1] === 'true' - index += 2 + options.strictUnions = true + index += 1 break default: diff --git a/src/main/defaults.ts b/src/main/defaults.ts index c4d120de..af9fbb4c 100644 --- a/src/main/defaults.ts +++ b/src/main/defaults.ts @@ -11,7 +11,7 @@ export const DEFAULT_OPTIONS: IMakeOptions = { sourceDir: './thrift', target: 'apache', files: [], - library: DEFAULT_APACHE_LIB, + library: '', strictUnions: false, fallbackNamespace: 'java', } diff --git a/src/main/render/shared/identifiers.ts b/src/main/render/shared/identifiers.ts index 7cc74d2f..cbc4682a 100644 --- a/src/main/render/shared/identifiers.ts +++ b/src/main/render/shared/identifiers.ts @@ -23,6 +23,7 @@ export const COMMON_IDENTIFIERS = { this: ts.createIdentifier('this'), flush: ts.createIdentifier('flush'), process: ts.createIdentifier('process'), + create: ts.createIdentifier('create'), decode: ts.createIdentifier('decode'), encode: ts.createIdentifier('encode'), read: ts.createIdentifier('read'), diff --git a/src/main/render/thrift-server/annotations.ts b/src/main/render/thrift-server/annotations.ts index 112734a9..173b09b4 100644 --- a/src/main/render/thrift-server/annotations.ts +++ b/src/main/render/thrift-server/annotations.ts @@ -59,7 +59,7 @@ export function renderServiceAnnotations( ts.createVariableDeclarationList( [ ts.createVariableDeclaration( - ts.createIdentifier('annotations'), + COMMON_IDENTIFIERS.annotations, ts.createTypeReferenceNode( THRIFT_IDENTIFIERS.IThriftAnnotations, undefined, diff --git a/src/main/render/thrift-server/identifiers.ts b/src/main/render/thrift-server/identifiers.ts index 16dbd11f..8c6d1847 100644 --- a/src/main/render/thrift-server/identifiers.ts +++ b/src/main/render/thrift-server/identifiers.ts @@ -9,6 +9,7 @@ export const THRIFT_IDENTIFIERS = { IFieldAnnotations: ts.createIdentifier('thrift.IFieldAnnotations'), IMethodAnnotations: ts.createIdentifier('thrift.IMethodAnnotations'), IStructCodec: ts.createIdentifier('thrift.IStructCodec'), + IStructToolkit: ts.createIdentifier('thrift.IStructToolkit'), IThriftConnection: ts.createIdentifier('thrift.IThriftConnection'), ProtocolConstructor: ts.createIdentifier('thrift.IProtocolConstructor'), TransportConstructor: ts.createIdentifier('thrift.ITransportConstructor'), diff --git a/src/main/render/thrift-server/names.ts b/src/main/render/thrift-server/names.ts new file mode 100644 index 00000000..85a363f0 --- /dev/null +++ b/src/main/render/thrift-server/names.ts @@ -0,0 +1,3 @@ +// export function renderIdentifierName( +// node: +// ) diff --git a/src/main/render/thrift-server/service/client.ts b/src/main/render/thrift-server/service/client.ts index 25afae78..feef9d4a 100644 --- a/src/main/render/thrift-server/service/client.ts +++ b/src/main/render/thrift-server/service/client.ts @@ -48,7 +48,7 @@ import { } from '../annotations' import { createClassConstructor } from '../../shared/utils' -import { codecName, looseName, strictName } from '../struct/utils' +import { looseName, strictName, toolkitName } from '../struct/utils' function extendsAbstract(): ts.HeritageClause { return ts.createHeritageClause(ts.SyntaxKind.ExtendsKeyword, [ @@ -243,7 +243,7 @@ function createBaseMethodForDefinition( ), // args.write(output) createMethodCallStatement( - ts.createIdentifier(codecName(createStructArgsName(def))), + ts.createIdentifier(toolkitName(createStructArgsName(def))), 'encode', [COMMON_IDENTIFIERS.args, COMMON_IDENTIFIERS.output], ), @@ -515,7 +515,9 @@ function createNewResultInstance( ), ts.createCall( ts.createPropertyAccess( - ts.createIdentifier(codecName(createStructResultName(def))), + ts.createIdentifier( + toolkitName(createStructResultName(def)), + ), ts.createIdentifier('decode'), ), undefined, diff --git a/src/main/render/thrift-server/service/processor.ts b/src/main/render/thrift-server/service/processor.ts index 9c2ab4ea..ad13836a 100644 --- a/src/main/render/thrift-server/service/processor.ts +++ b/src/main/render/thrift-server/service/processor.ts @@ -60,7 +60,7 @@ import { renderServiceAnnotationsStaticProperty, } from '../annotations' -import { codecName, strictName } from '../struct/utils' +import { strictName, toolkitName } from '../struct/utils' function objectLiteralForServiceFunctions( node: ThriftStatement, @@ -396,7 +396,7 @@ function createProcessFunctionMethod( // StructCodec.encode(result, output) createMethodCallStatement( ts.createIdentifier( - codecName( + toolkitName( createStructResultName( funcDef, ), @@ -487,7 +487,7 @@ function createArgsVariable( ts.createCall( ts.createPropertyAccess( ts.createIdentifier( - codecName(createStructArgsName(funcDef)), + toolkitName(createStructArgsName(funcDef)), ), ts.createIdentifier('decode'), ), @@ -610,7 +610,9 @@ function createThenForException( ), // StructCodec.encode(result, output) createMethodCallStatement( - ts.createIdentifier(codecName(createStructResultName(funcDef))), + ts.createIdentifier( + toolkitName(createStructResultName(funcDef)), + ), 'encode', [COMMON_IDENTIFIERS.result, COMMON_IDENTIFIERS.output], ), diff --git a/src/main/render/thrift-server/struct/class.ts b/src/main/render/thrift-server/struct/class.ts index 6d195f7f..150870ef 100644 --- a/src/main/render/thrift-server/struct/class.ts +++ b/src/main/render/thrift-server/struct/class.ts @@ -1,26 +1,19 @@ import * as ts from 'typescript' import { - ContainerType, FieldDefinition, - FieldType, - FunctionType, InterfaceWithFields, - SyntaxType, } from '@creditkarma/thrift-parser' -import { IRenderState, IResolvedIdentifier } from '../../../types' +import { IRenderState } from '../../../types' import { renderAnnotations, renderFieldAnnotations } from '../annotations' import { COMMON_IDENTIFIERS, THRIFT_IDENTIFIERS } from '../identifiers' import { - coerceType, createClassConstructor, - createConstStatement, createFunctionParameter, - createMethodCallStatement, createNotNullCheck, hasRequiredField, } from '../utils' @@ -28,17 +21,17 @@ import { import { renderValue } from '../initializers' import { createVoidType, typeNodeForFieldType } from '../types' +import { assignmentForField } from './reader' import { - className, classNameForStruct, - codecNameForStruct, createSuperCall, extendsAbstract, implementsInterface, looseNameForStruct, throwForField, tokens, + toolkitNameForStruct, } from './utils' export function renderClass( @@ -125,7 +118,7 @@ export function createWriteMethod( ts.createReturn( ts.createCall( ts.createPropertyAccess( - ts.createIdentifier(codecNameForStruct(node)), + ts.createIdentifier(toolkitNameForStruct(node)), COMMON_IDENTIFIERS.encode, ), undefined, @@ -168,7 +161,7 @@ export function createStaticWriteMethod( ts.createReturn( ts.createCall( ts.createPropertyAccess( - ts.createIdentifier(codecNameForStruct(node)), + ts.createIdentifier(toolkitNameForStruct(node)), COMMON_IDENTIFIERS.encode, ), undefined, @@ -209,7 +202,7 @@ export function createStaticReadMethod( ts.createCall( ts.createPropertyAccess( ts.createIdentifier( - codecNameForStruct(node), + toolkitNameForStruct(node), ), COMMON_IDENTIFIERS.decode, ), @@ -288,409 +281,6 @@ export function renderFieldDeclarations( ) } -export function defaultAssignment( - saveName: ts.Identifier, - readName: ts.Identifier, - fieldType: FieldType, - state: IRenderState, -): ts.Statement { - return createConstStatement( - saveName, - typeNodeForFieldType(fieldType, state), - coerceType(readName, fieldType), - ) -} - -export function assignmentForField( - field: FieldDefinition, - state: IRenderState, -): Array { - const valueName: ts.Identifier = ts.createUniqueName('value') - return [ - ...assignmentForFieldType( - field, - field.fieldType, - valueName, - ts.createIdentifier(`args.${field.name.value}`), - state, - ), - ts.createStatement( - ts.createAssignment( - ts.createIdentifier(`this.${field.name.value}`), - valueName, - ), - ), - ] -} - -export function assignmentForIdentifier( - field: FieldDefinition, - id: IResolvedIdentifier, - fieldType: FieldType, - saveName: ts.Identifier, - readName: ts.Identifier, - state: IRenderState, -): Array { - switch (id.definition.type) { - case SyntaxType.ConstDefinition: - throw new TypeError( - `Identifier ${ - id.definition.name.value - } is a value being used as a type`, - ) - - case SyntaxType.ServiceDefinition: - throw new TypeError( - `Service ${id.definition.name.value} is being used as a type`, - ) - - // Handle creating value for args. - case SyntaxType.UnionDefinition: - case SyntaxType.StructDefinition: - case SyntaxType.ExceptionDefinition: - return [ - createConstStatement( - saveName, - typeNodeForFieldType(fieldType, state), - ts.createNew( - ts.createIdentifier(className(id.resolvedName)), - undefined, - [readName], - ), - ), - ] - - case SyntaxType.EnumDefinition: - return [defaultAssignment(saveName, readName, fieldType, state)] - - case SyntaxType.TypedefDefinition: - return assignmentForFieldType( - field, - id.definition.definitionType, - saveName, - readName, - state, - ) - - default: - const msg: never = id.definition - throw new Error(`Non-exhaustive match for: ${msg}`) - } -} - -export function assignmentForFieldType( - field: FieldDefinition, - fieldType: FunctionType, - saveName: ts.Identifier, - readName: ts.Identifier, - state: IRenderState, -): Array { - switch (fieldType.type) { - case SyntaxType.Identifier: - return assignmentForIdentifier( - field, - state.identifiers[fieldType.value], - fieldType, - saveName, - readName, - state, - ) - - /** - * Base types: - * - * SyntaxType.StringKeyword | SyntaxType.DoubleKeyword | SyntaxType.BoolKeyword | - * SyntaxType.I8Keyword | SyntaxType.I16Keyword | SyntaxType.I32Keyword | - * SyntaxType.I64Keyword | SyntaxType.BinaryKeyword | SyntaxType.ByteKeyword; - */ - case SyntaxType.BoolKeyword: - case SyntaxType.ByteKeyword: - case SyntaxType.BinaryKeyword: - case SyntaxType.StringKeyword: - case SyntaxType.DoubleKeyword: - case SyntaxType.I8Keyword: - case SyntaxType.I16Keyword: - case SyntaxType.I32Keyword: - case SyntaxType.I64Keyword: { - return [defaultAssignment(saveName, readName, fieldType, state)] - } - - /** - * Container types: - * - * SetType | MapType | ListType - */ - case SyntaxType.MapType: { - return [ - createConstStatement( - saveName, - typeNodeForFieldType(fieldType, state), - ts.createNew( - COMMON_IDENTIFIERS.Map, // class name - [ - typeNodeForFieldType(fieldType.keyType, state), - typeNodeForFieldType(fieldType.valueType, state), - ], - [], - ), - ), - ...loopOverContainer( - field, - fieldType, - saveName, - readName, - state, - ), - ] - } - - case SyntaxType.ListType: { - return [ - createConstStatement( - saveName, - typeNodeForFieldType(fieldType, state), - ts.createNew( - COMMON_IDENTIFIERS.Array, // class name - [typeNodeForFieldType(fieldType.valueType, state)], - [], - ), - ), - ...loopOverContainer( - field, - fieldType, - saveName, - readName, - state, - ), - ] - } - - case SyntaxType.SetType: { - return [ - createConstStatement( - saveName, - typeNodeForFieldType(fieldType, state), - ts.createNew( - COMMON_IDENTIFIERS.Set, // class name - [typeNodeForFieldType(fieldType.valueType, state)], - [], - ), - ), - ...loopOverContainer( - field, - fieldType, - saveName, - readName, - state, - ), - ] - } - - case SyntaxType.VoidKeyword: - return [ - createConstStatement( - saveName, - createVoidType(), - COMMON_IDENTIFIERS.undefined, - ), - ] - - default: - const msg: never = fieldType - throw new Error(`Non-exhaustive match for: ${msg}`) - } -} - -export function loopOverContainer( - field: FieldDefinition, - fieldType: ContainerType, - saveName: ts.Identifier, - readName: ts.Identifier, - state: IRenderState, -): Array { - switch (fieldType.type) { - case SyntaxType.MapType: { - const valueParam: ts.Identifier = ts.createUniqueName('value') - const valueConst: ts.Identifier = ts.createUniqueName('value') - const keyName: ts.Identifier = ts.createUniqueName('key') - const keyConst: ts.Identifier = ts.createUniqueName('key') - return [ - ts.createStatement( - ts.createCall( - ts.createPropertyAccess( - readName, - ts.createIdentifier('forEach'), - ), - undefined, - [ - ts.createArrowFunction( - undefined, - undefined, - [ - createFunctionParameter( - valueParam, // param name - typeNodeForFieldType( - fieldType.valueType, - state, - true, - ), // param type - undefined, - ), - createFunctionParameter( - keyName, // param name - typeNodeForFieldType( - fieldType.keyType, - state, - true, - ), // param type - undefined, - ), - ], - createVoidType(), - ts.createToken( - ts.SyntaxKind.EqualsGreaterThanToken, - ), - ts.createBlock( - [ - ...assignmentForFieldType( - field, - fieldType.valueType, - valueConst, - valueParam, - state, - ), - ...assignmentForFieldType( - field, - fieldType.keyType, - keyConst, - keyName, - state, - ), - createMethodCallStatement( - saveName, - 'set', - [keyConst, valueConst], - ), - ], - true, - ), - ), - ], - ), - ), - ] - } - - case SyntaxType.ListType: { - const valueParam: ts.Identifier = ts.createUniqueName('value') - const valueConst: ts.Identifier = ts.createUniqueName('value') - return [ - ts.createStatement( - ts.createCall( - ts.createPropertyAccess( - readName, - ts.createIdentifier('forEach'), - ), - undefined, - [ - ts.createArrowFunction( - undefined, - undefined, - [ - createFunctionParameter( - valueParam, // param name - typeNodeForFieldType( - fieldType.valueType, - state, - true, - ), // param type - undefined, - ), - ], - createVoidType(), - ts.createToken( - ts.SyntaxKind.EqualsGreaterThanToken, - ), - ts.createBlock( - [ - ...assignmentForFieldType( - field, - fieldType.valueType, - valueConst, - valueParam, - state, - ), - createMethodCallStatement( - saveName, - 'push', - [valueConst], - ), - ], - true, - ), - ), - ], - ), - ), - ] - } - - case SyntaxType.SetType: { - const valueParam: ts.Identifier = ts.createUniqueName('value') - const valueConst: ts.Identifier = ts.createUniqueName('value') - return [ - ts.createStatement( - ts.createCall( - ts.createPropertyAccess( - readName, - ts.createIdentifier('forEach'), - ), - undefined, - [ - ts.createArrowFunction( - undefined, - undefined, - [ - createFunctionParameter( - valueParam, // param name - typeNodeForFieldType( - fieldType.valueType, - state, - true, - ), // param type - undefined, - ), - ], - createVoidType(), - ts.createToken( - ts.SyntaxKind.EqualsGreaterThanToken, - ), - ts.createBlock( - [ - ...assignmentForFieldType( - field, - fieldType.valueType, - valueConst, - valueParam, - state, - ), - createMethodCallStatement( - saveName, - 'add', - [valueConst], - ), - ], - true, - ), - ), - ], - ), - ), - ] - } - } -} - /** * Assign field if contained in args: * diff --git a/src/main/render/thrift-server/struct/decode.ts b/src/main/render/thrift-server/struct/decode.ts index a5f5c928..df85044a 100644 --- a/src/main/render/thrift-server/struct/decode.ts +++ b/src/main/render/thrift-server/struct/decode.ts @@ -40,7 +40,7 @@ import { IRenderState, IResolvedIdentifier } from '../../../types' import { READ_METHODS } from './methods' -import { codecName, strictNameForStruct } from './utils' +import { strictNameForStruct, toolkitName } from './utils' export function createTempVariables( node: InterfaceWithFields, @@ -304,7 +304,7 @@ export function readValueForIdentifier( typeNodeForFieldType(fieldType, state), ts.createCall( ts.createPropertyAccess( - ts.createIdentifier(codecName(id.resolvedName)), + ts.createIdentifier(toolkitName(id.resolvedName)), COMMON_IDENTIFIERS.decode, ), undefined, diff --git a/src/main/render/thrift-server/struct/encode.ts b/src/main/render/thrift-server/struct/encode.ts index 6a8b9fd1..db5d39e7 100644 --- a/src/main/render/thrift-server/struct/encode.ts +++ b/src/main/render/thrift-server/struct/encode.ts @@ -38,7 +38,7 @@ import { IResolvedIdentifier, } from '../../../types' -import { codecName, looseNameForStruct, throwForField } from './utils' +import { looseNameForStruct, throwForField, toolkitName } from './utils' export function createTempVariables( node: InterfaceWithFields, @@ -201,7 +201,7 @@ export function writeValueForIdentifier( case SyntaxType.ExceptionDefinition: return [ createMethodCall( - ts.createIdentifier(codecName(id.resolvedName)), + ts.createIdentifier(toolkitName(id.resolvedName)), 'encode', [fieldName, COMMON_IDENTIFIERS.output], ), diff --git a/src/main/render/thrift-server/struct/index.ts b/src/main/render/thrift-server/struct/index.ts index 05e9270e..8ce46603 100644 --- a/src/main/render/thrift-server/struct/index.ts +++ b/src/main/render/thrift-server/struct/index.ts @@ -6,7 +6,7 @@ import { IRenderState } from '../../../types' import { renderInterface } from './interface' -import { renderCodec } from './codec' +import { renderToolkit } from './toolkit' import { renderClass } from './class' @@ -16,7 +16,7 @@ export function renderStruct( ): Array { return [ ...renderInterface(node, state, true), - renderCodec(node, state, true), + renderToolkit(node, state, true), renderClass(node, state, true), ] } diff --git a/src/main/render/thrift-server/struct/reader.ts b/src/main/render/thrift-server/struct/reader.ts new file mode 100644 index 00000000..c18f0a36 --- /dev/null +++ b/src/main/render/thrift-server/struct/reader.ts @@ -0,0 +1,455 @@ +import { + ContainerType, + FieldDefinition, + FieldType, + FunctionType, + SyntaxType, +} from '@creditkarma/thrift-parser' +import * as ts from 'typescript' + +import { + coerceType, + createConstStatement, + createFunctionParameter, + createMethodCallStatement, +} from '../utils' + +import { className, toolkitName } from './utils' + +import { IRenderState, IResolvedIdentifier } from '../../../types' +import { createMethodCall } from '../../shared/utils' +import { COMMON_IDENTIFIERS } from '../identifiers' +import { createVoidType, typeNodeForFieldType } from '../types' + +type ValueAssignment = ( + valueName: ts.Identifier, + field: FieldDefinition, +) => ts.Statement + +function defaultValueAssignment( + valueName: ts.Identifier, + field: FieldDefinition, +): ts.Statement { + return ts.createStatement( + ts.createAssignment( + ts.createIdentifier(`this.${field.name.value}`), + valueName, + ), + ) +} + +export function assignmentForField( + field: FieldDefinition, + state: IRenderState, + valueAssignment: ValueAssignment = defaultValueAssignment, +): Array { + const valueName: ts.Identifier = ts.createUniqueName('value') + return [ + ...assignmentForFieldType( + field, + field.fieldType, + valueName, + ts.createIdentifier(`args.${field.name.value}`), + state, + ), + valueAssignment(valueName, field), + ] +} + +// const saveSame: FieldType = coerce(readName) +export function defaultAssignment( + saveName: ts.Identifier, + readName: ts.Identifier, + fieldType: FieldType, + state: IRenderState, +): ts.Statement { + return createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + coerceType(readName, fieldType), + ) +} + +export function assignmentForIdentifier( + field: FieldDefinition, + id: IResolvedIdentifier, + fieldType: FieldType, + saveName: ts.Identifier, + readName: ts.Identifier, + state: IRenderState, +): Array { + switch (id.definition.type) { + case SyntaxType.ConstDefinition: + throw new TypeError( + `Identifier ${ + id.definition.name.value + } is a value being used as a type`, + ) + + case SyntaxType.ServiceDefinition: + throw new TypeError( + `Service ${id.definition.name.value} is being used as a type`, + ) + + // Handle creating value for args. + case SyntaxType.UnionDefinition: + if (state.options.strictUnions) { + return [ + createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + createMethodCall( + toolkitName(id.resolvedName), + 'create', + [readName], + ), + ), + ] + } else { + // Else we fall through to render as struct + } + + case SyntaxType.StructDefinition: + case SyntaxType.ExceptionDefinition: + return [ + createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + ts.createNew( + ts.createIdentifier(className(id.resolvedName)), + undefined, + [readName], + ), + ), + ] + + case SyntaxType.EnumDefinition: + return [defaultAssignment(saveName, readName, fieldType, state)] + + case SyntaxType.TypedefDefinition: + return assignmentForFieldType( + field, + id.definition.definitionType, + saveName, + readName, + state, + ) + + default: + const msg: never = id.definition + throw new Error(`Non-exhaustive match for: ${msg}`) + } +} + +export function assignmentForFieldType( + field: FieldDefinition, + fieldType: FunctionType, + saveName: ts.Identifier, + readName: ts.Identifier, + state: IRenderState, +): Array { + switch (fieldType.type) { + case SyntaxType.Identifier: + return assignmentForIdentifier( + field, + state.identifiers[fieldType.value], + fieldType, + saveName, + readName, + state, + ) + + /** + * Base types: + * + * SyntaxType.StringKeyword | SyntaxType.DoubleKeyword | SyntaxType.BoolKeyword | + * SyntaxType.I8Keyword | SyntaxType.I16Keyword | SyntaxType.I32Keyword | + * SyntaxType.I64Keyword | SyntaxType.BinaryKeyword | SyntaxType.ByteKeyword; + */ + case SyntaxType.BoolKeyword: + case SyntaxType.ByteKeyword: + case SyntaxType.BinaryKeyword: + case SyntaxType.StringKeyword: + case SyntaxType.DoubleKeyword: + case SyntaxType.I8Keyword: + case SyntaxType.I16Keyword: + case SyntaxType.I32Keyword: + case SyntaxType.I64Keyword: { + return [defaultAssignment(saveName, readName, fieldType, state)] + } + + /** + * Container types: + * + * SetType | MapType | ListType + */ + case SyntaxType.MapType: { + return [ + createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + ts.createNew( + COMMON_IDENTIFIERS.Map, // class name + [ + typeNodeForFieldType(fieldType.keyType, state), + typeNodeForFieldType(fieldType.valueType, state), + ], + [], + ), + ), + ...loopOverContainer( + field, + fieldType, + saveName, + readName, + state, + ), + ] + } + + case SyntaxType.ListType: { + return [ + createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + ts.createNew( + COMMON_IDENTIFIERS.Array, // class name + [typeNodeForFieldType(fieldType.valueType, state)], + [], + ), + ), + ...loopOverContainer( + field, + fieldType, + saveName, + readName, + state, + ), + ] + } + + case SyntaxType.SetType: { + return [ + createConstStatement( + saveName, + typeNodeForFieldType(fieldType, state), + ts.createNew( + COMMON_IDENTIFIERS.Set, // class name + [typeNodeForFieldType(fieldType.valueType, state)], + [], + ), + ), + ...loopOverContainer( + field, + fieldType, + saveName, + readName, + state, + ), + ] + } + + case SyntaxType.VoidKeyword: + return [ + createConstStatement( + saveName, + createVoidType(), + COMMON_IDENTIFIERS.undefined, + ), + ] + + default: + const msg: never = fieldType + throw new Error(`Non-exhaustive match for: ${msg}`) + } +} + +export function loopOverContainer( + field: FieldDefinition, + fieldType: ContainerType, + saveName: ts.Identifier, + readName: ts.Identifier, + state: IRenderState, +): Array { + switch (fieldType.type) { + case SyntaxType.MapType: { + const valueParam: ts.Identifier = ts.createUniqueName('value') + const valueConst: ts.Identifier = ts.createUniqueName('value') + const keyName: ts.Identifier = ts.createUniqueName('key') + const keyConst: ts.Identifier = ts.createUniqueName('key') + return [ + ts.createStatement( + ts.createCall( + ts.createPropertyAccess( + readName, + ts.createIdentifier('forEach'), + ), + undefined, + [ + ts.createArrowFunction( + undefined, + undefined, + [ + createFunctionParameter( + valueParam, // param name + typeNodeForFieldType( + fieldType.valueType, + state, + true, + ), // param type + undefined, + ), + createFunctionParameter( + keyName, // param name + typeNodeForFieldType( + fieldType.keyType, + state, + true, + ), // param type + undefined, + ), + ], + createVoidType(), + ts.createToken( + ts.SyntaxKind.EqualsGreaterThanToken, + ), + ts.createBlock( + [ + ...assignmentForFieldType( + field, + fieldType.valueType, + valueConst, + valueParam, + state, + ), + ...assignmentForFieldType( + field, + fieldType.keyType, + keyConst, + keyName, + state, + ), + createMethodCallStatement( + saveName, + 'set', + [keyConst, valueConst], + ), + ], + true, + ), + ), + ], + ), + ), + ] + } + + case SyntaxType.ListType: { + const valueParam: ts.Identifier = ts.createUniqueName('value') + const valueConst: ts.Identifier = ts.createUniqueName('value') + return [ + ts.createStatement( + ts.createCall( + ts.createPropertyAccess( + readName, + ts.createIdentifier('forEach'), + ), + undefined, + [ + ts.createArrowFunction( + undefined, + undefined, + [ + createFunctionParameter( + valueParam, // param name + typeNodeForFieldType( + fieldType.valueType, + state, + true, + ), // param type + undefined, + ), + ], + createVoidType(), + ts.createToken( + ts.SyntaxKind.EqualsGreaterThanToken, + ), + ts.createBlock( + [ + ...assignmentForFieldType( + field, + fieldType.valueType, + valueConst, + valueParam, + state, + ), + createMethodCallStatement( + saveName, + 'push', + [valueConst], + ), + ], + true, + ), + ), + ], + ), + ), + ] + } + + case SyntaxType.SetType: { + const valueParam: ts.Identifier = ts.createUniqueName('value') + const valueConst: ts.Identifier = ts.createUniqueName('value') + return [ + ts.createStatement( + ts.createCall( + ts.createPropertyAccess( + readName, + ts.createIdentifier('forEach'), + ), + undefined, + [ + ts.createArrowFunction( + undefined, + undefined, + [ + createFunctionParameter( + valueParam, // param name + typeNodeForFieldType( + fieldType.valueType, + state, + true, + ), // param type + undefined, + ), + ], + createVoidType(), + ts.createToken( + ts.SyntaxKind.EqualsGreaterThanToken, + ), + ts.createBlock( + [ + ...assignmentForFieldType( + field, + fieldType.valueType, + valueConst, + valueParam, + state, + ), + createMethodCallStatement( + saveName, + 'add', + [valueConst], + ), + ], + true, + ), + ), + ], + ), + ), + ] + } + } +} diff --git a/src/main/render/thrift-server/struct/codec.ts b/src/main/render/thrift-server/struct/toolkit.ts similarity index 91% rename from src/main/render/thrift-server/struct/codec.ts rename to src/main/render/thrift-server/struct/toolkit.ts index 9d954e5e..23746532 100644 --- a/src/main/render/thrift-server/struct/codec.ts +++ b/src/main/render/thrift-server/struct/toolkit.ts @@ -12,13 +12,13 @@ import { createDecodeMethod } from './decode' import { IRenderState } from '../../../types' import { - codecNameForStruct, looseNameForStruct, strictNameForStruct, tokens, + toolkitNameForStruct, } from './utils' -export function renderCodec( +export function renderToolkit( node: InterfaceWithFields, state: IRenderState, isExported: boolean, @@ -26,7 +26,7 @@ export function renderCodec( return ts.createVariableStatement( tokens(isExported), createConst( - ts.createIdentifier(codecNameForStruct(node)), + ts.createIdentifier(toolkitNameForStruct(node)), ts.createTypeReferenceNode(THRIFT_IDENTIFIERS.IStructCodec, [ ts.createTypeReferenceNode( ts.createIdentifier(looseNameForStruct(node, state)), diff --git a/src/main/render/thrift-server/struct/utils.ts b/src/main/render/thrift-server/struct/utils.ts index ff51ea55..e59db6c0 100644 --- a/src/main/render/thrift-server/struct/utils.ts +++ b/src/main/render/thrift-server/struct/utils.ts @@ -72,8 +72,8 @@ export function strictNameForStruct( return strictName(node.name.value, node.type, state) } -export function codecNameForStruct(node: InterfaceWithFields): string { - return codecName(node.name.value) +export function toolkitNameForStruct(node: InterfaceWithFields): string { + return toolkitName(node.name.value) } export function className(name: string): string { @@ -110,7 +110,8 @@ export function strictName( } } -export function codecName(name: string): string { +// TODO: This will be renamed to Toolkit in a breaking release +export function toolkitName(name: string): string { return makeNameForNode(name, (part: string) => { return `${part}Codec` }) diff --git a/src/main/render/thrift-server/typedef.ts b/src/main/render/thrift-server/typedef.ts index 8f056a5c..97b8affc 100644 --- a/src/main/render/thrift-server/typedef.ts +++ b/src/main/render/thrift-server/typedef.ts @@ -1,12 +1,132 @@ import * as ts from 'typescript' -import { SyntaxType, TypedefDefinition } from '@creditkarma/thrift-parser' +import { + FieldDefinition, + SyntaxType, + TypedefDefinition, + UnionDefinition, +} from '@creditkarma/thrift-parser' import { TypeMapping } from './types' import { IRenderState, IResolvedIdentifier } from '../../types' -import { className, codecName, looseName, strictName } from './struct/utils' +import { className, looseName, strictName, toolkitName } from './struct/utils' +import { + fieldInterfaceName, + renderUnionTypeName, + unionTypeName, +} from './union/union-fields' + +function renderStrictInterfaceReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, + state: IRenderState, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier( + strictName(node.name.value, id.definition.type, state), + ), + ts.createIdentifier( + `${id.pathName}.${strictName(id.name, id.definition.type, state)}`, + ), + ) +} + +function renderLooseInterfaceReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, + state: IRenderState, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier( + looseName(node.name.value, id.definition.type, state), + ), + ts.createIdentifier( + `${id.pathName}.${looseName(id.name, id.definition.type, state)}`, + ), + ) +} + +function renderClassReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier(className(node.name.value)), + ts.createIdentifier(`${id.pathName}.${className(id.name)}`), + ) +} + +function renderToolkitReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier(toolkitName(node.name.value)), + ts.createIdentifier(toolkitName(id.resolvedName)), + ) +} + +function renderUnionTypeReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, + state: IRenderState, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier(renderUnionTypeName(node.name.value, true)), + ts.createIdentifier( + `${id.pathName}.${renderUnionTypeName(id.name, true)}`, + ), + ) +} + +function renderUnionInterfaceReexports( + id: IResolvedIdentifier, + union: UnionDefinition, + node: TypedefDefinition, + state: IRenderState, + strict: boolean, +): Array { + return union.fields.map((next: FieldDefinition) => { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier( + fieldInterfaceName(node.name.value, next.name.value, strict), + ), + ts.createIdentifier( + `${id.pathName}.${fieldInterfaceName( + union.name.value, + next.name.value, + strict, + )}`, + ), + ) + }) +} + +function renderUnionArgsReexport( + id: IResolvedIdentifier, + node: TypedefDefinition, +): ts.ImportEqualsDeclaration { + return ts.createImportEqualsDeclaration( + undefined, + [ts.createToken(ts.SyntaxKind.ExportKeyword)], + ts.createIdentifier(unionTypeName(node.name.value, false)), + ts.createIdentifier(`${id.pathName}.${unionTypeName(id.name, false)}`), + ) +} function renderTypeDefForIdentifier( id: IResolvedIdentifier, @@ -15,50 +135,38 @@ function renderTypeDefForIdentifier( state: IRenderState, ): Array { switch (id.definition.type) { - case SyntaxType.ExceptionDefinition: - case SyntaxType.StructDefinition: case SyntaxType.UnionDefinition: - return [ - ts.createImportEqualsDeclaration( - undefined, - [ts.createToken(ts.SyntaxKind.ExportKeyword)], - ts.createIdentifier( - strictName(node.name.value, id.definition.type, state), - ), - ts.createIdentifier( - `${id.pathName}.${strictName( - id.name, - id.definition.type, - state, - )}`, - ), - ), - ts.createImportEqualsDeclaration( - undefined, - [ts.createToken(ts.SyntaxKind.ExportKeyword)], - ts.createIdentifier( - looseName(node.name.value, id.definition.type, state), + if (state.options.strictUnions) { + return [ + renderUnionTypeReexport(id, node, state), + renderClassReexport(id, node), + ...renderUnionInterfaceReexports( + id, + id.definition, + node, + state, + true, ), - ts.createIdentifier( - `${id.pathName}.${looseName( - id.name, - id.definition.type, - state, - )}`, + renderUnionArgsReexport(id, node), + ...renderUnionInterfaceReexports( + id, + id.definition, + node, + state, + false, ), - ), - ts.createImportEqualsDeclaration( - undefined, - [ts.createToken(ts.SyntaxKind.ExportKeyword)], - ts.createIdentifier(className(node.name.value)), - ts.createIdentifier(`${id.pathName}.${className(id.name)}`), - ), - ts.createImportEqualsDeclaration( - undefined, - [ts.createToken(ts.SyntaxKind.ExportKeyword)], - ts.createIdentifier(codecName(node.name.value)), - ts.createIdentifier(codecName(id.resolvedName)), - ), + renderToolkitReexport(id, node), + ] + } else { + // Fallthrough to reexport union as struct + } + case SyntaxType.ExceptionDefinition: + case SyntaxType.StructDefinition: + return [ + renderStrictInterfaceReexport(id, node, state), + renderLooseInterfaceReexport(id, node, state), + renderClassReexport(id, node), + renderToolkitReexport(id, node), ] default: diff --git a/src/main/render/thrift-server/union/class.ts b/src/main/render/thrift-server/union/class.ts index ea6bd033..ba68b586 100644 --- a/src/main/render/thrift-server/union/class.ts +++ b/src/main/render/thrift-server/union/class.ts @@ -10,23 +10,17 @@ import { IRenderState } from '../../../types' import { COMMON_IDENTIFIERS, THRIFT_IDENTIFIERS } from '../identifiers' -import { - createClassConstructor, - createFunctionParameter, - createNotNullCheck, -} from '../utils' +import { createClassConstructor, createFunctionParameter } from '../utils' import { classNameForStruct, createSuperCall, extendsAbstract, implementsInterface, - throwForField, tokens, } from '../struct/utils' import { - assignmentForField as _assignmentForField, createArgsParameterForStruct, createStaticReadMethod, createStaticWriteMethod, @@ -34,10 +28,12 @@ import { renderFieldDeclarations, } from '../struct/class' +import { assignmentForField as _assignmentForField } from '../struct/reader' + import { + createFieldAssignment, createFieldIncrementer, createFieldValidation, - incrementFieldsSet, } from './utils' import { renderAnnotations, renderFieldAnnotations } from '../annotations' @@ -136,56 +132,3 @@ export function createFieldsForStruct( return renderFieldDeclarations(field, state) }) } - -/** - * This actually creates the assignment for some field in the args argument to the corresponding field - * in our struct class - * - * interface IStructArgs { - * id: number; - * } - * - * constructor(args: IStructArgs) { - * if (args.id !== null && args.id !== undefined) { - * this.id = args.id; - * } - * } - * - * This function creates the 'this.id = args.id' bit. - */ -export function assignmentForField( - field: FieldDefinition, - state: IRenderState, -): Array { - return [incrementFieldsSet(), ..._assignmentForField(field, state)] -} - -/** - * Assign field if contained in args: - * - * if (args && args. != null) { - * this. = args. - * } - * - * If field is required throw an error: - * - * else { - * throw new Thrift.TProtocolException(Thrift.TProtocolExceptionType.UNKNOWN, 'Required field {{fieldName}} is unset!') - * } - */ -export function createFieldAssignment( - field: FieldDefinition, - state: IRenderState, -): ts.IfStatement { - const hasValue: ts.BinaryExpression = createNotNullCheck( - ts.createPropertyAccess(COMMON_IDENTIFIERS.args, `${field.name.value}`), - ) - const thenAssign: Array = assignmentForField(field, state) - const elseThrow: ts.Statement | undefined = throwForField(field) - - return ts.createIf( - hasValue, - ts.createBlock([...thenAssign], true), - elseThrow === undefined ? undefined : ts.createBlock([elseThrow], true), - ) -} diff --git a/src/main/render/thrift-server/union/codec.ts b/src/main/render/thrift-server/union/codec.ts deleted file mode 100644 index 6f41e647..00000000 --- a/src/main/render/thrift-server/union/codec.ts +++ /dev/null @@ -1,49 +0,0 @@ -import * as ts from 'typescript' - -import { UnionDefinition } from '@creditkarma/thrift-parser' - -import { createConst } from '../utils' - -import { THRIFT_IDENTIFIERS } from '../identifiers' - -import { createEncodeMethod } from './encode' - -import { createDecodeMethod } from './decode' - -import { IRenderState } from '../../../types' -import { - codecNameForStruct, - looseNameForStruct, - strictNameForStruct, - tokens, -} from '../struct/utils' - -export function renderCodec( - node: UnionDefinition, - state: IRenderState, - isExported: boolean, -): ts.Statement { - return ts.createVariableStatement( - tokens(isExported), - createConst( - ts.createIdentifier(codecNameForStruct(node)), - ts.createTypeReferenceNode(THRIFT_IDENTIFIERS.IStructCodec, [ - ts.createTypeReferenceNode( - ts.createIdentifier(looseNameForStruct(node, state)), - undefined, - ), - ts.createTypeReferenceNode( - ts.createIdentifier(strictNameForStruct(node, state)), - undefined, - ), - ]), - ts.createObjectLiteral( - [ - createEncodeMethod(node, state), - createDecodeMethod(node, state), - ], - true, - ), - ), - ) -} diff --git a/src/main/render/thrift-server/union/create.ts b/src/main/render/thrift-server/union/create.ts new file mode 100644 index 00000000..5121b97d --- /dev/null +++ b/src/main/render/thrift-server/union/create.ts @@ -0,0 +1,281 @@ +import * as ts from 'typescript' + +import { + FieldDefinition, + InterfaceWithFields, + SyntaxType, + UnionDefinition, +} from '@creditkarma/thrift-parser' + +import { COMMON_IDENTIFIERS } from '../identifiers' + +import { thriftTypeForFieldType } from '../types' + +import { createFunctionParameter } from '../utils' + +import { + createEqualsCheck, + getInitializerForField, + hasRequiredField, + throwProtocolException, +} from '../utils' + +import { IRenderState } from '../../../types' + +import { + createCheckForFields, + createSkipBlock, + // readFieldEnd, + // readStructBegin, + // readStructEnd, + readValueForFieldType, +} from '../struct/decode' + +import { strictNameForStruct } from '../struct/utils' +import { fieldTypeAccess, unionTypeName } from './union-fields' +import { + createFieldAssignment, + createFieldIncrementer, + createFieldValidation, + createReturnVariable, + incrementFieldsSet, +} from './utils' + +function createArgsParameter(node: UnionDefinition): ts.ParameterDeclaration { + return createFunctionParameter( + 'args', // param name + ts.createTypeReferenceNode( + ts.createIdentifier(unionTypeName(node.name.value, false)), + undefined, + ), + ) +} + +export function createCreateMethod( + node: UnionDefinition, + state: IRenderState, +): ts.MethodDeclaration { + const inputParameter: ts.ParameterDeclaration = createArgsParameter(node) + const returnVariable: ts.VariableStatement = createReturnVariable( + node, + state, + ) + + const fieldsSet: ts.VariableStatement = createFieldIncrementer() + + const fieldAssignments: Array = node.fields.map( + (next: FieldDefinition) => { + return createFieldAssignment(next, state) + }, + ) + + return ts.createMethod( + undefined, + undefined, + undefined, + COMMON_IDENTIFIERS.create, + undefined, + undefined, + [inputParameter], + ts.createTypeReferenceNode( + ts.createIdentifier(strictNameForStruct(node, state)), + undefined, + ), // return type + ts.createBlock( + [ + fieldsSet, + returnVariable, + ...fieldAssignments, + createFieldValidation(node), + ts.createIf( + ts.createBinary( + COMMON_IDENTIFIERS._returnValue, + ts.SyntaxKind.ExclamationEqualsEqualsToken, + ts.createNull(), + ), + ts.createBlock( + [createReturnForFields(node, node.fields, state)], + true, + ), + ts.createBlock( + [ + throwProtocolException( + 'UNKNOWN', + 'Unable to read data for TUnion', + ), + ], + true, + ), + ), + ], + true, + ), + ) +} + +function createUnionObjectForField( + node: UnionDefinition, + field: FieldDefinition, +): ts.ObjectLiteralExpression { + return ts.createObjectLiteral( + [ + ts.createPropertyAssignment( + COMMON_IDENTIFIERS.__type, + ts.createIdentifier(fieldTypeAccess(node, field)), + ), + ts.createPropertyAssignment( + ts.createIdentifier(field.name.value), + ts.createPropertyAccess( + COMMON_IDENTIFIERS._returnValue, + field.name.value, + ), + ), + ], + true, + ) +} + +function createReturnForFields( + node: UnionDefinition, + fields: Array, + state: IRenderState, +): ts.Statement { + if (state.options.strictUnions) { + const [head, ...tail] = fields + if (tail.length > 0) { + return ts.createIf( + ts.createPropertyAccess( + COMMON_IDENTIFIERS._returnValue, + head.name.value, + ), + ts.createBlock( + [ts.createReturn(createUnionObjectForField(node, head))], + true, + ), + ts.createBlock( + [createReturnForFields(node, tail, state)], + true, + ), + ) + } else { + return ts.createReturn(createUnionObjectForField(node, head)) + } + } else { + return ts.createReturn(COMMON_IDENTIFIERS._returnValue) + } +} + +/** + * EXAMPLE + * + * case 1: { + * if (fieldType === Thrift.Type.I32) { + * this.id = input.readI32(); + * } + * else { + * input.skip(fieldType); + * } + * break; + * } + */ +export function createCaseForField( + node: UnionDefinition, + field: FieldDefinition, + state: IRenderState, +): ts.CaseClause { + const fieldAlias: ts.Identifier = ts.createUniqueName('value') + const checkType: ts.IfStatement = ts.createIf( + createEqualsCheck( + COMMON_IDENTIFIERS.fieldType, + thriftTypeForFieldType(field.fieldType, state.identifiers), + ), + ts.createBlock( + [ + incrementFieldsSet(), + ...readValueForFieldType(field.fieldType, fieldAlias, state), + ...endReadForField(fieldAlias, field), + ], + true, + ), + createSkipBlock(), + ) + + if (field.fieldID !== null) { + return ts.createCaseClause(ts.createLiteral(field.fieldID.value), [ + checkType, + ts.createBreak(), + ]) + } else { + throw new Error(`FieldID on line ${field.loc.start.line} is null`) + } +} + +export function endReadForField( + fieldName: ts.Identifier, + field: FieldDefinition, +): Array { + switch (field.fieldType.type) { + case SyntaxType.VoidKeyword: + return [] + + default: + return [ + ts.createStatement( + ts.createAssignment( + COMMON_IDENTIFIERS._returnValue, + ts.createObjectLiteral([ + ts.createPropertyAssignment( + field.name.value, + fieldName, + ), + ]), + ), + ), + ] + } +} + +export function createReturnForStruct( + struct: InterfaceWithFields, +): ts.Statement { + if (hasRequiredField(struct)) { + return ts.createIf( + createCheckForFields(struct.fields), + ts.createBlock( + [ + ts.createReturn( + ts.createObjectLiteral( + struct.fields.map( + ( + next: FieldDefinition, + ): ts.ObjectLiteralElementLike => { + return ts.createPropertyAssignment( + next.name.value, + getInitializerForField('_args', next), + ) + }, + ), + true, // multiline + ), + ), + ], + true, + ), + ts.createBlock( + [ + throwProtocolException( + 'UNKNOWN', + `Unable to read ${struct.name.value} from input`, + ), + ], + true, + ), + ) + } else { + return ts.createReturn( + ts.createNew(ts.createIdentifier(struct.name.value), undefined, [ + COMMON_IDENTIFIERS._args, + ]), + ) + } +} diff --git a/src/main/render/thrift-server/union/decode.ts b/src/main/render/thrift-server/union/decode.ts index 885b50e5..5c456868 100644 --- a/src/main/render/thrift-server/union/decode.ts +++ b/src/main/render/thrift-server/union/decode.ts @@ -19,7 +19,6 @@ import { createNumberType, thriftTypeForFieldType } from '../types' import { createConstStatement, createEqualsCheck, - createLetStatement, getInitializerForField, hasRequiredField, propertyAccessForIdentifier, @@ -39,39 +38,15 @@ import { readValueForFieldType, } from '../struct/decode' -import { createAnyType } from '../../shared/types' import { strictNameForStruct } from '../struct/utils' +import { fieldTypeAccess } from './union-fields' import { createFieldIncrementer, createFieldValidation, + createReturnVariable, incrementFieldsSet, } from './utils' -function createReturnVariable( - node: UnionDefinition, - state: IRenderState, -): ts.VariableStatement { - if (state.options.strictUnions) { - return createLetStatement( - COMMON_IDENTIFIERS._returnValue, - createAnyType(), - ts.createNull(), - ) - } else { - return createLetStatement( - COMMON_IDENTIFIERS._returnValue, - ts.createUnionTypeNode([ - ts.createTypeReferenceNode( - ts.createIdentifier(strictNameForStruct(node, state)), - undefined, - ), - ts.createNull(), - ]), - ts.createNull(), - ) - } -} - export function createDecodeMethod( node: UnionDefinition, state: IRenderState, @@ -167,7 +142,7 @@ export function createDecodeMethod( ts.createNull(), ), ts.createBlock( - [createReturnForFields(node.fields, state)], + [createReturnForFields(node, node.fields, state)], true, ), ts.createBlock( @@ -187,13 +162,14 @@ export function createDecodeMethod( } function createUnionObjectForField( + node: UnionDefinition, field: FieldDefinition, ): ts.ObjectLiteralExpression { return ts.createObjectLiteral( [ ts.createPropertyAssignment( COMMON_IDENTIFIERS.__type, - ts.createLiteral(field.name.value), + ts.createIdentifier(fieldTypeAccess(node, field)), ), ts.createPropertyAssignment( ts.createIdentifier(field.name.value), @@ -208,6 +184,7 @@ function createUnionObjectForField( } function createReturnForFields( + node: UnionDefinition, fields: Array, state: IRenderState, ): ts.Statement { @@ -220,13 +197,13 @@ function createReturnForFields( head.name.value, ), ts.createBlock( - [ts.createReturn(createUnionObjectForField(head))], + [ts.createReturn(createUnionObjectForField(node, head))], true, ), - ts.createBlock([createReturnForFields(tail, state)]), + ts.createBlock([createReturnForFields(node, tail, state)]), ) } else { - return ts.createReturn(createUnionObjectForField(head)) + return ts.createReturn(createUnionObjectForField(node, head)) } } else { return ts.createReturn(COMMON_IDENTIFIERS._returnValue) diff --git a/src/main/render/thrift-server/union/index.ts b/src/main/render/thrift-server/union/index.ts index ac5cf81f..55fefe44 100644 --- a/src/main/render/thrift-server/union/index.ts +++ b/src/main/render/thrift-server/union/index.ts @@ -6,11 +6,11 @@ import { IRenderState } from '../../../types' import { renderInterface } from '../struct/interface' -import { renderCodec } from './codec' +import { renderToolkit } from './toolkit' import { renderClass } from './class' -import { renderUnionsForFields } from './union-fields' +import { renderUnionsForFields, renderUnionTypes } from './union-fields' export function renderUnion( node: UnionDefinition, @@ -19,7 +19,7 @@ export function renderUnion( ): Array { return [ ...renderInterface(node, state, isExported), - renderCodec(node, state, isExported), + renderToolkit(node, state, isExported), renderClass(node, state, isExported), ] } @@ -30,8 +30,9 @@ export function renderStrictUnion( isExported: boolean = true, ): Array { return [ + renderUnionTypes(node, isExported), ...renderUnionsForFields(node, state, isExported, true), ...renderUnionsForFields(node, state, isExported, false), - renderCodec(node, state, isExported), + renderToolkit(node, state, isExported), ] } diff --git a/src/main/render/thrift-server/union/toolkit.ts b/src/main/render/thrift-server/union/toolkit.ts new file mode 100644 index 00000000..5e18f338 --- /dev/null +++ b/src/main/render/thrift-server/union/toolkit.ts @@ -0,0 +1,78 @@ +import * as ts from 'typescript' + +import { UnionDefinition } from '@creditkarma/thrift-parser' + +import { createConst } from '../utils' + +import { THRIFT_IDENTIFIERS } from '../identifiers' + +import { createEncodeMethod } from './encode' + +import { createDecodeMethod } from './decode' + +import { createCreateMethod } from './create' + +import { IRenderState } from '../../../types' +import { + looseNameForStruct, + strictNameForStruct, + tokens, + toolkitNameForStruct, +} from '../struct/utils' + +function renderMethodsForCodec( + node: UnionDefinition, + state: IRenderState, +): Array { + if (state.options.strictUnions) { + return [ + createCreateMethod(node, state), + createEncodeMethod(node, state), + createDecodeMethod(node, state), + ] + } else { + return [ + createEncodeMethod(node, state), + createDecodeMethod(node, state), + ] + } +} + +function toolkitBaseClass(state: IRenderState): ts.Identifier { + if (state.options.strictUnions) { + return THRIFT_IDENTIFIERS.IStructToolkit + } else { + return THRIFT_IDENTIFIERS.IStructCodec + } +} + +function renderToolkitTypeNode( + node: UnionDefinition, + state: IRenderState, +): ts.TypeNode { + return ts.createTypeReferenceNode(toolkitBaseClass(state), [ + ts.createTypeReferenceNode( + ts.createIdentifier(looseNameForStruct(node, state)), + undefined, + ), + ts.createTypeReferenceNode( + ts.createIdentifier(strictNameForStruct(node, state)), + undefined, + ), + ]) +} + +export function renderToolkit( + node: UnionDefinition, + state: IRenderState, + isExported: boolean, +): ts.Statement { + return ts.createVariableStatement( + tokens(isExported), + createConst( + ts.createIdentifier(toolkitNameForStruct(node)), + renderToolkitTypeNode(node, state), + ts.createObjectLiteral(renderMethodsForCodec(node, state), true), + ), + ) +} diff --git a/src/main/render/thrift-server/union/union-fields.ts b/src/main/render/thrift-server/union/union-fields.ts index b6b027f6..77630824 100644 --- a/src/main/render/thrift-server/union/union-fields.ts +++ b/src/main/render/thrift-server/union/union-fields.ts @@ -8,6 +8,46 @@ import { createVoidType } from '../../shared/types' import { className, tokens } from '../struct/utils' import { typeNodeForFieldType } from '../types' +export function renderUnionTypes( + node: UnionDefinition, + isExported: boolean, +): ts.Statement { + return ts.createEnumDeclaration( + undefined, // decorators + tokens(isExported), // modifiers + renderUnionTypeName(node.name.value, true), // enum name + node.fields.map((field: FieldDefinition) => { + return ts.createEnumMember( + fieldTypeName(node.name.value, field.name.value, true), + ts.createLiteral(field.name.value), + ) + }), + ) +} + +export function fieldTypeAccess( + node: UnionDefinition, + field: FieldDefinition, +): string { + return `${renderUnionTypeName(node.name.value, true)}.${fieldTypeName( + node.name.value, + field.name.value, + true, + )}` +} + +export function unionTypeName(name: string, strict: boolean): string { + if (strict) { + return className(name) + } else { + return `${className(name)}Args` + } +} + +export function renderUnionTypeName(name: string, strict: boolean): string { + return `${unionTypeName(name, strict)}Type` +} + function capitalize(str: string): string { if (str.length > 0) { const head: string = str[0] @@ -18,26 +58,84 @@ function capitalize(str: string): string { } } -function fieldInterfaceName( - node: UnionDefinition, - field: FieldDefinition, +export function fieldTypeName( + nodeName: string, + fieldName: string, strict: boolean, ): string { if (strict) { - return `I${node.name.value}With${capitalize(field.name.value)}` + return `${nodeName}With${capitalize(fieldName)}` } else { - return `I${node.name.value}With${capitalize(field.name.value)}Args` + return `${nodeName}With${capitalize(fieldName)}Args` } } -function unionTypeName(node: UnionDefinition, strict: boolean): string { +export function fieldInterfaceName( + nodeName: string, + fieldName: string, + strict: boolean, +): string { if (strict) { - return className(node.name.value) + return `I${fieldTypeName(nodeName, fieldName, strict)}` } else { - return `${className(node.name.value)}Args` + return `I${fieldTypeName(nodeName, fieldName, strict)}` } } +function renderInterfaceForField( + node: UnionDefinition, + field: FieldDefinition, + state: IRenderState, + strict: boolean, + isExported: boolean, +): ts.InterfaceDeclaration { + const signatures = node.fields.map((next: FieldDefinition) => { + if (field.name.value === next.name.value) { + return ts.createPropertySignature( + undefined, + field.name.value, + undefined, + typeNodeForFieldType(next.fieldType, state, !strict), + undefined, + ) + } else { + return ts.createPropertySignature( + undefined, + next.name.value, + ts.createToken(ts.SyntaxKind.QuestionToken), + createVoidType(), + undefined, + ) + } + }) + + if (strict) { + signatures.unshift( + ts.createPropertySignature( + undefined, + COMMON_IDENTIFIERS.__type, + undefined, + ts.createTypeReferenceNode( + ts.createIdentifier(fieldTypeAccess(node, field)), + undefined, + ), + undefined, + ), + ) + } + + return ts.createInterfaceDeclaration( + undefined, + tokens(isExported), + ts.createIdentifier( + fieldInterfaceName(node.name.value, field.name.value, strict), + ), + [], + [], + signatures, + ) +} + export function renderUnionsForFields( node: UnionDefinition, state: IRenderState, @@ -48,12 +146,16 @@ export function renderUnionsForFields( ts.createTypeAliasDeclaration( undefined, tokens(isExported), - unionTypeName(node, strict), + unionTypeName(node.name.value, strict), undefined, ts.createUnionTypeNode([ ...node.fields.map((next: FieldDefinition) => { return ts.createTypeReferenceNode( - fieldInterfaceName(node, next, strict), + fieldInterfaceName( + node.name.value, + next.name.value, + strict, + ), undefined, ) }), @@ -61,52 +163,7 @@ export function renderUnionsForFields( ), ...node.fields.map( (next: FieldDefinition): ts.InterfaceDeclaration => { - const signatures = node.fields.map((field: FieldDefinition) => { - if (field.name.value === next.name.value) { - return ts.createPropertySignature( - undefined, - field.name.value, - undefined, - typeNodeForFieldType( - field.fieldType, - state, - !strict, - ), - undefined, - ) - } else { - return ts.createPropertySignature( - undefined, - field.name.value, - ts.createToken(ts.SyntaxKind.QuestionToken), - createVoidType(), - undefined, - ) - } - }) - - if (strict) { - signatures.unshift( - ts.createPropertySignature( - undefined, - COMMON_IDENTIFIERS.__type, - undefined, - ts.createLiteralTypeNode( - ts.createLiteral(next.name.value), - ), - undefined, - ), - ) - } - - return ts.createInterfaceDeclaration( - undefined, - tokens(isExported), - ts.createIdentifier(fieldInterfaceName(node, next, strict)), - [], - [], - signatures, - ) + return renderInterfaceForField(node, next, state, strict, true) }, ), ] diff --git a/src/main/render/thrift-server/union/utils.ts b/src/main/render/thrift-server/union/utils.ts index 49226611..ab903286 100644 --- a/src/main/render/thrift-server/union/utils.ts +++ b/src/main/render/thrift-server/union/utils.ts @@ -1,11 +1,40 @@ import * as ts from 'typescript' -import { UnionDefinition } from '@creditkarma/thrift-parser' +import { FieldDefinition, UnionDefinition } from '@creditkarma/thrift-parser' import { createLetStatement, throwProtocolException } from '../utils' +import { IRenderState } from '../../../types' import { COMMON_IDENTIFIERS } from '../../shared/identifiers' -import { createNumberType } from '../types' +import { assignmentForField as _assignmentForField } from '../struct/reader' +import { strictNameForStruct } from '../struct/utils' +import { createAnyType, createNumberType } from '../types' +import { createNotNullCheck } from '../utils' + +export function createReturnVariable( + node: UnionDefinition, + state: IRenderState, +): ts.VariableStatement { + if (state.options.strictUnions) { + return createLetStatement( + COMMON_IDENTIFIERS._returnValue, + createAnyType(), + ts.createNull(), + ) + } else { + return createLetStatement( + COMMON_IDENTIFIERS._returnValue, + ts.createUnionTypeNode([ + ts.createTypeReferenceNode( + ts.createIdentifier(strictNameForStruct(node, state)), + undefined, + ), + ts.createNull(), + ]), + ts.createNull(), + ) + } +} // let _fieldsSet: number = 0; export function createFieldIncrementer(): ts.VariableStatement { @@ -65,3 +94,73 @@ export function createFieldValidation(node: UnionDefinition): ts.IfStatement { ), ) } + +function returnAssignment( + valueName: ts.Identifier, + field: FieldDefinition, +): ts.Statement { + return ts.createStatement( + ts.createAssignment( + COMMON_IDENTIFIERS._returnValue, + ts.createObjectLiteral([ + ts.createPropertyAssignment(field.name.value, valueName), + ]), + ), + ) +} + +/** + * This actually creates the assignment for some field in the args argument to the corresponding field + * in our struct class + * + * interface IStructArgs { + * id: number; + * } + * + * constructor(args: IStructArgs) { + * if (args.id !== null && args.id !== undefined) { + * this.id = args.id; + * } + * } + * + * This function creates the 'this.id = args.id' bit. + */ +export function assignmentForField( + field: FieldDefinition, + state: IRenderState, +): Array { + if (state.options.strictUnions) { + return [ + incrementFieldsSet(), + ..._assignmentForField(field, state, returnAssignment), + ] + } else { + return [incrementFieldsSet(), ..._assignmentForField(field, state)] + } +} + +/** + * Assign field if contained in args: + * + * if (args && args. != null) { + * this. = args. + * } + * + * If field is required throw an error: + * + * else { + * throw new Thrift.TProtocolException(Thrift.TProtocolExceptionType.UNKNOWN, 'Required field {{fieldName}} is unset!') + * } + */ +export function createFieldAssignment( + field: FieldDefinition, + state: IRenderState, +): ts.IfStatement { + const hasValue: ts.BinaryExpression = createNotNullCheck( + ts.createPropertyAccess(COMMON_IDENTIFIERS.args, `${field.name.value}`), + ) + + const thenAssign: Array = assignmentForField(field, state) + + return ts.createIf(hasValue, ts.createBlock([...thenAssign], true)) +} diff --git a/src/tests/unit/fixtures/generated/strict-unions/calculator/index.ts b/src/tests/unit/fixtures/generated/strict-unions/calculator/index.ts index 1bae9dd2..70fa3541 100644 --- a/src/tests/unit/fixtures/generated/strict-unions/calculator/index.ts +++ b/src/tests/unit/fixtures/generated/strict-unions/calculator/index.ts @@ -323,16 +323,69 @@ export class LastName extends thrift.StructLike implements ILastName { return LastNameCodec.encode(this, output); } } -export interface IChoice { - firstName?: IFirstName; - lastName?: ILastName; +export enum ChoiceType { + ChoiceWithFirstName = "firstName", + ChoiceWithLastName = "lastName" } -export interface IChoiceArgs { - firstName?: IFirstNameArgs; - lastName?: ILastNameArgs; +export type Choice = IChoiceWithFirstName | IChoiceWithLastName; +export interface IChoiceWithFirstName { + __type: ChoiceType.ChoiceWithFirstName; + firstName: IFirstName; + lastName?: void; } -export const ChoiceCodec: thrift.IStructCodec = { - encode(args: IChoiceArgs, output: thrift.TProtocol): void { +export interface IChoiceWithLastName { + __type: ChoiceType.ChoiceWithLastName; + firstName?: void; + lastName: ILastName; +} +export type ChoiceArgs = IChoiceWithFirstNameArgs | IChoiceWithLastNameArgs; +export interface IChoiceWithFirstNameArgs { + firstName: IFirstNameArgs; + lastName?: void; +} +export interface IChoiceWithLastNameArgs { + firstName?: void; + lastName: ILastNameArgs; +} +export const ChoiceCodec: thrift.IStructToolkit = { + create(args: ChoiceArgs): Choice { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.firstName != null) { + _fieldsSet++; + const value_13: IFirstName = new FirstName(args.firstName); + _returnValue = { firstName: value_13 }; + } + if (args.lastName != null) { + _fieldsSet++; + const value_14: ILastName = new LastName(args.lastName); + _returnValue = { lastName: value_14 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.firstName) { + return { + __type: ChoiceType.ChoiceWithFirstName, + firstName: _returnValue.firstName + }; + } + else { + return { + __type: ChoiceType.ChoiceWithLastName, + lastName: _returnValue.lastName + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, + encode(args: ChoiceArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { firstName: args.firstName, @@ -361,9 +414,9 @@ export const ChoiceCodec: thrift.IStructCodec = { } return; }, - decode(input: thrift.TProtocol): IChoice { + decode(input: thrift.TProtocol): Choice { let _fieldsSet: number = 0; - let _returnValue: IChoice | null = null; + let _returnValue: any = null; input.readStructBegin(); while (true) { const ret: thrift.IThriftField = input.readFieldBegin(); @@ -376,8 +429,8 @@ export const ChoiceCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.STRUCT) { _fieldsSet++; - const value_13: IFirstName = FirstNameCodec.decode(input); - _returnValue = { firstName: value_13 }; + const value_15: IFirstName = FirstNameCodec.decode(input); + _returnValue = { firstName: value_15 }; } else { input.skip(fieldType); @@ -386,8 +439,8 @@ export const ChoiceCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.STRUCT) { _fieldsSet++; - const value_14: ILastName = LastNameCodec.decode(input); - _returnValue = { lastName: value_14 }; + const value_16: ILastName = LastNameCodec.decode(input); + _returnValue = { lastName: value_16 }; } else { input.skip(fieldType); @@ -407,48 +460,24 @@ export const ChoiceCodec: thrift.IStructCodec = { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); } if (_returnValue !== null) { - return _returnValue; + if (_returnValue.firstName) { + return { + __type: ChoiceType.ChoiceWithFirstName, + firstName: _returnValue.firstName + }; + } + else { + return { + __type: ChoiceType.ChoiceWithLastName, + lastName: _returnValue.lastName + }; + } } else { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); } } }; -export class Choice extends thrift.StructLike implements IChoice { - public firstName?: IFirstName; - public lastName?: ILastName; - public readonly _annotations: thrift.IThriftAnnotations = {}; - public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; - constructor(args: IChoiceArgs = {}) { - super(); - let _fieldsSet: number = 0; - if (args.firstName != null) { - _fieldsSet++; - const value_15: IFirstName = new FirstName(args.firstName); - this.firstName = value_15; - } - if (args.lastName != null) { - _fieldsSet++; - const value_16: ILastName = new LastName(args.lastName); - this.lastName = value_16; - } - if (_fieldsSet > 1) { - throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); - } - else if (_fieldsSet < 1) { - throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); - } - } - public static read(input: thrift.TProtocol): Choice { - return new Choice(ChoiceCodec.decode(input)); - } - public static write(args: IChoiceArgs, output: thrift.TProtocol): void { - return ChoiceCodec.encode(args, output); - } - public write(output: thrift.TProtocol): void { - return ChoiceCodec.encode(this, output); - } -} export namespace Calculator { export const serviceName: string = "Calculator"; export const annotations: thrift.IThriftAnnotations = {}; @@ -1193,10 +1222,10 @@ export namespace Calculator { } } export interface ICheckName__Args { - choice: IChoice; + choice: Choice; } export interface ICheckName__ArgsArgs { - choice: IChoiceArgs; + choice: ChoiceArgs; } export const CheckName__ArgsCodec: thrift.IStructCodec = { encode(args: ICheckName__ArgsArgs, output: thrift.TProtocol): void { @@ -1229,7 +1258,7 @@ export namespace Calculator { switch (fieldId) { case 1: if (fieldType === thrift.TType.STRUCT) { - const value_37: IChoice = ChoiceCodec.decode(input); + const value_37: Choice = ChoiceCodec.decode(input); _args.choice = value_37; } else { @@ -1254,13 +1283,13 @@ export namespace Calculator { } }; export class CheckName__Args extends thrift.StructLike implements ICheckName__Args { - public choice: IChoice; + public choice: Choice; public readonly _annotations: thrift.IThriftAnnotations = {}; public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; constructor(args: ICheckName__ArgsArgs) { super(); if (args.choice != null) { - const value_38: IChoice = new Choice(args.choice); + const value_38: Choice = ChoiceCodec.create(args.choice); this.choice = value_38; } else { @@ -3154,7 +3183,7 @@ export namespace Calculator { } }); } - public checkName(choice: IChoiceArgs, context?: Context): Promise { + public checkName(choice: ChoiceArgs, context?: Context): Promise { const writer: thrift.TTransport = new this.transport(); const output: thrift.TProtocol = new this.protocol(writer); output.writeMessageBegin("checkName", thrift.MessageType.CALL, this.incrementRequestId()); @@ -3424,7 +3453,7 @@ export namespace Calculator { calculate(logid: number, work: IWork, context?: Context): number | Promise; echoBinary(word: Buffer, context?: Context): string | Promise; echoString(word: string, context?: Context): string | Promise; - checkName(choice: IChoice, context?: Context): string | Promise; + checkName(choice: Choice, context?: Context): string | Promise; checkOptional(type?: string, context?: Context): string | Promise; mapOneList(arg: Array, context?: Context): Array | Promise>; mapValues(arg: Map, context?: Context): Array | Promise>; diff --git a/src/tests/unit/fixtures/generated/strict-unions/common/index.ts b/src/tests/unit/fixtures/generated/strict-unions/common/index.ts index f3a91842..4ba8bc9f 100644 --- a/src/tests/unit/fixtures/generated/strict-unions/common/index.ts +++ b/src/tests/unit/fixtures/generated/strict-unions/common/index.ts @@ -9,9 +9,13 @@ export import ICommonStruct = shared.ISharedStruct; export import ICommonStructArgs = shared.ISharedStructArgs; export import CommonStruct = shared.SharedStruct; export import CommonStructCodec = shared.SharedStructCodec; -export import ICommonUnion = shared.ISharedUnion; -export import ICommonUnionArgs = shared.ISharedUnionArgs; +export import CommonUnionType = shared.SharedUnionType; export import CommonUnion = shared.SharedUnion; +export import ICommonUnionWithOption1 = shared.ISharedUnionWithOption1; +export import ICommonUnionWithOption2 = shared.ISharedUnionWithOption2; +export import CommonUnionArgs = shared.SharedUnionArgs; +export import ICommonUnionWithOption1Args = shared.ISharedUnionWithOption1Args; +export import ICommonUnionWithOption2Args = shared.ISharedUnionWithOption2Args; export import CommonUnionCodec = shared.SharedUnionCodec; export import COMMON_INT = shared.SHARED_INT; export interface IAuthException { diff --git a/src/tests/unit/fixtures/generated/strict-unions/shared/index.ts b/src/tests/unit/fixtures/generated/strict-unions/shared/index.ts index f7f9d546..b871b7c1 100644 --- a/src/tests/unit/fixtures/generated/strict-unions/shared/index.ts +++ b/src/tests/unit/fixtures/generated/strict-unions/shared/index.ts @@ -193,37 +193,69 @@ export class SharedStruct extends thrift.StructLike implements ISharedStruct { return SharedStructCodec.encode(this, output); } } -export type SharedUnion = ISharedUnionWithOption1 | ISharedUnionWithOption2; export enum SharedUnionType { SharedUnionWithOption1 = "option1", SharedUnionWithOption2 = "option2" } +export type SharedUnion = ISharedUnionWithOption1 | ISharedUnionWithOption2; export interface ISharedUnionWithOption1 { __type: SharedUnionType.SharedUnionWithOption1; option1: string; option2?: void; } -export class SharedUnionWithOption1 implements ISharedUnionWithOption1 { - public readonly __type: SharedUnionType.SharedUnionWithOption1; - public option1: string; - constructor(args: ISharedUnionWithOption1) { - - } -} export interface ISharedUnionWithOption2 { - __type: "option2"; + __type: SharedUnionType.SharedUnionWithOption2; option1?: void; option2: string; } -export class SharedUnionWithOption2 implements ISharedUnionWithOption2 { - public readonly __type: SharedUnionType.SharedUnionWithOption2; - public option2: string; - constructor(args: ISharedUnionWithOption2) { - - } +export type SharedUnionArgs = ISharedUnionWithOption1Args | ISharedUnionWithOption2Args; +export interface ISharedUnionWithOption1Args { + option1: string; + option2?: void; } -export const SharedUnionCodec: thrift.IStructCodec = { - encode(args: SharedUnion, output: thrift.TProtocol): void { +export interface ISharedUnionWithOption2Args { + option1?: void; + option2: string; +} +export const SharedUnionCodec: thrift.IStructToolkit = { + create(args: SharedUnionArgs): SharedUnion { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.option1 != null) { + _fieldsSet++; + const value_7: string = args.option1; + _returnValue = { option1: value_7 }; + } + if (args.option2 != null) { + _fieldsSet++; + const value_8: string = args.option2; + _returnValue = { option2: value_8 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.option1) { + return { + __type: SharedUnionType.SharedUnionWithOption1, + option1: _returnValue.option1 + }; + } + else { + return { + __type: SharedUnionType.SharedUnionWithOption2, + option2: _returnValue.option2 + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, + encode(args: SharedUnionArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { option1: args.option1, @@ -252,9 +284,9 @@ export const SharedUnionCodec: thrift.IStructCodec = { } return; }, - decode(input: thrift.TProtocol): ISharedUnion { + decode(input: thrift.TProtocol): SharedUnion { let _fieldsSet: number = 0; - let _returnValue: ISharedUnion | null = null; + let _returnValue: any = null; input.readStructBegin(); while (true) { const ret: thrift.IThriftField = input.readFieldBegin(); @@ -267,8 +299,8 @@ export const SharedUnionCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.STRING) { _fieldsSet++; - const value_7: string = input.readString(); - _returnValue = { option1: value_7 }; + const value_9: string = input.readString(); + _returnValue = { option1: value_9 }; } else { input.skip(fieldType); @@ -277,8 +309,8 @@ export const SharedUnionCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.STRING) { _fieldsSet++; - const value_8: string = input.readString(); - _returnValue = { option2: value_8 }; + const value_10: string = input.readString(); + _returnValue = { option2: value_10 }; } else { input.skip(fieldType); @@ -298,7 +330,18 @@ export const SharedUnionCodec: thrift.IStructCodec = { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); } if (_returnValue !== null) { - return _returnValue; + if (_returnValue.option1) { + return { + __type: SharedUnionType.SharedUnionWithOption1, + option1: _returnValue.option1 + }; + } + else { + return { + __type: SharedUnionType.SharedUnionWithOption2, + option2: _returnValue.option2 + }; + } } else { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); @@ -564,10 +607,10 @@ export namespace SharedService { } } export interface IGetUnion__Result { - success?: ISharedUnion; + success?: SharedUnion; } export interface IGetUnion__ResultArgs { - success?: ISharedUnionArgs; + success?: SharedUnionArgs; } export const GetUnion__ResultCodec: thrift.IStructCodec = { encode(args: IGetUnion__ResultArgs, output: thrift.TProtocol): void { @@ -597,7 +640,7 @@ export namespace SharedService { switch (fieldId) { case 0: if (fieldType === thrift.TType.STRUCT) { - const value_17: ISharedUnion = SharedUnionCodec.decode(input); + const value_17: SharedUnion = SharedUnionCodec.decode(input); _args.success = value_17; } else { @@ -617,13 +660,13 @@ export namespace SharedService { } }; export class GetUnion__Result extends thrift.StructLike implements IGetUnion__Result { - public success?: ISharedUnion; + public success?: SharedUnion; public readonly _annotations: thrift.IThriftAnnotations = {}; public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; constructor(args: IGetUnion__ResultArgs = {}) { super(); if (args.success != null) { - const value_18: ISharedUnion = new SharedUnion(args.success); + const value_18: SharedUnion = SharedUnionCodec.create(args.success); this.success = value_18; } } @@ -684,7 +727,7 @@ export namespace SharedService { } }); } - public getUnion(index: number, context?: Context): Promise { + public getUnion(index: number, context?: Context): Promise { const writer: thrift.TTransport = new this.transport(); const output: thrift.TProtocol = new this.protocol(writer); output.writeMessageBegin("getUnion", thrift.MessageType.CALL, this.incrementRequestId()); @@ -725,7 +768,7 @@ export namespace SharedService { } export interface IHandler { getStruct(key: number, context?: Context): ISharedStruct | Promise; - getUnion(index: number, context?: Context): ISharedUnion | Promise; + getUnion(index: number, context?: Context): SharedUnion | Promise; } export class Processor extends thrift.ThriftProcessor> { protected readonly _handler: IHandler; @@ -795,7 +838,7 @@ export namespace SharedService { }); } public process_getUnion(requestId: number, input: thrift.TProtocol, output: thrift.TProtocol, context: Context): Promise { - return new Promise((resolve, reject): void => { + return new Promise((resolve, reject): void => { try { const args: IGetUnion__Args = GetUnion__ArgsCodec.decode(input); input.readMessageEnd(); @@ -804,7 +847,7 @@ export namespace SharedService { catch (err) { reject(err); } - }).then((data: ISharedUnion): Buffer => { + }).then((data: SharedUnion): Buffer => { const result: IGetUnion__Result = { success: data }; output.writeMessageBegin("getUnion", thrift.MessageType.REPLY, requestId); GetUnion__ResultCodec.encode(result, output); diff --git a/src/tests/unit/fixtures/thrift-server/basic_service.strict_union.solution.ts b/src/tests/unit/fixtures/thrift-server/basic_service.strict_union.solution.ts index b012300b..61aea314 100644 --- a/src/tests/unit/fixtures/thrift-server/basic_service.strict_union.solution.ts +++ b/src/tests/unit/fixtures/thrift-server/basic_service.strict_union.solution.ts @@ -1,14 +1,66 @@ +export enum MyUnionType { + MyUnionWithField1 = "field1", + MyUnionWithField2 = "field2" +} export type MyUnion = IMyUnionWithField1 | IMyUnionWithField2; export interface IMyUnionWithField1 { + __type: MyUnionType.MyUnionWithField1; field1: number; field2?: void; } export interface IMyUnionWithField2 { + __type: MyUnionType.MyUnionWithField2; field1?: void; field2: thrift.Int64; } -export const MyUnionCodec: thrift.IStructCodec = { - encode(args: MyUnion, output: thrift.TProtocol): void { +export type MyUnionArgs = IMyUnionWithField1Args | IMyUnionWithField2Args; +export interface IMyUnionWithField1Args { + field1: number; + field2?: void; +} +export interface IMyUnionWithField2Args { + field1?: void; + field2: number | thrift.Int64; +} +export const MyUnionCodec: thrift.IStructToolkit = { + create(args: MyUnionArgs): MyUnion { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.field1 != null) { + _fieldsSet++; + const value_1: number = args.field1; + _returnValue = { field1: value_1 }; + } + if (args.field2 != null) { + _fieldsSet++; + const value_2: thrift.Int64 = (typeof args.field2 === "number" ? new thrift.Int64(args.field2) : args.field2); + _returnValue = { field2: value_2 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.field1) { + return { + __type: MyUnionType.MyUnionWithField1, + field1: _returnValue.field1 + }; + } + else { + return { + __type: MyUnionType.MyUnionWithField2, + field2: _returnValue.field2 + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, + encode(args: MyUnionArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { field1: args.field1, @@ -39,7 +91,7 @@ export const MyUnionCodec: thrift.IStructCodec = { }, decode(input: thrift.TProtocol): MyUnion { let _fieldsSet: number = 0; - let _returnValue: MyUnion | null = null; + let _returnValue: any = null; input.readStructBegin(); while (true) { const ret: thrift.IThriftField = input.readFieldBegin(); @@ -52,8 +104,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.I32) { _fieldsSet++; - const value_1: number = input.readI32(); - _returnValue = { field1: value_1 }; + const value_3: number = input.readI32(); + _returnValue = { field1: value_3 }; } else { input.skip(fieldType); @@ -62,8 +114,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.I64) { _fieldsSet++; - const value_2: thrift.Int64 = input.readI64(); - _returnValue = { field2: value_2 }; + const value_4: thrift.Int64 = input.readI64(); + _returnValue = { field2: value_4 }; } else { input.skip(fieldType); @@ -83,10 +135,483 @@ export const MyUnionCodec: thrift.IStructCodec = { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); } if (_returnValue !== null) { - return _returnValue; + if (_returnValue.field1) { + return { + __type: MyUnionType.MyUnionWithField1, + field1: _returnValue.field1 + }; + } + else { + return { + __type: MyUnionType.MyUnionWithField2, + field2: _returnValue.field2 + }; + } } else { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); } } }; +export namespace MyService { + export const serviceName: string = "MyService"; + export const annotations: thrift.IThriftAnnotations = {}; + export const methodAnnotations: thrift.IMethodAnnotations = { + getUser: { + annotations: {}, + fieldAnnotations: {} + }, + ping: { + annotations: {}, + fieldAnnotations: {} + } + }; + export const methodNames: Array = ["getUser", "ping"]; + export interface IGetUser__Args { + arg1: MyUnion; + } + export interface IGetUser__ArgsArgs { + arg1: MyUnionArgs; + } + export const GetUser__ArgsCodec: thrift.IStructCodec = { + encode(args: IGetUser__ArgsArgs, output: thrift.TProtocol): void { + const obj = { + arg1: args.arg1 + }; + output.writeStructBegin("GetUser__Args"); + if (obj.arg1 != null) { + output.writeFieldBegin("arg1", thrift.TType.STRUCT, 1); + MyUnionCodec.encode(obj.arg1, output); + output.writeFieldEnd(); + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Required field[arg1] is unset!"); + } + output.writeFieldStop(); + output.writeStructEnd(); + return; + }, + decode(input: thrift.TProtocol): IGetUser__Args { + let _args: any = {}; + input.readStructBegin(); + while (true) { + const ret: thrift.IThriftField = input.readFieldBegin(); + const fieldType: thrift.TType = ret.fieldType; + const fieldId: number = ret.fieldId; + if (fieldType === thrift.TType.STOP) { + break; + } + switch (fieldId) { + case 1: + if (fieldType === thrift.TType.STRUCT) { + const value_5: MyUnion = MyUnionCodec.decode(input); + _args.arg1 = value_5; + } + else { + input.skip(fieldType); + } + break; + default: { + input.skip(fieldType); + } + } + input.readFieldEnd(); + } + input.readStructEnd(); + if (_args.arg1 !== undefined) { + return { + arg1: _args.arg1 + }; + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read GetUser__Args from input"); + } + } + }; + export class GetUser__Args extends thrift.StructLike implements IGetUser__Args { + public arg1: MyUnion; + public readonly _annotations: thrift.IThriftAnnotations = {}; + public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; + constructor(args: IGetUser__ArgsArgs) { + super(); + if (args.arg1 != null) { + const value_6: MyUnion = MyUnionCodec.create(args.arg1); + this.arg1 = value_6; + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Required field[arg1] is unset!"); + } + } + public static read(input: thrift.TProtocol): GetUser__Args { + return new GetUser__Args(GetUser__ArgsCodec.decode(input)); + } + public static write(args: IGetUser__ArgsArgs, output: thrift.TProtocol): void { + return GetUser__ArgsCodec.encode(args, output); + } + public write(output: thrift.TProtocol): void { + return GetUser__ArgsCodec.encode(this, output); + } + } + export interface IPing__Args { + } + export interface IPing__ArgsArgs { + } + export const Ping__ArgsCodec: thrift.IStructCodec = { + encode(args: IPing__ArgsArgs, output: thrift.TProtocol): void { + output.writeStructBegin("Ping__Args"); + output.writeFieldStop(); + output.writeStructEnd(); + return; + }, + decode(input: thrift.TProtocol): IPing__Args { + input.readStructBegin(); + while (true) { + const ret: thrift.IThriftField = input.readFieldBegin(); + const fieldType: thrift.TType = ret.fieldType; + const fieldId: number = ret.fieldId; + if (fieldType === thrift.TType.STOP) { + break; + } + switch (fieldId) { + default: { + input.skip(fieldType); + } + } + input.readFieldEnd(); + } + input.readStructEnd(); + return {}; + } + }; + export class Ping__Args extends thrift.StructLike implements IPing__Args { + public readonly _annotations: thrift.IThriftAnnotations = {}; + public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; + constructor(args: IPing__ArgsArgs = {}) { + super(); + } + public static read(input: thrift.TProtocol): Ping__Args { + return new Ping__Args(Ping__ArgsCodec.decode(input)); + } + public static write(args: IPing__ArgsArgs, output: thrift.TProtocol): void { + return Ping__ArgsCodec.encode(args, output); + } + public write(output: thrift.TProtocol): void { + return Ping__ArgsCodec.encode(this, output); + } + } + export interface IGetUser__Result { + success?: string; + } + export interface IGetUser__ResultArgs { + success?: string; + } + export const GetUser__ResultCodec: thrift.IStructCodec = { + encode(args: IGetUser__ResultArgs, output: thrift.TProtocol): void { + const obj = { + success: args.success + }; + output.writeStructBegin("GetUser__Result"); + if (obj.success != null) { + output.writeFieldBegin("success", thrift.TType.STRING, 0); + output.writeString(obj.success); + output.writeFieldEnd(); + } + output.writeFieldStop(); + output.writeStructEnd(); + return; + }, + decode(input: thrift.TProtocol): IGetUser__Result { + let _args: any = {}; + input.readStructBegin(); + while (true) { + const ret: thrift.IThriftField = input.readFieldBegin(); + const fieldType: thrift.TType = ret.fieldType; + const fieldId: number = ret.fieldId; + if (fieldType === thrift.TType.STOP) { + break; + } + switch (fieldId) { + case 0: + if (fieldType === thrift.TType.STRING) { + const value_7: string = input.readString(); + _args.success = value_7; + } + else { + input.skip(fieldType); + } + break; + default: { + input.skip(fieldType); + } + } + input.readFieldEnd(); + } + input.readStructEnd(); + return { + success: _args.success + }; + } + }; + export class GetUser__Result extends thrift.StructLike implements IGetUser__Result { + public success?: string; + public readonly _annotations: thrift.IThriftAnnotations = {}; + public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; + constructor(args: IGetUser__ResultArgs = {}) { + super(); + if (args.success != null) { + const value_8: string = args.success; + this.success = value_8; + } + } + public static read(input: thrift.TProtocol): GetUser__Result { + return new GetUser__Result(GetUser__ResultCodec.decode(input)); + } + public static write(args: IGetUser__ResultArgs, output: thrift.TProtocol): void { + return GetUser__ResultCodec.encode(args, output); + } + public write(output: thrift.TProtocol): void { + return GetUser__ResultCodec.encode(this, output); + } + } + export interface IPing__Result { + success?: void; + } + export interface IPing__ResultArgs { + success?: void; + } + export const Ping__ResultCodec: thrift.IStructCodec = { + encode(args: IPing__ResultArgs, output: thrift.TProtocol): void { + output.writeStructBegin("Ping__Result"); + output.writeFieldStop(); + output.writeStructEnd(); + return; + }, + decode(input: thrift.TProtocol): IPing__Result { + let _args: any = {}; + input.readStructBegin(); + while (true) { + const ret: thrift.IThriftField = input.readFieldBegin(); + const fieldType: thrift.TType = ret.fieldType; + const fieldId: number = ret.fieldId; + if (fieldType === thrift.TType.STOP) { + break; + } + switch (fieldId) { + case 0: + if (fieldType === thrift.TType.VOID) { + input.skip(fieldType); + } + else { + input.skip(fieldType); + } + break; + default: { + input.skip(fieldType); + } + } + input.readFieldEnd(); + } + input.readStructEnd(); + return { + success: _args.success + }; + } + }; + export class Ping__Result extends thrift.StructLike implements IPing__Result { + public success?: void; + public readonly _annotations: thrift.IThriftAnnotations = {}; + public readonly _fieldAnnotations: thrift.IFieldAnnotations = {}; + constructor(args: IPing__ResultArgs = {}) { + super(); + if (args.success != null) { + const value_9: void = undefined; + this.success = value_9; + } + } + public static read(input: thrift.TProtocol): Ping__Result { + return new Ping__Result(Ping__ResultCodec.decode(input)); + } + public static write(args: IPing__ResultArgs, output: thrift.TProtocol): void { + return Ping__ResultCodec.encode(args, output); + } + public write(output: thrift.TProtocol): void { + return Ping__ResultCodec.encode(this, output); + } + } + export class Client extends thrift.ThriftClient { + public static readonly serviceName: string = serviceName; + public static readonly annotations: thrift.IThriftAnnotations = annotations; + public static readonly methodAnnotations: thrift.IMethodAnnotations = methodAnnotations; + public static readonly methodNames: Array = methodNames; + public readonly _serviceName: string = serviceName; + public readonly _annotations: thrift.IThriftAnnotations = annotations; + public readonly _methodAnnotations: thrift.IMethodAnnotations = methodAnnotations; + public readonly _methodNames: Array = methodNames; + public getUser(arg1: MyUnionArgs, context?: Context): Promise { + const writer: thrift.TTransport = new this.transport(); + const output: thrift.TProtocol = new this.protocol(writer); + output.writeMessageBegin("getUser", thrift.MessageType.CALL, this.incrementRequestId()); + const args: IGetUser__ArgsArgs = { arg1 }; + GetUser__ArgsCodec.encode(args, output); + output.writeMessageEnd(); + return this.connection.send(writer.flush(), context).then((data: Buffer) => { + const reader: thrift.TTransport = this.transport.receiver(data); + const input: thrift.TProtocol = new this.protocol(reader); + try { + const { fieldName: fieldName, messageType: messageType }: thrift.IThriftMessage = input.readMessageBegin(); + if (fieldName === "getUser") { + if (messageType === thrift.MessageType.EXCEPTION) { + const err: thrift.TApplicationException = thrift.TApplicationExceptionCodec.decode(input); + input.readMessageEnd(); + return Promise.reject(err); + } + else { + const result: IGetUser__Result = GetUser__ResultCodec.decode(input); + input.readMessageEnd(); + if (result.success != null) { + return Promise.resolve(result.success); + } + else { + return Promise.reject(new thrift.TApplicationException(thrift.TApplicationExceptionType.UNKNOWN, "getUser failed: unknown result")); + } + } + } + else { + return Promise.reject(new thrift.TApplicationException(thrift.TApplicationExceptionType.WRONG_METHOD_NAME, "Received a response to an unknown RPC function: " + fieldName)); + } + } + catch (err) { + return Promise.reject(err); + } + }); + } + public ping(context?: Context): Promise { + const writer: thrift.TTransport = new this.transport(); + const output: thrift.TProtocol = new this.protocol(writer); + output.writeMessageBegin("ping", thrift.MessageType.CALL, this.incrementRequestId()); + const args: IPing__ArgsArgs = {}; + Ping__ArgsCodec.encode(args, output); + output.writeMessageEnd(); + return this.connection.send(writer.flush(), context).then((data: Buffer) => { + const reader: thrift.TTransport = this.transport.receiver(data); + const input: thrift.TProtocol = new this.protocol(reader); + try { + const { fieldName: fieldName, messageType: messageType }: thrift.IThriftMessage = input.readMessageBegin(); + if (fieldName === "ping") { + if (messageType === thrift.MessageType.EXCEPTION) { + const err: thrift.TApplicationException = thrift.TApplicationExceptionCodec.decode(input); + input.readMessageEnd(); + return Promise.reject(err); + } + else { + const result: IPing__Result = Ping__ResultCodec.decode(input); + input.readMessageEnd(); + return Promise.resolve(result.success); + } + } + else { + return Promise.reject(new thrift.TApplicationException(thrift.TApplicationExceptionType.WRONG_METHOD_NAME, "Received a response to an unknown RPC function: " + fieldName)); + } + } + catch (err) { + return Promise.reject(err); + } + }); + } + } + export interface IHandler { + getUser(arg1: MyUnion, context?: Context): string | Promise; + ping(context?: Context): void | Promise; + } + export class Processor extends thrift.ThriftProcessor> { + protected readonly _handler: IHandler; + public static readonly serviceName: string = serviceName; + public static readonly annotations: thrift.IThriftAnnotations = annotations; + public static readonly methodAnnotations: thrift.IMethodAnnotations = methodAnnotations; + public static readonly methodNames: Array = methodNames; + public readonly _serviceName: string = serviceName; + public readonly _annotations: thrift.IThriftAnnotations = annotations; + public readonly _methodAnnotations: thrift.IMethodAnnotations = methodAnnotations; + public readonly _methodNames: Array = methodNames; + constructor(handler: IHandler) { + super(); + this._handler = handler; + } + public process(input: thrift.TProtocol, output: thrift.TProtocol, context: Context): Promise { + return new Promise((resolve, reject): void => { + const metadata: thrift.IThriftMessage = input.readMessageBegin(); + const fieldName: string = metadata.fieldName; + const requestId: number = metadata.requestId; + const methodName: string = "process_" + fieldName; + switch (methodName) { + case "process_getUser": { + resolve(this.process_getUser(requestId, input, output, context)); + break; + } + case "process_ping": { + resolve(this.process_ping(requestId, input, output, context)); + break; + } + default: { + input.skip(thrift.TType.STRUCT); + input.readMessageEnd(); + const errMessage = "Unknown function " + fieldName; + const err = new thrift.TApplicationException(thrift.TApplicationExceptionType.UNKNOWN_METHOD, errMessage); + output.writeMessageBegin(fieldName, thrift.MessageType.EXCEPTION, requestId); + thrift.TApplicationExceptionCodec.encode(err, output); + output.writeMessageEnd(); + resolve(output.flush()); + break; + } + } + }); + } + public process_getUser(requestId: number, input: thrift.TProtocol, output: thrift.TProtocol, context: Context): Promise { + return new Promise((resolve, reject): void => { + try { + const args: IGetUser__Args = GetUser__ArgsCodec.decode(input); + input.readMessageEnd(); + resolve(this._handler.getUser(args.arg1, context)); + } + catch (err) { + reject(err); + } + }).then((data: string): Buffer => { + const result: IGetUser__Result = { success: data }; + output.writeMessageBegin("getUser", thrift.MessageType.REPLY, requestId); + GetUser__ResultCodec.encode(result, output); + output.writeMessageEnd(); + return output.flush(); + }).catch((err: Error): Buffer => { + const result: thrift.TApplicationException = new thrift.TApplicationException(thrift.TApplicationExceptionType.UNKNOWN, err.message); + output.writeMessageBegin("getUser", thrift.MessageType.EXCEPTION, requestId); + thrift.TApplicationExceptionCodec.encode(result, output); + output.writeMessageEnd(); + return output.flush(); + }); + } + public process_ping(requestId: number, input: thrift.TProtocol, output: thrift.TProtocol, context: Context): Promise { + return new Promise((resolve, reject): void => { + try { + input.readMessageEnd(); + resolve(this._handler.ping(context)); + } + catch (err) { + reject(err); + } + }).then((data: void): Buffer => { + const result: IPing__Result = { success: data }; + output.writeMessageBegin("ping", thrift.MessageType.REPLY, requestId); + Ping__ResultCodec.encode(result, output); + output.writeMessageEnd(); + return output.flush(); + }).catch((err: Error): Buffer => { + const result: thrift.TApplicationException = new thrift.TApplicationException(thrift.TApplicationExceptionType.UNKNOWN, err.message); + output.writeMessageBegin("ping", thrift.MessageType.EXCEPTION, requestId); + thrift.TApplicationExceptionCodec.encode(result, output); + output.writeMessageEnd(); + return output.flush(); + }); + } + } +} diff --git a/src/tests/unit/fixtures/thrift-server/basic_union.strict_union.solution.ts b/src/tests/unit/fixtures/thrift-server/basic_union.strict_union.solution.ts index 1ae75860..f4afbe7a 100644 --- a/src/tests/unit/fixtures/thrift-server/basic_union.strict_union.solution.ts +++ b/src/tests/unit/fixtures/thrift-server/basic_union.strict_union.solution.ts @@ -1,8 +1,8 @@ -export type MyUnion = IMyUnionWithField1 | IMyUnionWithField2; export enum MyUnionType { MyUnionWithField1 = "field1", MyUnionWithField2 = "field2" } +export type MyUnion = IMyUnionWithField1 | IMyUnionWithField2; export interface IMyUnionWithField1 { __type: MyUnionType.MyUnionWithField1; field1: number; @@ -22,7 +22,44 @@ export interface IMyUnionWithField2Args { field1?: void; field2: number | thrift.Int64; } -export const MyUnionCodec: thrift.IStructCodec = { +export const MyUnionCodec: thrift.IStructToolkit = { + create(args: MyUnionArgs): MyUnion { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.field1 != null) { + _fieldsSet++; + const value_1: number = args.field1; + _returnValue = { field1: value_1 }; + } + if (args.field2 != null) { + _fieldsSet++; + const value_2: thrift.Int64 = (typeof args.field2 === "number" ? new thrift.Int64(args.field2) : args.field2); + _returnValue = { field2: value_2 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.field1) { + return { + __type: MyUnionType.MyUnionWithField1, + field1: _returnValue.field1 + }; + } + else { + return { + __type: MyUnionType.MyUnionWithField2, + field2: _returnValue.field2 + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, encode(args: MyUnionArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { @@ -67,8 +104,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.I32) { _fieldsSet++; - const value_1: number = input.readI32(); - _returnValue = { field1: value_1 }; + const value_3: number = input.readI32(); + _returnValue = { field1: value_3 }; } else { input.skip(fieldType); @@ -77,8 +114,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.I64) { _fieldsSet++; - const value_2: thrift.Int64 = input.readI64(); - _returnValue = { field2: value_2 }; + const value_4: thrift.Int64 = input.readI64(); + _returnValue = { field2: value_4 }; } else { input.skip(fieldType); @@ -100,13 +137,13 @@ export const MyUnionCodec: thrift.IStructCodec = { if (_returnValue !== null) { if (_returnValue.field1) { return { - __type: "field1", + __type: MyUnionType.MyUnionWithField1, field1: _returnValue.field1 }; } else { return { - __type: "field2", + __type: MyUnionType.MyUnionWithField2, field2: _returnValue.field2 }; } diff --git a/src/tests/unit/fixtures/thrift-server/nested_union.strict_union.solution.ts b/src/tests/unit/fixtures/thrift-server/nested_union.strict_union.solution.ts index aca132d1..7c9aa525 100644 --- a/src/tests/unit/fixtures/thrift-server/nested_union.strict_union.solution.ts +++ b/src/tests/unit/fixtures/thrift-server/nested_union.strict_union.solution.ts @@ -1,14 +1,66 @@ +export enum InnerUnionType { + InnerUnionWithName = "name", + InnerUnionWithId = "id" +} export type InnerUnion = IInnerUnionWithName | IInnerUnionWithId; export interface IInnerUnionWithName { + __type: InnerUnionType.InnerUnionWithName; name: string; id?: void; } export interface IInnerUnionWithId { + __type: InnerUnionType.InnerUnionWithId; + name?: void; + id: number; +} +export type InnerUnionArgs = IInnerUnionWithNameArgs | IInnerUnionWithIdArgs; +export interface IInnerUnionWithNameArgs { + name: string; + id?: void; +} +export interface IInnerUnionWithIdArgs { name?: void; id: number; } -export const InnerUnionCodec: thrift.IStructCodec = { - encode(args: InnerUnion, output: thrift.TProtocol): void { +export const InnerUnionCodec: thrift.IStructToolkit = { + create(args: InnerUnionArgs): InnerUnion { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.name != null) { + _fieldsSet++; + const value_1: string = args.name; + _returnValue = { name: value_1 }; + } + if (args.id != null) { + _fieldsSet++; + const value_2: number = args.id; + _returnValue = { id: value_2 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.name) { + return { + __type: InnerUnionType.InnerUnionWithName, + name: _returnValue.name + }; + } + else { + return { + __type: InnerUnionType.InnerUnionWithId, + id: _returnValue.id + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, + encode(args: InnerUnionArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { name: args.name, @@ -39,7 +91,7 @@ export const InnerUnionCodec: thrift.IStructCodec = { }, decode(input: thrift.TProtocol): InnerUnion { let _fieldsSet: number = 0; - let _returnValue: InnerUnion | null = null; + let _returnValue: any = null; input.readStructBegin(); while (true) { const ret: thrift.IThriftField = input.readFieldBegin(); @@ -52,8 +104,8 @@ export const InnerUnionCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.STRING) { _fieldsSet++; - const value_1: string = input.readString(); - _returnValue = { name: value_1 }; + const value_3: string = input.readString(); + _returnValue = { name: value_3 }; } else { input.skip(fieldType); @@ -62,8 +114,8 @@ export const InnerUnionCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.I32) { _fieldsSet++; - const value_2: number = input.readI32(); - _returnValue = { id: value_2 }; + const value_4: number = input.readI32(); + _returnValue = { id: value_4 }; } else { input.skip(fieldType); @@ -83,24 +135,87 @@ export const InnerUnionCodec: thrift.IStructCodec = { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); } if (_returnValue !== null) { - return _returnValue; + if (_returnValue.name) { + return { + __type: InnerUnionType.InnerUnionWithName, + name: _returnValue.name + }; + } + else { + return { + __type: InnerUnionType.InnerUnionWithId, + id: _returnValue.id + }; + } } else { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); } } }; +export enum MyUnionType { + MyUnionWithUser = "user", + MyUnionWithField2 = "field2" +} export type MyUnion = IMyUnionWithUser | IMyUnionWithField2; export interface IMyUnionWithUser { + __type: MyUnionType.MyUnionWithUser; user: InnerUnion; field2?: void; } export interface IMyUnionWithField2 { + __type: MyUnionType.MyUnionWithField2; user?: void; field2: string; } -export const MyUnionCodec: thrift.IStructCodec = { - encode(args: MyUnion, output: thrift.TProtocol): void { +export type MyUnionArgs = IMyUnionWithUserArgs | IMyUnionWithField2Args; +export interface IMyUnionWithUserArgs { + user: InnerUnionArgs; + field2?: void; +} +export interface IMyUnionWithField2Args { + user?: void; + field2: string; +} +export const MyUnionCodec: thrift.IStructToolkit = { + create(args: MyUnionArgs): MyUnion { + let _fieldsSet: number = 0; + let _returnValue: any = null; + if (args.user != null) { + _fieldsSet++; + const value_5: InnerUnion = InnerUnionCodec.create(args.user); + _returnValue = { user: value_5 }; + } + if (args.field2 != null) { + _fieldsSet++; + const value_6: string = args.field2; + _returnValue = { field2: value_6 }; + } + if (_fieldsSet > 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion cannot have more than one value"); + } + else if (_fieldsSet < 1) { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); + } + if (_returnValue !== null) { + if (_returnValue.user) { + return { + __type: MyUnionType.MyUnionWithUser, + user: _returnValue.user + }; + } + else { + return { + __type: MyUnionType.MyUnionWithField2, + field2: _returnValue.field2 + }; + } + } + else { + throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); + } + }, + encode(args: MyUnionArgs, output: thrift.TProtocol): void { let _fieldsSet: number = 0; const obj = { user: args.user, @@ -131,7 +246,7 @@ export const MyUnionCodec: thrift.IStructCodec = { }, decode(input: thrift.TProtocol): MyUnion { let _fieldsSet: number = 0; - let _returnValue: MyUnion | null = null; + let _returnValue: any = null; input.readStructBegin(); while (true) { const ret: thrift.IThriftField = input.readFieldBegin(); @@ -144,8 +259,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 1: if (fieldType === thrift.TType.STRUCT) { _fieldsSet++; - const value_3: InnerUnion = InnerUnionCodec.decode(input); - _returnValue = { user: value_3 }; + const value_7: InnerUnion = InnerUnionCodec.decode(input); + _returnValue = { user: value_7 }; } else { input.skip(fieldType); @@ -154,8 +269,8 @@ export const MyUnionCodec: thrift.IStructCodec = { case 2: if (fieldType === thrift.TType.STRING) { _fieldsSet++; - const value_4: string = input.readString(); - _returnValue = { field2: value_4 }; + const value_8: string = input.readString(); + _returnValue = { field2: value_8 }; } else { input.skip(fieldType); @@ -175,7 +290,18 @@ export const MyUnionCodec: thrift.IStructCodec = { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.INVALID_DATA, "TUnion must have one value set"); } if (_returnValue !== null) { - return _returnValue; + if (_returnValue.user) { + return { + __type: MyUnionType.MyUnionWithUser, + user: _returnValue.user + }; + } + else { + return { + __type: MyUnionType.MyUnionWithField2, + field2: _returnValue.field2 + }; + } } else { throw new thrift.TProtocolException(thrift.TProtocolExceptionType.UNKNOWN, "Unable to read data for TUnion"); diff --git a/src/tests/unit/index.spec.ts b/src/tests/unit/index.spec.ts index 1ec3588a..859aec35 100644 --- a/src/tests/unit/index.spec.ts +++ b/src/tests/unit/index.spec.ts @@ -76,48 +76,76 @@ describe('Thrift TypeScript Generator', () => { }) describe('Thrift Server v2 Generated w Strict Unions', () => { - // before(() => { - // generate({ - // rootDir: __dirname, - // outDir: 'generated/strict-unions', - // sourceDir: 'fixtures/thrift', - // target: 'thrift-server', - // files: [], - // library: 'test-lib', - // strictUnions: true, - // }) - // }) - // it('should correctly generate typedefs for includes', () => { - // const actual: string = readGenerated( - // 'operation', - // 'generated/strict-unions', - // ) - // const expected: string = readGeneratedSolution( - // 'operation', - // 'generated/strict-unions', - // ) - // assert.deepEqual(actual, expected) - // }) - // it('should correctly generate a struct using includes', () => { - // const actual: string = readGenerated('common') - // const expected: string = readGeneratedSolution('common') - // assert.deepEqual(actual, expected) - // }) - // it('should correctly generate an exception using includes', () => { - // const actual: string = readGenerated('exceptions') - // const expected: string = readGeneratedSolution('exceptions') - // assert.deepEqual(actual, expected) - // }) - // it('should correctly generate a service', () => { - // const actual: string = readGenerated('shared') - // const expected: string = readGeneratedSolution('shared') - // assert.deepEqual(actual, expected) - // }) - // it('should correctly generate a service using includes', () => { - // const actual: string = readGenerated('calculator') - // const expected: string = readGeneratedSolution('calculator') - // assert.deepEqual(actual, expected) - // }) + before(() => { + generate({ + rootDir: __dirname, + outDir: 'generated/strict-unions', + sourceDir: 'fixtures/thrift', + target: 'thrift-server', + files: [], + library: 'test-lib', + strictUnions: true, + }) + }) + it('should correctly generate typedefs for includes', () => { + const actual: string = readGenerated( + 'operation', + 'generated/strict-unions', + ) + const expected: string = readGeneratedSolution( + 'operation', + 'generated/strict-unions', + ) + assert.deepEqual(actual, expected) + }) + + it('should correctly generate a struct using includes', () => { + const actual: string = readGenerated( + 'common', + 'generated/strict-unions', + ) + const expected: string = readGeneratedSolution( + 'common', + 'generated/strict-unions', + ) + assert.deepEqual(actual, expected) + }) + + it('should correctly generate an exception using includes', () => { + const actual: string = readGenerated( + 'exceptions', + 'generated/strict-unions', + ) + const expected: string = readGeneratedSolution( + 'exceptions', + 'generated/strict-unions', + ) + assert.deepEqual(actual, expected) + }) + + it('should correctly generate a service', () => { + const actual: string = readGenerated( + 'shared', + 'generated/strict-unions', + ) + const expected: string = readGeneratedSolution( + 'shared', + 'generated/strict-unions', + ) + assert.deepEqual(actual, expected) + }) + + it('should correctly generate a service using includes', () => { + const actual: string = readGenerated( + 'calculator', + 'generated/strict-unions', + ) + const expected: string = readGeneratedSolution( + 'calculator', + 'generated/strict-unions', + ) + assert.deepEqual(actual, expected) + }) }) describe('Thrift Server w/ Strict Unions', () => { @@ -134,54 +162,50 @@ describe('Thrift TypeScript Generator', () => { ) const actual: string = make(content, 'thrift-server', true) - console.log('actual: ', actual) - - assert.deepEqual(actual, expected) - }) - - // it('should correctly generate a union with a union field', () => { - // const content: string = ` - // union InnerUnion { - // 1: string name - // 2: i32 id - // } - - // union MyUnion { - // 1: InnerUnion user - // 2: string field2 - // } - // ` - // const expected: string = readSolution( - // 'nested_union.strict_union', - // 'thrift-server', - // ) - // const actual: string = make(content, 'thrift-server', true) - - // assert.deepEqual(actual, expected) - // }) - - // it('should correctly generate a service using a union', () => { - // const content: string = ` - // union MyUnion { - // 1: i32 field1 - // 2: i64 field2 - // } - - // service MyService { - // User getUser(1: MyUnion arg1) - // void ping() - // } - // ` - // const expected: string = readSolution( - // 'basic_service.strict_union', - // 'thrift-server', - // ) - // const actual: string = make(content, 'thrift-server', true) - - // console.log('actual: ', actual) - - // assert.deepEqual(actual, expected) - // }) + assert.deepEqual(actual, expected) + }) + + it('should correctly generate a union with a union field', () => { + const content: string = ` + union InnerUnion { + 1: string name + 2: i32 id + } + + union MyUnion { + 1: InnerUnion user + 2: string field2 + } + ` + const expected: string = readSolution( + 'nested_union.strict_union', + 'thrift-server', + ) + const actual: string = make(content, 'thrift-server', true) + + assert.deepEqual(actual, expected) + }) + + it('should correctly generate a service using a union', () => { + const content: string = ` + union MyUnion { + 1: i32 field1 + 2: i64 field2 + } + + service MyService { + string getUser(1: MyUnion arg1) + void ping() + } + ` + const expected: string = readSolution( + 'basic_service.strict_union', + 'thrift-server', + ) + const actual: string = make(content, 'thrift-server', true) + + assert.deepEqual(actual, expected) + }) }) describe('Thrift Server', () => { @@ -283,11 +307,11 @@ describe('Thrift TypeScript Generator', () => { it('should correctly generate a union', () => { const content: string = ` - union MyUnion { - 1: i32 field1 - 2: i64 field2 - } - ` + union MyUnion { + 1: i32 field1 + 2: i64 field2 + } + ` const expected: string = readSolution( 'basic_union', 'thrift-server', diff --git a/tsconfig.json b/tsconfig.json index 6d3c8689..37362a62 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -9,7 +9,7 @@ "outDir": "./dist", "noEmitOnError": true, "strict": true, - "noUnusedLocals": true, + "noUnusedLocals": false, "pretty": true, "removeComments": true }, diff --git a/tsconfig.test.json b/tsconfig.test.json index a3be346e..762e2345 100644 --- a/tsconfig.test.json +++ b/tsconfig.test.json @@ -9,7 +9,7 @@ "outDir": "./dist", "noEmitOnError": true, "strict": true, - "noUnusedLocals": true, + "noUnusedLocals": false, "pretty": true, "removeComments": true },