chore: upgrade black (#19410)

Ville Brofeldt 2022-03-29 20:03:09 +03:00 committed by GitHub
parent 816a2c3e1e
commit a619cb4ea9
204 changed files with 2125 additions and 608 deletions


@@ -41,7 +41,7 @@ repos:
   - id: trailing-whitespace
     args: ["--markdown-linebreak-ext=md"]
 - repo: https://github.com/psf/black
-  rev: 19.10b0
+  rev: 22.3.0
   hooks:
   - id: black
     language_version: python3
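Almost everything else in this commit is mechanical reformatting produced by the new black version. Since black 20.8b0 the formatter honors the "magic trailing comma": a bracket pair whose last element carries a trailing comma is exploded one element per line, even when the collapsed form would fit within the line length. A minimal, self-contained sketch of the behavior (the function and values are illustrative, not taken from this commit):

def greet(name: str, punctuation: str = "!") -> str:
    return f"Hello, {name}{punctuation}"

# black 19.10b0 left this call collapsed despite the trailing comma:
#     message = greet("world", punctuation="?",)
# black 22.3.0 treats that trailing comma as "magic" and explodes the call:
message = greet(
    "world",
    punctuation="?",
)
print(message)  # Hello, world?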


@@ -167,7 +167,10 @@ class GitChangeLog:
         return f"### {self._version} ({self._logs[0].time})"

     def _parse_change_log(
-        self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str,
+        self,
+        changelog: Dict[str, str],
+        pr_info: Dict[str, str],
+        github_login: str,
     ) -> None:
         formatted_pr = (
             f"- [#{pr_info.get('id')}]"
@@ -355,7 +358,8 @@ def compare(base_parameters: BaseParameters) -> None:
 @cli.command("changelog")
 @click.option(
-    "--csv", help="The csv filename to export the changelog to",
+    "--csv",
+    help="The csv filename to export the changelog to",
 )
 @click.option(
     "--access_token",


@@ -106,7 +106,12 @@ def inter_send_email(
 class BaseParameters(object):
     def __init__(
-        self, email: str, username: str, password: str, version: str, version_rc: str,
+        self,
+        email: str,
+        username: str,
+        password: str,
+        version: str,
+        version_rc: str,
     ) -> None:
         self.email = email
         self.username = username


@@ -60,7 +60,8 @@ def request(
 def list_runs(
-    repo: str, params: Optional[Dict[str, str]] = None,
+    repo: str,
+    params: Optional[Dict[str, str]] = None,
 ) -> Iterator[Dict[str, Any]]:
     """List all github workflow runs.

     Returns:
@@ -193,7 +194,11 @@ def cancel_github_workflows(
     if branch and ":" in branch:
         [user, branch] = branch.split(":", 2)
     runs = get_runs(
-        repo, branch=branch, user=user, statuses=statuses, events=events,
+        repo,
+        branch=branch,
+        user=user,
+        statuses=statuses,
+        events=events,
     )
     # sort old jobs to the front, so to cancel older jobs first


@@ -73,7 +73,9 @@ class UpdateAnnotationCommand(BaseCommand):
         # Validate short descr uniqueness on this layer
         if not AnnotationDAO.validate_update_uniqueness(
-            layer_id, short_descr, annotation_id=self._model_id,
+            layer_id,
+            short_descr,
+            annotation_id=self._model_id,
         ):
             exceptions.append(AnnotationUniquenessValidationError())
         else:


@@ -64,13 +64,17 @@ class AnnotationPostSchema(Schema):
     )
     long_descr = fields.String(description=annotation_long_descr, allow_none=True)
     start_dttm = fields.DateTime(
-        description=annotation_start_dttm, required=True, allow_none=False,
+        description=annotation_start_dttm,
+        required=True,
+        allow_none=False,
     )
     end_dttm = fields.DateTime(
         description=annotation_end_dttm, required=True, allow_none=False
     )
     json_metadata = fields.String(
-        description=annotation_json_metadata, validate=validate_json, allow_none=True,
+        description=annotation_json_metadata,
+        validate=validate_json,
+        allow_none=True,
     )


@@ -110,9 +110,11 @@ class CacheRestApi(BaseSupersetModelRestApi):
         )
         try:
-            delete_stmt = CacheKey.__table__.delete().where(  # pylint: disable=no-member
-                CacheKey.cache_key.in_(cache_keys)
-            )
+            delete_stmt = (
+                CacheKey.__table__.delete().where(  # pylint: disable=no-member
+                    CacheKey.cache_key.in_(cache_keys)
+                )
+            )
             db.session.execute(delete_stmt)
             db.session.commit()
             self.stats_logger.gauge("invalidated_cache", len(cache_keys))


@@ -25,9 +25,15 @@ from superset.charts.schemas import (
 class Datasource(Schema):
-    database_name = fields.String(description="Datasource name",)
-    datasource_name = fields.String(description=datasource_name_description,)
-    schema = fields.String(description="Datasource schema",)
+    database_name = fields.String(
+        description="Datasource name",
+    )
+    datasource_name = fields.String(
+        description=datasource_name_description,
+    )
+    schema = fields.String(
+        description="Datasource schema",
+    )
     datasource_type = fields.String(
         description=datasource_type_description,
         validate=validate.OneOf(choices=("druid", "table", "view")),
@@ -37,7 +43,8 @@ class Datasource(Schema):
 class CacheInvalidationRequestSchema(Schema):
     datasource_uids = fields.List(
-        fields.String(), description=datasource_uid_description,
+        fields.String(),
+        description=datasource_uid_description,
     )
     datasources = fields.List(
         fields.Nested(Datasource),


@@ -279,7 +279,8 @@ class ChartCacheScreenshotResponseSchema(Schema):
 class ChartDataColumnSchema(Schema):
     column_name = fields.String(
-        description="The name of the target column", example="mycol",
+        description="The name of the target column",
+        example="mycol",
     )
     type = fields.String(description="Type of target column", example="BIGINT")
@@ -325,7 +326,8 @@ class ChartDataAdhocMetricSchema(Schema):
         example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30",
     )
     timeGrain = fields.String(
-        description="Optional time grain for temporal filters", example="PT1M",
+        description="Optional time grain for temporal filters",
+        example="PT1M",
     )
     isExtra = fields.Boolean(
         description="Indicates if the filter has been added by a filter component as "
@@ -370,7 +372,8 @@ class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     groupby = (
         fields.List(
             fields.String(
-                allow_none=False, description="Columns by which to group by",
+                allow_none=False,
+                description="Columns by which to group by",
             ),
             minLength=1,
             required=True,
@@ -425,7 +428,9 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
         example="percentile",
     )
     window = fields.Integer(
-        description="Size of the rolling window in days.", required=True, example=7,
+        description="Size of the rolling window in days.",
+        required=True,
+        example=7,
     )
     rolling_type_options = fields.Dict(
         desctiption="Optional options to pass to rolling method. Needed for "
@@ -592,7 +597,9 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     """

     groupby = fields.List(
-        fields.String(description="Columns by which to group the query.",),
+        fields.String(
+            description="Columns by which to group the query.",
+        ),
         allow_none=True,
     )
@@ -699,13 +706,16 @@ class ChartDataGeohashDecodeOptionsSchema(
     """

     geohash = fields.String(
-        description="Name of source column containing geohash string", required=True,
+        description="Name of source column containing geohash string",
+        required=True,
     )
     latitude = fields.String(
-        description="Name of target column for decoded latitude", required=True,
+        description="Name of target column for decoded latitude",
+        required=True,
     )
     longitude = fields.String(
-        description="Name of target column for decoded longitude", required=True,
+        description="Name of target column for decoded longitude",
+        required=True,
     )
@@ -717,13 +727,16 @@ class ChartDataGeohashEncodeOptionsSchema(
     """

     latitude = fields.String(
-        description="Name of source latitude column", required=True,
+        description="Name of source latitude column",
+        required=True,
     )
     longitude = fields.String(
-        description="Name of source longitude column", required=True,
+        description="Name of source longitude column",
+        required=True,
     )
     geohash = fields.String(
-        description="Name of target column for encoded geohash string", required=True,
+        description="Name of target column for encoded geohash string",
+        required=True,
     )
@@ -739,10 +752,12 @@ class ChartDataGeodeticParseOptionsSchema(
         required=True,
     )
     latitude = fields.String(
-        description="Name of target column for decoded latitude", required=True,
+        description="Name of target column for decoded latitude",
+        required=True,
     )
     longitude = fields.String(
-        description="Name of target column for decoded longitude", required=True,
+        description="Name of target column for decoded longitude",
+        required=True,
     )
     altitude = fields.String(
         description="Name of target column for decoded altitude. If omitted, "
@@ -789,7 +804,10 @@ class ChartDataPostProcessingOperationSchema(Schema):
                     "column": "age",
                     "options": {"q": 0.25},
                 },
-                "age_mean": {"operator": "mean", "column": "age",},
+                "age_mean": {
+                    "operator": "mean",
+                    "column": "age",
+                },
             },
         },
     )
@@ -816,7 +834,8 @@ class ChartDataFilterSchema(Schema):
         example=["China", "France", "Japan"],
     )
     grain = fields.String(
-        description="Optional time grain for temporal filters", example="PT1M",
+        description="Optional time grain for temporal filters",
+        example="PT1M",
     )
     isExtra = fields.Boolean(
         description="Indicates if the filter has been added by a filter component as "
@@ -873,7 +892,10 @@ class AnnotationLayerSchema(Schema):
         description="Type of annotation layer",
         validate=validate.OneOf(choices=[ann.value for ann in AnnotationType]),
     )
-    color = fields.String(description="Layer color", allow_none=True,)
+    color = fields.String(
+        description="Layer color",
+        allow_none=True,
+    )
     descriptionColumns = fields.List(
         fields.String(),
         description="Columns to use as the description. If none are provided, "
@@ -911,7 +933,8 @@ class AnnotationLayerSchema(Schema):
     )
     show = fields.Boolean(description="Should the layer be shown", required=True)
     showLabel = fields.Boolean(
-        description="Should the label always be shown", allow_none=True,
+        description="Should the label always be shown",
+        allow_none=True,
     )
     showMarkers = fields.Boolean(
         description="Should markers be shown. Only applies to line annotations.",
@@ -919,16 +942,34 @@ class AnnotationLayerSchema(Schema):
     )
     sourceType = fields.String(
         description="Type of source for annotation data",
-        validate=validate.OneOf(choices=("", "line", "NATIVE", "table",)),
+        validate=validate.OneOf(
+            choices=(
+                "",
+                "line",
+                "NATIVE",
+                "table",
+            )
+        ),
     )
     style = fields.String(
         description="Line style. Only applies to time-series annotations",
-        validate=validate.OneOf(choices=("dashed", "dotted", "solid", "longDashed",)),
+        validate=validate.OneOf(
+            choices=(
+                "dashed",
+                "dotted",
+                "solid",
+                "longDashed",
+            )
+        ),
     )
     timeColumn = fields.String(
-        description="Column with event date or interval start date", allow_none=True,
+        description="Column with event date or interval start date",
+        allow_none=True,
+    )
+    titleColumn = fields.String(
+        description="Column with title",
+        allow_none=True,
     )
-    titleColumn = fields.String(description="Column with title", allow_none=True,)
     width = fields.Float(
         description="Width of annotation line",
         validate=[
@@ -948,7 +989,10 @@ class AnnotationLayerSchema(Schema):
 class ChartDataDatasourceSchema(Schema):
     description = "Chart datasource"
-    id = fields.Integer(description="Datasource id", required=True,)
+    id = fields.Integer(
+        description="Datasource id",
+        required=True,
+    )
     type = fields.String(
         description="Datasource type",
         validate=validate.OneOf(choices=("druid", "table")),
@@ -1039,7 +1083,8 @@ class ChartDataQueryObjectSchema(Schema):
         allow_none=True,
     )
     is_timeseries = fields.Boolean(
-        description="Is the `query_object` a timeseries.", allow_none=True,
+        description="Is the `query_object` a timeseries.",
+        allow_none=True,
     )
     series_columns = fields.List(
         fields.Raw(),
@@ -1084,7 +1129,8 @@ class ChartDataQueryObjectSchema(Schema):
         ],
     )
     order_desc = fields.Boolean(
-        description="Reverse order. Default: `false`", allow_none=True,
+        description="Reverse order. Default: `false`",
+        allow_none=True,
     )
     extras = fields.Nested(
         ChartDataExtrasSchema,
@@ -1151,7 +1197,10 @@ class ChartDataQueryObjectSchema(Schema):
         description="Should the rowcount of the actual query be returned",
         allow_none=True,
     )
-    time_offsets = fields.List(fields.String(), allow_none=True,)
+    time_offsets = fields.List(
+        fields.String(),
+        allow_none=True,
+    )


 class ChartDataQueryContextSchema(Schema):
@@ -1190,7 +1239,9 @@ class AnnotationDataSchema(Schema):
         required=True,
     )
     records = fields.List(
-        fields.Dict(keys=fields.String(),),
+        fields.Dict(
+            keys=fields.String(),
+        ),
         description="records mapping the column name to it's value",
         required=True,
     )
@@ -1206,10 +1257,14 @@ class ChartDataResponseResult(Schema):
         allow_none=True,
     )
     cache_key = fields.String(
-        description="Unique cache key for query object", required=True, allow_none=True,
+        description="Unique cache key for query object",
+        required=True,
+        allow_none=True,
     )
     cached_dttm = fields.String(
-        description="Cache timestamp", required=True, allow_none=True,
+        description="Cache timestamp",
+        required=True,
+        allow_none=True,
    )
     cache_timeout = fields.Integer(
         description="Cache timeout in following order: custom timeout, datasource "
@@ -1217,12 +1272,19 @@ class ChartDataResponseResult(Schema):
         required=True,
         allow_none=True,
     )
-    error = fields.String(description="Error", allow_none=True,)
+    error = fields.String(
+        description="Error",
+        allow_none=True,
+    )
     is_cached = fields.Boolean(
-        description="Is the result cached", required=True, allow_none=None,
+        description="Is the result cached",
+        required=True,
+        allow_none=None,
     )
     query = fields.String(
-        description="The executed query statement", required=True, allow_none=False,
+        description="The executed query statement",
+        required=True,
+        allow_none=False,
     )
     status = fields.String(
         description="Status of the query",
@@ -1240,10 +1302,12 @@ class ChartDataResponseResult(Schema):
         allow_none=False,
     )
     stacktrace = fields.String(
-        desciption="Stacktrace if there was an error", allow_none=True,
+        desciption="Stacktrace if there was an error",
+        allow_none=True,
     )
     rowcount = fields.Integer(
-        description="Amount of rows in result set", allow_none=False,
+        description="Amount of rows in result set",
+        allow_none=False,
     )
     data = fields.List(fields.Dict(), description="A list with results")
     colnames = fields.List(fields.String(), description="A list of column names")
@@ -1273,13 +1337,24 @@ class ChartDataResponseSchema(Schema):
 class ChartDataAsyncResponseSchema(Schema):
     channel_id = fields.String(
-        description="Unique session async channel ID", allow_none=False,
+        description="Unique session async channel ID",
+        allow_none=False,
+    )
+    job_id = fields.String(
+        description="Unique async job ID",
+        allow_none=False,
+    )
+    user_id = fields.String(
+        description="Requesting user ID",
+        allow_none=True,
+    )
+    status = fields.String(
+        description="Status value for async job",
+        allow_none=False,
     )
-    job_id = fields.String(description="Unique async job ID", allow_none=False,)
-    user_id = fields.String(description="Requesting user ID", allow_none=True,)
-    status = fields.String(description="Status value for async job", allow_none=False,)
     result_url = fields.String(
-        description="Unique result URL for fetching async query data", allow_none=False,
+        description="Unique result URL for fetching async query data",
+        allow_none=False,
     )


@@ -93,10 +93,16 @@ def load_examples_run(
 @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
 @click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data")
 @click.option(
-    "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data",
+    "--only-metadata",
+    "-m",
+    is_flag=True,
+    help="Only load metadata, skip actual data",
 )
 @click.option(
-    "--force", "-f", is_flag=True, help="Force load data even if table already exists",
+    "--force",
+    "-f",
+    is_flag=True,
+    help="Force load data even if table already exists",
 )
 def load_examples(
     load_test_data: bool,


@@ -36,10 +36,16 @@ logger = logging.getLogger(__name__)
 @click.command()
 @click.argument("directory")
 @click.option(
-    "--overwrite", "-o", is_flag=True, help="Overwriting existing metadata definitions",
+    "--overwrite",
+    "-o",
+    is_flag=True,
+    help="Overwriting existing metadata definitions",
 )
 @click.option(
-    "--force", "-f", is_flag=True, help="Force load data even if table already exists",
+    "--force",
+    "-f",
+    is_flag=True,
+    help="Force load data even if table already exists",
 )
 def import_directory(directory: str, overwrite: bool, force: bool) -> None:
     """Imports configs from a given directory"""
@@ -47,7 +53,9 @@ def import_directory(directory: str, overwrite: bool, force: bool) -> None:
     from superset.examples.utils import load_configs_from_directory

     load_configs_from_directory(
-        root=Path(directory), overwrite=overwrite, force_data=force,
+        root=Path(directory),
+        overwrite=overwrite,
+        force_data=force,
     )
@@ -56,7 +64,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
     @click.command()
     @with_appcontext
     @click.option(
-        "--dashboard-file", "-f", help="Specify the the file to export to",
+        "--dashboard-file",
+        "-f",
+        help="Specify the the file to export to",
     )
     def export_dashboards(dashboard_file: Optional[str] = None) -> None:
         """Export dashboards to ZIP file"""
@@ -90,7 +100,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
     @click.command()
     @with_appcontext
     @click.option(
-        "--datasource-file", "-f", help="Specify the the file to export to",
+        "--datasource-file",
+        "-f",
+        help="Specify the the file to export to",
     )
     def export_datasources(datasource_file: Optional[str] = None) -> None:
         """Export datasources to ZIP file"""
@@ -122,7 +134,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
     @click.command()
     @with_appcontext
     @click.option(
-        "--path", "-p", help="Path to a single ZIP file",
+        "--path",
+        "-p",
+        help="Path to a single ZIP file",
     )
     @click.option(
         "--username",
@@ -160,7 +174,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
     @click.command()
     @with_appcontext
     @click.option(
-        "--path", "-p", help="Path to a single ZIP file",
+        "--path",
+        "-p",
+        help="Path to a single ZIP file",
     )
     def import_datasources(path: str) -> None:
         """Import datasources from ZIP file"""
@@ -185,7 +201,6 @@ if feature_flags.get("VERSIONED_EXPORT"):
             )
             sys.exit(1)

 else:
-
     @click.command()


@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
 @click.group(
-    cls=FlaskGroup, context_settings={"token_normalize_func": normalize_token},
+    cls=FlaskGroup,
+    context_settings={"token_normalize_func": normalize_token},
 )
 @with_appcontext
 def superset() -> None:


@@ -44,7 +44,11 @@ logger = logging.getLogger(__name__)
     help="Only process dashboards",
 )
 @click.option(
-    "--charts_only", "-c", is_flag=True, default=False, help="Only process charts",
+    "--charts_only",
+    "-c",
+    is_flag=True,
+    default=False,
+    help="Only process charts",
 )
 @click.option(
     "--force",


@@ -35,7 +35,10 @@ from superset.models.helpers import (
 class Column(
-    Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin,
+    Model,
+    AuditMixinNullable,
+    ExtraJSONMixin,
+    ImportExportMixin,
 ):
     """
     A "column".


@@ -23,7 +23,7 @@ from superset.exceptions import SupersetException
 class CommandException(SupersetException):
-    """ Common base class for Command exceptions. """
+    """Common base class for Command exceptions."""

     def __repr__(self) -> str:
         if self._exception:
@@ -52,7 +52,7 @@ class ObjectNotFoundError(CommandException):
 class CommandInvalidError(CommandException):
-    """ Common base class for Command Invalid errors. """
+    """Common base class for Command Invalid errors."""

     status = 422
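The two docstring hunks above reflect black's docstring normalization (introduced in 20.8b0): whitespace immediately inside the triple quotes is stripped. A small illustrative sketch, with both spellings as they would appear before and after formatting (function names are hypothetical):

def before() -> None:
    """ Common base class for Command exceptions. """

def after() -> None:
    """Common base class for Command exceptions."""

# black 22.3.0 rewrites before()'s docstring into the after() form;
# the docstring text itself is unchanged.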


@@ -79,7 +79,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
 def validate_metadata_type(
-    metadata: Optional[Dict[str, str]], type_: str, exceptions: List[ValidationError],
+    metadata: Optional[Dict[str, str]],
+    type_: str,
+    exceptions: List[ValidationError],
 ) -> None:
     """Validate that the type declared in METADATA_FILE_NAME is correct"""
     if metadata and "type" in metadata:


@@ -34,7 +34,9 @@ if TYPE_CHECKING:
 def populate_owners(
-    user: User, owner_ids: Optional[List[int]], default_to_user: bool,
+    user: User,
+    owner_ids: Optional[List[int]],
+    default_to_user: bool,
 ) -> List[User]:
     """
     Helper function for commands, will fetch all users from owners id's


@@ -79,7 +79,9 @@ def _get_timegrains(
 def _get_query(
-    query_context: "QueryContext", query_obj: "QueryObject", _: bool,
+    query_context: "QueryContext",
+    query_obj: "QueryObject",
+    _: bool,
 ) -> Dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     result = {"language": datasource.query_language}


@@ -69,7 +69,7 @@ class QueryContext:
         result_format: ChartDataResultFormat,
         force: bool = False,
         custom_cache_timeout: Optional[int] = None,
-        cache_values: Dict[str, Any]
+        cache_values: Dict[str, Any],
     ) -> None:
         self.datasource = datasource
         self.result_type = result_type
@@ -81,11 +81,16 @@ class QueryContext:
         self.cache_values = cache_values
         self._processor = QueryContextProcessor(self)

-    def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]:
+    def get_data(
+        self,
+        df: pd.DataFrame,
+    ) -> Union[str, List[Dict[str, Any]]]:
         return self._processor.get_data(df)

     def get_payload(
-        self, cache_query_context: Optional[bool] = False, force_cached: bool = False,
+        self,
+        cache_query_context: Optional[bool] = False,
+        force_cached: bool = False,
     ) -> Dict[str, Any]:
         """Returns the query results with both metadata and data"""
         return self._processor.get_payload(cache_query_context, force_cached)
@@ -103,7 +108,9 @@ class QueryContext:
         return self._processor.query_cache_key(query_obj, **kwargs)

     def get_df_payload(
-        self, query_obj: QueryObject, force_cached: Optional[bool] = False,
+        self,
+        query_obj: QueryObject,
+        force_cached: Optional[bool] = False,
     ) -> Dict[str, Any]:
         return self._processor.get_df_payload(query_obj, force_cached)
@@ -111,7 +118,9 @@ class QueryContext:
         return self._processor.get_query_result(query_object)

     def processing_time_offsets(
-        self, df: pd.DataFrame, query_object: QueryObject,
+        self,
+        df: pd.DataFrame,
+        query_object: QueryObject,
     ) -> CachedTimeOffset:
         return self._processor.processing_time_offsets(df, query_object)


@@ -50,7 +50,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
         result_type: Optional[ChartDataResultType] = None,
         result_format: Optional[ChartDataResultFormat] = None,
         force: bool = False,
-        custom_cache_timeout: Optional[int] = None
+        custom_cache_timeout: Optional[int] = None,
     ) -> QueryContext:
         datasource_model_instance = None
         if datasource:


@@ -99,7 +99,10 @@ class QueryContextProcessor:
         """Handles caching around the df payload retrieval"""
         cache_key = self.query_cache_key(query_obj)
         cache = QueryCacheManager.get(
-            cache_key, CacheRegion.DATA, self._query_context.force, force_cached,
+            cache_key,
+            CacheRegion.DATA,
+            self._query_context.force,
+            force_cached,
         )

         if query_obj and cache_key and not cache.is_loaded:
@@ -235,7 +238,9 @@ class QueryContextProcessor:
         return df

     def processing_time_offsets(  # pylint: disable=too-many-locals
-        self, df: pd.DataFrame, query_object: QueryObject,
+        self,
+        df: pd.DataFrame,
+        query_object: QueryObject,
     ) -> CachedTimeOffset:
         query_context = self._query_context
         # ensure query_object is immutable
@@ -250,7 +255,8 @@ class QueryContextProcessor:
         for offset in time_offsets:
             try:
                 query_object_clone.from_dttm = get_past_or_future(
-                    offset, outer_from_dttm,
+                    offset,
+                    outer_from_dttm,
                 )
                 query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)
             except ValueError as ex:
@@ -322,7 +328,9 @@ class QueryContextProcessor:
             # df left join `offset_metrics_df`
             offset_df = df_utils.left_join_df(
-                left_df=df, right_df=offset_metrics_df, join_keys=join_keys,
+                left_df=df,
+                right_df=offset_metrics_df,
+                join_keys=join_keys,
             )
             offset_slice = offset_df[metrics_mapping.values()]
@@ -358,7 +366,9 @@ class QueryContextProcessor:
         return df.to_dict(orient="records")

     def get_payload(
-        self, cache_query_context: Optional[bool] = False, force_cached: bool = False,
+        self,
+        cache_query_context: Optional[bool] = False,
+        force_cached: bool = False,
     ) -> Dict[str, Any]:
         """Returns the query results with both metadata and data"""

@@ -341,7 +341,11 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     def __repr__(self) -> str:
         # we use `print` or `logging` output QueryObject
-        return json.dumps(self.to_dict(), sort_keys=True, default=str,)
+        return json.dumps(
+            self.to_dict(),
+            sort_keys=True,
+            default=str,
+        )

     def cache_key(self, **extra: Any) -> str:
         """


@@ -80,7 +80,8 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods
         )

     def _process_extras(  # pylint: disable=no-self-use
-        self, extras: Optional[Dict[str, Any]],
+        self,
+        extras: Optional[Dict[str, Any]],
     ) -> Dict[str, Any]:
         extras = extras or {}
         return extras


@@ -26,7 +26,9 @@ if TYPE_CHECKING:
 def left_join_df(
-    left_df: pd.DataFrame, right_df: pd.DataFrame, join_keys: List[str],
+    left_df: pd.DataFrame,
+    right_df: pd.DataFrame,
+    join_keys: List[str],
 ) -> pd.DataFrame:
     df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
     df.reset_index(inplace=True)


@@ -128,7 +128,6 @@ try:
             self.name = name
             self.post_aggregator = post_aggregator

-
 except NameError:
     pass


@@ -62,7 +62,9 @@ class EnsureEnabledMixin:
 class DruidColumnInlineView(  # pylint: disable=too-many-ancestors
-    CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView,
+    CompactCRUDMixin,
+    EnsureEnabledMixin,
+    SupersetModelView,
 ):
     datamodel = SQLAInterface(models.DruidColumn)
     include_route_methods = RouteMethod.RELATED_VIEW_SET
@@ -151,7 +153,9 @@ class DruidColumnInlineView(  # pylint: disable=too-many-ancestors
 class DruidMetricInlineView(  # pylint: disable=too-many-ancestors
-    CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView,
+    CompactCRUDMixin,
+    EnsureEnabledMixin,
+    SupersetModelView,
 ):
     datamodel = SQLAInterface(models.DruidMetric)
     include_route_methods = RouteMethod.RELATED_VIEW_SET
@@ -206,7 +210,10 @@ class DruidMetricInlineView(  # pylint: disable=too-many-ancestors
 class DruidClusterModelView(  # pylint: disable=too-many-ancestors
-    EnsureEnabledMixin, SupersetModelView, DeleteMixin, YamlExportMixin,
+    EnsureEnabledMixin,
+    SupersetModelView,
+    DeleteMixin,
+    YamlExportMixin,
 ):
     datamodel = SQLAInterface(models.DruidCluster)
     include_route_methods = RouteMethod.CRUD_SET
@@ -270,7 +277,10 @@ class DruidClusterModelView(  # pylint: disable=too-many-ancestors
 class DruidDatasourceModelView(  # pylint: disable=too-many-ancestors
-    EnsureEnabledMixin, DatasourceModelView, DeleteMixin, YamlExportMixin,
+    EnsureEnabledMixin,
+    DatasourceModelView,
+    DeleteMixin,
+    YamlExportMixin,
 ):
     datamodel = SQLAInterface(models.DruidDatasource)
     include_route_methods = RouteMethod.CRUD_SET


@@ -311,7 +311,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         return self.table

     def get_time_filter(
-        self, start_dttm: DateTime, end_dttm: DateTime,
+        self,
+        start_dttm: DateTime,
+        end_dttm: DateTime,
     ) -> ColumnElement:
         col = self.get_sqla_col(label="__time")
         l = []
@@ -687,7 +689,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         if self.sql:
             return get_virtual_table_metadata(dataset=self)
         return get_physical_table_metadata(
-            database=self.database, table_name=self.table_name, schema_name=self.schema,
+            database=self.database,
+            table_name=self.table_name,
+            schema_name=self.schema,
         )

     @property
@@ -1013,7 +1017,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             return all_filters
         except TemplateError as ex:
             raise QueryObjectValidationError(
-                _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
+                _(
+                    "Error in jinja expression in RLS filters: %(msg)s",
+                    msg=ex.message,
+                )
             ) from ex

     def text(self, clause: str) -> TextClause:
@@ -1233,7 +1240,8 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         ):
             time_filters.append(
                 columns_by_name[self.main_dttm_col].get_time_filter(
-                    from_dttm, to_dttm,
+                    from_dttm,
+                    to_dttm,
                 )
             )
         time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
@@ -1444,7 +1452,8 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         if dttm_col and not db_engine_spec.time_groupby_inline:
             inner_time_filter = [
                 dttm_col.get_time_filter(
-                    inner_from_dttm or from_dttm, inner_to_dttm or to_dttm,
+                    inner_from_dttm or from_dttm,
+                    inner_to_dttm or to_dttm,
                 )
             ]
             subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
@@ -1473,7 +1482,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             orderby = [
                 (
                     self._get_series_orderby(
-                        series_limit_metric, metrics_by_name, columns_by_name,
+                        series_limit_metric,
+                        metrics_by_name,
+                        columns_by_name,
                     ),
                     not order_desc,
                 )
@@ -1549,7 +1560,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         return ob

     def _normalize_prequery_result_type(
-        self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
+        self,
+        row: pd.Series,
+        dimension: str,
+        columns_by_name: Dict[str, TableColumn],
     ) -> Union[str, int, float, bool, Text]:
         """
         Convert a prequery result type to its equivalent Python type.
@@ -1594,7 +1608,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             group = []
             for dimension in dimensions:
                 value = self._normalize_prequery_result_type(
-                    row, dimension, columns_by_name,
+                    row,
+                    dimension,
+                    columns_by_name,
                 )
                 group.append(groupby_exprs[dimension] == value)
@@ -1933,7 +1949,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
     @staticmethod
     def after_insert(
-        mapper: Mapper, connection: Connection, target: "SqlaTable",
+        mapper: Mapper,
+        connection: Connection,
+        target: "SqlaTable",
     ) -> None:
         """
         Shadow write the dataset to new models.
@@ -1962,7 +1980,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
     @staticmethod
     def after_delete(  # pylint: disable=unused-argument
-        mapper: Mapper, connection: Connection, target: "SqlaTable",
+        mapper: Mapper,
+        connection: Connection,
+        target: "SqlaTable",
     ) -> None:
         """
         Shadow write the dataset to new models.
@@ -1985,7 +2005,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
     @staticmethod
     def after_update(  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
-        mapper: Mapper, connection: Connection, target: "SqlaTable",
+        mapper: Mapper,
+        connection: Connection,
+        target: "SqlaTable",
     ) -> None:
         """
         Shadow write the dataset to new models.


@@ -36,7 +36,9 @@ if TYPE_CHECKING:
 def get_physical_table_metadata(
-    database: Database, table_name: str, schema_name: Optional[str] = None,
+    database: Database,
+    table_name: str,
+    schema_name: Optional[str] = None,
 ) -> List[Dict[str, str]]:
     """Use SQLAlchemy inspector to get table metadata"""
     db_engine_spec = database.db_engine_spec
@@ -72,7 +74,11 @@ def get_physical_table_metadata(
         # from different drivers that fall outside CompileError
         except Exception:  # pylint: disable=broad-except
             col.update(
-                {"type": "UNKNOWN", "generic_type": None, "is_dttm": None,}
+                {
+                    "type": "UNKNOWN",
+                    "generic_type": None,
+                    "is_dttm": None,
+                }
             )
     return cols


@@ -151,7 +151,8 @@ def import_dashboard(
             old_dataset_id = target.get("datasetId")
             if dataset_id_mapping and old_dataset_id is not None:
                 target["datasetId"] = dataset_id_mapping.get(
-                    old_dataset_id, old_dataset_id,
+                    old_dataset_id,
+                    old_dataset_id,
                 )

     dashboard.json_metadata = json.dumps(json_metadata)


@@ -85,7 +85,8 @@ class BaseFilterSetCommand:
             )
         except NotAuthorizedException as err:
             raise FilterSetForbiddenError(
-                str(self._filter_set_id), "user not authorized to access the filterset",
+                str(self._filter_set_id),
+                "user not authorized to access the filterset",
             ) from err
         except FilterSetForbiddenError as err:
             raise err


@@ -46,7 +46,11 @@ class FilterSetSchema(Schema):
 class FilterSetPostSchema(FilterSetSchema):
     json_metadata_schema: JsonMetadataSchema = JsonMetadataSchema()
     # pylint: disable=W0613
-    name = fields.String(required=True, allow_none=False, validate=Length(0, 500),)
+    name = fields.String(
+        required=True,
+        allow_none=False,
+        validate=Length(0, 500),
+    )
     description = fields.String(
         required=False, allow_none=True, validate=[Length(1, 1000)]
     )


@@ -170,7 +170,9 @@ class FilterRelatedRoles(BaseFilter):  # pylint: disable=too-few-public-methods
     def apply(self, query: Query, value: Optional[Any]) -> Query:
         role_model = security_manager.role_model
         if value:
-            return query.filter(role_model.name.ilike(f"%{value}%"),)
+            return query.filter(
+                role_model.name.ilike(f"%{value}%"),
+            )
         return query
@@ -184,7 +186,15 @@ class DashboardCertifiedFilter(BaseFilter):  # pylint: disable=too-few-public-methods
     def apply(self, query: Query, value: Any) -> Query:
         if value is True:
-            return query.filter(and_(Dashboard.certified_by.isnot(None),))
+            return query.filter(
+                and_(
+                    Dashboard.certified_by.isnot(None),
+                )
+            )
         if value is False:
-            return query.filter(and_(Dashboard.certified_by.is_(None),))
+            return query.filter(
+                and_(
+                    Dashboard.certified_by.is_(None),
+                )
+            )
         return query


@@ -104,14 +104,19 @@ class DashboardPermalinkRestApi(BaseApi):
         try:
             state = self.add_model_schema.load(request.json)
             key = CreateDashboardPermalinkCommand(
-                actor=g.user, dashboard_id=pk, state=state,
+                actor=g.user,
+                dashboard_id=pk,
+                state=state,
             ).run()
             http_origin = request.headers.environ.get("HTTP_ORIGIN")
             url = f"{http_origin}/superset/dashboard/p/{key}/"
             return self.response(201, key=key, url=url)
         except (ValidationError, DashboardPermalinkInvalidStateError) as ex:
             return self.response(400, message=str(ex))
-        except (DashboardAccessDeniedError, KeyValueAccessDeniedError,) as ex:
+        except (
+            DashboardAccessDeniedError,
+            KeyValueAccessDeniedError,
+        ) as ex:
             return self.response(403, message=str(ex))
         except DashboardNotFoundError as ex:
             return self.response(404, message=str(ex))


@@ -31,7 +31,10 @@ logger = logging.getLogger(__name__)
 class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
     def __init__(
-        self, actor: User, dashboard_id: str, state: DashboardPermalinkState,
+        self,
+        actor: User,
+        dashboard_id: str,
+        state: DashboardPermalinkState,
     ):
         self.actor = actor
         self.dashboard_id = dashboard_id
@@ -46,7 +49,9 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
                 "state": self.state,
             }
             key = CreateKeyValueCommand(
-                actor=self.actor, resource=self.resource, value=value,
+                actor=self.actor,
+                resource=self.resource,
+                value=value,
             ).run()
             return encode_permalink_key(key=key.id, salt=self.salt)
         except SQLAlchemyError as ex:


@@ -19,7 +19,9 @@ from marshmallow import fields, Schema
 class DashboardPermalinkPostSchema(Schema):
     filterState = fields.Dict(
-        required=False, allow_none=True, description="Native filter state",
+        required=False,
+        allow_none=True,
+        description="Native filter state",
     )
     urlParams = fields.List(
         fields.Tuple(


@@ -243,7 +243,8 @@ class DashboardPostSchema(BaseDashboardSchema):
     )
     css = fields.String()
     json_metadata = fields.String(
-        description=json_metadata_description, validate=validate_json_metadata,
+        description=json_metadata_description,
+        validate=validate_json_metadata,
     )
     published = fields.Boolean(description=published_description)
     certified_by = fields.String(description=certified_by_description, allow_none=True)


@@ -881,7 +881,10 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
         database = DatabaseDAO.find_by_id(pk)
         if not database:
             return self.response_404()
-        return self.response(200, function_names=database.function_names,)
+        return self.response(
+            200,
+            function_names=database.function_names,
+        )

     @expose("/available/", methods=["GET"])
     @protect()


@@ -47,7 +47,8 @@ class DatabaseExistsValidationError(ValidationError):
 class DatabaseRequiredFieldValidationError(ValidationError):
     def __init__(self, field_name: str) -> None:
         super().__init__(
-            [_("Field is required")], field_name=field_name,
+            [_("Field is required")],
+            field_name=field_name,
         )
@@ -100,7 +101,8 @@ class DatabaseUpdateFailedError(UpdateFailedError):
 class DatabaseConnectionFailedError(  # pylint: disable=too-many-ancestors
-    DatabaseCreateFailedError, DatabaseUpdateFailedError,
+    DatabaseCreateFailedError,
+    DatabaseUpdateFailedError,
 ):
     message = _("Connection failed, please check your connection settings")


@@ -57,7 +57,8 @@ class ValidateDatabaseParametersCommand(BaseCommand):
             raise InvalidEngineError(
                 SupersetError(
                     message=__(
-                        'Engine "%(engine)s" is not a valid engine.', engine=engine,
+                        'Engine "%(engine)s" is not a valid engine.',
+                        engine=engine,
                     ),
                     error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                     level=ErrorLevel.ERROR,
@@ -101,7 +102,8 @@ class ValidateDatabaseParametersCommand(BaseCommand):
         # try to connect
         sqlalchemy_uri = engine_spec.build_sqlalchemy_uri(  # type: ignore
-            self._properties.get("parameters"), encrypted_extra,
+            self._properties.get("parameters"),
+            encrypted_extra,
         )
         if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
             sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted


@@ -42,7 +42,8 @@ class DatabaseDAO(BaseDAO):
     @staticmethod
     def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
         database_query = db.session.query(Database).filter(
-            Database.database_name == database_name, Database.id != database_id,
+            Database.database_name == database_name,
+            Database.id != database_id,
         )
         return not db.session.query(database_query.exists()).scalar()


@@ -27,7 +27,8 @@ class DatabaseFilter(BaseFilter):
     # TODO(bogdan): consider caching.
     def can_access_databases(  # noqa pylint: disable=no-self-use
-        self, view_menu_name: str,
+        self,
+        view_menu_name: str,
     ) -> Set[str]:
         return {
             security_manager.unpack_database_and_schema(vm).database


@@ -308,7 +308,12 @@ def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
     engine_specs = get_engine_specs()
     if engine not in engine_specs:
         raise ValidationError(
-            [_('Engine "%(engine)s" is not a valid engine.', engine=engine,)]
+            [
+                _(
+                    'Engine "%(engine)s" is not a valid engine.',
+                    engine=engine,
+                )
+            ]
         )
     return engine_specs[engine]
@@ -324,7 +329,9 @@ class DatabaseValidateParametersSchema(Schema):
         description="DB-specific parameters for configuration",
     )
     database_name = fields.String(
-        description=database_name_description, allow_none=True, validate=Length(1, 250),
+        description=database_name_description,
+        allow_none=True,
+        validate=Length(1, 250),
     )
     impersonate_user = fields.Boolean(description=impersonate_user_description)
     extra = fields.String(description=extra_description, validate=extra_validator)
@@ -351,7 +358,9 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
     unknown = EXCLUDE

     database_name = fields.String(
-        description=database_name_description, required=True, validate=Length(1, 250),
+        description=database_name_description,
+        required=True,
+        validate=Length(1, 250),
     )
     cache_timeout = fields.Integer(
         description=cache_timeout_description, allow_none=True
@@ -395,7 +404,9 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
     unknown = EXCLUDE

     database_name = fields.String(
-        description=database_name_description, allow_none=True, validate=Length(1, 250),
+        description=database_name_description,
+        allow_none=True,
+        validate=Length(1, 250),
     )
     cache_timeout = fields.Integer(
         description=cache_timeout_description, allow_none=True
@@ -436,7 +447,9 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
 class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
     database_name = fields.String(
-        description=database_name_description, allow_none=True, validate=Length(1, 250),
+        description=database_name_description,
+        allow_none=True,
+        validate=Length(1, 250),
     )
     impersonate_user = fields.Boolean(description=impersonate_user_description)
     extra = fields.String(description=extra_description, validate=extra_validator)


@@ -285,7 +285,10 @@ class ImportDatasetsCommand(BaseCommand):
     # pylint: disable=unused-argument
     def __init__(
-        self, contents: Dict[str, str], *args: Any, **kwargs: Any,
+        self,
+        contents: Dict[str, str],
+        *args: Any,
+        **kwargs: Any,
     ):
         self.contents = contents
         self._configs: Dict[str, Any] = {}


@@ -65,7 +65,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
         if self._model:
             try:
                 dataset = DatasetDAO.update(
-                    model=self._model, properties=self._properties,
+                    model=self._model,
+                    properties=self._properties,
                 )
                 return dataset
             except DAOUpdateFailedError as ex:


@@ -205,7 +205,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             types.BigInteger(),
             GenericDataType.NUMERIC,
         ),
-        (re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
+        (
+            re.compile(r"^long", re.IGNORECASE),
+            types.Float(),
+            GenericDataType.NUMERIC,
+        ),
         (
             re.compile(r"^decimal", re.IGNORECASE),
             types.Numeric(),
@@ -216,13 +220,21 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             types.Numeric(),
             GenericDataType.NUMERIC,
         ),
-        (re.compile(r"^float", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
+        (
+            re.compile(r"^float", re.IGNORECASE),
+            types.Float(),
+            GenericDataType.NUMERIC,
+        ),
         (
             re.compile(r"^double", re.IGNORECASE),
             types.Float(),
             GenericDataType.NUMERIC,
         ),
-        (re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
+        (
+            re.compile(r"^real", re.IGNORECASE),
+            types.REAL,
+            GenericDataType.NUMERIC,
+        ),
         (
             re.compile(r"^smallserial", re.IGNORECASE),
             types.SmallInteger(),
@@ -258,7 +270,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             types.DateTime(),
             GenericDataType.TEMPORAL,
         ),
-        (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
+        (
+            re.compile(r"^time", re.IGNORECASE),
+            types.Time(),
+            GenericDataType.TEMPORAL,
+        ),
         (
             re.compile(r"^interval", re.IGNORECASE),
             types.Interval(),
@@ -351,7 +367,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_allow_cost_estimate(  # pylint: disable=unused-argument
-        cls, extra: Dict[str, Any],
+        cls,
+        extra: Dict[str, Any],
     ) -> bool:
         return False
@@ -618,7 +635,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def extra_table_metadata(  # pylint: disable=unused-argument
-        cls, database: "Database", table_name: str, schema_name: str,
+        cls,
+        database: "Database",
+        table_name: str,
+        schema_name: str,
     ) -> Dict[str, Any]:
         """
         Returns engine-specific table metadata
@@ -944,7 +964,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_table_names(  # pylint: disable=unused-argument
-        cls, database: "Database", inspector: Inspector, schema: Optional[str],
+        cls,
+        database: "Database",
+        inspector: Inspector,
+        schema: Optional[str],
     ) -> List[str]:
         """
         Get all tables from schema
@@ -961,7 +984,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_view_names(  # pylint: disable=unused-argument
-        cls, database: "Database", inspector: Inspector, schema: Optional[str],
+        cls,
+        database: "Database",
+        inspector: Inspector,
+        schema: Optional[str],
     ) -> List[str]:
         """
         Get all views from schema
@@ -1193,7 +1219,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def update_impersonation_config(
-        cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
+        cls,
+        connect_args: Dict[str, Any],
+        uri: str,
+        username: Optional[str],
     ) -> None:
         """
         Update a configuration dictionary
@@ -1207,7 +1236,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def execute(  # pylint: disable=unused-argument
-        cls, cursor: Any, query: str, **kwargs: Any,
+        cls,
+        cursor: Any,
+        query: str,
+        **kwargs: Any,
     ) -> None:
         """
         Execute a SQL query
@@ -1333,7 +1365,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_function_names(  # pylint: disable=unused-argument
-        cls, database: "Database",
+        cls,
+        database: "Database",
     ) -> List[str]:
         """
         Get a list of function names that are able to be called on the database.
@@ -1471,7 +1504,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_cancel_query_id(  # pylint: disable=unused-argument
-        cls, cursor: Any, query: Query,
+        cls,
+        cursor: Any,
+        query: Query,
     ) -> Optional[str]:
         """
         Select identifiers from the database engine that uniquely identifies the
@@ -1487,7 +1522,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def cancel_query(  # pylint: disable=unused-argument
-        cls, cursor: Any, query: Query, cancel_query_id: str,
+        cls,
+        cursor: Any,
+        query: Query,
+        cancel_query_id: str,
     ) -> bool:
         """
         Cancel query in the underlying database.
@@ -1515,7 +1553,7 @@ class BasicParametersSchema(Schema):
     port = fields.Integer(
         required=True,
         description=__("Database port"),
-        validate=Range(min=0, max=2 ** 16, max_inclusive=False),
+        validate=Range(min=0, max=2**16, max_inclusive=False),
    )
     database = fields.String(required=True, description=__("Database name"))
     query = fields.Dict(
@ -1665,7 +1703,7 @@ class BasicParametersMixin:
extra={"invalid": ["port"]}, extra={"invalid": ["port"]},
), ),
) )
if not (isinstance(port, int) and 0 <= port < 2 ** 16): if not (isinstance(port, int) and 0 <= port < 2**16):
errors.append( errors.append(
SupersetError( SupersetError(
message=( message=(
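Besides the trailing-comma explosions, this file picks up black 22.x's PEP 8 power-operator rule: `2 ** 16` becomes `2**16` because both operands are simple, while complex operands keep their spaces. A short sketch of the rule, assuming black 22.3.0:

    import black

    src = "limit = 2 ** 16\nscaled = (value + 1) ** 2\n"
    # Simple operands (names, literals) hug the operator;
    # the parenthesized operand keeps spaces around **.
    print(black.format_str(src, mode=black.Mode()))
    # -> limit = 2**16
    # -> scaled = (value + 1) ** 2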

View File

@ -72,7 +72,8 @@ ma_plugin = MarshmallowPlugin()
class BigQueryParametersSchema(Schema): class BigQueryParametersSchema(Schema):
credentials_info = EncryptedString( credentials_info = EncryptedString(
required=False, description="Contents of BigQuery JSON credentials.", required=False,
description="Contents of BigQuery JSON credentials.",
) )
query = fields.Dict(required=False) query = fields.Dict(required=False)

View File

@ -82,7 +82,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod @classmethod
def modify_url_for_impersonation( def modify_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str], cls,
url: URL,
impersonate_user: bool,
username: Optional[str],
) -> None: ) -> None:
if impersonate_user and username is not None: if impersonate_user and username is not None:
user = security_manager.find_user(username=username) user = security_manager.find_user(username=username)
@ -91,7 +94,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod @classmethod
def extra_table_metadata( def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: str, cls,
database: "Database",
table_name: str,
schema_name: str,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
engine = cls.get_engine(database, schema=schema_name) engine = cls.get_engine(database, schema=schema_name)
with closing(engine.raw_connection()) as conn: with closing(engine.raw_connection()) as conn:
@ -150,7 +156,8 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod @classmethod
def validate_parameters( def validate_parameters(
cls, parameters: GSheetsParametersType, cls,
parameters: GSheetsParametersType,
) -> List[SupersetError]: ) -> List[SupersetError]:
errors: List[SupersetError] = [] errors: List[SupersetError] = []
encrypted_credentials = parameters.get("service_account_info") or "{}" encrypted_credentials = parameters.get("service_account_info") or "{}"
@ -173,7 +180,9 @@ class GSheetsEngineSpec(SqliteEngineSpec):
subject = g.user.email if g.user else None subject = g.user.email if g.user else None
engine = create_engine( engine = create_engine(
"gsheets://", service_account_info=encrypted_credentials, subject=subject, "gsheets://",
service_account_info=encrypted_credentials,
subject=subject,
) )
conn = engine.connect() conn = engine.connect()
idx = 0 idx = 0

View File

@ -496,7 +496,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod @classmethod
def update_impersonation_config( def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None: ) -> None:
""" """
Update a configuration dictionary Update a configuration dictionary

View File

@ -74,24 +74,56 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
encryption_parameters = {"ssl": "1"} encryption_parameters = {"ssl": "1"}
column_type_mappings = ( column_type_mappings = (
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,), (
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,), re.compile(r"^int.*", re.IGNORECASE),
INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^tinyint", re.IGNORECASE),
TINYINT(),
GenericDataType.NUMERIC,
),
( (
re.compile(r"^mediumint", re.IGNORECASE), re.compile(r"^mediumint", re.IGNORECASE),
MEDIUMINT(), MEDIUMINT(),
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
(re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,), (
(re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,), re.compile(r"^decimal", re.IGNORECASE),
(re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,), DECIMAL(),
(re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,), GenericDataType.NUMERIC,
(re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,), ),
(
re.compile(r"^float", re.IGNORECASE),
FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double", re.IGNORECASE),
DOUBLE(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bit", re.IGNORECASE),
BIT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^tinytext", re.IGNORECASE),
TINYTEXT(),
GenericDataType.STRING,
),
( (
re.compile(r"^mediumtext", re.IGNORECASE), re.compile(r"^mediumtext", re.IGNORECASE),
MEDIUMTEXT(), MEDIUMTEXT(),
GenericDataType.STRING, GenericDataType.STRING,
), ),
(re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,), (
re.compile(r"^longtext", re.IGNORECASE),
LONGTEXT(),
GenericDataType.STRING,
),
) )
_time_grain_expressions = { _time_grain_expressions = {

View File

@ -188,8 +188,16 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
lambda match: ARRAY(int(match[2])) if match[2] else String(), lambda match: ARRAY(int(match[2])) if match[2] else String(),
GenericDataType.STRING, GenericDataType.STRING,
), ),
(re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,), (
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,), re.compile(r"^json.*", re.IGNORECASE),
JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^enum.*", re.IGNORECASE),
ENUM(),
GenericDataType.STRING,
),
) )
@classmethod @classmethod

View File

@ -214,7 +214,10 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
@classmethod @classmethod
def update_impersonation_config( def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None: ) -> None:
""" """
Update a configuration dictionary Update a configuration dictionary
@ -487,7 +490,11 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
types.VARBINARY(), types.VARBINARY(),
GenericDataType.STRING, GenericDataType.STRING,
), ),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,), (
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
( (
re.compile(r"^date.*", re.IGNORECASE), re.compile(r"^date.*", re.IGNORECASE),
types.DATETIME(), types.DATETIME(),

View File

@ -94,7 +94,10 @@ class TrinoEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def update_impersonation_config( def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None: ) -> None:
""" """
Update a configuration dictionary Update a configuration dictionary

View File

@ -186,8 +186,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
default_query_context = { default_query_context = {
"result_format": "json", "result_format": "json",
"result_type": "full", "result_type": "full",
"datasource": {"id": tbl.id, "type": "table",}, "datasource": {
"queries": [{"columns": [], "metrics": [],},], "id": tbl.id,
"type": "table",
},
"queries": [
{
"columns": [],
"metrics": [],
},
],
} }
admin = get_admin_user() admin = get_admin_user()
@ -381,7 +389,12 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
), ),
query_context=get_slice_json( query_context=get_slice_json(
default_query_context, default_query_context,
queries=[{"columns": ["name", "state"], "metrics": [metric],}], queries=[
{
"columns": ["name", "state"],
"metrics": [metric],
}
],
), ),
), ),
] ]

View File

@ -43,7 +43,9 @@ from .helpers import (
def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements
only_metadata: bool = False, force: bool = False, sample: bool = False, only_metadata: bool = False,
force: bool = False,
sample: bool = False,
) -> None: ) -> None:
"""Loads the world bank health dataset, slices and a dashboard""" """Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population" tbl_name = "wb_health_population"

View File

@ -129,7 +129,10 @@ class SupersetGenericDBErrorException(SupersetErrorFromParamsException):
extra: Optional[Dict[str, Any]] = None, extra: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
SupersetErrorType.GENERIC_DB_ENGINE_ERROR, message, level, extra, SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
message,
level,
extra,
) )
@ -144,7 +147,10 @@ class SupersetTemplateParamsErrorException(SupersetErrorFromParamsException):
extra: Optional[Dict[str, Any]] = None, extra: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
error, message, level, extra, error,
message,
level,
extra,
) )

View File

@ -39,7 +39,8 @@ logger = logging.getLogger(__name__)
class UpdateFormDataCommand(BaseCommand, ABC): class UpdateFormDataCommand(BaseCommand, ABC):
def __init__( def __init__(
self, cmd_params: CommandParameters, self,
cmd_params: CommandParameters,
): ):
self._cmd_params = cmd_params self._cmd_params = cmd_params

View File

@ -162,7 +162,10 @@ class ExplorePermalinkRestApi(BaseApi):
return self.response(200, **value) return self.response(200, **value)
except ExplorePermalinkInvalidStateError as ex: except ExplorePermalinkInvalidStateError as ex:
return self.response(400, message=str(ex)) return self.response(400, message=str(ex))
except (ChartAccessDeniedError, DatasetAccessDeniedError,) as ex: except (
ChartAccessDeniedError,
DatasetAccessDeniedError,
) as ex:
return self.response(403, message=str(ex)) return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
return self.response(404, message=str(ex)) return self.response(404, message=str(ex))
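The exploded `except (...)` tuple above is the same magic trailing comma at work; it applies to any bracketed collection, not just signatures. Projects that prefer the collapsed form can opt out per run with `-C`/`--skip-magic-trailing-comma`. A sketch, assuming black 22.3.0 (the names are placeholders):

    import black

    src = "errors = (ValueError, KeyError,)\n"
    # Default mode honors the trailing comma and explodes the tuple.
    print(black.format_str(src, mode=black.Mode()))
    # magic_trailing_comma=False (the -C command-line flag) collapses
    # it back onto one line and drops the comma.
    print(black.format_str(src, mode=black.Mode(magic_trailing_comma=False)))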

View File

@ -48,7 +48,9 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
"state": self.state, "state": self.state,
} }
command = CreateKeyValueCommand( command = CreateKeyValueCommand(
actor=self.actor, resource=self.resource, value=value, actor=self.actor,
resource=self.resource,
value=value,
) )
key = command.run() key = command.run()
return encode_permalink_key(key=key.id, salt=self.salt) return encode_permalink_key(key=key.id, salt=self.salt)

View File

@ -42,7 +42,8 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
try: try:
key = decode_permalink_id(self.key, salt=self.salt) key = decode_permalink_id(self.key, salt=self.salt)
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand( value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
resource=self.resource, key=key, resource=self.resource,
key=key,
).run() ).run()
if value: if value:
chart_id: Optional[int] = value.get("chartId") chart_id: Optional[int] = value.get("chartId")

View File

@ -19,7 +19,9 @@ from marshmallow import fields, Schema
class ExplorePermalinkPostSchema(Schema): class ExplorePermalinkPostSchema(Schema):
formData = fields.Dict( formData = fields.Dict(
required=True, allow_none=False, description="Chart form data", required=True,
allow_none=False,
description="Chart form data",
) )
urlParams = fields.List( urlParams = fields.List(
fields.Tuple( fields.Tuple(

View File

@ -355,7 +355,10 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return_value = json.loads(json.dumps(return_value)) return_value = json.loads(json.dumps(return_value))
except TypeError as ex: except TypeError as ex:
raise SupersetTemplateException( raise SupersetTemplateException(
_("Unsupported return value for method %(name)s", name=func.__name__,) _(
"Unsupported return value for method %(name)s",
name=func.__name__,
)
) from ex ) from ex
return return_value return return_value
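A reformat of this scale is safe to re-run because black's output is stable: formatting already-formatted code is a no-op. A quick self-check sketch, assuming black 22.3.0 (the `_(...)` call is only parsed here, never executed):

    import black

    src = '_("Unsupported return value for method %(name)s", name="demo",)\n'
    once = black.format_str(src, mode=black.Mode())
    twice = black.format_str(once, mode=black.Mode())
    assert once == twice  # idempotent: a second pass changes nothing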

View File

@ -214,7 +214,9 @@ def _delete_old_permissions(
def migrate_roles( def migrate_roles(
session: Session, pvm_key_map: PvmMigrationMapType, commit: bool = False, session: Session,
pvm_key_map: PvmMigrationMapType,
commit: bool = False,
) -> None: ) -> None:
""" """
Migrates all existing roles that have the permissions to be migrated Migrates all existing roles that have the permissions to be migrated

View File

@ -91,7 +91,8 @@ def downgrade():
for dashboard in session.query(Dashboard).all(): for dashboard in session.query(Dashboard).all():
logger.info( logger.info(
"[RemoveTypeToNativeFilter] Updating Dashobard<pk:%s>", dashboard.id, "[RemoveTypeToNativeFilter] Updating Dashobard<pk:%s>",
dashboard.id,
) )
if not dashboard.json_metadata: if not dashboard.json_metadata:
logger.info( logger.info(

View File

@ -39,24 +39,63 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"Dashboard": ("can_read", "can_write",)} NEW_PVMS = {
"Dashboard": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("DashboardModelView", "can_add"): (Pvm("Dashboard", "can_write"),), Pvm("DashboardModelView", "can_add"): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "can_delete"): (Pvm("Dashboard", "can_write"),), Pvm("DashboardModelView", "can_delete"): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "can_download_dashboards",): ( Pvm(
Pvm("Dashboard", "can_read"), "DashboardModelView",
), "can_download_dashboards",
Pvm("DashboardModelView", "can_edit",): (Pvm("Dashboard", "can_write"),), ): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelView", "can_favorite_status",): (Pvm("Dashboard", "can_read"),), Pvm(
Pvm("DashboardModelView", "can_list",): (Pvm("Dashboard", "can_read"),), "DashboardModelView",
Pvm("DashboardModelView", "can_mulexport",): (Pvm("Dashboard", "can_read"),), "can_edit",
Pvm("DashboardModelView", "can_show",): (Pvm("Dashboard", "can_read"),), ): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "muldelete",): (Pvm("Dashboard", "can_write"),), Pvm(
Pvm("DashboardModelView", "mulexport",): (Pvm("Dashboard", "can_read"),), "DashboardModelView",
Pvm("DashboardModelViewAsync", "can_list",): (Pvm("Dashboard", "can_read"),), "can_favorite_status",
Pvm("DashboardModelViewAsync", "muldelete",): (Pvm("Dashboard", "can_write"),), ): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelViewAsync", "mulexport",): (Pvm("Dashboard", "can_read"),), Pvm(
Pvm("Dashboard", "can_new",): (Pvm("Dashboard", "can_write"),), "DashboardModelView",
"can_list",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_show",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"muldelete",
): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelView",
"mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelViewAsync",
"can_list",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelViewAsync",
"muldelete",
): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelViewAsync",
"mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"Dashboard",
"can_new",
): (Pvm("Dashboard", "can_write"),),
} }
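The migration hunks above show a subtler consequence: when a dict key is itself a call carrying a trailing comma, such as `Pvm("DashboardModelView", "can_edit",)`, the key call explodes across lines along with its value. A sketch with a placeholder `Pvm`, assuming black 22.3.0:

    import black

    src = (
        'PVM_MAP = {Pvm("DashboardModelView", "can_edit",):'
        ' (Pvm("Dashboard", "can_write"),)}\n'
    )
    # The trailing comma inside the key call spreads the key across
    # lines, the shape seen throughout the security_converge migrations.
    print(black.format_str(src, mode=black.Mode()))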

View File

@ -43,9 +43,18 @@ def upgrade():
sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=False), sa.Column("alert_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), ["alert_id"],
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), ["alerts.id"],
),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_table( op.create_table(
@ -58,10 +67,22 @@ def upgrade():
sa.Column("changed_by_fk", sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=False), sa.Column("alert_id", sa.Integer(), nullable=False),
sa.Column("database_id", sa.Integer(), nullable=False), sa.Column("database_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), ["alert_id"],
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), ["alerts.id"],
sa.ForeignKeyConstraint(["database_id"], ["dbs.id"],), ),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["database_id"],
["dbs.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_table( op.create_table(
@ -72,8 +93,14 @@ def upgrade():
sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("value", sa.Float(), nullable=True), sa.Column("value", sa.Float(), nullable=True),
sa.Column("error_msg", sa.String(length=500), nullable=True), sa.Column("error_msg", sa.String(length=500), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["observer_id"], ["sql_observers.id"],), ["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["observer_id"],
["sql_observers.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(

View File

@ -49,8 +49,14 @@ def upgrade():
sa.Column("dashboard_id", sa.Integer(), nullable=True), sa.Column("dashboard_id", sa.Integer(), nullable=True),
sa.Column("last_eval_dttm", sa.DateTime(), nullable=True), sa.Column("last_eval_dttm", sa.DateTime(), nullable=True),
sa.Column("last_state", sa.String(length=10), nullable=True), sa.Column("last_state", sa.String(length=10), nullable=True),
sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],), ["dashboard_id"],
["dashboards.id"],
),
sa.ForeignKeyConstraint(
["slice_id"],
["slices.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_alerts_active"), "alerts", ["active"], unique=False) op.create_index(op.f("ix_alerts_active"), "alerts", ["active"], unique=False)
@ -62,7 +68,10 @@ def upgrade():
sa.Column("dttm_end", sa.DateTime(), nullable=True), sa.Column("dttm_end", sa.DateTime(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("state", sa.String(length=10), nullable=True), sa.Column("state", sa.String(length=10), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_table( op.create_table(
@ -70,8 +79,14 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],), ["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )

View File

@ -34,7 +34,9 @@ def upgrade():
with op.batch_alter_table("report_schedule") as batch_op: with op.batch_alter_table("report_schedule") as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"creation_method", sa.VARCHAR(255), server_default="alerts_reports", "creation_method",
sa.VARCHAR(255),
server_default="alerts_reports",
) )
) )
batch_op.create_index( batch_op.create_index(

View File

@ -39,13 +39,27 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"ReportSchedule": ("can_read", "can_write",)} NEW_PVMS = {
"ReportSchedule": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("ReportSchedule", "can_list"): (Pvm("ReportSchedule", "can_read"),), Pvm("ReportSchedule", "can_list"): (Pvm("ReportSchedule", "can_read"),),
Pvm("ReportSchedule", "can_show"): (Pvm("ReportSchedule", "can_read"),), Pvm("ReportSchedule", "can_show"): (Pvm("ReportSchedule", "can_read"),),
Pvm("ReportSchedule", "can_add",): (Pvm("ReportSchedule", "can_write"),), Pvm(
Pvm("ReportSchedule", "can_edit",): (Pvm("ReportSchedule", "can_write"),), "ReportSchedule",
Pvm("ReportSchedule", "can_delete",): (Pvm("ReportSchedule", "can_write"),), "can_add",
): (Pvm("ReportSchedule", "can_write"),),
Pvm(
"ReportSchedule",
"can_edit",
): (Pvm("ReportSchedule", "can_write"),),
Pvm(
"ReportSchedule",
"can_delete",
): (Pvm("ReportSchedule", "can_write"),),
} }

View File

@ -39,17 +39,43 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"Database": ("can_read", "can_write",)} NEW_PVMS = {
"Database": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("DatabaseView", "can_add"): (Pvm("Database", "can_write"),), Pvm("DatabaseView", "can_add"): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_delete"): (Pvm("Database", "can_write"),), Pvm("DatabaseView", "can_delete"): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_edit",): (Pvm("Database", "can_write"),), Pvm(
Pvm("DatabaseView", "can_list",): (Pvm("Database", "can_read"),), "DatabaseView",
Pvm("DatabaseView", "can_mulexport",): (Pvm("Database", "can_read"),), "can_edit",
Pvm("DatabaseView", "can_post",): (Pvm("Database", "can_write"),), ): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_show",): (Pvm("Database", "can_read"),), Pvm(
Pvm("DatabaseView", "muldelete",): (Pvm("Database", "can_write"),), "DatabaseView",
Pvm("DatabaseView", "yaml_export",): (Pvm("Database", "can_read"),), "can_list",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"can_mulexport",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"can_post",
): (Pvm("Database", "can_write"),),
Pvm(
"DatabaseView",
"can_show",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"muldelete",
): (Pvm("Database", "can_write"),),
Pvm(
"DatabaseView",
"yaml_export",
): (Pvm("Database", "can_read"),),
} }

View File

@ -38,7 +38,12 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"Dataset": ("can_read", "can_write",)} NEW_PVMS = {
"Dataset": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("SqlMetricInlineView", "can_add"): (Pvm("Dataset", "can_write"),), Pvm("SqlMetricInlineView", "can_add"): (Pvm("Dataset", "can_write"),),
Pvm("SqlMetricInlineView", "can_delete"): (Pvm("Dataset", "can_write"),), Pvm("SqlMetricInlineView", "can_delete"): (Pvm("Dataset", "can_write"),),
@ -50,15 +55,33 @@ PVM_MAP = {
Pvm("TableColumnInlineView", "can_edit"): (Pvm("Dataset", "can_write"),), Pvm("TableColumnInlineView", "can_edit"): (Pvm("Dataset", "can_write"),),
Pvm("TableColumnInlineView", "can_list"): (Pvm("Dataset", "can_read"),), Pvm("TableColumnInlineView", "can_list"): (Pvm("Dataset", "can_read"),),
Pvm("TableColumnInlineView", "can_show"): (Pvm("Dataset", "can_read"),), Pvm("TableColumnInlineView", "can_show"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_add",): (Pvm("Dataset", "can_write"),), Pvm(
Pvm("TableModelView", "can_delete",): (Pvm("Dataset", "can_write"),), "TableModelView",
Pvm("TableModelView", "can_edit",): (Pvm("Dataset", "can_write"),), "can_add",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"can_delete",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"can_edit",
): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "can_list"): (Pvm("Dataset", "can_read"),), Pvm("TableModelView", "can_list"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_mulexport"): (Pvm("Dataset", "can_read"),), Pvm("TableModelView", "can_mulexport"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_show"): (Pvm("Dataset", "can_read"),), Pvm("TableModelView", "can_show"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "muldelete",): (Pvm("Dataset", "can_write"),), Pvm(
Pvm("TableModelView", "refresh",): (Pvm("Dataset", "can_write"),), "TableModelView",
Pvm("TableModelView", "yaml_export",): (Pvm("Dataset", "can_read"),), "muldelete",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"refresh",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"yaml_export",
): (Pvm("Dataset", "can_read"),),
} }

View File

@ -114,8 +114,14 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("report_schedule_id", sa.Integer(), nullable=False), sa.Column("report_schedule_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["report_schedule_id"], ["report_schedule.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],), ["report_schedule_id"],
["report_schedule.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )

View File

@ -37,10 +37,18 @@ from superset.migrations.shared.security_converge import (
revision = "4b84f97828aa" revision = "4b84f97828aa"
down_revision = "45731db65d9c" down_revision = "45731db65d9c"
NEW_PVMS = {"Log": ("can_read", "can_write",)} NEW_PVMS = {
"Log": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("LogModelView", "can_show"): (Pvm("Log", "can_read"),), Pvm("LogModelView", "can_show"): (Pvm("Log", "can_read"),),
Pvm("LogModelView", "can_add",): (Pvm("Log", "can_write"),), Pvm(
"LogModelView",
"can_add",
): (Pvm("Log", "can_write"),),
Pvm("LogModelView", "can_list"): (Pvm("Log", "can_read"),), Pvm("LogModelView", "can_list"): (Pvm("Log", "can_read"),),
} }

View File

@ -62,5 +62,8 @@ def downgrade():
) )
batch_op.create_foreign_key( batch_op.create_foreign_key(
"saved_query_id", "saved_query", ["saved_query_id"], ["id"], "saved_query_id",
"saved_query",
["saved_query_id"],
["id"],
) )

View File

@ -42,8 +42,14 @@ def upgrade():
sa.Column("bundle_url", sa.String(length=1000), nullable=False), sa.Column("bundle_url", sa.String(length=1000), nullable=False),
sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), ["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("key"), sa.UniqueConstraint("key"),
sa.UniqueConstraint("name"), sa.UniqueConstraint("name"),

View File

@ -39,16 +39,39 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"CssTemplate": ("can_read", "can_write",)} NEW_PVMS = {
"CssTemplate": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("CssTemplateModelView", "can_list"): (Pvm("CssTemplate", "can_read"),), Pvm("CssTemplateModelView", "can_list"): (Pvm("CssTemplate", "can_read"),),
Pvm("CssTemplateModelView", "can_show"): (Pvm("CssTemplate", "can_read"),), Pvm("CssTemplateModelView", "can_show"): (Pvm("CssTemplate", "can_read"),),
Pvm("CssTemplateModelView", "can_add",): (Pvm("CssTemplate", "can_write"),), Pvm(
Pvm("CssTemplateModelView", "can_edit",): (Pvm("CssTemplate", "can_write"),), "CssTemplateModelView",
Pvm("CssTemplateModelView", "can_delete",): (Pvm("CssTemplate", "can_write"),), "can_add",
Pvm("CssTemplateModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),), ): (Pvm("CssTemplate", "can_write"),),
Pvm("CssTemplateAsyncModelView", "can_list",): (Pvm("CssTemplate", "can_read"),), Pvm(
Pvm("CssTemplateAsyncModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),), "CssTemplateModelView",
"can_edit",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"can_delete",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"muldelete",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateAsyncModelView",
"can_list",
): (Pvm("CssTemplate", "can_read"),),
Pvm(
"CssTemplateAsyncModelView",
"muldelete",
): (Pvm("CssTemplate", "can_write"),),
} }

View File

@ -65,7 +65,10 @@ def upgrade():
with op.batch_alter_table("saved_query") as batch_op: with op.batch_alter_table("saved_query") as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4, "uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
), ),
) )
except OperationalError: except OperationalError:

View File

@ -159,7 +159,10 @@ def upgrade():
for key_to_remove in keys_to_remove: for key_to_remove in keys_to_remove:
del position_dict[key_to_remove] del position_dict[key_to_remove]
dashboard.position_json = json.dumps( dashboard.position_json = json.dumps(
position_dict, indent=None, separators=(",", ":"), sort_keys=True, position_dict,
indent=None,
separators=(",", ":"),
sort_keys=True,
) )
session.merge(dashboard) session.merge(dashboard)

View File

@ -39,7 +39,12 @@ def upgrade():
with op.batch_alter_table("report_schedule") as batch_op: with op.batch_alter_table("report_schedule") as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column("extra", sa.Text(), nullable=True, default="{}",), sa.Column(
"extra",
sa.Text(),
nullable=True,
default="{}",
),
) )
bind.execute(report_schedule.update().values({"extra": "{}"})) bind.execute(report_schedule.update().values({"extra": "{}"}))
with op.batch_alter_table("report_schedule") as batch_op: with op.batch_alter_table("report_schedule") as batch_op:

View File

@ -118,7 +118,8 @@ def upgrade():
sa.Column("validator_config", sa.Text(), default="", nullable=True), sa.Column("validator_config", sa.Text(), default="", nullable=True),
) )
op.add_column( op.add_column(
"alerts", sa.Column("database_id", sa.Integer(), default=0, nullable=False), "alerts",
sa.Column("database_id", sa.Integer(), default=0, nullable=False),
) )
op.add_column("alerts", sa.Column("sql", sa.Text(), default="", nullable=False)) op.add_column("alerts", sa.Column("sql", sa.Text(), default="", nullable=False))
op.add_column( op.add_column(
@ -159,7 +160,10 @@ def upgrade():
sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("value", sa.Float(), nullable=True), sa.Column("value", sa.Float(), nullable=True),
sa.Column("error_msg", sa.String(length=500), nullable=True), sa.Column("error_msg", sa.String(length=500), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
else: else:
@ -192,7 +196,11 @@ def downgrade():
sa.Column("created_on", sa.DateTime(), nullable=True), sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True), sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("validator_type", sa.String(length=100), nullable=False,), sa.Column(
"validator_type",
sa.String(length=100),
nullable=False,
),
sa.Column("config", sa.Text(), nullable=True), sa.Column("config", sa.Text(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), autoincrement=False, nullable=True), sa.Column("created_by_fk", sa.Integer(), autoincrement=False, nullable=True),
sa.Column("changed_by_fk", sa.Integer(), autoincrement=False, nullable=True), sa.Column("changed_by_fk", sa.Integer(), autoincrement=False, nullable=True),
@ -261,10 +269,22 @@ def downgrade():
sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("created_on", sa.DateTime(), nullable=True), sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("slack_channel", sa.Text(), nullable=True), sa.Column("slack_channel", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],), ["dashboard_id"],
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), ["dashboards.id"],
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), ),
sa.ForeignKeyConstraint(
["slice_id"],
["slices.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
else: else:

View File

@ -42,6 +42,6 @@ def upgrade():
def downgrade(): def downgrade():
try: try:
# Trying since sqlite doesn't like constraints # Trying since sqlite doesn't like constraints
op.drop_constraint(u"_customer_location_uc", "tables", type_="unique") op.drop_constraint("_customer_location_uc", "tables", type_="unique")
except Exception: except Exception:
pass pass
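Beyond layout, black 22.x also normalizes string prefixes, which is why the redundant Python 2 era `u` prefix disappears in the hunk above. A one-line sketch, assuming black 22.3.0:

    import black

    # The redundant u prefix is dropped; the string itself is untouched.
    print(black.format_str('name = u"_customer_location_uc"\n', mode=black.Mode()))
    # -> name = "_customer_location_uc"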

View File

@ -171,7 +171,10 @@ def upgrade():
with op.batch_alter_table(table_name) as batch_op: with op.batch_alter_table(table_name) as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4, "uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
), ),
) )

View File

@ -36,7 +36,8 @@ def upgrade():
kwargs: Dict[str, str] = {} kwargs: Dict[str, str] = {}
bind = op.get_bind() bind = op.get_bind()
op.add_column( op.add_column(
"dbs", sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs), "dbs",
sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs),
) )

View File

@ -379,16 +379,46 @@ def upgrade():
sa.Column("name", sa.TEXT(), nullable=False), sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("type", sa.TEXT(), nullable=False), sa.Column("type", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False), sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=True,), sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column("description", sa.TEXT(), nullable=True), sa.Column("description", sa.TEXT(), nullable=True),
sa.Column("warning_text", sa.TEXT(), nullable=True), sa.Column("warning_text", sa.TEXT(), nullable=True),
sa.Column("unit", sa.TEXT(), nullable=True), sa.Column("unit", sa.TEXT(), nullable=True),
sa.Column("is_temporal", sa.BOOLEAN(), nullable=False), sa.Column("is_temporal", sa.BOOLEAN(), nullable=False),
sa.Column("is_spatial", sa.BOOLEAN(), nullable=False, default=False,), sa.Column(
sa.Column("is_partition", sa.BOOLEAN(), nullable=False, default=False,), "is_spatial",
sa.Column("is_aggregation", sa.BOOLEAN(), nullable=False, default=False,), sa.BOOLEAN(),
sa.Column("is_additive", sa.BOOLEAN(), nullable=False, default=False,), nullable=False,
sa.Column("is_increase_desired", sa.BOOLEAN(), nullable=False, default=True,), default=False,
),
sa.Column(
"is_partition",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_aggregation",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_additive",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_increase_desired",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column( sa.Column(
"is_managed_externally", "is_managed_externally",
sa.Boolean(), sa.Boolean(),
@ -459,7 +489,12 @@ def upgrade():
sa.Column("sqlatable_id", sa.INTEGER(), nullable=True), sa.Column("sqlatable_id", sa.INTEGER(), nullable=True),
sa.Column("name", sa.TEXT(), nullable=False), sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False), sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=False,), sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column( sa.Column(
"is_managed_externally", "is_managed_externally",
sa.Boolean(), sa.Boolean(),

View File

@ -39,19 +39,51 @@ revision = "c25cb2c78727"
down_revision = "ccb74baaa89b" down_revision = "ccb74baaa89b"
NEW_PVMS = {"Annotation": ("can_read", "can_write",)} NEW_PVMS = {
"Annotation": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("AnnotationLayerModelView", "can_delete"): (Pvm("Annotation", "can_write"),), Pvm("AnnotationLayerModelView", "can_delete"): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationLayerModelView", "can_list"): (Pvm("Annotation", "can_read"),), Pvm("AnnotationLayerModelView", "can_list"): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationLayerModelView", "can_show",): (Pvm("Annotation", "can_read"),), Pvm(
Pvm("AnnotationLayerModelView", "can_add",): (Pvm("Annotation", "can_write"),), "AnnotationLayerModelView",
Pvm("AnnotationLayerModelView", "can_edit",): (Pvm("Annotation", "can_write"),), "can_show",
Pvm("AnnotationModelView", "can_annotation",): (Pvm("Annotation", "can_read"),), ): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationModelView", "can_show",): (Pvm("Annotation", "can_read"),), Pvm(
Pvm("AnnotationModelView", "can_add",): (Pvm("Annotation", "can_write"),), "AnnotationLayerModelView",
Pvm("AnnotationModelView", "can_delete",): (Pvm("Annotation", "can_write"),), "can_add",
Pvm("AnnotationModelView", "can_edit",): (Pvm("Annotation", "can_write"),), ): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationModelView", "can_list",): (Pvm("Annotation", "can_read"),), Pvm(
"AnnotationLayerModelView",
"can_edit",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_annotation",
): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationModelView",
"can_show",
): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationModelView",
"can_add",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_delete",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_edit",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_list",
): (Pvm("Annotation", "can_read"),),
} }

View File

@ -67,7 +67,10 @@ def upgrade():
with op.batch_alter_table(table_name) as batch_op: with op.batch_alter_table(table_name) as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4, "uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
), ),
) )
add_uuids(model, table_name, session) add_uuids(model, table_name, session)

View File

@ -52,7 +52,10 @@ class AuditMixinNullable(AuditMixin):
@declared_attr @declared_attr
def created_by_fk(self) -> Column: def created_by_fk(self) -> Column:
return Column( return Column(
Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True, Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
nullable=True,
) )
@declared_attr @declared_attr

View File

@ -80,7 +80,8 @@ def upgrade():
if isinstance(bind.dialect, MySQLDialect): if isinstance(bind.dialect, MySQLDialect):
op.drop_index( op.drop_index(
op.f("name"), table_name="report_schedule", op.f("name"),
table_name="report_schedule",
) )
if isinstance(bind.dialect, PGDialect): if isinstance(bind.dialect, PGDialect):

View File

@ -39,22 +39,63 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"Chart": ("can_read", "can_write",)} NEW_PVMS = {
"Chart": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("SliceModelView", "can_list"): (Pvm("Chart", "can_read"),), Pvm("SliceModelView", "can_list"): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_show"): (Pvm("Chart", "can_read"),), Pvm("SliceModelView", "can_show"): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_edit",): (Pvm("Chart", "can_write"),), Pvm(
Pvm("SliceModelView", "can_delete",): (Pvm("Chart", "can_write"),), "SliceModelView",
Pvm("SliceModelView", "can_add",): (Pvm("Chart", "can_write"),), "can_edit",
Pvm("SliceModelView", "can_download",): (Pvm("Chart", "can_read"),), ): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "muldelete",): (Pvm("Chart", "can_write"),), Pvm(
Pvm("SliceModelView", "can_mulexport",): (Pvm("Chart", "can_read"),), "SliceModelView",
Pvm("SliceModelView", "can_favorite_status",): (Pvm("Chart", "can_read"),), "can_delete",
Pvm("SliceModelView", "can_cache_screenshot",): (Pvm("Chart", "can_read"),), ): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "can_screenshot",): (Pvm("Chart", "can_read"),), Pvm(
Pvm("SliceModelView", "can_data_from_cache",): (Pvm("Chart", "can_read"),), "SliceModelView",
Pvm("SliceAsync", "can_list",): (Pvm("Chart", "can_read"),), "can_add",
Pvm("SliceAsync", "muldelete",): (Pvm("Chart", "can_write"),), ): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_download",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"muldelete",
): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_mulexport",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_favorite_status",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_cache_screenshot",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_screenshot",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_data_from_cache",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceAsync",
"can_list",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceAsync",
"muldelete",
): (Pvm("Chart", "can_write"),),
} }

View File

@ -33,7 +33,11 @@ from alembic import op
def upgrade(): def upgrade():
with op.batch_alter_table("query") as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column("limiting_factor", sa.VARCHAR(255), server_default="UNKNOWN",) sa.Column(
"limiting_factor",
sa.VARCHAR(255),
server_default="UNKNOWN",
)
) )

View File

@ -39,20 +39,55 @@ from superset.migrations.shared.security_converge import (
Pvm, Pvm,
) )
NEW_PVMS = {"SavedQuery": ("can_read", "can_write",)} NEW_PVMS = {
"SavedQuery": (
"can_read",
"can_write",
)
}
PVM_MAP = { PVM_MAP = {
Pvm("SavedQueryView", "can_list"): (Pvm("SavedQuery", "can_read"),), Pvm("SavedQueryView", "can_list"): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryView", "can_show"): (Pvm("SavedQuery", "can_read"),), Pvm("SavedQueryView", "can_show"): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryView", "can_add",): (Pvm("SavedQuery", "can_write"),), Pvm(
Pvm("SavedQueryView", "can_edit",): (Pvm("SavedQuery", "can_write"),), "SavedQueryView",
Pvm("SavedQueryView", "can_delete",): (Pvm("SavedQuery", "can_write"),), "can_add",
Pvm("SavedQueryView", "muldelete",): (Pvm("SavedQuery", "can_write"),), ): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryView", "can_mulexport",): (Pvm("SavedQuery", "can_read"),), Pvm(
Pvm("SavedQueryViewApi", "can_show",): (Pvm("SavedQuery", "can_read"),), "SavedQueryView",
Pvm("SavedQueryViewApi", "can_edit",): (Pvm("SavedQuery", "can_write"),), "can_edit",
Pvm("SavedQueryViewApi", "can_list",): (Pvm("SavedQuery", "can_read"),), ): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryViewApi", "can_add",): (Pvm("SavedQuery", "can_write"),), Pvm(
Pvm("SavedQueryViewApi", "muldelete",): (Pvm("SavedQuery", "can_write"),), "SavedQueryView",
"can_delete",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"muldelete",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"can_mulexport",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_show",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_edit",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryViewApi",
"can_list",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_add",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryViewApi",
"muldelete",
): (Pvm("SavedQuery", "can_write"),),
} }

View File

@ -53,6 +53,8 @@ def upgrade():
def downgrade(): def downgrade():
with op.batch_alter_table("row_level_security_filters") as batch_op: with op.batch_alter_table("row_level_security_filters") as batch_op:
batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),) batch_op.drop_index(
op.f("ix_row_level_security_filters_filter_type"),
)
batch_op.drop_column("filter_type") batch_op.drop_column("filter_type")
batch_op.drop_column("group_key") batch_op.drop_column("group_key")

View File

@ -322,7 +322,9 @@ class Database(
self.sqlalchemy_uri = str(conn) # hides the password self.sqlalchemy_uri = str(conn) # hides the password
def get_effective_user( def get_effective_user(
self, object_url: URL, user_name: Optional[str] = None, self,
object_url: URL,
user_name: Optional[str] = None,
) -> Optional[str]: ) -> Optional[str]:
""" """
Get the effective user, especially during impersonation. Get the effective user, especially during impersonation.

View File

@ -344,7 +344,8 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
@debounce(0.1) @debounce(0.1)
def clear_cache_for_datasource(cls, datasource_id: int) -> None: def clear_cache_for_datasource(cls, datasource_id: int) -> None:
filter_query = select( filter_query = select(
[dashboard_slices.c.dashboard_id], distinct=True, [dashboard_slices.c.dashboard_id],
distinct=True,
).select_from( ).select_from(
join( join(
dashboard_slices, dashboard_slices,

View File

@ -18,7 +18,11 @@ from marshmallow import fields, Schema
from marshmallow.validate import Length from marshmallow.validate import Length
openapi_spec_methods_override = { openapi_spec_methods_override = {
"get": {"get": {"description": "Get a saved query",}}, "get": {
"get": {
"description": "Get a saved query",
}
},
"get_list": { "get_list": {
"get": { "get": {
"description": "Get a list of saved queries, use Rison or JSON " "description": "Get a list of saved queries, use Rison or JSON "

View File

@ -47,7 +47,7 @@ class BaseReportScheduleCommand(BaseCommand):
def validate_chart_dashboard( def validate_chart_dashboard(
self, exceptions: List[ValidationError], update: bool = False self, exceptions: List[ValidationError], update: bool = False
) -> None: ) -> None:
""" Validate chart or dashboard relation """ """Validate chart or dashboard relation"""
chart_id = self._properties.get("chart") chart_id = self._properties.get("chart")
dashboard_id = self._properties.get("dashboard") dashboard_id = self._properties.get("dashboard")
creation_method = self._properties.get("creation_method") creation_method = self._properties.get("creation_method")
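The docstring change above is another newer normalization: black strips the padding spaces just inside the triple quotes. A sketch, assuming black 22.3.0:

    import black

    src = 'def validate():\n    """ Validate chart or dashboard relation """\n'
    # Leading/trailing spaces inside the docstring quotes are removed.
    print(black.format_str(src, mode=black.Mode()))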

View File

@ -95,7 +95,9 @@ class BaseReportState:
self._execution_id = execution_id self._execution_id = execution_id
def set_state_and_log( def set_state_and_log(
self, state: ReportState, error_message: Optional[str] = None, self,
state: ReportState,
error_message: Optional[str] = None,
) -> None: ) -> None:
""" """
Updates current ReportSchedule state and TS. If on final state writes the log Updates current ReportSchedule state and TS. If on final state writes the log
@ -104,7 +106,8 @@ class BaseReportState:
now_dttm = datetime.utcnow() now_dttm = datetime.utcnow()
self.set_state(state, now_dttm) self.set_state(state, now_dttm)
self.create_log( self.create_log(
state, error_message=error_message, state,
error_message=error_message,
) )
def set_state(self, state: ReportState, dttm: datetime) -> None: def set_state(self, state: ReportState, dttm: datetime) -> None:
@ -531,12 +534,14 @@ class ReportWorkingState(BaseReportState):
if self.is_on_working_timeout(): if self.is_on_working_timeout():
exception_timeout = ReportScheduleWorkingTimeoutError() exception_timeout = ReportScheduleWorkingTimeoutError()
self.set_state_and_log( self.set_state_and_log(
ReportState.ERROR, error_message=str(exception_timeout), ReportState.ERROR,
error_message=str(exception_timeout),
) )
raise exception_timeout raise exception_timeout
exception_working = ReportSchedulePreviousWorkingError() exception_working = ReportSchedulePreviousWorkingError()
self.set_state_and_log( self.set_state_and_log(
ReportState.WORKING, error_message=str(exception_working), ReportState.WORKING,
error_message=str(exception_working),
) )
raise exception_working raise exception_working

View File

@ -227,7 +227,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod @staticmethod
def find_last_success_log( def find_last_success_log(
report_schedule: ReportSchedule, session: Optional[Session] = None, report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]: ) -> Optional[ReportExecutionLog]:
""" """
Finds last success execution log for a given report Finds last success execution log for a given report
@ -245,7 +246,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod @staticmethod
def find_last_entered_working_log( def find_last_entered_working_log(
report_schedule: ReportSchedule, session: Optional[Session] = None, report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]: ) -> Optional[ReportExecutionLog]:
""" """
Finds last success execution log for a given report Finds last success execution log for a given report
@ -264,7 +266,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod @staticmethod
def find_last_error_notification( def find_last_error_notification(
report_schedule: ReportSchedule, session: Optional[Session] = None, report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]: ) -> Optional[ReportExecutionLog]:
""" """
Finds last error email sent Finds last error email sent

View File

@ -203,7 +203,9 @@ class ReportSchedulePostSchema(Schema):
default=ReportDataFormat.VISUALIZATION, default=ReportDataFormat.VISUALIZATION,
validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)), validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)),
) )
extra = fields.Dict(default=None,) extra = fields.Dict(
default=None,
)
force_screenshot = fields.Boolean(default=False) force_screenshot = fields.Boolean(default=False)
@validates_schema @validates_schema

View File

@ -1307,7 +1307,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
@staticmethod @staticmethod
def _get_current_epoch_time() -> float: def _get_current_epoch_time() -> float:
""" This is used so the tests can mock time """ """This is used so the tests can mock time"""
return time.time() return time.time()
@staticmethod @staticmethod
@ -1376,7 +1376,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: def get_guest_user_from_token(self, token: GuestToken) -> GuestUser:
return self.guest_user_cls( return self.guest_user_cls(
token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], token=token,
roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])],
) )
def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]:

View File

@ -170,7 +170,10 @@ class ExecuteSqlCommand(BaseCommand):
except Exception as ex: except Exception as ex:
raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex
def _set_query_limit_if_required(self, rendered_query: str,) -> None: def _set_query_limit_if_required(
self,
rendered_query: str,
) -> None:
if self._is_required_to_set_limit(): if self._is_required_to_set_limit():
self._set_query_limit(rendered_query) self._set_query_limit(rendered_query)

Some files were not shown because too many files have changed in this diff.