fix(sqla): use same template processor in all methods (#22280)

This commit is contained in:
Ville Brofeldt 2022-12-03 06:19:25 +02:00 committed by GitHub
parent 1c20206057
commit 1ad5147016
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 139 additions and 56 deletions

View File

@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations
import dataclasses
import json
import logging
@ -222,7 +224,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table: "SqlaTable" = relationship(
table: SqlaTable = relationship(
"SqlaTable",
backref=backref("columns", cascade="all, delete-orphan"),
foreign_keys=[table_id],
@ -301,14 +303,18 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
)
return column_spec.generic_type if column_spec else None
def get_sqla_col(self, label: Optional[str] = None) -> Column:
def get_sqla_col(
self,
label: Optional[str] = None,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
type_ = column_spec.sqla_type if column_spec else None
if self.expression:
tp = self.table.get_template_processor()
expression = tp.process_template(self.expression)
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
@ -324,8 +330,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
start_dttm: Optional[DateTime] = None,
end_dttm: Optional[DateTime] = None,
label: Optional[str] = "__time",
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
col = self.get_sqla_col(label=label)
col = self.get_sqla_col(label=label, template_processor=template_processor)
l = []
if start_dttm:
l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
@ -358,10 +365,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=type_)
return self.table.make_sqla_column_compatible(sqla_col, label)
if self.expression:
expression = self.expression
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(self.expression)
expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
@ -458,10 +464,17 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
def __repr__(self) -> str:
    """Return this metric's ``metric_name`` coerced to ``str``."""
    name = self.metric_name
    return str(name)
def get_sqla_col(
    self,
    label: Optional[str] = None,
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
    """
    Build the SQLAlchemy column for this metric's SQL expression.

    :param label: Label to give the resulting column; falls back to
        ``metric_name`` when falsy.
    :param template_processor: Optional processor whose ``process_template``
        is applied to the metric expression. When omitted, the expression is
        used verbatim and no template processor is created here.
    :returns: The value returned by ``self.table.make_sqla_column_compatible``
        for the literal column and the chosen label.
    """
    label = label or self.metric_name
    expression = self.expression
    # Only render the template when the caller supplies a processor, so one
    # processor instance can be shared by every method taking part in a query.
    if template_processor:
        expression = template_processor.process_template(expression)
    sqla_col: ColumnClause = literal_column(expression)
    return self.table.make_sqla_column_compatible(sqla_col, label)
@property
@ -650,7 +663,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
datasource_name: str,
schema: Optional[str],
database_name: str,
) -> Optional["SqlaTable"]:
) -> Optional[SqlaTable]:
schema = schema or None
query = (
session.query(cls)
@ -778,10 +791,17 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
except (TypeError, json.JSONDecodeError):
return {}
def get_fetch_values_predicate(self) -> TextClause:
tp = self.get_template_processor()
def get_fetch_values_predicate(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> TextClause:
fetch_values_predicate = self.fetch_values_predicate
if template_processor:
fetch_values_predicate = template_processor.process_template(
fetch_values_predicate
)
try:
return self.text(tp.process_template(self.fetch_values_predicate))
return self.text(fetch_values_predicate)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@ -799,12 +819,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
qry = (
select([target_col.get_sqla_col(template_processor=tp)])
.select_from(tbl)
.distinct()
)
if limit:
qry = qry.limit(limit)
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
@ -936,7 +960,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
column_name = cast(str, metric_column.get("column_name"))
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col()
sqla_column = table_column.get_sqla_col(
template_processor=template_processor
)
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
@ -975,7 +1001,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
col_in_metadata = self.get_column(expression)
if col_in_metadata:
sqla_column = col_in_metadata.get_sqla_col()
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
)
is_dttm = col_in_metadata.is_temporal
else:
sqla_column = literal_column(expression)
@ -1190,7 +1218,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
)
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
metrics_exprs.append(
metrics_by_name[metric].get_sqla_col(
template_processor=template_processor
)
)
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
@ -1229,12 +1261,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
col = columns_by_name[col].get_sqla_col()
col = columns_by_name[col].get_sqla_col(
template_processor=template_processor
)
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
elif col in metrics_by_name:
col = metrics_by_name[col].get_sqla_col()
col = metrics_by_name[col].get_sqla_col(
template_processor=template_processor
)
need_groupby = True
if isinstance(col, ColumnElement):
@ -1268,7 +1304,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
# if groupby field equals a selected column
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
outer = columns_by_name[selected].get_sqla_col(
template_processor=template_processor
)
else:
selected = validate_adhoc_subquery(
selected,
@ -1302,7 +1340,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
self.schema,
)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
columns_by_name[selected].get_sqla_col(
template_processor=template_processor
)
if isinstance(selected, str) and selected in columns_by_name
else self.make_sqla_column_compatible(
literal_column(selected), _column_label
@ -1336,11 +1376,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
):
time_filters.append(
columns_by_name[self.main_dttm_col].get_time_filter(
from_dttm,
to_dttm,
start_dttm=from_dttm,
end_dttm=to_dttm,
template_processor=template_processor,
)
)
time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
time_filters.append(
dttm_col.get_time_filter(
start_dttm=from_dttm,
end_dttm=to_dttm,
template_processor=template_processor,
)
)
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
@ -1396,7 +1443,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
time_grain=filter_grain, template_processor=template_processor
)
elif col_obj:
sqla_col = col_obj.get_sqla_col()
sqla_col = col_obj.get_sqla_col(
template_processor=template_processor
)
col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(
native_type=col_type,
@ -1521,6 +1570,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
start_dttm=_since,
end_dttm=_until,
label=sqla_col.key,
template_processor=template_processor,
)
)
else:
@ -1565,7 +1615,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
qry = qry.where(
self.get_fetch_values_predicate(template_processor=template_processor)
)
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and)))
else:
@ -1617,8 +1669,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if dttm_col and not db_engine_spec.time_groupby_inline:
inner_time_filter = [
dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
start_dttm=inner_from_dttm or from_dttm,
end_dttm=inner_to_dttm or to_dttm,
template_processor=template_processor,
)
]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
@ -1627,7 +1680,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
ob = inner_main_metric_expr
if series_limit_metric:
ob = self._get_series_orderby(
series_limit_metric, metrics_by_name, columns_by_name
series_limit_metric=series_limit_metric,
metrics_by_name=metrics_by_name,
columns_by_name=columns_by_name,
template_processor=template_processor,
)
direction = desc if order_desc else asc
subq = subq.order_by(direction(ob))
@ -1647,9 +1703,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
orderby = [
(
self._get_series_orderby(
series_limit_metric,
metrics_by_name,
columns_by_name,
series_limit_metric=series_limit_metric,
metrics_by_name=metrics_by_name,
columns_by_name=columns_by_name,
template_processor=template_processor,
),
not order_desc,
)
@ -1709,6 +1766,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
series_limit_metric: Metric,
metrics_by_name: Dict[str, SqlMetric],
columns_by_name: Dict[str, TableColumn],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
@ -1717,7 +1775,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
isinstance(series_limit_metric, str)
and series_limit_metric in metrics_by_name
):
ob = metrics_by_name[series_limit_metric].get_sqla_col()
ob = metrics_by_name[series_limit_metric].get_sqla_col(
template_processor=template_processor
)
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
@ -1930,7 +1990,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List["SqlaTable"]:
) -> List[SqlaTable]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@ -1947,7 +2007,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List["SqlaTable"]:
) -> List[SqlaTable]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
@ -1964,7 +2024,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
) -> "SqlaTable":
) -> SqlaTable:
"""Returns SqlaTable with columns and metrics."""
return (
session.query(cls)
@ -1977,7 +2037,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
@classmethod
def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@ -2038,7 +2098,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def before_update(
mapper: Mapper, # pylint: disable=unused-argument
connection: Connection, # pylint: disable=unused-argument
target: "SqlaTable",
target: SqlaTable,
) -> None:
"""
Check before update if the target table already exists.
@ -2110,7 +2170,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_insert(
mapper: Mapper,
connection: Connection,
sqla_table: "SqlaTable",
sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions after insert
@ -2124,7 +2184,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_delete(
mapper: Mapper,
connection: Connection,
sqla_table: "SqlaTable",
sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions after delete
@ -2135,7 +2195,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_update(
mapper: Mapper,
connection: Connection,
sqla_table: "SqlaTable",
sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions
@ -2170,7 +2230,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return
def write_shadow_dataset(
self: "SqlaTable",
self: SqlaTable,
) -> None:
"""
This method is deprecated

View File

@ -201,22 +201,34 @@ class TestDatabaseModel(SupersetTestCase):
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user", "expr"],
"columns": [
"user",
"expr",
{
"hasCustomLabel": True,
"label": "adhoc_column",
"sqlExpression": "'{{ 'foo_' + time_grain }}'",
},
],
"metrics": [
{
"hasCustomLabel": True,
"label": "adhoc_metric",
"expressionType": AdhocMetricExpressionType.SQL,
"sqlExpression": "SUM(case when user = '{{ current_username() }}' "
"then 1 else 0 end)",
"label": "SUM(userid)",
}
"sqlExpression": "SUM(case when user = '{{ 'user_' + "
"current_username() }}' then 1 else 0 end)",
},
"count_timegrain",
],
"is_timeseries": False,
"filter": [],
"extras": {"time_grain_sqla": "P1D"},
}
table = SqlaTable(
table_name="test_has_jinja_metric_and_expr",
sql="SELECT '{{ current_username() }}' as user",
sql="SELECT '{{ 'user_' + current_username() }}' as user, "
"'{{ 'xyz_' + time_grain }}' as time_grain",
database=get_example_database(),
)
TableColumn(
@ -226,14 +238,25 @@ class TestDatabaseModel(SupersetTestCase):
type="VARCHAR(100)",
table=table,
)
SqlMetric(
metric_name="count_timegrain",
expression="count('{{ 'bar_' + time_grain }}')",
table=table,
)
db.session.commit()
sqla_query = table.get_sqla_query(**base_query_obj)
query = table.database.compile_sqla_query(sqla_query.sqla_query)
# assert expression
assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
# assert metric
assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
# assert virtual dataset
assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
# assert dataset calculated column
assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
# assert adhoc column
assert "'foo_P1D'" in query
# assert dataset saved metric
assert "count('bar_P1D')" in query
# assert adhoc metric
assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
# Cleanup
db.session.delete(table)
db.session.commit()