fix: handle empty catalog when DB supports them (#29840)
This commit is contained in:
parent
9f5eb899e8
commit
39209c2b40
|
|
@ -52,7 +52,7 @@ GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
|
||||||
def fetch_files_github_api(url: str): # type: ignore
|
def fetch_files_github_api(url: str): # type: ignore
|
||||||
"""Fetches data using GitHub API."""
|
"""Fetches data using GitHub API."""
|
||||||
req = Request(url)
|
req = Request(url)
|
||||||
req.add_header("Authorization", f"token {GITHUB_TOKEN}")
|
req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
|
||||||
req.add_header("Accept", "application/vnd.github.v3+json")
|
req.add_header("Accept", "application/vnd.github.v3+json")
|
||||||
|
|
||||||
print(f"Fetching from {url}")
|
print(f"Fetching from {url}")
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ class Datasource(Schema):
|
||||||
datasource_name = fields.String(
|
datasource_name = fields.String(
|
||||||
metadata={"description": datasource_name_description},
|
metadata={"description": datasource_name_description},
|
||||||
)
|
)
|
||||||
|
catalog = fields.String(
|
||||||
|
allow_none=True,
|
||||||
|
metadata={"description": "Datasource catalog"},
|
||||||
|
)
|
||||||
schema = fields.String(
|
schema = fields.String(
|
||||||
metadata={"description": "Datasource schema"},
|
metadata={"description": "Datasource schema"},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -54,23 +54,28 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: list[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
database_id = self._properties["database"]
|
database_id = self._properties["database"]
|
||||||
schema = self._properties.get("schema")
|
|
||||||
catalog = self._properties.get("catalog")
|
catalog = self._properties.get("catalog")
|
||||||
|
schema = self._properties.get("schema")
|
||||||
|
table_name = self._properties["table_name"]
|
||||||
sql = self._properties.get("sql")
|
sql = self._properties.get("sql")
|
||||||
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
|
|
||||||
table = Table(self._properties["table_name"], schema, catalog)
|
|
||||||
|
|
||||||
# Validate uniqueness
|
|
||||||
if not DatasetDAO.validate_uniqueness(database_id, table):
|
|
||||||
exceptions.append(DatasetExistsValidationError(table))
|
|
||||||
|
|
||||||
# Validate/Populate database
|
# Validate/Populate database
|
||||||
database = DatasetDAO.get_database_by_id(database_id)
|
database = DatasetDAO.get_database_by_id(database_id)
|
||||||
if not database:
|
if not database:
|
||||||
exceptions.append(DatabaseNotFoundValidationError())
|
exceptions.append(DatabaseNotFoundValidationError())
|
||||||
self._properties["database"] = database
|
self._properties["database"] = database
|
||||||
|
|
||||||
|
# Validate uniqueness
|
||||||
|
if database:
|
||||||
|
if not catalog:
|
||||||
|
catalog = self._properties["catalog"] = database.get_default_catalog()
|
||||||
|
|
||||||
|
table = Table(table_name, schema, catalog)
|
||||||
|
|
||||||
|
if not DatasetDAO.validate_uniqueness(database, table):
|
||||||
|
exceptions.append(DatasetExistsValidationError(table))
|
||||||
|
|
||||||
# Validate table exists on dataset if sql is not provided
|
# Validate table exists on dataset if sql is not provided
|
||||||
# This should be validated when the dataset is physical
|
# This should be validated when the dataset is physical
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
|
|
@ -166,7 +166,7 @@ def import_dataset(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
table_exists = dataset.database.has_table(
|
table_exists = dataset.database.has_table(
|
||||||
Table(dataset.table_name, dataset.schema),
|
Table(dataset.table_name, dataset.schema, dataset.catalog),
|
||||||
)
|
)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
# MySQL doesn't play nice with GSheets table names
|
# MySQL doesn't play nice with GSheets table names
|
||||||
|
|
|
||||||
|
|
@ -79,10 +79,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: list[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
|
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
self._model = DatasetDAO.find_by_id(self._model_id)
|
self._model = DatasetDAO.find_by_id(self._model_id)
|
||||||
if not self._model:
|
if not self._model:
|
||||||
raise DatasetNotFoundError()
|
raise DatasetNotFoundError()
|
||||||
|
|
||||||
# Check ownership
|
# Check ownership
|
||||||
try:
|
try:
|
||||||
security_manager.raise_for_ownership(self._model)
|
security_manager.raise_for_ownership(self._model)
|
||||||
|
|
@ -91,22 +93,30 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
||||||
|
|
||||||
database_id = self._properties.get("database")
|
database_id = self._properties.get("database")
|
||||||
|
|
||||||
|
catalog = self._properties.get("catalog")
|
||||||
|
if not catalog:
|
||||||
|
catalog = self._properties["catalog"] = (
|
||||||
|
self._model.database.get_default_catalog()
|
||||||
|
)
|
||||||
|
|
||||||
table = Table(
|
table = Table(
|
||||||
self._properties.get("table_name"), # type: ignore
|
self._properties.get("table_name"), # type: ignore
|
||||||
self._properties.get("schema"),
|
self._properties.get("schema"),
|
||||||
self._properties.get("catalog"),
|
catalog,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate uniqueness
|
# Validate uniqueness
|
||||||
if not DatasetDAO.validate_update_uniqueness(
|
if not DatasetDAO.validate_update_uniqueness(
|
||||||
self._model.database_id,
|
self._model.database,
|
||||||
table,
|
table,
|
||||||
self._model_id,
|
self._model_id,
|
||||||
):
|
):
|
||||||
exceptions.append(DatasetExistsValidationError(table))
|
exceptions.append(DatasetExistsValidationError(table))
|
||||||
|
|
||||||
# Validate/Populate database not allowed to change
|
# Validate/Populate database not allowed to change
|
||||||
if database_id and database_id != self._model:
|
if database_id and database_id != self._model:
|
||||||
exceptions.append(DatabaseChangeValidationError())
|
exceptions.append(DatabaseChangeValidationError())
|
||||||
|
|
||||||
# Validate/Populate owner
|
# Validate/Populate owner
|
||||||
try:
|
try:
|
||||||
owners = self.compute_owners(
|
owners = self.compute_owners(
|
||||||
|
|
@ -116,6 +126,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
||||||
self._properties["owners"] = owners
|
self._properties["owners"] = owners
|
||||||
except ValidationError as ex:
|
except ValidationError as ex:
|
||||||
exceptions.append(ex)
|
exceptions.append(ex)
|
||||||
|
|
||||||
# Validate columns
|
# Validate columns
|
||||||
if columns := self._properties.get("columns"):
|
if columns := self._properties.get("columns"):
|
||||||
self._validate_columns(columns, exceptions)
|
self._validate_columns(columns, exceptions)
|
||||||
|
|
|
||||||
|
|
@ -461,9 +461,11 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_columns = [
|
_columns = [
|
||||||
utils.get_column_name(column_)
|
(
|
||||||
if utils.is_adhoc_column(column_)
|
utils.get_column_name(column_)
|
||||||
else column_
|
if utils.is_adhoc_column(column_)
|
||||||
|
else column_
|
||||||
|
)
|
||||||
for column_param in COLUMN_FORM_DATA_PARAMS
|
for column_param in COLUMN_FORM_DATA_PARAMS
|
||||||
for column_ in utils.as_list(form_data.get(column_param) or [])
|
for column_ in utils.as_list(form_data.get(column_param) or [])
|
||||||
]
|
]
|
||||||
|
|
@ -1963,7 +1965,7 @@ class SqlaTable(
|
||||||
if self.has_extra_cache_key_calls(query_obj):
|
if self.has_extra_cache_key_calls(query_obj):
|
||||||
sqla_query = self.get_sqla_query(**query_obj)
|
sqla_query = self.get_sqla_query(**query_obj)
|
||||||
extra_cache_keys += sqla_query.extra_cache_keys
|
extra_cache_keys += sqla_query.extra_cache_keys
|
||||||
return extra_cache_keys
|
return list(set(extra_cache_keys))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def quote_identifier(self) -> Callable[[str], str]:
|
def quote_identifier(self) -> Callable[[str], str]:
|
||||||
|
|
|
||||||
|
|
@ -84,15 +84,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_uniqueness(
|
def validate_uniqueness(
|
||||||
database_id: int,
|
database: Database,
|
||||||
table: Table,
|
table: Table,
|
||||||
dataset_id: int | None = None,
|
dataset_id: int | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
# The catalog might not be set even if the database supports catalogs, in case
|
||||||
|
# multi-catalog is disabled.
|
||||||
|
catalog = table.catalog or database.get_default_catalog()
|
||||||
|
|
||||||
dataset_query = db.session.query(SqlaTable).filter(
|
dataset_query = db.session.query(SqlaTable).filter(
|
||||||
SqlaTable.table_name == table.table,
|
SqlaTable.table_name == table.table,
|
||||||
SqlaTable.schema == table.schema,
|
SqlaTable.schema == table.schema,
|
||||||
SqlaTable.catalog == table.catalog,
|
SqlaTable.catalog == catalog,
|
||||||
SqlaTable.database_id == database_id,
|
SqlaTable.database_id == database.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
|
|
@ -103,15 +107,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_update_uniqueness(
|
def validate_update_uniqueness(
|
||||||
database_id: int,
|
database: Database,
|
||||||
table: Table,
|
table: Table,
|
||||||
dataset_id: int,
|
dataset_id: int,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
# The catalog might not be set even if the database supports catalogs, in case
|
||||||
|
# multi-catalog is disabled.
|
||||||
|
catalog = table.catalog or database.get_default_catalog()
|
||||||
|
|
||||||
dataset_query = db.session.query(SqlaTable).filter(
|
dataset_query = db.session.query(SqlaTable).filter(
|
||||||
SqlaTable.table_name == table.table,
|
SqlaTable.table_name == table.table,
|
||||||
SqlaTable.database_id == database_id,
|
SqlaTable.database_id == database.id,
|
||||||
SqlaTable.schema == table.schema,
|
SqlaTable.schema == table.schema,
|
||||||
SqlaTable.catalog == table.catalog,
|
SqlaTable.catalog == catalog,
|
||||||
SqlaTable.id != dataset_id,
|
SqlaTable.id != dataset_id,
|
||||||
)
|
)
|
||||||
return not db.session.query(dataset_query.exists()).scalar()
|
return not db.session.query(dataset_query.exists()).scalar()
|
||||||
|
|
|
||||||
|
|
@ -1159,7 +1159,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
self.incr_stats("init", self.select_star.__name__)
|
self.incr_stats("init", self.select_star.__name__)
|
||||||
try:
|
try:
|
||||||
result = database.select_star(
|
result = database.select_star(
|
||||||
Table(table_name, schema_name),
|
Table(table_name, schema_name, database.get_default_catalog()),
|
||||||
latest_partition=True,
|
latest_partition=True,
|
||||||
)
|
)
|
||||||
except NoSuchTableError:
|
except NoSuchTableError:
|
||||||
|
|
|
||||||
|
|
@ -565,7 +565,7 @@ class NoOpTemplateProcessor(BaseTemplateProcessor):
|
||||||
"""
|
"""
|
||||||
Makes processing a template a noop
|
Makes processing a template a noop
|
||||||
"""
|
"""
|
||||||
return sql
|
return str(sql)
|
||||||
|
|
||||||
|
|
||||||
class PrestoTemplateProcessor(JinjaTemplateProcessor):
|
class PrestoTemplateProcessor(JinjaTemplateProcessor):
|
||||||
|
|
|
||||||
|
|
@ -491,7 +491,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
g.user.id,
|
g.user.id,
|
||||||
self.db_engine_spec,
|
self.db_engine_spec,
|
||||||
)
|
)
|
||||||
if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config
|
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
# If using MySQL or Presto for example, will set url.username
|
# If using MySQL or Presto for example, will set url.username
|
||||||
|
|
|
||||||
|
|
@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
|
||||||
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
|
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
|
||||||
if ds:
|
if ds:
|
||||||
target.perm = ds.perm
|
target.perm = ds.perm
|
||||||
|
target.catalog_perm = ds.catalog_perm
|
||||||
target.schema_perm = ds.schema_perm
|
target.schema_perm = ds.schema_perm
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -774,6 +774,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
default_catalog = database.get_default_catalog()
|
||||||
|
catalog = catalog or default_catalog
|
||||||
|
|
||||||
if hierarchical and (
|
if hierarchical and (
|
||||||
self.can_access_database(database)
|
self.can_access_database(database)
|
||||||
or (catalog and self.can_access_catalog(database, catalog))
|
or (catalog and self.can_access_catalog(database, catalog))
|
||||||
|
|
@ -783,7 +786,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
# schema_access
|
# schema_access
|
||||||
accessible_schemas: set[str] = set()
|
accessible_schemas: set[str] = set()
|
||||||
schema_access = self.user_view_menu_names("schema_access")
|
schema_access = self.user_view_menu_names("schema_access")
|
||||||
default_catalog = database.get_default_catalog()
|
|
||||||
default_schema = database.get_default_schema(default_catalog)
|
default_schema = database.get_default_schema(default_catalog)
|
||||||
|
|
||||||
for perm in schema_access:
|
for perm in schema_access:
|
||||||
|
|
@ -800,7 +802,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
# [database].[catalog].[schema] matches when the catalog is equal to the
|
# [database].[catalog].[schema] matches when the catalog is equal to the
|
||||||
# requested catalog or, when no catalog specified, it's equal to the default
|
# requested catalog or, when no catalog specified, it's equal to the default
|
||||||
# catalog.
|
# catalog.
|
||||||
elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
|
elif len(parts) == 3 and parts[1] == catalog:
|
||||||
accessible_schemas.add(parts[2])
|
accessible_schemas.add(parts[2])
|
||||||
|
|
||||||
# datasource_access
|
# datasource_access
|
||||||
|
|
@ -906,16 +908,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
if self.can_access_database(database):
|
if self.can_access_database(database):
|
||||||
return datasource_names
|
return datasource_names
|
||||||
|
|
||||||
|
catalog = catalog or database.get_default_catalog()
|
||||||
if catalog:
|
if catalog:
|
||||||
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
|
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
|
||||||
if catalog_perm and self.can_access("catalog_access", catalog_perm):
|
if catalog_perm and self.can_access("catalog_access", catalog_perm):
|
||||||
return datasource_names
|
return datasource_names
|
||||||
|
|
||||||
if schema:
|
if schema:
|
||||||
default_catalog = database.get_default_catalog()
|
|
||||||
schema_perm = self.get_schema_perm(
|
schema_perm = self.get_schema_perm(
|
||||||
database.database_name,
|
database.database_name,
|
||||||
catalog or default_catalog,
|
catalog,
|
||||||
schema,
|
schema,
|
||||||
)
|
)
|
||||||
if schema_perm and self.can_access("schema_access", schema_perm):
|
if schema_perm and self.can_access("schema_access", schema_perm):
|
||||||
|
|
@ -2183,6 +2185,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
database = query.database
|
database = query.database
|
||||||
|
|
||||||
database = cast("Database", database)
|
database = cast("Database", database)
|
||||||
|
default_catalog = database.get_default_catalog()
|
||||||
|
|
||||||
if self.can_access_database(database):
|
if self.can_access_database(database):
|
||||||
return
|
return
|
||||||
|
|
@ -2196,19 +2199,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
|
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
|
||||||
# inspector to read it.
|
# inspector to read it.
|
||||||
default_schema = database.get_default_schema_for_query(query)
|
default_schema = database.get_default_schema_for_query(query)
|
||||||
# Determining the default catalog is much easier, because DB engine
|
|
||||||
# specs need explicit support for catalogs.
|
|
||||||
default_catalog = database.get_default_catalog()
|
|
||||||
tables = {
|
tables = {
|
||||||
Table(
|
Table(
|
||||||
table_.table,
|
table_.table,
|
||||||
table_.schema or default_schema,
|
table_.schema or default_schema,
|
||||||
table_.catalog or default_catalog,
|
table_.catalog or query.catalog or default_catalog,
|
||||||
)
|
)
|
||||||
for table_ in extract_tables_from_jinja_sql(query.sql, database)
|
for table_ in extract_tables_from_jinja_sql(query.sql, database)
|
||||||
}
|
}
|
||||||
elif table:
|
elif table:
|
||||||
tables = {table}
|
# Make sure table has the default catalog, if not specified.
|
||||||
|
tables = {
|
||||||
|
Table(table.table, table.schema, table.catalog or default_catalog)
|
||||||
|
}
|
||||||
|
|
||||||
denied = set()
|
denied = set()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -284,6 +284,7 @@ class SqlLabRestApi(BaseSupersetApi):
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"row_count": row_count,
|
"row_count": row_count,
|
||||||
"database": query.database.name,
|
"database": query.database.name,
|
||||||
|
"catalog": query.catalog,
|
||||||
"schema": query.schema,
|
"schema": query.schema,
|
||||||
"sql": query.sql,
|
"sql": query.sql,
|
||||||
"exported_format": "csv",
|
"exported_format": "csv",
|
||||||
|
|
|
||||||
|
|
@ -125,6 +125,8 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
|
||||||
def set_database(self, database: Database) -> None:
|
def set_database(self, database: Database) -> None:
|
||||||
self._validate_db(database)
|
self._validate_db(database)
|
||||||
self.database = database
|
self.database = database
|
||||||
|
if self.catalog is None:
|
||||||
|
self.catalog = database.get_default_catalog()
|
||||||
if self.select_as_cta:
|
if self.select_as_cta:
|
||||||
schema_name = self._get_ctas_target_schema_name(database)
|
schema_name = self._get_ctas_target_schema_name(database)
|
||||||
self.create_table_as_select.target_schema_name = schema_name # type: ignore
|
self.create_table_as_select.target_schema_name = schema_name # type: ignore
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,7 @@ class TableSchemaView(BaseSupersetView):
|
||||||
db.session.query(TableSchema).filter(
|
db.session.query(TableSchema).filter(
|
||||||
TableSchema.tab_state_id == table["queryEditorId"],
|
TableSchema.tab_state_id == table["queryEditorId"],
|
||||||
TableSchema.database_id == table["dbId"],
|
TableSchema.database_id == table["dbId"],
|
||||||
|
TableSchema.catalog == table["catalog"],
|
||||||
TableSchema.schema == table["schema"],
|
TableSchema.schema == table["schema"],
|
||||||
TableSchema.table == table["name"],
|
TableSchema.table == table["name"],
|
||||||
).delete(synchronize_session=False)
|
).delete(synchronize_session=False)
|
||||||
|
|
@ -246,6 +247,7 @@ class TableSchemaView(BaseSupersetView):
|
||||||
table_schema = TableSchema(
|
table_schema = TableSchema(
|
||||||
tab_state_id=table["queryEditorId"],
|
tab_state_id=table["queryEditorId"],
|
||||||
database_id=table["dbId"],
|
database_id=table["dbId"],
|
||||||
|
catalog=table["catalog"],
|
||||||
schema=table["schema"],
|
schema=table["schema"],
|
||||||
table=table["name"],
|
table=table["name"],
|
||||||
description=json.dumps(table),
|
description=json.dumps(table),
|
||||||
|
|
|
||||||
|
|
@ -1563,34 +1563,6 @@ class TestDatabaseApi(SupersetTestCase):
|
||||||
rv = self.client.get(uri)
|
rv = self.client.get(uri)
|
||||||
self.assertEqual(rv.status_code, 404)
|
self.assertEqual(rv.status_code, 404)
|
||||||
|
|
||||||
def test_get_select_star_datasource_access(self):
|
|
||||||
"""
|
|
||||||
Database API: Test get select star with datasource access
|
|
||||||
"""
|
|
||||||
table = SqlaTable(
|
|
||||||
schema="main", table_name="ab_permission", database=get_main_database()
|
|
||||||
)
|
|
||||||
db.session.add(table)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
tmp_table_perm = security_manager.find_permission_view_menu(
|
|
||||||
"datasource_access", table.get_perm()
|
|
||||||
)
|
|
||||||
gamma_role = security_manager.find_role("Gamma")
|
|
||||||
security_manager.add_permission_role(gamma_role, tmp_table_perm)
|
|
||||||
|
|
||||||
self.login(GAMMA_USERNAME)
|
|
||||||
main_db = get_main_database()
|
|
||||||
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
|
|
||||||
rv = self.client.get(uri)
|
|
||||||
self.assertEqual(rv.status_code, 200)
|
|
||||||
|
|
||||||
# rollback changes
|
|
||||||
security_manager.del_permission_role(gamma_role, tmp_table_perm)
|
|
||||||
db.session.delete(table)
|
|
||||||
db.session.delete(main_db)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
def test_get_select_star_not_found_database(self):
|
def test_get_select_star_not_found_database(self):
|
||||||
"""
|
"""
|
||||||
Database API: Test get select star not found database
|
Database API: Test get select star not found database
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,6 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||||
from superset.extensions import db, security_manager
|
from superset.extensions import db, security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
from superset.sql_parse import Table
|
|
||||||
from superset.utils import json
|
from superset.utils import json
|
||||||
from superset.utils.core import backend, get_example_default_schema
|
from superset.utils.core import backend, get_example_default_schema
|
||||||
from superset.utils.database import get_example_database, get_main_database
|
from superset.utils.database import get_example_database, get_main_database
|
||||||
|
|
@ -676,57 +675,6 @@ class TestDatasetApi(SupersetTestCase):
|
||||||
expected_result = {"message": {"owners": ["Owners are invalid"]}}
|
expected_result = {"message": {"owners": ["Owners are invalid"]}}
|
||||||
assert data == expected_result
|
assert data == expected_result
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
|
||||||
def test_create_dataset_validate_uniqueness(self):
|
|
||||||
"""
|
|
||||||
Dataset API: Test create dataset validate table uniqueness
|
|
||||||
"""
|
|
||||||
|
|
||||||
energy_usage_ds = self.get_energy_usage_dataset()
|
|
||||||
self.login(ADMIN_USERNAME)
|
|
||||||
table_data = {
|
|
||||||
"database": energy_usage_ds.database_id,
|
|
||||||
"table_name": energy_usage_ds.table_name,
|
|
||||||
}
|
|
||||||
if schema := get_example_default_schema():
|
|
||||||
table_data["schema"] = schema
|
|
||||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
|
||||||
assert rv.status_code == 422
|
|
||||||
data = json.loads(rv.data.decode("utf-8"))
|
|
||||||
assert data == {
|
|
||||||
"message": {
|
|
||||||
"table": [
|
|
||||||
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
|
||||||
def test_create_dataset_with_sql_validate_uniqueness(self):
|
|
||||||
"""
|
|
||||||
Dataset API: Test create dataset with sql
|
|
||||||
"""
|
|
||||||
|
|
||||||
energy_usage_ds = self.get_energy_usage_dataset()
|
|
||||||
self.login(ADMIN_USERNAME)
|
|
||||||
table_data = {
|
|
||||||
"database": energy_usage_ds.database_id,
|
|
||||||
"table_name": energy_usage_ds.table_name,
|
|
||||||
"sql": "select * from energy_usage",
|
|
||||||
}
|
|
||||||
if schema := get_example_default_schema():
|
|
||||||
table_data["schema"] = schema
|
|
||||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
|
||||||
assert rv.status_code == 422
|
|
||||||
data = json.loads(rv.data.decode("utf-8"))
|
|
||||||
assert data == {
|
|
||||||
"message": {
|
|
||||||
"table": [
|
|
||||||
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||||
def test_create_dataset_with_sql(self):
|
def test_create_dataset_with_sql(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1455,27 +1403,6 @@ class TestDatasetApi(SupersetTestCase):
|
||||||
db.session.delete(dataset)
|
db.session.delete(dataset)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def test_update_dataset_item_uniqueness(self):
|
|
||||||
"""
|
|
||||||
Dataset API: Test update dataset uniqueness
|
|
||||||
"""
|
|
||||||
|
|
||||||
dataset = self.insert_default_dataset()
|
|
||||||
self.login(ADMIN_USERNAME)
|
|
||||||
ab_user = self.insert_dataset(
|
|
||||||
"ab_user", [self.get_user("admin").id], get_main_database()
|
|
||||||
)
|
|
||||||
table_data = {"table_name": "ab_user"}
|
|
||||||
uri = f"api/v1/dataset/{dataset.id}"
|
|
||||||
rv = self.put_assert_metric(uri, table_data, "put")
|
|
||||||
data = json.loads(rv.data.decode("utf-8"))
|
|
||||||
assert rv.status_code == 422
|
|
||||||
expected_response = {"message": {"table": ["Dataset ab_user already exists"]}}
|
|
||||||
assert data == expected_response
|
|
||||||
db.session.delete(dataset)
|
|
||||||
db.session.delete(ab_user)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
@patch("superset.daos.dataset.DatasetDAO.update")
|
@patch("superset.daos.dataset.DatasetDAO.update")
|
||||||
def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
|
def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from superset.db_engine_specs import load_engine_specs
|
||||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
from superset.utils.core import backend
|
||||||
from superset.utils.database import get_example_database
|
from superset.utils.database import get_example_database
|
||||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||||
|
|
@ -525,11 +526,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
|
||||||
"""
|
"""
|
||||||
Test the ``get_catalog_names`` method.
|
Test the ``get_catalog_names`` method.
|
||||||
"""
|
"""
|
||||||
database = get_example_database()
|
if backend() != "postgresql":
|
||||||
|
|
||||||
if database.backend != "postgresql":
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
database = get_example_database()
|
||||||
with database.get_inspector() as inspector:
|
with database.get_inspector() as inspector:
|
||||||
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
|
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
|
||||||
"postgres",
|
"postgres",
|
||||||
|
|
|
||||||
|
|
@ -1633,7 +1633,10 @@ class TestSecurityManager(SupersetTestCase):
|
||||||
@patch("superset.security.SupersetSecurityManager.can_access")
|
@patch("superset.security.SupersetSecurityManager.can_access")
|
||||||
def test_raise_for_access_query(self, mock_can_access, mock_is_owner):
|
def test_raise_for_access_query(self, mock_can_access, mock_is_owner):
|
||||||
query = Mock(
|
query = Mock(
|
||||||
database=get_example_database(), schema="bar", sql="SELECT * FROM foo"
|
database=get_example_database(),
|
||||||
|
schema="bar",
|
||||||
|
sql="SELECT * FROM foo",
|
||||||
|
catalog=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_can_access.return_value = True
|
mock_can_access.return_value = True
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,9 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
|
||||||
from superset.constants import EMPTY_STRING, NULL_STRING
|
from superset.constants import EMPTY_STRING, NULL_STRING
|
||||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||||
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException # noqa: F401
|
from superset.exceptions import (
|
||||||
|
QueryObjectValidationError,
|
||||||
|
) # noqa: F401
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.utils.core import (
|
from superset.utils.core import (
|
||||||
AdhocMetricExpressionType,
|
AdhocMetricExpressionType,
|
||||||
|
|
@ -160,7 +162,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||||
query_obj = dict(**base_query_obj, extras={})
|
query_obj = dict(**base_query_obj, extras={})
|
||||||
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
|
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
|
||||||
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
|
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
|
||||||
assert extra_cache_keys == [1, "abc", "abc@test.com"]
|
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
|
||||||
|
|
||||||
# Table with Jinja callable disabled.
|
# Table with Jinja callable disabled.
|
||||||
table2 = SqlaTable(
|
table2 = SqlaTable(
|
||||||
|
|
|
||||||
|
|
@ -255,11 +255,11 @@ def test_dataset_uniqueness(session: Session) -> None:
|
||||||
|
|
||||||
# but the DAO enforces application logic for uniqueness
|
# but the DAO enforces application logic for uniqueness
|
||||||
assert not DatasetDAO.validate_uniqueness(
|
assert not DatasetDAO.validate_uniqueness(
|
||||||
database.id,
|
database,
|
||||||
Table("table", "schema", None),
|
Table("table", "schema", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert DatasetDAO.validate_uniqueness(
|
assert DatasetDAO.validate_uniqueness(
|
||||||
database.id,
|
database,
|
||||||
Table("table", "schema", "some_catalog"),
|
Table("table", "schema", "some_catalog"),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
DatasetDAO.validate_update_uniqueness(
|
DatasetDAO.validate_update_uniqueness(
|
||||||
database_id=database.id,
|
database=database,
|
||||||
table=Table(dataset1.table_name, dataset1.schema),
|
table=Table(dataset1.table_name, dataset1.schema),
|
||||||
dataset_id=dataset1.id,
|
dataset_id=dataset1.id,
|
||||||
)
|
)
|
||||||
|
|
@ -62,7 +62,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
DatasetDAO.validate_update_uniqueness(
|
DatasetDAO.validate_update_uniqueness(
|
||||||
database_id=database.id,
|
database=database,
|
||||||
table=Table(dataset1.table_name, dataset2.schema),
|
table=Table(dataset1.table_name, dataset2.schema),
|
||||||
dataset_id=dataset1.id,
|
dataset_id=dataset1.id,
|
||||||
)
|
)
|
||||||
|
|
@ -71,7 +71,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
DatasetDAO.validate_update_uniqueness(
|
DatasetDAO.validate_update_uniqueness(
|
||||||
database_id=database.id,
|
database=database,
|
||||||
table=Table(dataset1.table_name),
|
table=Table(dataset1.table_name),
|
||||||
dataset_id=dataset1.id,
|
dataset_id=dataset1.id,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -366,6 +366,7 @@ def test_raise_for_access_query_default_schema(
|
||||||
database.get_default_catalog.return_value = None
|
database.get_default_catalog.return_value = None
|
||||||
database.get_default_schema_for_query.return_value = "public"
|
database.get_default_schema_for_query.return_value = "public"
|
||||||
query = mocker.MagicMock()
|
query = mocker.MagicMock()
|
||||||
|
query.catalog = None
|
||||||
query.database = database
|
query.database = database
|
||||||
query.sql = "SELECT * FROM ab_user"
|
query.sql = "SELECT * FROM ab_user"
|
||||||
|
|
||||||
|
|
@ -421,6 +422,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
|
||||||
database.get_default_catalog.return_value = None
|
database.get_default_catalog.return_value = None
|
||||||
database.get_default_schema_for_query.return_value = "public"
|
database.get_default_schema_for_query.return_value = "public"
|
||||||
query = mocker.MagicMock()
|
query = mocker.MagicMock()
|
||||||
|
query.catalog = None
|
||||||
query.database = database
|
query.database = database
|
||||||
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"
|
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"
|
||||||
|
|
||||||
|
|
@ -434,7 +436,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
|
||||||
viz=None,
|
viz=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
get_table_access_error_object.assert_called_with({Table("ab_user", "public")})
|
get_table_access_error_object.assert_called_with({Table("ab_user", "public", None)})
|
||||||
|
|
||||||
|
|
||||||
def test_raise_for_access_chart_for_datasource_permission(
|
def test_raise_for_access_chart_for_datasource_permission(
|
||||||
|
|
@ -736,6 +738,7 @@ def test_raise_for_access_catalog(
|
||||||
database.get_default_catalog.return_value = "db1"
|
database.get_default_catalog.return_value = "db1"
|
||||||
database.get_default_schema_for_query.return_value = "public"
|
database.get_default_schema_for_query.return_value = "public"
|
||||||
query = mocker.MagicMock()
|
query = mocker.MagicMock()
|
||||||
|
query.catalog = "db1"
|
||||||
query.database = database
|
query.database = database
|
||||||
query.sql = "SELECT * FROM ab_user"
|
query.sql = "SELECT * FROM ab_user"
|
||||||
|
|
||||||
|
|
@ -776,7 +779,8 @@ def test_get_datasources_accessible_by_user_schema_access(
|
||||||
database.database_name = "db1"
|
database.database_name = "db1"
|
||||||
database.get_default_catalog.return_value = "catalog2"
|
database.get_default_catalog.return_value = "catalog2"
|
||||||
|
|
||||||
can_access = mocker.patch.object(sm, "can_access", return_value=True)
|
# False for catalog_access, True for schema_access
|
||||||
|
can_access = mocker.patch.object(sm, "can_access", side_effect=[False, True])
|
||||||
|
|
||||||
datasource_names = [
|
datasource_names = [
|
||||||
DatasourceName("table1", "schema1", "catalog2"),
|
DatasourceName("table1", "schema1", "catalog2"),
|
||||||
|
|
@ -795,7 +799,12 @@ def test_get_datasources_accessible_by_user_schema_access(
|
||||||
|
|
||||||
# Even though we passed `catalog=None,` the schema check uses the default catalog
|
# Even though we passed `catalog=None,` the schema check uses the default catalog
|
||||||
# when building the schema permission, since the DB supports catalog.
|
# when building the schema permission, since the DB supports catalog.
|
||||||
can_access.assert_called_with("schema_access", "[db1].[catalog2].[schema1]")
|
can_access.assert_has_calls(
|
||||||
|
[
|
||||||
|
mocker.call("catalog_access", "[db1].[catalog2]"),
|
||||||
|
mocker.call("schema_access", "[db1].[catalog2].[schema1]"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_catalogs_accessible_by_user_schema_access(
|
def test_get_catalogs_accessible_by_user_schema_access(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue