Skip to content

Commit

Permalink
Code review comments
Browse files Browse the repository at this point in the history
Corrected formatting
Added case for decimal
used inbuilt function for getting procedure name
Added syntax error if identity_into function is called directly

Task: BABEL-539

Signed-off-by: Deepakshi Mittal <[email protected]>
  • Loading branch information
deepakshi-mittal committed Jul 30, 2023
1 parent 5293789 commit e0774d6
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 103 deletions.
2 changes: 1 addition & 1 deletion contrib/babelfishpg_tsql/runtime/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ identity_into(PG_FUNCTION_ARGS)
int64 result;
Assert(tsql_select_into_seq_oid != InvalidOid);
result = nextval_internal(tsql_select_into_seq_oid, false);
return result;
PG_RETURN_INT64((int64) result);
}

/*
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/sql/sys_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3393,6 +3393,7 @@ GRANT EXECUTE ON FUNCTION sys.host_id() TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.identity_into(IN typename INT, IN seed INT, IN increment INT)
RETURNS int AS 'babelfishpg_tsql' LANGUAGE C STABLE;
GRANT EXECUTE ON FUNCTION sys.identity_into(INT, INT, INT) TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.degrees(IN arg1 BIGINT)
RETURNS bigint AS 'babelfishpg_tsql','bigint_degrees' LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ LANGUAGE C IMMUTABLE STRICT;

CREATE OR REPLACE FUNCTION sys.identity_into(IN typename INT, IN seed INT, IN increment INT)
RETURNS int AS 'babelfishpg_tsql' LANGUAGE C STABLE;
GRANT EXECUTE ON FUNCTION sys.identity_into(INT, INT, INT) TO PUBLIC;

CREATE OR REPLACE VIEW sys.sql_expression_dependencies
AS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,14 @@ TsqlFunctionIdentityInto(TypeName *typename, Node *seed, Node *increment, int lo
case INT4OID:
case INT8OID:
case NUMERICOID:
args = list_make3((Node *)makeIntConst((int)type_oid, -1), seed, increment);
args = list_make3((Node *)makeIntConst((int)type_oid, location), seed, increment);
result = (Node *) makeFuncCall(TsqlSystemFuncName("identity_into"), args, COERCE_EXPLICIT_CALL, location);
break;
default:
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR),
errmsg("identity column type must be smallint, integer, or bigint")));
errmsg("identity column type must be smallint, integer, bigint, or numeric")));
break;

}

Expand Down
182 changes: 96 additions & 86 deletions contrib/babelfishpg_tsql/src/pl_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Datum sp_prepare(PG_FUNCTION_ARGS);
Datum sp_unprepare(PG_FUNCTION_ARGS);
static List *transformReturningList(ParseState *pstate, List *returningList);
static List *transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *queryString);
static char *parse_type_argument(int type_oid);
static char *get_oid_type_string(int type_oid);
extern char *construct_unique_index_name(char *index_name, char *relation_name);
extern int CurrentLineNumber;
static non_tsql_proc_entry_hook_type prev_non_tsql_proc_entry_hook = NULL;
Expand Down Expand Up @@ -5026,9 +5026,16 @@ pltsql_revert_guc(int nest_level)

}

static char *parse_type_argument(int type_oid){
static char *get_oid_type_string(int type_oid){
char *type_string = NULL;
switch(type_oid){
if ((*common_utility_plugin_ptr->is_tsql_decimal_datatype) (type_oid))
{
type_string = "decimal";
return type_string;
}

switch(type_oid)
{
case INT2OID:
type_string = "pg_catalog.int2";
break;
Expand All @@ -5042,18 +5049,20 @@ static char *parse_type_argument(int type_oid){
type_string = "pg_catalog.numeric";
break;
default:
type_string = "decimal";
break;
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Identity column type must be smallint, integer, bigint, or numeric")));
break;
}
return type_string;
}

static List * transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *queryString){
List *result;
ListCell *elements;
static List *transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *queryString)
{
List *result;
ListCell *elements;
CreateSeqStmt *seqstmt;
AlterSeqStmt *altseqstmt;
List *attnamelist;
List *attnamelist;
IntoClause *into;
Node *n;

Expand All @@ -5065,62 +5074,69 @@ static List * transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *query

if (n && n->type == T_Query)
{
Query *q = (Query *) n;
Query *q = (Query *)n;
bool seen_identity = false;
foreach(elements, q->targetList)
{
TargetEntry *tle = (TargetEntry *) lfirst(elements);
FuncExpr *funcexpr;
if (tle->expr && IsA(tle->expr, FuncExpr) && strcasecmp(get_func_name(((FuncExpr *) (tle->expr))->funcid), "identity_into") ==0 ){

Oid snamespaceid;
foreach (elements, q->targetList)
{
TargetEntry *tle = (TargetEntry *)lfirst(elements);

if (tle->expr && IsA(tle->expr, FuncExpr) && strcasecmp(get_func_name(((FuncExpr *)(tle->expr))->funcid), "identity_into") == 0)
{
FuncExpr *funcexpr;
Oid snamespaceid;
char *snamespace;
char *sname;
List *seqoptions = NIL;
ListCell *arg;

int type_oid;
char *type= NULL;
char *type = NULL;
TypeName *ofTypename;
int64 seed_value;
int arg_num;

if(seen_identity){
if (seen_identity)
{
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR), errmsg("Attempting to add multiple identity columns to table \"%s\" using the SELECT INTO statement.", into->rel->relname )));
(errcode(ERRCODE_SYNTAX_ERROR), errmsg("Attempting to add multiple identity columns to table \"%s\" using the SELECT INTO statement.", into->rel->relname)));
}

if(tle->resname == NULL){
if (tle->resname == NULL)
{
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR), errmsg("Incorrect syntax near the keyword 'INTO'")));
(errcode(ERRCODE_SYNTAX_ERROR), errmsg("Incorrect syntax near the keyword 'INTO'")));
}

funcexpr = (FuncExpr *)tle->expr;
arg_num = 0;
foreach(arg, funcexpr->args)
foreach (arg, funcexpr->args)
{
Node *fargNode = (Node *) lfirst(arg);
Node *fargNode = (Node *)lfirst(arg);
Const *con;
arg_num++;
switch (arg_num){
switch (arg_num)
{
case 1:
con = (Const *) fargNode;
type_oid = (int) con->constvalue;
type = parse_type_argument(type_oid);
con = (Const *)fargNode;
type_oid = (int)con->constvalue;
type = get_oid_type_string(type_oid);
ofTypename = typeStringToTypeName(type);
seqoptions = lappend(seqoptions, makeDefElem("as", (Node *) ofTypename, -1));
seqoptions = lappend(seqoptions, makeDefElem("as", (Node *)ofTypename, -1));
break;
case 2:
con = (Const *) fargNode;
seqoptions = lappend(seqoptions, makeDefElem("start", (Node *)makeInteger(con->constvalue), -1));
seed_value = (int64) con->constvalue;
con = (Const *)fargNode;
seqoptions = lappend(seqoptions, makeDefElem("start", (Node *)makeInteger((int64)con->constvalue), -1));
seed_value = (int64)con->constvalue;
break;
case 3:
con = (Const *) fargNode;
seqoptions = lappend(seqoptions, makeDefElem("increment", (Node *)makeInteger(con->constvalue), -1));
if ((int) con->constvalue > 0){
con = (Const *)fargNode;
seqoptions = lappend(seqoptions, makeDefElem("increment", (Node *)makeInteger((int64)con->constvalue), -1));
if ((int)con->constvalue > 0)
{
seqoptions = lappend(seqoptions, makeDefElem("minvalue", (Node *)makeInteger(seed_value), -1));
}else{
}
else
{
seqoptions = lappend(seqoptions, makeDefElem("maxvalue", (Node *)makeInteger(seed_value), -1));
}
break;
Expand All @@ -5134,7 +5150,7 @@ static List * transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *query

snamespaceid = RangeVarGetCreationNamespace(into->rel);
snamespace = get_namespace_name(snamespaceid);
sname = ChooseRelationName(into->rel->relname, tle->resname,"seq",snamespaceid, false);
sname = ChooseRelationName(into->rel->relname, tle->resname, "seq", snamespaceid, false);
seqstmt = makeNode(CreateSeqStmt);
seqstmt->for_identity = true;
seqstmt->sequence = makeRangeVar(snamespace, sname, -1);
Expand All @@ -5143,74 +5159,68 @@ static List * transformSelectIntoStmt(CreateTableAsStmt *stmt, const char *query

altseqstmt = makeNode(AlterSeqStmt);
altseqstmt->sequence = makeRangeVar(snamespace, sname, -1);
attnamelist = list_make3(makeString(snamespace), makeString(into->rel->relname),makeString(tle->resname));
altseqstmt->options = list_make1(makeDefElem("owned_by", (Node *) attnamelist, -1));
attnamelist = list_make3(makeString(snamespace), makeString(into->rel->relname), makeString(tle->resname));
altseqstmt->options = list_make1(makeDefElem("owned_by", (Node *)attnamelist, -1));
altseqstmt->for_identity = true;

}
}
}

if(seqstmt){
result = lappend(result, seqstmt);
if (seqstmt)
{
result = lappend(result, seqstmt);
}
result = lappend(result, stmt);

if(altseqstmt){
result = lappend(result, altseqstmt);
if (altseqstmt)
{
result = lappend(result, altseqstmt);
}

return result;
return result;
}

void pltsql_bbfSelectIntoUtility(ParseState *pstate, PlannedStmt *pstmt, const char *queryString, QueryEnvironment *queryEnv,
ParamListInfo params, QueryCompletion *qc)
{

void
pltsql_bbfSelectIntoUtility(ParseState *pstate, PlannedStmt *pstmt, const char *queryString, QueryEnvironment *queryEnv,
ParamListInfo params, QueryCompletion *qc){

Node *parsetree = pstmt->utilityStmt;
Node *parsetree = pstmt->utilityStmt;
ObjectAddress address;
ObjectAddress secondaryObject = InvalidObjectAddress;
List *stmts;
stmts = transformSelectIntoStmt((CreateTableAsStmt *) parsetree, queryString);
while (stmts != NIL)
stmts = transformSelectIntoStmt((CreateTableAsStmt *)parsetree, queryString);
while (stmts != NIL)
{
Node *stmt = (Node *)linitial(stmts);
stmts = list_delete_first(stmts);
if (IsA(stmt, CreateTableAsStmt))
{
Node *stmt = (Node *) linitial(stmts);
stmts = list_delete_first(stmts);
if (IsA(stmt, CreateTableAsStmt))
{
address = ExecCreateTableAs(pstate, (CreateTableAsStmt *) parsetree,params, queryEnv, qc);
EventTriggerCollectSimpleCommand(address,secondaryObject,stmt);
}
else if(IsA(stmt, CreateSeqStmt)) {
address = DefineSequence(pstate, (CreateSeqStmt *) stmt);
Assert(address.objectId != InvalidOid);
tsql_select_into_seq_oid = address.objectId;
EventTriggerCollectSimpleCommand(address,secondaryObject,stmt);
}
else{

PlannedStmt *wrapper;
wrapper = makeNode(PlannedStmt);
wrapper->commandType = CMD_UTILITY;
wrapper->canSetTag = false;
wrapper->utilityStmt = stmt;
wrapper->stmt_location = pstmt->stmt_location;
wrapper->stmt_len = pstmt->stmt_len;

ProcessUtility(wrapper,
queryString,
false,
PROCESS_UTILITY_SUBCOMMAND,
params,
NULL,
None_Receiver,
NULL);
}
address = ExecCreateTableAs(pstate, (CreateTableAsStmt *)parsetree, params, queryEnv, qc);
EventTriggerCollectSimpleCommand(address, secondaryObject, stmt);
}
else if (IsA(stmt, CreateSeqStmt))
{
address = DefineSequence(pstate, (CreateSeqStmt *)stmt);
Assert(address.objectId != InvalidOid);
tsql_select_into_seq_oid = address.objectId;
EventTriggerCollectSimpleCommand(address, secondaryObject, stmt);
}
else
{
PlannedStmt *wrapper;
wrapper = makeNode(PlannedStmt);
wrapper->commandType = CMD_UTILITY;
wrapper->canSetTag = false;
wrapper->utilityStmt = stmt;
wrapper->stmt_location = pstmt->stmt_location;
wrapper->stmt_len = pstmt->stmt_len;

ProcessUtility(wrapper, queryString, false, PROCESS_UTILITY_SUBCOMMAND, params, NULL, None_Receiver, NULL);
}
if (stmts != NIL)
CommandCounterIncrement();
}
}
}

void
set_current_query_is_create_tbl_check_constraint(Node *expr)
{
Expand Down
26 changes: 15 additions & 11 deletions contrib/babelfishpg_tsql/src/tsqlIface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1847,17 +1847,6 @@ class tsqlBuilder : public tsqlCommonMutator
}
}

void enterFunction_call(TSqlParser::Function_callContext *ctx) override
{
std::string func_name = ::getFullText(ctx);
size_t Lbracket_index = func_name.find('(');
func_name = func_name.substr(0, Lbracket_index);
if (pg_strcasecmp(func_name.c_str(), "identity") == 0 || pg_strcasecmp(func_name.c_str(), "identity_into") == 0
|| pg_strcasecmp(func_name.c_str(), "sys.identity_into") == 0) {
has_identity_function = true;
}
}

//////////////////////////////////////////////////////////////////////////////
// function/procedure call analysis
//////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1966,6 +1955,21 @@ class tsqlBuilder : public tsqlCommonMutator
}

}

if (ctx->func_proc_name_server_database_schema()->procedure)
{
std::string proc_name = stripQuoteFromId(ctx->func_proc_name_server_database_schema()->procedure);
if (pg_strcasecmp(proc_name.c_str(), "identity") == 0)
{
has_identity_function = true;
}

if (pg_strcasecmp(proc_name.c_str(), "identity_into") == 0)
{
throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED,
format_errmsg("function %s does not exist", proc_name.c_str()), getLineAndPos(ctx));
}
}
}
}

Expand Down
Loading

0 comments on commit e0774d6

Please sign in to comment.