chore: upgrade SQLAlchemy to 1.4 (#19890)

* chore: upgrade SQLAlchemy

* Convert integration test to unit test

* Fix SQLite

* Update method names/docstrings

* Skip test

* Fix SQLite
This commit is contained in:
Beto Dealmeida 2022-07-18 15:21:38 -07:00 committed by GitHub
parent 90600d1883
commit e60083b45b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 656 additions and 255 deletions

View File

@ -122,6 +122,8 @@ geopy==2.2.0
# via apache-superset
graphlib-backport==1.0.3
# via apache-superset
greenlet==1.1.2
# via sqlalchemy
gunicorn==20.1.0
# via apache-superset
hashids==1.3.1
@ -259,7 +261,7 @@ six==1.16.0
# wtforms-json
slackclient==2.5.0
# via apache-superset
sqlalchemy==1.3.24
sqlalchemy==1.4.36
# via
# alembic
# apache-superset

View File

@ -12,8 +12,6 @@
# -r requirements/docker.in
gevent==21.8.0
# via -r requirements/docker.in
greenlet==1.1.1
# via gevent
psycopg2-binary==2.9.1
# via apache-superset
zope-event==4.5.0

View File

@ -109,7 +109,7 @@ setup(
"selenium>=3.141.0",
"simplejson>=3.15.0",
"slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions
"sqlalchemy>=1.3.16, <1.4, !=1.3.21",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.37.8, <0.38",
"sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
"tabulate==0.8.9",

View File

@ -129,7 +129,6 @@ class ImportAssetsCommand(BaseCommand):
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
# pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
session.execute(dashboard_slices.insert(), values)
def run(self) -> None:

View File

@ -181,5 +181,4 @@ class ImportExamplesCommand(ImportModelsCommand):
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
# pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
session.execute(dashboard_slices.insert(), values)

View File

@ -139,5 +139,4 @@ class ImportDashboardsCommand(ImportModelsCommand):
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
# pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
session.execute(dashboard_slices.insert(), values)

View File

@ -117,7 +117,9 @@ builtin_time_grains: Dict[Optional[str], str] = {
}
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method
class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be can be used to render native column elements
respeting engine-specific quoting rules as part of a string-based expression.
@ -933,9 +935,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
]
@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
def adjust_database_uri( # pylint: disable=unused-argument
cls,
uri: URL,
selected_schema: Optional[str],
) -> URL:
"""
Mutate the database component of the SQLAlchemy URI.
Return a modified URL with a new database component.
The URI here represents the URI as entered when saving the database,
``selected_schema`` is the schema currently active presumably in
@ -949,9 +955,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
For those it's probably better to not alter the database
component of the URI with the schema name, it won't work.
Some database drivers like presto accept '{catalog}/{schema}' in
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
"""
return uri
@classmethod
def patch(cls) -> None:
@ -1206,17 +1213,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return costs
@classmethod
def modify_url_for_impersonation(
def get_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
) -> URL:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
Return a modified URL with the username set.
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
"""
if impersonate_user and username is not None:
url.username = username
url = url.set(username=username)
return url
@classmethod
def update_impersonation_config(

View File

@ -68,26 +68,31 @@ class DrillEngineSpec(BaseEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
uri = uri.set(database=parse.quote(selected_schema, safe=""))
return uri
@classmethod
def modify_url_for_impersonation(
def get_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
) -> URL:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
Return a modified URL with the username set.
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
"""
if impersonate_user and username is not None:
if url.drivername == "drill+odbc":
url.query["DelegationUID"] = username
url = url.update_query_dict({"DelegationUID": username})
elif url.drivername in ["drill+sadrill", "drill+jdbc"]:
url.query["impersonation_target"] = username
url = url.update_query_dict({"impersonation_target": username})
else:
raise SupersetDBAPIProgrammingError(
f"impersonation is not supported for {url.drivername}"
)
return url

View File

@ -81,16 +81,18 @@ class GSheetsEngineSpec(SqliteEngineSpec):
}
@classmethod
def modify_url_for_impersonation(
def get_url_for_impersonation(
cls,
url: URL,
impersonate_user: bool,
username: Optional[str],
) -> None:
) -> URL:
if impersonate_user and username is not None:
user = security_manager.find_user(username=username)
if user and user.email:
url.query["subject"] = user.email
url = url.update_query_dict({"subject": user.email})
return url
@classmethod
def extra_table_metadata(

View File

@ -269,9 +269,11 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
) -> URL:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
uri = uri.set(database=parse.quote(selected_schema, safe=""))
return uri
@classmethod
def _extract_error_message(cls, ex: Exception) -> str:
@ -485,17 +487,19 @@ class HiveEngineSpec(PrestoEngineSpec):
)
@classmethod
def modify_url_for_impersonation(
def get_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
) -> URL:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
Return a modified URL with the username set.
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
"""
# Do nothing in the URL object since instead this should modify
# the configuraiton dictionary. See get_configuration_for_impersonation
return url
@classmethod
def update_impersonation_config(

View File

@ -193,9 +193,11 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
) -> URL:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
uri = uri.set(database=parse.quote(selected_schema, safe=""))
return uri
@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:

View File

@ -33,7 +33,7 @@ from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.result import Row as ResultRow
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
@ -430,7 +430,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> List[RowProxy]:
) -> List[ResultRow]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
@ -729,7 +729,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
) -> URL:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
@ -737,7 +737,9 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
database = database.split("/")[0] + "/" + selected_schema
else:
database += "/" + selected_schema
uri.database = database
uri = uri.set(database=database)
return uri
@classmethod
def convert_dttm(

View File

@ -114,13 +114,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
) -> URL:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri.database = database + "/" + selected_schema
uri = uri.set(database=f"{database}/{selected_schema}")
return uri
@classmethod
def epoch_to_dttm(cls) -> str:

View File

@ -65,16 +65,18 @@ class TrinoEngineSpec(PrestoEngineSpec):
connect_args["user"] = username
@classmethod
def modify_url_for_impersonation(
def get_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
) -> URL:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
Return a modified URL with the username set.
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
"""
# Do nothing and let update_impersonation_config take care of impersonation
return url
@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:

View File

@ -31,7 +31,7 @@ import logging
import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative.api import declarative_base
from sqlalchemy.ext.declarative import declarative_base
from superset import db

View File

@ -312,7 +312,7 @@ class Database(
def get_password_masked_url(cls, masked_url: URL) -> URL:
url_copy = deepcopy(masked_url)
if url_copy.password is not None:
url_copy.password = PASSWORD_MASK
url_copy = url_copy.set(password=PASSWORD_MASK)
return url_copy
def set_sqlalchemy_uri(self, uri: str) -> None:
@ -320,7 +320,7 @@ class Database(
if conn.password != PASSWORD_MASK and not custom_password_store:
# do not over-write the password with the password mask
self.password = conn.password
conn.password = PASSWORD_MASK if conn.password else None
conn = conn.set(password=PASSWORD_MASK if conn.password else None)
self.sqlalchemy_uri = str(conn) # hides the password
def get_effective_user(self, object_url: URL) -> Optional[str]:
@ -355,12 +355,12 @@ class Database(
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
self.db_engine_spec.modify_url_for_impersonation(
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
)
@ -736,9 +736,9 @@ class Database(
# (so users see 500 less often)
return "dialect://invalid_uri"
if custom_password_store:
conn.password = custom_password_store(conn)
conn = conn.set(password=custom_password_store(conn))
else:
conn.password = self.password
conn = conn.set(password=self.password)
return str(conn)
@property

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=abstract-method
# pylint: disable=abstract-method, no-init
from typing import Any, Dict, List, Optional, Type
from sqlalchemy.engine.interfaces import Dialect

View File

@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional
from flask import Flask
from sqlalchemy import text, TypeDecorator
from sqlalchemy.engine import Connection, Dialect, RowProxy
from sqlalchemy.engine import Connection, Dialect, Row
from sqlalchemy_utils import EncryptedType
logger = logging.getLogger(__name__)
@ -114,13 +114,13 @@ class SecretsMigrator:
@staticmethod
def _select_columns_from_table(
conn: Connection, column_names: List[str], table_name: str
) -> RowProxy:
) -> Row:
return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}")
def _re_encrypt_row(
self,
conn: Connection,
row: RowProxy,
row: Row,
table_name: str,
columns: Dict[str, EncryptedType],
) -> None:

View File

@ -206,11 +206,9 @@ def add_data(
metadata.create_all(engine)
if not append:
# pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
engine.execute(table.delete())
data = generate_data(columns, num_rows)
# pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
engine.execute(table.insert(), data)

View File

@ -28,7 +28,7 @@ from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect

View File

@ -1,173 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import unittest
from typing import Any, Dict
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.test_app import app
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.utils.database import get_or_create_db
FULL_DTTM_DEFAULTS_EXAMPLE = {
"main_dttm_col": "id",
"dttm_columns": {
"dttm": {
"python_date_format": "epoch_s",
"expression": "CAST(dttm as INTEGER)",
},
"id": {"python_date_format": "epoch_ms"},
"month": {
"python_date_format": "%Y-%m-%d",
"expression": "CASE WHEN length(month) = 7 THEN month || '-01' ELSE month END",
},
},
}
def apply_dttm_defaults(table: SqlaTable, dttm_defaults: Dict[str, Any]):
"""Applies dttm defaults to the table, mutates in place."""
for dbcol in table.columns:
# Set is_dttm is column is listed in dttm_columns.
if dbcol.column_name in dttm_defaults.get("dttm_columns", {}):
dbcol.is_dttm = True
# Skip non dttm columns.
if dbcol.column_name not in dttm_defaults.get("dttm_columns", {}):
continue
# Set table main_dttm_col.
if dbcol.column_name == dttm_defaults.get("main_dttm_col"):
table.main_dttm_col = dbcol.column_name
# Apply defaults if empty.
dttm_column_defaults = dttm_defaults.get("dttm_columns", {}).get(
dbcol.column_name, {}
)
dbcol.is_dttm = True
if (
not dbcol.python_date_format
and "python_date_format" in dttm_column_defaults
):
dbcol.python_date_format = dttm_column_defaults["python_date_format"]
if not dbcol.expression and "expression" in dttm_column_defaults:
dbcol.expression = dttm_column_defaults["expression"]
class TestConfig(SupersetTestCase):
def setUp(self) -> None:
self.login(username="admin")
self._test_db_id = get_or_create_db(
"column_test_db", app.config["SQLALCHEMY_DATABASE_URI"]
).id
self._old_sqla_table_mutator = app.config["SQLA_TABLE_MUTATOR"]
def createTable(self, dttm_defaults):
app.config["SQLA_TABLE_MUTATOR"] = lambda t: apply_dttm_defaults(
t, dttm_defaults
)
resp = self.client.post(
"/tablemodelview/add",
data=dict(database=self._test_db_id, table_name="logs"),
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
self._logs_table = (
db.session.query(SqlaTable).filter_by(table_name="logs").one()
)
def tearDown(self):
app.config["SQLA_TABLE_MUTATOR"] = self._old_sqla_table_mutator
if hasattr(self, "_logs_table"):
db.session.delete(self._logs_table)
db.session.delete(self._logs_table.database)
db.session.commit()
def test_main_dttm_col(self):
# Make sure that dttm column is set properly.
self.createTable({"main_dttm_col": "id", "dttm_columns": {"id": {}}})
self.assertEqual(self._logs_table.main_dttm_col, "id")
def test_main_dttm_col_nonexistent(self):
self.createTable({"main_dttm_col": "nonexistent"})
# Column doesn't exist, falls back to dttm.
self.assertEqual(self._logs_table.main_dttm_col, "dttm")
def test_main_dttm_col_nondttm(self):
self.createTable({"main_dttm_col": "duration_ms"})
# duration_ms is not dttm column, falls back to dttm.
self.assertEqual(self._logs_table.main_dttm_col, "dttm")
def test_python_date_format_by_column_name(self):
table_defaults = {
"dttm_columns": {
"id": {"python_date_format": "epoch_ms"},
"dttm": {"python_date_format": "epoch_s"},
"duration_ms": {"python_date_format": "invalid"},
}
}
self.createTable(table_defaults)
id_col = [c for c in self._logs_table.columns if c.column_name == "id"][0]
self.assertTrue(id_col.is_dttm)
self.assertEqual(id_col.python_date_format, "epoch_ms")
dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0]
self.assertTrue(dttm_col.is_dttm)
self.assertEqual(dttm_col.python_date_format, "epoch_s")
dms_col = [
c for c in self._logs_table.columns if c.column_name == "duration_ms"
][0]
self.assertTrue(dms_col.is_dttm)
self.assertEqual(dms_col.python_date_format, "invalid")
def test_expression_by_column_name(self):
table_defaults = {
"dttm_columns": {
"dttm": {"expression": "CAST(dttm as INTEGER)"},
"duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"},
}
}
self.createTable(table_defaults)
dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0]
self.assertTrue(dttm_col.is_dttm)
self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)")
dms_col = [
c for c in self._logs_table.columns if c.column_name == "duration_ms"
][0]
self.assertEqual(dms_col.expression, "CAST(duration_ms as DOUBLE)")
self.assertTrue(dms_col.is_dttm)
def test_full_setting(self):
self.createTable(FULL_DTTM_DEFAULTS_EXAMPLE)
self.assertEqual(self._logs_table.main_dttm_col, "id")
id_col = [c for c in self._logs_table.columns if c.column_name == "id"][0]
self.assertTrue(id_col.is_dttm)
self.assertEqual(id_col.python_date_format, "epoch_ms")
self.assertIsNone(id_col.expression)
dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0]
self.assertTrue(dttm_col.is_dttm)
self.assertEqual(dttm_col.python_date_format, "epoch_s")
self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)")
if __name__ == "__main__":
unittest.main()

View File

@ -230,7 +230,7 @@ class TestCore(SupersetTestCase):
def test_get_superset_tables_substr(self):
example_db = superset.utils.database.get_example_database()
if example_db.backend in {"presto", "hive"}:
if example_db.backend in {"presto", "hive", "sqlite"}:
# TODO: change table to the real table that is in examples.
return
self.login(username="admin")

View File

@ -104,6 +104,10 @@ class TestDatasetApi(SupersetTestCase):
@pytest.fixture()
def create_virtual_datasets(self):
with self.create_app().app_context():
if backend() == "sqlite":
yield
return
datasets = []
admin = self.get_user("admin")
main_db = get_main_database()
@ -126,6 +130,10 @@ class TestDatasetApi(SupersetTestCase):
@pytest.fixture()
def create_datasets(self):
with self.create_app().app_context():
if backend() == "sqlite":
yield
return
datasets = []
admin = self.get_user("admin")
main_db = get_main_database()
@ -172,6 +180,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset list
"""
if backend() == "sqlite":
return
example_db = get_example_database()
self.login(username="admin")
arguments = {
@ -210,6 +221,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset list gamma
"""
if backend() == "sqlite":
return
self.login(username="gamma")
uri = "api/v1/dataset/"
rv = self.get_assert_metric(uri, "get_list")
@ -221,6 +235,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset list owned by gamma
"""
if backend() == "sqlite":
return
main_db = get_main_database()
owned_dataset = self.insert_dataset(
"ab_user", [self.get_user("gamma").id], main_db
@ -242,6 +259,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset related databases gamma
"""
if backend() == "sqlite":
return
self.login(username="gamma")
uri = "api/v1/dataset/related/database"
rv = self.client.get(uri)
@ -257,6 +277,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset item
"""
if backend() == "sqlite":
return
table = self.get_energy_usage_dataset()
main_db = get_main_database()
self.login(username="admin")
@ -297,6 +320,8 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset distinct schema
"""
if backend() == "sqlite":
return
def pg_test_query_parameter(query_parameter, expected_response):
uri = f"api/v1/dataset/distinct/schema?q={prison.dumps(query_parameter)}"
@ -367,6 +392,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset distinct not allowed
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/distinct/table_name"
rv = self.client.get(uri)
@ -376,6 +404,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset distinct with gamma
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="gamma")
@ -393,6 +424,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset info
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/_info"
rv = self.get_assert_metric(uri, "info")
@ -402,6 +436,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test info security
"""
if backend() == "sqlite":
return
self.login(username="admin")
params = {"keys": ["permissions"]}
uri = f"api/v1/dataset/_info?q={prison.dumps(params)}"
@ -414,6 +451,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset item
"""
if backend() == "sqlite":
return
main_db = get_main_database()
self.login(username="admin")
table_data = {
@ -456,6 +496,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset item gamma
"""
if backend() == "sqlite":
return
self.login(username="gamma")
main_db = get_main_database()
table_data = {
@ -471,6 +514,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create item owner
"""
if backend() == "sqlite":
return
main_db = get_main_database()
self.login(username="alpha")
admin = self.get_user("admin")
@ -496,6 +542,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset item owner invalid
"""
if backend() == "sqlite":
return
admin = self.get_user("admin")
main_db = get_main_database()
self.login(username="admin")
@ -517,6 +566,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset validate table uniqueness
"""
if backend() == "sqlite":
return
schema = get_example_default_schema()
energy_usage_ds = self.get_energy_usage_dataset()
self.login(username="admin")
@ -568,6 +620,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset validate database exists
"""
if backend() == "sqlite":
return
self.login(username="admin")
dataset_data = {"database": 1000, "schema": "", "table_name": "birth_names"}
uri = "api/v1/dataset/"
@ -580,6 +635,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset validate table exists
"""
if backend() == "sqlite":
return
example_db = get_example_database()
self.login(username="admin")
table_data = {
@ -600,6 +658,8 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset validate view exists
"""
if backend() == "sqlite":
return
mock_get_columns.return_value = [
{
@ -644,6 +704,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test create dataset sqlalchemy error
"""
if backend() == "sqlite":
return
mock_dao_create.side_effect = DAOCreateFailedError()
self.login(username="admin")
main_db = get_main_database()
@ -662,6 +725,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset item
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
dataset_data = {"description": "changed_description"}
@ -678,6 +744,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset with override columns
"""
if backend() == "sqlite":
return
# Add default dataset
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -714,6 +783,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset create column
"""
if backend() == "sqlite":
return
# create example dataset by Command
dataset = self.insert_default_dataset()
@ -809,6 +881,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset delete column
"""
if backend() == "sqlite":
return
# create example dataset by Command
dataset = self.insert_default_dataset()
@ -858,6 +933,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset columns
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -894,6 +972,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset delete metric
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
metrics_query = (
db.session.query(SqlMetric)
@ -937,6 +1018,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset columns uniqueness
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -957,6 +1041,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset metric uniqueness
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -977,6 +1064,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset columns duplicate
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -1002,6 +1092,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset metric duplicate
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
@ -1027,6 +1120,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset item gamma
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="gamma")
table_data = {"description": "changed_description"}
@ -1040,6 +1136,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset item not owned
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="alpha")
table_data = {"description": "changed_description"}
@ -1053,6 +1152,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset item owner invalid
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
table_data = {"description": "changed_description", "owners": [1000]}
@ -1066,6 +1168,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset uniqueness
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="admin")
ab_user = self.insert_dataset(
@ -1089,6 +1194,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test update dataset sqlalchemy error
"""
if backend() == "sqlite":
return
mock_dao_update.side_effect = DAOUpdateFailedError()
dataset = self.insert_default_dataset()
@ -1107,6 +1215,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset item
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
view_menu = security_manager.find_view_menu(dataset.get_perm())
assert view_menu is not None
@ -1124,6 +1235,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete item not owned
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="alpha")
uri = f"api/v1/dataset/{dataset.id}"
@ -1136,6 +1250,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete item not authorized
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="gamma")
uri = f"api/v1/dataset/{dataset.id}"
@ -1149,6 +1266,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset sqlalchemy error
"""
if backend() == "sqlite":
return
mock_dao_delete.side_effect = DAODeleteFailedError()
dataset = self.insert_default_dataset()
@ -1166,6 +1286,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset column
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
column_id = dataset.columns[0].id
self.login(username="admin")
@ -1179,6 +1302,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset column not found
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
non_id = self.get_nonexistent_numeric_id(TableColumn)
@ -1200,6 +1326,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset column not owned
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
column_id = dataset.columns[0].id
@ -1214,6 +1343,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset column
"""
if backend() == "sqlite":
return
mock_dao_delete.side_effect = DAODeleteFailedError()
dataset = self.get_fixture_datasets()[0]
column_id = dataset.columns[0].id
@ -1229,6 +1361,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset metric
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
test_metric = SqlMetric(
metric_name="metric1", expression="COUNT(*)", table=dataset
@ -1247,6 +1382,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset metric not found
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
non_id = self.get_nonexistent_numeric_id(SqlMetric)
@ -1268,6 +1406,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset metric not owned
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
metric_id = dataset.metrics[0].id
@ -1282,6 +1423,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test delete dataset metric
"""
if backend() == "sqlite":
return
mock_dao_delete.side_effect = DAODeleteFailedError()
dataset = self.get_fixture_datasets()[0]
column_id = dataset.metrics[0].id
@ -1297,6 +1441,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test bulk delete dataset items
"""
if backend() == "sqlite":
return
datasets = self.get_fixture_datasets()
dataset_ids = [dataset.id for dataset in datasets]
@ -1326,6 +1473,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test bulk delete item not owned
"""
if backend() == "sqlite":
return
datasets = self.get_fixture_datasets()
dataset_ids = [dataset.id for dataset in datasets]
@ -1339,6 +1489,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test bulk delete item not found
"""
if backend() == "sqlite":
return
datasets = self.get_fixture_datasets()
dataset_ids = [dataset.id for dataset in datasets]
dataset_ids.append(db.session.query(func.max(SqlaTable.id)).scalar())
@ -1353,6 +1506,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test bulk delete item not authorized
"""
if backend() == "sqlite":
return
datasets = self.get_fixture_datasets()
dataset_ids = [dataset.id for dataset in datasets]
@ -1366,6 +1522,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test bulk delete item incorrect request
"""
if backend() == "sqlite":
return
datasets = self.get_fixture_datasets()
dataset_ids = [dataset.id for dataset in datasets]
dataset_ids.append("Wrong")
@ -1379,6 +1538,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test item refresh
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
# delete a column
id_column = (
@ -1407,6 +1569,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test item refresh not found dataset
"""
if backend() == "sqlite":
return
max_id = db.session.query(func.max(SqlaTable.id)).scalar()
self.login(username="admin")
@ -1418,6 +1583,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test item refresh not owned dataset
"""
if backend() == "sqlite":
return
dataset = self.insert_default_dataset()
self.login(username="alpha")
uri = f"api/v1/dataset/{dataset.id}/refresh"
@ -1432,6 +1600,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset
"""
if backend() == "sqlite":
return
birth_names_dataset = self.get_birth_names_dataset()
# TODO: fix test for presto
# debug with dump: https://github.com/apache/superset/runs/1092546855
@ -1464,6 +1635,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset not found
"""
if backend() == "sqlite":
return
max_id = db.session.query(func.max(SqlaTable.id)).scalar()
# Just one does not exist and we get 404
argument = [max_id + 1, 1]
@ -1477,6 +1651,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset has gamma
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
argument = [dataset.id]
@ -1505,6 +1682,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset
"""
if backend() == "sqlite":
return
birth_names_dataset = self.get_birth_names_dataset()
# TODO: fix test for presto
# debug with dump: https://github.com/apache/superset/runs/1092546855
@ -1526,6 +1706,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset not found
"""
if backend() == "sqlite":
return
# Just one does not exist and we get 404
argument = [-1, 1]
uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}"
@ -1539,6 +1722,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test export dataset has gamma
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
argument = [dataset.id]
@ -1556,6 +1742,9 @@ class TestDatasetApi(SupersetTestCase):
Dataset API: Test get chart and dashboard count related to a dataset
:return:
"""
if backend() == "sqlite":
return
self.login(username="admin")
table = self.get_birth_names_dataset()
uri = f"api/v1/dataset/{table.id}/related_objects"
@ -1569,6 +1758,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test related objects not found
"""
if backend() == "sqlite":
return
max_id = db.session.query(func.max(SqlaTable.id)).scalar()
# id does not exist and we get 404
invalid_id = max_id + 1
@ -1588,6 +1780,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test custom dataset_is_null_or_empty filter for sql
"""
if backend() == "sqlite":
return
arguments = {
"filters": [
{"col": "sql", "opr": "dataset_is_null_or_empty", "value": False}
@ -1621,6 +1816,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test import dataset
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/import/"
@ -1656,6 +1854,9 @@ class TestDatasetApi(SupersetTestCase):
db.session.commit()
def test_import_dataset_v0_export(self):
if backend() == "sqlite":
return
num_datasets = db.session.query(SqlaTable).count()
self.login(username="admin")
@ -1684,6 +1885,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test import existing dataset
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/import/"
@ -1753,6 +1957,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test import invalid dataset
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/import/"
@ -1803,6 +2010,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test import invalid dataset
"""
if backend() == "sqlite":
return
self.login(username="admin")
uri = "api/v1/dataset/import/"
@ -1848,6 +2058,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test custom dataset_is_certified filter
"""
if backend() == "sqlite":
return
table_w_certification = SqlaTable(
table_name="foo",
schema=None,
@ -1878,6 +2091,9 @@ class TestDatasetApi(SupersetTestCase):
"""
Dataset API: Test get dataset samples
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
self.login(username="admin")
@ -1919,6 +2135,9 @@ class TestDatasetApi(SupersetTestCase):
@pytest.mark.usefixtures("create_datasets")
def test_get_dataset_samples_with_failed_cc(self):
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
self.login(username="admin")
@ -1938,6 +2157,9 @@ class TestDatasetApi(SupersetTestCase):
assert "INCORRECT SQL" in rv_data.get("message")
def test_get_dataset_samples_on_virtual_dataset(self):
if backend() == "sqlite":
return
virtual_dataset = SqlaTable(
table_name="virtual_dataset",
sql=("SELECT 'foo' as foo, 'bar' as bar"),

View File

@ -19,7 +19,6 @@ from unittest import mock, skipUnless
import pandas as pd
from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
@ -83,12 +82,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
keymap = {
"Column": (None, None, 0),
"Type": (None, None, 1),
"Null": (None, None, 2),
}
row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
row = mock.Mock()
row.Column, row.Type, row.Null = column
inspector.bind.execute = mock.Mock(return_value=[row])
results = PrestoEngineSpec.get_columns(inspector, "", "")
self.assertEqual(len(expected_results), len(results))

View File

@ -19,7 +19,7 @@ from typing import Any, Dict, Generator
import pytest
from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table
from sqlalchemy.ext.declarative.api import declarative_base
from sqlalchemy.ext.declarative import declarative_base
from superset.columns.models import Column as Sl_Column
from superset.connectors.sqla.models import SqlaTable, TableColumn

View File

@ -199,7 +199,10 @@ class TestDatabaseModel(SupersetTestCase):
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert (
str(call_args[0][0])
== "trino://original_user:original_user_password@localhost"
)
assert call_args[1]["connect_args"] == {"user": "gamma"}
@mock.patch("superset.models.core.create_engine")

View File

@ -379,9 +379,8 @@ class TestDatabaseModel(SupersetTestCase):
"extras": {},
}
# Table with Jinja callable.
table = SqlaTable(
table_name="test_table",
table_name="another_test_table",
sql="SELECT * from test_table;",
database=get_example_database(),
)

View File

@ -0,0 +1,330 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument, redefined-outer-name, invalid-name
from functools import partial
from typing import Any, Dict, TYPE_CHECKING
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
FULL_DTTM_DEFAULTS_EXAMPLE = {
"main_dttm_col": "id",
"dttm_columns": {
"dttm": {
"python_date_format": "epoch_s",
"expression": "CAST(dttm as INTEGER)",
},
"id": {"python_date_format": "epoch_ms"},
"month": {
"python_date_format": "%Y-%m-%d",
"expression": (
"CASE WHEN length(month) = 7 THEN month || '-01' ELSE month END"
),
},
},
}
def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: Dict[str, Any]) -> None:
"""Applies dttm defaults to the table, mutates in place."""
for dbcol in table.columns:
# Set is_dttm is column is listed in dttm_columns.
if dbcol.column_name in dttm_defaults.get("dttm_columns", {}):
dbcol.is_dttm = True
# Skip non dttm columns.
if dbcol.column_name not in dttm_defaults.get("dttm_columns", {}):
continue
# Set table main_dttm_col.
if dbcol.column_name == dttm_defaults.get("main_dttm_col"):
table.main_dttm_col = dbcol.column_name
# Apply defaults if empty.
dttm_column_defaults = dttm_defaults.get("dttm_columns", {}).get(
dbcol.column_name, {}
)
dbcol.is_dttm = True
if (
not dbcol.python_date_format
and "python_date_format" in dttm_column_defaults
):
dbcol.python_date_format = dttm_column_defaults["python_date_format"]
if not dbcol.expression and "expression" in dttm_column_defaults:
dbcol.expression = dttm_column_defaults["expression"]
@pytest.fixture
def test_table(app_context: None, session: Session) -> "SqlaTable":
"""
Fixture that generates an in-memory table.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
columns = [
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
TableColumn(column_name="id", type="INTEGER"),
TableColumn(column_name="dttm", type="INTEGER"),
TableColumn(column_name="duration_ms", type="INTEGER"),
]
return SqlaTable(
table_name="test_table",
columns=columns,
metrics=[],
main_dttm_col=None,
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
)
def test_main_dttm_col(mocker: MockerFixture, test_table: "SqlaTable") -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` config.
"""
dttm_defaults = {
"main_dttm_col": "event_time",
"dttm_columns": {"ds": {}, "event_time": {}},
}
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=dttm_defaults,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
],
)
assert test_table.main_dttm_col is None
test_table.fetch_metadata()
assert test_table.main_dttm_col == "event_time"
def test_main_dttm_col_nonexistent(
mocker: MockerFixture,
test_table: "SqlaTable",
) -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` config when main datetime column doesn't exist.
"""
dttm_defaults = {
"main_dttm_col": "nonexistent",
}
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=dttm_defaults,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
],
)
assert test_table.main_dttm_col is None
test_table.fetch_metadata()
# fall back to ds
assert test_table.main_dttm_col == "ds"
def test_main_dttm_col_nondttm(
mocker: MockerFixture,
test_table: "SqlaTable",
) -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` config when main datetime column has wrong type.
"""
dttm_defaults = {
"main_dttm_col": "id",
}
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=dttm_defaults,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "ds", "type": "TIMESTAMP", "is_dttm": True},
{"name": "event_time", "type": "TIMESTAMP", "is_dttm": True},
{"name": "id", "type": "INTEGER", "is_dttm": False},
],
)
assert test_table.main_dttm_col is None
test_table.fetch_metadata()
# fall back to ds
assert test_table.main_dttm_col == "ds"
def test_python_date_format_by_column_name(
mocker: MockerFixture,
test_table: "SqlaTable",
) -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` setting for "python_date_format".
"""
table_defaults = {
"dttm_columns": {
"id": {"python_date_format": "epoch_ms"},
"dttm": {"python_date_format": "epoch_s"},
"duration_ms": {"python_date_format": "invalid"},
},
}
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=table_defaults,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)
test_table.fetch_metadata()
id_col = [c for c in test_table.columns if c.column_name == "id"][0]
assert id_col.is_dttm
assert id_col.python_date_format == "epoch_ms"
dttm_col = [c for c in test_table.columns if c.column_name == "dttm"][0]
assert dttm_col.is_dttm
assert dttm_col.python_date_format == "epoch_s"
duration_ms_col = [c for c in test_table.columns if c.column_name == "duration_ms"][
0
]
assert duration_ms_col.is_dttm
assert duration_ms_col.python_date_format == "invalid"
def test_expression_by_column_name(
mocker: MockerFixture,
test_table: "SqlaTable",
) -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` setting for expression.
"""
table_defaults = {
"dttm_columns": {
"dttm": {"expression": "CAST(dttm as INTEGER)"},
"duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"},
},
}
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=table_defaults,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)
test_table.fetch_metadata()
dttm_col = [c for c in test_table.columns if c.column_name == "dttm"][0]
assert dttm_col.is_dttm
assert dttm_col.expression == "CAST(dttm as INTEGER)"
duration_ms_col = [c for c in test_table.columns if c.column_name == "duration_ms"][
0
]
assert duration_ms_col.is_dttm
assert duration_ms_col.expression == "CAST(duration_ms as DOUBLE)"
def test_full_setting(
mocker: MockerFixture,
test_table: "SqlaTable",
) -> None:
"""
Test the ``SQLA_TABLE_MUTATOR`` with full settings.
"""
mocker.patch(
"superset.connectors.sqla.models.config",
new={
"SQLA_TABLE_MUTATOR": partial(
apply_dttm_defaults,
dttm_defaults=FULL_DTTM_DEFAULTS_EXAMPLE,
)
},
)
mocker.patch(
"superset.connectors.sqla.models.get_physical_table_metadata",
return_value=[
{"name": "id", "type": "INTEGER", "is_dttm": False},
{"name": "dttm", "type": "INTEGER", "is_dttm": False},
{"name": "duration_ms", "type": "INTEGER", "is_dttm": False},
],
)
test_table.fetch_metadata()
id_col = [c for c in test_table.columns if c.column_name == "id"][0]
assert id_col.is_dttm
assert id_col.python_date_format == "epoch_ms"
assert id_col.expression == ""
dttm_col = [c for c in test_table.columns if c.column_name == "dttm"][0]
assert dttm_col.is_dttm
assert dttm_col.python_date_format == "epoch_s"
assert dttm_col.expression == "CAST(dttm as INTEGER)"

View File

@ -47,10 +47,12 @@ def get_session(mocker: MockFixture) -> Callable[[], Session]:
in_memory_session.remove = lambda: None
# patch session
mocker.patch(
get_session = mocker.patch(
"superset.security.SupersetSecurityManager.get_session",
return_value=in_memory_session,
)
get_session.return_value = in_memory_session
# FAB calls get_session.get_bind() to get a handler to the engine
get_session.get_bind.return_value = engine
mocker.patch("superset.db.session", in_memory_session)
return in_memory_session

View File

@ -124,11 +124,11 @@ def test_import_dataset(app_context: None, session: Session) -> None:
assert len(sqla_table.columns) == 1
assert sqla_table.columns[0].column_name == "profit"
assert sqla_table.columns[0].verbose_name is None
assert sqla_table.columns[0].is_dttm is None
assert sqla_table.columns[0].is_active is None
assert sqla_table.columns[0].is_dttm is False
assert sqla_table.columns[0].is_active is True
assert sqla_table.columns[0].type == "INTEGER"
assert sqla_table.columns[0].groupby is None
assert sqla_table.columns[0].filterable is None
assert sqla_table.columns[0].groupby is True
assert sqla_table.columns[0].filterable is True
assert sqla_table.columns[0].expression == "revenue-expenses"
assert sqla_table.columns[0].description is None
assert sqla_table.columns[0].python_date_format is None

View File

@ -259,7 +259,6 @@ def test_dataset_attributes(app_context: None, session: Session) -> None:
"main_dttm_col",
"metrics",
"offset",
"owners",
"params",
"perm",
"schema",

View File

@ -22,7 +22,7 @@ from pytest import raises
def test_odbc_impersonation(app_context: AppContext) -> None:
"""
Test ``modify_url_for_impersonation`` method when driver == odbc.
Test ``get_url_for_impersonation`` method when driver == odbc.
The method adds the parameter ``DelegationUID`` to the query string.
"""
@ -32,13 +32,13 @@ def test_odbc_impersonation(app_context: AppContext) -> None:
url = URL("drill+odbc")
username = "DoAsUser"
DrillEngineSpec.modify_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
assert url.query["DelegationUID"] == username
def test_jdbc_impersonation(app_context: AppContext) -> None:
"""
Test ``modify_url_for_impersonation`` method when driver == jdbc.
Test ``get_url_for_impersonation`` method when driver == jdbc.
The method adds the parameter ``impersonation_target`` to the query string.
"""
@ -48,13 +48,13 @@ def test_jdbc_impersonation(app_context: AppContext) -> None:
url = URL("drill+jdbc")
username = "DoAsUser"
DrillEngineSpec.modify_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
assert url.query["impersonation_target"] == username
def test_sadrill_impersonation(app_context: AppContext) -> None:
"""
Test ``modify_url_for_impersonation`` method when driver == sadrill.
Test ``get_url_for_impersonation`` method when driver == sadrill.
The method adds the parameter ``impersonation_target`` to the query string.
"""
@ -64,13 +64,13 @@ def test_sadrill_impersonation(app_context: AppContext) -> None:
url = URL("drill+sadrill")
username = "DoAsUser"
DrillEngineSpec.modify_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
assert url.query["impersonation_target"] == username
def test_invalid_impersonation(app_context: AppContext) -> None:
"""
Test ``modify_url_for_impersonation`` method when driver == foobar.
Test ``get_url_for_impersonation`` method when driver == foobar.
The method raises an exception because impersonation is not supported
for drill+foobar.
@ -84,4 +84,4 @@ def test_invalid_impersonation(app_context: AppContext) -> None:
username = "DoAsUser"
with raises(SupersetDBAPIProgrammingError):
DrillEngineSpec.modify_url_for_impersonation(url, True, username)
DrillEngineSpec.get_url_for_impersonation(url, True, username)