[mypy] Disallowing implicit optional (#9150)

This commit is contained in:
John Bodley 2020-02-16 22:34:16 -08:00 committed by GitHub
parent 114642d78c
commit a7e433a512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 39 additions and 27 deletions

View File

@ -51,3 +51,4 @@ order_by_type = false
[mypy] [mypy]
ignore_missing_imports = true ignore_missing_imports = true
no_implicit_optional = true

View File

@ -128,7 +128,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name try_remove_schema_from_table_name = True # pylint: disable=invalid-name
@classmethod @classmethod
def get_allow_cost_estimate(cls, version: str = None) -> bool: def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
return False return False
@classmethod @classmethod
@ -686,7 +686,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod @classmethod
def estimate_query_cost( def estimate_query_cost(
cls, database, schema: str, sql: str, source: str = None cls, database, schema: str, sql: str, source: Optional[str] = None
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
""" """
Estimate the cost of a multiple statement SQL query. Estimate the cost of a multiple statement SQL query.

View File

@ -358,7 +358,7 @@ class HiveEngineSpec(PrestoEngineSpec):
database, database,
table_name: str, table_name: str,
engine: Engine, engine: Engine,
schema: str = None, schema: Optional[str] = None,
limit: int = 100, limit: int = 100,
show_cols: bool = False, show_cols: bool = False,
indent: bool = True, indent: bool = True,

View File

@ -113,7 +113,7 @@ class PrestoEngineSpec(BaseEngineSpec):
} }
@classmethod @classmethod
def get_allow_cost_estimate(cls, version: str = None) -> bool: def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
return version is not None and StrictVersion(version) >= StrictVersion("0.319") return version is not None and StrictVersion(version) >= StrictVersion("0.319")
@classmethod @classmethod
@ -395,7 +395,7 @@ class PrestoEngineSpec(BaseEngineSpec):
database, database,
table_name: str, table_name: str,
engine: Engine, engine: Engine,
schema: str = None, schema: Optional[str] = None,
limit: int = 100, limit: int = 100,
show_cols: bool = False, show_cols: bool = False,
indent: bool = True, indent: bool = True,

View File

@ -429,7 +429,7 @@ class Database(
attribute_in_key="id", attribute_in_key="id",
) )
def get_all_table_names_in_database( def get_all_table_names_in_database(
self, cache: bool = False, cache_timeout: bool = None, force=False self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False
) -> List[utils.DatasourceName]: ) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments.""" """Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch: if not self.allow_multi_schema_metadata_fetch:
@ -441,7 +441,10 @@ class Database(
attribute_in_key="id", # type: ignore attribute_in_key="id", # type: ignore
) )
def get_all_view_names_in_database( def get_all_view_names_in_database(
self, cache: bool = False, cache_timeout: bool = None, force: bool = False self,
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]: ) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments.""" """Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch: if not self.allow_multi_schema_metadata_fetch:
@ -514,7 +517,10 @@ class Database(
key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id" key=lambda *args, **kwargs: "db:{}:schema_list", attribute_in_key="id"
) )
def get_all_schema_names( def get_all_schema_names(
self, cache: bool = False, cache_timeout: int = None, force: bool = False self,
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[str]: ) -> List[str]:
"""Parameters need to be passed as keyword arguments. """Parameters need to be passed as keyword arguments.

View File

@ -311,12 +311,12 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK") return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def can_access_datasource( def can_access_datasource(
self, database: "Database", table_name: str, schema: str = None self, database: "Database", table_name: str, schema: Optional[str] = None
) -> bool: ) -> bool:
return self._datasource_access_by_name(database, table_name, schema=schema) return self._datasource_access_by_name(database, table_name, schema=schema)
def _datasource_access_by_name( def _datasource_access_by_name(
self, database: "Database", table_name: str, schema: str = None self, database: "Database", table_name: str, schema: Optional[str] = None
) -> bool: ) -> bool:
""" """
Return True if the user can access the SQL table, False otherwise. Return True if the user can access the SQL table, False otherwise.

View File

@ -55,7 +55,7 @@ class BaseSupersetSchema(Schema):
return super().load(data, many=many, partial=partial, **kwargs) return super().load(data, many=many, partial=partial, **kwargs)
@post_load @post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Model: def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
""" """
Creates a Model object from POST or PUT requests. PUT will use self.instance Creates a Model object from POST or PUT requests. PUT will use self.instance
previously fetched from the endpoint handler previously fetched from the endpoint handler
@ -81,7 +81,7 @@ class BaseOwnedSchema(BaseSupersetSchema):
owners_field_name = "owners" owners_field_name = "owners"
@post_load @post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Model: def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
discard = discard or [] discard = discard or []
discard.append(self.owners_field_name) discard.append(self.owners_field_name)
instance = super().make_object(data, discard) instance = super().make_object(data, discard)

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Dict, List from typing import Dict, List, Optional
from flask import current_app from flask import current_app
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -95,7 +95,7 @@ class ChartPostSchema(BaseOwnedSchema):
validate_update_datasource(data) validate_update_datasource(data)
@post_load @post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Slice: def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice:
instance = super().make_object(data, discard=["dashboards"]) instance = super().make_object(data, discard=["dashboards"])
populate_dashboards(instance, data.get("dashboards", [])) populate_dashboards(instance, data.get("dashboards", []))
return instance return instance
@ -119,7 +119,7 @@ class ChartPutSchema(BaseOwnedSchema):
validate_update_datasource(data) validate_update_datasource(data)
@post_load @post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Slice: def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice:
self.instance = super().make_object(data, ["dashboards"]) self.instance = super().make_object(data, ["dashboards"])
if "dashboards" in data: if "dashboards" in data:
populate_dashboards(self.instance, data["dashboards"]) populate_dashboards(self.instance, data["dashboards"])

View File

@ -166,7 +166,7 @@ def is_owner(obj, user):
def check_datasource_perms( def check_datasource_perms(
self, datasource_type: str = None, datasource_id: int = None self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
) -> None: ) -> None:
""" """
Check if user can access a cached response from explore_json. Check if user can access a cached response from explore_json.
@ -1973,7 +1973,9 @@ class Superset(BaseSupersetView):
@expose("/estimate_query_cost/<database_id>/", methods=["POST"]) @expose("/estimate_query_cost/<database_id>/", methods=["POST"])
@expose("/estimate_query_cost/<database_id>/<schema>/", methods=["POST"]) @expose("/estimate_query_cost/<database_id>/<schema>/", methods=["POST"])
@event_logger.log_this @event_logger.log_this
def estimate_query_cost(self, database_id: int, schema: str = None) -> Response: def estimate_query_cost(
self, database_id: int, schema: Optional[str] = None
) -> Response:
mydb = db.session.query(models.Database).get(database_id) mydb = db.session.query(models.Database).get(database_id)
sql = json.loads(request.form.get("sql", '""')) sql = json.loads(request.form.get("sql", '""'))

View File

@ -17,7 +17,7 @@
import json import json
import logging import logging
import re import re
from typing import Dict, List from typing import Dict, List, Optional
from flask import current_app, g, make_response from flask import current_app, g, make_response
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
@ -119,7 +119,7 @@ class DashboardPutSchema(BaseDashboardSchema):
published = fields.Boolean() published = fields.Boolean()
@post_load @post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Dashboard: def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Dashboard:
self.instance = super().make_object(data, []) self.instance = super().make_object(data, [])
for slc in self.instance.slices: for slc in self.instance.slices:
slc.owners = list(set(self.instance.owners) | set(slc.owners)) slc.owners = list(set(self.instance.owners) | set(slc.owners))

View File

@ -283,7 +283,9 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
@check_datasource_access @check_datasource_access
@safe @safe
@event_logger.log_this @event_logger.log_this
def select_star(self, database: Database, table_name: str, schema_name: str = None): def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
):
""" Table schema info """ Table schema info
--- ---
get: get:

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import functools import functools
import logging import logging
from typing import Optional
from flask import g from flask import g
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
@ -32,7 +33,7 @@ def check_datasource_access(f):
""" """
def wraps( def wraps(
self, pk: int, table_name: str, schema_name: str = None self, pk: int, table_name: str, schema_name: Optional[str] = None
): # pylint: disable=invalid-name ): # pylint: disable=invalid-name
schema_name_parsed = parse_js_uri_path_item(schema_name, eval_undefined=True) schema_name_parsed = parse_js_uri_path_item(schema_name, eval_undefined=True)
table_name_parsed = parse_js_uri_path_item(table_name) table_name_parsed = parse_js_uri_path_item(table_name)

View File

@ -1882,7 +1882,7 @@ class IFrameViz(BaseViz):
def query_obj(self): def query_obj(self):
return None return None
def get_df(self, query_obj: Dict[str, Any] = None) -> pd.DataFrame: def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
return pd.DataFrame() return pd.DataFrame()
def get_data(self, df: pd.DataFrame) -> VizData: def get_data(self, df: pd.DataFrame) -> VizData:

View File

@ -41,9 +41,9 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
owners: List[int], owners: List[int],
datasource_id: int, datasource_id: int,
datasource_type: str = "table", datasource_type: str = "table",
description: str = None, description: Optional[str] = None,
viz_type: str = None, viz_type: Optional[str] = None,
params: str = None, params: Optional[str] = None,
cache_timeout: Optional[int] = None, cache_timeout: Optional[int] = None,
) -> Slice: ) -> Slice:
obj_owners = list() obj_owners = list()

View File

@ -17,7 +17,7 @@
# isort:skip_file # isort:skip_file
"""Unit tests for Superset""" """Unit tests for Superset"""
import json import json
from typing import List from typing import List, Optional
import prison import prison
@ -42,7 +42,7 @@ class DashboardApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
dashboard_title: str, dashboard_title: str,
slug: str, slug: str,
owners: List[int], owners: List[int],
slices: List[Slice] = None, slices: Optional[List[Slice]] = None,
position_json: str = "", position_json: str = "",
css: str = "", css: str = "",
json_metadata: str = "", json_metadata: str = "",