@classmethod
def get_create_table_stmt(  # pylint: disable=too-many-arguments
    cls,
    table: Table,
    schema_definition: str,
    location: str,
    delim: str,
    header_line_count: Optional[int],
    null_values: Optional[List[str]],
) -> Tuple[str, Dict[str, str]]:
    """Build the Hive ``CREATE TABLE`` DDL used by the CSV upload flow.

    Returns a 2-tuple ``(sql, params)``: the SQL text containing ``:name``
    bind placeholders, and the dict of parameter values to bind when the
    statement is executed.

    :param table: target table; only its string representation is used
    :param schema_definition: comma-separated ``column type`` pairs
    :param location: external storage location (e.g. an ``s3a://`` URI)
    :param header_line_count: value of the pandas ``header`` kwarg; when
        not None and >= 0, one more line than that is skipped so the
        header row itself is excluded from the data
    :param null_values: values to treat as NULL; Hive supports a single
        null format, so only the first entry is used
    """
    tblproperties = []
    # available options:
    # https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
    # TODO(bkyryliuk): figure out what to do with the skip rows field.
    params: Dict[str, str] = {
        "delim": delim,
        "location": location,
    }
    if header_line_count is not None and header_line_count >= 0:
        header_line_count += 1
        # NOTE(review): the bind placeholder sits inside single quotes;
        # substitution relies on SQLAlchemy text() parameter parsing —
        # confirm the driver renders the quoted value as intended.
        tblproperties.append("'skip.header.line.count'=':header_line_count'")
        params["header_line_count"] = str(header_line_count)
    if null_values:
        # hive only supports 1 value for the null format
        tblproperties.append("'serialization.null.format'=':null_value'")
        params["null_value"] = null_values[0]

    if tblproperties:
        tblproperties_stmt = f"tblproperties ({', '.join(tblproperties)})"
        sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location
            {tblproperties_stmt}"""
    else:
        sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location"""
    return sql, params
class JsonListField(Field):
    """A form field whose value round-trips as a JSON-encoded list.

    Rendering serializes the current list with ``json.dumps``; submission
    parses the raw string with ``json.loads``. A malformed document raises
    ``json.JSONDecodeError`` (a ``ValueError`` subclass), which WTForms'
    ``process()`` turns into a regular form validation error.
    """

    widget = BS3TextFieldWidget()
    data: List[str] = []

    def _value(self) -> str:
        """Serialize the current value for display in the text widget."""
        return json.dumps(self.data)

    def process_formdata(self, valuelist: List[str]) -> None:
        """Parse the submitted raw value (a JSON document) into a list."""
        raw = valuelist[0] if valuelist else None
        self.data = json.loads(raw) if raw else []
+ ), + ) class ExcelToDatabaseForm(DynamicForm): @@ -376,3 +390,13 @@ class ExcelToDatabaseForm(DynamicForm): validators=[Optional()], widget=BS3TextFieldWidget(), ) + null_values = JsonListField( + _("Null values"), + default=config["CSV_DEFAULT_NA_NAMES"], + description=_( + "Json list of the values that should be treated as null. " + 'Examples: [""], ["None", "N/A"], ["nan", "null"]. ' + "Warning: Hive database supports only single value. " + 'Use [""] for empty string.' + ), + ) diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 47b8686f0..70b590746 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -149,6 +149,9 @@ class CsvToDatabaseView(SimpleFormView): database = ( db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) + + # More can be found here: + # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html csv_to_df_kwargs = { "sep": form.sep.data, "header": form.header.data if form.header.data else 0, @@ -162,6 +165,12 @@ class CsvToDatabaseView(SimpleFormView): "infer_datetime_format": form.infer_datetime_format.data, "chunksize": 1000, } + if form.null_values.data: + csv_to_df_kwargs["na_values"] = form.null_values.data + csv_to_df_kwargs["keep_default_na"] = False + + # More can be found here: + # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html df_to_sql_kwargs = { "name": csv_table.table, "if_exists": form.if_exists.data, diff --git a/tests/base_tests.py b/tests/base_tests.py index 871480bfa..8b6db2cd3 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -27,8 +27,8 @@ from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase from sqlalchemy.orm import Session +from tests.test_app import app from superset.sql_parse import CtasMethod -from tests.test_app import app # isort:skip from superset import db, security_manager from 
superset.connectors.base.models import BaseDatasource from superset.connectors.druid.models import DruidCluster, DruidDatasource diff --git a/tests/core_tests.py b/tests/core_tests.py index 1efea886e..2e9ab4b36 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -916,12 +916,12 @@ class TestCore(SupersetTestCase): def test_import_csv(self): self.login(username="admin") - table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5)) + table_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) f1 = "testCSV.csv" self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"]) f2 = "testCSV2.csv" - self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"]) + self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,"]) self.enable_csv_upload(utils.get_example_database()) try: @@ -957,6 +957,23 @@ class TestCore(SupersetTestCase): table = self.get_table_by_name(table_name) # make sure the new column name is reflected in the table metadata self.assertIn("d", table.column_names) + + # null values are set + self.upload_csv( + f2, + table_name, + extra={"null_values": '["", "john"]', "if_exists": "replace"}, + ) + # make sure that john and empty string are replaced with None + data = db.session.execute(f"SELECT * from {table_name}").fetchall() + assert data == [(None, 1, "x"), ("paul", 2, None)] + + # default null values + self.upload_csv(f2, table_name, extra={"if_exists": "replace"}) + # make sure that john and empty string are replaced with None + data = db.session.execute(f"SELECT * from {table_name}").fetchall() + assert data == [("john", 1, "x"), ("paul", 2, None)] + finally: os.remove(f1) os.remove(f2) diff --git a/tests/db_engine_specs/base_tests.py b/tests/db_engine_specs/base_tests.py index 9df91f4b1..b95b97fd1 100644 --- a/tests/db_engine_specs/base_tests.py +++ b/tests/db_engine_specs/base_tests.py @@ -14,13 +14,13 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file from datetime import datetime +from tests.test_app import app +from tests.base_tests import SupersetTestCase from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.models.core import Database -from tests.base_tests import SupersetTestCase - -from tests.test_app import app # isort:skip class TestDbEngineSpec(SupersetTestCase): diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py index 83b23e994..b272bbb61 100644 --- a/tests/db_engine_specs/hive_tests.py +++ b/tests/db_engine_specs/hive_tests.py @@ -14,164 +14,223 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file +from datetime import datetime from unittest import mock +import pytest + +from tests.test_app import app from superset.db_engine_specs.hive import HiveEngineSpec from superset.exceptions import SupersetException from superset.sql_parse import Table -from tests.db_engine_specs.base_tests import TestDbEngineSpec -class TestHiveDbEngineSpec(TestDbEngineSpec): - def test_0_progress(self): - log = """ - 17/02/07 18:26:27 INFO log.PerfLogger: - 17/02/07 18:26:27 INFO log.PerfLogger: - """.split( - "\n" - ) - self.assertEqual(0, HiveEngineSpec.progress(log)) +def test_0_progress(): + log = """ + 17/02/07 18:26:27 INFO log.PerfLogger: + 17/02/07 18:26:27 INFO log.PerfLogger: + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 0 - def test_number_of_jobs_progress(self): - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - """.split( - "\n" - ) - self.assertEqual(0, HiveEngineSpec.progress(log)) - def test_job_1_launched_progress(self): - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 - """.split( - "\n" - ) - self.assertEqual(0, 
HiveEngineSpec.progress(log)) +def test_number_of_jobs_progress(): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 0 - def test_job_1_launched_stage_1(self): - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% - """.split( - "\n" - ) - self.assertEqual(0, HiveEngineSpec.progress(log)) - def test_job_1_launched_stage_1_map_40_progress( - self, - ): # pylint: disable=invalid-name - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% - """.split( - "\n" - ) - self.assertEqual(10, HiveEngineSpec.progress(log)) +def test_job_1_launched_progress(): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 0 - def test_job_1_launched_stage_1_map_80_reduce_40_progress( - self, - ): # pylint: disable=invalid-name - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% - """.split( - "\n" - ) - self.assertEqual(30, HiveEngineSpec.progress(log)) - def test_job_1_launched_stage_2_stages_progress( - self, - ): # pylint: disable=invalid-name - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO 
ql.Driver: Launching Job 1 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% - """.split( - "\n" - ) - self.assertEqual(12, HiveEngineSpec.progress(log)) +def test_job_1_launched_stage_1(): + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 0 - def test_job_2_launched_stage_2_stages_progress( - self, - ): # pylint: disable=invalid-name - log = """ - 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% - 17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2 - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% - 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% - """.split( - "\n" - ) - self.assertEqual(60, HiveEngineSpec.progress(log)) - def test_hive_error_msg(self): - msg = ( - '{...} errorMessage="Error while compiling statement: FAILED: ' - "SemanticException [Error 10001]: Line 4" - ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, " - "sqlState='42S02', errorCode=10001)){...}" - ) - self.assertEqual( - ( - "hive error: Error while compiling statement: FAILED: " - "SemanticException [Error 10001]: Line 4:5 " - "Table not found 'fact_ridesfdslakj'" - ), - 
HiveEngineSpec.extract_error_message(Exception(msg)), +def test_job_1_launched_stage_1_map_40_progress(): # pylint: disable=invalid-name + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 10 + + +def test_job_1_launched_stage_1_map_80_reduce_40_progress(): # pylint: disable=invalid-name + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 30 + + +def test_job_1_launched_stage_2_stages_progress(): # pylint: disable=invalid-name + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 12 + + +def test_job_2_launched_stage_2_stages_progress(): # pylint: disable=invalid-name + log = """ + 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 + 17/02/07 19:15:55 INFO ql.Driver: 
Launching Job 1 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0% + 17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2 + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0% + 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% + """.split( + "\n" + ) + assert HiveEngineSpec.progress(log) == 60 + + +def test_hive_error_msg(): + msg = ( + '{...} errorMessage="Error while compiling statement: FAILED: ' + "SemanticException [Error 10001]: Line 4" + ":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, " + "sqlState='42S02', errorCode=10001)){...}" + ) + assert HiveEngineSpec.extract_error_message(Exception(msg)) == ( + "hive error: Error while compiling statement: FAILED: " + "SemanticException [Error 10001]: Line 4:5 " + "Table not found 'fact_ridesfdslakj'" + ) + + e = Exception("Some string that doesn't match the regex") + assert HiveEngineSpec.extract_error_message(e) == f"hive error: {e}" + + msg = ( + "errorCode=10001, " + 'errorMessage="Error while compiling statement"), operationHandle' + '=None)"' + ) + assert ( + HiveEngineSpec.extract_error_message(Exception(msg)) + == "hive error: Error while compiling statement" + ) + + +def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name + assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == [] + + +def test_convert_dttm(): + dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f") + assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)" + assert ( + HiveEngineSpec.convert_dttm("TIMESTAMP", dttm) + == "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)" + ) + + +def test_create_table_from_csv_append() -> None: + + with pytest.raises(SupersetException): + HiveEngineSpec.create_table_from_csv( + "foo.csv", Table("foobar"), mock.MagicMock(), {}, {"if_exists": "append"} ) - e = Exception("Some string 
def test_get_create_table_stmt() -> None:
    """get_create_table_stmt builds the expected DDL and bind params for
    the header-count / null-values combinations used by CSV upload."""
    table = Table("employee")
    schema_def = """eid int, name String, salary String, destination String"""
    location = "s3a://directory/table"
    # header=0 skips one line (the header row); [""] maps empty string to NULL
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 0, [""]
    ) == (
        """CREATE TABLE employee ( eid int, name String, salary String, destination String )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location
            tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
        {
            "delim": ",",
            "location": "s3a://directory/table",
            "header_line_count": "1",
            "null_value": "",
        },
    )
    # only the first null value is used — Hive supports a single null format
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 1, ["1", "2"]
    ) == (
        """CREATE TABLE employee ( eid int, name String, salary String, destination String )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location
            tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
        {
            "delim": ",",
            "location": "s3a://directory/table",
            "header_line_count": "2",
            "null_value": "1",
        },
    )
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 100, ["NaN"]
    ) == (
        """CREATE TABLE employee ( eid int, name String, salary String, destination String )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location
            tblproperties ('skip.header.line.count'=':header_line_count', 'serialization.null.format'=':null_value')""",
        {
            "delim": ",",
            "location": "s3a://directory/table",
            "header_line_count": "101",
            "null_value": "NaN",
        },
    )
    # no header count, no null values -> no tblproperties clause at all
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", None, None
    ) == (
        """CREATE TABLE employee ( eid int, name String, salary String, destination String )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location""",
        {"delim": ",", "location": "s3a://directory/table"},
    )
    # empty null-values list behaves like None for the null format
    assert HiveEngineSpec.get_create_table_stmt(
        table, schema_def, location, ",", 100, []
    ) == (
        """CREATE TABLE employee ( eid int, name String, salary String, destination String )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim
            STORED AS TEXTFILE LOCATION :location
            tblproperties ('skip.header.line.count'=':header_line_count')""",
        {"delim": ",", "location": "s3a://directory/table", "header_line_count": "101"},
    )