chore: refactor file upload commands (#28164)

Daniel Vaz Gaspar 2024-04-23 08:42:19 +01:00 committed by GitHub
parent cfc440c56c
commit de82d90b9c
17 changed files with 930 additions and 734 deletions

View File

@@ -301,7 +301,6 @@ test('CSV, renders the columns elements correctly', () => {
   const selectColumnsToRead = screen.getByRole('combobox', {
     name: /Choose columns to read/i,
   });
-  const switchOverwriteDuplicates = screen.getByTestId('overwriteDuplicates');
   const inputColumnDataTypes = screen.getByRole('textbox', {
     name: /Column data types/i,
   });
@@ -312,7 +311,6 @@ test('CSV, renders the columns elements correctly', () => {
     switchDataFrameIndex,
     inputColumnLabels,
     selectColumnsToRead,
-    switchOverwriteDuplicates,
     inputColumnDataTypes,
   ];
   visibleComponents.forEach(component => {

View File

@@ -68,7 +68,6 @@ const CSVSpecificFields = [
   'skip_initial_space',
   'skip_blank_lines',
   'day_first',
-  'overwrite_duplicates',
   'column_data_types',
 ];
@@ -109,7 +108,6 @@ interface UploadInfo {
   dataframe_index: boolean;
   column_labels: string;
   columns_read: Array<string>;
-  overwrite_duplicates: boolean;
   column_data_types: string;
 }
@@ -132,7 +130,6 @@ const defaultUploadInfo: UploadInfo = {
   dataframe_index: false,
   column_labels: '',
   columns_read: [],
-  overwrite_duplicates: false,
   column_data_types: '',
 };
@@ -975,20 +972,6 @@ const UploadDataModal: FunctionComponent<UploadDataModalProps> = ({
             </StyledFormItem>
           </Col>
         </Row>
-        {type === 'csv' && (
-          <Row>
-            <Col span={24}>
-              <StyledFormItem name="overwrite_duplicates">
-                <SwitchContainer
-                  label={t(
-                    'Overwrite Duplicate Columns. If duplicate columns are not overridden, they will be presented as "X.1, X.2 ...X.x"',
-                  )}
-                  dataTest="overwriteDuplicates"
-                />
-              </StyledFormItem>
-            </Col>
-          </Row>
-        )}
       </Collapse.Panel>
       <Collapse.Panel
         header={

View File

@@ -1,198 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional, TypedDict
import pandas as pd
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import (
DatabaseNotFoundError,
DatabaseSchemaUploadNotAllowed,
DatabaseUploadFailed,
DatabaseUploadSaveMetadataFailed,
)
from superset.connectors.sqla.models import SqlaTable
from superset.daos.database import DatabaseDAO
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import get_user
from superset.views.database.validators import schema_allows_file_upload
logger = logging.getLogger(__name__)
READ_CSV_CHUNK_SIZE = 1000
class CSVImportOptions(TypedDict, total=False):
schema: str
delimiter: str
already_exists: str
column_data_types: dict[str, str]
column_dates: list[str]
column_labels: str
columns_read: list[str]
dataframe_index: str
day_first: bool
decimal_character: str
header_row: int
index_column: str
null_values: list[str]
overwrite_duplicates: bool
rows_to_read: int
skip_blank_lines: bool
skip_initial_space: bool
skip_rows: int
class CSVImportCommand(BaseCommand):
def __init__(
self,
model_id: int,
table_name: str,
file: Any,
options: CSVImportOptions,
) -> None:
self._model_id = model_id
self._model: Optional[Database] = None
self._table_name = table_name
self._schema = options.get("schema")
self._file = file
self._options = options
def _read_csv(self) -> pd.DataFrame:
"""
Read CSV file into a DataFrame
:return: pandas DataFrame
:throws DatabaseUploadFailed: if there is an error reading the CSV file
"""
try:
return pd.concat(
pd.read_csv(
chunksize=READ_CSV_CHUNK_SIZE,
encoding="utf-8",
filepath_or_buffer=self._file,
header=self._options.get("header_row", 0),
index_col=self._options.get("index_column"),
dayfirst=self._options.get("day_first", False),
iterator=True,
keep_default_na=not self._options.get("null_values"),
usecols=self._options.get("columns_read")
if self._options.get("columns_read") # None if an empty list
else None,
na_values=self._options.get("null_values")
if self._options.get("null_values") # None if an empty list
else None,
nrows=self._options.get("rows_to_read"),
parse_dates=self._options.get("column_dates"),
sep=self._options.get("delimiter", ","),
skip_blank_lines=self._options.get("skip_blank_lines", False),
skipinitialspace=self._options.get("skip_initial_space", False),
skiprows=self._options.get("skip_rows", 0),
dtype=self._options.get("column_data_types")
if self._options.get("column_data_types")
else None,
)
)
except (
pd.errors.ParserError,
pd.errors.EmptyDataError,
UnicodeDecodeError,
ValueError,
) as ex:
raise DatabaseUploadFailed(
message=_("Parsing error: %(error)s", error=str(ex))
) from ex
except Exception as ex:
raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
"""
Upload DataFrame to database
:param df:
:throws DatabaseUploadFailed: if there is an error uploading the DataFrame
"""
try:
csv_table = Table(table=self._table_name, schema=self._schema)
database.db_engine_spec.df_to_sql(
database,
csv_table,
df,
to_sql_kwargs={
"chunksize": READ_CSV_CHUNK_SIZE,
"if_exists": self._options.get("already_exists", "fail"),
"index": self._options.get("index_column"),
"index_label": self._options.get("column_labels"),
},
)
except ValueError as ex:
raise DatabaseUploadFailed(
message=_(
"Table already exists. You can change your "
"'if table already exists' strategy to append or "
"replace or provide a different Table Name to use."
)
) from ex
except Exception as ex:
raise DatabaseUploadFailed(exception=ex) from ex
def run(self) -> None:
self.validate()
if not self._model:
return
df = self._read_csv()
self._dataframe_to_database(df, self._model)
sqla_table = (
db.session.query(SqlaTable)
.filter_by(
table_name=self._table_name,
schema=self._schema,
database_id=self._model_id,
)
.one_or_none()
)
if not sqla_table:
sqla_table = SqlaTable(
table_name=self._table_name,
database=self._model,
database_id=self._model_id,
owners=[get_user()],
schema=self._schema,
)
db.session.add(sqla_table)
sqla_table.fetch_metadata()
try:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DatabaseUploadSaveMetadataFailed() from ex
def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
if not schema_allows_file_upload(self._model, self._schema):
raise DatabaseSchemaUploadNotAllowed()

View File

@@ -98,6 +98,11 @@ class DatabaseSchemaUploadNotAllowed(CommandException):
     message = _("Database schema is not allowed for csv uploads.")


+class DatabaseUploadNotSupported(CommandException):
+    status = 422
+    message = _("Database type does not support file uploads.")
+
+
 class DatabaseUploadFailed(CommandException):
     status = 422
     message = _("Database upload file failed")

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -14,7 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from abc import abstractmethod
 from typing import Any, Optional, TypedDict

 import pandas as pd
@@ -27,6 +29,7 @@ from superset.commands.database.exceptions import (
     DatabaseNotFoundError,
     DatabaseSchemaUploadNotAllowed,
     DatabaseUploadFailed,
+    DatabaseUploadNotSupported,
     DatabaseUploadSaveMetadataFailed,
 )
 from superset.connectors.sqla.models import SqlaTable
@@ -38,78 +41,43 @@ from superset.views.database.validators import schema_allows_file_upload
 logger = logging.getLogger(__name__)

-READ_EXCEL_CHUNK_SIZE = 1000
+READ_CHUNK_SIZE = 1000


-class ExcelImportOptions(TypedDict, total=False):
-    sheet_name: str
-    schema: str
+class ReaderOptions(TypedDict, total=False):
     already_exists: str
-    column_dates: list[str]
     column_labels: str
-    columns_read: list[str]
-    dataframe_index: str
-    decimal_character: str
-    header_row: int
     index_column: str
-    null_values: list[str]
-    rows_to_read: int
-    skip_rows: int


-class ExcelImportCommand(BaseCommand):
-    def __init__(
-        self,
-        model_id: int,
-        table_name: str,
-        file: Any,
-        options: ExcelImportOptions,
-    ) -> None:
-        self._model_id = model_id
-        self._model: Optional[Database] = None
-        self._table_name = table_name
-        self._schema = options.get("schema")
-        self._file = file
-        self._options = options
-
-    def _read_excel(self) -> pd.DataFrame:
-        """
-        Read Excel file into a DataFrame
-        :return: pandas DataFrame
-        :throws DatabaseUploadFailed: if there is an error reading the CSV file
-        """
-        kwargs = {
-            "header": self._options.get("header_row", 0),
-            "index_col": self._options.get("index_column"),
-            "io": self._file,
-            "keep_default_na": not self._options.get("null_values"),
-            "na_values": self._options.get("null_values")
-            if self._options.get("null_values")  # None if an empty list
-            else None,
-            "parse_dates": self._options.get("column_dates"),
-            "skiprows": self._options.get("skip_rows", 0),
-            "sheet_name": self._options.get("sheet_name", 0),
-            "nrows": self._options.get("rows_to_read"),
-        }
-        if self._options.get("columns_read"):
-            kwargs["usecols"] = self._options.get("columns_read")
-        try:
-            return pd.read_excel(**kwargs)
-        except (
-            pd.errors.ParserError,
-            pd.errors.EmptyDataError,
-            UnicodeDecodeError,
-            ValueError,
-        ) as ex:
-            raise DatabaseUploadFailed(
-                message=_("Parsing error: %(error)s", error=str(ex))
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(_("Error reading Excel file")) from ex
+class BaseDataReader:
+    """
+    Base class for reading data from a file and uploading it to a database
+    These child objects are used by the UploadCommand as a dependency injection
+    to read data from multiple file types (e.g. CSV, Excel, etc.)
+    """

-    def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
+    def __init__(self, options: dict[str, Any]) -> None:
+        self._options = options
+
+    @abstractmethod
+    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
+        ...
+
+    def read(
+        self, file: Any, database: Database, table_name: str, schema_name: Optional[str]
+    ) -> None:
+        self._dataframe_to_database(
+            self.file_to_dataframe(file), database, table_name, schema_name
+        )
+
+    def _dataframe_to_database(
+        self,
+        df: pd.DataFrame,
+        database: Database,
+        table_name: str,
+        schema_name: Optional[str],
+    ) -> None:
         """
         Upload DataFrame to database
@@ -117,13 +85,13 @@ class ExcelImportCommand(BaseCommand):
         :throws DatabaseUploadFailed: if there is an error uploading the DataFrame
         """
         try:
-            data_table = Table(table=self._table_name, schema=self._schema)
+            data_table = Table(table=table_name, schema=schema_name)
             database.db_engine_spec.df_to_sql(
                 database,
                 data_table,
                 df,
                 to_sql_kwargs={
-                    "chunksize": READ_EXCEL_CHUNK_SIZE,
+                    "chunksize": READ_CHUNK_SIZE,
                     "if_exists": self._options.get("already_exists", "fail"),
                     "index": self._options.get("index_column"),
                     "index_label": self._options.get("column_labels"),
@@ -140,13 +108,29 @@ class ExcelImportCommand(BaseCommand):
         except Exception as ex:
             raise DatabaseUploadFailed(exception=ex) from ex


+class UploadCommand(BaseCommand):
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        model_id: int,
+        table_name: str,
+        file: Any,
+        schema: Optional[str],
+        reader: BaseDataReader,
+    ) -> None:
+        self._model_id = model_id
+        self._model: Optional[Database] = None
+        self._table_name = table_name
+        self._schema = schema
+        self._file = file
+        self._reader = reader
+
     def run(self) -> None:
         self.validate()
         if not self._model:
             return

-        df = self._read_excel()
-        self._dataframe_to_database(df, self._model)
+        self._reader.read(self._file, self._model, self._table_name, self._schema)

         sqla_table = (
             db.session.query(SqlaTable)
@@ -181,3 +165,5 @@ class ExcelImportCommand(BaseCommand):
             raise DatabaseNotFoundError()
         if not schema_allows_file_upload(self._model, self._schema):
             raise DatabaseSchemaUploadNotAllowed()
+        if not self._model.db_engine_spec.supports_file_upload:
+            raise DatabaseUploadNotSupported()
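
The net effect of this file: the format-specific command classes are gone, and UploadCommand takes any BaseDataReader as a constructor argument. A minimal sketch of the extension point this enables; JSONLReader and its option values are hypothetical, not part of this commit:

from typing import Any

import pandas as pd

from superset.commands.database.uploaders.base import BaseDataReader, UploadCommand


class JSONLReader(BaseDataReader):
    # Hypothetical reader: a new file format only needs file_to_dataframe;
    # validation, df_to_sql, and dataset registration stay in UploadCommand.
    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
        # pandas reads newline-delimited JSON natively with lines=True
        return pd.read_json(file, lines=True)


def upload_jsonl(database_id: int, file: Any) -> None:
    # Same call shape the API layer uses for CSVReader/ExcelReader below.
    UploadCommand(
        database_id,
        "my_table",
        file,
        None,  # schema: None falls back to the database default
        JSONLReader({"already_exists": "replace"}),
    ).run()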

View File

@@ -0,0 +1,102 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
import pandas as pd
from flask_babel import lazy_gettext as _
from superset.commands.database.exceptions import DatabaseUploadFailed
from superset.commands.database.uploaders.base import BaseDataReader, ReaderOptions
logger = logging.getLogger(__name__)
READ_CSV_CHUNK_SIZE = 1000
class CSVReaderOptions(ReaderOptions, total=False):
delimiter: str
column_data_types: dict[str, str]
column_dates: list[str]
columns_read: list[str]
dataframe_index: str
day_first: bool
decimal_character: str
header_row: int
null_values: list[str]
rows_to_read: int
skip_blank_lines: bool
skip_initial_space: bool
skip_rows: int
class CSVReader(BaseDataReader):
def __init__(
self,
options: CSVReaderOptions,
) -> None:
super().__init__(
options=dict(options),
)
def file_to_dataframe(self, file: Any) -> pd.DataFrame:
"""
Read CSV file into a DataFrame
:return: pandas DataFrame
:throws DatabaseUploadFailed: if there is an error reading the CSV file
"""
try:
return pd.concat(
pd.read_csv(
chunksize=READ_CSV_CHUNK_SIZE,
encoding="utf-8",
filepath_or_buffer=file,
header=self._options.get("header_row", 0),
decimal=self._options.get("decimal_character", "."),
index_col=self._options.get("index_column"),
dayfirst=self._options.get("day_first", False),
iterator=True,
keep_default_na=not self._options.get("null_values"),
usecols=self._options.get("columns_read")
if self._options.get("columns_read") # None if an empty list
else None,
na_values=self._options.get("null_values")
if self._options.get("null_values") # None if an empty list
else None,
nrows=self._options.get("rows_to_read"),
parse_dates=self._options.get("column_dates"),
sep=self._options.get("delimiter", ","),
skip_blank_lines=self._options.get("skip_blank_lines", False),
skipinitialspace=self._options.get("skip_initial_space", False),
skiprows=self._options.get("skip_rows", 0),
dtype=self._options.get("column_data_types")
if self._options.get("column_data_types")
else None,
)
)
except (
pd.errors.ParserError,
pd.errors.EmptyDataError,
UnicodeDecodeError,
ValueError,
) as ex:
raise DatabaseUploadFailed(
message=_("Parsing error: %(error)s", error=str(ex))
) from ex
except Exception as ex:
raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
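
CSVReader.file_to_dataframe uses the chunked-read idiom: with chunksize and iterator=True, pd.read_csv returns a lazy iterator of DataFrames, and pd.concat accepts that iterator directly. A standalone sketch of just that pattern (the chunk size of 1000 mirrors READ_CSV_CHUNK_SIZE above):

import io

import pandas as pd

# 5000 rows of synthetic CSV, parsed 1000 rows at a time
csv_file = io.StringIO(
    "name,age\n" + "\n".join(f"p{i},{20 + i % 50}" for i in range(5000))
)
df = pd.concat(pd.read_csv(csv_file, chunksize=1000, iterator=True))
assert len(df) == 5000  # the chunks are stitched back into one frame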

View File

@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
import pandas as pd
from flask_babel import lazy_gettext as _
from superset.commands.database.exceptions import DatabaseUploadFailed
from superset.commands.database.uploaders.base import BaseDataReader, ReaderOptions
logger = logging.getLogger(__name__)
class ExcelReaderOptions(ReaderOptions, total=False):
sheet_name: str
column_dates: list[str]
columns_read: list[str]
dataframe_index: str
decimal_character: str
header_row: int
null_values: list[str]
rows_to_read: int
skip_rows: int
class ExcelReader(BaseDataReader):
def __init__(
self,
options: ExcelReaderOptions,
) -> None:
super().__init__(
options=dict(options),
)
def file_to_dataframe(self, file: Any) -> pd.DataFrame:
"""
Read Excel file into a DataFrame
:return: pandas DataFrame
:throws DatabaseUploadFailed: if there is an error reading the Excel file
"""
kwargs = {
"header": self._options.get("header_row", 0),
"index_col": self._options.get("index_column"),
"io": file,
"keep_default_na": not self._options.get("null_values"),
"decimal": self._options.get("decimal_character", "."),
"na_values": self._options.get("null_values")
if self._options.get("null_values") # None if an empty list
else None,
"parse_dates": self._options.get("column_dates"),
"skiprows": self._options.get("skip_rows", 0),
"sheet_name": self._options.get("sheet_name", 0),
"nrows": self._options.get("rows_to_read"),
}
if self._options.get("columns_read"):
kwargs["usecols"] = self._options.get("columns_read")
try:
return pd.read_excel(**kwargs)
except (
pd.errors.ParserError,
pd.errors.EmptyDataError,
UnicodeDecodeError,
ValueError,
) as ex:
raise DatabaseUploadFailed(
message=_("Parsing error: %(error)s", error=str(ex))
) from ex
except Exception as ex:
raise DatabaseUploadFailed(_("Error reading Excel file")) from ex
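
ExcelReaderOptions (like CSVReaderOptions) extends ReaderOptions with total=False, so callers may pass any subset of keys; missing keys fall through to the defaults baked into file_to_dataframe (header_row=0, sheet_name=0, decimal_character="."). A small usage sketch; the data.xlsx path is illustrative only:

from superset.commands.database.uploaders.excel_reader import (
    ExcelReader,
    ExcelReaderOptions,
)

# Only the keys you care about; everything else takes the default.
options = ExcelReaderOptions(sheet_name="Sheet1", rows_to_read=100)
with open("data.xlsx", "rb") as f:  # assumed local file, for illustration
    df = ExcelReader(options).file_to_dataframe(f)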

View File

@@ -34,9 +34,7 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
 from superset import app, event_logger
 from superset.commands.database.create import CreateDatabaseCommand
-from superset.commands.database.csv_import import CSVImportCommand
 from superset.commands.database.delete import DeleteDatabaseCommand
-from superset.commands.database.excel_import import ExcelImportCommand
 from superset.commands.database.exceptions import (
     DatabaseConnectionFailedError,
     DatabaseCreateFailedError,
@@ -59,6 +57,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
 from superset.commands.database.tables import TablesDatabaseCommand
 from superset.commands.database.test_connection import TestConnectionDatabaseCommand
 from superset.commands.database.update import UpdateDatabaseCommand
+from superset.commands.database.uploaders.base import UploadCommand
+from superset.commands.database.uploaders.csv_reader import CSVReader
+from superset.commands.database.uploaders.excel_reader import ExcelReader
 from superset.commands.database.validate import ValidateDatabaseParametersCommand
 from superset.commands.database.validate_sql import ValidateSQLCommand
 from superset.commands.importers.exceptions import (
@@ -1491,11 +1492,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
             request_form = request.form.to_dict()
             request_form["file"] = request.files.get("file")
             parameters = CSVUploadPostSchema().load(request_form)
-            CSVImportCommand(
+            UploadCommand(
                 pk,
                 parameters["table_name"],
                 parameters["file"],
-                parameters,
+                parameters.get("schema"),
+                CSVReader(parameters),
             ).run()
         except ValidationError as error:
             return self.response_400(message=error.messages)
@@ -1550,11 +1552,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
             request_form = request.form.to_dict()
             request_form["file"] = request.files.get("file")
             parameters = ExcelUploadPostSchema().load(request_form)
-            ExcelImportCommand(
+            UploadCommand(
                 pk,
                 parameters["table_name"],
                 parameters["file"],
-                parameters,
+                parameters.get("schema"),
+                ExcelReader(parameters),
             ).run()
         except ValidationError as error:
             return self.response_400(message=error.messages)
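
The HTTP contract is unchanged by this refactor: both endpoints still take a multipart form validated by the upload schemas, and only the command wiring behind them differs. A hedged client-side sketch against the CSV endpoint, assuming a local instance on port 8088 with authentication and CSRF already handled on the session:

import requests

session = requests.Session()
# ...log in and attach auth/CSRF headers here; omitted for brevity...

with open("out.csv", "rb") as f:
    resp = session.post(
        "http://localhost:8088/api/v1/database/1/csv_upload/",
        files={"file": ("out.csv", f)},
        data={
            "table_name": "table1",
            "delimiter": ",",
            "already_exists": "fail",  # or "append" / "replace"
        },
    )
assert resp.status_code == 200 and resp.json() == {"message": "OK"}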

View File

@@ -1116,12 +1116,6 @@ class CSVUploadPostSchema(BaseUploadPostSchema):
             "description": "DD/MM format dates, international and European format"
         }
     )
-    overwrite_duplicates = fields.Boolean(
-        metadata={
-            "description": "If duplicate columns are not overridden,"
-            "they will be presented as 'X.1, X.2 ...X.x'."
-        }
-    )
     skip_blank_lines = fields.Boolean(
         metadata={"description": "Skip blank lines in the CSV file."}
     )

View File

@@ -1,257 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from datetime import datetime
import pytest
from superset import db, security_manager
from superset.commands.database.excel_import import ExcelImportCommand
from superset.commands.database.exceptions import (
DatabaseNotFoundError,
DatabaseSchemaUploadNotAllowed,
DatabaseUploadFailed,
)
from superset.models.core import Database
from superset.utils.core import override_user
from superset.utils.database import get_or_create_db
from tests.integration_tests.conftest import only_postgresql
from tests.integration_tests.test_app import app
from tests.unit_tests.fixtures.common import create_excel_file
EXCEL_UPLOAD_DATABASE = "excel_explore_db"
EXCEL_UPLOAD_TABLE = "excel_upload"
EXCEL_UPLOAD_TABLE_W_SCHEMA = "excel_upload_w_schema"
EXCEL_FILE_1 = {
"Name": ["name1", "name2", "name3"],
"Age": [30, 29, 28],
"City": ["city1", "city2", "city3"],
"Birth": ["1-1-1980", "1-1-1981", "1-1-1982"],
}
EXCEL_FILE_2 = {
"Name": ["name1", "name2", "name3"],
"Age": ["N/A", 29, 28],
"City": ["city1", "None", "city3"],
"Birth": ["1-1-1980", "1-1-1981", "1-1-1982"],
}
def _setup_excel_upload(allowed_schemas: list[str] | None = None):
upload_db = get_or_create_db(
EXCEL_UPLOAD_DATABASE, app.config["SQLALCHEMY_EXAMPLES_URI"]
)
upload_db.allow_file_upload = True
extra = upload_db.get_extra()
allowed_schemas = allowed_schemas or []
extra["schemas_allowed_for_file_upload"] = allowed_schemas
upload_db.extra = json.dumps(extra)
db.session.commit()
yield
upload_db = get_upload_db()
with upload_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE_W_SCHEMA}")
db.session.delete(upload_db)
db.session.commit()
def get_upload_db():
return (
db.session.query(Database).filter_by(database_name=EXCEL_UPLOAD_DATABASE).one()
)
@pytest.fixture(scope="function")
def setup_excel_upload_with_context():
with app.app_context():
yield from _setup_excel_upload()
@pytest.fixture(scope="function")
def setup_excel_upload_with_context_schema():
with app.app_context():
yield from _setup_excel_upload(["public"])
@only_postgresql
@pytest.mark.parametrize(
"excel_data,options, table_data",
[
(
EXCEL_FILE_1,
{},
[
("name1", 30, "city1", "1-1-1980"),
("name2", 29, "city2", "1-1-1981"),
("name3", 28, "city3", "1-1-1982"),
],
),
(
EXCEL_FILE_1,
{"columns_read": ["Name", "Age"]},
[("name1", 30), ("name2", 29), ("name3", 28)],
),
(
EXCEL_FILE_1,
{"columns_read": []},
[
("name1", 30, "city1", "1-1-1980"),
("name2", 29, "city2", "1-1-1981"),
("name3", 28, "city3", "1-1-1982"),
],
),
(
EXCEL_FILE_1,
{"rows_to_read": 1},
[
("name1", 30, "city1", "1-1-1980"),
],
),
(
EXCEL_FILE_1,
{"rows_to_read": 1, "columns_read": ["Name", "Age"]},
[
("name1", 30),
],
),
(
EXCEL_FILE_1,
{"skip_rows": 1},
[("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
),
(
EXCEL_FILE_1,
{"rows_to_read": 2},
[
("name1", 30, "city1", "1-1-1980"),
("name2", 29, "city2", "1-1-1981"),
],
),
(
EXCEL_FILE_1,
{"column_dates": ["Birth"]},
[
("name1", 30, "city1", datetime(1980, 1, 1, 0, 0)),
("name2", 29, "city2", datetime(1981, 1, 1, 0, 0)),
("name3", 28, "city3", datetime(1982, 1, 1, 0, 0)),
],
),
(
EXCEL_FILE_2,
{"null_values": ["N/A", "None"]},
[
("name1", None, "city1", "1-1-1980"),
("name2", 29, None, "1-1-1981"),
("name3", 28, "city3", "1-1-1982"),
],
),
(
EXCEL_FILE_2,
{
"null_values": ["N/A", "None"],
"column_dates": ["Birth"],
"columns_read": ["Name", "Age", "Birth"],
},
[
("name1", None, datetime(1980, 1, 1, 0, 0)),
("name2", 29, datetime(1981, 1, 1, 0, 0)),
("name3", 28, datetime(1982, 1, 1, 0, 0)),
],
),
],
)
@pytest.mark.usefixtures("setup_excel_upload_with_context")
def test_excel_upload_options(excel_data, options, table_data):
admin_user = security_manager.find_user(username="admin")
upload_database = get_upload_db()
with override_user(admin_user):
ExcelImportCommand(
upload_database.id,
EXCEL_UPLOAD_TABLE,
create_excel_file(excel_data),
options=options,
).run()
with upload_database.get_sqla_engine_with_context() as engine:
data = engine.execute(f"SELECT * from {EXCEL_UPLOAD_TABLE}").fetchall()
assert data == table_data
@only_postgresql
@pytest.mark.usefixtures("setup_excel_upload_with_context")
def test_excel_upload_database_not_found():
admin_user = security_manager.find_user(username="admin")
with override_user(admin_user):
with pytest.raises(DatabaseNotFoundError):
ExcelImportCommand(
1000,
EXCEL_UPLOAD_TABLE,
create_excel_file(EXCEL_FILE_1),
options={},
).run()
@only_postgresql
@pytest.mark.usefixtures("setup_excel_upload_with_context_schema")
def test_excel_upload_schema_not_allowed():
admin_user = security_manager.find_user(username="admin")
upload_db_id = get_upload_db().id
with override_user(admin_user):
with pytest.raises(DatabaseSchemaUploadNotAllowed):
ExcelImportCommand(
upload_db_id,
EXCEL_UPLOAD_TABLE,
create_excel_file(EXCEL_FILE_1),
options={},
).run()
with pytest.raises(DatabaseSchemaUploadNotAllowed):
ExcelImportCommand(
upload_db_id,
EXCEL_UPLOAD_TABLE,
create_excel_file(EXCEL_FILE_1),
options={"schema": "schema1"},
).run()
ExcelImportCommand(
upload_db_id,
EXCEL_UPLOAD_TABLE,
create_excel_file(EXCEL_FILE_1),
options={"schema": "public"},
).run()
@only_postgresql
@pytest.mark.usefixtures("setup_excel_upload_with_context")
def test_excel_upload_broken_file():
admin_user = security_manager.find_user(username="admin")
with override_user(admin_user):
with pytest.raises(DatabaseUploadFailed):
ExcelImportCommand(
get_upload_db().id,
EXCEL_UPLOAD_TABLE,
create_excel_file([""]),
options={"column_dates": ["Birth"]},
).run()

View File

@@ -18,17 +18,19 @@
 from __future__ import annotations

 import json
-from datetime import datetime

 import pytest

 from superset import db, security_manager
-from superset.commands.database.csv_import import CSVImportCommand
 from superset.commands.database.exceptions import (
     DatabaseNotFoundError,
     DatabaseSchemaUploadNotAllowed,
     DatabaseUploadFailed,
+    DatabaseUploadNotSupported,
 )
+from superset.commands.database.uploaders.base import UploadCommand
+from superset.commands.database.uploaders.csv_reader import CSVReader
+from superset.connectors.sqla.models import SqlaTable
 from superset.models.core import Database
 from superset.utils.core import override_user
 from superset.utils.database import get_or_create_db
@@ -48,27 +50,13 @@ CSV_FILE_1 = [
     ["name3", "28", "city3", "1-1-1982"],
 ]

-CSV_FILE_2 = [
-    ["name1", "30", "city1", "1-1-1980"],
-    ["Name", "Age", "City", "Birth"],
-    ["name2", "29", "city2", "1-1-1981"],
-    ["name3", "28", "city3", "1-1-1982"],
-]
-
-CSV_FILE_3 = [
+CSV_FILE_WITH_NULLS = [
     ["Name", "Age", "City", "Birth"],
     ["name1", "N/A", "city1", "1-1-1980"],
     ["name2", "29", "None", "1-1-1981"],
     ["name3", "28", "city3", "1-1-1982"],
 ]

-CSV_FILE_BROKEN = [
-    ["Name", "Age", "City", "Birth"],
-    ["name1", "30", "city1", "1-1-1980"],
-    ["name2", "29"],
-    ["name3", "28", "city3", "1-1-1982"],
-]

 def _setup_csv_upload(allowed_schemas: list[str] | None = None):
     upload_db = get_or_create_db(
@@ -108,122 +96,48 @@ def setup_csv_upload_with_context_schema():
         yield from _setup_csv_upload(["public"])

-@only_postgresql
-@pytest.mark.parametrize(
-    "csv_data,options, table_data",
-    [
-        (
-            CSV_FILE_1,
-            {},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"columns_read": ["Name", "Age"]},
-            [("name1", 30), ("name2", 29), ("name3", 28)],
-        ),
-        (
-            CSV_FILE_1,
-            {"columns_read": []},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 1},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 1, "columns_read": ["Name", "Age"]},
-            [
-                ("name1", 30),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"skip_rows": 1},
-            [("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 2},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"column_dates": ["Birth"]},
-            [
-                ("name1", 30, "city1", datetime(1980, 1, 1, 0, 0)),
-                ("name2", 29, "city2", datetime(1981, 1, 1, 0, 0)),
-                ("name3", 28, "city3", datetime(1982, 1, 1, 0, 0)),
-            ],
-        ),
-        (
-            CSV_FILE_2,
-            {"header_row": 1},
-            [("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
-        ),
-        (
-            CSV_FILE_3,
-            {"null_values": ["N/A", "None"]},
-            [
-                ("name1", None, "city1", "1-1-1980"),
-                ("name2", 29, None, "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_3,
-            {
-                "null_values": ["N/A", "None"],
-                "column_dates": ["Birth"],
-                "columns_read": ["Name", "Age", "Birth"],
-            },
-            [
-                ("name1", None, datetime(1980, 1, 1, 0, 0)),
-                ("name2", 29, datetime(1981, 1, 1, 0, 0)),
-                ("name3", 28, datetime(1982, 1, 1, 0, 0)),
-            ],
-        ),
-        (
-            CSV_FILE_BROKEN,
-            {},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, None, None),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-    ],
-)
 @pytest.mark.usefixtures("setup_csv_upload_with_context")
-def test_csv_upload_options(csv_data, options, table_data):
+def test_csv_upload_with_nulls():
     admin_user = security_manager.find_user(username="admin")
     upload_database = get_upload_db()
     with override_user(admin_user):
-        CSVImportCommand(
+        UploadCommand(
             upload_database.id,
             CSV_UPLOAD_TABLE,
-            create_csv_file(csv_data),
-            options=options,
+            create_csv_file(CSV_FILE_WITH_NULLS),
+            None,
+            CSVReader({"null_values": ["N/A", "None"]}),
         ).run()
         with upload_database.get_sqla_engine_with_context() as engine:
             data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
-            assert data == table_data
+            assert data == [
+                ("name1", None, "city1", "1-1-1980"),
+                ("name2", 29, None, "1-1-1981"),
+                ("name3", 28, "city3", "1-1-1982"),
+            ]
+
+
+@pytest.mark.usefixtures("setup_csv_upload_with_context")
+def test_csv_upload_dataset():
+    admin_user = security_manager.find_user(username="admin")
+    upload_database = get_upload_db()
+    with override_user(admin_user):
+        UploadCommand(
+            upload_database.id,
+            CSV_UPLOAD_TABLE,
+            create_csv_file(),
+            None,
+            CSVReader({}),
+        ).run()
+        dataset = (
+            db.session.query(SqlaTable)
+            .filter_by(database_id=upload_database.id, table_name=CSV_UPLOAD_TABLE)
+            .one_or_none()
+        )
+        assert dataset is not None
+        assert security_manager.find_user("admin") in dataset.owners


 @only_postgresql
@@ -233,14 +147,33 @@ def test_csv_upload_database_not_found():
     with override_user(admin_user):
         with pytest.raises(DatabaseNotFoundError):
-            CSVImportCommand(
+            UploadCommand(
                 1000,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={},
+                None,
+                CSVReader({}),
             ).run()


+@only_postgresql
+@pytest.mark.usefixtures("setup_csv_upload_with_context")
+def test_csv_upload_database_not_supported():
+    admin_user = security_manager.find_user(username="admin")
+    upload_db: Database = get_upload_db()
+    upload_db.db_engine_spec.supports_file_upload = False
+    with override_user(admin_user):
+        with pytest.raises(DatabaseUploadNotSupported):
+            UploadCommand(
+                upload_db.id,
+                CSV_UPLOAD_TABLE,
+                create_csv_file(CSV_FILE_1),
+                None,
+                CSVReader({}),
+            ).run()
+    upload_db.db_engine_spec.supports_file_upload = True
+
+
 @only_postgresql
 @pytest.mark.usefixtures("setup_csv_upload_with_context_schema")
 def test_csv_upload_schema_not_allowed():
@@ -248,39 +181,25 @@ def test_csv_upload_schema_not_allowed():
     upload_db_id = get_upload_db().id
     with override_user(admin_user):
         with pytest.raises(DatabaseSchemaUploadNotAllowed):
-            CSVImportCommand(
+            UploadCommand(
                 upload_db_id,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={},
+                None,
+                CSVReader({}),
             ).run()
         with pytest.raises(DatabaseSchemaUploadNotAllowed):
-            CSVImportCommand(
+            UploadCommand(
                 upload_db_id,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={"schema": "schema1"},
+                "schema1",
+                CSVReader({}),
             ).run()
-        CSVImportCommand(
+        UploadCommand(
             upload_db_id,
-            CSV_UPLOAD_TABLE,
+            CSV_UPLOAD_TABLE_W_SCHEMA,
             create_csv_file(CSV_FILE_1),
-            options={"schema": "public"},
+            "public",
+            CSVReader({}),
         ).run()
-
-
-@only_postgresql
-@pytest.mark.usefixtures("setup_csv_upload_with_context")
-def test_csv_upload_broken_file():
-    admin_user = security_manager.find_user(username="admin")
-    with override_user(admin_user):
-        with pytest.raises(DatabaseUploadFailed):
-            CSVImportCommand(
-                get_upload_db().id,
-                CSV_UPLOAD_TABLE,
-                create_csv_file([""]),
-                options={"column_dates": ["Birth"]},
-            ).run()

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,313 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
from datetime import datetime
import numpy as np
import pytest
from superset.commands.database.exceptions import DatabaseUploadFailed
from superset.commands.database.uploaders.csv_reader import CSVReader, CSVReaderOptions
from tests.unit_tests.fixtures.common import create_csv_file
CSV_DATA = [
["Name", "Age", "City", "Birth"],
["name1", "30", "city1", "1990-02-01"],
["name2", "25", "city2", "1995-02-01"],
["name3", "20", "city3", "2000-02-01"],
]
CSV_DATA_CHANGED_HEADER = [
["name1", "30", "city1", "1990-02-01"],
["Name", "Age", "City", "Birth"],
["name2", "25", "city2", "1995-02-01"],
["name3", "20", "city3", "2000-02-01"],
]
CSV_DATA_WITH_NULLS = [
["Name", "Age", "City", "Birth"],
["name1", "N/A", "city1", "1990-02-01"],
["name2", "25", "None", "1995-02-01"],
["name3", "20", "city3", "2000-02-01"],
]
CSV_DATA_DAY_FIRST = [
["Name", "Age", "City", "Birth"],
["name1", "30", "city1", "01-02-1990"],
]
CSV_DATA_DECIMAL_CHAR = [
["Name", "Age", "City", "Birth"],
["name1", "30,1", "city1", "1990-02-01"],
]
CSV_DATA_SKIP_INITIAL_SPACE = [
[" Name", "Age", "City", "Birth"],
[" name1", "30", "city1", "1990-02-01"],
]
@pytest.mark.parametrize(
"file, options, expected_cols, expected_values",
[
(
create_csv_file(CSV_DATA),
CSVReaderOptions(),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA, delimiter="|"),
CSVReaderOptions(delimiter="|"),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
columns_read=["Name", "Age"],
),
["Name", "Age"],
[
["name1", 30],
["name2", 25],
["name3", 20],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
columns_read=["Name", "Age"],
column_data_types={"Age": "float"},
),
["Name", "Age"],
[
["name1", 30.0],
["name2", 25.0],
["name3", 20.0],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
columns_read=[],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
columns_read=[],
column_data_types={"Age": "float"},
),
["Name", "Age", "City", "Birth"],
[
["name1", 30.0, "city1", "1990-02-01"],
["name2", 25.0, "city2", "1995-02-01"],
["name3", 20.0, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
rows_to_read=1,
),
["Name", "Age", "City", "Birth"],
[
["name1", 30.0, "city1", "1990-02-01"],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
rows_to_read=1,
columns_read=["Name", "Age"],
),
["Name", "Age"],
[
["name1", 30.0],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
skip_rows=1,
),
["name1", "30", "city1", "1990-02-01"],
[
["name2", 25.0, "city2", "1995-02-01"],
["name3", 20.0, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA),
CSVReaderOptions(
column_dates=["Birth"],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
["name2", 25, "city2", datetime(1995, 2, 1, 0, 0)],
["name3", 20, "city3", datetime(2000, 2, 1, 0, 0)],
],
),
(
create_csv_file(CSV_DATA_CHANGED_HEADER),
CSVReaderOptions(
header_row=1,
),
["Name", "Age", "City", "Birth"],
[
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA_WITH_NULLS),
CSVReaderOptions(
null_values=["N/A", "None"],
),
["Name", "Age", "City", "Birth"],
[
["name1", np.nan, "city1", "1990-02-01"],
["name2", 25.0, np.nan, "1995-02-01"],
["name3", 20.0, "city3", "2000-02-01"],
],
),
(
create_csv_file(CSV_DATA_DAY_FIRST),
CSVReaderOptions(
day_first=False,
column_dates=["Birth"],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", datetime(1990, 1, 2, 0, 0)],
],
),
(
create_csv_file(CSV_DATA_DAY_FIRST),
CSVReaderOptions(
day_first=True,
column_dates=["Birth"],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
],
),
(
create_csv_file(CSV_DATA_DECIMAL_CHAR),
CSVReaderOptions(
decimal_character=",",
),
["Name", "Age", "City", "Birth"],
[
["name1", 30.1, "city1", "1990-02-01"],
],
),
(
create_csv_file(CSV_DATA_SKIP_INITIAL_SPACE),
CSVReaderOptions(
skip_initial_space=True,
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
],
),
],
)
def test_csv_reader_file_to_dataframe(file, options, expected_cols, expected_values):
csv_reader = CSVReader(
options=options,
)
df = csv_reader.file_to_dataframe(file)
assert df.columns.tolist() == expected_cols
actual_values = df.values.tolist()
for i in range(len(expected_values)):
for j in range(len(expected_values[i])):
expected_val = expected_values[i][j]
actual_val = actual_values[i][j]
# Check if both values are NaN
if isinstance(expected_val, float) and isinstance(actual_val, float):
assert np.isnan(expected_val) == np.isnan(actual_val)
else:
assert expected_val == actual_val
file.close()
def test_csv_reader_broken_file_no_columns():
csv_reader = CSVReader(
options=CSVReaderOptions(),
)
with pytest.raises(DatabaseUploadFailed) as ex:
csv_reader.file_to_dataframe(create_csv_file([""]))
assert str(ex.value) == "Parsing error: No columns to parse from file"
def test_csv_reader_wrong_columns_to_read():
csv_reader = CSVReader(
options=CSVReaderOptions(columns_read=["xpto"]),
)
with pytest.raises(DatabaseUploadFailed) as ex:
csv_reader.file_to_dataframe(create_csv_file(CSV_DATA))
assert str(ex.value) == (
"Parsing error: Usecols do not match columns, "
"columns expected but not found: ['xpto']"
)
def test_csv_reader_invalid_file():
csv_reader = CSVReader(
options=CSVReaderOptions(),
)
with pytest.raises(DatabaseUploadFailed) as ex:
csv_reader.file_to_dataframe(
io.StringIO("c1,c2,c3\na,b,c\n1,2,3,4,5,6,7\n1,2,3")
)
assert str(ex.value) == (
"Parsing error: Error tokenizing data. C error:"
" Expected 3 fields in line 3, saw 7\n"
)
def test_csv_reader_invalid_encoding():
csv_reader = CSVReader(
options=CSVReaderOptions(),
)
binary_data = b"col1,col2,col3\nv1,v2,\xba\nv3,v4,v5\n"
with pytest.raises(DatabaseUploadFailed) as ex:
csv_reader.file_to_dataframe(io.BytesIO(binary_data))
assert str(ex.value) == (
"Parsing error: 'utf-8' codec can't decode byte 0xba in"
" position 21: invalid start byte"
)

View File

@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
from datetime import datetime
from typing import Any
import numpy as np
import pytest
from superset.commands.database.exceptions import DatabaseUploadFailed
from superset.commands.database.uploaders.excel_reader import (
ExcelReader,
ExcelReaderOptions,
)
from tests.unit_tests.fixtures.common import create_excel_file
EXCEL_DATA: dict[str, list[Any]] = {
"Name": ["name1", "name2", "name3"],
"Age": [30, 25, 20],
"City": ["city1", "city2", "city3"],
"Birth": ["1990-02-01", "1995-02-01", "2000-02-01"],
}
EXCEL_WITH_NULLS: dict[str, list[Any]] = {
"Name": ["name1", "name2", "name3"],
"Age": ["N/A", 25, 20],
"City": ["city1", "None", "city3"],
"Birth": ["1990-02-01", "1995-02-01", "2000-02-01"],
}
EXCEL_DATA_DECIMAL_CHAR = {
"Name": ["name1"],
"Age": ["30,1"],
"City": ["city1"],
"Birth": ["1990-02-01"],
}
@pytest.mark.parametrize(
"file, options, expected_cols, expected_values",
[
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
columns_read=["Name", "Age"],
),
["Name", "Age"],
[
["name1", 30],
["name2", 25],
["name3", 20],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
columns_read=[],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", "1990-02-01"],
["name2", 25, "city2", "1995-02-01"],
["name3", 20, "city3", "2000-02-01"],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
rows_to_read=1,
),
["Name", "Age", "City", "Birth"],
[
["name1", 30.0, "city1", "1990-02-01"],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
rows_to_read=1,
columns_read=["Name", "Age"],
),
["Name", "Age"],
[
["name1", 30.0],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
skip_rows=1,
),
["name1", 30, "city1", "1990-02-01"],
[
["name2", 25.0, "city2", "1995-02-01"],
["name3", 20.0, "city3", "2000-02-01"],
],
),
(
create_excel_file(EXCEL_DATA),
ExcelReaderOptions(
column_dates=["Birth"],
),
["Name", "Age", "City", "Birth"],
[
["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
["name2", 25, "city2", datetime(1995, 2, 1, 0, 0)],
["name3", 20, "city3", datetime(2000, 2, 1, 0, 0)],
],
),
(
create_excel_file(EXCEL_WITH_NULLS),
ExcelReaderOptions(
null_values=["N/A", "None"],
),
["Name", "Age", "City", "Birth"],
[
["name1", np.nan, "city1", "1990-02-01"],
["name2", 25.0, np.nan, "1995-02-01"],
["name3", 20.0, "city3", "2000-02-01"],
],
),
(
create_excel_file(EXCEL_DATA_DECIMAL_CHAR),
ExcelReaderOptions(
decimal_character=",",
),
["Name", "Age", "City", "Birth"],
[
["name1", 30.1, "city1", "1990-02-01"],
],
),
],
)
def test_excel_reader_file_to_dataframe(file, options, expected_cols, expected_values):
excel_reader = ExcelReader(
options=options,
)
df = excel_reader.file_to_dataframe(file)
assert df.columns.tolist() == expected_cols
actual_values = df.values.tolist()
for i in range(len(expected_values)):
for j in range(len(expected_values[i])):
expected_val = expected_values[i][j]
actual_val = actual_values[i][j]
# Check if both values are NaN
if isinstance(expected_val, float) and isinstance(actual_val, float):
assert np.isnan(expected_val) == np.isnan(actual_val)
else:
assert expected_val == actual_val
file.close()
def test_excel_reader_wrong_columns_to_read():
excel_reader = ExcelReader(
options=ExcelReaderOptions(columns_read=["xpto"]),
)
with pytest.raises(DatabaseUploadFailed) as ex:
excel_reader.file_to_dataframe(create_excel_file(EXCEL_DATA))
assert str(ex.value) == (
"Parsing error: Usecols do not match columns, "
"columns expected but not found: ['xpto'] (sheet: 0)"
)
def test_excel_reader_wrong_date():
excel_reader = ExcelReader(
options=ExcelReaderOptions(column_dates=["xpto"]),
)
with pytest.raises(DatabaseUploadFailed) as ex:
excel_reader.file_to_dataframe(create_excel_file(EXCEL_DATA))
assert str(ex.value) == (
"Parsing error: Missing column provided to 'parse_dates':" " 'xpto' (sheet: 0)"
)
def test_excel_reader_invalid_file():
excel_reader = ExcelReader(
options=ExcelReaderOptions(),
)
with pytest.raises(DatabaseUploadFailed) as ex:
excel_reader.file_to_dataframe(io.StringIO("c1"))
assert str(ex.value) == (
"Parsing error: Excel file format cannot be determined, you must specify an engine manually."
)

View File

@@ -33,8 +33,9 @@ from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session

 from superset import db
-from superset.commands.database.csv_import import CSVImportCommand
-from superset.commands.database.excel_import import ExcelImportCommand
+from superset.commands.database.uploaders.base import UploadCommand
+from superset.commands.database.uploaders.csv_reader import CSVReader
+from superset.commands.database.uploaders.excel_reader import ExcelReader
 from superset.db_engine_specs.sqlite import SqliteEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetSecurityException
@@ -829,7 +830,7 @@ def test_oauth2_error(
 @pytest.mark.parametrize(
-    "payload,cmd_called_with",
+    "payload,upload_called_with,reader_called_with",
     [
         (
             {
@@ -841,6 +842,10 @@ def test_oauth2_error(
                 1,
                 "table1",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "fail",
                 "delimiter": ",",
@@ -861,6 +866,10 @@ def test_oauth2_error(
                 1,
                 "table2",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "replace",
                 "column_dates": ["col1", "col2"],
@@ -879,7 +888,6 @@ def test_oauth2_error(
                 "columns_read": "col1,col2",
                 "day_first": True,
                 "rows_to_read": "1",
-                "overwrite_duplicates": True,
                 "skip_blank_lines": True,
                 "skip_initial_space": True,
                 "skip_rows": "10",
@@ -890,12 +898,15 @@ def test_oauth2_error(
                 1,
                 "table2",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "replace",
                 "columns_read": ["col1", "col2"],
                 "null_values": ["None", "N/A", "''"],
                 "day_first": True,
-                "overwrite_duplicates": True,
                 "rows_to_read": 1,
                 "skip_blank_lines": True,
                 "skip_initial_space": True,
@@ -911,7 +922,8 @@ def test_oauth2_error(
 )
 def test_csv_upload(
     payload: dict[str, Any],
-    cmd_called_with: tuple[int, str, Any, dict[str, Any]],
+    upload_called_with: tuple[int, str, Any, dict[str, Any]],
+    reader_called_with: dict[str, Any],
     mocker: MockFixture,
     client: Any,
     full_api_access: None,
@@ -919,9 +931,11 @@ def test_csv_upload(
     """
     Test CSV Upload success.
     """
-    init_mock = mocker.patch.object(CSVImportCommand, "__init__")
+    init_mock = mocker.patch.object(UploadCommand, "__init__")
     init_mock.return_value = None
-    _ = mocker.patch.object(CSVImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
+    reader_mock = mocker.patch.object(CSVReader, "__init__")
+    reader_mock.return_value = None
     response = client.post(
         f"/api/v1/database/1/csv_upload/",
         data=payload,
@@ -929,7 +943,8 @@ def test_csv_upload(
     )
     assert response.status_code == 200
     assert response.json == {"message": "OK"}
-    init_mock.assert_called_with(*cmd_called_with)
+    init_mock.assert_called_with(*upload_called_with)
+    reader_mock.assert_called_with(*reader_called_with)


 @pytest.mark.parametrize(
@@ -994,16 +1009,6 @@ def test_csv_upload(
             },
             {"message": {"header_row": ["Not a valid integer."]}},
         ),
-        (
-            {
-                "file": (create_csv_file(), "out.csv"),
-                "table_name": "table1",
-                "delimiter": ",",
-                "already_exists": "fail",
-                "overwrite_duplicates": "test1",
-            },
-            {"message": {"overwrite_duplicates": ["Not a valid boolean."]}},
-        ),
         (
             {
                 "file": (create_csv_file(), "out.csv"),
@@ -1066,7 +1071,7 @@ def test_csv_upload_validation(
     """
     Test CSV Upload validation fails.
     """
-    _ = mocker.patch.object(CSVImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")

     response = client.post(
         f"/api/v1/database/1/csv_upload/",
@@ -1085,7 +1090,7 @@ def test_csv_upload_file_size_validation(
     """
     Test CSV Upload validation fails.
     """
-    _ = mocker.patch.object(CSVImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
     current_app.config["CSV_UPLOAD_MAX_SIZE"] = 5
     response = client.post(
         f"/api/v1/database/1/csv_upload/",
@@ -1127,7 +1132,7 @@ def test_csv_upload_file_extension_invalid(
     """
     Test CSV Upload validation fails.
     """
-    _ = mocker.patch.object(CSVImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
     response = client.post(
         f"/api/v1/database/1/csv_upload/",
         data={
@@ -1163,7 +1168,7 @@ def test_csv_upload_file_extension_valid(
     """
     Test CSV Upload validation fails.
     """
-    _ = mocker.patch.object(CSVImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
     response = client.post(
         f"/api/v1/database/1/csv_upload/",
         data={
@@ -1177,7 +1182,7 @@ def test_csv_upload_file_extension_valid(
 @pytest.mark.parametrize(
-    "payload,cmd_called_with",
+    "payload,upload_called_with,reader_called_with",
     [
         (
             {
@@ -1188,6 +1193,10 @@ def test_csv_upload_file_extension_valid(
                 1,
                 "table1",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "fail",
                 "file": ANY,
@@ -1207,6 +1216,10 @@ def test_csv_upload_file_extension_valid(
                 1,
                 "table2",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "replace",
                 "column_dates": ["col1", "col2"],
@@ -1231,6 +1244,10 @@ def test_csv_upload_file_extension_valid(
                 1,
                 "table2",
                 ANY,
+                None,
+                ANY,
+            ),
+            (
             {
                 "already_exists": "replace",
                 "columns_read": ["col1", "col2"],
@@ -1247,7 +1264,8 @@ def test_csv_upload_file_extension_valid(
 )
 def test_excel_upload(
     payload: dict[str, Any],
-    cmd_called_with: tuple[int, str, Any, dict[str, Any]],
+    upload_called_with: tuple[int, str, Any, dict[str, Any]],
+    reader_called_with: dict[str, Any],
     mocker: MockFixture,
     client: Any,
     full_api_access: None,
@@ -1255,9 +1273,11 @@ def test_excel_upload(
     """
     Test Excel Upload success.
     """
-    init_mock = mocker.patch.object(ExcelImportCommand, "__init__")
+    init_mock = mocker.patch.object(UploadCommand, "__init__")
     init_mock.return_value = None
-    _ = mocker.patch.object(ExcelImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
+    reader_mock = mocker.patch.object(ExcelReader, "__init__")
+    reader_mock.return_value = None
     response = client.post(
         f"/api/v1/database/1/excel_upload/",
         data=payload,
@@ -1265,7 +1285,8 @@ def test_excel_upload(
     )
     assert response.status_code == 200
     assert response.json == {"message": "OK"}
-    init_mock.assert_called_with(*cmd_called_with)
+    init_mock.assert_called_with(*upload_called_with)
+    reader_mock.assert_called_with(*reader_called_with)


 @pytest.mark.parametrize(
@@ -1347,7 +1368,7 @@ def test_excel_upload_validation(
     """
     Test Excel Upload validation fails.
     """
-    _ = mocker.patch.object(ExcelImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")

     response = client.post(
         f"/api/v1/database/1/excel_upload/",
@@ -1382,7 +1403,7 @@ def test_excel_upload_file_extension_invalid(
     """
     Test Excel Upload file extension fails.
     """
-    _ = mocker.patch.object(ExcelImportCommand, "run")
+    _ = mocker.patch.object(UploadCommand, "run")
     response = client.post(
         f"/api/v1/database/1/excel_upload/",
         data={

View File

@@ -31,7 +31,7 @@ def dttm() -> datetime:
     return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")

-def create_csv_file(data: list[list[str]] | None = None) -> BytesIO:
+def create_csv_file(data: list[list[str]] | None = None, delimiter=",") -> BytesIO:
     data = (
         [
             ["Name", "Age", "City"],
@@ -42,7 +42,7 @@ def create_csv_file(data: list[list[str]] | None = None) -> BytesIO:
     )
     output = StringIO()
-    writer = csv.writer(output)
+    writer = csv.writer(output, delimiter=delimiter)
     for row in data:
         writer.writerow(row)
     output.seek(0)
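
The new delimiter parameter lets the fixture produce non-comma files, which is what the pipe-delimiter case in the CSVReader unit tests above relies on. For instance:

from superset.commands.database.uploaders.csv_reader import CSVReader, CSVReaderOptions
from tests.unit_tests.fixtures.common import create_csv_file

# The fixture writes "|" between fields; the reader is told to split on "|".
file = create_csv_file([["Name", "Age"], ["name1", "30"]], delimiter="|")
df = CSVReader(CSVReaderOptions(delimiter="|")).file_to_dataframe(file)
assert df.columns.tolist() == ["Name", "Age"]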