fix(presto/trino): Ensure get_table_names only returns real tables (#21794)

This commit is contained in:
John Bodley 2022-11-09 14:30:49 -08:00 committed by GitHub
parent 53ed8f2d5a
commit 9f7bd1e63f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 125 additions and 116 deletions

View File

@ -31,6 +31,7 @@ assists people when migrating to a new version.
- [21002](https://github.com/apache/superset/pull/21002): Support Python 3.10 and bump pandas 1.4 and pyarrow 6.
- [21163](https://github.com/apache/superset/pull/21163): When `GENERIC_CHART_AXES` feature flags set to `True`, the Time Grain control will move below the X-Axis control.
- [21284](https://github.com/apache/superset/pull/21284): The non-functional `MAX_TABLE_NAMES` config key has been removed.
- [21794](https://github.com/apache/superset/pull/21794): Deprecates the undocumented `PRESTO_SPLIT_VIEWS_FROM_TABLES` feature flag, which is no longer consulted. For Presto, as for other engines, only physical tables are now treated as tables.
### Breaking Changes

View File

@ -108,8 +108,6 @@ geopy==2.2.0
# via apache-superset
graphlib-backport==1.0.3
# via apache-superset
greenlet==1.1.2
# via sqlalchemy
gunicorn==20.1.0
# via apache-superset
hashids==1.3.1

View File

@ -12,6 +12,8 @@
# -r requirements/docker.in
gevent==21.8.0
# via -r requirements/docker.in
greenlet==1.1.3.post0
# via gevent
psycopg2-binary==2.9.1
# via apache-superset
zope-event==4.5.0

View File

@ -130,7 +130,7 @@ rsa==4.7.2
# via google-auth
statsd==3.3.0
# via -r requirements/testing.in
trino==0.315.0
trino==0.319.0
# via apache-superset
typing-inspect==0.7.1
# via libcst

View File

@ -160,7 +160,7 @@ setup(
"pinot": ["pinotdb>=0.3.3, <0.4"],
"postgres": ["psycopg2-binary==2.9.1"],
"presto": ["pyhive[presto]>=0.6.5"],
"trino": ["trino>=0.313.0"],
"trino": ["trino>=0.319.0"],
"prophet": ["prophet>=1.0.1, <1.1", "pystan<3.0"],
"redshift": ["sqlalchemy-redshift>=0.8.1, < 0.9"],
"rockset": ["rockset>=0.8.10, <0.9"],

View File

@ -19,7 +19,7 @@ import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional
from typing import Any, cast, Dict, List, Optional
from zipfile import is_zipfile, ZipFile
from flask import request, Response, send_file
@ -611,7 +611,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("init", self.table_metadata.__name__)
parsed_schema = parse_js_uri_path_item(schema_name, eval_undefined=True)
table_name = parse_js_uri_path_item(table_name) # type: ignore
table_name = cast(str, parse_js_uri_path_item(table_name))
payload = database.db_engine_spec.extra_table_metadata(
database, table_name, parsed_schema
)

View File

@ -1018,13 +1018,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str],
) -> List[str]:
"""
Get all tables from schema
Get all the real table names within the specified schema.
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema to inspect. If omitted, uses default schema for database
:return: All tables in schema
Per the SQLAlchemy definition, if the schema is omitted the database's default
schema is used; however, some dialects infer the request as schema agnostic.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The physical table names
"""
try:
tables = inspector.get_table_names(schema)
except Exception as ex:
@ -1042,13 +1046,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str],
) -> List[str]:
"""
Get all views from schema
Get all the view names within the specified schema.
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema name. If omitted, uses default schema for database
:return: All views in schema
Per the SQLAlchemy definition, if the schema is omitted the database's default
schema is used; however, some dialects infer the request as schema agnostic.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The view names
"""
try:
views = inspector.get_view_names(schema)
except Exception as ex:

View File

@ -19,13 +19,13 @@ from __future__ import annotations
import logging
import re
import textwrap
import time
from abc import ABCMeta
from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from urllib import parse
@ -392,46 +392,84 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
tables = super().get_table_names(database, inspector, schema)
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return tables
"""
Get all the real table names within the specified schema.
views = set(cls.get_view_names(database, inspector, schema))
actual_tables = set(tables) - views
return list(actual_tables)
Per the SQLAlchemy definition, if the schema is omitted the database's default
schema is used; however, some dialects infer the request as schema agnostic.
Note that PyHive's Hive and Presto SQLAlchemy dialects do not adhere to the
specification, in that the `get_table_names` method returns both real tables and
views. Furthermore, the dialects wrongfully infer the request as schema agnostic
when the schema is omitted.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The physical table names
"""
return sorted(
list(
set(super().get_table_names(database, inspector, schema))
- set(cls.get_view_names(database, inspector, schema))
)
)
@classmethod
def get_view_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
"""Returns an empty list
get_table_names() function returns all table names and view names,
and get_view_names() is not implemented in sqlalchemy_presto.py
https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
"""
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return []
Get all the view names within the specified schema.
Per the SQLAlchemy definition, if the schema is omitted the database's default
schema is used; however, some dialects infer the request as schema agnostic.
Note that PyHive's Hive and Presto SQLAlchemy dialects do not implement the
`get_view_names` method. To ensure consistency with the `get_table_names` method,
the request is deemed schema agnostic when the schema is omitted.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The view names
"""
if schema:
sql = (
"SELECT table_name FROM information_schema.views "
"WHERE table_schema=%(schema)s"
)
sql = dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %(schema)s
AND table_type = 'VIEW'
"""
).strip()
params = {"schema": schema}
else:
sql = "SELECT table_name FROM information_schema.views"
sql = dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_type = 'VIEW'
"""
).strip()
params = {}
engine = cls.get_engine(database, schema=schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
return [row[0] for row in results]
return sorted([row[0] for row in results])
@classmethod
def _create_column_info(
@ -1087,7 +1125,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
else f"SHOW PARTITIONS FROM {table_name}"
)
sql = textwrap.dedent(
sql = dedent(
f"""\
{partition_select_clause}
{where_clause}

View File

@ -150,10 +150,6 @@ def test_hive_error_msg():
)
def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name
assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []
def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from collections import namedtuple
from textwrap import dedent
from unittest import mock, skipUnless
import pandas as pd
@ -33,52 +34,50 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_datatype_presto(self):
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
def test_presto_get_view_names_return_empty_list(
self,
): # pylint: disable=invalid-name
self.assertEqual(
[], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views", {}
)
assert result == ["a", "d"]
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names_with_schema(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views "
"WHERE table_schema=%(schema)s",
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %(schema)s
AND table_type = 'VIEW'
"""
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
mock_execute.assert_called_once_with(
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_type = 'VIEW'
"""
).strip(),
{},
)
assert result == ["a", "d"]
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
@ -663,50 +662,17 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_no_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
def test_get_table_names(
self,
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = False
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == table_names
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert sorted(tables) == sorted(table_names)
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables_no_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = []
table_names = []
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == []
assert tables == ["table1", "table2"]
def test_get_full_name(self):
names = [