fix(hive): Use parquet rather than textfile when uploading CSV files to Hive (#14240)

* fix(hive): Use parquet rather than textfile when uploading CSV files

* [csv/excel]: Use stream rather than temporary file

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2021-04-24 18:17:30 +12:00 committed by GitHub
parent e392e2ed39
commit b0f8f6b6ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 247 additions and 387 deletions

View File

@ -618,50 +618,41 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.set_or_update_query_limit(limit)
@staticmethod
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
"""Read csv into Pandas DataFrame
:param kwargs: params to be passed to DataFrame.read_csv
:return: Pandas DataFrame containing data from csv
"""
kwargs["encoding"] = "utf-8"
kwargs["iterator"] = True
chunks = pd.read_csv(**kwargs)
df = pd.concat(chunk for chunk in chunks)
return df
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
"""Upload data from a Pandas DataFrame to a database. For
regular engines this calls the DataFrame.to_sql() method. Can be
overridden for engines that don't work well with to_sql(), e.g.
BigQuery.
:param df: Dataframe with data to be uploaded
:param kwargs: kwargs to be passed to to_sql() method
"""
df.to_sql(**kwargs)
@classmethod
def create_table_from_csv( # pylint: disable=too-many-arguments
def df_to_sql(
cls,
filename: str,
table: Table,
database: "Database",
csv_to_df_kwargs: Dict[str, Any],
df_to_sql_kwargs: Dict[str, Any],
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
) -> None:
"""
Create table from contents of a csv. Note: this method does not create
metadata for the table.
Upload data from a Pandas DataFrame to a database.
For regular engines this calls the `pandas.DataFrame.to_sql` method. Can be
overridden for engines that don't work well with this method, e.g. Hive and
BigQuery.
Note this method does not create metadata for the table.
:param database: The database to upload the data to
:param table: The table to upload the data to
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""
df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs)
engine = cls.get_engine(database)
to_sql_kwargs["name"] = table.table
if table.schema:
# only add schema when it is preset and non empty
df_to_sql_kwargs["schema"] = table.schema
# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema
if engine.dialect.supports_multivalues_insert:
df_to_sql_kwargs["method"] = "multi"
cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)
to_sql_kwargs["method"] = "multi"
df.to_sql(con=engine, **to_sql_kwargs)
@classmethod
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
@ -674,28 +665,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
return None
@classmethod
def create_table_from_excel( # pylint: disable=too-many-arguments
cls,
filename: str,
table: Table,
database: "Database",
excel_to_df_kwargs: Dict[str, Any],
df_to_sql_kwargs: Dict[str, Any],
) -> None:
"""
Create table from contents of a excel. Note: this method does not create
metadata for the table.
"""
df = pd.read_excel(io=filename, **excel_to_df_kwargs)
engine = cls.get_engine(database)
if table.schema:
# only add schema when it is preset and non empty
df_to_sql_kwargs["schema"] = table.schema
if engine.dialect.supports_multivalues_insert:
df_to_sql_kwargs["method"] = "multi"
cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)
@classmethod
def get_all_datasource_names(
cls, database: "Database", datasource_type: str

View File

@ -26,6 +26,7 @@ from sqlalchemy.sql.expression import ColumnClause
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.sql_parse import Table
from superset.utils import core as utils
if TYPE_CHECKING:
@ -228,16 +229,26 @@ class BigQueryEngineSpec(BaseEngineSpec):
return "TIMESTAMP_MILLIS({col})"
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
def df_to_sql(
cls,
database: "Database",
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to BigQuery. Calls
`DataFrame.to_gbq()` which requires `pandas_gbq` to be installed.
Upload data from a Pandas DataFrame to a database.
:param df: Dataframe with data to be uploaded
:param kwargs: kwargs to be passed to to_gbq() method. Requires that `schema`,
`name` and `con` are present in kwargs. `name` and `schema` are combined
and passed to `to_gbq()` as `destination_table`.
Calls `pandas_gbq.DataFrame.to_gbq` which requires `pandas_gbq` to be installed.
Note this method does not create metadata for the table.
:param database: The database to upload the data to
:param table: The table to upload the data to
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""
try:
import pandas_gbq
from google.oauth2 import service_account
@ -248,22 +259,25 @@ class BigQueryEngineSpec(BaseEngineSpec):
"to upload data to BigQuery"
)
if not ("name" in kwargs and "schema" in kwargs and "con" in kwargs):
raise Exception("name, schema and con need to be defined in kwargs")
if not table.schema:
raise Exception("The table schema must be defined")
gbq_kwargs = {}
gbq_kwargs["project_id"] = kwargs["con"].engine.url.host
gbq_kwargs["destination_table"] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"
engine = cls.get_engine(database)
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info
# add credentials if they are set on the SQLAlchemy Dialect:
creds = kwargs["con"].dialect.credentials_info
if creds:
credentials = service_account.Credentials.from_service_account_info(creds)
gbq_kwargs["credentials"] = credentials
to_gbq_kwargs[
"credentials"
] = service_account.Credentials.from_service_account_info(creds)
# Only pass through supported kwargs
# Only pass through supported kwargs.
supported_kwarg_keys = {"if_exists"}
for key in supported_kwarg_keys:
if key in kwargs:
gbq_kwargs[key] = kwargs[key]
pandas_gbq.to_gbq(df, **gbq_kwargs)
if key in to_sql_kwargs:
to_gbq_kwargs[key] = to_sql_kwargs[key]
pandas_gbq.to_gbq(df, **to_gbq_kwargs)

View File

@ -17,12 +17,16 @@
import logging
import os
import re
import tempfile
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from flask import g
from sqlalchemy import Column, text
from sqlalchemy.engine.base import Engine
@ -54,6 +58,15 @@ hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
"""
Upload the file to S3.
:param filename: The file to upload
:param upload_prefix: The S3 prefix
:param table: The table that will be created
:returns: The S3 location of the table
"""
# Optional dependency
import boto3 # pylint: disable=import-error
@ -156,89 +169,37 @@ class HiveEngineSpec(PrestoEngineSpec):
return []
@classmethod
def get_create_table_stmt( # pylint: disable=too-many-arguments
def df_to_sql(
cls,
table: Table,
schema_definition: str,
location: str,
delim: str,
header_line_count: Optional[int],
null_values: Optional[List[str]],
) -> text:
tblproperties = []
# available options:
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
# TODO(bkyryliuk): figure out what to do with the skip rows field.
params: Dict[str, str] = {
"delim": delim,
"location": location,
}
if header_line_count is not None and header_line_count >= 0:
header_line_count += 1
tblproperties.append("'skip.header.line.count'=:header_line_count")
params["header_line_count"] = str(header_line_count)
if null_values:
# hive only supports 1 value for the null format
tblproperties.append("'serialization.null.format'=:null_value")
params["null_value"] = null_values[0]
if tblproperties:
tblproperties_stmt = f"tblproperties ({', '.join(tblproperties)})"
sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
{tblproperties_stmt}"""
else:
sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location"""
return sql, params
@classmethod
def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals
cls,
filename: str,
table: Table,
database: "Database",
csv_to_df_kwargs: Dict[str, Any],
df_to_sql_kwargs: Dict[str, Any],
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
) -> None:
"""Uploads a csv file and creates a superset datasource in Hive."""
if_exists = df_to_sql_kwargs["if_exists"]
if if_exists == "append":
"""
Upload data from a Pandas DataFrame to a database.
The data is stored via the binary Parquet format which is both less problematic
and more performant than a text file. More specifically storing a table as a
CSV text file has severe limitations including the fact that the Hive CSV SerDe
does not support multiline fields.
Note this method does not create metadata for the table.
:param database: The database to upload the data to
:param: table The table to upload the data to
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""
engine = cls.get_engine(database)
if to_sql_kwargs["if_exists"] == "append":
raise SupersetException("Append operation not currently supported")
def convert_to_hive_type(col_type: str) -> str:
"""maps tableschema's types to hive types"""
tableschema_to_hive_types = {
"boolean": "BOOLEAN",
"integer": "BIGINT",
"number": "DOUBLE",
"string": "STRING",
}
return tableschema_to_hive_types.get(col_type, "STRING")
if to_sql_kwargs["if_exists"] == "fail":
upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
database, g.user, table.schema
)
# Optional dependency
from tableschema import ( # pylint: disable=import-error
Table as TableSchemaTable,
)
hive_table_schema = TableSchemaTable(filename).infer()
column_name_and_type = []
for column_info in hive_table_schema["fields"]:
column_name_and_type.append(
"`{}` {}".format(
column_info["name"], convert_to_hive_type(column_info["type"])
)
)
schema_definition = ", ".join(column_name_and_type)
# ensure table doesn't already exist
if if_exists == "fail":
# Ensure table doesn't already exist.
if table.schema:
table_exists = not database.get_df(
f"SHOW TABLES IN {table.schema} LIKE '{table.table}'"
@ -247,24 +208,47 @@ class HiveEngineSpec(PrestoEngineSpec):
table_exists = not database.get_df(
f"SHOW TABLES LIKE '{table.table}'"
).empty
if table_exists:
raise SupersetException("Table already exists")
engine = cls.get_engine(database)
if if_exists == "replace":
elif to_sql_kwargs["if_exists"] == "replace":
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
location = upload_to_s3(filename, upload_prefix, table)
sql, params = cls.get_create_table_stmt(
table,
schema_definition,
location,
csv_to_df_kwargs["sep"].encode().decode("unicode_escape"),
int(csv_to_df_kwargs.get("header", 0)),
csv_to_df_kwargs.get("na_values"),
def _get_hive_type(dtype: np.dtype) -> str:
hive_type_by_dtype = {
np.dtype("bool"): "BOOLEAN",
np.dtype("float64"): "DOUBLE",
np.dtype("int64"): "BIGINT",
np.dtype("object"): "STRING",
}
return hive_type_by_dtype.get(dtype, "STRING")
schema_definition = ", ".join(
f"`{name}` {_get_hive_type(dtype)}" for name, dtype in df.dtypes.items()
)
engine = cls.get_engine(database)
engine.execute(text(sql), **params)
with tempfile.NamedTemporaryFile(
dir=config["UPLOAD_FOLDER"], suffix=".parquet"
) as file:
pq.write_table(pa.Table.from_pandas(df), where=file.name)
engine.execute(
text(
f"""
CREATE TABLE {str(table)} ({schema_definition})
STORED AS PARQUET
LOCATION :location
"""
),
location=upload_to_s3(
filename=file.name,
upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
database, g.user, table.schema
),
table=table,
),
)
@classmethod
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:

View File

@ -18,6 +18,7 @@ import os
import tempfile
from typing import TYPE_CHECKING
import pandas as pd
from flask import flash, g, redirect
from flask_appbuilder import expose, SimpleFormView
from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -149,55 +150,44 @@ class CsvToDatabaseView(SimpleFormView):
flash(message, "danger")
return redirect("/csvtodatabaseview/form")
uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
dir=app.config["UPLOAD_FOLDER"],
suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(),
delete=False,
).name
try:
utils.ensure_path_exists(config["UPLOAD_FOLDER"])
upload_stream_write(form.csv_file.data, uploaded_tmp_file_path)
con = form.data.get("con")
database = (
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
df = pd.concat(
pd.read_csv(
chunksize=1000,
encoding="utf-8",
filepath_or_buffer=form.csv_file.data,
header=form.header.data if form.header.data else 0,
index_col=form.index_col.data,
infer_datetime_format=form.infer_datetime_format.data,
iterator=True,
keep_default_na=not form.null_values.data,
mangle_dupe_cols=form.mangle_dupe_cols.data,
na_values=form.null_values.data if form.null_values.data else None,
nrows=form.nrows.data,
parse_dates=form.parse_dates.data,
sep=form.sep.data,
skip_blank_lines=form.skip_blank_lines.data,
skipinitialspace=form.skipinitialspace.data,
skiprows=form.skiprows.data,
)
)
# More can be found here:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
csv_to_df_kwargs = {
"sep": form.sep.data,
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"skipinitialspace": form.skipinitialspace.data,
"skiprows": form.skiprows.data,
"nrows": form.nrows.data,
"skip_blank_lines": form.skip_blank_lines.data,
"parse_dates": form.parse_dates.data,
"infer_datetime_format": form.infer_datetime_format.data,
"chunksize": 1000,
}
if form.null_values.data:
csv_to_df_kwargs["na_values"] = form.null_values.data
csv_to_df_kwargs["keep_default_na"] = False
database = (
db.session.query(models.Database)
.filter_by(id=form.data.get("con").data.get("id"))
.one()
)
# More can be found here:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
df_to_sql_kwargs = {
"name": csv_table.table,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
"chunksize": 1000,
}
database.db_engine_spec.create_table_from_csv(
uploaded_tmp_file_path,
csv_table,
database.db_engine_spec.df_to_sql(
database,
csv_to_df_kwargs,
df_to_sql_kwargs,
csv_table,
df,
to_sql_kwargs={
"chunksize": 1000,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
},
)
# Connect table to the database that should be used for exploration.
@ -236,10 +226,6 @@ class CsvToDatabaseView(SimpleFormView):
db.session.commit()
except Exception as ex: # pylint: disable=broad-except
db.session.rollback()
try:
os.remove(uploaded_tmp_file_path)
except OSError:
pass
message = _(
'Unable to upload CSV file "%(filename)s" to table '
'"%(table_name)s" in database "%(db_name)s". '
@ -254,7 +240,6 @@ class CsvToDatabaseView(SimpleFormView):
stats_logger.incr("failed_csv_upload")
return redirect("/csvtodatabaseview/form")
os.remove(uploaded_tmp_file_path)
# Go back to welcome page / splash screen
message = _(
'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
@ -316,40 +301,34 @@ class ExcelToDatabaseView(SimpleFormView):
utils.ensure_path_exists(config["UPLOAD_FOLDER"])
upload_stream_write(form.excel_file.data, uploaded_tmp_file_path)
con = form.data.get("con")
database = (
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
df = pd.read_excel(
header=form.header.data if form.header.data else 0,
index_col=form.index_col.data,
io=form.excel_file.data,
keep_default_na=not form.null_values.data,
mangle_dupe_cols=form.mangle_dupe_cols.data,
na_values=form.null_values.data if form.null_values.data else None,
parse_dates=form.parse_dates.data,
skiprows=form.skiprows.data,
sheet_name=form.sheet_name.data if form.sheet_name.data else 0,
)
# some params are not supported by pandas.read_excel (e.g. chunksize).
# More can be found here:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html
excel_to_df_kwargs = {
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"skiprows": form.skiprows.data,
"nrows": form.nrows.data,
"sheet_name": form.sheet_name.data if form.sheet_name.data else 0,
"parse_dates": form.parse_dates.data,
}
if form.null_values.data:
excel_to_df_kwargs["na_values"] = form.null_values.data
excel_to_df_kwargs["keep_default_na"] = False
database = (
db.session.query(models.Database)
.filter_by(id=form.data.get("con").data.get("id"))
.one()
)
df_to_sql_kwargs = {
"name": excel_table.table,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
"chunksize": 1000,
}
database.db_engine_spec.create_table_from_excel(
uploaded_tmp_file_path,
excel_table,
database.db_engine_spec.df_to_sql(
database,
excel_to_df_kwargs,
df_to_sql_kwargs,
excel_table,
df,
to_sql_kwargs={
"chunksize": 1000,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
},
)
# Connect table to the database that should be used for exploration.
@ -388,10 +367,6 @@ class ExcelToDatabaseView(SimpleFormView):
db.session.commit()
except Exception as ex: # pylint: disable=broad-except
db.session.rollback()
try:
os.remove(uploaded_tmp_file_path)
except OSError:
pass
message = _(
'Unable to upload Excel file "%(filename)s" to table '
'"%(table_name)s" in database "%(db_name)s". '
@ -406,7 +381,6 @@ class ExcelToDatabaseView(SimpleFormView):
stats_logger.incr("failed_excel_upload")
return redirect("/exceltodatabaseview/form")
os.remove(uploaded_tmp_file_path)
# Go back to welcome page / splash screen
message = _(
'Excel file "%(excel_filename)s" uploaded to table "%(table_name)s" in '

View File

@ -134,13 +134,14 @@ def upload_excel(
return get_resp(test_client, "/exceltodatabaseview/form", data=form_data)
def mock_upload_to_s3(f: str, p: str, t: Table) -> str:
""" HDFS is used instead of S3 for the unit tests.
def mock_upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
"""
HDFS is used instead of S3 for the unit tests.
:param f: filepath
:param p: unused parameter
:param t: table that will be created
:return: hdfs path to the directory with external table files
:param filename: The file to upload
:param upload_prefix: The S3 prefix
:param table: The table that will be created
:returns: The HDFS path to the directory with external table files
"""
# only needed for the hive tests
import docker
@ -148,11 +149,11 @@ def mock_upload_to_s3(f: str, p: str, t: Table) -> str:
client = docker.from_env()
container = client.containers.get("namenode")
# docker mounted volume that contains csv uploads
src = os.path.join("/tmp/superset_uploads", os.path.basename(f))
src = os.path.join("/tmp/superset_uploads", os.path.basename(filename))
# hdfs destination for the external tables
dest_dir = os.path.join("/tmp/external/superset_uploads/", str(t))
dest_dir = os.path.join("/tmp/external/superset_uploads/", str(table))
container.exec_run(f"hdfs dfs -mkdir -p {dest_dir}")
dest = os.path.join(dest_dir, os.path.basename(f))
dest = os.path.join(dest_dir, os.path.basename(filename))
container.exec_run(f"hdfs dfs -put {src} {dest}")
# hive external table expectes a directory for the location
return dest_dir
@ -279,23 +280,13 @@ def test_import_csv(setup_csv_upload, create_csv_files):
# make sure that john and empty string are replaced with None
engine = get_upload_db().get_sqla_engine()
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
if utils.backend() == "hive":
# Be aware that hive only uses first value from the null values list.
# It is hive database engine limitation.
# TODO(bkyryliuk): preprocess csv file for hive upload to match default engine capabilities.
assert data == [("john", 1, "x"), ("paul", 2, None)]
else:
assert data == [(None, 1, "x"), ("paul", 2, None)]
assert data == [(None, 1, "x"), ("paul", 2, None)]
# default null values
upload_csv(CSV_FILENAME2, CSV_UPLOAD_TABLE, extra={"if_exists": "replace"})
# make sure that john and empty string are replaced with None
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
if utils.backend() == "hive":
# By default hive does not convert values to null vs other databases.
assert data == [("john", 1, "x"), ("paul", 2, "")]
else:
assert data == [("john", 1, "x"), ("paul", 2, None)]
assert data == [("john", 1, "x"), ("paul", 2, None)]
@mock.patch("superset.db_engine_specs.hive.upload_to_s3", mock_upload_to_s3)

View File

@ -23,6 +23,7 @@ from sqlalchemy import column
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@ -166,21 +167,23 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
[{"name": "partition", "column_names": ["dttm"], "unique": False}],
)
def test_df_to_sql(self):
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
def test_df_to_sql(self, mock_get_engine):
"""
DB Eng Specs (bigquery): Test DataFrame to SQL contract
"""
# test missing google.oauth2 dependency
sys.modules["pandas_gbq"] = mock.MagicMock()
df = DataFrame()
database = mock.MagicMock()
self.assertRaisesRegexp(
Exception,
"Could not import libraries",
BigQueryEngineSpec.df_to_sql,
df,
con="some_connection",
schema="schema",
name="name",
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={},
)
invalid_kwargs = [
@ -191,15 +194,17 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
{"name": "some_name", "schema": "some_schema"},
{"con": "some_con", "schema": "some_schema"},
]
# Test check for missing required kwargs (name, schema, con)
# Test check for missing schema.
sys.modules["google.oauth2"] = mock.MagicMock()
for invalid_kwarg in invalid_kwargs:
self.assertRaisesRegexp(
Exception,
"name, schema and con need to be defined in kwargs",
"The table schema must be defined",
BigQueryEngineSpec.df_to_sql,
df,
**invalid_kwarg,
database=database,
table=Table(table="name"),
df=df,
to_sql_kwargs=invalid_kwarg,
)
import pandas_gbq
@ -209,12 +214,15 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
service_account.Credentials.from_service_account_info = mock.MagicMock(
return_value="account_info"
)
connection = mock.Mock()
connection.engine.url.host = "google-host"
connection.dialect.credentials_info = "secrets"
mock_get_engine.return_value.url.host = "google-host"
mock_get_engine.return_value.dialect.credentials_info = "secrets"
BigQueryEngineSpec.df_to_sql(
df, con=connection, schema="schema", name="name", if_exists="extra_key"
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={"if_exists": "extra_key"},
)
pandas_gbq.to_gbq.assert_called_with(

View File

@ -163,11 +163,10 @@ def test_convert_dttm():
)
def test_create_table_from_csv_append() -> None:
def test_df_to_csv() -> None:
with pytest.raises(SupersetException):
HiveEngineSpec.create_table_from_csv(
"foo.csv", Table("foobar"), mock.MagicMock(), {}, {"if_exists": "append"}
HiveEngineSpec.df_to_sql(
mock.MagicMock(), Table("foobar"), pd.DataFrame(), {"if_exists": "append"},
)
@ -176,15 +175,13 @@ def test_create_table_from_csv_append() -> None:
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("tableschema.Table")
def test_create_table_from_csv_if_exists_fail(mock_table, mock_g):
mock_table.infer.return_value = {}
def test_df_to_sql_if_exists_fail(mock_g):
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
with pytest.raises(SupersetException, match="Table already exists"):
HiveEngineSpec.create_table_from_csv(
"foo.csv", Table("foobar"), mock_database, {}, {"if_exists": "fail"}
HiveEngineSpec.df_to_sql(
mock_database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"}
)
@ -193,18 +190,15 @@ def test_create_table_from_csv_if_exists_fail(mock_table, mock_g):
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("tableschema.Table")
def test_create_table_from_csv_if_exists_fail_with_schema(mock_table, mock_g):
mock_table.infer.return_value = {}
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
with pytest.raises(SupersetException, match="Table already exists"):
HiveEngineSpec.create_table_from_csv(
"foo.csv",
Table(table="foobar", schema="schema"),
HiveEngineSpec.df_to_sql(
mock_database,
{},
Table(table="foobar", schema="schema"),
pd.DataFrame(),
{"if_exists": "fail"},
)
@ -214,11 +208,9 @@ def test_create_table_from_csv_if_exists_fail_with_schema(mock_table, mock_g):
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("tableschema.Table")
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table, mock_g):
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
mock_upload_to_s3.return_value = "mock-location"
mock_table.infer.return_value = {}
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
@ -226,12 +218,11 @@ def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table,
mock_database.get_sqla_engine.return_value.execute = mock_execute
table_name = "foobar"
HiveEngineSpec.create_table_from_csv(
"foo.csv",
Table(table=table_name),
HiveEngineSpec.df_to_sql(
mock_database,
{"sep": "mock", "header": 1, "na_values": "mock"},
{"if_exists": "replace"},
Table(table=table_name),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
@ -242,13 +233,9 @@ def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table,
{**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("tableschema.Table")
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_create_table_from_csv_if_exists_replace_with_schema(
mock_upload_to_s3, mock_table, mock_g
):
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
mock_upload_to_s3.return_value = "mock-location"
mock_table.infer.return_value = {}
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
@ -256,84 +243,17 @@ def test_create_table_from_csv_if_exists_replace_with_schema(
mock_database.get_sqla_engine.return_value.execute = mock_execute
table_name = "foobar"
schema = "schema"
HiveEngineSpec.create_table_from_csv(
"foo.csv",
Table(table=table_name, schema=schema),
HiveEngineSpec.df_to_sql(
mock_database,
{"sep": "mock", "header": 1, "na_values": "mock"},
{"if_exists": "replace"},
Table(table=table_name, schema=schema),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
def test_get_create_table_stmt() -> None:
table = Table("employee")
schema_def = """eid int, name String, salary String, destination String"""
location = "s3a://directory/table"
from unittest import TestCase
assert HiveEngineSpec.get_create_table_stmt(
table, schema_def, location, ",", 0, [""]
) == (
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""",
{
"delim": ",",
"location": "s3a://directory/table",
"header_line_count": "1",
"null_value": "",
},
)
assert HiveEngineSpec.get_create_table_stmt(
table, schema_def, location, ",", 1, ["1", "2"]
) == (
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""",
{
"delim": ",",
"location": "s3a://directory/table",
"header_line_count": "2",
"null_value": "1",
},
)
assert HiveEngineSpec.get_create_table_stmt(
table, schema_def, location, ",", 100, ["NaN"]
) == (
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""",
{
"delim": ",",
"location": "s3a://directory/table",
"header_line_count": "101",
"null_value": "NaN",
},
)
assert HiveEngineSpec.get_create_table_stmt(
table, schema_def, location, ",", None, None
) == (
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location""",
{"delim": ",", "location": "s3a://directory/table"},
)
assert HiveEngineSpec.get_create_table_stmt(
table, schema_def, location, ",", 100, []
) == (
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
tblproperties ('skip.header.line.count'=:header_line_count)""",
{"delim": ",", "location": "s3a://directory/table", "header_line_count": "101"},
)
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))