diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 89c640b7a..922c78f21 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines, redefined-outer-name +# pylint: disable=too-many-lines import dataclasses import json import logging @@ -36,7 +36,6 @@ from typing import ( Type, Union, ) -from uuid import uuid4 import dateutil.parser import numpy as np @@ -68,7 +67,6 @@ from sqlalchemy import ( from sqlalchemy.engine.base import Connection from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session -from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table @@ -78,18 +76,15 @@ from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager from superset.advanced_data_type.types import AdvancedDataTypeResponse -from superset.columns.models import Column as NewColumn, UNKOWN_TYPE from superset.common.db_query_status import QueryStatus from superset.common.utils.time_range_utils import get_since_until_from_time_range from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( - find_cached_objects_in_session, get_columns_description, get_physical_table_metadata, get_virtual_table_metadata, validate_adhoc_subquery, ) -from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import ( AdvancedDataTypeResponseError, @@ -106,18 +101,8 @@ from superset.jinja_context import ( ) from superset.models.annotations import Annotation from superset.models.core import Database -from superset.models.helpers import ( - AuditMixinNullable, - CertificationMixin, - clone_model, - QueryResult, -) -from superset.sql_parse import ( - extract_table_references, - ParsedQuery, - sanitize_clause, - Table as TableName, -) +from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult +from superset.sql_parse import ParsedQuery, sanitize_clause from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -126,7 +111,6 @@ from superset.superset_typing import ( OrderBy, QueryObjectDict, ) -from superset.tables.models import Table as NewTable from superset.utils import core as utils from superset.utils.core import ( GenericDataType, @@ -439,76 +423,6 @@ class TableColumn(Model, BaseColumn, CertificationMixin): return attr_dict - def to_sl_column( - self, known_columns: Optional[Dict[str, NewColumn]] = None - ) -> NewColumn: - """Convert a TableColumn to NewColumn""" - session: Session = inspect(self).session - column = known_columns.get(self.uuid) if known_columns else None - if not column: - column = NewColumn() - - extra_json = self.get_extra_dict() - for attr in { - "verbose_name", - "python_date_format", - }: - value = getattr(self, attr) - if value: - extra_json[attr] = value - - # column id is primary key, so make sure that we check uuid against - # the id as well - if not column.id: - with session.no_autoflush: - saved_column: NewColumn = ( - session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() - ) - if saved_column is not None: - logger.warning( - "sl_column already exists. Using this row for db update %s", - self, - ) - - # overwrite the existing column instead of creating a new one - column = saved_column - - column.uuid = self.uuid - column.created_on = self.created_on - column.changed_on = self.changed_on - column.created_by = self.created_by - column.changed_by = self.changed_by - column.name = self.column_name - column.type = self.type or UNKOWN_TYPE - column.expression = self.expression or self.table.quote_identifier( - self.column_name - ) - column.description = self.description - column.is_aggregation = False - column.is_dimensional = self.groupby - column.is_filterable = self.filterable - column.is_increase_desired = True - column.is_managed_externally = self.table.is_managed_externally - column.is_partition = False - column.is_physical = not self.expression - column.is_spatial = False - column.is_temporal = self.is_dttm - column.extra_json = json.dumps(extra_json) if extra_json else None - column.external_url = self.table.external_url - - return column - - @staticmethod - def after_delete( # pylint: disable=unused-argument - mapper: Mapper, - connection: Connection, - target: "TableColumn", - ) -> None: - session = inspect(target).session - column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none() - if column: - session.delete(column) - class SqlMetric(Model, BaseMetric, CertificationMixin): @@ -574,76 +488,6 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): attr_dict.update(super().data) return attr_dict - def to_sl_column( - self, known_columns: Optional[Dict[str, NewColumn]] = None - ) -> NewColumn: - """Convert a SqlMetric to NewColumn. Find and update existing or - create a new one.""" - session: Session = inspect(self).session - column = known_columns.get(self.uuid) if known_columns else None - if not column: - column = NewColumn() - - extra_json = self.get_extra_dict() - for attr in {"verbose_name", "metric_type", "d3format"}: - value = getattr(self, attr) - if value is not None: - extra_json[attr] = value - is_additive = ( - self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER - ) - - # column id is primary key, so make sure that we check uuid against - # the id as well - if not column.id: - with session.no_autoflush: - saved_column: NewColumn = ( - session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() - ) - - if saved_column is not None: - logger.warning( - "sl_column already exists. Using this row for db update %s", - self, - ) - - # overwrite the existing column instead of creating a new one - column = saved_column - - column.uuid = self.uuid - column.name = self.metric_name - column.created_on = self.created_on - column.changed_on = self.changed_on - column.created_by = self.created_by - column.changed_by = self.changed_by - column.type = UNKOWN_TYPE - column.expression = self.expression - column.warning_text = self.warning_text - column.description = self.description - column.is_aggregation = True - column.is_additive = is_additive - column.is_filterable = False - column.is_increase_desired = True - column.is_managed_externally = self.table.is_managed_externally - column.is_partition = False - column.is_physical = False - column.is_spatial = False - column.extra_json = json.dumps(extra_json) if extra_json else None - column.external_url = self.table.external_url - - return column - - @staticmethod - def after_delete( # pylint: disable=unused-argument - mapper: Mapper, - connection: Connection, - target: "SqlMetric", - ) -> None: - session = inspect(target).session - column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none() - if column: - session.delete(column) - sqlatable_user = Table( "sqlatable_user", @@ -2228,40 +2072,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ): raise Exception(get_dataset_exist_error_msg(target.full_name)) - def get_sl_columns(self) -> List[NewColumn]: - """ - Convert `SqlaTable.columns` and `SqlaTable.metrics` to the new Column model - """ - session: Session = inspect(self).session - - uuids = set() - for column_or_metric in self.columns + self.metrics: - # pre-assign uuid after new columns or metrics are inserted so - # the related `NewColumn` can have a deterministic uuid, too - if not column_or_metric.uuid: - column_or_metric.uuid = uuid4() - else: - uuids.add(column_or_metric.uuid) - - # load existing columns from cached session states first - existing_columns = set( - find_cached_objects_in_session(session, NewColumn, uuids=uuids) - ) - for column in existing_columns: - uuids.remove(column.uuid) - - if uuids: - with session.no_autoflush: - # load those not found from db - existing_columns |= set( - session.query(NewColumn).filter(NewColumn.uuid.in_(uuids)) - ) - - known_columns = {column.uuid: column for column in existing_columns} - return [ - item.to_sl_column(known_columns) for item in self.columns + self.metrics - ] - @staticmethod def update_column( # pylint: disable=unused-argument mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] @@ -2278,46 +2088,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho # table is updated. This busts the cache key for all charts that use the table. session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) - # if table itself has changed, shadow-writing will happen in `after_update` anyway - if target.table not in session.dirty: - dataset: NewDataset = ( - session.query(NewDataset) - .filter_by(uuid=target.table.uuid) - .one_or_none() - ) - # Update shadow dataset and columns - # did we find the dataset? - if not dataset: - # if dataset is not found create a new copy - target.table.write_shadow_dataset() - return - - # update changed_on timestamp - session.execute(update(NewDataset).where(NewDataset.id == dataset.id)) - try: - with session.no_autoflush: - column = session.query(NewColumn).filter_by(uuid=target.uuid).one() - # update `Column` model as well - session.merge(target.to_sl_column({target.uuid: column})) - except NoResultFound: - logger.warning("No column was found for %s", target) - # see if the column is in cache - column = next( - find_cached_objects_in_session( - session, NewColumn, uuids=[target.uuid] - ), - None, - ) - if column: - logger.warning("New column was found in cache: %s", column) - - else: - # to be safe, use a different uuid and create a new column - uuid = uuid4() - target.uuid = uuid - - session.add(target.to_sl_column()) - @staticmethod def after_insert( mapper: Mapper, @@ -2325,19 +2095,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho sqla_table: "SqlaTable", ) -> None: """ - Shadow write the dataset to new models. - - The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` - and ``Dataset``. In the first phase of the migration the new models are populated - whenever ``SqlaTable`` is modified (created, updated, or deleted). - - In the second phase of the migration reads will be done from the new models. - Finally, in the third phase of the migration the old models will be removed. - - For more context: https://github.com/apache/superset/issues/14909 + Update dataset permissions after insert """ security_manager.dataset_after_insert(mapper, connection, sqla_table) - sqla_table.write_shadow_dataset() @staticmethod def after_delete( @@ -2346,24 +2106,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho sqla_table: "SqlaTable", ) -> None: """ - Shadow write the dataset to new models. - - The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` - and ``Dataset``. In the first phase of the migration the new models are populated - whenever ``SqlaTable`` is modified (created, updated, or deleted). - - In the second phase of the migration reads will be done from the new models. - Finally, in the third phase of the migration the old models will be removed. - - For more context: https://github.com/apache/superset/issues/14909 + Update dataset permissions after delete """ security_manager.dataset_after_delete(mapper, connection, sqla_table) - session = inspect(sqla_table).session - dataset = ( - session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() - ) - if dataset: - session.delete(dataset) @staticmethod def after_update( @@ -2372,240 +2117,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho sqla_table: "SqlaTable", ) -> None: """ - Shadow write the dataset to new models. - - The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` - and ``Dataset``. In the first phase of the migration the new models are populated - whenever ``SqlaTable`` is modified (created, updated, or deleted). - - In the second phase of the migration reads will be done from the new models. - Finally, in the third phase of the migration the old models will be removed. - - For more context: https://github.com/apache/superset/issues/14909 + Update dataset permissions after update """ # set permissions security_manager.dataset_after_update(mapper, connection, sqla_table) - inspector = inspect(sqla_table) - session = inspector.session - - # double-check that ``UPDATE``s are actually pending (this method is called even - # for instances that have no net changes to their column-based attributes) - if not session.is_modified(sqla_table, include_collections=True): - return - - # find the dataset from the known instance list first - # (it could be either from a previous query or newly created) - dataset = next( - find_cached_objects_in_session( - session, NewDataset, uuids=[sqla_table.uuid] - ), - None, - ) - # if not found, pull from database - if not dataset: - dataset = ( - session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() - ) - if not dataset: - sqla_table.write_shadow_dataset() - return - - # sync column list and delete removed columns - if ( - inspector.attrs.columns.history.has_changes() - or inspector.attrs.metrics.history.has_changes() - ): - # add pending new columns to known columns list, too, so if calling - # `after_update` twice before changes are persisted will not create - # two duplicate columns with the same uuids. - dataset.columns = sqla_table.get_sl_columns() - - # physical dataset - if not sqla_table.sql: - # if the table name changed we should relink the dataset to another table - # (and create one if necessary) - if ( - inspector.attrs.table_name.history.has_changes() - or inspector.attrs.schema.history.has_changes() - or inspector.attrs.database.history.has_changes() - ): - tables = NewTable.bulk_load_or_create( - sqla_table.database, - [TableName(schema=sqla_table.schema, table=sqla_table.table_name)], - sync_columns=False, - default_props=dict( - changed_by=sqla_table.changed_by, - created_by=sqla_table.created_by, - is_managed_externally=sqla_table.is_managed_externally, - external_url=sqla_table.external_url, - ), - ) - if not tables[0].id: - # dataset columns will only be assigned to newly created tables - # existing tables should manage column syncing in another process - physical_columns = [ - clone_model( - column, ignore=["uuid"], keep_relations=["changed_by"] - ) - for column in dataset.columns - if column.is_physical - ] - tables[0].columns = physical_columns - dataset.tables = tables - - # virtual dataset - else: - # mark all columns as virtual (not physical) - for column in dataset.columns: - column.is_physical = False - - # update referenced tables if SQL changed - if sqla_table.sql and inspector.attrs.sql.history.has_changes(): - referenced_tables = extract_table_references( - sqla_table.sql, sqla_table.database.get_dialect().name - ) - dataset.tables = NewTable.bulk_load_or_create( - sqla_table.database, - referenced_tables, - default_schema=sqla_table.schema, - # sync metadata is expensive, we'll do it in another process - # e.g. when users open a Table page - sync_columns=False, - default_props=dict( - changed_by=sqla_table.changed_by, - created_by=sqla_table.created_by, - is_managed_externally=sqla_table.is_managed_externally, - external_url=sqla_table.external_url, - ), - ) - - # update other attributes - dataset.name = sqla_table.table_name - dataset.expression = sqla_table.sql or sqla_table.quote_identifier( - sqla_table.table_name - ) - dataset.is_physical = not sqla_table.sql - - def write_shadow_dataset( - self: "SqlaTable", - ) -> None: - """ - Shadow write the dataset to new models. - - The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` - and ``Dataset``. In the first phase of the migration the new models are populated - whenever ``SqlaTable`` is modified (created, updated, or deleted). - - In the second phase of the migration reads will be done from the new models. - Finally, in the third phase of the migration the old models will be removed. - - For more context: https://github.com/apache/superset/issues/14909 - """ - session = inspect(self).session - # make sure database points to the right instance, in case only - # `table.database_id` is updated and the changes haven't been - # consolidated by SQLA - if self.database_id and ( - not self.database or self.database.id != self.database_id - ): - self.database = session.query(Database).filter_by(id=self.database_id).one() - - # create columns - columns = [] - for item in self.columns + self.metrics: - item.created_by = self.created_by - item.changed_by = self.changed_by - # on `SqlaTable.after_insert`` event, although the table itself - # already has a `uuid`, the associated columns will not. - # Here we pre-assign a uuid so they can still be matched to the new - # Column after creation. - if not item.uuid: - item.uuid = uuid4() - columns.append(item.to_sl_column()) - - # physical dataset - if not self.sql: - # always create separate column entries for Dataset and Table - # so updating a dataset would not update columns in the related table - physical_columns = [ - clone_model( - column, - ignore=["uuid"], - # `created_by` will always be left empty because it'd always - # be created via some sort of automated system. - # But keep `changed_by` in case someone manually changes - # column attributes such as `is_dttm`. - keep_relations=["changed_by"], - ) - for column in columns - if column.is_physical - ] - tables = NewTable.bulk_load_or_create( - self.database, - [TableName(schema=self.schema, table=self.table_name)], - sync_columns=False, - default_props=dict( - created_by=self.created_by, - changed_by=self.changed_by, - is_managed_externally=self.is_managed_externally, - external_url=self.external_url, - ), - ) - tables[0].columns = physical_columns - - # virtual dataset - else: - # mark all columns as virtual (not physical) - for column in columns: - column.is_physical = False - - # find referenced tables - referenced_tables = extract_table_references( - self.sql, self.database.get_dialect().name - ) - tables = NewTable.bulk_load_or_create( - self.database, - referenced_tables, - default_schema=self.schema, - # syncing table columns can be slow so we are not doing it here - sync_columns=False, - default_props=dict( - created_by=self.created_by, - changed_by=self.changed_by, - is_managed_externally=self.is_managed_externally, - external_url=self.external_url, - ), - ) - - # create the new dataset - new_dataset = NewDataset( - uuid=self.uuid, - database_id=self.database_id, - created_on=self.created_on, - created_by=self.created_by, - changed_by=self.changed_by, - changed_on=self.changed_on, - owners=self.owners, - name=self.table_name, - expression=self.sql or self.quote_identifier(self.table_name), - tables=tables, - columns=columns, - is_physical=not self.sql, - is_managed_externally=self.is_managed_externally, - external_url=self.external_url, - ) - session.add(new_dataset) - sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update) +sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert) sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete) -sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column) -sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete) sa.event.listen(TableColumn, "after_update", SqlaTable.update_column) -sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete) RLSFilterRoles = Table( "rls_filter_roles", diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index d260df361..aed50574c 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union +from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import joinedload from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.dao.base import BaseDAO @@ -35,6 +37,26 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods model_cls = SqlaTable base_filter = DatasourceFilter + @classmethod + def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[SqlaTable]: + """ + Find a List of models by a list of ids, if defined applies `base_filter` + """ + id_col = getattr(SqlaTable, cls.id_column_name, None) + if id_col is None: + return [] + + # the joinedload option ensures that the database is + # available in the session later and not lazy loaded + query = ( + db.session.query(SqlaTable) + .options(joinedload(SqlaTable.database)) + .filter(id_col.in_(model_ids)) + ) + data_model = SQLAInterface(SqlaTable, db.session) + query = DatasourceFilter(cls.id_column_name, data_model).apply(query, None) + return query.all() + @staticmethod def get_database_by_id(database_id: int) -> Optional[Database]: try: diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 0175a2c33..33243a801 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -25,6 +25,7 @@ from zipfile import is_zipfile, ZipFile import prison import pytest import yaml +from sqlalchemy.orm import joinedload from sqlalchemy.sql import func from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn @@ -95,6 +96,7 @@ class TestDatasetApi(SupersetTestCase): def get_fixture_datasets(self) -> List[SqlaTable]: return ( db.session.query(SqlaTable) + .options(joinedload(SqlaTable.database)) .filter(SqlaTable.table_name.in_(self.fixture_tables_names)) .all() ) @@ -1973,21 +1975,17 @@ class TestDatasetApi(SupersetTestCase): database = ( db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() ) - shadow_dataset = ( - db.session.query(Dataset).filter_by(uuid=dataset_config["uuid"]).one() - ) + assert database.database_name == "imported_database" assert len(database.tables) == 1 dataset = database.tables[0] assert dataset.table_name == "imported_dataset" assert str(dataset.uuid) == dataset_config["uuid"] - assert str(shadow_dataset.uuid) == dataset_config["uuid"] dataset.owners = [] database.owners = [] db.session.delete(dataset) - db.session.delete(shadow_dataset) db.session.delete(database) db.session.commit() diff --git a/tests/integration_tests/datasets/model_tests.py b/tests/integration_tests/datasets/model_tests.py deleted file mode 100644 index 3bcc4c079..000000000 --- a/tests/integration_tests/datasets/model_tests.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest import mock - -import pytest -from sqlalchemy import inspect -from sqlalchemy.orm.exc import NoResultFound - -from superset.columns.models import Column -from superset.connectors.sqla.models import SqlaTable, TableColumn -from superset.extensions import db -from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.fixtures.datasource import load_dataset_with_columns - - -class SqlaTableModelTest(SupersetTestCase): - @pytest.mark.usefixtures("load_dataset_with_columns") - def test_dual_update_column(self) -> None: - """ - Test that when updating a sqla ``TableColumn`` - That the shadow ``Column`` is also updated - """ - dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() - column = dataset.columns[0] - column_name = column.column_name - column.column_name = "new_column_name" - SqlaTable.update_column(None, None, target=column) - - # refetch - dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one() - assert dataset.columns[0].column_name == "new_column_name" - - # reset - column.column_name = column_name - SqlaTable.update_column(None, None, target=column) - - @pytest.mark.usefixtures("load_dataset_with_columns") - @mock.patch("superset.columns.models.Column") - def test_dual_update_column_not_found(self, column_mock) -> None: - """ - Test that when updating a sqla ``TableColumn`` - That the shadow ``Column`` is also updated - """ - dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() - column = dataset.columns[0] - column_uuid = column.uuid - with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound): - SqlaTable.update_column(None, None, target=column) - - session = inspect(column).session - - session.flush() - - # refetch - dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one() - # it should create a new uuid - assert dataset.columns[0].uuid != column_uuid - - # reset - column.uuid = column_uuid - SqlaTable.update_column(None, None, target=column) - - @pytest.mark.usefixtures("load_dataset_with_columns") - def test_to_sl_column_no_known_columns(self) -> None: - """ - Test that the function returns a new column - """ - dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() - column = dataset.columns[0] - new_column = column.to_sl_column() - - # it should use the same uuid - assert column.uuid == new_column.uuid diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py deleted file mode 100644 index 771bb0d0e..000000000 --- a/tests/unit_tests/datasets/test_models.py +++ /dev/null @@ -1,1153 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import json -from typing import Any, Callable, Dict, List, TYPE_CHECKING - -from pytest_mock import MockFixture -from sqlalchemy.orm.session import Session - -from tests.unit_tests.utils.db import get_test_user - -if TYPE_CHECKING: - from superset.connectors.sqla.models import SqlMetric, TableColumn - - -def test_dataset_model(session: Session) -> None: - """ - Test basic attributes of a ``Dataset``. - """ - from superset.columns.models import Column - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - table = Table( - name="my_table", - schema="my_schema", - catalog="my_catalog", - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - columns=[ - Column(name="longitude", expression="longitude"), - Column(name="latitude", expression="latitude"), - ], - ) - session.add(table) - session.flush() - - dataset = Dataset( - database=table.database, - name="positions", - expression=""" -SELECT array_agg(array[longitude,latitude]) AS position -FROM my_catalog.my_schema.my_table -""", - tables=[table], - columns=[ - Column( - name="position", - expression="array_agg(array[longitude,latitude])", - ), - ], - ) - session.add(dataset) - session.flush() - - assert dataset.id == 1 - assert dataset.uuid is not None - - assert dataset.name == "positions" - assert ( - dataset.expression - == """ -SELECT array_agg(array[longitude,latitude]) AS position -FROM my_catalog.my_schema.my_table -""" - ) - - assert [table.name for table in dataset.tables] == ["my_table"] - assert [column.name for column in dataset.columns] == ["position"] - - -def test_cascade_delete_table(session: Session) -> None: - """ - Test that deleting ``Table`` also deletes its columns. - """ - from superset.columns.models import Column - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Table.metadata.create_all(engine) # pylint: disable=no-member - - table = Table( - name="my_table", - schema="my_schema", - catalog="my_catalog", - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - columns=[ - Column(name="longitude", expression="longitude"), - Column(name="latitude", expression="latitude"), - ], - ) - session.add(table) - session.flush() - - columns = session.query(Column).all() - assert len(columns) == 2 - - session.delete(table) - session.flush() - - # test that columns were deleted - columns = session.query(Column).all() - assert len(columns) == 0 - - -def test_cascade_delete_dataset(session: Session) -> None: - """ - Test that deleting ``Dataset`` also deletes its columns. - """ - from superset.columns.models import Column - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - table = Table( - name="my_table", - schema="my_schema", - catalog="my_catalog", - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - columns=[ - Column(name="longitude", expression="longitude"), - Column(name="latitude", expression="latitude"), - ], - ) - session.add(table) - session.flush() - - dataset = Dataset( - name="positions", - expression=""" -SELECT array_agg(array[longitude,latitude]) AS position -FROM my_catalog.my_schema.my_table -""", - database=table.database, - tables=[table], - columns=[ - Column( - name="position", - expression="array_agg(array[longitude,latitude])", - ), - ], - ) - session.add(dataset) - session.flush() - - columns = session.query(Column).all() - assert len(columns) == 3 - - session.delete(dataset) - session.flush() - - # test that dataset columns were deleted (but not table columns) - columns = session.query(Column).all() - assert len(columns) == 2 - - -def test_dataset_attributes(session: Session) -> None: - """ - Test that checks attributes in the dataset. - - If this check fails it means new attributes were added to ``SqlaTable``, and - ``SqlaTable.after_insert`` should be updated to handle them! - """ - from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn - from superset.models.core import Database - - engine = session.get_bind() - SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - TableColumn(column_name="num_boys", type="INTEGER"), - TableColumn(column_name="revenue", type="INTEGER"), - TableColumn(column_name="expenses", type="INTEGER"), - TableColumn( - column_name="profit", type="INTEGER", expression="revenue-expenses" - ), - ] - metrics = [ - SqlMetric(metric_name="cnt", expression="COUNT(*)"), - ] - - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=metrics, - main_dttm_col="ds", - default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - offset=-8, - description="This is the description", - is_featured=1, - cache_timeout=3600, - schema="my_schema", - sql=None, - params=json.dumps( - { - "remote_id": 64, - "database_name": "examples", - "import_time": 1606677834, - } - ), - perm=None, - filter_select_enabled=1, - fetch_values_predicate="foo IN (1, 2)", - is_sqllab_view=0, # no longer used? - template_params=json.dumps({"answer": "42"}), - schema_perm=None, - extra=json.dumps({"warning_markdown": "*WARNING*"}), - ) - - session.add(sqla_table) - session.flush() - - dataset = session.query(SqlaTable).one() - # If this test fails because attributes changed, make sure to update - # ``SqlaTable.after_insert`` accordingly. - assert sorted(dataset.__dict__.keys()) == [ - "_sa_instance_state", - "cache_timeout", - "changed_by_fk", - "changed_on", - "columns", - "created_by_fk", - "created_on", - "database", - "database_id", - "default_endpoint", - "description", - "external_url", - "extra", - "fetch_values_predicate", - "filter_select_enabled", - "id", - "is_featured", - "is_managed_externally", - "is_sqllab_view", - "main_dttm_col", - "metrics", - "offset", - "params", - "perm", - "schema", - "schema_perm", - "sql", - "table_name", - "template_params", - "uuid", - ] - - -def test_create_physical_sqlatable( - app_context: None, - session: Session, - sample_columns: Dict["TableColumn", Dict[str, Any]], - sample_metrics: Dict["SqlMetric", Dict[str, Any]], - columns_default: Dict[str, Any], -) -> None: - """ - Test shadow write when creating a new ``SqlaTable``. - - When a new physical ``SqlaTable`` is created, new models should also be created for - ``Dataset``, ``Table``, and ``Column``. - """ - from superset.columns.models import Column - from superset.columns.schemas import ColumnSchema - from superset.connectors.sqla.models import SqlaTable - from superset.datasets.models import Dataset - from superset.datasets.schemas import DatasetSchema - from superset.models.core import Database - from superset.tables.models import Table - from superset.tables.schemas import TableSchema - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - user1 = get_test_user(1, "abc") - columns = list(sample_columns.keys()) - metrics = list(sample_metrics.keys()) - expected_table_columns = list(sample_columns.values()) - expected_metric_columns = list(sample_metrics.values()) - - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=metrics, - main_dttm_col="ds", - default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - offset=-8, - description="This is the description", - is_featured=1, - cache_timeout=3600, - schema="my_schema", - sql=None, - params=json.dumps( - { - "remote_id": 64, - "database_name": "examples", - "import_time": 1606677834, - } - ), - created_by=user1, - changed_by=user1, - owners=[user1], - perm=None, - filter_select_enabled=1, - fetch_values_predicate="foo IN (1, 2)", - is_sqllab_view=0, # no longer used? - template_params=json.dumps({"answer": "42"}), - schema_perm=None, - extra=json.dumps({"warning_markdown": "*WARNING*"}), - ) - session.add(sqla_table) - session.flush() - - # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on"} - - # check that columns were created - column_schema = ColumnSchema() - actual_columns = [ - {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} - for column in session.query(Column).all() - ] - num_physical_columns = len( - [col for col in expected_table_columns if col.get("is_physical") == True] - ) - num_dataset_table_columns = len(columns) - num_dataset_metric_columns = len(metrics) - assert ( - len(actual_columns) - == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns - ) - - # table columns are created before dataset columns are created - offset = 0 - for i in range(num_physical_columns): - assert actual_columns[i + offset] == { - **columns_default, - **expected_table_columns[i], - "id": i + offset + 1, - # physical columns for table have its own uuid - "uuid": actual_columns[i + offset]["uuid"], - "is_physical": True, - # table columns do not have creators - "created_by": None, - "tables": [1], - } - - offset += num_physical_columns - for i, column in enumerate(sqla_table.columns): - assert actual_columns[i + offset] == { - **columns_default, - **expected_table_columns[i], - "id": i + offset + 1, - # columns for dataset reuses the same uuid of TableColumn - "uuid": str(column.uuid), - "datasets": [1], - } - - offset += num_dataset_table_columns - for i, metric in enumerate(sqla_table.metrics): - assert actual_columns[i + offset] == { - **columns_default, - **expected_metric_columns[i], - "id": i + offset + 1, - "uuid": str(metric.uuid), - "datasets": [1], - } - - # check that table was created - table_schema = TableSchema() - tables = [ - { - k: v - for k, v in table_schema.dump(table).items() - if k not in (ignored_keys | {"uuid"}) - } - for table in session.query(Table).all() - ] - assert len(tables) == 1 - assert tables[0] == { - "id": 1, - "database": 1, - "created_by": 1, - "changed_by": 1, - "datasets": [1], - "columns": [1, 2, 3], - "extra_json": "{}", - "catalog": None, - "schema": "my_schema", - "name": "old_dataset", - "is_managed_externally": False, - "external_url": None, - } - - # check that dataset was created - dataset_schema = DatasetSchema() - datasets = [ - {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} - for dataset in session.query(Dataset).all() - ] - assert len(datasets) == 1 - assert datasets[0] == { - "id": 1, - "uuid": str(sqla_table.uuid), - "created_by": 1, - "changed_by": 1, - "owners": [1], - "name": "old_dataset", - "columns": [4, 5, 6, 7, 8, 9], - "is_physical": True, - "database": 1, - "tables": [1], - "extra_json": "{}", - "expression": "old_dataset", - "is_managed_externally": False, - "external_url": None, - } - - -def test_create_virtual_sqlatable( - app_context: None, - mocker: MockFixture, - session: Session, - sample_columns: Dict["TableColumn", Dict[str, Any]], - sample_metrics: Dict["SqlMetric", Dict[str, Any]], - columns_default: Dict[str, Any], -) -> None: - """ - Test shadow write when creating a new ``SqlaTable``. - - When a new virtual ``SqlaTable`` is created, new models should also be created for - ``Dataset`` and ``Column``. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - - from superset.columns.models import Column - from superset.columns.schemas import ColumnSchema - from superset.connectors.sqla.models import SqlaTable - from superset.datasets.models import Dataset - from superset.datasets.schemas import DatasetSchema - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - user1 = get_test_user(1, "abc") - physical_table_columns: List[Dict[str, Any]] = [ - dict( - name="ds", - is_temporal=True, - type="TIMESTAMP", - advanced_data_type=None, - expression="ds", - is_physical=True, - ), - dict( - name="num_boys", - type="INTEGER", - advanced_data_type=None, - expression="num_boys", - is_physical=True, - ), - dict( - name="revenue", - type="INTEGER", - advanced_data_type=None, - expression="revenue", - is_physical=True, - ), - dict( - name="expenses", - type="INTEGER", - advanced_data_type=None, - expression="expenses", - is_physical=True, - ), - ] - # create a physical ``Table`` that the virtual dataset points to - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - table = Table( - name="some_table", - schema="my_schema", - catalog=None, - database=database, - columns=[ - Column(**props, created_by=user1, changed_by=user1) - for props in physical_table_columns - ], - ) - session.add(table) - session.commit() - - assert session.query(Table).count() == 1 - assert session.query(Dataset).count() == 0 - - # create virtual dataset - columns = list(sample_columns.keys()) - metrics = list(sample_metrics.keys()) - expected_table_columns = list(sample_columns.values()) - expected_metric_columns = list(sample_metrics.values()) - - sqla_table = SqlaTable( - created_by=user1, - changed_by=user1, - owners=[user1], - table_name="old_dataset", - columns=columns, - metrics=metrics, - main_dttm_col="ds", - default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used - database=database, - offset=-8, - description="This is the description", - is_featured=1, - cache_timeout=3600, - schema="my_schema", - sql=""" -SELECT - ds, - num_boys, - revenue, - expenses, - revenue - expenses AS profit -FROM - some_table""", - params=json.dumps( - { - "remote_id": 64, - "database_name": "examples", - "import_time": 1606677834, - } - ), - perm=None, - filter_select_enabled=1, - fetch_values_predicate="foo IN (1, 2)", - is_sqllab_view=0, # no longer used? - template_params=json.dumps({"answer": "42"}), - schema_perm=None, - extra=json.dumps({"warning_markdown": "*WARNING*"}), - ) - session.add(sqla_table) - session.flush() - - # should not add a new table - assert session.query(Table).count() == 1 - assert session.query(Dataset).count() == 1 - - # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on"} - column_schema = ColumnSchema() - actual_columns = [ - {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} - for column in session.query(Column).all() - ] - num_physical_columns = len(physical_table_columns) - num_dataset_table_columns = len(columns) - num_dataset_metric_columns = len(metrics) - assert ( - len(actual_columns) - == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns - ) - - for i, column in enumerate(table.columns): - assert actual_columns[i] == { - **columns_default, - **physical_table_columns[i], - "id": i + 1, - "uuid": str(column.uuid), - "tables": [1], - } - - offset = num_physical_columns - for i, column in enumerate(sqla_table.columns): - assert actual_columns[i + offset] == { - **columns_default, - **expected_table_columns[i], - "id": i + offset + 1, - "uuid": str(column.uuid), - "is_physical": False, - "datasets": [1], - } - - offset = num_physical_columns + num_dataset_table_columns - for i, metric in enumerate(sqla_table.metrics): - assert actual_columns[i + offset] == { - **columns_default, - **expected_metric_columns[i], - "id": i + offset + 1, - "uuid": str(metric.uuid), - "datasets": [1], - } - - # check that dataset was created, and has a reference to the table - dataset_schema = DatasetSchema() - datasets = [ - {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} - for dataset in session.query(Dataset).all() - ] - assert len(datasets) == 1 - assert datasets[0] == { - "id": 1, - "database": 1, - "uuid": str(sqla_table.uuid), - "name": "old_dataset", - "changed_by": 1, - "created_by": 1, - "owners": [1], - "columns": [5, 6, 7, 8, 9, 10], - "is_physical": False, - "tables": [1], - "extra_json": "{}", - "external_url": None, - "is_managed_externally": False, - "expression": """ -SELECT - ds, - num_boys, - revenue, - expenses, - revenue - expenses AS profit -FROM - some_table""", - } - - -def test_delete_sqlatable(session: Session) -> None: - """ - Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``. - """ - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - ] - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - assert session.query(Dataset).count() == 1 - assert session.query(Table).count() == 1 - assert session.query(Column).count() == 2 - - session.delete(sqla_table) - session.flush() - - # test that dataset and dataset columns are also deleted - # but the physical table and table columns are kept - assert session.query(Dataset).count() == 0 - assert session.query(Table).count() == 1 - assert session.query(Column).count() == 1 - - -def test_update_physical_sqlatable_columns( - mocker: MockFixture, session: Session -) -> None: - """ - Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - ] - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - - session.add(sqla_table) - session.flush() - - assert session.query(Table).count() == 1 - assert session.query(Dataset).count() == 1 - assert session.query(Column).count() == 2 # 1 for table, 1 for dataset - - dataset = session.query(Dataset).one() - assert len(dataset.columns) == 1 - - # add a column to the original ``SqlaTable`` instance - sqla_table.columns.append(TableColumn(column_name="num_boys", type="INTEGER")) - session.flush() - - assert session.query(Column).count() == 3 - dataset = session.query(Dataset).one() - assert len(dataset.columns) == 2 - - # check that both lists have the same uuids - assert [col.uuid for col in sqla_table.columns].sort() == [ - col.uuid for col in dataset.columns - ].sort() - - # delete the column in the original instance - sqla_table.columns = sqla_table.columns[1:] - session.flush() - - # check that the column was added to the dataset and the added columns have - # the correct uuid. - assert session.query(TableColumn).count() == 1 - # the extra Dataset.column is deleted, but Table.column is kept - assert session.query(Column).count() == 2 - - # check that the column was also removed from the dataset - dataset = session.query(Dataset).one() - assert len(dataset.columns) == 1 - - # modify the attribute in a column - sqla_table.columns[0].is_dttm = True - session.flush() - - # check that the dataset column was modified - dataset = session.query(Dataset).one() - assert dataset.columns[0].is_temporal is True - - -def test_update_physical_sqlatable_schema( - mocker: MockFixture, session: Session -) -> None: - """ - Test that updating a ``SqlaTable`` schema also updates the corresponding ``Dataset``. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - mocker.patch("superset.datasets.dao.db.session", session) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - ] - sqla_table = SqlaTable( - table_name="old_dataset", - schema="old_schema", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - dataset = session.query(Dataset).one() - assert dataset.tables[0].schema == "old_schema" - assert dataset.tables[0].id == 1 - - sqla_table.schema = "new_schema" - session.flush() - - new_dataset = session.query(Dataset).one() - assert new_dataset.tables[0].schema == "new_schema" - assert new_dataset.tables[0].id == 2 - - -def test_update_physical_sqlatable_metrics( - mocker: MockFixture, - app_context: None, - session: Session, - get_session: Callable[[], Session], -) -> None: - """ - Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. - - For this test we check that updating the SQL expression in a metric belonging to a - ``SqlaTable`` is reflected in the ``Dataset`` metric. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - ] - metrics = [ - SqlMetric(metric_name="cnt", expression="COUNT(*)"), - ] - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=metrics, - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - # check that the metric was created - # 1 physical column for table + (1 column + 1 metric for datasets) - assert session.query(Column).count() == 3 - - column = session.query(Column).filter_by(is_physical=False).one() - assert column.expression == "COUNT(*)" - - # change the metric definition - sqla_table.metrics[0].expression = "MAX(ds)" - session.flush() - - assert column.expression == "MAX(ds)" - - # in a new session, update new columns and metrics at the same time - # reload the sqla_table so we can test the case that accessing an not already - # loaded attribute (`sqla_table.metrics`) while there are updates on the instance - # may trigger `after_update` before the attribute is loaded - session = get_session() - sqla_table = session.query(SqlaTable).filter(SqlaTable.id == sqla_table.id).one() - sqla_table.columns.append( - TableColumn( - column_name="another_column", - is_dttm=0, - type="TIMESTAMP", - expression="concat('a', 'b')", - ) - ) - # Here `SqlaTable.after_update` is triggered - # before `sqla_table.metrics` is loaded - sqla_table.metrics.append( - SqlMetric(metric_name="another_metric", expression="COUNT(*)") - ) - # `SqlaTable.after_update` will trigger again at flushing - session.flush() - assert session.query(Column).count() == 5 - - -def test_update_physical_sqlatable_database( - mocker: MockFixture, - app_context: None, - session: Session, - get_session: Callable[[], Session], -) -> None: - """ - Test updating the table on a physical dataset. - - When updating the table on a physical dataset by pointing it somewhere else (change - in database ID, schema, or table name) we should point the ``Dataset`` to an - existing ``Table`` if possible, and create a new one otherwise. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - mocker.patch("superset.datasets.dao.db.session", session) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset, dataset_column_association_table - from superset.models.core import Database - from superset.tables.models import Table, table_column_association_table - from superset.tables.schemas import TableSchema - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="a", type="INTEGER"), - ] - - original_database = Database( - database_name="my_database", sqlalchemy_uri="sqlite://" - ) - sqla_table = SqlaTable( - table_name="original_table", - columns=columns, - metrics=[], - database=original_database, - ) - session.add(sqla_table) - session.flush() - - assert session.query(Table).count() == 1 - assert session.query(Dataset).count() == 1 - assert session.query(Column).count() == 2 # 1 for table, 1 for dataset - - # check that the table was created, and that the created dataset points to it - table = session.query(Table).one() - assert table.id == 1 - assert table.name == "original_table" - assert table.schema is None - assert table.database_id == 1 - - dataset = session.query(Dataset).one() - assert dataset.tables == [table] - - # point ``SqlaTable`` to a different database - new_database = Database( - database_name="my_other_database", sqlalchemy_uri="sqlite://" - ) - session.add(new_database) - session.flush() - sqla_table.database = new_database - sqla_table.table_name = "new_table" - session.flush() - - assert session.query(Dataset).count() == 1 - assert session.query(Table).count() == 2 - # is kept for the old table - # is kept for the updated dataset - # is created for the new table - assert session.query(Column).count() == 3 - - # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on", "uuid"} - - # check that the old table still exists, and that the dataset points to the newly - # created table, column and dataset - table_schema = TableSchema() - tables = [ - {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} - for table in session.query(Table).all() - ] - assert tables[0] == { - "id": 1, - "database": 1, - "columns": [1], - "datasets": [], - "created_by": None, - "changed_by": None, - "extra_json": "{}", - "catalog": None, - "schema": None, - "name": "original_table", - "external_url": None, - "is_managed_externally": False, - } - assert tables[1] == { - "id": 2, - "database": 2, - "datasets": [1], - "columns": [3], - "created_by": None, - "changed_by": None, - "catalog": None, - "schema": None, - "name": "new_table", - "is_managed_externally": False, - "extra_json": "{}", - "external_url": None, - } - - # check that dataset now points to the new table - assert dataset.tables[0].database_id == 2 - # and a new column is created - assert len(dataset.columns) == 1 - assert dataset.columns[0].id == 2 - - # point ``SqlaTable`` back - sqla_table.database = original_database - sqla_table.table_name = "original_table" - session.flush() - - # should not create more table and datasets - assert session.query(Dataset).count() == 1 - assert session.query(Table).count() == 2 - # is deleted for the old table - # is kept for the updated dataset - # is kept for the new table - assert session.query(Column.id).order_by(Column.id).all() == [ - (1,), - (2,), - (3,), - ] - assert session.query(dataset_column_association_table).all() == [(1, 2)] - assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)] - assert session.query(Dataset).filter_by(id=1).one().columns[0].id == 2 - assert session.query(Table).filter_by(id=2).one().columns[0].id == 3 - assert session.query(Table).filter_by(id=1).one().columns[0].id == 1 - - # the dataset points back to the original table - assert dataset.tables[0].database_id == 1 - assert dataset.tables[0].name == "original_table" - - # kept the original column - assert dataset.columns[0].id == 2 - session.commit() - session.close() - - # querying in a new session should still return the same result - session = get_session() - assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)] - - -def test_update_virtual_sqlatable_references( - mocker: MockFixture, session: Session -) -> None: - """ - Test that changing the SQL of a virtual ``SqlaTable`` updates ``Dataset``. - - When the SQL is modified the list of referenced tables should be updated in the new - ``Dataset`` model. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - table1 = Table( - name="table_a", - schema="my_schema", - catalog=None, - database=database, - columns=[Column(name="a", type="INTEGER")], - ) - table2 = Table( - name="table_b", - schema="my_schema", - catalog=None, - database=database, - columns=[Column(name="b", type="INTEGER")], - ) - session.add(table1) - session.add(table2) - session.commit() - - # create virtual dataset - columns = [TableColumn(column_name="a", type="INTEGER")] - - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - database=database, - schema="my_schema", - sql="SELECT a FROM table_a", - ) - session.add(sqla_table) - session.flush() - - # check that new dataset has table1 - dataset: Dataset = session.query(Dataset).one() - assert dataset.tables == [table1] - - # change SQL - sqla_table.sql = "SELECT a, b FROM table_a JOIN table_b" - session.flush() - - # check that new dataset has both tables - new_dataset: Dataset = session.query(Dataset).one() - assert new_dataset.tables == [table1, table2] - assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b" - - # automatically add new referenced table - sqla_table.sql = "SELECT a, b, c FROM table_a JOIN table_b JOIN table_c" - session.flush() - - new_dataset = session.query(Dataset).one() - assert len(new_dataset.tables) == 3 - assert new_dataset.tables[2].name == "table_c" - - -def test_quote_expressions(session: Session) -> None: - """ - Test that expressions are quoted appropriately in columns and datasets. - """ - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="has space", type="INTEGER"), - TableColumn(column_name="no_need", type="INTEGER"), - ] - - sqla_table = SqlaTable( - table_name="old dataset", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - dataset = session.query(Dataset).one() - assert dataset.expression == '"old dataset"' - assert dataset.columns[0].expression == '"has space"' - assert dataset.columns[1].expression == "no_need" diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index cb313b3dd..16334066d 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -144,15 +144,13 @@ def test_get_datasource_sl_table(session_with_data: Session) -> None: from superset.datasource.dao import DatasourceDAO from superset.tables.models import Table - # todo(hugh): This will break once we remove the dual write - # update the datsource_id=1 and this will pass again result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.SLTABLE, - datasource_id=2, + datasource_id=1, session=session_with_data, ) - assert result.id == 2 + assert result.id == 1 assert isinstance(result, Table) @@ -160,15 +158,13 @@ def test_get_datasource_sl_dataset(session_with_data: Session) -> None: from superset.datasets.models import Dataset from superset.datasource.dao import DatasourceDAO - # todo(hugh): This will break once we remove the dual write - # update the datsource_id=1 and this will pass again result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.DATASET, - datasource_id=2, + datasource_id=1, session=session_with_data, ) - assert result.id == 2 + assert result.id == 1 assert isinstance(result, Dataset)