Fix confusion when selecting the schema across database engines (#2572)

This commit is contained in:
Maxime Beauchemin 2017-04-10 15:36:58 -07:00 committed by GitHub
parent 40b3d3b3ef
commit ac84fc2b65
4 changed files with 85 additions and 55 deletions

View File

@ -116,6 +116,27 @@ class BaseEngineSpec(object):
"""Extract error message for queries"""
return utils.error_msg_from_exception(e)
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Return a (possibly altered) URI pointing at ``selected_schema``.

    ``uri`` is the SQLAlchemy URL object as entered when saving the
    database; ``selected_schema`` is the schema currently active,
    presumably in the SQL Lab dropdown. Engine specs may override this
    to return an altered URI that connects straight to the active
    schema, so users don't have to prefix object names with the schema
    name.

    Engines with two levels of namespacing — database and schema
    (postgres, oracle, mssql, ...) — should generally NOT alter the
    database component of the URI with the schema name; it won't work.
    Drivers like presto that accept "{catalog}/{schema}" in the database
    component of the URL can handle the schema here.

    This base implementation is a no-op and returns ``uri`` unchanged.
    """
    # ``selected_schema`` defaults to None to match the subclass
    # overrides (MySQL, Presto), keeping a single consistent signature.
    return uri
@classmethod
def sql_preprocessor(cls, sql):
"""If the SQL needs to be altered prior to running it
@ -290,6 +311,12 @@ class MySQLEngineSpec(BaseEngineSpec):
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point ``uri`` straight at ``selected_schema``.

    MySQL has a single level of namespacing, so the active schema simply
    replaces the database component of the URL. The URL object is
    mutated in place and returned.
    """
    if not selected_schema:
        # No active schema: leave the saved URI untouched.
        return uri
    uri.database = selected_schema
    return uri
@classmethod
def epoch_to_dttm(cls):
return "from_unixtime({col})"
@ -328,6 +355,17 @@ class PrestoEngineSpec(BaseEngineSpec):
from superset.db_engines import presto as patched_presto
presto.Cursor.cancel = patched_presto.cancel
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Connect straight to ``selected_schema`` for Presto.

    Presto accepts "{catalog}/{schema}" in the database component of the
    URL: the active schema replaces (or is appended to) the schema part
    while the catalog is preserved. The URL object is mutated in place
    and returned.
    """
    if selected_schema:
        database = uri.database
        if not database:
            # Robustness fix: a URI saved without a database component
            # (uri.database is None) previously raised TypeError on the
            # ``'/' in database`` test; fall back to the schema alone.
            database = selected_schema
        elif '/' in database:
            # "{catalog}/{old_schema}" -> keep catalog, swap schema.
            database = database.split('/')[0] + '/' + selected_schema
        else:
            # Only a catalog was saved; append the active schema.
            database += '/' + selected_schema
        uri.database = database
    return uri
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()

View File

@ -560,26 +560,10 @@ class Database(Model, AuditMixinNullable):
def get_sqla_engine(self, schema=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
uri = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
url.database = self.get_database_for_various_backend(url, schema)
return create_engine(url, **params)
def get_database_for_various_backend(self, uri, default_database=None):
database = uri.database
if self.backend == 'presto' and default_database:
if '/' in database:
database = database.split('/')[0] + '/' + default_database
else:
database += '/' + default_database
# Postgres and Redshift use the concept of schema as a logical entity
# on top of the database, so the database should not be changed
# even if passed default_database
elif self.backend in ('redshift', 'postgresql', 'sqlite'):
pass
elif default_database:
database = default_database
return database
uri = self.db_engine_spec.adjust_database_uri(uri, schema)
return create_engine(uri, **params)
def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
@ -662,9 +646,8 @@ class Database(Model, AuditMixinNullable):
@property
def db_engine_spec(self):
engine_name = self.get_sqla_engine().name or 'base'
return db_engine_specs.engines.get(
engine_name, db_engine_specs.BaseEngineSpec)
self.backend, db_engine_specs.BaseEngineSpec)
def grains(self):
"""Defines time granularity database-specific expressions.

View File

@ -169,8 +169,8 @@ def generate_download_headers(extension):
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.Database)
list_columns = [
'verbose_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'changed_on_', 'database_name']
'database_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'modified']
add_columns = [
'database_name', 'sqlalchemy_uri', 'cache_timeout', 'extra',
'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
@ -1351,6 +1351,7 @@ class Superset(BaseSupersetView):
engine.connect()
return json.dumps(engine.table_names(), indent=4)
except Exception as e:
logging.exception(e)
return json_error_response((
"Connection failed!\n\n"
"The error message returned was:\n{}").format(e))

View File

@ -6,43 +6,51 @@ from superset.models.core import Database
class DatabaseModelTestCase(unittest.TestCase):
def test_database_for_various_backend(self):
def test_database_schema_presto(self):
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'hive/default'
db = model.get_database_for_various_backend(url, 'raw_data')
assert db == 'hive/raw_data'
sqlalchemy_uri = 'redshift+psycopg2://superset:XXXXXXXXXX@redshift.airbnb.io:5439/prod'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'test')
assert db == 'prod'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)
sqlalchemy_uri = 'postgresql+psycopg2://superset:XXXXXXXXXX@postgres.airbnb.io:5439/prod'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'prod'
db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)
sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/raw_data'
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'raw_data'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'
sqlalchemy_uri = 'mysql://superset:XXXXXXXXXX@mysql.airbnb.io/superset'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive', db)
db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)
def test_database_schema_postgres(self):
sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'superset'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('prod', db)
db = make_url(model.get_sqla_engine(schema='foo').url).database
self.assertEquals('prod', db)
def test_database_schema_hive(self):
sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)
db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)
def test_database_schema_mysql(self):
sqlalchemy_uri = 'mysql://root@localhost/superset'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('superset', db)
db = make_url(model.get_sqla_engine(schema='staging').url).database
self.assertEquals('staging', db)