Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Crash in FetchMoreBcpData #1549

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 68 additions & 43 deletions contrib/babelfishpg_tds/src/backend/tds/tdsbulkload.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

static StringInfo SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message);
void ProcessBCPRequest(TDSRequest request);
static void FetchMoreBcpData(StringInfo *message, int dataLenToRead);
static void FetchMoreBcpData(StringInfo *message, int dataLenToRead, bool freeMessageData);
static void FetchMoreBcpPlpData(StringInfo *message, int dataLenToRead);
static int ReadBcpPlp(ParameterToken temp, StringInfo *message, TDSRequestBulkLoad request);
uint64_t offset = 0;
Expand Down Expand Up @@ -70,12 +70,26 @@ do \
temp->rowCount, colNum + 1, temp->colMetaData[i].columnTdsType))); \
} while(0)

/* Check if Message has enough data to read, if not then fetch more. */
#define CheckMessageHasEnoughBytesToRead(message, dataLen) \
/*
* Check if Message has enough data to read column metadata.
* If not then fetch more.
*/
#define CheckMessageHasEnoughBytesToReadColMetadata(message, dataLen) \
do \
{ \
if ((*message)->len - offset < dataLen) \
FetchMoreBcpData(message, dataLen, false); \
} while(0)

/*
* Check if Message has enough data to read the rows' data.
* If not then fetch more.
*/
#define CheckMessageHasEnoughBytesToReadRows(message, dataLen) \
do \
{ \
if ((*message)->len - offset < dataLen) \
FetchMoreBcpData(message, dataLen); \
FetchMoreBcpData(message, dataLen, true); \
} while(0)

/* Check if Message has enough data to read, if not then fetch more. */
Expand All @@ -87,7 +101,7 @@ do \
} while(0)

static void
FetchMoreBcpData(StringInfo *message, int dataLenToRead)
FetchMoreBcpData(StringInfo *message, int dataLenToRead, bool freeMessageData)
{
StringInfo temp;
int ret;
Expand All @@ -107,17 +121,30 @@ FetchMoreBcpData(StringInfo *message, int dataLenToRead)
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("Trying to read more data than available in BCP request.")));

temp = makeStringInfo();
appendBinaryStringInfo(temp, (*message)->data + offset, (*message)->len - offset);
/*
* If we are trying to read next packet and freeMessageData is true then
* afford to free it.
* NOTE: We should free the message data only while reading
* rows' data and on the other hand we should not free the
* column-metadata information until we are done with it.
*/
if (freeMessageData)
{
temp = makeStringInfo();
appendBinaryStringInfo(temp, (*message)->data + offset, (*message)->len - offset);

if ((*message)->data)
pfree((*message)->data);
pfree((*message));
if ((*message)->data)
pfree((*message)->data);
pfree((*message));
offset = 0;
}
else
temp = *message;

/*
* Keep fetching for additional packets until we have enough data to read.
*/
while (dataLenToRead > temp->len)
while (dataLenToRead + offset > temp->len)
{
/*
* We should hold the interrupts until we read the next request frame.
Expand All @@ -139,7 +166,6 @@ FetchMoreBcpData(StringInfo *message, int dataLenToRead)
}
}

offset = 0;
(*message) = temp;
}

Expand Down Expand Up @@ -228,7 +254,7 @@ GetBulkLoadRequest(StringInfo message)

for (int currentColumn = 0; currentColumn < colCount; currentColumn++)
{
CheckMessageHasEnoughBytesToRead(&message, COLUMNMETADATA_HEADER_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, COLUMNMETADATA_HEADER_LEN);
/* UserType */
memcpy(&colmetadata[currentColumn].userType, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
Expand All @@ -249,12 +275,12 @@ GetBulkLoadRequest(StringInfo message)
case TDS_TYPE_MONEYN:
case TDS_TYPE_DATETIMEN:
case TDS_TYPE_UNIQUEIDENTIFIER:
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].maxLen = message->data[offset++];
break;
case TDS_TYPE_DECIMALN:
case TDS_TYPE_NUMERICN:
CheckMessageHasEnoughBytesToRead(&message, NUMERIC_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, NUMERIC_COLUMNMETADATA_LEN);
colmetadata[currentColumn].maxLen = message->data[offset++];
colmetadata[currentColumn].precision = message->data[offset++];
colmetadata[currentColumn].scale = message->data[offset++];
Expand All @@ -264,7 +290,7 @@ GetBulkLoadRequest(StringInfo message)
case TDS_TYPE_NCHAR:
case TDS_TYPE_NVARCHAR:
{
CheckMessageHasEnoughBytesToRead(&message, STRING_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, STRING_COLUMNMETADATA_LEN);
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint16));
offset += sizeof(uint16);

Expand All @@ -280,53 +306,53 @@ GetBulkLoadRequest(StringInfo message)
{
uint16_t tableLen = 0;

CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint32_t));
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);

/* Read collation(LICD) and sort-id for TEXT and NTEXT. */
if (colmetadata[currentColumn].columnTdsType == TDS_TYPE_TEXT ||
colmetadata[currentColumn].columnTdsType == TDS_TYPE_NTEXT)
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t) + 1);
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint32_t) + 1);
memcpy(&collation, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
colmetadata[currentColumn].sortId = message->data[offset++];
colmetadata[currentColumn].encoding = TdsGetEncoding(collation);
}

CheckMessageHasEnoughBytesToRead(&message, sizeof(uint16_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint16_t));
memcpy(&tableLen, &message->data[offset], sizeof(uint16_t));
offset += sizeof(uint16_t);

/* Skip table name for now. */
CheckMessageHasEnoughBytesToRead(&message, tableLen * 2);
CheckMessageHasEnoughBytesToReadColMetadata(&message, tableLen * 2);
offset += tableLen * 2;
}
break;
case TDS_TYPE_XML:
{
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadColMetadata(&message, 1);
colmetadata[currentColumn].maxLen = message->data[offset++];
}
break;
case TDS_TYPE_DATETIME2:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 8;
}
break;
case TDS_TYPE_TIME:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 5;
}
break;
case TDS_TYPE_DATETIMEOFFSET:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 10;
}
Expand All @@ -336,7 +362,7 @@ GetBulkLoadRequest(StringInfo message)
{
uint16 plp;

CheckMessageHasEnoughBytesToRead(&message, BINARY_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, BINARY_COLUMNMETADATA_LEN);
memcpy(&plp, &message->data[offset], sizeof(uint16));
offset += sizeof(uint16);
colmetadata[currentColumn].maxLen = plp;
Expand All @@ -346,7 +372,7 @@ GetBulkLoadRequest(StringInfo message)
colmetadata[currentColumn].maxLen = 3;
break;
case TDS_TYPE_SQLVARIANT:
CheckMessageHasEnoughBytesToRead(&message, SQL_VARIANT_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, SQL_VARIANT_COLUMNMETADATA_LEN);
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
break;
Expand Down Expand Up @@ -442,10 +468,10 @@ GetBulkLoadRequest(StringInfo message)
}

/* Column Name */
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint8_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint8_t));
memcpy(&colmetadata[currentColumn].colNameLen, &message->data[offset++], sizeof(uint8_t));

CheckMessageHasEnoughBytesToRead(&message, colmetadata[currentColumn].colNameLen * 2);
CheckMessageHasEnoughBytesToReadColMetadata(&message, colmetadata[currentColumn].colNameLen * 2);
colmetadata[currentColumn].colName = (char *) palloc0(colmetadata[currentColumn].colNameLen * sizeof(char) * 2 + 1);
memcpy(colmetadata[currentColumn].colName, &message->data[offset],
colmetadata[currentColumn].colNameLen * 2);
Expand Down Expand Up @@ -475,7 +501,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
request->rowData = NIL;
request->currentBatchSize = 0;

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

/* Loop over each row. */
while ((uint8_t) message->data[offset] == TDS_TOKEN_ROW
Expand Down Expand Up @@ -515,7 +541,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
}
else
{
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);
len = message->data[offset++];
request->currentBatchSize++;

Expand All @@ -528,7 +554,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
}
CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -593,7 +619,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
"Check the source data for invalid values. An example of an invalid value is data of numeric type with scale greater than precision.",
request->rowCount, i + 1)));

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

len = message->data[offset++];
request->currentBatchSize++;
Expand All @@ -606,7 +632,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -634,15 +660,15 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
{
if (colmetadata[i].maxLen != 0xffff)
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(short));
CheckMessageHasEnoughBytesToReadRows(&message, sizeof(short));
memcpy(&len, &message->data[offset], sizeof(short));
offset += sizeof(short);
request->currentBatchSize += sizeof(short);
if (len != 0xffff)
{
CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -716,7 +742,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
{
uint8 dataTextPtrLen;

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

/*
* Ignore the Data Text Ptr since its currently of no
Expand All @@ -731,7 +757,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
continue;
}

CheckMessageHasEnoughBytesToRead(&message, dataTextPtrLen + 8 + sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadRows(&message, dataTextPtrLen + 8 + sizeof(uint32_t));

offset += dataTextPtrLen;
request->currentBatchSize += dataTextPtrLen;
Expand All @@ -751,7 +777,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -813,7 +839,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
break;
case TDS_TYPE_SQLVARIANT:
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadRows(&message, sizeof(uint32_t));

memcpy(&len, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
Expand All @@ -828,7 +854,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand All @@ -850,15 +876,14 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
i++;
}
request->rowData = lappend(request->rowData, rowData);
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);
}

/*
* If row count is less than the default batch size then this is the last
* packet, the next byte should be the done token.
*/
CheckMessageHasEnoughBytesToRead(&message, 1);

CheckMessageHasEnoughBytesToReadRows(&message, 1);
if (request->rowCount < pltsql_plugin_handler_ptr->get_insert_bulk_rows_per_batch()
&& request->currentBatchSize < pltsql_plugin_handler_ptr->get_insert_bulk_kilobytes_per_batch() * 1024
&& (uint8_t) message->data[offset] != TDS_TOKEN_DONE)
Expand Down
Loading