Skip to content

Commit

Permalink
fix(sec): remove most instances of possible sql injection
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 18, 2024
1 parent bdc1b3f commit 7191e8e
Show file tree
Hide file tree
Showing 19 changed files with 294 additions and 155 deletions.
2 changes: 1 addition & 1 deletion ci/check_disallowed_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def generate_dependency_graph(*args):
command = ("pydeps", "--show-deps", *args)
print(f"Running: {' '.join(command)}") # noqa: T201
result = subprocess.check_output(command, text=True)
result = subprocess.check_output(command, text=True) # noqa: S603
return json.loads(result)


Expand Down
2 changes: 1 addition & 1 deletion ci/make_geography_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main() -> None:

args = parser.parse_args()

response = requests.get(args.input_data_url)
response = requests.get(args.input_data_url, timeout=600)
response.raise_for_status()
input_data = response.json()
db_path = Path(args.output_directory).joinpath("geography.duckdb")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
@st.cache_data
def get_emoji():
resp = requests.get(
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json"
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json",
timeout=60,
)
resp.raise_for_status()
emojis = resp.json()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"https://www.googleapis.com/auth/drive",
]
CLIENT_ID = "546535678771-gvffde27nd83kfl6qbrnletqvkdmsese.apps.googleusercontent.com"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt" # noqa: S105


def _create_user_agent(application_name: str) -> str:
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,9 +1239,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
).sql(self.name, pretty=True)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join("?" * len(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
insert_stmt = self._make_insert_stmt(
table=sg.table(name, quoted=quoted), schema=schema
)
with self._safe_raw_sql(create_stmt) as cur:
for row in data:
cur.execute(insert_stmt, row)
Expand Down
99 changes: 63 additions & 36 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import datetime
import struct
from contextlib import closing
from functools import partial
from itertools import repeat
from operator import itemgetter
from typing import TYPE_CHECKING, Any

Expand All @@ -25,7 +23,7 @@
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.mssql.compiler import MSSQLCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.sql.compiler import STAR, C

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -287,76 +285,111 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
return cursor

def create_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("databases", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="DATABASE", this=sg.to_identifier(name, quoted=self.compiler.quoted)
).sql(self.dialect)
create_stmt = (
f"""\
IF NOT EXISTS (SELECT name FROM sys.databases WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE DATABASE {name};
{stmt};
END;
GO"""
if force
else f"CREATE DATABASE {name}"
else stmt
)
with self._safe_raw_sql(create_stmt):
pass

def drop_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
if_exists = "IF EXISTS " * force

with self._safe_raw_sql(f"DROP DATABASE {if_exists}{name}"):
with self._safe_raw_sql(
sge.Drop(
kind="DATABASE",
this=sg.to_identifier(name, quoted=self.compiler.quoted),
exists=force,
)
):
pass

def create_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog
quoted = self.compiler.quoted

name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("schemas", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="SCHEMA", this=sg.to_identifier(name, quoted=quoted)
).sql(self.dialect)

create_stmt = (
f"""\
IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE SCHEMA {name};
{stmt};
END;
GO"""
if force
else f"CREATE SCHEMA {name}"
else stmt
)

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(create_stmt)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")

def _quote(self, name: str):
return sg.to_identifier(name, quoted=True).sql(self.dialect)
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def drop_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog

name = self._quote(name)

if_exists = "IF EXISTS " * force
quoted = self.compiler.quoted

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(f"DROP SCHEMA {if_exists}{name}")
cur.execute(
sge.Drop(
kind="SCHEMA",
exists=force,
this=sg.to_identifier(name, quoted=quoted),
).sql(self.dialect)
)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def list_tables(
self,
Expand Down Expand Up @@ -570,19 +603,13 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

df = op.data.to_frame()
data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.dialect)
for ident in map(
partial(sg.to_identifier, quoted=quoted), schema.keys()
)
)
specs = ", ".join(repeat("?", len(schema)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

insert_stmt = self._make_insert_stmt(
table=sg.table(name, quoted=quoted), schema=schema, columns=True
)
with self._safe_raw_sql(create_stmt) as cur:
if not df.empty:
cur.executemany(sql, data)
cur.executemany(insert_stmt, data)

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
Expand Down
34 changes: 21 additions & 13 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import contextlib
import re
import warnings
from functools import cached_property, partial
from itertools import repeat
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
Expand All @@ -26,7 +25,7 @@
from ibis.backends.mysql.compiler import MySQLCompiler
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import TRUE, C
from ibis.backends.sql.compiler import STAR, TRUE, C

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -191,7 +190,16 @@ def list_databases(self, like: str | None = None) -> list[str]:

def _get_schema_using_query(self, query: str) -> sch.Schema:
with self.begin() as cur:
cur.execute(f"SELECT * FROM ({query}) AS tmp LIMIT 0")
cur.execute(
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
)
.limit(0)
.sql(self.dialect)
)

return sch.Schema(
{
Expand All @@ -203,10 +211,12 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
table = sg.table(name, db=database, catalog=catalog, quoted=True).sql(self.name)
table = sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
).sql(self.dialect)

with self.begin() as cur:
cur.execute(f"DESCRIBE {table}")
cur.execute(sge.Describe(this=table).sql(self.dialect))
result = cur.fetchall()

type_mapper = self.compiler.type_mapper
Expand Down Expand Up @@ -493,19 +503,17 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
)
create_stmt_sql = create_stmt.sql(self.name)

columns = schema.keys()
df = op.data.to_frame()
# nan can not be used with MySQL
df = df.replace(np.nan, None)

data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.name)
for ident in map(partial(sg.to_identifier, quoted=quoted), columns)
sql = self._make_insert_stmt(
table=sg.table(name, quoted=quoted),
schema=schema,
columns=True,
fmt="%s",
)
specs = ", ".join(repeat("%s", len(columns)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.name)} ({cols}) VALUES ({specs})"
with self.begin() as cur:
cur.execute(create_stmt_sql)

Expand Down
13 changes: 7 additions & 6 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from operator import itemgetter
from typing import TYPE_CHECKING, Any

import numpy as np
import oracledb
import sqlglot as sg
import sqlglot.expressions as sge
Expand Down Expand Up @@ -501,15 +502,15 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.name, pretty=True)
).sql(self.name)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join(f":{i}" for i, _ in enumerate(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
data = op.data.to_frame().replace({np.nan: None})
insert_stmt = self._make_insert_stmt(
table=sg.table(name, quoted=quoted), schema=schema, fmt=":{i:d}"
)
with self.begin() as cur:
cur.execute(create_stmt)
for row in data:
for row in data.itertuples(index=False):
cur.execute(insert_stmt, row)

atexit.register(self._clean_up_tmp_table, name)
Expand Down
14 changes: 6 additions & 8 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import inspect
import textwrap
from functools import partial
from itertools import repeat, takewhile
from itertools import takewhile
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
Expand Down Expand Up @@ -148,7 +148,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
)
create_stmt_sql = create_stmt.sql(self.dialect)

columns = schema.keys()
df = op.data.to_frame()
# nan gets compiled into 'NaN'::float which throws errors in non-float columns
# In order to hold NaN values, pandas automatically converts integer columns
Expand All @@ -161,13 +160,12 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
df[col] = df[col].replace(np.nan, None)

data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.dialect)
for ident in map(partial(sg.to_identifier, quoted=quoted), columns)
sql = self._make_insert_stmt(
table=sg.table(name, quoted=quoted),
schema=schema,
columns=True,
fmt="%s",
)
specs = ", ".join(repeat("%s", len(columns)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

with self.begin() as cur:
cur.execute(create_stmt_sql)
Expand Down
Loading

0 comments on commit 7191e8e

Please sign in to comment.