chore(db_engine_specs): Refactor get_index (#23656)
This commit is contained in:
parent
976e33330f
commit
b35b5a6e05
|
|
@ -43,6 +43,7 @@ import pandas as pd
|
|||
import sqlparse
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from deprecation import deprecated
|
||||
from flask import current_app
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
|
|
@ -797,6 +798,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
@deprecated(deprecated_in="3.0")
|
||||
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Normalizes indexes for more consistency across db engines
|
||||
|
|
@ -1179,6 +1181,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
|
||||
return views
|
||||
|
||||
@classmethod
|
||||
def get_indexes(
|
||||
cls,
|
||||
database: Database, # pylint: disable=unused-argument
|
||||
inspector: Inspector,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get the indexes associated with the specified schema/table.
|
||||
|
||||
:param database: The database to inspect
|
||||
:param inspector: The SQLAlchemy inspector
|
||||
:param table_name: The table to inspect
|
||||
:param schema: The schema to inspect
|
||||
:returns: The indexes
|
||||
"""
|
||||
|
||||
return inspector.get_indexes(table_name, schema)
|
||||
|
||||
@classmethod
|
||||
def get_table_comment(
|
||||
cls, inspector: Inspector, table_name: str, schema: Optional[str]
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKIN
|
|||
import pandas as pd
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from deprecation import deprecated
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.exceptions import ValidationError
|
||||
|
|
@ -278,6 +279,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
return "_" + md5_sha_from_str(label)
|
||||
|
||||
@classmethod
|
||||
@deprecated(deprecated_in="3.0")
|
||||
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Normalizes indexes for more consistency across db engines
|
||||
|
|
@ -296,6 +298,26 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
normalized_idxs.append(ix)
|
||||
return normalized_idxs
|
||||
|
||||
@classmethod
|
||||
def get_indexes(
|
||||
cls,
|
||||
database: "Database",
|
||||
inspector: Inspector,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get the indexes associated with the specified schema/table.
|
||||
|
||||
:param database: The database to inspect
|
||||
:param inspector: The SQLAlchemy inspector
|
||||
:param table_name: The table to inspect
|
||||
:param schema: The schema to inspect
|
||||
:returns: The indexes
|
||||
"""
|
||||
|
||||
return cls.normalize_indexes(inspector.get_indexes(table_name, schema))
|
||||
|
||||
@classmethod
|
||||
def extra_table_metadata(
|
||||
cls, database: "Database", table_name: str, schema_name: Optional[str]
|
||||
|
|
|
|||
|
|
@ -561,10 +561,18 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
)
|
||||
|
||||
column_names = indexes[0]["column_names"]
|
||||
part_fields = [(column_name, True) for column_name in column_names]
|
||||
sql = cls._partition_query(table_name, database, 1, part_fields)
|
||||
df = database.get_df(sql, schema)
|
||||
return column_names, cls._latest_partition_from_df(df)
|
||||
|
||||
return column_names, cls._latest_partition_from_df(
|
||||
df=database.get_df(
|
||||
sql=cls._partition_query(
|
||||
table_name,
|
||||
database,
|
||||
limit=1,
|
||||
order_by=[(column_name, True) for column_name in column_names],
|
||||
),
|
||||
schema=schema,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latest_sub_partition(
|
||||
|
|
|
|||
|
|
@ -847,8 +847,7 @@ class Database(
|
|||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
indexes = inspector.get_indexes(table_name, schema)
|
||||
return self.db_engine_spec.normalize_indexes(indexes)
|
||||
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
|
||||
|
||||
def get_pk_constraint(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -521,3 +521,26 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
|
|||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_get_indexes():
|
||||
indexes = [
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["a", "b"],
|
||||
"unique": False,
|
||||
},
|
||||
]
|
||||
|
||||
inspector = mock.Mock()
|
||||
inspector.get_indexes = mock.Mock(return_value=indexes)
|
||||
|
||||
assert (
|
||||
BaseEngineSpec.get_indexes(
|
||||
database=mock.Mock(),
|
||||
inspector=inspector,
|
||||
table_name="bar",
|
||||
schema="foo",
|
||||
)
|
||||
== indexes
|
||||
)
|
||||
|
|
|
|||
|
|
@ -144,27 +144,78 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_normalize_indexes(self):
|
||||
"""
|
||||
DB Eng Specs (bigquery): Test extra table metadata
|
||||
"""
|
||||
indexes = [{"name": "partition", "column_names": [None], "unique": False}]
|
||||
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
|
||||
self.assertEqual(normalized_idx, [])
|
||||
def test_get_indexes(self):
|
||||
database = mock.Mock()
|
||||
inspector = mock.Mock()
|
||||
schema = "foo"
|
||||
table_name = "bar"
|
||||
|
||||
indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}]
|
||||
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
|
||||
self.assertEqual(normalized_idx, indexes)
|
||||
|
||||
indexes = [
|
||||
{"name": "partition", "column_names": ["dttm", None], "unique": False}
|
||||
]
|
||||
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
|
||||
self.assertEqual(
|
||||
normalized_idx,
|
||||
[{"name": "partition", "column_names": ["dttm"], "unique": False}],
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": [None],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
table_name,
|
||||
schema,
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
table_name,
|
||||
schema,
|
||||
) == [
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm", None],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
table_name,
|
||||
schema,
|
||||
) == [
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
|
||||
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
|
||||
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
|
||||
@mock.patch("superset.db_engine_specs.bigquery.service_account")
|
||||
|
|
|
|||
Loading…
Reference in New Issue