From 17b58037f85dfb9db68167484d6afe5bda8f4f1c Mon Sep 17 00:00:00 2001 From: Bogdan Date: Tue, 9 Aug 2022 09:59:31 -0700 Subject: [PATCH] perf: Implement model specific lookups by id to improve performance (#20974) * Implement model specific lookups by id to improve performance * Address comments e.g. better variable names and test cleanup * commit after cleanup * even better name and test cleanup via rollback Co-authored-by: Bogdan Kyryliuk --- superset/common/query_context_processor.py | 2 + superset/dao/base.py | 7 +- superset/explore/utils.py | 9 ++- tests/integration_tests/datasets/api_tests.py | 1 + .../explore/form_data/api_tests.py | 10 +-- .../explore/permalink/api_tests.py | 2 +- tests/unit_tests/charts/dao/__init__.py | 16 ++++ tests/unit_tests/charts/dao/dao_tests.py | 67 +++++++++++++++++ tests/unit_tests/datasets/dao/__init__.py | 16 ++++ tests/unit_tests/datasets/dao/dao_tests.py | 73 +++++++++++++++++++ 10 files changed, 192 insertions(+), 11 deletions(-) create mode 100644 tests/unit_tests/charts/dao/__init__.py create mode 100644 tests/unit_tests/charts/dao/dao_tests.py create mode 100644 tests/unit_tests/datasets/dao/__init__.py create mode 100644 tests/unit_tests/datasets/dao/dao_tests.py diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 19d78e0b3..2978eeace 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -493,6 +493,8 @@ class QueryContextProcessor: chart = ChartDAO.find_by_id(annotation_layer["value"]) if not chart: raise QueryObjectValidationError(_("The chart does not exist")) + if not chart.datasource: + raise QueryObjectValidationError(_("The chart datasource does not exist")) form_data = chart.form_data.copy() try: viz_obj = get_viz( diff --git a/superset/dao/base.py b/superset/dao/base.py index 0090c4e53..c6890e53a 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -50,14 +50,17 @@ class BaseDAO: @classmethod 
def find_by_id( - cls, model_id: Union[str, int], session: Session = None + cls, + model_id: Union[str, int], + session: Session = None, + skip_base_filter: bool = False, ) -> Optional[Model]: """ Find a model by id, if defined applies `base_filter` """ session = session or db.session query = session.query(cls.model_cls) - if cls.base_filter: + if cls.base_filter and not skip_base_filter: data_model = SQLAInterface(cls.model_cls, session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model diff --git a/superset/explore/utils.py b/superset/explore/utils.py index a1c329510..01f63f53f 100644 --- a/superset/explore/utils.py +++ b/superset/explore/utils.py @@ -38,7 +38,8 @@ from superset.utils.core import DatasourceType def check_dataset_access(dataset_id: int) -> Optional[bool]: if dataset_id: - dataset = DatasetDAO.find_by_id(dataset_id) + # Access checks below, no need to validate them twice as they can be expensive. + dataset = DatasetDAO.find_by_id(dataset_id, skip_base_filter=True) if dataset: can_access_datasource = security_manager.can_access_datasource(dataset) if can_access_datasource: @@ -49,7 +50,8 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]: def check_query_access(query_id: int) -> Optional[bool]: if query_id: - query = QueryDAO.find_by_id(query_id) + # Access checks below, no need to validate them twice as they can be expensive. + query = QueryDAO.find_by_id(query_id, skip_base_filter=True) if query: security_manager.raise_for_access(query=query) return True @@ -81,7 +83,8 @@ def check_access( check_datasource_access(datasource_id, datasource_type) if not chart_id: return True - chart = ChartDAO.find_by_id(chart_id) + # Access checks below, no need to validate them twice as they can be expensive. 
+ chart = ChartDAO.find_by_id(chart_id, skip_base_filter=True) if chart: can_access_chart = security_manager.is_owner( chart diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index a993f0c0b..950756d81 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -1817,6 +1817,7 @@ class TestDatasetApi(SupersetTestCase): rv = self.client.get(uri) assert rv.status_code == 404 self.logout() + self.login(username="gamma") table = self.get_birth_names_dataset() uri = f"api/v1/dataset/{table.id}/related_objects" diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index fe8425e28..0e73d0b51 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -126,7 +126,7 @@ def test_post_access_denied( "form_data": INITIAL_FORM_DATA, } resp = test_client.post("api/v1/explore/form_data", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_post_same_key_for_same_context( @@ -337,7 +337,7 @@ def test_put_access_denied(test_client, login_as, chart_id: int, datasource: Sql "form_data": UPDATED_FORM_DATA, } resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_put_not_owner(test_client, login_as, chart_id: int, datasource: SqlaTable): @@ -349,7 +349,7 @@ def test_put_not_owner(test_client, login_as, chart_id: int, datasource: SqlaTab "form_data": UPDATED_FORM_DATA, } resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_get_key_not_found(test_client, login_as_admin): @@ -367,7 +367,7 @@ def test_get(test_client, login_as_admin): def test_get_access_denied(test_client, login_as): login_as("gamma") resp = 
test_client.get(f"api/v1/explore/form_data/{KEY}") - assert resp.status_code == 404 + assert resp.status_code == 403 @patch("superset.security.SupersetSecurityManager.can_access_datasource") @@ -387,7 +387,7 @@ def test_delete(test_client, login_as_admin): def test_delete_access_denied(test_client, login_as): login_as("gamma") resp = test_client.delete(f"api/v1/explore/form_data/{KEY}") - assert resp.status_code == 404 + assert resp.status_code == 403 def test_delete_not_owner( diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index a808f0111..22a36f41e 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -84,7 +84,7 @@ def test_post( def test_post_access_denied(test_client, login_as, form_data): login_as("gamma") resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_get_missing_chart( diff --git a/tests/unit_tests/charts/dao/__init__.py b/tests/unit_tests/charts/dao/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/tests/unit_tests/charts/dao/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py new file mode 100644 index 000000000..15310712a --- /dev/null +++ b/tests/unit_tests/charts/dao/dao_tests.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from typing import Iterator
+
+import pytest
+from sqlalchemy.orm.session import Session
+
+from superset.utils.core import DatasourceType
+
+
+@pytest.fixture
+def session_with_data(session: Session) -> Iterator[Session]:  # seeds one Slice (id=1)
+    from superset.models.slice import Slice
+
+    engine = session.get_bind()
+    Slice.metadata.create_all(engine)  # pylint: disable=no-member
+
+    slice_obj = Slice(
+        id=1,
+        datasource_id=1,
+        datasource_type=DatasourceType.TABLE,
+        datasource_name="tmp_perm_table",
+        slice_name="slice_name",
+    )
+
+    session.add(slice_obj)
+    session.flush()  # flush (not commit) so the rollback below discards the row
+    yield session
+    session.rollback()  # teardown: drop the flushed Slice
+
+
+def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
+    from superset.charts.dao import ChartDAO
+    from superset.models.slice import Slice
+
+    result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+
+    assert result
+    assert 1 == result.id
+    assert "slice_name" == result.slice_name
+    assert isinstance(result, Slice)
+
+
+def test_datasource_find_by_id_skip_base_filter_not_found(
+    session_with_data: Session,
+) -> None:
+    from superset.charts.dao import ChartDAO
+
+    result = ChartDAO.find_by_id(
+        125326326, session=session_with_data, skip_base_filter=True
+    )
+    assert result is None  # no Slice was seeded with this id
diff --git a/tests/unit_tests/datasets/dao/__init__.py b/tests/unit_tests/datasets/dao/__init__.py
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/tests/unit_tests/datasets/dao/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py new file mode 100644 index 000000000..31aa9f27d --- /dev/null +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from typing import Iterator
+
+import pytest
+from sqlalchemy.orm.session import Session
+
+
+@pytest.fixture
+def session_with_data(session: Session) -> Iterator[Session]:  # seeds one SqlaTable
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.models.core import Database
+
+    engine = session.get_bind()
+    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
+
+    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    sqla_table = SqlaTable(
+        table_name="my_sqla_table",
+        columns=[],
+        metrics=[],
+        database=db,
+    )
+
+    session.add(db)
+    session.add(sqla_table)
+    session.flush()  # write the rows inside the open transaction; nothing is committed
+    yield session
+    session.rollback()  # teardown: discard the flushed rows
+
+
+def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> None:
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.datasets.dao import DatasetDAO
+
+    result = DatasetDAO.find_by_id(
+        1,  # presumably the autoincrement id of the fixture's only table; confirm
+        session=session_with_data,
+        skip_base_filter=True,
+    )
+
+    assert result
+    assert 1 == result.id
+    assert "my_sqla_table" == result.table_name
+    assert isinstance(result, SqlaTable)
+
+
+def test_datasource_find_by_id_skip_base_filter_not_found(
+    session_with_data: Session,
+) -> None:
+    from superset.datasets.dao import DatasetDAO
+
+    result = DatasetDAO.find_by_id(
+        125326326,  # arbitrary id; the fixture adds only one table
+        session=session_with_data,
+        skip_base_filter=True,
+    )
+    assert result is None