Add support for period character in table names (#7453)
* Move schema name handling in table names from frontend to backend * Rename all_schema_names to get_all_schema_names * Fix js errors * Fix additional js linting errors * Refactor datasource getters and fix linting errors * Update js unit tests * Add python unit test for get_table_names method * Add python unit test for get_table_names method * Fix js linting error
This commit is contained in:
parent
47ba2ad394
commit
f7d3413a50
|
|
@ -208,19 +208,20 @@ describe('TableSelector', () => {
|
|||
|
||||
it('test 1', () => {
|
||||
wrapper.instance().changeTable({
|
||||
value: 'birth_names',
|
||||
value: { schema: 'main', table: 'birth_names' },
|
||||
label: 'birth_names',
|
||||
});
|
||||
expect(wrapper.state().tableName).toBe('birth_names');
|
||||
});
|
||||
|
||||
it('test 2', () => {
|
||||
it('should call onTableChange with schema from table object', () => {
|
||||
wrapper.setProps({ schema: null });
|
||||
wrapper.instance().changeTable({
|
||||
value: 'main.my_table',
|
||||
label: 'my_table',
|
||||
value: { schema: 'other_schema', table: 'my_table' },
|
||||
label: 'other_schema.my_table',
|
||||
});
|
||||
expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table');
|
||||
expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main');
|
||||
expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema');
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -329,15 +329,15 @@ export const databases = {
|
|||
export const tables = {
|
||||
options: [
|
||||
{
|
||||
value: 'birth_names',
|
||||
value: { schema: 'main', table: 'birth_names' },
|
||||
label: 'birth_names',
|
||||
},
|
||||
{
|
||||
value: 'energy_usage',
|
||||
value: { schema: 'main', table: 'energy_usage' },
|
||||
label: 'energy_usage',
|
||||
},
|
||||
{
|
||||
value: 'wb_health_population',
|
||||
value: { schema: 'main', table: 'wb_health_population' },
|
||||
label: 'wb_health_population',
|
||||
},
|
||||
],
|
||||
|
|
|
|||
|
|
@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent {
|
|||
this.setState({ tableName: '' });
|
||||
return;
|
||||
}
|
||||
const namePieces = tableOpt.value.split('.');
|
||||
let tableName = namePieces[0];
|
||||
let schemaName = this.props.queryEditor.schema;
|
||||
if (namePieces.length === 1) {
|
||||
this.setState({ tableName });
|
||||
} else {
|
||||
schemaName = namePieces[0];
|
||||
tableName = namePieces[1];
|
||||
this.setState({ tableName });
|
||||
this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
|
||||
}
|
||||
const schemaName = tableOpt.value.schema;
|
||||
const tableName = tableOpt.value.table;
|
||||
this.setState({ tableName });
|
||||
this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
|
||||
this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent {
|
|||
this.setState({ tableName: '' });
|
||||
return;
|
||||
}
|
||||
const namePieces = tableOpt.value.split('.');
|
||||
let tableName = namePieces[0];
|
||||
let schemaName = this.props.schema;
|
||||
if (namePieces.length > 1) {
|
||||
schemaName = namePieces[0];
|
||||
tableName = namePieces[1];
|
||||
}
|
||||
const schemaName = tableOpt.value.schema;
|
||||
const tableName = tableOpt.value.table;
|
||||
if (this.props.tableNameSticky) {
|
||||
this.setState({ tableName }, this.onChange);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -288,9 +288,9 @@ def update_datasources_cache():
|
|||
if database.allow_multi_schema_metadata_fetch:
|
||||
print('Fetching {} datasources ...'.format(database.name))
|
||||
try:
|
||||
database.all_table_names_in_database(
|
||||
database.get_all_table_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60)
|
||||
database.all_view_names_in_database(
|
||||
database.get_all_view_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60)
|
||||
except Exception as e:
|
||||
print('{}'.format(str(e)))
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ class BaseEngineSpec(object):
|
|||
force_column_alias_quotes = False
|
||||
arraysize = 0
|
||||
max_column_name_length = 0
|
||||
try_remove_schema_from_table_name = True
|
||||
|
||||
@classmethod
|
||||
def get_time_expr(cls, expr, pdf, time_grain, grain):
|
||||
|
|
@ -279,33 +280,32 @@ class BaseEngineSpec(object):
|
|||
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
||||
@classmethod
|
||||
def fetch_result_sets(cls, db, datasource_type):
|
||||
"""Returns a list of tables [schema1.table1, schema2.table2, ...]
|
||||
def get_all_datasource_names(cls, db, datasource_type: str) \
|
||||
-> List[utils.DatasourceName]:
|
||||
"""Returns a list of all tables or views in database.
|
||||
|
||||
Datasource_type can be 'table' or 'view'.
|
||||
Empty schema corresponds to the list of full names of the all
|
||||
tables or views: <schema>.<result_set_name>.
|
||||
:param db: Database instance
|
||||
:param datasource_type: Datasource_type can be 'table' or 'view'
|
||||
:return: List of all datasources in database or schema
|
||||
"""
|
||||
schemas = db.all_schema_names(cache=db.schema_cache_enabled,
|
||||
cache_timeout=db.schema_cache_timeout,
|
||||
force=True)
|
||||
all_result_sets = []
|
||||
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
|
||||
cache_timeout=db.schema_cache_timeout,
|
||||
force=True)
|
||||
all_datasources: List[utils.DatasourceName] = []
|
||||
for schema in schemas:
|
||||
if datasource_type == 'table':
|
||||
all_datasource_names = db.all_table_names_in_schema(
|
||||
all_datasources += db.get_all_table_names_in_schema(
|
||||
schema=schema, force=True,
|
||||
cache=db.table_cache_enabled,
|
||||
cache_timeout=db.table_cache_timeout)
|
||||
elif datasource_type == 'view':
|
||||
all_datasource_names = db.all_view_names_in_schema(
|
||||
all_datasources += db.get_all_view_names_in_schema(
|
||||
schema=schema, force=True,
|
||||
cache=db.table_cache_enabled,
|
||||
cache_timeout=db.table_cache_timeout)
|
||||
else:
|
||||
raise Exception(f'Unsupported datasource_type: {datasource_type}')
|
||||
all_result_sets += [
|
||||
'{}.{}'.format(schema, t) for t in all_datasource_names]
|
||||
return all_result_sets
|
||||
return all_datasources
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor, query, session):
|
||||
|
|
@ -352,11 +352,17 @@ class BaseEngineSpec(object):
|
|||
|
||||
@classmethod
|
||||
def get_table_names(cls, inspector, schema):
|
||||
return sorted(inspector.get_table_names(schema))
|
||||
tables = inspector.get_table_names(schema)
|
||||
if schema and cls.try_remove_schema_from_table_name:
|
||||
tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
|
||||
return sorted(tables)
|
||||
|
||||
@classmethod
|
||||
def get_view_names(cls, inspector, schema):
|
||||
return sorted(inspector.get_view_names(schema))
|
||||
views = inspector.get_view_names(schema)
|
||||
if schema and cls.try_remove_schema_from_table_name:
|
||||
views = [re.sub(f'^{schema}\\.', '', view) for view in views]
|
||||
return sorted(views)
|
||||
|
||||
@classmethod
|
||||
def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
|
||||
|
|
@ -528,6 +534,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
class PostgresEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'postgresql'
|
||||
max_column_name_length = 63
|
||||
try_remove_schema_from_table_name = False
|
||||
|
||||
@classmethod
|
||||
def get_table_names(cls, inspector, schema):
|
||||
|
|
@ -685,29 +692,25 @@ class SqliteEngineSpec(BaseEngineSpec):
|
|||
return "datetime({col}, 'unixepoch')"
|
||||
|
||||
@classmethod
|
||||
def fetch_result_sets(cls, db, datasource_type):
|
||||
schemas = db.all_schema_names(cache=db.schema_cache_enabled,
|
||||
cache_timeout=db.schema_cache_timeout,
|
||||
force=True)
|
||||
all_result_sets = []
|
||||
def get_all_datasource_names(cls, db, datasource_type: str) \
|
||||
-> List[utils.DatasourceName]:
|
||||
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
|
||||
cache_timeout=db.schema_cache_timeout,
|
||||
force=True)
|
||||
schema = schemas[0]
|
||||
if datasource_type == 'table':
|
||||
all_datasource_names = db.all_table_names_in_schema(
|
||||
return db.get_all_table_names_in_schema(
|
||||
schema=schema, force=True,
|
||||
cache=db.table_cache_enabled,
|
||||
cache_timeout=db.table_cache_timeout)
|
||||
elif datasource_type == 'view':
|
||||
all_datasource_names = db.all_view_names_in_schema(
|
||||
return db.get_all_view_names_in_schema(
|
||||
schema=schema, force=True,
|
||||
cache=db.table_cache_enabled,
|
||||
cache_timeout=db.table_cache_timeout)
|
||||
else:
|
||||
raise Exception(f'Unsupported datasource_type: {datasource_type}')
|
||||
|
||||
all_result_sets += [
|
||||
'{}.{}'.format(schema, t) for t in all_datasource_names]
|
||||
return all_result_sets
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type, dttm):
|
||||
iso = dttm.isoformat().replace('T', ' ')
|
||||
|
|
@ -1107,24 +1110,19 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return 'from_unixtime({col})'
|
||||
|
||||
@classmethod
|
||||
def fetch_result_sets(cls, db, datasource_type):
|
||||
"""Returns a list of tables [schema1.table1, schema2.table2, ...]
|
||||
|
||||
Datasource_type can be 'table' or 'view'.
|
||||
Empty schema corresponds to the list of full names of the all
|
||||
tables or views: <schema>.<result_set_name>.
|
||||
"""
|
||||
result_set_df = db.get_df(
|
||||
def get_all_datasource_names(cls, db, datasource_type: str) \
|
||||
-> List[utils.DatasourceName]:
|
||||
datasource_df = db.get_df(
|
||||
"""SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
|
||||
ORDER BY concat(table_schema, '.', table_name)""".format(
|
||||
datasource_type.upper(),
|
||||
),
|
||||
None)
|
||||
result_sets = []
|
||||
for unused, row in result_set_df.iterrows():
|
||||
result_sets.append('{}.{}'.format(
|
||||
row['table_schema'], row['table_name']))
|
||||
return result_sets
|
||||
datasource_names: List[utils.DatasourceName] = []
|
||||
for unused, row in datasource_df.iterrows():
|
||||
datasource_names.append(utils.DatasourceName(
|
||||
schema=row['table_schema'], table=row['table_name']))
|
||||
return datasource_names
|
||||
|
||||
@classmethod
|
||||
def extra_table_metadata(cls, database, table_name, schema_name):
|
||||
|
|
@ -1385,9 +1383,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
hive.Cursor.fetch_logs = patched_hive.fetch_logs
|
||||
|
||||
@classmethod
|
||||
def fetch_result_sets(cls, db, datasource_type):
|
||||
return BaseEngineSpec.fetch_result_sets(
|
||||
db, datasource_type)
|
||||
def get_all_datasource_names(cls, db, datasource_type: str) \
|
||||
-> List[utils.DatasourceName]:
|
||||
return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit):
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import functools
|
|||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from flask import escape, g, Markup, request
|
||||
from flask_appbuilder import Model
|
||||
|
|
@ -65,6 +66,7 @@ metadata = Model.metadata # pylint: disable=no-member
|
|||
|
||||
PASSWORD_MASK = 'X' * 10
|
||||
|
||||
|
||||
def set_related_perm(mapper, connection, target): # noqa
|
||||
src_class = target.cls_model
|
||||
id_ = target.datasource_id
|
||||
|
|
@ -184,7 +186,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
description=self.description,
|
||||
cache_timeout=self.cache_timeout)
|
||||
|
||||
@datasource.getter
|
||||
@datasource.getter # type: ignore
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
return (
|
||||
|
|
@ -210,7 +212,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
|
|||
datasource = self.datasource
|
||||
return datasource.url if datasource else None
|
||||
|
||||
@property
|
||||
@property # type: ignore
|
||||
@utils.memoized
|
||||
def viz(self):
|
||||
d = json.loads(self.params)
|
||||
|
|
@ -930,100 +932,87 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: 'db:{}:schema:None:table_list',
|
||||
attribute_in_key='id')
|
||||
def all_table_names_in_database(self, cache=False,
|
||||
cache_timeout=None, force=False):
|
||||
def get_all_table_names_in_database(self, cache: bool = False,
|
||||
cache_timeout: bool = None,
|
||||
force=False) -> List[utils.DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments."""
|
||||
if not self.allow_multi_schema_metadata_fetch:
|
||||
return []
|
||||
return self.db_engine_spec.fetch_result_sets(self, 'table')
|
||||
return self.db_engine_spec.get_all_datasource_names(self, 'table')
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: 'db:{}:schema:None:view_list',
|
||||
attribute_in_key='id')
|
||||
def all_view_names_in_database(self, cache=False,
|
||||
cache_timeout=None, force=False):
|
||||
def get_all_view_names_in_database(self, cache: bool = False,
|
||||
cache_timeout: bool = None,
|
||||
force: bool = False) -> List[utils.DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments."""
|
||||
if not self.allow_multi_schema_metadata_fetch:
|
||||
return []
|
||||
return self.db_engine_spec.fetch_result_sets(self, 'view')
|
||||
return self.db_engine_spec.get_all_datasource_names(self, 'view')
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:table_list'.format(
|
||||
kwargs.get('schema')),
|
||||
attribute_in_key='id')
|
||||
def all_table_names_in_schema(self, schema, cache=False,
|
||||
cache_timeout=None, force=False):
|
||||
def get_all_table_names_in_schema(self, schema: str, cache: bool = False,
|
||||
cache_timeout: int = None, force: bool = False):
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param schema: schema name
|
||||
:type schema: str
|
||||
:param cache: whether cache is enabled for the function
|
||||
:type cache: bool
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
:type cache_timeout: int
|
||||
:param force: whether to force refresh the cache
|
||||
:type force: bool
|
||||
:return: table list
|
||||
:rtype: list
|
||||
:return: list of tables
|
||||
"""
|
||||
tables = []
|
||||
try:
|
||||
tables = self.db_engine_spec.get_table_names(
|
||||
inspector=self.inspector, schema=schema)
|
||||
return [utils.DatasourceName(table=table, schema=schema) for table in tables]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return tables
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:view_list'.format(
|
||||
kwargs.get('schema')),
|
||||
attribute_in_key='id')
|
||||
def all_view_names_in_schema(self, schema, cache=False,
|
||||
cache_timeout=None, force=False):
|
||||
def get_all_view_names_in_schema(self, schema: str, cache: bool = False,
|
||||
cache_timeout: int = None, force: bool = False):
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param schema: schema name
|
||||
:type schema: str
|
||||
:param cache: whether cache is enabled for the function
|
||||
:type cache: bool
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
:type cache_timeout: int
|
||||
:param force: whether to force refresh the cache
|
||||
:type force: bool
|
||||
:return: view list
|
||||
:rtype: list
|
||||
:return: list of views
|
||||
"""
|
||||
views = []
|
||||
try:
|
||||
views = self.db_engine_spec.get_view_names(
|
||||
inspector=self.inspector, schema=schema)
|
||||
return [utils.DatasourceName(table=view, schema=schema) for view in views]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return views
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: 'db:{}:schema_list',
|
||||
attribute_in_key='id')
|
||||
def all_schema_names(self, cache=False, cache_timeout=None, force=False):
|
||||
def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None,
|
||||
force: bool = False) -> List[str]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param cache: whether cache is enabled for the function
|
||||
:type cache: bool
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
:type cache_timeout: int
|
||||
:param force: whether to force refresh the cache
|
||||
:type force: bool
|
||||
:return: schema list
|
||||
:rtype: list
|
||||
"""
|
||||
return self.db_engine_spec.get_schema_names(self.inspector)
|
||||
|
||||
|
|
@ -1232,7 +1221,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
|
|||
def datasource(self):
|
||||
return self.get_datasource
|
||||
|
||||
@datasource.getter
|
||||
@datasource.getter # type: ignore
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
# pylint: disable=no-member
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# pylint: disable=C,R,W
|
||||
"""A set of constants and methods to manage permissions and security"""
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from flask import g
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
|
|
@ -26,6 +27,7 @@ from sqlalchemy import or_
|
|||
from superset import sql_parse
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.utils.core import DatasourceName
|
||||
|
||||
|
||||
class SupersetSecurityManager(SecurityManager):
|
||||
|
|
@ -240,7 +242,9 @@ class SupersetSecurityManager(SecurityManager):
|
|||
subset.add(t.schema)
|
||||
return sorted(list(subset))
|
||||
|
||||
def accessible_by_user(self, database, datasource_names, schema=None):
|
||||
def get_datasources_accessible_by_user(
|
||||
self, database, datasource_names: List[DatasourceName],
|
||||
schema: str = None) -> List[DatasourceName]:
|
||||
from superset import db
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return datasource_names
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import signal
|
|||
import smtplib
|
||||
import sys
|
||||
from time import struct_time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, NamedTuple, Optional, Tuple
|
||||
from urllib.parse import unquote_plus
|
||||
import uuid
|
||||
import zlib
|
||||
|
|
@ -1100,3 +1100,8 @@ def MediumText() -> Variant:
|
|||
|
||||
def shortid() -> str:
|
||||
return '{}'.format(uuid.uuid4())[-12:]
|
||||
|
||||
|
||||
class DatasourceName(NamedTuple):
|
||||
table: str
|
||||
schema: str
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import List # noqa: F401
|
||||
from typing import Dict, List # noqa: F401
|
||||
from urllib import parse
|
||||
|
||||
from flask import (
|
||||
|
|
@ -311,7 +311,7 @@ class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
|
|||
db.set_sqlalchemy_uri(db.sqlalchemy_uri)
|
||||
security_manager.add_permission_view_menu('database_access', db.perm)
|
||||
# adding a new database we always want to force refresh schema list
|
||||
for schema in db.all_schema_names():
|
||||
for schema in db.get_all_schema_names():
|
||||
security_manager.add_permission_view_menu(
|
||||
'schema_access', security_manager.get_schema_perm(db, schema))
|
||||
|
||||
|
|
@ -1545,7 +1545,7 @@ class Superset(BaseSupersetView):
|
|||
.first()
|
||||
)
|
||||
if database:
|
||||
schemas = database.all_schema_names(
|
||||
schemas = database.get_all_schema_names(
|
||||
cache=database.schema_cache_enabled,
|
||||
cache_timeout=database.schema_cache_timeout,
|
||||
force=force_refresh)
|
||||
|
|
@ -1570,50 +1570,57 @@ class Superset(BaseSupersetView):
|
|||
database = db.session.query(models.Database).filter_by(id=db_id).one()
|
||||
|
||||
if schema:
|
||||
table_names = database.all_table_names_in_schema(
|
||||
tables = database.get_all_table_names_in_schema(
|
||||
schema=schema, force=force_refresh,
|
||||
cache=database.table_cache_enabled,
|
||||
cache_timeout=database.table_cache_timeout)
|
||||
view_names = database.all_view_names_in_schema(
|
||||
cache_timeout=database.table_cache_timeout) or []
|
||||
views = database.get_all_view_names_in_schema(
|
||||
schema=schema, force=force_refresh,
|
||||
cache=database.table_cache_enabled,
|
||||
cache_timeout=database.table_cache_timeout)
|
||||
cache_timeout=database.table_cache_timeout) or []
|
||||
else:
|
||||
table_names = database.all_table_names_in_database(
|
||||
tables = database.get_all_table_names_in_database(
|
||||
cache=True, force=False, cache_timeout=24 * 60 * 60)
|
||||
view_names = database.all_view_names_in_database(
|
||||
views = database.get_all_view_names_in_database(
|
||||
cache=True, force=False, cache_timeout=24 * 60 * 60)
|
||||
table_names = security_manager.accessible_by_user(database, table_names, schema)
|
||||
view_names = security_manager.accessible_by_user(database, view_names, schema)
|
||||
tables = security_manager.get_datasources_accessible_by_user(
|
||||
database, tables, schema)
|
||||
views = security_manager.get_datasources_accessible_by_user(
|
||||
database, views, schema)
|
||||
|
||||
def get_datasource_label(ds_name: utils.DatasourceName) -> str:
|
||||
return ds_name.table if schema else f'{ds_name.schema}.{ds_name.table}'
|
||||
|
||||
if substr:
|
||||
table_names = [tn for tn in table_names if substr in tn]
|
||||
view_names = [vn for vn in view_names if substr in vn]
|
||||
tables = [tn for tn in tables if substr in get_datasource_label(tn)]
|
||||
views = [vn for vn in views if substr in get_datasource_label(vn)]
|
||||
|
||||
if not schema and database.default_schemas:
|
||||
def get_schema(tbl_or_view_name):
|
||||
return tbl_or_view_name.split('.')[0] if '.' in tbl_or_view_name else None
|
||||
|
||||
user_schema = g.user.email.split('@')[0]
|
||||
valid_schemas = set(database.default_schemas + [user_schema])
|
||||
|
||||
table_names = [tn for tn in table_names if get_schema(tn) in valid_schemas]
|
||||
view_names = [vn for vn in view_names if get_schema(vn) in valid_schemas]
|
||||
tables = [tn for tn in tables if tn.schema in valid_schemas]
|
||||
views = [vn for vn in views if vn.schema in valid_schemas]
|
||||
|
||||
max_items = config.get('MAX_TABLE_NAMES') or len(table_names)
|
||||
total_items = len(table_names) + len(view_names)
|
||||
max_tables = len(table_names)
|
||||
max_views = len(view_names)
|
||||
max_items = config.get('MAX_TABLE_NAMES') or len(tables)
|
||||
total_items = len(tables) + len(views)
|
||||
max_tables = len(tables)
|
||||
max_views = len(views)
|
||||
if total_items and substr:
|
||||
max_tables = max_items * len(table_names) // total_items
|
||||
max_views = max_items * len(view_names) // total_items
|
||||
max_tables = max_items * len(tables) // total_items
|
||||
max_views = max_items * len(views) // total_items
|
||||
|
||||
table_options = [{'value': tn, 'label': tn}
|
||||
for tn in table_names[:max_tables]]
|
||||
table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)}
|
||||
for vn in view_names[:max_views]])
|
||||
def get_datasource_value(ds_name: utils.DatasourceName) -> Dict[str, str]:
|
||||
return {'schema': ds_name.schema, 'table': ds_name.table}
|
||||
|
||||
table_options = [{'value': get_datasource_value(tn),
|
||||
'label': get_datasource_label(tn)}
|
||||
for tn in tables[:max_tables]]
|
||||
table_options.extend([{'value': get_datasource_value(vn),
|
||||
'label': f'[view] {get_datasource_label(vn)}'}
|
||||
for vn in views[:max_views]])
|
||||
payload = {
|
||||
'tableLength': len(table_names) + len(view_names),
|
||||
'tableLength': len(tables) + len(views),
|
||||
'options': table_options,
|
||||
}
|
||||
return json_success(json.dumps(payload))
|
||||
|
|
|
|||
|
|
@ -464,3 +464,22 @@ class DbEngineSpecsTestCase(SupersetTestCase):
|
|||
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
|
||||
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa
|
||||
self.assertEqual(query, query_expected)
|
||||
|
||||
def test_get_table_names(self):
|
||||
inspector = mock.Mock()
|
||||
inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2'])
|
||||
inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])
|
||||
|
||||
""" Make sure base engine spec removes schema name from table name
|
||||
ie. when try_remove_schema_from_table_name == True. """
|
||||
base_result_expected = ['table', 'table_2']
|
||||
base_result = db_engine_specs.BaseEngineSpec.get_table_names(
|
||||
schema='schema', inspector=inspector)
|
||||
self.assertListEqual(base_result_expected, base_result)
|
||||
|
||||
""" Make sure postgres doesn't try to remove schema name from table name
|
||||
ie. when try_remove_schema_from_table_name == False. """
|
||||
pg_result_expected = ['schema.table', 'table_2', 'table_3']
|
||||
pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
|
||||
schema='schema', inspector=inspector)
|
||||
self.assertListEqual(pg_result_expected, pg_result)
|
||||
|
|
|
|||
Loading…
Reference in New Issue