feat: support nulls in the csv uploads (#10208)

* Support more table properties for the hive upload

Refactor

Add tests, and refactor them to be pytest friendly

Use lowercase table names

Ignore isort

* Use sql params

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-07-06 13:26:43 -07:00 committed by GitHub
parent 318e5347bc
commit 84f8a51458
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 325 additions and 160 deletions

View File

@ -15,7 +15,6 @@
# limitations under the License.
#
[pytest]
addopts = -ra -q
testpaths =
tests
python_files = *_test.py test_*.py *_tests.py

View File

@ -35,6 +35,7 @@ from celery.schedules import crontab
from dateutil import tz
from flask import Blueprint
from flask_appbuilder.security.manager import AUTH_DB
from pandas.io.parsers import STR_NA_VALUES
from superset.jinja_context import ( # pylint: disable=unused-import
BaseTemplateProcessor,
@ -622,6 +623,9 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
UPLOADED_CSV_HIVE_NAMESPACE
] if UPLOADED_CSV_HIVE_NAMESPACE else []
# Values that should be treated as nulls for the csv uploads.
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
# A dictionary of items that gets merged into the Jinja context for
# SQL Lab. The existing context gets updated with this dictionary,
# meaning values for existing keys get overwritten by the content of this

View File

@ -106,6 +106,45 @@ class HiveEngineSpec(PrestoEngineSpec):
except pyhive.exc.ProgrammingError:
return []
@classmethod
def get_create_table_stmt(  # pylint: disable=too-many-arguments
    cls,
    table: "Table",
    schema_definition: str,
    location: str,
    delim: str,
    header_line_count: Optional[int],
    null_values: Optional[List[str]],
) -> "Tuple[str, Dict[str, str]]":
    """
    Build the Hive ``CREATE TABLE`` DDL for a CSV upload plus its bind params.

    :param table: target table; only its string form is interpolated into the DDL
    :param schema_definition: comma-separated column definitions, e.g. ``"a int, b string"``
    :param location: external storage location (e.g. an s3a:// path), bound as ``:location``
    :param delim: CSV field delimiter, bound as ``:delim``
    :param header_line_count: pandas-style 0-based header row index, or None if headerless
    :param null_values: values to treat as NULL; Hive honors only the first one
    :return: ``(sql, params)`` suitable for ``engine.execute(text(sql), **params)``

    NOTE: the original signature declared ``-> text`` although a tuple is
    returned; the annotation is corrected here (quoted, so no runtime import
    of ``Tuple`` is required).
    """
    # available options:
    # https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
    # TODO(bkyryliuk): figure out what to do with the skip rows field.
    tblproperties = []
    params: Dict[str, str] = {
        "delim": delim,
        "location": location,
    }
    if header_line_count is not None and header_line_count >= 0:
        # Hive's skip.header.line.count is a number of leading lines to drop,
        # while pandas' "header" is a 0-based row index — hence the +1.
        header_line_count += 1
        tblproperties.append("'skip.header.line.count'=':header_line_count'")
        params["header_line_count"] = str(header_line_count)
    if null_values:
        # hive only supports 1 value for the null format
        tblproperties.append("'serialization.null.format'=':null_value'")
        params["null_value"] = null_values[0]
    # Single template for both branches; the tblproperties clause is appended
    # only when at least one property was collected above.
    sql = (
        f"CREATE TABLE {table} ( {schema_definition} )\n"
        "ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim\n"
        "STORED AS TEXTFILE LOCATION :location"
    )
    if tblproperties:
        sql += f"\ntblproperties ({', '.join(tblproperties)})"
    return sql, params
@classmethod
def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals
cls,
@ -182,18 +221,17 @@ class HiveEngineSpec(PrestoEngineSpec):
bucket_path,
os.path.join(upload_prefix, table.table, os.path.basename(filename)),
)
sql = text(
f"""CREATE TABLE {str(table)} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
STORED AS TEXTFILE LOCATION :location
tblproperties ('skip.header.line.count'='1')"""
sql, params = cls.get_create_table_stmt(
table,
schema_definition,
location,
csv_to_df_kwargs["sep"].encode().decode("unicode_escape"),
int(csv_to_df_kwargs.get("header", 0)),
csv_to_df_kwargs.get("na_values"),
)
engine = cls.get_engine(database)
engine.execute(
sql,
delim=csv_to_df_kwargs["sep"].encode().decode("unicode_escape"),
location=location,
)
engine.execute(text(sql), **params)
@classmethod
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:

View File

@ -15,12 +15,27 @@
# specific language governing permissions and limitations
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
import json
from typing import Any, List, Optional
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field
class JsonListField(Field):
    """WTForms field that round-trips its value as a JSON-encoded list."""

    widget = BS3TextFieldWidget()
    data: List[str] = []

    def _value(self) -> str:
        # Render the current list back into the form's text input.
        return json.dumps(self.data)

    def process_formdata(self, valuelist: List[str]) -> None:
        # A missing field or an empty submission yields an empty list.
        raw = valuelist[0] if valuelist else None
        self.data = json.loads(raw) if raw else []
class CommaSeparatedListField(Field):
widget = BS3TextFieldWidget()
data: List[str] = []

View File

@ -26,7 +26,11 @@ from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import DataRequired, Length, NumberRange, Optional
from superset import app, db, security_manager
from superset.forms import CommaSeparatedListField, filter_not_empty_values
from superset.forms import (
CommaSeparatedListField,
filter_not_empty_values,
JsonListField,
)
from superset.models.core import Database
config = app.config
@ -210,6 +214,16 @@ class CsvToDatabaseForm(DynamicForm):
validators=[Optional()],
widget=BS3TextFieldWidget(),
)
null_values = JsonListField(
_("Null values"),
default=config["CSV_DEFAULT_NA_NAMES"],
description=_(
"Json list of the values that should be treated as null. "
'Examples: [""], ["None", "N/A"], ["nan", "null"]. '
"Warning: Hive database supports only single value. "
'Use [""] for empty string.'
),
)
class ExcelToDatabaseForm(DynamicForm):
@ -376,3 +390,13 @@ class ExcelToDatabaseForm(DynamicForm):
validators=[Optional()],
widget=BS3TextFieldWidget(),
)
null_values = JsonListField(
_("Null values"),
default=config["CSV_DEFAULT_NA_NAMES"],
description=_(
"Json list of the values that should be treated as null. "
'Examples: [""], ["None", "N/A"], ["nan", "null"]. '
"Warning: Hive database supports only single value. "
'Use [""] for empty string.'
),
)

View File

@ -149,6 +149,9 @@ class CsvToDatabaseView(SimpleFormView):
database = (
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
)
# More can be found here:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
csv_to_df_kwargs = {
"sep": form.sep.data,
"header": form.header.data if form.header.data else 0,
@ -162,6 +165,12 @@ class CsvToDatabaseView(SimpleFormView):
"infer_datetime_format": form.infer_datetime_format.data,
"chunksize": 1000,
}
if form.null_values.data:
csv_to_df_kwargs["na_values"] = form.null_values.data
csv_to_df_kwargs["keep_default_na"] = False
# More can be found here:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
df_to_sql_kwargs = {
"name": csv_table.table,
"if_exists": form.if_exists.data,

View File

@ -27,8 +27,8 @@ from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.orm import Session
from tests.test_app import app
from superset.sql_parse import CtasMethod
from tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.druid.models import DruidCluster, DruidDatasource

View File

@ -916,12 +916,12 @@ class TestCore(SupersetTestCase):
def test_import_csv(self):
self.login(username="admin")
table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))
table_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
f1 = "testCSV.csv"
self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"])
f2 = "testCSV2.csv"
self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"])
self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,"])
self.enable_csv_upload(utils.get_example_database())
try:
@ -957,6 +957,23 @@ class TestCore(SupersetTestCase):
table = self.get_table_by_name(table_name)
# make sure the new column name is reflected in the table metadata
self.assertIn("d", table.column_names)
# null values are set
self.upload_csv(
f2,
table_name,
extra={"null_values": '["", "john"]', "if_exists": "replace"},
)
# make sure that john and empty string are replaced with None
data = db.session.execute(f"SELECT * from {table_name}").fetchall()
assert data == [(None, 1, "x"), ("paul", 2, None)]
# default null values
self.upload_csv(f2, table_name, extra={"if_exists": "replace"})
# make sure that john and empty string are replaced with None
data = db.session.execute(f"SELECT * from {table_name}").fetchall()
assert data == [("john", 1, "x"), ("paul", 2, None)]
finally:
os.remove(f1)
os.remove(f2)

View File

@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
from tests.test_app import app
from tests.base_tests import SupersetTestCase
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.models.core import Database
from tests.base_tests import SupersetTestCase
from tests.test_app import app # isort:skip
class TestDbEngineSpec(SupersetTestCase):

View File

@ -14,42 +14,48 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
from unittest import mock
import pytest
from tests.test_app import app
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.exceptions import SupersetException
from superset.sql_parse import Table
from tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestHiveDbEngineSpec(TestDbEngineSpec):
def test_0_progress():
    """A log with no job/stage lines reports 0% progress."""
    log = """
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
""".split("\n")
    assert HiveEngineSpec.progress(log) == 0
def test_number_of_jobs_progress():
    """A bare "Total jobs" line does not advance progress beyond 0%."""
    log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
""".split("\n")
    assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_progress():
    """Launching a job without any stage output still reports 0% progress."""
    log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
""".split("\n")
    assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_stage_1(self):
def test_job_1_launched_stage_1():
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
@ -57,11 +63,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
""".split(
"\n"
)
self.assertEqual(0, HiveEngineSpec.progress(log))
assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_stage_1_map_40_progress(
self,
): # pylint: disable=invalid-name
def test_job_1_launched_stage_1_map_40_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
@ -70,11 +75,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
""".split(
"\n"
)
self.assertEqual(10, HiveEngineSpec.progress(log))
assert HiveEngineSpec.progress(log) == 10
def test_job_1_launched_stage_1_map_80_reduce_40_progress(
self,
): # pylint: disable=invalid-name
def test_job_1_launched_stage_1_map_80_reduce_40_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
@ -84,11 +88,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
""".split(
"\n"
)
self.assertEqual(30, HiveEngineSpec.progress(log))
assert HiveEngineSpec.progress(log) == 30
def test_job_1_launched_stage_2_stages_progress(
self,
): # pylint: disable=invalid-name
def test_job_1_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
@ -100,11 +103,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
""".split(
"\n"
)
self.assertEqual(12, HiveEngineSpec.progress(log))
assert HiveEngineSpec.progress(log) == 12
def test_job_2_launched_stage_2_stages_progress(
self,
): # pylint: disable=invalid-name
def test_job_2_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
@ -115,63 +117,120 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
""".split(
"\n"
)
self.assertEqual(60, HiveEngineSpec.progress(log))
assert HiveEngineSpec.progress(log) == 60
def test_hive_error_msg():
    """extract_error_message() unwraps Hive's verbose thrift error strings."""
    msg = (
        '{...} errorMessage="Error while compiling statement: FAILED: '
        "SemanticException [Error 10001]: Line 4"
        ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
        "sqlState='42S02', errorCode=10001)){...}"
    )
    assert HiveEngineSpec.extract_error_message(Exception(msg)) == (
        "hive error: Error while compiling statement: FAILED: "
        "SemanticException [Error 10001]: Line 4:5 "
        "Table not found 'fact_ridesfdslakj'"
    )

    # A message that does not match the extraction regex is passed through.
    e = Exception("Some string that doesn't match the regex")
    assert HiveEngineSpec.extract_error_message(e) == f"hive error: {e}"

    msg = (
        "errorCode=10001, "
        'errorMessage="Error while compiling statement"), operationHandle'
        '=None)"'
    )
    assert (
        HiveEngineSpec.extract_error_message(Exception(msg))
        == "hive error: Error while compiling statement"
    )
def test_hive_get_view_names_return_empty_list():  # pylint: disable=invalid-name
    """Hive does not expose views here, so the spec returns an empty list."""
    assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []
def test_convert_dttm():
    """DATE/TIMESTAMP targets are rendered as Hive CAST expressions."""
    dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
    assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
    assert (
        HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
        == "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
    )
def test_create_table_from_csv_append() -> None:
    """Appending to an existing Hive table via CSV upload is unsupported."""
    with pytest.raises(SupersetException):
        HiveEngineSpec.create_table_from_csv(
            "foo.csv", Table("foobar"), mock.MagicMock(), {}, {"if_exists": "append"}
        )
def test_get_create_table_stmt() -> None:
    """Cover header-count/null-format combinations of the Hive upload DDL.

    The stray ``from unittest import TestCase; TestCase.maxDiff = None``
    debugging leftover was removed: it mutated a global class attribute and
    has no effect on plain pytest asserts.
    """
    table = Table("employee")
    schema_def = "eid int, name String, salary String, destination String"
    location = "s3a://directory/table"
    base_sql = (
        "CREATE TABLE employee ( eid int, name String, salary String, destination String )\n"
        "ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim\n"
        "STORED AS TEXTFILE LOCATION :location"
    )
    both_props = (
        "\ntblproperties ('skip.header.line.count'=':header_line_count', "
        "'serialization.null.format'=':null_value')"
    )

    # pandas header=0 (first row is the header) -> Hive skips 1 line.
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 0, [""]
    ) == (
        base_sql + both_props,
        {
            "delim": ",",
            "location": location,
            "header_line_count": "1",
            "null_value": "",
        },
    )
    # Hive honors only the first null value from the list.
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 1, ["1", "2"]
    ) == (
        base_sql + both_props,
        {
            "delim": ",",
            "location": location,
            "header_line_count": "2",
            "null_value": "1",
        },
    )
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 100, ["NaN"]
    ) == (
        base_sql + both_props,
        {
            "delim": ",",
            "location": location,
            "header_line_count": "101",
            "null_value": "NaN",
        },
    )
    # No header info and no null values -> no tblproperties clause at all.
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", None, None
    ) == (base_sql, {"delim": ",", "location": location})
    # Header only: the null-format property must be absent.
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 100, []
    ) == (
        base_sql + "\ntblproperties ('skip.header.line.count'=':header_line_count')",
        {"delim": ",", "location": location, "header_line_count": "101"},
    )