Implement table name extraction. (#1598)
* Implement table name extraction tests. * Address comments. * Fix tests and reimplement the token processing. * Exclude aliases. * Clean up print statements and code. * Reverse select test. * Fix failing test. * Test JOINs * Refactor as a class. * Check for permissions in SQL Lab. * Implement permissions check for the datasources in sql_lab. * Address comments.
This commit is contained in:
parent
fcb870728d
commit
dc98c6739f
|
|
@ -665,6 +665,7 @@ class Database(Model, AuditMixinNullable):
|
|||
"""An ORM object that stores Database related information"""
|
||||
|
||||
__tablename__ = 'dbs'
|
||||
type = "table"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
database_name = Column(String(250), unique=True)
|
||||
|
|
@ -1524,6 +1525,7 @@ class DruidCluster(Model, AuditMixinNullable):
|
|||
"""ORM object referencing the Druid clusters"""
|
||||
|
||||
__tablename__ = 'clusters'
|
||||
type = "druid"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
cluster_name = Column(String(250), unique=True)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,27 @@ class SourceRegistry(object):
|
|||
d.name == datasource_name and schema == schema]
|
||||
return db_ds[0]
|
||||
|
||||
@classmethod
def query_datasources_by_name(
        cls, session, database, datasource_name, schema=None):
    """Query all datasources on *database* matching *datasource_name*.

    :param session: SQLAlchemy session to run the query against
    :param database: Database ORM object; its ``type`` selects the
        datasource class to query ('table' or 'druid')
    :param datasource_name: table name / druid datasource name to match
    :param schema: optional schema filter (table datasources only)
    :return: list of matching datasource ORM objects, possibly empty
    """
    datasource_class = SourceRegistry.sources[database.type]
    if database.type == 'table':
        query = (
            session.query(datasource_class)
            .filter_by(database_id=database.id)
            .filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()
    if database.type == 'druid':
        # NOTE(review): filtering ``cluster_name`` against the numeric
        # ``database.id`` looks suspicious (name vs. id) -- confirm
        # against the druid datasource model.
        return (
            session.query(datasource_class)
            .filter_by(cluster_name=database.id)
            .filter_by(datasource_name=datasource_name)
            .all()
        )
    # Return an empty list rather than None so callers can always
    # iterate the result (datasource_access_by_name does a for-loop
    # over it, which would raise TypeError on None).
    return []
|
||||
|
||||
@classmethod
|
||||
def get_eager_datasource(cls, session, datasource_type, datasource_id):
|
||||
"""Returns datasource with columns and metrics."""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sqlalchemy.pool import NullPool
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from superset import (
|
||||
app, db, models, utils, dataframe, results_backend)
|
||||
app, db, models, utils, dataframe, results_backend, sql_parse, sm)
|
||||
from superset.db_engine_specs import LimitMethod
|
||||
from superset.jinja_context import get_template_processor
|
||||
QueryStatus = models.QueryStatus
|
||||
|
|
@ -19,16 +19,12 @@ QueryStatus = models.QueryStatus
|
|||
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
|
||||
|
||||
|
||||
def is_query_select(sql):
    """Naive check for whether the raw statement begins with SELECT.

    :param sql: the sql statement as a string
    :return: True when the uppercased statement starts with 'SELECT'
    """
    uppercased = sql.upper()
    return uppercased.startswith('SELECT')
|
||||
|
||||
|
||||
def create_table_as(sql, table_name, schema=None, override=False):
|
||||
"""Reformats the query into the create table as query.
|
||||
|
||||
Works only for the single select SQL statements, in all other cases
|
||||
the sql query is not modified.
|
||||
:param sql: string, sql query that will be executed
|
||||
:param superset_query: string, sql query that will be executed
|
||||
:param table_name: string, will contain the results of the query execution
|
||||
:param override, boolean, table table_name will be dropped if true
|
||||
:return: string, create table as query
|
||||
|
|
@ -41,12 +37,9 @@ def create_table_as(sql, table_name, schema=None, override=False):
|
|||
if schema:
|
||||
table_name = schema + '.' + table_name
|
||||
exec_sql = ''
|
||||
if is_query_select(sql):
|
||||
if override:
|
||||
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
|
||||
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
|
||||
else:
|
||||
raise Exception("Could not generate CREATE TABLE statement")
|
||||
if override:
|
||||
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
|
||||
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
|
||||
return exec_sql.format(**locals())
|
||||
|
||||
|
||||
|
|
@ -76,12 +69,12 @@ def get_sql_results(self, query_id, return_results=True, store_results=False):
|
|||
raise Exception(query.error_message)
|
||||
|
||||
# Limit enforced only for retrieving the data, not for the CTA queries.
|
||||
is_select = is_query_select(executed_sql);
|
||||
if not is_select and not database.allow_dml:
|
||||
superset_query = sql_parse.SupersetQuery(executed_sql)
|
||||
if not superset_query.is_select() and not database.allow_dml:
|
||||
handle_error(
|
||||
"Only `SELECT` statements are allowed against this database")
|
||||
if query.select_as_cta:
|
||||
if not is_select:
|
||||
if not superset_query.is_select():
|
||||
handle_error(
|
||||
"Only `SELECT` statements can be used with the CREATE TABLE "
|
||||
"feature.")
|
||||
|
|
@ -94,7 +87,7 @@ def get_sql_results(self, query_id, return_results=True, store_results=False):
|
|||
executed_sql, query.tmp_table_name, database.force_ctas_schema)
|
||||
query.select_as_cta_used = True
|
||||
elif (
|
||||
query.limit and is_select and
|
||||
query.limit and superset_query.is_select() and
|
||||
db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
|
||||
executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
|
||||
query.limit_used = True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
import sqlparse
|
||||
from sqlparse.sql import IdentifierList, Identifier
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
|
||||
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
|
||||
|
||||
|
||||
# TODO: some sql_lab logic here.
|
||||
class SupersetQuery(object):
    """Extracts the physical table names referenced by a SQL statement.

    Parses the statement with ``sqlparse`` and walks the token tree,
    collecting names that follow table-introducing keywords (FROM, JOIN,
    DESC/DESCRIBE, WITH) while excluding aliases and names defined by
    subselects / CTEs.
    """

    def __init__(self, sql_statement):
        self._tokens = []
        self.sql = sql_statement
        self._table_names = set()
        self._alias_names = set()
        # TODO: multistatement support
        for statement in sqlparse.parse(self.sql):
            self.__extract_from_token(statement)
        # Names that turned out to be aliases (e.g. CTE names) are not
        # real tables.
        self._table_names = self._table_names - self._alias_names

    @property
    def tables(self):
        """Set of table names (optionally schema-qualified) referenced."""
        return self._table_names

    # TODO: use sqlparse for this check.
    def is_select(self):
        """Naive check: does the raw statement start with SELECT?"""
        return self.sql.upper().startswith('SELECT')

    @staticmethod
    def __precedes_table_name(token_value):
        # Substring match on purpose: compound keywords such as
        # "LEFT OUTER JOIN" arrive as a single token value.
        return any(keyword in token_value for keyword in PRECEDES_TABLE_NAME)

    @staticmethod
    def __get_full_name(identifier):
        """Return "schema.table" for dotted identifiers, else the name."""
        if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
            return "{}.{}".format(identifier.tokens[0].value,
                                  identifier.tokens[2].value)
        return identifier.get_real_name()

    @staticmethod
    def __is_result_operation(keyword):
        # UNION/INTERSECT/EXCEPT mean another table name may follow.
        return any(op in keyword.upper() for op in RESULT_OPERATIONS)

    @staticmethod
    def __is_identifier(token):
        return isinstance(token, (IdentifierList, Identifier))

    def __process_identifier(self, identifier):
        # exclude subselects: a '(' in the rendered identifier means the
        # "table" is actually a parenthesized subquery
        if '(' not in '{}'.format(identifier):
            self._table_names.add(SupersetQuery.__get_full_name(identifier))
            return

        # store aliases
        if hasattr(identifier, 'get_alias'):
            self._alias_names.add(identifier.get_alias())
        if hasattr(identifier, 'tokens'):
            # some aliases are not parsed properly
            if identifier.tokens[0].ttype == Name:
                self._alias_names.add(identifier.tokens[0].value)
        self.__extract_from_token(identifier)

    def __extract_from_token(self, token):
        """Recursively walk *token*'s children, collecting table names."""
        if not hasattr(token, 'tokens'):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            if item.is_group and not self.__is_identifier(item):
                self.__extract_from_token(item)

            if item.ttype in Keyword:
                if SupersetQuery.__precedes_table_name(item.value.upper()):
                    table_name_preceding_token = True
                    continue

            if not table_name_preceding_token:
                continue

            if item.ttype in Keyword:
                if SupersetQuery.__is_result_operation(item.value):
                    table_name_preceding_token = False
                    continue
                # FROM clause is over
                break

            if isinstance(item, Identifier):
                self.__process_identifier(item)

            if isinstance(item, IdentifierList):
                # Renamed loop variable: the original used `token`, which
                # shadowed this method's own `token` parameter.
                for child in item.tokens:
                    if SupersetQuery.__is_identifier(child):
                        self.__process_identifier(child)
|
||||
|
|
@ -36,7 +36,7 @@ from wtforms.validators import ValidationError
|
|||
import superset
|
||||
from superset import (
|
||||
appbuilder, cache, db, models, viz, utils, app,
|
||||
sm, sql_lab, results_backend, security,
|
||||
sm, sql_lab, sql_parse, results_backend, security,
|
||||
)
|
||||
from superset.source_registry import SourceRegistry
|
||||
from superset.models import DatasourceAccessRequest as DAR
|
||||
|
|
@ -74,6 +74,18 @@ class BaseSupersetView(BaseView):
|
|||
self.can_access("datasource_access", datasource.perm)
|
||||
)
|
||||
|
||||
def datasource_access_by_name(
        self, database, datasource_name, schema=None):
    """Check whether the current user may access *datasource_name*.

    Broad permissions (database-wide access or all-datasource access)
    grant access immediately; otherwise each matching datasource's own
    permission is checked individually.
    """
    has_broad_access = (
        self.database_access(database) or self.all_datasource_access())
    if has_broad_access:
        return True
    matches = SourceRegistry.query_datasources_by_name(
        db.session, database, datasource_name, schema=schema)
    return any(
        self.can_access("datasource_access", datasource.perm)
        for datasource in matches)
|
||||
|
||||
|
||||
class ListWidgetWithCheckboxes(ListWidget):
|
||||
"""An alternative to list view that renders Boolean fields as checkboxes
|
||||
|
|
@ -2303,27 +2315,45 @@ class Superset(BaseSupersetView):
|
|||
@log_this
|
||||
def sql_json(self):
|
||||
"""Runs arbitrary sql and returns and json"""
|
||||
def table_accessible(database, full_table_name, schema_name=None):
|
||||
table_name_pieces = full_table_name.split(".")
|
||||
if len(table_name_pieces) == 2:
|
||||
table_schema = table_name_pieces[0]
|
||||
table_name = table_name_pieces[1]
|
||||
else:
|
||||
table_schema = schema_name
|
||||
table_name = table_name_pieces[0]
|
||||
return self.datasource_access_by_name(
|
||||
database, table_name, schema=table_schema)
|
||||
|
||||
async = request.form.get('runAsync') == 'true'
|
||||
sql = request.form.get('sql')
|
||||
database_id = request.form.get('database_id')
|
||||
|
||||
session = db.session()
|
||||
mydb = session.query(models.Database).filter_by(id=database_id).first()
|
||||
mydb = session.query(models.Database).filter_by(id=database_id).one()
|
||||
|
||||
if not mydb:
|
||||
json_error_response(
|
||||
'Database with id {} is missing.'.format(database_id))
|
||||
|
||||
if not self.database_access(mydb):
|
||||
superset_query = sql_parse.SupersetQuery(sql)
|
||||
schema = request.form.get('schema')
|
||||
schema = schema if schema else None
|
||||
|
||||
rejected_tables = [
|
||||
t for t in superset_query.tables if not
|
||||
table_accessible(mydb, t, schema_name=schema)]
|
||||
if rejected_tables:
|
||||
json_error_response(
|
||||
get_database_access_error_msg(mydb.database_name))
|
||||
get_datasource_access_error_msg('{}'.format(rejected_tables)))
|
||||
session.commit()
|
||||
|
||||
query = models.Query(
|
||||
database_id=int(database_id),
|
||||
limit=int(app.config.get('SQL_MAX_ROW', None)),
|
||||
sql=sql,
|
||||
schema=request.form.get('schema'),
|
||||
schema=schema,
|
||||
select_as_cta=request.form.get('select_as_cta') == 'true',
|
||||
start_time=utils.now_as_float(),
|
||||
tab_name=request.form.get('tab'),
|
||||
|
|
@ -2341,7 +2371,8 @@ class Superset(BaseSupersetView):
|
|||
if async:
|
||||
# Ignore the celery future object and the request may time out.
|
||||
sql_lab.get_sql_results.delay(
|
||||
query_id, return_results=False, store_results=not query.select_as_cta)
|
||||
query_id, return_results=False,
|
||||
store_results=not query.select_as_cta)
|
||||
return Response(
|
||||
json.dumps({'query': query.to_dict()},
|
||||
default=utils.json_int_dttm_ser,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,295 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from superset import sql_parse
|
||||
|
||||
|
||||
class SupersetTestCase(unittest.TestCase):
    """Tests for table-name extraction in superset.sql_parse."""

    def extract_tables(self, query):
        # Helper: parse *query* and return the extracted table-name set.
        sq = sql_parse.SupersetQuery(query)
        return sq.tables

    def test_simple_select(self):
        query = "SELECT * FROM tbname"
        self.assertEqual({"tbname"}, self.extract_tables(query))

        # underscores
        query = "SELECT * FROM tb_name"
        self.assertEqual({"tb_name"},
                         self.extract_tables(query))

        # quotes
        query = 'SELECT * FROM "tbname"'
        self.assertEqual({"tbname"}, self.extract_tables(query))

        # schema
        self.assertEqual(
            {"schemaname.tbname"},
            self.extract_tables("SELECT * FROM schemaname.tbname"))

        # field list (was mislabeled "quotes" in the original)
        query = "SELECT field1, field2 FROM tb_name"
        self.assertEqual({"tb_name"}, self.extract_tables(query))

        query = "SELECT t1.f1, t2.f2 FROM t1, t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_select_named_table(self):
        query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
        self.assertEqual(
            {"left_table"}, self.extract_tables(query))

    def test_reverse_select(self):
        query = "FROM t1 SELECT field"
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_subselect(self):
        query = """
            SELECT sub.*
            FROM (
                SELECT *
                FROM s1.t1
                WHERE day_of_week = 'Friday'
            ) sub, s2.t2
            WHERE sub.resolution = 'NONE'
        """
        self.assertEqual({"s1.t1", "s2.t2"},
                         self.extract_tables(query))

        query = """
            SELECT sub.*
            FROM (
                SELECT *
                FROM s1.t1
                WHERE day_of_week = 'Friday'
            ) sub
            WHERE sub.resolution = 'NONE'
        """
        self.assertEqual({"s1.t1"}, self.extract_tables(query))

        query = """
            SELECT * FROM t1
            WHERE s11 > ANY
            (SELECT COUNT(*) /* no hint */ FROM t2
                WHERE NOT EXISTS
                (SELECT * FROM t3
                    WHERE ROW(5*t2.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self.assertEqual({"t1", "t2", "t3", "t4"},
                         self.extract_tables(query))

    def test_select_in_expression(self):
        query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_union(self):
        query = "SELECT * FROM t1 UNION SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_select_from_values(self):
        query = "SELECT * FROM VALUES (13, 42)"
        self.assertFalse(self.extract_tables(query))

    def test_select_array(self):
        query = """
            SELECT ARRAY[1, 2, 3] AS my_array
            FROM t1 LIMIT 10
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_select_if(self):
        query = """
            SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
            FROM t1 LIMIT 10
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
    def test_show_tables(self):
        query = 'SHOW TABLES FROM s1 like "%order%"'
        # TODO: figure out what should code do here
        self.assertEqual({"s1"}, self.extract_tables(query))

    # SHOW COLUMNS (FROM | IN) qualifiedName
    def test_show_columns(self):
        query = "SHOW COLUMNS FROM t1"
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_where_subquery(self):
        query = """
            SELECT name
            FROM t1
            WHERE regionkey = (SELECT max(regionkey) FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = """
            SELECT name
            FROM t1
            WHERE regionkey IN (SELECT regionkey FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = """
            SELECT name
            FROM t1
            WHERE regionkey EXISTS (SELECT regionkey FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    # DESCRIBE | DESC qualifiedName
    def test_describe(self):
        self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
        self.assertEqual({"t1"}, self.extract_tables("DESC t1"))

    # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
    # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
    def test_show_partitions(self):
        query = """
            SHOW PARTITIONS FROM orders
            WHERE ds >= '2013-01-01' ORDER BY ds DESC;
        """
        self.assertEqual({"orders"}, self.extract_tables(query))

    def test_join(self):
        query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        # subquery + join
        query = """
            SELECT a.date, b.name FROM
                left_table a
                JOIN (
                    SELECT
                        CAST((b.year) as VARCHAR) date,
                        name
                    FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                LEFT INNER JOIN (
                    SELECT
                        CAST((b.year) as VARCHAR) date,
                        name
                    FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                RIGHT OUTER JOIN (
                    SELECT
                        CAST((b.year) as VARCHAR) date,
                        name
                    FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                FULL OUTER JOIN (
                    SELECT
                        CAST((b.year) as VARCHAR) date,
                        name
                    FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        # TODO: add SEMI join support, SQL Parse does not handle it.
        # query = """
        #     SELECT a.date, b.name FROM
        #         left_table a
        #         LEFT SEMI JOIN (
        #             SELECT
        #                 CAST((b.year) as VARCHAR) date,
        #                 name
        #             FROM right_table
        #         ) b
        #         ON a.date = b.date
        # """
        # self.assertEqual({"left_table", "right_table"},
        #                  sql_parse.extract_tables(query))

    def test_combinations(self):
        query = """
            SELECT * FROM t1
            WHERE s11 > ANY
            (SELECT * FROM t1 UNION ALL SELECT * FROM (
                SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
                WHERE NOT EXISTS
                (SELECT * FROM t3
                    WHERE ROW(5*t3.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self.assertEqual({"t1", "t3", "t4", "t6"},
                         self.extract_tables(query))

        query = """
            SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
                AS S1) AS S2) AS S3;
        """
        self.assertEqual({"EmployeeS"}, self.extract_tables(query))

    def test_with(self):
        query = """
            WITH
                x AS (SELECT a FROM t1),
                y AS (SELECT a AS b FROM t2),
                z AS (SELECT b AS c FROM t3)
            SELECT c FROM z;
        """
        self.assertEqual({"t1", "t2", "t3"},
                         self.extract_tables(query))

        query = """
            WITH
                x AS (SELECT a FROM t1),
                y AS (SELECT a AS b FROM x),
                z AS (SELECT b AS c FROM y)
            SELECT c FROM z;
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_reusing_aliases(self):
        query = """
            with q1 as ( select key from q2 where key = '5'),
            q2 as ( select key from src where key = '5')
            select * from (select key from q1) a;
        """
        self.assertEqual({"src"}, self.extract_tables(query))

    # NOTE(review): missing the `test_` prefix, so unittest never collects
    # this method. Presumably disabled on purpose pending multistatement
    # support (see the TODO in SupersetQuery.__init__) -- confirm before
    # renaming it to test_multistatement.
    def multistatement(self):
        query = "SELECT * FROM t1; SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1; SELECT * FROM t2;"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
Loading…
Reference in New Issue