[fix] SQL parsing of table names (#7490)
This commit is contained in:
parent
78c1674dc7
commit
45b41aadcc
|
|
@ -183,13 +183,11 @@ class SupersetSecurityManager(SecurityManager):
|
|||
|
||||
def get_schema_and_table(self, table_in_query, schema):
|
||||
table_name_pieces = table_in_query.split('.')
|
||||
if len(table_name_pieces) == 2:
|
||||
table_schema = table_name_pieces[0]
|
||||
table_name = table_name_pieces[1]
|
||||
else:
|
||||
table_schema = schema
|
||||
table_name = table_name_pieces[0]
|
||||
return (table_schema, table_name)
|
||||
if len(table_name_pieces) == 3:
|
||||
return tuple(table_name_pieces[1:])
|
||||
elif len(table_name_pieces) == 2:
|
||||
return tuple(table_name_pieces)
|
||||
return (schema, table_name_pieces[0])
|
||||
|
||||
def datasource_access_by_fullname(
|
||||
self, database, table_in_query, schema):
|
||||
|
|
|
|||
|
|
@ -16,10 +16,12 @@
|
|||
# under the License.
|
||||
# pylint: disable=C,R,W
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
|
||||
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
|
||||
from sqlparse.utils import imt
|
||||
|
||||
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
|
||||
ON_KEYWORD = 'ON'
|
||||
|
|
@ -75,11 +77,34 @@ class ParsedQuery(object):
|
|||
return statements
|
||||
|
||||
@staticmethod
|
||||
def __get_full_name(tlist: TokenList):
|
||||
if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
|
||||
return '{}.{}'.format(tlist.tokens[0].value,
|
||||
tlist.tokens[2].value)
|
||||
return tlist.get_real_name()
|
||||
def __get_full_name(tlist: TokenList) -> Optional[str]:
|
||||
"""
|
||||
Return the full unquoted table name if valid, i.e., conforms to the following
|
||||
[[cluster.]schema.]table construct.
|
||||
|
||||
:param tlist: The SQL tokens
|
||||
:returns: The valid full table name
|
||||
"""
|
||||
|
||||
# Strip the alias if present.
|
||||
idx = len(tlist.tokens)
|
||||
|
||||
if tlist.has_alias():
|
||||
ws_idx, _ = tlist.token_next_by(t=Whitespace)
|
||||
|
||||
if ws_idx != -1:
|
||||
idx = ws_idx
|
||||
|
||||
tokens = tlist.tokens[:idx]
|
||||
|
||||
if (
|
||||
len(tokens) in (1, 3, 5) and
|
||||
all(imt(token, t=[Name, String]) for token in tokens[0::2]) and
|
||||
all(imt(token, m=(Punctuation, '.')) for token in tokens[1::2])
|
||||
):
|
||||
return '.'.join([remove_quotes(token.value) for token in tokens[0::2]])
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __is_identifier(token: Token):
|
||||
|
|
|
|||
|
|
@ -29,6 +29,12 @@ class SupersetTestCase(unittest.TestCase):
|
|||
query = 'SELECT * FROM tbname'
|
||||
self.assertEquals({'tbname'}, self.extract_tables(query))
|
||||
|
||||
query = 'SELECT * FROM tbname foo'
|
||||
self.assertEquals({'tbname'}, self.extract_tables(query))
|
||||
|
||||
query = 'SELECT * FROM tbname AS foo'
|
||||
self.assertEquals({'tbname'}, self.extract_tables(query))
|
||||
|
||||
# underscores
|
||||
query = 'SELECT * FROM tb_name'
|
||||
self.assertEquals({'tb_name'},
|
||||
|
|
@ -47,11 +53,40 @@ class SupersetTestCase(unittest.TestCase):
|
|||
{'schemaname.tbname'},
|
||||
self.extract_tables('SELECT * FROM schemaname.tbname'))
|
||||
|
||||
# Ill-defined schema/table.
|
||||
self.assertEquals(
|
||||
{'schemaname.tbname'},
|
||||
self.extract_tables('SELECT * FROM "schemaname"."tbname"'))
|
||||
|
||||
self.assertEquals(
|
||||
{'schemaname.tbname'},
|
||||
self.extract_tables('SELECT * FROM schemaname.tbname foo'))
|
||||
|
||||
self.assertEquals(
|
||||
{'schemaname.tbname'},
|
||||
self.extract_tables('SELECT * FROM schemaname.tbname AS foo'))
|
||||
|
||||
# cluster
|
||||
self.assertEquals(
|
||||
{'clustername.schemaname.tbname'},
|
||||
self.extract_tables('SELECT * FROM clustername.schemaname.tbname'))
|
||||
|
||||
# Ill-defined cluster/schema/table.
|
||||
self.assertEquals(
|
||||
set(),
|
||||
self.extract_tables('SELECT * FROM schemaname.'))
|
||||
|
||||
self.assertEquals(
|
||||
set(),
|
||||
self.extract_tables('SELECT * FROM clustername.schemaname.'))
|
||||
|
||||
self.assertEquals(
|
||||
set(),
|
||||
self.extract_tables('SELECT * FROM clustername..'))
|
||||
|
||||
self.assertEquals(
|
||||
set(),
|
||||
self.extract_tables('SELECT * FROM clustername..tbname'))
|
||||
|
||||
# quotes
|
||||
query = 'SELECT field1, field2 FROM tb_name'
|
||||
self.assertEquals({'tb_name'}, self.extract_tables(query))
|
||||
|
|
|
|||
Loading…
Reference in New Issue