Skip to content

Commit

Permalink
Rewrite the Drill DB-API implementation using ijson. (#69)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jnturton authored Jul 8, 2021
1 parent 26da485 commit e67dde1
Show file tree
Hide file tree
Showing 6 changed files with 532 additions and 368 deletions.
18 changes: 10 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
long_description = f.read()

setup(name='sqlalchemy_drill',
version='1.0.0',
version='1.1.0',
description="Apache Drill for SQLAlchemy",
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -49,20 +49,22 @@
],
install_requires=[
"requests",
"numpy",
"pandas",
"ijson",
"sqlalchemy"
],
extras_require={
"jdbc": ["JPype1", "JayDeBeApi"],
"odbc": ["pyodbc"],
},
keywords='SQLAlchemy Apache Drill',
author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna',
author_email='[email protected], [email protected], [email protected], [email protected]',
author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna'
', James Turton',
author_email='[email protected], [email protected], davide.miceli.dap'
'@gmail.com, [email protected], [email protected]',
license='MIT',
url = 'https://github.com/JohnOmernik/sqlalchemy-drill',
download_url = 'https://github.com/JohnOmernik/sqlalchemy-drill/archive/1.0.0.tar.gz',
url='https://github.com/JohnOmernik/sqlalchemy-drill',
download_url='https://github.com/JohnOmernik/sqlalchemy-drill/archive/'
'1.1.0.tar.gz',
packages=find_packages(),
include_package_data=True,
tests_require=['nose >= 0.11'],
Expand All @@ -76,4 +78,4 @@
'drill.odbc = sqlalchemy_drill.odbc:DrillDialect_odbc',
]
}
)
)
2 changes: 1 addition & 1 deletion sqlalchemy_drill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

__version__ = '1.0.0'
__version__ = '1.1.0'
from sqlalchemy.dialects import registry

registry.register("drill", "sqlalchemy_drill.sadrill", "DrillDialect_sadrill")
Expand Down
64 changes: 41 additions & 23 deletions sqlalchemy_drill/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
'json': types.JSON,
}


class DrillCompiler_sadrill(compiler.SQLCompiler):

def default_from(self):
Expand All @@ -76,13 +77,16 @@ def visit_table(self, table, asfrom=False, **kwargs):
try:
fixed_schema = ""
if table.schema != "":
fixed_schema = ".".join(["`{i}`".format(i=i.replace('`', '')) for i in table.schema.split(".")])
fixed_schema = ".".join(
["`{i}`".format(i=i.replace('`', '')) for i in table.schema.split(".")])
fixed_table = "{fixed_schema}.`{table_name}`".format(
fixed_schema=fixed_schema,table_name=table.name.replace("`", "")
fixed_schema=fixed_schema, table_name=table.name.replace(
"`", "")
)
return fixed_table
except Exception as ex:
logging.error( "Error in DrillCompiler_sadrill.visit_table :: " + str(ex))
logging.error(
"Error in DrillCompiler_sadrill.visit_table :: " + str(ex))

else:
return ""
Expand Down Expand Up @@ -151,7 +155,8 @@ class DrillIdentifierPreparer(compiler.IdentifierPreparer):
)

def __init__(self, dialect):
super(DrillIdentifierPreparer, self).__init__(dialect, initial_quote='`', final_quote='`')
super(DrillIdentifierPreparer, self).__init__(
dialect, initial_quote='`', final_quote='`')

def format_drill_table(self, schema, isFile=True):
formatted_schema = ""
Expand All @@ -171,7 +176,8 @@ def format_drill_table(self, schema, isFile=True):
elif isFile and num_dots == 2:
# Case for file and no workspace
plugin = schema_parts[0]
formatted_schema = plugin + "." + schema_parts[1] + ".`" + schema_parts[2] + "`"
formatted_schema = plugin + "." + \
schema_parts[1] + ".`" + schema_parts[2] + "`"
else:
# Case for non-file plugins or incomplete schema parts
for part in schema_parts:
Expand All @@ -184,7 +190,6 @@ def format_drill_table(self, schema, isFile=True):
return formatted_schema



class DrillDialect(default.DefaultDialect):
name = 'drilldbapi'
driver = 'rest'
Expand Down Expand Up @@ -242,7 +247,8 @@ def create_connect_args(self, url, **kwargs):
if url.password:
qargs['drillpass'] = url.password
except Exception as ex:
logging.error("Error in DrillDialect_sadrill.create_connect_args :: " + str(ex))
logging.error(
"Error in DrillDialect_sadrill.create_connect_args :: " + str(ex))

return [], qargs

Expand Down Expand Up @@ -274,7 +280,8 @@ def get_schema_names(self, connection, **kw):
if row.SCHEMA_NAME != "cp.default" and row.SCHEMA_NAME != "INFORMATION_SCHEMA" and row.SCHEMA_NAME != "dfs.default":
result.append(row.SCHEMA_NAME)
except Exception as ex:
logging.error(("Error in DrillDialect_sadrill.get_schema_names :: ", str(ex)))
logging.error(
("Error in DrillDialect_sadrill.get_schema_names :: ", str(ex)))

return tuple(result)

Expand Down Expand Up @@ -312,7 +319,8 @@ def get_table_names(self, connection, schema=None, **kw):
tables_names.append(myname)

except Exception as ex:
logging.error("Error in DrillDialect_sadrill.get_table_names :: " + str(ex))
logging.error(
"Error in DrillDialect_sadrill.get_table_names :: " + str(ex))

return tuple(tables_names)
else:
Expand All @@ -328,20 +336,23 @@ def get_table_names(self, connection, schema=None, **kw):
tables_names.append(myname)

except Exception as ex:
logging.error("Error in DrillDialect_sadrill.get_table_names :: " + str(ex))
logging.error(
"Error in DrillDialect_sadrill.get_table_names :: " + str(ex))

return tuple(tables_names)

def get_view_names(self, connection, schema=None, **kw):
view_names = []
curs = connection.execute("SELECT `TABLE_NAME` FROM INFORMATION_SCHEMA.views WHERE table_schema='" + schema + "'")
curs = connection.execute(
"SELECT `TABLE_NAME` FROM INFORMATION_SCHEMA.views WHERE table_schema='" + schema + "'")
try:
for row in curs:
myname = row.TABLE_NAME
view_names.append(myname)

except Exception as ex:
logging.error("Error in DrillDialect_sadrill.get_view_names :: " + str(ex))
logging.error(
"Error in DrillDialect_sadrill.get_view_names :: " + str(ex))

return tuple(view_names)

Expand All @@ -350,7 +361,8 @@ def has_table(self, connection, table_name, schema=None):
self.get_columns(connection, table_name, schema)
return True
except exc.NoSuchTableError:
logging.error("Error in DrillDialect_sadrill.has_table :: " + exc.NoSuchTableError)
logging.error(
"Error in DrillDialect_sadrill.has_table :: " + exc.NoSuchTableError)
return False

def _check_unicode_returns(self, connection, additional_tests=None):
Expand Down Expand Up @@ -381,21 +393,26 @@ def get_columns(self, connection, table_name, schema=None, **kw):
if plugin_type == "file" or plugin_type == "mongo":
views = self.get_view_names(connection, schema)


file_name = schema + "." + table_name
quoted_file_name = self.identifier_preparer.format_drill_table(file_name, isFile=True)
quoted_file_name = self.identifier_preparer.format_drill_table(
file_name, isFile=True)

# Since MongoDB uses the ** notation, bypass that and query the data directly.
if plugin_type == "mongo":
print("FILE NAME:", file_name, quoted_file_name)
mongo_quoted_file_name = self.identifier_preparer.format_drill_table(file_name, isFile=False)
q = "SELECT `**` FROM {table_name} LIMIT 1".format(table_name=mongo_quoted_file_name)
mongo_quoted_file_name = self.identifier_preparer.format_drill_table(
file_name, isFile=False)
q = "SELECT `**` FROM {table_name} LIMIT 1".format(
table_name=mongo_quoted_file_name)
elif table_name in views:
logging.debug("View: ", quoted_file_name, table_name, schema)
view_name = "`{schema}`.`{table_name}`".format(schema=schema, table_name=table_name)
q = "SELECT * FROM {file_name} LIMIT 1".format(file_name=view_name)
view_name = "`{schema}`.`{table_name}`".format(
schema=schema, table_name=table_name)
q = "SELECT * FROM {file_name} LIMIT 1".format(
file_name=view_name)
else:
q = "SELECT * FROM {file_name} LIMIT 1".format(file_name=quoted_file_name)
q = "SELECT * FROM {file_name} LIMIT 1".format(
file_name=quoted_file_name)

column_metadata = connection.execute(q).cursor.description

Expand All @@ -419,7 +436,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
elif "SELECT " in table_name:
q = "SELECT * FROM ({table_name}) LIMIT 1".format(table_name=table_name)
else:
quoted_schema = self.identifier_preparer.format_drill_table(schema + "." + table_name, isFile=False)
quoted_schema = self.identifier_preparer.format_drill_table(
schema + "." + table_name, isFile=False)
q = "DESCRIBE {table_name}".format(table_name=quoted_schema)
logging.debug("QUERY:" + q)
query_results = connection.execute(q)
Expand Down Expand Up @@ -451,6 +469,6 @@ def get_plugin_type(self, connection, plugin=None):
return plugin_type

except Exception as ex:
logging.error("Error in DrillDialect_sadrill.get_plugin_type :: " + str(ex))
logging.error(
"Error in DrillDialect_sadrill.get_plugin_type :: " + str(ex))
return False

Loading

0 comments on commit e67dde1

Please sign in to comment.