feat: safer insert RLS (#20323)
This commit is contained in:
parent
90e210892b
commit
2bd611916d
|
|
@ -68,7 +68,12 @@ from superset.exceptions import (
|
|||
)
|
||||
from superset.extensions import feature_flag_manager
|
||||
from superset.jinja_context import BaseTemplateProcessor
|
||||
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause
|
||||
from superset.sql_parse import (
|
||||
has_table_query,
|
||||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
)
|
||||
from superset.superset_typing import (
|
||||
AdhocMetric,
|
||||
Column as ColumnTyping,
|
||||
|
|
@ -128,7 +133,7 @@ def validate_adhoc_subquery(
|
|||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
statement = insert_rls(statement, database_id, default_schema)
|
||||
statement = insert_rls_in_predicate(statement, database_id, default_schema)
|
||||
statements.append(statement)
|
||||
|
||||
return ";\n".join(str(statement) for statement in statements)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,12 @@ from superset.extensions import celery_app
|
|||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
|
||||
from superset.sql_parse import (
|
||||
CtasMethod,
|
||||
insert_rls_as_subquery,
|
||||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
)
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.sqllab.utils import write_ipc_buffer
|
||||
from superset.utils.celery import session_scope
|
||||
|
|
@ -191,7 +196,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
|||
return handle_query_error(ex, query, session)
|
||||
|
||||
|
||||
def execute_sql_statement( # pylint: disable=too-many-arguments
|
||||
def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-locals
|
||||
sql_statement: str,
|
||||
query: Query,
|
||||
session: Session,
|
||||
|
|
@ -205,6 +210,16 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
|
|||
|
||||
parsed_query = ParsedQuery(sql_statement)
|
||||
if is_feature_enabled("RLS_IN_SQLLAB"):
|
||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
||||
# safer, but not supported in all databases.
|
||||
insert_rls = (
|
||||
insert_rls_as_subquery
|
||||
if database.db_engine_spec.allows_subqueries
|
||||
and database.db_engine_spec.allows_alias_in_select
|
||||
else insert_rls_in_predicate
|
||||
)
|
||||
|
||||
# Insert any applicable RLS predicates
|
||||
parsed_query = ParsedQuery(
|
||||
str(
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from sqlparse.tokens import (
|
|||
Punctuation,
|
||||
String,
|
||||
Whitespace,
|
||||
Wildcard,
|
||||
)
|
||||
from sqlparse.utils import imt
|
||||
|
||||
|
|
@ -660,18 +661,29 @@ def get_rls_for_table(
|
|||
return None
|
||||
|
||||
rls = sqlparse.parse(predicate)[0]
|
||||
add_table_name(rls, str(dataset))
|
||||
add_table_name(rls, table.table)
|
||||
|
||||
return rls
|
||||
|
||||
|
||||
def insert_rls(
|
||||
def insert_rls_as_subquery(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement in place, applying any associated RLS predicates.
|
||||
|
||||
The RLS predicate is applied as a subquery replacing the original table:
|
||||
|
||||
before: SELECT * FROM some_table WHERE 1=1
|
||||
after: SELECT * FROM (
|
||||
SELECT * FROM some_table WHERE some_table.id=42
|
||||
) AS some_table
|
||||
WHERE 1=1
|
||||
|
||||
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
||||
databases.
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
|
|
@ -679,7 +691,98 @@ def insert_rls(
|
|||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
|
||||
token_list.tokens[i] = insert_rls_as_subquery(
|
||||
token,
|
||||
database_id,
|
||||
default_schema,
|
||||
)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
state = InsertRLSState.SEEN_SOURCE
|
||||
|
||||
# Found identifier/keyword after FROM/JOIN, test for table
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, Identifier) or token.ttype == Keyword
|
||||
):
|
||||
rls = get_rls_for_table(token, database_id, default_schema)
|
||||
if rls:
|
||||
# replace table with subquery
|
||||
subquery_alias = (
|
||||
token.tokens[-1].value
|
||||
if isinstance(token, Identifier)
|
||||
else token.value
|
||||
)
|
||||
i = token_list.tokens.index(token)
|
||||
|
||||
# strip alias from table name
|
||||
if isinstance(token, Identifier) and token.has_alias():
|
||||
whitespace_index = token.token_next_by(t=Whitespace)[0]
|
||||
token.tokens = token.tokens[:whitespace_index]
|
||||
|
||||
token_list.tokens[i] = Identifier(
|
||||
[
|
||||
Parenthesis(
|
||||
[
|
||||
Token(Punctuation, "("),
|
||||
Token(DML, "SELECT"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Wildcard, "*"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "FROM"),
|
||||
Token(Whitespace, " "),
|
||||
token,
|
||||
Token(Whitespace, " "),
|
||||
Where(
|
||||
[
|
||||
Token(Keyword, "WHERE"),
|
||||
Token(Whitespace, " "),
|
||||
rls,
|
||||
]
|
||||
),
|
||||
Token(Punctuation, ")"),
|
||||
]
|
||||
),
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "AS"),
|
||||
Token(Whitespace, " "),
|
||||
Identifier([Token(Name, subquery_alias)]),
|
||||
]
|
||||
)
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# Found nothing, leaving source
|
||||
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
return token_list
|
||||
|
||||
|
||||
def insert_rls_in_predicate(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement in place, applying any associated RLS predicates.
|
||||
|
||||
The RLS predicate is ``AND``ed to any existing predicates:
|
||||
|
||||
before: SELECT * FROM some_table WHERE 1=1
|
||||
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
||||
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls_in_predicate(
|
||||
token,
|
||||
database_id,
|
||||
default_schema,
|
||||
)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ def test_execute_sql_statement_with_rls(
|
|||
cursor = mocker.MagicMock()
|
||||
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
|
||||
mocker.patch(
|
||||
"superset.sql_lab.insert_rls",
|
||||
"superset.sql_lab.insert_rls_as_subquery",
|
||||
return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0],
|
||||
)
|
||||
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
|
||||
|
|
@ -112,12 +112,12 @@ def test_execute_sql_statement_with_rls(
|
|||
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
||||
|
||||
|
||||
def test_sql_lab_insert_rls(
|
||||
def test_sql_lab_insert_rls_as_subquery(
|
||||
mocker: MockerFixture,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Integration test for `insert_rls`.
|
||||
Integration test for `insert_rls_as_subquery`.
|
||||
"""
|
||||
from flask_appbuilder.security.sqla.models import Role, User
|
||||
|
||||
|
|
@ -213,4 +213,7 @@ def test_sql_lab_insert_rls(
|
|||
| 2 | 8 |
|
||||
| 3 | 9 |""".strip()
|
||||
)
|
||||
assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
|
||||
assert (
|
||||
query.executed_sql
|
||||
== "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
|
||||
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -31,7 +31,8 @@ from superset.sql_parse import (
|
|||
extract_table_references,
|
||||
get_rls_for_table,
|
||||
has_table_query,
|
||||
insert_rls,
|
||||
insert_rls_as_subquery,
|
||||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
strip_comments_from_sql,
|
||||
|
|
@ -1318,6 +1319,184 @@ def test_has_table_query(sql: str, expected: bool) -> None:
|
|||
assert has_table_query(statement) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,table,rls,expected",
    [
        # Simplest case: a single table carrying an existing WHERE clause
        (
            "SELECT * FROM some_table WHERE 1=1",
            "some_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM some_table WHERE some_table.id=42) "
                "AS some_table WHERE 1=1"
            ),
        ),
        # "table" is tokenized by sqlparse as a keyword rather than an identifier;
        # unquoted reserved words used as table names must still be handled.
        (
            "SELECT * FROM table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table WHERE 1=1",
        ),
        # Queries that do not read from the RLS-protected table are left untouched
        (
            "SELECT * FROM table WHERE 1=1",
            "other_table",
            "id=42",
            "SELECT * FROM table WHERE 1=1",
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM other_table WHERE 1=1",
        ),
        # A table introduced through JOIN is wrapped as well
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN "
                "(SELECT * FROM other_table WHERE other_table.id=42) AS other_table "
                "ON table.id = other_table.id"
            ),
        ),
        # A table referenced inside a subquery
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            "other_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42"
                ") AS other_table)"
            ),
        ),
        # UNION: only the branch reading the protected table is rewritten
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table "
                "UNION ALL SELECT * FROM other_table"
            ),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table UNION ALL SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42) AS other_table"
            ),
        ),
        # Fully qualified names (schema.table) versus bare names (table): the match
        # is conservative and assumes the same schema, because the default schema
        # is not known here.
        (
            "SELECT * FROM schema.table_name",
            "table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM schema.table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE schema.table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table_name WHERE "
                "schema.table_name.id=42) AS table_name"
            ),
        ),
        # Aliased tables, both with and without the explicit AS keyword
        (
            "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
            "tbl_a",
            "id=42",
            (
                "SELECT a.*, b.* FROM "
                "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a "
                "INNER JOIN tbl_b AS b "
                "ON a.col = b.col"
            ),
        ),
        (
            "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
            "tbl_a",
            "id=42",
            (
                "SELECT a.*, b.* FROM "
                "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a "
                "INNER JOIN tbl_b b ON a.col = b.col"
            ),
        ),
    ],
)
def test_insert_rls_as_subquery(
    mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
    """
    Check that ``insert_rls_as_subquery`` rewrites ``sql`` into ``expected``.

    ``superset.sql_parse.get_rls_for_table`` is patched with a local stand-in
    that returns the parsed ``rls`` predicate (prefixed with ``table`` via
    ``add_table_name``) only when the candidate token names ``table``.
    """
    predicate = sqlparse.parse(rls)[0]
    add_table_name(predicate, table)

    # Name parts of the protected table, innermost (table name) first.
    wanted_parts = table.split(".")[::-1]

    # pylint: disable=unused-argument
    def fake_get_rls_for_table(
        candidate: Token,
        database_id: int,
        default_schema: str,
    ) -> Optional[TokenList]:
        """
        Hand back ``predicate`` when ``candidate`` refers to ``table``, else ``None``.
        """
        # Bare keyword tokens (e.g. an unquoted ``table``) are wrapped in an
        # Identifier so that ``ParsedQuery.get_table`` can process them.
        identifier = (
            candidate
            if isinstance(candidate, Identifier)
            else Identifier([Token(Name, candidate.value)])
        )

        parsed_table = ParsedQuery.get_table(identifier)
        if not parsed_table:
            return None

        if parsed_table.schema:
            qualified_name = f"{parsed_table.schema}.{parsed_table.table}"
        else:
            qualified_name = parsed_table.table

        # Compare from the table name outwards; ``zip`` stops at the shorter
        # name, so a missing schema on either side does not prevent a match.
        pairs = zip(reversed(qualified_name.split(".")), wanted_parts)
        if any(ours != theirs for ours, theirs in pairs):
            return None
        return predicate

    mocker.patch(
        "superset.sql_parse.get_rls_for_table",
        new=fake_get_rls_for_table,
    )

    statement = sqlparse.parse(sql)[0]
    rewritten = insert_rls_as_subquery(
        token_list=statement,
        database_id=1,
        default_schema="my_schema",
    )
    assert str(rewritten).strip() == expected.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,table,rls,expected",
|
||||
[
|
||||
|
|
@ -1492,7 +1671,7 @@ def test_has_table_query(sql: str, expected: bool) -> None:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_insert_rls(
|
||||
def test_insert_rls_in_predicate(
|
||||
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -1521,7 +1700,11 @@ def test_insert_rls(
|
|||
statement = sqlparse.parse(sql)[0]
|
||||
assert (
|
||||
str(
|
||||
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
|
||||
insert_rls_in_predicate(
|
||||
token_list=statement,
|
||||
database_id=1,
|
||||
default_schema="my_schema",
|
||||
)
|
||||
).strip()
|
||||
== expected.strip()
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue