feat: method for dynamic `allows_alias_in_select` (#25882)
This commit is contained in:
parent
3ee22667a7
commit
80caba3fd1
|
|
@ -398,6 +398,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
# Can the catalog be changed on a per-query basis?
|
||||
supports_dynamic_catalog = False
|
||||
|
||||
@classmethod
|
||||
def get_allows_alias_in_select(
|
||||
cls, database: Database # pylint: disable=unused-argument
|
||||
) -> bool:
|
||||
"""
|
||||
Method for dynamic `allows_alias_in_select`.
|
||||
|
||||
In Dremio this atribute is version-dependent, so Superset needs to inspect the
|
||||
database configuration in order to determine it. This method allows engine-specs
|
||||
to define dynamic values for the attribute.
|
||||
"""
|
||||
return cls.allows_alias_in_select
|
||||
|
||||
@classmethod
|
||||
def supports_url(cls, url: URL) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -14,14 +14,25 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from packaging.version import Version
|
||||
from sqlalchemy import types
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
# See https://github.com/apache/superset/pull/25657
|
||||
FIXED_ALIAS_IN_SELECT_VERSION = Version("24.1.0")
|
||||
|
||||
|
||||
class DremioEngineSpec(BaseEngineSpec):
|
||||
engine = "dremio"
|
||||
|
|
@ -43,10 +54,25 @@ class DremioEngineSpec(BaseEngineSpec):
|
|||
def epoch_to_dttm(cls) -> str:
|
||||
return "TO_DATE({col})"
|
||||
|
||||
@classmethod
|
||||
def get_allows_alias_in_select(cls, database: Database) -> bool:
|
||||
"""
|
||||
Dremio supports aliases in SELECT statements since version 24.1.0.
|
||||
|
||||
If no version is specified in the DB extra, we assume the Dremio version is post
|
||||
24.1.0. This way, as we move forward people don't have to specify a version when
|
||||
setting up their databases.
|
||||
"""
|
||||
version = database.get_extra().get("version")
|
||||
if version and Version(version) < FIXED_ALIAS_IN_SELECT_VERSION:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
|
||||
) -> str | None:
|
||||
sqla_type = cls.get_sqla_column_type(target_type)
|
||||
|
||||
if isinstance(sqla_type, types.Date):
|
||||
|
|
|
|||
|
|
@ -965,7 +965,7 @@ class Database(
|
|||
"""
|
||||
label_expected = label or sqla_col.name
|
||||
# add quotes to tables
|
||||
if self.db_engine_spec.allows_alias_in_select:
|
||||
if self.db_engine_spec.get_allows_alias_in_select(self):
|
||||
label = self.db_engine_spec.make_label_compatible(label_expected)
|
||||
sqla_col = sqla_col.label(label)
|
||||
sqla_col.key = label_expected
|
||||
|
|
|
|||
|
|
@ -765,7 +765,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def database(self) -> builtins.type["Database"]:
|
||||
def database(self) -> "Database":
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
|
|
@ -865,7 +865,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
label_expected = label or sqla_col.name
|
||||
db_engine_spec = self.db_engine_spec
|
||||
# add quotes to tables
|
||||
if db_engine_spec.allows_alias_in_select:
|
||||
if db_engine_spec.get_allows_alias_in_select(self.database):
|
||||
label = db_engine_spec.make_label_compatible(label_expected)
|
||||
sqla_col = sqla_col.label(label)
|
||||
sqla_col.key = label_expected
|
||||
|
|
@ -900,7 +900,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
self, query_obj: QueryObjectDict, mutate: bool = True
|
||||
) -> QueryStringExtended:
|
||||
sqlaq = self.get_sqla_query(**query_obj)
|
||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore
|
||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
|
||||
sql = self._apply_cte(sql, sqlaq.cte)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
if mutate:
|
||||
|
|
@ -939,7 +939,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
value = value.item()
|
||||
|
||||
column_ = columns_by_name[dimension]
|
||||
db_extra: dict[str, Any] = self.database.get_extra() # type: ignore
|
||||
db_extra: dict[str, Any] = self.database.get_extra()
|
||||
|
||||
if isinstance(column_, dict):
|
||||
if (
|
||||
|
|
@ -1024,9 +1024,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
return df
|
||||
|
||||
try:
|
||||
df = self.database.get_df(
|
||||
sql, self.schema, mutator=assign_column_label # type: ignore
|
||||
)
|
||||
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
df = pd.DataFrame()
|
||||
status = QueryStatus.FAILED
|
||||
|
|
@ -1361,7 +1359,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
with self.database.get_sqla_engine_with_context() as engine: # type: ignore
|
||||
with self.database.get_sqla_engine_with_context() as engine:
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
|
|
@ -1958,7 +1956,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
col = col.element
|
||||
|
||||
if (
|
||||
db_engine_spec.allows_alias_in_select
|
||||
db_engine_spec.get_allows_alias_in_select(self.database)
|
||||
and db_engine_spec.allows_hidden_cc_in_orderby
|
||||
and col.name in [select_col.name for select_col in select_exprs]
|
||||
):
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm
|
||||
|
|
@ -40,3 +41,18 @@ def test_convert_dttm(
|
|||
from superset.db_engine_specs.dremio import DremioEngineSpec as spec
|
||||
|
||||
assert_convert_dttm(spec, target_type, expected_result, dttm)
|
||||
|
||||
|
||||
def test_get_allows_alias_in_select(mocker: MockerFixture) -> None:
|
||||
from superset.db_engine_specs.dremio import DremioEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
|
||||
database.get_extra.return_value = {}
|
||||
assert DremioEngineSpec.get_allows_alias_in_select(database) is True
|
||||
|
||||
database.get_extra.return_value = {"version": "24.1.0"}
|
||||
assert DremioEngineSpec.get_allows_alias_in_select(database) is True
|
||||
|
||||
database.get_extra.return_value = {"version": "24.0.0"}
|
||||
assert DremioEngineSpec.get_allows_alias_in_select(database) is False
|
||||
|
|
|
|||
Loading…
Reference in New Issue