Add support for period character in table names (#7453)

* Move schema name handling in table names from frontend to backend

* Rename all_schema_names to get_all_schema_names

* Fix js errors

* Fix additional js linting errors

* Refactor datasource getters and fix linting errors

* Update js unit tests

* Add python unit test for get_table_names method

* Add python unit test for get_table_names method

* Fix js linting error
This commit is contained in:
Ville Brofeldt 2019-05-26 06:13:16 +03:00 committed by GitHub
parent 47ba2ad394
commit f7d3413a50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 147 additions and 136 deletions

View File

@ -208,19 +208,20 @@ describe('TableSelector', () => {
it('test 1', () => {
wrapper.instance().changeTable({
value: 'birth_names',
value: { schema: 'main', table: 'birth_names' },
label: 'birth_names',
});
expect(wrapper.state().tableName).toBe('birth_names');
});
it('test 2', () => {
it('should call onTableChange with schema from table object', () => {
wrapper.setProps({ schema: null });
wrapper.instance().changeTable({
value: 'main.my_table',
label: 'my_table',
value: { schema: 'other_schema', table: 'my_table' },
label: 'other_schema.my_table',
});
expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table');
expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main');
expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema');
});
});

View File

@ -329,15 +329,15 @@ export const databases = {
export const tables = {
options: [
{
value: 'birth_names',
value: { schema: 'main', table: 'birth_names' },
label: 'birth_names',
},
{
value: 'energy_usage',
value: { schema: 'main', table: 'energy_usage' },
label: 'energy_usage',
},
{
value: 'wb_health_population',
value: { schema: 'main', table: 'wb_health_population' },
label: 'wb_health_population',
},
],

View File

@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent {
this.setState({ tableName: '' });
return;
}
const namePieces = tableOpt.value.split('.');
let tableName = namePieces[0];
let schemaName = this.props.queryEditor.schema;
if (namePieces.length === 1) {
this.setState({ tableName });
} else {
schemaName = namePieces[0];
tableName = namePieces[1];
this.setState({ tableName });
this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
}
const schemaName = tableOpt.value.schema;
const tableName = tableOpt.value.table;
this.setState({ tableName });
this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
}

View File

@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent {
this.setState({ tableName: '' });
return;
}
const namePieces = tableOpt.value.split('.');
let tableName = namePieces[0];
let schemaName = this.props.schema;
if (namePieces.length > 1) {
schemaName = namePieces[0];
tableName = namePieces[1];
}
const schemaName = tableOpt.value.schema;
const tableName = tableOpt.value.table;
if (this.props.tableNameSticky) {
this.setState({ tableName }, this.onChange);
}

View File

@ -288,9 +288,9 @@ def update_datasources_cache():
if database.allow_multi_schema_metadata_fetch:
print('Fetching {} datasources ...'.format(database.name))
try:
database.all_table_names_in_database(
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
database.all_view_names_in_database(
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
except Exception as e:
print('{}'.format(str(e)))

View File

@ -122,6 +122,7 @@ class BaseEngineSpec(object):
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
try_remove_schema_from_table_name = True
@classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain):
@ -279,33 +280,32 @@ class BaseEngineSpec(object):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@classmethod
def fetch_result_sets(cls, db, datasource_type):
"""Returns a list of tables [schema1.table1, schema2.table2, ...]
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
"""Returns a list of all tables or views in database.
Datasource_type can be 'table' or 'view'.
Empty schema corresponds to the list of full names of the all
tables or views: <schema>.<result_set_name>.
:param db: Database instance
:param datasource_type: Datasource_type can be 'table' or 'view'
:return: List of all datasources in database or schema
"""
schemas = db.all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
all_result_sets = []
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
all_datasources: List[utils.DatasourceName] = []
for schema in schemas:
if datasource_type == 'table':
all_datasource_names = db.all_table_names_in_schema(
all_datasources += db.get_all_table_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
all_datasource_names = db.all_view_names_in_schema(
all_datasources += db.get_all_view_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
else:
raise Exception(f'Unsupported datasource_type: {datasource_type}')
all_result_sets += [
'{}.{}'.format(schema, t) for t in all_datasource_names]
return all_result_sets
return all_datasources
@classmethod
def handle_cursor(cls, cursor, query, session):
@ -352,11 +352,17 @@ class BaseEngineSpec(object):
@classmethod
def get_table_names(cls, inspector, schema):
return sorted(inspector.get_table_names(schema))
tables = inspector.get_table_names(schema)
if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
return sorted(tables)
@classmethod
def get_view_names(cls, inspector, schema):
return sorted(inspector.get_view_names(schema))
views = inspector.get_view_names(schema)
if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f'^{schema}\\.', '', view) for view in views]
return sorted(views)
@classmethod
def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
@ -528,6 +534,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(PostgresBaseEngineSpec):
engine = 'postgresql'
max_column_name_length = 63
try_remove_schema_from_table_name = False
@classmethod
def get_table_names(cls, inspector, schema):
@ -685,29 +692,25 @@ class SqliteEngineSpec(BaseEngineSpec):
return "datetime({col}, 'unixepoch')"
@classmethod
def fetch_result_sets(cls, db, datasource_type):
schemas = db.all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
all_result_sets = []
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
schema = schemas[0]
if datasource_type == 'table':
all_datasource_names = db.all_table_names_in_schema(
return db.get_all_table_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
all_datasource_names = db.all_view_names_in_schema(
return db.get_all_view_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
else:
raise Exception(f'Unsupported datasource_type: {datasource_type}')
all_result_sets += [
'{}.{}'.format(schema, t) for t in all_datasource_names]
return all_result_sets
@classmethod
def convert_dttm(cls, target_type, dttm):
iso = dttm.isoformat().replace('T', ' ')
@ -1107,24 +1110,19 @@ class PrestoEngineSpec(BaseEngineSpec):
return 'from_unixtime({col})'
@classmethod
def fetch_result_sets(cls, db, datasource_type):
"""Returns a list of tables [schema1.table1, schema2.table2, ...]
Datasource_type can be 'table' or 'view'.
Empty schema corresponds to the list of full names of the all
tables or views: <schema>.<result_set_name>.
"""
result_set_df = db.get_df(
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
datasource_df = db.get_df(
"""SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
ORDER BY concat(table_schema, '.', table_name)""".format(
datasource_type.upper(),
),
None)
result_sets = []
for unused, row in result_set_df.iterrows():
result_sets.append('{}.{}'.format(
row['table_schema'], row['table_name']))
return result_sets
datasource_names: List[utils.DatasourceName] = []
for unused, row in datasource_df.iterrows():
datasource_names.append(utils.DatasourceName(
schema=row['table_schema'], table=row['table_name']))
return datasource_names
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
@ -1385,9 +1383,9 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = patched_hive.fetch_logs
@classmethod
def fetch_result_sets(cls, db, datasource_type):
return BaseEngineSpec.fetch_result_sets(
db, datasource_type)
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
@classmethod
def fetch_data(cls, cursor, limit):

View File

@ -23,6 +23,7 @@ import functools
import json
import logging
import textwrap
from typing import List
from flask import escape, g, Markup, request
from flask_appbuilder import Model
@ -65,6 +66,7 @@ metadata = Model.metadata # pylint: disable=no-member
PASSWORD_MASK = 'X' * 10
def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
id_ = target.datasource_id
@ -184,7 +186,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
description=self.description,
cache_timeout=self.cache_timeout)
@datasource.getter
@datasource.getter # type: ignore
@utils.memoized
def get_datasource(self):
return (
@ -210,7 +212,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
datasource = self.datasource
return datasource.url if datasource else None
@property
@property # type: ignore
@utils.memoized
def viz(self):
d = json.loads(self.params)
@ -930,100 +932,87 @@ class Database(Model, AuditMixinNullable, ImportMixin):
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema:None:table_list',
attribute_in_key='id')
def all_table_names_in_database(self, cache=False,
cache_timeout=None, force=False):
def get_all_table_names_in_database(self, cache: bool = False,
cache_timeout: bool = None,
force=False) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.fetch_result_sets(self, 'table')
return self.db_engine_spec.get_all_datasource_names(self, 'table')
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema:None:view_list',
attribute_in_key='id')
def all_view_names_in_database(self, cache=False,
cache_timeout=None, force=False):
def get_all_view_names_in_database(self, cache: bool = False,
cache_timeout: bool = None,
force: bool = False) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.fetch_result_sets(self, 'view')
return self.db_engine_spec.get_all_datasource_names(self, 'view')
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:table_list'.format(
kwargs.get('schema')),
attribute_in_key='id')
def all_table_names_in_schema(self, schema, cache=False,
cache_timeout=None, force=False):
def get_all_table_names_in_schema(self, schema: str, cache: bool = False,
cache_timeout: int = None, force: bool = False):
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param schema: schema name
:type schema: str
:param cache: whether cache is enabled for the function
:type cache: bool
:param cache_timeout: timeout in seconds for the cache
:type cache_timeout: int
:param force: whether to force refresh the cache
:type force: bool
:return: table list
:rtype: list
:return: list of tables
"""
tables = []
try:
tables = self.db_engine_spec.get_table_names(
inspector=self.inspector, schema=schema)
return [utils.DatasourceName(table=table, schema=schema) for table in tables]
except Exception as e:
logging.exception(e)
return tables
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:view_list'.format(
kwargs.get('schema')),
attribute_in_key='id')
def all_view_names_in_schema(self, schema, cache=False,
cache_timeout=None, force=False):
def get_all_view_names_in_schema(self, schema: str, cache: bool = False,
cache_timeout: int = None, force: bool = False):
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param schema: schema name
:type schema: str
:param cache: whether cache is enabled for the function
:type cache: bool
:param cache_timeout: timeout in seconds for the cache
:type cache_timeout: int
:param force: whether to force refresh the cache
:type force: bool
:return: view list
:rtype: list
:return: list of views
"""
views = []
try:
views = self.db_engine_spec.get_view_names(
inspector=self.inspector, schema=schema)
return [utils.DatasourceName(table=view, schema=schema) for view in views]
except Exception as e:
logging.exception(e)
return views
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema_list',
attribute_in_key='id')
def all_schema_names(self, cache=False, cache_timeout=None, force=False):
def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None,
force: bool = False) -> List[str]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param cache: whether cache is enabled for the function
:type cache: bool
:param cache_timeout: timeout in seconds for the cache
:type cache_timeout: int
:param force: whether to force refresh the cache
:type force: bool
:return: schema list
:rtype: list
"""
return self.db_engine_spec.get_schema_names(self.inspector)
@ -1232,7 +1221,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
def datasource(self):
return self.get_datasource
@datasource.getter
@datasource.getter # type: ignore
@utils.memoized
def get_datasource(self):
# pylint: disable=no-member

View File

@ -17,6 +17,7 @@
# pylint: disable=C,R,W
"""A set of constants and methods to manage permissions and security"""
import logging
from typing import List
from flask import g
from flask_appbuilder.security.sqla import models as ab_models
@ -26,6 +27,7 @@ from sqlalchemy import or_
from superset import sql_parse
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import SupersetSecurityException
from superset.utils.core import DatasourceName
class SupersetSecurityManager(SecurityManager):
@ -240,7 +242,9 @@ class SupersetSecurityManager(SecurityManager):
subset.add(t.schema)
return sorted(list(subset))
def accessible_by_user(self, database, datasource_names, schema=None):
def get_datasources_accessible_by_user(
self, database, datasource_names: List[DatasourceName],
schema: str = None) -> List[DatasourceName]:
from superset import db
if self.database_access(database) or self.all_datasource_access():
return datasource_names

View File

@ -32,7 +32,7 @@ import signal
import smtplib
import sys
from time import struct_time
from typing import List, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple
from urllib.parse import unquote_plus
import uuid
import zlib
@ -1100,3 +1100,8 @@ def MediumText() -> Variant:
def shortid() -> str:
    """Return a short, 12-character random identifier.

    The id is the last 12 characters of a random UUID4's canonical
    string form (hex digits and possibly a dash).
    """
    return str(uuid.uuid4())[-12:]
# Fully-qualified datasource reference: a table (or view) name together
# with the schema it belongs to. Field order (table, schema) matters for
# positional construction and tuple unpacking.
DatasourceName = NamedTuple('DatasourceName', [('table', str), ('schema', str)])

View File

@ -22,7 +22,7 @@ import logging
import os
import re
import traceback
from typing import List # noqa: F401
from typing import Dict, List # noqa: F401
from urllib import parse
from flask import (
@ -311,7 +311,7 @@ class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
db.set_sqlalchemy_uri(db.sqlalchemy_uri)
security_manager.add_permission_view_menu('database_access', db.perm)
# adding a new database we always want to force refresh schema list
for schema in db.all_schema_names():
for schema in db.get_all_schema_names():
security_manager.add_permission_view_menu(
'schema_access', security_manager.get_schema_perm(db, schema))
@ -1545,7 +1545,7 @@ class Superset(BaseSupersetView):
.first()
)
if database:
schemas = database.all_schema_names(
schemas = database.get_all_schema_names(
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout,
force=force_refresh)
@ -1570,50 +1570,57 @@ class Superset(BaseSupersetView):
database = db.session.query(models.Database).filter_by(id=db_id).one()
if schema:
table_names = database.all_table_names_in_schema(
tables = database.get_all_table_names_in_schema(
schema=schema, force=force_refresh,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout)
view_names = database.all_view_names_in_schema(
cache_timeout=database.table_cache_timeout) or []
views = database.get_all_view_names_in_schema(
schema=schema, force=force_refresh,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout)
cache_timeout=database.table_cache_timeout) or []
else:
table_names = database.all_table_names_in_database(
tables = database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60)
view_names = database.all_view_names_in_database(
views = database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60)
table_names = security_manager.accessible_by_user(database, table_names, schema)
view_names = security_manager.accessible_by_user(database, view_names, schema)
tables = security_manager.get_datasources_accessible_by_user(
database, tables, schema)
views = security_manager.get_datasources_accessible_by_user(
database, views, schema)
def get_datasource_label(ds_name: utils.DatasourceName) -> str:
return ds_name.table if schema else f'{ds_name.schema}.{ds_name.table}'
if substr:
table_names = [tn for tn in table_names if substr in tn]
view_names = [vn for vn in view_names if substr in vn]
tables = [tn for tn in tables if substr in get_datasource_label(tn)]
views = [vn for vn in views if substr in get_datasource_label(vn)]
if not schema and database.default_schemas:
def get_schema(tbl_or_view_name):
return tbl_or_view_name.split('.')[0] if '.' in tbl_or_view_name else None
user_schema = g.user.email.split('@')[0]
valid_schemas = set(database.default_schemas + [user_schema])
table_names = [tn for tn in table_names if get_schema(tn) in valid_schemas]
view_names = [vn for vn in view_names if get_schema(vn) in valid_schemas]
tables = [tn for tn in tables if tn.schema in valid_schemas]
views = [vn for vn in views if vn.schema in valid_schemas]
max_items = config.get('MAX_TABLE_NAMES') or len(table_names)
total_items = len(table_names) + len(view_names)
max_tables = len(table_names)
max_views = len(view_names)
max_items = config.get('MAX_TABLE_NAMES') or len(tables)
total_items = len(tables) + len(views)
max_tables = len(tables)
max_views = len(views)
if total_items and substr:
max_tables = max_items * len(table_names) // total_items
max_views = max_items * len(view_names) // total_items
max_tables = max_items * len(tables) // total_items
max_views = max_items * len(views) // total_items
table_options = [{'value': tn, 'label': tn}
for tn in table_names[:max_tables]]
table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)}
for vn in view_names[:max_views]])
def get_datasource_value(ds_name: utils.DatasourceName) -> Dict[str, str]:
return {'schema': ds_name.schema, 'table': ds_name.table}
table_options = [{'value': get_datasource_value(tn),
'label': get_datasource_label(tn)}
for tn in tables[:max_tables]]
table_options.extend([{'value': get_datasource_value(vn),
'label': f'[view] {get_datasource_label(vn)}'}
for vn in views[:max_views]])
payload = {
'tableLength': len(table_names) + len(view_names),
'tableLength': len(tables) + len(views),
'options': table_options,
}
return json_success(json.dumps(payload))

View File

@ -464,3 +464,22 @@ class DbEngineSpecsTestCase(SupersetTestCase):
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa
self.assertEqual(query, query_expected)
def test_get_table_names(self):
    """Schema prefixes are stripped from table names only when the engine
    spec sets ``try_remove_schema_from_table_name``.
    """
    # Inspector whose backend reports one schema-qualified and one bare
    # table name, plus a foreign table (used by the Postgres spec only).
    inspector = mock.Mock()
    inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2'])
    inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])

    # BaseEngineSpec has try_remove_schema_from_table_name == True, so the
    # leading 'schema.' prefix must be removed from the table name.
    base_result_expected = ['table', 'table_2']
    base_result = db_engine_specs.BaseEngineSpec.get_table_names(
        schema='schema', inspector=inspector)
    self.assertListEqual(base_result_expected, base_result)

    # PostgresEngineSpec sets try_remove_schema_from_table_name == False, so
    # names are kept verbatim (and foreign tables are included).
    pg_result_expected = ['schema.table', 'table_2', 'table_3']
    pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
        schema='schema', inspector=inspector)
    self.assertListEqual(pg_result_expected, pg_result)