Split up tests/db_engine_test.py (#8449)
* Split up db_engine_specs_test.py into a number of targeted files * Remove db_engine_specs_test.py * isort
This commit is contained in:
parent
1d5718a1a8
commit
82b174701f
|
|
@ -0,0 +1,204 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
from superset import app
|
||||
from superset.db_engine_specs import engines
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
|
||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||
from superset.utils.core import get_example_database
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class DbEngineSpecsTests(DbEngineSpecTestCase):
    """Tests for behaviour exposed by ``BaseEngineSpec`` and generic helpers."""

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """Check ``get_limit_from_sql`` on queries with and without a usable LIMIT.

        :param engine_spec_class: engine spec whose ``get_limit_from_sql`` is
            exercised; defaults to ``BaseEngineSpec``.
        """
        cases = [
            ("select * from table", None),
            ("select * from mytable limit 10", 10),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20",
                20,
            ),
            ("select * from (select * from my_subquery limit 10);", None),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20;",
                20,
            ),
            ("select * from mytable limit 20, 10", 10),
            ("select * from mytable limit 10 offset 20", 10),
            ("select * from mytable limit", None),
            ("select * from mytable limit 10.0", None),
            ("select * from mytable limit x", None),
            ("select * from mytable limit 20, x", None),
            ("select * from mytable limit x offset 20", None),
        ]
        for sql, expected_limit in cases:
            # subTest labels a failure with the offending query and lets the
            # remaining cases still run.
            with self.subTest(sql=sql):
                self.assertEqual(
                    engine_spec_class.get_limit_from_sql(sql), expected_limit
                )

    def test_wrapped_semi_tabs(self):
        """Trailing whitespace and a semicolon are dropped before the limit is added."""
        self.sql_limit_regex(
            "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without LIMIT gets one appended."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """An existing LIMIT value is replaced."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery LIMIT is untouched."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    # NOTE: the multi-line SQL literals below are deliberately flush-left;
    # their line content is part of the strings being compared.

    def test_limit_with_expr(self):
        """A 'LIMIT' inside a string literal is ignored; the real LIMIT is rewritten."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as ``test_limit_with_expr`` but with a trailing semicolon."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
        )

    def test_get_datatype(self):
        """``BaseEngineSpec.get_datatype`` returns a string type name unchanged."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """With ``LIMIT <offset>, <count>`` only the count is rewritten."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """With ``LIMIT <count> OFFSET <n>`` the count changes and OFFSET is kept."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """A 'LIMIT' that only appears inside a string literal is not a limit."""
        self.sql_limit_regex(
            """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
        )

    def test_time_grain_blacklist(self):
        """A blacklisted grain is absent from the engine's time grain functions."""
        with app.app_context():
            # patch.dict restores app.config on exit, so the blacklist does
            # not leak into tests that run afterwards.
            with mock.patch.dict(app.config, {"TIME_GRAIN_BLACKLIST": ["PT1M"]}):
                time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
        self.assertNotIn("PT1M", time_grain_functions)

    def test_time_grain_addons(self):
        """A configured addon grain shows up as the last time grain."""
        addon_config = {
            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
            "TIME_GRAIN_ADDON_FUNCTIONS": {"sqlite": {"PTXM": "ABC({col})"}},
        }
        with app.app_context():
            # Scoped override: the addon entries are removed again on exit.
            with mock.patch.dict(app.config, addon_config):
                time_grains = SqliteEngineSpec.get_time_grains()
        time_grain_addon = time_grains[-1]
        self.assertEqual("PTXM", time_grain_addon.duration)
        self.assertEqual("x seconds", time_grain_addon.label)

    def test_engine_time_grain_validity(self):
        """Every registered engine defines grains, all of them builtin ones."""
        builtin_durations = set(builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for engine in engines.values():
            if engine is BaseEngineSpec:
                continue
            # make sure time grain functions have been defined
            self.assertGreater(len(engine.get_time_grain_functions()), 0)
            # make sure all defined time grains are supported
            defined_grains = {grain.duration for grain in engine.get_time_grains()}
            intersection = builtin_durations.intersection(defined_grains)
            self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_table_names(self):
        """Make sure base engine spec removes schema name from table name.

        ie. when try_remove_schema_from_table_name == True.
        """
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        base_result_expected = ["table", "table_2"]
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(base_result_expected, base_result)

    def test_column_datatype_to_string(self):
        """Column types of ``energy_usage`` render to the backend's type names."""
        example_db = get_example_database()
        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()
        col_types = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if example_db.backend == "postgresql":
            expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
        else:
            expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        self.assertEqual(col_types, expected)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.models.core import Database
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class DbEngineSpecTestCase(SupersetTestCase):
    """Common base class for the per-engine spec test cases."""

    def sql_limit_regex(
        self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000
    ):
        """Assert that applying ``limit`` to ``sql`` produces ``expected_sql``.

        :param sql: query handed to ``apply_limit_to_sql``
        :param expected_sql: query expected back
        :param engine_spec_class: engine spec whose ``apply_limit_to_sql`` is called
        :param limit: row limit to apply
        """
        database = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
        self.assertEqual(
            expected_sql, engine_spec_class.apply_limit_to_sql(sql, limit, database)
        )
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from sqlalchemy import column
|
||||
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class BigQueryTestCase(DbEngineSpecTestCase):
    """Tests for the BigQuery engine spec."""

    def test_bigquery_sqla_column_label(self):
        """``make_label_compatible`` returns the expected label for each column name."""
        expectations = (
            ("Col", "Col"),
            ("SUM(x)", "SUM_x__5f110"),
            ("SUM[x]", "SUM_x__7ebe1"),
            ("12345_col", "_12345_col_8d390"),
        )
        for original_name, label_expected in expectations:
            label = BigQueryEngineSpec.make_label_compatible(
                column(original_name).name
            )
            self.assertEqual(label, label_expected)
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class HiveTests(DbEngineSpecTestCase):
    """Tests for the Hive engine spec: progress parsing and error messages."""

    # NOTE: the log literals below are deliberately flush-left; their line
    # content is part of the strings handed to ``HiveEngineSpec.progress``.

    def test_0_progress(self):
        """A log with no job information reports 0."""
        log_lines = """
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
""".split("\n")
        self.assertEqual(0, HiveEngineSpec.progress(log_lines))

    def test_number_of_jobs_progress(self):
        """Only the total job count is known: 0."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
""".split("\n")
        self.assertEqual(0, HiveEngineSpec.progress(log_lines))

    def test_job_1_launched_progress(self):
        """First job launched, no stage output yet: 0."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
""".split("\n")
        self.assertEqual(0, HiveEngineSpec.progress(log_lines))

    def test_job_1_launched_stage_1(self):
        """First job, stage 1 at map 0% / reduce 0%: 0."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
""".split("\n")
        self.assertEqual(0, HiveEngineSpec.progress(log_lines))

    def test_job_1_launched_stage_1_map_40_progress(
        self
    ):  # pylint: disable=invalid-name
        """First of two jobs, stage 1 at map 40% / reduce 0%: 10."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split("\n")
        self.assertEqual(10, HiveEngineSpec.progress(log_lines))

    def test_job_1_launched_stage_1_map_80_reduce_40_progress(
        self
    ):  # pylint: disable=invalid-name
        """First of two jobs, stage 1 at map 80% / reduce 40%: 30."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
""".split("\n")
        self.assertEqual(30, HiveEngineSpec.progress(log_lines))

    def test_job_1_launched_stage_2_stages_progress(
        self
    ):  # pylint: disable=invalid-name
        """First of two jobs with two stages reporting: 12."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
""".split("\n")
        self.assertEqual(12, HiveEngineSpec.progress(log_lines))

    def test_job_2_launched_stage_2_stages_progress(
        self
    ):  # pylint: disable=invalid-name
        """Second of two jobs launched, its stage 1 at map 40%: 60."""
        log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split("\n")
        self.assertEqual(60, HiveEngineSpec.progress(log_lines))

    def test_hive_error_msg(self):
        """``extract_error_message`` prefixes with 'hive error: ' and pulls out errorMessage."""
        # errorMessage terminated by a quote followed by ", statusCode=..."
        compile_failure = (
            '{...} errorMessage="Error while compiling statement: FAILED: '
            "SemanticException [Error 10001]: Line 4"
            ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
            "sqlState='42S02', errorCode=10001)){...}"
        )
        self.assertEqual(
            (
                "hive error: Error while compiling statement: FAILED: "
                "SemanticException [Error 10001]: Line 4:5 "
                "Table not found 'fact_ridesfdslakj'"
            ),
            HiveEngineSpec.extract_error_message(Exception(compile_failure)),
        )

        # No errorMessage field: the exception text itself is used.
        unmatched = Exception("Some string that doesn't match the regex")
        self.assertEqual(
            f"hive error: {unmatched}", HiveEngineSpec.extract_error_message(unmatched)
        )

        # errorMessage terminated by a quote followed by ")"
        short_failure = (
            "errorCode=10001, "
            'errorMessage="Error while compiling statement"), operationHandle'
            '=None)"'
        )
        self.assertEqual(
            ("hive error: Error while compiling statement"),
            HiveEngineSpec.extract_error_message(Exception(short_failure)),
        )

    def test_hive_get_view_names_return_empty_list(
        self
    ):  # pylint: disable=invalid-name
        """``get_view_names`` returns an empty list for Hive."""
        view_names = HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
        self.assertEqual([], view_names)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from sqlalchemy import column, table
|
||||
from sqlalchemy.dialects import mssql
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class MssqlEngineSpecTest(DbEngineSpecTestCase):
    """Tests for the SQL Server engine spec."""

    def test_mssql_column_types(self):
        """Type strings resolve to the expected SQLAlchemy type, or to ``None``."""
        expectations = (
            ("INT", None),
            ("STRING", String),
            ("CHAR(10)", String),
            ("VARCHAR(10)", String),
            ("TEXT", String),
            ("NCHAR(10)", UnicodeText),
            ("NVARCHAR(10)", UnicodeText),
            ("NTEXT", UnicodeText),
        )
        for type_string, type_expected in expectations:
            type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
            if type_expected is None:
                self.assertIsNone(type_assigned)
            else:
                self.assertIsInstance(type_assigned, type_expected)

    def test_where_clause_n_prefix(self):
        """Literals compared against the NTEXT-typed column are rendered with ``N``."""
        varchar_type = MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)")
        ntext_type = MssqlEngineSpec.get_sqla_column_type("NTEXT")
        str_col = column("col", type_=varchar_type)
        unicode_col = column("unicode_col", type_=ntext_type)

        statement = select([str_col, unicode_col]).select_from(table("tbl"))
        statement = statement.where(str_col == "abc")
        statement = statement.where(unicode_col == "abc")

        compiled = statement.compile(
            dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
        )
        query = str(compiled)
        query_expected = (
            "SELECT col, unicode_col \n"
            "FROM tbl \n"
            "WHERE col = 'abc' AND unicode_col = N'abc'"
        )
        self.assertEqual(query, query_expected)

    def test_time_exp_mixd_case_col_1y(self):
        """A mixed-case column is bracket-quoted in the yearly timestamp expression."""
        timestamp_expr = MssqlEngineSpec.get_timestamp_expr(
            column("MixedCase"), None, "P1Y"
        )
        rendered = str(timestamp_expr.compile(None, dialect=mssql.dialect()))
        self.assertEqual(rendered, "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest
|
||||
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class MySQLEngineSpecsTestCase(DbEngineSpecTestCase):
    """Tests for the MySQL engine spec."""

    @unittest.skipUnless(
        DbEngineSpecTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
    )
    def test_get_datatype_mysql(self):
        """Numeric type codes are mapped to MySQL type names."""
        for type_code, type_name in ((1, "TINY"), (15, "VARCHAR")):
            self.assertEqual(type_name, MySQLEngineSpec.get_datatype(type_code))
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy.dialects import oracle
|
||||
|
||||
from superset.db_engine_specs.oracle import OracleEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class OracleTestCase(DbEngineSpecTestCase):
    """Tests for the Oracle engine spec."""

    def test_oracle_sqla_column_name_length_exceeded(self):
        """A 32-character column name yields the expected quoted replacement label."""
        long_name = column("This_Is_32_Character_Column_Name").name
        label = OracleEngineSpec.make_label_compatible(long_name)
        self.assertEqual(label.quote, True)
        self.assertEqual(label, "3b26974078683be078219674eeb8f5")

    def test_oracle_time_expression_reserved_keyword_1m_grain(self):
        """A column named after a reserved keyword is quoted in the monthly expression."""
        timestamp_expr = OracleEngineSpec.get_timestamp_expr(
            column("decimal"), None, "P1M"
        )
        rendered = str(timestamp_expr.compile(dialect=oracle.dialect()))
        self.assertEqual(rendered, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from sqlalchemy import column
|
||||
|
||||
from superset.db_engine_specs.pinot import PinotEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class PinotTestCase(DbEngineSpecTestCase):
    """ Tests pertaining to our Pinot database support """

    def test_pinot_time_expression_sec_one_1m_grain(self):
        """An epoch-seconds column with a monthly grain renders as DATETIMECONVERT."""
        timestamp_expr = PinotEngineSpec.get_timestamp_expr(
            column("tstamp"), "epoch_s", "P1M"
        )
        rendered = str(timestamp_expr.compile())
        self.assertEqual(
            rendered,
            'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")',
        )  # noqa
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
from sqlalchemy import column, literal_column
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class PostgresTests(DbEngineSpecTestCase):
    """Tests for the PostgreSQL engine spec."""

    @staticmethod
    def _render(expr):
        """Compile ``expr`` with the PostgreSQL dialect and return the SQL text."""
        return str(expr.compile(None, dialect=postgresql.dialect()))

    def test_get_table_names(self):
        """ Make sure postgres doesn't try to remove schema name from table name
        ie. when try_remove_schema_from_table_name == False. """
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        pg_result = PostgresEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(["schema.table", "table_2", "table_3"], pg_result)

    def test_time_exp_literal_no_grain(self):
        """A literal column with no grain is rendered unchanged."""
        timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
            literal_column("COALESCE(a, b)"), None, None
        )
        self.assertEqual(self._render(timestamp_expr), "COALESCE(a, b)")

    def test_time_exp_literal_1y_grain(self):
        """A literal column with a yearly grain is wrapped in DATE_TRUNC."""
        timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
            literal_column("COALESCE(a, b)"), None, "P1Y"
        )
        self.assertEqual(
            self._render(timestamp_expr), "DATE_TRUNC('year', COALESCE(a, b))"
        )

    def test_time_ex_lowr_col_no_grain(self):
        """A lower-case column with no grain is rendered unquoted."""
        timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
            column("lower_case"), None, None
        )
        self.assertEqual(self._render(timestamp_expr), "lower_case")

    def test_time_exp_lowr_col_sec_1y(self):
        """An epoch-seconds column is converted to a timestamp before truncation."""
        timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
            column("lower_case"), "epoch_s", "P1Y"
        )
        self.assertEqual(
            self._render(timestamp_expr),
            "DATE_TRUNC('year', "
            "(timestamp 'epoch' + lower_case * interval '1 second'))",
        )

    def test_time_exp_mixd_case_col_1y(self):
        """A mixed-case column is double-quoted inside DATE_TRUNC."""
        timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
            column("MixedCase"), None, "P1Y"
        )
        self.assertEqual(
            self._render(timestamp_expr), "DATE_TRUNC('year', \"MixedCase\")"
        )
|
||||
|
|
@ -0,0 +1,343 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock, skipUnless
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
class PrestoTests(DbEngineSpecTestCase):
|
||||
@skipUnless(
    DbEngineSpecTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
def test_get_datatype_presto(self):
    """``get_datatype`` returns the upper-cased type name for 'string'."""
    datatype = PrestoEngineSpec.get_datatype("string")
    self.assertEqual("STRING", datatype)
|
||||
|
||||
def test_presto_get_view_names_return_empty_list(
    self
):  # pylint: disable=invalid-name
    """``get_view_names`` returns an empty list for Presto."""
    view_names = PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
    self.assertEqual([], view_names)
|
||||
|
||||
def verify_presto_column(self, column, expected_results):
    """Feed one mocked column row to ``get_columns`` and check the output.

    :param column: row values served under the "Column", "Type" and "Null" keys
    :param expected_results: sequence whose items hold the expected name at
        index 0 and the expected type string at index 1
    """
    # Row layout: key -> (processor, object, index into ``column``).
    keymap = {
        "Column": (None, None, 0),
        "Type": (None, None, 1),
        "Null": (None, None, 2),
    }
    inspector = mock.Mock()
    inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
    fake_row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
    inspector.bind.execute = mock.Mock(return_value=[fake_row])

    actual_results = PrestoEngineSpec.get_columns(inspector, "", "")

    self.assertEqual(len(expected_results), len(actual_results))
    for expected, actual in zip(expected_results, actual_results):
        self.assertEqual(expected[0], actual["name"])
        self.assertEqual(expected[1], str(actual["type"]))
|
||||
|
||||
def test_presto_get_column(self):
    """A plain boolean column yields a single BOOLEAN result."""
    self.verify_presto_column(
        ("column_name", "boolean", ""), [("column_name", "BOOLEAN")]
    )
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_simple_row_column(self):
|
||||
presto_column = ("column_name", "row(nested_obj double)", "")
|
||||
expected_results = [("column_name", "ROW"), ("column_name.nested_obj", "FLOAT")]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
|
||||
presto_column = ("column name", "row(nested_obj double)", "")
|
||||
expected_results = [("column name", "ROW"), ("column name.nested_obj", "FLOAT")]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
|
||||
presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "")
|
||||
expected_results = [
|
||||
("column_name", "ROW"),
|
||||
('column_name."Field Name(Tricky, Name)"', "FLOAT"),
|
||||
]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_simple_array_column(self):
|
||||
presto_column = ("column_name", "array(double)", "")
|
||||
expected_results = [("column_name", "ARRAY")]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_row_within_array_within_row_column(self):
|
||||
presto_column = (
|
||||
"column_name",
|
||||
"row(nested_array array(row(nested_row double)), nested_obj double)",
|
||||
"",
|
||||
)
|
||||
expected_results = [
|
||||
("column_name", "ROW"),
|
||||
("column_name.nested_array", "ARRAY"),
|
||||
("column_name.nested_array.nested_row", "FLOAT"),
|
||||
("column_name.nested_obj", "FLOAT"),
|
||||
]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_get_array_within_row_within_array_column(self):
|
||||
presto_column = (
|
||||
"column_name",
|
||||
"array(row(nested_array array(double), nested_obj double))",
|
||||
"",
|
||||
)
|
||||
expected_results = [
|
||||
("column_name", "ARRAY"),
|
||||
("column_name.nested_array", "ARRAY"),
|
||||
("column_name.nested_obj", "FLOAT"),
|
||||
]
|
||||
self.verify_presto_column(presto_column, expected_results)
|
||||
|
||||
def test_presto_get_fields(self):
|
||||
cols = [
|
||||
{"name": "column"},
|
||||
{"name": "column.nested_obj"},
|
||||
{"name": 'column."quoted.nested obj"'},
|
||||
]
|
||||
actual_results = PrestoEngineSpec._get_fields(cols)
|
||||
expected_results = [
|
||||
{"name": '"column"', "label": "column"},
|
||||
{"name": '"column"."nested_obj"', "label": "column.nested_obj"},
|
||||
{
|
||||
"name": '"column"."quoted.nested obj"',
|
||||
"label": 'column."quoted.nested obj"',
|
||||
},
|
||||
]
|
||||
for actual_result, expected_result in zip(actual_results, expected_results):
|
||||
self.assertEqual(actual_result.element.name, expected_result["name"])
|
||||
self.assertEqual(actual_result.name, expected_result["label"])
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_expand_data_with_simple_structural_columns(self):
|
||||
cols = [
|
||||
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
|
||||
{"name": "array_column", "type": "ARRAY(BIGINT)"},
|
||||
]
|
||||
data = [
|
||||
{"row_column": ["a"], "array_column": [1, 2, 3]},
|
||||
{"row_column": ["b"], "array_column": [4, 5, 6]},
|
||||
]
|
||||
actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
|
||||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
|
||||
{"name": "row_column.nested_obj", "type": "VARCHAR"},
|
||||
{"name": "array_column", "type": "ARRAY(BIGINT)"},
|
||||
]
|
||||
|
||||
expected_data = [
|
||||
{"array_column": 1, "row_column": ["a"], "row_column.nested_obj": "a"},
|
||||
{"array_column": 2, "row_column": "", "row_column.nested_obj": ""},
|
||||
{"array_column": 3, "row_column": "", "row_column.nested_obj": ""},
|
||||
{"array_column": 4, "row_column": ["b"], "row_column.nested_obj": "b"},
|
||||
{"array_column": 5, "row_column": "", "row_column.nested_obj": ""},
|
||||
{"array_column": 6, "row_column": "", "row_column.nested_obj": ""},
|
||||
]
|
||||
|
||||
expected_expanded_cols = [{"name": "row_column.nested_obj", "type": "VARCHAR"}]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_expand_data_with_complex_row_columns(self):
|
||||
cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
|
||||
}
|
||||
]
|
||||
data = [{"row_column": ["a1", ["a2"]]}, {"row_column": ["b1", ["b2"]]}]
|
||||
actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
|
||||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{
|
||||
"name": "row_column",
|
||||
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
|
||||
},
|
||||
{"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
|
||||
{"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
|
||||
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
|
||||
]
|
||||
expected_data = [
|
||||
{
|
||||
"row_column": ["a1", ["a2"]],
|
||||
"row_column.nested_obj1": "a1",
|
||||
"row_column.nested_row": ["a2"],
|
||||
"row_column.nested_row.nested_obj2": "a2",
|
||||
},
|
||||
{
|
||||
"row_column": ["b1", ["b2"]],
|
||||
"row_column.nested_obj1": "b1",
|
||||
"row_column.nested_row": ["b2"],
|
||||
"row_column.nested_row.nested_obj2": "b2",
|
||||
},
|
||||
]
|
||||
|
||||
expected_expanded_cols = [
|
||||
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
|
||||
{"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
|
||||
{"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
|
||||
)
|
||||
def test_presto_expand_data_with_complex_array_columns(self):
|
||||
cols = [
|
||||
{"name": "int_column", "type": "BIGINT"},
|
||||
{
|
||||
"name": "array_column",
|
||||
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
|
||||
},
|
||||
]
|
||||
data = [
|
||||
{"int_column": 1, "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]]},
|
||||
{"int_column": 2, "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]]},
|
||||
]
|
||||
actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
|
||||
cols, data
|
||||
)
|
||||
expected_cols = [
|
||||
{"name": "int_column", "type": "BIGINT"},
|
||||
{
|
||||
"name": "array_column",
|
||||
"type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
|
||||
},
|
||||
{
|
||||
"name": "array_column.nested_array",
|
||||
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
|
||||
},
|
||||
{"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
|
||||
]
|
||||
expected_data = [
|
||||
{
|
||||
"array_column": [[["a"], ["b"]]],
|
||||
"array_column.nested_array": ["a"],
|
||||
"array_column.nested_array.nested_obj": "a",
|
||||
"int_column": 1,
|
||||
},
|
||||
{
|
||||
"array_column": "",
|
||||
"array_column.nested_array": ["b"],
|
||||
"array_column.nested_array.nested_obj": "b",
|
||||
"int_column": "",
|
||||
},
|
||||
{
|
||||
"array_column": [[["c"], ["d"]]],
|
||||
"array_column.nested_array": ["c"],
|
||||
"array_column.nested_array.nested_obj": "c",
|
||||
"int_column": "",
|
||||
},
|
||||
{
|
||||
"array_column": "",
|
||||
"array_column.nested_array": ["d"],
|
||||
"array_column.nested_array.nested_obj": "d",
|
||||
"int_column": "",
|
||||
},
|
||||
{
|
||||
"array_column": [[["e"], ["f"]]],
|
||||
"array_column.nested_array": ["e"],
|
||||
"array_column.nested_array.nested_obj": "e",
|
||||
"int_column": 2,
|
||||
},
|
||||
{
|
||||
"array_column": "",
|
||||
"array_column.nested_array": ["f"],
|
||||
"array_column.nested_array.nested_obj": "f",
|
||||
"int_column": "",
|
||||
},
|
||||
{
|
||||
"array_column": [[["g"], ["h"]]],
|
||||
"array_column.nested_array": ["g"],
|
||||
"array_column.nested_array.nested_obj": "g",
|
||||
"int_column": "",
|
||||
},
|
||||
{
|
||||
"array_column": "",
|
||||
"array_column.nested_array": ["h"],
|
||||
"array_column.nested_array.nested_obj": "h",
|
||||
"int_column": "",
|
||||
},
|
||||
]
|
||||
expected_expanded_cols = [
|
||||
{
|
||||
"name": "array_column.nested_array",
|
||||
"type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
|
||||
},
|
||||
{"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
|
||||
def test_presto_extra_table_metadata(self):
|
||||
db = mock.Mock()
|
||||
db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
|
||||
db.get_extra = mock.Mock(return_value={})
|
||||
df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
|
||||
db.get_df = mock.Mock(return_value=df)
|
||||
PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
|
||||
result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
|
||||
self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
|
||||
|
||||
def test_presto_where_latest_partition(self):
|
||||
db = mock.Mock()
|
||||
db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
|
||||
db.get_extra = mock.Mock(return_value={})
|
||||
df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
|
||||
db.get_df = mock.Mock(return_value=df)
|
||||
columns = [{"name": "ds"}, {"name": "hour"}]
|
||||
result = PrestoEngineSpec.where_latest_partition(
|
||||
"test_table", "test_schema", db, select(), columns
|
||||
)
|
||||
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
|
||||
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
|
||||
|
|
@ -1,810 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import column, literal_column, table
|
||||
from sqlalchemy.dialects import mssql, oracle, postgresql
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset import app
|
||||
from superset.db_engine_specs import engines
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.oracle import OracleEngineSpec
|
||||
from superset.db_engine_specs.pinot import PinotEngineSpec
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_example_database
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class DbEngineSpecsTestCase(SupersetTestCase):
|
||||
def test_0_progress(self):
    """Progress is 0 when the Hive log only contains PERFLOG lines."""
    log_lines = """
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(0, progress)
|
||||
|
||||
def test_number_of_jobs_progress(self):
    """Progress is 0 when only the total job count has been logged."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(0, progress)
|
||||
|
||||
def test_job_1_launched_progress(self):
    """Progress is 0 right after job 1 of 2 is launched."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(0, progress)
|
||||
|
||||
def test_job_1_launched_stage_1_0_progress(self):
    """Progress is 0 while stage 1 reports 0% map and 0% reduce."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(0, progress)
|
||||
|
||||
def test_job_1_launched_stage_1_map_40_progress(self):
    """Progress is 10 with job 1 of 2 at 40% map and 0% reduce."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(10, progress)
|
||||
|
||||
def test_job_1_launched_stage_1_map_80_reduce_40_progress(self):
    """Progress is 30 with job 1 of 2 at 80% map and 40% reduce."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(30, progress)
|
||||
|
||||
def test_job_1_launched_stage_2_stages_progress(self):
    """Progress is 12 when job 1 of 2 has two stages and only stage 1 advanced."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(12, progress)
|
||||
|
||||
def test_job_2_launched_stage_2_stages_progress(self):
    """Progress is 60 once job 2 of 2 is launched and its stage 1 maps reach 40%."""
    log_lines = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split("\n")
    progress = HiveEngineSpec.progress(log_lines)
    self.assertEqual(60, progress)
|
||||
|
||||
def test_hive_error_msg(self):
    """Check ``HiveEngineSpec.extract_error_message`` on three message shapes."""
    # Case 1: errorMessage embedded in a larger payload.
    table_not_found = (
        '{...} errorMessage="Error while compiling statement: FAILED: '
        "SemanticException [Error 10001]: Line 4"
        ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
        "sqlState='42S02', errorCode=10001)){...}"
    )
    expected_table_not_found = (
        "hive error: Error while compiling statement: FAILED: "
        "SemanticException [Error 10001]: Line 4:5 "
        "Table not found 'fact_ridesfdslakj'"
    )
    self.assertEqual(
        expected_table_not_found,
        HiveEngineSpec.extract_error_message(Exception(table_not_found)),
    )

    # Case 2: no errorMessage field, so the whole exception text is used.
    unmatched = Exception("Some string that doesn't match the regex")
    self.assertEqual(
        f"hive error: {unmatched}", HiveEngineSpec.extract_error_message(unmatched)
    )

    # Case 3: errorMessage appears after the error code.
    code_first = (
        "errorCode=10001, "
        'errorMessage="Error while compiling statement"), operationHandle'
        '=None)"'
    )
    self.assertEqual(
        "hive error: Error while compiling statement",
        HiveEngineSpec.extract_error_message(Exception(code_first)),
    )
|
||||
|
||||
def get_generic_database(self):
    """Return a ``Database`` named ``test_database`` with a ``sqlite://`` URI."""
    generic_db = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
    return generic_db
|
||||
|
||||
def sql_limit_regex(
    self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000
):
    """Assert that applying ``limit`` to ``sql`` produces ``expected_sql``.

    :param sql: statement passed to ``apply_limit_to_sql``.
    :param expected_sql: exact statement expected back.
    :param engine_spec_class: engine spec whose limit logic is exercised.
    :param limit: row limit to apply.
    """
    database = self.get_generic_database()
    actual_sql = engine_spec_class.apply_limit_to_sql(sql, limit, database)
    self.assertEqual(expected_sql, actual_sql)
|
||||
|
||||
def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
    """Check ``get_limit_from_sql`` against a table of statements."""
    # (statement, limit expected from get_limit_from_sql)
    cases = [
        ("select * from table", None),
        ("select * from mytable limit 10", 10),
        (
            "select * from (select * from my_subquery limit 10) where col=1 limit 20",
            20,
        ),
        ("select * from (select * from my_subquery limit 10);", None),
        (
            "select * from (select * from my_subquery limit 10) where col=1 limit 20;",
            20,
        ),
        ("select * from mytable limit 20, 10", 10),
        ("select * from mytable limit 10 offset 20", 10),
        ("select * from mytable limit", None),
        ("select * from mytable limit 10.0", None),
        ("select * from mytable limit x", None),
        ("select * from mytable limit 20, x", None),
        ("select * from mytable limit x offset 20", None),
    ]
    for statement, expected_limit in cases:
        self.assertEqual(
            engine_spec_class.get_limit_from_sql(statement), expected_limit
        )
|
||||
|
||||
def test_wrapped_query(self):
    """With ``MssqlEngineSpec`` the query is wrapped in an outer limited SELECT."""
    wrapped_sql = (
        "SELECT * \nFROM (SELECT * FROM a) AS inner_qry\n LIMIT 1000 OFFSET 0"
    )
    self.sql_limit_regex("SELECT * FROM a", wrapped_sql, MssqlEngineSpec)
|
||||
|
||||
@unittest.skipUnless(
    SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
def test_wrapped_semi_tabs(self):
    """A trailing semicolon and surrounding whitespace are removed before the LIMIT."""
    messy_sql = "SELECT * FROM a \t \n ; \t \n "
    limited_sql = "SELECT * FROM a\nLIMIT 1000"
    self.sql_limit_regex(messy_sql, limited_sql)
|
||||
|
||||
def test_simple_limit_query(self):
    """A query without a LIMIT gets ``LIMIT 1000`` appended on a new line."""
    unlimited_sql = "SELECT * FROM a"
    limited_sql = "SELECT * FROM a\nLIMIT 1000"
    self.sql_limit_regex(unlimited_sql, limited_sql)
|
||||
|
||||
def test_modify_limit_query(self):
    """An existing ``LIMIT 9999`` is replaced by ``LIMIT 1000``."""
    original_sql = "SELECT * FROM a LIMIT 9999"
    limited_sql = "SELECT * FROM a LIMIT 1000"
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_limit_query_with_limit_subquery(self):
    """Only the outer LIMIT changes; the subquery's ``LIMIT 10`` is kept."""
    original_sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999"
    limited_sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000"
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_limit_with_expr(self):
    """``'LIMIT 777'`` inside a string literal is untouched; the real LIMIT is capped."""
    original_sql = """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990"""
    limited_sql = """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000"""
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_limit_expr_and_semicolon(self):
    """The trailing `` ;`` is dropped and the real LIMIT capped; the literal is untouched."""
    original_sql = """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;"""
    limited_sql = """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000"""
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
@unittest.skipUnless(
    SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
def test_get_datatype_mysql(self):
    """MySQL type codes 1 and 15 are reported as ``TINY`` and ``VARCHAR``."""
    for type_code, type_name in ((1, "TINY"), (15, "VARCHAR")):
        self.assertEqual(type_name, MySQLEngineSpec.get_datatype(type_code))
|
||||
|
||||
@unittest.skipUnless(
    SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
def test_get_datatype_presto(self):
    """Presto reports the ``"string"`` type as ``"STRING"``."""
    datatype = PrestoEngineSpec.get_datatype("string")
    self.assertEqual("STRING", datatype)
|
||||
|
||||
def test_get_datatype(self):
    """The base engine spec returns ``"VARCHAR"`` for ``"VARCHAR"``."""
    datatype = BaseEngineSpec.get_datatype("VARCHAR")
    self.assertEqual("VARCHAR", datatype)
|
||||
|
||||
def test_limit_with_implicit_offset(self):
    """In ``LIMIT offset, count`` only the count is capped at 1000."""
    original_sql = """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999"""
    limited_sql = """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000"""
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_limit_with_explicit_offset(self):
    """With ``LIMIT ... OFFSET ...`` the limit is capped and the offset kept."""
    original_sql = """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999"""
    limited_sql = """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999"""
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_limit_with_non_token_limit(self):
    """``LIMIT`` inside a string literal is not a limit clause; one is appended."""
    original_sql = "SELECT 'LIMIT 777'"
    limited_sql = "SELECT 'LIMIT 777'\nLIMIT 1000"
    self.sql_limit_regex(original_sql, limited_sql)
|
||||
|
||||
def test_time_grain_blacklist(self):
    """A grain listed in ``TIME_GRAIN_BLACKLIST`` is absent from the grain functions."""
    # patch.dict restores app.config on exit, so the blacklist cannot leak
    # into tests that run later against the same global app object.
    with app.app_context(), mock.patch.dict(
        app.config, {"TIME_GRAIN_BLACKLIST": ["PT1M"]}
    ):
        time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
        self.assertNotIn("PT1M", time_grain_functions)
|
||||
|
||||
def test_time_grain_addons(self):
    """A grain configured through the add-on settings is the last grain returned."""
    addon_config = {
        "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
        "TIME_GRAIN_ADDON_FUNCTIONS": {"sqlite": {"PTXM": "ABC({col})"}},
    }
    # patch.dict restores app.config on exit, so the add-on grain cannot
    # leak into tests that run later against the same global app object.
    with app.app_context(), mock.patch.dict(app.config, addon_config):
        time_grains = SqliteEngineSpec.get_time_grains()
        time_grain_addon = time_grains[-1]
        self.assertEqual("PTXM", time_grain_addon.duration)
        self.assertEqual("x seconds", time_grain_addon.label)
|
||||
|
||||
def test_engine_time_grain_validity(self):
    """Every engine spec defines time grains, all of them built-in durations."""
    builtin_durations = set(builtin_time_grains.keys())
    # Walk every registered engine spec; the base class itself is skipped.
    for engine_spec in engines.values():
        if engine_spec is BaseEngineSpec:
            continue
        # Each engine must define at least one time grain function.
        grain_functions = engine_spec.get_time_grain_functions()
        self.assertGreater(len(grain_functions), 0)
        # Each grain the engine defines must be one of the built-ins.
        engine_durations = {grain.duration for grain in engine_spec.get_time_grains()}
        supported = builtin_durations.intersection(engine_durations)
        self.assertSetEqual(engine_durations, supported, engine_spec)
|
||||
|
||||
def test_presto_get_view_names_return_empty_list(self):
    """``get_view_names`` returns an empty list for any arguments."""
    view_names = PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
    self.assertEqual([], view_names)
|
||||
|
||||
def verify_presto_column(self, column, expected_results):
    """Run ``get_columns`` against one mocked row and check the output.

    :param column: tuple of (column name, Presto type string, null flag)
        used as the single row returned by the mocked inspector.
    :param expected_results: list of (name, type-as-string) pairs the
        engine spec is expected to produce, in order.
    """
    # Maps the result-set keys to their positions in the row tuple.
    keymap = {
        "Column": (None, None, 0),
        "Type": (None, None, 1),
        "Null": (None, None, 2),
    }
    inspector = mock.Mock()
    inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
    mocked_row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
    inspector.bind.execute = mock.Mock(return_value=[mocked_row])

    actual_columns = PrestoEngineSpec.get_columns(inspector, "", "")

    self.assertEqual(len(expected_results), len(actual_columns))
    for wanted, found in zip(expected_results, actual_columns):
        self.assertEqual(wanted[0], found["name"])
        self.assertEqual(wanted[1], str(found["type"]))
|
||||
|
||||
def test_presto_get_column(self):
    """A plain ``boolean`` column yields one BOOLEAN column."""
    self.verify_presto_column(
        ("column_name", "boolean", ""), [("column_name", "BOOLEAN")]
    )
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_simple_row_column(self):
    """A ROW column yields the row itself plus its dotted nested field."""
    self.verify_presto_column(
        ("column_name", "row(nested_obj double)", ""),
        [("column_name", "ROW"), ("column_name.nested_obj", "FLOAT")],
    )
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
    """Whitespace in the parent column name is kept in nested names."""
    self.verify_presto_column(
        ("column name", "row(nested_obj double)", ""),
        [("column name", "ROW"), ("column name.nested_obj", "FLOAT")],
    )
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
    """A quoted nested field name with parens and commas stays intact."""
    self.verify_presto_column(
        ("column_name", 'row("Field Name(Tricky, Name)" double)', ""),
        [
            ("column_name", "ROW"),
            ('column_name."Field Name(Tricky, Name)"', "FLOAT"),
        ],
    )
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_simple_array_column(self):
    """An array of scalars yields a single ARRAY column."""
    self.verify_presto_column(
        ("column_name", "array(double)", ""), [("column_name", "ARRAY")]
    )
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_row_within_array_within_row_column(self):
    """Row -> array -> row nesting is flattened into dotted column names."""
    nested_type = "row(nested_array array(row(nested_row double)), nested_obj double)"
    flattened = [
        ("column_name", "ROW"),
        ("column_name.nested_array", "ARRAY"),
        ("column_name.nested_array.nested_row", "FLOAT"),
        ("column_name.nested_obj", "FLOAT"),
    ]
    self.verify_presto_column(("column_name", nested_type, ""), flattened)
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_get_array_within_row_within_array_column(self):
    """Array -> row -> array nesting is flattened into dotted column names."""
    nested_type = "array(row(nested_array array(double), nested_obj double))"
    flattened = [
        ("column_name", "ARRAY"),
        ("column_name.nested_array", "ARRAY"),
        ("column_name.nested_obj", "FLOAT"),
    ]
    self.verify_presto_column(("column_name", nested_type, ""), flattened)
|
||||
|
||||
def test_presto_get_fields(self):
    """``_get_fields`` quotes each dotted part and labels with the raw name."""
    cols = [
        {"name": "column"},
        {"name": "column.nested_obj"},
        {"name": 'column."quoted.nested obj"'},
    ]
    actual_results = list(PrestoEngineSpec._get_fields(cols))
    expected_results = [
        {"name": '"column"', "label": "column"},
        {"name": '"column"."nested_obj"', "label": "column.nested_obj"},
        {
            "name": '"column"."quoted.nested obj"',
            "label": 'column."quoted.nested obj"',
        },
    ]
    # zip() stops at the shorter input, so without this check a missing
    # (or extra) field would go unnoticed.
    self.assertEqual(len(expected_results), len(actual_results))
    for actual_result, expected_result in zip(actual_results, expected_results):
        self.assertEqual(actual_result.element.name, expected_result["name"])
        self.assertEqual(actual_result.name, expected_result["label"])
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_expand_data_with_simple_structural_columns(self):
    """ROW fields become extra columns; ARRAY items become extra rows."""
    row_col = {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"}
    array_col = {"name": "array_column", "type": "ARRAY(BIGINT)"}
    nested_col = {"name": "row_column.nested_obj", "type": "VARCHAR"}

    input_rows = [
        {"row_column": ["a"], "array_column": [1, 2, 3]},
        {"row_column": ["b"], "array_column": [4, 5, 6]},
    ]
    # Rows added for extra array items carry "" in the row columns.
    wanted_rows = [
        {"array_column": 1, "row_column": ["a"], "row_column.nested_obj": "a"},
        {"array_column": 2, "row_column": "", "row_column.nested_obj": ""},
        {"array_column": 3, "row_column": "", "row_column.nested_obj": ""},
        {"array_column": 4, "row_column": ["b"], "row_column.nested_obj": "b"},
        {"array_column": 5, "row_column": "", "row_column.nested_obj": ""},
        {"array_column": 6, "row_column": "", "row_column.nested_obj": ""},
    ]

    got_cols, got_rows, got_expanded = PrestoEngineSpec.expand_data(
        [row_col, array_col], input_rows
    )

    self.assertEqual(
        got_cols,
        [
            {"name": "row_column", "type": "ROW(NESTED_OBJ VARCHAR)"},
            {"name": "row_column.nested_obj", "type": "VARCHAR"},
            {"name": "array_column", "type": "ARRAY(BIGINT)"},
        ],
    )
    self.assertEqual(got_rows, wanted_rows)
    self.assertEqual(got_expanded, [nested_col])
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_expand_data_with_complex_row_columns(self):
    """A ROW nested inside a ROW is expanded at every level."""
    outer_type = "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))"
    input_cols = [{"name": "row_column", "type": outer_type}]
    input_rows = [{"row_column": ["a1", ["a2"]]}, {"row_column": ["b1", ["b2"]]}]

    wanted_cols = [
        {"name": "row_column", "type": outer_type},
        {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
        {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
        {"name": "row_column.nested_obj1", "type": "VARCHAR"},
    ]
    wanted_rows = [
        {
            "row_column": ["a1", ["a2"]],
            "row_column.nested_obj1": "a1",
            "row_column.nested_row": ["a2"],
            "row_column.nested_row.nested_obj2": "a2",
        },
        {
            "row_column": ["b1", ["b2"]],
            "row_column.nested_obj1": "b1",
            "row_column.nested_row": ["b2"],
            "row_column.nested_row.nested_obj2": "b2",
        },
    ]
    wanted_expanded = [
        {"name": "row_column.nested_obj1", "type": "VARCHAR"},
        {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
        {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
    ]

    got_cols, got_rows, got_expanded = PrestoEngineSpec.expand_data(
        input_cols, input_rows
    )

    self.assertEqual(got_cols, wanted_cols)
    self.assertEqual(got_rows, wanted_rows)
    self.assertEqual(got_expanded, wanted_expanded)
|
||||
|
||||
@mock.patch.dict(
    "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
)
def test_presto_expand_data_with_complex_array_columns(self):
    """Check ``expand_data`` on an ARRAY of ROWs that nests another ARRAY of ROWs.

    The expectations pin the returned column list, the eight output rows
    produced from the two input rows, and the newly expanded columns.
    """
    array_type = "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))"
    nested_array_type = "ARRAY(ROW(NESTED_OBJ VARCHAR))"

    def expected_row(array_value, letter, int_value=""):
        # One expected output row; cells not filled for a row are "".
        return {
            "array_column": array_value,
            "array_column.nested_array": [letter],
            "array_column.nested_array.nested_obj": letter,
            "int_column": int_value,
        }

    input_cols = [
        {"name": "int_column", "type": "BIGINT"},
        {"name": "array_column", "type": array_type},
    ]
    input_rows = [
        {"int_column": 1, "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]]},
        {"int_column": 2, "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]]},
    ]

    result_cols, result_rows, result_expanded = PrestoEngineSpec.expand_data(
        input_cols, input_rows
    )

    want_cols = [
        {"name": "int_column", "type": "BIGINT"},
        {"name": "array_column", "type": array_type},
        {"name": "array_column.nested_array", "type": nested_array_type},
        {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
    ]
    want_rows = [
        expected_row([[["a"], ["b"]]], "a", 1),
        expected_row("", "b"),
        expected_row([[["c"], ["d"]]], "c"),
        expected_row("", "d"),
        expected_row([[["e"], ["f"]]], "e", 2),
        expected_row("", "f"),
        expected_row([[["g"], ["h"]]], "g"),
        expected_row("", "h"),
    ]
    want_expanded = [
        {"name": "array_column.nested_array", "type": nested_array_type},
        {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
    ]

    self.assertEqual(result_cols, want_cols)
    self.assertEqual(result_rows, want_rows)
    self.assertEqual(result_expanded, want_expanded)
|
||||
|
||||
def test_presto_extra_table_metadata(self):
    """Check that ``extra_table_metadata`` reports the latest partition values.

    The database is a mock whose index lists ``ds`` and ``hour`` as column
    names and whose ``get_df`` returns a single row; that row is expected
    under ``result["partitions"]["latest"]``.
    """
    db = mock.Mock()
    db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
    db.get_extra = mock.Mock(return_value={})
    df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
    db.get_df = mock.Mock(return_value=df)

    # Stub get_create_view only for the duration of the call. Assigning the
    # mock directly on the class would leave it in place for every test that
    # runs afterwards in the same process.
    with mock.patch.object(PrestoEngineSpec, "get_create_view", return_value=None):
        result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")

    self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
|
||||
|
||||
def test_presto_where_latest_partition(self):
    """Check the WHERE clause that ``where_latest_partition`` adds to a query.

    The mocked database returns one row (``ds='01-01-19'``, ``hour=1``)
    from ``get_df``; the compiled query is expected to filter on both.
    """
    latest_row = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
    database = mock.Mock(
        get_indexes=mock.Mock(return_value=[{"column_names": ["ds", "hour"]}]),
        get_extra=mock.Mock(return_value={}),
        get_df=mock.Mock(return_value=latest_row),
    )
    partition_columns = [{"name": "ds"}, {"name": "hour"}]

    filtered = PrestoEngineSpec.where_latest_partition(
        "test_table", "test_schema", database, select(), partition_columns
    )

    # literal_binds inlines the bound values so the SQL text can be compared.
    compiled = filtered.compile(compile_kwargs={"literal_binds": True})
    self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", str(compiled))
|
||||
|
||||
def test_hive_get_view_names_return_empty_list(self):
    """Check that the Hive spec's ``get_view_names`` returns an empty list."""
    view_names = HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
    self.assertEqual([], view_names)
|
||||
|
||||
def test_bigquery_sqla_column_label(self):
    """Check BigQuery's ``make_label_compatible`` against fixed expected labels."""
    # (original column name, label expected from make_label_compatible)
    cases = (
        ("Col", "Col"),
        ("SUM(x)", "SUM_x__5f110"),
        ("SUM[x]", "SUM_x__7ebe1"),
        ("12345_col", "_12345_col_8d390"),
    )
    for raw_name, wanted in cases:
        produced = BigQueryEngineSpec.make_label_compatible(column(raw_name).name)
        self.assertEqual(produced, wanted)
|
||||
|
||||
def test_oracle_sqla_column_name_length_exceeded(self):
    """Check Oracle's label for a 32-character column name.

    The returned label is expected to have ``quote`` set to ``True`` and
    to equal a fixed 30-character hexadecimal string.
    """
    long_name = column("This_Is_32_Character_Column_Name").name
    produced = OracleEngineSpec.make_label_compatible(long_name)

    self.assertEqual(produced.quote, True)
    self.assertEqual(produced, "3b26974078683be078219674eeb8f5")
|
||||
|
||||
def test_mssql_column_types(self):
    """Check which SQLAlchemy type ``get_sqla_column_type`` returns for MSSQL.

    ``None`` as the expected type means no type is expected to be
    assigned for that type string.
    """
    # (database type string, expected SQLAlchemy type class or None)
    cases = (
        ("INT", None),
        ("STRING", String),
        ("CHAR(10)", String),
        ("VARCHAR(10)", String),
        ("TEXT", String),
        ("NCHAR(10)", UnicodeText),
        ("NVARCHAR(10)", UnicodeText),
        ("NTEXT", UnicodeText),
    )
    for type_string, wanted_type in cases:
        assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
        if wanted_type is not None:
            self.assertIsInstance(assigned, wanted_type)
        else:
            self.assertIsNone(assigned)
|
||||
|
||||
def test_mssql_where_clause_n_prefix(self):
    """Check that only the unicode column's literal gets an ``N`` prefix.

    Both columns are compared with ``"abc"``; in the compiled MSSQL query
    the VARCHAR column is expected as ``'abc'`` and the NTEXT column as
    ``N'abc'``.
    """
    plain_col = column(
        "col", type_=MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)")
    )
    ntext_col = column(
        "unicode_col", type_=MssqlEngineSpec.get_sqla_column_type("NTEXT")
    )

    statement = select([plain_col, ntext_col]).select_from(table("tbl"))
    statement = statement.where(plain_col == "abc")
    statement = statement.where(ntext_col == "abc")

    # literal_binds inlines the bound values so the prefix shows in the SQL.
    compiled = statement.compile(
        dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
    )
    wanted_sql = "\n".join(
        [
            "SELECT col, unicode_col ",
            "FROM tbl ",
            "WHERE col = 'abc' AND unicode_col = N'abc'",
        ]
    )
    self.assertEqual(str(compiled), wanted_sql)
|
||||
|
||||
def test_get_table_names(self):
    """Compare ``get_table_names`` for the base and Postgres engine specs.

    Both specs receive the same mocked inspector, which reports
    ``"schema.table"`` and ``"table_2"`` as tables and ``"table_3"`` as a
    foreign table.
    """
    inspector = mock.Mock()
    inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
    inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

    # Make sure base engine spec removes schema name from table name,
    # i.e. when try_remove_schema_from_table_name == True.
    base_result_expected = ["table", "table_2"]
    base_result = BaseEngineSpec.get_table_names(
        database=mock.ANY, schema="schema", inspector=inspector
    )
    self.assertListEqual(base_result_expected, base_result)

    # Make sure postgres doesn't try to remove schema name from table name,
    # i.e. when try_remove_schema_from_table_name == False.
    pg_result_expected = ["schema.table", "table_2", "table_3"]
    pg_result = PostgresEngineSpec.get_table_names(
        database=mock.ANY, schema="schema", inspector=inspector
    )
    self.assertListEqual(pg_result_expected, pg_result)
|
||||
|
||||
def test_pg_time_expression_literal_no_grain(self):
    """With no grain, a literal column expression compiles unchanged."""
    source_col = literal_column("COALESCE(a, b)")
    timestamp_expr = PostgresEngineSpec.get_timestamp_expr(source_col, None, None)
    compiled = timestamp_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "COALESCE(a, b)")
|
||||
|
||||
def test_pg_time_expression_literal_1y_grain(self):
    """With a ``P1Y`` grain, the literal is wrapped in ``DATE_TRUNC('year', ...)``."""
    source_col = literal_column("COALESCE(a, b)")
    timestamp_expr = PostgresEngineSpec.get_timestamp_expr(source_col, None, "P1Y")
    compiled = timestamp_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "DATE_TRUNC('year', COALESCE(a, b))")
|
||||
|
||||
def test_pg_time_expression_lower_column_no_grain(self):
    """With no grain, a lower-case column name compiles unquoted and unchanged."""
    source_col = column("lower_case")
    timestamp_expr = PostgresEngineSpec.get_timestamp_expr(source_col, None, None)
    compiled = timestamp_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "lower_case")
|
||||
|
||||
def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
    """Check the Postgres expression for an ``epoch_s`` column with a ``P1Y`` grain."""
    source_col = column("lower_case")
    timestamp_expr = PostgresEngineSpec.get_timestamp_expr(
        source_col, "epoch_s", "P1Y"
    )
    compiled = timestamp_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(
        str(compiled),
        "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))",
    )
|
||||
|
||||
def test_pg_time_expression_mixed_case_column_1y_grain(self):
    """A mixed-case column name is double-quoted inside ``DATE_TRUNC`` on Postgres."""
    source_col = column("MixedCase")
    timestamp_expr = PostgresEngineSpec.get_timestamp_expr(source_col, None, "P1Y")
    compiled = timestamp_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "DATE_TRUNC('year', \"MixedCase\")")
|
||||
|
||||
def test_mssql_time_expression_mixed_case_column_1y_grain(self):
    """A mixed-case column name is bracket-quoted in the MSSQL ``P1Y`` expression."""
    source_col = column("MixedCase")
    timestamp_expr = MssqlEngineSpec.get_timestamp_expr(source_col, None, "P1Y")
    compiled = timestamp_expr.compile(dialect=mssql.dialect())
    self.assertEqual(
        str(compiled), "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)"
    )
|
||||
|
||||
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
    """A column named ``decimal`` is double-quoted in the Oracle ``P1M`` expression."""
    source_col = column("decimal")
    timestamp_expr = OracleEngineSpec.get_timestamp_expr(source_col, None, "P1M")
    compiled = timestamp_expr.compile(dialect=oracle.dialect())
    self.assertEqual(str(compiled), "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
|
||||
|
||||
def test_pinot_time_expression_sec_1m_grain(self):
    """Check the Pinot expression for an ``epoch_s`` column with a ``P1M`` grain."""
    source_col = column("tstamp")
    timestamp_expr = PinotEngineSpec.get_timestamp_expr(source_col, "epoch_s", "P1M")
    # Compiled without an explicit dialect, so the default compiler is used.
    compiled = timestamp_expr.compile()
    self.assertEqual(
        str(compiled),
        'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")',
    )
|
||||
|
||||
def test_column_datatype_to_string(self):
    """Check ``column_datatype_to_string`` for the ``energy_usage`` example table.

    The third column's expected type string depends on the example
    database backend: ``DOUBLE PRECISION`` on postgresql, ``FLOAT`` otherwise.
    """
    database = get_example_database()
    energy_table = database.get_table("energy_usage")
    db_dialect = database.get_dialect()

    rendered_types = [
        database.db_engine_spec.column_datatype_to_string(col.type, db_dialect)
        for col in energy_table.columns
    ]

    float_type = "DOUBLE PRECISION" if database.backend == "postgresql" else "FLOAT"
    self.assertEqual(rendered_types, ["VARCHAR(255)", "VARCHAR(255)", float_type])
|
||||
Loading…
Reference in New Issue