refactor: Ensure Celery leverages the Flask-SQLAlchemy session (#26186)

This commit is contained in:
John Bodley 2024-01-17 17:06:22 +13:00 committed by GitHub
parent aaa4a7b371
commit 7af82ae87d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 932 additions and 348 deletions
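The gist of the change: Celery tasks and the helpers they call stop threading an explicit SQLAlchemy session argument around (and stop using the ad-hoc session_scope context manager, which is deleted below) and instead rely on the Flask-SQLAlchemy scoped session, db.session, which a new task_postrun signal handler tears down after every task. A minimal sketch of the pattern, using Superset's own create_app, db and celery_app but with an illustrative task name and query:

from typing import Any

from celery.signals import task_postrun

from superset import create_app, db
from superset.extensions import celery_app
from superset.models.sql_lab import Query

flask_app = create_app()


@celery_app.task(name="example.count_queries")  # illustrative task, not part of this commit
def count_queries() -> int:
    # Work happens against the shared Flask-SQLAlchemy scoped session rather
    # than a per-task session opened via session_scope().
    with flask_app.app_context():
        return db.session.query(Query).count()


@task_postrun.connect
def teardown(retval: Any, *args: Any, **kwargs: Any) -> None:
    # Mirrors the handler added in superset/tasks/celery_app.py below: remove
    # the scoped session once the task finishes so the next task starts clean.
    if not flask_app.config.get("CELERY_ALWAYS_EAGER"):
        db.session.remove()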

View File

@ -0,0 +1,688 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
import pytest
from flask import g
from sqlalchemy.orm.session import make_transient
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
load_energy_table_data,
)
from tests.integration_tests.test_app import app
from superset.commands.dashboard.importers.v0 import decode_dashboards
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.commands.dashboard.importers.v0 import import_chart, import_dashboard
from superset.commands.dataset.importers.v0 import import_dataset
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from .base_tests import SupersetTestCase
def delete_imports():
with app.app_context():
# Imported data clean up
session = db.session
for slc in session.query(Slice):
if "remote_id" in slc.params_dict:
session.delete(slc)
for dash in session.query(Dashboard):
if "remote_id" in dash.params_dict:
session.delete(dash)
for table in session.query(SqlaTable):
if "remote_id" in table.params_dict:
session.delete(table)
session.commit()
@pytest.fixture(autouse=True, scope="module")
def clean_imports():
yield
delete_imports()
class TestImportExport(SupersetTestCase):
"""Testing export import functionality for dashboards"""
def create_slice(
self,
name,
ds_id=None,
id=None,
db_name="examples",
table_name="wb_health_population",
schema=None,
):
params = {
"num_period_compare": "10",
"remote_id": id,
"datasource_name": table_name,
"database_name": db_name,
"schema": schema,
# Test for trailing commas
"metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"],
}
if table_name and not ds_id:
table = self.get_table(schema=schema, name=table_name)
if table:
ds_id = table.id
return Slice(
slice_name=name,
datasource_type=DatasourceType.TABLE,
viz_type="bubble",
params=json.dumps(params),
datasource_id=ds_id,
id=id,
)
def create_dashboard(self, title, id=0, slcs=[]):
json_metadata = {"remote_id": id}
return Dashboard(
id=id,
dashboard_title=title,
slices=slcs,
position_json='{"size_y": 2, "size_x": 2}',
slug=f"{title.lower()}_imported",
json_metadata=json.dumps(json_metadata),
published=False,
)
def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]):
params = {"remote_id": id, "database_name": "examples"}
table = SqlaTable(
id=id,
schema=schema,
table_name=name,
params=json.dumps(params),
)
for col_name in cols_names:
table.columns.append(TableColumn(column_name=col_name))
for metric_name in metric_names:
table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
return table
def get_slice(self, slc_id):
return db.session.query(Slice).filter_by(id=slc_id).first()
def get_slice_by_name(self, name):
return db.session.query(Slice).filter_by(slice_name=name).first()
def get_dash(self, dash_id):
return db.session.query(Dashboard).filter_by(id=dash_id).first()
def assert_dash_equals(
self, expected_dash, actual_dash, check_position=True, check_slugs=True
):
if check_slugs:
self.assertEqual(expected_dash.slug, actual_dash.slug)
self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title)
self.assertEqual(len(expected_dash.slices), len(actual_dash.slices))
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
for e_slc, a_slc in zip(expected_slices, actual_slices):
self.assert_slice_equals(e_slc, a_slc)
if check_position:
self.assertEqual(expected_dash.position_json, actual_dash.position_json)
def assert_table_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(expected_ds.schema, actual_ds.schema)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
def assert_datasource_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
def assert_slice_equals(self, expected_slc, actual_slc):
# to avoid bad slice data (no slice_name)
expected_slc_name = expected_slc.slice_name or ""
actual_slc_name = actual_slc.slice_name or ""
self.assertEqual(expected_slc_name, actual_slc_name)
self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type)
self.assertEqual(expected_slc.viz_type, actual_slc.viz_type)
exp_params = json.loads(expected_slc.params)
actual_params = json.loads(actual_slc.params)
diff_params_keys = (
"schema",
"database_name",
"datasource_name",
"remote_id",
"import_time",
)
for k in diff_params_keys:
if k in actual_params:
actual_params.pop(k)
if k in exp_params:
exp_params.pop(k)
self.assertEqual(exp_params, actual_params)
def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
"""only exported json has this params
imported/created dashboard has relationships to other models instead
"""
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
for e_slc, a_slc in zip(expected_slices, actual_slices):
params = a_slc.params_dict
self.assertEqual(e_slc.datasource.name, params["datasource_name"])
self.assertEqual(e_slc.datasource.schema, params["schema"])
self.assertEqual(e_slc.datasource.database.name, params["database_name"])
@unittest.skip("Schema needs to be updated")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_export_1_dashboard(self):
self.login("admin")
birth_dash = self.get_dash_by_slug("births")
id_ = birth_dash.id
export_dash_url = f"/dashboard/export_dashboards_form?id={id_}&action=go"
resp = self.client.get(export_dash_url)
exported_dashboards = json.loads(
resp.data.decode("utf-8"), object_hook=decode_dashboards
)["dashboards"]
birth_dash = self.get_dash_by_slug("births")
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEqual(
id_,
json.loads(
exported_dashboards[0].json_metadata, object_hook=decode_dashboards
)["remote_id"],
)
exported_tables = json.loads(
resp.data.decode("utf-8"), object_hook=decode_dashboards
)["datasources"]
self.assertEqual(1, len(exported_tables))
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
@unittest.skip("Schema needs to be updated")
@pytest.mark.usefixtures(
"load_world_bank_dashboard_with_slices",
"load_birth_names_dashboard_with_slices",
)
def test_export_2_dashboards(self):
self.login("admin")
birth_dash = self.get_dash_by_slug("births")
world_health_dash = self.get_dash_by_slug("world_health")
export_dash_url = (
"/dashboard/export_dashboards_form?id={}&id={}&action=go".format(
birth_dash.id, world_health_dash.id
)
)
resp = self.client.get(export_dash_url)
resp_data = json.loads(resp.data.decode("utf-8"), object_hook=decode_dashboards)
exported_dashboards = sorted(
resp_data.get("dashboards"), key=lambda d: d.dashboard_title
)
self.assertEqual(2, len(exported_dashboards))
birth_dash = self.get_dash_by_slug("births")
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEqual(
birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"]
)
world_health_dash = self.get_dash_by_slug("world_health")
self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1])
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
self.assertEqual(
world_health_dash.id,
json.loads(exported_dashboards[1].json_metadata)["remote_id"],
)
exported_tables = sorted(
resp_data.get("datasources"), key=lambda t: t.table_name
)
self.assertEqual(2, len(exported_tables))
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
self.assert_table_equals(
self.get_table(name="wb_health_population"), exported_tables[1]
)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_1_slice(self):
expected_slice = self.create_slice(
"Import Me", id=10001, schema=get_example_default_schema()
)
slc_id = import_chart(expected_slice, None, import_time=1989)
slc = self.get_slice(slc_id)
self.assertEqual(slc.datasource.perm, slc.perm)
self.assert_slice_equals(expected_slice, slc)
table_id = self.get_table(name="wb_health_population").id
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_2_slices_for_same_table(self):
schema = get_example_default_schema()
table_id = self.get_table(name="wb_health_population").id
slc_1 = self.create_slice(
"Import Me 1", ds_id=table_id, id=10002, schema=schema
)
slc_id_1 = import_chart(slc_1, None)
slc_2 = self.create_slice(
"Import Me 2", ds_id=table_id, id=10003, schema=schema
)
slc_id_2 = import_chart(slc_2, None)
imported_slc_1 = self.get_slice(slc_id_1)
imported_slc_2 = self.get_slice(slc_id_2)
self.assertEqual(table_id, imported_slc_1.datasource_id)
self.assert_slice_equals(slc_1, imported_slc_1)
self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm)
self.assertEqual(table_id, imported_slc_2.datasource_id)
self.assert_slice_equals(slc_2, imported_slc_2)
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
def test_import_slices_override(self):
schema = get_example_default_schema()
slc = self.create_slice("Import Me New", id=10005, schema=schema)
slc_1_id = import_chart(slc, None, import_time=1990)
slc.slice_name = "Import Me New"
imported_slc_1 = self.get_slice(slc_1_id)
slc_2 = self.create_slice("Import Me New", id=10005, schema=schema)
slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990)
self.assertEqual(slc_1_id, slc_2_id)
imported_slc_2 = self.get_slice(slc_2_id)
self.assert_slice_equals(slc, imported_slc_2)
def test_import_empty_dashboard(self):
empty_dash = self.create_dashboard("empty_dashboard", id=10001)
imported_dash_id = import_dashboard(empty_dash, import_time=1989)
imported_dash = self.get_dash(imported_dash_id)
self.assert_dash_equals(empty_dash, imported_dash, check_position=False)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_dashboard_1_slice(self):
slc = self.create_slice(
"health_slc", id=10006, schema=get_example_default_schema()
)
dash_with_1_slice = self.create_dashboard(
"dash_with_1_slice", slcs=[slc], id=10002
)
dash_with_1_slice.position_json = """
{{"DASHBOARD_VERSION_KEY": "v2",
"DASHBOARD_CHART_TYPE-{0}": {{
"type": "CHART",
"id": {0},
"children": [],
"meta": {{
"width": 4,
"height": 50,
"chartId": {0}
}}
}}
}}
""".format(
slc.id
)
imported_dash_id = import_dashboard(dash_with_1_slice, import_time=1990)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002)
make_transient(expected_dash)
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False, check_slugs=False
)
self.assertEqual(
{
"remote_id": 10002,
"import_time": 1990,
"native_filter_configuration": [],
},
json.loads(imported_dash.json_metadata),
)
expected_position = dash_with_1_slice.position
# new slice id (auto-incremental) assigned on insert
# id from json is used only for updating position with new id
meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"]
meta["chartId"] = imported_dash.slices[0].id
self.assertEqual(expected_position, imported_dash.position)
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_import_dashboard_2_slices(self):
schema = get_example_default_schema()
e_slc = self.create_slice(
"e_slc", id=10007, table_name="energy_usage", schema=schema
)
b_slc = self.create_slice(
"b_slc", id=10008, table_name="birth_names", schema=schema
)
dash_with_2_slices = self.create_dashboard(
"dash_with_2_slices", slcs=[e_slc, b_slc], id=10003
)
dash_with_2_slices.json_metadata = json.dumps(
{
"remote_id": 10003,
"expanded_slices": {
f"{e_slc.id}": True,
f"{b_slc.id}": False,
},
# mocked legacy filter_scope metadata
"filter_scopes": {
str(e_slc.id): {
"region": {"scope": ["ROOT_ID"], "immune": [b_slc.id]}
}
},
}
)
imported_dash_id = import_dashboard(dash_with_2_slices, import_time=1991)
imported_dash = self.get_dash(imported_dash_id)
expected_dash = self.create_dashboard(
"dash_with_2_slices", slcs=[e_slc, b_slc], id=10003
)
make_transient(expected_dash)
self.assert_dash_equals(
imported_dash, expected_dash, check_position=False, check_slugs=False
)
i_e_slc = self.get_slice_by_name("e_slc")
i_b_slc = self.get_slice_by_name("b_slc")
expected_json_metadata = {
"remote_id": 10003,
"import_time": 1991,
"expanded_slices": {
f"{i_e_slc.id}": True,
f"{i_b_slc.id}": False,
},
"native_filter_configuration": [],
}
self.assertEqual(
expected_json_metadata, json.loads(imported_dash.json_metadata)
)
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_import_override_dashboard_2_slices(self):
schema = get_example_default_schema()
e_slc = self.create_slice(
"e_slc", id=10009, table_name="energy_usage", schema=schema
)
b_slc = self.create_slice(
"b_slc", id=10010, table_name="birth_names", schema=schema
)
dash_to_import = self.create_dashboard(
"override_dashboard", slcs=[e_slc, b_slc], id=10004
)
imported_dash_id_1 = import_dashboard(dash_to_import, import_time=1992)
# create new instances of the slices
e_slc = self.create_slice(
"e_slc", id=10009, table_name="energy_usage", schema=schema
)
b_slc = self.create_slice(
"b_slc", id=10010, table_name="birth_names", schema=schema
)
c_slc = self.create_slice(
"c_slc", id=10011, table_name="birth_names", schema=schema
)
dash_to_import_override = self.create_dashboard(
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
)
imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992)
# override doesn't change the id
self.assertEqual(imported_dash_id_1, imported_dash_id_2)
expected_dash = self.create_dashboard(
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
)
make_transient(expected_dash)
imported_dash = self.get_dash(imported_dash_id_2)
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False, check_slugs=False
)
self.assertEqual(
{
"remote_id": 10004,
"import_time": 1992,
"native_filter_configuration": [],
},
json.loads(imported_dash.json_metadata),
)
def test_import_new_dashboard_slice_reset_ownership(self):
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
# set another user as an owner of importing dashboard
dash_with_1_slice.created_by = admin_user
dash_with_1_slice.changed_by = admin_user
dash_with_1_slice.owners = [admin_user]
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
def test_import_override_dashboard_slice_reset_ownership(self):
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
# re-import with another user shouldn't change the permissions
g.user = admin_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
def _create_dashboard_for_import(self, id_=10100):
slc = self.create_slice(
"health_slc" + str(id_), id=id_ + 1, schema=get_example_default_schema()
)
dash_with_1_slice = self.create_dashboard(
"dash_with_1_slice" + str(id_), slcs=[slc], id=id_ + 2
)
dash_with_1_slice.position_json = """
{{"DASHBOARD_VERSION_KEY": "v2",
"DASHBOARD_CHART_TYPE-{0}": {{
"type": "CHART",
"id": {0},
"children": [],
"meta": {{
"width": 4,
"height": 50,
"chartId": {0}
}}
}}
}}
""".format(
slc.id
)
return dash_with_1_slice
def test_import_table_no_metadata(self):
schema = get_example_default_schema()
db_id = get_example_database().id
table = self.create_table("pure_table", id=10001, schema=schema)
imported_id = import_dataset(table, db_id, import_time=1989)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_1_col_1_met(self):
schema = get_example_default_schema()
table = self.create_table(
"table_1_col_1_met",
id=10002,
cols_names=["col1"],
metric_names=["metric1"],
schema=schema,
)
db_id = get_example_database().id
imported_id = import_dataset(table, db_id, import_time=1990)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.assertEqual(
{
"remote_id": 10002,
"import_time": 1990,
"database_name": "examples",
},
json.loads(imported.params),
)
def test_import_table_2_col_2_met(self):
schema = get_example_default_schema()
table = self.create_table(
"table_2_col_2_met",
id=10003,
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
schema=schema,
)
db_id = get_example_database().id
imported_id = import_dataset(table, db_id, import_time=1991)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_override(self):
schema = get_example_default_schema()
table = self.create_table(
"table_override",
id=10003,
cols_names=["col1"],
metric_names=["m1"],
schema=schema,
)
db_id = get_example_database().id
imported_id = import_dataset(table, db_id, import_time=1991)
table_over = self.create_table(
"table_override",
id=10003,
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
schema=schema,
)
imported_over_id = import_dataset(table_over, db_id, import_time=1992)
imported_over = self.get_table_by_id(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
expected_table = self.create_table(
"table_override",
id=10003,
metric_names=["new_metric1", "m1"],
cols_names=["col1", "new_col1", "col2", "col3"],
schema=schema,
)
self.assert_table_equals(expected_table, imported_over)
def test_import_table_override_identical(self):
schema = get_example_default_schema()
table = self.create_table(
"copy_cat",
id=10004,
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
schema=schema,
)
db_id = get_example_database().id
imported_id = import_dataset(table, db_id, import_time=1993)
copy_table = self.create_table(
"copy_cat",
id=10004,
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
schema=schema,
)
imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
if __name__ == "__main__":
unittest.main()

View File

@ -22,9 +22,8 @@ from uuid import UUID
import pandas as pd
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from superset import app, security_manager
from superset import app, db, security_manager
from superset.commands.base import BaseCommand
from superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand
from superset.commands.exceptions import CommandException
@ -68,7 +67,6 @@ from superset.reports.notifications import create_notification
from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError
from superset.tasks.utils import get_executor
from superset.utils.celery import session_scope
from superset.utils.core import HeaderDataType, override_user
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.decorators import logs_context
@ -85,12 +83,10 @@ class BaseReportState:
@logs_context()
def __init__(
self,
session: Session,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
execution_id: UUID,
) -> None:
self._session = session
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
self._start_dttm = datetime.utcnow()
@ -123,7 +119,7 @@ class BaseReportState:
self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow()
self._session.commit()
db.session.commit()
def create_log(self, error_message: Optional[str] = None) -> None:
"""
@ -140,8 +136,8 @@ class BaseReportState:
report_schedule=self._report_schedule,
uuid=self._execution_id,
)
self._session.add(log)
self._session.commit()
db.session.add(log)
db.session.commit()
def _get_url(
self,
@ -485,9 +481,7 @@ class BaseReportState:
"""
Checks if an alert is in its grace period
"""
last_success = ReportScheduleDAO.find_last_success_log(
self._report_schedule, session=self._session
)
last_success = ReportScheduleDAO.find_last_success_log(self._report_schedule)
return (
last_success is not None
and self._report_schedule.grace_period
@ -501,7 +495,7 @@ class BaseReportState:
Checks if an alert/report on error is in its notification grace period
"""
last_success = ReportScheduleDAO.find_last_error_notification(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_success:
return False
@ -518,7 +512,7 @@ class BaseReportState:
Checks if an alert is in a working timeout
"""
last_working = ReportScheduleDAO.find_last_entered_working_log(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_working:
return False
@ -668,12 +662,10 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
def __init__(
self,
session: Session,
task_uuid: UUID,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
):
self._session = session
self._execution_id = task_uuid
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
@ -684,7 +676,6 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
self._report_schedule.last_state in state_cls.current_states
):
state_cls(
self._session,
self._report_schedule,
self._scheduled_dttm,
self._execution_id,
@ -708,31 +699,30 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
self._execution_id = UUID(task_id)
def run(self) -> None:
with session_scope(nullpool=True) as session:
try:
self.validate(session=session)
if not self._model:
raise ReportScheduleExecuteUnexpectedError()
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
try:
self.validate()
if not self._model:
raise ReportScheduleExecuteUnexpectedError()
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
ReportScheduleStateMachine(
session, self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex
ReportScheduleStateMachine(
self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex
def validate(self, session: Session = None) -> None:
def validate(self) -> None:
# Validate/populate model exists
logger.info(
"session is validated: id %s, executionid: %s",
@ -740,7 +730,7 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
self._execution_id,
)
self._model = (
session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
db.session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
)
if not self._model:
raise ReportScheduleNotFoundError()

View File

@ -17,12 +17,12 @@
import logging
from datetime import datetime, timedelta
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import ReportSchedulePruneLogError
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule
from superset.utils.celery import session_scope
logger = logging.getLogger(__name__)
@ -36,28 +36,27 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
self._worker_context = worker_context
def run(self) -> None:
with session_scope(nullpool=True) as session:
self.validate()
prune_errors = []
self.validate()
prune_errors = []
for report_schedule in session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
for report_schedule in db.session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, commit=False
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, session=session, commit=False
)
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))
def validate(self) -> None:
pass

View File

@ -22,7 +22,6 @@ from datetime import datetime
from typing import Any
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
@ -204,27 +203,25 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
return super().update(item, attributes, commit)
@staticmethod
def find_active(session: Session | None = None) -> list[ReportSchedule]:
def find_active() -> list[ReportSchedule]:
"""
Find all active reports. If session is passed it will be used instead of the
default `db.session`, this is useful when on a celery worker session context
Find all active reports.
"""
session = session or db.session
return (
session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
db.session.query(ReportSchedule)
.filter(ReportSchedule.active.is_(True))
.all()
)
@staticmethod
def find_last_success_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last success execution log for a given report
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.SUCCESS,
ReportExecutionLog.report_schedule == report_schedule,
@ -236,14 +233,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
@staticmethod
def find_last_entered_working_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds the last execution log that entered the working state for a given report
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.WORKING,
ReportExecutionLog.report_schedule == report_schedule,
@ -256,14 +251,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
@staticmethod
def find_last_error_notification(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last error email sent
"""
session = session or db.session
last_error_email_log = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.error_message
== REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
@ -276,7 +269,7 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
return None
# Checks that only errors have occurred since the last email
report_from_last_email = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state.notin_(
[ReportState.ERROR, ReportState.WORKING]
@ -293,13 +286,11 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
def bulk_delete_logs(
model: ReportSchedule,
from_date: datetime,
session: Session | None = None,
commit: bool = True,
) -> int | None:
session = session or db.session
try:
row_count = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
@ -307,8 +298,8 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
.delete(synchronize_session="fetch")
)
if commit:
session.commit()
db.session.commit()
return row_count
except SQLAlchemyError as ex:
session.rollback()
db.session.rollback()
raise DAODeleteFailedError(str(ex)) from ex

View File

@ -50,7 +50,6 @@ from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
@ -1071,7 +1070,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query) -> None:
"""Handle a live cursor between the execute and fetchall calls
The flow works without this method doing anything, but it allows
@ -1080,9 +1079,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# TODO: Fix circular import error caused by importing sql_lab.Query
@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@ -1095,7 +1092,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
logger.debug("Query %d: Running query: %s", query.id, sql)
cls.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)
cls.handle_cursor(cursor, query)
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
@ -1841,7 +1838,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
def prepare_cancel_query(cls, query: Query) -> None:
"""
Some databases may acquire the query cancelation id after the query
cancelation request has been received. For those cases, the db engine spec

View File

@ -34,9 +34,9 @@ from sqlalchemy import Column, text, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from superset import db
from superset.common.db_query_status import QueryStatus
from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
@ -334,7 +334,7 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def handle_cursor( # pylint: disable=too-many-locals
cls, cursor: Any, query: Query, session: Session
cls, cursor: Any, query: Query
) -> None:
"""Updates progress information"""
# pylint: disable=import-outside-toplevel
@ -353,8 +353,8 @@ class HiveEngineSpec(PrestoEngineSpec):
# Queries don't terminate when user clicks the STOP button on SQL LAB.
# Refresh session so that the `query.status` modified in stop_query in
# views/core.py is reflected here.
session.refresh(query)
query = session.query(type(query)).filter_by(id=query_id).one()
db.session.refresh(query)
query = db.session.query(type(query)).filter_by(id=query_id).one()
if query.status == QueryStatus.STOPPED:
cursor.cancel()
break
@ -396,7 +396,7 @@ class HiveEngineSpec(PrestoEngineSpec):
logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
last_log_line = len(log_lines)
if needs_commit:
session.commit()
db.session.commit()
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
logger.warning(
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead"

View File

@ -23,8 +23,8 @@ from typing import Any, Optional
from flask import current_app
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session
from superset import db
from superset.constants import QUERY_EARLY_CANCEL_KEY, TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.sql_lab import Query
@ -101,7 +101,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
raise cls.get_dbapi_mapped_exception(ex)
@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query) -> None:
"""Stop query and updates progress information"""
query_id = query.id
@ -113,8 +113,8 @@ class ImpalaEngineSpec(BaseEngineSpec):
try:
status = cursor.status()
while status in unfinished_states:
session.refresh(query)
query = session.query(Query).filter_by(id=query_id).one()
db.session.refresh(query)
query = db.session.query(Query).filter_by(id=query_id).one()
# if query cancelation was requested prior to the handle_cursor call, but
# the query was still executed
# modified in stop_query in views / core.py is reflected here.
@ -145,7 +145,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
needs_commit = True
if needs_commit:
session.commit()
db.session.commit()
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
cls.engine, 5
)

View File

@ -23,7 +23,6 @@ from typing import Any, Callable, List, NamedTuple, Optional
from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session
with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed
# Ensure pyocient inherits Superset's logging level
@ -372,13 +371,13 @@ class OcientEngineSpec(BaseEngineSpec):
return "DUMMY_VALUE"
@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query) -> None:
with OcientEngineSpec.query_id_mapping_lock:
OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id
# Add the query id to the cursor
setattr(cursor, "superset_query_id", query.id)
return super().handle_cursor(cursor, query, session)
return super().handle_cursor(cursor, query)
@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:

View File

@ -39,10 +39,9 @@ from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import Row as ResultRow
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from superset import cache_manager, is_feature_enabled
from superset import cache_manager, db, is_feature_enabled
from superset.common.db_query_status import QueryStatus
from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
@ -1288,11 +1287,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return None
@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Cursor, query: Query) -> None:
"""Updates progress information"""
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
session.commit()
db.session.commit()
query_id = query.id
poll_interval = query.database.connect_args.get(
@ -1308,7 +1307,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# Update the object and wait for the kill signal.
stats = polled.get("stats", {})
query = session.query(type(query)).filter_by(id=query_id).one()
query = db.session.query(type(query)).filter_by(id=query_id).one()
if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
cursor.cancel()
break
@ -1332,7 +1331,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
)
if progress > query.progress:
query.progress = progress
session.commit()
db.session.commit()
time.sleep(poll_interval)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()

View File

@ -27,8 +27,8 @@ from flask import current_app
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
@ -155,7 +155,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return None
@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Cursor, query: Query) -> None:
"""
Handle a trino client cursor.
@ -172,7 +172,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
session.commit()
db.session.commit()
# if query cancelation was requested prior to the handle_cursor call, but
# the query was still executed, trigger the actual query cancelation now
@ -183,12 +183,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
cancel_query_id=cancel_query_id,
)
super().handle_cursor(cursor=cursor, query=query, session=session)
super().handle_cursor(cursor=cursor, query=query)
@classmethod
def execute_with_cursor(
cls, cursor: Cursor, sql: str, query: Query, session: Session
) -> None:
def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@ -225,7 +223,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
time.sleep(0.1)
logger.debug("Query %d: Handling cursor", query_id)
cls.handle_cursor(cursor, query, session)
cls.handle_cursor(cursor, query)
# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query_id)
@ -237,10 +235,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
raise err
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
def prepare_cancel_query(cls, query: Query) -> None:
if QUERY_CANCEL_KEY not in query.extra:
query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True)
session.commit()
db.session.commit()
@classmethod
def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:

View File

@ -25,10 +25,8 @@ from typing import Any, cast, Optional, Union
import backoff
import msgpack
import simplejson as json
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from flask_babel import gettext as __
from sqlalchemy.orm import Session
from superset import (
app,
@ -56,7 +54,6 @@ from superset.sql_parse import (
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import write_ipc_buffer
from superset.utils.celery import session_scope
from superset.utils.core import (
json_iso_dttm_ser,
override_user,
@ -92,7 +89,6 @@ class SqlLabQueryStoppedException(SqlLabException):
def handle_query_error(
ex: Exception,
query: Query,
session: Session,
payload: Optional[dict[str, Any]] = None,
prefix_message: str = "",
) -> dict[str, Any]:
@ -120,7 +116,7 @@ def handle_query_error(
if errors:
query.set_extra_json_key("errors", errors_payload)
session.commit()
db.session.commit()
payload.update({"status": query.status, "error": msg, "errors": errors_payload})
if troubleshooting_link := config["TROUBLESHOOTING_LINK"]:
payload["link"] = troubleshooting_link
@ -150,22 +146,20 @@ def get_query_giveup_handler(_: Any) -> None:
on_giveup=get_query_giveup_handler,
max_tries=5,
)
def get_query(query_id: int, session: Session) -> Query:
def get_query(query_id: int) -> Query:
"""attempts to get the query and retry if it cannot"""
try:
return session.query(Query).filter_by(id=query_id).one()
return db.session.query(Query).filter_by(id=query_id).one()
except Exception as ex:
raise SqlLabException("Failed at getting query") from ex
@celery_app.task(
name="sql_lab.get_sql_results",
bind=True,
time_limit=SQLLAB_HARD_TIMEOUT,
soft_time_limit=SQLLAB_TIMEOUT,
)
def get_sql_results( # pylint: disable=too-many-arguments
ctask: Task,
query_id: int,
rendered_query: str,
return_results: bool = True,
@ -176,30 +170,27 @@ def get_sql_results( # pylint: disable=too-many-arguments
log_params: Optional[dict[str, Any]] = None,
) -> Optional[dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id)
return handle_query_error(ex, query)
def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-locals
def execute_sql_statement(
sql_statement: str,
query: Query,
session: Session,
cursor: Any,
log_params: Optional[dict[str, Any]],
apply_ctas: bool = False,
@ -284,9 +275,9 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
security_manager,
log_params,
)
session.commit()
db.session.commit()
with stats_timing("sqllab.query.time_executing_query", stats_logger):
db_engine_spec.execute_with_cursor(cursor, sql, query, session)
db_engine_spec.execute_with_cursor(cursor, sql, query)
with stats_timing("sqllab.query.time_fetching_results", stats_logger):
logger.debug(
@ -319,7 +310,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
except Exception as ex:
# query is stopped in another thread/worker
# stopping raises expected exceptions which we should skip
session.refresh(query)
db.session.refresh(query)
if query.status == QueryStatus.STOPPED:
raise SqlLabQueryStoppedException() from ex
@ -393,7 +384,6 @@ def execute_sql_statements(
rendered_query: str,
return_results: bool,
store_results: bool,
session: Session,
start_time: Optional[float],
expand_data: bool,
log_params: Optional[dict[str, Any]],
@ -403,7 +393,7 @@ def execute_sql_statements(
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
query = get_query(query_id, session)
query = get_query(query_id)
payload: dict[str, Any] = {"query_id": query_id}
database = query.database
db_engine_spec = database.db_engine_spec
@ -432,7 +422,7 @@ def execute_sql_statements(
logger.info("Query %s: Set query to 'running'", str(query_id))
query.status = QueryStatus.RUNNING
query.start_running_time = now_as_float()
session.commit()
db.session.commit()
# Should we create a table or view from the select?
if (
@ -476,11 +466,11 @@ def execute_sql_statements(
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
session.commit()
db.session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
db.session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload
@ -497,12 +487,11 @@ def execute_sql_statements(
)
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
db.session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,
@ -521,9 +510,7 @@ def execute_sql_statements(
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
payload = handle_query_error(ex, query, payload, prefix_message)
return payload
# Commit the connection so CTA queries will create the table and any DML.
@ -593,7 +580,7 @@ def execute_sql_statements(
query.results_key = key
query.status = QueryStatus.SUCCESS
session.commit()
db.session.commit()
if return_results:
# since we're returning results we need to create non-arrow data
@ -634,7 +621,7 @@ def cancel_query(query: Query) -> bool:
return True
# Some databases may need to make preparations for query cancellation
query.database.db_engine_spec.prepare_cancel_query(query, db.session)
query.database.db_engine_spec.prepare_cancel_query(query)
if query.extra.get(QUERY_EARLY_CANCEL_KEY):
# Query has been cancelled prior to being able to set the cancel key.

View File

@ -94,14 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dummy"
def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()
try:
charts = session.query(Slice).all()
finally:
session.close()
return [get_payload(chart) for chart in charts]
return [get_payload(chart) for chart in db.session.query(Slice).all()]
class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
@ -130,28 +123,24 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
self.since = parse_human_datetime(since) if since else None
def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()
records = (
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads
return [
get_payload(chart, dashboard)
for dashboard in dashboards
for chart in dashboard.slices
]
class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
@ -178,48 +167,44 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
# add dashboards that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))
# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = db.session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))
finally:
session.close()
# add charts that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
return payloads

View File

@ -21,7 +21,7 @@ it needs to call create_app() in order to initialize things properly
"""
from typing import Any
from celery.signals import worker_process_init
from celery.signals import task_postrun, worker_process_init
# Superset framework imports
from superset import create_app
@ -43,3 +43,27 @@ def reset_db_connection_pool(**kwargs: Any) -> None: # pylint: disable=unused-a
with flask_app.app_context():
# https://docs.sqlalchemy.org/en/14/core/connections.html#engine-disposal
db.engine.dispose()
@task_postrun.connect
def teardown( # pylint: disable=unused-argument
retval: Any,
*args: Any,
**kwargs: Any,
) -> None:
"""
After each Celery task, tear down the Flask-SQLAlchemy session.
Note that for non-eager requests Flask-SQLAlchemy will perform the teardown.
:param retval: The return value of the task
:see: https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun
:see: https://gist.github.com/twolfson/a1b329e9353f9b575131
"""
if flask_app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN"):
if not isinstance(retval, Exception):
db.session.commit()
if not flask_app.config.get("CELERY_ALWAYS_EAGER"):
db.session.remove()

View File

@ -29,7 +29,6 @@ from superset.daos.report import ReportScheduleDAO
from superset.extensions import celery_app
from superset.stats_logger import BaseStatsLogger
from superset.tasks.cron_util import cron_schedule_window
from superset.utils.celery import session_scope
from superset.utils.core import LoggerLevel
from superset.utils.log import get_logger_from_status
@ -46,35 +45,32 @@ def scheduler() -> None:
if not is_feature_enabled("ALERT_REPORTS"):
return
with session_scope(nullpool=True) as session:
active_schedules = ReportScheduleDAO.find_active(session)
triggered_at = (
datetime.fromisoformat(scheduler.request.expires)
- app.config["CELERY_BEAT_SCHEDULER_EXPIRES"]
if scheduler.request.expires
else datetime.utcnow()
)
for active_schedule in active_schedules:
for schedule in cron_schedule_window(
triggered_at, active_schedule.crontab, active_schedule.timezone
active_schedules = ReportScheduleDAO.find_active()
triggered_at = (
datetime.fromisoformat(scheduler.request.expires)
- app.config["CELERY_BEAT_SCHEDULER_EXPIRES"]
if scheduler.request.expires
else datetime.utcnow()
)
for active_schedule in active_schedules:
for schedule in cron_schedule_window(
triggered_at, active_schedule.crontab, active_schedule.timezone
):
logger.info("Scheduling alert %s eta: %s", active_schedule.name, schedule)
async_options = {"eta": schedule}
if (
active_schedule.working_timeout is not None
and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"]
):
logger.info(
"Scheduling alert %s eta: %s", active_schedule.name, schedule
async_options["time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"]
)
async_options = {"eta": schedule}
if (
active_schedule.working_timeout is not None
and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"]
):
async_options["time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"]
)
async_options["soft_time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"]
)
execute.apply_async((active_schedule.id,), **async_options)
async_options["soft_time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"]
)
execute.apply_async((active_schedule.id,), **async_options)
@celery_app.task(name="reports.execute", bind=True)

View File

@ -1,59 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from superset import app, db
logger = logging.getLogger(__name__)
# Null pool is used for the celery workers due to process forking side effects.
# For more info see: https://github.com/apache/superset/issues/10530
@contextmanager
def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
logger.warning(
"SQLite Database support for metadata databases will be removed \
in a future version of Superset."
)
if nullpool:
engine = create_engine(database_uri, poolclass=NullPool)
session_class = sessionmaker()
session_class.configure(bind=engine)
session = session_class()
else:
session = db.session()
session.commit() # HACK
try:
yield session
session.commit()
except SQLAlchemyError as ex:
session.rollback()
logger.exception(ex)
raise
finally:
session.close()
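With this module removed, the process-forking concern is addressed by disposing the engine when a worker task starts (shown earlier in this diff) rather than by building ad-hoc NullPool engines. If a deployment still wants pooling disabled for the metadata database, that can be expressed through Flask-SQLAlchemy's engine options instead, e.g. in superset_config.py; the snippet below is an illustration, not something this commit adds.

from sqlalchemy.pool import NullPool

# Flask-SQLAlchemy passes these options to create_engine() for the metadata database.
SQLALCHEMY_ENGINE_OPTIONS = {"poolclass": NullPool}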

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset import app, db
from superset import app
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
@ -29,7 +29,6 @@ def test_non_async_execute(non_async_example_db: Database, example_query: Query)
"select 1 as foo;",
store_results=False,
return_results=True,
session=db.session,
start_time=now_as_float(),
expand_data=True,
log_params=dict(),

View File

@ -575,16 +575,22 @@ class TestSqlLab(SupersetTestCase):
)
assert data["errors"][0]["error_type"] == "GENERIC_BACKEND_ERROR"
@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query):
def test_execute_sql_statements(
self,
mock_execute_sql_statement,
mock_get_query,
mock_db,
):
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_db = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
@ -599,7 +605,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
@ -609,7 +614,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
mock_session,
mock_cursor,
None,
False,
@ -617,7 +621,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
mock_session,
mock_cursor,
None,
False,
@ -637,7 +640,6 @@ class TestSqlLab(SupersetTestCase):
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = True
mock_cursor = mock.MagicMock()
@ -653,7 +655,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
@ -676,10 +677,14 @@ class TestSqlLab(SupersetTestCase):
},
)
@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements_ctas(
self, mock_execute_sql_statement, mock_get_query
self,
mock_execute_sql_statement,
mock_get_query,
mock_db,
):
sql = """
-- comment
@ -687,7 +692,7 @@ class TestSqlLab(SupersetTestCase):
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_db = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
@ -706,7 +711,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
@ -716,7 +720,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
mock_session,
mock_cursor,
None,
False,
@ -724,7 +727,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
mock_session,
mock_cursor,
None,
True, # apply_ctas
@ -740,7 +742,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
@ -773,7 +774,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,

View File

@ -352,9 +352,8 @@ def test_prepare_cancel_query(
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
session_mock = mocker.MagicMock()
query = Query(extra_json=json.dumps(initial_extra))
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
TrinoEngineSpec.prepare_cancel_query(query=query)
assert query.extra == final_extra
@ -374,14 +373,13 @@ def test_handle_cursor_early_cancel(
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.query_id = query_id
session_mock = mocker.MagicMock()
query = Query()
if cancel_early:
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
TrinoEngineSpec.prepare_cancel_query(query=query)
TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)
if cancel_early:
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
@ -399,7 +397,6 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
mock_session = mocker.MagicMock()
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id
@ -410,7 +407,6 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
session=mock_session,
)
mock_query.set_extra_json_key.assert_called_once_with(

View File

@ -20,6 +20,7 @@ import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.utils.core import override_user
@ -41,14 +42,12 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]
session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
execute_sql_statement(
sql_statement,
query,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,
@ -56,7 +55,7 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", query, session
cursor, "SELECT 42 AS answer LIMIT 2", query
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
@ -83,7 +82,6 @@ def test_execute_sql_statement_with_rls(
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]
session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
mocker.patch(
@ -95,7 +93,6 @@ def test_execute_sql_statement_with_rls(
execute_sql_statement(
sql_statement,
query,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,
@ -107,7 +104,7 @@ def test_execute_sql_statement_with_rls(
force=True,
)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
@ -162,7 +159,6 @@ def test_sql_lab_insert_rls_as_subquery(
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
@ -198,7 +194,6 @@ def test_sql_lab_insert_rls_as_subquery(
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,