[fix] SQL parsing of table names (#7490)

This commit is contained in:
John Bodley 2019-06-03 11:07:57 -07:00 committed by GitHub
parent 78c1674dc7
commit 45b41aadcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 15 deletions

View File

@ -183,13 +183,11 @@ class SupersetSecurityManager(SecurityManager):
def get_schema_and_table(self, table_in_query, schema):
table_name_pieces = table_in_query.split('.')
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return (table_schema, table_name)
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:])
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces)
return (schema, table_name_pieces[0])
def datasource_access_by_fullname(
self, database, table_in_query, schema):

View File

@ -16,10 +16,12 @@
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Optional
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
@ -75,11 +77,34 @@ class ParsedQuery(object):
return statements
@staticmethod
def __get_full_name(tlist: TokenList):
if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
return '{}.{}'.format(tlist.tokens[0].value,
tlist.tokens[2].value)
return tlist.get_real_name()
def __get_full_name(tlist: TokenList) -> Optional[str]:
"""
Return the full unquoted table name if valid, i.e., conforms to the following
[[cluster.]schema.]table construct.
:param tlist: The SQL tokens
:returns: The valid full table name
"""
# Strip the alias if present.
idx = len(tlist.tokens)
if tlist.has_alias():
ws_idx, _ = tlist.token_next_by(t=Whitespace)
if ws_idx != -1:
idx = ws_idx
tokens = tlist.tokens[:idx]
if (
len(tokens) in (1, 3, 5) and
all(imt(token, t=[Name, String]) for token in tokens[0::2]) and
all(imt(token, m=(Punctuation, '.')) for token in tokens[1::2])
):
return '.'.join([remove_quotes(token.value) for token in tokens[0::2]])
return None
@staticmethod
def __is_identifier(token: Token):

View File

@ -29,6 +29,12 @@ class SupersetTestCase(unittest.TestCase):
query = 'SELECT * FROM tbname'
self.assertEquals({'tbname'}, self.extract_tables(query))
query = 'SELECT * FROM tbname foo'
self.assertEquals({'tbname'}, self.extract_tables(query))
query = 'SELECT * FROM tbname AS foo'
self.assertEquals({'tbname'}, self.extract_tables(query))
# underscores
query = 'SELECT * FROM tb_name'
self.assertEquals({'tb_name'},
@ -47,11 +53,40 @@ class SupersetTestCase(unittest.TestCase):
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname'))
# Ill-defined schema/table.
self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'))
self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname foo'))
self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname AS foo'))
# cluster
self.assertEquals(
{'clustername.schemaname.tbname'},
self.extract_tables('SELECT * FROM clustername.schemaname.tbname'))
# Ill-defined cluster/schema/table.
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM schemaname.'))
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername.schemaname.'))
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..'))
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..tbname'))
# quotes
query = 'SELECT field1, field2 FROM tb_name'
self.assertEquals({'tb_name'}, self.extract_tables(query))