chore: upgrade mypy and add type guards (#16227)

This commit is contained in:
Ville Brofeldt 2021-08-14 06:31:45 +03:00 committed by GitHub
parent 9b2dffeb1d
commit d46dc9aa45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 29 additions and 18 deletions

View File

@ -24,9 +24,10 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [types-all]
- repo: https://github.com/peterdemin/pip-compile-multi
rev: v2.4.1
hooks:

View File

@ -384,12 +384,12 @@ def change_log(
with open(csv, "w") as csv_file:
log_items = list(logs)
field_names = log_items[0].keys()
writer = lib_csv.DictWriter(
writer = lib_csv.DictWriter( # type: ignore
csv_file,
delimiter=",",
quotechar='"',
quoting=lib_csv.QUOTE_ALL,
fieldnames=field_names,
fieldnames=field_names, # type: ignore
)
writer.writeheader()
for log in logs:

View File

@ -44,9 +44,13 @@ def import_migration_script(filepath: Path) -> ModuleType:
Import migration script as if it were a module.
"""
spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
if spec:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)
def extract_modified_tables(module: ModuleType) -> Set[str]:

View File

@ -106,14 +106,14 @@ setup(
"simplejson>=3.15.0",
"slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions
"sqlalchemy>=1.3.16, <1.4, !=1.3.21",
"sqlalchemy-utils>=0.36.6,<0.37",
"sqlalchemy-utils>=0.36.6, <0.37",
"sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
"tabulate==0.8.9",
"typing-extensions>=3.7.4.3,<4", # needed to support typing.Literal on py37
"typing-extensions>=3.10, <4", # needed to support Literal (3.8) and TypeGuard (3.10)
"wtforms-json",
],
extras_require={
"athena": ["pyathena>=1.10.8,<1.11"],
"athena": ["pyathena>=1.10.8, <1.11"],
"bigquery": [
"pandas_gbq>=0.10.0",
"pybigquery>=0.4.10",

View File

@ -89,7 +89,7 @@ WEBDRIVER_BASEURL = config["WEBDRIVER_BASEURL"]
WEBDRIVER_BASEURL_USER_FRIENDLY = config["WEBDRIVER_BASEURL_USER_FRIENDLY"]
ReportContent = namedtuple(
"EmailContent",
"ReportContent",
[
"body", # email body
"data", # attachments

View File

@ -29,7 +29,7 @@ from typing import (
from flask import Flask
from flask_caching import Cache
from typing_extensions import TypedDict
from typing_extensions import Literal, TypedDict
from werkzeug.wrappers import Response
if TYPE_CHECKING:
@ -57,7 +57,7 @@ class AdhocMetricColumn(TypedDict, total=False):
class AdhocMetric(TypedDict, total=False):
aggregate: str
column: Optional[AdhocMetricColumn]
expressionType: str
expressionType: Literal["SIMPLE", "SQL"]
label: Optional[str]
sqlExpression: Optional[str]

View File

@ -73,7 +73,7 @@ class AsyncQueryManager:
def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis
self._redis: redis.Redis # type: ignore
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
@ -100,7 +100,7 @@ class AsyncQueryManager:
"Please provide a JWT secret at least 32 bytes long"
)
self._redis = redis.Redis( # type: ignore
self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]

View File

@ -81,7 +81,7 @@ from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict
from typing_extensions import TypedDict, TypeGuard
import _thread # pylint: disable=C0411
from superset.constants import (
@ -1275,7 +1275,7 @@ def backend() -> str:
return get_example_database().backend
def is_adhoc_metric(metric: Metric) -> bool:
def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]:
return isinstance(metric, dict) and "expressionType" in metric
@ -1288,7 +1288,6 @@ def get_metric_name(metric: Metric) -> str:
:raises ValueError: if metric object is invalid
"""
if is_adhoc_metric(metric):
metric = cast(AdhocMetric, metric)
label = metric.get("label")
if label:
return label
@ -1306,7 +1305,7 @@ def get_metric_name(metric: Metric) -> str:
if column_name:
return column_name
raise ValueError(__("Invalid metric object"))
return cast(str, metric)
return metric # type: ignore
def get_metric_names(metrics: Sequence[Metric]) -> List[str]:

View File

@ -23,6 +23,7 @@ from superset.utils.core import (
GenericDataType,
get_metric_name,
get_metric_names,
is_adhoc_metric,
)
STR_METRIC = "my_metric"
@ -91,3 +92,9 @@ def test_get_metric_names():
assert get_metric_names(
[STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC]
) == ["my_metric", "my SUM", "my_sql"]
def test_is_adhoc_metric():
assert is_adhoc_metric(STR_METRIC) is False
assert is_adhoc_metric(SIMPLE_SUM_ADHOC_METRIC) is True
assert is_adhoc_metric(SQL_ADHOC_METRIC) is True