Implement table name extraction. (#1598)

* Implement table name extraction tests.

* Address comments.

* Fix tests and reimplement the token processing.

* Exclude aliases.

* Clean up print statements and code.

* Reverse select test.

* Fix failing test.

* Test JOINs

* refactore as a class

* Check for permissions in SQL Lab.

* Implement permissions check for the datasources in sql_lab

* Address comments.
This commit is contained in:
Bogdan 2016-11-29 15:43:36 -05:00 committed by GitHub
parent fcb870728d
commit dc98c6739f
6 changed files with 465 additions and 22 deletions

View File

@ -665,6 +665,7 @@ class Database(Model, AuditMixinNullable):
"""An ORM object that stores Database related information"""
__tablename__ = 'dbs'
type = "table"
id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
@ -1524,6 +1525,7 @@ class DruidCluster(Model, AuditMixinNullable):
"""ORM object referencing the Druid clusters"""
__tablename__ = 'clusters'
type = "druid"
id = Column(Integer, primary_key=True)
cluster_name = Column(String(250), unique=True)

View File

@ -40,6 +40,27 @@ class SourceRegistry(object):
d.name == datasource_name and schema == schema]
return db_ds[0]
@classmethod
def query_datasources_by_name(
cls, session, database, datasource_name, schema=None):
datasource_class = SourceRegistry.sources[database.type]
if database.type == 'table':
query = (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name))
if schema:
query = query.filter_by(schema=schema)
return query.all()
if database.type == 'druid':
return (
session.query(datasource_class)
.filter_by(cluster_name=database.id)
.filter_by(datasource_name=datasource_name)
.all()
)
return None
@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
"""Returns datasource with columns and metrics."""

View File

@ -11,7 +11,7 @@ from sqlalchemy.pool import NullPool
from sqlalchemy.orm import sessionmaker
from superset import (
app, db, models, utils, dataframe, results_backend)
app, db, models, utils, dataframe, results_backend, sql_parse, sm)
from superset.db_engine_specs import LimitMethod
from superset.jinja_context import get_template_processor
QueryStatus = models.QueryStatus
@ -19,16 +19,12 @@ QueryStatus = models.QueryStatus
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
def is_query_select(sql):
return sql.upper().startswith('SELECT')
def create_table_as(sql, table_name, schema=None, override=False):
"""Reformats the query into the create table as query.
Works only for the single select SQL statements, in all other cases
the sql query is not modified.
:param sql: string, sql query that will be executed
:param superset_query: string, sql query that will be executed
:param table_name: string, will contain the results of the query execution
:param override, boolean, table table_name will be dropped if true
:return: string, create table as query
@ -41,12 +37,9 @@ def create_table_as(sql, table_name, schema=None, override=False):
if schema:
table_name = schema + '.' + table_name
exec_sql = ''
if is_query_select(sql):
if override:
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
else:
raise Exception("Could not generate CREATE TABLE statement")
if override:
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
return exec_sql.format(**locals())
@ -76,12 +69,12 @@ def get_sql_results(self, query_id, return_results=True, store_results=False):
raise Exception(query.error_message)
# Limit enforced only for retrieving the data, not for the CTA queries.
is_select = is_query_select(executed_sql);
if not is_select and not database.allow_dml:
superset_query = sql_parse.SupersetQuery(executed_sql)
if not superset_query.is_select() and not database.allow_dml:
handle_error(
"Only `SELECT` statements are allowed against this database")
if query.select_as_cta:
if not is_select:
if not superset_query.is_select():
handle_error(
"Only `SELECT` statements can be used with the CREATE TABLE "
"feature.")
@ -94,7 +87,7 @@ def get_sql_results(self, query_id, return_results=True, store_results=False):
executed_sql, query.tmp_table_name, database.force_ctas_schema)
query.select_as_cta_used = True
elif (
query.limit and is_select and
query.limit and superset_query.is_select() and
db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
query.limit_used = True

101
superset/sql_parse.py Normal file
View File

@ -0,0 +1,101 @@
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, Name
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
# TODO: some sql_lab logic here.
class SupersetQuery(object):
def __init__(self, sql_statement):
self._tokens = []
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
# TODO: multistatement support
for statement in sqlparse.parse(self.sql):
self.__extract_from_token(statement)
self._table_names = self._table_names - self._alias_names
@property
def tables(self):
return self._table_names
# TODO: use sqlparse for this check.
def is_select(self):
return self.sql.upper().startswith('SELECT')
@staticmethod
def __precedes_table_name(token_value):
for keyword in PRECEDES_TABLE_NAME:
if keyword in token_value:
return True
return False
@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
return "{}.{}".format(identifier.tokens[0].value,
identifier.tokens[2].value)
return identifier.get_real_name()
@staticmethod
def __is_result_operation(keyword):
for operation in RESULT_OPERATIONS:
if operation in keyword.upper():
return True
return False
@staticmethod
def __is_identifier(token):
return (
isinstance(token, IdentifierList) or isinstance(token, Identifier))
def __process_identifier(self, identifier):
# exclude subselects
if '(' not in '{}'.format(identifier):
self._table_names.add(SupersetQuery.__get_full_name(identifier))
return
# store aliases
if hasattr(identifier, 'get_alias'):
self._alias_names.add(identifier.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly
if identifier.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value)
self.__extract_from_token(identifier)
def __extract_from_token(self, token):
if not hasattr(token, 'tokens'):
return
table_name_preceding_token = False
for item in token.tokens:
if item.is_group and not self.__is_identifier(item):
self.__extract_from_token(item)
if item.ttype in Keyword:
if SupersetQuery.__precedes_table_name(item.value.upper()):
table_name_preceding_token = True
continue
if not table_name_preceding_token:
continue
if item.ttype in Keyword:
if SupersetQuery.__is_result_operation(item.value):
table_name_preceding_token = False
continue
# FROM clause is over
break
if isinstance(item, Identifier):
self.__process_identifier(item)
if isinstance(item, IdentifierList):
for token in item.tokens:
if SupersetQuery.__is_identifier(token):
self.__process_identifier(token)

View File

@ -36,7 +36,7 @@ from wtforms.validators import ValidationError
import superset
from superset import (
appbuilder, cache, db, models, viz, utils, app,
sm, sql_lab, results_backend, security,
sm, sql_lab, sql_parse, results_backend, security,
)
from superset.source_registry import SourceRegistry
from superset.models import DatasourceAccessRequest as DAR
@ -74,6 +74,18 @@ class BaseSupersetView(BaseView):
self.can_access("datasource_access", datasource.perm)
)
def datasource_access_by_name(
self, database, datasource_name, schema=None):
if (self.database_access(database) or
self.all_datasource_access()):
return True
datasources = SourceRegistry.query_datasources_by_name(
db.session, database, datasource_name, schema=schema)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False
class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
@ -2303,27 +2315,45 @@ class Superset(BaseSupersetView):
@log_this
def sql_json(self):
"""Runs arbitrary sql and returns and json"""
def table_accessible(database, full_table_name, schema_name=None):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema_name
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)
async = request.form.get('runAsync') == 'true'
sql = request.form.get('sql')
database_id = request.form.get('database_id')
session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).first()
mydb = session.query(models.Database).filter_by(id=database_id).one()
if not mydb:
json_error_response(
'Database with id {} is missing.'.format(database_id))
if not self.database_access(mydb):
superset_query = sql_parse.SupersetQuery(sql)
schema = request.form.get('schema')
schema = schema if schema else None
rejected_tables = [
t for t in superset_query.tables if not
table_accessible(mydb, t, schema_name=schema)]
if rejected_tables:
json_error_response(
get_database_access_error_msg(mydb.database_name))
get_datasource_access_error_msg('{}'.format(rejected_tables)))
session.commit()
query = models.Query(
database_id=int(database_id),
limit=int(app.config.get('SQL_MAX_ROW', None)),
sql=sql,
schema=request.form.get('schema'),
schema=schema,
select_as_cta=request.form.get('select_as_cta') == 'true',
start_time=utils.now_as_float(),
tab_name=request.form.get('tab'),
@ -2341,7 +2371,8 @@ class Superset(BaseSupersetView):
if async:
# Ignore the celery future object and the request may time out.
sql_lab.get_sql_results.delay(
query_id, return_results=False, store_results=not query.select_as_cta)
query_id, return_results=False,
store_results=not query.select_as_cta)
return Response(
json.dumps({'query': query.to_dict()},
default=utils.json_int_dttm_ser,

295
tests/sql_parse_tests.py Normal file
View File

@ -0,0 +1,295 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from superset import sql_parse
class SupersetTestCase(unittest.TestCase):
def extract_tables(self, query):
sq = sql_parse.SupersetQuery(query)
return sq.tables
def test_simple_select(self):
query = "SELECT * FROM tbname"
self.assertEquals({"tbname"}, self.extract_tables(query))
# underscores
query = "SELECT * FROM tb_name"
self.assertEquals({"tb_name"},
self.extract_tables(query))
# quotes
query = 'SELECT * FROM "tbname"'
self.assertEquals({"tbname"}, self.extract_tables(query))
# schema
self.assertEquals(
{"schemaname.tbname"},
self.extract_tables("SELECT * FROM schemaname.tbname"))
# quotes
query = "SELECT field1, field2 FROM tb_name"
self.assertEquals({"tb_name"}, self.extract_tables(query))
query = "SELECT t1.f1, t2.f2 FROM t1, t2"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
def test_select_named_table(self):
query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
self.assertEquals(
{"left_table"}, self.extract_tables(query))
def test_reverse_select(self):
query = "FROM t1 SELECT field"
self.assertEquals({"t1"}, self.extract_tables(query))
def test_subselect(self):
query = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
self.assertEquals({"s1.t1", "s2.t2"},
self.extract_tables(query))
query = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
self.assertEquals({"s1.t1"}, self.extract_tables(query))
query = """
SELECT * FROM t1
WHERE s11 > ANY
(SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS
(SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
self.assertEquals({"t1", "t2", "t3", "t4"},
self.extract_tables(query))
def test_select_in_expression(self):
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
def test_union(self):
query = "SELECT * FROM t1 UNION SELECT * FROM t2"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
def test_select_from_values(self):
query = "SELECT * FROM VALUES (13, 42)"
self.assertFalse(self.extract_tables(query))
def test_select_array(self):
query = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
self.assertEquals({"t1"}, self.extract_tables(query))
def test_select_if(self):
query = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
self.assertEquals({"t1"}, self.extract_tables(query))
# SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
def test_show_tables(self):
query = 'SHOW TABLES FROM s1 like "%order%"'
# TODO: figure out what should code do here
self.assertEquals({"s1"}, self.extract_tables(query))
# SHOW COLUMNS (FROM | IN) qualifiedName
def test_show_columns(self):
query = "SHOW COLUMNS FROM t1"
self.assertEquals({"t1"}, self.extract_tables(query))
def test_where_subquery(self):
query = """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
"""
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
# DESCRIBE | DESC qualifiedName
def test_describe(self):
self.assertEquals({"t1"}, self.extract_tables("DESCRIBE t1"))
self.assertEquals({"t1"}, self.extract_tables("DESC t1"))
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
def test_show_partitions(self):
query = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC;
"""
self.assertEquals({"orders"}, self.extract_tables(query))
def test_join(self):
query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
# subquery + join
query = """
SELECT a.date, b.name FROM
left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
self.assertEquals({"left_table", "right_table"},
self.extract_tables(query))
query = """
SELECT a.date, b.name FROM
left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
self.assertEquals({"left_table", "right_table"},
self.extract_tables(query))
query = """
SELECT a.date, b.name FROM
left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
self.assertEquals({"left_table", "right_table"},
self.extract_tables(query))
query = """
SELECT a.date, b.name FROM
left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
self.assertEquals({"left_table", "right_table"},
self.extract_tables(query))
# TODO: add SEMI join support, SQL Parse does not handle it.
# query = """
# SELECT a.date, b.name FROM
# left_table a
# LEFT SEMI JOIN (
# SELECT
# CAST((b.year) as VARCHAR) date,
# name
# FROM right_table
# ) b
# ON a.date = b.date
# """
# self.assertEquals({"left_table", "right_table"},
# sql_parse.extract_tables(query))
def test_combinations(self):
query = """
SELECT * FROM t1
WHERE s11 > ANY
(SELECT * FROM t1 UNION ALL SELECT * FROM (
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
WHERE NOT EXISTS
(SELECT * FROM t3
WHERE ROW(5*t3.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
self.assertEquals({"t1", "t3", "t4", "t6"},
self.extract_tables(query))
query = """
SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
AS S1) AS S2) AS S3;
"""
self.assertEquals({"EmployeeS"}, self.extract_tables(query))
def test_with(self):
query = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z;
"""
self.assertEquals({"t1", "t2", "t3"},
self.extract_tables(query))
query = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z;
"""
self.assertEquals({"t1"}, self.extract_tables(query))
def test_reusing_aliases(self):
query = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a;
"""
self.assertEquals({"src"}, self.extract_tables(query))
def multistatement(self):
query = "SELECT * FROM t1; SELECT * FROM t2"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))
query = "SELECT * FROM t1; SELECT * FROM t2;"
self.assertEquals({"t1", "t2"}, self.extract_tables(query))