fix: handle empty catalog when the DB supports catalogs (#29840)

This commit is contained in:
Beto Dealmeida 2024-08-13 10:08:43 -04:00 committed by GitHub
parent 9f5eb899e8
commit 39209c2b40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 100 additions and 148 deletions

View File

@ -52,7 +52,7 @@ GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
def fetch_files_github_api(url: str): # type: ignore
"""Fetches data using GitHub API."""
req = Request(url)
req.add_header("Authorization", f"token {GITHUB_TOKEN}")
req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
req.add_header("Accept", "application/vnd.github.v3+json")
print(f"Fetching from {url}")

View File

@ -32,6 +32,10 @@ class Datasource(Schema):
datasource_name = fields.String(
metadata={"description": datasource_name_description},
)
catalog = fields.String(
allow_none=True,
metadata={"description": "Datasource catalog"},
)
schema = fields.String(
metadata={"description": "Datasource schema"},
)

View File

@ -54,23 +54,28 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
def validate(self) -> None:
exceptions: list[ValidationError] = []
database_id = self._properties["database"]
schema = self._properties.get("schema")
catalog = self._properties.get("catalog")
schema = self._properties.get("schema")
table_name = self._properties["table_name"]
sql = self._properties.get("sql")
owner_ids: Optional[list[int]] = self._properties.get("owners")
table = Table(self._properties["table_name"], schema, catalog)
# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, table):
exceptions.append(DatasetExistsValidationError(table))
# Validate/Populate database
database = DatasetDAO.get_database_by_id(database_id)
if not database:
exceptions.append(DatabaseNotFoundValidationError())
self._properties["database"] = database
# Validate uniqueness
if database:
if not catalog:
catalog = self._properties["catalog"] = database.get_default_catalog()
table = Table(table_name, schema, catalog)
if not DatasetDAO.validate_uniqueness(database, table):
exceptions.append(DatasetExistsValidationError(table))
# Validate table exists on dataset if sql is not provided
# This should be validated when the dataset is physical
if (

View File

@ -166,7 +166,7 @@ def import_dataset(
try:
table_exists = dataset.database.has_table(
Table(dataset.table_name, dataset.schema),
Table(dataset.table_name, dataset.schema, dataset.catalog),
)
except Exception: # pylint: disable=broad-except
# MySQL doesn't play nice with GSheets table names

View File

@ -79,10 +79,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
def validate(self) -> None:
exceptions: list[ValidationError] = []
owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/populate model exists
self._model = DatasetDAO.find_by_id(self._model_id)
if not self._model:
raise DatasetNotFoundError()
# Check ownership
try:
security_manager.raise_for_ownership(self._model)
@ -91,22 +93,30 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
database_id = self._properties.get("database")
catalog = self._properties.get("catalog")
if not catalog:
catalog = self._properties["catalog"] = (
self._model.database.get_default_catalog()
)
table = Table(
self._properties.get("table_name"), # type: ignore
self._properties.get("schema"),
self._properties.get("catalog"),
catalog,
)
# Validate uniqueness
if not DatasetDAO.validate_update_uniqueness(
self._model.database_id,
self._model.database,
table,
self._model_id,
):
exceptions.append(DatasetExistsValidationError(table))
# Validate/Populate database not allowed to change.
# Compare the requested database id with the id of the database the dataset
# already belongs to. Comparing against the model object itself is never
# equal, which flagged every update that supplied a database as a change.
if database_id and database_id != self._model.database_id:
    exceptions.append(DatabaseChangeValidationError())
# Validate/Populate owner
try:
owners = self.compute_owners(
@ -116,6 +126,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self._properties["owners"] = owners
except ValidationError as ex:
exceptions.append(ex)
# Validate columns
if columns := self._properties.get("columns"):
self._validate_columns(columns, exceptions)

View File

@ -461,9 +461,11 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
)
else:
_columns = [
(
utils.get_column_name(column_)
if utils.is_adhoc_column(column_)
else column_
)
for column_param in COLUMN_FORM_DATA_PARAMS
for column_ in utils.as_list(form_data.get(column_param) or [])
]
@ -1963,7 +1965,7 @@ class SqlaTable(
if self.has_extra_cache_key_calls(query_obj):
sqla_query = self.get_sqla_query(**query_obj)
extra_cache_keys += sqla_query.extra_cache_keys
return extra_cache_keys
return list(set(extra_cache_keys))
@property
def quote_identifier(self) -> Callable[[str], str]:

View File

@ -84,15 +84,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod
def validate_uniqueness(
database_id: int,
database: Database,
table: Table,
dataset_id: int | None = None,
) -> bool:
# The catalog might not be set even if the database supports catalogs, in case
# multi-catalog is disabled.
catalog = table.catalog or database.get_default_catalog()
dataset_query = db.session.query(SqlaTable).filter(
SqlaTable.table_name == table.table,
SqlaTable.schema == table.schema,
SqlaTable.catalog == table.catalog,
SqlaTable.database_id == database_id,
SqlaTable.catalog == catalog,
SqlaTable.database_id == database.id,
)
if dataset_id:
@ -103,15 +107,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod
def validate_update_uniqueness(
database_id: int,
database: Database,
table: Table,
dataset_id: int,
) -> bool:
# The catalog might not be set even if the database supports catalogs, in case
# multi-catalog is disabled.
catalog = table.catalog or database.get_default_catalog()
dataset_query = db.session.query(SqlaTable).filter(
SqlaTable.table_name == table.table,
SqlaTable.database_id == database_id,
SqlaTable.database_id == database.id,
SqlaTable.schema == table.schema,
SqlaTable.catalog == table.catalog,
SqlaTable.catalog == catalog,
SqlaTable.id != dataset_id,
)
return not db.session.query(dataset_query.exists()).scalar()

View File

@ -1159,7 +1159,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("init", self.select_star.__name__)
try:
result = database.select_star(
Table(table_name, schema_name),
Table(table_name, schema_name, database.get_default_catalog()),
latest_partition=True,
)
except NoSuchTableError:

View File

@ -565,7 +565,7 @@ class NoOpTemplateProcessor(BaseTemplateProcessor):
"""
Makes processing a template a noop
"""
return sql
return str(sql)
class PrestoTemplateProcessor(JinjaTemplateProcessor):

View File

@ -491,7 +491,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
g.user.id,
self.db_engine_spec,
)
if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
else None
)
# If using MySQL or Presto for example, will set url.username

View File

@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
if ds:
target.perm = ds.perm
target.catalog_perm = ds.catalog_perm
target.schema_perm = ds.schema_perm

View File

@ -774,6 +774,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
default_catalog = database.get_default_catalog()
catalog = catalog or default_catalog
if hierarchical and (
self.can_access_database(database)
or (catalog and self.can_access_catalog(database, catalog))
@ -783,7 +786,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# schema_access
accessible_schemas: set[str] = set()
schema_access = self.user_view_menu_names("schema_access")
default_catalog = database.get_default_catalog()
default_schema = database.get_default_schema(default_catalog)
for perm in schema_access:
@ -800,7 +802,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# [database].[catalog].[schema] matches when the catalog is equal to the
# requested catalog or, when no catalog specified, it's equal to the default
# catalog.
elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
elif len(parts) == 3 and parts[1] == catalog:
accessible_schemas.add(parts[2])
# datasource_access
@ -906,16 +908,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database):
return datasource_names
catalog = catalog or database.get_default_catalog()
if catalog:
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
if catalog_perm and self.can_access("catalog_access", catalog_perm):
return datasource_names
if schema:
default_catalog = database.get_default_catalog()
schema_perm = self.get_schema_perm(
database.database_name,
catalog or default_catalog,
catalog,
schema,
)
if schema_perm and self.can_access("schema_access", schema_perm):
@ -2183,6 +2185,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
database = query.database
database = cast("Database", database)
default_catalog = database.get_default_catalog()
if self.can_access_database(database):
return
@ -2196,19 +2199,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
# inspector to read it.
default_schema = database.get_default_schema_for_query(query)
# Determining the default catalog is much easier, because DB engine
# specs need explicit support for catalogs.
default_catalog = database.get_default_catalog()
tables = {
Table(
table_.table,
table_.schema or default_schema,
table_.catalog or default_catalog,
table_.catalog or query.catalog or default_catalog,
)
for table_ in extract_tables_from_jinja_sql(query.sql, database)
}
elif table:
tables = {table}
# Make sure table has the default catalog, if not specified.
tables = {
Table(table.table, table.schema, table.catalog or default_catalog)
}
denied = set()

View File

@ -284,6 +284,7 @@ class SqlLabRestApi(BaseSupersetApi):
"client_id": client_id,
"row_count": row_count,
"database": query.database.name,
"catalog": query.catalog,
"schema": query.schema,
"sql": query.sql,
"exported_format": "csv",

View File

@ -125,6 +125,8 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
def set_database(self, database: Database) -> None:
self._validate_db(database)
self.database = database
if self.catalog is None:
self.catalog = database.get_default_catalog()
if self.select_as_cta:
schema_name = self._get_ctas_target_schema_name(database)
self.create_table_as_select.target_schema_name = schema_name # type: ignore

View File

@ -239,6 +239,7 @@ class TableSchemaView(BaseSupersetView):
db.session.query(TableSchema).filter(
TableSchema.tab_state_id == table["queryEditorId"],
TableSchema.database_id == table["dbId"],
TableSchema.catalog == table["catalog"],
TableSchema.schema == table["schema"],
TableSchema.table == table["name"],
).delete(synchronize_session=False)
@ -246,6 +247,7 @@ class TableSchemaView(BaseSupersetView):
table_schema = TableSchema(
tab_state_id=table["queryEditorId"],
database_id=table["dbId"],
catalog=table["catalog"],
schema=table["schema"],
table=table["name"],
description=json.dumps(table),

View File

@ -1563,34 +1563,6 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star_datasource_access(self):
"""
Database API: Test get select star with datasource access
"""
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
db.session.add(table)
db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)
self.login(GAMMA_USERNAME)
main_db = get_main_database()
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
# rollback changes
security_manager.del_permission_role(gamma_role, tmp_table_perm)
db.session.delete(table)
db.session.delete(main_db)
db.session.commit()
def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database

View File

@ -36,7 +36,6 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.models.slice import Slice
from superset.sql_parse import Table
from superset.utils import json
from superset.utils.core import backend, get_example_default_schema
from superset.utils.database import get_example_database, get_main_database
@ -676,57 +675,6 @@ class TestDatasetApi(SupersetTestCase):
expected_result = {"message": {"owners": ["Owners are invalid"]}}
assert data == expected_result
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_validate_uniqueness(self):
"""
Dataset API: Test create dataset validate table uniqueness
"""
energy_usage_ds = self.get_energy_usage_dataset()
self.login(ADMIN_USERNAME)
table_data = {
"database": energy_usage_ds.database_id,
"table_name": energy_usage_ds.table_name,
}
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {
"table": [
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
]
}
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_with_sql_validate_uniqueness(self):
"""
Dataset API: Test create dataset with sql
"""
energy_usage_ds = self.get_energy_usage_dataset()
self.login(ADMIN_USERNAME)
table_data = {
"database": energy_usage_ds.database_id,
"table_name": energy_usage_ds.table_name,
"sql": "select * from energy_usage",
}
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {
"table": [
f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
]
}
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_with_sql(self):
"""
@ -1455,27 +1403,6 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(dataset)
db.session.commit()
def test_update_dataset_item_uniqueness(self):
"""
Dataset API: Test update dataset uniqueness
"""
dataset = self.insert_default_dataset()
self.login(ADMIN_USERNAME)
ab_user = self.insert_dataset(
"ab_user", [self.get_user("admin").id], get_main_database()
)
table_data = {"table_name": "ab_user"}
uri = f"api/v1/dataset/{dataset.id}"
rv = self.put_assert_metric(uri, table_data, "put")
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
expected_response = {"message": {"table": ["Dataset ab_user already exists"]}}
assert data == expected_response
db.session.delete(dataset)
db.session.delete(ab_user)
db.session.commit()
@patch("superset.daos.dataset.DatasetDAO.update")
def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
"""

View File

@ -25,6 +25,7 @@ from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import backend
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
@ -525,11 +526,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
"""
Test the ``get_catalog_names`` method.
"""
database = get_example_database()
if database.backend != "postgresql":
if backend() != "postgresql":
return
database = get_example_database()
with database.get_inspector() as inspector:
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
"postgres",

View File

@ -1633,7 +1633,10 @@ class TestSecurityManager(SupersetTestCase):
@patch("superset.security.SupersetSecurityManager.can_access")
def test_raise_for_access_query(self, mock_can_access, mock_is_owner):
query = Mock(
database=get_example_database(), schema="bar", sql="SELECT * FROM foo"
database=get_example_database(),
schema="bar",
sql="SELECT * FROM foo",
catalog=None,
)
mock_can_access.return_value = True

View File

@ -34,7 +34,9 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException # noqa: F401
from superset.exceptions import (
QueryObjectValidationError,
) # noqa: F401
from superset.models.core import Database
from superset.utils.core import (
AdhocMetricExpressionType,
@ -160,7 +162,7 @@ class TestDatabaseModel(SupersetTestCase):
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
assert extra_cache_keys == [1, "abc", "abc@test.com"]
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
# Table with Jinja callable disabled.
table2 = SqlaTable(

View File

@ -255,11 +255,11 @@ def test_dataset_uniqueness(session: Session) -> None:
# but the DAO enforces application logic for uniqueness
assert not DatasetDAO.validate_uniqueness(
database.id,
database,
Table("table", "schema", None),
)
assert DatasetDAO.validate_uniqueness(
database.id,
database,
Table("table", "schema", "some_catalog"),
)

View File

@ -53,7 +53,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
database=database,
table=Table(dataset1.table_name, dataset1.schema),
dataset_id=dataset1.id,
)
@ -62,7 +62,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
database=database,
table=Table(dataset1.table_name, dataset2.schema),
dataset_id=dataset1.id,
)
@ -71,7 +71,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
database=database,
table=Table(dataset1.table_name),
dataset_id=dataset1.id,
)

View File

@ -366,6 +366,7 @@ def test_raise_for_access_query_default_schema(
database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.catalog = None
query.database = database
query.sql = "SELECT * FROM ab_user"
@ -421,6 +422,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.catalog = None
query.database = database
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"
@ -434,7 +436,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) ->
viz=None,
)
get_table_access_error_object.assert_called_with({Table("ab_user", "public")})
get_table_access_error_object.assert_called_with({Table("ab_user", "public", None)})
def test_raise_for_access_chart_for_datasource_permission(
@ -736,6 +738,7 @@ def test_raise_for_access_catalog(
database.get_default_catalog.return_value = "db1"
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.catalog = "db1"
query.database = database
query.sql = "SELECT * FROM ab_user"
@ -776,7 +779,8 @@ def test_get_datasources_accessible_by_user_schema_access(
database.database_name = "db1"
database.get_default_catalog.return_value = "catalog2"
can_access = mocker.patch.object(sm, "can_access", return_value=True)
# False for catalog_access, True for schema_access
can_access = mocker.patch.object(sm, "can_access", side_effect=[False, True])
datasource_names = [
DatasourceName("table1", "schema1", "catalog2"),
@ -795,7 +799,12 @@ def test_get_datasources_accessible_by_user_schema_access(
# Even though we passed `catalog=None`, the schema check uses the default catalog
# when building the schema permission, since the DB supports catalogs.
can_access.assert_called_with("schema_access", "[db1].[catalog2].[schema1]")
can_access.assert_has_calls(
[
mocker.call("catalog_access", "[db1].[catalog2]"),
mocker.call("schema_access", "[db1].[catalog2].[schema1]"),
]
)
def test_get_catalogs_accessible_by_user_schema_access(