chore: Enforce Mypy for non-tests (#15757)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2021-07-21 11:46:43 -07:00 committed by GitHub
parent d26254099e
commit ab4e3b9bf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 134 additions and 74 deletions

View File

@ -24,6 +24,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
import click
from click.core import Context
try:
from github import BadCredentialsException, Github, PullRequest, Repository
@ -50,7 +51,7 @@ class GitLog:
author_email: str = ""
def __eq__(self, other: object) -> bool:
""" A log entry is considered equal if it has the same PR number """
"""A log entry is considered equal if it has the same PR number"""
if isinstance(other, self.__class__):
return other.pr_number == self.pr_number
return False
@ -170,7 +171,7 @@ class GitChangeLog:
def _parse_change_log(
self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str,
):
) -> None:
formatted_pr = (
f"- [#{pr_info.get('id')}]"
f"(https://github.com/{SUPERSET_REPO}/pull/{pr_info.get('id')}) "
@ -324,8 +325,8 @@ def print_title(message: str) -> None:
@click.pass_context
@click.option("--previous_version", help="The previous release version", required=True)
@click.option("--current_version", help="The current release version", required=True)
def cli(ctx, previous_version: str, current_version: str) -> None:
""" Welcome to change log generator """
def cli(ctx: Context, previous_version: str, current_version: str) -> None:
"""Welcome to change log generator"""
previous_logs = GitLogs(previous_version)
current_logs = GitLogs(current_version)
previous_logs.fetch()
@ -337,7 +338,7 @@ def cli(ctx, previous_version: str, current_version: str) -> None:
@cli.command("compare")
@click.pass_obj
def compare(base_parameters: BaseParameters) -> None:
""" Compares both versions (by PR) """
"""Compares both versions (by PR)"""
previous_logs = base_parameters.previous_logs
current_logs = base_parameters.current_logs
print_title(
@ -369,7 +370,7 @@ def compare(base_parameters: BaseParameters) -> None:
def change_log(
base_parameters: BaseParameters, csv: str, access_token: str, risk: bool
) -> None:
""" Outputs a changelog (by PR) """
"""Outputs a changelog (by PR)"""
previous_logs = base_parameters.previous_logs
current_logs = base_parameters.current_logs
previous_diff_logs = previous_logs.diff(current_logs)

View File

@ -17,7 +17,9 @@
#
import smtplib
import ssl
from typing import List
from typing import Any, Dict, List, Optional
from click.core import Context
try:
import jinja2
@ -50,7 +52,7 @@ def send_email(
sender_email: str,
receiver_email: str,
message: str,
):
) -> None:
"""
Send a simple text email (SMTP)
"""
@ -61,7 +63,7 @@ def send_email(
server.sendmail(sender_email, receiver_email, message)
def render_template(template_file: str, **kwargs) -> str:
def render_template(template_file: str, **kwargs: Any) -> str:
"""
Simple render template based on named parameters
@ -73,7 +75,9 @@ def render_template(template_file: str, **kwargs) -> str:
return template.render(kwargs)
def inter_send_email(username, password, sender_email, receiver_email, message):
def inter_send_email(
username: str, password: str, sender_email: str, receiver_email: str, message: str
) -> None:
print("--------------------------")
print("SMTP Message")
print("--------------------------")
@ -102,16 +106,16 @@ def inter_send_email(username, password, sender_email, receiver_email, message):
class BaseParameters(object):
def __init__(
self, email=None, username=None, password=None, version=None, version_rc=None
):
self, email: str, username: str, password: str, version: str, version_rc: str,
) -> None:
self.email = email
self.username = username
self.password = password
self.version = version
self.version_rc = version_rc
self.template_arguments = dict()
self.template_arguments: Dict[str, Any] = {}
def __repr__(self):
def __repr__(self) -> str:
return f"Apache Credentials: {self.email}/{self.username}/{self.version}/{self.version_rc}"
@ -133,8 +137,15 @@ class BaseParameters(object):
)
@click.option("--version", envvar="SUPERSET_VERSION")
@click.option("--version_rc", envvar="SUPERSET_VERSION_RC")
def cli(ctx, apache_email, apache_username, apache_password, version, version_rc):
""" Welcome to releasing send email CLI interface! """
def cli(
ctx: Context,
apache_email: str,
apache_username: str,
apache_password: str,
version: str,
version_rc: str,
) -> None:
"""Welcome to releasing send email CLI interface!"""
base_parameters = BaseParameters(
apache_email, apache_username, apache_password, version, version_rc
)
@ -155,7 +166,7 @@ def cli(ctx, apache_email, apache_username, apache_password, version, version_rc
prompt="The receiver email (To:)",
)
@click.pass_obj
def vote_pmc(base_parameters, receiver_email):
def vote_pmc(base_parameters: BaseParameters, receiver_email: str) -> None:
template_file = "email_templates/vote_pmc.j2"
base_parameters.template_arguments["receiver_email"] = receiver_email
message = render_template(template_file, **base_parameters.template_arguments)
@ -202,13 +213,13 @@ def vote_pmc(base_parameters, receiver_email):
)
@click.pass_obj
def result_pmc(
base_parameters,
receiver_email,
vote_bindings,
vote_nonbindings,
vote_negatives,
vote_thread,
):
base_parameters: BaseParameters,
receiver_email: str,
vote_bindings: str,
vote_nonbindings: str,
vote_negatives: str,
vote_thread: str,
) -> None:
template_file = "email_templates/result_pmc.j2"
base_parameters.template_arguments["receiver_email"] = receiver_email
base_parameters.template_arguments["vote_bindings"] = string_comma_to_list(
@ -239,7 +250,7 @@ def result_pmc(
prompt="The receiver email (To:)",
)
@click.pass_obj
def announce(base_parameters, receiver_email):
def announce(base_parameters: BaseParameters, receiver_email: str) -> None:
template_file = "email_templates/announce.j2"
base_parameters.template_arguments["receiver_email"] = receiver_email
message = render_template(template_file, **base_parameters.template_arguments)

View File

@ -23,6 +23,7 @@
import logging
import os
from datetime import timedelta
from typing import Optional
from cachelib.file import FileSystemCache
from celery.schedules import crontab
@ -30,7 +31,7 @@ from celery.schedules import crontab
logger = logging.getLogger()
def get_env_variable(var_name, default=None):
def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
"""Get the environment variable or raise exception."""
try:
return os.environ[var_name]
@ -63,8 +64,8 @@ SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
REDIS_HOST = get_env_variable("REDIS_HOST")
REDIS_PORT = get_env_variable("REDIS_PORT")
REDIS_CELERY_DB = get_env_variable("REDIS_CELERY_DB", 0)
REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", 1)
REDIS_CELERY_DB = get_env_variable("REDIS_CELERY_DB", "0")
REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", "1")
RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")

View File

@ -33,7 +33,7 @@ Example:
./cancel_github_workflows.py 1024 --include-last
"""
import os
from typing import Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
import click
import requests
@ -45,7 +45,9 @@ github_token = os.environ.get("GITHUB_TOKEN")
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
def request(method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs):
def request(
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
) -> Dict[str, Any]:
resp = requests.request(
method,
f"https://api.github.com/{endpoint.lstrip('/')}",
@ -57,10 +59,14 @@ def request(method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kw
return resp
def list_runs(repo: str, params=None):
def list_runs(
repo: str, params: Optional[Dict[str, str]] = None,
) -> Iterator[Dict[str, Any]]:
"""List all github workflow runs.
Returns:
An iterator that will iterate through all pages of matching runs."""
if params is None:
params = {}
page = 1
total_count = 10000
while page * 100 < total_count:
@ -75,11 +81,11 @@ def list_runs(repo: str, params=None):
page += 1
def cancel_run(repo: str, run_id: Union[str, int]):
def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
def get_pull_request(repo: str, pull_number: Union[str, int]):
def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
@ -89,7 +95,7 @@ def get_runs(
user: Optional[str] = None,
statuses: Iterable[str] = ("queued", "in_progress"),
events: Iterable[str] = ("pull_request", "push"),
):
) -> List[Dict[str, Any]]:
"""Get workflow runs associated with the given branch"""
return [
item
@ -101,7 +107,7 @@ def get_runs(
]
def print_commit(commit, branch):
def print_commit(commit: Dict[str, Any], branch: str) -> None:
"""Print out commit message for verification"""
indented_message = " \n".join(commit["message"].split("\n"))
date_str = (
@ -151,7 +157,7 @@ def cancel_github_workflows(
event: List[str],
include_last: bool,
include_running: bool,
):
) -> None:
"""Cancel running or queued GitHub workflows by branch or pull request ID"""
if not github_token:
raise ClickException("Please provide GITHUB_TOKEN as an env variable")
@ -231,7 +237,7 @@ def cancel_github_workflows(
try:
print(f"[{entry['status']}] {entry['name']}", end="\r")
cancel_run(repo, entry["id"])
print(f"[Cancled] {entry['name']} ")
print(f"[Canceled] {entry['name']} ")
except ClickException as error:
print(f"[Error: {error.message}] {entry['name']} ")
print("")

View File

@ -19,7 +19,7 @@ from collections import defaultdict
from superset import security_manager
def cleanup_permissions():
def cleanup_permissions() -> None:
# 1. Clean up duplicates.
pvms = security_manager.get_session.query(
security_manager.permissionview_model

View File

@ -30,21 +30,23 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false
[mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_untyped_calls = true
disallow_untyped_defs = true
ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true
[mypy-superset.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
warn_unused_ignores = false
[mypy-superset.migrations.versions.*]
ignore_errors = true
[mypy-tests.*]
check_untyped_defs = false
disallow_untyped_calls = false
disallow_untyped_defs = false

View File

@ -32,7 +32,7 @@ with io.open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
def get_git_sha():
def get_git_sha() -> str:
try:
s = subprocess.check_output(["git", "rev-parse", "HEAD"])
return s.decode().strip()

View File

@ -170,9 +170,7 @@ class QueryObject:
# 2. { label: 'label_name' } - legacy format for a predefined metric
# 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
self.metrics = metrics and [
x
if isinstance(x, str) or is_adhoc_metric(x)
else x["label"] # type: ignore
x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]
for x in metrics
]

View File

@ -734,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
self.table_name, schema=self.schema, show_cols=False, latest_partition=False
)
@property # type: ignore
@property
def health_check_message(self) -> Optional[str]:
check = config["DATASET_HEALTH_CHECK"]
return check(self) if check else None

View File

@ -97,7 +97,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
# try to connect
sqlalchemy_uri = engine_spec.build_sqlalchemy_uri(
self._properties.get("parameters", None), # type: ignore
self._properties.get("parameters"), # type: ignore
encrypted_extra,
)
if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():

View File

@ -1444,7 +1444,7 @@ class BasicParametersMixin:
errors: List[SupersetError] = []
required = {"host", "port", "username", "database"}
present = {key for key in parameters if parameters.get(key, ())} # type: ignore
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)
if missing:

View File

@ -200,7 +200,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
return message
@classmethod
def get_column_spec( # type: ignore
def get_column_spec(
cls,
native_type: Optional[str],
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,

View File

@ -276,7 +276,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
return extra
@classmethod
def get_column_spec( # type: ignore
def get_column_spec(
cls,
native_type: Optional[str],
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,

View File

@ -1212,7 +1212,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
return super().is_readonly_query(parsed_query) or parsed_query.is_show()
@classmethod
def get_column_spec( # type: ignore
def get_column_spec(
cls,
native_type: Optional[str],
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,

View File

@ -66,7 +66,7 @@ class SupersetAppInitializer:
self.manifest: Dict[Any, Any] = {}
@deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore # pylint: disable=line-too-long
@property # type: ignore
@property
def flask_app(self) -> SupersetApp:
return self.superset_app
@ -99,7 +99,7 @@ class SupersetAppInitializer:
# Grab each call into the task and set up an app context
def __call__(self, *args: Any, **kwargs: Any) -> Any:
with superset_app.app_context(): # type: ignore
with superset_app.app_context():
return task_base.__call__(self, *args, **kwargs)
celery_app.Task = AppContextTask
@ -573,7 +573,7 @@ class SupersetAppInitializer:
self.configure_middlewares()
self.configure_cache()
with self.superset_app.app_context(): # type: ignore
with self.superset_app.app_context():
self.init_app_in_ctx()
self.post_init()
@ -689,7 +689,7 @@ class SupersetAppInitializer:
def setup_db(self) -> None:
db.init_app(self.superset_app)
with self.superset_app.app_context(): # type: ignore
with self.superset_app.app_context():
pessimistic_connection_handling(db.engine)
migrate.init_app(self.superset_app, db=db, directory=APP_DIR + "/migrations")

View File

@ -37,7 +37,7 @@ from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str
from superset.utils.memoized import memoized
from superset.utils.urls import get_url_path
from superset.viz import BaseViz, viz_types # type: ignore
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource

View File

@ -352,7 +352,7 @@ def _get_slice_data(
# Parse the csv file and generate HTML
columns = rows.pop(0)
with app.app_context(): # type: ignore
with app.app_context():
body = render_template(
"superset/reports/slice_data.html",
columns=columns,

View File

@ -35,7 +35,7 @@ from superset.utils.hashing import md5_sha_from_dict
if TYPE_CHECKING:
from superset.stats_logger import BaseStatsLogger
config = app.config # type: ignore
config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, invalid-name
from flask.ctx import AppContext
from pytest_mock import MockFixture
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@ -24,16 +27,30 @@ class ProgrammingError(Exception):
"""
def test_validate_parameters_simple(mocker, app_context):
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
def test_validate_parameters_simple(
mocker: MockFixture, app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
)
parameters = {}
parameters: GSheetsParametersType = {
"credentials_info": {},
"query": {},
"table_catalog": {},
}
errors = GSheetsEngineSpec.validate_parameters(parameters)
assert errors == []
def test_validate_parameters_catalog(mocker, app_context):
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
def test_validate_parameters_catalog(
mocker: MockFixture, app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
)
g = mocker.patch("superset.db_engine_specs.gsheets.g")
g.user.email = "admin@example.com"
@ -47,7 +64,9 @@ def test_validate_parameters_catalog(mocker, app_context):
ProgrammingError("Unsupported table: https://www.google.com/"),
]
parameters = {
parameters: GSheetsParametersType = {
"credentials_info": {},
"query": {},
"table_catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
@ -114,12 +133,17 @@ def test_validate_parameters_catalog(mocker, app_context):
),
]
create_engine.assert_called_with(
"gsheets://", service_account_info=None, subject="admin@example.com",
"gsheets://", service_account_info={}, subject="admin@example.com",
)
def test_validate_parameters_catalog_and_credentials(mocker, app_context):
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
def test_validate_parameters_catalog_and_credentials(
mocker: MockFixture, app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
)
g = mocker.patch("superset.db_engine_specs.gsheets.g")
g.user.email = "admin@example.com"
@ -133,13 +157,14 @@ def test_validate_parameters_catalog_and_credentials(mocker, app_context):
ProgrammingError("Unsupported table: https://www.google.com/"),
]
parameters = {
parameters: GSheetsParametersType = {
"credentials_info": {},
"query": {},
"table_catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
"not_a_sheet": "https://www.google.com/",
},
"credentials_info": "SECRET",
}
errors = GSheetsEngineSpec.validate_parameters(parameters)
assert errors == [
@ -173,5 +198,5 @@ def test_validate_parameters_catalog_and_credentials(mocker, app_context):
),
]
create_engine.assert_called_with(
"gsheets://", service_account_info="SECRET", subject="admin@example.com",
"gsheets://", service_account_info={}, subject="admin@example.com",
)