From 53798c79041a5b5961a87ad1da0af5032d750fa8 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Thu, 13 Jun 2024 08:55:09 -0700 Subject: [PATCH] feat(trino): Add functionality to upload data (#29164) --- pyproject.toml | 3 +- superset/db_engine_specs/hive.py | 6 + superset/db_engine_specs/trino.py | 89 +++++++++++++- .../db_engine_specs/trino_tests.py | 115 ++++++++++++++++++ 4 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 tests/integration_tests/db_engine_specs/trino_tests.py diff --git a/pyproject.toml b/pyproject.toml index 84736580c..65471b0cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ gevent = ["gevent>=23.9.1"] gsheets = ["shillelagh[gsheetsapi]>=1.2.18, <2"] hana = ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"] hive = [ + "boto3", "pyhive[hive]>=0.6.5;python_version<'3.11'", "pyhive[hive_pure_sasl]>=0.7.0", "tableschema", @@ -154,7 +155,7 @@ pinot = ["pinotdb>=0.3.3, <0.4"] playwright = ["playwright>=1.37.0, <2"] postgres = ["psycopg2-binary==2.9.6"] presto = ["pyhive[presto]>=0.6.5"] -trino = ["trino>=0.328.0"] +trino = ["boto3", "trino>=0.328.0"] prophet = ["prophet>=1.1.5, <2"] redshift = ["sqlalchemy-redshift>=0.8.1, <0.9"] rockset = ["rockset-sqlalchemy>=0.0.1, <1"] diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 80892b598..519618aaa 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -79,6 +79,12 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: ) s3 = boto3.client("s3") + + # The location is merely an S3 prefix and thus we first need to ensure that there is + # one and only one key associated with the table. 
+    bucket = boto3.resource("s3").Bucket(bucket_path)
+    bucket.objects.filter(Prefix=os.path.join(upload_prefix, table.table)).delete()
+ + :param database: The database to upload the data to + :param table: The table to upload the data to + :param df: The Pandas Dataframe with data to be uploaded + :param to_sql_kwargs: The `pandas.DataFrame.to_sql` keyword arguments + :see: superset.db_engine_specs.HiveEngineSpec.df_to_sql + """ + + # pylint: disable=import-outside-toplevel + + if to_sql_kwargs["if_exists"] == "append": + raise SupersetException("Append operation not currently supported") + + if to_sql_kwargs["if_exists"] == "fail": + if database.has_table_by_name(table.table, table.schema): + raise SupersetException("Table already exists") + elif to_sql_kwargs["if_exists"] == "replace": + with cls.get_engine(database) as engine: + engine.execute(f"DROP TABLE IF EXISTS {str(table)}") + + def _get_trino_type(dtype: np.dtype[Any]) -> str: + return { + np.dtype("bool"): "BOOLEAN", + np.dtype("float64"): "DOUBLE", + np.dtype("int64"): "BIGINT", + np.dtype("object"): "VARCHAR", + }.get(dtype, "VARCHAR") + + with NamedTemporaryFile( + dir=current_app.config["UPLOAD_FOLDER"], + suffix=".parquet", + ) as file: + pa.parquet.write_table(pa.Table.from_pandas(df), where=file.name) + + with cls.get_engine(database) as engine: + engine.execute( + # pylint: disable=consider-using-f-string + text( + """ + CREATE TABLE {table} ({schema}) + WITH ( + format = 'PARQUET', + external_location = '{location}' + ) + """.format( + location=upload_to_s3( + filename=file.name, + upload_prefix=current_app.config[ + "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC" + ]( + database, + g.user, + table.schema, + ), + table=table, + ), + schema=", ".join( + f'"{name}" {_get_trino_type(dtype)}' + for name, dtype in df.dtypes.items() + ), + table=str(table), + ), + ), + ) diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py new file mode 100644 index 000000000..d03999713 --- /dev/null +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -0,0 +1,115 @@ +# Licensed 
to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest import mock + +import pandas as pd +import pytest + +from superset.db_engine_specs.trino import TrinoEngineSpec +from superset.exceptions import SupersetException +from superset.sql_parse import Table +from tests.integration_tests.test_app import app + + +def test_df_to_csv() -> None: + with pytest.raises(SupersetException): + TrinoEngineSpec.df_to_sql( + mock.MagicMock(), + Table("foobar"), + pd.DataFrame(), + {"if_exists": "append"}, + ) + + +@mock.patch("superset.db_engine_specs.trino.g", spec={}) +def test_df_to_sql_if_exists_fail(mock_g): + mock_g.user = True + mock_database = mock.MagicMock() + mock_database.get_df.return_value.empty = False + with pytest.raises(SupersetException, match="Table already exists"): + TrinoEngineSpec.df_to_sql( + mock_database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"} + ) + + +@mock.patch("superset.db_engine_specs.trino.g", spec={}) +def test_df_to_sql_if_exists_fail_with_schema(mock_g): + mock_g.user = True + mock_database = mock.MagicMock() + mock_database.get_df.return_value.empty = False + with pytest.raises(SupersetException, match="Table already exists"): + TrinoEngineSpec.df_to_sql( + mock_database, + Table(table="foobar", 
schema="schema"), + pd.DataFrame(), + {"if_exists": "fail"}, + ) + + +@mock.patch("superset.db_engine_specs.trino.g", spec={}) +@mock.patch("superset.db_engine_specs.trino.upload_to_s3") +def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g): + config = app.config.copy() + app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: "" # noqa: F722 + mock_upload_to_s3.return_value = "mock-location" + mock_g.user = True + mock_database = mock.MagicMock() + mock_database.get_df.return_value.empty = False + mock_execute = mock.MagicMock(return_value=True) + mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = ( + mock_execute + ) + table_name = "foobar" + + with app.app_context(): + TrinoEngineSpec.df_to_sql( + mock_database, + Table(table=table_name), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, + ) + + mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}") + app.config = config + + +@mock.patch("superset.db_engine_specs.trino.g", spec={}) +@mock.patch("superset.db_engine_specs.trino.upload_to_s3") +def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): + config = app.config.copy() + app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: "" # noqa: F722 + mock_upload_to_s3.return_value = "mock-location" + mock_g.user = True + mock_database = mock.MagicMock() + mock_database.get_df.return_value.empty = False + mock_execute = mock.MagicMock(return_value=True) + mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = ( + mock_execute + ) + table_name = "foobar" + schema = "schema" + + with app.app_context(): + TrinoEngineSpec.df_to_sql( + mock_database, + Table(table=table_name, schema=schema), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, + ) + + mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}") + app.config = config