style(mypy): Enforcing typing for superset (#9943)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-06-03 15:26:12 -07:00 committed by GitHub
parent dcac860f3e
commit 244677cf5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 393 additions and 313 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*,superset.viz,superset.viz_sip38]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -17,6 +17,7 @@
import logging
import os
from typing import Any, Callable, Dict
import wtforms_json
from flask import Flask, redirect
@ -41,13 +42,14 @@ from superset.extensions import (
talisman,
)
from superset.security import SupersetSecurityManager
from superset.typing import FlaskResponse
from superset.utils.core import pessimistic_connection_handling
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
logger = logging.getLogger(__name__)
def create_app():
def create_app() -> Flask:
app = Flask(__name__)
try:
@ -68,7 +70,7 @@ def create_app():
class SupersetIndexView(IndexView):
@expose("/")
def index(self):
def index(self) -> FlaskResponse:
return redirect("/superset/welcome")
@ -109,8 +111,8 @@ class SupersetAppInitializer:
abstract = True
# Grab each call into the task and set up an app context
def __call__(self, *args, **kwargs):
with flask_app.app_context():
def __call__(self, *args: Any, **kwargs: Any) -> Any:
with flask_app.app_context(): # type: ignore
return task_base.__call__(self, *args, **kwargs)
celery_app.Task = AppContextTask
@ -454,51 +456,41 @@ class SupersetAppInitializer:
order to fully init the app
"""
self.pre_init()
self.setup_db()
self.configure_celery()
self.setup_event_logger()
self.setup_bundle_manifest()
self.register_blueprints()
self.configure_wtf()
self.configure_logging()
self.configure_middlewares()
self.configure_cache()
self.configure_jinja_context()
with self.flask_app.app_context():
with self.flask_app.app_context(): # type: ignore
self.init_app_in_ctx()
self.post_init()
def setup_event_logger(self):
def setup_event_logger(self) -> None:
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
)
def configure_data_sources(self):
def configure_data_sources(self) -> None:
# Registering sources
module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
def configure_cache(self):
def configure_cache(self) -> None:
cache_manager.init_app(self.flask_app)
results_backend_manager.init_app(self.flask_app)
def configure_feature_flags(self):
def configure_feature_flags(self) -> None:
feature_flag_manager.init_app(self.flask_app)
def configure_fab(self):
def configure_fab(self) -> None:
if self.config["SILENCE_FAB"]:
logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
@ -516,7 +508,7 @@ class SupersetAppInitializer:
appbuilder.update_perms = False
appbuilder.init_app(self.flask_app, db.session)
def configure_url_map_converters(self):
def configure_url_map_converters(self) -> None:
#
# Doing local imports here as model importing causes a reference to
# app.config to be invoked and we need the current_app to have been setup
@ -527,10 +519,10 @@ class SupersetAppInitializer:
self.flask_app.url_map.converters["regex"] = RegexConverter
self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter
def configure_jinja_context(self):
def configure_jinja_context(self) -> None:
jinja_context_manager.init_app(self.flask_app)
def configure_middlewares(self):
def configure_middlewares(self) -> None:
if self.config["ENABLE_CORS"]:
from flask_cors import CORS
@ -539,24 +531,28 @@ class SupersetAppInitializer:
if self.config["ENABLE_PROXY_FIX"]:
from werkzeug.middleware.proxy_fix import ProxyFix
self.flask_app.wsgi_app = ProxyFix(
self.flask_app.wsgi_app = ProxyFix( # type: ignore
self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
)
if self.config["ENABLE_CHUNK_ENCODING"]:
class ChunkedEncodingFix: # pylint: disable=too-few-public-methods
def __init__(self, app):
def __init__(self, app: Flask) -> None:
self.app = app
def __call__(self, environ, start_response):
def __call__(
self, environ: Dict[str, Any], start_response: Callable
) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
environ["wsgi.input_terminated"] = True
return self.app(environ, start_response)
self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)
self.flask_app.wsgi_app = ChunkedEncodingFix( # type: ignore
self.flask_app.wsgi_app # type: ignore
)
if self.config["UPLOAD_FOLDER"]:
try:
@ -565,7 +561,9 @@ class SupersetAppInitializer:
pass
for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)
self.flask_app.wsgi_app = middleware( # type: ignore
self.flask_app.wsgi_app
)
# Flask-Compress
if self.config["ENABLE_FLASK_COMPRESS"]:
@ -574,27 +572,27 @@ class SupersetAppInitializer:
if self.config["TALISMAN_ENABLED"]:
talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])
def configure_logging(self):
def configure_logging(self) -> None:
self.config["LOGGING_CONFIGURATOR"].configure_logging(
self.config, self.flask_app.debug
)
def setup_db(self):
def setup_db(self) -> None:
db.init_app(self.flask_app)
with self.flask_app.app_context():
with self.flask_app.app_context(): # type: ignore
pessimistic_connection_handling(db.engine)
migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")
def configure_wtf(self):
def configure_wtf(self) -> None:
if self.config["WTF_CSRF_ENABLED"]:
csrf = CSRFProtect(self.flask_app)
csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
for ex in csrf_exempt_list:
csrf.exempt(ex)
def register_blueprints(self):
def register_blueprints(self) -> None:
for bp in self.config["BLUEPRINTS"]:
try:
logger.info(f"Registering blueprint: '{bp.name}'")
@ -602,5 +600,5 @@ class SupersetAppInitializer:
except Exception: # pylint: disable=broad-except
logger.exception("blueprint registration failed")
def setup_bundle_manifest(self):
def setup_bundle_manifest(self) -> None:
manifest_processor.init_app(self.flask_app)

View File

@ -19,10 +19,11 @@ import logging
from datetime import datetime
from subprocess import Popen
from sys import stdout
from typing import Type, Union
from typing import Any, Dict, Type, Union
import click
import yaml
from celery.utils.abstract import CallableTask
from colorama import Fore, Style
from flask import g
from flask.cli import FlaskGroup, with_appcontext
@ -56,17 +57,17 @@ def normalize_token(token_name: str) -> str:
context_settings={"token_normalize_func": normalize_token},
)
@with_appcontext
def superset():
def superset() -> None:
"""This is a management script for the Superset application."""
@app.shell_context_processor
def make_shell_context(): # pylint: disable=unused-variable
def make_shell_context() -> Dict[str, Any]: # pylint: disable=unused-variable
return dict(app=app, db=db)
@superset.command()
@with_appcontext
def init():
def init() -> None:
"""Inits the Superset application"""
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@ -75,7 +76,7 @@ def init():
@superset.command()
@with_appcontext
@click.option("--verbose", "-v", is_flag=True, help="Show extra information")
def version(verbose):
def version(verbose: bool) -> None:
"""Prints the current version number"""
print(Fore.BLUE + "-=" * 15)
print(
@ -90,7 +91,9 @@ def version(verbose):
print(Style.RESET_ALL)
def load_examples_run(load_test_data, only_metadata=False, force=False):
def load_examples_run(
load_test_data: bool, only_metadata: bool = False, force: bool = False
) -> None:
if only_metadata:
print("Loading examples metadata")
else:
@ -160,7 +163,9 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
@click.option(
"--force", "-f", is_flag=True, help="Force load data even if table already exists"
)
def load_examples(load_test_data, only_metadata=False, force=False):
def load_examples(
load_test_data: bool, only_metadata: bool = False, force: bool = False
) -> None:
"""Loads a set of Slices and Dashboards and a supporting dataset """
load_examples_run(load_test_data, only_metadata, force)
@ -169,7 +174,7 @@ def load_examples(load_test_data, only_metadata=False, force=False):
@superset.command()
@click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change")
def set_database_uri(database_name, uri):
def set_database_uri(database_name: str, uri: str) -> None:
"""Updates a database connection URI """
utils.get_or_create_db(database_name, uri)
@ -189,7 +194,7 @@ def set_database_uri(database_name, uri):
default=False,
help="Specify using 'merge' property during operation. " "Default value is False.",
)
def refresh_druid(datasource, merge):
def refresh_druid(datasource: str, merge: bool) -> None:
"""Refresh druid datasources"""
session = db.session()
from superset.connectors.druid.models import DruidCluster
@ -226,7 +231,7 @@ def refresh_druid(datasource, merge):
default=None,
help="Specify the user name to assign dashboards to",
)
def import_dashboards(path, recursive, username):
def import_dashboards(path: str, recursive: bool, username: str) -> None:
"""Import dashboards from JSON"""
from superset.utils import dashboard_import_export
@ -258,7 +263,7 @@ def import_dashboards(path, recursive, username):
@click.option(
"--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
)
def export_dashboards(print_stdout, dashboard_file):
def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export
@ -295,7 +300,7 @@ def export_dashboards(print_stdout, dashboard_file):
default=False,
help="recursively search the path for yaml files",
)
def import_datasources(path, sync, recursive):
def import_datasources(path: str, sync: str, recursive: bool) -> None:
"""Import datasources from YAML"""
from superset.utils import dict_import_export
@ -345,8 +350,11 @@ def import_datasources(path, sync, recursive):
help="Include fields containing defaults",
)
def export_datasources(
print_stdout, datasource_file, back_references, include_defaults
):
print_stdout: bool,
datasource_file: str,
back_references: bool,
include_defaults: bool,
) -> None:
"""Export datasources to YAML"""
from superset.utils import dict_import_export
@ -373,7 +381,7 @@ def export_datasources(
default=False,
help="Include parent back references",
)
def export_datasource_schema(back_references):
def export_datasource_schema(back_references: bool) -> None:
"""Export datasource YAML schema to stdout"""
from superset.utils import dict_import_export
@ -383,7 +391,7 @@ def export_datasource_schema(back_references):
@superset.command()
@with_appcontext
def update_datasources_cache():
def update_datasources_cache() -> None:
"""Refresh sqllab datasources cache"""
from superset.models.core import Database
@ -406,7 +414,7 @@ def update_datasources_cache():
@click.option(
"--workers", "-w", type=int, help="Number of celery server workers to fire up"
)
def worker(workers):
def worker(workers: int) -> None:
"""Starts a Superset worker for async SQL query execution."""
logger.info(
"The 'superset worker' command is deprecated. Please use the 'celery "
@ -431,7 +439,7 @@ def worker(workers):
@click.option(
"-a", "--address", default="localhost", help="Address on which to run the service"
)
def flower(port, address):
def flower(port: int, address: str) -> None:
"""Runs a Celery Flower web server
Celery Flower is a UI to monitor the Celery operation on a given
@ -487,7 +495,7 @@ def compute_thumbnails(
charts_only: bool,
force: bool,
model_id: int,
):
) -> None:
"""Compute thumbnails"""
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@ -500,8 +508,8 @@ def compute_thumbnails(
friendly_type: str,
model_cls: Union[Type[Dashboard], Type[Slice]],
model_id: int,
compute_func,
):
compute_func: CallableTask,
) -> None:
query = db.session.query(model_cls)
if model_id:
query = query.filter(model_cls.id.in_(model_id))
@ -528,7 +536,7 @@ def compute_thumbnails(
@superset.command()
@with_appcontext
def load_test_users():
def load_test_users() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
@ -538,7 +546,7 @@ def load_test_users():
load_test_users_run()
def load_test_users_run():
def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
@ -583,7 +591,7 @@ def load_test_users_run():
@superset.command()
@with_appcontext
def sync_tags():
def sync_tags() -> None:
"""Rebuilds special tags (owner, type, favorited by)."""
# pylint: disable=no-member
metadata = Model.metadata

View File

@ -28,8 +28,9 @@ import os
import sys
from collections import OrderedDict
from datetime import date
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from cachelib.base import BaseCache
from celery.schedules import crontab
from dateutil import tz
from flask_appbuilder.security.manager import AUTH_DB
@ -78,7 +79,7 @@ PACKAGE_JSON_FILE = os.path.join(BASE_DIR, "static", "assets", "package.json")
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
def _try_json_readversion(filepath):
def _try_json_readversion(filepath: str) -> Optional[str]:
try:
with open(filepath, "r") as f:
return json.load(f).get("version")
@ -86,7 +87,9 @@ def _try_json_readversion(filepath):
return None
def _try_json_readsha(filepath, length): # pylint: disable=unused-argument
def _try_json_readsha( # pylint: disable=unused-argument
filepath: str, length: int
) -> Optional[str]:
try:
with open(filepath, "r") as f:
return json.load(f).get("GIT_SHA")[:length]
@ -453,6 +456,7 @@ BACKUP_COUNT = 30
# user=None,
# client=None,
# security_manager=None,
# log_params=None,
# ):
# pass
QUERY_LOGGER = None
@ -578,10 +582,9 @@ SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
Callable[["Database", "models.User", str, str], str]
] = None
# An instantiated derivative of cachelib.base.BaseCache
# if enabled, it can be used to store the results of long-running queries
# If enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
RESULTS_BACKEND = None
RESULTS_BACKEND: Optional[BaseCache] = None
# Use PyArrow and MessagePack for async query results serialization,
# rather than JSON. This feature requires additional testing from the
@ -604,7 +607,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE = None
UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_csv_upload
@ -614,7 +617,7 @@ UPLOADED_CSV_HIVE_NAMESPACE = None
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
["Database", "models.User"], List[str]
] = lambda database, user: [
UPLOADED_CSV_HIVE_NAMESPACE # type: ignore
UPLOADED_CSV_HIVE_NAMESPACE
] if UPLOADED_CSV_HIVE_NAMESPACE else []
# A dictionary of items that gets merged into the Jinja context for
@ -628,7 +631,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
# dictionary, which means the existing keys get overwritten by the content of this
dictionary. The customized addons don't necessarily need to use jinja templating
language. This allows you to define custom logic to process macro templates.
CUSTOM_TEMPLATE_PROCESSORS = {} # type: Dict[str, BaseTemplateProcessor]
CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
# Roles that are controlled by the API / Superset and should not be changed
# by humans.

View File

@ -32,7 +32,7 @@ class SupersetException(Exception):
super().__init__(self.message)
@property
def exception(self):
def exception(self) -> Optional[Exception]:
return self._exception

View File

@ -20,10 +20,12 @@ import random
import time
import uuid
from datetime import datetime, timedelta
from typing import Dict, TYPE_CHECKING # pylint: disable=unused-import
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
import celery
from cachelib.base import BaseCache
from dateutil.relativedelta import relativedelta
from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
from flask_migrate import Migrate
from flask_talisman import Talisman
@ -32,7 +34,6 @@ from werkzeug.local import LocalProxy
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
# Avoid circular import
if TYPE_CHECKING:
from superset.jinja_context import ( # pylint: disable=unused-import
BaseTemplateProcessor,
@ -49,18 +50,18 @@ class JinjaContextManager:
"timedelta": timedelta,
"uuid": uuid,
}
self._template_processors = {} # type: Dict[str, BaseTemplateProcessor]
self._template_processors: Dict[str, Type["BaseTemplateProcessor"]] = {}
def init_app(self, app):
def init_app(self, app: Flask) -> None:
self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"])
self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"])
@property
def base_context(self):
def base_context(self) -> Dict[str, Any]:
return self._base_context
@property
def template_processors(self):
def template_processors(self) -> Dict[str, Type["BaseTemplateProcessor"]]:
return self._template_processors
@ -69,35 +70,35 @@ class ResultsBackendManager:
self._results_backend = None
self._use_msgpack = False
def init_app(self, app):
self._results_backend = app.config.get("RESULTS_BACKEND")
self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
def init_app(self, app: Flask) -> None:
self._results_backend = app.config["RESULTS_BACKEND"]
self._use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
@property
def results_backend(self):
def results_backend(self) -> Optional[BaseCache]:
return self._results_backend
@property
def should_use_msgpack(self):
def should_use_msgpack(self) -> bool:
return self._use_msgpack
class UIManifestProcessor:
def __init__(self, app_dir: str) -> None:
self.app = None
self.manifest: dict = {}
self.app: Optional[Flask] = None
self.manifest: Dict[str, Dict[str, List[str]]] = {}
self.manifest_file = f"{app_dir}/static/assets/manifest.json"
def init_app(self, app):
def init_app(self, app: Flask) -> None:
self.app = app
# Preload the cache
self.parse_manifest_json()
@app.context_processor
def get_manifest(): # pylint: disable=unused-variable
def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
loaded_chunks = set()
def get_files(bundle, asset_type="js"):
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
files = self.get_manifest_files(bundle, asset_type)
filtered_files = [f for f in files if f not in loaded_chunks]
for f in filtered_files:
@ -109,18 +110,18 @@ class UIManifestProcessor:
css_manifest=lambda bundle: get_files(bundle, "css"),
)
def parse_manifest_json(self):
def parse_manifest_json(self) -> None:
try:
with open(self.manifest_file, "r") as f:
# the manifest includes non-entry files
# we only need entries in templates
# the manifest includes non-entry files we only need entries in
# templates
full_manifest = json.load(f)
self.manifest = full_manifest.get("entrypoints", {})
except Exception: # pylint: disable=broad-except
pass
def get_manifest_files(self, bundle, asset_type):
if self.app.debug:
def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
if self.app and self.app.debug:
self.parse_manifest_json()
return self.manifest.get(bundle, {}).get(asset_type, [])
@ -133,7 +134,7 @@ db = SQLA()
_event_logger: dict = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager() # type: JinjaContextManager
jinja_context_manager = JinjaContextManager()
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
from typing import List # pylint: disable=unused-import
from typing import Any, List, Optional
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field
@ -25,24 +25,24 @@ class CommaSeparatedListField(Field):
widget = BS3TextFieldWidget()
data: List[str] = []
def _value(self):
def _value(self) -> str:
if self.data:
return u", ".join(self.data)
return ", ".join(self.data)
return u""
return ""
def process_formdata(self, valuelist):
def process_formdata(self, valuelist: List[str]) -> None:
if valuelist:
self.data = [x.strip() for x in valuelist[0].split(",")]
else:
self.data = []
def filter_not_empty_values(value):
def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]:
"""Returns a list of non-empty values or None"""
if not value:
if not values:
return None
data = [x for x in value if x]
data = [value for value in values if value]
if not data:
return None
return data

View File

@ -17,7 +17,7 @@
"""Defines the templating context for SQL Lab"""
import inspect
import re
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, cast, List, Optional, Tuple, TYPE_CHECKING
from flask import g, request
from jinja2.sandbox import SandboxedEnvironment
@ -207,7 +207,7 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods
def __init__(
self,
database: Optional["Database"] = None,
database: "Database",
query: Optional["Query"] = None,
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[List[Any]] = None,
@ -266,7 +266,7 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
schema, table_name = table_name.split(".")
return table_name, schema
def first_latest_partition(self, table_name: str) -> str:
def first_latest_partition(self, table_name: str) -> Optional[str]:
"""
Gets the first value in the array of all latest partitions
@ -275,9 +275,10 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
:raises IndexError: If no partition exists
"""
return self.latest_partitions(table_name)[0]
latest_partitions = self.latest_partitions(table_name)
return latest_partitions[0] if latest_partitions else None
def latest_partitions(self, table_name: str) -> List[str]:
def latest_partitions(self, table_name: str) -> Optional[List[str]]:
"""
Gets the array of all latest partitions
@ -285,16 +286,21 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
:return: the latest partition array
"""
from superset.db_engine_specs.presto import PrestoEngineSpec
table_name, schema = self._schema_table(table_name, self.schema)
assert self.database
return self.database.db_engine_spec.latest_partition( # type: ignore
return cast(PrestoEngineSpec, self.database.db_engine_spec).latest_partition(
table_name, schema, self.database
)[1]
def latest_sub_partition(self, table_name, **kwargs):
def latest_sub_partition(self, table_name: str, **kwargs: Any) -> Any:
table_name, schema = self._schema_table(table_name, self.schema)
assert self.database
return self.database.db_engine_spec.latest_sub_partition(
from superset.db_engine_specs.presto import PrestoEngineSpec
return cast(
PrestoEngineSpec, self.database.db_engine_spec
).latest_sub_partition(
table_name=table_name, schema=schema, database=self.database, **kwargs
)

View File

@ -19,7 +19,7 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
import backoff
import msgpack
@ -27,9 +27,10 @@ import pyarrow as pa
import simplejson as json
import sqlalchemy
from celery.exceptions import SoftTimeLimitExceeded
from celery.task.base import Task
from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from superset import (
@ -77,7 +78,9 @@ class SqlLabTimeoutException(SqlLabException):
pass
def handle_query_error(msg, query, session, payload=None):
def handle_query_error(
msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Local method handling error while processing the SQL"""
payload = payload or {}
troubleshooting_link = config["TROUBLESHOOTING_LINK"]
@ -91,14 +94,14 @@ def handle_query_error(msg, query, session, payload=None):
return payload
def get_query_backoff_handler(details):
def get_query_backoff_handler(details: Dict[Any, Any]) -> None:
query_id = details["kwargs"]["query_id"]
logger.error(f"Query with id `{query_id}` could not be retrieved")
stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
def get_query_giveup_handler(_):
def get_query_giveup_handler(_: Any) -> None:
stats_logger.incr("error_failed_at_getting_orm_query")
@ -110,7 +113,7 @@ def get_query_giveup_handler(_):
on_giveup=get_query_giveup_handler,
max_tries=5,
)
def get_query(query_id, session):
def get_query(query_id: int, session: Session) -> Query:
"""attempts to get the query and retry if it cannot"""
try:
return session.query(Query).filter_by(id=query_id).one()
@ -119,7 +122,7 @@ def get_query(query_id, session):
@contextmanager
def session_scope(nullpool):
def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
@ -154,16 +157,16 @@ def session_scope(nullpool):
soft_time_limit=SQLLAB_TIMEOUT,
)
def get_sql_results( # pylint: disable=too-many-arguments
ctask,
query_id,
rendered_query,
return_results=True,
store_results=False,
user_name=None,
start_time=None,
expand_data=False,
log_params=None,
):
ctask: Task,
query_id: int,
rendered_query: str,
return_results: bool = True,
store_results: bool = False,
user_name: Optional[str] = None,
start_time: Optional[float] = None,
expand_data: bool = False,
log_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
@ -188,7 +191,14 @@ def get_sql_results( # pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments
def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_params):
def execute_sql_statement(
sql_statement: str,
query: Query,
user_name: Optional[str],
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
db_engine_spec = database.db_engine_spec
@ -275,7 +285,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
def _serialize_payload(
payload: dict, use_msgpack: Optional[bool] = False
payload: Dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
logger.debug(f"Serializing to msgpack: {use_msgpack}")
if use_msgpack:
@ -321,24 +331,24 @@ def _serialize_and_expand_data(
return (data, selected_columns, all_columns, expanded_columns)
def execute_sql_statements(
query_id,
rendered_query,
return_results=True,
store_results=False,
user_name=None,
session=None,
start_time=None,
expand_data=False,
log_params=None,
): # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
query_id: int,
rendered_query: str,
return_results: bool,
store_results: bool,
user_name: Optional[str],
session: Session,
start_time: Optional[float],
expand_data: bool,
log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
query = get_query(query_id, session)
payload = dict(query_id=query_id)
payload: Dict[str, Any] = dict(query_id=query_id)
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
@ -406,7 +416,7 @@ def execute_sql_statements(
)
query.end_time = now_as_float()
use_arrow_data = store_results and results_backend_use_msgpack
use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
result_set, db_engine_spec, use_arrow_data, expand_data
)
@ -432,7 +442,7 @@ def execute_sql_statements(
"sqllab.query.results_backend_write_serialization", stats_logger
):
serialized_payload = _serialize_payload(
payload, results_backend_use_msgpack
payload, cast(bool, results_backend_use_msgpack)
)
cache_timeout = database.cache_timeout
if cache_timeout is None:

View File

@ -158,7 +158,7 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
def _process_tokenlist(self, token_list: TokenList):
def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
@ -204,7 +204,9 @@ class ParsedQuery:
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql
def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
def _extract_from_token( # pylint: disable=too-many-branches
self, token: Token
) -> None:
"""
Populate self._tables from token

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from colorama import Fore, Style
@ -40,7 +41,7 @@ class BaseStatsLogger:
"""Decrement a counter"""
raise NotImplementedError()
def timing(self, key, value: float) -> None:
def timing(self, key: str, value: float) -> None:
raise NotImplementedError()
def gauge(self, key: str) -> None:
@ -49,18 +50,18 @@ class BaseStatsLogger:
class DummyStatsLogger(BaseStatsLogger):
def incr(self, key):
def incr(self, key: str) -> None:
logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL)
def decr(self, key):
def decr(self, key: str) -> None:
logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL))
def timing(self, key, value):
def timing(self, key: str, value: float) -> None:
logger.debug(
(Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL)
)
def gauge(self, key):
def gauge(self, key: str) -> None:
logger.debug(
(Fore.CYAN + "[stats_logger] (gauge) " + f"{key}" + Style.RESET_ALL)
)
@ -71,8 +72,12 @@ try:
class StatsdStatsLogger(BaseStatsLogger):
def __init__( # pylint: disable=super-init-not-called
self, host="localhost", port=8125, prefix="superset", statsd_client=None
):
self,
host: str = "localhost",
port: int = 8125,
prefix: str = "superset",
statsd_client: Optional[StatsClient] = None,
) -> None:
"""
Initializes from either params or a supplied, pre-constructed statsd client.
@ -84,16 +89,16 @@ try:
else:
self.client = StatsClient(host=host, port=port, prefix=prefix)
def incr(self, key):
def incr(self, key: str) -> None:
self.client.incr(key)
def decr(self, key):
def decr(self, key: str) -> None:
self.client.decr(key)
def timing(self, key, value):
def timing(self, key: str, value: float) -> None:
self.client.timing(key, value)
def gauge(self, key):
def gauge(self, key: str) -> None:
# pylint: disable=no-value-for-parameter
self.client.gauge(key)

View File

@ -33,6 +33,7 @@ Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
VizPayload = Dict[str, Any]
# Flask response.
Base = Union[bytes, str]

File diff suppressed because it is too large Load Diff

View File

@ -20,6 +20,7 @@
These objects represent the backend of all the visualizations that
Superset can render.
"""
# mypy: ignore-errors
import copy
import hashlib
import inspect
@ -610,7 +611,7 @@ class TableViz(BaseViz):
raise QueryObjectValidationError(
_("Pick a granularity in the Time section or " "uncheck 'Include Time'")
)
return fd.get("include_time")
return bool(fd.get("include_time"))
def query_obj(self):
d = super().query_obj()

View File

@ -974,7 +974,7 @@ class BaseDeckGLVizTestCase(SupersetTestCase):
test_viz_deckgl = viz.DeckScatterViz(datasource, form_data)
test_viz_deckgl.point_radius_fixed = {}
result = test_viz_deckgl.get_metrics()
assert result is None
assert result == []
def test_get_js_columns(self):
form_data = load_fixture("deck_path_form_data.json")