chore: Change get_table_names/get_view_names return type (#22085)

This commit is contained in:
John Bodley 2022-11-18 12:41:21 -08:00 committed by GitHub
parent e990690dde
commit 7e54b88a51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 76 additions and 62 deletions

View File

@ -1034,7 +1034,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.
@ -1048,13 +1048,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
try:
tables = inspector.get_table_names(schema)
tables = set(inspector.get_table_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
return sorted(tables)
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
return tables
@classmethod
def get_view_names( # pylint: disable=unused-argument
@ -1062,7 +1062,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.
@ -1076,13 +1076,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
try:
views = inspector.get_view_names(schema)
views = set(inspector.get_view_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f"^{schema}\\.", "", view) for view in views]
return sorted(views)
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
return views
@classmethod
def get_table_comment(

View File

@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
from sqlalchemy.engine.reflection import Inspector
@ -103,9 +103,7 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
tables = set(super().get_table_names(database, inspector, schema))
views = set(cls.get_view_names(database, inspector, schema))
actual_tables = tables - views
return list(actual_tables)
) -> Set[str]:
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
@ -75,5 +75,5 @@ class DuckDBEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
) -> List[str]:
return sorted(inspector.get_table_names(schema))
) -> Set[str]:
return set(inspector.get_table_names(schema))

View File

@ -18,7 +18,7 @@ import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
@ -228,11 +228,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to consider foreign tables for PostgreSQL"""
tables = inspector.get_table_names(schema)
tables.extend(inspector.get_foreign_table_names(schema))
return sorted(tables)
return set(inspector.get_table_names(schema)) | set(
inspector.get_foreign_table_names(schema)
)
@classmethod
def convert_dttm(

View File

@ -26,7 +26,18 @@ from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
cast,
Dict,
List,
Optional,
Pattern,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib import parse
import pandas as pd
@ -396,7 +407,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.
@ -414,12 +425,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
:returns: The physical table names
"""
return sorted(
list(
set(super().get_table_names(database, inspector, schema))
- set(cls.get_view_names(database, inspector, schema))
)
)
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
@classmethod
def get_view_names(
@ -427,7 +435,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.
@ -468,7 +476,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cursor.execute(sql, params)
results = cursor.fetchall()
return sorted([row[0] for row in results])
return {row[0] for row in results}
@classmethod
def _create_column_info(

View File

@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
@ -88,6 +88,6 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to disregard the schema for Sqlite"""
return sorted(inspector.get_table_names())
return set(inspector.get_table_names())

View File

@ -543,7 +543,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@ -553,13 +553,17 @@ class Database(
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of tables
:return: The table/schema pairs
"""
try:
tables = self.db_engine_spec.get_table_names(
database=self, inspector=self.inspector, schema=schema
)
return [(table, schema) for table in tables]
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@ -573,7 +577,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@ -583,13 +587,17 @@ class Database(
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of views
:return: set of views
"""
try:
views = self.db_engine_spec.get_view_names(
database=self, inspector=self.inspector, schema=schema
)
return [(view, schema) for view in views]
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

View File

@ -1173,7 +1173,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
tables = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema_parsed,
@ -1181,13 +1181,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)
views = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema_parsed,
@ -1195,7 +1195,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)
except SupersetException as ex:
return json_error_response(ex.message, ex.status)

View File

@ -767,7 +767,7 @@ class TestDatasetApi(SupersetTestCase):
with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
patch_get_view_names.return_value = {"test_case_view"}
self.login(username="admin")
table_data = {

View File

@ -229,11 +229,11 @@ class TestDbEngineSpecs(TestDbEngineSpec):
""" Make sure base engine spec removes schema name from table name
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result_expected = {"table", "table_2"}
base_result = BaseEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)
assert base_result_expected == base_result
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):

View File

@ -45,11 +45,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result_expected = {"schema.table", "table_2", "table_3"}
pg_result = PostgresEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)
assert pg_result_expected == pg_result
def test_time_exp_literal_no_grain(self):
"""

View File

@ -56,7 +56,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]
assert result == {"a", "d"}
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
@ -77,7 +77,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
).strip(),
{},
)
assert result == ["a", "d"]
assert result == {"a", "d"}
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
@ -670,10 +670,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
mock_get_view_names.return_value = {"view1", "view2"}
mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"}
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == ["table1", "table2"]
assert tables == {"table1", "table2"}
def test_get_full_name(self):
names = [