diff --git a/setup.cfg b/setup.cfg index 8ca150e80..038d7b846 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,3 +51,4 @@ order_by_type = false [mypy] ignore_missing_imports = true +no_implicit_optional = true diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 97e86fb9f..b38f96db7 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -128,7 +128,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods try_remove_schema_from_table_name = True # pylint: disable=invalid-name @classmethod - def get_allow_cost_estimate(cls, version: str = None) -> bool: + def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool: return False @classmethod @@ -686,7 +686,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def estimate_query_cost( - cls, database, schema: str, sql: str, source: str = None + cls, database, schema: str, sql: str, source: Optional[str] = None ) -> List[Dict[str, str]]: """ Estimate the cost of a multiple statement SQL query. diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index ac06bc411..8785149e8 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -358,7 +358,7 @@ class HiveEngineSpec(PrestoEngineSpec): database, table_name: str, engine: Engine, - schema: str = None, + schema: Optional[str] = None, limit: int = 100, show_cols: bool = False, indent: bool = True, diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 3e4bda5d1..67d1e3879 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -113,7 +113,7 @@ class PrestoEngineSpec(BaseEngineSpec): } @classmethod - def get_allow_cost_estimate(cls, version: str = None) -> bool: + def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool: return version is not None and StrictVersion(version) >= StrictVersion("0.319") @classmethod @@ -395,7 +395,7 @@ class PrestoEngineSpec(BaseEngineSpec): database, table_name: str, engine: Engine, - schema: str = None, + schema: Optional[str] = None, limit: int = 100, show_cols: bool = False, indent: bool = True, diff --git a/superset/models/core.py b/superset/models/core.py index 3ad341bf0..af0d21ee5 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -429,7 +429,7 @@ class Database( attribute_in_key="id", ) def get_all_table_names_in_database( - self, cache: bool = False, cache_timeout: bool = None, force=False + self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: @@ -441,7 +441,10 @@ class Database( attribute_in_key="id", # type: ignore ) def get_all_view_names_in_database( - self, cache: bool = False, cache_timeout: bool = None, force: bool = False + self, + cache: bool = False, + cache_timeout: Optional[bool] = None, + force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: @@ -514,7 +517,10 @@ class Database( key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id" ) def get_all_schema_names( - self, cache: bool = False, cache_timeout: int = None, force: bool = False + self, + cache: bool = False, + cache_timeout: Optional[int] = None, + force: bool = False, ) -> List[str]: """Parameters need to be passed as keyword arguments. diff --git a/superset/security/manager.py b/superset/security/manager.py index 59e7b7538..f0afae1f6 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -311,12 +311,12 @@ class SupersetSecurityManager(SecurityManager): return conf.get("PERMISSION_INSTRUCTIONS_LINK") def can_access_datasource( - self, database: "Database", table_name: str, schema: str = None + self, database: "Database", table_name: str, schema: Optional[str] = None ) -> bool: return self._datasource_access_by_name(database, table_name, schema=schema) def _datasource_access_by_name( - self, database: "Database", table_name: str, schema: str = None + self, database: "Database", table_name: str, schema: Optional[str] = None ) -> bool: """ Return True if the user can access the SQL table, False otherwise. diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index 704ede4e3..e4795c53c 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -55,7 +55,7 @@ class BaseSupersetSchema(Schema): return super().load(data, many=many, partial=partial, **kwargs) @post_load - def make_object(self, data: Dict, discard: List[str] = None) -> Model: + def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model: """ Creates a Model object from POST or PUT requests. PUT will use self.instance previously fetched from the endpoint handler @@ -81,7 +81,7 @@ class BaseOwnedSchema(BaseSupersetSchema): owners_field_name = "owners" @post_load - def make_object(self, data: Dict, discard: List[str] = None) -> Model: + def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model: discard = discard or [] discard.append(self.owners_field_name) instance = super().make_object(data, discard) diff --git a/superset/views/chart/api.py b/superset/views/chart/api.py index f3bbcbd2a..6389b9352 100644 --- a/superset/views/chart/api.py +++ b/superset/views/chart/api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List +from typing import Dict, List, Optional from flask import current_app from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -95,7 +95,7 @@ class ChartPostSchema(BaseOwnedSchema): validate_update_datasource(data) @post_load - def make_object(self, data: Dict, discard: List[str] = None) -> Slice: + def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice: instance = super().make_object(data, discard=["dashboards"]) populate_dashboards(instance, data.get("dashboards", [])) return instance @@ -119,7 +119,7 @@ class ChartPutSchema(BaseOwnedSchema): validate_update_datasource(data) @post_load - def make_object(self, data: Dict, discard: List[str] = None) -> Slice: + def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice: self.instance = super().make_object(data, ["dashboards"]) if "dashboards" in data: populate_dashboards(self.instance, data["dashboards"]) diff --git a/superset/views/core.py b/superset/views/core.py index c127ce722..69347f1c5 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -166,7 +166,7 @@ def is_owner(obj, user): def check_datasource_perms( - self, datasource_type: str = None, datasource_id: int = None + self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None ) -> None: """ Check if user can access a cached response from explore_json. @@ -1973,7 +1973,9 @@ class Superset(BaseSupersetView): @expose("/estimate_query_cost//", methods=["POST"]) @expose("/estimate_query_cost///", methods=["POST"]) @event_logger.log_this - def estimate_query_cost(self, database_id: int, schema: str = None) -> Response: + def estimate_query_cost( + self, database_id: int, schema: Optional[str] = None + ) -> Response: mydb = db.session.query(models.Database).get(database_id) sql = json.loads(request.form.get("sql", '""')) diff --git a/superset/views/dashboard/api.py b/superset/views/dashboard/api.py index 80cac5569..d11def5ba 100644 --- a/superset/views/dashboard/api.py +++ b/superset/views/dashboard/api.py @@ -17,7 +17,7 @@ import json import logging import re -from typing import Dict, List +from typing import Dict, List, Optional from flask import current_app, g, make_response from flask_appbuilder.api import expose, protect, rison, safe @@ -119,7 +119,7 @@ class DashboardPutSchema(BaseDashboardSchema): published = fields.Boolean() @post_load - def make_object(self, data: Dict, discard: List[str] = None) -> Dashboard: + def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Dashboard: self.instance = super().make_object(data, []) for slc in self.instance.slices: slc.owners = list(set(self.instance.owners) | set(slc.owners)) diff --git a/superset/views/database/api.py b/superset/views/database/api.py index 9a846ceaf..2eb5ad88c 100644 --- a/superset/views/database/api.py +++ b/superset/views/database/api.py @@ -283,7 +283,9 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): @check_datasource_access @safe @event_logger.log_this - def select_star(self, database: Database, table_name: str, schema_name: str = None): + def select_star( + self, database: Database, table_name: str, schema_name: Optional[str] = None + ): """ Table schema info --- get: diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py index ea72a3017..789fbce03 100644 --- a/superset/views/database/decorators.py +++ b/superset/views/database/decorators.py @@ -16,6 +16,7 @@ # under the License. import functools import logging +from typing import Optional from flask import g from flask_babel import lazy_gettext as _ @@ -32,7 +33,7 @@ def check_datasource_access(f): """ def wraps( - self, pk: int, table_name: str, schema_name: str = None + self, pk: int, table_name: str, schema_name: Optional[str] = None ): # pylint: disable=invalid-name schema_name_parsed = parse_js_uri_path_item(schema_name, eval_undefined=True) table_name_parsed = parse_js_uri_path_item(table_name) diff --git a/superset/viz.py b/superset/viz.py index ec26eda4e..0b0a6ac0a 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -1882,7 +1882,7 @@ class IFrameViz(BaseViz): def query_obj(self): return None - def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame: + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: return pd.DataFrame() def get_data(self, df: pd.DataFrame) -> VizData: diff --git a/tests/chart_api_tests.py b/tests/chart_api_tests.py index 8a67b030c..307d4add9 100644 --- a/tests/chart_api_tests.py +++ b/tests/chart_api_tests.py @@ -41,9 +41,9 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): owners: List[int], datasource_id: int, datasource_type: str = "table", - description: str = None, - viz_type: str = None, - params: str = None, + description: Optional[str] = None, + viz_type: Optional[str] = None, + params: Optional[str] = None, cache_timeout: Optional[int] = None, ) -> Slice: obj_owners = list() diff --git a/tests/dashboard_api_tests.py b/tests/dashboard_api_tests.py index af16e5fbd..56fe4c0e5 100644 --- a/tests/dashboard_api_tests.py +++ b/tests/dashboard_api_tests.py @@ -17,7 +17,7 @@ # isort:skip_file """Unit tests for Superset""" import json -from typing import List +from typing import List, Optional import prison @@ -42,7 +42,7 @@ class DashboardApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): dashboard_title: str, slug: str, owners: List[int], - slices: List[Slice] = None, + slices: Optional[List[Slice]] = None, position_json: str = "", css: str = "", json_metadata: str = "",