fix: handle empty catalog when DB supports them (#29840)
This commit is contained in:
parent
9f5eb899e8
commit
39209c2b40
|
|
@ -52,7 +52,7 @@ GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
|
|||
def fetch_files_github_api(url: str): # type: ignore
|
||||
"""Fetches data using GitHub API."""
|
||||
req = Request(url)
|
||||
req.add_header("Authorization", f"token {GITHUB_TOKEN}")
|
||||
req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
|
||||
req.add_header("Accept", "application/vnd.github.v3+json")
|
||||
|
||||
print(f"Fetching from {url}")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ class Datasource(Schema):
|
|||
datasource_name = fields.String(
|
||||
metadata={"description": datasource_name_description},
|
||||
)
|
||||
catalog = fields.String(
|
||||
allow_none=True,
|
||||
metadata={"description": "Datasource catalog"},
|
||||
)
|
||||
schema = fields.String(
|
||||
metadata={"description": "Datasource schema"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -54,23 +54,28 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
|
|||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
database_id = self._properties["database"]
|
||||
schema = self._properties.get("schema")
|
||||
catalog = self._properties.get("catalog")
|
||||
schema = self._properties.get("schema")
|
||||
table_name = self._properties["table_name"]
|
||||
sql = self._properties.get("sql")
|
||||
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||
|
||||
table = Table(self._properties["table_name"], schema, catalog)
|
||||
|
||||
# Validate uniqueness
|
||||
if not DatasetDAO.validate_uniqueness(database_id, table):
|
||||
exceptions.append(DatasetExistsValidationError(table))
|
||||
|
||||
# Validate/Populate database
|
||||
database = DatasetDAO.get_database_by_id(database_id)
|
||||
if not database:
|
||||
exceptions.append(DatabaseNotFoundValidationError())
|
||||
self._properties["database"] = database
|
||||
|
||||
# Validate uniqueness
|
||||
if database:
|
||||
if not catalog:
|
||||
catalog = self._properties["catalog"] = database.get_default_catalog()
|
||||
|
||||
table = Table(table_name, schema, catalog)
|
||||
|
||||
if not DatasetDAO.validate_uniqueness(database, table):
|
||||
exceptions.append(DatasetExistsValidationError(table))
|
||||
|
||||
# Validate table exists on dataset if sql is not provided
|
||||
# This should be validated when the dataset is physical
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def import_dataset(
|
|||
|
||||
try:
|
||||
table_exists = dataset.database.has_table(
|
||||
Table(dataset.table_name, dataset.schema),
|
||||
Table(dataset.table_name, dataset.schema, dataset.catalog),
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# MySQL doesn't play nice with GSheets table names
|
||||
|
|
|
|||
|
|
@ -79,10 +79,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
|||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||
|
||||
# Validate/populate model exists
|
||||
self._model = DatasetDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatasetNotFoundError()
|
||||
|
||||
# Check ownership
|
||||
try:
|
||||
security_manager.raise_for_ownership(self._model)
|
||||
|
|
@ -91,22 +93,30 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
|||
|
||||
database_id = self._properties.get("database")
|
||||
|
||||
catalog = self._properties.get("catalog")
|
||||
if not catalog:
|
||||
catalog = self._properties["catalog"] = (
|
||||
self._model.database.get_default_catalog()
|
||||
)
|
||||
|
||||
table = Table(
|
||||
self._properties.get("table_name"), # type: ignore
|
||||
self._properties.get("schema"),
|
||||
self._properties.get("catalog"),
|
||||
catalog,
|
||||
)
|
||||
|
||||
# Validate uniqueness
|
||||
if not DatasetDAO.validate_update_uniqueness(
|
||||
self._model.database_id,
|
||||
self._model.database,
|
||||
table,
|
||||
self._model_id,
|
||||
):
|
||||
exceptions.append(DatasetExistsValidationError(table))
|
||||
|
||||
# Validate/Populate database not allowed to change
|
||||
if database_id and database_id != self._model:
|
||||
exceptions.append(DatabaseChangeValidationError())
|
||||
|
||||
# Validate/Populate owner
|
||||
try:
|
||||
owners = self.compute_owners(
|
||||
|
|
@ -116,6 +126,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
|||
self._properties["owners"] = owners
|
||||
except ValidationError as ex:
|
||||
exceptions.append(ex)
|
||||
|
||||
# Validate columns
|
||||
if columns := self._properties.get("columns"):
|
||||
self._validate_columns(columns, exceptions)
|
||||
|
|
|
|||
|
|
@ -461,9 +461,11 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
|||
)
|
||||
else:
|
||||
_columns = [
|
||||
utils.get_column_name(column_)
|
||||
if utils.is_adhoc_column(column_)
|
||||
else column_
|
||||
(
|
||||
utils.get_column_name(column_)
|
||||
if utils.is_adhoc_column(column_)
|
||||
else column_
|
||||
)
|
||||
for column_param in COLUMN_FORM_DATA_PARAMS
|
||||
for column_ in utils.as_list(form_data.get(column_param) or [])
|
||||
]
|
||||
|
|
@ -1963,7 +1965,7 @@ class SqlaTable(
|
|||
if self.has_extra_cache_key_calls(query_obj):
|
||||
sqla_query = self.get_sqla_query(**query_obj)
|
||||
extra_cache_keys += sqla_query.extra_cache_keys
|
||||
return extra_cache_keys
|
||||
return list(set(extra_cache_keys))
|
||||
|
||||
@property
|
||||
def quote_identifier(self) -> Callable[[str], str]:
|
||||
|
|
|
|||
|
|
@ -84,15 +84,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
|||
|
||||
@staticmethod
|
||||
def validate_uniqueness(
|
||||
database_id: int,
|
||||
database: Database,
|
||||
table: Table,
|
||||
dataset_id: int | None = None,
|
||||
) -> bool:
|
||||
# The catalog might not be set even if the database supports catalogs, in case
|
||||
# multi-catalog is disabled.
|
||||
catalog = table.catalog or database.get_default_catalog()
|
||||
|
||||
dataset_query = db.session.query(SqlaTable).filter(
|
||||
SqlaTable.table_name == table.table,
|
||||
SqlaTable.schema == table.schema,
|
||||
SqlaTable.catalog == table.catalog,
|
||||
SqlaTable.database_id == database_id,
|
||||
SqlaTable.catalog == catalog,
|
||||
SqlaTable.database_id == database.id,
|
||||
)
|
||||
|
||||
if dataset_id:
|
||||
|
|
@ -103,15 +107,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
|||
|
||||
@staticmethod
|
||||
def validate_update_uniqueness(
|
||||
database_id: int,
|
||||
database: Database,
|
||||
table: Table,
|
||||
dataset_id: int,
|
||||
) -> bool:
|
||||
# The catalog might not be set even if the database supports catalogs, in case
|
||||
# multi-catalog is disabled.
|
||||
catalog = table.catalog or database.get_default_catalog()
|
||||
|
||||
dataset_query = db.session.query(SqlaTable).filter(
|
||||
SqlaTable.table_name == table.table,
|
||||
SqlaTable.database_id == database_id,
|
||||
SqlaTable.database_id == database.id,
|
||||
SqlaTable.schema == table.schema,
|
||||
SqlaTable.catalog == table.catalog,
|
||||
SqlaTable.catalog == catalog,
|
||||
SqlaTable.id != dataset_id,
|
||||
)
|
||||
return not db.session.query(dataset_query.exists()).scalar()
|
||||
|
|
|
|||
|
|
@ -1159,7 +1159,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
self.incr_stats("init", self.select_star.__name__)
|
||||
try:
|
||||
result = database.select_star(
|
||||
Table(table_name, schema_name),
|
||||
Table(table_name, schema_name, database.get_default_catalog()),
|
||||
latest_partition=True,
|
||||
)
|
||||
except NoSuchTableError:
|
||||
|
|
|
|||
|
|
@ -565,7 +565,7 @@ class NoOpTemplateProcessor(BaseTemplateProcessor):
|
|||
"""
|
||||
Makes processing a template a noop
|
||||
"""
|
||||
return sql
|
||||
return str(sql)
|
||||
|
||||
|
||||
class PrestoTemplateProcessor(JinjaTemplateProcessor):
|
||||
|
|
|
|||
|
|
@ -491,7 +491,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
g.user.id,
|
||||
self.db_engine_spec,
|
||||
)
|
||||
if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config
|
||||
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
|
||||
else None
|
||||
)
|
||||
# If using MySQL or Presto for example, will set url.username
|
||||
|
|
|
|||
|
|
@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
|
|||
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
|
||||
if ds:
|
||||
target.perm = ds.perm
|
||||
target.catalog_perm = ds.catalog_perm
|
||||
target.schema_perm = ds.schema_perm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -774,6 +774,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
default_catalog = database.get_default_catalog()
|
||||
catalog = catalog or default_catalog
|
||||
|
||||
if hierarchical and (
|
||||
self.can_access_database(database)
|
||||
or (catalog and self.can_access_catalog(database, catalog))
|
||||
|
|
@ -783,7 +786,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
# schema_access
|
||||
accessible_schemas: set[str] = set()
|
||||
schema_access = self.user_view_menu_names("schema_access")
|
||||
default_catalog = database.get_default_catalog()
|
||||
default_schema = database.get_default_schema(default_catalog)
|
||||
|
||||
for perm in schema_access:
|
||||
|
|
@ -800,7 +802,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
# [database].[catalog].[schema] matches when the catalog is equal to the
|
||||
# requested catalog or, when no catalog specified, it's equal to the default
|
||||
# catalog.
|
||||
elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
|
||||
elif len(parts) == 3 and parts[1] == catalog:
|
||||
accessible_schemas.add(parts[2])
|
||||
|
||||
# datasource_access
|
||||
|
|
@ -906,16 +908,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
if self.can_access_database(database):
|
||||
return datasource_names
|
||||
|
||||
catalog = catalog or database.get_default_catalog()
|
||||
if catalog:
|
||||
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
|
||||
if catalog_perm and self.can_access("catalog_access", catalog_perm):
|
||||
return datasource_names
|
||||
|
||||
if schema:
|
||||
default_catalog = database.get_default_catalog()
|
||||
schema_perm = self.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog or default_catalog,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
if schema_perm and self.can_access("schema_access", schema_perm):
|
||||
|
|
@ -2183,6 +2185,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
database = query.database
|
||||
|
||||
database = cast("Database", database)
|
||||
default_catalog = database.get_default_catalog()
|
||||
|
||||
if self.can_access_database(database):
|
||||
return
|
||||
|
|
@ -2196,19 +2199,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
|
||||
# inspector to read it.
|
||||
default_schema = database.get_default_schema_for_query(query)
|
||||
# Determining the default catalog is much easier, because DB engine
|
||||
# specs need explicit support for catalogs.
|
||||
default_catalog = database.get_default_catalog()
|
||||
tables = {
|
||||
Table(
|
||||
table_.table,
|
||||
table_.schema or default_schema,
|
||||
table_.catalog or default_catalog,
|
||||
table_.catalog or query.catalog or default_catalog,
|
||||
)
|
||||
for table_ in extract_tables_from_jinja_sql(query.sql, database)
|
||||
}
|
||||
elif table:
|
||||
tables = {table}
|
||||
# Make sure table has the default catalog, if not specified.
|
||||
tables = {
|
||||
Table(table.table, table.schema, table.catalog or default_catalog)
|
||||
}
|
||||
|
||||
denied = set()
|
||||
|
||||
|
|
|
|||
|
|
@ -284,6 +284,7 @@ class SqlLabRestApi(BaseSupersetApi):
|
|||
"client_id": client_id,
|
||||
"row_count": row_count,
|
||||
"database": query.database.name,
|
||||
"catalog": query.catalog,
|
||||
"schema": query.schema,
|
||||
"sql": query.sql,
|
||||
"exported_format": "csv",
|
||||
|
|
|
|||
|
|
@ -125,6 +125,8 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
|
|||
def set_database(self, database: Database) -> None:
|
||||
self._validate_db(database)
|
||||
self.database = database
|
||||
if self.catalog is None:
|
||||
self.catalog = database.get_default_catalog()
|
||||
if self.select_as_cta:
|
||||
schema_name = self._get_ctas_target_schema_name(database)
|
||||
self.create_table_as_select.target_schema_name = schema_name # type: ignore
|
||||
|
|
|
|||
|
|
@ -239,6 +239,7 @@ class TableSchemaView(BaseSupersetView):
|
|||
db.session.query(TableSchema).filter(
|
||||
TableSchema.tab_state_id == table["queryEditorId"],
|
||||
TableSchema.database_id == table["dbId"],
|
||||
TableSchema.catalog == table["catalog"],
|
||||
TableSchema.schema == table["schema"],
|
||||
TableSchema.table == table["name"],
|
||||
).delete(synchronize_session=False)
|
||||
|
|
@ -246,6 +247,7 @@ class TableSchemaView(BaseSupersetView):
|
|||
table_schema = TableSchema(
|
||||
tab_state_id=table["queryEditorId"],
|
||||
database_id=table["dbId"],
|
||||
catalog=table["catalog"],
|
||||
schema=table["schema"],
|
||||
table=table["name"],
|
||||
description=json.dumps(table),
|
||||
|
|
|
|||
|
|
@ -1563,34 +1563,6 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star_datasource_access(self):
|
||||
"""
|
||||
Database API: Test get select star with datasource access
|
||||
"""
|
||||
table = SqlaTable(
|
||||
schema="main", table_name="ab_permission", database=get_main_database()
|
||||
)
|
||||
db.session.add(table)
|
||||
db.session.commit()
|
||||
|
||||
tmp_table_perm = security_manager.find_permission_view_menu(
|
||||
"datasource_access", table.get_perm()
|
||||
)
|
||||
gamma_role = security_manager.find_role("Gamma")
|
||||
security_manager.add_permission_role(gamma_role, tmp_table_perm)
|
||||
|
||||
self.login(GAMMA_USERNAME)
|
||||
main_db = get_main_database()
|
||||
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
# rollback changes
|
||||
security_manager.del_permission_role(gamma_role, tmp_table_perm)
|
||||
db.session.delete(table)
|
||||
db.session.delete(main_db)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_select_star_not_found_database(self):
|
||||
"""
|
||||
Database API: Test get select star not found database
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
|||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils import json
|
||||
from superset.utils.core import backend, get_example_default_schema
|
||||
from superset.utils.database import get_example_database, get_main_database
|
||||
|
|
@ -676,57 +675,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
expected_result = {"message": {"owners": ["Owners are invalid"]}}
|
||||
assert data == expected_result
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_create_dataset_validate_uniqueness(self):
|
||||
"""
|
||||
Dataset API: Test create dataset validate table uniqueness
|
||||
"""
|
||||
|
||||
energy_usage_ds = self.get_energy_usage_dataset()
|
||||
self.login(ADMIN_USERNAME)
|
||||
table_data = {
|
||||
"database": energy_usage_ds.database_id,
|
||||
"table_name": energy_usage_ds.table_name,
|
||||
}
|
||||
if schema := get_example_default_schema():
|
||||
table_data["schema"] = schema
|
||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
||||
assert rv.status_code == 422
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data == {
|
||||
"message": {
|
||||
"table": [
|
||||
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_create_dataset_with_sql_validate_uniqueness(self):
|
||||
"""
|
||||
Dataset API: Test create dataset with sql
|
||||
"""
|
||||
|
||||
energy_usage_ds = self.get_energy_usage_dataset()
|
||||
self.login(ADMIN_USERNAME)
|
||||
table_data = {
|
||||
"database": energy_usage_ds.database_id,
|
||||
"table_name": energy_usage_ds.table_name,
|
||||
"sql": "select * from energy_usage",
|
||||
}
|
||||
if schema := get_example_default_schema():
|
||||
table_data["schema"] = schema
|
||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
||||
assert rv.status_code == 422
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data == {
|
||||
"message": {
|
||||
"table": [
|
||||
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_create_dataset_with_sql(self):
|
||||
"""
|
||||
|
|
@ -1455,27 +1403,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_dataset_item_uniqueness(self):
|
||||
"""
|
||||
Dataset API: Test update dataset uniqueness
|
||||
"""
|
||||
|
||||
dataset = self.insert_default_dataset()
|
||||
self.login(ADMIN_USERNAME)
|
||||
ab_user = self.insert_dataset(
|
||||
"ab_user", [self.get_user("admin").id], get_main_database()
|
||||
)
|
||||
table_data = {"table_name": "ab_user"}
|
||||
uri = f"api/v1/dataset/{dataset.id}"
|
||||
rv = self.put_assert_metric(uri, table_data, "put")
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert rv.status_code == 422
|
||||
expected_response = {"message": {"table": ["Dataset ab_user already exists"]}}
|
||||
assert data == expected_response
|
||||
db.session.delete(dataset)
|
||||
db.session.delete(ab_user)
|
||||
db.session.commit()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.update")
|
||||
def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from superset.db_engine_specs import load_engine_specs
|
|||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils.core import backend
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||
|
|
@ -525,11 +526,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
|
|||
"""
|
||||
Test the ``get_catalog_names`` method.
|
||||
"""
|
||||
database = get_example_database()
|
||||
|
||||
if database.backend != "postgresql":
|
||||
if backend() != "postgresql":
|
||||
return
|
||||
|
||||
database = get_example_database()
|
||||
with database.get_inspector() as inspector:
|
||||
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
|
||||
"postgres",
|
||||
|
|
|
|||
|
|
@ -1633,7 +1633,10 @@ class TestSecurityManager(SupersetTestCase):
|
|||
@patch("superset.security.SupersetSecurityManager.can_access")
|
||||
def test_raise_for_access_query(self, mock_can_access, mock_is_owner):
|
||||
query = Mock(
|
||||
database=get_example_database(), schema="bar", sql="SELECT * FROM foo"
|
||||
database=get_example_database(),
|
||||
schema="bar",
|
||||
sql="SELECT * FROM foo",
|
||||
catalog=None,
|
||||
)
|
||||
|
||||
mock_can_access.return_value = True
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
|
|||
from superset.constants import EMPTY_STRING, NULL_STRING
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException # noqa: F401
|
||||
from superset.exceptions import (
|
||||
QueryObjectValidationError,
|
||||
) # noqa: F401
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
|
|
@ -160,7 +162,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
query_obj = dict(**base_query_obj, extras={})
|
||||
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
|
||||
assert extra_cache_keys == [1, "abc", "abc@test.com"]
|
||||
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
|
||||
|
||||
# Table with Jinja callable disabled.
|
||||
table2 = SqlaTable(
|
||||
|
|
|
|||
|
|
@ -255,11 +255,11 @@ def test_dataset_uniqueness(session: Session) -> None:
|
|||
|
||||
# but the DAO enforces application logic for uniqueness
|
||||
assert not DatasetDAO.validate_uniqueness(
|
||||
database.id,
|
||||
database,
|
||||
Table("table", "schema", None),
|
||||
)
|
||||
|
||||
assert DatasetDAO.validate_uniqueness(
|
||||
database.id,
|
||||
database,
|
||||
Table("table", "schema", "some_catalog"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
|||
|
||||
assert (
|
||||
DatasetDAO.validate_update_uniqueness(
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
table=Table(dataset1.table_name, dataset1.schema),
|
||||
dataset_id=dataset1.id,
|
||||
)
|
||||
|
|
@ -62,7 +62,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
|||
|
||||
assert (
|
||||
DatasetDAO.validate_update_uniqueness(
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
table=Table(dataset1.table_name, dataset2.schema),
|
||||
dataset_id=dataset1.id,
|
||||
)
|
||||
|
|
@ -71,7 +71,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
|||
|
||||
assert (
|
||||
DatasetDAO.validate_update_uniqueness(
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
table=Table(dataset1.table_name),
|
||||
dataset_id=dataset1.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -366,6 +366,7 @@ def test_raise_for_access_query_default_schema(
|
|||
database.get_default_catalog.return_value = None
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
query = mocker.MagicMock()
|
||||
query.catalog = None
|
||||
query.database = database
|
||||
query.sql = "SELECT * FROM ab_user"
|
||||
|
||||
|
|
@ -421,6 +422,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
|
|||
database.get_default_catalog.return_value = None
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
query = mocker.MagicMock()
|
||||
query.catalog = None
|
||||
query.database = database
|
||||
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"
|
||||
|
||||
|
|
@ -434,7 +436,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
|
|||
viz=None,
|
||||
)
|
||||
|
||||
get_table_access_error_object.assert_called_with({Table("ab_user", "public")})
|
||||
get_table_access_error_object.assert_called_with({Table("ab_user", "public", None)})
|
||||
|
||||
|
||||
def test_raise_for_access_chart_for_datasource_permission(
|
||||
|
|
@ -736,6 +738,7 @@ def test_raise_for_access_catalog(
|
|||
database.get_default_catalog.return_value = "db1"
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
query = mocker.MagicMock()
|
||||
query.catalog = "db1"
|
||||
query.database = database
|
||||
query.sql = "SELECT * FROM ab_user"
|
||||
|
||||
|
|
@ -776,7 +779,8 @@ def test_get_datasources_accessible_by_user_schema_access(
|
|||
database.database_name = "db1"
|
||||
database.get_default_catalog.return_value = "catalog2"
|
||||
|
||||
can_access = mocker.patch.object(sm, "can_access", return_value=True)
|
||||
# False for catalog_access, True for schema_access
|
||||
can_access = mocker.patch.object(sm, "can_access", side_effect=[False, True])
|
||||
|
||||
datasource_names = [
|
||||
DatasourceName("table1", "schema1", "catalog2"),
|
||||
|
|
@ -795,7 +799,12 @@ def test_get_datasources_accessible_by_user_schema_access(
|
|||
|
||||
# Even though we passed `catalog=None,` the schema check uses the default catalog
|
||||
# when building the schema permission, since the DB supports catalog.
|
||||
can_access.assert_called_with("schema_access", "[db1].[catalog2].[schema1]")
|
||||
can_access.assert_has_calls(
|
||||
[
|
||||
mocker.call("catalog_access", "[db1].[catalog2]"),
|
||||
mocker.call("schema_access", "[db1].[catalog2].[schema1]"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_get_catalogs_accessible_by_user_schema_access(
|
||||
|
|
|
|||
Loading…
Reference in New Issue