diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index f178db458..de9f4d1fd 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,6 +18,7 @@ from collections import namedtuple, OrderedDict from datetime import datetime import logging +from typing import Optional, Union from flask import escape, Markup from flask_appbuilder import Model @@ -32,11 +33,12 @@ from sqlalchemy.exc import CompileError from sqlalchemy.orm import backref, relationship from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, literal_column, table, text -from sqlalchemy.sql.expression import TextAsFrom +from sqlalchemy.sql.expression import Label, TextAsFrom import sqlparse from superset import app, db, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric +from superset.db_engine_specs import TimestampExpression from superset.jinja_context import get_template_processor from superset.models.annotations import Annotation from superset.models.core import Database @@ -140,8 +142,14 @@ class TableColumn(Model, BaseColumn): l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc))) return and_(*l) - def get_timestamp_expression(self, time_grain): - """Getting the time component of the query""" + def get_timestamp_expression(self, time_grain: Optional[str]) \ + -> Union[TimestampExpression, Label]: + """ + Return a SQLAlchemy Core element representation of self to be used in a query. + + :param time_grain: Optional time grain, e.g. P1Y + :return: A TimeExpression object wrapped in a Label if supported by db + """ label = utils.DTTM_ALIAS db = self.table.database @@ -150,16 +158,12 @@ class TableColumn(Model, BaseColumn): if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) - grain = None - if time_grain: - grain = db.grains_dict().get(time_grain) - if not grain: - raise NotImplementedError( - f'No grain spec for {time_grain} for database {db.database_name}') - col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name) - expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - sqla_col = literal_column(expr, type_=DateTime) - return self.table.make_sqla_column_compatible(sqla_col, label) + if self.expression: + col = literal_column(self.expression) + else: + col = column(self.column_name) + time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain) + return self.table.make_sqla_column_compatible(time_expr, label) @classmethod def import_obj(cls, i_column): diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 04efef78b..b6103c347 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -36,19 +36,20 @@ import os import re import textwrap import time -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple from urllib import parse from flask import g from flask_babel import lazy_gettext as _ import pandas import sqlalchemy as sqla -from sqlalchemy import Column, select, types +from sqlalchemy import Column, DateTime, select, types from sqlalchemy.engine import create_engine from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.url import make_url +from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause from sqlalchemy.sql.expression import TextAsFrom @@ -90,6 +91,24 @@ builtin_time_grains = { } +class TimestampExpression(ColumnClause): + def __init__(self, expr: str, col: ColumnClause, **kwargs): + """Sqlalchemy class that can be can be used to render native column elements + respeting engine-specific quoting rules as part of a string-based expression. + + :param expr: Sql expression with '{col}' denoting the locations where the col + object will be rendered. + :param col: the target column + """ + super().__init__(expr, **kwargs) + self.col = col + + +@compiles(TimestampExpression) +def compile_timegrain_expression(element: TimestampExpression, compiler, **kw): + return element.name.replace('{col}', compiler.process(element.col, **kw)) + + def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist): ret_list = [] blacklist = blacklist if blacklist else [] @@ -112,7 +131,7 @@ class BaseEngineSpec(object): """Abstract class for database engine specific configurations""" engine = 'base' # str as defined in sqlalchemy.engine.engine - time_grain_functions: dict = {} + time_grain_functions: Dict[Optional[str], str] = {} time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -125,16 +144,31 @@ class BaseEngineSpec(object): try_remove_schema_from_table_name = True @classmethod - def get_time_expr(cls, expr, pdf, time_grain, grain): + def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], + time_grain: Optional[str]) -> TimestampExpression: + """ + Construct a TimeExpression to be used in a SQLAlchemy query. + + :param col: Target column for the TimeExpression + :param pdf: date format (seconds or milliseconds) + :param time_grain: time grain, e.g. P1Y for 1 year + :return: TimestampExpression object + """ + if time_grain: + time_expr = cls.time_grain_functions.get(time_grain) + if not time_expr: + raise NotImplementedError( + f'No grain spec for {time_grain} for database {cls.engine}') + else: + time_expr = '{col}' + # if epoch, translate to DATE using db specific conf if pdf == 'epoch_s': - expr = cls.epoch_to_dttm().format(col=expr) + time_expr = time_expr.replace('{col}', cls.epoch_to_dttm()) elif pdf == 'epoch_ms': - expr = cls.epoch_ms_to_dttm().format(col=expr) + time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm()) - if grain: - expr = grain.function.format(col=expr) - return expr + return TimestampExpression(time_expr, col, type_=DateTime) @classmethod def get_time_grains(cls): @@ -489,13 +523,6 @@ class BaseEngineSpec(object): label = label[:cls.max_column_name_length] return label - @staticmethod - def get_timestamp_column(expression, column_name): - """Return the expression if defined, otherwise return column_name. Some - engines require forcing quotes around column name, in which case this method - can be overridden.""" - return expression or column_name - class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -543,16 +570,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec): tables.extend(inspector.get_foreign_table_names(schema)) return sorted(tables) - @staticmethod - def get_timestamp_column(expression, column_name): - """Postgres is unable to identify mixed case column names unless they - are quoted.""" - if expression: - return expression - elif column_name.lower() != column_name: - return f'"{column_name}"' - return column_name - class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' @@ -794,7 +811,7 @@ class MySQLEngineSpec(BaseEngineSpec): 'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))', } - type_code_map: dict = {} # loaded from get_datatype only if needed + type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed @classmethod def convert_dttm(cls, target_type, dttm): @@ -1812,20 +1829,21 @@ class PinotEngineSpec(BaseEngineSpec): inner_joins = False supports_column_aliases = False - _time_grain_to_datetimeconvert = { + # Pinot does its own conversion below + time_grain_functions: Dict[Optional[str], str] = { 'PT1S': '1:SECONDS', 'PT1M': '1:MINUTES', 'PT1H': '1:HOURS', 'P1D': '1:DAYS', - 'P1Y': '1:YEARS', + 'P1W': '1:WEEKS', 'P1M': '1:MONTHS', + 'P0.25Y': '3:MONTHS', + 'P1Y': '1:YEARS', } - # Pinot does its own conversion below - time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()} - @classmethod - def get_time_expr(cls, expr, pdf, time_grain, grain): + def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], + time_grain: Optional[str]) -> TimestampExpression: is_epoch = pdf in ('epoch_s', 'epoch_ms') if not is_epoch: raise NotImplementedError('Pinot currently only supports epochs') @@ -1834,11 +1852,12 @@ class PinotEngineSpec(BaseEngineSpec): # We are not really converting any time units, just bucketing them. seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS' tf = f'1:{seconds_or_ms}:EPOCH' - granularity = cls._time_grain_to_datetimeconvert.get(time_grain) + granularity = cls.time_grain_functions.get(time_grain) if not granularity: raise NotImplementedError('No pinot grain spec for ' + str(time_grain)) # In pinot the output is a string since there is no timestamp column like pg - return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")' + time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")' + return TimestampExpression(time_expr, col) @classmethod def make_select_compatible(cls, groupby_exprs, select_exprs): diff --git a/superset/models/core.py b/superset/models/core.py index 047a3ddb1..b379af7ca 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -1029,21 +1029,13 @@ class Database(Model, AuditMixinNullable, ImportMixin): """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain - form a datetime (maybe the source grain is arbitrary timestamps, daily + from a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. """ return self.db_engine_spec.get_time_grains() - def grains_dict(self): - """Allowing to lookup grain by either label or duration - - For backward compatibility""" - d = {grain.duration: grain for grain in self.grains()} - d.update({grain.label: grain for grain in self.grains()}) - return d - def get_extra(self): extra = {} if self.extra: diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 0372366a2..43f89c143 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -17,15 +17,16 @@ import inspect from unittest import mock -from sqlalchemy import column, select, table -from sqlalchemy.dialects.mssql import pymssql +from sqlalchemy import column, literal_column, select, table +from sqlalchemy.dialects import mssql, oracle, postgresql from sqlalchemy.engine.result import RowProxy from sqlalchemy.types import String, UnicodeText from superset import db_engine_specs from superset.db_engine_specs import ( BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec, - MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec, + MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec, + PrestoEngineSpec, ) from superset.models.core import Database from .base_tests import SupersetTestCase @@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase): assert_type('NTEXT', UnicodeText) def test_mssql_where_clause_n_prefix(self): - dialect = pymssql.dialect() + dialect = mssql.dialect() spec = MssqlEngineSpec str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)')) unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT')) @@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase): where(unicode_col == 'abc') query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True})) - query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa + query_expected = 'SELECT col, unicode_col \n' \ + 'FROM tbl \n' \ + "WHERE col = 'abc' AND unicode_col = N'abc'" self.assertEqual(query, query_expected) def test_get_table_names(self): @@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase): pg_result = db_engine_specs.PostgresEngineSpec.get_table_names( schema='schema', inspector=inspector) self.assertListEqual(pg_result_expected, pg_result) + + def test_pg_time_expression_literal_no_grain(self): + col = literal_column('COALESCE(a, b)') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, 'COALESCE(a, b)') + + def test_pg_time_expression_literal_1y_grain(self): + col = literal_column('COALESCE(a, b)') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))") + + def test_pg_time_expression_lower_column_no_grain(self): + col = column('lower_case') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, 'lower_case') + + def test_pg_time_expression_lower_case_column_sec_1y_grain(self): + col = column('lower_case') + expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y') + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))") # noqa + + def test_pg_time_expression_mixed_case_column_1y_grain(self): + col = column('MixedCase') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")") + + def test_mssql_time_expression_mixed_case_column_1y_grain(self): + col = column('MixedCase') + expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=mssql.dialect())) + self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)') + + def test_oracle_time_expression_reserved_keyword_1m_grain(self): + col = column('decimal') + expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M') + result = str(expr.compile(dialect=oracle.dialect())) + self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')") + + def test_pinot_time_expression_sec_1m_grain(self): + col = column('tstamp') + expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M') + result = str(expr.compile()) + self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa diff --git a/tests/model_tests.py b/tests/model_tests.py index 0fe03de93..53e53cc5f 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -109,47 +109,6 @@ class DatabaseModelTestCase(SupersetTestCase): LIMIT 100""") assert sql.startswith(expected) - def test_grains_dict(self): - uri = 'mysql://root@localhost' - database = Database(sqlalchemy_uri=uri) - d = database.grains_dict() - self.assertEquals(d.get('day').function, 'DATE({col})') - self.assertEquals(d.get('P1D').function, 'DATE({col})') - self.assertEquals(d.get('Time Column').function, '{col}') - - def test_postgres_expression_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', '' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', expression) - self.assertEqual(grain_expr, grain_expr_expected) - - def test_postgres_lowercase_col_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = '', 'lowercase_col' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', column_name) - self.assertEqual(grain_expr, grain_expr_expected) - - def test_postgres_mixedcase_col_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = '', 'MixedCaseCol' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"') - self.assertEqual(grain_expr, grain_expr_expected) - def test_single_statement(self): main_db = get_main_database(db.session) @@ -217,24 +176,6 @@ class SqlaTableModelTestCase(SupersetTestCase): self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') ds_col.expression = prev_ds_expr - def test_get_timestamp_expression_backward(self): - tbl = self.get_table_by_name('birth_names') - ds_col = tbl.get_column('ds') - - ds_col.expression = None - ds_col.python_date_format = None - sqla_literal = ds_col.get_timestamp_expression('day') - compiled = '{}'.format(sqla_literal.compile()) - if tbl.database.backend == 'mysql': - self.assertEquals(compiled, 'DATE(ds)') - - ds_col.expression = None - ds_col.python_date_format = None - sqla_literal = ds_col.get_timestamp_expression('Time Column') - compiled = '{}'.format(sqla_literal.compile()) - if tbl.database.backend == 'mysql': - self.assertEquals(compiled, 'ds') - def query_with_expr_helper(self, is_timeseries, inner_join=True): tbl = self.get_table_by_name('birth_names') ds_col = tbl.get_column('ds')