fix(sqla): use same template processor in all methods (#22280)
This commit is contained in:
parent
1c20206057
commit
1ad5147016
|
|
@ -15,6 +15,8 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=too-many-lines
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -222,7 +224,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
|||
__tablename__ = "table_columns"
|
||||
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
|
||||
table_id = Column(Integer, ForeignKey("tables.id"))
|
||||
table: "SqlaTable" = relationship(
|
||||
table: SqlaTable = relationship(
|
||||
"SqlaTable",
|
||||
backref=backref("columns", cascade="all, delete-orphan"),
|
||||
foreign_keys=[table_id],
|
||||
|
|
@ -301,14 +303,18 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
|||
)
|
||||
return column_spec.generic_type if column_spec else None
|
||||
|
||||
def get_sqla_col(self, label: Optional[str] = None) -> Column:
|
||||
def get_sqla_col(
|
||||
self,
|
||||
label: Optional[str] = None,
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> Column:
|
||||
label = label or self.column_name
|
||||
db_engine_spec = self.db_engine_spec
|
||||
column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
|
||||
type_ = column_spec.sqla_type if column_spec else None
|
||||
if self.expression:
|
||||
tp = self.table.get_template_processor()
|
||||
expression = tp.process_template(self.expression)
|
||||
if expression := self.expression:
|
||||
if template_processor:
|
||||
expression = template_processor.process_template(expression)
|
||||
col = literal_column(expression, type_=type_)
|
||||
else:
|
||||
col = column(self.column_name, type_=type_)
|
||||
|
|
@ -324,8 +330,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
|||
start_dttm: Optional[DateTime] = None,
|
||||
end_dttm: Optional[DateTime] = None,
|
||||
label: Optional[str] = "__time",
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> ColumnElement:
|
||||
col = self.get_sqla_col(label=label)
|
||||
col = self.get_sqla_col(label=label, template_processor=template_processor)
|
||||
l = []
|
||||
if start_dttm:
|
||||
l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
|
||||
|
|
@ -358,10 +365,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
|||
if not self.expression and not time_grain and not is_epoch:
|
||||
sqla_col = column(self.column_name, type_=type_)
|
||||
return self.table.make_sqla_column_compatible(sqla_col, label)
|
||||
if self.expression:
|
||||
expression = self.expression
|
||||
if expression := self.expression:
|
||||
if template_processor:
|
||||
expression = template_processor.process_template(self.expression)
|
||||
expression = template_processor.process_template(expression)
|
||||
col = literal_column(expression, type_=type_)
|
||||
else:
|
||||
col = column(self.column_name, type_=type_)
|
||||
|
|
@ -458,10 +464,17 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
|
|||
def __repr__(self) -> str:
|
||||
return str(self.metric_name)
|
||||
|
||||
def get_sqla_col(self, label: Optional[str] = None) -> Column:
|
||||
def get_sqla_col(
|
||||
self,
|
||||
label: Optional[str] = None,
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> Column:
|
||||
label = label or self.metric_name
|
||||
tp = self.table.get_template_processor()
|
||||
sqla_col: ColumnClause = literal_column(tp.process_template(self.expression))
|
||||
expression = self.expression
|
||||
if template_processor:
|
||||
expression = template_processor.process_template(expression)
|
||||
|
||||
sqla_col: ColumnClause = literal_column(expression)
|
||||
return self.table.make_sqla_column_compatible(sqla_col, label)
|
||||
|
||||
@property
|
||||
|
|
@ -650,7 +663,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
datasource_name: str,
|
||||
schema: Optional[str],
|
||||
database_name: str,
|
||||
) -> Optional["SqlaTable"]:
|
||||
) -> Optional[SqlaTable]:
|
||||
schema = schema or None
|
||||
query = (
|
||||
session.query(cls)
|
||||
|
|
@ -778,10 +791,17 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
except (TypeError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
def get_fetch_values_predicate(self) -> TextClause:
|
||||
tp = self.get_template_processor()
|
||||
def get_fetch_values_predicate(
|
||||
self,
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> TextClause:
|
||||
fetch_values_predicate = self.fetch_values_predicate
|
||||
if template_processor:
|
||||
fetch_values_predicate = template_processor.process_template(
|
||||
fetch_values_predicate
|
||||
)
|
||||
try:
|
||||
return self.text(tp.process_template(self.fetch_values_predicate))
|
||||
return self.text(fetch_values_predicate)
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
|
|
@ -799,12 +819,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
tp = self.get_template_processor()
|
||||
tbl, cte = self.get_from_clause(tp)
|
||||
|
||||
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
|
||||
qry = (
|
||||
select([target_col.get_sqla_col(template_processor=tp)])
|
||||
.select_from(tbl)
|
||||
.distinct()
|
||||
)
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
if self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
|
||||
|
||||
with self.database.get_sqla_engine_with_context() as engine:
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
|
|
@ -936,7 +960,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
column_name = cast(str, metric_column.get("column_name"))
|
||||
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
|
||||
if table_column:
|
||||
sqla_column = table_column.get_sqla_col()
|
||||
sqla_column = table_column.get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
else:
|
||||
sqla_column = column(column_name)
|
||||
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
|
||||
|
|
@ -975,7 +1001,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
col_in_metadata = self.get_column(expression)
|
||||
if col_in_metadata:
|
||||
sqla_column = col_in_metadata.get_sqla_col()
|
||||
sqla_column = col_in_metadata.get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
is_dttm = col_in_metadata.is_temporal
|
||||
else:
|
||||
sqla_column = literal_column(expression)
|
||||
|
|
@ -1190,7 +1218,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
)
|
||||
elif isinstance(metric, str) and metric in metrics_by_name:
|
||||
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
|
||||
metrics_exprs.append(
|
||||
metrics_by_name[metric].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise QueryObjectValidationError(
|
||||
_("Metric '%(metric)s' does not exist", metric=metric)
|
||||
|
|
@ -1229,12 +1261,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
col = metrics_exprs_by_expr.get(str(col), col)
|
||||
need_groupby = True
|
||||
elif col in columns_by_name:
|
||||
col = columns_by_name[col].get_sqla_col()
|
||||
col = columns_by_name[col].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
elif col in metrics_exprs_by_label:
|
||||
col = metrics_exprs_by_label[col]
|
||||
need_groupby = True
|
||||
elif col in metrics_by_name:
|
||||
col = metrics_by_name[col].get_sqla_col()
|
||||
col = metrics_by_name[col].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
need_groupby = True
|
||||
|
||||
if isinstance(col, ColumnElement):
|
||||
|
|
@ -1268,7 +1304,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
# if groupby field equals a selected column
|
||||
elif selected in columns_by_name:
|
||||
outer = columns_by_name[selected].get_sqla_col()
|
||||
outer = columns_by_name[selected].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
else:
|
||||
selected = validate_adhoc_subquery(
|
||||
selected,
|
||||
|
|
@ -1302,7 +1340,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
self.schema,
|
||||
)
|
||||
select_exprs.append(
|
||||
columns_by_name[selected].get_sqla_col()
|
||||
columns_by_name[selected].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
if isinstance(selected, str) and selected in columns_by_name
|
||||
else self.make_sqla_column_compatible(
|
||||
literal_column(selected), _column_label
|
||||
|
|
@ -1336,11 +1376,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
):
|
||||
time_filters.append(
|
||||
columns_by_name[self.main_dttm_col].get_time_filter(
|
||||
from_dttm,
|
||||
to_dttm,
|
||||
start_dttm=from_dttm,
|
||||
end_dttm=to_dttm,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
)
|
||||
time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
|
||||
time_filters.append(
|
||||
dttm_col.get_time_filter(
|
||||
start_dttm=from_dttm,
|
||||
end_dttm=to_dttm,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
)
|
||||
|
||||
# Always remove duplicates by column name, as sometimes `metrics_exprs`
|
||||
# can have the same name as a groupby column (e.g. when users use
|
||||
|
|
@ -1396,7 +1443,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
time_grain=filter_grain, template_processor=template_processor
|
||||
)
|
||||
elif col_obj:
|
||||
sqla_col = col_obj.get_sqla_col()
|
||||
sqla_col = col_obj.get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
col_type = col_obj.type if col_obj else None
|
||||
col_spec = db_engine_spec.get_column_spec(
|
||||
native_type=col_type,
|
||||
|
|
@ -1521,6 +1570,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
start_dttm=_since,
|
||||
end_dttm=_until,
|
||||
label=sqla_col.key,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -1565,7 +1615,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
having_clause_and += [self.text(having)]
|
||||
|
||||
if apply_fetch_values_predicate and self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
qry = qry.where(
|
||||
self.get_fetch_values_predicate(template_processor=template_processor)
|
||||
)
|
||||
if granularity:
|
||||
qry = qry.where(and_(*(time_filters + where_clause_and)))
|
||||
else:
|
||||
|
|
@ -1617,8 +1669,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
if dttm_col and not db_engine_spec.time_groupby_inline:
|
||||
inner_time_filter = [
|
||||
dttm_col.get_time_filter(
|
||||
inner_from_dttm or from_dttm,
|
||||
inner_to_dttm or to_dttm,
|
||||
start_dttm=inner_from_dttm or from_dttm,
|
||||
end_dttm=inner_to_dttm or to_dttm,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
]
|
||||
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
|
||||
|
|
@ -1627,7 +1680,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
ob = inner_main_metric_expr
|
||||
if series_limit_metric:
|
||||
ob = self._get_series_orderby(
|
||||
series_limit_metric, metrics_by_name, columns_by_name
|
||||
series_limit_metric=series_limit_metric,
|
||||
metrics_by_name=metrics_by_name,
|
||||
columns_by_name=columns_by_name,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
direction = desc if order_desc else asc
|
||||
subq = subq.order_by(direction(ob))
|
||||
|
|
@ -1647,9 +1703,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
orderby = [
|
||||
(
|
||||
self._get_series_orderby(
|
||||
series_limit_metric,
|
||||
metrics_by_name,
|
||||
columns_by_name,
|
||||
series_limit_metric=series_limit_metric,
|
||||
metrics_by_name=metrics_by_name,
|
||||
columns_by_name=columns_by_name,
|
||||
template_processor=template_processor,
|
||||
),
|
||||
not order_desc,
|
||||
)
|
||||
|
|
@ -1709,6 +1766,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
series_limit_metric: Metric,
|
||||
metrics_by_name: Dict[str, SqlMetric],
|
||||
columns_by_name: Dict[str, TableColumn],
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> Column:
|
||||
if utils.is_adhoc_metric(series_limit_metric):
|
||||
assert isinstance(series_limit_metric, dict)
|
||||
|
|
@ -1717,7 +1775,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
isinstance(series_limit_metric, str)
|
||||
and series_limit_metric in metrics_by_name
|
||||
):
|
||||
ob = metrics_by_name[series_limit_metric].get_sqla_col()
|
||||
ob = metrics_by_name[series_limit_metric].get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
else:
|
||||
raise QueryObjectValidationError(
|
||||
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
|
||||
|
|
@ -1930,7 +1990,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
database: Database,
|
||||
datasource_name: str,
|
||||
schema: Optional[str] = None,
|
||||
) -> List["SqlaTable"]:
|
||||
) -> List[SqlaTable]:
|
||||
query = (
|
||||
session.query(cls)
|
||||
.filter_by(database_id=database.id)
|
||||
|
|
@ -1947,7 +2007,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
database: Database,
|
||||
permissions: Set[str],
|
||||
schema_perms: Set[str],
|
||||
) -> List["SqlaTable"]:
|
||||
) -> List[SqlaTable]:
|
||||
# TODO(hughhhh): add unit test
|
||||
return (
|
||||
session.query(cls)
|
||||
|
|
@ -1964,7 +2024,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
@classmethod
|
||||
def get_eager_sqlatable_datasource(
|
||||
cls, session: Session, datasource_id: int
|
||||
) -> "SqlaTable":
|
||||
) -> SqlaTable:
|
||||
"""Returns SqlaTable with columns and metrics."""
|
||||
return (
|
||||
session.query(cls)
|
||||
|
|
@ -1977,7 +2037,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
|
||||
def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
|
||||
qry = session.query(cls)
|
||||
qry = cls.default_query(qry)
|
||||
return qry.all()
|
||||
|
|
@ -2038,7 +2098,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def before_update(
|
||||
mapper: Mapper, # pylint: disable=unused-argument
|
||||
connection: Connection, # pylint: disable=unused-argument
|
||||
target: "SqlaTable",
|
||||
target: SqlaTable,
|
||||
) -> None:
|
||||
"""
|
||||
Check before update if the target table already exists.
|
||||
|
|
@ -2110,7 +2170,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def after_insert(
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
sqla_table: "SqlaTable",
|
||||
sqla_table: SqlaTable,
|
||||
) -> None:
|
||||
"""
|
||||
Update dataset permissions after insert
|
||||
|
|
@ -2124,7 +2184,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def after_delete(
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
sqla_table: "SqlaTable",
|
||||
sqla_table: SqlaTable,
|
||||
) -> None:
|
||||
"""
|
||||
Update dataset permissions after delete
|
||||
|
|
@ -2135,7 +2195,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def after_update(
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
sqla_table: "SqlaTable",
|
||||
sqla_table: SqlaTable,
|
||||
) -> None:
|
||||
"""
|
||||
Update dataset permissions
|
||||
|
|
@ -2170,7 +2230,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
return
|
||||
|
||||
def write_shadow_dataset(
|
||||
self: "SqlaTable",
|
||||
self: SqlaTable,
|
||||
) -> None:
|
||||
"""
|
||||
This method is deprecated
|
||||
|
|
|
|||
|
|
@ -201,22 +201,34 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
"granularity": None,
|
||||
"from_dttm": None,
|
||||
"to_dttm": None,
|
||||
"groupby": ["user", "expr"],
|
||||
"columns": [
|
||||
"user",
|
||||
"expr",
|
||||
{
|
||||
"hasCustomLabel": True,
|
||||
"label": "adhoc_column",
|
||||
"sqlExpression": "'{{ 'foo_' + time_grain }}'",
|
||||
},
|
||||
],
|
||||
"metrics": [
|
||||
{
|
||||
"hasCustomLabel": True,
|
||||
"label": "adhoc_metric",
|
||||
"expressionType": AdhocMetricExpressionType.SQL,
|
||||
"sqlExpression": "SUM(case when user = '{{ current_username() }}' "
|
||||
"then 1 else 0 end)",
|
||||
"label": "SUM(userid)",
|
||||
}
|
||||
"sqlExpression": "SUM(case when user = '{{ 'user_' + "
|
||||
"current_username() }}' then 1 else 0 end)",
|
||||
},
|
||||
"count_timegrain",
|
||||
],
|
||||
"is_timeseries": False,
|
||||
"filter": [],
|
||||
"extras": {"time_grain_sqla": "P1D"},
|
||||
}
|
||||
|
||||
table = SqlaTable(
|
||||
table_name="test_has_jinja_metric_and_expr",
|
||||
sql="SELECT '{{ current_username() }}' as user",
|
||||
sql="SELECT '{{ 'user_' + current_username() }}' as user, "
|
||||
"'{{ 'xyz_' + time_grain }}' as time_grain",
|
||||
database=get_example_database(),
|
||||
)
|
||||
TableColumn(
|
||||
|
|
@ -226,14 +238,25 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
type="VARCHAR(100)",
|
||||
table=table,
|
||||
)
|
||||
SqlMetric(
|
||||
metric_name="count_timegrain",
|
||||
expression="count('{{ 'bar_' + time_grain }}')",
|
||||
table=table,
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
sqla_query = table.get_sqla_query(**base_query_obj)
|
||||
query = table.database.compile_sqla_query(sqla_query.sqla_query)
|
||||
# assert expression
|
||||
assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
|
||||
# assert metric
|
||||
assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
|
||||
# assert virtual dataset
|
||||
assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
|
||||
# assert dataset calculated column
|
||||
assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
|
||||
# assert adhoc column
|
||||
assert "'foo_P1D'" in query
|
||||
# assert dataset saved metric
|
||||
assert "count('bar_P1D')" in query
|
||||
# assert adhoc metric
|
||||
assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
|
||||
# Cleanup
|
||||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
|
|
|
|||
Loading…
Reference in New Issue