chore(db_engine_specs): Refactor get_index (#23656)

This commit is contained in:
John Bodley 2023-04-13 09:23:16 +12:00 committed by GitHub
parent 976e33330f
commit b35b5a6e05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 149 additions and 24 deletions

View File

@ -43,6 +43,7 @@ import pandas as pd
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask import current_app
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
@ -797,6 +798,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
@deprecated(deprecated_in="3.0")
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@ -1179,6 +1181,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
return views
@classmethod
def get_indexes(
    cls,
    database: Database,  # pylint: disable=unused-argument
    inspector: Inspector,
    table_name: str,
    schema: Optional[str],
) -> List[Dict[str, Any]]:
    """
    Return the index definitions the inspector reports for a table.

    This implementation forwards the request to the supplied inspector and
    hands its result back unchanged; ``database`` is accepted but not used
    here.

    :param database: The database to inspect
    :param inspector: The SQLAlchemy inspector
    :param table_name: The table whose indexes are requested
    :param schema: The schema the table belongs to
    :returns: The indexes exactly as returned by the inspector
    """

    reported_indexes = inspector.get_indexes(table_name, schema)
    return reported_indexes
@classmethod
def get_table_comment(
cls, inspector: Inspector, table_name: str, schema: Optional[str]

View File

@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKIN
import pandas as pd
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
@ -278,6 +279,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return "_" + md5_sha_from_str(label)
@classmethod
@deprecated(deprecated_in="3.0")
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@ -296,6 +298,26 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
normalized_idxs.append(ix)
return normalized_idxs
@classmethod
def get_indexes(
    cls,
    database: "Database",
    inspector: Inspector,
    table_name: str,
    schema: Optional[str],
) -> List[Dict[str, Any]]:
    """
    Return the normalized index definitions for a table.

    The raw indexes are fetched from the supplied inspector and then passed
    through ``cls.normalize_indexes`` before being returned; ``database`` is
    accepted but not used here.

    :param database: The database to inspect
    :param inspector: The SQLAlchemy inspector
    :param table_name: The table whose indexes are requested
    :param schema: The schema the table belongs to
    :returns: The indexes after normalization
    """

    raw_indexes = inspector.get_indexes(table_name, schema)
    return cls.normalize_indexes(raw_indexes)
@classmethod
def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: Optional[str]

View File

@ -561,10 +561,18 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
)
column_names = indexes[0]["column_names"]
part_fields = [(column_name, True) for column_name in column_names]
sql = cls._partition_query(table_name, database, 1, part_fields)
df = database.get_df(sql, schema)
return column_names, cls._latest_partition_from_df(df)
return column_names, cls._latest_partition_from_df(
df=database.get_df(
sql=cls._partition_query(
table_name,
database,
limit=1,
order_by=[(column_name, True) for column_name in column_names],
),
schema=schema,
)
)
@classmethod
def latest_sub_partition(

View File

@ -847,8 +847,7 @@ class Database(
self, table_name: str, schema: Optional[str] = None
) -> List[Dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
indexes = inspector.get_indexes(table_name, schema)
return self.db_engine_spec.normalize_indexes(indexes)
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
def get_pk_constraint(
self, table_name: str, schema: Optional[str] = None

View File

@ -521,3 +521,26 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
},
)
]
def test_get_indexes():
    """
    Check that ``BaseEngineSpec.get_indexes`` returns the inspector's indexes
    without modifying them.
    """
    expected = [
        {
            "name": "partition",
            "column_names": ["a", "b"],
            "unique": False,
        },
    ]

    # Stand-in inspector whose ``get_indexes`` always yields ``expected``.
    inspector = mock.Mock()
    inspector.get_indexes = mock.Mock(return_value=expected)

    result = BaseEngineSpec.get_indexes(
        database=mock.Mock(),
        inspector=inspector,
        table_name="bar",
        schema="foo",
    )

    assert result == expected

View File

@ -144,27 +144,78 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
)
self.assertEqual(result, expected_result)
def test_normalize_indexes(self):
"""
DB Eng Specs (bigquery): Test extra table metadata
"""
indexes = [{"name": "partition", "column_names": [None], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, [])
def test_get_indexes(self):
database = mock.Mock()
inspector = mock.Mock()
schema = "foo"
table_name = "bar"
indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, indexes)
indexes = [
{"name": "partition", "column_names": ["dttm", None], "unique": False}
]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(
normalized_idx,
[{"name": "partition", "column_names": ["dttm"], "unique": False}],
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": [None],
"unique": False,
}
]
)
assert (
BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
)
== []
)
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm", None],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
@mock.patch("superset.db_engine_specs.bigquery.service_account")