fix(drill): no rows returned (#27073)

This commit is contained in:
Beto Dealmeida 2024-02-12 12:11:06 -05:00 committed by GitHub
parent 16e49cb2f7
commit 0950bb7b7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 8 deletions

View File

@ -154,7 +154,7 @@ setup(
], ],
"db2": ["ibm-db-sa>0.3.8, <=0.4.0"], "db2": ["ibm-db-sa>0.3.8, <=0.4.0"],
"dremio": ["sqlalchemy-dremio>=1.1.5, <1.3"], "dremio": ["sqlalchemy-dremio>=1.1.5, <1.3"],
"drill": ["sqlalchemy-drill==0.1.dev"], "drill": ["sqlalchemy-drill>=1.1.4, <2"],
"druid": ["pydruid>=0.6.5,<0.7"], "druid": ["pydruid>=0.6.5,<0.7"],
"duckdb": ["duckdb-engine>=0.9.5, <0.10"], "duckdb": ["duckdb-engine>=0.9.5, <0.10"],
"dynamodb": ["pydynamodb>=0.4.2"], "dynamodb": ["pydynamodb>=0.4.2"],

View File

@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any
from urllib import parse from urllib import parse
from sqlalchemy import types from sqlalchemy import types
@ -60,8 +63,8 @@ class DrillEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def convert_dttm( def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
) -> Optional[str]: ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type) sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date): if isinstance(sqla_type, types.Date):
@ -76,8 +79,8 @@ class DrillEngineSpec(BaseEngineSpec):
cls, cls,
uri: URL, uri: URL,
connect_args: dict[str, Any], connect_args: dict[str, Any],
catalog: Optional[str] = None, catalog: str | None = None,
schema: Optional[str] = None, schema: str | None = None,
) -> tuple[URL, dict[str, Any]]: ) -> tuple[URL, dict[str, Any]]:
if schema: if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe="")) uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
@ -89,7 +92,7 @@ class DrillEngineSpec(BaseEngineSpec):
cls, cls,
sqlalchemy_uri: URL, sqlalchemy_uri: URL,
connect_args: dict[str, Any], connect_args: dict[str, Any],
) -> Optional[str]: ) -> str | None:
""" """
Return the configured schema. Return the configured schema.
""" """
@ -97,7 +100,7 @@ class DrillEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def get_url_for_impersonation( def get_url_for_impersonation(
cls, url: URL, impersonate_user: bool, username: Optional[str] cls, url: URL, impersonate_user: bool, username: str | None
) -> URL: ) -> URL:
""" """
Return a modified URL with the username set. Return a modified URL with the username set.
@ -117,3 +120,23 @@ class DrillEngineSpec(BaseEngineSpec):
) )
return url return url
@classmethod
def fetch_data(
cls,
cursor: Any,
limit: int | None = None,
) -> list[tuple[Any, ...]]:
"""
Custom `fetch_data` for Drill.
When no rows are returned, Drill raises a `RuntimeError` with the message
"generator raised StopIteration". This method catches the exception and
returns an empty list instead.
"""
try:
return super().fetch_data(cursor, limit)
except RuntimeError as ex:
if str(ex) == "generator raised StopIteration":
return []
raise