Fixing confusion when selecting schema across engines (#2572)
This commit is contained in:
parent
40b3d3b3ef
commit
ac84fc2b65
|
|
@ -116,6 +116,27 @@ class BaseEngineSpec(object):
|
|||
"""Extract error message for queries"""
|
||||
return utils.error_msg_from_exception(e)
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema):
|
||||
"""Based on a URI and selected schema, return a new URI
|
||||
|
||||
The URI here represents the URI as entered when saving the database,
|
||||
``selected_schema`` is the schema currently active presumably in
|
||||
the SQL Lab dropdown. Based on that, for some database engine,
|
||||
we can return a new altered URI that connects straight to the
|
||||
active schema, meaning the users won't have to prefix the object
|
||||
names by the schema name.
|
||||
|
||||
Some databases engines have 2 level of namespacing: database and
|
||||
schema (postgres, oracle, mssql, ...)
|
||||
For those it's probably better to not alter the database
|
||||
component of the URI with the schema name, it won't work.
|
||||
|
||||
Some database drivers like presto accept "{catalog}/{schema}" in
|
||||
the database component of the URL, that can be handled here.
|
||||
"""
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def sql_preprocessor(cls, sql):
|
||||
"""If the SQL needs to be altered prior to running it
|
||||
|
|
@ -290,6 +311,12 @@ class MySQLEngineSpec(BaseEngineSpec):
|
|||
dttm.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
if selected_schema:
|
||||
uri.database = selected_schema
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls):
|
||||
return "from_unixtime({col})"
|
||||
|
|
@ -328,6 +355,17 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
from superset.db_engines import presto as patched_presto
|
||||
presto.Cursor.cancel = patched_presto.cancel
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
database = uri.database
|
||||
if selected_schema:
|
||||
if '/' in database:
|
||||
database = database.split('/')[0] + '/' + selected_schema
|
||||
else:
|
||||
database += '/' + selected_schema
|
||||
uri.database = database
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type, dttm):
|
||||
tt = target_type.upper()
|
||||
|
|
|
|||
|
|
@ -560,26 +560,10 @@ class Database(Model, AuditMixinNullable):
|
|||
|
||||
def get_sqla_engine(self, schema=None):
|
||||
extra = self.get_extra()
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
uri = make_url(self.sqlalchemy_uri_decrypted)
|
||||
params = extra.get('engine_params', {})
|
||||
url.database = self.get_database_for_various_backend(url, schema)
|
||||
return create_engine(url, **params)
|
||||
|
||||
def get_database_for_various_backend(self, uri, default_database=None):
|
||||
database = uri.database
|
||||
if self.backend == 'presto' and default_database:
|
||||
if '/' in database:
|
||||
database = database.split('/')[0] + '/' + default_database
|
||||
else:
|
||||
database += '/' + default_database
|
||||
# Postgres and Redshift use the concept of schema as a logical entity
|
||||
# on top of the database, so the database should not be changed
|
||||
# even if passed default_database
|
||||
elif self.backend in ('redshift', 'postgresql', 'sqlite'):
|
||||
pass
|
||||
elif default_database:
|
||||
database = default_database
|
||||
return database
|
||||
uri = self.db_engine_spec.adjust_database_uri(uri, schema)
|
||||
return create_engine(uri, **params)
|
||||
|
||||
def get_reserved_words(self):
|
||||
return self.get_sqla_engine().dialect.preparer.reserved_words
|
||||
|
|
@ -662,9 +646,8 @@ class Database(Model, AuditMixinNullable):
|
|||
|
||||
@property
|
||||
def db_engine_spec(self):
|
||||
engine_name = self.get_sqla_engine().name or 'base'
|
||||
return db_engine_specs.engines.get(
|
||||
engine_name, db_engine_specs.BaseEngineSpec)
|
||||
self.backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
def grains(self):
|
||||
"""Defines time granularity database-specific expressions.
|
||||
|
|
|
|||
|
|
@ -169,8 +169,8 @@ def generate_download_headers(extension):
|
|||
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.Database)
|
||||
list_columns = [
|
||||
'verbose_name', 'backend', 'allow_run_sync', 'allow_run_async',
|
||||
'allow_dml', 'creator', 'changed_on_', 'database_name']
|
||||
'database_name', 'backend', 'allow_run_sync', 'allow_run_async',
|
||||
'allow_dml', 'creator', 'modified']
|
||||
add_columns = [
|
||||
'database_name', 'sqlalchemy_uri', 'cache_timeout', 'extra',
|
||||
'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
|
||||
|
|
@ -1351,6 +1351,7 @@ class Superset(BaseSupersetView):
|
|||
engine.connect()
|
||||
return json.dumps(engine.table_names(), indent=4)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return json_error_response((
|
||||
"Connection failed!\n\n"
|
||||
"The error message returned was:\n{}").format(e))
|
||||
|
|
|
|||
|
|
@ -6,43 +6,51 @@ from superset.models.core import Database
|
|||
|
||||
|
||||
class DatabaseModelTestCase(unittest.TestCase):
|
||||
def test_database_for_various_backend(self):
|
||||
|
||||
def test_database_schema_presto(self):
|
||||
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
url = make_url(model.sqlalchemy_uri)
|
||||
db = model.get_database_for_various_backend(url, None)
|
||||
assert db == 'hive/default'
|
||||
db = model.get_database_for_various_backend(url, 'raw_data')
|
||||
assert db == 'hive/raw_data'
|
||||
|
||||
sqlalchemy_uri = 'redshift+psycopg2://superset:XXXXXXXXXX@redshift.airbnb.io:5439/prod'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
url = make_url(model.sqlalchemy_uri)
|
||||
db = model.get_database_for_various_backend(url, None)
|
||||
assert db == 'prod'
|
||||
db = model.get_database_for_various_backend(url, 'test')
|
||||
assert db == 'prod'
|
||||
db = make_url(model.get_sqla_engine().url).database
|
||||
self.assertEquals('hive/default', db)
|
||||
|
||||
sqlalchemy_uri = 'postgresql+psycopg2://superset:XXXXXXXXXX@postgres.airbnb.io:5439/prod'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
url = make_url(model.sqlalchemy_uri)
|
||||
db = model.get_database_for_various_backend(url, None)
|
||||
assert db == 'prod'
|
||||
db = model.get_database_for_various_backend(url, 'adhoc')
|
||||
assert db == 'prod'
|
||||
db = make_url(model.get_sqla_engine(schema='core_db').url).database
|
||||
self.assertEquals('hive/core_db', db)
|
||||
|
||||
sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/raw_data'
|
||||
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
url = make_url(model.sqlalchemy_uri)
|
||||
db = model.get_database_for_various_backend(url, None)
|
||||
assert db == 'raw_data'
|
||||
db = model.get_database_for_various_backend(url, 'adhoc')
|
||||
assert db == 'adhoc'
|
||||
|
||||
sqlalchemy_uri = 'mysql://superset:XXXXXXXXXX@mysql.airbnb.io/superset'
|
||||
db = make_url(model.get_sqla_engine().url).database
|
||||
self.assertEquals('hive', db)
|
||||
|
||||
db = make_url(model.get_sqla_engine(schema='core_db').url).database
|
||||
self.assertEquals('hive/core_db', db)
|
||||
|
||||
def test_database_schema_postgres(self):
|
||||
sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
url = make_url(model.sqlalchemy_uri)
|
||||
db = model.get_database_for_various_backend(url, None)
|
||||
assert db == 'superset'
|
||||
db = model.get_database_for_various_backend(url, 'adhoc')
|
||||
assert db == 'adhoc'
|
||||
|
||||
db = make_url(model.get_sqla_engine().url).database
|
||||
self.assertEquals('prod', db)
|
||||
|
||||
db = make_url(model.get_sqla_engine(schema='foo').url).database
|
||||
self.assertEquals('prod', db)
|
||||
|
||||
def test_database_schema_hive(self):
|
||||
sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/hive/default'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
db = make_url(model.get_sqla_engine().url).database
|
||||
self.assertEquals('hive/default', db)
|
||||
|
||||
db = make_url(model.get_sqla_engine(schema='core_db').url).database
|
||||
self.assertEquals('hive/core_db', db)
|
||||
|
||||
def test_database_schema_mysql(self):
|
||||
sqlalchemy_uri = 'mysql://root@localhost/superset'
|
||||
model = Database(sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
db = make_url(model.get_sqla_engine().url).database
|
||||
self.assertEquals('superset', db)
|
||||
|
||||
db = make_url(model.get_sqla_engine(schema='staging').url).database
|
||||
self.assertEquals('staging', db)
|
||||
|
|
|
|||
Loading…
Reference in New Issue