feat: method for dynamic `allows_alias_in_select` (#25882)

This commit is contained in:
Beto Dealmeida 2023-11-07 14:28:28 -05:00 committed by GitHub
parent 3ee22667a7
commit 80caba3fd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 14 deletions

View File

@ -398,6 +398,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Can the catalog be changed on a per-query basis?
supports_dynamic_catalog = False
@classmethod
def get_allows_alias_in_select(
cls, database: Database # pylint: disable=unused-argument
) -> bool:
"""
Method for dynamic `allows_alias_in_select`.
In Dremio this atribute is version-dependent, so Superset needs to inspect the
database configuration in order to determine it. This method allows engine-specs
to define dynamic values for the attribute.
"""
return cls.allows_alias_in_select
@classmethod
def supports_url(cls, url: URL) -> bool:
"""

View File

@ -14,14 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Optional
from __future__ import annotations
from datetime import datetime
from typing import Any, TYPE_CHECKING
from packaging.version import Version
from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
if TYPE_CHECKING:
from superset.models.core import Database
# See https://github.com/apache/superset/pull/25657
FIXED_ALIAS_IN_SELECT_VERSION = Version("24.1.0")
class DremioEngineSpec(BaseEngineSpec):
engine = "dremio"
@ -43,10 +54,25 @@ class DremioEngineSpec(BaseEngineSpec):
def epoch_to_dttm(cls) -> str:
return "TO_DATE({col})"
@classmethod
def get_allows_alias_in_select(cls, database: Database) -> bool:
"""
Dremio supports aliases in SELECT statements since version 24.1.0.
If no version is specified in the DB extra, we assume the Dremio version is post
24.1.0. This way, as we move forward people don't have to specify a version when
setting up their databases.
"""
version = database.get_extra().get("version")
if version and Version(version) < FIXED_ALIAS_IN_SELECT_VERSION:
return False
return True
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):

View File

@ -965,7 +965,7 @@ class Database(
"""
label_expected = label or sqla_col.name
# add quotes to tables
if self.db_engine_spec.allows_alias_in_select:
if self.db_engine_spec.get_allows_alias_in_select(self):
label = self.db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected

View File

@ -765,7 +765,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
def database(self) -> builtins.type["Database"]:
def database(self) -> "Database":
raise NotImplementedError()
@property
@ -865,7 +865,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
label_expected = label or sqla_col.name
db_engine_spec = self.db_engine_spec
# add quotes to tables
if db_engine_spec.allows_alias_in_select:
if db_engine_spec.get_allows_alias_in_select(self.database):
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected
@ -900,7 +900,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
self, query_obj: QueryObjectDict, mutate: bool = True
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
if mutate:
@ -939,7 +939,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
value = value.item()
column_ = columns_by_name[dimension]
db_extra: dict[str, Any] = self.database.get_extra() # type: ignore
db_extra: dict[str, Any] = self.database.get_extra()
if isinstance(column_, dict):
if (
@ -1024,9 +1024,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return df
try:
df = self.database.get_df(
sql, self.schema, mutator=assign_column_label # type: ignore
)
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
except Exception as ex: # pylint: disable=broad-except
df = pd.DataFrame()
status = QueryStatus.FAILED
@ -1361,7 +1359,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if limit:
qry = qry.limit(limit)
with self.database.get_sqla_engine_with_context() as engine: # type: ignore
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
@ -1958,7 +1956,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col = col.element
if (
db_engine_spec.allows_alias_in_select
db_engine_spec.get_allows_alias_in_select(self.database)
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):

View File

@ -18,6 +18,7 @@ from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@ -40,3 +41,18 @@ def test_convert_dttm(
from superset.db_engine_specs.dremio import DremioEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_get_allows_alias_in_select(mocker: MockerFixture) -> None:
from superset.db_engine_specs.dremio import DremioEngineSpec
database = mocker.MagicMock()
database.get_extra.return_value = {}
assert DremioEngineSpec.get_allows_alias_in_select(database) is True
database.get_extra.return_value = {"version": "24.1.0"}
assert DremioEngineSpec.get_allows_alias_in_select(database) is True
database.get_extra.return_value = {"version": "24.0.0"}
assert DremioEngineSpec.get_allows_alias_in_select(database) is False