[mypy] Enforcing typing for some modules (#9416)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-04-04 12:45:14 -07:00 committed by GitHub
parent 1cdfb829d7
commit 5e55e09e3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 39 additions and 27 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.charts.*,superset.db_engine_specs.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import Optional
from flask_appbuilder.models.sqla import Model
class BaseCommand(ABC):
@ -23,7 +26,7 @@ class BaseCommand(ABC):
"""
@abstractmethod
def run(self):
def run(self) -> Optional[Model]:
"""
Run executes the command. Can raise command exceptions
:raises: CommandException

View File

@ -25,10 +25,10 @@ from superset.exceptions import SupersetException
class CommandException(SupersetException):
""" Common base class for Command exceptions. """
def __repr__(self):
def __repr__(self) -> str:
if self._exception:
return self._exception
return self
return repr(self._exception)
return repr(self)
class CommandInvalidError(CommandException):
@ -36,14 +36,14 @@ class CommandInvalidError(CommandException):
status = 422
def __init__(self, message="") -> None:
def __init__(self, message: str = "") -> None:
self._invalid_exceptions: List[ValidationError] = []
super().__init__(self.message)
def add(self, exception: ValidationError):
def add(self, exception: ValidationError) -> None:
self._invalid_exceptions.append(exception)
def add_list(self, exceptions: List[ValidationError]):
def add_list(self, exceptions: List[ValidationError]) -> None:
self._invalid_exceptions.extend(exceptions)
def normalized_messages(self) -> Dict[Any, Any]:
@ -76,12 +76,12 @@ class ForbiddenError(CommandException):
class OwnersNotFoundValidationError(ValidationError):
status = 422
def __init__(self):
def __init__(self) -> None:
super().__init__(_("Owners are invalid"), field_names=["owners"])
class DatasourceNotFoundValidationError(ValidationError):
status = 404
def __init__(self):
def __init__(self) -> None:
super().__init__(_("Datasource does not exist"), field_names=["datasource_id"])

View File

@ -157,7 +157,7 @@ class QueryContext:
return self.datasource.database.cache_timeout
return config["CACHE_DEFAULT_TIMEOUT"]
def cache_key(self, query_obj: QueryObject, **kwargs) -> Optional[str]:
def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
query_obj.cache_key(
@ -173,7 +173,7 @@ class QueryContext:
return cache_key
def get_df_payload( # pylint: disable=too-many-locals,too-many-statements
self, query_obj: QueryObject, **kwargs
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.cache_key(query_obj, **kwargs)

View File

@ -122,7 +122,7 @@ class QueryObject:
}
return query_object_dict
def cache_key(self, **extra) -> str:
def cache_key(self, **extra: Any) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`

View File

@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy import Metadata
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import and_, func, functions, join, literal, select
from superset.models.tags import ObjectTypes, TagTypes
def add_types(engine, metadata):
def add_types(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object according to its type:
@ -163,7 +164,7 @@ def add_types(engine, metadata):
engine.execute(query)
def add_owners(engine, metadata):
def add_owners(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object according to its owner:
@ -319,7 +320,7 @@ def add_owners(engine, metadata):
engine.execute(query)
def add_favorites(engine, metadata):
def add_favorites(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object that was favorited:

View File

@ -112,7 +112,7 @@ class BaseDAO:
return model
@classmethod
def delete(cls, model: Model, commit=True):
def delete(cls, model: Model, commit: bool = True) -> Model:
"""
Generic delete a model
:raises: DAOCreateFailedError

View File

@ -14,12 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from pyhive.hive import Cursor # pylint: disable=unused-import
from TCLIService.ttypes import TFetchOrientation # pylint: disable=unused-import
# pylint: disable=protected-access
# TODO: contribute back to pyhive.
def fetch_logs(
self, max_rows=1024, orientation=None
): # pylint: disable=unused-argument
self: "Cursor",
max_rows: int = 1024, # pylint: disable=unused-argument
orientation: Optional["TFetchOrientation"] = None,
) -> str: # pylint: disable=unused-argument
"""Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after
the previous call.

View File

@ -24,26 +24,26 @@ logger = logging.getLogger(__name__)
class BaseStatsLogger:
"""Base class for logging realtime events"""
def __init__(self, prefix="superset"):
def __init__(self, prefix: str = "superset") -> None:
self.prefix = prefix
def key(self, key):
def key(self, key: str) -> str:
if self.prefix:
return self.prefix + key
return key
def incr(self, key):
def incr(self, key: str) -> None:
"""Increment a counter"""
raise NotImplementedError()
def decr(self, key):
def decr(self, key: str) -> None:
"""Decrement a counter"""
raise NotImplementedError()
def timing(self, key, value):
def timing(self, key, value: float) -> None:
raise NotImplementedError()
def gauge(self, key):
def gauge(self, key: str) -> None:
"""Setup a gauge"""
raise NotImplementedError()

View File

@ -1224,9 +1224,10 @@ class DatasourceName(NamedTuple):
schema: str
def get_stacktrace():
def get_stacktrace() -> Optional[str]:
if current_app.config["SHOW_STACKTRACE"]:
return traceback.format_exc()
return None
def split(