fix(dataset): use sqlglot for DML check (#31024)

This commit is contained in:
Beto Dealmeida 2024-11-22 07:21:05 -05:00 committed by GitHub
parent ccce9abf57
commit 832fed1db5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 7 deletions

View File

@ -38,7 +38,8 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery, Table
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
@ -105,8 +106,8 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
parsed_script = SQLScript(sql, engine=db_engine_spec.engine)
if parsed_script.has_mutation():
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@ -114,8 +115,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
level=ErrorLevel.ERROR,
)
)
statements = parsed_query.get_statements()
if len(statements) > 1:
if len(parsed_script.statements) > 1:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@ -127,7 +127,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
dataset.database,
dataset.catalog,
dataset.schema,
statements[0],
sql,
)

View File

@ -1964,6 +1964,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
data["result"].sort(key=lambda x: x["datasource_id"])
assert data["result"][0]["slice_name"] == "name0"
assert data["result"][0]["datasource_id"] == 1

View File

@ -15,9 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import pytest
from pytest_mock import MockerFixture
from superset.connectors.sqla.utils import get_columns_description
from superset.connectors.sqla.utils import (
get_columns_description,
get_virtual_table_metadata,
)
from superset.exceptions import SupersetSecurityException
# Returns column descriptions when given valid database, catalog, schema, and query
@ -89,3 +94,46 @@ def test_returns_column_descriptions(mocker: MockerFixture) -> None:
"is_dttm": False,
},
]
def test_get_virtual_table_metadata(mocker: MockerFixture) -> None:
    """
    Check that `get_virtual_table_metadata` returns the column descriptions
    for a dataset whose SQL is a single `WITH ... SELECT` statement.
    """
    query = "with source as ( select 1 as one ) select * from source"

    # Replace the column lookup with a canned result.
    mocker.patch(
        "superset.connectors.sqla.utils.get_columns_description",
        return_value=[{"name": "one", "type": "INTEGER"}],
    )

    dataset = mocker.MagicMock(sql=query)
    dataset.database.db_engine_spec.engine = "postgresql"
    # The mocked template processor hands the SQL back unchanged.
    template_processor = dataset.get_template_processor()
    template_processor.process_template.return_value = query

    columns = get_virtual_table_metadata(dataset)

    assert columns == [{"name": "one", "type": "INTEGER"}]
def test_get_virtual_table_metadata_mutating(mocker: MockerFixture) -> None:
    """
    Check that `get_virtual_table_metadata` rejects SQL that mutates data.
    """
    mutating_sql = "DROP TABLE sample_data"

    dataset = mocker.MagicMock(sql=mutating_sql)
    dataset.database.db_engine_spec.engine = "postgresql"
    # The mocked template processor hands the SQL back unchanged.
    template_processor = dataset.get_template_processor()
    template_processor.process_template.return_value = mutating_sql

    with pytest.raises(SupersetSecurityException) as raised:
        get_virtual_table_metadata(dataset)

    message = str(raised.value)
    assert message == "Only `SELECT` statements are allowed"
def test_get_virtual_table_metadata_multiple(mocker: MockerFixture) -> None:
    """
    Check that `get_virtual_table_metadata` rejects SQL containing more than
    one statement.
    """
    two_statements = "SELECT 1; SELECT 2"

    dataset = mocker.MagicMock(sql=two_statements)
    dataset.database.db_engine_spec.engine = "postgresql"
    # The mocked template processor hands the SQL back unchanged.
    template_processor = dataset.get_template_processor()
    template_processor.process_template.return_value = two_statements

    with pytest.raises(SupersetSecurityException) as raised:
        get_virtual_table_metadata(dataset)

    message = str(raised.value)
    assert message == "Only single queries supported"