fix(Explore): Apply RLS at column values (#30490)
Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>
This commit is contained in:
parent
0b34197815
commit
f314685a8e
|
|
@ -1309,7 +1309,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||||
)
|
)
|
||||||
return and_(*l)
|
return and_(*l)
|
||||||
|
|
||||||
def values_for_column(
|
def values_for_column( # pylint: disable=too-many-locals
|
||||||
self,
|
self,
|
||||||
column_name: str,
|
column_name: str,
|
||||||
limit: int = 10000,
|
limit: int = 10000,
|
||||||
|
|
@ -1345,6 +1345,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||||
if self.fetch_values_predicate:
|
if self.fetch_values_predicate:
|
||||||
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
|
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
|
||||||
|
|
||||||
|
rls_filters = self.get_sqla_row_level_filters(template_processor=tp)
|
||||||
|
qry = qry.where(and_(*rls_filters))
|
||||||
|
|
||||||
with self.database.get_sqla_engine() as engine:
|
with self.database.get_sqla_engine() as engine:
|
||||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||||
sql = self._apply_cte(sql, cte)
|
sql = self._apply_cte(sql, cte)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.sql.elements import TextClause
|
||||||
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
@ -176,3 +177,31 @@ class TestDatasourceApi(SupersetTestCase):
|
||||||
table.normalize_columns = False
|
table.normalize_columns = False
|
||||||
self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/") # noqa: F841
|
self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/") # noqa: F841
|
||||||
denormalize_name_mock.assert_called_with(ANY, "col2")
|
denormalize_name_mock.assert_called_with(ANY, "col2")
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||||
|
def test_get_column_values_with_rls(self):
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
table = self.get_virtual_dataset()
|
||||||
|
with patch.object(
|
||||||
|
table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'b'")]
|
||||||
|
):
|
||||||
|
rv = self.client.get(
|
||||||
|
f"api/v1/datasource/table/{table.id}/column/col2/values/"
|
||||||
|
)
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(response["result"], ["b"])
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||||
|
def test_get_column_values_with_rls_no_values(self):
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
table = self.get_virtual_dataset()
|
||||||
|
with patch.object(
|
||||||
|
table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'q'")]
|
||||||
|
):
|
||||||
|
rv = self.client.get(
|
||||||
|
f"api/v1/datasource/table/{table.id}/column/col2/values/"
|
||||||
|
)
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(response["result"], [])
|
||||||
|
|
|
||||||
|
|
@ -626,6 +626,32 @@ def test_values_for_column_on_text_column(text_column_table):
|
||||||
assert len(with_null) == 8
|
assert len(with_null) == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_values_for_column_on_text_column_with_rls(text_column_table):
|
||||||
|
with patch.object(
|
||||||
|
text_column_table,
|
||||||
|
"get_sqla_row_level_filters",
|
||||||
|
return_value=[
|
||||||
|
TextClause("foo = 'foo'"),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
|
||||||
|
assert with_rls == ["foo"]
|
||||||
|
assert len(with_rls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_values_for_column_on_text_column_with_rls_no_values(text_column_table):
|
||||||
|
with patch.object(
|
||||||
|
text_column_table,
|
||||||
|
"get_sqla_row_level_filters",
|
||||||
|
return_value=[
|
||||||
|
TextClause("foo = 'bar'"),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
|
||||||
|
assert with_rls == []
|
||||||
|
assert len(with_rls) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_filter_on_text_column(text_column_table):
|
def test_filter_on_text_column(text_column_table):
|
||||||
table = text_column_table
|
table = text_column_table
|
||||||
# null value should be replaced
|
# null value should be replaced
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
@ -85,6 +86,58 @@ def test_values_for_column(database: Database) -> None:
|
||||||
assert table.values_for_column("a") == [1, None]
|
assert table.values_for_column("a") == [1, None]
|
||||||
|
|
||||||
|
|
||||||
|
def test_values_for_column_with_rls(database: Database) -> None:
|
||||||
|
"""
|
||||||
|
Test the `values_for_column` method with RLS enabled.
|
||||||
|
"""
|
||||||
|
from sqlalchemy.sql.elements import TextClause
|
||||||
|
|
||||||
|
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||||
|
|
||||||
|
table = SqlaTable(
|
||||||
|
database=database,
|
||||||
|
schema=None,
|
||||||
|
table_name="t",
|
||||||
|
columns=[
|
||||||
|
TableColumn(column_name="a"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
with patch.object(
|
||||||
|
table,
|
||||||
|
"get_sqla_row_level_filters",
|
||||||
|
return_value=[
|
||||||
|
TextClause("a = 1"),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
assert table.values_for_column("a") == [1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_values_for_column_with_rls_no_values(database: Database) -> None:
|
||||||
|
"""
|
||||||
|
Test the `values_for_column` method with RLS enabled and no values.
|
||||||
|
"""
|
||||||
|
from sqlalchemy.sql.elements import TextClause
|
||||||
|
|
||||||
|
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||||
|
|
||||||
|
table = SqlaTable(
|
||||||
|
database=database,
|
||||||
|
schema=None,
|
||||||
|
table_name="t",
|
||||||
|
columns=[
|
||||||
|
TableColumn(column_name="a"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
with patch.object(
|
||||||
|
table,
|
||||||
|
"get_sqla_row_level_filters",
|
||||||
|
return_value=[
|
||||||
|
TextClause("a = 2"),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
assert table.values_for_column("a") == []
|
||||||
|
|
||||||
|
|
||||||
def test_values_for_column_calculated(
|
def test_values_for_column_calculated(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
database: Database,
|
database: Database,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue