Skip to content

Commit

Permalink
Reset db context at the time TDS resets the connection (babelfish-for…
Browse files Browse the repository at this point in the history
…-postgresql#2971)

T-SQL Behaviour suggests that if we connect to database db1 and if during the session we have changed the database context to db2 then at the time of reset connection, the server must reset the connection to db1. Earlier we were not resetting the database context to that of the database used to login, in the above example db1, this lead to clients being handed a stale connection.
To Fix this we reset the database context to that from the loginInfo which was maintained at time of login. Changes were also made to avoid sending the environment change token for the implicit "USE DB" being run at time of reset.

Issues Resolved
BABEL-5256

Signed off by: Kushaal Shroff <[email protected]>
  • Loading branch information
KushaalShroff committed Sep 26, 2024
1 parent 800ea29 commit 50b72a9
Show file tree
Hide file tree
Showing 15 changed files with 308 additions and 86 deletions.
1 change: 1 addition & 0 deletions contrib/babelfishpg_tds/src/backend/tds/tds_srv.c
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ pe_tds_init(void)
pltsql_plugin_handler_ptr->invalidate_stat_view = &invalidate_stat_table;
pltsql_plugin_handler_ptr->get_host_name = &get_tds_host_name;
pltsql_plugin_handler_ptr->set_reset_tds_connection_flag = &SetResetTDSConnectionFlag;
pltsql_plugin_handler_ptr->get_reset_tds_connection_flag = &GetResetTDSConnectionFlag;

invalidate_stat_table_hook = invalidate_stat_table;
guc_newval_hook = TdsSetGucStatVariable;
Expand Down
198 changes: 125 additions & 73 deletions contrib/babelfishpg_tds/src/backend/tds/tdslogin.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

#include "src/include/tds_debug.h"
#include "src/include/tds_int.h"
#include "src/include/tds_protocol.h"
#include "src/include/tds_request.h"
#include "src/include/tds_response.h"
#include "src/include/guc.h"
Expand Down Expand Up @@ -2008,6 +2009,128 @@ TdsProcessLogin(Port *port, bool loadedSsl)
return rc;
}

/*
* TdsSetDbContext:
* Used to Set the Database Context during login
* and during reset connection.
* Note: We should not optimize the scenario during
* reset connection to reset to the same database
* which might be in use since the USE db command
* will reset other configurations which might
* have changed.
*/
void
TdsSetDbContext()
{
char *dbname = NULL;
char *useDbCommand = NULL;
char *user = NULL;
MemoryContext oldContext = CurrentMemoryContext;

PG_TRY();
{
if (loginInfo->database != NULL && loginInfo->database[0] != '\0')
{
Oid db_id;

/*
* Before preparing the query, first check whether we got a valid
* database name and it exists. Otherwise, there'll be risk of
* SQL injection.
*/
StartTransactionCommand();
db_id = pltsql_plugin_handler_ptr->pltsql_get_database_oid(loginInfo->database);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);

if (!OidIsValid(db_id))
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("database \"%s\" does not exist", loginInfo->database)));

/*
* Any delimitated/quoted db name identifier requested in login
* must be already handled before this point.
*/
useDbCommand = psprintf("USE [%s]", loginInfo->database);
dbname = pstrdup(loginInfo->database);
}
else
{
char *temp = NULL;

StartTransactionCommand();
temp = pltsql_plugin_handler_ptr->pltsql_get_login_default_db(loginInfo->username);
MemoryContextSwitchTo(oldContext);

if (temp == NULL)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("could not find default database for user \"%s\"", loginInfo->username)));

useDbCommand = psprintf("USE [%s]", temp);
dbname = pstrdup(temp);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);
}

StartTransactionCommand();
/*
* Check if user has privileges to access current database.
*/
user = pltsql_plugin_handler_ptr->pltsql_get_user_for_database(dbname);
if (!user)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("Cannot open database \"%s\" requested by the login. The login failed", dbname)));

/*
* loginInfo has a database name provided, so we execute a "USE
* [<db_name>]" through pltsql inline handler.
*/
ExecuteSQLBatch(useDbCommand);
CommitTransactionCommand();
}
PG_CATCH();
{
/*
* If this is during reset phase and we encounter an error
* with mapped user or db not found then we should terminate
* the connection.
*/
if (resetTdsConnectionFlag)
{
/* Before terminating the connection, send the response to the client. */
EmitErrorReport();
FlushErrorState();

/*
* Client driver terminates the connection with a
* dual error token and with error 596. Otherwise
* it sends the next requests before realising the
* session was terminated.
*/
TdsSendError(596, 1, ERROR,
"Cannot continue the execution because the session is in the kill state.", 1);

TdsSendDone(TDS_TOKEN_DONE, TDS_DONE_ERROR, 0, 0);
TdsFlush();

/* Terminate the connection. */
ereport(FATAL,
(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
errmsg("Reset Connection Failed")));
}
/* Else rethrow the error. */
PG_RE_THROW();
}
PG_END_TRY();
if (useDbCommand)
pfree(useDbCommand);
if (dbname)
pfree(dbname);
}

/*
* TdsSendLoginAck - Send a login acknowledgement to the client
*
Expand All @@ -2017,16 +2140,13 @@ void
TdsSendLoginAck(Port *port)
{
uint16_t temp16;
char *dbname = NULL;
int prognameLen = pg_mbstrlen(default_server_name);
LoginRequest request;
StringInfoData buf;
uint8 temp8;
uint32_t collationInfo;
char collationBytesNew[5];
char *useDbCommand = NULL;
char *user = NULL;
MemoryContext oldContext;
Oid roleid = InvalidOid;
uint32_t tdsVersion = pg_hton32(loginInfo->tdsVersion);
char srvVersionBytes[4];

Expand Down Expand Up @@ -2138,75 +2258,7 @@ TdsSendLoginAck(Port *port)
errmsg("\"%s\" is not a Babelfish user", port->user_name)));
}

oldContext = CurrentMemoryContext;

if (request->database != NULL && request->database[0] != '\0')
{
Oid db_id;

/*
* Before preparing the query, first check whether we got a valid
* database name and it exists. Otherwise, there'll be risk of
* SQL injection.
*/
StartTransactionCommand();
db_id = pltsql_plugin_handler_ptr->pltsql_get_database_oid(request->database);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);

if (!OidIsValid(db_id))
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("database \"%s\" does not exist", request->database)));

/*
* Any delimitated/quoted db name identifier requested in login
* must be already handled before this point.
*/
useDbCommand = psprintf("USE [%s]", request->database);
dbname = pstrdup(request->database);
}
else
{
char *temp = NULL;

StartTransactionCommand();
temp = pltsql_plugin_handler_ptr->pltsql_get_login_default_db(port->user_name);
MemoryContextSwitchTo(oldContext);

if (temp == NULL)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("could not find default database for user \"%s\"", port->user_name)));

useDbCommand = psprintf("USE [%s]", temp);
dbname = pstrdup(temp);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);
}

/*
* Check if user has privileges to access current database
*/
StartTransactionCommand();
user = pltsql_plugin_handler_ptr->pltsql_get_user_for_database(dbname);
if (!user)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("Cannot open database \"%s\" requested by the login. The login failed", dbname)));
CommitTransactionCommand();
if (dbname)
pfree(dbname);

/*
* Request has a database name provided, so we execute a "USE
* [<db_name>]" through pgtsql inline handler
*/
StartTransactionCommand();
ExecuteSQLBatch(useDbCommand);
CommitTransactionCommand();
if (useDbCommand)
pfree(useDbCommand);
TdsSetDbContext();

/*
* Set the GUC for language, it will take care of changing the GUC,
Expand Down
32 changes: 28 additions & 4 deletions contrib/babelfishpg_tds/src/backend/tds/tdsprotocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ typedef ResetConnectionData *ResetConnection;
TdsRequestCtrlData *TdsRequestCtrl = NULL;

ResetConnection resetCon = NULL;
static bool resetTdsConnectionFlag = false;
bool resetTdsConnectionFlag = false;

/* Local functions */
static void ResetTDSConnection(void);
Expand Down Expand Up @@ -155,16 +155,15 @@ ResetTDSConnection(void)
TdsResetCache();
TdsResponseReset();
TdsResetBcpOffset();
/* Retore previous isolation level when not called by sys.sp_reset_connection */
/* Retore previous isolation level when not called by sys.sp_reset_connection. */
if (!resetTdsConnectionFlag)
{
SetConfigOption("default_transaction_isolation", isolationOld,
PGC_BACKEND, PGC_S_CLIENT);
}

tvp_lookup_list = NIL;

/* send an environement change token is its not called via sys.sp_reset_connection procedure */
/* Send an environement change token is its not called via sys.sp_reset_connection procedure. */
if (!resetTdsConnectionFlag)
{
TdsSendEnvChange(TDS_ENVID_RESETCON, NULL, NULL);
Expand All @@ -179,6 +178,11 @@ void SetResetTDSConnectionFlag()
resetTdsConnectionFlag = true;
}

bool GetResetTDSConnectionFlag()
{
return resetTdsConnectionFlag;
}

/*
* GetTDSRequest - Fetch and parse a TDS packet and generate a TDS request that
* can be processed later.
Expand Down Expand Up @@ -290,7 +294,16 @@ GetTDSRequest(bool *resetProtocol)
resetCon->messageType = messageType;
resetCon->status = (status & ~TDS_PACKET_HEADER_STATUS_RESETCON);

/*
* Set resetTdsConnectionFlag to true so that we avoid
* sending any env change token for the USE DB command
* which will get executed.
*/
resetTdsConnectionFlag = true;
TdsSetDbContext();
resetTdsConnectionFlag = false;
ResetTDSConnection();

TdsErrorContext->err_text = "Fetching TDS Request";
*resetProtocol = true;
return NULL;
Expand Down Expand Up @@ -678,6 +691,17 @@ TdsSocketBackend(void)
case TDS_REQUEST_PHASE_FLUSH:
{
TdsErrorContext->phase = "TDS_REQUEST_PHASE_FLUSH";

if (resetTdsConnectionFlag)
{
/*
* We must set the Db Context before resetting TDS state,
* becasue we need the existing TDS state to flush any errors
* along with the reset.
*/
TdsSetDbContext();
}

/* Send the response now */
TdsFlush();

Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tds/src/include/tds_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ extern int TdsProcessLogin(Port *port, bool LoadSsl);
extern void TdsSendLoginAck(Port *port);
extern uint32_t GetClientTDSVersion(void);
extern char *get_tds_login_domainname(void);
extern void TdsSetDbContext(void);

/* Functions in backend/tds/tdsprotocol.c */
extern int TdsSocketBackend(void);
Expand Down
3 changes: 3 additions & 0 deletions contrib/babelfishpg_tds/src/include/tds_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,8 @@ typedef struct

extern TdsRequestCtrlData *TdsRequestCtrl;
extern void SetResetTDSConnectionFlag(void);
extern bool GetResetTDSConnectionFlag(void);

extern bool resetTdsConnectionFlag;

#endif /* TDS_PROTOCOL_H */
21 changes: 13 additions & 8 deletions contrib/babelfishpg_tsql/src/pl_exec-2.c
Original file line number Diff line number Diff line change
Expand Up @@ -2725,14 +2725,19 @@ exec_stmt_usedb(PLtsql_execstate *estate, PLtsql_stmt_usedb *stmt)
top_es_entry = top_es_entry->next;
}

snprintf(message, sizeof(message), "Changed database context to '%s'.", stmt->db_name);
/* send env change token to user */
if (*pltsql_protocol_plugin_ptr && (*pltsql_protocol_plugin_ptr)->send_env_change)
((*pltsql_protocol_plugin_ptr)->send_env_change) (1, stmt->db_name, old_db_name);
/* send message to user */
if (*pltsql_protocol_plugin_ptr && (*pltsql_protocol_plugin_ptr)->send_info)
((*pltsql_protocol_plugin_ptr)->send_info) (0, 1, 0, message, 0);

/*
* In case of reset-connection we do not need to send the environment change token.
*/
if (!((*pltsql_protocol_plugin_ptr) && (*pltsql_protocol_plugin_ptr)->get_reset_tds_connection_flag()))
{
snprintf(message, sizeof(message), "Changed database context to '%s'.", stmt->db_name);
/* send env change token to user */
if (*pltsql_protocol_plugin_ptr && (*pltsql_protocol_plugin_ptr)->send_env_change)
((*pltsql_protocol_plugin_ptr)->send_env_change) (1, stmt->db_name, old_db_name);
/* send message to user */
if (*pltsql_protocol_plugin_ptr && (*pltsql_protocol_plugin_ptr)->send_info)
((*pltsql_protocol_plugin_ptr)->send_info) (0, 1, 0, message, 0);
}
return PLTSQL_RC_OK;
}

Expand Down
2 changes: 2 additions & 0 deletions contrib/babelfishpg_tsql/src/pltsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,8 @@ typedef struct PLtsql_protocol_plugin

void (*set_reset_tds_connection_flag) ();

bool (*get_reset_tds_connection_flag) ();

/* Session level GUCs */
bool quoted_identifier;
bool arithabort;
Expand Down
1 change: 0 additions & 1 deletion contrib/babelfishpg_tsql/src/session.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ void
reset_session_properties(void)
{
reset_cached_batch();
set_session_properties(get_cur_db_name());
}

void
Expand Down
Loading

0 comments on commit 50b72a9

Please sign in to comment.