From d5d62b4d94acf05d4335122efa9e36b07955eb2d Mon Sep 17 00:00:00 2001 From: Michael Lumish Date: Fri, 7 Jun 2024 10:11:06 -0700 Subject: [PATCH] grpc-js: Avoid buffering significantly more than max_receive_message_size per received message (1.9.x) --- packages/grpc-js/src/compression-filter.ts | 67 ++++++++++---- packages/grpc-js/src/internal-channel.ts | 2 - .../grpc-js/src/max-message-size-filter.ts | 88 ------------------- packages/grpc-js/src/server-call.ts | 87 +++++++++++------- packages/grpc-js/src/stream-decoder.ts | 5 ++ packages/grpc-js/src/subchannel-call.ts | 14 ++- packages/grpc-js/src/transport.ts | 7 +- .../grpc-js/test/fixtures/test_service.proto | 1 + packages/grpc-js/test/test-server-errors.ts | 49 ++++++++++- 9 files changed, 175 insertions(+), 145 deletions(-) delete mode 100644 packages/grpc-js/src/max-message-size-filter.ts diff --git a/packages/grpc-js/src/compression-filter.ts b/packages/grpc-js/src/compression-filter.ts index 136311ad5..f1600b36d 100644 --- a/packages/grpc-js/src/compression-filter.ts +++ b/packages/grpc-js/src/compression-filter.ts @@ -21,7 +21,7 @@ import { WriteObject, WriteFlags } from './call-interface'; import { Channel } from './channel'; import { ChannelOptions } from './channel-options'; import { CompressionAlgorithms } from './compression-algorithms'; -import { LogVerbosity } from './constants'; +import { DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH, LogVerbosity, Status } from './constants'; import { BaseFilter, Filter, FilterFactory } from './filter'; import * as logging from './logging'; import { Metadata, MetadataValue } from './metadata'; @@ -98,6 +98,10 @@ class IdentityHandler extends CompressionHandler { } class DeflateHandler extends CompressionHandler { + constructor(private maxRecvMessageLength: number) { + super(); + } + compressMessage(message: Buffer) { return new Promise((resolve, reject) => { zlib.deflate(message, (err, output) => { @@ -112,18 +116,34 @@ class DeflateHandler extends CompressionHandler { decompressMessage(message: Buffer) { return new Promise((resolve, reject) => { - zlib.inflate(message, (err, output) => { - if (err) { - reject(err); - } else { - resolve(output); + let totalLength = 0; + const messageParts: Buffer[] = []; + const decompresser = zlib.createInflate(); + decompresser.on('data', (chunk: Buffer) => { + messageParts.push(chunk); + totalLength += chunk.byteLength; + if (this.maxRecvMessageLength !== -1 && totalLength > this.maxRecvMessageLength) { + decompresser.destroy(); + reject({ + code: Status.RESOURCE_EXHAUSTED, + details: `Received message that decompresses to a size larger than ${this.maxRecvMessageLength}` + }); } }); + decompresser.on('end', () => { + resolve(Buffer.concat(messageParts)); + }); + decompresser.write(message); + decompresser.end(); }); } } class GzipHandler extends CompressionHandler { + constructor(private maxRecvMessageLength: number) { + super(); + } + compressMessage(message: Buffer) { return new Promise((resolve, reject) => { zlib.gzip(message, (err, output) => { @@ -138,13 +158,25 @@ class GzipHandler extends CompressionHandler { decompressMessage(message: Buffer) { return new Promise((resolve, reject) => { - zlib.unzip(message, (err, output) => { - if (err) { - reject(err); - } else { - resolve(output); + let totalLength = 0; + const messageParts: Buffer[] = []; + const decompresser = zlib.createGunzip(); + decompresser.on('data', (chunk: Buffer) => { + messageParts.push(chunk); + totalLength += chunk.byteLength; + if (this.maxRecvMessageLength !== -1 && totalLength > this.maxRecvMessageLength) { + decompresser.destroy(); + reject({ + code: Status.RESOURCE_EXHAUSTED, + details: `Received message that decompresses to a size larger than ${this.maxRecvMessageLength}` + }); } }); + decompresser.on('end', () => { + resolve(Buffer.concat(messageParts)); + }); + decompresser.write(message); + decompresser.end(); }); } } @@ -169,14 +201,14 @@ class UnknownHandler extends CompressionHandler { } } -function getCompressionHandler(compressionName: string): CompressionHandler { +function getCompressionHandler(compressionName: string, maxReceiveMessageSize: number): CompressionHandler { switch (compressionName) { case 'identity': return new IdentityHandler(); case 'deflate': - return new DeflateHandler(); + return new DeflateHandler(maxReceiveMessageSize); case 'gzip': - return new GzipHandler(); + return new GzipHandler(maxReceiveMessageSize); default: return new UnknownHandler(compressionName); } @@ -186,6 +218,7 @@ export class CompressionFilter extends BaseFilter implements Filter { private sendCompression: CompressionHandler = new IdentityHandler(); private receiveCompression: CompressionHandler = new IdentityHandler(); private currentCompressionAlgorithm: CompressionAlgorithm = 'identity'; + private maxReceiveMessageLength: number; constructor( channelOptions: ChannelOptions, @@ -195,6 +228,7 @@ export class CompressionFilter extends BaseFilter implements Filter { const compressionAlgorithmKey = channelOptions['grpc.default_compression_algorithm']; + this.maxReceiveMessageLength = channelOptions['grpc.max_receive_message_length'] ?? DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH if (compressionAlgorithmKey !== undefined) { if (isCompressionAlgorithmKey(compressionAlgorithmKey)) { const clientSelectedEncoding = CompressionAlgorithms[ @@ -215,7 +249,8 @@ export class CompressionFilter extends BaseFilter implements Filter { ) { this.currentCompressionAlgorithm = clientSelectedEncoding; this.sendCompression = getCompressionHandler( - this.currentCompressionAlgorithm + this.currentCompressionAlgorithm, + -1 ); } } else { @@ -247,7 +282,7 @@ export class CompressionFilter extends BaseFilter implements Filter { if (receiveEncoding.length > 0) { const encoding: MetadataValue = receiveEncoding[0]; if (typeof encoding === 'string') { - this.receiveCompression = getCompressionHandler(encoding); + this.receiveCompression = getCompressionHandler(encoding, this.maxReceiveMessageLength); } } metadata.remove('grpc-encoding'); diff --git a/packages/grpc-js/src/internal-channel.ts b/packages/grpc-js/src/internal-channel.ts index 6a65b712f..5c8c91c41 100644 --- a/packages/grpc-js/src/internal-channel.ts +++ b/packages/grpc-js/src/internal-channel.ts @@ -33,7 +33,6 @@ import { } from './resolver'; import { trace } from './logging'; import { SubchannelAddress } from './subchannel-address'; -import { MaxMessageSizeFilterFactory } from './max-message-size-filter'; import { mapProxyName } from './http_proxy'; import { GrpcUri, parseUri, uriToString } from './uri-parser'; import { ServerSurfaceCall } from './server-call'; @@ -393,7 +392,6 @@ export class InternalChannel { } ); this.filterStackFactory = new FilterStackFactory([ - new MaxMessageSizeFilterFactory(this.options), new CompressionFilterFactory(this, this.options), ]); this.trace( diff --git a/packages/grpc-js/src/max-message-size-filter.ts b/packages/grpc-js/src/max-message-size-filter.ts deleted file mode 100644 index b6df374b2..000000000 --- a/packages/grpc-js/src/max-message-size-filter.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -import { BaseFilter, Filter, FilterFactory } from './filter'; -import { WriteObject } from './call-interface'; -import { - Status, - DEFAULT_MAX_SEND_MESSAGE_LENGTH, - DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH, -} from './constants'; -import { ChannelOptions } from './channel-options'; -import { Metadata } from './metadata'; - -export class MaxMessageSizeFilter extends BaseFilter implements Filter { - private maxSendMessageSize: number = DEFAULT_MAX_SEND_MESSAGE_LENGTH; - private maxReceiveMessageSize: number = DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH; - constructor(options: ChannelOptions) { - super(); - if ('grpc.max_send_message_length' in options) { - this.maxSendMessageSize = options['grpc.max_send_message_length']!; - } - if ('grpc.max_receive_message_length' in options) { - this.maxReceiveMessageSize = options['grpc.max_receive_message_length']!; - } - } - - async sendMessage(message: Promise): Promise { - /* A configured size of -1 means that there is no limit, so skip the check - * entirely */ - if (this.maxSendMessageSize === -1) { - return message; - } else { - const concreteMessage = await message; - if (concreteMessage.message.length > this.maxSendMessageSize) { - throw { - code: Status.RESOURCE_EXHAUSTED, - details: `Sent message larger than max (${concreteMessage.message.length} vs. ${this.maxSendMessageSize})`, - metadata: new Metadata(), - }; - } else { - return concreteMessage; - } - } - } - - async receiveMessage(message: Promise): Promise { - /* A configured size of -1 means that there is no limit, so skip the check - * entirely */ - if (this.maxReceiveMessageSize === -1) { - return message; - } else { - const concreteMessage = await message; - if (concreteMessage.length > this.maxReceiveMessageSize) { - throw { - code: Status.RESOURCE_EXHAUSTED, - details: `Received message larger than max (${concreteMessage.length} vs. ${this.maxReceiveMessageSize})`, - metadata: new Metadata(), - }; - } else { - return concreteMessage; - } - } - } -} - -export class MaxMessageSizeFilterFactory - implements FilterFactory -{ - constructor(private readonly options: ChannelOptions) {} - - createFilter(): MaxMessageSizeFilter { - return new MaxMessageSizeFilter(this.options); - } -} diff --git a/packages/grpc-js/src/server-call.ts b/packages/grpc-js/src/server-call.ts index 107c2e3ef..3032c6e4e 100644 --- a/packages/grpc-js/src/server-call.ts +++ b/packages/grpc-js/src/server-call.ts @@ -19,7 +19,6 @@ import { EventEmitter } from 'events'; import * as http2 from 'http2'; import { Duplex, Readable, Writable } from 'stream'; import * as zlib from 'zlib'; -import { promisify } from 'util'; import { Status, @@ -38,8 +37,6 @@ import { Deadline } from './deadline'; import { getErrorCode, getErrorMessage } from './error'; const TRACER_NAME = 'server_call'; -const unzip = promisify(zlib.unzip); -const inflate = promisify(zlib.inflate); function trace(text: string): void { logging.trace(LogVerbosity.DEBUG, TRACER_NAME, text); @@ -480,19 +477,42 @@ export class Http2ServerCallStream< private getDecompressedMessage( message: Buffer, encoding: string - ): Buffer | Promise { - if (encoding === 'deflate') { - return inflate(message.subarray(5)); - } else if (encoding === 'gzip') { - return unzip(message.subarray(5)); - } else if (encoding === 'identity') { - return message.subarray(5); + ): Buffer | Promise { const messageContents = message.subarray(5); + if (encoding === 'identity') { + return messageContents; + } else if (encoding === 'deflate' || encoding === 'gzip') { + let decompresser: zlib.Gunzip | zlib.Deflate; + if (encoding === 'deflate') { + decompresser = zlib.createInflate(); + } else { + decompresser = zlib.createGunzip(); + } + return new Promise((resolve, reject) => { + let totalLength = 0 + const messageParts: Buffer[] = []; + decompresser.on('data', (chunk: Buffer) => { + messageParts.push(chunk); + totalLength += chunk.byteLength; + if (this.maxReceiveMessageSize !== -1 && totalLength > this.maxReceiveMessageSize) { + decompresser.destroy(); + reject({ + code: Status.RESOURCE_EXHAUSTED, + details: `Received message that decompresses to a size larger than ${this.maxReceiveMessageSize}` + }); + } + }); + decompresser.on('end', () => { + resolve(Buffer.concat(messageParts)); + }); + decompresser.write(messageContents); + decompresser.end(); + }); + } else { + return Promise.reject({ + code: Status.UNIMPLEMENTED, + details: `Received message compressed with unsupported encoding "${encoding}"`, + }); } - - return Promise.reject({ - code: Status.UNIMPLEMENTED, - details: `Received message compressed with unsupported encoding "${encoding}"`, - }); } sendMetadata(customMetadata?: Metadata) { @@ -816,7 +836,7 @@ export class Http2ServerCallStream< | ServerDuplexStream, encoding: string ) { - const decoder = new StreamDecoder(); + const decoder = new StreamDecoder(this.maxReceiveMessageSize); let readsDone = false; @@ -832,29 +852,34 @@ export class Http2ServerCallStream< }; this.stream.on('data', async (data: Buffer) => { - const messages = decoder.write(data); + let messages: Buffer[]; + try { + messages = decoder.write(data); + } catch (e) { + this.sendError({ + code: Status.RESOURCE_EXHAUSTED, + details: (e as Error).message + }); + return; + } pendingMessageProcessing = true; this.stream.pause(); for (const message of messages) { - if ( - this.maxReceiveMessageSize !== -1 && - message.length > this.maxReceiveMessageSize - ) { - this.sendError({ - code: Status.RESOURCE_EXHAUSTED, - details: `Received message larger than max (${message.length} vs. ${this.maxReceiveMessageSize})`, - }); - return; - } this.emit('receiveMessage'); const compressed = message.readUInt8(0) === 1; const compressedMessageEncoding = compressed ? encoding : 'identity'; - const decompressedMessage = await this.getDecompressedMessage( - message, - compressedMessageEncoding - ); + let decompressedMessage: Buffer; + try { + decompressedMessage = await this.getDecompressedMessage( + message, + compressedMessageEncoding + ); + } catch (e) { + this.sendError(e as Partial); + return; + } // Encountered an error with decompression; it'll already have been propogated back // Just return early diff --git a/packages/grpc-js/src/stream-decoder.ts b/packages/grpc-js/src/stream-decoder.ts index 671ad41ae..ea669d14c 100644 --- a/packages/grpc-js/src/stream-decoder.ts +++ b/packages/grpc-js/src/stream-decoder.ts @@ -30,6 +30,8 @@ export class StreamDecoder { private readPartialMessage: Buffer[] = []; private readMessageRemaining = 0; + constructor(private maxReadMessageLength: number) {} + write(data: Buffer): Buffer[] { let readHead = 0; let toRead: number; @@ -60,6 +62,9 @@ export class StreamDecoder { // readSizeRemaining >=0 here if (this.readSizeRemaining === 0) { this.readMessageSize = this.readPartialSize.readUInt32BE(0); + if (this.maxReadMessageLength !== -1 && this.readMessageSize > this.maxReadMessageLength) { + throw new Error(`Received message larger than max (${this.readMessageSize} vs ${this.maxReadMessageLength})`); + } this.readMessageRemaining = this.readMessageSize; if (this.readMessageRemaining > 0) { this.readState = ReadState.READING_MESSAGE; diff --git a/packages/grpc-js/src/subchannel-call.ts b/packages/grpc-js/src/subchannel-call.ts index 3b9b6152f..b9f3191cc 100644 --- a/packages/grpc-js/src/subchannel-call.ts +++ b/packages/grpc-js/src/subchannel-call.ts @@ -18,7 +18,7 @@ import * as http2 from 'http2'; import * as os from 'os'; -import { Status } from './constants'; +import { DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH, Status } from './constants'; import { Metadata } from './metadata'; import { StreamDecoder } from './stream-decoder'; import * as logging from './logging'; @@ -82,7 +82,7 @@ export interface SubchannelCallInterceptingListener } export class Http2SubchannelCall implements SubchannelCall { - private decoder = new StreamDecoder(); + private decoder: StreamDecoder; private isReadFilterPending = false; private isPushPending = false; @@ -112,6 +112,8 @@ export class Http2SubchannelCall implements SubchannelCall { private readonly transport: Transport, private readonly callId: number ) { + const maxReceiveMessageLength = transport.getOptions()['grpc.max_receive_message_length'] ?? DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH; + this.decoder = new StreamDecoder(maxReceiveMessageLength); http2Stream.on('response', (headers, flags) => { let headersString = ''; for (const header of Object.keys(headers)) { @@ -169,7 +171,13 @@ export class Http2SubchannelCall implements SubchannelCall { return; } this.trace('receive HTTP/2 data frame of length ' + data.length); - const messages = this.decoder.write(data); + let messages: Buffer[]; + try { + messages = this.decoder.write(data); + } catch (e) { + this.cancelWithStatus(Status.RESOURCE_EXHAUSTED, (e as Error).message); + return; + } for (const message of messages) { this.trace('parsed message of length ' + message.length); diff --git a/packages/grpc-js/src/transport.ts b/packages/grpc-js/src/transport.ts index 39ca69383..fe9a81352 100644 --- a/packages/grpc-js/src/transport.ts +++ b/packages/grpc-js/src/transport.ts @@ -83,6 +83,7 @@ export interface TransportDisconnectListener { export interface Transport { getChannelzRef(): SocketRef; getPeerName(): string; + getOptions(): ChannelOptions; createCall( metadata: Metadata, host: string, @@ -146,7 +147,7 @@ class Http2Transport implements Transport { constructor( private session: http2.ClientHttp2Session, subchannelAddress: SubchannelAddress, - options: ChannelOptions, + private options: ChannelOptions, /** * Name of the remote server, if it is not the same as the subchannel * address, i.e. if connecting through an HTTP CONNECT proxy. @@ -601,6 +602,10 @@ class Http2Transport implements Transport { return this.subchannelAddressString; } + getOptions() { + return this.options; + } + shutdown() { this.session.close(); unregisterChannelzRef(this.channelzRef); diff --git a/packages/grpc-js/test/fixtures/test_service.proto b/packages/grpc-js/test/fixtures/test_service.proto index 64ce0d378..2a7a303f3 100644 --- a/packages/grpc-js/test/fixtures/test_service.proto +++ b/packages/grpc-js/test/fixtures/test_service.proto @@ -21,6 +21,7 @@ message Request { bool error = 1; string message = 2; int32 errorAfter = 3; + int32 responseLength = 4; } message Response { diff --git a/packages/grpc-js/test/test-server-errors.ts b/packages/grpc-js/test/test-server-errors.ts index 24ccfeef3..243e10918 100644 --- a/packages/grpc-js/test/test-server-errors.ts +++ b/packages/grpc-js/test/test-server-errors.ts @@ -33,6 +33,7 @@ import { } from '../src/server-call'; import { loadProtoFile } from './common'; +import { CompressionAlgorithms } from '../src/compression-algorithms'; const protoFile = join(__dirname, 'fixtures', 'test_service.proto'); const testServiceDef = loadProtoFile(protoFile); @@ -310,7 +311,7 @@ describe('Other conditions', () => { trailerMetadata ); } else { - cb(null, { count: 1 }, trailerMetadata); + cb(null, { count: 1, message: 'a'.repeat(req.responseLength) }, trailerMetadata); } }, @@ -320,6 +321,7 @@ describe('Other conditions', () => { ) { let count = 0; let errored = false; + let responseLength = 0; stream.on('data', (data: any) => { if (data.error) { @@ -327,13 +329,14 @@ describe('Other conditions', () => { errored = true; cb(new Error(message) as ServiceError, null, trailerMetadata); } else { + responseLength += data.responseLength; count++; } }); stream.on('end', () => { if (!errored) { - cb(null, { count }, trailerMetadata); + cb(null, { count, message: 'a'.repeat(responseLength) }, trailerMetadata); } }); }, @@ -349,7 +352,7 @@ describe('Other conditions', () => { }); } else { for (let i = 1; i <= 5; i++) { - stream.write({ count: i }); + stream.write({ count: i, message: 'a'.repeat(req.responseLength) }); if (req.errorAfter && req.errorAfter === i) { stream.emit('error', { code: grpc.status.UNKNOWN, @@ -376,7 +379,7 @@ describe('Other conditions', () => { err.metadata.add('count', '' + count); stream.emit('error', err); } else { - stream.write({ count }); + stream.write({ count, message: 'a'.repeat(data.responseLength) }); count++; } }); @@ -740,6 +743,44 @@ describe('Other conditions', () => { }); }); }); + + describe('Max message size', () => { + const largeMessage = 'a'.repeat(10_000_000); + it('Should be enforced on the server', done => { + client.unary({ message: largeMessage }, (error?: ServiceError) => { + assert(error); + assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED); + done(); + }); + }); + it('Should be enforced on the client', done => { + client.unary({ responseLength: 10_000_000 }, (error?: ServiceError) => { + assert(error); + assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED); + done(); + }); + }); + describe('Compressed messages', () => { + it('Should be enforced with gzip', done => { + const compressingClient = new testServiceClient(`localhost:${port}`, clientInsecureCreds, {'grpc.default_compression_algorithm': CompressionAlgorithms.gzip}); + compressingClient.unary({ message: largeMessage }, (error?: ServiceError) => { + assert(error); + assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED); + assert.match(error.details, /Received message that decompresses to a size larger/); + done(); + }); + }); + it('Should be enforced with deflate', done => { + const compressingClient = new testServiceClient(`localhost:${port}`, clientInsecureCreds, {'grpc.default_compression_algorithm': CompressionAlgorithms.deflate}); + compressingClient.unary({ message: largeMessage }, (error?: ServiceError) => { + assert(error); + assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED); + assert.match(error.details, /Received message that decompresses to a size larger/); + done(); + }); + }); + }); + }); }); function identity(arg: any): any {