refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (#26200)

parent 80a6e25a98
commit df79522160
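
The pattern this commit applies everywhere below: code that used to create a session by hand (`db.session()`), or accept a `Session`/`sessionmaker` argument, now reaches Flask-SQLAlchemy's scoped `db.session` directly, and the redundant parameters are dropped from the signatures. A minimal sketch of why the explicit argument is unnecessary, assuming a standard Flask-SQLAlchemy 3.x setup (the `User` model and in-memory SQLite URI are illustrative, not Superset code):

```python
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
db = SQLAlchemy(app)


class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50))


def count_users() -> int:
    # No session parameter: db.session is a scoped session, so every caller
    # in the same app/request context resolves to the same Session object.
    return db.session.query(User).count()


with app.app_context():
    db.create_all()
    db.session.add(User(name="alice"))
    db.session.commit()
    assert count_users() == 1
```
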
@@ -142,8 +142,6 @@ def main(
     filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
 ) -> None:
     auto_cleanup = not no_auto_cleanup
-    session = db.session()
-
     print(f"Importing migration script: {filepath}")
     module = import_migration_script(Path(filepath))

@@ -174,10 +172,9 @@ def main(
     models = find_models(module)
     model_rows: dict[type[Model], int] = {}
     for model in models:
-        rows = session.query(model).count()
+        rows = db.session.query(model).count()
         print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
         model_rows[model] = rows
-    session.close()

     print("Benchmarking migration")
     results: dict[str, float] = {}

@@ -199,16 +196,16 @@ def main(
                 print(f"- Adding {missing} entities to the {model.__name__} model")
                 bar = ChargingBar("Processing", max=missing)
                 try:
-                    for entity in add_sample_rows(session, model, missing):
+                    for entity in add_sample_rows(model, missing):
                         entities.append(entity)
                         bar.next()
                 except Exception:
-                    session.rollback()
+                    db.session.rollback()
                     raise
                 bar.finish()
                 model_rows[model] = min_entities
-                session.add_all(entities)
-                session.commit()
+                db.session.add_all(entities)
+                db.session.commit()

                 if auto_cleanup:
                     new_models[model].extend(entities)

@@ -227,10 +224,10 @@ def main(
         print("Cleaning up DB")
         # delete in reverse order of creation to handle relationships
         for model, entities in list(new_models.items())[::-1]:
-            session.query(model).filter(
+            db.session.query(model).filter(
                 model.id.in_(entity.id for entity in entities)
             ).delete(synchronize_session=False)
-            session.commit()
+            db.session.commit()

     if current_revision != revision and not force:
         click.confirm(f"\nRevert DB to {revision}?", abort=True)
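
A note on the first hunk above: `db.session` is a `scoped_session` registry, so calling it (`db.session()`) and using it as a proxy are two routes to the same underlying `Session`; binding the call result to a local added nothing. A standalone sketch in plain SQLAlchemy:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")
registry = scoped_session(sessionmaker(bind=engine))

# Calling the registry returns the scope's single Session...
assert registry() is registry()
# ...and attribute access on the registry proxies to that same Session.
assert registry.execute(text("select 1")).scalar() == 1
registry.remove()  # end of scope
```
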
@@ -84,7 +84,6 @@ class CacheRestApi(BaseSupersetModelRestApi):
         datasource_uids = set(datasources.get("datasource_uids", []))
         for ds in datasources.get("datasources", []):
             ds_obj = SqlaTable.get_datasource_by_name(
-                session=db.session,
                 datasource_name=ds.get("datasource_name"),
                 schema=ds.get("schema"),
                 database_name=ds.get("database_name"),
@@ -22,7 +22,7 @@ from datetime import datetime
 from typing import Any, Optional

 from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import make_transient, Session
+from sqlalchemy.orm import make_transient

 from superset import db
 from superset.commands.base import BaseCommand

@@ -55,7 +55,6 @@ def import_chart(
     :returns: The resulting id for the imported slice
     :rtype: int
     """
-    session = db.session
     make_transient(slc_to_import)
     slc_to_import.dashboards = []
     slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)

@@ -64,7 +63,6 @@ def import_chart(
     slc_to_import.reset_ownership()
     params = slc_to_import.params_dict
     datasource = SqlaTable.get_datasource_by_name(
-        session=session,
         datasource_name=params["datasource_name"],
         database_name=params["database_name"],
         schema=params["schema"],

@@ -72,11 +70,11 @@ def import_chart(
     slc_to_import.datasource_id = datasource.id  # type: ignore
     if slc_to_override:
         slc_to_override.override(slc_to_import)
-        session.flush()
+        db.session.flush()
         return slc_to_override.id
-    session.add(slc_to_import)
+    db.session.add(slc_to_import)
     logger.info("Final slice: %s", str(slc_to_import.to_json()))
-    session.flush()
+    db.session.flush()
     return slc_to_import.id
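
`import_chart` keeps its flush-before-reading-the-id pattern; only the session reference changes. A toy sketch (generic model, not Superset's `Slice`) of why `flush()` is enough to obtain the autoincrement id before any commit:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Chart(Base):  # stand-in for Superset's Slice model
    __tablename__ = "chart"
    id = Column(Integer, primary_key=True)
    name = Column(String(100))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    chart = Chart(name="example")
    session.add(chart)
    assert chart.id is None      # pending: no INSERT has been emitted yet
    session.flush()              # emits the INSERT, populating the id
    assert chart.id is not None  # readable before the transaction commits
    session.commit()
```
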
@@ -156,7 +154,6 @@ def import_dashboard(
     dashboard.json_metadata = json.dumps(json_metadata)

     logger.info("Started import of the dashboard: %s", dashboard_to_import.to_json())
-    session = db.session
     logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
     # copy slices object as Slice.import_slice will mutate the slice
     # and will remove the existing dashboard - slice association

@@ -173,7 +170,7 @@ def import_dashboard(
     i_params_dict = dashboard_to_import.params_dict
     remote_id_slice_map = {
         slc.params_dict["remote_id"]: slc
-        for slc in session.query(Slice).all()
+        for slc in db.session.query(Slice).all()
         if "remote_id" in slc.params_dict
     }
     for slc in slices:

@@ -224,7 +221,7 @@ def import_dashboard(

     # override the dashboard
     existing_dashboard = None
-    for dash in session.query(Dashboard).all():
+    for dash in db.session.query(Dashboard).all():
         if (
             "remote_id" in dash.params_dict
             and dash.params_dict["remote_id"] == dashboard_to_import.id

@@ -253,18 +250,20 @@ def import_dashboard(
     alter_native_filters(dashboard_to_import)

     new_slices = (
-        session.query(Slice).filter(Slice.id.in_(old_to_new_slc_id_dict.values())).all()
+        db.session.query(Slice)
+        .filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
+        .all()
     )

     if existing_dashboard:
         existing_dashboard.override(dashboard_to_import)
         existing_dashboard.slices = new_slices
-        session.flush()
+        db.session.flush()
         return existing_dashboard.id

     dashboard_to_import.slices = new_slices
-    session.add(dashboard_to_import)
-    session.flush()
+    db.session.add(dashboard_to_import)
+    db.session.flush()
     return dashboard_to_import.id  # type: ignore

@@ -291,7 +290,6 @@ def decode_dashboards(o: dict[str, Any]) -> Any:


 def import_dashboards(
-    session: Session,
     content: str,
     database_id: Optional[int] = None,
     import_time: Optional[int] = None,
@@ -308,10 +306,10 @@ def import_dashboards(
         params = json.loads(table.params)
         dataset_id_mapping[params["remote_id"]] = new_dataset_id

-    session.commit()
+    db.session.commit()
     for dashboard in data["dashboards"]:
         import_dashboard(dashboard, dataset_id_mapping, import_time=import_time)
-    session.commit()
+    db.session.commit()


 class ImportDashboardsCommand(BaseCommand):

@@ -334,7 +332,7 @@ class ImportDashboardsCommand(BaseCommand):

         for file_name, content in self.contents.items():
             logger.info("Importing dashboard from file %s", file_name)
-            import_dashboards(db.session, content, self.database_id)
+            import_dashboards(content, self.database_id)

     def validate(self) -> None:
         # ensure all files are JSON
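
`import_dashboards` now issues both of its commits on the shared scoped session: once after the dataset id mapping is built and once after the dashboards are imported. That is safe because a `Session` remains usable after `commit()` and simply begins a new transaction. A standalone sketch:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")
session = scoped_session(sessionmaker(bind=engine))

session.execute(text("create table t (x integer)"))
session.execute(text("insert into t values (1)"))
session.commit()  # first unit of work
session.execute(text("insert into t values (2)"))
session.commit()  # the session keeps working after a commit
assert session.execute(text("select count(*) from t")).scalar() == 2
session.remove()
```
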
@@ -24,7 +24,6 @@ from flask import request
 from flask_babel import lazy_gettext as _
 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.base import BaseCommand
 from superset.commands.explore.form_data.get import GetFormDataCommand
 from superset.commands.explore.form_data.parameters import (

@@ -114,7 +113,7 @@ class GetExploreCommand(BaseCommand, ABC):
         if self._datasource_id is not None:
             with contextlib.suppress(DatasourceNotFound):
                 datasource = DatasourceDAO.get_datasource(
-                    db.session, cast(str, self._datasource_type), self._datasource_id
+                    cast(str, self._datasource_type), self._datasource_id
                 )
         datasource_name = datasource.name if datasource else _("[Missing Dataset]")
         viz_type = form_data.get("viz_type")
@@ -29,7 +29,6 @@ from superset.commands.exceptions import (
 )
 from superset.daos.datasource import DatasourceDAO
 from superset.daos.exceptions import DatasourceNotFound
-from superset.extensions import db
 from superset.utils.core import DatasourceType, get_user_id

 if TYPE_CHECKING:

@@ -80,7 +79,7 @@ def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
 def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
     try:
         return DatasourceDAO.get_datasource(
-            db.session, DatasourceType(datasource_type), datasource_id
+            DatasourceType(datasource_type), datasource_id
         )
     except DatasourceNotFound as ex:
         raise DatasourceNotFoundValidationError() from ex
@@ -18,7 +18,7 @@ from __future__ import annotations

 from typing import Any, TYPE_CHECKING

-from superset import app, db
+from superset import app
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context import QueryContext
 from superset.common.query_object import QueryObject

@@ -35,7 +35,7 @@ config = app.config


 def create_query_object_factory() -> QueryObjectFactory:
-    return QueryObjectFactory(config, DatasourceDAO(), db.session)
+    return QueryObjectFactory(config, DatasourceDAO())


 class QueryContextFactory:  # pylint: disable=too-few-public-methods

@@ -95,7 +95,6 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods

     def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
         return DatasourceDAO.get_datasource(
-            session=db.session,
             datasource_type=DatasourceType(datasource["type"]),
             datasource_id=int(datasource["id"]),
         )
@@ -33,8 +33,6 @@ from superset.utils.core import (
 )

 if TYPE_CHECKING:
-    from sqlalchemy.orm import sessionmaker
-
     from superset.connectors.sqla.models import BaseDatasource
     from superset.daos.datasource import DatasourceDAO

@@ -42,17 +40,14 @@ if TYPE_CHECKING:
 class QueryObjectFactory:  # pylint: disable=too-few-public-methods
     _config: dict[str, Any]
     _datasource_dao: DatasourceDAO
-    _session_maker: sessionmaker

     def __init__(
         self,
         app_configurations: dict[str, Any],
         _datasource_dao: DatasourceDAO,
-        session_maker: sessionmaker,
     ):
         self._config = app_configurations
         self._datasource_dao = _datasource_dao
-        self._session_maker = session_maker

     def create(  # pylint: disable=too-many-arguments
         self,

@@ -91,7 +86,6 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods
         return self._datasource_dao.get_datasource(
             datasource_type=DatasourceType(datasource["type"]),
             datasource_id=int(datasource["id"]),
-            session=self._session_maker(),
         )

     def _process_extras(
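
With the `session_maker` collaborator gone, constructing the factory takes two arguments, mirroring `create_query_object_factory` in the hunks above. A usage sketch (it needs a Superset app context to actually run):

```python
from superset import app
from superset.common.query_object_factory import QueryObjectFactory
from superset.daos.datasource import DatasourceDAO

factory = QueryObjectFactory(app.config, DatasourceDAO())
# The DAO now reaches the database through the ambient db.session,
# so no sessionmaker has to be injected (or mocked in tests).
```
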
@@ -699,7 +699,7 @@ class BaseDatasource(

     @classmethod
     def get_datasource_by_name(
-        cls, session: Session, datasource_name: str, schema: str, database_name: str
+        cls, datasource_name: str, schema: str, database_name: str
     ) -> BaseDatasource | None:
         raise NotImplementedError()


@@ -1238,14 +1238,13 @@ class SqlaTable(
     @classmethod
     def get_datasource_by_name(
         cls,
-        session: Session,
         datasource_name: str,
         schema: str | None,
         database_name: str,
     ) -> SqlaTable | None:
         schema = schema or None
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .join(Database)
             .filter(cls.table_name == datasource_name)
             .filter(Database.database_name == database_name)

@@ -1939,12 +1938,10 @@ class SqlaTable(
     )

     @classmethod
-    def get_eager_sqlatable_datasource(
-        cls, session: Session, datasource_id: int
-    ) -> SqlaTable:
+    def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable:
         """Returns SqlaTable with columns and metrics."""
         return (
-            session.query(cls)
+            db.session.query(cls)
             .options(
                 sa.orm.subqueryload(cls.columns),
                 sa.orm.subqueryload(cls.metrics),

@@ -2037,8 +2034,7 @@ class SqlaTable(
         :param connection: Unused.
         :param target: The metric or column that was updated.
         """
-        inspector = inspect(target)
-        session = inspector.session
+        session = inspect(target).session

         # Forces an update to the table's changed_on value when a metric or column on the
         # table is updated. This busts the cache key for all charts that use the table.
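
`get_eager_sqlatable_datasource` keeps its `subqueryload` options and only swaps the session it queries through. For reference, `subqueryload` fetches each mapped collection with one extra SELECT instead of a JOIN; a toy sketch (generic models, not Superset's):

```python
from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship, subqueryload

Base = declarative_base()


class Dataset(Base):
    __tablename__ = "dataset"
    id = Column(Integer, primary_key=True)
    columns = relationship("DatasetColumn", back_populates="dataset")


class DatasetColumn(Base):
    __tablename__ = "dataset_column"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(Integer, ForeignKey("dataset.id"))
    dataset = relationship(Dataset, back_populates="columns")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Dataset(columns=[DatasetColumn(), DatasetColumn()]))
    session.commit()
    # Load the parent plus its collection eagerly, as the classmethod does:
    ds = session.query(Dataset).options(subqueryload(Dataset.columns)).one()
    assert len(ds.columns) == 2
```
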
@@ -170,7 +170,7 @@ class DashboardDAO(BaseDAO[Dashboard]):
         return True

     @staticmethod
-    def set_dash_metadata(  # pylint: disable=too-many-locals
+    def set_dash_metadata(
         dashboard: Dashboard,
         data: dict[Any, Any],
         old_to_new_slice_ids: dict[int, int] | None = None,

@@ -187,8 +187,9 @@ class DashboardDAO(BaseDAO[Dashboard]):
             if isinstance(value, dict)
         ]

-        session = db.session()
-        current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+        current_slices = (
+            db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+        )

         dashboard.slices = current_slices

@@ -18,8 +18,7 @@
 import logging
 from typing import Union

-from sqlalchemy.orm import Session
-
+from superset import db
 from superset.connectors.sqla.models import SqlaTable
 from superset.daos.base import BaseDAO
 from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError

@@ -45,7 +44,6 @@ class DatasourceDAO(BaseDAO[Datasource]):
     @classmethod
     def get_datasource(
         cls,
-        session: Session,
         datasource_type: Union[DatasourceType, str],
         datasource_id: int,
     ) -> Datasource:

@@ -53,7 +51,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
             raise DatasourceTypeNotSupportedError()

         datasource = (
-            session.query(cls.sources[datasource_type])
+            db.session.query(cls.sources[datasource_type])
             .filter_by(id=datasource_id)
             .one_or_none()
         )
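
The DAO's public API loses its leading session parameter. Usage after this change (sketch; it needs a running Superset app context, and `datasource_id=1` is a placeholder):

```python
from superset.daos.datasource import DatasourceDAO
from superset.utils.core import DatasourceType

# Before: DatasourceDAO.get_datasource(db.session, DatasourceType.TABLE, 1)
datasource = DatasourceDAO.get_datasource(DatasourceType.TABLE, 1)
```
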
@@ -18,7 +18,7 @@ import logging

 from flask_appbuilder.api import expose, protect, safe

-from superset import app, db, event_logger
+from superset import app, event_logger
 from superset.daos.datasource import DatasourceDAO
 from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
 from superset.exceptions import SupersetSecurityException

@@ -100,7 +100,7 @@ class DatasourceRestApi(BaseSupersetApi):
         """
         try:
             datasource = DatasourceDAO.get_datasource(
-                db.session, DatasourceType(datasource_type), datasource_id
+                DatasourceType(datasource_type), datasource_id
             )
             datasource.raise_for_access()
         except ValueError:
@@ -39,7 +39,7 @@ from sqlalchemy import (
     UniqueConstraint,
 )
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, sessionmaker, subqueryload
+from sqlalchemy.orm import relationship, subqueryload
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.sql import join, select
 from sqlalchemy.sql.elements import BinaryExpression

@@ -62,38 +62,33 @@ config = app.config
 logger = logging.getLogger(__name__)


-def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -> None:
+def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard) -> None:
     dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
     if dashboard_id is None:
         return

-    session_class = sessionmaker(autoflush=False)
-    session = session_class(bind=connection)
+    session = sqla.inspect(target).session
+    new_user = session.query(User).filter_by(id=target.id).first()

-    try:
-        new_user = session.query(User).filter_by(id=target.id).first()
-        # copy template dashboard to user
-        template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
-        dashboard = Dashboard(
-            dashboard_title=template.dashboard_title,
-            position_json=template.position_json,
-            description=template.description,
-            css=template.css,
-            json_metadata=template.json_metadata,
-            slices=template.slices,
-            owners=[new_user],
-        )
-        session.add(dashboard)
+    # copy template dashboard to user
+    template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
+    dashboard = Dashboard(
+        dashboard_title=template.dashboard_title,
+        position_json=template.position_json,
+        description=template.description,
+        css=template.css,
+        json_metadata=template.json_metadata,
+        slices=template.slices,
+        owners=[new_user],
+    )
+    session.add(dashboard)

-        # set dashboard as the welcome dashboard
-        extra_attributes = UserAttribute(
-            user_id=target.id, welcome_dashboard_id=dashboard.id
-        )
-        session.add(extra_attributes)
-        session.commit()
-    finally:
-        session.close()
+    # set dashboard as the welcome dashboard
+    extra_attributes = UserAttribute(
+        user_id=target.id, welcome_dashboard_id=dashboard.id
+    )
+    session.add(extra_attributes)
+    session.commit()


 sqla.event.listen(User, "after_insert", copy_dashboard)

@@ -397,7 +392,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model):
             if id_ is None:
                 continue
             datasource = DatasourceDAO.get_datasource(
-                db.session, utils.DatasourceType.TABLE, id_
+                utils.DatasourceType.TABLE, id_
             )
             datasource_ids.add((datasource.id, datasource.type))

@@ -406,9 +401,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model):

         eager_datasources = []
         for datasource_id, _ in datasource_ids:
-            eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
-                db.session, datasource_id
-            )
+            eager_datasource = SqlaTable.get_eager_sqlatable_datasource(datasource_id)
             copied_datasource = eager_datasource.copy()
             copied_datasource.alter_params(
                 remote_id=eager_datasource.id,
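
The `copy_dashboard` rewrite is the most involved change here: instead of opening a throwaway `sessionmaker(autoflush=False)` session bound to the event's raw connection and closing it in a `finally`, the listener reuses the session that is already flushing the new `User` row, recovered via `sqla.inspect(target).session`. The template dashboard, its copy, and the `UserAttribute` then share one unit of work, and the try/finally bookkeeping disappears because the session's lifecycle is owned elsewhere. A runnable toy sketch of the mechanism (not Superset code):

```python
from sqlalchemy import Column, Integer, create_engine, event, inspect
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)


seen = []


def on_insert(_mapper, _connection, target):
    # after_insert fires during flush; the new row's instance is attached
    # to the very session performing that flush.
    seen.append(inspect(target).session)


event.listen(User, "after_insert", on_insert)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(User())
    session.commit()
    assert seen == [session]
```
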
@@ -48,7 +48,7 @@ from flask_login import AnonymousUserMixin, LoginManager
 from jwt.api_jwt import _jwt_global_obj
 from sqlalchemy import and_, inspect, or_
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import eagerload, Session
+from sqlalchemy.orm import eagerload
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.orm.query import Query as SqlaQuery

@@ -545,8 +545,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )

         # group all datasources by database
-        session = self.get_session
-        all_datasources = SqlaTable.get_all_datasources(session)
+        all_datasources = SqlaTable.get_all_datasources(self.get_session)
         datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set)
         for datasource in all_datasources:
             datasources_by_database[datasource.database].add(datasource)

@@ -2017,17 +2016,14 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 self.get_dashboard_access_error_object(dashboard)
             )

-    def get_user_by_username(
-        self, username: str, session: Session = None
-    ) -> Optional[User]:
+    def get_user_by_username(self, username: str) -> Optional[User]:
         """
         Retrieves a user by it's username case sensitive. Optional session parameter
         utility method normally useful for celery tasks where the session
         need to be scoped
         """
-        session = session or self.get_session
         return (
-            session.query(self.user_model)
+            self.get_session.query(self.user_model)
             .filter(self.user_model.username == username)
             .one_or_none()
         )
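
`get_user_by_username` drops its optional `session` override and always reads through the manager's own `get_session`. Callers simplify accordingly (sketch; `security_manager` is Superset's global security-manager instance, assumed importable as shown):

```python
from superset import security_manager

# Before: security_manager.get_user_by_username("admin", session=session)
user = security_manager.get_user_by_username("admin")
```
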
@@ -79,6 +79,5 @@ def remove_database(database: Database) -> None:
     # pylint: disable=import-outside-toplevel
     from superset import db

-    session = db.session
-    session.delete(database)
-    session.commit()
+    db.session.delete(database)
+    db.session.commit()
@@ -27,7 +27,7 @@ from contextlib import contextmanager
 from datetime import datetime, timedelta
 from typing import Any, Callable, cast, Literal, TYPE_CHECKING

-from flask import current_app, g, request
+from flask import g, request
 from flask_appbuilder.const import API_URI_RIS_KEY
 from sqlalchemy.exc import SQLAlchemyError

@@ -139,6 +139,7 @@ class AbstractEventLogger(ABC):
         **payload_override: dict[str, Any] | None,
     ) -> None:
         # pylint: disable=import-outside-toplevel
+        from superset import db
         from superset.views.core import get_form_data

         referrer = request.referrer[:1000] if request and request.referrer else None

@@ -152,8 +153,7 @@ class AbstractEventLogger(ABC):
             # need to add them back before logging to capture user_id
             if user_id is None:
                 try:
-                    session = current_app.appbuilder.get_session
-                    session.add(g.user)
+                    db.session.add(g.user)
                     user_id = get_user_id()
                 except Exception as ex:  # pylint: disable=broad-except
                     logging.warning(ex)

@@ -332,6 +332,7 @@ class DBEventLogger(AbstractEventLogger):
         **kwargs: Any,
     ) -> None:
         # pylint: disable=import-outside-toplevel
+        from superset import db
         from superset.models.core import Log

         records = kwargs.get("records", [])

@@ -353,9 +354,8 @@ class DBEventLogger(AbstractEventLogger):
             )
             logs.append(log)
         try:
-            sesh = current_app.appbuilder.get_session
-            sesh.bulk_save_objects(logs)
-            sesh.commit()
+            db.session.bulk_save_objects(logs)
+            db.session.commit()
         except SQLAlchemyError as ex:
             logging.error("DBEventLogger failed to log event(s)")
             logging.exception(ex)
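
`DBEventLogger` now persists log rows through the scoped session's `bulk_save_objects`, the same legacy bulk API it previously reached via `current_app.appbuilder.get_session`. A standalone sketch of that call (generic model, not Superset's `Log`):

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class LogRow(Base):
    __tablename__ = "log_row"
    id = Column(Integer, primary_key=True)
    action = Column(String(64))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # bulk_save_objects skips per-instance event/identity-map bookkeeping,
    # which is why it suits fire-and-forget log writes.
    session.bulk_save_objects([LogRow(action=f"evt{i}") for i in range(3)])
    session.commit()
    assert session.query(LogRow).count() == 3
```
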
@@ -31,7 +31,6 @@ import sqlalchemy_utils
 from flask_appbuilder import Model
 from sqlalchemy import Column, inspect, MetaData, Table
 from sqlalchemy.dialects import postgresql
-from sqlalchemy.orm import Session
 from sqlalchemy.sql import func
 from sqlalchemy.sql.visitors import VisitableType

@@ -231,12 +230,10 @@ def generate_column_data(column: ColumnInfo, num_rows: int) -> list[Any]:
     return [gen() for _ in range(num_rows)]


-def add_sample_rows(
-    session: Session, model: type[Model], count: int
-) -> Iterator[Model]:
+def add_sample_rows(model: type[Model], count: int) -> Iterator[Model]:
     """
     Add entities of a given model.

-    :param Session session: an SQLAlchemy session
     :param Model model: a Superset/FAB model
     :param int count: how many entities to generate and insert
     """

@@ -244,7 +241,7 @@ def add_sample_rows(

     # select samples to copy relationship values
     relationships = inspector.relationships.items()
-    samples = session.query(model).limit(count).all() if relationships else []
+    samples = db.session.query(model).limit(count).all() if relationships else []

     max_primary_key: Optional[int] = None
     for i in range(count):

@@ -255,7 +252,7 @@ def add_sample_rows(
             if column.primary_key:
                 if max_primary_key is None:
                     max_primary_key = (
-                        session.query(func.max(getattr(model, column.name))).scalar()
+                        db.session.query(func.max(getattr(model, column.name))).scalar()
                         or 0
                     )
                 max_primary_key += 1
@@ -510,7 +510,6 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         if datasource_id is not None:
             with contextlib.suppress(DatasetNotFoundError):
                 datasource = DatasourceDAO.get_datasource(
-                    db.session,
                     DatasourceType("table"),
                     datasource_id,
                 )

@@ -751,7 +750,6 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         In terms of the `extra_filters` these can be obtained from records in the JSON
         encoded `logs.json` column associated with the `explore_json` action.
         """
-        session = db.session()
         slice_id = request.args.get("slice_id")
         dashboard_id = request.args.get("dashboard_id")
         table_name = request.args.get("table_name")

@@ -768,14 +766,14 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 status=400,
             )
         if slice_id:
-            slices = session.query(Slice).filter_by(id=slice_id).all()
+            slices = db.session.query(Slice).filter_by(id=slice_id).all()
             if not slices:
                 return json_error_response(
                     __("Chart %(id)s not found", id=slice_id), status=404
                 )
         elif table_name and db_name:
             table = (
-                session.query(SqlaTable)
+                db.session.query(SqlaTable)
                 .join(Database)
                 .filter(
                     Database.database_name == db_name

@@ -792,7 +790,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                     status=404,
                 )
             slices = (
-                session.query(Slice)
+                db.session.query(Slice)
                 .filter_by(datasource_id=table.id, datasource_type=table.type)
                 .all()
             )

@@ -919,7 +917,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         """
         datasource_id, datasource_type = request.args["datasourceKey"].split("__")
         datasource = DatasourceDAO.get_datasource(
-            db.session, DatasourceType(datasource_type), int(datasource_id)
+            DatasourceType(datasource_type), int(datasource_id)
         )
         # Check if datasource exists
         if not datasource:
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any, Optional

-from superset import app, db
+from superset import app
 from superset.commands.dataset.exceptions import DatasetSamplesFailedError
 from superset.common.chart_data import ChartDataResultType
 from superset.common.query_context_factory import QueryContextFactory

@@ -52,7 +52,6 @@ def get_samples(  # pylint: disable=too-many-arguments
     payload: Optional[SamplesPayloadSchema] = None,
 ) -> dict[str, Any]:
     datasource = DatasourceDAO.get_datasource(
-        session=db.session,
         datasource_type=datasource_type,
         datasource_id=datasource_id,
     )
@@ -83,7 +83,7 @@ class Datasource(BaseSupersetView):
         datasource_type = datasource_dict.get("type")
         database_id = datasource_dict["database"].get("id")
         orm_datasource = DatasourceDAO.get_datasource(
-            db.session, DatasourceType(datasource_type), datasource_id
+            DatasourceType(datasource_type), datasource_id
         )
         orm_datasource.database_id = database_id

@@ -126,7 +126,7 @@ class Datasource(BaseSupersetView):
     @deprecated(new_target="/api/v1/dataset/<int:pk>")
     def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
         datasource = DatasourceDAO.get_datasource(
-            db.session, DatasourceType(datasource_type), datasource_id
+            DatasourceType(datasource_type), datasource_id
         )
         return self.json_response(sanitize_datasource_data(datasource.data))

@@ -139,7 +139,6 @@ class Datasource(BaseSupersetView):
     ) -> FlaskResponse:
         """Gets column info from the source system"""
         datasource = DatasourceDAO.get_datasource(
-            db.session,
             DatasourceType(datasource_type),
             datasource_id,
         )

@@ -164,7 +163,6 @@ class Datasource(BaseSupersetView):
             return json_error_response(str(err), status=400)

         datasource = SqlaTable.get_datasource_by_name(
-            session=db.session,
             database_name=params["database_name"],
             schema=params["schema_name"],
             datasource_name=params["table_name"],
@@ -129,7 +129,6 @@ def get_viz(
 ) -> BaseViz:
     viz_type = form_data.get("viz_type", "table")
     datasource = DatasourceDAO.get_datasource(
-        db.session,
         DatasourceType(datasource_type),
         datasource_id,
     )

@@ -312,8 +311,7 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
 def get_dashboard_extra_filters(
     slice_id: int, dashboard_id: int
 ) -> list[dict[str, Any]]:
-    session = db.session()
-    dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
+    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()

     # is chart in this dashboard?
     if (
@@ -474,7 +474,7 @@ class TestDatasource(SupersetTestCase):

         pytest.raises(
             DatasourceNotFound,
-            lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999),
+            lambda: DatasourceDAO.get_datasource("table", 9999999),
         )

         self.login(username="admin")

@@ -486,7 +486,7 @@ class TestDatasource(SupersetTestCase):

         pytest.raises(
             DatasourceTypeNotSupportedError,
-            lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999),
+            lambda: DatasourceDAO.get_datasource("druid", 9999999),
         )

         self.login(username="admin")
@@ -145,7 +145,6 @@ class TestQueryContext(SupersetTestCase):

         # make temporary change and revert it to refresh the changed_on property
         datasource = DatasourceDAO.get_datasource(
-            session=db.session,
             datasource_type=DatasourceType(payload["datasource"]["type"]),
             datasource_id=payload["datasource"]["id"],
         )

@@ -169,7 +168,6 @@ class TestQueryContext(SupersetTestCase):

         # make temporary change and revert it to refresh the changed_on property
         datasource = DatasourceDAO.get_datasource(
-            session=db.session,
             datasource_type=DatasourceType(payload["datasource"]["type"]),
             datasource_id=payload["datasource"]["id"],
         )
File diff suppressed because it is too large
@@ -38,11 +38,6 @@ def app_config() -> dict[str, Any]:
     return create_app_config().copy()


-@fixture
-def session_factory() -> Mock:
-    return Mock()
-
-
 @fixture
 def connector_registry() -> Mock:
     return Mock(spec=["get_datasource"])

@@ -58,12 +53,12 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:

 @fixture
 def query_object_factory(
-    app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock
+    app_config: dict[str, Any], connector_registry: Mock
 ) -> QueryObjectFactory:
     import superset.common.query_object_factory as mod

     mod.apply_max_row_limit = apply_max_row_limit
-    return QueryObjectFactory(app_config, connector_registry, session_factory)
+    return QueryObjectFactory(app_config, connector_registry)


 @fixture
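
On the test side the `session_factory` mock disappears entirely; the factory under test needs only a config dict and a DAO double. A sketch of the resulting construction, mirroring the fixture above (the `ROW_LIMIT` value echoes the test below):

```python
from unittest.mock import Mock

from superset.common.query_object_factory import QueryObjectFactory

factory = QueryObjectFactory(
    {"ROW_LIMIT": 100},             # minimal app config for the factory
    Mock(spec=["get_datasource"]),  # DAO double; no session factory needed
)
```
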
@@ -172,7 +172,6 @@ def dummy_query_object(request, app_context):
             "ROW_LIMIT": 100,
         },
         _datasource_dao=unittest.mock.Mock(),
-        session_maker=unittest.mock.Mock(),
     ).create(parent_result_type=result_type, **query_object)
@@ -106,7 +106,6 @@ def test_get_datasource_sqlatable(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.TABLE,
         datasource_id=1,
-        session=session_with_data,
     )

     assert 1 == result.id

@@ -119,7 +118,7 @@ def test_get_datasource_query(session_with_data: Session) -> None:
     from superset.models.sql_lab import Query

     result = DatasourceDAO.get_datasource(
-        datasource_type=DatasourceType.QUERY, datasource_id=1, session=session_with_data
+        datasource_type=DatasourceType.QUERY, datasource_id=1
     )

     assert result.id == 1

@@ -133,7 +132,6 @@ def test_get_datasource_saved_query(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.SAVEDQUERY,
         datasource_id=1,
-        session=session_with_data,
     )

     assert result.id == 1

@@ -147,7 +145,6 @@ def test_get_datasource_sl_table(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.SLTABLE,
         datasource_id=1,
-        session=session_with_data,
     )

     assert result.id == 1

@@ -161,7 +158,6 @@ def test_get_datasource_sl_dataset(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.DATASET,
         datasource_id=1,
-        session=session_with_data,
     )

     assert result.id == 1

@@ -178,7 +174,6 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="table",
             datasource_id=1,
-            session=session_with_data,
         ),
         SqlaTable,
     )

@@ -187,7 +182,6 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="sl_table",
             datasource_id=1,
-            session=session_with_data,
         ),
         Table,
     )

@@ -208,5 +202,4 @@ def test_not_found_datasource(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="table",
             datasource_id=500000,
-            session=session_with_data,
         )