1073 lines
28 KiB
Python
1073 lines
28 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
|
|
|
|
|
import pytest
|
|
from sqlglot import Dialects
|
|
|
|
from superset.exceptions import SupersetParseError
|
|
from superset.sql.parse import (
|
|
extract_tables_from_statement,
|
|
KustoKQLStatement,
|
|
split_kql,
|
|
SQLGLOT_DIALECTS,
|
|
SQLScript,
|
|
SQLStatement,
|
|
Table,
|
|
)
|
|
|
|
|
|
def test_table() -> None:
|
|
"""
|
|
Test the `Table` class and its string conversion.
|
|
|
|
Special characters in the table, schema, or catalog name should be escaped correctly.
|
|
"""
|
|
assert str(Table("tbname")) == "tbname"
|
|
assert str(Table("tbname", "schemaname")) == "schemaname.tbname"
|
|
assert (
|
|
str(Table("tbname", "schemaname", "catalogname"))
|
|
== "catalogname.schemaname.tbname"
|
|
)
|
|
assert (
|
|
str(Table("table.name", "schema/name", "catalog\nname"))
|
|
== "catalog%0Aname.schema%2Fname.table%2Ename"
|
|
)
|
|
|
|
|
|
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
|
|
"""
|
|
Helper function to extract tables from SQL.
|
|
"""
|
|
dialect = SQLGLOT_DIALECTS.get(engine)
|
|
return {
|
|
table
|
|
for statement in SQLScript(sql, engine).statements
|
|
for table in extract_tables_from_statement(statement._parsed, dialect)
|
|
}
|
|
|
|
|
|
def test_extract_tables_from_sql() -> None:
|
|
"""
|
|
Test that referenced tables are parsed correctly from the SQL.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM tbname") == {Table("tbname")}
|
|
assert extract_tables_from_sql("SELECT * FROM tbname foo") == {Table("tbname")}
|
|
assert extract_tables_from_sql("SELECT * FROM tbname AS foo") == {Table("tbname")}
|
|
|
|
# underscore
|
|
assert extract_tables_from_sql("SELECT * FROM tb_name") == {Table("tb_name")}
|
|
|
|
# quotes
|
|
assert extract_tables_from_sql('SELECT * FROM "tbname"') == {Table("tbname")}
|
|
|
|
# unicode
|
|
assert extract_tables_from_sql('SELECT * FROM "tb_name" WHERE city = "Lübeck"') == {
|
|
Table("tb_name")
|
|
}
|
|
|
|
# columns
|
|
assert extract_tables_from_sql("SELECT field1, field2 FROM tb_name") == {
|
|
Table("tb_name")
|
|
}
|
|
assert extract_tables_from_sql("SELECT t1.f1, t2.f2 FROM t1, t2") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
|
|
# named table
|
|
assert extract_tables_from_sql(
|
|
"SELECT a.date, a.field FROM left_table a LIMIT 10"
|
|
) == {Table("left_table")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
|
|
) == {Table("forbidden_table")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"select * from (select * from forbidden_table) forbidden_table"
|
|
) == {Table("forbidden_table")}
|
|
|
|
|
|
def test_extract_tables_subselect() -> None:
|
|
"""
|
|
Test that tables inside subselects are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT sub.*
|
|
FROM (
|
|
SELECT *
|
|
FROM s1.t1
|
|
WHERE day_of_week = 'Friday'
|
|
) sub, s2.t2
|
|
WHERE sub.resolution = 'NONE'
|
|
"""
|
|
) == {Table("t1", "s1"), Table("t2", "s2")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT sub.*
|
|
FROM (
|
|
SELECT *
|
|
FROM s1.t1
|
|
WHERE day_of_week = 'Friday'
|
|
) sub
|
|
WHERE sub.resolution = 'NONE'
|
|
"""
|
|
) == {Table("t1", "s1")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT * FROM t1
|
|
WHERE s11 > ANY (
|
|
SELECT COUNT(*) /* no hint */ FROM t2
|
|
WHERE NOT EXISTS (
|
|
SELECT * FROM t3
|
|
WHERE ROW(5*t2.s1,77)=(
|
|
SELECT 50,11*s1 FROM t4
|
|
)
|
|
)
|
|
)
|
|
"""
|
|
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
|
|
|
|
|
def test_extract_tables_select_in_expression() -> None:
|
|
"""
|
|
Test that parser works with `SELECT`s used as expressions.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT f1, (SELECT count(1) FROM t2) FROM t1") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
assert extract_tables_from_sql(
|
|
"SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1"
|
|
) == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
|
|
|
|
def test_extract_tables_parenthesis() -> None:
|
|
"""
|
|
Test that parenthesis are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1") == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_with_schema() -> None:
|
|
"""
|
|
Test that schemas are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM schemaname.tbname") == {
|
|
Table("tbname", "schemaname")
|
|
}
|
|
assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname"') == {
|
|
Table("tbname", "schemaname")
|
|
}
|
|
assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" foo') == {
|
|
Table("tbname", "schemaname")
|
|
}
|
|
assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" AS foo') == {
|
|
Table("tbname", "schemaname")
|
|
}
|
|
|
|
|
|
def test_extract_tables_union() -> None:
|
|
"""
|
|
Test that `UNION` queries work as expected.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM t1 UNION SELECT * FROM t2") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
assert extract_tables_from_sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
assert extract_tables_from_sql(
|
|
"SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
|
|
) == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
|
|
|
|
def test_extract_tables_select_from_values() -> None:
|
|
"""
|
|
Test that selecting from values returns no tables.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM VALUES (13, 42)") == set()
|
|
|
|
|
|
def test_extract_tables_select_array() -> None:
|
|
"""
|
|
Test that queries selecting arrays work as expected.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT ARRAY[1, 2, 3] AS my_array
|
|
FROM t1 LIMIT 10
|
|
"""
|
|
) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_select_if() -> None:
|
|
"""
|
|
Test that queries with an `IF` work as expected.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
|
FROM t1 LIMIT 10
|
|
"""
|
|
) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_with_catalog() -> None:
|
|
"""
|
|
Test that catalogs are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname") == {
|
|
Table("tbname", "schemaname", "catalogname")
|
|
}
|
|
|
|
|
|
def test_extract_tables_illdefined() -> None:
|
|
"""
|
|
Test that ill-defined tables return an empty set.
|
|
"""
|
|
with pytest.raises(SupersetParseError) as excinfo:
|
|
extract_tables_from_sql("SELECT * FROM schemaname.")
|
|
assert str(excinfo.value) == "Error parsing near '.' at line 1:25"
|
|
|
|
with pytest.raises(SupersetParseError) as excinfo:
|
|
extract_tables_from_sql("SELECT * FROM catalogname.schemaname.")
|
|
assert str(excinfo.value) == "Error parsing near '.' at line 1:37"
|
|
|
|
with pytest.raises(SupersetParseError) as excinfo:
|
|
extract_tables_from_sql("SELECT * FROM catalogname..")
|
|
assert str(excinfo.value) == "Error parsing near '.' at line 1:27"
|
|
|
|
with pytest.raises(SupersetParseError) as excinfo:
|
|
extract_tables_from_sql('SELECT * FROM "tbname')
|
|
assert str(excinfo.value) == "Unable to parse script"
|
|
|
|
# odd edge case that works
|
|
assert extract_tables_from_sql("SELECT * FROM catalogname..tbname") == {
|
|
Table(table="tbname", schema=None, catalog="catalogname")
|
|
}
|
|
|
|
|
|
def test_extract_tables_show_tables_from() -> None:
|
|
"""
|
|
Test `SHOW TABLES FROM`.
|
|
"""
|
|
assert (
|
|
extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
|
|
)
|
|
|
|
|
|
def test_format_show_tables() -> None:
|
|
"""
|
|
Test format when `ast.sql()` raises an exception.
|
|
|
|
In that case sqlparse should be used instead.
|
|
"""
|
|
assert (
|
|
SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
|
|
== "SHOW TABLES FROM s1 LIKE '%order%'"
|
|
)
|
|
|
|
|
|
def test_format_no_dialect() -> None:
|
|
"""
|
|
Test format with an engine that has no corresponding dialect.
|
|
"""
|
|
assert (
|
|
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
|
|
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
|
|
)
|
|
|
|
|
|
def test_split_no_dialect() -> None:
|
|
"""
|
|
Test the statement split when the engine has no corresponding dialect.
|
|
"""
|
|
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
|
|
statements = SQLScript(sql, "firebolt").statements
|
|
assert len(statements) == 3
|
|
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
|
|
assert statements[1]._sql == "SELECT * FROM t"
|
|
assert statements[2]._sql == "SELECT foo"
|
|
|
|
|
|
def test_extract_tables_show_columns_from() -> None:
|
|
"""
|
|
Test `SHOW COLUMNS FROM`.
|
|
"""
|
|
assert extract_tables_from_sql("SHOW COLUMNS FROM t1") == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_where_subquery() -> None:
|
|
"""
|
|
Test that tables in a `WHERE` subquery are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT name
|
|
FROM t1
|
|
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
|
"""
|
|
) == {Table("t1"), Table("t2")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT name
|
|
FROM t1
|
|
WHERE regionkey IN (SELECT regionkey FROM t2)
|
|
"""
|
|
) == {Table("t1"), Table("t2")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT name
|
|
FROM t1
|
|
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
|
"""
|
|
) == {Table("t1"), Table("t2")}
|
|
|
|
|
|
def test_extract_tables_describe() -> None:
|
|
"""
|
|
Test `DESCRIBE`.
|
|
"""
|
|
assert extract_tables_from_sql("DESCRIBE t1") == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_show_partitions() -> None:
|
|
"""
|
|
Test `SHOW PARTITIONS`.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SHOW PARTITIONS FROM orders
|
|
WHERE ds >= '2013-01-01' ORDER BY ds DESC
|
|
"""
|
|
) == {Table("orders")}
|
|
|
|
|
|
def test_extract_tables_join() -> None:
|
|
"""
|
|
Test joins.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
|
|
) == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT a.date, b.name
|
|
FROM left_table a
|
|
JOIN (
|
|
SELECT
|
|
CAST((b.year) as VARCHAR) date,
|
|
name
|
|
FROM right_table
|
|
) b
|
|
ON a.date = b.date
|
|
"""
|
|
) == {Table("left_table"), Table("right_table")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT a.date, b.name
|
|
FROM left_table a
|
|
LEFT INNER JOIN (
|
|
SELECT
|
|
CAST((b.year) as VARCHAR) date,
|
|
name
|
|
FROM right_table
|
|
) b
|
|
ON a.date = b.date
|
|
"""
|
|
) == {Table("left_table"), Table("right_table")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT a.date, b.name
|
|
FROM left_table a
|
|
RIGHT OUTER JOIN (
|
|
SELECT
|
|
CAST((b.year) as VARCHAR) date,
|
|
name
|
|
FROM right_table
|
|
) b
|
|
ON a.date = b.date
|
|
"""
|
|
) == {Table("left_table"), Table("right_table")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT a.date, b.name
|
|
FROM left_table a
|
|
FULL OUTER JOIN (
|
|
SELECT
|
|
CAST((b.year) as VARCHAR) date,
|
|
name
|
|
FROM right_table
|
|
) b
|
|
ON a.date = b.date
|
|
"""
|
|
) == {Table("left_table"), Table("right_table")}
|
|
|
|
|
|
def test_extract_tables_semi_join() -> None:
|
|
"""
|
|
Test `LEFT SEMI JOIN`.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT a.date, b.name
|
|
FROM left_table a
|
|
LEFT SEMI JOIN (
|
|
SELECT
|
|
CAST((b.year) as VARCHAR) date,
|
|
name
|
|
FROM right_table
|
|
) b
|
|
ON a.data = b.date
|
|
"""
|
|
) == {Table("left_table"), Table("right_table")}
|
|
|
|
|
|
def test_extract_tables_combinations() -> None:
|
|
"""
|
|
Test a complex case with nested queries.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT * FROM t1
|
|
WHERE s11 > ANY (
|
|
SELECT * FROM t1 UNION ALL SELECT * FROM (
|
|
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
|
|
) tmp_join
|
|
WHERE NOT EXISTS (
|
|
SELECT * FROM t3
|
|
WHERE ROW(5*t3.s1,77)=(
|
|
SELECT 50,11*s1 FROM t4
|
|
)
|
|
)
|
|
)
|
|
"""
|
|
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT * FROM (
|
|
SELECT * FROM (
|
|
SELECT * FROM (
|
|
SELECT * FROM EmployeeS
|
|
) AS S1
|
|
) AS S2
|
|
) AS S3
|
|
"""
|
|
) == {Table("EmployeeS")}
|
|
|
|
|
|
def test_extract_tables_with() -> None:
|
|
"""
|
|
Test `WITH`.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
WITH
|
|
x AS (SELECT a FROM t1),
|
|
y AS (SELECT a AS b FROM t2),
|
|
z AS (SELECT b AS c FROM t3)
|
|
SELECT c FROM z
|
|
"""
|
|
) == {Table("t1"), Table("t2"), Table("t3")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
WITH
|
|
x AS (SELECT a FROM t1),
|
|
y AS (SELECT a AS b FROM x),
|
|
z AS (SELECT b AS c FROM y)
|
|
SELECT c FROM z
|
|
"""
|
|
) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_reusing_aliases() -> None:
|
|
"""
|
|
Test that the parser follows aliases.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
with q1 as ( select key from q2 where key = '5'),
|
|
q2 as ( select key from src where key = '5')
|
|
select * from (select key from q1) a
|
|
"""
|
|
) == {Table("src")}
|
|
|
|
# weird query with circular dependency
|
|
assert (
|
|
extract_tables_from_sql(
|
|
"""
|
|
with src as ( select key from q2 where key = '5'),
|
|
q2 as ( select key from src where key = '5')
|
|
select * from (select key from src) a
|
|
"""
|
|
)
|
|
== set()
|
|
)
|
|
|
|
|
|
def test_extract_tables_multistatement() -> None:
|
|
"""
|
|
Test that the parser works with multiple statements.
|
|
"""
|
|
assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2;") == {
|
|
Table("t1"),
|
|
Table("t2"),
|
|
}
|
|
assert extract_tables_from_sql(
|
|
"ADD JAR file:///hive.jar; SELECT * FROM t1;",
|
|
engine="hive",
|
|
) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_complex() -> None:
|
|
"""
|
|
Test a few complex queries.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT sum(m_examples) AS "sum__m_example"
|
|
FROM (
|
|
SELECT
|
|
COUNT(DISTINCT id_userid) AS m_examples,
|
|
some_more_info
|
|
FROM my_b_table b
|
|
JOIN my_t_table t ON b.ds=t.ds
|
|
JOIN my_l_table l ON b.uid=l.uid
|
|
WHERE
|
|
b.rid IN (
|
|
SELECT other_col
|
|
FROM inner_table
|
|
)
|
|
AND l.bla IN ('x', 'y')
|
|
GROUP BY 2
|
|
ORDER BY 2 ASC
|
|
) AS "meh"
|
|
ORDER BY "sum__m_example" DESC
|
|
LIMIT 10;
|
|
"""
|
|
) == {
|
|
Table("my_l_table"),
|
|
Table("my_b_table"),
|
|
Table("my_t_table"),
|
|
Table("inner_table"),
|
|
}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT *
|
|
FROM table_a AS a, table_b AS b, table_c as c
|
|
WHERE a.id = b.id and b.id = c.id
|
|
"""
|
|
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT somecol AS somecol
|
|
FROM (
|
|
WITH bla AS (
|
|
SELECT col_a
|
|
FROM a
|
|
WHERE
|
|
1=1
|
|
AND column_of_choice NOT IN (
|
|
SELECT interesting_col
|
|
FROM b
|
|
)
|
|
),
|
|
rb AS (
|
|
SELECT yet_another_column
|
|
FROM (
|
|
SELECT a
|
|
FROM c
|
|
GROUP BY the_other_col
|
|
) not_table
|
|
LEFT JOIN bla foo
|
|
ON foo.prop = not_table.bad_col0
|
|
WHERE 1=1
|
|
GROUP BY
|
|
not_table.bad_col1 ,
|
|
not_table.bad_col2 ,
|
|
ORDER BY not_table.bad_col_3 DESC ,
|
|
not_table.bad_col4 ,
|
|
not_table.bad_col5
|
|
)
|
|
SELECT random_col
|
|
FROM d
|
|
WHERE 1=1
|
|
UNION ALL SELECT even_more_cols
|
|
FROM e
|
|
WHERE 1=1
|
|
UNION ALL SELECT lets_go_deeper
|
|
FROM f
|
|
WHERE 1=1
|
|
WHERE 2=2
|
|
GROUP BY last_col
|
|
LIMIT 50000
|
|
)
|
|
"""
|
|
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
|
|
|
|
|
def test_extract_tables_mixed_from_clause() -> None:
|
|
"""
|
|
Test that the parser handles a `FROM` clause with table and subselect.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
SELECT *
|
|
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
|
WHERE a.id = b.id and b.id = c.id
|
|
"""
|
|
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
|
|
|
|
|
def test_extract_tables_nested_select() -> None:
|
|
"""
|
|
Test that the parser handles selects inside functions.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
|
from INFORMATION_SCHEMA.COLUMNS
|
|
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
|
""",
|
|
"mysql",
|
|
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
|
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
|
from INFORMATION_SCHEMA.COLUMNS
|
|
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
|
""",
|
|
"mysql",
|
|
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
|
|
|
|
|
def test_extract_tables_complex_cte_with_prefix() -> None:
|
|
"""
|
|
Test that the parser handles CTEs with prefixes.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
|
|
AS (
|
|
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
|
|
FROM SalesOrderHeader
|
|
WHERE SalesPersonID IS NOT NULL
|
|
)
|
|
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
|
|
FROM CTE__test
|
|
GROUP BY SalesYear, SalesPersonID
|
|
ORDER BY SalesPersonID, SalesYear;
|
|
"""
|
|
) == {Table("SalesOrderHeader")}
|
|
|
|
|
|
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
|
|
"""
|
|
Test that aliases that are keywords are parsed correctly.
|
|
"""
|
|
assert extract_tables_from_sql(
|
|
"""
|
|
WITH
|
|
f AS (SELECT * FROM foo),
|
|
match AS (SELECT * FROM f)
|
|
SELECT * FROM match
|
|
"""
|
|
) == {Table("foo")}
|
|
|
|
|
|
def test_sqlscript() -> None:
|
|
"""
|
|
Test the `SQLScript` class.
|
|
"""
|
|
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
|
|
|
assert len(script.statements) == 2
|
|
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
|
assert script.statements[0].format() == "SELECT\n 1"
|
|
|
|
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
|
assert script.get_settings() == {"a": "2"}
|
|
|
|
query = SQLScript(
|
|
"""set querytrace;
|
|
Events | take 100""",
|
|
"kustokql",
|
|
)
|
|
assert query.get_settings() == {"querytrace": True}
|
|
|
|
|
|
def test_sqlstatement() -> None:
|
|
"""
|
|
Test the `SQLStatement` class.
|
|
"""
|
|
statement = SQLStatement(
|
|
"SELECT * FROM table1 UNION ALL SELECT * FROM table2",
|
|
"sqlite",
|
|
)
|
|
|
|
assert statement.tables == {
|
|
Table(table="table1", schema=None, catalog=None),
|
|
Table(table="table2", schema=None, catalog=None),
|
|
}
|
|
assert (
|
|
statement.format()
|
|
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
|
|
)
|
|
|
|
statement = SQLStatement("SET a=1", "sqlite")
|
|
assert statement.get_settings() == {"a": "1"}
|
|
|
|
|
|
def test_kustokqlstatement_split_script() -> None:
|
|
"""
|
|
Test the `KustoKQLStatement` split method.
|
|
"""
|
|
statements = KustoKQLStatement.split_script(
|
|
"""
|
|
let totalPagesPerDay = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp)
|
|
| summarize count() by Day;
|
|
let materializedScope = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp);
|
|
let cachedResult = materialize(materializedScope);
|
|
cachedResult
|
|
| project Page, Day1 = Day
|
|
| join kind = inner
|
|
(
|
|
cachedResult
|
|
| project Page, Day2 = Day
|
|
)
|
|
on Page
|
|
| where Day2 > Day1
|
|
| summarize count() by Day1, Day2
|
|
| join kind = inner
|
|
totalPagesPerDay
|
|
on $left.Day1 == $right.Day
|
|
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
""",
|
|
"kustokql",
|
|
)
|
|
assert len(statements) == 4
|
|
|
|
|
|
def test_kustokqlstatement_with_program() -> None:
|
|
"""
|
|
Test the `KustoKQLStatement` split method when the KQL has a program.
|
|
"""
|
|
statements = KustoKQLStatement.split_script(
|
|
"""
|
|
print program = ```
|
|
public class Program {
|
|
public static void Main() {
|
|
System.Console.WriteLine("Hello!");
|
|
}
|
|
}```
|
|
""",
|
|
"kustokql",
|
|
)
|
|
assert len(statements) == 1
|
|
|
|
|
|
def test_kustokqlstatement_with_set() -> None:
|
|
"""
|
|
Test the `KustoKQLStatement` split method when the KQL has a set command.
|
|
"""
|
|
statements = KustoKQLStatement.split_script(
|
|
"""
|
|
set querytrace;
|
|
Events | take 100
|
|
""",
|
|
"kustokql",
|
|
)
|
|
assert len(statements) == 2
|
|
assert statements[0].format() == "set querytrace"
|
|
assert statements[1].format() == "Events | take 100"
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"kql,statements",
|
|
[
|
|
('print banner=strcat("Hello", ", ", "World!")', 1),
|
|
(r"print 'O\'Malley\'s'", 1),
|
|
(r"print 'O\'Mal;ley\'s'", 1),
|
|
("print ```foo;\nbar;\nbaz;```\n", 1),
|
|
],
|
|
)
|
|
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
|
|
assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements
|
|
|
|
|
|
def test_split_kql() -> None:
|
|
"""
|
|
Test the `split_kql` function.
|
|
"""
|
|
kql = """
|
|
let totalPagesPerDay = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp)
|
|
| summarize count() by Day;
|
|
let materializedScope = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp);
|
|
let cachedResult = materialize(materializedScope);
|
|
cachedResult
|
|
| project Page, Day1 = Day
|
|
| join kind = inner
|
|
(
|
|
cachedResult
|
|
| project Page, Day2 = Day
|
|
)
|
|
on Page
|
|
| where Day2 > Day1
|
|
| summarize count() by Day1, Day2
|
|
| join kind = inner
|
|
totalPagesPerDay
|
|
on $left.Day1 == $right.Day
|
|
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
"""
|
|
assert split_kql(kql) == [
|
|
"""
|
|
let totalPagesPerDay = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp)
|
|
| summarize count() by Day""",
|
|
"""
|
|
let materializedScope = PageViews
|
|
| summarize by Page, Day = startofday(Timestamp)""",
|
|
"""
|
|
let cachedResult = materialize(materializedScope)""",
|
|
"""
|
|
cachedResult
|
|
| project Page, Day1 = Day
|
|
| join kind = inner
|
|
(
|
|
cachedResult
|
|
| project Page, Day2 = Day
|
|
)
|
|
on Page
|
|
| where Day2 > Day1
|
|
| summarize count() by Day1, Day2
|
|
| join kind = inner
|
|
totalPagesPerDay
|
|
on $left.Day1 == $right.Day
|
|
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
""",
|
|
]
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("engine", "sql", "expected"),
|
|
[
|
|
("sqlite", "SELECT 1", False),
|
|
("sqlite", "INSERT INTO foo VALUES (1)", True),
|
|
("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
|
|
("sqlite", "DELETE FROM foo WHERE id = 1", True),
|
|
("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
|
|
("sqlite", "DROP TABLE foo", True),
|
|
("sqlite", "EXPLAIN SELECT * FROM foo", False),
|
|
("sqlite", "PRAGMA table_info(foo)", False),
|
|
("postgresql", "SELECT 1", False),
|
|
("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
|
|
("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
|
|
("postgresql", "DELETE FROM foo WHERE id = 1", True),
|
|
("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
|
|
("postgresql", "DROP TABLE foo", True),
|
|
("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
|
|
("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
|
|
("postgresql", "SHOW search_path", False),
|
|
("postgresql", "SET search_path TO public", False),
|
|
(
|
|
"postgres",
|
|
"""
|
|
with source as (
|
|
select 1 as one
|
|
)
|
|
select * from source
|
|
""",
|
|
False,
|
|
),
|
|
("trino", "SELECT 1", False),
|
|
("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
|
|
("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
|
|
("trino", "DELETE FROM foo WHERE id = 1", True),
|
|
("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
|
|
("trino", "DROP TABLE foo", True),
|
|
("trino", "EXPLAIN SELECT * FROM foo", False),
|
|
("trino", "SHOW SCHEMAS", False),
|
|
("trino", "SET SESSION optimization_level = '3'", False),
|
|
("kustokql", "tbl | limit 100", False),
|
|
("kustokql", "let foo = 1; tbl | where bar == foo", False),
|
|
("kustokql", ".show tables", False),
|
|
("kustokql", "print 1", False),
|
|
("kustokql", "set querytrace; Events | take 100", False),
|
|
("kustokql", ".drop table foo", True),
|
|
("kustokql", ".set-or-append table foo <| bar", True),
|
|
("base", "SHOW LOCKS test EXTENDED", False),
|
|
("base", "SET hivevar:desc='Legislators'", False),
|
|
("base", "UPDATE t1 SET col1 = NULL", True),
|
|
("base", "EXPLAIN SELECT 1", False),
|
|
("base", "SELECT 1", False),
|
|
("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
|
("base", "SHOW CATALOGS", False),
|
|
("base", "SHOW TABLES", False),
|
|
("hive", "UPDATE t1 SET col1 = NULL", True),
|
|
("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
|
("hive", "SHOW LOCKS test EXTENDED", False),
|
|
("hive", "SET hivevar:desc='Legislators'", False),
|
|
("hive", "EXPLAIN SELECT 1", False),
|
|
("hive", "SELECT 1", False),
|
|
("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
|
("presto", "SET hivevar:desc='Legislators'", False),
|
|
("presto", "UPDATE t1 SET col1 = NULL", True),
|
|
("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
|
("presto", "SHOW LOCKS test EXTENDED", False),
|
|
("presto", "EXPLAIN SELECT 1", False),
|
|
("presto", "SELECT 1", False),
|
|
("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
|
],
|
|
)
|
|
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
|
"""
|
|
Test the `has_mutation` method.
|
|
"""
|
|
assert SQLScript(sql, engine).has_mutation() == expected
|
|
|
|
|
|
def test_get_settings() -> None:
|
|
"""
|
|
Test `get_settings` in some edge cases.
|
|
"""
|
|
sql = """
|
|
set
|
|
-- this is a tricky comment
|
|
search_path -- another one
|
|
= bar;
|
|
SELECT * FROM some_table;
|
|
"""
|
|
assert SQLScript(sql, "postgresql").get_settings() == {"search_path": "bar"}
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"app",
|
|
[{"SQLGLOT_DIALECTS_EXTENSIONS": {"custom": Dialects.MYSQL}}],
|
|
indirect=True,
|
|
)
|
|
def test_custom_dialect(app: None) -> None:
|
|
"""
|
|
Test that custom dialects are loaded correctly.
|
|
"""
|
|
assert SQLGLOT_DIALECTS.get("custom") == Dialects.MYSQL
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"engine",
|
|
[
|
|
"ascend",
|
|
"awsathena",
|
|
"base",
|
|
"bigquery",
|
|
"clickhouse",
|
|
"clickhousedb",
|
|
"cockroachdb",
|
|
"couchbase",
|
|
"crate",
|
|
"databend",
|
|
"databricks",
|
|
"db2",
|
|
"denodo",
|
|
"dremio",
|
|
"drill",
|
|
"druid",
|
|
"duckdb",
|
|
"dynamodb",
|
|
"elasticsearch",
|
|
"exa",
|
|
"firebird",
|
|
"firebolt",
|
|
"gsheets",
|
|
"hana",
|
|
"hive",
|
|
"ibmi",
|
|
"impala",
|
|
"kustokql",
|
|
"kustosql",
|
|
"kylin",
|
|
"mariadb",
|
|
"motherduck",
|
|
"mssql",
|
|
"mysql",
|
|
"netezza",
|
|
"oceanbase",
|
|
"ocient",
|
|
"odelasticsearch",
|
|
"oracle",
|
|
"pinot",
|
|
"postgresql",
|
|
"presto",
|
|
"pydoris",
|
|
"redshift",
|
|
"risingwave",
|
|
"rockset",
|
|
"shillelagh",
|
|
"snowflake",
|
|
"solr",
|
|
"sqlite",
|
|
"starrocks",
|
|
"superset",
|
|
"teradatasql",
|
|
"trino",
|
|
"vertica",
|
|
],
|
|
)
|
|
def test_is_mutating(engine: str) -> None:
|
|
"""
|
|
Global tests for `is_mutating`, covering all supported engines.
|
|
"""
|
|
assert not SQLStatement(
|
|
"with source as ( select 1 as one ) select * from source",
|
|
engine=engine,
|
|
).is_mutating()
|