From 9f5f8e5d92f4d346acbedfbd2a1105f7194fea78 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Mon, 17 Feb 2020 23:08:11 -0800 Subject: [PATCH] [mypy] Enforcing typing for db_engine_specs (#9138) --- setup.cfg | 5 ++ superset/db_engine_specs/base.py | 74 +++++++++++++++----------- superset/db_engine_specs/bigquery.py | 17 +++--- superset/db_engine_specs/drill.py | 5 +- superset/db_engine_specs/druid.py | 9 +++- superset/db_engine_specs/exasol.py | 4 +- superset/db_engine_specs/hive.py | 58 ++++++++++++-------- superset/db_engine_specs/mssql.py | 6 +-- superset/db_engine_specs/mysql.py | 8 +-- superset/db_engine_specs/postgres.py | 4 +- superset/db_engine_specs/presto.py | 62 +++++++++++++-------- superset/db_engine_specs/snowflake.py | 7 ++- superset/db_engine_specs/sqlite.py | 2 +- superset/models/core.py | 2 +- superset/sql_parse.py | 2 +- superset/utils/core.py | 10 ++-- superset/utils/feature_flag_manager.py | 2 +- 17 files changed, 173 insertions(+), 104 deletions(-) diff --git a/setup.cfg b/setup.cfg index 038d7b846..46dde49f6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,3 +52,8 @@ order_by_type = false [mypy] ignore_missing_imports = true no_implicit_optional = true + +[mypy-superset.db_engine_specs.*] +check_untyped_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index b38f96db7..ff40b2a19 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -30,16 +30,23 @@ from sqlalchemy import column, DateTime, select from sqlalchemy.engine.base import Engine from sqlalchemy.engine.interfaces import Compiled, Dialect from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom from sqlalchemy.types import TypeEngine +from wtforms.form import Form from superset import app, sql_parse +from superset.models.sql_lab import Query from superset.utils import core as utils if TYPE_CHECKING: # prevent circular imports + from superset.connectors.sqla.models import ( # pylint: disable=unused-import + TableColumn, + ) from superset.models.core import Database # pylint: disable=unused-import @@ -77,7 +84,7 @@ builtin_time_grains: Dict[Optional[str], str] = { class TimestampExpression( ColumnClause ): # pylint: disable=abstract-method,too-many-ancestors,too-few-public-methods - def __init__(self, expr: str, col: ColumnClause, **kwargs): + def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class that can be can be used to render native column elements respeting engine-specific quoting rules as part of a string-based expression. @@ -89,7 +96,7 @@ class TimestampExpression( self.col = col @property - def _constructor(self): + def _constructor(self) -> ColumnClause: # Needed to ensure that the column label is rendered correctly when # proxied to the outer query. # See https://github.com/sqlalchemy/sqlalchemy/issues/4730 @@ -98,9 +105,9 @@ class TimestampExpression( @compiles(TimestampExpression) def compile_timegrain_expression( - element: TimestampExpression, compiler: Compiled, **kw + element: TimestampExpression, compiler: Compiled, **kwargs: Any ) -> str: - return element.name.replace("{col}", compiler.process(element.col, **kw)) + return element.name.replace("{col}", compiler.process(element.col, **kwargs)) class LimitMethod: # pylint: disable=too-few-public-methods @@ -132,7 +139,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return False @classmethod - def get_engine(cls, database, schema=None, source=None): + def get_engine( + cls, + database: "Database", + schema: Optional[str] = None, + source: Optional[str] = None, + ) -> Engine: user_name = utils.get_username() return database.get_sqla_engine( schema=schema, nullpool=True, user_name=user_name, source=source @@ -217,7 +229,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return select_exprs @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: """ :param cursor: Cursor instance @@ -246,7 +258,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return columns, data, [] @classmethod - def alter_new_orm_column(cls, orm_col): + def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: """Allow altering default column attributes when first detected/added For instance special column like `__time` for Druid can be @@ -290,7 +302,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def extra_table_metadata( - cls, database, table_name: str, schema_name: str + cls, database: "Database", table_name: str, schema_name: str ) -> Dict[str, Any]: """ Returns engine-specific table metadata @@ -304,7 +316,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return {} @classmethod - def apply_limit_to_sql(cls, sql: str, limit: int, database) -> str: + def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str: """ Alters the SQL statement to apply a LIMIT clause @@ -351,7 +363,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return parsed_query.get_query_with_new_limit(limit) @staticmethod - def csv_to_df(**kwargs) -> pd.DataFrame: + def csv_to_df(**kwargs: Any) -> pd.DataFrame: """ Read csv into Pandas DataFrame :param kwargs: params to be passed to DataFrame.read_csv :return: Pandas DataFrame containing data from csv @@ -363,7 +375,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return df @classmethod - def df_to_sql(cls, df: pd.DataFrame, **kwargs): # pylint: disable=invalid-name + def df_to_sql( # pylint: disable=invalid-name + cls, df: pd.DataFrame, **kwargs: Any + ) -> None: """ Upload data from a Pandas DataFrame to a database. For regular engines this calls the DataFrame.to_sql() method. Can be overridden for engines that don't work well with to_sql(), e.g. @@ -374,7 +388,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods df.to_sql(**kwargs) @classmethod - def create_table_from_csv(cls, form, database) -> None: + def create_table_from_csv(cls, form: Form, database: "Database") -> None: """ Create table from contents of a csv. Note: this method does not create metadata for the table. @@ -437,7 +451,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_all_datasource_names( - cls, database, datasource_type: str + cls, database: "Database", datasource_type: str ) -> List[utils.DatasourceName]: """Returns a list of all tables or views in database. @@ -472,7 +486,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return all_datasources @classmethod - def handle_cursor(cls, cursor, query, session): + def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: """Handle a live cursor between the execute and fetchall calls The flow works without this method doing anything, but it allows @@ -486,13 +500,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return f"{cls.engine} error: {cls._extract_error_message(e)}" @classmethod - def _extract_error_message(cls, e: Exception) -> str: + def _extract_error_message(cls, e: Exception) -> Optional[str]: """Extract error message for queries""" return utils.error_msg_from_exception(e) @classmethod - def adjust_database_uri(cls, uri, selected_schema: Optional[str]): - """Based on a URI and selected schema, return a new URI + def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None: + """ + Mutate the database component of the SQLAlchemy URI. The URI here represents the URI as entered when saving the database, ``selected_schema`` is the schema currently active presumably in @@ -509,11 +524,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Some database drivers like presto accept '{catalog}/{schema}' in the database component of the URL, that can be handled here. """ - # TODO: All overrides mutate input uri; should be renamed or refactored - return uri + pass @classmethod - def patch(cls): + def patch(cls) -> None: """ TODO: Improve docstring and refactor implementation in Hive """ @@ -580,7 +594,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, table_name: str, schema: Optional[str], - database, + database: "Database", query: Select, columns: Optional[List] = None, ) -> Optional[Select]: @@ -599,13 +613,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return None @classmethod - def _get_fields(cls, cols): - return [column(c.get("name")) for c in cols] + def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: + return [column(c["name"]) for c in cols] @classmethod def select_star( # pylint: disable=too-many-arguments,too-many-locals cls, - database, + database: "Database", table_name: str, engine: Engine, schema: Optional[str] = None, @@ -629,7 +643,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param cols: Columns to include in query :return: SQL query """ - fields = "*" + fields: Union[str, List[Any]] = "*" cols = cols or [] if (show_cols or latest_partition) and not cols: cols = database.get_columns(table_name, schema) @@ -659,7 +673,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def estimate_statement_cost( - cls, statement: str, database, cursor, user_name: str + cls, statement: str, database: "Database", cursor: Any, user_name: str ) -> Dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. @@ -686,7 +700,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def estimate_query_cost( - cls, database, schema: str, sql: str, source: Optional[str] = None + cls, database: "Database", schema: str, sql: str, source: Optional[str] = None ) -> List[Dict[str, str]]: """ Estimate the cost of a multiple statement SQL query. @@ -718,8 +732,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def modify_url_for_impersonation( - cls, url, impersonate_user: bool, username: Optional[str] - ): + cls, url: URL, impersonate_user: bool, username: Optional[str] + ) -> None: """ Modify the SQL Alchemy URL object with the user to impersonate if applicable. :param url: SQLAlchemy URL object @@ -745,7 +759,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return {} @classmethod - def execute(cls, cursor, query: str, **kwargs): + def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None: """ Execute a SQL query diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 93b7802e2..9bf4b28f3 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -17,13 +17,17 @@ import hashlib import re from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import pandas as pd from sqlalchemy import literal_column +from sqlalchemy.sql.expression import ColumnClause from superset.db_engine_specs.base import BaseEngineSpec +if TYPE_CHECKING: + from superset.models.core import Database # pylint: disable=unused-import + class BigQueryEngineSpec(BaseEngineSpec): """Engine spec for Google's BigQuery @@ -69,7 +73,7 @@ class BigQueryEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit) if data and type(data[0]).__name__ == "Row": data = [r.values() for r in data] # type: ignore @@ -112,7 +116,7 @@ class BigQueryEngineSpec(BaseEngineSpec): @classmethod def extra_table_metadata( - cls, database, table_name: str, schema_name: str + cls, database: "Database", table_name: str, schema_name: str ) -> Dict[str, Any]: indexes = database.get_indexes(table_name, schema_name) if not indexes: @@ -133,7 +137,7 @@ class BigQueryEngineSpec(BaseEngineSpec): } @classmethod - def _get_fields(cls, cols): + def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: """ BigQuery dialect requires us to not use backtick in the fieldname which are nested. @@ -143,8 +147,7 @@ class BigQueryEngineSpec(BaseEngineSpec): column names in the result. """ return [ - literal_column(c.get("name")).label(c.get("name").replace(".", "__")) - for c in cols + literal_column(c["name"]).label(c["name"].replace(".", "__")) for c in cols ] @classmethod @@ -156,7 +159,7 @@ class BigQueryEngineSpec(BaseEngineSpec): return "TIMESTAMP_MILLIS({col})" @classmethod - def df_to_sql(cls, df: pd.DataFrame, **kwargs): + def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None: """ Upload data from a Pandas DataFrame to BigQuery. Calls `DataFrame.to_gbq()` which requires `pandas_gbq` to be installed. diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index f8277eeb7..73b5912e2 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -18,6 +18,8 @@ from datetime import datetime from typing import Optional from urllib import parse +from sqlalchemy.engine.url import URL + from superset.db_engine_specs.base import BaseEngineSpec @@ -59,7 +61,6 @@ class DrillEngineSpec(BaseEngineSpec): return None @classmethod - def adjust_database_uri(cls, uri, selected_schema): + def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None: if selected_schema: uri.database = parse.quote(selected_schema, safe="") - return uri diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 3610a5891..35b3b4748 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -14,8 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import TYPE_CHECKING + from superset.db_engine_specs.base import BaseEngineSpec +if TYPE_CHECKING: + from superset.connectors.sqla.models import ( # pylint: disable=unused-import + TableColumn, + ) + class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method """Engine spec for Druid.io""" @@ -37,6 +44,6 @@ class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method } @classmethod - def alter_new_orm_column(cls, orm_col): + def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: if orm_col.column_name == "__time": orm_col.is_dttm = True diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py index ea4a40032..8c14581de 100644 --- a/superset/db_engine_specs/exasol.py +++ b/superset/db_engine_specs/exasol.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Tuple +from typing import Any, List, Tuple from superset.db_engine_specs.base import BaseEngineSpec @@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method } @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 8785149e8..82d20d7ce 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -22,15 +22,19 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from urllib import parse +import pandas as pd from sqlalchemy import Column from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.url import make_url +from sqlalchemy.engine.url import make_url, URL +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select +from wtforms.form import Form from superset import app, cache, conf from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec +from superset.models.sql_lab import Query from superset.utils import core as utils if TYPE_CHECKING: @@ -67,7 +71,7 @@ class HiveEngineSpec(PrestoEngineSpec): ) @classmethod - def patch(cls): + def patch(cls) -> None: from pyhive import hive # pylint: disable=no-name-in-module from superset.db_engines import hive as patched_hive from TCLIService import ( @@ -83,12 +87,12 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_all_datasource_names( - cls, database, datasource_type: str + cls, database: "Database", datasource_type: str ) -> List[utils.DatasourceName]: return BaseEngineSpec.get_all_datasource_names(database, datasource_type) @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: import pyhive from TCLIService import ttypes @@ -102,11 +106,11 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def create_table_from_csv( # pylint: disable=too-many-locals - cls, form, database + cls, form: Form, database: "Database" ) -> None: """Uploads a csv file and creates a superset datasource in Hive.""" - def convert_to_hive_type(col_type): + def convert_to_hive_type(col_type: str) -> str: """maps tableschema's types to hive types""" tableschema_to_hive_types = { "boolean": "BOOLEAN", @@ -192,13 +196,14 @@ class HiveEngineSpec(PrestoEngineSpec): return None @classmethod - def adjust_database_uri(cls, uri, selected_schema=None): + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: if selected_schema: uri.database = parse.quote(selected_schema, safe="") - return uri @classmethod - def _extract_error_message(cls, e): + def _extract_error_message(cls, e: Exception) -> str: msg = str(e) match = re.search(r'errorMessage="(.*?)(? int: total_jobs = 1 # assuming there's at least 1 job current_job = 1 - stages = {} + stages: Dict[int, float] = {} for line in log_lines: match = cls.jobs_stats_r.match(line) if match: @@ -237,15 +242,17 @@ class HiveEngineSpec(PrestoEngineSpec): return int(progress) @classmethod - def get_tracking_url(cls, log_lines): + def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]: lkp = "Tracking URL = " for line in log_lines: if lkp in line: return line.split(lkp)[1] - return None + return None @classmethod - def handle_cursor(cls, cursor, query, session): # pylint: disable=too-many-locals + def handle_cursor( # pylint: disable=too-many-locals + cls, cursor: Any, query: Query, session: Session + ) -> None: """Updates progress information""" from pyhive import hive # pylint: disable=no-name-in-module @@ -310,7 +317,7 @@ class HiveEngineSpec(PrestoEngineSpec): cls, table_name: str, schema: Optional[str], - database, + database: "Database", query: Select, columns: Optional[List] = None, ) -> Optional[Select]: @@ -335,12 +342,14 @@ class HiveEngineSpec(PrestoEngineSpec): return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access @classmethod - def latest_sub_partition(cls, table_name, schema, database, **kwargs): + def latest_sub_partition( + cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any + ) -> str: # TODO(bogdan): implement` pass @classmethod - def _latest_partition_from_df(cls, df) -> Optional[List[str]]: + def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: """Hive partitions look like ds={partition name}""" if not df.empty: return [df.ix[:, 0].max().split("=")[1]] @@ -348,14 +357,19 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def _partition_query( # pylint: disable=too-many-arguments - cls, table_name, database, limit=0, order_by=None, filters=None - ): + cls, + table_name: str, + database: "Database", + limit: int = 0, + order_by: Optional[List[Tuple[str, bool]]] = None, + filters: Optional[Dict[Any, Any]] = None, + ) -> str: return f"SHOW PARTITIONS {table_name}" @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database, + database: "Database", table_name: str, engine: Engine, schema: Optional[str] = None, @@ -381,8 +395,8 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def modify_url_for_impersonation( - cls, url, impersonate_user: bool, username: Optional[str] - ): + cls, url: URL, impersonate_user: bool, username: Optional[str] + ) -> None: """ Modify the SQL Alchemy URL object with the user to impersonate if applicable. :param url: SQLAlchemy URL object diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 94555b0a9..be91d1438 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import datetime -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.types import String, TypeEngine, UnicodeText @@ -46,7 +46,7 @@ class MssqlEngineSpec(BaseEngineSpec): } @classmethod - def epoch_to_dttm(cls): + def epoch_to_dttm(cls) -> str: return "dateadd(S, {col}, '1970-01-01')" @classmethod @@ -61,7 +61,7 @@ class MssqlEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 81dc40dca..023dd76aa 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional from urllib import parse from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.url import URL from sqlalchemy.types import TypeEngine from superset.db_engine_specs.base import BaseEngineSpec @@ -59,10 +60,11 @@ class MySQLEngineSpec(BaseEngineSpec): return None @classmethod - def adjust_database_uri(cls, uri, selected_schema=None): + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: if selected_schema: uri.database = parse.quote(selected_schema, safe="") - return uri @classmethod def get_datatype(cls, type_code: Any) -> Optional[str]: @@ -86,7 +88,7 @@ class MySQLEngineSpec(BaseEngineSpec): return "from_unixtime({col})" @classmethod - def _extract_error_message(cls, e): + def _extract_error_message(cls, e: Exception) -> str: """Extract error message for queries""" message = str(e) try: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 68f442b61..388ae6aa7 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import List, Optional, Tuple, TYPE_CHECKING +from typing import Any, List, Optional, Tuple, TYPE_CHECKING from pytz import _FixedOffset # type: ignore from sqlalchemy.dialects.postgresql.base import PGInspector @@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): } @classmethod - def fetch_data(cls, cursor, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: cursor.tzinfo_factory = FixedOffsetTimezone if not cursor.description: return [] diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 67d1e3879..18bd95858 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -25,16 +25,20 @@ from distutils.version import StrictVersion from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING from urllib import parse +import pandas as pd import simplejson as json from sqlalchemy import Column, literal_column from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy +from sqlalchemy.engine.url import URL +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select from superset import app, cache, is_feature_enabled, security_manager from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import SupersetTemplateException +from superset.models.sql_lab import Query from superset.models.sql_types.presto_sql_types import type_map as presto_type_map from superset.sql_parse import ParsedQuery from superset.utils import core as utils @@ -392,7 +396,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database, + database: "Database", table_name: str, engine: Engine, schema: Optional[str] = None, @@ -428,7 +432,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def estimate_statement_cost( # pylint: disable=too-many-locals - cls, statement: str, database, cursor, user_name: str + cls, statement: str, database: "Database", cursor: Any, user_name: str ) -> Dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. @@ -510,7 +514,9 @@ class PrestoEngineSpec(BaseEngineSpec): return cost @classmethod - def adjust_database_uri(cls, uri, selected_schema=None): + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: database = uri.database if selected_schema and database: selected_schema = parse.quote(selected_schema, safe="") @@ -519,7 +525,6 @@ class PrestoEngineSpec(BaseEngineSpec): else: database += "/" + selected_schema uri.database = database - return uri @classmethod def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: @@ -536,7 +541,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def get_all_datasource_names( - cls, database, datasource_type: str + cls, database: "Database", datasource_type: str ) -> List[utils.DatasourceName]: datasource_df = database.get_df( "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S " @@ -656,7 +661,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def extra_table_metadata( - cls, database, table_name: str, schema_name: str + cls, database: "Database", table_name: str, schema_name: str ) -> Dict[str, Any]: metadata = {} @@ -670,10 +675,12 @@ class PrestoEngineSpec(BaseEngineSpec): col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) - latest_parts = latest_parts or tuple([None] * len(col_names)) + + if not latest_parts: + latest_parts = tuple([None] * len(col_names)) # type: ignore metadata["partitions"] = { "cols": cols, - "latest": dict(zip(col_names, latest_parts)), + "latest": dict(zip(col_names, latest_parts)), # type: ignore "partitionQuery": pql, } @@ -685,7 +692,9 @@ class PrestoEngineSpec(BaseEngineSpec): return metadata @classmethod - def get_create_view(cls, database, schema: str, table: str) -> Optional[str]: + def get_create_view( + cls, database: "Database", schema: str, table: str + ) -> Optional[str]: """ Return a CREATE VIEW statement, or `None` if not a view. @@ -712,7 +721,7 @@ class PrestoEngineSpec(BaseEngineSpec): return rows[0][0] @classmethod - def handle_cursor(cls, cursor, query, session): + def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: """Updates progress information""" query_id = query.id logger.info(f"Query {query_id}: Polling the cursor for progress") @@ -753,13 +762,13 @@ class PrestoEngineSpec(BaseEngineSpec): polled = cursor.poll() @classmethod - def _extract_error_message(cls, e): + def _extract_error_message(cls, e: Exception) -> Optional[str]: if ( hasattr(e, "orig") - and type(e.orig).__name__ == "DatabaseError" - and isinstance(e.orig[0], dict) + and type(e.orig).__name__ == "DatabaseError" # type: ignore + and isinstance(e.orig[0], dict) # type: ignore ): - error_dict = e.orig[0] + error_dict = e.orig[0] # type: ignore return "{} at {}: {}".format( error_dict.get("errorName"), error_dict.get("errorLocation"), @@ -772,8 +781,13 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def _partition_query( # pylint: disable=too-many-arguments,too-many-locals - cls, table_name, database, limit=0, order_by=None, filters=None - ): + cls, + table_name: str, + database: "Database", + limit: int = 0, + order_by: Optional[List[Tuple[str, bool]]] = None, + filters: Optional[Dict[Any, Any]] = None, + ) -> str: """Returns a partition query :param table_name: the name of the table to get partitions from @@ -827,7 +841,7 @@ class PrestoEngineSpec(BaseEngineSpec): cls, table_name: str, schema: Optional[str], - database, + database: "Database", query: Select, columns: Optional[List] = None, ) -> Optional[Select]: @@ -850,7 +864,7 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def _latest_partition_from_df( # pylint: disable=invalid-name - cls, df + cls, df: pd.DataFrame ) -> Optional[List[str]]: if not df.empty: return df.to_records(index=False)[0].item() @@ -858,8 +872,12 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def latest_partition( - cls, table_name: str, schema: Optional[str], database, show_first: bool = False - ): + cls, + table_name: str, + schema: Optional[str], + database: "Database", + show_first: bool = False, + ) -> Tuple[List[str], Optional[List[str]]]: """Returns col name and the latest (max) partition value for a table :param table_name: the name of the table @@ -897,7 +915,9 @@ class PrestoEngineSpec(BaseEngineSpec): return column_names, cls._latest_partition_from_df(df) @classmethod - def latest_sub_partition(cls, table_name, schema, database, **kwargs): + def latest_sub_partition( + cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any + ) -> Any: """Returns the latest (max) partition value for a table A filtering criteria should be passed for all fields that are diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index dc2e248fa..ada9fae8b 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -18,6 +18,8 @@ from datetime import datetime from typing import Optional from urllib import parse +from sqlalchemy.engine.url import URL + from superset.db_engine_specs.postgres import PostgresBaseEngineSpec @@ -47,14 +49,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): } @classmethod - def adjust_database_uri(cls, uri, selected_schema=None): + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: database = uri.database if "/" in uri.database: database = uri.database.split("/")[0] if selected_schema: selected_schema = parse.quote(selected_schema, safe="") uri.database = database + "/" + selected_schema - return uri @classmethod def epoch_to_dttm(cls) -> str: diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index 8444ede45..12d422add 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -49,7 +49,7 @@ class SqliteEngineSpec(BaseEngineSpec): @classmethod def get_all_datasource_names( - cls, database, datasource_type: str + cls, database: "Database", datasource_type: str ) -> List[utils.DatasourceName]: schemas = database.get_all_schema_names( cache=database.schema_cache_enabled, diff --git a/superset/models/core.py b/superset/models/core.py index af0d21ee5..36aa1c287 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -282,7 +282,7 @@ class Database( ) -> Engine: extra = self.get_extra() sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) - sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) + self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) effective_username = self.get_effective_user(sqlalchemy_url, user_name) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 5e1612083..acd3c3226 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -51,7 +51,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: class ParsedQuery: - def __init__(self, sql_statement): + def __init__(self, sql_statement: str): self.sql: str = sql_statement self._table_names: Set[str] = set() self._alias_names: Set[str] = set() diff --git a/superset/utils/core.py b/superset/utils/core.py index 41daebc08..436b06cce 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -414,7 +414,7 @@ def json_dumps_w_dates(payload): return json.dumps(payload, default=json_int_dttm_ser) -def error_msg_from_exception(e): +def error_msg_from_exception(e: Exception) -> str: """Translate exception into error message Database have different ways to handle exception. This function attempts @@ -430,10 +430,10 @@ def error_msg_from_exception(e): """ msg = "" if hasattr(e, "message"): - if isinstance(e.message, dict): - msg = e.message.get("message") - elif e.message: - msg = e.message + if isinstance(e.message, dict): # type: ignore + msg = e.message.get("message") # type: ignore + elif e.message: # type: ignore + msg = e.message # type: ignore return msg or str(e) diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py index 7802f65c3..654607b02 100644 --- a/superset/utils/feature_flag_manager.py +++ b/superset/utils/feature_flag_manager.py @@ -34,6 +34,6 @@ class FeatureFlagManager: return self._feature_flags - def is_feature_enabled(self, feature): + def is_feature_enabled(self, feature) -> bool: """Utility function for checking whether a feature is turned on""" return self.get_feature_flags().get(feature)