chore: upgrade black (#19410)

This commit is contained in:
Ville Brofeldt 2022-03-29 20:03:09 +03:00 committed by GitHub
parent 816a2c3e1e
commit a619cb4ea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
204 changed files with 2125 additions and 608 deletions

View File

@ -41,7 +41,7 @@ repos:
- id: trailing-whitespace
args: ["--markdown-linebreak-ext=md"]
- repo: https://github.com/psf/black
rev: 19.10b0
rev: 22.3.0
hooks:
- id: black
language_version: python3

View File

@ -167,7 +167,10 @@ class GitChangeLog:
return f"### {self._version} ({self._logs[0].time})"
def _parse_change_log(
self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str,
self,
changelog: Dict[str, str],
pr_info: Dict[str, str],
github_login: str,
) -> None:
formatted_pr = (
f"- [#{pr_info.get('id')}]"
@ -355,7 +358,8 @@ def compare(base_parameters: BaseParameters) -> None:
@cli.command("changelog")
@click.option(
"--csv", help="The csv filename to export the changelog to",
"--csv",
help="The csv filename to export the changelog to",
)
@click.option(
"--access_token",

View File

@ -106,7 +106,12 @@ def inter_send_email(
class BaseParameters(object):
def __init__(
self, email: str, username: str, password: str, version: str, version_rc: str,
self,
email: str,
username: str,
password: str,
version: str,
version_rc: str,
) -> None:
self.email = email
self.username = username

View File

@ -60,7 +60,8 @@ def request(
def list_runs(
repo: str, params: Optional[Dict[str, str]] = None,
repo: str,
params: Optional[Dict[str, str]] = None,
) -> Iterator[Dict[str, Any]]:
"""List all github workflow runs.
Returns:
@ -193,7 +194,11 @@ def cancel_github_workflows(
if branch and ":" in branch:
[user, branch] = branch.split(":", 2)
runs = get_runs(
repo, branch=branch, user=user, statuses=statuses, events=events,
repo,
branch=branch,
user=user,
statuses=statuses,
events=events,
)
# sort old jobs to the front, so to cancel older jobs first

View File

@ -73,7 +73,9 @@ class UpdateAnnotationCommand(BaseCommand):
# Validate short descr uniqueness on this layer
if not AnnotationDAO.validate_update_uniqueness(
layer_id, short_descr, annotation_id=self._model_id,
layer_id,
short_descr,
annotation_id=self._model_id,
):
exceptions.append(AnnotationUniquenessValidationError())
else:

View File

@ -64,13 +64,17 @@ class AnnotationPostSchema(Schema):
)
long_descr = fields.String(description=annotation_long_descr, allow_none=True)
start_dttm = fields.DateTime(
description=annotation_start_dttm, required=True, allow_none=False,
description=annotation_start_dttm,
required=True,
allow_none=False,
)
end_dttm = fields.DateTime(
description=annotation_end_dttm, required=True, allow_none=False
)
json_metadata = fields.String(
description=annotation_json_metadata, validate=validate_json, allow_none=True,
description=annotation_json_metadata,
validate=validate_json,
allow_none=True,
)

View File

@ -110,9 +110,11 @@ class CacheRestApi(BaseSupersetModelRestApi):
)
try:
delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member
delete_stmt = (
CacheKey.__table__.delete().where( # pylint: disable=no-member
CacheKey.cache_key.in_(cache_keys)
)
)
db.session.execute(delete_stmt)
db.session.commit()
self.stats_logger.gauge("invalidated_cache", len(cache_keys))

View File

@ -25,9 +25,15 @@ from superset.charts.schemas import (
class Datasource(Schema):
database_name = fields.String(description="Datasource name",)
datasource_name = fields.String(description=datasource_name_description,)
schema = fields.String(description="Datasource schema",)
database_name = fields.String(
description="Datasource name",
)
datasource_name = fields.String(
description=datasource_name_description,
)
schema = fields.String(
description="Datasource schema",
)
datasource_type = fields.String(
description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")),
@ -37,7 +43,8 @@ class Datasource(Schema):
class CacheInvalidationRequestSchema(Schema):
datasource_uids = fields.List(
fields.String(), description=datasource_uid_description,
fields.String(),
description=datasource_uid_description,
)
datasources = fields.List(
fields.Nested(Datasource),

View File

@ -279,7 +279,8 @@ class ChartCacheScreenshotResponseSchema(Schema):
class ChartDataColumnSchema(Schema):
column_name = fields.String(
description="The name of the target column", example="mycol",
description="The name of the target column",
example="mycol",
)
type = fields.String(description="Type of target column", example="BIGINT")
@ -325,7 +326,8 @@ class ChartDataAdhocMetricSchema(Schema):
example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30",
)
timeGrain = fields.String(
description="Optional time grain for temporal filters", example="PT1M",
description="Optional time grain for temporal filters",
example="PT1M",
)
isExtra = fields.Boolean(
description="Indicates if the filter has been added by a filter component as "
@ -370,7 +372,8 @@ class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSch
groupby = (
fields.List(
fields.String(
allow_none=False, description="Columns by which to group by",
allow_none=False,
description="Columns by which to group by",
),
minLength=1,
required=True,
@ -425,7 +428,9 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
example="percentile",
)
window = fields.Integer(
description="Size of the rolling window in days.", required=True, example=7,
description="Size of the rolling window in days.",
required=True,
example=7,
)
rolling_type_options = fields.Dict(
desctiption="Optional options to pass to rolling method. Needed for "
@ -592,7 +597,9 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
"""
groupby = fields.List(
fields.String(description="Columns by which to group the query.",),
fields.String(
description="Columns by which to group the query.",
),
allow_none=True,
)
@ -699,13 +706,16 @@ class ChartDataGeohashDecodeOptionsSchema(
"""
geohash = fields.String(
description="Name of source column containing geohash string", required=True,
description="Name of source column containing geohash string",
required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
description="Name of target column for decoded latitude",
required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
description="Name of target column for decoded longitude",
required=True,
)
@ -717,13 +727,16 @@ class ChartDataGeohashEncodeOptionsSchema(
"""
latitude = fields.String(
description="Name of source latitude column", required=True,
description="Name of source latitude column",
required=True,
)
longitude = fields.String(
description="Name of source longitude column", required=True,
description="Name of source longitude column",
required=True,
)
geohash = fields.String(
description="Name of target column for encoded geohash string", required=True,
description="Name of target column for encoded geohash string",
required=True,
)
@ -739,10 +752,12 @@ class ChartDataGeodeticParseOptionsSchema(
required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
description="Name of target column for decoded latitude",
required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
description="Name of target column for decoded longitude",
required=True,
)
altitude = fields.String(
description="Name of target column for decoded altitude. If omitted, "
@ -789,7 +804,10 @@ class ChartDataPostProcessingOperationSchema(Schema):
"column": "age",
"options": {"q": 0.25},
},
"age_mean": {"operator": "mean", "column": "age",},
"age_mean": {
"operator": "mean",
"column": "age",
},
},
},
)
@ -816,7 +834,8 @@ class ChartDataFilterSchema(Schema):
example=["China", "France", "Japan"],
)
grain = fields.String(
description="Optional time grain for temporal filters", example="PT1M",
description="Optional time grain for temporal filters",
example="PT1M",
)
isExtra = fields.Boolean(
description="Indicates if the filter has been added by a filter component as "
@ -873,7 +892,10 @@ class AnnotationLayerSchema(Schema):
description="Type of annotation layer",
validate=validate.OneOf(choices=[ann.value for ann in AnnotationType]),
)
color = fields.String(description="Layer color", allow_none=True,)
color = fields.String(
description="Layer color",
allow_none=True,
)
descriptionColumns = fields.List(
fields.String(),
description="Columns to use as the description. If none are provided, "
@ -911,7 +933,8 @@ class AnnotationLayerSchema(Schema):
)
show = fields.Boolean(description="Should the layer be shown", required=True)
showLabel = fields.Boolean(
description="Should the label always be shown", allow_none=True,
description="Should the label always be shown",
allow_none=True,
)
showMarkers = fields.Boolean(
description="Should markers be shown. Only applies to line annotations.",
@ -919,16 +942,34 @@ class AnnotationLayerSchema(Schema):
)
sourceType = fields.String(
description="Type of source for annotation data",
validate=validate.OneOf(choices=("", "line", "NATIVE", "table",)),
validate=validate.OneOf(
choices=(
"",
"line",
"NATIVE",
"table",
)
),
)
style = fields.String(
description="Line style. Only applies to time-series annotations",
validate=validate.OneOf(choices=("dashed", "dotted", "solid", "longDashed",)),
validate=validate.OneOf(
choices=(
"dashed",
"dotted",
"solid",
"longDashed",
)
),
)
timeColumn = fields.String(
description="Column with event date or interval start date", allow_none=True,
description="Column with event date or interval start date",
allow_none=True,
)
titleColumn = fields.String(
description="Column with title",
allow_none=True,
)
titleColumn = fields.String(description="Column with title", allow_none=True,)
width = fields.Float(
description="Width of annotation line",
validate=[
@ -948,7 +989,10 @@ class AnnotationLayerSchema(Schema):
class ChartDataDatasourceSchema(Schema):
description = "Chart datasource"
id = fields.Integer(description="Datasource id", required=True,)
id = fields.Integer(
description="Datasource id",
required=True,
)
type = fields.String(
description="Datasource type",
validate=validate.OneOf(choices=("druid", "table")),
@ -1039,7 +1083,8 @@ class ChartDataQueryObjectSchema(Schema):
allow_none=True,
)
is_timeseries = fields.Boolean(
description="Is the `query_object` a timeseries.", allow_none=True,
description="Is the `query_object` a timeseries.",
allow_none=True,
)
series_columns = fields.List(
fields.Raw(),
@ -1084,7 +1129,8 @@ class ChartDataQueryObjectSchema(Schema):
],
)
order_desc = fields.Boolean(
description="Reverse order. Default: `false`", allow_none=True,
description="Reverse order. Default: `false`",
allow_none=True,
)
extras = fields.Nested(
ChartDataExtrasSchema,
@ -1151,7 +1197,10 @@ class ChartDataQueryObjectSchema(Schema):
description="Should the rowcount of the actual query be returned",
allow_none=True,
)
time_offsets = fields.List(fields.String(), allow_none=True,)
time_offsets = fields.List(
fields.String(),
allow_none=True,
)
class ChartDataQueryContextSchema(Schema):
@ -1190,7 +1239,9 @@ class AnnotationDataSchema(Schema):
required=True,
)
records = fields.List(
fields.Dict(keys=fields.String(),),
fields.Dict(
keys=fields.String(),
),
description="records mapping the column name to it's value",
required=True,
)
@ -1206,10 +1257,14 @@ class ChartDataResponseResult(Schema):
allow_none=True,
)
cache_key = fields.String(
description="Unique cache key for query object", required=True, allow_none=True,
description="Unique cache key for query object",
required=True,
allow_none=True,
)
cached_dttm = fields.String(
description="Cache timestamp", required=True, allow_none=True,
description="Cache timestamp",
required=True,
allow_none=True,
)
cache_timeout = fields.Integer(
description="Cache timeout in following order: custom timeout, datasource "
@ -1217,12 +1272,19 @@ class ChartDataResponseResult(Schema):
required=True,
allow_none=True,
)
error = fields.String(description="Error", allow_none=True,)
error = fields.String(
description="Error",
allow_none=True,
)
is_cached = fields.Boolean(
description="Is the result cached", required=True, allow_none=None,
description="Is the result cached",
required=True,
allow_none=None,
)
query = fields.String(
description="The executed query statement", required=True, allow_none=False,
description="The executed query statement",
required=True,
allow_none=False,
)
status = fields.String(
description="Status of the query",
@ -1240,10 +1302,12 @@ class ChartDataResponseResult(Schema):
allow_none=False,
)
stacktrace = fields.String(
desciption="Stacktrace if there was an error", allow_none=True,
desciption="Stacktrace if there was an error",
allow_none=True,
)
rowcount = fields.Integer(
description="Amount of rows in result set", allow_none=False,
description="Amount of rows in result set",
allow_none=False,
)
data = fields.List(fields.Dict(), description="A list with results")
colnames = fields.List(fields.String(), description="A list of column names")
@ -1273,13 +1337,24 @@ class ChartDataResponseSchema(Schema):
class ChartDataAsyncResponseSchema(Schema):
channel_id = fields.String(
description="Unique session async channel ID", allow_none=False,
description="Unique session async channel ID",
allow_none=False,
)
job_id = fields.String(
description="Unique async job ID",
allow_none=False,
)
user_id = fields.String(
description="Requesting user ID",
allow_none=True,
)
status = fields.String(
description="Status value for async job",
allow_none=False,
)
job_id = fields.String(description="Unique async job ID", allow_none=False,)
user_id = fields.String(description="Requesting user ID", allow_none=True,)
status = fields.String(description="Status value for async job", allow_none=False,)
result_url = fields.String(
description="Unique result URL for fetching async query data", allow_none=False,
description="Unique result URL for fetching async query data",
allow_none=False,
)

View File

@ -93,10 +93,16 @@ def load_examples_run(
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
@click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data")
@click.option(
"--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data",
"--only-metadata",
"-m",
is_flag=True,
help="Only load metadata, skip actual data",
)
@click.option(
"--force", "-f", is_flag=True, help="Force load data even if table already exists",
"--force",
"-f",
is_flag=True,
help="Force load data even if table already exists",
)
def load_examples(
load_test_data: bool,

View File

@ -36,10 +36,16 @@ logger = logging.getLogger(__name__)
@click.command()
@click.argument("directory")
@click.option(
"--overwrite", "-o", is_flag=True, help="Overwriting existing metadata definitions",
"--overwrite",
"-o",
is_flag=True,
help="Overwriting existing metadata definitions",
)
@click.option(
"--force", "-f", is_flag=True, help="Force load data even if table already exists",
"--force",
"-f",
is_flag=True,
help="Force load data even if table already exists",
)
def import_directory(directory: str, overwrite: bool, force: bool) -> None:
"""Imports configs from a given directory"""
@ -47,7 +53,9 @@ def import_directory(directory: str, overwrite: bool, force: bool) -> None:
from superset.examples.utils import load_configs_from_directory
load_configs_from_directory(
root=Path(directory), overwrite=overwrite, force_data=force,
root=Path(directory),
overwrite=overwrite,
force_data=force,
)
@ -56,7 +64,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
@click.command()
@with_appcontext
@click.option(
"--dashboard-file", "-f", help="Specify the the file to export to",
"--dashboard-file",
"-f",
help="Specify the the file to export to",
)
def export_dashboards(dashboard_file: Optional[str] = None) -> None:
"""Export dashboards to ZIP file"""
@ -90,7 +100,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
@click.command()
@with_appcontext
@click.option(
"--datasource-file", "-f", help="Specify the the file to export to",
"--datasource-file",
"-f",
help="Specify the the file to export to",
)
def export_datasources(datasource_file: Optional[str] = None) -> None:
"""Export datasources to ZIP file"""
@ -122,7 +134,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
@click.command()
@with_appcontext
@click.option(
"--path", "-p", help="Path to a single ZIP file",
"--path",
"-p",
help="Path to a single ZIP file",
)
@click.option(
"--username",
@ -160,7 +174,9 @@ if feature_flags.get("VERSIONED_EXPORT"):
@click.command()
@with_appcontext
@click.option(
"--path", "-p", help="Path to a single ZIP file",
"--path",
"-p",
help="Path to a single ZIP file",
)
def import_datasources(path: str) -> None:
"""Import datasources from ZIP file"""
@ -185,7 +201,6 @@ if feature_flags.get("VERSIONED_EXPORT"):
)
sys.exit(1)
else:
@click.command()

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
@click.group(
cls=FlaskGroup, context_settings={"token_normalize_func": normalize_token},
cls=FlaskGroup,
context_settings={"token_normalize_func": normalize_token},
)
@with_appcontext
def superset() -> None:

View File

@ -44,7 +44,11 @@ logger = logging.getLogger(__name__)
help="Only process dashboards",
)
@click.option(
"--charts_only", "-c", is_flag=True, default=False, help="Only process charts",
"--charts_only",
"-c",
is_flag=True,
default=False,
help="Only process charts",
)
@click.option(
"--force",

View File

@ -35,7 +35,10 @@ from superset.models.helpers import (
class Column(
Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin,
Model,
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
):
"""
A "column".

View File

@ -79,7 +79,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
def validate_metadata_type(
metadata: Optional[Dict[str, str]], type_: str, exceptions: List[ValidationError],
metadata: Optional[Dict[str, str]],
type_: str,
exceptions: List[ValidationError],
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:

View File

@ -34,7 +34,9 @@ if TYPE_CHECKING:
def populate_owners(
user: User, owner_ids: Optional[List[int]], default_to_user: bool,
user: User,
owner_ids: Optional[List[int]],
default_to_user: bool,
) -> List[User]:
"""
Helper function for commands, will fetch all users from owners id's

View File

@ -79,7 +79,9 @@ def _get_timegrains(
def _get_query(
query_context: "QueryContext", query_obj: "QueryObject", _: bool,
query_context: "QueryContext",
query_obj: "QueryObject",
_: bool,
) -> Dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result = {"language": datasource.query_language}

View File

@ -69,7 +69,7 @@ class QueryContext:
result_format: ChartDataResultFormat,
force: bool = False,
custom_cache_timeout: Optional[int] = None,
cache_values: Dict[str, Any]
cache_values: Dict[str, Any],
) -> None:
self.datasource = datasource
self.result_type = result_type
@ -81,11 +81,16 @@ class QueryContext:
self.cache_values = cache_values
self._processor = QueryContextProcessor(self)
def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]:
def get_data(
self,
df: pd.DataFrame,
) -> Union[str, List[Dict[str, Any]]]:
return self._processor.get_data(df)
def get_payload(
self, cache_query_context: Optional[bool] = False, force_cached: bool = False,
self,
cache_query_context: Optional[bool] = False,
force_cached: bool = False,
) -> Dict[str, Any]:
"""Returns the query results with both metadata and data"""
return self._processor.get_payload(cache_query_context, force_cached)
@ -103,7 +108,9 @@ class QueryContext:
return self._processor.query_cache_key(query_obj, **kwargs)
def get_df_payload(
self, query_obj: QueryObject, force_cached: Optional[bool] = False,
self,
query_obj: QueryObject,
force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
return self._processor.get_df_payload(query_obj, force_cached)
@ -111,7 +118,9 @@ class QueryContext:
return self._processor.get_query_result(query_object)
def processing_time_offsets(
self, df: pd.DataFrame, query_object: QueryObject,
self,
df: pd.DataFrame,
query_object: QueryObject,
) -> CachedTimeOffset:
return self._processor.processing_time_offsets(df, query_object)

View File

@ -50,7 +50,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
force: bool = False,
custom_cache_timeout: Optional[int] = None
custom_cache_timeout: Optional[int] = None,
) -> QueryContext:
datasource_model_instance = None
if datasource:

View File

@ -99,7 +99,10 @@ class QueryContextProcessor:
"""Handles caching around the df payload retrieval"""
cache_key = self.query_cache_key(query_obj)
cache = QueryCacheManager.get(
cache_key, CacheRegion.DATA, self._query_context.force, force_cached,
cache_key,
CacheRegion.DATA,
self._query_context.force,
force_cached,
)
if query_obj and cache_key and not cache.is_loaded:
@ -235,7 +238,9 @@ class QueryContextProcessor:
return df
def processing_time_offsets( # pylint: disable=too-many-locals
self, df: pd.DataFrame, query_object: QueryObject,
self,
df: pd.DataFrame,
query_object: QueryObject,
) -> CachedTimeOffset:
query_context = self._query_context
# ensure query_object is immutable
@ -250,7 +255,8 @@ class QueryContextProcessor:
for offset in time_offsets:
try:
query_object_clone.from_dttm = get_past_or_future(
offset, outer_from_dttm,
offset,
outer_from_dttm,
)
query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)
except ValueError as ex:
@ -322,7 +328,9 @@ class QueryContextProcessor:
# df left join `offset_metrics_df`
offset_df = df_utils.left_join_df(
left_df=df, right_df=offset_metrics_df, join_keys=join_keys,
left_df=df,
right_df=offset_metrics_df,
join_keys=join_keys,
)
offset_slice = offset_df[metrics_mapping.values()]
@ -358,7 +366,9 @@ class QueryContextProcessor:
return df.to_dict(orient="records")
def get_payload(
self, cache_query_context: Optional[bool] = False, force_cached: bool = False,
self,
cache_query_context: Optional[bool] = False,
force_cached: bool = False,
) -> Dict[str, Any]:
"""Returns the query results with both metadata and data"""

View File

@ -341,7 +341,11 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
def __repr__(self) -> str:
# we use `print` or `logging` output QueryObject
return json.dumps(self.to_dict(), sort_keys=True, default=str,)
return json.dumps(
self.to_dict(),
sort_keys=True,
default=str,
)
def cache_key(self, **extra: Any) -> str:
"""

View File

@ -80,7 +80,8 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
)
def _process_extras( # pylint: disable=no-self-use
self, extras: Optional[Dict[str, Any]],
self,
extras: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
extras = extras or {}
return extras

View File

@ -26,7 +26,9 @@ if TYPE_CHECKING:
def left_join_df(
left_df: pd.DataFrame, right_df: pd.DataFrame, join_keys: List[str],
left_df: pd.DataFrame,
right_df: pd.DataFrame,
join_keys: List[str],
) -> pd.DataFrame:
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
df.reset_index(inplace=True)

View File

@ -128,7 +128,6 @@ try:
self.name = name
self.post_aggregator = post_aggregator
except NameError:
pass

View File

@ -62,7 +62,9 @@ class EnsureEnabledMixin:
class DruidColumnInlineView( # pylint: disable=too-many-ancestors
CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView,
CompactCRUDMixin,
EnsureEnabledMixin,
SupersetModelView,
):
datamodel = SQLAInterface(models.DruidColumn)
include_route_methods = RouteMethod.RELATED_VIEW_SET
@ -151,7 +153,9 @@ class DruidColumnInlineView( # pylint: disable=too-many-ancestors
class DruidMetricInlineView( # pylint: disable=too-many-ancestors
CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView,
CompactCRUDMixin,
EnsureEnabledMixin,
SupersetModelView,
):
datamodel = SQLAInterface(models.DruidMetric)
include_route_methods = RouteMethod.RELATED_VIEW_SET
@ -206,7 +210,10 @@ class DruidMetricInlineView( # pylint: disable=too-many-ancestors
class DruidClusterModelView( # pylint: disable=too-many-ancestors
EnsureEnabledMixin, SupersetModelView, DeleteMixin, YamlExportMixin,
EnsureEnabledMixin,
SupersetModelView,
DeleteMixin,
YamlExportMixin,
):
datamodel = SQLAInterface(models.DruidCluster)
include_route_methods = RouteMethod.CRUD_SET
@ -270,7 +277,10 @@ class DruidClusterModelView( # pylint: disable=too-many-ancestors
class DruidDatasourceModelView( # pylint: disable=too-many-ancestors
EnsureEnabledMixin, DatasourceModelView, DeleteMixin, YamlExportMixin,
EnsureEnabledMixin,
DatasourceModelView,
DeleteMixin,
YamlExportMixin,
):
datamodel = SQLAInterface(models.DruidDatasource)
include_route_methods = RouteMethod.CRUD_SET

View File

@ -311,7 +311,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.table
def get_time_filter(
self, start_dttm: DateTime, end_dttm: DateTime,
self,
start_dttm: DateTime,
end_dttm: DateTime,
) -> ColumnElement:
col = self.get_sqla_col(label="__time")
l = []
@ -687,7 +689,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if self.sql:
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database, table_name=self.table_name, schema_name=self.schema,
database=self.database,
table_name=self.table_name,
schema_name=self.schema,
)
@property
@ -1013,7 +1017,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
_(
"Error in jinja expression in RLS filters: %(msg)s",
msg=ex.message,
)
) from ex
def text(self, clause: str) -> TextClause:
@ -1233,7 +1240,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
):
time_filters.append(
columns_by_name[self.main_dttm_col].get_time_filter(
from_dttm, to_dttm,
from_dttm,
to_dttm,
)
)
time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
@ -1444,7 +1452,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if dttm_col and not db_engine_spec.time_groupby_inline:
inner_time_filter = [
dttm_col.get_time_filter(
inner_from_dttm or from_dttm, inner_to_dttm or to_dttm,
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
)
]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
@ -1473,7 +1482,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
orderby = [
(
self._get_series_orderby(
series_limit_metric, metrics_by_name, columns_by_name,
series_limit_metric,
metrics_by_name,
columns_by_name,
),
not order_desc,
)
@ -1549,7 +1560,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return ob
def _normalize_prequery_result_type(
self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
self,
row: pd.Series,
dimension: str,
columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
"""
Convert a prequery result type to its equivalent Python type.
@ -1594,7 +1608,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
group = []
for dimension in dimensions:
value = self._normalize_prequery_result_type(
row, dimension, columns_by_name,
row,
dimension,
columns_by_name,
)
group.append(groupby_exprs[dimension] == value)
@ -1933,7 +1949,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
@staticmethod
def after_insert(
mapper: Mapper, connection: Connection, target: "SqlaTable",
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -1962,7 +1980,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
@staticmethod
def after_delete( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: "SqlaTable",
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -1985,7 +2005,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
@staticmethod
def after_update( # pylint: disable=too-many-branches, too-many-locals, too-many-statements
mapper: Mapper, connection: Connection, target: "SqlaTable",
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.

View File

@ -36,7 +36,9 @@ if TYPE_CHECKING:
def get_physical_table_metadata(
database: Database, table_name: str, schema_name: Optional[str] = None,
database: Database,
table_name: str,
schema_name: Optional[str] = None,
) -> List[Dict[str, str]]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
@ -72,7 +74,11 @@ def get_physical_table_metadata(
# from different drivers that fall outside CompileError
except Exception: # pylint: disable=broad-except
col.update(
{"type": "UNKNOWN", "generic_type": None, "is_dttm": None,}
{
"type": "UNKNOWN",
"generic_type": None,
"is_dttm": None,
}
)
return cols

View File

@ -151,7 +151,8 @@ def import_dashboard(
old_dataset_id = target.get("datasetId")
if dataset_id_mapping and old_dataset_id is not None:
target["datasetId"] = dataset_id_mapping.get(
old_dataset_id, old_dataset_id,
old_dataset_id,
old_dataset_id,
)
dashboard.json_metadata = json.dumps(json_metadata)

View File

@ -85,7 +85,8 @@ class BaseFilterSetCommand:
)
except NotAuthorizedException as err:
raise FilterSetForbiddenError(
str(self._filter_set_id), "user not authorized to access the filterset",
str(self._filter_set_id),
"user not authorized to access the filterset",
) from err
except FilterSetForbiddenError as err:
raise err

View File

@ -46,7 +46,11 @@ class FilterSetSchema(Schema):
class FilterSetPostSchema(FilterSetSchema):
json_metadata_schema: JsonMetadataSchema = JsonMetadataSchema()
# pylint: disable=W0613
name = fields.String(required=True, allow_none=False, validate=Length(0, 500),)
name = fields.String(
required=True,
allow_none=False,
validate=Length(0, 500),
)
description = fields.String(
required=False, allow_none=True, validate=[Length(1, 1000)]
)

View File

@ -170,7 +170,9 @@ class FilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: Query, value: Optional[Any]) -> Query:
role_model = security_manager.role_model
if value:
return query.filter(role_model.name.ilike(f"%{value}%"),)
return query.filter(
role_model.name.ilike(f"%{value}%"),
)
return query
@ -184,7 +186,15 @@ class DashboardCertifiedFilter(BaseFilter): # pylint: disable=too-few-public-me
def apply(self, query: Query, value: Any) -> Query:
if value is True:
return query.filter(and_(Dashboard.certified_by.isnot(None),))
return query.filter(
and_(
Dashboard.certified_by.isnot(None),
)
)
if value is False:
return query.filter(and_(Dashboard.certified_by.is_(None),))
return query.filter(
and_(
Dashboard.certified_by.is_(None),
)
)
return query

View File

@ -104,14 +104,19 @@ class DashboardPermalinkRestApi(BaseApi):
try:
state = self.add_model_schema.load(request.json)
key = CreateDashboardPermalinkCommand(
actor=g.user, dashboard_id=pk, state=state,
actor=g.user,
dashboard_id=pk,
state=state,
).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/dashboard/p/{key}/"
return self.response(201, key=key, url=url)
except (ValidationError, DashboardPermalinkInvalidStateError) as ex:
return self.response(400, message=str(ex))
except (DashboardAccessDeniedError, KeyValueAccessDeniedError,) as ex:
except (
DashboardAccessDeniedError,
KeyValueAccessDeniedError,
) as ex:
return self.response(403, message=str(ex))
except DashboardNotFoundError as ex:
return self.response(404, message=str(ex))

View File

@ -31,7 +31,10 @@ logger = logging.getLogger(__name__)
class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(
self, actor: User, dashboard_id: str, state: DashboardPermalinkState,
self,
actor: User,
dashboard_id: str,
state: DashboardPermalinkState,
):
self.actor = actor
self.dashboard_id = dashboard_id
@ -46,7 +49,9 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
"state": self.state,
}
key = CreateKeyValueCommand(
actor=self.actor, resource=self.resource, value=value,
actor=self.actor,
resource=self.resource,
value=value,
).run()
return encode_permalink_key(key=key.id, salt=self.salt)
except SQLAlchemyError as ex:

View File

@ -19,7 +19,9 @@ from marshmallow import fields, Schema
class DashboardPermalinkPostSchema(Schema):
filterState = fields.Dict(
required=False, allow_none=True, description="Native filter state",
required=False,
allow_none=True,
description="Native filter state",
)
urlParams = fields.List(
fields.Tuple(

View File

@ -243,7 +243,8 @@ class DashboardPostSchema(BaseDashboardSchema):
)
css = fields.String()
json_metadata = fields.String(
description=json_metadata_description, validate=validate_json_metadata,
description=json_metadata_description,
validate=validate_json_metadata,
)
published = fields.Boolean(description=published_description)
certified_by = fields.String(description=certified_by_description, allow_none=True)

View File

@ -881,7 +881,10 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
database = DatabaseDAO.find_by_id(pk)
if not database:
return self.response_404()
return self.response(200, function_names=database.function_names,)
return self.response(
200,
function_names=database.function_names,
)
@expose("/available/", methods=["GET"])
@protect()

View File

@ -47,7 +47,8 @@ class DatabaseExistsValidationError(ValidationError):
class DatabaseRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
[_("Field is required")], field_name=field_name,
[_("Field is required")],
field_name=field_name,
)
@ -100,7 +101,8 @@ class DatabaseUpdateFailedError(UpdateFailedError):
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError, DatabaseUpdateFailedError,
DatabaseCreateFailedError,
DatabaseUpdateFailedError,
):
message = _("Connection failed, please check your connection settings")

View File

@ -57,7 +57,8 @@ class ValidateDatabaseParametersCommand(BaseCommand):
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.', engine=engine,
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
@ -101,7 +102,8 @@ class ValidateDatabaseParametersCommand(BaseCommand):
# try to connect
sqlalchemy_uri = engine_spec.build_sqlalchemy_uri( # type: ignore
self._properties.get("parameters"), encrypted_extra,
self._properties.get("parameters"),
encrypted_extra,
)
if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted

View File

@ -42,7 +42,8 @@ class DatabaseDAO(BaseDAO):
@staticmethod
def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
database_query = db.session.query(Database).filter(
Database.database_name == database_name, Database.id != database_id,
Database.database_name == database_name,
Database.id != database_id,
)
return not db.session.query(database_query.exists()).scalar()

View File

@ -27,7 +27,8 @@ class DatabaseFilter(BaseFilter):
# TODO(bogdan): consider caching.
def can_access_databases( # noqa pylint: disable=no-self-use
self, view_menu_name: str,
self,
view_menu_name: str,
) -> Set[str]:
return {
security_manager.unpack_database_and_schema(vm).database

View File

@ -308,7 +308,12 @@ def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[_('Engine "%(engine)s" is not a valid engine.', engine=engine,)]
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]
@ -324,7 +329,9 @@ class DatabaseValidateParametersSchema(Schema):
description="DB-specific parameters for configuration",
)
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
description=database_name_description,
allow_none=True,
validate=Length(1, 250),
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
extra = fields.String(description=extra_description, validate=extra_validator)
@ -351,7 +358,9 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
unknown = EXCLUDE
database_name = fields.String(
description=database_name_description, required=True, validate=Length(1, 250),
description=database_name_description,
required=True,
validate=Length(1, 250),
)
cache_timeout = fields.Integer(
description=cache_timeout_description, allow_none=True
@ -395,7 +404,9 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
unknown = EXCLUDE
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
description=database_name_description,
allow_none=True,
validate=Length(1, 250),
)
cache_timeout = fields.Integer(
description=cache_timeout_description, allow_none=True
@ -436,7 +447,9 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
description=database_name_description,
allow_none=True,
validate=Length(1, 250),
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
extra = fields.String(description=extra_description, validate=extra_validator)

View File

@ -285,7 +285,10 @@ class ImportDatasetsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
self, contents: Dict[str, str], *args: Any, **kwargs: Any,
self,
contents: Dict[str, str],
*args: Any,
**kwargs: Any,
):
self.contents = contents
self._configs: Dict[str, Any] = {}

View File

@ -65,7 +65,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
if self._model:
try:
dataset = DatasetDAO.update(
model=self._model, properties=self._properties,
model=self._model,
properties=self._properties,
)
return dataset
except DAOUpdateFailedError as ex:

View File

@ -205,7 +205,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.BigInteger(),
GenericDataType.NUMERIC,
),
(re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^long", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
@ -216,13 +220,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.Numeric(),
GenericDataType.NUMERIC,
),
(re.compile(r"^float", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^float", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
(
re.compile(r"^real", re.IGNORECASE),
types.REAL,
GenericDataType.NUMERIC,
),
(
re.compile(r"^smallserial", re.IGNORECASE),
types.SmallInteger(),
@ -258,7 +270,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.DateTime(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^time", re.IGNORECASE),
types.Time(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
@ -351,7 +367,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_allow_cost_estimate( # pylint: disable=unused-argument
cls, extra: Dict[str, Any],
cls,
extra: Dict[str, Any],
) -> bool:
return False
@ -618,7 +635,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def extra_table_metadata( # pylint: disable=unused-argument
cls, database: "Database", table_name: str, schema_name: str,
cls,
database: "Database",
table_name: str,
schema_name: str,
) -> Dict[str, Any]:
"""
Returns engine-specific table metadata
@ -944,7 +964,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_table_names( # pylint: disable=unused-argument
cls, database: "Database", inspector: Inspector, schema: Optional[str],
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
"""
Get all tables from schema
@ -961,7 +984,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_view_names( # pylint: disable=unused-argument
cls, database: "Database", inspector: Inspector, schema: Optional[str],
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
"""
Get all views from schema
@ -1193,7 +1219,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None:
"""
Update a configuration dictionary
@ -1207,7 +1236,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def execute( # pylint: disable=unused-argument
cls, cursor: Any, query: str, **kwargs: Any,
cls,
cursor: Any,
query: str,
**kwargs: Any,
) -> None:
"""
Execute a SQL query
@ -1333,7 +1365,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_function_names( # pylint: disable=unused-argument
cls, database: "Database",
cls,
database: "Database",
) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
@ -1471,7 +1504,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_cancel_query_id( # pylint: disable=unused-argument
cls, cursor: Any, query: Query,
cls,
cursor: Any,
query: Query,
) -> Optional[str]:
"""
Select identifiers from the database engine that uniquely identifies the
@ -1487,7 +1522,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def cancel_query( # pylint: disable=unused-argument
cls, cursor: Any, query: Query, cancel_query_id: str,
cls,
cursor: Any,
query: Query,
cancel_query_id: str,
) -> bool:
"""
Cancel query in the underlying database.

View File

@ -72,7 +72,8 @@ ma_plugin = MarshmallowPlugin()
class BigQueryParametersSchema(Schema):
credentials_info = EncryptedString(
required=False, description="Contents of BigQuery JSON credentials.",
required=False,
description="Contents of BigQuery JSON credentials.",
)
query = fields.Dict(required=False)

View File

@ -82,7 +82,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod
def modify_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str],
cls,
url: URL,
impersonate_user: bool,
username: Optional[str],
) -> None:
if impersonate_user and username is not None:
user = security_manager.find_user(username=username)
@ -91,7 +94,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod
def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: str,
cls,
database: "Database",
table_name: str,
schema_name: str,
) -> Dict[str, Any]:
engine = cls.get_engine(database, schema=schema_name)
with closing(engine.raw_connection()) as conn:
@ -150,7 +156,8 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod
def validate_parameters(
cls, parameters: GSheetsParametersType,
cls,
parameters: GSheetsParametersType,
) -> List[SupersetError]:
errors: List[SupersetError] = []
encrypted_credentials = parameters.get("service_account_info") or "{}"
@ -173,7 +180,9 @@ class GSheetsEngineSpec(SqliteEngineSpec):
subject = g.user.email if g.user else None
engine = create_engine(
"gsheets://", service_account_info=encrypted_credentials, subject=subject,
"gsheets://",
service_account_info=encrypted_credentials,
subject=subject,
)
conn = engine.connect()
idx = 0

View File

@ -496,7 +496,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None:
"""
Update a configuration dictionary

View File

@ -74,24 +74,56 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
encryption_parameters = {"ssl": "1"}
column_type_mappings = (
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
(
re.compile(r"^int.*", re.IGNORECASE),
INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^tinyint", re.IGNORECASE),
TINYINT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^mediumint", re.IGNORECASE),
MEDIUMINT(),
GenericDataType.NUMERIC,
),
(re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,),
(re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,),
(re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,),
(re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,),
(re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,),
(
re.compile(r"^decimal", re.IGNORECASE),
DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^float", re.IGNORECASE),
FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double", re.IGNORECASE),
DOUBLE(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bit", re.IGNORECASE),
BIT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^tinytext", re.IGNORECASE),
TINYTEXT(),
GenericDataType.STRING,
),
(
re.compile(r"^mediumtext", re.IGNORECASE),
MEDIUMTEXT(),
GenericDataType.STRING,
),
(re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,),
(
re.compile(r"^longtext", re.IGNORECASE),
LONGTEXT(),
GenericDataType.STRING,
),
)
_time_grain_expressions = {

View File

@ -188,8 +188,16 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
lambda match: ARRAY(int(match[2])) if match[2] else String(),
GenericDataType.STRING,
),
(re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,),
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,),
(
re.compile(r"^json.*", re.IGNORECASE),
JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^enum.*", re.IGNORECASE),
ENUM(),
GenericDataType.STRING,
),
)
@classmethod

View File

@ -214,7 +214,10 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
@classmethod
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None:
"""
Update a configuration dictionary
@ -487,7 +490,11 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
types.VARBINARY(),
GenericDataType.STRING,
),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^date.*", re.IGNORECASE),
types.DATETIME(),

View File

@ -94,7 +94,10 @@ class TrinoEngineSpec(BaseEngineSpec):
@classmethod
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
cls,
connect_args: Dict[str, Any],
uri: str,
username: Optional[str],
) -> None:
"""
Update a configuration dictionary

View File

@ -186,8 +186,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
default_query_context = {
"result_format": "json",
"result_type": "full",
"datasource": {"id": tbl.id, "type": "table",},
"queries": [{"columns": [], "metrics": [],},],
"datasource": {
"id": tbl.id,
"type": "table",
},
"queries": [
{
"columns": [],
"metrics": [],
},
],
}
admin = get_admin_user()
@ -381,7 +389,12 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
),
query_context=get_slice_json(
default_query_context,
queries=[{"columns": ["name", "state"], "metrics": [metric],}],
queries=[
{
"columns": ["name", "state"],
"metrics": [metric],
}
],
),
),
]

View File

@ -43,7 +43,9 @@ from .helpers import (
def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements
only_metadata: bool = False, force: bool = False, sample: bool = False,
only_metadata: bool = False,
force: bool = False,
sample: bool = False,
) -> None:
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"

View File

@ -129,7 +129,10 @@ class SupersetGenericDBErrorException(SupersetErrorFromParamsException):
extra: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
SupersetErrorType.GENERIC_DB_ENGINE_ERROR, message, level, extra,
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
message,
level,
extra,
)
@ -144,7 +147,10 @@ class SupersetTemplateParamsErrorException(SupersetErrorFromParamsException):
extra: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
error, message, level, extra,
error,
message,
level,
extra,
)

View File

@ -39,7 +39,8 @@ logger = logging.getLogger(__name__)
class UpdateFormDataCommand(BaseCommand, ABC):
def __init__(
self, cmd_params: CommandParameters,
self,
cmd_params: CommandParameters,
):
self._cmd_params = cmd_params

View File

@ -162,7 +162,10 @@ class ExplorePermalinkRestApi(BaseApi):
return self.response(200, **value)
except ExplorePermalinkInvalidStateError as ex:
return self.response(400, message=str(ex))
except (ChartAccessDeniedError, DatasetAccessDeniedError,) as ex:
except (
ChartAccessDeniedError,
DatasetAccessDeniedError,
) as ex:
return self.response(403, message=str(ex))
except (ChartNotFoundError, DatasetNotFoundError) as ex:
return self.response(404, message=str(ex))

View File

@ -48,7 +48,9 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
"state": self.state,
}
command = CreateKeyValueCommand(
actor=self.actor, resource=self.resource, value=value,
actor=self.actor,
resource=self.resource,
value=value,
)
key = command.run()
return encode_permalink_key(key=key.id, salt=self.salt)

View File

@ -42,7 +42,8 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
try:
key = decode_permalink_id(self.key, salt=self.salt)
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
resource=self.resource, key=key,
resource=self.resource,
key=key,
).run()
if value:
chart_id: Optional[int] = value.get("chartId")

View File

@ -19,7 +19,9 @@ from marshmallow import fields, Schema
class ExplorePermalinkPostSchema(Schema):
formData = fields.Dict(
required=True, allow_none=False, description="Chart form data",
required=True,
allow_none=False,
description="Chart form data",
)
urlParams = fields.List(
fields.Tuple(

View File

@ -355,7 +355,10 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return_value = json.loads(json.dumps(return_value))
except TypeError as ex:
raise SupersetTemplateException(
_("Unsupported return value for method %(name)s", name=func.__name__,)
_(
"Unsupported return value for method %(name)s",
name=func.__name__,
)
) from ex
return return_value

View File

@ -214,7 +214,9 @@ def _delete_old_permissions(
def migrate_roles(
session: Session, pvm_key_map: PvmMigrationMapType, commit: bool = False,
session: Session,
pvm_key_map: PvmMigrationMapType,
commit: bool = False,
) -> None:
"""
Migrates all existing roles that have the permissions to be migrated

View File

@ -91,7 +91,8 @@ def downgrade():
for dashboard in session.query(Dashboard).all():
logger.info(
"[RemoveTypeToNativeFilter] Updating Dashobard<pk:%s>", dashboard.id,
"[RemoveTypeToNativeFilter] Updating Dashobard<pk:%s>",
dashboard.id,
)
if not dashboard.json_metadata:
logger.info(

View File

@ -39,24 +39,63 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"Dashboard": ("can_read", "can_write",)}
NEW_PVMS = {
"Dashboard": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("DashboardModelView", "can_add"): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "can_delete"): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "can_download_dashboards",): (
Pvm("Dashboard", "can_read"),
),
Pvm("DashboardModelView", "can_edit",): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "can_favorite_status",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelView", "can_list",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelView", "can_mulexport",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelView", "can_show",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelView", "muldelete",): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelView", "mulexport",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelViewAsync", "can_list",): (Pvm("Dashboard", "can_read"),),
Pvm("DashboardModelViewAsync", "muldelete",): (Pvm("Dashboard", "can_write"),),
Pvm("DashboardModelViewAsync", "mulexport",): (Pvm("Dashboard", "can_read"),),
Pvm("Dashboard", "can_new",): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelView",
"can_download_dashboards",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_edit",
): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelView",
"can_favorite_status",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_list",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"can_show",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelView",
"muldelete",
): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelView",
"mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelViewAsync",
"can_list",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"DashboardModelViewAsync",
"muldelete",
): (Pvm("Dashboard", "can_write"),),
Pvm(
"DashboardModelViewAsync",
"mulexport",
): (Pvm("Dashboard", "can_read"),),
Pvm(
"Dashboard",
"can_new",
): (Pvm("Dashboard", "can_write"),),
}

View File

@ -43,9 +43,18 @@ def upgrade():
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
@ -58,10 +67,22 @@ def upgrade():
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=False),
sa.Column("database_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(["database_id"], ["dbs.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["database_id"],
["dbs.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
@ -72,8 +93,14 @@ def upgrade():
sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("value", sa.Float(), nullable=True),
sa.Column("error_msg", sa.String(length=500), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(["observer_id"], ["sql_observers.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["observer_id"],
["sql_observers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(

View File

@ -49,8 +49,14 @@ def upgrade():
sa.Column("dashboard_id", sa.Integer(), nullable=True),
sa.Column("last_eval_dttm", sa.DateTime(), nullable=True),
sa.Column("last_state", sa.String(length=10), nullable=True),
sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],),
sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],),
sa.ForeignKeyConstraint(
["dashboard_id"],
["dashboards.id"],
),
sa.ForeignKeyConstraint(
["slice_id"],
["slices.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_alerts_active"), "alerts", ["active"], unique=False)
@ -62,7 +68,10 @@ def upgrade():
sa.Column("dttm_end", sa.DateTime(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("state", sa.String(length=10), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
@ -70,8 +79,14 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("alert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@ -34,7 +34,9 @@ def upgrade():
with op.batch_alter_table("report_schedule") as batch_op:
batch_op.add_column(
sa.Column(
"creation_method", sa.VARCHAR(255), server_default="alerts_reports",
"creation_method",
sa.VARCHAR(255),
server_default="alerts_reports",
)
)
batch_op.create_index(

View File

@ -39,13 +39,27 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"ReportSchedule": ("can_read", "can_write",)}
NEW_PVMS = {
"ReportSchedule": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("ReportSchedule", "can_list"): (Pvm("ReportSchedule", "can_read"),),
Pvm("ReportSchedule", "can_show"): (Pvm("ReportSchedule", "can_read"),),
Pvm("ReportSchedule", "can_add",): (Pvm("ReportSchedule", "can_write"),),
Pvm("ReportSchedule", "can_edit",): (Pvm("ReportSchedule", "can_write"),),
Pvm("ReportSchedule", "can_delete",): (Pvm("ReportSchedule", "can_write"),),
Pvm(
"ReportSchedule",
"can_add",
): (Pvm("ReportSchedule", "can_write"),),
Pvm(
"ReportSchedule",
"can_edit",
): (Pvm("ReportSchedule", "can_write"),),
Pvm(
"ReportSchedule",
"can_delete",
): (Pvm("ReportSchedule", "can_write"),),
}

View File

@ -39,17 +39,43 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"Database": ("can_read", "can_write",)}
NEW_PVMS = {
"Database": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("DatabaseView", "can_add"): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_delete"): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_edit",): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_list",): (Pvm("Database", "can_read"),),
Pvm("DatabaseView", "can_mulexport",): (Pvm("Database", "can_read"),),
Pvm("DatabaseView", "can_post",): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "can_show",): (Pvm("Database", "can_read"),),
Pvm("DatabaseView", "muldelete",): (Pvm("Database", "can_write"),),
Pvm("DatabaseView", "yaml_export",): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"can_edit",
): (Pvm("Database", "can_write"),),
Pvm(
"DatabaseView",
"can_list",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"can_mulexport",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"can_post",
): (Pvm("Database", "can_write"),),
Pvm(
"DatabaseView",
"can_show",
): (Pvm("Database", "can_read"),),
Pvm(
"DatabaseView",
"muldelete",
): (Pvm("Database", "can_write"),),
Pvm(
"DatabaseView",
"yaml_export",
): (Pvm("Database", "can_read"),),
}

View File

@ -38,7 +38,12 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"Dataset": ("can_read", "can_write",)}
NEW_PVMS = {
"Dataset": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("SqlMetricInlineView", "can_add"): (Pvm("Dataset", "can_write"),),
Pvm("SqlMetricInlineView", "can_delete"): (Pvm("Dataset", "can_write"),),
@ -50,15 +55,33 @@ PVM_MAP = {
Pvm("TableColumnInlineView", "can_edit"): (Pvm("Dataset", "can_write"),),
Pvm("TableColumnInlineView", "can_list"): (Pvm("Dataset", "can_read"),),
Pvm("TableColumnInlineView", "can_show"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_add",): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "can_delete",): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "can_edit",): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"can_add",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"can_delete",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"can_edit",
): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "can_list"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_mulexport"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "can_show"): (Pvm("Dataset", "can_read"),),
Pvm("TableModelView", "muldelete",): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "refresh",): (Pvm("Dataset", "can_write"),),
Pvm("TableModelView", "yaml_export",): (Pvm("Dataset", "can_read"),),
Pvm(
"TableModelView",
"muldelete",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"refresh",
): (Pvm("Dataset", "can_write"),),
Pvm(
"TableModelView",
"yaml_export",
): (Pvm("Dataset", "can_read"),),
}

View File

@ -114,8 +114,14 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("report_schedule_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["report_schedule_id"], ["report_schedule.id"],),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],),
sa.ForeignKeyConstraint(
["report_schedule_id"],
["report_schedule.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@ -37,10 +37,18 @@ from superset.migrations.shared.security_converge import (
revision = "4b84f97828aa"
down_revision = "45731db65d9c"
NEW_PVMS = {"Log": ("can_read", "can_write",)}
NEW_PVMS = {
"Log": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("LogModelView", "can_show"): (Pvm("Log", "can_read"),),
Pvm("LogModelView", "can_add",): (Pvm("Log", "can_write"),),
Pvm(
"LogModelView",
"can_add",
): (Pvm("Log", "can_write"),),
Pvm("LogModelView", "can_list"): (Pvm("Log", "can_read"),),
}

View File

@ -62,5 +62,8 @@ def downgrade():
)
batch_op.create_foreign_key(
"saved_query_id", "saved_query", ["saved_query_id"], ["id"],
"saved_query_id",
"saved_query",
["saved_query_id"],
["id"],
)

View File

@ -42,8 +42,14 @@ def upgrade():
sa.Column("bundle_url", sa.String(length=1000), nullable=False),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("key"),
sa.UniqueConstraint("name"),

View File

@ -39,16 +39,39 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"CssTemplate": ("can_read", "can_write",)}
NEW_PVMS = {
"CssTemplate": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("CssTemplateModelView", "can_list"): (Pvm("CssTemplate", "can_read"),),
Pvm("CssTemplateModelView", "can_show"): (Pvm("CssTemplate", "can_read"),),
Pvm("CssTemplateModelView", "can_add",): (Pvm("CssTemplate", "can_write"),),
Pvm("CssTemplateModelView", "can_edit",): (Pvm("CssTemplate", "can_write"),),
Pvm("CssTemplateModelView", "can_delete",): (Pvm("CssTemplate", "can_write"),),
Pvm("CssTemplateModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),),
Pvm("CssTemplateAsyncModelView", "can_list",): (Pvm("CssTemplate", "can_read"),),
Pvm("CssTemplateAsyncModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"can_add",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"can_edit",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"can_delete",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateModelView",
"muldelete",
): (Pvm("CssTemplate", "can_write"),),
Pvm(
"CssTemplateAsyncModelView",
"can_list",
): (Pvm("CssTemplate", "can_read"),),
Pvm(
"CssTemplateAsyncModelView",
"muldelete",
): (Pvm("CssTemplate", "can_write"),),
}

View File

@ -65,7 +65,10 @@ def upgrade():
with op.batch_alter_table("saved_query") as batch_op:
batch_op.add_column(
sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4,
"uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
),
)
except OperationalError:

View File

@ -159,7 +159,10 @@ def upgrade():
for key_to_remove in keys_to_remove:
del position_dict[key_to_remove]
dashboard.position_json = json.dumps(
position_dict, indent=None, separators=(",", ":"), sort_keys=True,
position_dict,
indent=None,
separators=(",", ":"),
sort_keys=True,
)
session.merge(dashboard)

View File

@ -39,7 +39,12 @@ def upgrade():
with op.batch_alter_table("report_schedule") as batch_op:
batch_op.add_column(
sa.Column("extra", sa.Text(), nullable=True, default="{}",),
sa.Column(
"extra",
sa.Text(),
nullable=True,
default="{}",
),
)
bind.execute(report_schedule.update().values({"extra": "{}"}))
with op.batch_alter_table("report_schedule") as batch_op:

View File

@ -118,7 +118,8 @@ def upgrade():
sa.Column("validator_config", sa.Text(), default="", nullable=True),
)
op.add_column(
"alerts", sa.Column("database_id", sa.Integer(), default=0, nullable=False),
"alerts",
sa.Column("database_id", sa.Integer(), default=0, nullable=False),
)
op.add_column("alerts", sa.Column("sql", sa.Text(), default="", nullable=False))
op.add_column(
@ -159,7 +160,10 @@ def upgrade():
sa.Column("alert_id", sa.Integer(), nullable=True),
sa.Column("value", sa.Float(), nullable=True),
sa.Column("error_msg", sa.String(length=500), nullable=True),
sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],),
sa.ForeignKeyConstraint(
["alert_id"],
["alerts.id"],
),
sa.PrimaryKeyConstraint("id"),
)
else:
@ -192,7 +196,11 @@ def downgrade():
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("validator_type", sa.String(length=100), nullable=False,),
sa.Column(
"validator_type",
sa.String(length=100),
nullable=False,
),
sa.Column("config", sa.Text(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), autoincrement=False, nullable=True),
sa.Column("changed_by_fk", sa.Integer(), autoincrement=False, nullable=True),
@ -261,10 +269,22 @@ def downgrade():
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("slack_channel", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],),
sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],),
sa.ForeignKeyConstraint(
["dashboard_id"],
["dashboards.id"],
),
sa.ForeignKeyConstraint(
["slice_id"],
["slices.id"],
),
sa.ForeignKeyConstraint(
["created_by_fk"],
["ab_user.id"],
),
sa.ForeignKeyConstraint(
["changed_by_fk"],
["ab_user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
else:

View File

@ -42,6 +42,6 @@ def upgrade():
def downgrade():
try:
# Trying since sqlite doesn't like constraints
op.drop_constraint(u"_customer_location_uc", "tables", type_="unique")
op.drop_constraint("_customer_location_uc", "tables", type_="unique")
except Exception:
pass

View File

@ -171,7 +171,10 @@ def upgrade():
with op.batch_alter_table(table_name) as batch_op:
batch_op.add_column(
sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4,
"uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
),
)

View File

@ -36,7 +36,8 @@ def upgrade():
kwargs: Dict[str, str] = {}
bind = op.get_bind()
op.add_column(
"dbs", sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs),
"dbs",
sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs),
)

View File

@ -379,16 +379,46 @@ def upgrade():
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("type", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=True,),
sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column("description", sa.TEXT(), nullable=True),
sa.Column("warning_text", sa.TEXT(), nullable=True),
sa.Column("unit", sa.TEXT(), nullable=True),
sa.Column("is_temporal", sa.BOOLEAN(), nullable=False),
sa.Column("is_spatial", sa.BOOLEAN(), nullable=False, default=False,),
sa.Column("is_partition", sa.BOOLEAN(), nullable=False, default=False,),
sa.Column("is_aggregation", sa.BOOLEAN(), nullable=False, default=False,),
sa.Column("is_additive", sa.BOOLEAN(), nullable=False, default=False,),
sa.Column("is_increase_desired", sa.BOOLEAN(), nullable=False, default=True,),
sa.Column(
"is_spatial",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_partition",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_aggregation",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_additive",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_increase_desired",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column(
"is_managed_externally",
sa.Boolean(),
@ -459,7 +489,12 @@ def upgrade():
sa.Column("sqlatable_id", sa.INTEGER(), nullable=True),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=False,),
sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_managed_externally",
sa.Boolean(),

View File

@ -39,19 +39,51 @@ revision = "c25cb2c78727"
down_revision = "ccb74baaa89b"
NEW_PVMS = {"Annotation": ("can_read", "can_write",)}
NEW_PVMS = {
"Annotation": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("AnnotationLayerModelView", "can_delete"): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationLayerModelView", "can_list"): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationLayerModelView", "can_show",): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationLayerModelView", "can_add",): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationLayerModelView", "can_edit",): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationModelView", "can_annotation",): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationModelView", "can_show",): (Pvm("Annotation", "can_read"),),
Pvm("AnnotationModelView", "can_add",): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationModelView", "can_delete",): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationModelView", "can_edit",): (Pvm("Annotation", "can_write"),),
Pvm("AnnotationModelView", "can_list",): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationLayerModelView",
"can_show",
): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationLayerModelView",
"can_add",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationLayerModelView",
"can_edit",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_annotation",
): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationModelView",
"can_show",
): (Pvm("Annotation", "can_read"),),
Pvm(
"AnnotationModelView",
"can_add",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_delete",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_edit",
): (Pvm("Annotation", "can_write"),),
Pvm(
"AnnotationModelView",
"can_list",
): (Pvm("Annotation", "can_read"),),
}

View File

@ -67,7 +67,10 @@ def upgrade():
with op.batch_alter_table(table_name) as batch_op:
batch_op.add_column(
sa.Column(
"uuid", UUIDType(binary=True), primary_key=False, default=uuid4,
"uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
),
)
add_uuids(model, table_name, session)

View File

@ -52,7 +52,10 @@ class AuditMixinNullable(AuditMixin):
@declared_attr
def created_by_fk(self) -> Column:
return Column(
Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True,
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
nullable=True,
)
@declared_attr

View File

@ -80,7 +80,8 @@ def upgrade():
if isinstance(bind.dialect, MySQLDialect):
op.drop_index(
op.f("name"), table_name="report_schedule",
op.f("name"),
table_name="report_schedule",
)
if isinstance(bind.dialect, PGDialect):

View File

@ -39,22 +39,63 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"Chart": ("can_read", "can_write",)}
NEW_PVMS = {
"Chart": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("SliceModelView", "can_list"): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_show"): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_edit",): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "can_delete",): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "can_add",): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "can_download",): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "muldelete",): (Pvm("Chart", "can_write"),),
Pvm("SliceModelView", "can_mulexport",): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_favorite_status",): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_cache_screenshot",): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_screenshot",): (Pvm("Chart", "can_read"),),
Pvm("SliceModelView", "can_data_from_cache",): (Pvm("Chart", "can_read"),),
Pvm("SliceAsync", "can_list",): (Pvm("Chart", "can_read"),),
Pvm("SliceAsync", "muldelete",): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_edit",
): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_delete",
): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_add",
): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_download",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"muldelete",
): (Pvm("Chart", "can_write"),),
Pvm(
"SliceModelView",
"can_mulexport",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_favorite_status",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_cache_screenshot",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_screenshot",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceModelView",
"can_data_from_cache",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceAsync",
"can_list",
): (Pvm("Chart", "can_read"),),
Pvm(
"SliceAsync",
"muldelete",
): (Pvm("Chart", "can_write"),),
}

View File

@ -33,7 +33,11 @@ from alembic import op
def upgrade():
with op.batch_alter_table("query") as batch_op:
batch_op.add_column(
sa.Column("limiting_factor", sa.VARCHAR(255), server_default="UNKNOWN",)
sa.Column(
"limiting_factor",
sa.VARCHAR(255),
server_default="UNKNOWN",
)
)

View File

@ -39,20 +39,55 @@ from superset.migrations.shared.security_converge import (
Pvm,
)
NEW_PVMS = {"SavedQuery": ("can_read", "can_write",)}
NEW_PVMS = {
"SavedQuery": (
"can_read",
"can_write",
)
}
PVM_MAP = {
Pvm("SavedQueryView", "can_list"): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryView", "can_show"): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryView", "can_add",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryView", "can_edit",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryView", "can_delete",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryView", "muldelete",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryView", "can_mulexport",): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryViewApi", "can_show",): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryViewApi", "can_edit",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryViewApi", "can_list",): (Pvm("SavedQuery", "can_read"),),
Pvm("SavedQueryViewApi", "can_add",): (Pvm("SavedQuery", "can_write"),),
Pvm("SavedQueryViewApi", "muldelete",): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"can_add",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"can_edit",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"can_delete",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"muldelete",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryView",
"can_mulexport",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_show",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_edit",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryViewApi",
"can_list",
): (Pvm("SavedQuery", "can_read"),),
Pvm(
"SavedQueryViewApi",
"can_add",
): (Pvm("SavedQuery", "can_write"),),
Pvm(
"SavedQueryViewApi",
"muldelete",
): (Pvm("SavedQuery", "can_write"),),
}

View File

@ -53,6 +53,8 @@ def upgrade():
def downgrade():
with op.batch_alter_table("row_level_security_filters") as batch_op:
batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),)
batch_op.drop_index(
op.f("ix_row_level_security_filters_filter_type"),
)
batch_op.drop_column("filter_type")
batch_op.drop_column("group_key")

View File

@ -322,7 +322,9 @@ class Database(
self.sqlalchemy_uri = str(conn) # hides the password
def get_effective_user(
self, object_url: URL, user_name: Optional[str] = None,
self,
object_url: URL,
user_name: Optional[str] = None,
) -> Optional[str]:
"""
Get the effective user, especially during impersonation.

View File

@ -344,7 +344,8 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
@debounce(0.1)
def clear_cache_for_datasource(cls, datasource_id: int) -> None:
filter_query = select(
[dashboard_slices.c.dashboard_id], distinct=True,
[dashboard_slices.c.dashboard_id],
distinct=True,
).select_from(
join(
dashboard_slices,

View File

@ -18,7 +18,11 @@ from marshmallow import fields, Schema
from marshmallow.validate import Length
openapi_spec_methods_override = {
"get": {"get": {"description": "Get a saved query",}},
"get": {
"get": {
"description": "Get a saved query",
}
},
"get_list": {
"get": {
"description": "Get a list of saved queries, use Rison or JSON "

View File

@ -95,7 +95,9 @@ class BaseReportState:
self._execution_id = execution_id
def set_state_and_log(
self, state: ReportState, error_message: Optional[str] = None,
self,
state: ReportState,
error_message: Optional[str] = None,
) -> None:
"""
Updates current ReportSchedule state and TS. If on final state writes the log
@ -104,7 +106,8 @@ class BaseReportState:
now_dttm = datetime.utcnow()
self.set_state(state, now_dttm)
self.create_log(
state, error_message=error_message,
state,
error_message=error_message,
)
def set_state(self, state: ReportState, dttm: datetime) -> None:
@ -531,12 +534,14 @@ class ReportWorkingState(BaseReportState):
if self.is_on_working_timeout():
exception_timeout = ReportScheduleWorkingTimeoutError()
self.set_state_and_log(
ReportState.ERROR, error_message=str(exception_timeout),
ReportState.ERROR,
error_message=str(exception_timeout),
)
raise exception_timeout
exception_working = ReportSchedulePreviousWorkingError()
self.set_state_and_log(
ReportState.WORKING, error_message=str(exception_working),
ReportState.WORKING,
error_message=str(exception_working),
)
raise exception_working

View File

@ -227,7 +227,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod
def find_last_success_log(
report_schedule: ReportSchedule, session: Optional[Session] = None,
report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]:
"""
Finds last success execution log for a given report
@ -245,7 +246,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod
def find_last_entered_working_log(
report_schedule: ReportSchedule, session: Optional[Session] = None,
report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]:
"""
Finds last success execution log for a given report
@ -264,7 +266,8 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod
def find_last_error_notification(
report_schedule: ReportSchedule, session: Optional[Session] = None,
report_schedule: ReportSchedule,
session: Optional[Session] = None,
) -> Optional[ReportExecutionLog]:
"""
Finds last error email sent

View File

@ -203,7 +203,9 @@ class ReportSchedulePostSchema(Schema):
default=ReportDataFormat.VISUALIZATION,
validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)),
)
extra = fields.Dict(default=None,)
extra = fields.Dict(
default=None,
)
force_screenshot = fields.Boolean(default=False)
@validates_schema

View File

@ -1376,7 +1376,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
def get_guest_user_from_token(self, token: GuestToken) -> GuestUser:
return self.guest_user_cls(
token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])],
token=token,
roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])],
)
def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]:

View File

@ -170,7 +170,10 @@ class ExecuteSqlCommand(BaseCommand):
except Exception as ex:
raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex
def _set_query_limit_if_required(self, rendered_query: str,) -> None:
def _set_query_limit_if_required(
self,
rendered_query: str,
) -> None:
if self._is_required_to_set_limit():
self._set_query_limit(rendered_query)

View File

@ -100,7 +100,12 @@ class SqlQueryRenderImpl(SqlQueryRender):
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": execution_context.template_params,
"issue_codes": [{"code": 1006, "message": MSG_OF_1006,}],
"issue_codes": [
{
"code": 1006,
"message": MSG_OF_1006,
}
],
},
)

View File

@ -107,6 +107,5 @@ try:
def gauge(self, key: str, value: float) -> None:
self.client.gauge(key, value)
except Exception: # pylint: disable=broad-except
pass

Some files were not shown because too many files have changed in this diff Show More