fix: CTE queries with non-SELECT statements (#25014)
This commit is contained in:
parent
6b660c86a4
commit
357986103b
|
|
@ -217,9 +217,53 @@ class ParsedQuery:
|
|||
def limit(self) -> Optional[int]:
|
||||
return self._limit
|
||||
|
||||
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
if "with" not in parsed:
|
||||
return []
|
||||
return parsed["with"].get("cte_tables", [])
|
||||
|
||||
def _check_cte_is_select(self, oxide_parse: list[dict[str, Any]]) -> bool:
|
||||
"""
|
||||
Check if a oxide parsed CTE contains only SELECT statements
|
||||
|
||||
:param oxide_parse: parsed CTE
|
||||
:return: True if CTE is a SELECT statement
|
||||
"""
|
||||
for query in oxide_parse:
|
||||
parsed_query = query["Query"]
|
||||
cte_tables = self._get_cte_tables(parsed_query)
|
||||
for cte_table in cte_tables:
|
||||
is_select = all(
|
||||
key == "Select" for key in cte_table["query"]["body"].keys()
|
||||
)
|
||||
if not is_select:
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_select(self) -> bool:
|
||||
# make sure we strip comments; prevents a bug with comments in the CTE
|
||||
parsed = sqlparse.parse(self.strip_comments())
|
||||
|
||||
# Check if this is a CTE
|
||||
if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
|
||||
if sqloxide_parse is not None:
|
||||
try:
|
||||
if not self._check_cte_is_select(
|
||||
sqloxide_parse(self.strip_comments(), dialect="ansi")
|
||||
):
|
||||
return False
|
||||
except ValueError:
|
||||
# sqloxide was not able to parse the query, so let's continue with
|
||||
# sqlparse
|
||||
pass
|
||||
inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
|
||||
# Check if the inner CTE is a not a SELECT
|
||||
if any(token.ttype == DDL for token in inner_cte) or any(
|
||||
token.ttype == DML and token.normalized != "SELECT"
|
||||
for token in inner_cte
|
||||
):
|
||||
return False
|
||||
|
||||
if parsed[0].get_type() == "SELECT":
|
||||
return True
|
||||
|
||||
|
|
@ -241,6 +285,17 @@ class ParsedQuery:
|
|||
token.ttype == DML and token.normalized == "SELECT" for token in parsed[0]
|
||||
)
|
||||
|
||||
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
|
||||
for token in tokens:
|
||||
if self._is_identifier(token):
|
||||
for identifier_token in token.tokens:
|
||||
if (
|
||||
isinstance(identifier_token, Parenthesis)
|
||||
and identifier_token.is_group
|
||||
):
|
||||
return identifier_token.tokens
|
||||
return None
|
||||
|
||||
def is_valid_ctas(self) -> bool:
|
||||
parsed = sqlparse.parse(self.strip_comments())
|
||||
return parsed[-1].get_type() == "SELECT"
|
||||
|
|
|
|||
|
|
@ -1029,6 +1029,87 @@ FROM foo f"""
|
|||
assert sql.is_select()
|
||||
|
||||
|
||||
def test_cte_insert_is_not_select() -> None:
|
||||
"""
|
||||
Some CTEs with lowercase select are not correctly identified as SELECTS.
|
||||
"""
|
||||
sql = ParsedQuery(
|
||||
"""WITH foo AS(
|
||||
INSERT INTO foo (id) VALUES (1) RETURNING 1
|
||||
) select * FROM foo f"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
|
||||
|
||||
def test_cte_delete_is_not_select() -> None:
|
||||
"""
|
||||
Some CTEs with lowercase select are not correctly identified as SELECTS.
|
||||
"""
|
||||
sql = ParsedQuery(
|
||||
"""WITH foo AS(
|
||||
DELETE FROM foo RETURNING *
|
||||
) select * FROM foo f"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
|
||||
|
||||
def test_cte_is_not_select_lowercase() -> None:
|
||||
"""
|
||||
Some CTEs with lowercase select are not correctly identified as SELECTS.
|
||||
"""
|
||||
sql = ParsedQuery(
|
||||
"""WITH foo AS(
|
||||
insert into foo (id) values (1) RETURNING 1
|
||||
) select * FROM foo f"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
|
||||
|
||||
def test_cte_with_multiple_selects() -> None:
|
||||
sql = ParsedQuery(
|
||||
"WITH a AS ( select * from foo1 ), b as (select * from foo2) SELECT * FROM a;"
|
||||
)
|
||||
assert sql.is_select()
|
||||
|
||||
|
||||
def test_cte_with_multiple_with_non_select() -> None:
|
||||
sql = ParsedQuery(
|
||||
"""WITH a AS (
|
||||
select * from foo1
|
||||
), b as (
|
||||
update foo2 set id=2
|
||||
) SELECT * FROM a"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
sql = ParsedQuery(
|
||||
"""WITH a AS (
|
||||
update foo2 set name=2
|
||||
),
|
||||
b as (
|
||||
select * from foo1
|
||||
) SELECT * FROM a"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
sql = ParsedQuery(
|
||||
"""WITH a AS (
|
||||
update foo2 set name=2
|
||||
),
|
||||
b as (
|
||||
update foo1 set name=2
|
||||
) SELECT * FROM a"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
sql = ParsedQuery(
|
||||
"""WITH a AS (
|
||||
INSERT INTO foo (id) VALUES (1)
|
||||
),
|
||||
b as (
|
||||
select 1
|
||||
) SELECT * FROM a"""
|
||||
)
|
||||
assert sql.is_select() is False
|
||||
|
||||
|
||||
def test_unknown_select() -> None:
|
||||
"""
|
||||
Test that `is_select` works when sqlparse fails to identify the type.
|
||||
|
|
|
|||
Loading…
Reference in New Issue