fix: Refactor SQL username logic (#19914)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2022-05-12 21:03:05 -07:00 committed by GitHub
parent fff9ad05d4
commit 449d08b25e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 388 additions and 340 deletions

View File

@ -31,6 +31,7 @@ from flask_appbuilder.api.manager import resolver
import superset.utils.database as database_utils
from superset.extensions import db
from superset.utils.core import override_user
from superset.utils.encrypt import SecretsMigrator
logger = logging.getLogger(__name__)
@ -54,23 +55,34 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None:
@click.command()
@with_appcontext
def update_datasources_cache() -> None:
@click.option(
"--username",
"-u",
default=None,
help=(
"Specify which user should execute the underlying SQL queries. If undefined "
"defaults to the user registered with the database connection."
),
)
def update_datasources_cache(username: Optional[str]) -> None:
"""Refresh sqllab datasources cache"""
# pylint: disable=import-outside-toplevel
from superset import security_manager
from superset.models.core import Database
for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch:
print("Fetching {} datasources ...".format(database.name))
try:
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as ex: # pylint: disable=broad-except
print("{}".format(str(ex)))
with override_user(security_manager.find_user(username)):
for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch:
print("Fetching {} datasources ...".format(database.name))
try:
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as ex: # pylint: disable=broad-except
print("{}".format(str(ex)))
@click.command()

View File

@ -680,7 +680,7 @@ BACKUP_COUNT = 30
# database,
# query,
# schema=None,
# user=None,
# user=None, # TODO(john-bodley): Deprecate in 3.0.
# client=None,
# security_manager=None,
# log_params=None,
@ -1020,9 +1020,14 @@ DB_CONNECTION_MUTATOR = None
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information
#
# def SQL_QUERY_MUTATOR(sql, user_name=user_name, security_manager=security_manager, database=database):
# def SQL_QUERY_MUTATOR(
# sql,
# user_name=user_name, # TODO(john-bodley): Deprecate in 3.0.
# security_manager=security_manager,
# database=database,
# ):
# dttm = datetime.now().isoformat()
# return f"-- [SQL LAB] {username} {dttm}\n{sql}"
# return f"-- [SQL LAB] {user_name} {dttm}\n{sql}"
# For backward compatibility, you can unpack any of the above arguments in your
# function definition, but keep the **kwargs as the last argument to allow new args
# to be added later without any errors.

View File

@ -120,6 +120,7 @@ from superset.utils import core as utils
from superset.utils.core import (
GenericDataType,
get_column_name,
get_username,
is_adhoc_column,
MediumText,
QueryObjectFilterClause,
@ -917,10 +918,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
Typically adds comments to the query with context"""
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
username = utils.get_username()
sql = sql_query_mutator(
sql,
user_name=username,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=self.database,
)

View File

@ -38,6 +38,7 @@ from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import SupersetSecurityException, SupersetTimeoutException
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.core import override_user
logger = logging.getLogger(__name__)
@ -74,42 +75,43 @@ class TestConnectionDatabaseCommand(BaseCommand):
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
)
with closing(engine.raw_connection()) as conn:
try:
alive = func_timeout(
int(
app.config[
"TEST_DATABASE_CONNECTION_TIMEOUT"
].total_seconds()
),
engine.dialect.do_ping,
args=(conn,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(conn)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception: # pylint: disable=broad-except
alive = False
if not alive:
raise DBAPIError(None, None, None)
with override_user(self._actor):
engine = database.get_sqla_engine()
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
)
with closing(engine.raw_connection()) as conn:
try:
alive = func_timeout(
int(
app.config[
"TEST_DATABASE_CONNECTION_TIMEOUT"
].total_seconds()
),
engine.dialect.do_ping,
args=(conn,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(conn)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database "
"settings, and ensure that your database is accepting "
"connections, then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception: # pylint: disable=broad-except
alive = False
if not alive:
raise DBAPIError(None, None, None)
# Log succesful connection test with engine
event_logger.log_with_context(

View File

@ -35,6 +35,7 @@ from superset.db_engine_specs.base import BasicParametersMixin
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.core import override_user
BYPASS_VALIDATION_ENGINES = {"bigquery"}
@ -115,22 +116,23 @@ class ValidateDatabaseParametersCommand(BaseCommand):
)
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
with override_user(self._actor):
engine = database.get_sqla_engine()
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
if not alive:
raise DatabaseOfflineError(

View File

@ -40,7 +40,7 @@ import pandas as pd
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import current_app, g
from flask import current_app
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Range
@ -64,7 +64,7 @@ from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import ColumnSpec, GenericDataType, get_username
from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open
@ -392,10 +392,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str] = None,
source: Optional[str] = None,
) -> Engine:
user_name = utils.get_username()
return database.get_sqla_engine(
schema=schema, nullpool=True, user_name=user_name, source=source
)
return database.get_sqla_engine(schema=schema, source=source)
@classmethod
def get_timestamp_expr(
@ -1158,15 +1155,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
raise Exception("Database does not support cost estimation")
@classmethod
def process_statement(
cls, statement: str, database: "Database", user_name: str
) -> str:
def process_statement(cls, statement: str, database: "Database") -> str:
"""
Process a SQL statement by stripping and mutating it.
:param statement: A single SQL statement
:param database: Database instance
:param user_name: Effective username
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
@ -1175,7 +1169,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if sql_query_mutator:
sql = sql_query_mutator(
sql,
user_name=user_name,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
@ -1198,7 +1192,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
@ -1207,9 +1200,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(
statement, database, user_name
)
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
return costs

View File

@ -61,6 +61,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.tags import FavStarUpdater
from superset.result_set import SupersetResultSet
from superset.utils import cache as cache_util, core as utils
from superset.utils.core import get_username
from superset.utils.memoized import memoized
config = app.config
@ -322,29 +323,21 @@ class Database(
conn.password = PASSWORD_MASK if conn.password else None
self.sqlalchemy_uri = str(conn) # hides the password
def get_effective_user(
self,
object_url: URL,
user_name: Optional[str] = None,
) -> Optional[str]:
def get_effective_user(self, object_url: URL) -> Optional[str]:
"""
Get the effective user, especially during impersonation.
:param object_url: SQL Alchemy URL object
:param user_name: Default username
:return: The effective username
"""
effective_username = None
if self.impersonate_user:
effective_username = object_url.username
if user_name:
effective_username = user_name
elif (
hasattr(g, "user")
and hasattr(g.user, "username")
and g.user.username is not None
):
effective_username = g.user.username
return effective_username
return ( # pylint: disable=used-before-assignment
username
if (username := get_username())
else object_url.username
if self.impersonate_user
else None
)
@memoized(
watch=(
@ -358,13 +351,12 @@ class Database(
self,
schema: Optional[str] = None,
nullpool: bool = True,
user_name: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url, user_name)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
@ -421,12 +413,9 @@ class Database(
sql: str,
schema: Optional[str] = None,
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
username: Optional[str] = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
engine = self.get_sqla_engine(schema=schema, user_name=username)
username = utils.get_username() or username
engine = self.get_sqla_engine(schema)
def needs_conversion(df_series: pd.Series) -> bool:
return (
@ -437,7 +426,14 @@ class Database(
def _log_query(sql: str) -> None:
if log_query:
log_query(engine.url, sql, schema, username, __name__, security_manager)
log_query(
engine.url,
sql,
schema,
get_username(),
__name__,
security_manager,
)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()

View File

@ -25,7 +25,7 @@ import pandas as pd
from celery.exceptions import SoftTimeLimitExceeded
from flask_babel import lazy_gettext as _
from superset import app, jinja_context
from superset import app, jinja_context, security_manager
from superset.commands.base import BaseCommand
from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
from superset.reports.commands.exceptions import (
@ -36,6 +36,7 @@ from superset.reports.commands.exceptions import (
AlertQueryTimeout,
AlertValidatorConfigError,
)
from superset.utils.core import override_user
logger = logging.getLogger(__name__)
@ -145,18 +146,21 @@ class AlertCommand(BaseCommand):
limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
rendered_sql, ALERT_SQL_LIMIT
)
query_username = app.config["THUMBNAIL_SELENIUM_USER"]
start = default_timer()
df = self._report_schedule.database.get_df(
sql=limited_rendered_sql, username=query_username
)
stop = default_timer()
logger.info(
"Query for %s took %.2f ms",
self._report_schedule.name,
(stop - start) * 1000.0,
)
return df
with override_user(
security_manager.find_user(
username=app.config["THUMBNAIL_SELENIUM_USER"]
)
):
start = default_timer()
df = self._report_schedule.database.get_df(sql=limited_rendered_sql)
stop = default_timer()
logger.info(
"Query for %s took %.2f ms",
self._report_schedule.name,
(stop - start) * 1000.0,
)
return df
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while executing the alert query: %s", ex)
raise AlertQueryTimeout() from ex

View File

@ -25,7 +25,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
from superset import app
from superset import app, security_manager
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandException
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
@ -179,11 +179,10 @@ class BaseReportState:
**kwargs,
)
def _get_user(self) -> User:
user = (
self._session.query(User)
.filter(User.username == app.config["THUMBNAIL_SELENIUM_USER"])
.one_or_none()
@staticmethod
def _get_user() -> User:
user = security_manager.find_user(
username=app.config["THUMBNAIL_SELENIUM_USER"]
)
if not user:
raise ReportScheduleSelleniumUserNotFoundError()

View File

@ -50,7 +50,13 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.celery import session_scope
from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress
from superset.utils.core import (
get_username,
json_iso_dttm_ser,
override_user,
QuerySource,
zlib_compress,
)
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
@ -155,37 +161,35 @@ def get_sql_results( # pylint: disable=too-many-arguments
rendered_query: str,
return_results: bool = True,
store_results: bool = False,
user_name: Optional[str] = None,
username: Optional[str] = None,
start_time: Optional[float] = None,
expand_data: bool = False,
log_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
user_name,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
sql_statement: str,
query: Query,
user_name: Optional[str],
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
@ -204,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
username=user_name,
username=get_username(),
)
)
)
@ -246,7 +250,10 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = SQL_QUERY_MUTATOR(
sql, user_name=user_name, security_manager=security_manager, database=database
sql,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
try:
query.executed_sql = sql
@ -255,7 +262,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
query.database.sqlalchemy_uri,
query.executed_sql,
query.schema,
user_name,
get_username(),
__name__,
security_manager,
log_params,
@ -375,7 +382,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
rendered_query: str,
return_results: bool,
store_results: bool,
user_name: Optional[str],
session: Session,
start_time: Optional[float],
expand_data: bool,
@ -452,12 +458,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)
engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
@ -490,7 +491,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
result_set = execute_sql_statement(
statement,
query,
user_name,
session,
cursor,
log_params,
@ -597,7 +597,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
return None
def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
def cancel_query(query: Query) -> bool:
"""
Cancel a running query.
@ -605,7 +605,6 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
action is required.
:param query: Query to cancel
:param user_name: Default username
:return: True if query cancelled successfully, False otherwise
"""
@ -616,12 +615,7 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
if cancel_query_id is None:
return False
engine = query.database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:

View File

@ -20,13 +20,11 @@ import time
from contextlib import closing
from typing import Any, Dict, List, Optional
from flask import g
from superset import app, security_manager
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
from superset.utils.core import get_username, QuerySource
MAX_ERROR_ROWS = 10
@ -45,7 +43,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls, statement: str, database: Database, cursor: Any, user_name: str
cls,
statement: str,
database: Database,
cursor: Any,
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
@ -57,7 +58,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
if sql_query_mutator:
sql = sql_query_mutator(
sql,
user_name=user_name,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
@ -157,26 +158,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = ParsedQuery(sql)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
engine = database.get_sqla_engine(
schema=schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: List[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(
statement, database, cursor, user_name
)
annotation = cls.validate_statement(statement, database, cursor)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))

View File

@ -22,7 +22,6 @@ import logging
from abc import ABC
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
from flask import g
from flask_babel import gettext as __
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@ -34,6 +33,7 @@ from superset.exceptions import (
)
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.utils import core as utils
from superset.utils.core import get_username
from superset.utils.dates import now_as_float
if TYPE_CHECKING:
@ -139,9 +139,7 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
rendered_query,
return_results=True,
store_results=self._is_store_results(execution_context),
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
username=get_username(),
expand_data=execution_context.expand_data,
log_params=log_params,
)
@ -174,9 +172,7 @@ class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase):
rendered_query,
return_results=False,
store_results=not execution_context.select_as_cta,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
username=get_username(),
start_time=now_as_float(),
expand_data=execution_context.expand_data,
log_params=log_params,

View File

@ -32,6 +32,7 @@ import threading
import traceback
import uuid
import zlib
from contextlib import contextmanager
from datetime import date, datetime, time, timedelta
from distutils.util import strtobool
from email.mime.application import MIMEApplication
@ -1408,6 +1409,30 @@ def get_username() -> Optional[str]:
return None
@contextmanager
def override_user(user: Optional[User]) -> Iterator[Any]:
"""
Temporarily override the current user (if defined) per `flask.g`.
Sometimes, often in the context of async Celery tasks, it is useful to switch the
current user (which may be undefined) to different one, execute some SQLAlchemy
tasks and then revert back to the original one.
:param user: The override user
"""
# pylint: disable=assigning-non-slot
if hasattr(g, "user"):
current = g.user
g.user = user
yield
g.user = current
else:
g.user = user
yield
delattr(g, "user")
def parse_ssl_cert(certificate: str) -> _Certificate:
"""
Parses the contents of a certificate and returns a valid certificate object

View File

@ -1367,11 +1367,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = (
g.user.username if g.user and hasattr(g.user, "username") else None
)
engine = database.get_sqla_engine(user_name=username)
engine = database.get_sqla_engine()
with closing(engine.raw_connection()) as conn:
if engine.dialect.do_ping(conn):
@ -2298,7 +2294,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
return self.json_response("OK")
if not sql_lab.cancel_query(query, g.user.username if g.user else None):
if not sql_lab.cancel_query(query):
raise SupersetCancelQueryException("Could not cancel query")
query.status = QueryStatus.STOPPED

View File

@ -28,8 +28,11 @@ from __future__ import annotations
from typing import Callable, TYPE_CHECKING
from unittest.mock import MagicMock, Mock, PropertyMock
from flask import Flask
from flask.ctx import AppContext
from pytest import fixture
from superset.app import create_app
from tests.example_data.data_loading.pandas.pandas_data_loader import PandasDataLoader
from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
PandasLoaderConfigurations,

View File

@ -21,6 +21,8 @@ import unittest
from unittest import mock
import pytest
from flask import g
from flask.ctx import AppContext
from sqlalchemy import inspect
from tests.integration_tests.fixtures.birth_names_dashboard import (
@ -41,6 +43,7 @@ from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_username, override_user
from superset.utils.database import get_example_database
from .base_tests import SupersetTestCase
@ -86,7 +89,7 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role"
def create_access_request(session, ds_type, ds_name, role_name, user_name):
def create_access_request(session, ds_type, ds_name, role_name, username):
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == "table":
@ -102,7 +105,7 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
access_request = DatasourceAccessRequest(
datasource_id=ds.id,
datasource_type=ds_type,
created_by_fk=security_manager.find_user(username=user_name).id,
created_by_fk=security_manager.find_user(username=username).id,
)
session.add(access_request)
session.commit()
@ -565,5 +568,46 @@ class TestRequestAccess(SupersetTestCase):
session.commit()
@pytest.mark.parametrize(
"username",
[
None,
"gamma",
],
)
def test_get_username(app_context: AppContext, username: str) -> None:
assert not hasattr(g, "user")
assert get_username() is None
g.user = security_manager.find_user(username)
assert get_username() == username
@pytest.mark.parametrize(
"username",
[
None,
"gamma",
],
)
def test_override_user(app_context: AppContext, username: str) -> None:
admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username)
assert not hasattr(g, "user")
with override_user(user):
assert g.user == user
assert not hasattr(g, "user")
g.user = admin
with override_user(user):
assert g.user == user
assert g.user == admin
if __name__ == "__main__":
unittest.main()

View File

@ -329,7 +329,7 @@ class SupersetTestCase(TestCase):
self,
sql,
client_id=None,
user_name=None,
username=None,
raise_on_error=False,
query_limit=None,
database_name="examples",
@ -340,9 +340,9 @@ class SupersetTestCase(TestCase):
ctas_method=CtasMethod.TABLE,
template_params="{}",
):
if user_name:
if username:
self.logout()
self.login(username=(user_name or "admin"))
self.login(username=username)
dbid = SupersetTestCase.get_database_by_name(database_name).id
json_payload = {
"database_id": dbid,
@ -427,14 +427,14 @@ class SupersetTestCase(TestCase):
self,
sql,
client_id=None,
user_name=None,
username=None,
raise_on_error=False,
database_name="examples",
template_params=None,
):
if user_name:
if username:
self.logout()
self.login(username=(user_name if user_name else "admin"))
self.login(username=username)
dbid = SupersetTestCase.get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",

View File

@ -1064,7 +1064,7 @@ class TestCore(SupersetTestCase):
LIMIT 10;
""",
client_id="client_id_1",
user_name="admin",
username="admin",
)
count_ds = []
count_name = []
@ -1454,7 +1454,7 @@ class TestCore(SupersetTestCase):
self.run_sql(
"SELECT name FROM birth_names",
"client_id_1",
user_name=username,
username=username,
raise_on_error=True,
sql_editor_id=str(tab_state_id),
)
@ -1462,7 +1462,7 @@ class TestCore(SupersetTestCase):
self.run_sql(
"SELECT name FROM birth_names",
"client_id_2",
user_name=username,
username=username,
raise_on_error=True,
)

View File

@ -20,8 +20,10 @@ import textwrap
import unittest
from unittest import mock
from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.utils.core import override_user
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
@ -112,21 +114,22 @@ class TestDatabaseModel(SupersetTestCase):
)
def test_database_impersonate_user(self):
uri = "mysql://root@localhost"
example_user = "giuseppe"
example_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertEqual(example_user, user_name)
with override_user(example_user):
model.impersonate_user = True
username = make_url(model.get_sqla_engine().url).username
self.assertEqual(example_user.username, username)
model.impersonate_user = False
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertNotEqual(example_user, user_name)
model.impersonate_user = False
username = make_url(model.get_sqla_engine().url).username
self.assertNotEqual(example_user.username, username)
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
uri = "presto://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@ -142,64 +145,66 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost"
assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "gamma",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "logged_in_user",
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost"
assert str(call_args[0][0]) == "presto://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
uri = "trino://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
assert str(call_args[0][0]) == "trino://localhost"
model = Database(
database_name="test_database",
sqlalchemy_uri="trino://original_user:original_user_password@localhost",
)
assert call_args[1]["connect_args"] == {
"user": "logged_in_user",
}
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
uri = "trino://original_user:original_user_password@localhost"
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "logged_in_user"}
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_hive(self, mocked_create_engine):
uri = "hive://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@ -215,32 +220,34 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "gamma"},
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_select_star(self):
@ -345,19 +352,6 @@ class TestDatabaseModel(SupersetTestCase):
df = main_db.get_df("USE superset; SELECT ';';", None)
self.assertEqual(df.iat[0, 0], ";")
@mock.patch("superset.models.core.Database.get_sqla_engine")
def test_username_param(self, mocked_get_sqla_engine):
main_db = get_example_database()
main_db.impersonate_user = True
test_username = "test_username_param"
if main_db.backend == "mysql":
main_db.get_df("USE superset; SELECT 1", username=test_username)
mocked_get_sqla_engine.assert_called_with(
schema=None,
user_name="test_username_param",
)
@mock.patch("superset.models.core.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(

View File

@ -187,7 +187,7 @@ class TestPrestoValidator(SupersetTestCase):
"message": "your query isn't how I like it",
}
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_success(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@ -197,7 +197,7 @@ class TestPrestoValidator(SupersetTestCase):
self.assertEqual([], errors)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_db_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@ -209,7 +209,7 @@ class TestPrestoValidator(SupersetTestCase):
with self.assertRaises(PrestoSQLValidationError):
self.validator.validate(sql, schema, self.database)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_unexpected_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@ -221,7 +221,7 @@ class TestPrestoValidator(SupersetTestCase):
with self.assertRaises(Exception):
self.validator.validate(sql, schema, self.database)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_query_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"

View File

@ -68,9 +68,9 @@ class TestSqlLab(SupersetTestCase):
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_3", username="admin")
self.run_sql(QUERY_3, client_id="client_id_2", username="gamma_sqllab")
self.logout()
def tearDown(self):
@ -162,7 +162,7 @@ class TestSqlLab(SupersetTestCase):
db.session.commit()
with freeze_time(datetime.now().isoformat(timespec="seconds")):
self.run_sql(sql_statement, "1")
self.run_sql(sql_statement, "1", username="admin")
saved_query_ = (
db.session.query(SavedQuery)
.filter(
@ -248,7 +248,7 @@ class TestSqlLab(SupersetTestCase):
# Gamma user, with sqllab and db permission
self.create_user_with_roles("Gagarin", ["ExampleDBAccess", "Gamma", "sql_lab"])
data = self.run_sql(QUERY_1, "1", user_name="Gagarin")
data = self.run_sql(QUERY_1, "1", username="Gagarin")
db.session.query(Query).delete()
db.session.commit()
self.assertLess(0, len(data["data"]))
@ -278,14 +278,14 @@ class TestSqlLab(SupersetTestCase):
)
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", user_name="SchemaUser"
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
)
self.assertEqual(1, len(data["data"]))
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
"4",
user_name="SchemaUser",
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
@ -295,7 +295,7 @@ class TestSqlLab(SupersetTestCase):
data = self.run_sql(
"SELECT * FROM test_table",
"5",
user_name="SchemaUser",
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
@ -441,7 +441,7 @@ class TestSqlLab(SupersetTestCase):
self.run_sql(
"SELECT name as col, gender as col FROM birth_names LIMIT 10",
client_id="2e2df3",
user_name="admin",
username="admin",
raise_on_error=True,
)
@ -747,7 +747,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@ -758,7 +757,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@ -767,7 +765,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@ -804,7 +801,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@ -858,7 +854,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@ -869,7 +864,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@ -878,7 +872,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@ -895,7 +888,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@ -929,7 +921,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,

View File

@ -20,6 +20,8 @@ import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset.utils.core import override_user
def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
"""
@ -46,7 +48,6 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
execute_sql_statement(
sql_statement,
query,
user_name=None,
session=session,
cursor=cursor,
log_params={},
@ -95,7 +96,6 @@ def test_execute_sql_statement_with_rls(
execute_sql_statement(
sql_statement,
query,
user_name=None,
session=session,
cursor=cursor,
log_params={},
@ -153,16 +153,24 @@ def test_sql_lab_insert_rls(
session.add(query)
session.commit()
# first without RLS
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
user_name="admin",
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
admin = User(
first_name="Alice",
last_name="Doe",
email="adoe@example.org",
username="admin",
roles=[Role(name="Admin")],
)
# first without RLS
with override_user(admin):
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
)
assert (
superset_result_set.to_pandas_df().to_markdown()
== """
@ -177,13 +185,6 @@ def test_sql_lab_insert_rls(
assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"
# now with RLS
admin = User(
first_name="Alice",
last_name="Doe",
email="adoe@example.org",
username="admin",
roles=[Role(name="Admin")],
)
rls = RowLevelSecurityFilter(
filter_type=RowLevelSecurityFilterType.REGULAR,
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
@ -196,15 +197,15 @@ def test_sql_lab_insert_rls(
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
user_name="admin",
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
)
with override_user(admin):
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
)
assert (
superset_result_set.to_pandas_df().to_markdown()
== """