feat: support nulls in the csv uploads (#10208)
* Support more table properties for the hive upload Refactor Add tests, and refactor them to be pytest friendly Use lowercase table names Ignore isort * Use sql params Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
parent
318e5347bc
commit
84f8a51458
|
|
@ -15,7 +15,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
[pytest]
|
[pytest]
|
||||||
addopts = -ra -q
|
|
||||||
testpaths =
|
testpaths =
|
||||||
tests
|
tests
|
||||||
python_files = *_test.py test_*.py *_tests.py
|
python_files = *_test.py test_*.py *_tests.py
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ from celery.schedules import crontab
|
||||||
from dateutil import tz
|
from dateutil import tz
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask_appbuilder.security.manager import AUTH_DB
|
from flask_appbuilder.security.manager import AUTH_DB
|
||||||
|
from pandas.io.parsers import STR_NA_VALUES
|
||||||
|
|
||||||
from superset.jinja_context import ( # pylint: disable=unused-import
|
from superset.jinja_context import ( # pylint: disable=unused-import
|
||||||
BaseTemplateProcessor,
|
BaseTemplateProcessor,
|
||||||
|
|
@ -622,6 +623,9 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
|
||||||
UPLOADED_CSV_HIVE_NAMESPACE
|
UPLOADED_CSV_HIVE_NAMESPACE
|
||||||
] if UPLOADED_CSV_HIVE_NAMESPACE else []
|
] if UPLOADED_CSV_HIVE_NAMESPACE else []
|
||||||
|
|
||||||
|
# Values that should be treated as nulls for the csv uploads.
|
||||||
|
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
|
||||||
|
|
||||||
# A dictionary of items that gets merged into the Jinja context for
|
# A dictionary of items that gets merged into the Jinja context for
|
||||||
# SQL Lab. The existing context gets updated with this dictionary,
|
# SQL Lab. The existing context gets updated with this dictionary,
|
||||||
# meaning values for existing keys get overwritten by the content of this
|
# meaning values for existing keys get overwritten by the content of this
|
||||||
|
|
|
||||||
|
|
@ -106,6 +106,45 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
except pyhive.exc.ProgrammingError:
|
except pyhive.exc.ProgrammingError:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_create_table_stmt( # pylint: disable=too-many-arguments
|
||||||
|
cls,
|
||||||
|
table: Table,
|
||||||
|
schema_definition: str,
|
||||||
|
location: str,
|
||||||
|
delim: str,
|
||||||
|
header_line_count: Optional[int],
|
||||||
|
null_values: Optional[List[str]],
|
||||||
|
) -> text:
|
||||||
|
tblproperties = []
|
||||||
|
# available options:
|
||||||
|
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
|
||||||
|
# TODO(bkyryliuk): figure out what to do with the skip rows field.
|
||||||
|
params: Dict[str, str] = {
|
||||||
|
"delim": delim,
|
||||||
|
"location": location,
|
||||||
|
}
|
||||||
|
if header_line_count is not None and header_line_count >= 0:
|
||||||
|
header_line_count += 1
|
||||||
|
tblproperties.append("'skip.header.line.count'=':header_line_count'")
|
||||||
|
params["header_line_count"] = str(header_line_count)
|
||||||
|
if null_values:
|
||||||
|
# hive only supports 1 value for the null format
|
||||||
|
tblproperties.append("'serialization.null.format'=':null_value'")
|
||||||
|
params["null_value"] = null_values[0]
|
||||||
|
|
||||||
|
if tblproperties:
|
||||||
|
tblproperties_stmt = f"tblproperties ({', '.join(tblproperties)})"
|
||||||
|
sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location
|
||||||
|
{tblproperties_stmt}"""
|
||||||
|
else:
|
||||||
|
sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location"""
|
||||||
|
return sql, params
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals
|
def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -182,18 +221,17 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||||
bucket_path,
|
bucket_path,
|
||||||
os.path.join(upload_prefix, table.table, os.path.basename(filename)),
|
os.path.join(upload_prefix, table.table, os.path.basename(filename)),
|
||||||
)
|
)
|
||||||
sql = text(
|
|
||||||
f"""CREATE TABLE {str(table)} ( {schema_definition} )
|
sql, params = cls.get_create_table_stmt(
|
||||||
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
table,
|
||||||
STORED AS TEXTFILE LOCATION :location
|
schema_definition,
|
||||||
tblproperties ('skip.header.line.count'='1')"""
|
location,
|
||||||
|
csv_to_df_kwargs["sep"].encode().decode("unicode_escape"),
|
||||||
|
int(csv_to_df_kwargs.get("header", 0)),
|
||||||
|
csv_to_df_kwargs.get("na_values"),
|
||||||
)
|
)
|
||||||
engine = cls.get_engine(database)
|
engine = cls.get_engine(database)
|
||||||
engine.execute(
|
engine.execute(text(sql), **params)
|
||||||
sql,
|
|
||||||
delim=csv_to_df_kwargs["sep"].encode().decode("unicode_escape"),
|
|
||||||
location=location,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
|
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,27 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
"""Contains the logic to create cohesive forms on the explore view"""
|
"""Contains the logic to create cohesive forms on the explore view"""
|
||||||
|
import json
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
|
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
|
||||||
from wtforms import Field
|
from wtforms import Field
|
||||||
|
|
||||||
|
|
||||||
|
class JsonListField(Field):
|
||||||
|
widget = BS3TextFieldWidget()
|
||||||
|
data: List[str] = []
|
||||||
|
|
||||||
|
def _value(self) -> str:
|
||||||
|
return json.dumps(self.data)
|
||||||
|
|
||||||
|
def process_formdata(self, valuelist: List[str]) -> None:
|
||||||
|
if valuelist and valuelist[0]:
|
||||||
|
self.data = json.loads(valuelist[0])
|
||||||
|
else:
|
||||||
|
self.data = []
|
||||||
|
|
||||||
|
|
||||||
class CommaSeparatedListField(Field):
|
class CommaSeparatedListField(Field):
|
||||||
widget = BS3TextFieldWidget()
|
widget = BS3TextFieldWidget()
|
||||||
data: List[str] = []
|
data: List[str] = []
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,11 @@ from wtforms.ext.sqlalchemy.fields import QuerySelectField
|
||||||
from wtforms.validators import DataRequired, Length, NumberRange, Optional
|
from wtforms.validators import DataRequired, Length, NumberRange, Optional
|
||||||
|
|
||||||
from superset import app, db, security_manager
|
from superset import app, db, security_manager
|
||||||
from superset.forms import CommaSeparatedListField, filter_not_empty_values
|
from superset.forms import (
|
||||||
|
CommaSeparatedListField,
|
||||||
|
filter_not_empty_values,
|
||||||
|
JsonListField,
|
||||||
|
)
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
config = app.config
|
config = app.config
|
||||||
|
|
@ -210,6 +214,16 @@ class CsvToDatabaseForm(DynamicForm):
|
||||||
validators=[Optional()],
|
validators=[Optional()],
|
||||||
widget=BS3TextFieldWidget(),
|
widget=BS3TextFieldWidget(),
|
||||||
)
|
)
|
||||||
|
null_values = JsonListField(
|
||||||
|
_("Null values"),
|
||||||
|
default=config["CSV_DEFAULT_NA_NAMES"],
|
||||||
|
description=_(
|
||||||
|
"Json list of the values that should be treated as null. "
|
||||||
|
'Examples: [""], ["None", "N/A"], ["nan", "null"]. '
|
||||||
|
"Warning: Hive database supports only single value. "
|
||||||
|
'Use [""] for empty string.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExcelToDatabaseForm(DynamicForm):
|
class ExcelToDatabaseForm(DynamicForm):
|
||||||
|
|
@ -376,3 +390,13 @@ class ExcelToDatabaseForm(DynamicForm):
|
||||||
validators=[Optional()],
|
validators=[Optional()],
|
||||||
widget=BS3TextFieldWidget(),
|
widget=BS3TextFieldWidget(),
|
||||||
)
|
)
|
||||||
|
null_values = JsonListField(
|
||||||
|
_("Null values"),
|
||||||
|
default=config["CSV_DEFAULT_NA_NAMES"],
|
||||||
|
description=_(
|
||||||
|
"Json list of the values that should be treated as null. "
|
||||||
|
'Examples: [""], ["None", "N/A"], ["nan", "null"]. '
|
||||||
|
"Warning: Hive database supports only single value. "
|
||||||
|
'Use [""] for empty string.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,9 @@ class CsvToDatabaseView(SimpleFormView):
|
||||||
database = (
|
database = (
|
||||||
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
|
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# More can be found here:
|
||||||
|
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
|
||||||
csv_to_df_kwargs = {
|
csv_to_df_kwargs = {
|
||||||
"sep": form.sep.data,
|
"sep": form.sep.data,
|
||||||
"header": form.header.data if form.header.data else 0,
|
"header": form.header.data if form.header.data else 0,
|
||||||
|
|
@ -162,6 +165,12 @@ class CsvToDatabaseView(SimpleFormView):
|
||||||
"infer_datetime_format": form.infer_datetime_format.data,
|
"infer_datetime_format": form.infer_datetime_format.data,
|
||||||
"chunksize": 1000,
|
"chunksize": 1000,
|
||||||
}
|
}
|
||||||
|
if form.null_values.data:
|
||||||
|
csv_to_df_kwargs["na_values"] = form.null_values.data
|
||||||
|
csv_to_df_kwargs["keep_default_na"] = False
|
||||||
|
|
||||||
|
# More can be found here:
|
||||||
|
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
|
||||||
df_to_sql_kwargs = {
|
df_to_sql_kwargs = {
|
||||||
"name": csv_table.table,
|
"name": csv_table.table,
|
||||||
"if_exists": form.if_exists.data,
|
"if_exists": form.if_exists.data,
|
||||||
|
|
|
||||||
|
|
@ -27,8 +27,8 @@ from flask_appbuilder.security.sqla import models as ab_models
|
||||||
from flask_testing import TestCase
|
from flask_testing import TestCase
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tests.test_app import app
|
||||||
from superset.sql_parse import CtasMethod
|
from superset.sql_parse import CtasMethod
|
||||||
from tests.test_app import app # isort:skip
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
from superset.connectors.base.models import BaseDatasource
|
from superset.connectors.base.models import BaseDatasource
|
||||||
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
||||||
|
|
|
||||||
|
|
@ -916,12 +916,12 @@ class TestCore(SupersetTestCase):
|
||||||
|
|
||||||
def test_import_csv(self):
|
def test_import_csv(self):
|
||||||
self.login(username="admin")
|
self.login(username="admin")
|
||||||
table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))
|
table_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
|
||||||
|
|
||||||
f1 = "testCSV.csv"
|
f1 = "testCSV.csv"
|
||||||
self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"])
|
self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"])
|
||||||
f2 = "testCSV2.csv"
|
f2 = "testCSV2.csv"
|
||||||
self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"])
|
self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,"])
|
||||||
self.enable_csv_upload(utils.get_example_database())
|
self.enable_csv_upload(utils.get_example_database())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -957,6 +957,23 @@ class TestCore(SupersetTestCase):
|
||||||
table = self.get_table_by_name(table_name)
|
table = self.get_table_by_name(table_name)
|
||||||
# make sure the new column name is reflected in the table metadata
|
# make sure the new column name is reflected in the table metadata
|
||||||
self.assertIn("d", table.column_names)
|
self.assertIn("d", table.column_names)
|
||||||
|
|
||||||
|
# null values are set
|
||||||
|
self.upload_csv(
|
||||||
|
f2,
|
||||||
|
table_name,
|
||||||
|
extra={"null_values": '["", "john"]', "if_exists": "replace"},
|
||||||
|
)
|
||||||
|
# make sure that john and empty string are replaced with None
|
||||||
|
data = db.session.execute(f"SELECT * from {table_name}").fetchall()
|
||||||
|
assert data == [(None, 1, "x"), ("paul", 2, None)]
|
||||||
|
|
||||||
|
# default null values
|
||||||
|
self.upload_csv(f2, table_name, extra={"if_exists": "replace"})
|
||||||
|
# make sure that john and empty string are replaced with None
|
||||||
|
data = db.session.execute(f"SELECT * from {table_name}").fetchall()
|
||||||
|
assert data == [("john", 1, "x"), ("paul", 2, None)]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
os.remove(f1)
|
os.remove(f1)
|
||||||
os.remove(f2)
|
os.remove(f2)
|
||||||
|
|
|
||||||
|
|
@ -14,13 +14,13 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
# isort:skip_file
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from tests.test_app import app
|
||||||
|
from tests.base_tests import SupersetTestCase
|
||||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from tests.base_tests import SupersetTestCase
|
|
||||||
|
|
||||||
from tests.test_app import app # isort:skip
|
|
||||||
|
|
||||||
|
|
||||||
class TestDbEngineSpec(SupersetTestCase):
|
class TestDbEngineSpec(SupersetTestCase):
|
||||||
|
|
|
||||||
|
|
@ -14,42 +14,48 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
# isort:skip_file
|
||||||
|
from datetime import datetime
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.test_app import app
|
||||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
|
||||||
|
|
||||||
|
|
||||||
class TestHiveDbEngineSpec(TestDbEngineSpec):
|
def test_0_progress():
|
||||||
def test_0_progress(self):
|
|
||||||
log = """
|
log = """
|
||||||
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
|
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
|
||||||
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
|
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(0, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 0
|
||||||
|
|
||||||
def test_number_of_jobs_progress(self):
|
|
||||||
|
def test_number_of_jobs_progress():
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(0, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 0
|
||||||
|
|
||||||
def test_job_1_launched_progress(self):
|
|
||||||
|
def test_job_1_launched_progress():
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(0, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 0
|
||||||
|
|
||||||
def test_job_1_launched_stage_1(self):
|
|
||||||
|
def test_job_1_launched_stage_1():
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
|
|
@ -57,11 +63,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(0, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 0
|
||||||
|
|
||||||
def test_job_1_launched_stage_1_map_40_progress(
|
|
||||||
self,
|
def test_job_1_launched_stage_1_map_40_progress(): # pylint: disable=invalid-name
|
||||||
): # pylint: disable=invalid-name
|
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
|
|
@ -70,11 +75,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(10, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 10
|
||||||
|
|
||||||
def test_job_1_launched_stage_1_map_80_reduce_40_progress(
|
|
||||||
self,
|
def test_job_1_launched_stage_1_map_80_reduce_40_progress(): # pylint: disable=invalid-name
|
||||||
): # pylint: disable=invalid-name
|
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
|
|
@ -84,11 +88,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(30, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 30
|
||||||
|
|
||||||
def test_job_1_launched_stage_2_stages_progress(
|
|
||||||
self,
|
def test_job_1_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
|
||||||
): # pylint: disable=invalid-name
|
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
|
|
@ -100,11 +103,10 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(12, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 12
|
||||||
|
|
||||||
def test_job_2_launched_stage_2_stages_progress(
|
|
||||||
self,
|
def test_job_2_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
|
||||||
): # pylint: disable=invalid-name
|
|
||||||
log = """
|
log = """
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
|
||||||
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
|
||||||
|
|
@ -115,63 +117,120 @@ class TestHiveDbEngineSpec(TestDbEngineSpec):
|
||||||
""".split(
|
""".split(
|
||||||
"\n"
|
"\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(60, HiveEngineSpec.progress(log))
|
assert HiveEngineSpec.progress(log) == 60
|
||||||
|
|
||||||
def test_hive_error_msg(self):
|
|
||||||
|
def test_hive_error_msg():
|
||||||
msg = (
|
msg = (
|
||||||
'{...} errorMessage="Error while compiling statement: FAILED: '
|
'{...} errorMessage="Error while compiling statement: FAILED: '
|
||||||
"SemanticException [Error 10001]: Line 4"
|
"SemanticException [Error 10001]: Line 4"
|
||||||
":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
|
":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
|
||||||
"sqlState='42S02', errorCode=10001)){...}"
|
"sqlState='42S02', errorCode=10001)){...}"
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
assert HiveEngineSpec.extract_error_message(Exception(msg)) == (
|
||||||
(
|
|
||||||
"hive error: Error while compiling statement: FAILED: "
|
"hive error: Error while compiling statement: FAILED: "
|
||||||
"SemanticException [Error 10001]: Line 4:5 "
|
"SemanticException [Error 10001]: Line 4:5 "
|
||||||
"Table not found 'fact_ridesfdslakj'"
|
"Table not found 'fact_ridesfdslakj'"
|
||||||
),
|
|
||||||
HiveEngineSpec.extract_error_message(Exception(msg)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
e = Exception("Some string that doesn't match the regex")
|
e = Exception("Some string that doesn't match the regex")
|
||||||
self.assertEqual(f"hive error: {e}", HiveEngineSpec.extract_error_message(e))
|
assert HiveEngineSpec.extract_error_message(e) == f"hive error: {e}"
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
"errorCode=10001, "
|
"errorCode=10001, "
|
||||||
'errorMessage="Error while compiling statement"), operationHandle'
|
'errorMessage="Error while compiling statement"), operationHandle'
|
||||||
'=None)"'
|
'=None)"'
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
assert (
|
||||||
("hive error: Error while compiling statement"),
|
HiveEngineSpec.extract_error_message(Exception(msg))
|
||||||
HiveEngineSpec.extract_error_message(Exception(msg)),
|
== "hive error: Error while compiling statement"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hive_get_view_names_return_empty_list(
|
|
||||||
self,
|
def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name
|
||||||
): # pylint: disable=invalid-name
|
assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []
|
||||||
self.assertEqual(
|
|
||||||
[], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
|
|
||||||
|
def test_convert_dttm():
|
||||||
|
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
|
||||||
|
assert (
|
||||||
|
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
|
||||||
|
== "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_convert_dttm(self):
|
|
||||||
dttm = self.get_dttm()
|
|
||||||
|
|
||||||
self.assertEqual(
|
def test_create_table_from_csv_append() -> None:
|
||||||
HiveEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
|
|
||||||
|
with pytest.raises(SupersetException):
|
||||||
|
HiveEngineSpec.create_table_from_csv(
|
||||||
|
"foo.csv", Table("foobar"), mock.MagicMock(), {}, {"if_exists": "append"}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
|
||||||
"CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_create_table_from_csv_append(self) -> None:
|
def test_get_create_table_stmt() -> None:
|
||||||
self.assertRaises(
|
table = Table("employee")
|
||||||
SupersetException,
|
schema_def = """eid int, name String, salary String, destination String"""
|
||||||
HiveEngineSpec.create_table_from_csv,
|
location = "s3a://directory/table"
|
||||||
"foo.csv",
|
from unittest import TestCase
|
||||||
Table("foobar"),
|
|
||||||
None,
|
TestCase.maxDiff = None
|
||||||
{},
|
assert HiveEngineSpec.get_create_table_stmt(
|
||||||
{"if_exists": "append"},
|
table, schema_def, location, ",", 0, [""]
|
||||||
|
) == (
|
||||||
|
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location
|
||||||
|
tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
|
||||||
|
{
|
||||||
|
"delim": ",",
|
||||||
|
"location": "s3a://directory/table",
|
||||||
|
"header_line_count": "1",
|
||||||
|
"null_value": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert HiveEngineSpec.get_create_table_stmt(
|
||||||
|
table, schema_def, location, ",", 1, ["1", "2"]
|
||||||
|
) == (
|
||||||
|
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location
|
||||||
|
tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
|
||||||
|
{
|
||||||
|
"delim": ",",
|
||||||
|
"location": "s3a://directory/table",
|
||||||
|
"header_line_count": "2",
|
||||||
|
"null_value": "1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert HiveEngineSpec.get_create_table_stmt(
|
||||||
|
table, schema_def, location, ",", 100, ["NaN"]
|
||||||
|
) == (
|
||||||
|
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location
|
||||||
|
tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
|
||||||
|
{
|
||||||
|
"delim": ",",
|
||||||
|
"location": "s3a://directory/table",
|
||||||
|
"header_line_count": "101",
|
||||||
|
"null_value": "NaN",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert HiveEngineSpec.get_create_table_stmt(
|
||||||
|
table, schema_def, location, ",", None, None
|
||||||
|
) == (
|
||||||
|
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location""",
|
||||||
|
{"delim": ",", "location": "s3a://directory/table"},
|
||||||
|
)
|
||||||
|
assert HiveEngineSpec.get_create_table_stmt(
|
||||||
|
table, schema_def, location, ",", 100, []
|
||||||
|
) == (
|
||||||
|
"""CREATE TABLE employee ( eid int, name String, salary String, destination String )
|
||||||
|
ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
|
||||||
|
STORED AS TEXTFILE LOCATION :location
|
||||||
|
tblproperties ('skip.header.line.count'=':header_line_count')""",
|
||||||
|
{"delim": ",", "location": "s3a://directory/table", "header_line_count": "101"},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue