superset/tests/integration_tests/base_tests.py

620 lines
21 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
from datetime import datetime
from importlib.util import find_spec
from contextlib import contextmanager
from typing import Any, Union, Optional
from unittest.mock import Mock, patch, MagicMock
import pandas as pd
import prison
from flask import Response, g
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session # noqa: F401
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.test_app import app, login
from superset.sql_parse import CtasMethod
from superset import db, security_manager
from superset.connectors.sqla.models import BaseDatasource, SqlaTable
from superset.models import core as models
from superset.models.slice import Slice
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.utils.core import get_example_default_schema, shortid
from superset.utils import json
from superset.utils.database import get_example_database
from superset.views.base_api import BaseSupersetModelRestApi
# Name of the throwaway Database row handled by
# SupersetTestCase.create_fake_db / delete_fake_db.
FAKE_DB_NAME = "fake_db_100"
# Password used when create_user_with_roles adds a user and the default
# password for SupersetTestCase.login.
DEFAULT_PASSWORD = "general" # noqa: S105
# Module-level test client created from the test application.
test_client = app.test_client()
def get_resp(
    client: Any,
    url: str,
    data: Any = None,
    follow_redirects: bool = True,
    raise_on_error: bool = True,
    json_: Optional[str] = None,
):
    """Issue a request through ``client`` and return the decoded body.

    A truthy ``data`` is POSTed as form data; otherwise a truthy ``json_`` is
    POSTed as a JSON payload; otherwise a GET is issued.

    :param client: test client exposing ``get`` and ``post``
    :param url: The URL to request
    :param data: form payload; takes precedence over ``json_``
    :param follow_redirects: forwarded unchanged to the client call
    :param raise_on_error: raise when the status code is greater than 400
        (a status of exactly 400 does not raise)
    :param json_: JSON payload, only used when ``data`` is falsy
    :return: the response body decoded as UTF-8
    :raises Exception: when ``raise_on_error`` is set and the status is > 400
    """
    options = {"follow_redirects": follow_redirects}
    if data:
        options["data"] = data
        response = client.post(url, **options)
    elif json_:
        options["json"] = json_
        response = client.post(url, **options)
    else:
        response = client.get(url, **options)

    if raise_on_error:
        status = response.status_code
        if status > 400:
            raise Exception(f"http request failed with code {status}")
    return response.data.decode("utf-8")
def post_assert_metric(
    client: Any, uri: str, data: dict[str, Any], func_name: str
) -> Response:
    """
    POST a JSON payload and check the statsd counter call it produced.

    ``BaseSupersetModelRestApi.incr_stats`` is patched for the duration of the
    request; afterwards it must have been called exactly once with the label
    that matches the response status: ``"success"`` for 2xx/3xx,
    ``"warning"`` for 4xx and ``"error"`` for anything else.

    :param client: test client for superset api requests
    :param uri: The URI to use for the HTTP POST
    :param data: The JSON data payload to be posted
    :param func_name: The function name that the HTTP POST triggers
        for the statsd metric assertion
    :return: HTTP Response
    """
    stats_patch = patch.object(
        BaseSupersetModelRestApi, "incr_stats", return_value=None
    )
    with stats_patch as mock_method:
        rv = client.post(uri, json=data)

    status = rv.status_code
    if 200 <= status < 400:
        outcome = "success"
    elif 400 <= status < 500:
        outcome = "warning"
    else:
        outcome = "error"
    mock_method.assert_called_once_with(outcome, func_name)
    return rv
class SupersetTestCase(TestCase):
    """Base class for Superset integration tests.

    Collects helpers for creating users and roles, looking up model rows,
    issuing requests through the test client, asserting statsd counters and
    creating/removing fixture rows.
    """

    # Schema name keyed by backend name (presumably each engine's default
    # schema in the test environment; confirm against the CI setup).
    default_schema_backend_map = {
        "sqlite": "main",
        "mysql": "superset",
        "postgresql": "public",
        "presto": "default",
        "hive": "default",
    }

    # NOTE(review): unittest documents ``None`` as "no maximum diff length";
    # with -1 every diff looks longer than the limit, so diffs may be omitted
    # from failure messages. Left unchanged; confirm the intent.
    maxDiff = -1

    def tearDown(self):
        """Log the test client out after every test."""
        self.logout()

    def create_app(self):
        """Return the application under test (the imported test ``app``)."""
        return app

    @staticmethod
    def get_nonexistent_numeric_id(model):
        """Return ``max(model.id) + 1``, or ``1`` when the query yields nothing."""
        return (db.session.query(func.max(model.id)).scalar() or 0) + 1

    @staticmethod
    def get_birth_names_dataset() -> SqlaTable:
        """Return the ``birth_names`` table via :meth:`get_table`."""
        return SupersetTestCase.get_table(name="birth_names")

    @staticmethod
    def create_user_with_roles(
        username: str, roles: list[str], should_create_roles: bool = False
    ):
        """Ensure ``username`` exists and holds exactly the given roles.

        :param username: login name; also used as first name, last name and
            the local part of the e-mail address when the user is created
        :param roles: names of the roles the user should end up with
        :param should_create_roles: when true, each role is first created as
            a copy of ``Gamma`` (``merge=False``)
        :return: the user object with its roles replaced
        """
        user_to_create = security_manager.find_user(username)
        if not user_to_create:
            security_manager.add_user(
                username,
                username,
                username,
                f"{username}@superset.com",
                security_manager.find_role("Gamma"),  # it needs a role
                password=DEFAULT_PASSWORD,
            )
            db.session.commit()
            user_to_create = security_manager.find_user(username)
            assert user_to_create

        # Start from a clean slate so the user ends up with only ``roles``.
        user_to_create.roles = []
        for chosen_user_role in roles:
            if should_create_roles:
                # copy role from gamma but without data permissions
                security_manager.copy_role("Gamma", chosen_user_role, merge=False)
            user_to_create.roles.append(security_manager.find_role(chosen_user_role))
        db.session.commit()
        return user_to_create

    @contextmanager
    def temporary_user(
        self,
        clone_user=None,
        username=None,
        extra_roles=None,
        extra_pvms=None,
        login=False,
    ):
        """
        Create a temporary user for testing and delete it after the test

            with self.temporary_user(login=True, extra_roles=[Role(...)]) as user:
                user.do_something()

            # user is automatically logged out and deleted after the test

        :param clone_user: optional user whose roles, names and password are
            copied onto the temporary user
        :param username: login name; a ``temp_user_<shortid>`` name is
            generated when omitted
        :param extra_roles: role objects appended to the user's roles
        :param extra_pvms: permission/view-menu entries to grant through a
            temporary role; tuples/lists are resolved with
            ``security_manager.find_permission_view_menu(*pvm)``, anything
            else is used as given
        :param login: log the user in through the test client when true,
            otherwise only assign it to ``g.user``
        """
        username = username or f"temp_user_{shortid()}"
        temp_user = ab_models.User(
            username=username, email=f"{username}@temp.com", active=True
        )
        if clone_user:
            temp_user.roles = clone_user.roles
            temp_user.first_name = clone_user.first_name
            temp_user.last_name = clone_user.last_name
            temp_user.password = clone_user.password
        else:
            temp_user.first_name = temp_user.last_name = username

        if extra_roles:
            temp_user.roles.extend(extra_roles)

        temp_role = None
        if extra_pvms:
            temp_role = ab_models.Role(name=f"tmp_role_{shortid()}")
            temp_role.permissions = [
                security_manager.find_permission_view_menu(*pvm)
                if isinstance(pvm, (tuple, list))
                else pvm
                for pvm in extra_pvms
            ]
            temp_user.roles.append(temp_role)
            db.session.add(temp_role)
            db.session.commit()

        # Add the temp user to the session and commit to apply changes for the test
        db.session.add(temp_user)
        db.session.commit()

        # Remember whatever user was active so it can be put back afterwards.
        previous_g_user = getattr(g, "user", None)
        try:
            if login:
                self.login(username=temp_user.username)
            else:
                g.user = temp_user
            yield temp_user
        finally:
            # Revert changes after the test
            if temp_role:
                db.session.delete(temp_role)
            if login:
                self.logout()
            db.session.delete(temp_user)
            db.session.commit()
            g.user = previous_g_user

    @staticmethod
    def create_user(
        username: str,
        password: str,
        role_name: str,
        first_name: str = "admin",
        last_name: str = "user",
        email: str = "admin@fab.org",
    ) -> Union[ab_models.User, bool]:
        """Create a user holding the role named ``role_name``.

        :return: whatever ``security_manager.add_user`` returns
        """
        role = security_manager.find_role(role_name)
        return security_manager.add_user(
            username, first_name, last_name, email, role, password
        )

    @staticmethod
    def get_user(username: str) -> Optional[ab_models.User]:
        """Return the user with this username, or ``None`` when there is none."""
        return (
            db.session.query(security_manager.user_model)
            .filter_by(username=username)
            .one_or_none()
        )

    @staticmethod
    def get_role(name: str) -> Optional[ab_models.Role]:
        """Return the role with this name, or ``None`` when there is none."""
        return (
            db.session.query(security_manager.role_model)
            .filter_by(name=name)
            .one_or_none()
        )

    @staticmethod
    def get_table_by_id(table_id: int) -> SqlaTable:
        """Return the ``SqlaTable`` with this id (``.one()`` semantics)."""
        return db.session.query(SqlaTable).filter_by(id=table_id).one()

    @staticmethod
    def is_module_installed(module_name: str) -> bool:
        """Return whether ``find_spec`` can locate ``module_name``.

        Lookup errors (``ModuleNotFoundError``, ``ValueError``, ``TypeError``,
        ``ImportError``) are reported as ``False``.
        """
        try:
            return find_spec(module_name) is not None
        except (ModuleNotFoundError, ValueError, TypeError, ImportError):
            return False

    def get_or_create(self, cls, criteria, **kwargs):
        """Fetch the first ``cls`` row matching ``criteria`` or build a new one.

        ``kwargs`` are then written onto the object, which is added to the
        session and committed.

        :param cls: model class to query / instantiate
        :param criteria: ``filter_by`` arguments, also used as constructor
            arguments for a new object
        :return: the persisted object
        """
        obj = db.session.query(cls).filter_by(**criteria).first()
        if not obj:
            obj = cls(**criteria)
        # NOTE(review): values are written straight into ``__dict__`` rather
        # than through ``setattr``; kept as-is since callers rely on it.
        obj.__dict__.update(**kwargs)
        db.session.add(obj)
        db.session.commit()
        return obj

    def login(self, username, password=DEFAULT_PASSWORD):
        """Log in through the test client using the imported ``login`` helper."""
        return login(self.client, username, password)

    def get_slice(self, slice_name: str) -> Slice:
        """Return the ``Slice`` with this name (``.one()`` semantics)."""
        return db.session.query(Slice).filter_by(slice_name=slice_name).one()

    @staticmethod
    def get_table(
        name: str, database_id: Optional[int] = None, schema: Optional[str] = None
    ) -> SqlaTable:
        """Return the ``SqlaTable`` called ``name``.

        :param name: table name
        :param database_id: owning database id; falls back to the id of the
            examples database when falsy
        :param schema: schema name; falls back to
            ``get_example_default_schema()`` when falsy
        """
        schema = schema or get_example_default_schema()
        resolved_database_id = (
            database_id or SupersetTestCase.get_database_by_name("examples").id
        )
        return (
            db.session.query(SqlaTable)
            .filter_by(
                database_id=resolved_database_id,
                schema=schema,
                table_name=name,
            )
            .one()
        )

    @staticmethod
    def get_database_by_id(db_id: int) -> Database:
        """Return the ``Database`` with this id (``.one()`` semantics)."""
        return db.session.query(Database).filter_by(id=db_id).one()

    @staticmethod
    def get_database_by_name(database_name: str = "main") -> Database:
        """Return the examples database.

        Only ``"examples"`` is recognised; any other name, including the
        default ``"main"``, raises.

        :raises ValueError: for every name other than ``"examples"``
        """
        if database_name != "examples":
            raise ValueError("Database doesn't exist")
        return get_example_database()

    @staticmethod
    def get_datasource_mock() -> BaseDatasource:
        """Build a ``MagicMock`` standing in for a ``SqlaTable`` datasource.

        ``query()`` returns a result mock carrying an empty DataFrame and no
        error message; permission strings and the id are fixed placeholders.
        """
        results = Mock()
        results.query = Mock()
        results.status = Mock()
        results.error_message = None
        results.df = pd.DataFrame()

        datasource = MagicMock()
        datasource.type = "table"
        mock_dttm_col = Mock()
        datasource.get_col = Mock(return_value=mock_dttm_col)
        datasource.query = Mock(return_value=results)
        datasource.database = Mock()
        datasource.database.db_engine_spec = Mock()
        datasource.database.perm = "mock_database_perm"
        datasource.schema_perm = "mock_schema_perm"
        datasource.perm = "mock_datasource_perm"
        # Reassigning __class__ makes isinstance(datasource, SqlaTable) true.
        datasource.__class__ = SqlaTable
        datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
        datasource.owners = MagicMock()
        datasource.id = 99999
        return datasource

    def get_resp(
        self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
    ):
        """Call the module-level ``get_resp`` with this test's client."""
        return get_resp(self.client, url, data, follow_redirects, raise_on_error, json_)

    def get_json_resp(
        self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
    ):
        """Shortcut to get the parsed results while following redirects"""
        resp = self.get_resp(url, data, follow_redirects, raise_on_error, json_)
        return json.loads(resp)

    def logout(self):
        """Request ``/logout/`` through the test client, following redirects."""
        self.client.get("/logout/", follow_redirects=True)

    @staticmethod
    def _iter_datasource_access_perms(table):
        """Yield every ``datasource_access`` permission view matching ``table``.

        A permission view matches when it has a view menu whose name contains
        ``table.perm``.
        """
        for perm in db.session.query(ab_models.PermissionView).all():
            if (
                perm.permission.name == "datasource_access"
                and perm.view_menu
                and table.perm in perm.view_menu.name
            ):
                yield perm

    def grant_public_access_to_table(self, table):
        """Grant the ``Public`` role datasource access to ``table``."""
        self.grant_role_access_to_table(table, "Public")

    def grant_role_access_to_table(self, table, role_name):
        """Add every matching ``datasource_access`` permission to the role."""
        role = security_manager.find_role(role_name)
        for perm in self._iter_datasource_access_perms(table):
            security_manager.add_permission_role(role, perm)

    def revoke_public_access_to_table(self, table):
        """Revoke the ``Public`` role's datasource access to ``table``."""
        self.revoke_role_access_to_table("Public", table)

    def revoke_role_access_to_table(self, role_name, table):
        """Remove every matching ``datasource_access`` permission from the role.

        Note the argument order differs from :meth:`grant_role_access_to_table`.
        """
        role = security_manager.find_role(role_name)
        for perm in self._iter_datasource_access_perms(table):
            security_manager.del_permission_role(role, perm)

    def run_sql(
        self,
        sql,
        client_id=None,
        username=None,
        raise_on_error=False,
        query_limit=None,
        database_name="examples",
        sql_editor_id=None,
        select_as_cta=False,
        tmp_table_name=None,
        schema=None,
        ctas_method=CtasMethod.TABLE,
        template_params="{}",
    ):
        """POST a query to ``/api/v1/sqllab/execute/`` and return the parsed JSON.

        :param sql: SQL text to execute
        :param username: when given, the client is logged out, logged in as
            this user for the request, and logged out again afterwards
        :param raise_on_error: raise when the response contains an ``"error"``
            key
        :param database_name: resolved through :meth:`get_database_by_name`
        :return: the decoded JSON response
        :raises Exception: when ``raise_on_error`` is set and the response
            contains ``"error"``

        The remaining parameters are forwarded in the request payload;
        ``tmp_table_name``, ``select_as_cta`` and ``schema`` only when truthy.
        """
        if username:
            self.logout()
            self.login(username)
        dbid = SupersetTestCase.get_database_by_name(database_name).id
        json_payload = {
            "database_id": dbid,
            "sql": sql,
            "client_id": client_id,
            "queryLimit": query_limit,
            "sql_editor_id": sql_editor_id,
            "ctas_method": ctas_method,
            "templateParams": template_params,
        }
        if tmp_table_name:
            json_payload["tmp_table_name"] = tmp_table_name
        if select_as_cta:
            json_payload["select_as_cta"] = select_as_cta
        if schema:
            json_payload["schema"] = schema

        # HTTP-level errors are never raised here; failures are detected from
        # the payload below instead.
        resp = self.get_json_resp(
            "/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload
        )
        if username:
            self.logout()
        if raise_on_error and "error" in resp:
            raise Exception("run_sql failed")
        return resp

    def create_fake_db(self):
        """Get or create the ``FAKE_DB_NAME`` database row (id 100, in-memory SQLite)."""
        extra = (
            "{\n"
            '"schemas_allowed_for_file_upload":\n'
            '["this_schema_is_allowed", "this_schema_is_allowed_too"]\n'
            "}"
        )
        return self.get_or_create(
            cls=models.Database,
            criteria={"database_name": FAKE_DB_NAME},
            sqlalchemy_uri="sqlite:///:memory:",
            id=100,
            extra=extra,
        )

    def delete_fake_db(self):
        """Delete the ``FAKE_DB_NAME`` database row if it exists."""
        database = (
            db.session.query(Database)
            .filter(Database.database_name == FAKE_DB_NAME)
            .scalar()
        )
        if database:
            db.session.delete(database)
            db.session.commit()

    def create_fake_db_for_macros(self):
        """Get or create the ``db_for_macros_testing`` database row (id 200).

        The returned object's ``get_dialect`` is replaced with a function
        returning a MySQL dialect instance.
        """
        database = self.get_or_create(
            cls=models.Database,
            criteria={"database_name": "db_for_macros_testing"},
            sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive",
            id=200,
        )

        def mock_get_dialect() -> Dialect:
            return dialect()

        database.get_dialect = mock_get_dialect
        return database

    @staticmethod
    def delete_fake_db_for_macros():
        """Delete the ``db_for_macros_testing`` database row if it exists."""
        database = (
            db.session.query(Database)
            .filter(Database.database_name == "db_for_macros_testing")
            .scalar()
        )
        if database:
            db.session.delete(database)
            db.session.commit()

    def get_dash_by_slug(self, dash_slug):
        """Return the first dashboard with this slug, or ``None``."""
        return db.session.query(Dashboard).filter_by(slug=dash_slug).first()

    @staticmethod
    def _assert_metric_recorded(mock_method, status_code: int, func_name: str) -> None:
        """Assert the patched ``incr_stats`` was called once with the right label.

        The label is ``"success"`` for 2xx/3xx, ``"warning"`` for 4xx and
        ``"error"`` for every other status code.
        """
        if 200 <= status_code < 400:
            outcome = "success"
        elif 400 <= status_code < 500:
            outcome = "warning"
        else:
            outcome = "error"
        mock_method.assert_called_once_with(outcome, func_name)

    def get_assert_metric(self, uri: str, func_name: str) -> Response:
        """
        Simple client get with an extra assertion for statsd metrics

        :param uri: The URI to use for the HTTP GET
        :param func_name: The function name that the HTTP GET triggers
            for the statsd metric assertion
        :return: HTTP Response
        """
        with patch.object(
            BaseSupersetModelRestApi, "incr_stats", return_value=None
        ) as mock_method:
            rv = self.client.get(uri)
        self._assert_metric_recorded(mock_method, rv.status_code, func_name)
        return rv

    def delete_assert_metric(self, uri: str, func_name: str) -> Response:
        """
        Simple client delete with an extra assertion for statsd metrics

        :param uri: The URI to use for the HTTP DELETE
        :param func_name: The function name that the HTTP DELETE triggers
            for the statsd metric assertion
        :return: HTTP Response
        """
        with patch.object(
            BaseSupersetModelRestApi, "incr_stats", return_value=None
        ) as mock_method:
            rv = self.client.delete(uri)
        self._assert_metric_recorded(mock_method, rv.status_code, func_name)
        return rv

    def post_assert_metric(
        self, uri: str, data: dict[str, Any], func_name: str
    ) -> Response:
        """Call the module-level ``post_assert_metric`` with this test's client."""
        return post_assert_metric(self.client, uri, data, func_name)

    def put_assert_metric(
        self, uri: str, data: dict[str, Any], func_name: str
    ) -> Response:
        """
        Simple client put with an extra assertion for statsd metrics

        :param uri: The URI to use for the HTTP PUT
        :param data: The JSON data payload to be posted
        :param func_name: The function name that the HTTP PUT triggers
            for the statsd metric assertion
        :return: HTTP Response
        """
        with patch.object(
            BaseSupersetModelRestApi, "incr_stats", return_value=None
        ) as mock_method:
            rv = self.client.put(uri, json=data)
        self._assert_metric_recorded(mock_method, rv.status_code, func_name)
        return rv

    @classmethod
    def get_dttm(cls):
        """Return the fixed timestamp 2019-01-02 03:04:05.678900 (naive)."""
        return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")

    def insert_dashboard(
        self,
        dashboard_title: str,
        slug: Optional[str],
        owners: list[int],
        roles: list[int] = [],  # noqa: B006
        created_by=None,
        slices: Optional[list[Slice]] = None,
        position_json: str = "",
        css: str = "",
        json_metadata: str = "",
        published: bool = False,
        certified_by: Optional[str] = None,
        certification_details: Optional[str] = None,
    ) -> Dashboard:
        """Create, commit and return a dashboard.

        :param owners: user ids, each resolved with a primary-key lookup
        :param roles: role ids, each resolved with a primary-key lookup; the
            shared default list is only read, never mutated
        :param slices: charts to attach; ``None`` means no charts
        :return: the committed ``Dashboard``

        All other parameters are passed to the ``Dashboard`` constructor
        unchanged.
        """
        obj_owners = [
            db.session.query(security_manager.user_model).get(owner)
            for owner in owners
        ]
        obj_roles = [
            db.session.query(security_manager.role_model).get(role) for role in roles
        ]
        dashboard = Dashboard(
            dashboard_title=dashboard_title,
            slug=slug,
            owners=obj_owners,
            roles=obj_roles,
            position_json=position_json,
            css=css,
            json_metadata=json_metadata,
            slices=slices or [],
            published=published,
            created_by=created_by,
            certified_by=certified_by,
            certification_details=certification_details,
        )
        db.session.add(dashboard)
        db.session.commit()
        return dashboard

    def get_list(
        self,
        asset_type: str,
        filter: dict[str, Any] = {},  # noqa: B006
        username: str = ADMIN_USERNAME,
    ) -> Response:
        """
        Get list of assets, by default using admin account. Can be filtered.

        :param asset_type: path segment placed after ``api/v1/``
        :param filter: query object serialised with ``prison.dumps`` into the
            ``q`` parameter; the shared default dict is only read
        :param username: account to log in with before the request
        :return: HTTP Response, after asserting the ``get_list`` statsd metric
        """
        self.login(username)
        uri = f"api/v1/{asset_type}/?q={prison.dumps(filter)}"
        return self.get_assert_metric(uri, "get_list")
@contextmanager
def db_insert_temp_object(obj: DeclarativeMeta):
    """Insert a temporary object in database; delete when done.

    The object is added and committed, yielded to the ``with`` body, and then
    deleted and committed again when the body exits, whether or not it raised.

    :param obj: the object to persist for the duration of the ``with`` block
    """
    # Persist before entering the try block: if the insert itself fails there
    # is nothing to clean up, and attempting the delete would likely raise a
    # second error that hides the original one.
    db.session.add(obj)
    db.session.commit()
    try:
        yield obj
    finally:
        db.session.delete(obj)
        db.session.commit()