fix: adhoc metrics (#30202)
This commit is contained in:
parent
ef0ede7c13
commit
0db59b45b8
|
|
@ -1533,6 +1533,7 @@ class SqlaTable(
|
|||
expression = self._process_sql_expression(
|
||||
expression=metric["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
@ -1566,6 +1567,7 @@ class SqlaTable(
|
|||
expression = self._process_sql_expression(
|
||||
expression=col["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ from superset.exceptions import (
|
|||
ColumnNotFoundException,
|
||||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
SupersetParseError,
|
||||
SupersetSecurityException,
|
||||
)
|
||||
from superset.extensions import feature_flag_manager
|
||||
|
|
@ -112,6 +113,7 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
|
|||
def validate_adhoc_subquery(
|
||||
sql: str,
|
||||
database_id: int,
|
||||
engine: str,
|
||||
default_schema: str,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -126,7 +128,12 @@ def validate_adhoc_subquery(
|
|||
"""
|
||||
statements = []
|
||||
for statement in sqlparse.parse(sql):
|
||||
if has_table_query(statement):
|
||||
try:
|
||||
has_table = has_table_query(str(statement), engine)
|
||||
except SupersetParseError:
|
||||
has_table = True
|
||||
|
||||
if has_table:
|
||||
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
|
|
@ -135,7 +142,9 @@ def validate_adhoc_subquery(
|
|||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
# TODO (betodealmeida): reimplement with sqlglot
|
||||
statement = insert_rls_in_predicate(statement, database_id, default_schema)
|
||||
|
||||
statements.append(statement)
|
||||
|
||||
return ";\n".join(str(statement) for statement in statements)
|
||||
|
|
@ -810,10 +819,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
# for datasources of type query
|
||||
return []
|
||||
|
||||
def _process_sql_expression(
|
||||
def _process_sql_expression( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
expression: Optional[str],
|
||||
database_id: int,
|
||||
engine: str,
|
||||
schema: str,
|
||||
template_processor: Optional[BaseTemplateProcessor],
|
||||
) -> Optional[str]:
|
||||
|
|
@ -823,6 +833,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
expression = validate_adhoc_subquery(
|
||||
expression,
|
||||
database_id,
|
||||
engine,
|
||||
schema,
|
||||
)
|
||||
try:
|
||||
|
|
@ -1108,6 +1119,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
expression = self._process_sql_expression(
|
||||
expression=metric["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
@ -1551,6 +1563,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
col["sqlExpression"] = self._process_sql_expression(
|
||||
expression=col["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
@ -1613,6 +1626,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
selected = validate_adhoc_subquery(
|
||||
selected,
|
||||
self.database_id,
|
||||
self.database.backend,
|
||||
self.schema,
|
||||
)
|
||||
outer = literal_column(f"({selected})")
|
||||
|
|
@ -1639,6 +1653,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
selected = validate_adhoc_subquery(
|
||||
_sql,
|
||||
self.database_id,
|
||||
self.database.backend,
|
||||
self.schema,
|
||||
)
|
||||
|
||||
|
|
@ -1915,6 +1930,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
where = self._process_sql_expression(
|
||||
expression=where,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
@ -1933,6 +1949,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
having = self._process_sql_expression(
|
||||
expression=having,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ class Query(
|
|||
expression = self._process_sql_expression(
|
||||
expression=col["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ from superset.sql.parse import (
|
|||
extract_tables_from_statement,
|
||||
SQLGLOT_DIALECTS,
|
||||
SQLScript,
|
||||
SQLStatement,
|
||||
Table,
|
||||
)
|
||||
from superset.utils.backports import StrEnum
|
||||
|
|
@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
|
|||
FOUND_TABLE = "FOUND_TABLE"
|
||||
|
||||
|
||||
def has_table_query(token_list: TokenList) -> bool:
|
||||
def has_table_query(expression: str, engine: str) -> bool:
|
||||
"""
|
||||
Return if a statement has a query reading from a table.
|
||||
|
||||
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
|
||||
>>> has_table_query("COUNT(*)", "postgresql")
|
||||
False
|
||||
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
|
||||
>>> has_table_query("SELECT * FROM table", "postgresql")
|
||||
True
|
||||
|
||||
Note that queries reading from constant values return false:
|
||||
|
||||
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
|
||||
>>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
|
||||
False
|
||||
|
||||
"""
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Ignore comments
|
||||
if isinstance(token, sqlparse.sql.Comment):
|
||||
continue
|
||||
# Remove trailing semicolon.
|
||||
expression = expression.strip().rstrip(";")
|
||||
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList) and has_table_query(token):
|
||||
return True
|
||||
# Wrap the expression in parentheses if it's not already.
|
||||
if not expression.startswith("("):
|
||||
expression = f"({expression})"
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
state = InsertRLSState.SEEN_SOURCE
|
||||
|
||||
# Found identifier/keyword after FROM/JOIN
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
|
||||
):
|
||||
return True
|
||||
|
||||
# Found nothing, leaving source
|
||||
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
return False
|
||||
sql = f"SELECT {expression}"
|
||||
statement = SQLStatement(sql, engine)
|
||||
return any(statement.tables)
|
||||
|
||||
|
||||
def add_table_name(rls: TokenList, table: str) -> None:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from superset.utils.database import ( # noqa: F401
|
|||
get_main_database,
|
||||
)
|
||||
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
|
||||
from tests.integration_tests.conftest import with_feature_flags
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices, # noqa: F401
|
||||
|
|
@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
|
|||
assert "INCORRECT SQL" in rv.json.get("error")
|
||||
|
||||
|
||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
|
||||
|
|
@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
|
|||
assert rv.json["result"]["rowcount"] == 0
|
||||
|
||||
|
||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||
def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
|
||||
|
|
@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data
|
|||
assert rv.json["result"]["total_count"] == 2
|
||||
|
||||
|
||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||
def test_get_samples_with_multiple_filters(
|
||||
test_client, login_as_admin, physical_dataset
|
||||
):
|
||||
|
|
|
|||
|
|
@ -42,7 +42,11 @@ from superset.utils.core import (
|
|||
)
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.conftest import only_postgresql, only_sqlite
|
||||
from tests.integration_tests.conftest import (
|
||||
only_postgresql,
|
||||
only_sqlite,
|
||||
with_feature_flags,
|
||||
)
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices, # noqa: F401
|
||||
load_birth_names_data, # noqa: F401
|
||||
|
|
@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset):
|
|||
assert df["COL2 ALIAS"][0] == "a"
|
||||
|
||||
|
||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||
def test_special_chars_in_column_name(app_context, physical_dataset):
|
||||
qc = QueryContextFactory().create(
|
||||
datasource={
|
||||
|
|
|
|||
|
|
@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,expected",
|
||||
("engine", "sql", "expected"),
|
||||
[
|
||||
("SELECT * FROM table", True),
|
||||
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
|
||||
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
|
||||
("COUNT(*)", False),
|
||||
("SELECT a FROM (SELECT 1 AS a)", False),
|
||||
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
|
||||
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
|
||||
("SELECT * FROM other_table", True),
|
||||
("extract(HOUR from from_unixtime(hour_ts)", False),
|
||||
("(SELECT * FROM table)", True),
|
||||
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
|
||||
("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
|
||||
("postgresql", "SELECT * FROM table", True),
|
||||
("postgresql", "(SELECT * FROM table)", True),
|
||||
(
|
||||
"postgresql",
|
||||
"SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"postgresql",
|
||||
"(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)",
|
||||
True,
|
||||
),
|
||||
("postgresql", "COUNT(*)", False),
|
||||
("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
|
||||
("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
|
||||
(
|
||||
"postgresql",
|
||||
"SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
|
||||
False,
|
||||
),
|
||||
("postgresql", "SELECT * FROM other_table", True),
|
||||
("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
|
||||
(
|
||||
"postgresql",
|
||||
"(SELECT table_name FROM information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"postgresql",
|
||||
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"postgresql",
|
||||
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"postgresql",
|
||||
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"postgresql",
|
||||
"((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')",
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_has_table_query(sql: str, expected: bool) -> None:
|
||||
def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
|
||||
"""
|
||||
Test if a given statement queries a table.
|
||||
|
||||
This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
|
||||
row-level security.
|
||||
"""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
assert has_table_query(statement) == expected
|
||||
assert has_table_query(sql, engine) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
Loading…
Reference in New Issue