refactor: Ensure Celery leverages the Flask-SQLAlchemy session (#26186)
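In practice the refactor swaps explicitly created or threaded-through SQLAlchemy sessions (the session_scope helper and session: Session parameters) for Flask-SQLAlchemy's shared db.session in code that runs on Celery workers. Below is a minimal before/after sketch of that pattern, assuming the worker executes tasks inside a Flask application context; celery_app stands for the configured Celery application (import omitted), the task names and functions are illustrative, and only session_scope, db.session, and ReportSchedule come from this commit.

# Before: worker code created its own session and passed it around.
from superset.utils.celery import session_scope
from superset.reports.models import ReportSchedule


@celery_app.task(name="example.refresh_report_old")  # illustrative task
def refresh_report_old(report_id: int) -> None:
    with session_scope(nullpool=True) as session:
        report = session.query(ReportSchedule).filter_by(id=report_id).one()
        report.last_state = "WORKING"  # illustrative update
        session.commit()


# After: worker code uses the shared Flask-SQLAlchemy scoped session.
from superset import db


@celery_app.task(name="example.refresh_report_new")  # illustrative task
def refresh_report_new(report_id: int) -> None:
    report = db.session.query(ReportSchedule).filter_by(id=report_id).one()
    report.last_state = "WORKING"  # illustrative update
    db.session.commit()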
parent aaa4a7b371
commit 7af82ae87d

@@ -0,0 +1,688 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from tests.integration_tests.fixtures.birth_names_dashboard import (
    load_birth_names_dashboard_with_slices,
    load_birth_names_data,
)

import pytest
from flask import g
from sqlalchemy.orm.session import make_transient

from tests.integration_tests.fixtures.energy_dashboard import (
    load_energy_table_with_slice,
    load_energy_table_data,
)
from tests.integration_tests.test_app import app
from superset.commands.dashboard.importers.v0 import decode_dashboards
from superset import db, security_manager

from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.commands.dashboard.importers.v0 import import_chart, import_dashboard
from superset.commands.dataset.importers.v0 import import_dataset
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database

from tests.integration_tests.fixtures.world_bank_dashboard import (
    load_world_bank_dashboard_with_slices,
    load_world_bank_data,
)
from .base_tests import SupersetTestCase

def delete_imports():
    with app.app_context():
        # Imported data clean up
        session = db.session
        for slc in session.query(Slice):
            if "remote_id" in slc.params_dict:
                session.delete(slc)
        for dash in session.query(Dashboard):
            if "remote_id" in dash.params_dict:
                session.delete(dash)
        for table in session.query(SqlaTable):
            if "remote_id" in table.params_dict:
                session.delete(table)
        session.commit()


@pytest.fixture(autouse=True, scope="module")
def clean_imports():
    yield
    delete_imports()

class TestImportExport(SupersetTestCase):
|
||||
"""Testing export import functionality for dashboards"""
|
||||
|
||||
def create_slice(
|
||||
self,
|
||||
name,
|
||||
ds_id=None,
|
||||
id=None,
|
||||
db_name="examples",
|
||||
table_name="wb_health_population",
|
||||
schema=None,
|
||||
):
|
||||
params = {
|
||||
"num_period_compare": "10",
|
||||
"remote_id": id,
|
||||
"datasource_name": table_name,
|
||||
"database_name": db_name,
|
||||
"schema": schema,
|
||||
# Test for trailing commas
|
||||
"metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"],
|
||||
}
|
||||
|
||||
if table_name and not ds_id:
|
||||
table = self.get_table(schema=schema, name=table_name)
|
||||
if table:
|
||||
ds_id = table.id
|
||||
|
||||
return Slice(
|
||||
slice_name=name,
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
viz_type="bubble",
|
||||
params=json.dumps(params),
|
||||
datasource_id=ds_id,
|
||||
id=id,
|
||||
)
|
||||
|
||||
def create_dashboard(self, title, id=0, slcs=[]):
|
||||
json_metadata = {"remote_id": id}
|
||||
return Dashboard(
|
||||
id=id,
|
||||
dashboard_title=title,
|
||||
slices=slcs,
|
||||
position_json='{"size_y": 2, "size_x": 2}',
|
||||
slug=f"{title.lower()}_imported",
|
||||
json_metadata=json.dumps(json_metadata),
|
||||
published=False,
|
||||
)
|
||||
|
||||
def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]):
|
||||
params = {"remote_id": id, "database_name": "examples"}
|
||||
table = SqlaTable(
|
||||
id=id,
|
||||
schema=schema,
|
||||
table_name=name,
|
||||
params=json.dumps(params),
|
||||
)
|
||||
for col_name in cols_names:
|
||||
table.columns.append(TableColumn(column_name=col_name))
|
||||
for metric_name in metric_names:
|
||||
table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
|
||||
return table
|
||||
|
||||
def get_slice(self, slc_id):
|
||||
return db.session.query(Slice).filter_by(id=slc_id).first()
|
||||
|
||||
def get_slice_by_name(self, name):
|
||||
return db.session.query(Slice).filter_by(slice_name=name).first()
|
||||
|
||||
def get_dash(self, dash_id):
|
||||
return db.session.query(Dashboard).filter_by(id=dash_id).first()
|
||||
|
||||
def assert_dash_equals(
|
||||
self, expected_dash, actual_dash, check_position=True, check_slugs=True
|
||||
):
|
||||
if check_slugs:
|
||||
self.assertEqual(expected_dash.slug, actual_dash.slug)
|
||||
self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title)
|
||||
self.assertEqual(len(expected_dash.slices), len(actual_dash.slices))
|
||||
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
|
||||
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
|
||||
for e_slc, a_slc in zip(expected_slices, actual_slices):
|
||||
self.assert_slice_equals(e_slc, a_slc)
|
||||
if check_position:
|
||||
self.assertEqual(expected_dash.position_json, actual_dash.position_json)
|
||||
|
||||
def assert_table_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(expected_ds.schema, actual_ds.schema)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
|
||||
def assert_datasource_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
|
||||
def assert_slice_equals(self, expected_slc, actual_slc):
|
||||
# to avoid bad slice data (no slice_name)
|
||||
expected_slc_name = expected_slc.slice_name or ""
|
||||
actual_slc_name = actual_slc.slice_name or ""
|
||||
self.assertEqual(expected_slc_name, actual_slc_name)
|
||||
self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type)
|
||||
self.assertEqual(expected_slc.viz_type, actual_slc.viz_type)
|
||||
exp_params = json.loads(expected_slc.params)
|
||||
actual_params = json.loads(actual_slc.params)
|
||||
diff_params_keys = (
|
||||
"schema",
|
||||
"database_name",
|
||||
"datasource_name",
|
||||
"remote_id",
|
||||
"import_time",
|
||||
)
|
||||
for k in diff_params_keys:
|
||||
if k in actual_params:
|
||||
actual_params.pop(k)
|
||||
if k in exp_params:
|
||||
exp_params.pop(k)
|
||||
self.assertEqual(exp_params, actual_params)
|
||||
|
||||
def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
|
||||
"""only exported json has this params
|
||||
imported/created dashboard has relationships to other models instead
|
||||
"""
|
||||
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
|
||||
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
|
||||
for e_slc, a_slc in zip(expected_slices, actual_slices):
|
||||
params = a_slc.params_dict
|
||||
self.assertEqual(e_slc.datasource.name, params["datasource_name"])
|
||||
self.assertEqual(e_slc.datasource.schema, params["schema"])
|
||||
self.assertEqual(e_slc.datasource.database.name, params["database_name"])
|
||||
|
||||
@unittest.skip("Schema needs to be updated")
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_export_1_dashboard(self):
|
||||
self.login("admin")
|
||||
birth_dash = self.get_dash_by_slug("births")
|
||||
id_ = birth_dash.id
|
||||
export_dash_url = f"/dashboard/export_dashboards_form?id={id_}&action=go"
|
||||
resp = self.client.get(export_dash_url)
|
||||
exported_dashboards = json.loads(
|
||||
resp.data.decode("utf-8"), object_hook=decode_dashboards
|
||||
)["dashboards"]
|
||||
|
||||
birth_dash = self.get_dash_by_slug("births")
|
||||
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEqual(
|
||||
id_,
|
||||
json.loads(
|
||||
exported_dashboards[0].json_metadata, object_hook=decode_dashboards
|
||||
)["remote_id"],
|
||||
)
|
||||
|
||||
exported_tables = json.loads(
|
||||
resp.data.decode("utf-8"), object_hook=decode_dashboards
|
||||
)["datasources"]
|
||||
self.assertEqual(1, len(exported_tables))
|
||||
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
|
||||
|
||||
@unittest.skip("Schema needs to be updated")
|
||||
@pytest.mark.usefixtures(
|
||||
"load_world_bank_dashboard_with_slices",
|
||||
"load_birth_names_dashboard_with_slices",
|
||||
)
|
||||
def test_export_2_dashboards(self):
|
||||
self.login("admin")
|
||||
birth_dash = self.get_dash_by_slug("births")
|
||||
world_health_dash = self.get_dash_by_slug("world_health")
|
||||
export_dash_url = (
|
||||
"/dashboard/export_dashboards_form?id={}&id={}&action=go".format(
|
||||
birth_dash.id, world_health_dash.id
|
||||
)
|
||||
)
|
||||
resp = self.client.get(export_dash_url)
|
||||
resp_data = json.loads(resp.data.decode("utf-8"), object_hook=decode_dashboards)
|
||||
exported_dashboards = sorted(
|
||||
resp_data.get("dashboards"), key=lambda d: d.dashboard_title
|
||||
)
|
||||
self.assertEqual(2, len(exported_dashboards))
|
||||
|
||||
birth_dash = self.get_dash_by_slug("births")
|
||||
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEqual(
|
||||
birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"]
|
||||
)
|
||||
|
||||
world_health_dash = self.get_dash_by_slug("world_health")
|
||||
self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1])
|
||||
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
|
||||
self.assertEqual(
|
||||
world_health_dash.id,
|
||||
json.loads(exported_dashboards[1].json_metadata)["remote_id"],
|
||||
)
|
||||
|
||||
exported_tables = sorted(
|
||||
resp_data.get("datasources"), key=lambda t: t.table_name
|
||||
)
|
||||
self.assertEqual(2, len(exported_tables))
|
||||
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
|
||||
self.assert_table_equals(
|
||||
self.get_table(name="wb_health_population"), exported_tables[1]
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_import_1_slice(self):
|
||||
expected_slice = self.create_slice(
|
||||
"Import Me", id=10001, schema=get_example_default_schema()
|
||||
)
|
||||
slc_id = import_chart(expected_slice, None, import_time=1989)
|
||||
slc = self.get_slice(slc_id)
|
||||
self.assertEqual(slc.datasource.perm, slc.perm)
|
||||
self.assert_slice_equals(expected_slice, slc)
|
||||
|
||||
table_id = self.get_table(name="wb_health_population").id
|
||||
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_import_2_slices_for_same_table(self):
|
||||
schema = get_example_default_schema()
|
||||
table_id = self.get_table(name="wb_health_population").id
|
||||
slc_1 = self.create_slice(
|
||||
"Import Me 1", ds_id=table_id, id=10002, schema=schema
|
||||
)
|
||||
slc_id_1 = import_chart(slc_1, None)
|
||||
slc_2 = self.create_slice(
|
||||
"Import Me 2", ds_id=table_id, id=10003, schema=schema
|
||||
)
|
||||
slc_id_2 = import_chart(slc_2, None)
|
||||
|
||||
imported_slc_1 = self.get_slice(slc_id_1)
|
||||
imported_slc_2 = self.get_slice(slc_id_2)
|
||||
self.assertEqual(table_id, imported_slc_1.datasource_id)
|
||||
self.assert_slice_equals(slc_1, imported_slc_1)
|
||||
self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm)
|
||||
|
||||
self.assertEqual(table_id, imported_slc_2.datasource_id)
|
||||
self.assert_slice_equals(slc_2, imported_slc_2)
|
||||
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
|
||||
|
||||
def test_import_slices_override(self):
|
||||
schema = get_example_default_schema()
|
||||
slc = self.create_slice("Import Me New", id=10005, schema=schema)
|
||||
slc_1_id = import_chart(slc, None, import_time=1990)
|
||||
slc.slice_name = "Import Me New"
|
||||
imported_slc_1 = self.get_slice(slc_1_id)
|
||||
slc_2 = self.create_slice("Import Me New", id=10005, schema=schema)
|
||||
slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990)
|
||||
self.assertEqual(slc_1_id, slc_2_id)
|
||||
imported_slc_2 = self.get_slice(slc_2_id)
|
||||
self.assert_slice_equals(slc, imported_slc_2)
|
||||
|
||||
def test_import_empty_dashboard(self):
|
||||
empty_dash = self.create_dashboard("empty_dashboard", id=10001)
|
||||
imported_dash_id = import_dashboard(empty_dash, import_time=1989)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assert_dash_equals(empty_dash, imported_dash, check_position=False)
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_import_dashboard_1_slice(self):
|
||||
slc = self.create_slice(
|
||||
"health_slc", id=10006, schema=get_example_default_schema()
|
||||
)
|
||||
dash_with_1_slice = self.create_dashboard(
|
||||
"dash_with_1_slice", slcs=[slc], id=10002
|
||||
)
|
||||
dash_with_1_slice.position_json = """
|
||||
{{"DASHBOARD_VERSION_KEY": "v2",
|
||||
"DASHBOARD_CHART_TYPE-{0}": {{
|
||||
"type": "CHART",
|
||||
"id": {0},
|
||||
"children": [],
|
||||
"meta": {{
|
||||
"width": 4,
|
||||
"height": 50,
|
||||
"chartId": {0}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
""".format(
|
||||
slc.id
|
||||
)
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice, import_time=1990)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
|
||||
expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002)
|
||||
make_transient(expected_dash)
|
||||
self.assert_dash_equals(
|
||||
expected_dash, imported_dash, check_position=False, check_slugs=False
|
||||
)
|
||||
self.assertEqual(
|
||||
{
|
||||
"remote_id": 10002,
|
||||
"import_time": 1990,
|
||||
"native_filter_configuration": [],
|
||||
},
|
||||
json.loads(imported_dash.json_metadata),
|
||||
)
|
||||
|
||||
expected_position = dash_with_1_slice.position
|
||||
# new slice id (auto-incremental) assigned on insert
|
||||
# id from json is used only for updating position with new id
|
||||
meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"]
|
||||
meta["chartId"] = imported_dash.slices[0].id
|
||||
self.assertEqual(expected_position, imported_dash.position)
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_import_dashboard_2_slices(self):
|
||||
schema = get_example_default_schema()
|
||||
e_slc = self.create_slice(
|
||||
"e_slc", id=10007, table_name="energy_usage", schema=schema
|
||||
)
|
||||
b_slc = self.create_slice(
|
||||
"b_slc", id=10008, table_name="birth_names", schema=schema
|
||||
)
|
||||
dash_with_2_slices = self.create_dashboard(
|
||||
"dash_with_2_slices", slcs=[e_slc, b_slc], id=10003
|
||||
)
|
||||
dash_with_2_slices.json_metadata = json.dumps(
|
||||
{
|
||||
"remote_id": 10003,
|
||||
"expanded_slices": {
|
||||
f"{e_slc.id}": True,
|
||||
f"{b_slc.id}": False,
|
||||
},
|
||||
# mocked legacy filter_scope metadata
|
||||
"filter_scopes": {
|
||||
str(e_slc.id): {
|
||||
"region": {"scope": ["ROOT_ID"], "immune": [b_slc.id]}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
imported_dash_id = import_dashboard(dash_with_2_slices, import_time=1991)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
|
||||
expected_dash = self.create_dashboard(
|
||||
"dash_with_2_slices", slcs=[e_slc, b_slc], id=10003
|
||||
)
|
||||
make_transient(expected_dash)
|
||||
self.assert_dash_equals(
|
||||
imported_dash, expected_dash, check_position=False, check_slugs=False
|
||||
)
|
||||
i_e_slc = self.get_slice_by_name("e_slc")
|
||||
i_b_slc = self.get_slice_by_name("b_slc")
|
||||
expected_json_metadata = {
|
||||
"remote_id": 10003,
|
||||
"import_time": 1991,
|
||||
"expanded_slices": {
|
||||
f"{i_e_slc.id}": True,
|
||||
f"{i_b_slc.id}": False,
|
||||
},
|
||||
"native_filter_configuration": [],
|
||||
}
|
||||
self.assertEqual(
|
||||
expected_json_metadata, json.loads(imported_dash.json_metadata)
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_import_override_dashboard_2_slices(self):
|
||||
schema = get_example_default_schema()
|
||||
e_slc = self.create_slice(
|
||||
"e_slc", id=10009, table_name="energy_usage", schema=schema
|
||||
)
|
||||
b_slc = self.create_slice(
|
||||
"b_slc", id=10010, table_name="birth_names", schema=schema
|
||||
)
|
||||
dash_to_import = self.create_dashboard(
|
||||
"override_dashboard", slcs=[e_slc, b_slc], id=10004
|
||||
)
|
||||
imported_dash_id_1 = import_dashboard(dash_to_import, import_time=1992)
|
||||
|
||||
# create new instances of the slices
|
||||
e_slc = self.create_slice(
|
||||
"e_slc", id=10009, table_name="energy_usage", schema=schema
|
||||
)
|
||||
b_slc = self.create_slice(
|
||||
"b_slc", id=10010, table_name="birth_names", schema=schema
|
||||
)
|
||||
c_slc = self.create_slice(
|
||||
"c_slc", id=10011, table_name="birth_names", schema=schema
|
||||
)
|
||||
dash_to_import_override = self.create_dashboard(
|
||||
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
|
||||
)
|
||||
imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992)
|
||||
|
||||
# override doesn't change the id
|
||||
self.assertEqual(imported_dash_id_1, imported_dash_id_2)
|
||||
expected_dash = self.create_dashboard(
|
||||
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
|
||||
)
|
||||
make_transient(expected_dash)
|
||||
imported_dash = self.get_dash(imported_dash_id_2)
|
||||
self.assert_dash_equals(
|
||||
expected_dash, imported_dash, check_position=False, check_slugs=False
|
||||
)
|
||||
self.assertEqual(
|
||||
{
|
||||
"remote_id": 10004,
|
||||
"import_time": 1992,
|
||||
"native_filter_configuration": [],
|
||||
},
|
||||
json.loads(imported_dash.json_metadata),
|
||||
)
|
||||
|
||||
def test_import_new_dashboard_slice_reset_ownership(self):
|
||||
admin_user = security_manager.find_user(username="admin")
|
||||
self.assertTrue(admin_user)
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
self.assertTrue(gamma_user)
|
||||
g.user = gamma_user
|
||||
|
||||
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
|
||||
# set another user as an owner of importing dashboard
|
||||
dash_with_1_slice.created_by = admin_user
|
||||
dash_with_1_slice.changed_by = admin_user
|
||||
dash_with_1_slice.owners = [admin_user]
|
||||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
|
||||
def test_import_override_dashboard_slice_reset_ownership(self):
|
||||
admin_user = security_manager.find_user(username="admin")
|
||||
self.assertTrue(admin_user)
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
self.assertTrue(gamma_user)
|
||||
g.user = gamma_user
|
||||
|
||||
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
|
||||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
|
||||
# re-import with another user shouldn't change the permissions
|
||||
g.user = admin_user
|
||||
|
||||
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
|
||||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
|
||||
def _create_dashboard_for_import(self, id_=10100):
|
||||
slc = self.create_slice(
|
||||
"health_slc" + str(id_), id=id_ + 1, schema=get_example_default_schema()
|
||||
)
|
||||
dash_with_1_slice = self.create_dashboard(
|
||||
"dash_with_1_slice" + str(id_), slcs=[slc], id=id_ + 2
|
||||
)
|
||||
dash_with_1_slice.position_json = """
|
||||
{{"DASHBOARD_VERSION_KEY": "v2",
|
||||
"DASHBOARD_CHART_TYPE-{0}": {{
|
||||
"type": "CHART",
|
||||
"id": {0},
|
||||
"children": [],
|
||||
"meta": {{
|
||||
"width": 4,
|
||||
"height": 50,
|
||||
"chartId": {0}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
""".format(
|
||||
slc.id
|
||||
)
|
||||
return dash_with_1_slice
|
||||
|
||||
def test_import_table_no_metadata(self):
|
||||
schema = get_example_default_schema()
|
||||
db_id = get_example_database().id
|
||||
table = self.create_table("pure_table", id=10001, schema=schema)
|
||||
imported_id = import_dataset(table, db_id, import_time=1989)
|
||||
imported = self.get_table_by_id(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
||||
def test_import_table_1_col_1_met(self):
|
||||
schema = get_example_default_schema()
|
||||
table = self.create_table(
|
||||
"table_1_col_1_met",
|
||||
id=10002,
|
||||
cols_names=["col1"],
|
||||
metric_names=["metric1"],
|
||||
schema=schema,
|
||||
)
|
||||
db_id = get_example_database().id
|
||||
imported_id = import_dataset(table, db_id, import_time=1990)
|
||||
imported = self.get_table_by_id(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
self.assertEqual(
|
||||
{
|
||||
"remote_id": 10002,
|
||||
"import_time": 1990,
|
||||
"database_name": "examples",
|
||||
},
|
||||
json.loads(imported.params),
|
||||
)
|
||||
|
||||
def test_import_table_2_col_2_met(self):
|
||||
schema = get_example_default_schema()
|
||||
table = self.create_table(
|
||||
"table_2_col_2_met",
|
||||
id=10003,
|
||||
cols_names=["c1", "c2"],
|
||||
metric_names=["m1", "m2"],
|
||||
schema=schema,
|
||||
)
|
||||
db_id = get_example_database().id
|
||||
imported_id = import_dataset(table, db_id, import_time=1991)
|
||||
|
||||
imported = self.get_table_by_id(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
||||
def test_import_table_override(self):
|
||||
schema = get_example_default_schema()
|
||||
table = self.create_table(
|
||||
"table_override",
|
||||
id=10003,
|
||||
cols_names=["col1"],
|
||||
metric_names=["m1"],
|
||||
schema=schema,
|
||||
)
|
||||
db_id = get_example_database().id
|
||||
imported_id = import_dataset(table, db_id, import_time=1991)
|
||||
|
||||
table_over = self.create_table(
|
||||
"table_override",
|
||||
id=10003,
|
||||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
schema=schema,
|
||||
)
|
||||
imported_over_id = import_dataset(table_over, db_id, import_time=1992)
|
||||
|
||||
imported_over = self.get_table_by_id(imported_over_id)
|
||||
self.assertEqual(imported_id, imported_over.id)
|
||||
expected_table = self.create_table(
|
||||
"table_override",
|
||||
id=10003,
|
||||
metric_names=["new_metric1", "m1"],
|
||||
cols_names=["col1", "new_col1", "col2", "col3"],
|
||||
schema=schema,
|
||||
)
|
||||
self.assert_table_equals(expected_table, imported_over)
|
||||
|
||||
def test_import_table_override_identical(self):
|
||||
schema = get_example_default_schema()
|
||||
table = self.create_table(
|
||||
"copy_cat",
|
||||
id=10004,
|
||||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
schema=schema,
|
||||
)
|
||||
db_id = get_example_database().id
|
||||
imported_id = import_dataset(table, db_id, import_time=1993)
|
||||
|
||||
copy_table = self.create_table(
|
||||
"copy_cat",
|
||||
id=10004,
|
||||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
schema=schema,
|
||||
)
|
||||
imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)
|
||||
|
||||
self.assertEqual(imported_id, imported_id_copy)
|
||||
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@@ -22,9 +22,8 @@ from uuid import UUID
 
 import pandas as pd
 from celery.exceptions import SoftTimeLimitExceeded
-from sqlalchemy.orm import Session
 
-from superset import app, security_manager
+from superset import app, db, security_manager
 from superset.commands.base import BaseCommand
 from superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand
 from superset.commands.exceptions import CommandException

@@ -68,7 +67,6 @@ from superset.reports.notifications import create_notification
 from superset.reports.notifications.base import NotificationContent
 from superset.reports.notifications.exceptions import NotificationError
 from superset.tasks.utils import get_executor
-from superset.utils.celery import session_scope
 from superset.utils.core import HeaderDataType, override_user
 from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
 from superset.utils.decorators import logs_context

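For context, session_scope (removed above) is a context manager that hands worker code its own session and owns the commit/rollback/close lifecycle. A simplified sketch of that style of helper follows; the real implementation in superset.utils.celery may differ, and create_session is an illustrative factory, not an existing function.

from contextlib import contextmanager


@contextmanager
def session_scope(nullpool: bool):
    # Create a dedicated session (optionally backed by NullPool) for the caller.
    session = create_session(nullpool=nullpool)  # illustrative factory
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()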
@@ -85,12 +83,10 @@ class BaseReportState:
     @logs_context()
     def __init__(
         self,
-        session: Session,
         report_schedule: ReportSchedule,
         scheduled_dttm: datetime,
         execution_id: UUID,
     ) -> None:
-        self._session = session
         self._report_schedule = report_schedule
         self._scheduled_dttm = scheduled_dttm
         self._start_dttm = datetime.utcnow()

@@ -123,7 +119,7 @@ class BaseReportState:
 
         self._report_schedule.last_state = state
         self._report_schedule.last_eval_dttm = datetime.utcnow()
-        self._session.commit()
+        db.session.commit()
 
     def create_log(self, error_message: Optional[str] = None) -> None:
         """

@ -140,8 +136,8 @@ class BaseReportState:
|
|||
report_schedule=self._report_schedule,
|
||||
uuid=self._execution_id,
|
||||
)
|
||||
self._session.add(log)
|
||||
self._session.commit()
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
|
||||
def _get_url(
|
||||
self,
|
||||
|
|
@ -485,9 +481,7 @@ class BaseReportState:
|
|||
"""
|
||||
Checks if an alert is in it's grace period
|
||||
"""
|
||||
last_success = ReportScheduleDAO.find_last_success_log(
|
||||
self._report_schedule, session=self._session
|
||||
)
|
||||
last_success = ReportScheduleDAO.find_last_success_log(self._report_schedule)
|
||||
return (
|
||||
last_success is not None
|
||||
and self._report_schedule.grace_period
|
||||
|
|
@ -501,7 +495,7 @@ class BaseReportState:
|
|||
Checks if an alert/report on error is in it's notification grace period
|
||||
"""
|
||||
last_success = ReportScheduleDAO.find_last_error_notification(
|
||||
self._report_schedule, session=self._session
|
||||
self._report_schedule
|
||||
)
|
||||
if not last_success:
|
||||
return False
|
||||
|
|
@ -518,7 +512,7 @@ class BaseReportState:
|
|||
Checks if an alert is in a working timeout
|
||||
"""
|
||||
last_working = ReportScheduleDAO.find_last_entered_working_log(
|
||||
self._report_schedule, session=self._session
|
||||
self._report_schedule
|
||||
)
|
||||
if not last_working:
|
||||
return False
|
||||
|
|
@ -668,12 +662,10 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
task_uuid: UUID,
|
||||
report_schedule: ReportSchedule,
|
||||
scheduled_dttm: datetime,
|
||||
):
|
||||
self._session = session
|
||||
self._execution_id = task_uuid
|
||||
self._report_schedule = report_schedule
|
||||
self._scheduled_dttm = scheduled_dttm
|
||||
|
|
@ -684,7 +676,6 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
|
|||
self._report_schedule.last_state in state_cls.current_states
|
||||
):
|
||||
state_cls(
|
||||
self._session,
|
||||
self._report_schedule,
|
||||
self._scheduled_dttm,
|
||||
self._execution_id,
|
||||
|
|
@ -708,31 +699,30 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
|
|||
self._execution_id = UUID(task_id)
|
||||
|
||||
def run(self) -> None:
|
||||
with session_scope(nullpool=True) as session:
|
||||
try:
|
||||
self.validate(session=session)
|
||||
if not self._model:
|
||||
raise ReportScheduleExecuteUnexpectedError()
|
||||
_, username = get_executor(
|
||||
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
|
||||
model=self._model,
|
||||
try:
|
||||
self.validate()
|
||||
if not self._model:
|
||||
raise ReportScheduleExecuteUnexpectedError()
|
||||
_, username = get_executor(
|
||||
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
|
||||
model=self._model,
|
||||
)
|
||||
user = security_manager.find_user(username)
|
||||
with override_user(user):
|
||||
logger.info(
|
||||
"Running report schedule %s as user %s",
|
||||
self._execution_id,
|
||||
username,
|
||||
)
|
||||
user = security_manager.find_user(username)
|
||||
with override_user(user):
|
||||
logger.info(
|
||||
"Running report schedule %s as user %s",
|
||||
self._execution_id,
|
||||
username,
|
||||
)
|
||||
ReportScheduleStateMachine(
|
||||
session, self._execution_id, self._model, self._scheduled_dttm
|
||||
).run()
|
||||
except CommandException as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise ReportScheduleUnexpectedError(str(ex)) from ex
|
||||
ReportScheduleStateMachine(
|
||||
self._execution_id, self._model, self._scheduled_dttm
|
||||
).run()
|
||||
except CommandException as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise ReportScheduleUnexpectedError(str(ex)) from ex
|
||||
|
||||
def validate(self, session: Session = None) -> None:
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
logger.info(
|
||||
"session is validated: id %s, executionid: %s",
|
||||
|
|
@ -740,7 +730,7 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
|
|||
self._execution_id,
|
||||
)
|
||||
self._model = (
|
||||
session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
|
||||
db.session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
|
||||
)
|
||||
if not self._model:
|
||||
raise ReportScheduleNotFoundError()
|
||||
|
|
|
|||
|
|
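Dropping session_scope from run() above only works if the Celery worker already executes tasks inside a Flask application context, so that db.session resolves to the app's scoped session. That wiring is not part of this diff; a common pattern for it looks roughly like the sketch below, where flask_app, make_celery, and AppContextTask are illustrative names.

from celery import Celery, Task


def make_celery(flask_app) -> Celery:
    class AppContextTask(Task):
        # Run every task body inside an application context so extensions such
        # as Flask-SQLAlchemy (db.session) behave as they do in the web app.
        def __call__(self, *args, **kwargs):
            with flask_app.app_context():
                return self.run(*args, **kwargs)

    return Celery(flask_app.import_name, task_cls=AppContextTask)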
@ -17,12 +17,12 @@
|
|||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.report.exceptions import ReportSchedulePruneLogError
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.daos.report import ReportScheduleDAO
|
||||
from superset.reports.models import ReportSchedule
|
||||
from superset.utils.celery import session_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,28 +36,27 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
|
|||
self._worker_context = worker_context
|
||||
|
||||
def run(self) -> None:
|
||||
with session_scope(nullpool=True) as session:
|
||||
self.validate()
|
||||
prune_errors = []
|
||||
self.validate()
|
||||
prune_errors = []
|
||||
|
||||
for report_schedule in session.query(ReportSchedule).all():
|
||||
if report_schedule.log_retention is not None:
|
||||
from_date = datetime.utcnow() - timedelta(
|
||||
days=report_schedule.log_retention
|
||||
for report_schedule in db.session.query(ReportSchedule).all():
|
||||
if report_schedule.log_retention is not None:
|
||||
from_date = datetime.utcnow() - timedelta(
|
||||
days=report_schedule.log_retention
|
||||
)
|
||||
try:
|
||||
row_count = ReportScheduleDAO.bulk_delete_logs(
|
||||
report_schedule, from_date, commit=False
|
||||
)
|
||||
try:
|
||||
row_count = ReportScheduleDAO.bulk_delete_logs(
|
||||
report_schedule, from_date, session=session, commit=False
|
||||
)
|
||||
logger.info(
|
||||
"Deleted %s logs for report schedule id: %s",
|
||||
str(row_count),
|
||||
str(report_schedule.id),
|
||||
)
|
||||
except DAODeleteFailedError as ex:
|
||||
prune_errors.append(str(ex))
|
||||
if prune_errors:
|
||||
raise ReportSchedulePruneLogError(";".join(prune_errors))
|
||||
logger.info(
|
||||
"Deleted %s logs for report schedule id: %s",
|
||||
str(row_count),
|
||||
str(report_schedule.id),
|
||||
)
|
||||
except DAODeleteFailedError as ex:
|
||||
prune_errors.append(str(ex))
|
||||
if prune_errors:
|
||||
raise ReportSchedulePruneLogError(";".join(prune_errors))
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from datetime import datetime
|
|||
from typing import Any
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
|
|
@@ -204,27 +203,25 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
         return super().update(item, attributes, commit)
 
     @staticmethod
-    def find_active(session: Session | None = None) -> list[ReportSchedule]:
+    def find_active() -> list[ReportSchedule]:
         """
-        Find all active reports. If session is passed it will be used instead of the
-        default `db.session`, this is useful when on a celery worker session context
+        Find all active reports.
         """
-        session = session or db.session
         return (
-            session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
+            db.session.query(ReportSchedule)
+            .filter(ReportSchedule.active.is_(True))
+            .all()
         )
 
     @staticmethod
     def find_last_success_log(
         report_schedule: ReportSchedule,
-        session: Session | None = None,
     ) -> ReportExecutionLog | None:
         """
         Finds last success execution log for a given report
         """
-        session = session or db.session
         return (
-            session.query(ReportExecutionLog)
+            db.session.query(ReportExecutionLog)
             .filter(
                 ReportExecutionLog.state == ReportState.SUCCESS,
                 ReportExecutionLog.report_schedule == report_schedule,

@ -236,14 +233,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
|||
@staticmethod
|
||||
def find_last_entered_working_log(
|
||||
report_schedule: ReportSchedule,
|
||||
session: Session | None = None,
|
||||
) -> ReportExecutionLog | None:
|
||||
"""
|
||||
Finds last success execution log for a given report
|
||||
"""
|
||||
session = session or db.session
|
||||
return (
|
||||
session.query(ReportExecutionLog)
|
||||
db.session.query(ReportExecutionLog)
|
||||
.filter(
|
||||
ReportExecutionLog.state == ReportState.WORKING,
|
||||
ReportExecutionLog.report_schedule == report_schedule,
|
||||
|
|
@ -256,14 +251,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
|||
@staticmethod
|
||||
def find_last_error_notification(
|
||||
report_schedule: ReportSchedule,
|
||||
session: Session | None = None,
|
||||
) -> ReportExecutionLog | None:
|
||||
"""
|
||||
Finds last error email sent
|
||||
"""
|
||||
session = session or db.session
|
||||
last_error_email_log = (
|
||||
session.query(ReportExecutionLog)
|
||||
db.session.query(ReportExecutionLog)
|
||||
.filter(
|
||||
ReportExecutionLog.error_message
|
||||
== REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
|
||||
|
|
@ -276,7 +269,7 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
|||
return None
|
||||
# Checks that only errors have occurred since the last email
|
||||
report_from_last_email = (
|
||||
session.query(ReportExecutionLog)
|
||||
db.session.query(ReportExecutionLog)
|
||||
.filter(
|
||||
ReportExecutionLog.state.notin_(
|
||||
[ReportState.ERROR, ReportState.WORKING]
|
||||
|
|
@ -293,13 +286,11 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
|||
def bulk_delete_logs(
|
||||
model: ReportSchedule,
|
||||
from_date: datetime,
|
||||
session: Session | None = None,
|
||||
commit: bool = True,
|
||||
) -> int | None:
|
||||
session = session or db.session
|
||||
try:
|
||||
row_count = (
|
||||
session.query(ReportExecutionLog)
|
||||
db.session.query(ReportExecutionLog)
|
||||
.filter(
|
||||
ReportExecutionLog.report_schedule == model,
|
||||
ReportExecutionLog.end_dttm < from_date,
|
||||
|
|
@ -307,8 +298,8 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
|||
.delete(synchronize_session="fetch")
|
||||
)
|
||||
if commit:
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
return row_count
|
||||
except SQLAlchemyError as ex:
|
||||
session.rollback()
|
||||
db.session.rollback()
|
||||
raise DAODeleteFailedError(str(ex)) from ex
|
||||
|
|
|
|||
|
|
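With the optional session keyword gone from the DAO, callers read and write through the ambient db.session; commit=False defers the commit to the caller, as the prune-log command above now does. A small usage sketch built from the signatures in this file; the surrounding script body and the 30-day cutoff are illustrative.

from datetime import datetime, timedelta

from superset import db
from superset.daos.report import ReportScheduleDAO

cutoff = datetime.utcnow() - timedelta(days=30)  # illustrative retention window
for schedule in ReportScheduleDAO.find_active():
    # Delete old execution logs per schedule without committing each time...
    ReportScheduleDAO.bulk_delete_logs(schedule, cutoff, commit=False)
# ...then persist everything in one transaction on the shared session.
db.session.commit()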
@ -50,7 +50,6 @@ from sqlalchemy.engine.interfaces import Compiled, Dialect
|
|||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import literal_column, quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
|
@ -1071,7 +1070,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Any, query: Query) -> None:
|
||||
"""Handle a live cursor between the execute and fetchall calls
|
||||
|
||||
The flow works without this method doing anything, but it allows
|
||||
|
|
@ -1080,9 +1079,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
# TODO: Fix circular import error caused by importing sql_lab.Query
|
||||
|
||||
@classmethod
|
||||
def execute_with_cursor(
|
||||
cls, cursor: Any, sql: str, query: Query, session: Session
|
||||
) -> None:
|
||||
def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
|
||||
"""
|
||||
Trigger execution of a query and handle the resulting cursor.
|
||||
|
||||
|
|
@ -1095,7 +1092,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
logger.debug("Query %d: Running query: %s", query.id, sql)
|
||||
cls.execute(cursor, sql, async_=True)
|
||||
logger.debug("Query %d: Handling cursor", query.id)
|
||||
cls.handle_cursor(cursor, query, session)
|
||||
cls.handle_cursor(cursor, query)
|
||||
|
||||
@classmethod
|
||||
def extract_error_message(cls, ex: Exception) -> str:
|
||||
|
|
@ -1841,7 +1838,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
# pylint: disable=unused-argument
|
||||
@classmethod
|
||||
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
|
||||
def prepare_cancel_query(cls, query: Query) -> None:
|
||||
"""
|
||||
Some databases may acquire the query cancelation id after the query
|
||||
cancelation request has been received. For those cases, the db engine spec
|
||||
|
|
|
|||
|
|
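Third-party or custom engine specs that override these hooks now implement the narrower signatures and use db.session for any persistence. A hedged sketch of what an override looks like after this change; ExampleEngineSpec and its behaviour are illustrative, not part of the commit.

from typing import Any

from superset import db
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.sql_lab import Query


class ExampleEngineSpec(BaseEngineSpec):
    engine = "example"

    @classmethod
    def handle_cursor(cls, cursor: Any, query: Query) -> None:
        # Progress updates are persisted through the shared Flask-SQLAlchemy
        # session rather than a Session argument threaded in by the caller.
        query.progress = 100
        db.session.commit()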
@ -34,9 +34,9 @@ from sqlalchemy import Column, text, types
|
|||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.constants import TimeGrain
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
|
@ -334,7 +334,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def handle_cursor( # pylint: disable=too-many-locals
|
||||
cls, cursor: Any, query: Query, session: Session
|
||||
cls, cursor: Any, query: Query
|
||||
) -> None:
|
||||
"""Updates progress information"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
|
@ -353,8 +353,8 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
# Queries don't terminate when user clicks the STOP button on SQL LAB.
|
||||
# Refresh session so that the `query.status` modified in stop_query in
|
||||
# views/core.py is reflected here.
|
||||
session.refresh(query)
|
||||
query = session.query(type(query)).filter_by(id=query_id).one()
|
||||
db.session.refresh(query)
|
||||
query = db.session.query(type(query)).filter_by(id=query_id).one()
|
||||
if query.status == QueryStatus.STOPPED:
|
||||
cursor.cancel()
|
||||
break
|
||||
|
|
@ -396,7 +396,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
|
||||
last_log_line = len(log_lines)
|
||||
if needs_commit:
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
|
||||
logger.warning(
|
||||
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead"
|
||||
|
|
|
|||
|
|
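The Hive handler above (and the Impala and Presto ones that follow) share one polling idea: while the engine still reports the query as running, re-read the Query row through db.session so a stop requested from the web app in another process becomes visible to the worker. A condensed, illustrative sketch of that loop; the engine-specific progress handling is omitted.

import time

from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.sql_lab import Query


def wait_for_stop(query: Query, poll_interval: float = 5.0) -> bool:
    # Illustrative condensation of the handle_cursor polling loops.
    while True:
        # Re-read the row so a status change made by the web app is visible here.
        db.session.refresh(query)
        query = db.session.query(Query).filter_by(id=query.id).one()
        if query.status == QueryStatus.STOPPED:
            return True  # caller should cancel the cursor and stop polling
        # (the real loops also break once the cursor reports completion)
        time.sleep(poll_interval)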
@ -23,8 +23,8 @@ from typing import Any, Optional
|
|||
from flask import current_app
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db
|
||||
from superset.constants import QUERY_EARLY_CANCEL_KEY, TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.models.sql_lab import Query
|
||||
|
|
@ -101,7 +101,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
|||
raise cls.get_dbapi_mapped_exception(ex)
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Any, query: Query) -> None:
|
||||
"""Stop query and updates progress information"""
|
||||
|
||||
query_id = query.id
|
||||
|
|
@ -113,8 +113,8 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
|||
try:
|
||||
status = cursor.status()
|
||||
while status in unfinished_states:
|
||||
session.refresh(query)
|
||||
query = session.query(Query).filter_by(id=query_id).one()
|
||||
db.session.refresh(query)
|
||||
query = db.session.query(Query).filter_by(id=query_id).one()
|
||||
# if query cancelation was requested prior to the handle_cursor call, but
|
||||
# the query was still executed
|
||||
# modified in stop_query in views / core.py is reflected here.
|
||||
|
|
@ -145,7 +145,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
|||
needs_commit = True
|
||||
|
||||
if needs_commit:
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
|
||||
cls.engine, 5
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from typing import Any, Callable, List, NamedTuple, Optional
|
|||
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed
|
||||
# Ensure pyocient inherits Superset's logging level
|
||||
|
|
@ -372,13 +371,13 @@ class OcientEngineSpec(BaseEngineSpec):
|
|||
return "DUMMY_VALUE"
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Any, query: Query) -> None:
|
||||
with OcientEngineSpec.query_id_mapping_lock:
|
||||
OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id
|
||||
|
||||
# Add the query id to the cursor
|
||||
setattr(cursor, "superset_query_id", query.id)
|
||||
return super().handle_cursor(cursor, query, session)
|
||||
return super().handle_cursor(cursor, query)
|
||||
|
||||
@classmethod
|
||||
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -39,10 +39,9 @@ from sqlalchemy.engine.base import Engine
|
|||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.result import Row as ResultRow
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
|
||||
from superset import cache_manager, is_feature_enabled
|
||||
from superset import cache_manager, db, is_feature_enabled
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.constants import TimeGrain
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
|
@ -1288,11 +1287,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query) -> None:
|
||||
"""Updates progress information"""
|
||||
if tracking_url := cls.get_tracking_url(cursor):
|
||||
query.tracking_url = tracking_url
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
query_id = query.id
|
||||
poll_interval = query.database.connect_args.get(
|
||||
|
|
@ -1308,7 +1307,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
# Update the object and wait for the kill signal.
|
||||
stats = polled.get("stats", {})
|
||||
|
||||
query = session.query(type(query)).filter_by(id=query_id).one()
|
||||
query = db.session.query(type(query)).filter_by(id=query_id).one()
|
||||
if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
|
||||
cursor.cancel()
|
||||
break
|
||||
|
|
@ -1332,7 +1331,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
)
|
||||
if progress > query.progress:
|
||||
query.progress = progress
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
time.sleep(poll_interval)
|
||||
logger.info("Query %i: Polling the cursor for progress", query_id)
|
||||
polled = cursor.poll()
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ from flask import current_app
|
|||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db
|
||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
|
@ -155,7 +155,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query) -> None:
|
||||
"""
|
||||
Handle a trino client cursor.
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
if tracking_url := cls.get_tracking_url(cursor):
|
||||
query.tracking_url = tracking_url
|
||||
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
# if query cancelation was requested prior to the handle_cursor call, but
|
||||
# the query was still executed, trigger the actual query cancelation now
|
||||
|
|
@ -183,12 +183,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
cancel_query_id=cancel_query_id,
|
||||
)
|
||||
|
||||
super().handle_cursor(cursor=cursor, query=query, session=session)
|
||||
super().handle_cursor(cursor=cursor, query=query)
|
||||
|
||||
@classmethod
|
||||
def execute_with_cursor(
|
||||
cls, cursor: Cursor, sql: str, query: Query, session: Session
|
||||
) -> None:
|
||||
def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None:
|
||||
"""
|
||||
Trigger execution of a query and handle the resulting cursor.
|
||||
|
||||
|
|
@ -225,7 +223,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
time.sleep(0.1)
|
||||
|
||||
logger.debug("Query %d: Handling cursor", query_id)
|
||||
cls.handle_cursor(cursor, query, session)
|
||||
cls.handle_cursor(cursor, query)
|
||||
|
||||
# Block until the query completes; same behaviour as the client itself
|
||||
logger.debug("Query %d: Waiting for query to complete", query_id)
|
||||
|
|
@ -237,10 +235,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
raise err
|
||||
|
||||
@classmethod
|
||||
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
|
||||
def prepare_cancel_query(cls, query: Query) -> None:
|
||||
if QUERY_CANCEL_KEY not in query.extra:
|
||||
query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True)
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -25,10 +25,8 @@ from typing import Any, cast, Optional, Union
|
|||
import backoff
|
||||
import msgpack
|
||||
import simplejson as json
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import (
|
||||
app,
|
||||
|
|
@ -56,7 +54,6 @@ from superset.sql_parse import (
|
|||
)
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.sqllab.utils import write_ipc_buffer
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.core import (
|
||||
json_iso_dttm_ser,
|
||||
override_user,
|
||||
|
|
@ -92,7 +89,6 @@ class SqlLabQueryStoppedException(SqlLabException):
|
|||
def handle_query_error(
|
||||
ex: Exception,
|
||||
query: Query,
|
||||
session: Session,
|
||||
payload: Optional[dict[str, Any]] = None,
|
||||
prefix_message: str = "",
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -120,7 +116,7 @@ def handle_query_error(
|
|||
if errors:
|
||||
query.set_extra_json_key("errors", errors_payload)
|
||||
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
payload.update({"status": query.status, "error": msg, "errors": errors_payload})
|
||||
if troubleshooting_link := config["TROUBLESHOOTING_LINK"]:
|
||||
payload["link"] = troubleshooting_link
|
||||
|
|
@ -150,22 +146,20 @@ def get_query_giveup_handler(_: Any) -> None:
|
|||
on_giveup=get_query_giveup_handler,
|
||||
max_tries=5,
|
||||
)
|
||||
def get_query(query_id: int, session: Session) -> Query:
|
||||
def get_query(query_id: int) -> Query:
|
||||
"""attempts to get the query and retry if it cannot"""
|
||||
try:
|
||||
return session.query(Query).filter_by(id=query_id).one()
|
||||
return db.session.query(Query).filter_by(id=query_id).one()
|
||||
except Exception as ex:
|
||||
raise SqlLabException("Failed at getting query") from ex
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="sql_lab.get_sql_results",
|
||||
bind=True,
|
||||
time_limit=SQLLAB_HARD_TIMEOUT,
|
||||
soft_time_limit=SQLLAB_TIMEOUT,
)
def get_sql_results( # pylint: disable=too-many-arguments
ctask: Task,
query_id: int,
rendered_query: str,
return_results: bool = True,

@@ -176,30 +170,27 @@ def get_sql_results( # pylint: disable=too-many-arguments
log_params: Optional[dict[str, Any]] = None,
) -> Optional[dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id)
return handle_query_error(ex, query)


def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-locals
def execute_sql_statement(
sql_statement: str,
query: Query,
session: Session,
cursor: Any,
log_params: Optional[dict[str, Any]],
apply_ctas: bool = False,

@@ -284,9 +275,9 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
security_manager,
log_params,
)
session.commit()
db.session.commit()
with stats_timing("sqllab.query.time_executing_query", stats_logger):
db_engine_spec.execute_with_cursor(cursor, sql, query, session)
db_engine_spec.execute_with_cursor(cursor, sql, query)

with stats_timing("sqllab.query.time_fetching_results", stats_logger):
logger.debug(

@@ -319,7 +310,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
except Exception as ex:
# query is stopped in another thread/worker
# stopping raises expected exceptions which we should skip
session.refresh(query)
db.session.refresh(query)
if query.status == QueryStatus.STOPPED:
raise SqlLabQueryStoppedException() from ex


@@ -393,7 +384,6 @@ def execute_sql_statements(
rendered_query: str,
return_results: bool,
store_results: bool,
session: Session,
start_time: Optional[float],
expand_data: bool,
log_params: Optional[dict[str, Any]],

@@ -403,7 +393,7 @@ def execute_sql_statements(
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

query = get_query(query_id, session)
query = get_query(query_id)
payload: dict[str, Any] = {"query_id": query_id}
database = query.database
db_engine_spec = database.db_engine_spec

@@ -432,7 +422,7 @@ def execute_sql_statements(
logger.info("Query %s: Set query to 'running'", str(query_id))
query.status = QueryStatus.RUNNING
query.start_running_time = now_as_float()
session.commit()
db.session.commit()

# Should we create a table or view from the select?
if (

@@ -476,11 +466,11 @@ def execute_sql_statements(
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
session.commit()
db.session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
db.session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload

@@ -497,12 +487,11 @@ def execute_sql_statements(
)
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
db.session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,

@@ -521,9 +510,7 @@ def execute_sql_statements(
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
payload = handle_query_error(ex, query, payload, prefix_message)
return payload

# Commit the connection so CTA queries will create the table and any DML.

@@ -593,7 +580,7 @@ def execute_sql_statements(
query.results_key = key

query.status = QueryStatus.SUCCESS
session.commit()
db.session.commit()

if return_results:
# since we're returning results we need to create non-arrow data

@@ -634,7 +621,7 @@ def cancel_query(query: Query) -> bool:
return True

# Some databases may need to make preparations for query cancellation
query.database.db_engine_spec.prepare_cancel_query(query, db.session)
query.database.db_engine_spec.prepare_cancel_query(query)

if query.extra.get(QUERY_EARLY_CANCEL_KEY):
# Query has been cancelled prior to being able to set the cancel key.
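
Note: the hunks above drop the Session object that SQL Lab used to thread through get_sql_results and execute_sql_statements; query state changes now go through the Flask-SQLAlchemy db.session. A minimal sketch of the resulting pattern (the helper name run_query_status_update and its body are illustrative, not code from this commit):

from superset import db
from superset.models.sql_lab import Query


def run_query_status_update(query_id: int) -> None:
    """Hypothetical helper showing the post-refactor session usage."""
    # Load the row on the shared Flask-SQLAlchemy session instead of on a
    # Session argument passed down the call chain.
    query = db.session.query(Query).get(query_id)
    query.status = "running"  # Superset itself assigns QueryStatus.RUNNING
    db.session.commit()

    # Re-read the row later to pick up changes made by other workers,
    # e.g. a stop request, mirroring the db.session.refresh(query) calls above.
    db.session.refresh(query)
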
@@ -94,14 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dummy"

def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()

try:
charts = session.query(Slice).all()
finally:
session.close()

return [get_payload(chart) for chart in charts]
return [get_payload(chart) for chart in db.session.query(Slice).all()]


class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods

@@ -130,28 +123,24 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
self.since = parse_human_datetime(since) if since else None

def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()
records = (
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)

try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads
return [
get_payload(chart, dashboard)
for dashboard in dashboards
for chart in dashboard.slices
]


class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods

@@ -178,48 +167,44 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods

def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
# add dashboards that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = db.session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))
finally:
session.close()

# add charts that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))

return payloads
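
Note: with db.create_scoped_session() gone, the warm-up strategies above read straight from db.session and no longer need the try/finally close() bookkeeping. A custom strategy following the same pattern might look like the sketch below; the class is hypothetical, and it assumes Strategy and get_payload are importable from Superset's cache warm-up module and that Slice exposes a changed_on column:

from superset import db
from superset.models.slice import Slice
from superset.tasks.cache import Strategy, get_payload


class RecentChartsStrategy(Strategy):  # hypothetical example, not part of this commit
    """Warm up the charts that changed most recently."""

    name = "recent_charts"

    def get_payloads(self) -> list[dict[str, int]]:
        # Query the shared Flask-SQLAlchemy session directly; the worker's
        # task_postrun handler removes the session when the task finishes.
        charts = (
            db.session.query(Slice)
            .order_by(Slice.changed_on.desc())
            .limit(100)
            .all()
        )
        return [get_payload(chart) for chart in charts]
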
@@ -21,7 +21,7 @@ it needs to call create_app() in order to initialize things properly
"""
from typing import Any

from celery.signals import worker_process_init
from celery.signals import task_postrun, worker_process_init

# Superset framework imports
from superset import create_app

@@ -43,3 +43,27 @@ def reset_db_connection_pool(**kwargs: Any) -> None: # pylint: disable=unused-a
with flask_app.app_context():
# https://docs.sqlalchemy.org/en/14/core/connections.html#engine-disposal
db.engine.dispose()


@task_postrun.connect
def teardown( # pylint: disable=unused-argument
retval: Any,
*args: Any,
**kwargs: Any,
) -> None:
"""
After each Celery task, tear down the Flask-SQLAlchemy session.

Note that for non-eager requests Flask-SQLAlchemy will perform the teardown.

:param retval: The return value of the task
:see: https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun
:see: https://gist.github.com/twolfson/a1b329e9353f9b575131
"""

if flask_app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN"):
if not isinstance(retval, Exception):
db.session.commit()

if not flask_app.config.get("CELERY_ALWAYS_EAGER"):
db.session.remove()
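
Note: the task_postrun handler added above mirrors what Flask-SQLAlchemy does at the end of a web request: commit when SQLALCHEMY_COMMIT_ON_TEARDOWN is set and the task did not raise, then remove the scoped session so the next task starts clean (skipped under CELERY_ALWAYS_EAGER, where the enclosing request teardown already handles it). The same wiring in isolation, as a sketch assuming a Flask app object flask_app and a Flask-SQLAlchemy instance db defined in a hypothetical myproject.app module:

from typing import Any

from celery.signals import task_postrun

from myproject.app import db, flask_app  # hypothetical module layout


@task_postrun.connect
def cleanup_session(retval: Any, *args: Any, **kwargs: Any) -> None:
    """Reset the Flask-SQLAlchemy scoped session after every Celery task."""
    if flask_app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN") and not isinstance(
        retval, Exception
    ):
        db.session.commit()
    if not flask_app.config.get("CELERY_ALWAYS_EAGER"):
        db.session.remove()
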
@@ -29,7 +29,6 @@ from superset.daos.report import ReportScheduleDAO
from superset.extensions import celery_app
from superset.stats_logger import BaseStatsLogger
from superset.tasks.cron_util import cron_schedule_window
from superset.utils.celery import session_scope
from superset.utils.core import LoggerLevel
from superset.utils.log import get_logger_from_status


@@ -46,35 +45,32 @@ def scheduler() -> None:

if not is_feature_enabled("ALERT_REPORTS"):
return
with session_scope(nullpool=True) as session:
active_schedules = ReportScheduleDAO.find_active(session)
triggered_at = (
datetime.fromisoformat(scheduler.request.expires)
- app.config["CELERY_BEAT_SCHEDULER_EXPIRES"]
if scheduler.request.expires
else datetime.utcnow()
)
for active_schedule in active_schedules:
for schedule in cron_schedule_window(
triggered_at, active_schedule.crontab, active_schedule.timezone
active_schedules = ReportScheduleDAO.find_active()
triggered_at = (
datetime.fromisoformat(scheduler.request.expires)
- app.config["CELERY_BEAT_SCHEDULER_EXPIRES"]
if scheduler.request.expires
else datetime.utcnow()
)
for active_schedule in active_schedules:
for schedule in cron_schedule_window(
triggered_at, active_schedule.crontab, active_schedule.timezone
):
logger.info("Scheduling alert %s eta: %s", active_schedule.name, schedule)
async_options = {"eta": schedule}
if (
active_schedule.working_timeout is not None
and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"]
):
logger.info(
"Scheduling alert %s eta: %s", active_schedule.name, schedule
async_options["time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"]
)
async_options = {"eta": schedule}
if (
active_schedule.working_timeout is not None
and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"]
):
async_options["time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"]
)
async_options["soft_time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"]
)
execute.apply_async((active_schedule.id,), **async_options)
async_options["soft_time_limit"] = (
active_schedule.working_timeout
+ app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"]
)
execute.apply_async((active_schedule.id,), **async_options)


@celery_app.task(name="reports.execute", bind=True)
@@ -1,59 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool

from superset import app, db

logger = logging.getLogger(__name__)


# Null pool is used for the celery workers due process forking side effects.
# For more info see: https://github.com/apache/superset/issues/10530
@contextmanager
def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
logger.warning(
"SQLite Database support for metadata databases will be removed \
in a future version of Superset."
)
if nullpool:
engine = create_engine(database_uri, poolclass=NullPool)
session_class = sessionmaker()
session_class.configure(bind=engine)
session = session_class()
else:
session = db.session()
session.commit() # HACK

try:
yield session
session.commit()
except SQLAlchemyError as ex:
session.rollback()
logger.exception(ex)
raise
finally:
session.close()
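
Note: the deleted session_scope helper above was where worker code built its own engine and session (NullPool to survive process forking). After this commit, callers such as the alert/report scheduler lean on the Flask-SQLAlchemy session that already exists in the worker, with cleanup handled by the task_postrun hook. Roughly, the migration looks like this sketch (the "before" lines are shown as comments; only the argument-free find_active() reflects the new signature in this diff):

from superset.daos.report import ReportScheduleDAO

# Before: each task opened and closed its own NullPool-backed session.
#     with session_scope(nullpool=True) as session:
#         active_schedules = ReportScheduleDAO.find_active(session)
#
# After: the DAO reads from the shared Flask-SQLAlchemy session, and the
# task_postrun handler removes that session once the task finishes.
active_schedules = ReportScheduleDAO.find_active()
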
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset import app, db
from superset import app
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query

@@ -29,7 +29,6 @@ def test_non_async_execute(non_async_example_db: Database, example_query: Query)
"select 1 as foo;",
store_results=False,
return_results=True,
session=db.session,
start_time=now_as_float(),
expand_data=True,
log_params=dict(),
@@ -575,16 +575,22 @@ class TestSqlLab(SupersetTestCase):
)
assert data["errors"][0]["error_type"] == "GENERIC_BACKEND_ERROR"

@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query):
def test_execute_sql_statements(
self,
mock_execute_sql_statement,
mock_get_query,
mock_db,
):
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_db = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()

@@ -599,7 +605,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,

@@ -609,7 +614,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
mock_session,
mock_cursor,
None,
False,

@@ -617,7 +621,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
mock_session,
mock_cursor,
None,
False,

@@ -637,7 +640,6 @@ class TestSqlLab(SupersetTestCase):
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = True
mock_cursor = mock.MagicMock()

@@ -653,7 +655,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,

@@ -676,10 +677,14 @@ class TestSqlLab(SupersetTestCase):
},
)

@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_execute_sql_statements_ctas(
self, mock_execute_sql_statement, mock_get_query
self,
mock_execute_sql_statement,
mock_get_query,
mock_db,
):
sql = """
-- comment

@@ -687,7 +692,7 @@ class TestSqlLab(SupersetTestCase):
SELECT @value AS foo;
-- comment
"""
mock_session = mock.MagicMock()
mock_db = mock.MagicMock()
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()

@@ -706,7 +711,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,

@@ -716,7 +720,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
mock_session,
mock_cursor,
None,
False,

@@ -724,7 +727,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
mock_session,
mock_cursor,
None,
True, # apply_ctas

@@ -740,7 +742,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,

@@ -773,7 +774,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
session=mock_session,
start_time=None,
expand_data=False,
log_params=None,
@@ -352,9 +352,8 @@ def test_prepare_cancel_query(
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query

session_mock = mocker.MagicMock()
query = Query(extra_json=json.dumps(initial_extra))
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
TrinoEngineSpec.prepare_cancel_query(query=query)
assert query.extra == final_extra


@@ -374,14 +373,13 @@ def test_handle_cursor_early_cancel(

cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.query_id = query_id
session_mock = mocker.MagicMock()

query = Query()

if cancel_early:
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
TrinoEngineSpec.prepare_cancel_query(query=query)

TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

if cancel_early:
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id

@@ -399,7 +397,6 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
mock_session = mocker.MagicMock()

def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

@@ -410,7 +407,6 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
session=mock_session,
)

mock_query.set_extra_json_key.assert_called_once_with(
@@ -20,6 +20,7 @@ import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session

from superset import db
from superset.utils.core import override_user


@@ -41,14 +42,12 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]

session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")

execute_sql_statement(
sql_statement,
query,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,

@@ -56,7 +55,7 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:

database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", query, session
cursor, "SELECT 42 AS answer LIMIT 2", query
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)

@@ -83,7 +82,6 @@ def test_execute_sql_statement_with_rls(
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]

session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
mocker.patch(

@@ -95,7 +93,6 @@ def test_execute_sql_statement_with_rls(
execute_sql_statement(
sql_statement,
query,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,

@@ -107,7 +104,7 @@ def test_execute_sql_statement_with_rls(
force=True,
)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)

@@ -162,7 +159,6 @@ def test_sql_lab_insert_rls_as_subquery(
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,

@@ -198,7 +194,6 @@ def test_sql_lab_insert_rls_as_subquery(
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,