feat(row-level-security): add base filter type and filter grouping (#10946)
* feat(row-level-security): add filter type and group key * simplify tests and add custom list widget * address comments * use enum value to ensure case sensitive value is used
This commit is contained in:
parent
3be8bdad9a
commit
448a41a4e7
|
|
@ -261,8 +261,8 @@ class SupersetAppInitializer:
|
|||
if self.config["ENABLE_ROW_LEVEL_SECURITY"]:
|
||||
appbuilder.add_view(
|
||||
RowLevelSecurityFiltersModelView,
|
||||
"Row Level Security Filters",
|
||||
label=__("Row level security filters"),
|
||||
"Row Level Security",
|
||||
label=__("Row level security"),
|
||||
category="Security",
|
||||
category_label=__("Security"),
|
||||
icon="fa-lock",
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import json
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from collections import defaultdict, OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
|
||||
|
|
@ -35,6 +35,7 @@ from sqlalchemy import (
|
|||
Column,
|
||||
DateTime,
|
||||
desc,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
or_,
|
||||
|
|
@ -92,8 +93,8 @@ class MetadataResult:
|
|||
|
||||
|
||||
class AnnotationDatasource(BaseDatasource):
|
||||
""" Dummy object so we can query annotations using 'Viz' objects just like
|
||||
regular datasources.
|
||||
"""Dummy object so we can query annotations using 'Viz' objects just like
|
||||
regular datasources.
|
||||
"""
|
||||
|
||||
cache_timeout = 0
|
||||
|
|
@ -798,11 +799,14 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
:returns: A list of SQL clauses to be ANDed together.
|
||||
:rtype: List[str]
|
||||
"""
|
||||
filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list)
|
||||
try:
|
||||
return [
|
||||
text("({})".format(template_processor.process_template(f.clause)))
|
||||
for f in security_manager.get_rls_filters(self)
|
||||
]
|
||||
for filter_ in security_manager.get_rls_filters(self):
|
||||
clause = text(
|
||||
f"({template_processor.process_template(filter_.clause)})"
|
||||
)
|
||||
filters_grouped[filter_.group_key or filter_.id].append(clause)
|
||||
return [or_(*clauses) for clauses in filters_grouped.values()]
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
|
||||
|
|
@ -1371,9 +1375,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
) -> int:
|
||||
"""Imports the datasource from the object to the database.
|
||||
|
||||
Metrics and columns and datasource will be overrided if exists.
|
||||
This function can be used to import/export dashboards between multiple
|
||||
superset instances. Audit metadata isn't copies over.
|
||||
Metrics and columns and datasource will be overrided if exists.
|
||||
This function can be used to import/export dashboards between multiple
|
||||
superset instances. Audit metadata isn't copies over.
|
||||
"""
|
||||
|
||||
def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable":
|
||||
|
|
@ -1506,6 +1510,10 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
|||
|
||||
__tablename__ = "row_level_security_filters"
|
||||
id = Column(Integer, primary_key=True)
|
||||
filter_type = Column(
|
||||
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
|
||||
)
|
||||
group_key = Column(String(255), nullable=True)
|
||||
roles = relationship(
|
||||
security_manager.role_model,
|
||||
secondary=RLSFilterRoles,
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@
|
|||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, cast, Dict, List, Union
|
||||
|
||||
from flask import flash, Markup, redirect
|
||||
from flask import current_app, flash, Markup, redirect
|
||||
from flask_appbuilder import CompactCRUDMixin, expose
|
||||
from flask_appbuilder.actions import action
|
||||
from flask_appbuilder.fieldwidgets import Select2Widget
|
||||
|
|
@ -41,6 +41,7 @@ from superset.views.base import (
|
|||
DatasourceFilter,
|
||||
DeleteMixin,
|
||||
ListWidgetWithCheckboxes,
|
||||
SupersetListWidget,
|
||||
SupersetModelView,
|
||||
validate_sqlatable,
|
||||
YamlExportMixin,
|
||||
|
|
@ -241,30 +242,73 @@ class SqlMetricInlineView( # pylint: disable=too-many-ancestors
|
|||
edit_form_extra_fields = add_form_extra_fields
|
||||
|
||||
|
||||
class RowLevelSecurityListWidget(
|
||||
SupersetListWidget
|
||||
): # pylint: disable=too-few-public-methods
|
||||
template = "superset/models/rls/list.html"
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
kwargs["appbuilder"] = current_app.appbuilder
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors
|
||||
SupersetModelView, DeleteMixin
|
||||
):
|
||||
datamodel = SQLAInterface(models.RowLevelSecurityFilter)
|
||||
|
||||
list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget)
|
||||
|
||||
list_title = _("Row level security filter")
|
||||
show_title = _("Show Row level security filter")
|
||||
add_title = _("Add Row level security filter")
|
||||
edit_title = _("Edit Row level security filter")
|
||||
|
||||
list_columns = ["tables", "roles", "clause", "creator", "modified"]
|
||||
order_columns = ["tables", "clause", "modified"]
|
||||
edit_columns = ["tables", "roles", "clause"]
|
||||
list_columns = [
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"group_key",
|
||||
"clause",
|
||||
"creator",
|
||||
"modified",
|
||||
]
|
||||
order_columns = ["filter_type", "group_key", "clause", "modified"]
|
||||
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
|
||||
show_columns = edit_columns
|
||||
search_columns = ("tables", "roles", "clause")
|
||||
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
|
||||
add_columns = edit_columns
|
||||
base_order = ("changed_on", "desc")
|
||||
description_columns = {
|
||||
"filter_type": _(
|
||||
"Regular filters add where clauses to queries if a user belongs to a "
|
||||
"role referenced in the filter. Base filters apply filters to all queries "
|
||||
"except the roles defined in the filter, and can be used to define what "
|
||||
"users can see if no RLS filters within a filter group apply to them."
|
||||
),
|
||||
"tables": _("These are the tables this filter will be applied to."),
|
||||
"roles": _("These are the roles this filter will be applied to."),
|
||||
"roles": _(
|
||||
"For regular filters, these are the roles this filter will be "
|
||||
"applied to. For base filters, these are the roles that the "
|
||||
"filter DOES NOT apply to, e.g. Admin if admin should see all "
|
||||
"data."
|
||||
),
|
||||
"group_key": _(
|
||||
"Filters with the same group key will be ORed together within the group, "
|
||||
"while different filter groups will be ANDed together. Undefined group "
|
||||
"keys are treated as unique groups, i.e. are not grouped together. "
|
||||
"For example, if a table has three filters, of which two are for "
|
||||
"departments Finance and Marketing (group key = 'department'), and one "
|
||||
"refers to the region Europe (group key = 'region'), the filter clause "
|
||||
"would apply the filter (department = 'Finance' OR department = "
|
||||
"'Marketing') AND (region = 'Europe')."
|
||||
),
|
||||
"clause": _(
|
||||
"This is the condition that will be added to the WHERE clause. "
|
||||
"For example, to only return rows for a particular client, "
|
||||
"you might put in: client_id = 9"
|
||||
"you might define a regular filter with the clause `client_id = 9`. To "
|
||||
"display no rows unless a user belongs to a RLS filter role, a base "
|
||||
"filter can be created with the clause `1 = 0` (always false)."
|
||||
),
|
||||
}
|
||||
label_columns = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""add rls filter type and grouping key
|
||||
|
||||
Revision ID: e5ef6828ac4e
|
||||
Revises: ae19b4ee3692
|
||||
Create Date: 2020-09-15 18:22:40.130985
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e5ef6828ac4e"
|
||||
down_revision = "ae19b4ee3692"
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from superset.utils import core as utils
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("row_level_security_filters") as batch_op:
|
||||
batch_op.add_column(sa.Column("filter_type", sa.VARCHAR(255), nullable=True)),
|
||||
batch_op.add_column(sa.Column("group_key", sa.VARCHAR(255), nullable=True)),
|
||||
batch_op.create_index(
|
||||
op.f("ix_row_level_security_filters_filter_type"),
|
||||
["filter_type"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
bind = op.get_bind()
|
||||
metadata = sa.MetaData(bind=bind)
|
||||
filters = sa.Table("row_level_security_filters", metadata, autoload=True)
|
||||
statement = filters.update().values(
|
||||
filter_type=utils.RowLevelSecurityFilterType.REGULAR.value
|
||||
)
|
||||
bind.execute(statement)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("row_level_security_filters") as batch_op:
|
||||
batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),)
|
||||
batch_op.drop_column("filter_type")
|
||||
batch_op.drop_column("group_key")
|
||||
|
|
@ -36,7 +36,7 @@ from flask_appbuilder.security.views import (
|
|||
ViewMenuModelView,
|
||||
)
|
||||
from flask_appbuilder.widgets import ListWidget
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||
|
|
@ -46,7 +46,7 @@ from superset.connectors.connector_registry import ConnectorRegistry
|
|||
from superset.constants import RouteMethod
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.utils.core import DatasourceName
|
||||
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.common.query_context import QueryContext
|
||||
|
|
@ -62,7 +62,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class SupersetSecurityListWidget(ListWidget):
|
||||
"""
|
||||
Redeclaring to avoid circular imports
|
||||
Redeclaring to avoid circular imports
|
||||
"""
|
||||
|
||||
template = "superset/fab_overrides/list.html"
|
||||
|
|
@ -70,8 +70,8 @@ class SupersetSecurityListWidget(ListWidget):
|
|||
|
||||
class SupersetRoleListWidget(ListWidget):
|
||||
"""
|
||||
Role model view from FAB already uses a custom list widget override
|
||||
So we override the override
|
||||
Role model view from FAB already uses a custom list widget override
|
||||
So we override the override
|
||||
"""
|
||||
|
||||
template = "superset/fab_overrides/list_role.html"
|
||||
|
|
@ -1012,8 +1012,23 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
.filter(assoc_user_role.c.user_id == g.user.id)
|
||||
.subquery()
|
||||
)
|
||||
filter_roles = (
|
||||
regular_filter_roles = (
|
||||
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
|
||||
.join(RowLevelSecurityFilter)
|
||||
.filter(
|
||||
RowLevelSecurityFilter.filter_type
|
||||
== RowLevelSecurityFilterType.REGULAR
|
||||
)
|
||||
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
|
||||
.subquery()
|
||||
)
|
||||
base_filter_roles = (
|
||||
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
|
||||
.join(RowLevelSecurityFilter)
|
||||
.filter(
|
||||
RowLevelSecurityFilter.filter_type
|
||||
== RowLevelSecurityFilterType.BASE
|
||||
)
|
||||
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
|
||||
.subquery()
|
||||
)
|
||||
|
|
@ -1024,10 +1039,25 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
)
|
||||
query = (
|
||||
self.get_session.query(
|
||||
RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause
|
||||
RowLevelSecurityFilter.id,
|
||||
RowLevelSecurityFilter.group_key,
|
||||
RowLevelSecurityFilter.clause,
|
||||
)
|
||||
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
|
||||
.filter(RowLevelSecurityFilter.id.in_(filter_roles))
|
||||
.filter(
|
||||
or_(
|
||||
and_(
|
||||
RowLevelSecurityFilter.filter_type
|
||||
== RowLevelSecurityFilterType.REGULAR,
|
||||
RowLevelSecurityFilter.id.in_(regular_filter_roles),
|
||||
),
|
||||
and_(
|
||||
RowLevelSecurityFilter.filter_type
|
||||
== RowLevelSecurityFilterType.BASE,
|
||||
RowLevelSecurityFilter.id.notin_(base_filter_roles),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
return query.all()
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
{#
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
#}
|
||||
{% extends 'appbuilder/general/widgets/base_list.html' %}
|
||||
{% import 'appbuilder/general/lib.html' as lib %}
|
||||
|
||||
{% block begin_content scoped %}
|
||||
<div class="table-responsive">
|
||||
<table class="table table-hover">
|
||||
{% endblock %}
|
||||
|
||||
{% block begin_loop_header scoped %}
|
||||
<thead>
|
||||
<tr>
|
||||
{% if actions %}
|
||||
<th class="action_checkboxes">
|
||||
<input id="check_all" class="action_check_all" name="check_all" type="checkbox">
|
||||
</th>
|
||||
{% endif %}
|
||||
|
||||
{% if can_show or can_edit or can_delete %}
|
||||
<th class="col-md-1 col-lg-1 col-sm-1" ></th>
|
||||
{% endif %}
|
||||
|
||||
{% for item in include_columns %}
|
||||
{% if item in order_columns %}
|
||||
{% set res = item | get_link_order(modelview_name) %}
|
||||
{% if res == 2 %}
|
||||
<th><a href={{ item | link_order(modelview_name) }}>{{label_columns.get(item)}}
|
||||
<i class="fa fa-chevron-up pull-right"></i></a></th>
|
||||
{% elif res == 1 %}
|
||||
<th><a href={{ item | link_order(modelview_name) }}>{{label_columns.get(item)}}
|
||||
<i class="fa fa-chevron-down pull-right"></i></a></th>
|
||||
{% else %}
|
||||
<th><a href={{ item | link_order(modelview_name) }}>{{label_columns.get(item)}}
|
||||
<i class="fa fa-arrows-v pull-right"></i></a></th>
|
||||
{% endif %}
|
||||
{% else %}
|
||||
<th>{{label_columns.get(item)}}</th>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</tr>
|
||||
</thead>
|
||||
{% endblock %}
|
||||
|
||||
{% block begin_loop_values %}
|
||||
{% for item in value_columns %}
|
||||
{% set pk = pks[loop.index-1] %}
|
||||
<tr>
|
||||
{% if actions %}
|
||||
<td>
|
||||
<input id="{{pk}}" class="action_check" name="rowid" value="{{pk}}" type="checkbox">
|
||||
</td>
|
||||
{% endif %}
|
||||
{% if can_show or can_edit or can_delete %}
|
||||
<td><center>
|
||||
{{ lib.btn_crud(can_show, can_edit, can_delete, pk, modelview_name, filters) }}
|
||||
</center></td>
|
||||
{% endif %}
|
||||
{% for value in include_columns %}
|
||||
<td>
|
||||
{% if value == "roles" and item["filter_type"] == "Base" and not item[value] %}
|
||||
All
|
||||
{% elif value == "roles" and item["filter_type"] == 'Base' %}
|
||||
Not {{ item[value] }}
|
||||
{% elif value == "roles" and item["filter_type"] == 'Regular' and not item[value] %}
|
||||
None
|
||||
{% elif value == "group_key" and item[value] == None %}
|
||||
{% else %}
|
||||
{{ item[value] }}
|
||||
{% endif %}
|
||||
</td>
|
||||
{% endfor %}
|
||||
</tr>
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
|
||||
{% block end_content scoped %}
|
||||
</table>
|
||||
</div>
|
||||
{% endblock %}
|
||||
|
|
@ -1547,3 +1547,8 @@ class PostProcessingContributionOrientation(str, Enum):
|
|||
class AdhocMetricExpressionType(str, Enum):
|
||||
SIMPLE = "SIMPLE"
|
||||
SQL = "SQL"
|
||||
|
||||
|
||||
class RowLevelSecurityFilterType(str, Enum):
|
||||
REGULAR = "Regular"
|
||||
BASE = "Base"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
import inspect
|
||||
import re
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
|
@ -1009,70 +1010,116 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
"""
|
||||
|
||||
rls_entry = None
|
||||
query_obj = dict(
|
||||
groupby=[],
|
||||
metrics=[],
|
||||
filter=[],
|
||||
is_timeseries=False,
|
||||
columns=["value"],
|
||||
granularity=None,
|
||||
from_dttm=None,
|
||||
to_dttm=None,
|
||||
extras={},
|
||||
)
|
||||
NAME_AB_ROLE = "NameAB"
|
||||
NAME_Q_ROLE = "NameQ"
|
||||
NAMES_A_REGEX = re.compile(r"name like 'A%'")
|
||||
NAMES_B_REGEX = re.compile(r"name like 'B%'")
|
||||
NAMES_Q_REGEX = re.compile(r"name like 'Q%'")
|
||||
BASE_FILTER_REGEX = re.compile(r"gender = 'boy'")
|
||||
|
||||
def setUp(self):
|
||||
session = db.session
|
||||
|
||||
# Create the RowLevelSecurityFilter
|
||||
self.rls_entry = RowLevelSecurityFilter()
|
||||
self.rls_entry.tables.extend(
|
||||
# Create roles
|
||||
security_manager.add_role(self.NAME_AB_ROLE)
|
||||
security_manager.add_role(self.NAME_Q_ROLE)
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE))
|
||||
gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE))
|
||||
self.create_user_with_roles("NoRlsRoleUser", ["Gamma"])
|
||||
session.commit()
|
||||
|
||||
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
|
||||
self.rls_entry1 = RowLevelSecurityFilter()
|
||||
self.rls_entry1.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
|
||||
.all()
|
||||
)
|
||||
self.rls_entry.clause = "value > {{ cache_key_wrapper(1) }}"
|
||||
self.rls_entry.roles.append(
|
||||
security_manager.find_role("Gamma")
|
||||
) # db.session.query(Role).filter_by(name="Gamma").first())
|
||||
self.rls_entry.roles.append(security_manager.find_role("Alpha"))
|
||||
db.session.add(self.rls_entry)
|
||||
self.rls_entry1.filter_type = "Regular"
|
||||
self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}"
|
||||
self.rls_entry1.group_key = None
|
||||
self.rls_entry1.roles.append(security_manager.find_role("Gamma"))
|
||||
self.rls_entry1.roles.append(security_manager.find_role("Alpha"))
|
||||
db.session.add(self.rls_entry1)
|
||||
|
||||
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
|
||||
self.rls_entry2 = RowLevelSecurityFilter()
|
||||
self.rls_entry2.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
.all()
|
||||
)
|
||||
self.rls_entry2.filter_type = "Regular"
|
||||
self.rls_entry2.clause = "name like 'A%' or name like 'B%'"
|
||||
self.rls_entry2.group_key = "name"
|
||||
self.rls_entry2.roles.append(security_manager.find_role("NameAB"))
|
||||
db.session.add(self.rls_entry2)
|
||||
|
||||
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
|
||||
self.rls_entry3 = RowLevelSecurityFilter()
|
||||
self.rls_entry3.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
.all()
|
||||
)
|
||||
self.rls_entry3.filter_type = "Regular"
|
||||
self.rls_entry3.clause = "name like 'Q%'"
|
||||
self.rls_entry3.group_key = "name"
|
||||
self.rls_entry3.roles.append(security_manager.find_role("NameQ"))
|
||||
db.session.add(self.rls_entry3)
|
||||
|
||||
# Create Base RowLevelSecurityFilter (birth_names boys)
|
||||
self.rls_entry4 = RowLevelSecurityFilter()
|
||||
self.rls_entry4.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
.all()
|
||||
)
|
||||
self.rls_entry4.filter_type = "Base"
|
||||
self.rls_entry4.clause = "gender = 'boy'"
|
||||
self.rls_entry4.group_key = "gender"
|
||||
self.rls_entry4.roles.append(security_manager.find_role("Admin"))
|
||||
db.session.add(self.rls_entry4)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def tearDown(self):
|
||||
session = db.session
|
||||
session.delete(self.rls_entry)
|
||||
session.delete(self.rls_entry1)
|
||||
session.delete(self.rls_entry2)
|
||||
session.delete(self.rls_entry3)
|
||||
session.delete(self.rls_entry4)
|
||||
session.delete(security_manager.find_role("NameAB"))
|
||||
session.delete(security_manager.find_role("NameQ"))
|
||||
session.delete(self.get_user("NoRlsRoleUser"))
|
||||
session.commit()
|
||||
|
||||
# Do another test to make sure it doesn't alter another query
|
||||
def test_rls_filter_alters_query(self):
|
||||
g.user = self.get_user(
|
||||
username="alpha"
|
||||
) # self.login() doesn't actually set the user
|
||||
def test_rls_filter_alters_energy_query(self):
|
||||
g.user = self.get_user(username="alpha")
|
||||
tbl = self.get_table_by_name("energy_usage")
|
||||
query_obj = dict(
|
||||
groupby=[],
|
||||
metrics=[],
|
||||
filter=[],
|
||||
is_timeseries=False,
|
||||
columns=["value"],
|
||||
granularity=None,
|
||||
from_dttm=None,
|
||||
to_dttm=None,
|
||||
extras={},
|
||||
)
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
assert tbl.get_extra_cache_keys(query_obj) == [1]
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
|
||||
assert "value > 1" in sql
|
||||
|
||||
def test_rls_filter_doesnt_alter_query(self):
|
||||
def test_rls_filter_doesnt_alter_energy_query(self):
|
||||
g.user = self.get_user(
|
||||
username="admin"
|
||||
) # self.login() doesn't actually set the user
|
||||
tbl = self.get_table_by_name("energy_usage")
|
||||
query_obj = dict(
|
||||
groupby=[],
|
||||
metrics=[],
|
||||
filter=[],
|
||||
is_timeseries=False,
|
||||
columns=["value"],
|
||||
granularity=None,
|
||||
from_dttm=None,
|
||||
to_dttm=None,
|
||||
extras={},
|
||||
)
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
assert tbl.get_extra_cache_keys(query_obj) == []
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
assert tbl.get_extra_cache_keys(self.query_obj) == []
|
||||
assert "value > 1" not in sql
|
||||
|
||||
def test_multiple_table_filter_alters_another_tables_query(self):
|
||||
|
|
@ -1080,17 +1127,41 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
username="alpha"
|
||||
) # self.login() doesn't actually set the user
|
||||
tbl = self.get_table_by_name("unicode_test")
|
||||
query_obj = dict(
|
||||
groupby=[],
|
||||
metrics=[],
|
||||
filter=[],
|
||||
is_timeseries=False,
|
||||
columns=["value"],
|
||||
granularity=None,
|
||||
from_dttm=None,
|
||||
to_dttm=None,
|
||||
extras={},
|
||||
)
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
assert tbl.get_extra_cache_keys(query_obj) == [1]
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
|
||||
assert "value > 1" in sql
|
||||
|
||||
def test_rls_filter_alters_gamma_birth_names_query(self):
|
||||
g.user = self.get_user(username="gamma")
|
||||
tbl = self.get_table_by_name("birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
# establish that the filters are grouped together correctly with
|
||||
# ANDs, ORs and parens in the correct place
|
||||
assert (
|
||||
"WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');"
|
||||
in sql
|
||||
)
|
||||
|
||||
def test_rls_filter_alters_no_role_user_birth_names_query(self):
|
||||
g.user = self.get_user(username="NoRlsRoleUser")
|
||||
tbl = self.get_table_by_name("birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
# gamma's filters should not be present query
|
||||
assert not self.NAMES_A_REGEX.search(sql)
|
||||
assert not self.NAMES_B_REGEX.search(sql)
|
||||
assert not self.NAMES_Q_REGEX.search(sql)
|
||||
# base query should be present
|
||||
assert self.BASE_FILTER_REGEX.search(sql)
|
||||
|
||||
def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
|
||||
g.user = self.get_user(username="admin")
|
||||
tbl = self.get_table_by_name("birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
# no filters are applied for admin user
|
||||
assert not self.NAMES_A_REGEX.search(sql)
|
||||
assert not self.NAMES_B_REGEX.search(sql)
|
||||
assert not self.NAMES_Q_REGEX.search(sql)
|
||||
assert not self.BASE_FILTER_REGEX.search(sql)
|
||||
|
|
|
|||
Loading…
Reference in New Issue