Skip to content

Commit

Permalink
fix(sec): remove most instances of possible sql injection (#9404)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Jun 21, 2024
1 parent 37687e2 commit a555774
Show file tree
Hide file tree
Showing 26 changed files with 364 additions and 171 deletions.
2 changes: 1 addition & 1 deletion ci/check_disallowed_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def generate_dependency_graph(*args):
command = ("pydeps", "--show-deps", *args)
print(f"Running: {' '.join(command)}") # noqa: T201
result = subprocess.check_output(command, text=True)
result = subprocess.check_output(command, text=True) # noqa: S603
return json.loads(result)


Expand Down
2 changes: 1 addition & 1 deletion ci/make_geography_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main() -> None:

args = parser.parse_args()

response = requests.get(args.input_data_url)
response = requests.get(args.input_data_url, timeout=600)
response.raise_for_status()
input_data = response.json()
db_path = Path(args.output_directory).joinpath("geography.duckdb")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
@st.cache_data
def get_emoji():
resp = requests.get(
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json"
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json",
timeout=60,
)
resp.raise_for_status()
emojis = resp.json()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"https://www.googleapis.com/auth/drive",
]
CLIENT_ID = "546535678771-gvffde27nd83kfl6qbrnletqvkdmsese.apps.googleusercontent.com"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt" # noqa: S105


def _create_user_agent(application_name: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/tests/unit/udf/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def f(a):
)
f.seek(0)
code = builtins.compile(f.read(), f.name, "exec")
exec(code, d)
exec(code, d) # noqa: S102
f = d["f"]
js = compile(f)
snapshot.assert_match(js, "out.js")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def insert(
elif not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

query = self._build_insert_query(target=name, source=obj)
query = self._build_insert_from_table(target=name, source=obj)
external_tables = self._collect_in_memory_tables(obj, {})
external_data = self._normalize_external_tables(external_tables)
return self.con.command(query.sql(self.name), external_data=external_data)
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/duckdb/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def test_temp_directory(tmp_path):
@pytest.fixture(scope="session")
def pgurl(): # pragma: no cover
pgcon = ibis.postgres.connect(
user="postgres", password="postgres", host="localhost"
user="postgres",
password="postgres", # noqa: S106
host="localhost",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
Expand All @@ -193,7 +195,11 @@ def test_read_postgres(con, pgurl): # pragma: no cover

@pytest.fixture(scope="session")
def mysqlurl(): # pragma: no cover
mysqlcon = ibis.mysql.connect(user="ibis", password="ibis", database="ibis_testing")
mysqlcon = ibis.mysql.connect(
user="ibis",
password="ibis", # noqa: S106
database="ibis_testing",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
s = ibis.schema(dict(x="float64", y="str"))
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,9 +1239,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
).sql(self.name, pretty=True)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join("?" * len(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
insert_stmt = self._build_insert_template(name, schema=schema)
with self._safe_raw_sql(create_stmt) as cur:
for row in data:
cur.execute(insert_stmt, row)
Expand Down
97 changes: 61 additions & 36 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import datetime
import struct
from contextlib import closing
from functools import partial
from itertools import repeat
from operator import itemgetter
from typing import TYPE_CHECKING, Any

Expand All @@ -25,7 +23,7 @@
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.mssql.compiler import MSSQLCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.sql.compiler import STAR, C

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -287,76 +285,111 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
return cursor

def create_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("databases", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="DATABASE", this=sg.to_identifier(name, quoted=self.compiler.quoted)
).sql(self.dialect)
create_stmt = (
f"""\
IF NOT EXISTS (SELECT name FROM sys.databases WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE DATABASE {name};
{stmt};
END;
GO"""
if force
else f"CREATE DATABASE {name}"
else stmt
)
with self._safe_raw_sql(create_stmt):
pass

def drop_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
if_exists = "IF EXISTS " * force

with self._safe_raw_sql(f"DROP DATABASE {if_exists}{name}"):
with self._safe_raw_sql(
sge.Drop(
kind="DATABASE",
this=sg.to_identifier(name, quoted=self.compiler.quoted),
exists=force,
)
):
pass

def create_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog
quoted = self.compiler.quoted

name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("schemas", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="SCHEMA", this=sg.to_identifier(name, quoted=quoted)
).sql(self.dialect)

create_stmt = (
f"""\
IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE SCHEMA {name};
{stmt};
END;
GO"""
if force
else f"CREATE SCHEMA {name}"
else stmt
)

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(create_stmt)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")

def _quote(self, name: str):
return sg.to_identifier(name, quoted=True).sql(self.dialect)
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def drop_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog

name = self._quote(name)

if_exists = "IF EXISTS " * force
quoted = self.compiler.quoted

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(f"DROP SCHEMA {if_exists}{name}")
cur.execute(
sge.Drop(
kind="SCHEMA",
exists=force,
this=sg.to_identifier(name, quoted=quoted),
).sql(self.dialect)
)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def list_tables(
self,
Expand Down Expand Up @@ -570,19 +603,11 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

df = op.data.to_frame()
data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.dialect)
for ident in map(
partial(sg.to_identifier, quoted=quoted), schema.keys()
)
)
specs = ", ".join(repeat("?", len(schema)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

insert_stmt = self._build_insert_template(name, schema=schema, columns=True)
with self._safe_raw_sql(create_stmt) as cur:
if not df.empty:
cur.executemany(sql, data)
cur.executemany(insert_stmt, data)

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
Expand Down
31 changes: 18 additions & 13 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import contextlib
import re
import warnings
from functools import cached_property, partial
from itertools import repeat
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
Expand All @@ -26,7 +25,7 @@
from ibis.backends.mysql.compiler import MySQLCompiler
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import TRUE, C
from ibis.backends.sql.compiler import STAR, TRUE, C

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -195,7 +194,16 @@ def list_databases(self, like: str | None = None) -> list[str]:

def _get_schema_using_query(self, query: str) -> sch.Schema:
with self.begin() as cur:
cur.execute(f"SELECT * FROM ({query}) AS tmp LIMIT 0")
cur.execute(
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
)
.limit(0)
.sql(self.dialect)
)

return sch.Schema(
{
Expand All @@ -207,10 +215,12 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
table = sg.table(name, db=database, catalog=catalog, quoted=True).sql(self.name)
table = sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
).sql(self.dialect)

with self.begin() as cur:
cur.execute(f"DESCRIBE {table}")
cur.execute(sge.Describe(this=table).sql(self.dialect))
result = cur.fetchall()

type_mapper = self.compiler.type_mapper
Expand Down Expand Up @@ -497,19 +507,14 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
)
create_stmt_sql = create_stmt.sql(self.name)

columns = schema.keys()
df = op.data.to_frame()
# nan can not be used with MySQL
df = df.replace(np.nan, None)

data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.name)
for ident in map(partial(sg.to_identifier, quoted=quoted), columns)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
specs = ", ".join(repeat("%s", len(columns)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.name)} ({cols}) VALUES ({specs})"
with self.begin() as cur:
cur.execute(create_stmt_sql)

Expand Down
17 changes: 10 additions & 7 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from operator import itemgetter
from typing import TYPE_CHECKING, Any

import numpy as np
import oracledb
import sqlglot as sg
import sqlglot.expressions as sge
Expand Down Expand Up @@ -501,16 +502,18 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.name, pretty=True)
).sql(self.name)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join(f":{i}" for i, _ in enumerate(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
data = op.data.to_frame().replace({np.nan: None})
insert_stmt = self._build_insert_template(
name, schema=schema, placeholder=":{i:d}"
)
with self.begin() as cur:
cur.execute(create_stmt)
for row in data:
cur.execute(insert_stmt, row)
for start, end in util.chunks(len(data), chunk_size=128):
cur.executemany(
insert_stmt, list(data.iloc[start:end].itertuples(index=False))
)

atexit.register(self._clean_up_tmp_table, name)

Expand Down
Loading

0 comments on commit a555774

Please sign in to comment.