fix: Refactor SQL username logic (#19914)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
fff9ad05d4
commit
449d08b25e
|
|
@ -31,6 +31,7 @@ from flask_appbuilder.api.manager import resolver
|
|||
|
||||
import superset.utils.database as database_utils
|
||||
from superset.extensions import db
|
||||
from superset.utils.core import override_user
|
||||
from superset.utils.encrypt import SecretsMigrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -54,23 +55,34 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None:
|
|||
|
||||
@click.command()
|
||||
@with_appcontext
|
||||
def update_datasources_cache() -> None:
|
||||
@click.option(
|
||||
"--username",
|
||||
"-u",
|
||||
default=None,
|
||||
help=(
|
||||
"Specify which user should execute the underlying SQL queries. If undefined "
|
||||
"defaults to the user registered with the database connection."
|
||||
),
|
||||
)
|
||||
def update_datasources_cache(username: Optional[str]) -> None:
|
||||
"""Refresh sqllab datasources cache"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import security_manager
|
||||
from superset.models.core import Database
|
||||
|
||||
for database in db.session.query(Database).all():
|
||||
if database.allow_multi_schema_metadata_fetch:
|
||||
print("Fetching {} datasources ...".format(database.name))
|
||||
try:
|
||||
database.get_all_table_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60
|
||||
)
|
||||
database.get_all_view_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
print("{}".format(str(ex)))
|
||||
with override_user(security_manager.find_user(username)):
|
||||
for database in db.session.query(Database).all():
|
||||
if database.allow_multi_schema_metadata_fetch:
|
||||
print("Fetching {} datasources ...".format(database.name))
|
||||
try:
|
||||
database.get_all_table_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60
|
||||
)
|
||||
database.get_all_view_names_in_database(
|
||||
force=True, cache=True, cache_timeout=24 * 60 * 60
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
print("{}".format(str(ex)))
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
|
|||
|
|
@ -680,7 +680,7 @@ BACKUP_COUNT = 30
|
|||
# database,
|
||||
# query,
|
||||
# schema=None,
|
||||
# user=None,
|
||||
# user=None, # TODO(john-bodley): Deprecate in 3.0.
|
||||
# client=None,
|
||||
# security_manager=None,
|
||||
# log_params=None,
|
||||
|
|
@ -1020,9 +1020,14 @@ DB_CONNECTION_MUTATOR = None
|
|||
# The use case is can be around adding some sort of comment header
|
||||
# with information such as the username and worker node information
|
||||
#
|
||||
# def SQL_QUERY_MUTATOR(sql, user_name=user_name, security_manager=security_manager, database=database):
|
||||
# def SQL_QUERY_MUTATOR(
|
||||
# sql,
|
||||
# user_name=user_name, # TODO(john-bodley): Deprecate in 3.0.
|
||||
# security_manager=security_manager,
|
||||
# database=database,
|
||||
# ):
|
||||
# dttm = datetime.now().isoformat()
|
||||
# return f"-- [SQL LAB] {username} {dttm}\n{sql}"
|
||||
# return f"-- [SQL LAB] {user_name} {dttm}\n{sql}"
|
||||
# For backward compatibility, you can unpack any of the above arguments in your
|
||||
# function definition, but keep the **kwargs as the last argument to allow new args
|
||||
# to be added later without any errors.
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ from superset.utils import core as utils
|
|||
from superset.utils.core import (
|
||||
GenericDataType,
|
||||
get_column_name,
|
||||
get_username,
|
||||
is_adhoc_column,
|
||||
MediumText,
|
||||
QueryObjectFilterClause,
|
||||
|
|
@ -917,10 +918,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
Typically adds comments to the query with context"""
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
if sql_query_mutator:
|
||||
username = utils.get_username()
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=username,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
security_manager=security_manager,
|
||||
database=self.database,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from superset.errors import ErrorLevel, SupersetErrorType
|
|||
from superset.exceptions import SupersetSecurityException, SupersetTimeoutException
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import override_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -74,42 +75,43 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
|
||||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
username = self._actor.username if self._actor is not None else None
|
||||
engine = database.get_sqla_engine(user_name=username)
|
||||
event_logger.log_with_context(
|
||||
action="test_connection_attempt",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
int(
|
||||
app.config[
|
||||
"TEST_DATABASE_CONNECTION_TIMEOUT"
|
||||
].total_seconds()
|
||||
),
|
||||
engine.dialect.do_ping,
|
||||
args=(conn,),
|
||||
)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``func_timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except FunctionTimedOut as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
message=(
|
||||
"Please check your connection details and database settings, "
|
||||
"and ensure that your database is accepting connections, "
|
||||
"then try connecting again."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception: # pylint: disable=broad-except
|
||||
alive = False
|
||||
if not alive:
|
||||
raise DBAPIError(None, None, None)
|
||||
|
||||
with override_user(self._actor):
|
||||
engine = database.get_sqla_engine()
|
||||
event_logger.log_with_context(
|
||||
action="test_connection_attempt",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
int(
|
||||
app.config[
|
||||
"TEST_DATABASE_CONNECTION_TIMEOUT"
|
||||
].total_seconds()
|
||||
),
|
||||
engine.dialect.do_ping,
|
||||
args=(conn,),
|
||||
)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``func_timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except FunctionTimedOut as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
message=(
|
||||
"Please check your connection details and database "
|
||||
"settings, and ensure that your database is accepting "
|
||||
"connections, then try connecting again."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception: # pylint: disable=broad-except
|
||||
alive = False
|
||||
if not alive:
|
||||
raise DBAPIError(None, None, None)
|
||||
|
||||
# Log succesful connection test with engine
|
||||
event_logger.log_with_context(
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from superset.db_engine_specs.base import BasicParametersMixin
|
|||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import override_user
|
||||
|
||||
BYPASS_VALIDATION_ENGINES = {"bigquery"}
|
||||
|
||||
|
|
@ -115,22 +116,23 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
)
|
||||
database.set_sqlalchemy_uri(sqlalchemy_uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
username = self._actor.username if self._actor is not None else None
|
||||
engine = database.get_sqla_engine(user_name=username)
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionFailedError(errors) from ex
|
||||
|
||||
with override_user(self._actor):
|
||||
engine = database.get_sqla_engine()
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionFailedError(errors) from ex
|
||||
|
||||
if not alive:
|
||||
raise DatabaseOfflineError(
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ import pandas as pd
|
|||
import sqlparse
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import current_app, g
|
||||
from flask import current_app
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
|
|
@ -64,7 +64,7 @@ from superset.models.sql_lab import Query
|
|||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
from superset.utils.core import ColumnSpec, GenericDataType, get_username
|
||||
from superset.utils.hashing import md5_sha_from_str
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
|
|
@ -392,10 +392,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
schema: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
) -> Engine:
|
||||
user_name = utils.get_username()
|
||||
return database.get_sqla_engine(
|
||||
schema=schema, nullpool=True, user_name=user_name, source=source
|
||||
)
|
||||
return database.get_sqla_engine(schema=schema, source=source)
|
||||
|
||||
@classmethod
|
||||
def get_timestamp_expr(
|
||||
|
|
@ -1158,15 +1155,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
raise Exception("Database does not support cost estimation")
|
||||
|
||||
@classmethod
|
||||
def process_statement(
|
||||
cls, statement: str, database: "Database", user_name: str
|
||||
) -> str:
|
||||
def process_statement(cls, statement: str, database: "Database") -> str:
|
||||
"""
|
||||
Process a SQL statement by stripping and mutating it.
|
||||
|
||||
:param statement: A single SQL statement
|
||||
:param database: Database instance
|
||||
:param user_name: Effective username
|
||||
:return: Dictionary with different costs
|
||||
"""
|
||||
parsed_query = ParsedQuery(statement)
|
||||
|
|
@ -1175,7 +1169,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
if sql_query_mutator:
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=user_name,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
security_manager=security_manager,
|
||||
database=database,
|
||||
)
|
||||
|
|
@ -1198,7 +1192,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
if not cls.get_allow_cost_estimate(extra):
|
||||
raise Exception("Database does not support cost estimation")
|
||||
|
||||
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
|
|
@ -1207,9 +1200,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in statements:
|
||||
processed_statement = cls.process_statement(
|
||||
statement, database, user_name
|
||||
)
|
||||
processed_statement = cls.process_statement(statement, database)
|
||||
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
|
||||
return costs
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
|||
from superset.models.tags import FavStarUpdater
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
config = app.config
|
||||
|
|
@ -322,29 +323,21 @@ class Database(
|
|||
conn.password = PASSWORD_MASK if conn.password else None
|
||||
self.sqlalchemy_uri = str(conn) # hides the password
|
||||
|
||||
def get_effective_user(
|
||||
self,
|
||||
object_url: URL,
|
||||
user_name: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
def get_effective_user(self, object_url: URL) -> Optional[str]:
|
||||
"""
|
||||
Get the effective user, especially during impersonation.
|
||||
|
||||
:param object_url: SQL Alchemy URL object
|
||||
:param user_name: Default username
|
||||
:return: The effective username
|
||||
"""
|
||||
effective_username = None
|
||||
if self.impersonate_user:
|
||||
effective_username = object_url.username
|
||||
if user_name:
|
||||
effective_username = user_name
|
||||
elif (
|
||||
hasattr(g, "user")
|
||||
and hasattr(g.user, "username")
|
||||
and g.user.username is not None
|
||||
):
|
||||
effective_username = g.user.username
|
||||
return effective_username
|
||||
|
||||
return ( # pylint: disable=used-before-assignment
|
||||
username
|
||||
if (username := get_username())
|
||||
else object_url.username
|
||||
if self.impersonate_user
|
||||
else None
|
||||
)
|
||||
|
||||
@memoized(
|
||||
watch=(
|
||||
|
|
@ -358,13 +351,12 @@ class Database(
|
|||
self,
|
||||
schema: Optional[str] = None,
|
||||
nullpool: bool = True,
|
||||
user_name: Optional[str] = None,
|
||||
source: Optional[utils.QuerySource] = None,
|
||||
) -> Engine:
|
||||
extra = self.get_extra()
|
||||
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
||||
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||
effective_username = self.get_effective_user(sqlalchemy_url, user_name)
|
||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||
# If using MySQL or Presto for example, will set url.username
|
||||
# If using Hive, will not do anything yet since that relies on a
|
||||
# configuration parameter instead.
|
||||
|
|
@ -421,12 +413,9 @@ class Database(
|
|||
sql: str,
|
||||
schema: Optional[str] = None,
|
||||
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = self.db_engine_spec.parse_sql(sql)
|
||||
|
||||
engine = self.get_sqla_engine(schema=schema, user_name=username)
|
||||
username = utils.get_username() or username
|
||||
engine = self.get_sqla_engine(schema)
|
||||
|
||||
def needs_conversion(df_series: pd.Series) -> bool:
|
||||
return (
|
||||
|
|
@ -437,7 +426,14 @@ class Database(
|
|||
|
||||
def _log_query(sql: str) -> None:
|
||||
if log_query:
|
||||
log_query(engine.url, sql, schema, username, __name__, security_manager)
|
||||
log_query(
|
||||
engine.url,
|
||||
sql,
|
||||
schema,
|
||||
get_username(),
|
||||
__name__,
|
||||
security_manager,
|
||||
)
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import pandas as pd
|
|||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset import app, jinja_context
|
||||
from superset import app, jinja_context, security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
|
||||
from superset.reports.commands.exceptions import (
|
||||
|
|
@ -36,6 +36,7 @@ from superset.reports.commands.exceptions import (
|
|||
AlertQueryTimeout,
|
||||
AlertValidatorConfigError,
|
||||
)
|
||||
from superset.utils.core import override_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -145,18 +146,21 @@ class AlertCommand(BaseCommand):
|
|||
limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
|
||||
rendered_sql, ALERT_SQL_LIMIT
|
||||
)
|
||||
query_username = app.config["THUMBNAIL_SELENIUM_USER"]
|
||||
start = default_timer()
|
||||
df = self._report_schedule.database.get_df(
|
||||
sql=limited_rendered_sql, username=query_username
|
||||
)
|
||||
stop = default_timer()
|
||||
logger.info(
|
||||
"Query for %s took %.2f ms",
|
||||
self._report_schedule.name,
|
||||
(stop - start) * 1000.0,
|
||||
)
|
||||
return df
|
||||
|
||||
with override_user(
|
||||
security_manager.find_user(
|
||||
username=app.config["THUMBNAIL_SELENIUM_USER"]
|
||||
)
|
||||
):
|
||||
start = default_timer()
|
||||
df = self._report_schedule.database.get_df(sql=limited_rendered_sql)
|
||||
stop = default_timer()
|
||||
logger.info(
|
||||
"Query for %s took %.2f ms",
|
||||
self._report_schedule.name,
|
||||
(stop - start) * 1000.0,
|
||||
)
|
||||
return df
|
||||
except SoftTimeLimitExceeded as ex:
|
||||
logger.warning("A timeout occurred while executing the alert query: %s", ex)
|
||||
raise AlertQueryTimeout() from ex
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from celery.exceptions import SoftTimeLimitExceeded
|
|||
from flask_appbuilder.security.sqla.models import User
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import app
|
||||
from superset import app, security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
|
|
@ -179,11 +179,10 @@ class BaseReportState:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_user(self) -> User:
|
||||
user = (
|
||||
self._session.query(User)
|
||||
.filter(User.username == app.config["THUMBNAIL_SELENIUM_USER"])
|
||||
.one_or_none()
|
||||
@staticmethod
|
||||
def _get_user() -> User:
|
||||
user = security_manager.find_user(
|
||||
username=app.config["THUMBNAIL_SELENIUM_USER"]
|
||||
)
|
||||
if not user:
|
||||
raise ReportScheduleSelleniumUserNotFoundError()
|
||||
|
|
|
|||
|
|
@ -50,7 +50,13 @@ from superset.result_set import SupersetResultSet
|
|||
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress
|
||||
from superset.utils.core import (
|
||||
get_username,
|
||||
json_iso_dttm_ser,
|
||||
override_user,
|
||||
QuerySource,
|
||||
zlib_compress,
|
||||
)
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import stats_timing
|
||||
|
||||
|
|
@ -155,37 +161,35 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
|||
rendered_query: str,
|
||||
return_results: bool = True,
|
||||
store_results: bool = False,
|
||||
user_name: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
expand_data: bool = False,
|
||||
log_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Executes the sql query returns the results."""
|
||||
with session_scope(not ctask.request.called_directly) as session:
|
||||
|
||||
try:
|
||||
return execute_sql_statements(
|
||||
query_id,
|
||||
rendered_query,
|
||||
return_results,
|
||||
store_results,
|
||||
user_name,
|
||||
session=session,
|
||||
start_time=start_time,
|
||||
expand_data=expand_data,
|
||||
log_params=log_params,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logger.debug("Query %d: %s", query_id, ex)
|
||||
stats_logger.incr("error_sqllab_unhandled")
|
||||
query = get_query(query_id, session)
|
||||
return handle_query_error(ex, query, session)
|
||||
with override_user(security_manager.find_user(username)):
|
||||
try:
|
||||
return execute_sql_statements(
|
||||
query_id,
|
||||
rendered_query,
|
||||
return_results,
|
||||
store_results,
|
||||
session=session,
|
||||
start_time=start_time,
|
||||
expand_data=expand_data,
|
||||
log_params=log_params,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logger.debug("Query %d: %s", query_id, ex)
|
||||
stats_logger.incr("error_sqllab_unhandled")
|
||||
query = get_query(query_id, session)
|
||||
return handle_query_error(ex, query, session)
|
||||
|
||||
|
||||
def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
|
||||
sql_statement: str,
|
||||
query: Query,
|
||||
user_name: Optional[str],
|
||||
session: Session,
|
||||
cursor: Any,
|
||||
log_params: Optional[Dict[str, Any]],
|
||||
|
|
@ -204,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
|
|||
parsed_query._parsed[0], # pylint: disable=protected-access
|
||||
database.id,
|
||||
query.schema,
|
||||
username=user_name,
|
||||
username=get_username(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -246,7 +250,10 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
|
|||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
sql = SQL_QUERY_MUTATOR(
|
||||
sql, user_name=user_name, security_manager=security_manager, database=database
|
||||
sql,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
security_manager=security_manager,
|
||||
database=database,
|
||||
)
|
||||
try:
|
||||
query.executed_sql = sql
|
||||
|
|
@ -255,7 +262,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
|
|||
query.database.sqlalchemy_uri,
|
||||
query.executed_sql,
|
||||
query.schema,
|
||||
user_name,
|
||||
get_username(),
|
||||
__name__,
|
||||
security_manager,
|
||||
log_params,
|
||||
|
|
@ -375,7 +382,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
rendered_query: str,
|
||||
return_results: bool,
|
||||
store_results: bool,
|
||||
user_name: Optional[str],
|
||||
session: Session,
|
||||
start_time: Optional[float],
|
||||
expand_data: bool,
|
||||
|
|
@ -452,12 +458,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
)
|
||||
)
|
||||
|
||||
engine = database.get_sqla_engine(
|
||||
schema=query.schema,
|
||||
nullpool=True,
|
||||
user_name=user_name,
|
||||
source=QuerySource.SQL_LAB,
|
||||
)
|
||||
engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
|
|
@ -490,7 +491,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
result_set = execute_sql_statement(
|
||||
statement,
|
||||
query,
|
||||
user_name,
|
||||
session,
|
||||
cursor,
|
||||
log_params,
|
||||
|
|
@ -597,7 +597,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
|
|||
return None
|
||||
|
||||
|
||||
def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
|
||||
def cancel_query(query: Query) -> bool:
|
||||
"""
|
||||
Cancel a running query.
|
||||
|
||||
|
|
@ -605,7 +605,6 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
|
|||
action is required.
|
||||
|
||||
:param query: Query to cancel
|
||||
:param user_name: Default username
|
||||
:return: True if query cancelled successfully, False otherwise
|
||||
"""
|
||||
|
||||
|
|
@ -616,12 +615,7 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
|
|||
if cancel_query_id is None:
|
||||
return False
|
||||
|
||||
engine = query.database.get_sqla_engine(
|
||||
schema=query.schema,
|
||||
nullpool=True,
|
||||
user_name=user_name,
|
||||
source=QuerySource.SQL_LAB,
|
||||
)
|
||||
engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
with closing(conn.cursor()) as cursor:
|
||||
|
|
|
|||
|
|
@ -20,13 +20,11 @@ import time
|
|||
from contextlib import closing
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||
from superset.utils.core import QuerySource
|
||||
from superset.utils.core import get_username, QuerySource
|
||||
|
||||
MAX_ERROR_ROWS = 10
|
||||
|
||||
|
|
@ -45,7 +43,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
|
||||
@classmethod
|
||||
def validate_statement(
|
||||
cls, statement: str, database: Database, cursor: Any, user_name: str
|
||||
cls,
|
||||
statement: str,
|
||||
database: Database,
|
||||
cursor: Any,
|
||||
) -> Optional[SQLValidationAnnotation]:
|
||||
# pylint: disable=too-many-locals
|
||||
db_engine_spec = database.db_engine_spec
|
||||
|
|
@ -57,7 +58,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
if sql_query_mutator:
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=user_name,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
security_manager=security_manager,
|
||||
database=database,
|
||||
)
|
||||
|
|
@ -157,26 +158,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||
VALIDATE) SELECT 1 FROM default.mytable.
|
||||
"""
|
||||
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
|
||||
parsed_query = ParsedQuery(sql)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
logger.info("Validating %i statement(s)", len(statements))
|
||||
engine = database.get_sqla_engine(
|
||||
schema=schema,
|
||||
nullpool=True,
|
||||
user_name=user_name,
|
||||
source=QuerySource.SQL_LAB,
|
||||
)
|
||||
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
annotations: List[SQLValidationAnnotation] = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(
|
||||
statement, database, cursor, user_name
|
||||
)
|
||||
annotation = cls.validate_statement(statement, database, cursor)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logger.debug("Validation found %i error(s)", len(annotations))
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import logging
|
|||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from flask import g
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
|
|
@ -34,6 +33,7 @@ from superset.exceptions import (
|
|||
)
|
||||
from superset.sqllab.command_status import SqlJsonExecutionStatus
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.dates import now_as_float
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -139,9 +139,7 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
|
|||
rendered_query,
|
||||
return_results=True,
|
||||
store_results=self._is_store_results(execution_context),
|
||||
user_name=g.user.username
|
||||
if g.user and hasattr(g.user, "username")
|
||||
else None,
|
||||
username=get_username(),
|
||||
expand_data=execution_context.expand_data,
|
||||
log_params=log_params,
|
||||
)
|
||||
|
|
@ -174,9 +172,7 @@ class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase):
|
|||
rendered_query,
|
||||
return_results=False,
|
||||
store_results=not execution_context.select_as_cta,
|
||||
user_name=g.user.username
|
||||
if g.user and hasattr(g.user, "username")
|
||||
else None,
|
||||
username=get_username(),
|
||||
start_time=now_as_float(),
|
||||
expand_data=execution_context.expand_data,
|
||||
log_params=log_params,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import threading
|
|||
import traceback
|
||||
import uuid
|
||||
import zlib
|
||||
from contextlib import contextmanager
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from distutils.util import strtobool
|
||||
from email.mime.application import MIMEApplication
|
||||
|
|
@ -1408,6 +1409,30 @@ def get_username() -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_user(user: Optional[User]) -> Iterator[Any]:
|
||||
"""
|
||||
Temporarily override the current user (if defined) per `flask.g`.
|
||||
|
||||
Sometimes, often in the context of async Celery tasks, it is useful to switch the
|
||||
current user (which may be undefined) to different one, execute some SQLAlchemy
|
||||
tasks and then revert back to the original one.
|
||||
|
||||
:param user: The override user
|
||||
"""
|
||||
|
||||
# pylint: disable=assigning-non-slot
|
||||
if hasattr(g, "user"):
|
||||
current = g.user
|
||||
g.user = user
|
||||
yield
|
||||
g.user = current
|
||||
else:
|
||||
g.user = user
|
||||
yield
|
||||
delattr(g, "user")
|
||||
|
||||
|
||||
def parse_ssl_cert(certificate: str) -> _Certificate:
|
||||
"""
|
||||
Parses the contents of a certificate and returns a valid certificate object
|
||||
|
|
|
|||
|
|
@ -1367,11 +1367,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
username = (
|
||||
g.user.username if g.user and hasattr(g.user, "username") else None
|
||||
)
|
||||
engine = database.get_sqla_engine(user_name=username)
|
||||
engine = database.get_sqla_engine()
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
if engine.dialect.do_ping(conn):
|
||||
|
|
@ -2298,7 +2294,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
return self.json_response("OK")
|
||||
|
||||
if not sql_lab.cancel_query(query, g.user.username if g.user else None):
|
||||
if not sql_lab.cancel_query(query):
|
||||
raise SupersetCancelQueryException("Could not cancel query")
|
||||
|
||||
query.status = QueryStatus.STOPPED
|
||||
|
|
|
|||
|
|
@ -28,8 +28,11 @@ from __future__ import annotations
|
|||
from typing import Callable, TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock
|
||||
|
||||
from flask import Flask
|
||||
from flask.ctx import AppContext
|
||||
from pytest import fixture
|
||||
|
||||
from superset.app import create_app
|
||||
from tests.example_data.data_loading.pandas.pandas_data_loader import PandasDataLoader
|
||||
from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
|
||||
PandasLoaderConfigurations,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ import unittest
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask import g
|
||||
from flask.ctx import AppContext
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
|
|
@ -41,6 +43,7 @@ from superset.connectors.connector_registry import ConnectorRegistry
|
|||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models import core as models
|
||||
from superset.models.datasource_access_request import DatasourceAccessRequest
|
||||
from superset.utils.core import get_username, override_user
|
||||
from superset.utils.database import get_example_database
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
|
@ -86,7 +89,7 @@ DB_ACCESS_ROLE = "db_access_role"
|
|||
SCHEMA_ACCESS_ROLE = "schema_access_role"
|
||||
|
||||
|
||||
def create_access_request(session, ds_type, ds_name, role_name, user_name):
|
||||
def create_access_request(session, ds_type, ds_name, role_name, username):
|
||||
ds_class = ConnectorRegistry.sources[ds_type]
|
||||
# TODO: generalize datasource names
|
||||
if ds_type == "table":
|
||||
|
|
@ -102,7 +105,7 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
|
|||
access_request = DatasourceAccessRequest(
|
||||
datasource_id=ds.id,
|
||||
datasource_type=ds_type,
|
||||
created_by_fk=security_manager.find_user(username=user_name).id,
|
||||
created_by_fk=security_manager.find_user(username=username).id,
|
||||
)
|
||||
session.add(access_request)
|
||||
session.commit()
|
||||
|
|
@ -565,5 +568,46 @@ class TestRequestAccess(SupersetTestCase):
|
|||
session.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username",
|
||||
[
|
||||
None,
|
||||
"gamma",
|
||||
],
|
||||
)
|
||||
def test_get_username(app_context: AppContext, username: str) -> None:
|
||||
assert not hasattr(g, "user")
|
||||
assert get_username() is None
|
||||
|
||||
g.user = security_manager.find_user(username)
|
||||
assert get_username() == username
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username",
|
||||
[
|
||||
None,
|
||||
"gamma",
|
||||
],
|
||||
)
|
||||
def test_override_user(app_context: AppContext, username: str) -> None:
|
||||
admin = security_manager.find_user(username="admin")
|
||||
user = security_manager.find_user(username)
|
||||
|
||||
assert not hasattr(g, "user")
|
||||
|
||||
with override_user(user):
|
||||
assert g.user == user
|
||||
|
||||
assert not hasattr(g, "user")
|
||||
|
||||
g.user = admin
|
||||
|
||||
with override_user(user):
|
||||
assert g.user == user
|
||||
|
||||
assert g.user == admin
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ class SupersetTestCase(TestCase):
|
|||
self,
|
||||
sql,
|
||||
client_id=None,
|
||||
user_name=None,
|
||||
username=None,
|
||||
raise_on_error=False,
|
||||
query_limit=None,
|
||||
database_name="examples",
|
||||
|
|
@ -340,9 +340,9 @@ class SupersetTestCase(TestCase):
|
|||
ctas_method=CtasMethod.TABLE,
|
||||
template_params="{}",
|
||||
):
|
||||
if user_name:
|
||||
if username:
|
||||
self.logout()
|
||||
self.login(username=(user_name or "admin"))
|
||||
self.login(username=username)
|
||||
dbid = SupersetTestCase.get_database_by_name(database_name).id
|
||||
json_payload = {
|
||||
"database_id": dbid,
|
||||
|
|
@ -427,14 +427,14 @@ class SupersetTestCase(TestCase):
|
|||
self,
|
||||
sql,
|
||||
client_id=None,
|
||||
user_name=None,
|
||||
username=None,
|
||||
raise_on_error=False,
|
||||
database_name="examples",
|
||||
template_params=None,
|
||||
):
|
||||
if user_name:
|
||||
if username:
|
||||
self.logout()
|
||||
self.login(username=(user_name if user_name else "admin"))
|
||||
self.login(username=username)
|
||||
dbid = SupersetTestCase.get_database_by_name(database_name).id
|
||||
resp = self.get_json_resp(
|
||||
"/superset/validate_sql_json/",
|
||||
|
|
|
|||
|
|
@ -1064,7 +1064,7 @@ class TestCore(SupersetTestCase):
|
|||
LIMIT 10;
|
||||
""",
|
||||
client_id="client_id_1",
|
||||
user_name="admin",
|
||||
username="admin",
|
||||
)
|
||||
count_ds = []
|
||||
count_name = []
|
||||
|
|
@ -1454,7 +1454,7 @@ class TestCore(SupersetTestCase):
|
|||
self.run_sql(
|
||||
"SELECT name FROM birth_names",
|
||||
"client_id_1",
|
||||
user_name=username,
|
||||
username=username,
|
||||
raise_on_error=True,
|
||||
sql_editor_id=str(tab_state_id),
|
||||
)
|
||||
|
|
@ -1462,7 +1462,7 @@ class TestCore(SupersetTestCase):
|
|||
self.run_sql(
|
||||
"SELECT name FROM birth_names",
|
||||
"client_id_2",
|
||||
user_name=username,
|
||||
username=username,
|
||||
raise_on_error=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,8 +20,10 @@ import textwrap
|
|||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from superset import security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
load_birth_names_data,
|
||||
|
|
@ -112,21 +114,22 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
)
|
||||
def test_database_impersonate_user(self):
|
||||
uri = "mysql://root@localhost"
|
||||
example_user = "giuseppe"
|
||||
example_user = security_manager.find_user(username="gamma")
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri)
|
||||
|
||||
model.impersonate_user = True
|
||||
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
|
||||
self.assertEqual(example_user, user_name)
|
||||
with override_user(example_user):
|
||||
model.impersonate_user = True
|
||||
username = make_url(model.get_sqla_engine().url).username
|
||||
self.assertEqual(example_user.username, username)
|
||||
|
||||
model.impersonate_user = False
|
||||
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
|
||||
self.assertNotEqual(example_user, user_name)
|
||||
model.impersonate_user = False
|
||||
username = make_url(model.get_sqla_engine().url).username
|
||||
self.assertNotEqual(example_user.username, username)
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_presto(self, mocked_create_engine):
|
||||
uri = "presto://localhost"
|
||||
principal_user = "logged_in_user"
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
extra = """
|
||||
{
|
||||
"metadata_params": {},
|
||||
|
|
@ -142,64 +145,66 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
"""
|
||||
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
|
||||
with override_user(principal_user):
|
||||
model = Database(
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
assert str(call_args[0][0]) == "presto://gamma@localhost"
|
||||
|
||||
assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"principal_username": "gamma",
|
||||
}
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"principal_username": "logged_in_user",
|
||||
}
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
assert str(call_args[0][0]) == "presto://localhost"
|
||||
|
||||
assert str(call_args[0][0]) == "presto://localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_trino(self, mocked_create_engine):
|
||||
uri = "trino://localhost"
|
||||
principal_user = "logged_in_user"
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri)
|
||||
with override_user(principal_user):
|
||||
model = Database(
|
||||
database_name="test_database", sqlalchemy_uri="trino://localhost"
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
assert str(call_args[0][0]) == "trino://localhost"
|
||||
assert call_args[1]["connect_args"] == {"user": "gamma"}
|
||||
|
||||
assert str(call_args[0][0]) == "trino://localhost"
|
||||
model = Database(
|
||||
database_name="test_database",
|
||||
sqlalchemy_uri="trino://original_user:original_user_password@localhost",
|
||||
)
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"user": "logged_in_user",
|
||||
}
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
uri = "trino://original_user:original_user_password@localhost"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "trino://original_user@localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {"user": "logged_in_user"}
|
||||
assert str(call_args[0][0]) == "trino://original_user@localhost"
|
||||
assert call_args[1]["connect_args"] == {"user": "gamma"}
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_hive(self, mocked_create_engine):
|
||||
uri = "hive://localhost"
|
||||
principal_user = "logged_in_user"
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
extra = """
|
||||
{
|
||||
"metadata_params": {},
|
||||
|
|
@ -215,32 +220,34 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
"""
|
||||
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
|
||||
with override_user(principal_user):
|
||||
model = Database(
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"configuration": {"hive.server2.proxy.user": "gamma"},
|
||||
}
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
|
||||
}
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_select_star(self):
|
||||
|
|
@ -345,19 +352,6 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
df = main_db.get_df("USE superset; SELECT ';';", None)
|
||||
self.assertEqual(df.iat[0, 0], ";")
|
||||
|
||||
@mock.patch("superset.models.core.Database.get_sqla_engine")
|
||||
def test_username_param(self, mocked_get_sqla_engine):
|
||||
main_db = get_example_database()
|
||||
main_db.impersonate_user = True
|
||||
test_username = "test_username_param"
|
||||
|
||||
if main_db.backend == "mysql":
|
||||
main_db.get_df("USE superset; SELECT 1", username=test_username)
|
||||
mocked_get_sqla_engine.assert_called_with(
|
||||
schema=None,
|
||||
user_name="test_username_param",
|
||||
)
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_get_sqla_engine(self, mocked_create_engine):
|
||||
model = Database(
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
"message": "your query isn't how I like it",
|
||||
}
|
||||
|
||||
@patch("superset.sql_validators.presto_db.g")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_validator_success(self, flask_g):
|
||||
flask_g.user.username = "nobody"
|
||||
sql = "SELECT 1 FROM default.notarealtable"
|
||||
|
|
@ -197,7 +197,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
|
||||
self.assertEqual([], errors)
|
||||
|
||||
@patch("superset.sql_validators.presto_db.g")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_validator_db_error(self, flask_g):
|
||||
flask_g.user.username = "nobody"
|
||||
sql = "SELECT 1 FROM default.notarealtable"
|
||||
|
|
@ -209,7 +209,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
with self.assertRaises(PrestoSQLValidationError):
|
||||
self.validator.validate(sql, schema, self.database)
|
||||
|
||||
@patch("superset.sql_validators.presto_db.g")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_validator_unexpected_error(self, flask_g):
|
||||
flask_g.user.username = "nobody"
|
||||
sql = "SELECT 1 FROM default.notarealtable"
|
||||
|
|
@ -221,7 +221,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
with self.assertRaises(Exception):
|
||||
self.validator.validate(sql, schema, self.database)
|
||||
|
||||
@patch("superset.sql_validators.presto_db.g")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_validator_query_error(self, flask_g):
|
||||
flask_g.user.username = "nobody"
|
||||
sql = "SELECT 1 FROM default.notarealtable"
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ class TestSqlLab(SupersetTestCase):
|
|||
def run_some_queries(self):
|
||||
db.session.query(Query).delete()
|
||||
db.session.commit()
|
||||
self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
|
||||
self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
|
||||
self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
|
||||
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
|
||||
self.run_sql(QUERY_2, client_id="client_id_3", username="admin")
|
||||
self.run_sql(QUERY_3, client_id="client_id_2", username="gamma_sqllab")
|
||||
self.logout()
|
||||
|
||||
def tearDown(self):
|
||||
|
|
@ -162,7 +162,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
with freeze_time(datetime.now().isoformat(timespec="seconds")):
|
||||
self.run_sql(sql_statement, "1")
|
||||
self.run_sql(sql_statement, "1", username="admin")
|
||||
saved_query_ = (
|
||||
db.session.query(SavedQuery)
|
||||
.filter(
|
||||
|
|
@ -248,7 +248,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
# Gamma user, with sqllab and db permission
|
||||
self.create_user_with_roles("Gagarin", ["ExampleDBAccess", "Gamma", "sql_lab"])
|
||||
|
||||
data = self.run_sql(QUERY_1, "1", user_name="Gagarin")
|
||||
data = self.run_sql(QUERY_1, "1", username="Gagarin")
|
||||
db.session.query(Query).delete()
|
||||
db.session.commit()
|
||||
self.assertLess(0, len(data["data"]))
|
||||
|
|
@ -278,14 +278,14 @@ class TestSqlLab(SupersetTestCase):
|
|||
)
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", user_name="SchemaUser"
|
||||
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
|
||||
"4",
|
||||
user_name="SchemaUser",
|
||||
username="SchemaUser",
|
||||
schema=CTAS_SCHEMA_NAME,
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
|
|
@ -295,7 +295,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
data = self.run_sql(
|
||||
"SELECT * FROM test_table",
|
||||
"5",
|
||||
user_name="SchemaUser",
|
||||
username="SchemaUser",
|
||||
schema=CTAS_SCHEMA_NAME,
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
|
|
@ -441,7 +441,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
self.run_sql(
|
||||
"SELECT name as col, gender as col FROM birth_names LIMIT 10",
|
||||
client_id="2e2df3",
|
||||
user_name="admin",
|
||||
username="admin",
|
||||
raise_on_error=True,
|
||||
)
|
||||
|
||||
|
|
@ -747,7 +747,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
rendered_query=sql,
|
||||
return_results=True,
|
||||
store_results=False,
|
||||
user_name="admin",
|
||||
session=mock_session,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
|
|
@ -758,7 +757,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock.call(
|
||||
"SET @value = 42",
|
||||
mock_query,
|
||||
"admin",
|
||||
mock_session,
|
||||
mock_cursor,
|
||||
None,
|
||||
|
|
@ -767,7 +765,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock.call(
|
||||
"SELECT @value AS foo",
|
||||
mock_query,
|
||||
"admin",
|
||||
mock_session,
|
||||
mock_cursor,
|
||||
None,
|
||||
|
|
@ -804,7 +801,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
rendered_query=sql,
|
||||
return_results=True,
|
||||
store_results=False,
|
||||
user_name="admin",
|
||||
session=mock_session,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
|
|
@ -858,7 +854,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
rendered_query=sql,
|
||||
return_results=True,
|
||||
store_results=False,
|
||||
user_name="admin",
|
||||
session=mock_session,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
|
|
@ -869,7 +864,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock.call(
|
||||
"SET @value = 42",
|
||||
mock_query,
|
||||
"admin",
|
||||
mock_session,
|
||||
mock_cursor,
|
||||
None,
|
||||
|
|
@ -878,7 +872,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
mock.call(
|
||||
"SELECT @value AS foo",
|
||||
mock_query,
|
||||
"admin",
|
||||
mock_session,
|
||||
mock_cursor,
|
||||
None,
|
||||
|
|
@ -895,7 +888,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
rendered_query=sql,
|
||||
return_results=True,
|
||||
store_results=False,
|
||||
user_name="admin",
|
||||
session=mock_session,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
|
|
@ -929,7 +921,6 @@ class TestSqlLab(SupersetTestCase):
|
|||
rendered_query=sql,
|
||||
return_results=True,
|
||||
store_results=False,
|
||||
user_name="admin",
|
||||
session=mock_session,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import sqlparse
|
|||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.utils.core import override_user
|
||||
|
||||
|
||||
def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
|
||||
"""
|
||||
|
|
@ -46,7 +48,6 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
|
|||
execute_sql_statement(
|
||||
sql_statement,
|
||||
query,
|
||||
user_name=None,
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params={},
|
||||
|
|
@ -95,7 +96,6 @@ def test_execute_sql_statement_with_rls(
|
|||
execute_sql_statement(
|
||||
sql_statement,
|
||||
query,
|
||||
user_name=None,
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params={},
|
||||
|
|
@ -153,16 +153,24 @@ def test_sql_lab_insert_rls(
|
|||
session.add(query)
|
||||
session.commit()
|
||||
|
||||
# first without RLS
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
user_name="admin",
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
admin = User(
|
||||
first_name="Alice",
|
||||
last_name="Doe",
|
||||
email="adoe@example.org",
|
||||
username="admin",
|
||||
roles=[Role(name="Admin")],
|
||||
)
|
||||
|
||||
# first without RLS
|
||||
with override_user(admin):
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
)
|
||||
assert (
|
||||
superset_result_set.to_pandas_df().to_markdown()
|
||||
== """
|
||||
|
|
@ -177,13 +185,6 @@ def test_sql_lab_insert_rls(
|
|||
assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"
|
||||
|
||||
# now with RLS
|
||||
admin = User(
|
||||
first_name="Alice",
|
||||
last_name="Doe",
|
||||
email="adoe@example.org",
|
||||
username="admin",
|
||||
roles=[Role(name="Admin")],
|
||||
)
|
||||
rls = RowLevelSecurityFilter(
|
||||
filter_type=RowLevelSecurityFilterType.REGULAR,
|
||||
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
|
||||
|
|
@ -196,15 +197,15 @@ def test_sql_lab_insert_rls(
|
|||
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
|
||||
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
|
||||
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
user_name="admin",
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
)
|
||||
with override_user(admin):
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
)
|
||||
assert (
|
||||
superset_result_set.to_pandas_df().to_markdown()
|
||||
== """
|
||||
|
|
|
|||
Loading…
Reference in New Issue