diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index afdab1669..c9006bc28 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -29,6 +29,7 @@ import boto3 from flask import g from flask_babel import lazy_gettext as _ import pandas +import sqlalchemy as sqla from sqlalchemy import select from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url @@ -231,14 +232,15 @@ class BaseEngineSpec(object): @classmethod def select_star(cls, my_db, table_name, schema=None, limit=100, - show_cols=False, indent=True, latest_partition=True): + show_cols=False, indent=True, latest_partition=True, + cols=None): fields = '*' - cols = [] - if show_cols or latest_partition: - cols = my_db.get_table(table_name, schema=schema).columns + cols = cols or [] + if (show_cols or latest_partition) and not cols: + cols = my_db.get_columns(table_name, schema) if show_cols: - fields = [my_db.get_quoter()(c.name) for c in cols] + fields = [sqla.column(c.get('name')) for c in cols] full_table_name = table_name if schema: full_table_name = schema + '.' + table_name diff --git a/superset/models/core.py b/superset/models/core.py index b477f02e3..cd7cc44b9 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -680,7 +680,8 @@ class Database(Model, AuditMixinNullable, ImportMixin): DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR') if DB_CONNECTION_MUTATOR: - url, params = DB_CONNECTION_MUTATOR(url, params, user_name, sm) + url, params = DB_CONNECTION_MUTATOR( + url, params, effective_username, sm) return create_engine(url, **params) def get_reserved_words(self): @@ -713,11 +714,11 @@ class Database(Model, AuditMixinNullable, ImportMixin): def select_star( self, table_name, schema=None, limit=100, show_cols=False, - indent=True, latest_partition=True): + indent=True, latest_partition=True, cols=None): """Generates a ``select *`` statement in the proper dialect""" return self.db_engine_spec.select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, - indent=indent, latest_partition=latest_partition) + indent=indent, latest_partition=latest_partition, cols=cols) def wrap_sql_limit(self, sql, limit=1000): qry = ( diff --git a/superset/views/core.py b/superset/views/core.py index 1d8ddf949..723e8ccfb 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2212,11 +2212,12 @@ class Superset(BaseSupersetView): def table(self, database_id, table_name, schema): schema = utils.js_string_to_python(schema) mydb = db.session.query(models.Database).filter_by(id=database_id).one() - cols = [] + payload_columns = [] indexes = [] - t = mydb.get_columns(table_name, schema) + primary_key = [] + foreign_keys = [] try: - t = mydb.get_columns(table_name, schema) + columns = mydb.get_columns(table_name, schema) indexes = mydb.get_indexes(table_name, schema) primary_key = mydb.get_pk_constraint(table_name, schema) foreign_keys = mydb.get_foreign_keys(table_name, schema) @@ -2235,13 +2236,13 @@ class Superset(BaseSupersetView): idx['type'] = 'index' keys += indexes - for col in t: + for col in columns: dtype = '' try: dtype = '{}'.format(col['type']) except Exception: pass - cols.append({ + payload_columns.append({ 'name': col['name'], 'type': dtype.split('(')[0] if '(' in dtype else dtype, 'longType': dtype, @@ -2252,9 +2253,10 @@ class Superset(BaseSupersetView): }) tbl = { 'name': table_name, - 'columns': cols, + 'columns': payload_columns, 'selectStar': mydb.select_star( - table_name, schema=schema, show_cols=True, indent=True), + table_name, schema=schema, show_cols=True, indent=True, + cols=columns, latest_partition=False), 'primaryKey': primary_key, 'foreignKeys': foreign_keys, 'indexes': keys, diff --git a/tests/model_tests.py b/tests/model_tests.py index 0b4a16bd4..19367cfff 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -4,14 +4,16 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import unittest +import textwrap from sqlalchemy.engine.url import make_url +from tests.base_tests import SupersetTestCase +from superset import db from superset.models.core import Database -class DatabaseModelTestCase(unittest.TestCase): +class DatabaseModelTestCase(SupersetTestCase): def test_database_schema_presto(self): sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default' @@ -73,3 +75,25 @@ class DatabaseModelTestCase(unittest.TestCase): model.impersonate_user = False user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username self.assertNotEquals(example_user, user_name) + + def test_select_star(self): + main_db = self.get_main_database(db.session) + table_name = 'bart_lines' + sql = main_db.select_star( + table_name, show_cols=False, latest_partition=False) + expected = textwrap.dedent("""\ + SELECT * + FROM {table_name} + LIMIT 100""".format(**locals())) + assert sql.startswith(expected) + + sql = main_db.select_star( + table_name, show_cols=True, latest_partition=False) + expected = textwrap.dedent("""\ + SELECT color, + name, + path_json, + polyline + FROM bart_lines + LIMIT 100""".format(**locals())) + assert sql.startswith(expected)