chore: remove duplicate code in `SqlaTable` (#28752)
This commit is contained in:
parent
020c79970f
commit
643ee17544
|
|
@ -21,7 +21,6 @@ import builtins
|
|||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
|
@ -70,7 +69,7 @@ from sqlalchemy.sql.elements import ColumnClause, TextClause
|
|||
from sqlalchemy.sql.expression import Label, TextAsFrom
|
||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||
|
||||
from superset import app, db, is_feature_enabled, security_manager
|
||||
from superset import app, db, security_manager
|
||||
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.connectors.sqla.utils import (
|
||||
|
|
@ -1603,48 +1602,6 @@ class SqlaTable(
|
|||
if is_alias_used_in_orderby(col):
|
||||
col.name = f"{col.name}__"
|
||||
|
||||
def get_sqla_row_level_filters(
|
||||
self,
|
||||
template_processor: BaseTemplateProcessor,
|
||||
) -> list[TextClause]:
|
||||
"""
|
||||
Return the appropriate row level security filters for this table and the
|
||||
current user. A custom username can be passed when the user is not present in the
|
||||
Flask global namespace.
|
||||
|
||||
:param template_processor: The template processor to apply to the filters.
|
||||
:returns: A list of SQL clauses to be ANDed together.
|
||||
"""
|
||||
all_filters: list[TextClause] = []
|
||||
filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
|
||||
try:
|
||||
for filter_ in security_manager.get_rls_filters(self):
|
||||
clause = self.text(
|
||||
f"({template_processor.process_template(filter_.clause)})"
|
||||
)
|
||||
if filter_.group_key:
|
||||
filter_groups[filter_.group_key].append(clause)
|
||||
else:
|
||||
all_filters.append(clause)
|
||||
|
||||
if is_feature_enabled("EMBEDDED_SUPERSET"):
|
||||
for rule in security_manager.get_guest_rls_filters(self):
|
||||
clause = self.text(
|
||||
f"({template_processor.process_template(rule['clause'])})"
|
||||
)
|
||||
all_filters.append(clause)
|
||||
|
||||
grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
|
||||
all_filters.extend(grouped_filters)
|
||||
return all_filters
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Error in jinja expression in RLS filters: %(msg)s",
|
||||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
|
||||
def text(self, clause: str) -> TextClause:
|
||||
return self.db_engine_spec.get_text_clause(clause)
|
||||
|
||||
|
|
|
|||
|
|
@ -805,7 +805,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
|
||||
def get_sqla_row_level_filters(
|
||||
self,
|
||||
template_processor: BaseTemplateProcessor,
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> list[TextClause]:
|
||||
"""
|
||||
Return the appropriate row level security filters for this table and the
|
||||
|
|
@ -815,6 +815,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
:param template_processor: The template processor to apply to the filters.
|
||||
:returns: A list of SQL clauses to be ANDed together.
|
||||
"""
|
||||
template_processor = template_processor or self.get_template_processor()
|
||||
|
||||
all_filters: list[TextClause] = []
|
||||
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1269,10 +1269,8 @@ def get_rls_for_table(
|
|||
if not dataset:
|
||||
return None
|
||||
|
||||
template_processor = dataset.get_template_processor()
|
||||
predicate = " AND ".join(
|
||||
str(filter_)
|
||||
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
|
||||
str(filter_) for filter_ in dataset.get_sqla_row_level_filters()
|
||||
)
|
||||
if not predicate:
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue