From 8b7262fa9040b6bc956dfa2c191953fe3b65bea6 Mon Sep 17 00:00:00 2001 From: Simon Thelin Date: Mon, 20 Jun 2022 00:28:59 +0100 Subject: [PATCH] fix(20428): Address-Presto/Trino-Poll-Issue-Refactor (#20434) * fix(20428)-Address-Presto/Trino-Poll-Issue-Refacto r Update linter * Update to only use BaseEngineSpec handle_cursor * Fix CI Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com> --- superset/db_engine_specs/presto.py | 4 ---- superset/db_engine_specs/trino.py | 36 +++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index d0621e288..cd6fa032b 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -949,11 +949,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho sql = f"SHOW CREATE VIEW {schema}.{table}" try: cls.execute(cursor, sql) - polled = cursor.poll() - while polled: - time.sleep(0.2) - polled = cursor.poll() except DatabaseError: # not a VIEW return None rows = cls.fetch_data(cursor, 1) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 46e3ed55d..acddb9710 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -15,15 +15,18 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING import simplejson as json from flask import current_app +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL +from sqlalchemy.orm import Session from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec +from superset.models.sql_lab import Query from superset.utils import core as utils if TYPE_CHECKING: @@ -77,6 +80,37 @@ class TrinoEngineSpec(PrestoEngineSpec): def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True + @classmethod + def get_table_names( + cls, + database: "Database", + inspector: Inspector, + schema: Optional[str], + ) -> List[str]: + return BaseEngineSpec.get_table_names( + database=database, + inspector=inspector, + schema=schema, + ) + + @classmethod + def get_view_names( + cls, + database: "Database", + inspector: Inspector, + schema: Optional[str], + ) -> List[str]: + return BaseEngineSpec.get_view_names( + database=database, + inspector=inspector, + schema=schema, + ) + + @classmethod + def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: + """Updates progress information""" + BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session) + @staticmethod def get_extra_params(database: "Database") -> Dict[str, Any]: """