chore: Add explicit bidirectional performant relationships for SQLA model (#22413)

This commit is contained in:
John Bodley 2023-01-21 10:17:56 +13:00 committed by GitHub
parent 858c6e19a0
commit 92cdb8c282
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 24 deletions

View File

@@ -68,7 +68,14 @@ from sqlalchemy import (
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm import (
backref,
Mapped,
Query,
relationship,
RelationshipProperty,
Session,
)
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
@@ -224,10 +231,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table: SqlaTable = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
backref=backref("columns", cascade="all, delete-orphan"),
foreign_keys=[table_id],
back_populates="columns",
lazy="joined", # Eager loading for efficient parent referencing with selectin.
)
is_dttm = Column(Boolean, default=False)
expression = Column(MediumText())
@@ -439,10 +446,10 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
backref=backref("metrics", cascade="all, delete-orphan"),
foreign_keys=[table_id],
back_populates="metrics",
lazy="joined", # Eager loading for efficient parent referencing with selectin.
)
expression = Column(MediumText(), nullable=False)
extra = Column(Text)
@@ -535,13 +542,23 @@ def _process_sql_expression(
class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""
"""An ORM object for SqlAlchemy table references."""
type = "table"
query_language = "sql"
is_rls_supported = True
columns: List[TableColumn] = []
metrics: List[SqlMetric] = []
columns: Mapped[List[TableColumn]] = relationship(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
lazy="selectin", # Only non-eager loading that works with bidirectional joined.
)
metrics: Mapped[List[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
lazy="selectin", # Only non-eager loading that works with bidirectional joined.
)
metric_class = SqlMetric
column_class = TableColumn
owner_class = security_manager.user_model

View File

@@ -242,6 +242,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
DatasetDuplicateSchema,
)
list_outer_default_load = True
show_outer_default_load = True
@expose("/", methods=["POST"])
@protect()
@safe

View File

@@ -144,8 +144,8 @@ def _add_table_metrics(datasource: SqlaTable) -> None:
metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
for col in columns:
if col.column_name == "ds":
col.is_dttm = True
if col.column_name == "ds": # type: ignore
col.is_dttm = True # type: ignore
break
datasource.columns = columns

View File

@@ -423,11 +423,6 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
datasource_class = copied_datasource.__class__
for field_name in datasource_class.export_children:
field_val = getattr(eager_datasource, field_name).copy()
# set children without creating ORM relations
copied_datasource.__dict__[field_name] = field_val
eager_datasources.append(copied_datasource)
return json.dumps(

View File

@@ -327,7 +327,10 @@ class ImportExportMixin:
# Recursively create children
if recursive:
for child in cls.export_children:
child_class = cls.__mapper__.relationships[child].argument.class_
argument = cls.__mapper__.relationships[child].argument
child_class = (
argument.class_ if hasattr(argument, "class_") else argument
)
added = []
for c_obj in new_children.get(child, []):
added.append(

View File

@@ -1315,9 +1315,10 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()
@@ -1387,9 +1388,10 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

View File

@@ -1987,8 +1987,8 @@ class TestDatabaseApi(SupersetTestCase):
assert str(dataset.uuid) == dataset_config["uuid"]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()
@@ -2058,8 +2058,8 @@ class TestDatabaseApi(SupersetTestCase):
)
dataset = database.tables[0]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

View File

@@ -1988,8 +1988,8 @@ class TestDatasetApi(SupersetTestCase):
assert str(dataset.uuid) == dataset_config["uuid"]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()
@@ -2090,8 +2090,8 @@ class TestDatasetApi(SupersetTestCase):
dataset = database.tables[0]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()