fix(20428): Address-Presto/Trino-Poll-Issue-Refactor (#20434)

* fix(20428)-Address-Presto/Trino-Poll-Issue-Refacto
r

Update linter

* Update to only use BaseEngineSpec handle_cursor

* Fix CI

Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com>
This commit is contained in:
Simon Thelin 2022-06-20 00:28:59 +01:00 committed by GitHub
parent c2f01a676c
commit 8b7262fa90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 5 deletions

View File

@ -949,11 +949,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
polled = cursor.poll()
while polled:
time.sleep(0.2)
polled = cursor.poll()
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)

View File

@ -15,15 +15,18 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING
import simplejson as json
from flask import current_app
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils
if TYPE_CHECKING:
@ -77,6 +80,37 @@ class TrinoEngineSpec(PrestoEngineSpec):
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True
@classmethod
def get_table_names(
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
return BaseEngineSpec.get_table_names(
database=database,
inspector=inspector,
schema=schema,
)
@classmethod
def get_view_names(
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
return BaseEngineSpec.get_view_names(
database=database,
inspector=inspector,
schema=schema,
)
@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Updates progress information"""
BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""