diff --git a/superset/security.py b/superset/security.py index 89eab5d53..3a92a7c21 100644 --- a/superset/security.py +++ b/superset/security.py @@ -183,13 +183,11 @@ class SupersetSecurityManager(SecurityManager): def get_schema_and_table(self, table_in_query, schema): table_name_pieces = table_in_query.split('.') - if len(table_name_pieces) == 2: - table_schema = table_name_pieces[0] - table_name = table_name_pieces[1] - else: - table_schema = schema - table_name = table_name_pieces[0] - return (table_schema, table_name) + if len(table_name_pieces) == 3: + return tuple(table_name_pieces[1:]) + elif len(table_name_pieces) == 2: + return tuple(table_name_pieces) + return (schema, table_name_pieces[0]) def datasource_access_by_fullname( self, database, table_in_query, schema): diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 63ae05e66..e2c25ef50 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -16,10 +16,12 @@ # under the License. # pylint: disable=C,R,W import logging +from typing import Optional import sqlparse -from sqlparse.sql import Identifier, IdentifierList, Token, TokenList -from sqlparse.tokens import Keyword, Name +from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList +from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace +from sqlparse.utils import imt RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'} ON_KEYWORD = 'ON' @@ -75,11 +77,34 @@ class ParsedQuery(object): return statements @staticmethod - def __get_full_name(tlist: TokenList): - if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.': - return '{}.{}'.format(tlist.tokens[0].value, - tlist.tokens[2].value) - return tlist.get_real_name() + def __get_full_name(tlist: TokenList) -> Optional[str]: + """ + Return the full unquoted table name if valid, i.e., conforms to the following + [[cluster.]schema.]table construct. + + :param tlist: The SQL tokens + :returns: The valid full table name + """ + + # Strip the alias if present. + idx = len(tlist.tokens) + + if tlist.has_alias(): + ws_idx, _ = tlist.token_next_by(t=Whitespace) + + if ws_idx != -1: + idx = ws_idx + + tokens = tlist.tokens[:idx] + + if ( + len(tokens) in (1, 3, 5) and + all(imt(token, t=[Name, String]) for token in tokens[0::2]) and + all(imt(token, m=(Punctuation, '.')) for token in tokens[1::2]) + ): + return '.'.join([remove_quotes(token.value) for token in tokens[0::2]]) + + return None @staticmethod def __is_identifier(token: Token): diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py index fb3dc7894..bf3cffb33 100644 --- a/tests/sql_parse_tests.py +++ b/tests/sql_parse_tests.py @@ -29,6 +29,12 @@ class SupersetTestCase(unittest.TestCase): query = 'SELECT * FROM tbname' self.assertEquals({'tbname'}, self.extract_tables(query)) + query = 'SELECT * FROM tbname foo' + self.assertEquals({'tbname'}, self.extract_tables(query)) + + query = 'SELECT * FROM tbname AS foo' + self.assertEquals({'tbname'}, self.extract_tables(query)) + # underscores query = 'SELECT * FROM tb_name' self.assertEquals({'tb_name'}, @@ -47,11 +53,40 @@ class SupersetTestCase(unittest.TestCase): {'schemaname.tbname'}, self.extract_tables('SELECT * FROM schemaname.tbname')) - # Ill-defined schema/table. + self.assertEquals( + {'schemaname.tbname'}, + self.extract_tables('SELECT * FROM "schemaname"."tbname"')) + + self.assertEquals( + {'schemaname.tbname'}, + self.extract_tables('SELECT * FROM schemaname.tbname foo')) + + self.assertEquals( + {'schemaname.tbname'}, + self.extract_tables('SELECT * FROM schemaname.tbname AS foo')) + + # cluster + self.assertEquals( + {'clustername.schemaname.tbname'}, + self.extract_tables('SELECT * FROM clustername.schemaname.tbname')) + + # Ill-defined cluster/schema/table. self.assertEquals( set(), self.extract_tables('SELECT * FROM schemaname.')) + self.assertEquals( + set(), + self.extract_tables('SELECT * FROM clustername.schemaname.')) + + self.assertEquals( + set(), + self.extract_tables('SELECT * FROM clustername..')) + + self.assertEquals( + set(), + self.extract_tables('SELECT * FROM clustername..tbname')) + # quotes query = 'SELECT field1, field2 FROM tb_name' self.assertEquals({'tb_name'}, self.extract_tables(query))