Break up db_engine_specs (#7676)
* Refactor db_engine_specs into package
* Rename bigquery class and add epoch funcs
* Fix flake8 errors
* Dynamically load all engine specs
* Fix linting errors and unit tests
* Implement Snowflake epoch time funcs
* Implement Teradata epoch time func
* Fix presto datasource query and remove unused import
* Fix broken datasource query
* Add mypy ignore for false positive
* Add missing license files
* Make create_time_grains_tuple public
* Fix flake8 quote
* Fix incorrect licence header
This commit is contained in:
parent
f3e5805b92
commit
95291facff
@@ -38,7 +38,7 @@ import sqlparse
 from superset import app, db, security_manager
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
-from superset.db_engine_specs import TimestampExpression
+from superset.db_engine_specs.base import TimestampExpression
 from superset.jinja_context import get_template_processor
 from superset.models.annotations import Annotation
 from superset.models.core import Database

File diff suppressed because it is too large
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
"""Compatibility layer for different database engines

This module stores logic specific to different database engines. Things
like time-related functions that are similar but not identical, or
information about whether or not to expose certain features, and how to
expose them.

For instance, Hive/Presto support partitions and have a specific API to
list partitions. Other databases like Vertica also support partitions but
have a different API to get to them. Other databases don't support partitions
at all. The classes here use a common interface to specify all this.

The general idea is to use static classes and an inheritance scheme.
"""
from importlib import import_module
import inspect
from pathlib import Path
import pkgutil
from typing import Dict, Type

from superset.db_engine_specs.base import BaseEngineSpec

engines: Dict[str, Type[BaseEngineSpec]] = {}

for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]):  # type: ignore
    imported_module = import_module('.' + name, package=__name__)

    for i in dir(imported_module):
        attribute = getattr(imported_module, i)

        if inspect.isclass(attribute) and issubclass(attribute, BaseEngineSpec) \
                and attribute.engine != '':
            engines[attribute.engine] = attribute

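The loop above registers every imported spec class under its `engine` attribute, so callers can resolve a spec from a SQLAlchemy backend name. A minimal lookup sketch (the fallback to `BaseEngineSpec` is our assumption about typical usage, not something this diff shows):

from superset.db_engine_specs import engines
from superset.db_engine_specs.base import BaseEngineSpec

# Resolve the spec registered for the 'bigquery' backend; fall back to the
# generic base spec for backends without a dedicated subclass.
spec = engines.get('bigquery', BaseEngineSpec)
print(spec.epoch_to_dttm())  # TIMESTAMP_SECONDS({col})
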
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec


class AthenaEngineSpec(BaseEngineSpec):
    engine = 'awsathena'

    time_grain_functions = {
        None: '{col}',
        'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
        'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
        'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
        'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
        'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
        'P1W/1970-01-03T00:00:00Z': "date_add('day', 5, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
        '1969-12-28T00:00:00Z/P1W': "date_add('day', -1, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
        return ("CAST ('{}' AS TIMESTAMP)"
                .format(dttm.strftime('%Y-%m-%d %H:%M:%S')))

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @staticmethod
    def mutate_label(label):
        """
        Athena only supports lowercase column names and aliases.

        :param str label: Original label which might include uppercase letters
        :return: String that is supported by the database
        """
        return label.lower()

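Each spec expresses its grains as format strings with a `{col}` placeholder that the base layer later splices the rendered column SQL into. A self-contained sketch of that convention (`render_grain` is a hypothetical helper for illustration; the real substitution happens in `compile_timegrain_expression` in base.py):

# One entry borrowed from the Athena dict above.
time_grain_functions = {
    'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
}

def render_grain(duration: str, col_sql: str) -> str:
    # Plain string substitution of the already-quoted column expression.
    return time_grain_functions[duration].replace('{col}', col_sql)

print(render_grain('P1D', '"ds"'))
# date_trunc('day', CAST("ds" AS TIMESTAMP))
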
@@ -0,0 +1,508 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from collections import namedtuple
import hashlib
import os
import re
from typing import Dict, List, Optional, Tuple

from flask import g
from flask_babel import lazy_gettext as _
import pandas as pd
from sqlalchemy import column, DateTime, select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from werkzeug.utils import secure_filename

from superset import app, db, sql_parse
from superset.utils import core as utils

Grain = namedtuple('Grain', 'name label function duration')

QueryStatus = utils.QueryStatus
config = app.config

builtin_time_grains = {
    None: 'Time Column',
    'PT1S': 'second',
    'PT1M': 'minute',
    'PT5M': '5 minute',
    'PT10M': '10 minute',
    'PT15M': '15 minute',
    'PT0.5H': 'half hour',
    'PT1H': 'hour',
    'P1D': 'day',
    'P1W': 'week',
    'P1M': 'month',
    'P0.25Y': 'quarter',
    'P1Y': 'year',
    '1969-12-28T00:00:00Z/P1W': 'week_start_sunday',
    '1969-12-29T00:00:00Z/P1W': 'week_start_monday',
    'P1W/1970-01-03T00:00:00Z': 'week_ending_saturday',
    'P1W/1970-01-04T00:00:00Z': 'week_ending_sunday',
}


class TimestampExpression(ColumnClause):
    def __init__(self, expr: str, col: ColumnClause, **kwargs):
        """SQLAlchemy class that can be used to render native column elements
        respecting engine-specific quoting rules as part of a string-based expression.

        :param expr: SQL expression with '{col}' denoting the locations where the col
        object will be rendered.
        :param col: the target column
        """
        super().__init__(expr, **kwargs)
        self.col = col


@compiles(TimestampExpression)
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
    return element.name.replace('{col}', compiler.process(element.col, **kw))


class LimitMethod(object):
    """Enum the ways that limits can be applied"""
    FETCH_MANY = 'fetch_many'
    WRAP_SQL = 'wrap_sql'
    FORCE_LIMIT = 'force_limit'


def create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
    ret_list = []
    blacklist = blacklist if blacklist else []
    for duration, func in time_grain_functions.items():
        if duration not in blacklist:
            name = time_grains.get(duration)
            ret_list.append(Grain(name, _(name), func, duration))
    return tuple(ret_list)


class BaseEngineSpec(object):
    """Abstract class for database engine specific configurations"""

    engine = 'base'  # str as defined in sqlalchemy.engine.engine
    time_grain_functions: Dict[Optional[str], str] = {}
    time_groupby_inline = False
    limit_method = LimitMethod.FORCE_LIMIT
    time_secondary_columns = False
    inner_joins = True
    allows_subquery = True
    supports_column_aliases = True
    force_column_alias_quotes = False
    arraysize = 0
    max_column_name_length = 0
    try_remove_schema_from_table_name = True

    @classmethod
    def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
                           time_grain: Optional[str]) -> TimestampExpression:
        """
        Construct a TimestampExpression to be used in a SQLAlchemy query.

        :param col: Target column for the TimestampExpression
        :param pdf: date format (seconds or milliseconds)
        :param time_grain: time grain, e.g. P1Y for 1 year
        :return: TimestampExpression object
        """
        if time_grain:
            time_expr = cls.time_grain_functions.get(time_grain)
            if not time_expr:
                raise NotImplementedError(
                    f'No grain spec for {time_grain} for database {cls.engine}')
        else:
            time_expr = '{col}'

        # if epoch, translate to DATE using db specific conf
        if pdf == 'epoch_s':
            time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
        elif pdf == 'epoch_ms':
            time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())

        return TimestampExpression(time_expr, col, type_=DateTime)

    @classmethod
    def get_time_grains(cls):
        blacklist = config.get('TIME_GRAIN_BLACKLIST', [])
        grains = builtin_time_grains.copy()
        grains.update(config.get('TIME_GRAIN_ADDONS', {}))
        grain_functions = cls.time_grain_functions.copy()
        grain_addon_functions = config.get('TIME_GRAIN_ADDON_FUNCTIONS', {})
        grain_functions.update(grain_addon_functions.get(cls.engine, {}))
        return create_time_grains_tuple(grains, grain_functions, blacklist)

    @classmethod
    def make_select_compatible(cls, groupby_exprs, select_exprs):
        # Some databases return the group-by field in the select automatically,
        # and don't allow the group-by field to be repeated in the select list.
        return select_exprs

    @classmethod
    def fetch_data(cls, cursor, limit):
        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        if cls.limit_method == LimitMethod.FETCH_MANY:
            return cursor.fetchmany(limit)
        return cursor.fetchall()

    @classmethod
    def expand_data(cls,
                    columns: List[dict],
                    data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]:
        return columns, data, []

    @classmethod
    def alter_new_orm_column(cls, orm_col):
        """Allow altering default column attributes when first detected/added

        For instance, a special column like `__time` for Druid can be
        set to is_dttm=True. Note that this only gets called when new
        columns are detected/created"""
        pass

    @classmethod
    def epoch_to_dttm(cls):
        raise NotImplementedError()

    @classmethod
    def epoch_ms_to_dttm(cls):
        return cls.epoch_to_dttm().replace('{col}', '({col}/1000)')

    @classmethod
    def get_datatype(cls, type_code):
        if isinstance(type_code, str) and len(type_code):
            return type_code.upper()

    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        """Returns engine-specific table metadata"""
        return {}

    @classmethod
    def apply_limit_to_sql(cls, sql, limit, database):
        """Alters the SQL statement to apply a LIMIT clause"""
        if cls.limit_method == LimitMethod.WRAP_SQL:
            sql = sql.strip('\t\n ;')
            qry = (
                select('*')
                .select_from(
                    TextAsFrom(text(sql), ['*']).alias('inner_qry'),
                )
                .limit(limit)
            )
            return database.compile_sqla_query(qry)
        elif cls.limit_method == LimitMethod.FORCE_LIMIT:
            parsed_query = sql_parse.ParsedQuery(sql)
            sql = parsed_query.get_query_with_new_limit(limit)
        return sql

    @classmethod
    def get_limit_from_sql(cls, sql):
        parsed_query = sql_parse.ParsedQuery(sql)
        return parsed_query.limit

    @classmethod
    def get_query_with_new_limit(cls, sql, limit):
        parsed_query = sql_parse.ParsedQuery(sql)
        return parsed_query.get_query_with_new_limit(limit)

    @staticmethod
    def csv_to_df(**kwargs):
        kwargs['filepath_or_buffer'] = \
            config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
        kwargs['encoding'] = 'utf-8'
        kwargs['iterator'] = True
        chunks = pd.read_csv(**kwargs)
        df = pd.concat(chunk for chunk in chunks)
        return df

    @staticmethod
    def df_to_db(df, table, **kwargs):
        df.to_sql(**kwargs)
        table.user_id = g.user.id
        table.schema = kwargs['schema']
        table.fetch_metadata()
        db.session.add(table)
        db.session.commit()

    @staticmethod
    def create_table_from_csv(form, table):
        def _allowed_file(filename):
            # Only allow specific file extensions as specified in the config
            extension = os.path.splitext(filename)[1]
            return extension and extension[1:] in config['ALLOWED_EXTENSIONS']

        filename = secure_filename(form.csv_file.data.filename)
        if not _allowed_file(filename):
            raise Exception('Invalid file type selected')
        kwargs = {
            'filepath_or_buffer': filename,
            'sep': form.sep.data,
            'header': form.header.data if form.header.data else 0,
            'index_col': form.index_col.data,
            'mangle_dupe_cols': form.mangle_dupe_cols.data,
            'skipinitialspace': form.skipinitialspace.data,
            'skiprows': form.skiprows.data,
            'nrows': form.nrows.data,
            'skip_blank_lines': form.skip_blank_lines.data,
            'parse_dates': form.parse_dates.data,
            'infer_datetime_format': form.infer_datetime_format.data,
            'chunksize': 10000,
        }
        df = BaseEngineSpec.csv_to_df(**kwargs)

        df_to_db_kwargs = {
            'table': table,
            'df': df,
            'name': form.name.data,
            'con': create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
            'schema': form.schema.data,
            'if_exists': form.if_exists.data,
            'index': form.index.data,
            'index_label': form.index_label.data,
            'chunksize': 10000,
        }

        BaseEngineSpec.df_to_db(**df_to_db_kwargs)

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def get_all_datasource_names(cls, db, datasource_type: str) \
            -> List[utils.DatasourceName]:
        """Returns a list of all tables or views in database.

        :param db: Database instance
        :param datasource_type: Datasource_type can be 'table' or 'view'
        :return: List of all datasources in database or schema
        """
        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
                                          cache_timeout=db.schema_cache_timeout,
                                          force=True)
        all_datasources: List[utils.DatasourceName] = []
        for schema in schemas:
            if datasource_type == 'table':
                all_datasources += db.get_all_table_names_in_schema(
                    schema=schema, force=True,
                    cache=db.table_cache_enabled,
                    cache_timeout=db.table_cache_timeout)
            elif datasource_type == 'view':
                all_datasources += db.get_all_view_names_in_schema(
                    schema=schema, force=True,
                    cache=db.table_cache_enabled,
                    cache_timeout=db.table_cache_timeout)
            else:
                raise Exception(f'Unsupported datasource_type: {datasource_type}')
        return all_datasources

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Handle a live cursor between the execute and fetchall calls

        The flow works without this method doing anything, but it allows
        for handling the cursor and updating progress information in the
        query object"""
        pass

    @classmethod
    def extract_error_message(cls, e):
        """Extract error message for queries"""
        return utils.error_msg_from_exception(e)

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema):
        """Based on a URI and selected schema, return a new URI

        The URI here represents the URI as entered when saving the database,
        ``selected_schema`` is the schema currently active, presumably in
        the SQL Lab dropdown. Based on that, for some database engines,
        we can return a new altered URI that connects straight to the
        active schema, meaning the users won't have to prefix the object
        names by the schema name.

        Some database engines have two levels of namespacing: database and
        schema (postgres, oracle, mssql, ...).
        For those it's probably better to not alter the database
        component of the URI with the schema name, as it won't work.

        Some database drivers like presto accept '{catalog}/{schema}' in
        the database component of the URL; that can be handled here.
        """
        return uri

    @classmethod
    def patch(cls):
        pass

    @classmethod
    def get_schema_names(cls, inspector):
        return sorted(inspector.get_schema_names())

    @classmethod
    def get_table_names(cls, inspector, schema):
        tables = inspector.get_table_names(schema)
        if schema and cls.try_remove_schema_from_table_name:
            tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
        return sorted(tables)

    @classmethod
    def get_view_names(cls, inspector, schema):
        views = inspector.get_view_names(schema)
        if schema and cls.try_remove_schema_from_table_name:
            views = [re.sub(f'^{schema}\\.', '', view) for view in views]
        return sorted(views)

    @classmethod
    def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        return False

    @classmethod
    def _get_fields(cls, cols):
        return [column(c.get('name')) for c in cols]

    @classmethod
    def select_star(cls, my_db, table_name, engine, schema=None, limit=100,
                    show_cols=False, indent=True, latest_partition=True,
                    cols=None):
        fields = '*'
        cols = cols or []
        if (show_cols or latest_partition) and not cols:
            cols = my_db.get_columns(table_name, schema)

        if show_cols:
            fields = cls._get_fields(cols)
        quote = engine.dialect.identifier_preparer.quote
        if schema:
            full_table_name = quote(schema) + '.' + quote(table_name)
        else:
            full_table_name = quote(table_name)

        qry = select(fields).select_from(text(full_table_name))

        if limit:
            qry = qry.limit(limit)
        if latest_partition:
            partition_query = cls.where_latest_partition(
                table_name, schema, my_db, qry, columns=cols)
            if partition_query != False:  # noqa
                qry = partition_query
        sql = my_db.compile_sqla_query(qry)
        if indent:
            sql = sqlparse.format(sql, reindent=True)
        return sql

    @classmethod
    def modify_url_for_impersonation(cls, url, impersonate_user, username):
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        """
        if impersonate_user is not None and username is not None:
            url.username = username

    @classmethod
    def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
        """
        Return a configuration dictionary that can be merged with other configs
        that can set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        :return: Dictionary with configs required for impersonation
        """
        return {}

    @classmethod
    def execute(cls, cursor, query, **kwargs):
        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        cursor.execute(query)

    @classmethod
    def make_label_compatible(cls, label):
        """
        Conditionally mutate and/or quote a sql column/expression label. If
        force_column_alias_quotes is set to True, return the label as a
        sqlalchemy.sql.elements.quoted_name object to ensure that the select query
        and query results have same case. Otherwise return the mutated label as a
        regular string. If the maximum supported column name length is exceeded,
        generate a truncated label by calling truncate_label().
        """
        label_mutated = cls.mutate_label(label)
        if cls.max_column_name_length and len(label_mutated) > cls.max_column_name_length:
            label_mutated = cls.truncate_label(label)
        if cls.force_column_alias_quotes:
            label_mutated = quoted_name(label_mutated, True)
        return label_mutated

    @classmethod
    def get_sqla_column_type(cls, type_):
        """
        Return a sqlalchemy native column type that corresponds to the column type
        defined in the data source (optional). Needs to be overridden if column requires
        special handling (see MSSQL for example of NCHAR/NVARCHAR handling).
        """
        return None

    @staticmethod
    def mutate_label(label):
        """
        Most engines support mixed case aliases that can include numbers
        and special characters, like commas, parentheses etc. For engines that
        have restrictions on what types of aliases are supported, this method
        can be overridden to ensure that labels conform to the engine's
        limitations. Mutated labels should be deterministic (input label A always
        yields output label X) and unique (input labels A and B don't yield the same
        output label X).
        """
        return label

    @classmethod
    def truncate_label(cls, label):
        """
        In the case that a label exceeds the max length supported by the engine,
        this method is used to construct a deterministic and unique label based on
        an md5 hash.
        """
        label = hashlib.md5(label.encode('utf-8')).hexdigest()
        # truncate hash if it exceeds max length
        if cls.max_column_name_length and len(label) > cls.max_column_name_length:
            label = label[:cls.max_column_name_length]
        return label

    @classmethod
    def column_datatype_to_string(cls, sqla_column_type, dialect):
        return sqla_column_type.compile(dialect=dialect).upper()

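To make the control flow of `get_timestamp_expr` concrete, here is a dependency-free sketch of how a grain template and an epoch converter compose through string substitution (the function and its arguments are illustrative stand-ins, not Superset API):

def compose_time_expr(grain_expr: str, pdf: str, epoch_expr: str) -> str:
    # Same order as get_timestamp_expr: the grain template wraps the epoch
    # conversion, which in turn wraps the raw column placeholder.
    time_expr = grain_expr
    if pdf == 'epoch_s':
        time_expr = time_expr.replace('{col}', epoch_expr)
    elif pdf == 'epoch_ms':
        # Mirrors epoch_ms_to_dttm: scale to seconds, then reuse the converter.
        time_expr = time_expr.replace(
            '{col}', epoch_expr.replace('{col}', '({col}/1000)'))
    return time_expr

print(compose_time_expr("date_trunc('day', {col})", 'epoch_ms',
                        'from_unixtime({col})'))
# date_trunc('day', from_unixtime(({col}/1000)))
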
@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
import hashlib
import re

from sqlalchemy import literal_column

from superset.db_engine_specs.base import BaseEngineSpec


class BigQueryEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery

    As contributed by @mxmzdlv on issue #945"""
    engine = 'bigquery'
    max_column_name_length = 128

    """
    https://www.python.org/dev/peps/pep-0249/#arraysize
    raw_connections bypass the pybigquery query execution context and deal with
    the raw DBAPI connection directly.
    If this value is not set, the default value is set to 1, as described here,
    https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor

    The default value of 5000 is derived from pybigquery:
    https://github.com/mxmzdlv/pybigquery/blob/d214bb089ca0807ca9aaa6ce4d5a01172d40264e/pybigquery/sqlalchemy_bigquery.py#L102
    """
    arraysize = 5000

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'TIMESTAMP_TRUNC({col}, SECOND)',
        'PT1M': 'TIMESTAMP_TRUNC({col}, MINUTE)',
        'PT1H': 'TIMESTAMP_TRUNC({col}, HOUR)',
        'P1D': 'TIMESTAMP_TRUNC({col}, DAY)',
        'P1W': 'TIMESTAMP_TRUNC({col}, WEEK)',
        'P1M': 'TIMESTAMP_TRUNC({col}, MONTH)',
        'P0.25Y': 'TIMESTAMP_TRUNC({col}, QUARTER)',
        'P1Y': 'TIMESTAMP_TRUNC({col}, YEAR)',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "'{}'".format(dttm.strftime('%Y-%m-%d'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def fetch_data(cls, cursor, limit):
        data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
        if data and type(data[0]).__name__ == 'Row':
            data = [r.values() for r in data]
        return data

    @staticmethod
    def mutate_label(label):
        """
        BigQuery field names should start with a letter or underscore and contain only
        alphanumeric characters. Labels that start with a number are prefixed with an
        underscore. Any unsupported characters are replaced with underscores and an
        md5 hash is added to the end of the label to avoid possible collisions.
        :param str label: the original label which might include unsupported characters
        :return: String that is supported by the database
        """
        label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()

        # if label starts with number, add underscore as first character
        label_mutated = '_' + label if re.match(r'^\d', label) else label

        # replace non-alphanumeric characters with underscores
        label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
        if label_mutated != label:
            # add md5 hash to label to avoid possible collisions
            label_mutated += label_hashed

        return label_mutated

    @classmethod
    def truncate_label(cls, label):
        """BigQuery requires column names to start with either a letter or
        underscore. To make sure this is always the case, an underscore is prefixed
        to the truncated label.
        """
        return '_' + hashlib.md5(label.encode('utf-8')).hexdigest()

    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        partitions_columns = [
            index.get('column_names', []) for index in indexes
            if index.get('name') == 'partition'
        ]
        cluster_columns = [
            index.get('column_names', []) for index in indexes
            if index.get('name') == 'clustering'
        ]
        return {
            'partitions': {
                'cols': partitions_columns,
            },
            'clustering': {
                'cols': cluster_columns,
            },
        }

    @classmethod
    def _get_fields(cls, cols):
        """
        The BigQuery dialect requires us to not use backticks in nested
        field names. Using literal_column handles that issue.
        https://docs.sqlalchemy.org/en/latest/core/tutorial.html#using-more-specific-text-with-table-literal-column-and-column
        Also explicitly specifying column names so we don't encounter duplicate
        column names in the result.
        """
        return [literal_column(c.get('name')).label(c.get('name').replace('.', '__'))
                for c in cols]

    @classmethod
    def epoch_to_dttm(cls):
        return 'TIMESTAMP_SECONDS({col})'

    @classmethod
    def epoch_ms_to_dttm(cls):
        return 'TIMESTAMP_MILLIS({col})'

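A quick demonstration of the label rules above, using a local copy of `mutate_label` (the outputs are what the regex/hash logic produces; nothing here talks to BigQuery itself):

import hashlib
import re

def mutate_label(label):
    # Same logic as BigQueryEngineSpec.mutate_label above.
    label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
    label_mutated = '_' + label if re.match(r'^\d', label) else label
    label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
    if label_mutated != label:
        label_mutated += label_hashed
    return label_mutated

print(mutate_label('total'))      # unchanged: total
print(mutate_label('count(*)'))   # count_ plus an _<md5> suffix
print(mutate_label('3pm_sales'))  # _3pm_sales plus an _<md5> suffix
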
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec


class ClickHouseEngineSpec(BaseEngineSpec):
    """Dialect for ClickHouse analytical DB."""

    engine = 'clickhouse'

    time_secondary_columns = True
    time_groupby_inline = True

    time_grain_functions = {
        None: '{col}',
        'PT1M': 'toStartOfMinute(toDateTime({col}))',
        'PT5M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)',
        'PT10M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)',
        'PT15M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)',
        'PT0.5H': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)',
        'PT1H': 'toStartOfHour(toDateTime({col}))',
        'P1D': 'toStartOfDay(toDateTime({col}))',
        'P1W': 'toMonday(toDateTime({col}))',
        'P1M': 'toStartOfMonth(toDateTime({col}))',
        'P0.25Y': 'toStartOfQuarter(toDateTime({col}))',
        'P1Y': 'toStartOfYear(toDateTime({col}))',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "toDate('{}')".format(dttm.strftime('%Y-%m-%d'))
        if tt == 'DATETIME':
            return "toDateTime('{}')".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod


class Db2EngineSpec(BaseEngineSpec):
    engine = 'ibm_db_sa'
    limit_method = LimitMethod.WRAP_SQL
    force_column_alias_quotes = True
    max_column_name_length = 30

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST({col} as TIMESTAMP)'
                ' - MICROSECOND({col}) MICROSECONDS',
        'PT1M': 'CAST({col} as TIMESTAMP)'
                ' - SECOND({col}) SECONDS'
                ' - MICROSECOND({col}) MICROSECONDS',
        'PT1H': 'CAST({col} as TIMESTAMP)'
                ' - MINUTE({col}) MINUTES'
                ' - SECOND({col}) SECONDS'
                ' - MICROSECOND({col}) MICROSECONDS ',
        'P1D': 'CAST({col} as TIMESTAMP)'
               ' - HOUR({col}) HOURS'
               ' - MINUTE({col}) MINUTES'
               ' - SECOND({col}) SECONDS'
               ' - MICROSECOND({col}) MICROSECONDS',
        'P1W': '{col} - (DAYOFWEEK({col})) DAYS',
        'P1M': '{col} - (DAY({col})-1) DAYS',
        'P0.25Y': '{col} - (DAY({col})-1) DAYS'
                  ' - (MONTH({col})-1) MONTHS'
                  ' + ((QUARTER({col})-1) * 3) MONTHS',
        'P1Y': '{col} - (DAY({col})-1) DAYS'
               ' - (MONTH({col})-1) MONTHS',
    }

    @classmethod
    def epoch_to_dttm(cls):
        return "(TIMESTAMP('1970-01-01', '00:00:00') + {col} SECONDS)"

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))

@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from urllib import parse

from superset.db_engine_specs.base import BaseEngineSpec


class DrillEngineSpec(BaseEngineSpec):
    """Engine spec for Apache Drill"""
    engine = 'drill'

    time_grain_functions = {
        None: '{col}',
        'PT1S': "NEARESTDATE({col}, 'SECOND')",
        'PT1M': "NEARESTDATE({col}, 'MINUTE')",
        'PT15M': "NEARESTDATE({col}, 'QUARTER_HOUR')",
        'PT0.5H': "NEARESTDATE({col}, 'HALF_HOUR')",
        'PT1H': "NEARESTDATE({col}, 'HOUR')",
        'P1D': "NEARESTDATE({col}, 'DAY')",
        'P1W': "NEARESTDATE({col}, 'WEEK_SUNDAY')",
        'P1M': "NEARESTDATE({col}, 'MONTH')",
        'P0.25Y': "NEARESTDATE({col}, 'QUARTER')",
        'P1Y': "NEARESTDATE({col}, 'YEAR')",
    }

    # Returns an expression converting a Unix timestamp in seconds to a date;
    # TO_DATE expects milliseconds, hence the scaling by 1000 below.
    @classmethod
    def epoch_to_dttm(cls):
        return cls.epoch_ms_to_dttm().replace('{col}', '({col}*1000)')

    @classmethod
    def epoch_ms_to_dttm(cls):
        return 'TO_DATE({col})'

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        elif tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema):
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe='')
        return uri

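Drill inverts the usual relationship between the two epoch converters: `TO_DATE` takes milliseconds, so the seconds variant scales the column first. Reproducing the two string-returning methods locally shows the derived result (a sketch; the real methods live on the class above):

def epoch_ms_to_dttm() -> str:
    return 'TO_DATE({col})'

def epoch_to_dttm() -> str:
    # The seconds converter is derived by scaling the column to milliseconds.
    return epoch_ms_to_dttm().replace('{col}', '({col}*1000)')

print(epoch_to_dttm())  # TO_DATE(({col}*1000))
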
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec


class DruidEngineSpec(BaseEngineSpec):
    """Engine spec for Druid.io"""
    engine = 'druid'
    inner_joins = False
    allows_subquery = False

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'FLOOR({col} TO SECOND)',
        'PT1M': 'FLOOR({col} TO MINUTE)',
        'PT1H': 'FLOOR({col} TO HOUR)',
        'P1D': 'FLOOR({col} TO DAY)',
        'P1W': 'FLOOR({col} TO WEEK)',
        'P1M': 'FLOOR({col} TO MONTH)',
        'P0.25Y': 'FLOOR({col} TO QUARTER)',
        'P1Y': 'FLOOR({col} TO YEAR)',
    }

    @classmethod
    def alter_new_orm_column(cls, orm_col):
        if orm_col.column_name == '__time':
            orm_col.is_dttm = True

@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.sqlite import SqliteEngineSpec


class GSheetsEngineSpec(SqliteEngineSpec):
    """Engine for Google spreadsheets"""
    engine = 'gsheets'
    inner_joins = False
    allows_subquery = False

@@ -0,0 +1,368 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
import logging
import os
import re
import time
from typing import List
from urllib import parse

from sqlalchemy import Column
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql.expression import ColumnClause
from werkzeug.utils import secure_filename

from superset import app, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.utils import core as utils

QueryStatus = utils.QueryStatus
config = app.config

tracking_url_trans = conf.get('TRACKING_URL_TRANSFORMER')
hive_poll_interval = conf.get('HIVE_POLL_INTERVAL')


class HiveEngineSpec(PrestoEngineSpec):
    """Reuses PrestoEngineSpec functionality."""

    engine = 'hive'
    max_column_name_length = 767

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(
        r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        '.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
        '(?P<max_jobs>[0-9]+)')
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%, reduce = 0%
    stage_progress_r = re.compile(
        r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
        r'map = (?P<map_progress>[0-9]+)%.*'
        r'reduce = (?P<reduce_progress>[0-9]+)%.*')

    @classmethod
    def patch(cls):
        from pyhive import hive  # pylint: disable=no-name-in-module
        from superset.db_engines import hive as patched_hive
        from TCLIService import (
            constants as patched_constants,
            ttypes as patched_ttypes,
            TCLIService as patched_TCLIService)

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    def get_all_datasource_names(cls, db, datasource_type: str) \
            -> List[utils.DatasourceName]:
        return BaseEngineSpec.get_all_datasource_names(db, datasource_type)

    @classmethod
    def fetch_data(cls, cursor, limit):
        import pyhive
        from TCLIService import ttypes
        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception('Query error', state.errorMessage)
        try:
            return super(HiveEngineSpec, cls).fetch_data(cursor, limit)
        except pyhive.exc.ProgrammingError:
            return []

    @staticmethod
    def create_table_from_csv(form, table):
        """Uploads a csv file and creates a superset datasource in Hive."""
        def convert_to_hive_type(col_type):
            """maps tableschema's types to hive types"""
            tableschema_to_hive_types = {
                'boolean': 'BOOLEAN',
                'integer': 'INT',
                'number': 'DOUBLE',
                'string': 'STRING',
            }
            return tableschema_to_hive_types.get(col_type, 'STRING')

        bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']

        if not bucket_path:
            logging.info('No upload bucket specified')
            raise Exception(
                'No upload bucket specified. You can specify one in the config file.')

        table_name = form.name.data
        schema_name = form.schema.data

        if config.get('UPLOADED_CSV_HIVE_NAMESPACE'):
            if '.' in table_name or schema_name:
                raise Exception(
                    "You can't specify a namespace. "
                    'All tables will be uploaded to the `{}` namespace'.format(
                        config.get('HIVE_NAMESPACE')))
            full_table_name = '{}.{}'.format(
                config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name)
        else:
            if '.' in table_name and schema_name:
                raise Exception(
                    "You can't specify a namespace both in the name of the table "
                    'and in the schema field. Please remove one')

            full_table_name = '{}.{}'.format(
                schema_name, table_name) if schema_name else table_name

        filename = form.csv_file.data.filename

        upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
        upload_path = config['UPLOAD_FOLDER'] + \
            secure_filename(filename)

        # Optional dependency
        from tableschema import Table  # pylint: disable=import-error
        hive_table_schema = Table(upload_path).infer()
        column_name_and_type = []
        for column_info in hive_table_schema['fields']:
            column_name_and_type.append(
                '`{}` {}'.format(
                    column_info['name'],
                    convert_to_hive_type(column_info['type'])))
        schema_definition = ', '.join(column_name_and_type)

        # Optional dependency
        import boto3  # pylint: disable=import-error

        s3 = boto3.client('s3')
        location = os.path.join('s3a://', bucket_path, upload_prefix, table_name)
        s3.upload_file(
            upload_path, bucket_path,
            os.path.join(upload_prefix, table_name, filename))
        sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
            TEXTFILE LOCATION '{location}'
            tblproperties ('skip.header.line.count'='1')"""
        logging.info(form.con.data)
        engine = create_engine(form.con.data.sqlalchemy_uri_decrypted)
        engine.execute(sql)

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        elif tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe='')
        return uri

    @classmethod
    def extract_error_message(cls, e):
        msg = str(e)
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg

    @classmethod
    def progress(cls, log_lines):
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                total_jobs = int(match.groupdict()['max_jobs']) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.groupdict()['job_number'])
                total_jobs = int(match.groupdict()['max_jobs']) or 1
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.groupdict()['stage_number'])
                map_progress = int(match.groupdict()['map_progress'])
                reduce_progress = int(match.groupdict()['reduce_progress'])
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logging.info(
            'Progress detail: {}, '
            'current job {}, '
            'total jobs: {}'.format(stages, current_job, total_jobs))

        stage_progress = sum(
            stages.values()) / len(stages.values()) if stages else 0

        progress = (
            100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        )
        return int(progress)

    @classmethod
    def get_tracking_url(cls, log_lines):
        lkp = 'Tracking URL = '
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Updates progress information"""
        from pyhive import hive  # pylint: disable=no-name-in-module
        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query.id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ''
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logging.info('Progress total: {}'.format(progress))
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split('/')[-2]
                        logging.info(
                            'Found the tracking url: {}'.format(tracking_url))
                        tracking_url = tracking_url_trans(tracking_url)
                        logging.info(
                            'Transformation applied: {}'.format(tracking_url))
                        query.tracking_url = tracking_url
                        logging.info('Job id: {}'.format(job_id))
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for l in log_lines[last_log_line:]:
                        logging.info('[{}] {}'.format(job_id, l))
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()

    @classmethod
    def get_columns(
            cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]:
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        try:
            col_name, value = cls.latest_partition(
                table_name, schema, database, show_first=True)
        except Exception:
            # table is not partitioned
            return False
        if value is not None:
            for c in columns:
                if c.get('name') == col_name:
                    return qry.where(Column(col_name) == value)
        return False

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        return BaseEngineSpec._get_fields(cols)

    @classmethod
    def latest_sub_partition(cls, table_name, schema, database, **kwargs):
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df):
        """Hive partitions look like ds={partition name}"""
        if not df.empty:
            return df.ix[:, 0].max().split('=')[1]

    @classmethod
    def _partition_query(
            cls, table_name, limit=0, order_by=None, filters=None):
        return f'SHOW PARTITIONS {table_name}'

    @classmethod
    def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None,
                    limit: int = 100, show_cols: bool = False, indent: bool = True,
                    latest_partition: bool = True, cols: List[dict] = []) -> str:
        return BaseEngineSpec.select_star(
            my_db, table_name, engine, schema, limit,
            show_cols, indent, latest_partition, cols)

    @classmethod
    def modify_url_for_impersonation(cls, url, impersonate_user, username):
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        """
        # Do nothing in the URL object since instead this should modify
        # the configuration dictionary. See get_configuration_for_impersonation
        pass

    @classmethod
    def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
        """
        Return a configuration dictionary that can be merged with other configs
        that can set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        :return: Dictionary with configs required for impersonation
        """
        configuration = {}
        url = make_url(uri)
        backend_name = url.get_backend_name()

        # Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS
        if (backend_name == 'hive' and 'auth' in url.query.keys() and
                impersonate_user is True and username is not None):
            configuration['hive.server2.proxy.user'] = username
        return configuration

    @staticmethod
    def execute(cursor, query, async_=False):
        kwargs = {'async': async_}
        cursor.execute(query, **kwargs)

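To see what `progress` extracts from the Hive logs, here is a standalone run of the stage-progress regex over a fabricated log line shaped like the samples in the comments above (the line itself is made up for illustration):

import re

stage_progress_r = re.compile(
    r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
    r'map = (?P<map_progress>[0-9]+)%.*'
    r'reduce = (?P<reduce_progress>[0-9]+)%.*')

line = ('17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 '
        'Stage-18 map = 40%, reduce = 20%')
m = stage_progress_r.match(line)
# A stage's progress is the average of its map and reduce progress.
print((int(m['map_progress']) + int(m['reduce_progress'])) / 2)  # 30.0
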
@ -0,0 +1,52 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=C,R,W
|
||||
from superset.db_engine_specs.base import BaseEngineSpec


class ImpalaEngineSpec(BaseEngineSpec):
    """Engine spec for Cloudera's Impala"""

    engine = 'impala'

    time_grain_functions = {
        None: '{col}',
        'PT1M': "TRUNC({col}, 'MI')",
        'PT1H': "TRUNC({col}, 'HH')",
        'P1D': "TRUNC({col}, 'DD')",
        'P1W': "TRUNC({col}, 'WW')",
        'P1M': "TRUNC({col}, 'MONTH')",
        'P0.25Y': "TRUNC({col}, 'Q')",
        'P1Y': "TRUNC({col}, 'YYYY')",
    }

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "'{}'".format(dttm.strftime('%Y-%m-%d'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def get_schema_names(cls, inspector):
        schemas = [row[0] for row in inspector.engine.execute('SHOW SCHEMAS')
                   if not row[0].startswith('_')]
        return schemas
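
# Sketch of how a time_grain_functions entry becomes SQL: the base spec
# substitutes the column expression for {col}. `order_ts` is a made-up column.
grain_sql = ImpalaEngineSpec.time_grain_functions['P1D'].format(col='order_ts')
assert grain_sql == "TRUNC(order_ts, 'DD')"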
@ -0,0 +1,48 @@
# superset/db_engine_specs/kylin.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec


class KylinEngineSpec(BaseEngineSpec):
    """Dialect for Apache Kylin"""

    engine = 'kylin'

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)',
        'PT1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)',
        'PT1H': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)',
        'P1D': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)',
        'P1W': 'CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \
              FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
        'P1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)',
        'P0.25Y': 'CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \
              FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
        'P1Y': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
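
# Deterministic examples of Kylin's literal rendering (illustration only):
from datetime import datetime
KylinEngineSpec.convert_dttm('TIMESTAMP', datetime(2019, 1, 2, 3, 4, 5))
# -> "CAST('2019-01-02 03:04:05' AS TIMESTAMP)"
KylinEngineSpec.convert_dttm('DATE', datetime(2019, 1, 2))
# -> "CAST('2019-01-02' AS DATE)"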
@ -0,0 +1,68 @@
# superset/db_engine_specs/mssql.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
import re

from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod


class MssqlEngineSpec(BaseEngineSpec):
    engine = 'mssql'
    limit_method = LimitMethod.WRAP_SQL
    max_column_name_length = 128

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')",
        'PT1M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)',
        'PT5M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)',
        'PT10M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)',
        'PT15M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)',
        'PT0.5H': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)',
        'PT1H': 'DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)',
        'P1D': 'DATEADD(day, DATEDIFF(day, 0, {col}), 0)',
        'P1W': 'DATEADD(week, DATEDIFF(week, 0, {col}), 0)',
        'P1M': 'DATEADD(month, DATEDIFF(month, 0, {col}), 0)',
        'P0.25Y': 'DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)',
        'P1Y': 'DATEADD(year, DATEDIFF(year, 0, {col}), 0)',
    }

    @classmethod
    def epoch_to_dttm(cls):
        # Defined as a classmethod (not a plain string attribute) so it matches
        # the BaseEngineSpec API that the other engine specs implement
        return "dateadd(S, {col}, '1970-01-01')"

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "CONVERT(DATETIME, '{}', 126)".format(dttm.isoformat())

    @classmethod
    def fetch_data(cls, cursor, limit):
        data = super(MssqlEngineSpec, cls).fetch_data(cursor, limit)
        if data and type(data[0]).__name__ == 'Row':
            data = [list(row) for row in data]
        return data

    column_types = [
        (String(), re.compile(r'^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)', re.IGNORECASE)),
        (UnicodeText(), re.compile(r'^N((VAR){0,1}CHAR|TEXT)', re.IGNORECASE)),
    ]

    @classmethod
    def get_sqla_column_type(cls, type_):
        for sqla_type, regex in cls.column_types:
            if regex.match(type_):
                return sqla_type
        return None
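
# Quick illustration of the column-type regexes above; the negative lookbehind
# keeps N-prefixed (Unicode) types out of the plain String bucket:
MssqlEngineSpec.get_sqla_column_type('VARCHAR(10)')   # -> String()
MssqlEngineSpec.get_sqla_column_type('NVARCHAR(10)')  # -> UnicodeText()
MssqlEngineSpec.get_sqla_column_type('INT')           # -> None (no mapping)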
@ -0,0 +1,107 @@
# superset/db_engine_specs/mysql.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from typing import Dict
from urllib import parse

from superset.db_engine_specs.base import BaseEngineSpec


class MySQLEngineSpec(BaseEngineSpec):
    engine = 'mysql'
    max_column_name_length = 64

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'DATE_ADD(DATE({col}), '
                'INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60'
                ' + SECOND({col})) SECOND)',
        'PT1M': 'DATE_ADD(DATE({col}), '
                'INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)',
        'PT1H': 'DATE_ADD(DATE({col}), '
                'INTERVAL HOUR({col}) HOUR)',
        'P1D': 'DATE({col})',
        'P1W': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFWEEK({col}) - 1 DAY))',
        'P1M': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFMONTH({col}) - 1 DAY))',
        'P0.25Y': 'MAKEDATE(YEAR({col}), 1) '
                  '+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER',
        'P1Y': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFYEAR({col}) - 1 DAY))',
        '1969-12-29T00:00:00Z/P1W': 'DATE(DATE_SUB({col}, '
                                    'INTERVAL DAYOFWEEK(DATE_SUB({col}, '
                                    'INTERVAL 1 DAY)) - 1 DAY))',
    }

    type_code_map: Dict[int, str] = {}  # loaded from get_datatype only if needed

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        if target_type.upper() in ('DATETIME', 'DATE'):
            return "STR_TO_DATE('{}', '%Y-%m-%d %H:%i:%s')".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe='')
        return uri

    @classmethod
    def get_datatype(cls, type_code):
        if not cls.type_code_map:
            # only import and store if needed at least once
            import MySQLdb
            ft = MySQLdb.constants.FIELD_TYPE
            cls.type_code_map = {
                getattr(ft, k): k
                for k in dir(ft)
                if not k.startswith('_')
            }
        datatype = type_code
        if isinstance(type_code, int):
            datatype = cls.type_code_map.get(type_code)
        if datatype and isinstance(datatype, str) and len(datatype):
            return datatype

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    def extract_error_message(cls, e):
        """Extract error message for queries"""
        message = str(e)
        try:
            if isinstance(e.args, tuple) and len(e.args) > 1:
                message = e.args[1]
        except Exception:
            pass
        return message

    @classmethod
    def column_datatype_to_string(cls, sqla_column_type, dialect):
        datatype = super().column_datatype_to_string(sqla_column_type, dialect)
        # The MySQL dialect returns a verbose datatype string such as
        # 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI'; strip the collation
        # suffix since we don't need it here
        str_cutoff = ' COLLATE '
        if str_cutoff in datatype:
            datatype = datatype.split(str_cutoff)[0]
        return datatype
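
# Schema switching sketch; relies on the mutable URL objects of
# SQLAlchemy < 1.4 (what superset used at the time), illustration only:
from sqlalchemy.engine.url import make_url
url = make_url('mysql://user:secret@localhost/superset')
MySQLEngineSpec.adjust_database_uri(url, selected_schema='analytics')
# str(url) -> 'mysql://user:secret@localhost/analytics'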
@ -0,0 +1,44 @@
# superset/db_engine_specs/oracle.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from superset.db_engine_specs.base import LimitMethod
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec


class OracleEngineSpec(PostgresBaseEngineSpec):
    engine = 'oracle'
    limit_method = LimitMethod.WRAP_SQL
    force_column_alias_quotes = True
    max_column_name_length = 30

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST({col} as DATE)',
        'PT1M': "TRUNC(CAST({col} as DATE), 'MI')",
        'PT1H': "TRUNC(CAST({col} as DATE), 'HH')",
        'P1D': "TRUNC(CAST({col} as DATE), 'DDD')",
        'P1W': "TRUNC(CAST({col} as DATE), 'WW')",
        'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')",
        'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')",
        'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')",
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return (
            """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
        ).format(dttm.isoformat())
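
# Example literal; the double-quoted "T" keeps the ISO-8601 text intact in the
# Oracle format mask (illustration only):
from datetime import datetime
OracleEngineSpec.convert_dttm('TIMESTAMP', datetime(2019, 1, 2, 3, 4, 5))
# -> TO_TIMESTAMP('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')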
@ -0,0 +1,72 @@
# superset/db_engine_specs/pinot.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from typing import Dict, Optional

from sqlalchemy.sql.expression import ColumnClause

from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression


class PinotEngineSpec(BaseEngineSpec):
    engine = 'pinot'
    allows_subquery = False
    inner_joins = False
    supports_column_aliases = False

    # Pinot does its own conversion below
    time_grain_functions: Dict[Optional[str], str] = {
        'PT1S': '1:SECONDS',
        'PT1M': '1:MINUTES',
        'PT1H': '1:HOURS',
        'P1D': '1:DAYS',
        'P1W': '1:WEEKS',
        'P1M': '1:MONTHS',
        'P0.25Y': '3:MONTHS',
        'P1Y': '1:YEARS',
    }

    @classmethod
    def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
                           time_grain: Optional[str]) -> TimestampExpression:
        is_epoch = pdf in ('epoch_s', 'epoch_ms')
        if not is_epoch:
            raise NotImplementedError('Pinot currently only supports epochs')
        # The DATETIMECONVERT pinot UDF is documented at
        # https://github.com/apache/incubator-pinot/wiki/dateTimeConvert-UDF
        # We are not really converting any time units, just bucketing them.
        seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
        tf = f'1:{seconds_or_ms}:EPOCH'
        granularity = cls.time_grain_functions.get(time_grain)
        if not granularity:
            raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
        # In pinot the output is a string since there is no timestamp column like pg
        time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
        return TimestampExpression(time_expr, col)

    @classmethod
    def make_select_compatible(cls, groupby_exprs, select_exprs):
        # Pinot does not want the group by expressions to appear in the select
        # clause. We want identity, not equality, so filter manually with `is`
        # instead of relying on `in`.
        select_sans_groupby = []
        for s in select_exprs:
            for gr in groupby_exprs:
                if s is gr:
                    break
            else:
                select_sans_groupby.append(s)
        return select_sans_groupby
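
# Bucketing sketch (assumes TimestampExpression renders {col} as the given
# column, per its definition in the base module; illustration only):
from sqlalchemy import column
expr = PinotEngineSpec.get_timestamp_expr(column('ts'), 'epoch_ms', 'P1D')
# compiles to roughly:
# DATETIMECONVERT(ts, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS:EPOCH", "1:DAYS")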
@ -0,0 +1,65 @@
# superset/db_engine_specs/postgres.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod


class PostgresBaseEngineSpec(BaseEngineSpec):
    """Abstract class for Postgres 'like' databases"""

    engine = ''

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATE_TRUNC('second', {col})",
        'PT1M': "DATE_TRUNC('minute', {col})",
        'PT1H': "DATE_TRUNC('hour', {col})",
        'P1D': "DATE_TRUNC('day', {col})",
        'P1W': "DATE_TRUNC('week', {col})",
        'P1M': "DATE_TRUNC('month', {col})",
        'P0.25Y': "DATE_TRUNC('quarter', {col})",
        'P1Y': "DATE_TRUNC('year', {col})",
    }

    @classmethod
    def fetch_data(cls, cursor, limit):
        if not cursor.description:
            return []
        if cls.limit_method == LimitMethod.FETCH_MANY:
            return cursor.fetchmany(limit)
        return cursor.fetchall()

    @classmethod
    def epoch_to_dttm(cls):
        return "(timestamp 'epoch' + {col} * interval '1 second')"

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


class PostgresEngineSpec(PostgresBaseEngineSpec):
    engine = 'postgresql'
    max_column_name_length = 63
    try_remove_schema_from_table_name = False

    @classmethod
    def get_table_names(cls, inspector, schema):
        """Need to consider foreign tables for PostgreSQL"""
        tables = inspector.get_table_names(schema)
        tables.extend(inspector.get_foreign_table_names(schema))
        return sorted(tables)
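
# The epoch template expands the same way as the time grains (illustration):
PostgresBaseEngineSpec.epoch_to_dttm().format(col='created_at')
# -> "(timestamp 'epoch' + created_at * interval '1 second')"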
@ -0,0 +1,967 @@
# superset/db_engine_specs/presto.py (new module; Apache License 2.0 header)
# pylint: disable=C,R,W
from collections import OrderedDict
import logging
import re
import textwrap
import time
from typing import List, Set, Tuple
from urllib import parse

from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ColumnClause

from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.utils import core as utils

QueryStatus = utils.QueryStatus


class PrestoEngineSpec(BaseEngineSpec):
    engine = 'presto'

    time_grain_functions = {
        None: '{col}',
        'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
        'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
        'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
        'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
        'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
        'P1W/1970-01-03T00:00:00Z':
            "date_add('day', 5, date_trunc('week', date_add('day', 1, "
            'CAST({col} AS TIMESTAMP))))',
        '1969-12-28T00:00:00Z/P1W':
            "date_add('day', -1, date_trunc('week', "
            "date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }

    @classmethod
    def get_view_names(cls, inspector, schema):
        """Returns an empty list

        get_table_names() function returns all table names and view names,
        and get_view_names() is not implemented in sqlalchemy_presto.py
        https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
        """
        return []

    @classmethod
    def _create_column_info(cls, name: str, data_type: str) -> dict:
        """
        Create column info object
        :param name: column name
        :param data_type: column data type
        :return: column info object
        """
        return {
            'name': name,
            'type': f'{data_type}',
        }

    @classmethod
    def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
        """
        Get the full column name
        :param names: list of all individual column names
        :return: full column name
        """
        return '.'.join(column[0] for column in names if column[0])

    @classmethod
    def _has_nested_data_types(cls, component_type: str) -> bool:
        """
        Check if string contains a data type. We determine if there is a data type by
        whitespace or multiple data types by commas
        :param component_type: data type
        :return: boolean
        """
        comma_regex = r',(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
        white_space_regex = r'\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
        return re.search(comma_regex, component_type) is not None \
            or re.search(white_space_regex, component_type) is not None

    @classmethod
    def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
        """
        Split data type based on given delimiter. Do not split the string if the
        delimiter is enclosed in quotes
        :param data_type: data type
        :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
               comma, whitespace)
        :return: list of strings after breaking it by the delimiter
        """
        return re.split(
            r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type)
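
    # Illustrative traces of the quote-aware splitter (not from the original
    # file); delimiters inside quotes are left alone:
    #   _split_data_type('row(a varchar, b int)', r'\(')
    #     -> ['row', 'a varchar, b int)']
    #   _split_data_type('a varchar, "my, col" int', ',')
    #     -> ['a varchar', ' "my, col" int']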

    @classmethod
    def _parse_structural_column(cls,
                                 parent_column_name: str,
                                 parent_data_type: str,
                                 result: List[dict]) -> None:
        """
        Parse a row or array column
        :param parent_column_name: name of the column being parsed
        :param parent_data_type: string representation of the column's data type
        :param result: list tracking the results
        """
        formatted_parent_column_name = parent_column_name
        # Quote the column name if there is a space
        if ' ' in parent_column_name:
            formatted_parent_column_name = f'"{parent_column_name}"'
        full_data_type = f'{formatted_parent_column_name} {parent_data_type}'
        original_result_len = len(result)
        # split on open parenthesis ( to get the structural
        # data type and its component types
        data_types = cls._split_data_type(full_data_type, r'\(')
        stack: List[Tuple[str, str]] = []
        for data_type in data_types:
            # split on closed parenthesis ) to track which component
            # types belong to what structural data type
            inner_types = cls._split_data_type(data_type, r'\)')
            for inner_type in inner_types:
                # We have finished parsing multiple structural data types
                if not inner_type and len(stack) > 0:
                    stack.pop()
                elif cls._has_nested_data_types(inner_type):
                    # split on comma , to get individual data types
                    single_fields = cls._split_data_type(inner_type, ',')
                    for single_field in single_fields:
                        single_field = single_field.strip()
                        # If component type starts with a comma, the first single field
                        # will be an empty string. Disregard this empty string.
                        if not single_field:
                            continue
                        # split on whitespace to get field name and data type
                        field_info = cls._split_data_type(single_field, r'\s')
                        # check if there is a structural data type within
                        # overall structural data type
                        if field_info[1] == 'array' or field_info[1] == 'row':
                            stack.append((field_info[0], field_info[1]))
                            full_parent_path = cls._get_full_name(stack)
                            result.append(cls._create_column_info(
                                full_parent_path,
                                presto_type_map[field_info[1]]()))
                        else:  # otherwise this field is a basic data type
                            full_parent_path = cls._get_full_name(stack)
                            column_name = '{}.{}'.format(full_parent_path, field_info[0])
                            result.append(cls._create_column_info(
                                column_name, presto_type_map[field_info[1]]()))
                    # If the component type ends with a structural data type, do not pop
                    # the stack. We have run across a structural data type within the
                    # overall structural data type. Otherwise, we have completely parsed
                    # through the entire structural data type and can move on.
                    if not (inner_type.endswith('array') or inner_type.endswith('row')):
                        stack.pop()
                # We have an array of row objects (i.e. array(row(...)))
                elif 'array' == inner_type or 'row' == inner_type:
                    # Push a dummy object to represent the structural data type
                    stack.append(('', inner_type))
                # We have an array of a basic data type (i.e. array(varchar)).
                elif len(stack) > 0:
                    # Because it is an array of a basic data type, we have finished
                    # parsing the structural data type and can move on.
                    stack.pop()
        # Unquote the column name if necessary
        if formatted_parent_column_name != parent_column_name:
            for index in range(original_result_len, len(result)):
                result[index]['name'] = result[index]['name'].replace(
                    formatted_parent_column_name, parent_column_name)

    @classmethod
    def _show_columns(
            cls, inspector: Inspector, table_name: str, schema: str) -> List[RowProxy]:
        """
        Show presto column names
        :param inspector: object that performs database schema inspection
        :param table_name: table name
        :param schema: schema name
        :return: list of column objects
        """
        quote = inspector.engine.dialect.identifier_preparer.quote_identifier
        full_table = quote(table_name)
        if schema:
            full_table = '{}.{}'.format(quote(schema), full_table)
        columns = inspector.bind.execute('SHOW COLUMNS FROM {}'.format(full_table))
        return columns

    @classmethod
    def get_columns(
            cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]:
        """
        Get columns from a Presto data source. This includes handling row and
        array data types
        :param inspector: object that performs database schema inspection
        :param table_name: table name
        :param schema: schema name
        :return: a list of results that contain column info
                 (i.e. column name and data type)
        """
        columns = cls._show_columns(inspector, table_name, schema)
        result: List[dict] = []
        for column in columns:
            try:
                # parse column if it is a row or array
                if 'array' in column.Type or 'row' in column.Type:
                    structural_column_index = len(result)
                    cls._parse_structural_column(column.Column, column.Type, result)
                    result[structural_column_index]['nullable'] = getattr(
                        column, 'Null', True)
                    result[structural_column_index]['default'] = None
                    continue
                else:  # otherwise column is a basic data type
                    column_type = presto_type_map[column.Type]()
            except KeyError:
                logging.info('Did not recognize type {} of column {}'.format(
                    column.Type, column.Column))
                column_type = types.NullType()
            column_info = cls._create_column_info(column.Column, column_type)
            column_info['nullable'] = getattr(column, 'Null', True)
            column_info['default'] = None
            result.append(column_info)
        return result

    @classmethod
    def _is_column_name_quoted(cls, column_name: str) -> bool:
        """
        Check if column name is in quotes
        :param column_name: column name
        :return: boolean
        """
        return column_name.startswith('"') and column_name.endswith('"')

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        """
        Format column clauses where names are in quotes and labels are specified
        :param cols: columns
        :return: column clauses
        """
        column_clauses = []
        # Column names are separated by periods. This regex will find periods in a
        # string if they are not enclosed in quotes because if a period is enclosed
        # in quotes, then that period is part of a column name.
        dot_pattern = r"""\.                # split on period
                          (?=               # look ahead
                          (?:               # create non-capture group
                          [^\"]*\"[^\"]*\"  # two quotes
                          )*[^\"]*$)        # end regex"""
        dot_regex = re.compile(dot_pattern, re.VERBOSE)
        for col in cols:
            # get individual column names
            col_names = re.split(dot_regex, col['name'])
            # quote each column name if it is not already quoted
            for index, col_name in enumerate(col_names):
                if not cls._is_column_name_quoted(col_name):
                    col_names[index] = '"{}"'.format(col_name)
            quoted_col_name = '.'.join(col_names)
            # create column clause in the format "name"."name" AS "name.name"
            column_clause = literal_column(quoted_col_name).label(col['name'])
            column_clauses.append(column_clause)
        return column_clauses
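
    # Illustration of the output shape (not from the original file):
    # _get_fields([{'name': 'a.b', 'type': 'VARCHAR'}]) yields a clause that
    # compiles to:  "a"."b" AS "a.b"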

    @classmethod
    def _filter_out_array_nested_cols(
            cls, cols: List[dict]) -> Tuple[List[dict], List[dict]]:
        """
        Filter out columns that correspond to array content. We know which columns to
        skip because cols is a list provided to us in a specific order where a
        structural column is positioned right before its content.

        Example: Column Name: ColA, Column Data Type: array(row(nest_obj int))
                 cols = [ ..., ColA, ColA.nest_obj, ... ]

        When we run across an array, check if subsequent column names start with the
        array name and skip them.
        :param cols: columns
        :return: filtered list of columns and list of array columns and their
                 nested fields
        """
        filtered_cols = []
        array_cols = []
        curr_array_col_name = None
        for col in cols:
            # col corresponds to an array's content and should be skipped
            if curr_array_col_name and col['name'].startswith(curr_array_col_name):
                array_cols.append(col)
                continue
            # col is an array so we need to check if subsequent
            # columns correspond to the array's contents
            elif str(col['type']) == 'ARRAY':
                curr_array_col_name = col['name']
                array_cols.append(col)
                filtered_cols.append(col)
            else:
                curr_array_col_name = None
                filtered_cols.append(col)
        return filtered_cols, array_cols

    @classmethod
    def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None,
                    limit: int = 100, show_cols: bool = False, indent: bool = True,
                    latest_partition: bool = True, cols: List[dict] = []) -> str:
        """
        Include selecting properties of row objects. We cannot easily break arrays into
        rows, so render the whole array in its own row and skip columns that correspond
        to an array's contents.
        """
        presto_cols = cols
        if show_cols:
            dot_regex = r'\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
            presto_cols = [
                col for col in presto_cols if not re.search(dot_regex, col['name'])]
        return super(PrestoEngineSpec, cls).select_star(
            my_db, table_name, engine, schema, limit,
            show_cols, indent, latest_partition, presto_cols,
        )

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        database = uri.database
        if selected_schema and database:
            selected_schema = parse.quote(selected_schema, safe='')
            if '/' in database:
                database = database.split('/')[0] + '/' + selected_schema
            else:
                database += '/' + selected_schema
            uri.database = database
        return uri

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
        if tt == 'BIGINT':
            return "to_unixtime(from_iso8601_timestamp('{}'))".format(dttm.isoformat())
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
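
    # Example literals produced by convert_dttm (illustration only):
    #   convert_dttm('DATE', datetime(2019, 1, 2))
    #     -> from_iso8601_date('2019-01-02')
    #   convert_dttm('TIMESTAMP', datetime(2019, 1, 2, 3, 4, 5))
    #     -> from_iso8601_timestamp('2019-01-02T03:04:05')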

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    def get_all_datasource_names(cls, db, datasource_type: str) \
            -> List[utils.DatasourceName]:
        datasource_df = db.get_df(
            'SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S '
            "ORDER BY concat(table_schema, '.', table_name)".format(
                datasource_type.upper(),
            ),
            None)
        datasource_names: List[utils.DatasourceName] = []
        for _, row in datasource_df.iterrows():
            datasource_names.append(utils.DatasourceName(
                schema=row['table_schema'], table=row['table_name']))
        return datasource_names

    @classmethod
    def _build_column_hierarchy(cls,
                                columns: List[dict],
                                parent_column_types: List[str],
                                column_hierarchy: dict) -> None:
        """
        Build a graph where the root node represents a column whose data type is in
        parent_column_types. A node's children represent that column's nested fields
        :param columns: list of columns
        :param parent_column_types: list of data types that decide what columns can
               be root nodes
        :param column_hierarchy: dictionary representing the graph
        """
        if len(columns) == 0:
            return
        root = columns.pop(0)
        root_info = {'type': root['type'], 'children': []}
        column_hierarchy[root['name']] = root_info
        while columns:
            column = columns[0]
            # If the column name does not start with the root's name,
            # then this column is not a nested field
            if not column['name'].startswith(f"{root['name']}."):
                break
            # If the column's data type is one of the parent types,
            # then this column may have nested fields
            if str(column['type']) in parent_column_types:
                cls._build_column_hierarchy(columns, parent_column_types,
                                            column_hierarchy)
                root_info['children'].append(column['name'])
                continue
            else:  # The column is a nested field
                root_info['children'].append(column['name'])
                columns.pop(0)
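
    # Worked example for _build_column_hierarchy (not from the original file):
    #   hierarchy = {}
    #   cols = [{'name': 'a', 'type': 'ROW'}, {'name': 'a.b', 'type': 'VARCHAR'}]
    #   _build_column_hierarchy(cols, ['ROW'], hierarchy)
    #   # hierarchy == {'a': {'type': 'ROW', 'children': ['a.b']}}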

    @classmethod
    def _create_row_and_array_hierarchy(
            cls, selected_columns: List[dict]) -> Tuple[dict, dict, List[dict]]:
        """
        Build graphs where the root node represents a row or array and its children
        are that column's nested fields
        :param selected_columns: columns selected in a query
        :return: graph representing a row, graph representing an array, and a list
                 of all the nested fields
        """
        row_column_hierarchy: OrderedDict = OrderedDict()
        array_column_hierarchy: OrderedDict = OrderedDict()
        expanded_columns: List[dict] = []
        for column in selected_columns:
            if column['type'].startswith('ROW'):
                parsed_row_columns: List[dict] = []
                cls._parse_structural_column(column['name'],
                                             column['type'].lower(),
                                             parsed_row_columns)
                expanded_columns = expanded_columns + parsed_row_columns[1:]
                filtered_row_columns, array_columns = cls._filter_out_array_nested_cols(
                    parsed_row_columns)
                cls._build_column_hierarchy(filtered_row_columns,
                                            ['ROW'],
                                            row_column_hierarchy)
                cls._build_column_hierarchy(array_columns,
                                            ['ROW', 'ARRAY'],
                                            array_column_hierarchy)
            elif column['type'].startswith('ARRAY'):
                parsed_array_columns: List[dict] = []
                cls._parse_structural_column(column['name'],
                                             column['type'].lower(),
                                             parsed_array_columns)
                expanded_columns = expanded_columns + parsed_array_columns[1:]
                cls._build_column_hierarchy(parsed_array_columns,
                                            ['ROW', 'ARRAY'],
                                            array_column_hierarchy)
        return row_column_hierarchy, array_column_hierarchy, expanded_columns

    @classmethod
    def _create_empty_row_of_data(cls, columns: List[dict]) -> dict:
        """
        Create an empty row of data
        :param columns: list of columns
        :return: dictionary representing an empty row of data
        """
        return {column['name']: '' for column in columns}

    @classmethod
    def _expand_row_data(cls, datum: dict, column: str, column_hierarchy: dict) -> None:
        """
        Separate out nested fields and their values in a row of data
        :param datum: row of data
        :param column: row column name
        :param column_hierarchy: dictionary tracking structural columns and their
               nested fields
        """
        if column in datum:
            row_data = datum[column]
            row_children = column_hierarchy[column]['children']
            if row_data and len(row_data) != len(row_children):
                raise Exception('The number of data values and number of nested '
                                'fields are not equal')
            elif row_data:
                for index, data_value in enumerate(row_data):
                    datum[row_children[index]] = data_value
            else:
                for row_child in row_children:
                    datum[row_child] = ''

    @classmethod
    def _split_array_columns_by_process_state(
            cls, array_columns: List[str],
            array_column_hierarchy: dict,
            datum: dict) -> Tuple[List[str], Set[str]]:
        """
        Take a list of array columns and split them according to whether or not we are
        ready to process them from a data set
        :param array_columns: list of array columns
        :param array_column_hierarchy: graph representing array columns
        :param datum: row of data
        :return: list of array columns ready to be processed and set of array columns
                 not ready to be processed
        """
        array_columns_to_process = []
        unprocessed_array_columns = set()
        child_array = None
        for array_column in array_columns:
            if array_column in datum:
                array_columns_to_process.append(array_column)
            elif str(array_column_hierarchy[array_column]['type']) == 'ARRAY':
                child_array = array_column
                unprocessed_array_columns.add(child_array)
            elif child_array and array_column.startswith(child_array):
                unprocessed_array_columns.add(array_column)
        return array_columns_to_process, unprocessed_array_columns

    @classmethod
    def _convert_data_list_to_array_data_dict(
            cls, data: List[dict], array_columns_to_process: List[str]) -> dict:
        """
        Pull out array data from rows of data into a dictionary where the key
        represents the index in the data list and the value is the array data values
        Example:
          data = [
              {'ColumnA': [1, 2], 'ColumnB': 3},
              {'ColumnA': [11, 22], 'ColumnB': 3}
          ]
          data dictionary = {
              0: [{'ColumnA': [1, 2]}],
              1: [{'ColumnA': [11, 22]}]
          }
        :param data: rows of data
        :param array_columns_to_process: array columns we want to pull out
        :return: data dictionary
        """
        array_data_dict = {}
        for data_index, datum in enumerate(data):
            all_array_datum = {}
            for array_column in array_columns_to_process:
                all_array_datum[array_column] = datum[array_column]
            array_data_dict[data_index] = [all_array_datum]
        return array_data_dict

    @classmethod
    def _process_array_data(cls,
                            data: List[dict],
                            all_columns: List[dict],
                            array_column_hierarchy: dict) -> dict:
        """
        Pull out array data that is ready to be processed into a dictionary.
        The key refers to the index in the original data set. The value is
        a list of data values. Initially this list will contain just one value,
        the row of data that corresponds to the index in the original data set.
        As we process arrays, we will pull out array values into separate rows
        and append them to the list of data values.
        Example:
          Original data set = [
              {'ColumnA': [1, 2], 'ColumnB': [3]},
              {'ColumnA': [11, 22], 'ColumnB': [33]}
          ]
          all_array_data (initially) = {
              0: [{'ColumnA': [1, 2], 'ColumnB': [3]}],
              1: [{'ColumnA': [11, 22], 'ColumnB': [33]}]
          }
          all_array_data (after processing) = {
              0: [
                  {'ColumnA': 1, 'ColumnB': 3},
                  {'ColumnA': 2, 'ColumnB': ''},
              ],
              1: [
                  {'ColumnA': 11, 'ColumnB': 33},
                  {'ColumnA': 22, 'ColumnB': ''},
              ],
          }
        :param data: rows of data
        :param all_columns: list of columns
        :param array_column_hierarchy: graph representing array columns
        :return: dictionary representing processed array data
        """
        array_columns = list(array_column_hierarchy.keys())
        # Determine what columns are ready to be processed. This is necessary for
        # array columns that contain rows with nested arrays. We first process
        # the outer arrays before processing inner arrays.
        array_columns_to_process, \
            unprocessed_array_columns = cls._split_array_columns_by_process_state(
                array_columns, array_column_hierarchy, data[0])

        # Pull out array data that is ready to be processed into a dictionary.
        all_array_data = cls._convert_data_list_to_array_data_dict(
            data, array_columns_to_process)

        for original_data_index, expanded_array_data in all_array_data.items():
            for array_column in array_columns:
                if array_column in unprocessed_array_columns:
                    continue
                # Expand array values that are rows
                if str(array_column_hierarchy[array_column]['type']) == 'ROW':
                    for array_value in expanded_array_data:
                        cls._expand_row_data(array_value,
                                             array_column,
                                             array_column_hierarchy)
                    continue
                array_data = expanded_array_data[0][array_column]
                array_children = array_column_hierarchy[array_column]
                # This is an empty array of primitive data type
                if not array_data and not array_children['children']:
                    continue
                # Pull out complex array values into their own rows of data
                elif array_data and array_children['children']:
                    for array_index, data_value in enumerate(array_data):
                        if array_index >= len(expanded_array_data):
                            empty_data = cls._create_empty_row_of_data(all_columns)
                            expanded_array_data.append(empty_data)
                        for index, datum_value in enumerate(data_value):
                            array_child = array_children['children'][index]
                            expanded_array_data[array_index][array_child] = datum_value
                # Pull out primitive array values into their own rows of data
                elif array_data:
                    for array_index, data_value in enumerate(array_data):
                        if array_index >= len(expanded_array_data):
                            empty_data = cls._create_empty_row_of_data(all_columns)
                            expanded_array_data.append(empty_data)
                        expanded_array_data[array_index][array_column] = data_value
                # This is an empty array with nested fields
                else:
                    for index, array_child in enumerate(array_children['children']):
                        for array_value in expanded_array_data:
                            array_value[array_child] = ''
        return all_array_data

    @classmethod
    def _consolidate_array_data_into_data(cls,
                                          data: List[dict],
                                          array_data: dict) -> None:
        """
        Consolidate data given a list representing rows of data and a dictionary
        representing expanded array data
        Example:
          Original data set = [
              {'ColumnA': [1, 2], 'ColumnB': [3]},
              {'ColumnA': [11, 22], 'ColumnB': [33]}
          ]
          array_data = {
              0: [
                  {'ColumnA': 1, 'ColumnB': 3},
                  {'ColumnA': 2, 'ColumnB': ''},
              ],
              1: [
                  {'ColumnA': 11, 'ColumnB': 33},
                  {'ColumnA': 22, 'ColumnB': ''},
              ],
          }
          Final data set = [
              {'ColumnA': 1, 'ColumnB': 3},
              {'ColumnA': 2, 'ColumnB': ''},
              {'ColumnA': 11, 'ColumnB': 33},
              {'ColumnA': 22, 'ColumnB': ''},
          ]
        :param data: list representing rows of data
        :param array_data: dictionary representing expanded array data
        :return: list where data and array_data are combined
        """
        data_index = 0
        original_data_index = 0
        while data_index < len(data):
            data[data_index].update(array_data[original_data_index][0])
            array_data[original_data_index].pop(0)
            data[data_index + 1:data_index + 1] = array_data[original_data_index]
            data_index = data_index + len(array_data[original_data_index]) + 1
            original_data_index = original_data_index + 1

    @classmethod
    def _remove_processed_array_columns(cls,
                                        unprocessed_array_columns: Set[str],
                                        array_column_hierarchy: dict) -> None:
        """
        Remove keys representing array columns that have already been processed
        :param unprocessed_array_columns: list of unprocessed array columns
        :param array_column_hierarchy: graph representing array columns
        """
        array_columns = list(array_column_hierarchy.keys())
        for array_column in array_columns:
            if array_column in unprocessed_array_columns:
                continue
            else:
                del array_column_hierarchy[array_column]

    @classmethod
    def expand_data(cls,
                    columns: List[dict],
                    data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]:
        """
        We do not immediately display rows and arrays clearly in the data grid. This
        method separates out nested fields and data values to help clearly display
        structural columns.

        Example: ColumnA is a row(nested_obj varchar) and ColumnB is an array(int)
          Original data set = [
              {'ColumnA': ['a1'], 'ColumnB': [1, 2]},
              {'ColumnA': ['a2'], 'ColumnB': [3, 4]},
          ]
          Expanded data set = [
              {'ColumnA': ['a1'], 'ColumnA.nested_obj': 'a1', 'ColumnB': 1},
              {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 2},
              {'ColumnA': ['a2'], 'ColumnA.nested_obj': 'a2', 'ColumnB': 3},
              {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 4},
          ]
        :param columns: columns selected in the query
        :param data: original data set
        :return: list of all columns (selected columns and their nested fields),
                 expanded data set, list of nested fields
        """
        all_columns: List[dict] = []
        # Get the list of all columns (selected fields and their nested fields)
        for column in columns:
            if column['type'].startswith('ARRAY') or column['type'].startswith('ROW'):
                cls._parse_structural_column(column['name'],
                                             column['type'].lower(),
                                             all_columns)
            else:
                all_columns.append(column)

        # Build graphs where the root node is a row or array and its children are that
        # column's nested fields
        row_column_hierarchy,\
            array_column_hierarchy,\
            expanded_columns = cls._create_row_and_array_hierarchy(columns)

        # Pull out a row's nested fields and their values into separate columns
        ordered_row_columns = row_column_hierarchy.keys()
        for datum in data:
            for row_column in ordered_row_columns:
                cls._expand_row_data(datum, row_column, row_column_hierarchy)

        while array_column_hierarchy:
            array_columns = list(array_column_hierarchy.keys())
            # Determine what columns are ready to be processed.
            array_columns_to_process,\
                unprocessed_array_columns = cls._split_array_columns_by_process_state(
                    array_columns, array_column_hierarchy, data[0])
            all_array_data = cls._process_array_data(data,
                                                     all_columns,
                                                     array_column_hierarchy)
            # Consolidate the original data set and the expanded array data
            cls._consolidate_array_data_into_data(data, all_array_data)
            # Remove processed array columns from the graph
            cls._remove_processed_array_columns(unprocessed_array_columns,
                                                array_column_hierarchy)

        return all_columns, data, expanded_columns

    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        cols = indexes[0].get('column_names', [])
        full_table_name = table_name
        if schema_name and '.' not in table_name:
            full_table_name = '{}.{}'.format(schema_name, table_name)
        pql = cls._partition_query(full_table_name)
        col_name, latest_part = cls.latest_partition(
            table_name, schema_name, database, show_first=True)
        return {
            'partitions': {
                'cols': cols,
                'latest': {col_name: latest_part},
                'partitionQuery': pql,
            },
        }

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Updates progress information"""
        logging.info('Polling the cursor for progress')
        polled = cursor.poll()
        # poll returns dict -- JSON status information or ``None``
        # if the query is done
        # https://github.com/dropbox/PyHive/blob/
        # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
        while polled:
            # Update the object and wait for the kill signal.
            stats = polled.get('stats', {})

            query = session.query(type(query)).filter_by(id=query.id).one()
            if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
                cursor.cancel()
                break

            if stats:
                state = stats.get('state')

                # if already finished, then stop polling
                if state == 'FINISHED':
                    break

                completed_splits = float(stats.get('completedSplits'))
                total_splits = float(stats.get('totalSplits'))
                if total_splits and completed_splits:
                    progress = 100 * (completed_splits / total_splits)
                    logging.info(
                        'Query progress: {} / {} '
                        'splits'.format(completed_splits, total_splits))
                    if progress > query.progress:
                        query.progress = progress
                    session.commit()
            time.sleep(1)
            logging.info('Polling the cursor for progress')
            polled = cursor.poll()

    @classmethod
    def extract_error_message(cls, e):
        if (
                hasattr(e, 'orig') and
                type(e.orig).__name__ == 'DatabaseError' and
                isinstance(e.orig[0], dict)):
            error_dict = e.orig[0]
            return '{} at {}: {}'.format(
                error_dict.get('errorName'),
                error_dict.get('errorLocation'),
                error_dict.get('message'),
            )
        if (
                type(e).__name__ == 'DatabaseError' and
                hasattr(e, 'args') and
                len(e.args) > 0
        ):
            error_dict = e.args[0]
            return error_dict.get('message')
        return utils.error_msg_from_exception(e)

    @classmethod
    def _partition_query(
            cls, table_name, limit=0, order_by=None, filters=None):
        """Returns a partition query

        :param table_name: the name of the table to get partitions from
        :type table_name: str
        :param limit: the number of partitions to be returned
        :type limit: int
        :param order_by: a list of tuples of field name and a boolean
            that determines if that field should be sorted in descending
            order
        :type order_by: list of (str, bool) tuples
        :param filters: dict of field name and filter value combinations
        """
        limit_clause = 'LIMIT {}'.format(limit) if limit else ''
        order_by_clause = ''
        if order_by:
            order_by_parts = []
            for field, desc in order_by:
                # parenthesize the conditional so ascending fields keep their
                # name instead of collapsing to an empty string
                order_by_parts.append(field + (' DESC' if desc else ''))
            order_by_clause = 'ORDER BY ' + ', '.join(order_by_parts)

        where_clause = ''
        if filters:
            filter_parts = []
            for field, value in filters.items():
                filter_parts.append(f"{field} = '{value}'")
            where_clause = 'WHERE ' + ' AND '.join(filter_parts)

        sql = textwrap.dedent(f"""\
            SELECT * FROM "{table_name}$partitions"

            {where_clause}
            {order_by_clause}
            {limit_clause}
        """)
        return sql
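
    # For a table partitioned by `ds`, _partition_query('logs', 1, [('ds', True)])
    # renders roughly (illustration only):
    #   SELECT * FROM "logs$partitions"
    #   ORDER BY ds DESC
    #   LIMIT 1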
|
||||
@classmethod
|
||||
def where_latest_partition(
|
||||
cls, table_name, schema, database, qry, columns=None):
|
||||
try:
|
||||
col_name, value = cls.latest_partition(
|
||||
table_name, schema, database, show_first=True)
|
||||
except Exception:
|
||||
# table is not partitioned
|
||||
return False
|
||||
if value is not None:
|
||||
for c in columns:
|
||||
if c.get('name') == col_name:
|
||||
return qry.where(Column(col_name) == value)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _latest_partition_from_df(cls, df):
|
||||
if not df.empty:
|
||||
return df.to_records(index=False)[0][0]
|
||||
|
||||
@classmethod
|
||||
def latest_partition(cls, table_name, schema, database, show_first=False):
|
||||
"""Returns col name and the latest (max) partition value for a table
|
||||
|
||||
:param table_name: the name of the table
|
||||
:type table_name: str
|
||||
:param schema: schema / database / namespace
|
||||
:type schema: str
|
||||
:param database: database query will be run against
|
||||
:type database: models.Database
|
||||
:param show_first: displays the value for the first partitioning key
|
||||
if there are many partitioning keys
|
||||
:type show_first: bool
|
||||
|
||||
>>> latest_partition('foo_table')
|
||||
('ds', '2018-01-01')
|
||||
"""
|
||||
indexes = database.get_indexes(table_name, schema)
|
||||
if len(indexes[0]['column_names']) < 1:
|
||||
raise SupersetTemplateException(
|
||||
'The table should have one partitioned field')
|
||||
elif not show_first and len(indexes[0]['column_names']) > 1:
|
||||
raise SupersetTemplateException(
|
||||
'The table should have a single partitioned field '
|
||||
'to use this function. You may want to use '
|
||||
'`presto.latest_sub_partition`')
|
||||
part_field = indexes[0]['column_names'][0]
|
||||
sql = cls._partition_query(table_name, 1, [(part_field, True)])
|
||||
df = database.get_df(sql, schema)
|
||||
return part_field, cls._latest_partition_from_df(df)

    @classmethod
    def latest_sub_partition(cls, table_name, schema, database, **kwargs):
        """Returns the latest (max) partition value for a table

        A filtering criteria should be passed for all fields that are
        partitioned except for the field to be returned. For example,
        if a table is partitioned by (``ds``, ``event_type`` and
        ``event_category``) and you want the latest ``ds``, you'll want
        to provide a filter as keyword arguments for both
        ``event_type`` and ``event_category`` as in
        ``latest_sub_partition('my_table',
            event_category='page', event_type='click')``

        :param table_name: the name of the table, can be just the table
            name or a fully qualified table name as ``schema_name.table_name``
        :type table_name: str
        :param schema: schema / database / namespace
        :type schema: str
        :param database: database query will be run against
        :type database: models.Database

        :param kwargs: keyword arguments define the filtering criteria
            on the partition list. There can be many of these.
        :type kwargs: str

        >>> latest_sub_partition('sub_partition_table', event_type='click')
        '2018-01-01'
        """
        indexes = database.get_indexes(table_name, schema)
        part_fields = indexes[0]['column_names']
        for k in kwargs.keys():
            if k not in part_fields:
                msg = f'Field [{k}] is not part of the partitioning key'
                raise SupersetTemplateException(msg)
        if len(kwargs.keys()) != len(part_fields) - 1:
            msg = (
                'A filter needs to be specified for {} out of the '
                '{} fields.'
            ).format(len(part_fields) - 1, len(part_fields))
            raise SupersetTemplateException(msg)

        for field in part_fields:
            if field not in kwargs.keys():
                field_to_return = field

        sql = cls._partition_query(
            table_name, 1, [(field_to_return, True)], kwargs)
        df = database.get_df(sql, schema)
        if df.empty:
            return ''
        return df.to_dict()[field_to_return][0]
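
For the docstring example above, assuming the table is partitioned by
(ds, event_type) and that ``database`` and the schema name are made up,
the method would issue roughly:

    # hypothetical usage sketch, not part of the diff
    ds = PrestoEngineSpec.latest_sub_partition(
        'sub_partition_table', 'default', database, event_type='click')
    # under the hood this runs roughly:
    #   SELECT * FROM "sub_partition_table$partitions"
    #   WHERE event_type = 'click' ORDER BY ds DESC LIMIT 1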

@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec


class RedshiftEngineSpec(PostgresBaseEngineSpec):
    engine = 'redshift'
    max_column_name_length = 127

    @staticmethod
    def mutate_label(label):
        """
        Redshift only supports lowercase column names and aliases.

        :param str label: Original label which might include uppercase letters
        :return: String that is supported by the database
        """
        return label.lower()
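
A quick illustrative check of the behavior (not part of the diff):

    # mutate_label only lowercases, so the mapping is deterministic
    assert RedshiftEngineSpec.mutate_label('SUM(Sales)') == 'sum(sales)'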

@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from urllib import parse

from superset.db_engine_specs.postgres import PostgresBaseEngineSpec


class SnowflakeEngineSpec(PostgresBaseEngineSpec):
    engine = 'snowflake'
    force_column_alias_quotes = True
    max_column_name_length = 256

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATE_TRUNC('SECOND', {col})",
        'PT1M': "DATE_TRUNC('MINUTE', {col})",
        'PT5M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
                 DATE_TRUNC('HOUR', {col}))",
        'PT10M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
                  DATE_TRUNC('HOUR', {col}))",
        'PT15M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
                  DATE_TRUNC('HOUR', {col}))",
        'PT0.5H': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
                   DATE_TRUNC('HOUR', {col}))",
        'PT1H': "DATE_TRUNC('HOUR', {col})",
        'P1D': "DATE_TRUNC('DAY', {col})",
        'P1W': "DATE_TRUNC('WEEK', {col})",
        'P1M': "DATE_TRUNC('MONTH', {col})",
        'P0.25Y': "DATE_TRUNC('QUARTER', {col})",
        'P1Y': "DATE_TRUNC('YEAR', {col})",
    }

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        database = uri.database
        if '/' in uri.database:
            database = uri.database.split('/')[0]
        if selected_schema:
            selected_schema = parse.quote(selected_schema, safe='')
            uri.database = database + '/' + selected_schema
        return uri
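
A minimal sketch of the URI rewrite, assuming a mutable SQLAlchemy URL
(account, credentials and names below are made up): Snowflake carries the
schema as a suffix on the database portion, so selecting a schema rewrites
that suffix, URL-quoting it on the way in.

    # hypothetical usage sketch, not part of the diff
    from sqlalchemy.engine.url import make_url
    url = make_url('snowflake://user:pw@acct/mydb/public')
    url = SnowflakeEngineSpec.adjust_database_uri(url, selected_schema='raw data')
    # url.database is now 'mydb/raw%20data'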

    @classmethod
    def epoch_to_dttm(cls):
        return "DATEADD(S, {col}, '1970-01-01')"

    @classmethod
    def epoch_ms_to_dttm(cls):
        return "DATEADD(MS, {col}, '1970-01-01')"

@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from typing import List

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils


class SqliteEngineSpec(BaseEngineSpec):
    engine = 'sqlite'

    time_grain_functions = {
        None: '{col}',
        'PT1H': "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))",
        'P1D': 'DATE({col})',
        'P1W': "DATE({col}, -strftime('%W', {col}) || ' days')",
        'P1M': "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')",
        'P1Y': "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))",
        'P1W/1970-01-03T00:00:00Z': "DATE({col}, 'weekday 6')",
        '1969-12-28T00:00:00Z/P1W': "DATE({col}, 'weekday 0', '-7 days')",
    }

    @classmethod
    def epoch_to_dttm(cls):
        return "datetime({col}, 'unixepoch')"

    @classmethod
    def get_all_datasource_names(cls, db, datasource_type: str) \
            -> List[utils.DatasourceName]:
        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
                                          cache_timeout=db.schema_cache_timeout,
                                          force=True)
        schema = schemas[0]
        if datasource_type == 'table':
            return db.get_all_table_names_in_schema(
                schema=schema, force=True,
                cache=db.table_cache_enabled,
                cache_timeout=db.table_cache_timeout)
        elif datasource_type == 'view':
            return db.get_all_view_names_in_schema(
                schema=schema, force=True,
                cache=db.table_cache_enabled,
                cache_timeout=db.table_cache_timeout)
        else:
            raise Exception(f'Unsupported datasource_type: {datasource_type}')

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        iso = dttm.isoformat().replace('T', ' ')
        if '.' not in iso:
            iso += '.000000'
        return "'{}'".format(iso)

    @classmethod
    def get_table_names(cls, inspector, schema):
        """Need to disregard the schema for Sqlite"""
        return sorted(inspector.get_table_names())
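
A small sketch of the literal convert_dttm emits (the datetime is made up):

    # hypothetical usage sketch, not part of the diff
    from datetime import datetime
    SqliteEngineSpec.convert_dttm('TEXT', datetime(2019, 1, 2, 3, 4, 5))
    # returns "'2019-01-02 03:04:05.000000'" (microseconds are padded in)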

@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod


class TeradataEngineSpec(BaseEngineSpec):
    """Dialect for Teradata DB."""
    engine = 'teradata'
    limit_method = LimitMethod.WRAP_SQL
    max_column_name_length = 30  # since 14.10 this is 128

    time_grain_functions = {
        None: '{col}',
        'PT1M': "TRUNC(CAST({col} as DATE), 'MI')",
        'PT1H': "TRUNC(CAST({col} as DATE), 'HH')",
        'P1D': "TRUNC(CAST({col} as DATE), 'DDD')",
        'P1W': "TRUNC(CAST({col} as DATE), 'WW')",
        'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')",
        'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')",
        'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')",
    }

    @classmethod
    def epoch_to_dttm(cls):
        return "CAST(((CAST(DATE '1970-01-01' + ({col} / 86400) AS TIMESTAMP(0) " \
               "AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' " \
               'HOUR TO SECOND) AS TIMESTAMP(0))'
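
The expression splits the epoch into whole days (added to DATE '1970-01-01')
and leftover seconds (scaled by a one-second INTERVAL). A quick sanity check
of that arithmetic in Python, with a made-up value:

    # hypothetical sanity check, not part of the diff
    epoch = 1546398245  # 2019-01-02 03:04:05 UTC
    days, secs = divmod(epoch, 86400)
    assert (days, secs) == (17898, 11045)  # 11045 s == 03:04:05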

@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec


class VerticaEngineSpec(PostgresBaseEngineSpec):
    engine = 'vertica'

@ -34,8 +34,8 @@ import sqlalchemy as sqla

from superset import dataframe, db, jinja_context, security_manager, sql_lab
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs import BaseEngineSpec
from superset.db_engine_specs import MssqlEngineSpec
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.utils import core as utils

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
from unittest import mock

from sqlalchemy import column, literal_column, select, table

@ -22,12 +21,20 @@ from sqlalchemy.dialects import mssql, oracle, postgresql
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.types import String, UnicodeText

from superset import db_engine_specs
from superset.db_engine_specs import (
    BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
    MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
    PrestoEngineSpec,
)
from superset.db_engine_specs import engines
from superset.db_engine_specs.base import (
    BaseEngineSpec,
    builtin_time_grains,
    create_time_grains_tuple,
)
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.oracle import OracleEngineSpec
from superset.db_engine_specs.pinot import PinotEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.core import Database
from .base_tests import SupersetTestCase

@ -302,24 +309,23 @@ class DbEngineSpecsTestCase(SupersetTestCase):
            'PT1S': '{col}',
            'PT1M': '{col}',
        }
        time_grains = db_engine_specs._create_time_grains_tuple(time_grains,
                                                                time_grain_functions,
                                                                blacklist)
        time_grains = create_time_grains_tuple(time_grains,
                                               time_grain_functions,
                                               blacklist)
        self.assertEqual(1, len(time_grains))
        self.assertEqual('PT1S', time_grains[0].duration)

    def test_engine_time_grain_validity(self):
        time_grains = set(db_engine_specs.builtin_time_grains.keys())
        time_grains = set(builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for cls_name, cls in inspect.getmembers(db_engine_specs):
            if inspect.isclass(cls) and issubclass(cls, BaseEngineSpec) \
                    and cls is not BaseEngineSpec:
        for engine in engines.values():
            if engine is not BaseEngineSpec:
                # make sure time grain functions have been defined
                self.assertGreater(len(cls.time_grain_functions), 0)
                self.assertGreater(len(engine.time_grain_functions), 0)
                # make sure that all defined time grains are supported
                defined_time_grains = {grain.duration for grain in cls.get_time_grains()}
                intersection = time_grains.intersection(defined_time_grains)
                self.assertSetEqual(defined_time_grains, intersection, cls_name)
                defined_grains = {grain.duration for grain in engine.get_time_grains()}
                intersection = time_grains.intersection(defined_grains)
                self.assertSetEqual(defined_grains, intersection, engine)

    def test_presto_get_view_names_return_empty_list(self):
        self.assertEquals([], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))

@ -691,19 +697,19 @@ class DbEngineSpecsTestCase(SupersetTestCase):
        self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def test_bigquery_sqla_column_label(self):
        label = BQEngineSpec.make_label_compatible(column('Col').name)
        label = BigQueryEngineSpec.make_label_compatible(column('Col').name)
        label_expected = 'Col'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('SUM(x)').name)
        label = BigQueryEngineSpec.make_label_compatible(column('SUM(x)').name)
        label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('SUM[x]').name)
        label = BigQueryEngineSpec.make_label_compatible(column('SUM[x]').name)
        label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('12345_col').name)
        label = BigQueryEngineSpec.make_label_compatible(column('12345_col').name)
        label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
        self.assertEqual(label, label_expected)

@ -756,14 +762,14 @@ class DbEngineSpecsTestCase(SupersetTestCase):
        """ Make sure base engine spec removes schema name from table name
        ie. when try_remove_schema_from_table_name == True. """
        base_result_expected = ['table', 'table_2']
        base_result = db_engine_specs.BaseEngineSpec.get_table_names(
        base_result = BaseEngineSpec.get_table_names(
            schema='schema', inspector=inspector)
        self.assertListEqual(base_result_expected, base_result)

        """ Make sure postgres doesn't try to remove schema name from table name
        ie. when try_remove_schema_from_table_name == False. """
        pg_result_expected = ['schema.table', 'table_2', 'table_3']
        pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
        pg_result = PostgresEngineSpec.get_table_names(
            schema='schema', inspector=inspector)
        self.assertListEqual(pg_result_expected, pg_result)

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from superset.connectors.sqla.models import TableColumn
from superset.db_engine_specs import DruidEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from .base_tests import SupersetTestCase