Skip to content

Commit

Permalink
Fix Crash in FetchMoreBcpData (#1549)
Browse files Browse the repository at this point in the history
We are crashing here when the implicit batching of packets is happening at
the very first phase i.e. reading column metadata phase because the table
itself has huge number of columns. Inherently TDS is in a "miscellaneous"
mem context during FETCH phase and the allocated pointer thus seems to crash
for invalid context during pfree.

To fix this we are appending the new packet's data during the FETCH
phase rather than freeing the message data which we do in the PROCESS phase.

Signed-off-by: Kushaal Shroff <[email protected]>
  • Loading branch information
KushaalShroff authored Jun 16, 2023
1 parent 2ccf842 commit d1312f0
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 43 deletions.
111 changes: 68 additions & 43 deletions contrib/babelfishpg_tds/src/backend/tds/tdsbulkload.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

static StringInfo SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message);
void ProcessBCPRequest(TDSRequest request);
static void FetchMoreBcpData(StringInfo *message, int dataLenToRead);
static void FetchMoreBcpData(StringInfo *message, int dataLenToRead, bool freeMessageData);
static void FetchMoreBcpPlpData(StringInfo *message, int dataLenToRead);
static int ReadBcpPlp(ParameterToken temp, StringInfo *message, TDSRequestBulkLoad request);
uint64_t offset = 0;
Expand Down Expand Up @@ -70,12 +70,26 @@ do \
temp->rowCount, colNum + 1, temp->colMetaData[i].columnTdsType))); \
} while(0)

/* Check if Message has enough data to read, if not then fetch more. */
#define CheckMessageHasEnoughBytesToRead(message, dataLen) \
/*
* Check if Message has enough data to read column metadata.
* If not then fetch more.
*/
#define CheckMessageHasEnoughBytesToReadColMetadata(message, dataLen) \
do \
{ \
if ((*message)->len - offset < dataLen) \
FetchMoreBcpData(message, dataLen, false); \
} while(0)

/*
* Check if Message has enough data to read the rows' data.
* If not then fetch more.
*/
#define CheckMessageHasEnoughBytesToReadRows(message, dataLen) \
do \
{ \
if ((*message)->len - offset < dataLen) \
FetchMoreBcpData(message, dataLen); \
FetchMoreBcpData(message, dataLen, true); \
} while(0)

/* Check if Message has enough data to read, if not then fetch more. */
Expand All @@ -87,7 +101,7 @@ do \
} while(0)

static void
FetchMoreBcpData(StringInfo *message, int dataLenToRead)
FetchMoreBcpData(StringInfo *message, int dataLenToRead, bool freeMessageData)
{
StringInfo temp;
int ret;
Expand All @@ -107,17 +121,30 @@ FetchMoreBcpData(StringInfo *message, int dataLenToRead)
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("Trying to read more data than available in BCP request.")));

temp = makeStringInfo();
appendBinaryStringInfo(temp, (*message)->data + offset, (*message)->len - offset);
/*
* If we are trying to read next packet and freeMessageData is true then
* afford to free it.
* NOTE: We should free the message data only while reading
* rows' data and on the other hand we should not free the
* column-metadata information until we are done with it.
*/
if (freeMessageData)
{
temp = makeStringInfo();
appendBinaryStringInfo(temp, (*message)->data + offset, (*message)->len - offset);

if ((*message)->data)
pfree((*message)->data);
pfree((*message));
if ((*message)->data)
pfree((*message)->data);
pfree((*message));
offset = 0;
}
else
temp = *message;

/*
* Keep fetching for additional packets until we have enough data to read.
*/
while (dataLenToRead > temp->len)
while (dataLenToRead + offset > temp->len)
{
/*
* We should hold the interrupts until we read the next request frame.
Expand All @@ -139,7 +166,6 @@ FetchMoreBcpData(StringInfo *message, int dataLenToRead)
}
}

offset = 0;
(*message) = temp;
}

Expand Down Expand Up @@ -228,7 +254,7 @@ GetBulkLoadRequest(StringInfo message)

for (int currentColumn = 0; currentColumn < colCount; currentColumn++)
{
CheckMessageHasEnoughBytesToRead(&message, COLUMNMETADATA_HEADER_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, COLUMNMETADATA_HEADER_LEN);
/* UserType */
memcpy(&colmetadata[currentColumn].userType, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
Expand All @@ -249,12 +275,12 @@ GetBulkLoadRequest(StringInfo message)
case TDS_TYPE_MONEYN:
case TDS_TYPE_DATETIMEN:
case TDS_TYPE_UNIQUEIDENTIFIER:
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].maxLen = message->data[offset++];
break;
case TDS_TYPE_DECIMALN:
case TDS_TYPE_NUMERICN:
CheckMessageHasEnoughBytesToRead(&message, NUMERIC_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, NUMERIC_COLUMNMETADATA_LEN);
colmetadata[currentColumn].maxLen = message->data[offset++];
colmetadata[currentColumn].precision = message->data[offset++];
colmetadata[currentColumn].scale = message->data[offset++];
Expand All @@ -264,7 +290,7 @@ GetBulkLoadRequest(StringInfo message)
case TDS_TYPE_NCHAR:
case TDS_TYPE_NVARCHAR:
{
CheckMessageHasEnoughBytesToRead(&message, STRING_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, STRING_COLUMNMETADATA_LEN);
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint16));
offset += sizeof(uint16);

Expand All @@ -280,53 +306,53 @@ GetBulkLoadRequest(StringInfo message)
{
uint16_t tableLen = 0;

CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint32_t));
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);

/* Read collation(LICD) and sort-id for TEXT and NTEXT. */
if (colmetadata[currentColumn].columnTdsType == TDS_TYPE_TEXT ||
colmetadata[currentColumn].columnTdsType == TDS_TYPE_NTEXT)
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t) + 1);
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint32_t) + 1);
memcpy(&collation, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
colmetadata[currentColumn].sortId = message->data[offset++];
colmetadata[currentColumn].encoding = TdsGetEncoding(collation);
}

CheckMessageHasEnoughBytesToRead(&message, sizeof(uint16_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint16_t));
memcpy(&tableLen, &message->data[offset], sizeof(uint16_t));
offset += sizeof(uint16_t);

/* Skip table name for now. */
CheckMessageHasEnoughBytesToRead(&message, tableLen * 2);
CheckMessageHasEnoughBytesToReadColMetadata(&message, tableLen * 2);
offset += tableLen * 2;
}
break;
case TDS_TYPE_XML:
{
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadColMetadata(&message, 1);
colmetadata[currentColumn].maxLen = message->data[offset++];
}
break;
case TDS_TYPE_DATETIME2:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 8;
}
break;
case TDS_TYPE_TIME:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 5;
}
break;
case TDS_TYPE_DATETIMEOFFSET:
{
CheckMessageHasEnoughBytesToRead(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, FIXED_LEN_TYPE_COLUMNMETADATA_LEN);
colmetadata[currentColumn].scale = message->data[offset++];
colmetadata[currentColumn].maxLen = 10;
}
Expand All @@ -336,7 +362,7 @@ GetBulkLoadRequest(StringInfo message)
{
uint16 plp;

CheckMessageHasEnoughBytesToRead(&message, BINARY_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, BINARY_COLUMNMETADATA_LEN);
memcpy(&plp, &message->data[offset], sizeof(uint16));
offset += sizeof(uint16);
colmetadata[currentColumn].maxLen = plp;
Expand All @@ -346,7 +372,7 @@ GetBulkLoadRequest(StringInfo message)
colmetadata[currentColumn].maxLen = 3;
break;
case TDS_TYPE_SQLVARIANT:
CheckMessageHasEnoughBytesToRead(&message, SQL_VARIANT_COLUMNMETADATA_LEN);
CheckMessageHasEnoughBytesToReadColMetadata(&message, SQL_VARIANT_COLUMNMETADATA_LEN);
memcpy(&colmetadata[currentColumn].maxLen, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
break;
Expand Down Expand Up @@ -442,10 +468,10 @@ GetBulkLoadRequest(StringInfo message)
}

/* Column Name */
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint8_t));
CheckMessageHasEnoughBytesToReadColMetadata(&message, sizeof(uint8_t));
memcpy(&colmetadata[currentColumn].colNameLen, &message->data[offset++], sizeof(uint8_t));

CheckMessageHasEnoughBytesToRead(&message, colmetadata[currentColumn].colNameLen * 2);
CheckMessageHasEnoughBytesToReadColMetadata(&message, colmetadata[currentColumn].colNameLen * 2);
colmetadata[currentColumn].colName = (char *) palloc0(colmetadata[currentColumn].colNameLen * sizeof(char) * 2 + 1);
memcpy(colmetadata[currentColumn].colName, &message->data[offset],
colmetadata[currentColumn].colNameLen * 2);
Expand Down Expand Up @@ -475,7 +501,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
request->rowData = NIL;
request->currentBatchSize = 0;

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

/* Loop over each row. */
while ((uint8_t) message->data[offset] == TDS_TOKEN_ROW
Expand Down Expand Up @@ -515,7 +541,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
}
else
{
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);
len = message->data[offset++];
request->currentBatchSize++;

Expand All @@ -528,7 +554,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
}
CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -593,7 +619,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
"Check the source data for invalid values. An example of an invalid value is data of numeric type with scale greater than precision.",
request->rowCount, i + 1)));

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

len = message->data[offset++];
request->currentBatchSize++;
Expand All @@ -606,7 +632,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -634,15 +660,15 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
{
if (colmetadata[i].maxLen != 0xffff)
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(short));
CheckMessageHasEnoughBytesToReadRows(&message, sizeof(short));
memcpy(&len, &message->data[offset], sizeof(short));
offset += sizeof(short);
request->currentBatchSize += sizeof(short);
if (len != 0xffff)
{
CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -716,7 +742,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
{
uint8 dataTextPtrLen;

CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);

/*
* Ignore the Data Text Ptr since its currently of no
Expand All @@ -731,7 +757,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
continue;
}

CheckMessageHasEnoughBytesToRead(&message, dataTextPtrLen + 8 + sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadRows(&message, dataTextPtrLen + 8 + sizeof(uint32_t));

offset += dataTextPtrLen;
request->currentBatchSize += dataTextPtrLen;
Expand All @@ -751,7 +777,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand Down Expand Up @@ -813,7 +839,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
break;
case TDS_TYPE_SQLVARIANT:
{
CheckMessageHasEnoughBytesToRead(&message, sizeof(uint32_t));
CheckMessageHasEnoughBytesToReadRows(&message, sizeof(uint32_t));

memcpy(&len, &message->data[offset], sizeof(uint32_t));
offset += sizeof(uint32_t);
Expand All @@ -828,7 +854,7 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)

CheckForInvalidLength(len, request, i);

CheckMessageHasEnoughBytesToRead(&message, len);
CheckMessageHasEnoughBytesToReadRows(&message, len);

/* Build temp Stringinfo. */
temp->data = &message->data[offset];
Expand All @@ -850,15 +876,14 @@ SetBulkLoadRowData(TDSRequestBulkLoad request, StringInfo message)
i++;
}
request->rowData = lappend(request->rowData, rowData);
CheckMessageHasEnoughBytesToRead(&message, 1);
CheckMessageHasEnoughBytesToReadRows(&message, 1);
}

/*
* If row count is less than the default batch size then this is the last
* packet, the next byte should be the done token.
*/
CheckMessageHasEnoughBytesToRead(&message, 1);

CheckMessageHasEnoughBytesToReadRows(&message, 1);
if (request->rowCount < pltsql_plugin_handler_ptr->get_insert_bulk_rows_per_batch()
&& request->currentBatchSize < pltsql_plugin_handler_ptr->get_insert_bulk_kilobytes_per_batch() * 1024
&& (uint8_t) message->data[offset] != TDS_TOKEN_DONE)
Expand Down
Loading

0 comments on commit d1312f0

Please sign in to comment.