fix: adhoc metrics (#30202)

This commit is contained in:
Beto Dealmeida 2024-10-10 16:46:17 -04:00 committed by GitHub
parent ef0ede7c13
commit 0db59b45b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 80 additions and 45 deletions

View File

@@ -1533,6 +1533,7 @@ class SqlaTable(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1566,6 +1567,7 @@ class SqlaTable(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)

View File

@@ -63,6 +63,7 @@ from superset.exceptions import (
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
@@ -112,6 +113,7 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
def validate_adhoc_subquery(
sql: str,
database_id: int,
engine: str,
default_schema: str,
) -> str:
"""
@@ -126,7 +128,12 @@
"""
statements = []
for statement in sqlparse.parse(sql):
if has_table_query(statement):
try:
has_table = has_table_query(str(statement), engine)
except SupersetParseError:
has_table = True
if has_table:
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
@@ -135,7 +142,9 @@
level=ErrorLevel.ERROR,
)
)
# TODO (betodealmeida): reimplement with sqlglot
statement = insert_rls_in_predicate(statement, database_id, default_schema)
statements.append(statement)
return ";\n".join(str(statement) for statement in statements)
@@ -810,10 +819,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# for datasources of type query
return []
def _process_sql_expression(
def _process_sql_expression( # pylint: disable=too-many-arguments
self,
expression: Optional[str],
database_id: int,
engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
@@ -823,6 +833,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
expression = validate_adhoc_subquery(
expression,
database_id,
engine,
schema,
)
try:
@@ -1108,6 +1119,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1551,6 +1563,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col["sqlExpression"] = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1613,6 +1626,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.database.backend,
self.schema,
)
outer = literal_column(f"({selected})")
@@ -1639,6 +1653,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
selected = validate_adhoc_subquery(
_sql,
self.database_id,
self.database.backend,
self.schema,
)
@@ -1915,6 +1930,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
where = self._process_sql_expression(
expression=where,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1933,6 +1949,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
having = self._process_sql_expression(
expression=having,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)

View File

@@ -374,6 +374,7 @@ class Query(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)

View File

@@ -64,6 +64,7 @@ from superset.sql.parse import (
extract_tables_from_statement,
SQLGLOT_DIALECTS,
SQLScript,
SQLStatement,
Table,
)
from superset.utils.backports import StrEnum
@@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
FOUND_TABLE = "FOUND_TABLE"
def has_table_query(token_list: TokenList) -> bool:
def has_table_query(expression: str, engine: str) -> bool:
"""
Return if a statement has a query reading from a table.
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
>>> has_table_query("COUNT(*)", "postgresql")
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
>>> has_table_query("SELECT * FROM table", "postgresql")
True
Note that queries reading from constant values return false:
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
>>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
False
"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Ignore comments
if isinstance(token, sqlparse.sql.Comment):
continue
# Remove trailing semicolon.
expression = expression.strip().rstrip(";")
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True
# Wrap the expression in parentheses if it's not already.
if not expression.startswith("("):
expression = f"({expression})"
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE
# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True
# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING
return False
sql = f"SELECT {expression}"
statement = SQLStatement(sql, engine)
return any(statement.tables)
def add_table_name(rls: TokenList, table: str) -> None:

View File

@@ -42,6 +42,7 @@ from superset.utils.database import ( # noqa: F401
get_main_database,
)
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
@@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
assert "INCORRECT SQL" in rv.json.get("error")
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
assert rv.json["result"]["rowcount"] == 0
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data
assert rv.json["result"]["total_count"] == 2
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):

View File

@@ -42,7 +42,11 @@ from superset.utils.core import (
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.conftest import only_postgresql, only_sqlite
from tests.integration_tests.conftest import (
only_postgresql,
only_sqlite,
with_feature_flags,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
@@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset):
assert df["COL2 ALIAS"][0] == "a"
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_special_chars_in_column_name(app_context, physical_dataset):
qc = QueryContextFactory().create(
datasource={

View File

@@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():
@pytest.mark.parametrize(
"sql,expected",
("engine", "sql", "expected"),
[
("SELECT * FROM table", True),
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
("COUNT(*)", False),
("SELECT a FROM (SELECT 1 AS a)", False),
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
("(SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
("postgresql", "SELECT * FROM table", True),
("postgresql", "(SELECT * FROM table)", True),
(
"postgresql",
"SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
True,
),
(
"postgresql",
"(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)",
True,
),
("postgresql", "COUNT(*)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
(
"postgresql",
"SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
False,
),
("postgresql", "SELECT * FROM other_table", True),
("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
(
"postgresql",
"(SELECT table_name FROM information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"postgresql",
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
(
"postgresql",
"((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
"""
Test if a given statement queries a table.
This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
row-level security.
"""
statement = sqlparse.parse(sql)[0]
assert has_table_query(statement) == expected
assert has_table_query(sql, engine) == expected
@pytest.mark.parametrize(