fix: handle empty catalog when DB supports them (#29840)

This commit is contained in:
Beto Dealmeida 2024-08-13 10:08:43 -04:00 committed by GitHub
parent 9f5eb899e8
commit 39209c2b40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 100 additions and 148 deletions

View File

@ -52,7 +52,7 @@ GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
def fetch_files_github_api(url: str): # type: ignore def fetch_files_github_api(url: str): # type: ignore
"""Fetches data using GitHub API.""" """Fetches data using GitHub API."""
req = Request(url) req = Request(url)
req.add_header("Authorization", f"token {GITHUB_TOKEN}") req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
req.add_header("Accept", "application/vnd.github.v3+json") req.add_header("Accept", "application/vnd.github.v3+json")
print(f"Fetching from {url}") print(f"Fetching from {url}")

View File

@ -32,6 +32,10 @@ class Datasource(Schema):
datasource_name = fields.String( datasource_name = fields.String(
metadata={"description": datasource_name_description}, metadata={"description": datasource_name_description},
) )
catalog = fields.String(
allow_none=True,
metadata={"description": "Datasource catalog"},
)
schema = fields.String( schema = fields.String(
metadata={"description": "Datasource schema"}, metadata={"description": "Datasource schema"},
) )

View File

@ -54,23 +54,28 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []
database_id = self._properties["database"] database_id = self._properties["database"]
schema = self._properties.get("schema")
catalog = self._properties.get("catalog") catalog = self._properties.get("catalog")
schema = self._properties.get("schema")
table_name = self._properties["table_name"]
sql = self._properties.get("sql") sql = self._properties.get("sql")
owner_ids: Optional[list[int]] = self._properties.get("owners") owner_ids: Optional[list[int]] = self._properties.get("owners")
table = Table(self._properties["table_name"], schema, catalog)
# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, table):
exceptions.append(DatasetExistsValidationError(table))
# Validate/Populate database # Validate/Populate database
database = DatasetDAO.get_database_by_id(database_id) database = DatasetDAO.get_database_by_id(database_id)
if not database: if not database:
exceptions.append(DatabaseNotFoundValidationError()) exceptions.append(DatabaseNotFoundValidationError())
self._properties["database"] = database self._properties["database"] = database
# Validate uniqueness
if database:
if not catalog:
catalog = self._properties["catalog"] = database.get_default_catalog()
table = Table(table_name, schema, catalog)
if not DatasetDAO.validate_uniqueness(database, table):
exceptions.append(DatasetExistsValidationError(table))
# Validate table exists on dataset if sql is not provided # Validate table exists on dataset if sql is not provided
# This should be validated when the dataset is physical # This should be validated when the dataset is physical
if ( if (

View File

@ -166,7 +166,7 @@ def import_dataset(
try: try:
table_exists = dataset.database.has_table( table_exists = dataset.database.has_table(
Table(dataset.table_name, dataset.schema), Table(dataset.table_name, dataset.schema, dataset.catalog),
) )
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
# MySQL doesn't play nice with GSheets table names # MySQL doesn't play nice with GSheets table names

View File

@ -79,10 +79,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []
owner_ids: Optional[list[int]] = self._properties.get("owners") owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/populate model exists # Validate/populate model exists
self._model = DatasetDAO.find_by_id(self._model_id) self._model = DatasetDAO.find_by_id(self._model_id)
if not self._model: if not self._model:
raise DatasetNotFoundError() raise DatasetNotFoundError()
# Check ownership # Check ownership
try: try:
security_manager.raise_for_ownership(self._model) security_manager.raise_for_ownership(self._model)
@ -91,22 +93,30 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
database_id = self._properties.get("database") database_id = self._properties.get("database")
catalog = self._properties.get("catalog")
if not catalog:
catalog = self._properties["catalog"] = (
self._model.database.get_default_catalog()
)
table = Table( table = Table(
self._properties.get("table_name"), # type: ignore self._properties.get("table_name"), # type: ignore
self._properties.get("schema"), self._properties.get("schema"),
self._properties.get("catalog"), catalog,
) )
# Validate uniqueness # Validate uniqueness
if not DatasetDAO.validate_update_uniqueness( if not DatasetDAO.validate_update_uniqueness(
self._model.database_id, self._model.database,
table, table,
self._model_id, self._model_id,
): ):
exceptions.append(DatasetExistsValidationError(table)) exceptions.append(DatasetExistsValidationError(table))
# Validate/Populate database not allowed to change # Validate/Populate database not allowed to change
if database_id and database_id != self._model: if database_id and database_id != self._model:
exceptions.append(DatabaseChangeValidationError()) exceptions.append(DatabaseChangeValidationError())
# Validate/Populate owner # Validate/Populate owner
try: try:
owners = self.compute_owners( owners = self.compute_owners(
@ -116,6 +126,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)
# Validate columns # Validate columns
if columns := self._properties.get("columns"): if columns := self._properties.get("columns"):
self._validate_columns(columns, exceptions) self._validate_columns(columns, exceptions)

View File

@ -461,9 +461,11 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
) )
else: else:
_columns = [ _columns = [
utils.get_column_name(column_) (
if utils.is_adhoc_column(column_) utils.get_column_name(column_)
else column_ if utils.is_adhoc_column(column_)
else column_
)
for column_param in COLUMN_FORM_DATA_PARAMS for column_param in COLUMN_FORM_DATA_PARAMS
for column_ in utils.as_list(form_data.get(column_param) or []) for column_ in utils.as_list(form_data.get(column_param) or [])
] ]
@ -1963,7 +1965,7 @@ class SqlaTable(
if self.has_extra_cache_key_calls(query_obj): if self.has_extra_cache_key_calls(query_obj):
sqla_query = self.get_sqla_query(**query_obj) sqla_query = self.get_sqla_query(**query_obj)
extra_cache_keys += sqla_query.extra_cache_keys extra_cache_keys += sqla_query.extra_cache_keys
return extra_cache_keys return list(set(extra_cache_keys))
@property @property
def quote_identifier(self) -> Callable[[str], str]: def quote_identifier(self) -> Callable[[str], str]:

View File

@ -84,15 +84,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod @staticmethod
def validate_uniqueness( def validate_uniqueness(
database_id: int, database: Database,
table: Table, table: Table,
dataset_id: int | None = None, dataset_id: int | None = None,
) -> bool: ) -> bool:
# The catalog might not be set even if the database supports catalogs, in case
# multi-catalog is disabled.
catalog = table.catalog or database.get_default_catalog()
dataset_query = db.session.query(SqlaTable).filter( dataset_query = db.session.query(SqlaTable).filter(
SqlaTable.table_name == table.table, SqlaTable.table_name == table.table,
SqlaTable.schema == table.schema, SqlaTable.schema == table.schema,
SqlaTable.catalog == table.catalog, SqlaTable.catalog == catalog,
SqlaTable.database_id == database_id, SqlaTable.database_id == database.id,
) )
if dataset_id: if dataset_id:
@ -103,15 +107,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod @staticmethod
def validate_update_uniqueness( def validate_update_uniqueness(
database_id: int, database: Database,
table: Table, table: Table,
dataset_id: int, dataset_id: int,
) -> bool: ) -> bool:
# The catalog might not be set even if the database supports catalogs, in case
# multi-catalog is disabled.
catalog = table.catalog or database.get_default_catalog()
dataset_query = db.session.query(SqlaTable).filter( dataset_query = db.session.query(SqlaTable).filter(
SqlaTable.table_name == table.table, SqlaTable.table_name == table.table,
SqlaTable.database_id == database_id, SqlaTable.database_id == database.id,
SqlaTable.schema == table.schema, SqlaTable.schema == table.schema,
SqlaTable.catalog == table.catalog, SqlaTable.catalog == catalog,
SqlaTable.id != dataset_id, SqlaTable.id != dataset_id,
) )
return not db.session.query(dataset_query.exists()).scalar() return not db.session.query(dataset_query.exists()).scalar()

View File

@ -1159,7 +1159,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("init", self.select_star.__name__) self.incr_stats("init", self.select_star.__name__)
try: try:
result = database.select_star( result = database.select_star(
Table(table_name, schema_name), Table(table_name, schema_name, database.get_default_catalog()),
latest_partition=True, latest_partition=True,
) )
except NoSuchTableError: except NoSuchTableError:

View File

@ -565,7 +565,7 @@ class NoOpTemplateProcessor(BaseTemplateProcessor):
""" """
Makes processing a template a noop Makes processing a template a noop
""" """
return sql return str(sql)
class PrestoTemplateProcessor(JinjaTemplateProcessor): class PrestoTemplateProcessor(JinjaTemplateProcessor):

View File

@ -491,7 +491,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
g.user.id, g.user.id,
self.db_engine_spec, self.db_engine_spec,
) )
if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
else None else None
) )
# If using MySQL or Presto for example, will set url.username # If using MySQL or Presto for example, will set url.username

View File

@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
ds = db.session.query(src_class).filter_by(id=int(id_)).first() ds = db.session.query(src_class).filter_by(id=int(id_)).first()
if ds: if ds:
target.perm = ds.perm target.perm = ds.perm
target.catalog_perm = ds.catalog_perm
target.schema_perm = ds.schema_perm target.schema_perm = ds.schema_perm

View File

@ -774,6 +774,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
default_catalog = database.get_default_catalog()
catalog = catalog or default_catalog
if hierarchical and ( if hierarchical and (
self.can_access_database(database) self.can_access_database(database)
or (catalog and self.can_access_catalog(database, catalog)) or (catalog and self.can_access_catalog(database, catalog))
@ -783,7 +786,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# schema_access # schema_access
accessible_schemas: set[str] = set() accessible_schemas: set[str] = set()
schema_access = self.user_view_menu_names("schema_access") schema_access = self.user_view_menu_names("schema_access")
default_catalog = database.get_default_catalog()
default_schema = database.get_default_schema(default_catalog) default_schema = database.get_default_schema(default_catalog)
for perm in schema_access: for perm in schema_access:
@ -800,7 +802,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# [database].[catalog].[schema] matches when the catalog is equal to the # [database].[catalog].[schema] matches when the catalog is equal to the
# requested catalog or, when no catalog specified, it's equal to the default # requested catalog or, when no catalog specified, it's equal to the default
# catalog. # catalog.
elif len(parts) == 3 and parts[1] == (catalog or default_catalog): elif len(parts) == 3 and parts[1] == catalog:
accessible_schemas.add(parts[2]) accessible_schemas.add(parts[2])
# datasource_access # datasource_access
@ -906,16 +908,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database): if self.can_access_database(database):
return datasource_names return datasource_names
catalog = catalog or database.get_default_catalog()
if catalog: if catalog:
catalog_perm = self.get_catalog_perm(database.database_name, catalog) catalog_perm = self.get_catalog_perm(database.database_name, catalog)
if catalog_perm and self.can_access("catalog_access", catalog_perm): if catalog_perm and self.can_access("catalog_access", catalog_perm):
return datasource_names return datasource_names
if schema: if schema:
default_catalog = database.get_default_catalog()
schema_perm = self.get_schema_perm( schema_perm = self.get_schema_perm(
database.database_name, database.database_name,
catalog or default_catalog, catalog,
schema, schema,
) )
if schema_perm and self.can_access("schema_access", schema_perm): if schema_perm and self.can_access("schema_access", schema_perm):
@ -2183,6 +2185,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
database = query.database database = query.database
database = cast("Database", database) database = cast("Database", database)
default_catalog = database.get_default_catalog()
if self.can_access_database(database): if self.can_access_database(database):
return return
@ -2196,19 +2199,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
# inspector to read it. # inspector to read it.
default_schema = database.get_default_schema_for_query(query) default_schema = database.get_default_schema_for_query(query)
# Determining the default catalog is much easier, because DB engine
# specs need explicit support for catalogs.
default_catalog = database.get_default_catalog()
tables = { tables = {
Table( Table(
table_.table, table_.table,
table_.schema or default_schema, table_.schema or default_schema,
table_.catalog or default_catalog, table_.catalog or query.catalog or default_catalog,
) )
for table_ in extract_tables_from_jinja_sql(query.sql, database) for table_ in extract_tables_from_jinja_sql(query.sql, database)
} }
elif table: elif table:
tables = {table} # Make sure table has the default catalog, if not specified.
tables = {
Table(table.table, table.schema, table.catalog or default_catalog)
}
denied = set() denied = set()

View File

@ -284,6 +284,7 @@ class SqlLabRestApi(BaseSupersetApi):
"client_id": client_id, "client_id": client_id,
"row_count": row_count, "row_count": row_count,
"database": query.database.name, "database": query.database.name,
"catalog": query.catalog,
"schema": query.schema, "schema": query.schema,
"sql": query.sql, "sql": query.sql,
"exported_format": "csv", "exported_format": "csv",

View File

@ -125,6 +125,8 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
def set_database(self, database: Database) -> None: def set_database(self, database: Database) -> None:
self._validate_db(database) self._validate_db(database)
self.database = database self.database = database
if self.catalog is None:
self.catalog = database.get_default_catalog()
if self.select_as_cta: if self.select_as_cta:
schema_name = self._get_ctas_target_schema_name(database) schema_name = self._get_ctas_target_schema_name(database)
self.create_table_as_select.target_schema_name = schema_name # type: ignore self.create_table_as_select.target_schema_name = schema_name # type: ignore

View File

@ -239,6 +239,7 @@ class TableSchemaView(BaseSupersetView):
db.session.query(TableSchema).filter( db.session.query(TableSchema).filter(
TableSchema.tab_state_id == table["queryEditorId"], TableSchema.tab_state_id == table["queryEditorId"],
TableSchema.database_id == table["dbId"], TableSchema.database_id == table["dbId"],
TableSchema.catalog == table["catalog"],
TableSchema.schema == table["schema"], TableSchema.schema == table["schema"],
TableSchema.table == table["name"], TableSchema.table == table["name"],
).delete(synchronize_session=False) ).delete(synchronize_session=False)
@ -246,6 +247,7 @@ class TableSchemaView(BaseSupersetView):
table_schema = TableSchema( table_schema = TableSchema(
tab_state_id=table["queryEditorId"], tab_state_id=table["queryEditorId"],
database_id=table["dbId"], database_id=table["dbId"],
catalog=table["catalog"],
schema=table["schema"], schema=table["schema"],
table=table["name"], table=table["name"],
description=json.dumps(table), description=json.dumps(table),

View File

@ -1563,34 +1563,6 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri) rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404) self.assertEqual(rv.status_code, 404)
def test_get_select_star_datasource_access(self):
"""
Database API: Test get select star with datasource access
"""
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
db.session.add(table)
db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)
self.login(GAMMA_USERNAME)
main_db = get_main_database()
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
# rollback changes
security_manager.del_permission_role(gamma_role, tmp_table_perm)
db.session.delete(table)
db.session.delete(main_db)
db.session.commit()
def test_get_select_star_not_found_database(self): def test_get_select_star_not_found_database(self):
""" """
Database API: Test get select star not found database Database API: Test get select star not found database

View File

@ -36,7 +36,6 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.extensions import db, security_manager from superset.extensions import db, security_manager
from superset.models.core import Database from superset.models.core import Database
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.sql_parse import Table
from superset.utils import json from superset.utils import json
from superset.utils.core import backend, get_example_default_schema from superset.utils.core import backend, get_example_default_schema
from superset.utils.database import get_example_database, get_main_database from superset.utils.database import get_example_database, get_main_database
@ -676,57 +675,6 @@ class TestDatasetApi(SupersetTestCase):
expected_result = {"message": {"owners": ["Owners are invalid"]}} expected_result = {"message": {"owners": ["Owners are invalid"]}}
assert data == expected_result assert data == expected_result
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_validate_uniqueness(self):
"""
Dataset API: Test create dataset validate table uniqueness
"""
energy_usage_ds = self.get_energy_usage_dataset()
self.login(ADMIN_USERNAME)
table_data = {
"database": energy_usage_ds.database_id,
"table_name": energy_usage_ds.table_name,
}
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {
"table": [
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
]
}
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_with_sql_validate_uniqueness(self):
"""
Dataset API: Test create dataset with sql
"""
energy_usage_ds = self.get_energy_usage_dataset()
self.login(ADMIN_USERNAME)
table_data = {
"database": energy_usage_ds.database_id,
"table_name": energy_usage_ds.table_name,
"sql": "select * from energy_usage",
}
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {
"table": [
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
]
}
}
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_with_sql(self): def test_create_dataset_with_sql(self):
""" """
@ -1455,27 +1403,6 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(dataset) db.session.delete(dataset)
db.session.commit() db.session.commit()
def test_update_dataset_item_uniqueness(self):
"""
Dataset API: Test update dataset uniqueness
"""
dataset = self.insert_default_dataset()
self.login(ADMIN_USERNAME)
ab_user = self.insert_dataset(
"ab_user", [self.get_user("admin").id], get_main_database()
)
table_data = {"table_name": "ab_user"}
uri = f"api/v1/dataset/{dataset.id}"
rv = self.put_assert_metric(uri, table_data, "put")
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
expected_response = {"message": {"table": ["Dataset ab_user already exists"]}}
assert data == expected_response
db.session.delete(dataset)
db.session.delete(ab_user)
db.session.commit()
@patch("superset.daos.dataset.DatasetDAO.update") @patch("superset.daos.dataset.DatasetDAO.update")
def test_update_dataset_sqlalchemy_error(self, mock_dao_update): def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
""" """

View File

@ -25,6 +25,7 @@ from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.utils.core import backend
from superset.utils.database import get_example_database from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.certificates import ssl_certificate
@ -525,11 +526,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
""" """
Test the ``get_catalog_names`` method. Test the ``get_catalog_names`` method.
""" """
database = get_example_database() if backend() != "postgresql":
if database.backend != "postgresql":
return return
database = get_example_database()
with database.get_inspector() as inspector: with database.get_inspector() as inspector:
assert PostgresEngineSpec.get_catalog_names(database, inspector) == { assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
"postgres", "postgres",

View File

@ -1633,7 +1633,10 @@ class TestSecurityManager(SupersetTestCase):
@patch("superset.security.SupersetSecurityManager.can_access") @patch("superset.security.SupersetSecurityManager.can_access")
def test_raise_for_access_query(self, mock_can_access, mock_is_owner): def test_raise_for_access_query(self, mock_can_access, mock_is_owner):
query = Mock( query = Mock(
database=get_example_database(), schema="bar", sql="SELECT * FROM foo" database=get_example_database(),
schema="bar",
sql="SELECT * FROM foo",
catalog=None,
) )
mock_can_access.return_value = True mock_can_access.return_value = True

View File

@ -34,7 +34,9 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
from superset.constants import EMPTY_STRING, NULL_STRING from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException # noqa: F401 from superset.exceptions import (
QueryObjectValidationError,
) # noqa: F401
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import ( from superset.utils.core import (
AdhocMetricExpressionType, AdhocMetricExpressionType,
@ -160,7 +162,7 @@ class TestDatabaseModel(SupersetTestCase):
query_obj = dict(**base_query_obj, extras={}) query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table1.get_extra_cache_keys(query_obj) extra_cache_keys = table1.get_extra_cache_keys(query_obj)
self.assertTrue(table1.has_extra_cache_key_calls(query_obj)) self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
assert extra_cache_keys == [1, "abc", "abc@test.com"] assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
# Table with Jinja callable disabled. # Table with Jinja callable disabled.
table2 = SqlaTable( table2 = SqlaTable(

View File

@ -255,11 +255,11 @@ def test_dataset_uniqueness(session: Session) -> None:
# but the DAO enforces application logic for uniqueness # but the DAO enforces application logic for uniqueness
assert not DatasetDAO.validate_uniqueness( assert not DatasetDAO.validate_uniqueness(
database.id, database,
Table("table", "schema", None), Table("table", "schema", None),
) )
assert DatasetDAO.validate_uniqueness( assert DatasetDAO.validate_uniqueness(
database.id, database,
Table("table", "schema", "some_catalog"), Table("table", "schema", "some_catalog"),
) )

View File

@ -53,7 +53,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert ( assert (
DatasetDAO.validate_update_uniqueness( DatasetDAO.validate_update_uniqueness(
database_id=database.id, database=database,
table=Table(dataset1.table_name, dataset1.schema), table=Table(dataset1.table_name, dataset1.schema),
dataset_id=dataset1.id, dataset_id=dataset1.id,
) )
@ -62,7 +62,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert ( assert (
DatasetDAO.validate_update_uniqueness( DatasetDAO.validate_update_uniqueness(
database_id=database.id, database=database,
table=Table(dataset1.table_name, dataset2.schema), table=Table(dataset1.table_name, dataset2.schema),
dataset_id=dataset1.id, dataset_id=dataset1.id,
) )
@ -71,7 +71,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert ( assert (
DatasetDAO.validate_update_uniqueness( DatasetDAO.validate_update_uniqueness(
database_id=database.id, database=database,
table=Table(dataset1.table_name), table=Table(dataset1.table_name),
dataset_id=dataset1.id, dataset_id=dataset1.id,
) )

View File

@ -366,6 +366,7 @@ def test_raise_for_access_query_default_schema(
database.get_default_catalog.return_value = None database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public" database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock() query = mocker.MagicMock()
query.catalog = None
query.database = database query.database = database
query.sql = "SELECT * FROM ab_user" query.sql = "SELECT * FROM ab_user"
@ -421,6 +422,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
database.get_default_catalog.return_value = None database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public" database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock() query = mocker.MagicMock()
query.catalog = None
query.database = database query.database = database
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1" query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"
@ -434,7 +436,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
viz=None, viz=None,
) )
get_table_access_error_object.assert_called_with({Table("ab_user", "public")}) get_table_access_error_object.assert_called_with({Table("ab_user", "public", None)})
def test_raise_for_access_chart_for_datasource_permission( def test_raise_for_access_chart_for_datasource_permission(
@ -736,6 +738,7 @@ def test_raise_for_access_catalog(
database.get_default_catalog.return_value = "db1" database.get_default_catalog.return_value = "db1"
database.get_default_schema_for_query.return_value = "public" database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock() query = mocker.MagicMock()
query.catalog = "db1"
query.database = database query.database = database
query.sql = "SELECT * FROM ab_user" query.sql = "SELECT * FROM ab_user"
@ -776,7 +779,8 @@ def test_get_datasources_accessible_by_user_schema_access(
database.database_name = "db1" database.database_name = "db1"
database.get_default_catalog.return_value = "catalog2" database.get_default_catalog.return_value = "catalog2"
can_access = mocker.patch.object(sm, "can_access", return_value=True) # False for catalog_access, True for schema_access
can_access = mocker.patch.object(sm, "can_access", side_effect=[False, True])
datasource_names = [ datasource_names = [
DatasourceName("table1", "schema1", "catalog2"), DatasourceName("table1", "schema1", "catalog2"),
@ -795,7 +799,12 @@ def test_get_datasources_accessible_by_user_schema_access(
# Even though we passed `catalog=None,` the schema check uses the default catalog # Even though we passed `catalog=None,` the schema check uses the default catalog
# when building the schema permission, since the DB supports catalog. # when building the schema permission, since the DB supports catalog.
can_access.assert_called_with("schema_access", "[db1].[catalog2].[schema1]") can_access.assert_has_calls(
[
mocker.call("catalog_access", "[db1].[catalog2]"),
mocker.call("schema_access", "[db1].[catalog2].[schema1]"),
]
)
def test_get_catalogs_accessible_by_user_schema_access( def test_get_catalogs_accessible_by_user_schema_access(