Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sec): remove most instances of possible sql injection #9404

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/check_disallowed_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def generate_dependency_graph(*args):
command = ("pydeps", "--show-deps", *args)
print(f"Running: {' '.join(command)}") # noqa: T201
result = subprocess.check_output(command, text=True)
result = subprocess.check_output(command, text=True) # noqa: S603
return json.loads(result)


Expand Down
2 changes: 1 addition & 1 deletion ci/make_geography_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main() -> None:

args = parser.parse_args()

response = requests.get(args.input_data_url)
response = requests.get(args.input_data_url, timeout=600)
response.raise_for_status()
input_data = response.json()
db_path = Path(args.output_directory).joinpath("geography.duckdb")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
@st.cache_data
def get_emoji():
resp = requests.get(
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json"
"https://raw.githubusercontent.com/omnidan/node-emoji/master/lib/emoji.json",
timeout=60,
)
resp.raise_for_status()
emojis = resp.json()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"https://www.googleapis.com/auth/drive",
]
CLIENT_ID = "546535678771-gvffde27nd83kfl6qbrnletqvkdmsese.apps.googleusercontent.com"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt"
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt" # noqa: S105


def _create_user_agent(application_name: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/tests/unit/udf/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def f(a):
)
f.seek(0)
code = builtins.compile(f.read(), f.name, "exec")
exec(code, d)
exec(code, d) # noqa: S102
f = d["f"]
js = compile(f)
snapshot.assert_match(js, "out.js")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@
elif not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

query = self._build_insert_query(target=name, source=obj)
query = self._build_insert_from_table(target=name, source=obj)

Check warning on line 426 in ibis/backends/clickhouse/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/clickhouse/__init__.py#L426

Added line #L426 was not covered by tests
external_tables = self._collect_in_memory_tables(obj, {})
external_data = self._normalize_external_tables(external_tables)
return self.con.command(query.sql(self.name), external_data=external_data)
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/duckdb/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def test_temp_directory(tmp_path):
@pytest.fixture(scope="session")
def pgurl(): # pragma: no cover
pgcon = ibis.postgres.connect(
user="postgres", password="postgres", host="localhost"
user="postgres",
password="postgres", # noqa: S106
host="localhost",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
Expand All @@ -193,7 +195,11 @@ def test_read_postgres(con, pgurl): # pragma: no cover

@pytest.fixture(scope="session")
def mysqlurl(): # pragma: no cover
mysqlcon = ibis.mysql.connect(user="ibis", password="ibis", database="ibis_testing")
mysqlcon = ibis.mysql.connect(
user="ibis",
password="ibis", # noqa: S106
database="ibis_testing",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
s = ibis.schema(dict(x="float64", y="str"))
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,9 +1239,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
).sql(self.name, pretty=True)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join("?" * len(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
insert_stmt = self._build_insert_template(name, schema=schema)
with self._safe_raw_sql(create_stmt) as cur:
for row in data:
cur.execute(insert_stmt, row)
Expand Down
97 changes: 61 additions & 36 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import datetime
import struct
from contextlib import closing
from functools import partial
from itertools import repeat
from operator import itemgetter
from typing import TYPE_CHECKING, Any

Expand All @@ -25,7 +23,7 @@
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.mssql.compiler import MSSQLCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.sql.compiler import STAR, C

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -287,76 +285,111 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
return cursor

def create_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("databases", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="DATABASE", this=sg.to_identifier(name, quoted=self.compiler.quoted)
).sql(self.dialect)
create_stmt = (
f"""\
IF NOT EXISTS (SELECT name FROM sys.databases WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE DATABASE {name};
{stmt};
END;
GO"""
if force
else f"CREATE DATABASE {name}"
else stmt
)
with self._safe_raw_sql(create_stmt):
pass

def drop_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
if_exists = "IF EXISTS " * force

with self._safe_raw_sql(f"DROP DATABASE {if_exists}{name}"):
with self._safe_raw_sql(
sge.Drop(
kind="DATABASE",
this=sg.to_identifier(name, quoted=self.compiler.quoted),
exists=force,
)
):
pass

def create_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog
quoted = self.compiler.quoted

name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("schemas", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="SCHEMA", this=sg.to_identifier(name, quoted=quoted)
).sql(self.dialect)

create_stmt = (
f"""\
IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE SCHEMA {name};
{stmt};
END;
GO"""
if force
else f"CREATE SCHEMA {name}"
else stmt
)

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(create_stmt)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")

def _quote(self, name: str):
return sg.to_identifier(name, quoted=True).sql(self.dialect)
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def drop_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog

name = self._quote(name)

if_exists = "IF EXISTS " * force
quoted = self.compiler.quoted

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(f"DROP SCHEMA {if_exists}{name}")
cur.execute(
sge.Drop(
kind="SCHEMA",
exists=force,
this=sg.to_identifier(name, quoted=quoted),
).sql(self.dialect)
)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def list_tables(
self,
Expand Down Expand Up @@ -570,19 +603,11 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

df = op.data.to_frame()
data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.dialect)
for ident in map(
partial(sg.to_identifier, quoted=quoted), schema.keys()
)
)
specs = ", ".join(repeat("?", len(schema)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

insert_stmt = self._build_insert_template(name, schema=schema, columns=True)
with self._safe_raw_sql(create_stmt) as cur:
if not df.empty:
cur.executemany(sql, data)
cur.executemany(insert_stmt, data)

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
Expand Down
31 changes: 18 additions & 13 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import contextlib
import re
import warnings
from functools import cached_property, partial
from itertools import repeat
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
Expand All @@ -26,7 +25,7 @@
from ibis.backends.mysql.compiler import MySQLCompiler
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import TRUE, C
from ibis.backends.sql.compiler import STAR, TRUE, C

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -195,7 +194,16 @@ def list_databases(self, like: str | None = None) -> list[str]:

def _get_schema_using_query(self, query: str) -> sch.Schema:
with self.begin() as cur:
cur.execute(f"SELECT * FROM ({query}) AS tmp LIMIT 0")
cur.execute(
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
)
.limit(0)
.sql(self.dialect)
)

return sch.Schema(
{
Expand All @@ -207,10 +215,12 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
table = sg.table(name, db=database, catalog=catalog, quoted=True).sql(self.name)
table = sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
).sql(self.dialect)

with self.begin() as cur:
cur.execute(f"DESCRIBE {table}")
cur.execute(sge.Describe(this=table).sql(self.dialect))
result = cur.fetchall()

type_mapper = self.compiler.type_mapper
Expand Down Expand Up @@ -497,19 +507,14 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
)
create_stmt_sql = create_stmt.sql(self.name)

columns = schema.keys()
df = op.data.to_frame()
# nan can not be used with MySQL
df = df.replace(np.nan, None)

data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.name)
for ident in map(partial(sg.to_identifier, quoted=quoted), columns)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
specs = ", ".join(repeat("%s", len(columns)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.name)} ({cols}) VALUES ({specs})"
with self.begin() as cur:
cur.execute(create_stmt_sql)

Expand Down
17 changes: 10 additions & 7 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from operator import itemgetter
from typing import TYPE_CHECKING, Any

import numpy as np
import oracledb
import sqlglot as sg
import sqlglot.expressions as sge
Expand Down Expand Up @@ -501,16 +502,18 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.name, pretty=True)
).sql(self.name)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join(f":{i}" for i, _ in enumerate(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
data = op.data.to_frame().replace({np.nan: None})
insert_stmt = self._build_insert_template(
name, schema=schema, placeholder=":{i:d}"
)
with self.begin() as cur:
cur.execute(create_stmt)
for row in data:
cur.execute(insert_stmt, row)
for start, end in util.chunks(len(data), chunk_size=128):
cur.executemany(
insert_stmt, list(data.iloc[start:end].itertuples(index=False))
)

atexit.register(self._clean_up_tmp_table, name)

Expand Down
Loading
Loading