fix(dataset): use sqlglot for DML check (#31024)

This commit is contained in:
Beto Dealmeida 2024-11-22 07:21:05 -05:00 committed by GitHub
parent ccce9abf57
commit 832fed1db5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 7 deletions

View File

@ -38,7 +38,8 @@ from superset.exceptions import (
) )
from superset.models.core import Database from superset.models.core import Database
from superset.result_set import SupersetResultSet from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery, Table from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -105,8 +106,8 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template( sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict dataset.sql, **dataset.template_params_dict
) )
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine) parsed_script = SQLScript(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query): if parsed_script.has_mutation():
raise SupersetSecurityException( raise SupersetSecurityException(
SupersetError( SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@ -114,8 +115,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
level=ErrorLevel.ERROR, level=ErrorLevel.ERROR,
) )
) )
statements = parsed_query.get_statements() if len(parsed_script.statements) > 1:
if len(statements) > 1:
raise SupersetSecurityException( raise SupersetSecurityException(
SupersetError( SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@ -127,7 +127,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
dataset.database, dataset.database,
dataset.catalog, dataset.catalog,
dataset.schema, dataset.schema,
statements[0], sql,
) )

View File

@ -1964,6 +1964,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
assert rv.status_code == 200 assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8")) data = json.loads(rv.data.decode("utf-8"))
data["result"].sort(key=lambda x: x["datasource_id"])
assert data["result"][0]["slice_name"] == "name0" assert data["result"][0]["slice_name"] == "name0"
assert data["result"][0]["datasource_id"] == 1 assert data["result"][0]["datasource_id"] == 1

View File

@ -15,9 +15,14 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from superset.connectors.sqla.utils import get_columns_description from superset.connectors.sqla.utils import (
get_columns_description,
get_virtual_table_metadata,
)
from superset.exceptions import SupersetSecurityException
# Returns column descriptions when given valid database, catalog, schema, and query # Returns column descriptions when given valid database, catalog, schema, and query
@ -89,3 +94,46 @@ def test_returns_column_descriptions(mocker: MockerFixture) -> None:
"is_dttm": False, "is_dttm": False,
}, },
] ]
def test_get_virtual_table_metadata(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function.
"""
mocker.patch(
"superset.connectors.sqla.utils.get_columns_description",
return_value=[{"name": "one", "type": "INTEGER"}],
)
dataset = mocker.MagicMock(
sql="with source as ( select 1 as one ) select * from source",
)
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql
assert get_virtual_table_metadata(dataset) == [{"name": "one", "type": "INTEGER"}]
def test_get_virtual_table_metadata_mutating(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function with mutating SQL.
"""
dataset = mocker.MagicMock(sql="DROP TABLE sample_data")
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql
with pytest.raises(SupersetSecurityException) as excinfo:
get_virtual_table_metadata(dataset)
assert str(excinfo.value) == "Only `SELECT` statements are allowed"
def test_get_virtual_table_metadata_multiple(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function with multiple statements.
"""
dataset = mocker.MagicMock(sql="SELECT 1; SELECT 2")
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql
with pytest.raises(SupersetSecurityException) as excinfo:
get_virtual_table_metadata(dataset)
assert str(excinfo.value) == "Only single queries supported"