fix: Ensure table uniqueness on update (#15909)

* fix: Ensure table uniqueness on update

* Update models.py

* Update slice.py

* Update datasource_tests.py

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2021-08-02 12:45:55 -07:00 committed by GitHub
parent 76a13dfc9a
commit c0615c55df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 344 additions and 274 deletions

View File

@ -23,6 +23,9 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version. assists people when migrating to a new version.
## Next ## Next
- [15909](https://github.com/apache/incubator-superset/pull/15909): a change which
drops a uniqueness criterion (which may or may not have existed) to the tables table. This constraint was obsolete as it is handled by the ORM due to differences in how MySQL, PostgreSQL, etc. handle uniqueness for NULL values.
- [15927](https://github.com/apache/superset/pull/15927): Upgrades Celery to 5.x. Per the [upgrading](https://docs.celeryproject.org/en/stable/history/whatsnew-5.0.html#upgrading-from-celery-4-x) instructions Celery 5.0 introduces a new CLI implementation which is not completely backwards compatible. Please ensure global options are positioned before the sub-command. - [15927](https://github.com/apache/superset/pull/15927): Upgrades Celery to 5.x. Per the [upgrading](https://docs.celeryproject.org/en/stable/history/whatsnew-5.0.html#upgrading-from-celery-4-x) instructions Celery 5.0 introduces a new CLI implementation which is not completely backwards compatible. Please ensure global options are positioned before the sub-command.
- [13772](https://github.com/apache/superset/pull/13772): Row level security (RLS) is now enabled by default. To activate the feature, please run `superset init` to expose the RLS menus to Admin users. - [13772](https://github.com/apache/superset/pull/13772): Row level security (RLS) is now enabled by default. To activate the feature, please run `superset init` to expose the RLS menus to Admin users.

View File

@ -483,7 +483,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
owner_class = security_manager.user_model owner_class = security_manager.user_model
__tablename__ = "tables" __tablename__ = "tables"
__table_args__ = (UniqueConstraint("database_id", "table_name"),)
# Note this uniqueness constraint is not part of the physical schema, i.e., it does
# not exist in the migrations, but is required by `import_from_dict` to ensure the
# correct filters are applied in order to identify uniqueness.
#
# The reason it does not physically exist is MySQL, PostgreSQL, etc. have a
# different interpretation of uniqueness when it comes to NULL which is problematic
# given the schema is optional.
__table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)
table_name = Column(String(250), nullable=False) table_name = Column(String(250), nullable=False)
main_dttm_col = Column(String(250)) main_dttm_col = Column(String(250))
@ -1604,6 +1612,47 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
extra_cache_keys += sqla_query.extra_cache_keys extra_cache_keys += sqla_query.extra_cache_keys
return extra_cache_keys return extra_cache_keys
@staticmethod
def before_update(
mapper: Mapper, # pylint: disable=unused-argument
connection: Connection, # pylint: disable=unused-argument
target: "SqlaTable",
) -> None:
"""
Check whether before update if the target table already exists.
Note this listener is called when any fields are being updated and thus it is
necessary to first check whether the reference table is being updated.
Note this logic is temporary, given uniqueness is handled via the dataset DAO,
but is necessary until both the legacy datasource editor and datasource/save
endpoints are deprecated.
:param mapper: The table mapper
:param connection: The DB-API connection
:param target: The mapped instance being persisted
:raises Exception: If the target table is not unique
"""
from superset.datasets.commands.exceptions import get_dataset_exist_error_msg
from superset.datasets.dao import DatasetDAO
# Check whether the relevant attributes have changed.
state = db.inspect(target) # pylint: disable=no-member
for attr in ["database_id", "schema", "table_name"]:
history = state.get_history(attr, True)
if history.has_changes():
break
else:
return None
if not DatasetDAO.validate_uniqueness(
target.database_id, target.schema, target.table_name
):
raise Exception(get_dataset_exist_error_msg(target.full_name))
def update_table( def update_table(
_mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn] _mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn]
@ -1621,6 +1670,7 @@ def update_table(
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm) sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm) sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlMetric, "after_update", update_table) sa.event.listen(SqlMetric, "after_update", update_table)
sa.event.listen(TableColumn, "after_update", update_table) sa.event.listen(TableColumn, "after_update", update_table)

View File

@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""drop tables constraint
Revision ID: 31b2a1039d4a
Revises: ae1ed299413b
Create Date: 2021-07-27 08:25:20.755453
"""
from alembic import op
from sqlalchemy import engine
from sqlalchemy.exc import OperationalError, ProgrammingError
from superset.utils.core import generic_find_uq_constraint_name
# revision identifiers, used by Alembic.
revision = "31b2a1039d4a"
down_revision = "ae1ed299413b"
conv = {"uq": "uq_%(table_name)s_%(column_0_name)s"}
def upgrade():
bind = op.get_bind()
insp = engine.reflection.Inspector.from_engine(bind)
# Drop the uniqueness constraint if it exists.
constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp)
if constraint:
with op.batch_alter_table("tables", naming_convention=conv) as batch_op:
batch_op.drop_constraint(constraint, type_="unique")
def downgrade():
# One cannot simply re-add the uniqueness constraint as it may not have previously
# existed.
pass

View File

@ -53,10 +53,9 @@ slice_user = Table(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Slice( class Slice( # pylint: disable=too-many-instance-attributes,too-many-public-methods
Model, AuditMixinNullable, ImportExportMixin Model, AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods, too-many-instance-attributes ):
"""A slice is essentially a report or a view on data""" """A slice is essentially a report or a view on data"""
__tablename__ = "slices" __tablename__ = "slices"

View File

@ -161,7 +161,7 @@ class TestRequestAccess(SupersetTestCase):
updated_override_me = security_manager.find_role("override_me") updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions)) self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names") birth_names = self.get_table(name="birth_names")
self.assertEqual( self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name birth_names.perm, updated_override_me.permissions[0].view_menu.name
) )
@ -190,7 +190,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", updated_role.permissions[1].permission.name "datasource_access", updated_role.permissions[1].permission.name
) )
birth_names = self.get_table_by_name("birth_names") birth_names = self.get_table(name="birth_names")
self.assertEqual(birth_names.perm, perms[2].view_menu.name) self.assertEqual(birth_names.perm, perms[2].view_menu.name)
self.assertEqual( self.assertEqual(
"datasource_access", updated_role.permissions[2].permission.name "datasource_access", updated_role.permissions[2].permission.name
@ -204,7 +204,7 @@ class TestRequestAccess(SupersetTestCase):
override_me = security_manager.find_role("override_me") override_me = security_manager.find_role("override_me")
override_me.permissions.append( override_me.permissions.append(
security_manager.find_permission_view_menu( security_manager.find_permission_view_menu(
view_menu_name=self.get_table_by_name("energy_usage").perm, view_menu_name=self.get_table(name="energy_usage").perm,
permission_name="datasource_access", permission_name="datasource_access",
) )
) )
@ -218,7 +218,7 @@ class TestRequestAccess(SupersetTestCase):
self.assertEqual(201, response.status_code) self.assertEqual(201, response.status_code)
updated_override_me = security_manager.find_role("override_me") updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions)) self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names") birth_names = self.get_table(name="birth_names")
self.assertEqual( self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name birth_names.perm, updated_override_me.permissions[0].view_menu.name
) )

View File

@ -99,10 +99,6 @@ def post_assert_metric(
return rv return rv
def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
@pytest.fixture @pytest.fixture
def logged_in_admin(): def logged_in_admin():
"""Fixture with app context and logged in admin user.""" """Fixture with app context and logged in admin user."""
@ -132,12 +128,7 @@ class SupersetTestCase(TestCase):
@staticmethod @staticmethod
def get_birth_names_dataset() -> SqlaTable: def get_birth_names_dataset() -> SqlaTable:
example_db = get_example_database() return SupersetTestCase.get_table(name="birth_names")
return (
db.session.query(SqlaTable)
.filter_by(database=example_db, table_name="birth_names")
.one()
)
@staticmethod @staticmethod
def create_user_with_roles( def create_user_with_roles(
@ -254,13 +245,31 @@ class SupersetTestCase(TestCase):
return slc return slc
@staticmethod @staticmethod
def get_table_by_name(name: str) -> SqlaTable: def get_table(
return get_table_by_name(name) name: str, database_id: Optional[int] = None, schema: Optional[str] = None
) -> SqlaTable:
return (
db.session.query(SqlaTable)
.filter_by(
database_id=database_id
or SupersetTestCase.get_database_by_name("examples").id,
schema=schema,
table_name=name,
)
.one()
)
@staticmethod @staticmethod
def get_database_by_id(db_id: int) -> Database: def get_database_by_id(db_id: int) -> Database:
return db.session.query(Database).filter_by(id=db_id).one() return db.session.query(Database).filter_by(id=db_id).one()
@staticmethod
def get_database_by_name(database_name: str = "main") -> Database:
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
@staticmethod @staticmethod
def get_druid_ds_by_name(name: str) -> DruidDatasource: def get_druid_ds_by_name(name: str) -> DruidDatasource:
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first() return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
@ -340,12 +349,6 @@ class SupersetTestCase(TestCase):
): ):
security_manager.del_permission_role(public_role, perm) security_manager.del_permission_role(public_role, perm)
def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
def run_sql( def run_sql(
self, self,
sql, sql,
@ -364,7 +367,7 @@ class SupersetTestCase(TestCase):
if user_name: if user_name:
self.logout() self.logout()
self.login(username=(user_name or "admin")) self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id dbid = SupersetTestCase.get_database_by_name(database_name).id
json_payload = { json_payload = {
"database_id": dbid, "database_id": dbid,
"sql": sql, "sql": sql,
@ -448,7 +451,7 @@ class SupersetTestCase(TestCase):
if user_name: if user_name:
self.logout() self.logout()
self.login(username=(user_name if user_name else "admin")) self.login(username=(user_name if user_name else "admin"))
dbid = self._get_database_by_name(database_name).id dbid = SupersetTestCase.get_database_by_name(database_name).id
resp = self.get_json_resp( resp = self.get_json_resp(
"/superset/validate_sql_json/", "/superset/validate_sql_json/",
raise_on_error=False, raise_on_error=False,

View File

@ -545,7 +545,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
""" """
admin = self.get_user("admin") admin = self.get_user("admin")
gamma = self.get_user("gamma") gamma = self.get_user("gamma")
birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id birth_names_table_id = SupersetTestCase.get_table(name="birth_names").id
chart_id = self.insert_chart( chart_id = self.insert_chart(
"title", [admin.id], birth_names_table_id, admin "title", [admin.id], birth_names_table_id, admin
).id ).id

View File

@ -221,7 +221,7 @@ def test_import_csv_explore_database(setup_csv_upload, create_csv_files):
f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"' f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"'
in resp in resp
) )
table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE_W_EXPLORE) table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE_W_EXPLORE)
assert table.database_id == utils.get_example_database().id assert table.database_id == utils.get_example_database().id
@ -267,7 +267,7 @@ def test_import_csv(setup_csv_upload, create_csv_files):
) )
assert success_msg_f2 in resp assert success_msg_f2 in resp
table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE) table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE)
# make sure the new column name is reflected in the table metadata # make sure the new column name is reflected in the table metadata
assert "d" in table.column_names assert "d" in table.column_names

View File

@ -35,6 +35,7 @@ def create_table_for_dashboard(
dtype: Dict[str, Any], dtype: Dict[str, Any],
table_description: str = "", table_description: str = "",
fetch_values_predicate: Optional[str] = None, fetch_values_predicate: Optional[str] = None,
schema: Optional[str] = None,
) -> SqlaTable: ) -> SqlaTable:
df.to_sql( df.to_sql(
table_name, table_name,
@ -44,14 +45,17 @@ def create_table_for_dashboard(
dtype=dtype, dtype=dtype,
index=False, index=False,
method="multi", method="multi",
schema=schema,
) )
table_source = ConnectorRegistry.sources["table"] table_source = ConnectorRegistry.sources["table"]
table = ( table = (
db.session.query(table_source).filter_by(table_name=table_name).one_or_none() db.session.query(table_source)
.filter_by(database_id=database.id, schema=schema, table_name=table_name)
.one_or_none()
) )
if not table: if not table:
table = table_source(table_name=table_name) table = table_source(schema=schema, table_name=table_name)
if fetch_values_predicate: if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate table.fetch_values_predicate = fetch_values_predicate
table.database = database table.database = database

View File

@ -63,10 +63,10 @@ class TestDatasetApi(SupersetTestCase):
@staticmethod @staticmethod
def insert_dataset( def insert_dataset(
table_name: str, table_name: str,
schema: str,
owners: List[int], owners: List[int],
database: Database, database: Database,
sql: Optional[str] = None, sql: Optional[str] = None,
schema: Optional[str] = None,
) -> SqlaTable: ) -> SqlaTable:
obj_owners = list() obj_owners = list()
for owner in owners: for owner in owners:
@ -86,7 +86,7 @@ class TestDatasetApi(SupersetTestCase):
def insert_default_dataset(self): def insert_default_dataset(self):
return self.insert_dataset( return self.insert_dataset(
"ab_permission", "", [self.get_user("admin").id], get_main_database() "ab_permission", [self.get_user("admin").id], get_main_database()
) )
def get_fixture_datasets(self) -> List[SqlaTable]: def get_fixture_datasets(self) -> List[SqlaTable]:
@ -105,11 +105,7 @@ class TestDatasetApi(SupersetTestCase):
for table_name in self.fixture_virtual_table_names: for table_name in self.fixture_virtual_table_names:
datasets.append( datasets.append(
self.insert_dataset( self.insert_dataset(
table_name, table_name, [admin.id], main_db, "SELECT * from ab_view_menu;",
"",
[admin.id],
main_db,
"SELECT * from ab_view_menu;",
) )
) )
yield datasets yield datasets
@ -126,9 +122,7 @@ class TestDatasetApi(SupersetTestCase):
admin = self.get_user("admin") admin = self.get_user("admin")
main_db = get_main_database() main_db = get_main_database()
for tables_name in self.fixture_tables_names: for tables_name in self.fixture_tables_names:
datasets.append( datasets.append(self.insert_dataset(tables_name, [admin.id], main_db))
self.insert_dataset(tables_name, "", [admin.id], main_db)
)
yield datasets yield datasets
# rollback changes # rollback changes
@ -270,11 +264,13 @@ class TestDatasetApi(SupersetTestCase):
datasets = [] datasets = []
if example_db.backend == "postgresql": if example_db.backend == "postgresql":
datasets.append( datasets.append(
self.insert_dataset("ab_permission", "public", [], get_main_database()) self.insert_dataset(
"ab_permission", [], get_main_database(), schema="public"
)
) )
datasets.append( datasets.append(
self.insert_dataset( self.insert_dataset(
"columns", "information_schema", [], get_main_database() "columns", [], get_main_database(), schema="information_schema",
) )
) )
schema_values = [ schema_values = [
@ -921,7 +917,7 @@ class TestDatasetApi(SupersetTestCase):
dataset = self.insert_default_dataset() dataset = self.insert_default_dataset()
self.login(username="admin") self.login(username="admin")
ab_user = self.insert_dataset( ab_user = self.insert_dataset(
"ab_user", "", [self.get_user("admin").id], get_main_database() "ab_user", [self.get_user("admin").id], get_main_database()
) )
table_data = {"table_name": "ab_user"} table_data = {"table_name": "ab_user"}
uri = f"api/v1/dataset/{dataset.id}" uri = f"api/v1/dataset/{dataset.id}"

View File

@ -17,7 +17,6 @@
"""Unit tests for Superset""" """Unit tests for Superset"""
import json import json
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy
from unittest import mock from unittest import mock
import pytest import pytest
@ -28,12 +27,11 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import get_example_database from superset.utils.core import get_example_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import ( from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, load_birth_names_dashboard_with_slices,
) )
from tests.integration_tests.fixtures.datasource import get_datasource_post
from .base_tests import db_insert_temp_object, SupersetTestCase
from .fixtures.datasource import datasource_post
@contextmanager @contextmanager
@ -54,20 +52,15 @@ def create_test_table_context(database: Database):
class TestDatasource(SupersetTestCase): class TestDatasource(SupersetTestCase):
def setUp(self): def setUp(self):
self.original_attrs = {} db.session.begin(subtransactions=True)
self.datasource = None
def tearDown(self): def tearDown(self):
if self.datasource: db.session.rollback()
for key, value in self.original_attrs.items():
setattr(self.datasource, key, value)
db.session.commit()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_for_physical_table(self): def test_external_metadata_for_physical_table(self):
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
url = f"/datasource/external_metadata/table/{tbl.id}/" url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url) resp = self.get_json_resp(url)
col_names = {o.get("name") for o in resp} col_names = {o.get("name") for o in resp}
@ -86,7 +79,7 @@ class TestDatasource(SupersetTestCase):
session.add(table) session.add(table)
session.commit() session.commit()
table = self.get_table_by_name("dummy_sql_table") table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/" url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url) resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol", "strcol"} assert {o.get("name") for o in resp} == {"intcol", "strcol"}
@ -96,7 +89,7 @@ class TestDatasource(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_by_name_for_physical_table(self): def test_external_metadata_by_name_for_physical_table(self):
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
# empty schema need to be represented by undefined # empty schema need to be represented by undefined
url = ( url = (
f"/datasource/external_metadata_by_name/table/" f"/datasource/external_metadata_by_name/table/"
@ -119,7 +112,7 @@ class TestDatasource(SupersetTestCase):
session.add(table) session.add(table)
session.commit() session.commit()
table = self.get_table_by_name("dummy_sql_table") table = self.get_table(name="dummy_sql_table")
# empty schema need to be represented by undefined # empty schema need to be represented by undefined
url = ( url = (
f"/datasource/external_metadata_by_name/table/" f"/datasource/external_metadata_by_name/table/"
@ -160,7 +153,7 @@ class TestDatasource(SupersetTestCase):
session.add(table) session.add(table)
session.commit() session.commit()
table = self.get_table_by_name("dummy_sql_table_with_template_params") table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/" url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url) resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol"} assert {o.get("name") for o in resp} == {"intcol"}
@ -196,7 +189,7 @@ class TestDatasource(SupersetTestCase):
@mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata") @mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata")
def test_external_metadata_error_return_400(self, mock_get_datasource): def test_external_metadata_error_return_400(self, mock_get_datasource):
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
url = f"/datasource/external_metadata/table/{tbl.id}/" url = f"/datasource/external_metadata/table/{tbl.id}/"
mock_get_datasource.side_effect = SupersetGenericDBErrorException("oops") mock_get_datasource.side_effect = SupersetGenericDBErrorException("oops")
@ -221,13 +214,9 @@ class TestDatasource(SupersetTestCase):
def test_save(self): def test_save(self):
self.login(username="admin") self.login(username="admin")
tbl_id = self.get_table_by_name("birth_names").id tbl_id = self.get_table(name="birth_names").id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id datasource_post["id"] = tbl_id
data = dict(data=json.dumps(datasource_post)) data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data) resp = self.get_json_resp("/datasource/save/", data)
@ -241,25 +230,21 @@ class TestDatasource(SupersetTestCase):
else: else:
self.assertEqual(resp[k], datasource_post[k]) self.assertEqual(resp[k], datasource_post[k])
def save_datasource_from_dict(self, datasource_dict): def save_datasource_from_dict(self, datasource_post):
data = dict(data=json.dumps(datasource_post)) data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data) resp = self.get_json_resp("/datasource/save/", data)
return resp return resp
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_change_database(self): def test_change_database(self):
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
tbl_id = tbl.id tbl_id = tbl.id
db_id = tbl.database_id db_id = tbl.database_id
datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id datasource_post["id"] = tbl_id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
new_db = self.create_fake_db() new_db = self.create_fake_db()
datasource_post["database"]["id"] = new_db.id datasource_post["database"]["id"] = new_db.id
resp = self.save_datasource_from_dict(datasource_post) resp = self.save_datasource_from_dict(datasource_post)
self.assertEqual(resp["database"]["id"], new_db.id) self.assertEqual(resp["database"]["id"], new_db.id)
@ -272,15 +257,11 @@ class TestDatasource(SupersetTestCase):
def test_save_duplicate_key(self): def test_save_duplicate_key(self):
self.login(username="admin") self.login(username="admin")
tbl_id = self.get_table_by_name("birth_names").id tbl_id = self.get_table(name="birth_names").id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
for key in self.datasource.export_fields: datasource_post = get_datasource_post()
self.original_attrs[key] = getattr(self.datasource, key) datasource_post["id"] = tbl_id
datasource_post["columns"].extend(
datasource_post_copy = deepcopy(datasource_post)
datasource_post_copy["id"] = tbl_id
datasource_post_copy["columns"].extend(
[ [
{ {
"column_name": "<new column>", "column_name": "<new column>",
@ -298,18 +279,15 @@ class TestDatasource(SupersetTestCase):
}, },
] ]
) )
data = dict(data=json.dumps(datasource_post_copy)) data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False) resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
self.assertIn("Duplicate column name(s): <new column>", resp["error"]) self.assertIn("Duplicate column name(s): <new column>", resp["error"])
def test_get_datasource(self): def test_get_datasource(self):
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
datasource_post = get_datasource_post()
datasource_post["id"] = tbl.id datasource_post["id"] = tbl.id
data = dict(data=json.dumps(datasource_post)) data = dict(data=json.dumps(datasource_post))
self.get_json_resp("/datasource/save/", data) self.get_json_resp("/datasource/save/", data)
@ -337,7 +315,7 @@ class TestDatasource(SupersetTestCase):
app.config["DATASET_HEALTH_CHECK"] = my_check app.config["DATASET_HEALTH_CHECK"] = my_check
self.login(username="admin") self.login(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
assert datasource.health_check_message == "Warning message!" assert datasource.health_check_message == "Warning message!"
app.config["DATASET_HEALTH_CHECK"] = None app.config["DATASET_HEALTH_CHECK"] = None

View File

@ -66,7 +66,7 @@ class TestDictImportExport(SupersetTestCase):
cls.delete_imports() cls.delete_imports()
def create_table( def create_table(
self, name, schema="", id=0, cols_names=[], cols_uuids=None, metric_names=[] self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[]
): ):
database_name = "main" database_name = "main"
name = "{0}{1}".format(NAME_PREFIX, name) name = "{0}{1}".format(NAME_PREFIX, name)
@ -128,9 +128,6 @@ class TestDictImportExport(SupersetTestCase):
def get_datasource(self, datasource_id): def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
def get_table_by_name(self, name):
return db.session.query(SqlaTable).filter_by(table_name=name).first()
def yaml_compare(self, obj_1, obj_2): def yaml_compare(self, obj_1, obj_2):
obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False) obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False) obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)

View File

@ -15,138 +15,142 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Fixtures for test_datasource.py""" """Fixtures for test_datasource.py"""
datasource_post = { from typing import Any, Dict
"id": None,
"column_formats": {"ratio": ".2%"},
"database": {"id": 1}, def get_datasource_post() -> Dict[str, Any]:
"description": "Adding a DESCRip", return {
"default_endpoint": "", "id": None,
"filter_select_enabled": True, "column_formats": {"ratio": ".2%"},
"name": "birth_names", "database": {"id": 1},
"table_name": "birth_names", "description": "Adding a DESCRip",
"datasource_name": "birth_names", "default_endpoint": "",
"type": "table", "filter_select_enabled": True,
"schema": "", "name": "birth_names",
"offset": 66, "table_name": "birth_names",
"cache_timeout": 55, "datasource_name": "birth_names",
"sql": "", "type": "table",
"columns": [ "schema": None,
{ "offset": 66,
"id": 504, "cache_timeout": 55,
"column_name": "ds", "sql": "",
"verbose_name": "", "columns": [
"description": None, {
"expression": "", "id": 504,
"filterable": True, "column_name": "ds",
"groupby": True, "verbose_name": "",
"is_dttm": True, "description": None,
"type": "DATETIME", "expression": "",
}, "filterable": True,
{ "groupby": True,
"id": 505, "is_dttm": True,
"column_name": "gender", "type": "DATETIME",
"verbose_name": None, },
"description": None, {
"expression": "", "id": 505,
"filterable": True, "column_name": "gender",
"groupby": True, "verbose_name": None,
"is_dttm": False, "description": None,
"type": "VARCHAR(16)", "expression": "",
}, "filterable": True,
{ "groupby": True,
"id": 506, "is_dttm": False,
"column_name": "name", "type": "VARCHAR(16)",
"verbose_name": None, },
"description": None, {
"expression": None, "id": 506,
"filterable": True, "column_name": "name",
"groupby": True, "verbose_name": None,
"is_dttm": None, "description": None,
"type": "VARCHAR(255)", "expression": None,
}, "filterable": True,
{ "groupby": True,
"id": 508, "is_dttm": None,
"column_name": "state", "type": "VARCHAR(255)",
"verbose_name": None, },
"description": None, {
"expression": None, "id": 508,
"filterable": True, "column_name": "state",
"groupby": True, "verbose_name": None,
"is_dttm": None, "description": None,
"type": "VARCHAR(10)", "expression": None,
}, "filterable": True,
{ "groupby": True,
"id": 509, "is_dttm": None,
"column_name": "num_boys", "type": "VARCHAR(10)",
"verbose_name": None, },
"description": None, {
"expression": None, "id": 509,
"filterable": True, "column_name": "num_boys",
"groupby": True, "verbose_name": None,
"is_dttm": None, "description": None,
"type": "BIGINT(20)", "expression": None,
}, "filterable": True,
{ "groupby": True,
"id": 510, "is_dttm": None,
"column_name": "num_girls", "type": "BIGINT(20)",
"verbose_name": None, },
"description": None, {
"expression": "", "id": 510,
"filterable": False, "column_name": "num_girls",
"groupby": False, "verbose_name": None,
"is_dttm": False, "description": None,
"type": "BIGINT(20)", "expression": "",
}, "filterable": False,
{ "groupby": False,
"id": 532, "is_dttm": False,
"column_name": "num", "type": "BIGINT(20)",
"verbose_name": None, },
"description": None, {
"expression": None, "id": 532,
"filterable": True, "column_name": "num",
"groupby": True, "verbose_name": None,
"is_dttm": None, "description": None,
"type": "BIGINT(20)", "expression": None,
}, "filterable": True,
{ "groupby": True,
"id": 522, "is_dttm": None,
"column_name": "num_california", "type": "BIGINT(20)",
"verbose_name": None, },
"description": None, {
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", "id": 522,
"filterable": False, "column_name": "num_california",
"groupby": False, "verbose_name": None,
"is_dttm": False, "description": None,
"type": "NUMBER", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
}, "filterable": False,
], "groupby": False,
"metrics": [ "is_dttm": False,
{ "type": "NUMBER",
"id": 824, },
"metric_name": "sum__num", ],
"verbose_name": "Babies", "metrics": [
"description": "", {
"expression": "SUM(num)", "id": 824,
"warning_text": "", "metric_name": "sum__num",
"d3format": "", "verbose_name": "Babies",
}, "description": "",
{ "expression": "SUM(num)",
"id": 836, "warning_text": "",
"metric_name": "count", "d3format": "",
"verbose_name": "", },
"description": None, {
"expression": "count(1)", "id": 836,
"warning_text": None, "metric_name": "count",
"d3format": None, "verbose_name": "",
}, "description": None,
{ "expression": "count(1)",
"id": 843, "warning_text": None,
"metric_name": "ratio", "d3format": None,
"verbose_name": "Ratio Boys/Girls", },
"description": "This represents the ratio of boys/girls", {
"expression": "sum(num_boys) / sum(num_girls)", "id": 843,
"warning_text": "no warning", "metric_name": "ratio",
"d3format": ".2%", "verbose_name": "Ratio Boys/Girls",
}, "description": "This represents the ratio of boys/girls",
], "expression": "sum(num_boys) / sum(num_girls)",
} "warning_text": "no warning",
"d3format": ".2%",
},
],
}

View File

@ -18,7 +18,7 @@ import copy
from typing import Any, Dict, List from typing import Any, Dict, List
from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint
from tests.integration_tests.base_tests import get_table_by_name from tests.integration_tests.base_tests import SupersetTestCase
query_birth_names = { query_birth_names = {
"extras": { "extras": {
@ -245,7 +245,7 @@ def get_query_context(
:return: Request payload :return: Request payload
""" """
table_name = query_name.split(":")[0] table_name = query_name.split(":")[0]
table = get_table_by_name(table_name) table = SupersetTestCase.get_table(name=table_name)
return { return {
"datasource": {"id": table.id, "type": table.type}, "datasource": {"id": table.id, "type": table.type},
"queries": [ "queries": [

View File

@ -89,19 +89,20 @@ class TestImportExport(SupersetTestCase):
id=None, id=None,
db_name="examples", db_name="examples",
table_name="wb_health_population", table_name="wb_health_population",
schema=None,
): ):
params = { params = {
"num_period_compare": "10", "num_period_compare": "10",
"remote_id": id, "remote_id": id,
"datasource_name": table_name, "datasource_name": table_name,
"database_name": db_name, "database_name": db_name,
"schema": "", "schema": schema,
# Test for trailing commas # Test for trailing commas
"metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"], "metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"],
} }
if table_name and not ds_id: if table_name and not ds_id:
table = self.get_table_by_name(table_name) table = self.get_table(schema=schema, name=table_name)
if table: if table:
ds_id = table.id ds_id = table.id
@ -167,9 +168,6 @@ class TestImportExport(SupersetTestCase):
def get_datasource(self, datasource_id): def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
def get_table_by_name(self, name):
return db.session.query(SqlaTable).filter_by(table_name=name).first()
def assert_dash_equals( def assert_dash_equals(
self, expected_dash, actual_dash, check_position=True, check_slugs=True self, expected_dash, actual_dash, check_position=True, check_slugs=True
): ):
@ -273,9 +271,7 @@ class TestImportExport(SupersetTestCase):
resp.data.decode("utf-8"), object_hook=decode_dashboards resp.data.decode("utf-8"), object_hook=decode_dashboards
)["datasources"] )["datasources"]
self.assertEqual(1, len(exported_tables)) self.assertEqual(1, len(exported_tables))
self.assert_table_equals( self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
self.get_table_by_name("birth_names"), exported_tables[0]
)
@pytest.mark.usefixtures( @pytest.mark.usefixtures(
"load_world_bank_dashboard_with_slices", "load_world_bank_dashboard_with_slices",
@ -314,11 +310,9 @@ class TestImportExport(SupersetTestCase):
resp_data.get("datasources"), key=lambda t: t.table_name resp_data.get("datasources"), key=lambda t: t.table_name
) )
self.assertEqual(2, len(exported_tables)) self.assertEqual(2, len(exported_tables))
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
self.assert_table_equals( self.assert_table_equals(
self.get_table_by_name("birth_names"), exported_tables[0] self.get_table(name="wb_health_population"), exported_tables[1]
)
self.assert_table_equals(
self.get_table_by_name("wb_health_population"), exported_tables[1]
) )
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@ -329,12 +323,12 @@ class TestImportExport(SupersetTestCase):
self.assertEqual(slc.datasource.perm, slc.perm) self.assertEqual(slc.datasource.perm, slc.perm)
self.assert_slice_equals(expected_slice, slc) self.assert_slice_equals(expected_slice, slc)
table_id = self.get_table_by_name("wb_health_population").id table_id = self.get_table(name="wb_health_population").id
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id) self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_2_slices_for_same_table(self): def test_import_2_slices_for_same_table(self):
table_id = self.get_table_by_name("wb_health_population").id table_id = self.get_table(name="wb_health_population").id
# table_id != 666, import func will have to find the table # table_id != 666, import func will have to find the table
slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002) slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002)
slc_id_1 = import_chart(slc_1, None) slc_id_1 = import_chart(slc_1, None)
@ -351,13 +345,6 @@ class TestImportExport(SupersetTestCase):
self.assert_slice_equals(slc_2, imported_slc_2) self.assert_slice_equals(slc_2, imported_slc_2)
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm) self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
def test_import_slices_for_non_existent_table(self):
with self.assertRaises(AttributeError):
import_chart(
self.create_slice("Import Me 3", id=10004, table_name="non_existent"),
None,
)
def test_import_slices_override(self): def test_import_slices_override(self):
slc = self.create_slice("Import Me New", id=10005) slc = self.create_slice("Import Me New", id=10005)
slc_1_id = import_chart(slc, None, import_time=1990) slc_1_id = import_chart(slc, None, import_time=1990)

View File

@ -339,7 +339,7 @@ class TestDatabaseModel(SupersetTestCase):
class TestSqlaTableModel(SupersetTestCase): class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_timestamp_expression(self): def test_get_timestamp_expression(self):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds") ds_col = tbl.get_column("ds")
sqla_literal = ds_col.get_timestamp_expression(None) sqla_literal = ds_col.get_timestamp_expression(None)
self.assertEqual(str(sqla_literal.compile()), "ds") self.assertEqual(str(sqla_literal.compile()), "ds")
@ -359,7 +359,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_timestamp_expression_epoch(self): def test_get_timestamp_expression_epoch(self):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds") ds_col = tbl.get_column("ds")
ds_col.expression = None ds_col.expression = None
@ -384,7 +384,7 @@ class TestSqlaTableModel(SupersetTestCase):
ds_col.expression = prev_ds_expr ds_col.expression = prev_ds_expr
def query_with_expr_helper(self, is_timeseries, inner_join=True): def query_with_expr_helper(self, is_timeseries, inner_join=True):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds") ds_col = tbl.get_column("ds")
ds_col.expression = None ds_col.expression = None
ds_col.python_date_format = None ds_col.python_date_format = None
@ -447,7 +447,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_sql_mutator(self): def test_sql_mutator(self):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
query_obj = dict( query_obj = dict(
groupby=[], groupby=[],
metrics=None, metrics=None,
@ -472,7 +472,7 @@ class TestSqlaTableModel(SupersetTestCase):
app.config["SQL_QUERY_MUTATOR"] = None app.config["SQL_QUERY_MUTATOR"] = None
def test_query_with_non_existent_metrics(self): def test_query_with_non_existent_metrics(self):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
query_obj = dict( query_obj = dict(
groupby=[], groupby=[],
@ -493,7 +493,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices(self): def test_data_for_slices(self):
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
slc = ( slc = (
metadata_db.session.query(Slice) metadata_db.session.query(Slice)
.filter_by( .filter_by(

View File

@ -92,7 +92,7 @@ class TestQueryContext(SupersetTestCase):
def test_cache(self): def test_cache(self):
table_name = "birth_names" table_name = "birth_names"
table = self.get_table_by_name(table_name) table = self.get_table(name=table_name)
payload = get_query_context(table.name, table.id) payload = get_query_context(table.name, table.id)
payload["force"] = True payload["force"] = True

View File

@ -1151,7 +1151,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self): def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha") g.user = self.get_user(username="alpha")
tbl = self.get_table_by_name("energy_usage") tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert tbl.get_extra_cache_keys(self.query_obj) == [1]
assert "value > 1" in sql assert "value > 1" in sql
@ -1161,7 +1161,7 @@ class TestRowLevelSecurity(SupersetTestCase):
g.user = self.get_user( g.user = self.get_user(
username="admin" username="admin"
) # self.login() doesn't actually set the user ) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("energy_usage") tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [] assert tbl.get_extra_cache_keys(self.query_obj) == []
assert "value > 1" not in sql assert "value > 1" not in sql
@ -1171,7 +1171,7 @@ class TestRowLevelSecurity(SupersetTestCase):
g.user = self.get_user( g.user = self.get_user(
username="alpha" username="alpha"
) # self.login() doesn't actually set the user ) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("unicode_test") tbl = self.get_table(name="unicode_test")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert tbl.get_extra_cache_keys(self.query_obj) == [1]
assert "value > 1" in sql assert "value > 1" in sql
@ -1179,7 +1179,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_gamma_birth_names_query(self): def test_rls_filter_alters_gamma_birth_names_query(self):
g.user = self.get_user(username="gamma") g.user = self.get_user(username="gamma")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
# establish that the filters are grouped together correctly with # establish that the filters are grouped together correctly with
@ -1192,7 +1192,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_no_role_user_birth_names_query(self): def test_rls_filter_alters_no_role_user_birth_names_query(self):
g.user = self.get_user(username="NoRlsRoleUser") g.user = self.get_user(username="NoRlsRoleUser")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
# gamma's filters should not be present query # gamma's filters should not be present query
@ -1205,7 +1205,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_doesnt_alter_admin_birth_names_query(self): def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
g.user = self.get_user(username="admin") g.user = self.get_user(username="admin")
tbl = self.get_table_by_name("birth_names") tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj) sql = tbl.get_query_str(self.query_obj)
# no filters are applied for admin user # no filters are applied for admin user

View File

@ -241,7 +241,7 @@ class TestDatabaseModel(SupersetTestCase):
FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"), FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"),
FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"), FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"),
) )
table = self.get_table_by_name("birth_names") table = self.get_table(name="birth_names")
for filter_ in filters: for filter_ in filters:
query_obj = { query_obj = {
"granularity": None, "granularity": None,

View File

@ -42,11 +42,6 @@ from tests.integration_tests.fixtures.query_context import get_query_context
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
def get_table_by_name(name: str) -> SqlaTable:
with app.app_context():
return db.session.query(SqlaTable).filter_by(table_name=name).one()
class TestAsyncQueries(SupersetTestCase): class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job") @mock.patch.object(async_query_manager, "update_job")
@ -127,7 +122,7 @@ class TestAsyncQueries(SupersetTestCase):
@mock.patch.object(async_query_manager, "update_job") @mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job): def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app) async_query_manager.init_app(app)
table = get_table_by_name("birth_names") table = self.get_table(name="birth_names")
user = security_manager.find_user("gamma") user = security_manager.find_user("gamma")
form_data = { form_data = {
"datasource": f"{table.id}__table", "datasource": f"{table.id}__table",