diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bca0923e8..80202bebf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,9 +24,10 @@ repos: hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 + rev: v0.910 hooks: - id: mypy + additional_dependencies: [types-all] - repo: https://github.com/peterdemin/pip-compile-multi rev: v2.4.1 hooks: diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py index d6a1842fd..e9ff2de04 100644 --- a/RELEASING/changelog.py +++ b/RELEASING/changelog.py @@ -384,12 +384,12 @@ def change_log( with open(csv, "w") as csv_file: log_items = list(logs) field_names = log_items[0].keys() - writer = lib_csv.DictWriter( + writer = lib_csv.DictWriter( # type: ignore csv_file, delimiter=",", quotechar='"', quoting=lib_csv.QUOTE_ALL, - fieldnames=field_names, + fieldnames=field_names, # type: ignore ) writer.writeheader() for log in logs: diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index e4e49906d..27670b5d4 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -44,9 +44,13 @@ def import_migration_script(filepath: Path) -> ModuleType: Import migration script as if it were a module. """ spec = importlib.util.spec_from_file_location(filepath.stem, filepath) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) # type: ignore - return module + if spec: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # type: ignore + return module + raise Exception( + "No module spec found in location: `{path}`".format(path=str(filepath)) + ) def extract_modified_tables(module: ModuleType) -> Set[str]: diff --git a/setup.py b/setup.py index a38c7d2a2..d082733fa 100644 --- a/setup.py +++ b/setup.py @@ -106,14 +106,14 @@ setup( "simplejson>=3.15.0", "slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions "sqlalchemy>=1.3.16, <1.4, !=1.3.21", - "sqlalchemy-utils>=0.36.6,<0.37", + "sqlalchemy-utils>=0.36.6, <0.37", "sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562 "tabulate==0.8.9", - "typing-extensions>=3.7.4.3,<4", # needed to support typing.Literal on py37 + "typing-extensions>=3.10, <4", # needed to support Literal (3.8) and TypeGuard (3.10) "wtforms-json", ], extras_require={ - "athena": ["pyathena>=1.10.8,<1.11"], + "athena": ["pyathena>=1.10.8, <1.11"], "bigquery": [ "pandas_gbq>=0.10.0", "pybigquery>=0.4.10", diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index fe75fbe4a..78c0e8efe 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -89,7 +89,7 @@ WEBDRIVER_BASEURL = config["WEBDRIVER_BASEURL"] WEBDRIVER_BASEURL_USER_FRIENDLY = config["WEBDRIVER_BASEURL_USER_FRIENDLY"] ReportContent = namedtuple( - "EmailContent", + "ReportContent", [ "body", # email body "data", # attachments diff --git a/superset/typing.py b/superset/typing.py index 4273444fe..d076402df 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -29,7 +29,7 @@ from typing import ( from flask import Flask from flask_caching import Cache -from typing_extensions import TypedDict +from typing_extensions import Literal, TypedDict from werkzeug.wrappers import Response if TYPE_CHECKING: @@ -57,7 +57,7 @@ class AdhocMetricColumn(TypedDict, total=False): class AdhocMetric(TypedDict, total=False): aggregate: str column: Optional[AdhocMetricColumn] - expressionType: str + expressionType: Literal["SIMPLE", "SQL"] label: Optional[str] sqlExpression: Optional[str] diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py index 90428a11e..258024f7f 100644 --- a/superset/utils/async_query_manager.py +++ b/superset/utils/async_query_manager.py @@ -73,7 +73,7 @@ class AsyncQueryManager: def __init__(self) -> None: super().__init__() - self._redis: redis.Redis + self._redis: redis.Redis # type: ignore self._stream_prefix: str = "" self._stream_limit: Optional[int] self._stream_limit_firehose: Optional[int] @@ -100,7 +100,7 @@ class AsyncQueryManager: "Please provide a JWT secret at least 32 bytes long" ) - self._redis = redis.Redis( # type: ignore + self._redis = redis.Redis( **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True ) self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] diff --git a/superset/utils/core.py b/superset/utils/core.py index 2e34fd4f8..062c3fa50 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -81,7 +81,7 @@ from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.type_api import Variant from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine -from typing_extensions import TypedDict +from typing_extensions import TypedDict, TypeGuard import _thread # pylint: disable=C0411 from superset.constants import ( @@ -1275,7 +1275,7 @@ def backend() -> str: return get_example_database().backend -def is_adhoc_metric(metric: Metric) -> bool: +def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]: return isinstance(metric, dict) and "expressionType" in metric @@ -1288,7 +1288,6 @@ def get_metric_name(metric: Metric) -> str: :raises ValueError: if metric object is invalid """ if is_adhoc_metric(metric): - metric = cast(AdhocMetric, metric) label = metric.get("label") if label: return label @@ -1306,7 +1305,7 @@ def get_metric_name(metric: Metric) -> str: if column_name: return column_name raise ValueError(__("Invalid metric object")) - return cast(str, metric) + return metric # type: ignore def get_metric_names(metrics: Sequence[Metric]) -> List[str]: diff --git a/tests/unit_tests/core_tests.py b/tests/unit_tests/core_tests.py index bb3e50f51..51d1d0993 100644 --- a/tests/unit_tests/core_tests.py +++ b/tests/unit_tests/core_tests.py @@ -23,6 +23,7 @@ from superset.utils.core import ( GenericDataType, get_metric_name, get_metric_names, + is_adhoc_metric, ) STR_METRIC = "my_metric" @@ -91,3 +92,9 @@ def test_get_metric_names(): assert get_metric_names( [STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC] ) == ["my_metric", "my SUM", "my_sql"] + + +def test_is_adhoc_metric(): + assert is_adhoc_metric(STR_METRIC) is False + assert is_adhoc_metric(SIMPLE_SUM_ADHOC_METRIC) is True + assert is_adhoc_metric(SQL_ADHOC_METRIC) is True