diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a67e686ff..fd5942c51 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=too-many-lines +from __future__ import annotations + import dataclasses import json import logging @@ -222,7 +224,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin): __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table: "SqlaTable" = relationship( + table: SqlaTable = relationship( "SqlaTable", backref=backref("columns", cascade="all, delete-orphan"), foreign_keys=[table_id], @@ -301,14 +303,18 @@ class TableColumn(Model, BaseColumn, CertificationMixin): ) return column_spec.generic_type if column_spec else None - def get_sqla_col(self, label: Optional[str] = None) -> Column: + def get_sqla_col( + self, + label: Optional[str] = None, + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> Column: label = label or self.column_name db_engine_spec = self.db_engine_spec column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) type_ = column_spec.sqla_type if column_spec else None - if self.expression: - tp = self.table.get_template_processor() - expression = tp.process_template(self.expression) + if expression := self.expression: + if template_processor: + expression = template_processor.process_template(expression) col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) @@ -324,8 +330,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin): start_dttm: Optional[DateTime] = None, end_dttm: Optional[DateTime] = None, label: Optional[str] = "__time", + template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - col = self.get_sqla_col(label=label) + col = self.get_sqla_col(label=label, template_processor=template_processor) l = [] if start_dttm: l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm))) @@ -358,10 +365,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin): if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=type_) return self.table.make_sqla_column_compatible(sqla_col, label) - if self.expression: - expression = self.expression + if expression := self.expression: if template_processor: - expression = template_processor.process_template(self.expression) + expression = template_processor.process_template(expression) col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) @@ -458,10 +464,17 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): def __repr__(self) -> str: return str(self.metric_name) - def get_sqla_col(self, label: Optional[str] = None) -> Column: + def get_sqla_col( + self, + label: Optional[str] = None, + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> Column: label = label or self.metric_name - tp = self.table.get_template_processor() - sqla_col: ColumnClause = literal_column(tp.process_template(self.expression)) + expression = self.expression + if template_processor: + expression = template_processor.process_template(expression) + + sqla_col: ColumnClause = literal_column(expression) return self.table.make_sqla_column_compatible(sqla_col, label) @property @@ -650,7 +663,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho datasource_name: str, schema: Optional[str], database_name: str, - ) -> Optional["SqlaTable"]: + ) -> Optional[SqlaTable]: schema = schema or None query = ( session.query(cls) @@ -778,10 +791,17 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho except (TypeError, json.JSONDecodeError): return {} - def get_fetch_values_predicate(self) -> TextClause: - tp = self.get_template_processor() + def get_fetch_values_predicate( + self, + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> TextClause: + fetch_values_predicate = self.fetch_values_predicate + if template_processor: + fetch_values_predicate = template_processor.process_template( + fetch_values_predicate + ) try: - return self.text(tp.process_template(self.fetch_values_predicate)) + return self.text(fetch_values_predicate) except TemplateError as ex: raise QueryObjectValidationError( _( @@ -799,12 +819,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho tp = self.get_template_processor() tbl, cte = self.get_from_clause(tp) - qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct() + qry = ( + select([target_col.get_sqla_col(template_processor=tp)]) + .select_from(tbl) + .distinct() + ) if limit: qry = qry.limit(limit) if self.fetch_values_predicate: - qry = qry.where(self.get_fetch_values_predicate()) + qry = qry.where(self.get_fetch_values_predicate(template_processor=tp)) with self.database.get_sqla_engine_with_context() as engine: sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) @@ -936,7 +960,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho column_name = cast(str, metric_column.get("column_name")) table_column: Optional[TableColumn] = columns_by_name.get(column_name) if table_column: - sqla_column = table_column.get_sqla_col() + sqla_column = table_column.get_sqla_col( + template_processor=template_processor + ) else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) @@ -975,7 +1001,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) col_in_metadata = self.get_column(expression) if col_in_metadata: - sqla_column = col_in_metadata.get_sqla_col() + sqla_column = col_in_metadata.get_sqla_col( + template_processor=template_processor + ) is_dttm = col_in_metadata.is_temporal else: sqla_column = literal_column(expression) @@ -1190,7 +1218,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) ) elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) + metrics_exprs.append( + metrics_by_name[metric].get_sqla_col( + template_processor=template_processor + ) + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1229,12 +1261,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - col = columns_by_name[col].get_sqla_col() + col = columns_by_name[col].get_sqla_col( + template_processor=template_processor + ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col() + col = metrics_by_name[col].get_sqla_col( + template_processor=template_processor + ) need_groupby = True if isinstance(col, ColumnElement): @@ -1268,7 +1304,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) # if groupby field equals a selected column elif selected in columns_by_name: - outer = columns_by_name[selected].get_sqla_col() + outer = columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) else: selected = validate_adhoc_subquery( selected, @@ -1302,7 +1340,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho self.schema, ) select_exprs.append( - columns_by_name[selected].get_sqla_col() + columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) if isinstance(selected, str) and selected in columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label @@ -1336,11 +1376,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ): time_filters.append( columns_by_name[self.main_dttm_col].get_time_filter( - from_dttm, - to_dttm, + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, ) ) - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + time_filters.append( + dttm_col.get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use @@ -1396,7 +1443,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho time_grain=filter_grain, template_processor=template_processor ) elif col_obj: - sqla_col = col_obj.get_sqla_col() + sqla_col = col_obj.get_sqla_col( + template_processor=template_processor + ) col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, @@ -1521,6 +1570,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho start_dttm=_since, end_dttm=_until, label=sqla_col.key, + template_processor=template_processor, ) ) else: @@ -1565,7 +1615,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho having_clause_and += [self.text(having)] if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where(self.get_fetch_values_predicate()) + qry = qry.where( + self.get_fetch_values_predicate(template_processor=template_processor) + ) if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -1617,8 +1669,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho if dttm_col and not db_engine_spec.time_groupby_inline: inner_time_filter = [ dttm_col.get_time_filter( - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, + start_dttm=inner_from_dttm or from_dttm, + end_dttm=inner_to_dttm or to_dttm, + template_processor=template_processor, ) ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) @@ -1627,7 +1680,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ob = inner_main_metric_expr if series_limit_metric: ob = self._get_series_orderby( - series_limit_metric, metrics_by_name, columns_by_name + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, ) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) @@ -1647,9 +1703,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho orderby = [ ( self._get_series_orderby( - series_limit_metric, - metrics_by_name, - columns_by_name, + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, ), not order_desc, ) @@ -1709,6 +1766,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho series_limit_metric: Metric, metrics_by_name: Dict[str, SqlMetric], columns_by_name: Dict[str, TableColumn], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) @@ -1717,7 +1775,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col() + ob = metrics_by_name[series_limit_metric].get_sqla_col( + template_processor=template_processor + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) @@ -1930,7 +1990,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho database: Database, datasource_name: str, schema: Optional[str] = None, - ) -> List["SqlaTable"]: + ) -> List[SqlaTable]: query = ( session.query(cls) .filter_by(database_id=database.id) @@ -1947,7 +2007,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho database: Database, permissions: Set[str], schema_perms: Set[str], - ) -> List["SqlaTable"]: + ) -> List[SqlaTable]: # TODO(hughhhh): add unit test return ( session.query(cls) @@ -1964,7 +2024,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho @classmethod def get_eager_sqlatable_datasource( cls, session: Session, datasource_id: int - ) -> "SqlaTable": + ) -> SqlaTable: """Returns SqlaTable with columns and metrics.""" return ( session.query(cls) @@ -1977,7 +2037,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) @classmethod - def get_all_datasources(cls, session: Session) -> List["SqlaTable"]: + def get_all_datasources(cls, session: Session) -> List[SqlaTable]: qry = session.query(cls) qry = cls.default_query(qry) return qry.all() @@ -2038,7 +2098,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def before_update( mapper: Mapper, # pylint: disable=unused-argument connection: Connection, # pylint: disable=unused-argument - target: "SqlaTable", + target: SqlaTable, ) -> None: """ Check before update if the target table already exists. @@ -2110,7 +2170,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def after_insert( mapper: Mapper, connection: Connection, - sqla_table: "SqlaTable", + sqla_table: SqlaTable, ) -> None: """ Update dataset permissions after insert @@ -2124,7 +2184,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def after_delete( mapper: Mapper, connection: Connection, - sqla_table: "SqlaTable", + sqla_table: SqlaTable, ) -> None: """ Update dataset permissions after delete @@ -2135,7 +2195,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def after_update( mapper: Mapper, connection: Connection, - sqla_table: "SqlaTable", + sqla_table: SqlaTable, ) -> None: """ Update dataset permissions @@ -2170,7 +2230,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho return def write_shadow_dataset( - self: "SqlaTable", + self: SqlaTable, ) -> None: """ This method is deprecated diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 3088bdfb0..dfba16179 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -201,22 +201,34 @@ class TestDatabaseModel(SupersetTestCase): "granularity": None, "from_dttm": None, "to_dttm": None, - "groupby": ["user", "expr"], + "columns": [ + "user", + "expr", + { + "hasCustomLabel": True, + "label": "adhoc_column", + "sqlExpression": "'{{ 'foo_' + time_grain }}'", + }, + ], "metrics": [ { + "hasCustomLabel": True, + "label": "adhoc_metric", "expressionType": AdhocMetricExpressionType.SQL, - "sqlExpression": "SUM(case when user = '{{ current_username() }}' " - "then 1 else 0 end)", - "label": "SUM(userid)", - } + "sqlExpression": "SUM(case when user = '{{ 'user_' + " + "current_username() }}' then 1 else 0 end)", + }, + "count_timegrain", ], "is_timeseries": False, "filter": [], + "extras": {"time_grain_sqla": "P1D"}, } table = SqlaTable( table_name="test_has_jinja_metric_and_expr", - sql="SELECT '{{ current_username() }}' as user", + sql="SELECT '{{ 'user_' + current_username() }}' as user, " + "'{{ 'xyz_' + time_grain }}' as time_grain", database=get_example_database(), ) TableColumn( @@ -226,14 +238,25 @@ class TestDatabaseModel(SupersetTestCase): type="VARCHAR(100)", table=table, ) + SqlMetric( + metric_name="count_timegrain", + expression="count('{{ 'bar_' + time_grain }}')", + table=table, + ) db.session.commit() sqla_query = table.get_sqla_query(**base_query_obj) query = table.database.compile_sqla_query(sqla_query.sqla_query) - # assert expression - assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query - # assert metric - assert "SUM(case when user = 'abc' then 1 else 0 end)" in query + # assert virtual dataset + assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query + # assert dataset calculated column + assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query + # assert adhoc column + assert "'foo_P1D'" in query + # assert dataset saved metric + assert "count('bar_P1D')" in query + # assert adhoc metric + assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query # Cleanup db.session.delete(table) db.session.commit()