diff --git a/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx b/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx
index 7adb325b6..4de9a89e6 100644
--- a/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx
+++ b/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx
@@ -301,7 +301,6 @@ test('CSV, renders the columns elements correctly', () => {
   const selectColumnsToRead = screen.getByRole('combobox', {
     name: /Choose columns to read/i,
   });
-  const switchOverwriteDuplicates = screen.getByTestId('overwriteDuplicates');
   const inputColumnDataTypes = screen.getByRole('textbox', {
     name: /Column data types/i,
   });
@@ -312,7 +311,6 @@ test('CSV, renders the columns elements correctly', () => {
     switchDataFrameIndex,
     inputColumnLabels,
     selectColumnsToRead,
-    switchOverwriteDuplicates,
     inputColumnDataTypes,
   ];
   visibleComponents.forEach(component => {
diff --git a/superset-frontend/src/features/databases/UploadDataModel/index.tsx b/superset-frontend/src/features/databases/UploadDataModel/index.tsx
index 2fcb267ae..066910c8e 100644
--- a/superset-frontend/src/features/databases/UploadDataModel/index.tsx
+++ b/superset-frontend/src/features/databases/UploadDataModel/index.tsx
@@ -68,7 +68,6 @@ const CSVSpecificFields = [
   'skip_initial_space',
   'skip_blank_lines',
   'day_first',
-  'overwrite_duplicates',
   'column_data_types',
 ];
@@ -109,7 +108,6 @@ interface UploadInfo {
   dataframe_index: boolean;
   column_labels: string;
   columns_read: Array<string>;
-  overwrite_duplicates: boolean;
   column_data_types: string;
 }
@@ -132,7 +130,6 @@ const defaultUploadInfo: UploadInfo = {
   dataframe_index: false,
   column_labels: '',
   columns_read: [],
-  overwrite_duplicates: false,
   column_data_types: '',
 };
@@ -975,20 +972,6 @@ const UploadDataModal: FunctionComponent<UploadDataModalProps> = ({
-        {type === 'csv' && (
-          ...
-        )}
diff --git a/superset/commands/database/csv_import.py b/superset/commands/database/csv_import.py
deleted file mode 100644
--- a/superset/commands/database/csv_import.py
+++ /dev/null
-class CSVImportCommand(BaseCommand):
-    def __init__(
-        self,
-        model_id: int,
-        table_name: str,
-        file: Any,
-        options: CSVImportOptions,
-    ) -> None:
-        self._model_id = model_id
-        self._model: Optional[Database] = None
-        self._table_name = table_name
-        self._schema = options.get("schema")
-        self._file = file
-        self._options = options
-
-    def _read_csv(self) -> pd.DataFrame:
-        """
-        Read CSV file into a DataFrame
-
-        :return: pandas DataFrame
-        :throws DatabaseUploadFailed: if there is an error reading the CSV file
-        """
-        try:
-            return pd.concat(
-                pd.read_csv(
-                    chunksize=READ_CSV_CHUNK_SIZE,
-                    encoding="utf-8",
-                    filepath_or_buffer=self._file,
-                    header=self._options.get("header_row", 0),
-                    index_col=self._options.get("index_column"),
-                    dayfirst=self._options.get("day_first", False),
-                    iterator=True,
-                    keep_default_na=not self._options.get("null_values"),
-                    usecols=self._options.get("columns_read")
-                    if self._options.get("columns_read")  # None if an empty list
-                    else None,
-                    na_values=self._options.get("null_values")
-                    if self._options.get("null_values")  # None if an empty list
-                    else None,
-                    nrows=self._options.get("rows_to_read"),
-                    parse_dates=self._options.get("column_dates"),
-                    sep=self._options.get("delimiter", ","),
-                    skip_blank_lines=self._options.get("skip_blank_lines", False),
-                    skipinitialspace=self._options.get("skip_initial_space", False),
-                    skiprows=self._options.get("skip_rows", 0),
-                    dtype=self._options.get("column_data_types")
-                    if self._options.get("column_data_types")
-                    else None,
-                )
-            )
-        except (
-            pd.errors.ParserError,
-            pd.errors.EmptyDataError,
-            UnicodeDecodeError,
-            ValueError,
-        ) as ex:
-            raise DatabaseUploadFailed(
-                message=_("Parsing error: %(error)s", error=str(ex))
-            ) from ex
-        except Exception as ex:
DatabaseUploadFailed(_("Error reading CSV file")) from ex - - def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None: - """ - Upload DataFrame to database - - :param df: - :throws DatabaseUploadFailed: if there is an error uploading the DataFrame - """ - try: - csv_table = Table(table=self._table_name, schema=self._schema) - database.db_engine_spec.df_to_sql( - database, - csv_table, - df, - to_sql_kwargs={ - "chunksize": READ_CSV_CHUNK_SIZE, - "if_exists": self._options.get("already_exists", "fail"), - "index": self._options.get("index_column"), - "index_label": self._options.get("column_labels"), - }, - ) - except ValueError as ex: - raise DatabaseUploadFailed( - message=_( - "Table already exists. You can change your " - "'if table already exists' strategy to append or " - "replace or provide a different Table Name to use." - ) - ) from ex - except Exception as ex: - raise DatabaseUploadFailed(exception=ex) from ex - - def run(self) -> None: - self.validate() - if not self._model: - return - - df = self._read_csv() - self._dataframe_to_database(df, self._model) - - sqla_table = ( - db.session.query(SqlaTable) - .filter_by( - table_name=self._table_name, - schema=self._schema, - database_id=self._model_id, - ) - .one_or_none() - ) - if not sqla_table: - sqla_table = SqlaTable( - table_name=self._table_name, - database=self._model, - database_id=self._model_id, - owners=[get_user()], - schema=self._schema, - ) - db.session.add(sqla_table) - - sqla_table.fetch_metadata() - - try: - db.session.commit() - except SQLAlchemyError as ex: - db.session.rollback() - raise DatabaseUploadSaveMetadataFailed() from ex - - def validate(self) -> None: - self._model = DatabaseDAO.find_by_id(self._model_id) - if not self._model: - raise DatabaseNotFoundError() - if not schema_allows_file_upload(self._model, self._schema): - raise DatabaseSchemaUploadNotAllowed() diff --git a/superset/commands/database/exceptions.py b/superset/commands/database/exceptions.py index 77199d9a7..410eb9236 100644 --- a/superset/commands/database/exceptions.py +++ b/superset/commands/database/exceptions.py @@ -98,6 +98,11 @@ class DatabaseSchemaUploadNotAllowed(CommandException): message = _("Database schema is not allowed for csv uploads.") +class DatabaseUploadNotSupported(CommandException): + status = 422 + message = _("Database type does not support file uploads.") + + class DatabaseUploadFailed(CommandException): status = 422 message = _("Database upload file failed") diff --git a/superset/commands/database/uploaders/__init__.py b/superset/commands/database/uploaders/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/commands/database/uploaders/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/superset/commands/database/excel_import.py b/superset/commands/database/uploaders/base.py
similarity index 65%
rename from superset/commands/database/excel_import.py
rename to superset/commands/database/uploaders/base.py
index 6c2133df0..80e9b135a 100644
--- a/superset/commands/database/excel_import.py
+++ b/superset/commands/database/uploaders/base.py
@@ -14,7 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import logging
+from abc import abstractmethod
 from typing import Any, Optional, TypedDict
 
 import pandas as pd
@@ -27,6 +29,7 @@ from superset.commands.database.exceptions import (
     DatabaseNotFoundError,
     DatabaseSchemaUploadNotAllowed,
     DatabaseUploadFailed,
+    DatabaseUploadNotSupported,
     DatabaseUploadSaveMetadataFailed,
 )
 from superset.connectors.sqla.models import SqlaTable
@@ -38,78 +41,43 @@ from superset.views.database.validators import schema_allows_file_upload
 
 logger = logging.getLogger(__name__)
 
-READ_EXCEL_CHUNK_SIZE = 1000
+READ_CHUNK_SIZE = 1000
 
 
-class ExcelImportOptions(TypedDict, total=False):
-    sheet_name: str
-    schema: str
+class ReaderOptions(TypedDict, total=False):
     already_exists: str
-    column_dates: list[str]
     column_labels: str
-    columns_read: list[str]
-    dataframe_index: str
-    decimal_character: str
-    header_row: int
     index_column: str
-    null_values: list[str]
-    rows_to_read: int
-    skip_rows: int
 
 
-class ExcelImportCommand(BaseCommand):
-    def __init__(
-        self,
-        model_id: int,
-        table_name: str,
-        file: Any,
-        options: ExcelImportOptions,
-    ) -> None:
-        self._model_id = model_id
-        self._model: Optional[Database] = None
-        self._table_name = table_name
-        self._schema = options.get("schema")
-        self._file = file
+class BaseDataReader:
    """
+    Base class for reading data from a file and uploading it to a database
+    These child objects are used by the UploadCommand as a dependency injection
+    to read data from multiple file types (e.g. CSV, Excel, etc.)
+    """
+
+    def __init__(self, options: dict[str, Any]) -> None:
         self._options = options
 
-    def _read_excel(self) -> pd.DataFrame:
-        """
-        Read Excel file into a DataFrame
+    @abstractmethod
+    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
+        ...
 
-        :return: pandas DataFrame
-        :throws DatabaseUploadFailed: if there is an error reading the CSV file
-        """
+    def read(
+        self, file: Any, database: Database, table_name: str, schema_name: Optional[str]
+    ) -> None:
+        self._dataframe_to_database(
+            self.file_to_dataframe(file), database, table_name, schema_name
+        )
 
-        kwargs = {
-            "header": self._options.get("header_row", 0),
-            "index_col": self._options.get("index_column"),
-            "io": self._file,
-            "keep_default_na": not self._options.get("null_values"),
-            "na_values": self._options.get("null_values")
-            if self._options.get("null_values")  # None if an empty list
-            else None,
-            "parse_dates": self._options.get("column_dates"),
-            "skiprows": self._options.get("skip_rows", 0),
-            "sheet_name": self._options.get("sheet_name", 0),
-            "nrows": self._options.get("rows_to_read"),
-        }
-        if self._options.get("columns_read"):
-            kwargs["usecols"] = self._options.get("columns_read")
-        try:
-            return pd.read_excel(**kwargs)
-        except (
-            pd.errors.ParserError,
-            pd.errors.EmptyDataError,
-            UnicodeDecodeError,
-            ValueError,
-        ) as ex:
-            raise DatabaseUploadFailed(
-                message=_("Parsing error: %(error)s", error=str(ex))
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(_("Error reading Excel file")) from ex
-
-    def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
+    def _dataframe_to_database(
+        self,
+        df: pd.DataFrame,
+        database: Database,
+        table_name: str,
+        schema_name: Optional[str],
+    ) -> None:
         """
         Upload DataFrame to database
 
@@ -117,13 +85,13 @@ class ExcelImportCommand(BaseCommand):
         :throws DatabaseUploadFailed: if there is an error uploading the DataFrame
         """
         try:
-            data_table = Table(table=self._table_name, schema=self._schema)
+            data_table = Table(table=table_name, schema=schema_name)
             database.db_engine_spec.df_to_sql(
                 database,
                 data_table,
                 df,
                 to_sql_kwargs={
-                    "chunksize": READ_EXCEL_CHUNK_SIZE,
+                    "chunksize": READ_CHUNK_SIZE,
                     "if_exists": self._options.get("already_exists", "fail"),
                     "index": self._options.get("index_column"),
                     "index_label": self._options.get("column_labels"),
@@ -140,13 +108,29 @@ class ExcelImportCommand(BaseCommand):
         except Exception as ex:
             raise DatabaseUploadFailed(exception=ex) from ex
 
+
+class UploadCommand(BaseCommand):
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        model_id: int,
+        table_name: str,
+        file: Any,
+        schema: Optional[str],
+        reader: BaseDataReader,
+    ) -> None:
+        self._model_id = model_id
+        self._model: Optional[Database] = None
+        self._table_name = table_name
+        self._schema = schema
+        self._file = file
+        self._reader = reader
+
     def run(self) -> None:
         self.validate()
         if not self._model:
             return
 
-        df = self._read_excel()
-        self._dataframe_to_database(df, self._model)
+        self._reader.read(self._file, self._model, self._table_name, self._schema)
 
         sqla_table = (
             db.session.query(SqlaTable)
@@ -181,3 +165,5 @@ class ExcelImportCommand(BaseCommand):
             raise DatabaseNotFoundError()
         if not schema_allows_file_upload(self._model, self._schema):
             raise DatabaseSchemaUploadNotAllowed()
+        if not self._model.db_engine_spec.supports_file_upload:
+            raise DatabaseUploadNotSupported()
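Because `BaseDataReader` only demands a `file_to_dataframe` implementation, supporting another format is one small subclass. A hypothetical `JSONReader` (not part of this diff) might look like:

```python
from typing import Any

import pandas as pd

from superset.commands.database.exceptions import DatabaseUploadFailed
from superset.commands.database.uploaders.base import BaseDataReader


class JSONReader(BaseDataReader):
    """Hypothetical reader, shown only to illustrate the extension point."""

    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
        try:
            # Let pandas infer the JSON orientation from the payload.
            return pd.read_json(file)
        except ValueError as ex:
            raise DatabaseUploadFailed(message=f"Parsing error: {ex}") from ex
```

`UploadCommand` would accept it unchanged: `UploadCommand(pk, table, file, schema, JSONReader({})).run()`.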
diff --git a/superset/commands/database/uploaders/csv_reader.py b/superset/commands/database/uploaders/csv_reader.py
new file mode 100644
index 000000000..06ff04f5c
--- /dev/null
+++ b/superset/commands/database/uploaders/csv_reader.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Any
+
+import pandas as pd
+from flask_babel import lazy_gettext as _
+
+from superset.commands.database.exceptions import DatabaseUploadFailed
+from superset.commands.database.uploaders.base import BaseDataReader, ReaderOptions
+
+logger = logging.getLogger(__name__)
+
+READ_CSV_CHUNK_SIZE = 1000
+
+
+class CSVReaderOptions(ReaderOptions, total=False):
+    delimiter: str
+    column_data_types: dict[str, str]
+    column_dates: list[str]
+    columns_read: list[str]
+    dataframe_index: str
+    day_first: bool
+    decimal_character: str
+    header_row: int
+    null_values: list[str]
+    rows_to_read: int
+    skip_blank_lines: bool
+    skip_initial_space: bool
+    skip_rows: int
+
+
+class CSVReader(BaseDataReader):
+    def __init__(
+        self,
+        options: CSVReaderOptions,
+    ) -> None:
+        super().__init__(
+            options=dict(options),
+        )
+
+    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
+        """
+        Read CSV file into a DataFrame
+
+        :return: pandas DataFrame
+        :throws DatabaseUploadFailed: if there is an error reading the CSV file
+        """
+        try:
+            return pd.concat(
+                pd.read_csv(
+                    chunksize=READ_CSV_CHUNK_SIZE,
+                    encoding="utf-8",
+                    filepath_or_buffer=file,
+                    header=self._options.get("header_row", 0),
+                    decimal=self._options.get("decimal_character", "."),
+                    index_col=self._options.get("index_column"),
+                    dayfirst=self._options.get("day_first", False),
+                    iterator=True,
+                    keep_default_na=not self._options.get("null_values"),
+                    usecols=self._options.get("columns_read")
+                    if self._options.get("columns_read")  # None if an empty list
+                    else None,
+                    na_values=self._options.get("null_values")
+                    if self._options.get("null_values")  # None if an empty list
+                    else None,
+                    nrows=self._options.get("rows_to_read"),
+                    parse_dates=self._options.get("column_dates"),
+                    sep=self._options.get("delimiter", ","),
+                    skip_blank_lines=self._options.get("skip_blank_lines", False),
+                    skipinitialspace=self._options.get("skip_initial_space", False),
+                    skiprows=self._options.get("skip_rows", 0),
+                    dtype=self._options.get("column_data_types")
+                    if self._options.get("column_data_types")
+                    else None,
+                )
+            )
+        except (
+            pd.errors.ParserError,
+            pd.errors.EmptyDataError,
+            UnicodeDecodeError,
+            ValueError,
+        ) as ex:
+            raise DatabaseUploadFailed(
+                message=_("Parsing error: %(error)s", error=str(ex))
+            ) from ex
+        except Exception as ex:
+            raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
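`CSVReader` can also be driven on its own, which is how the unit tests later in this diff exercise it. A small sketch of the option handling (the buffer contents are made up):

```python
import io

from superset.commands.database.uploaders.csv_reader import CSVReader

reader = CSVReader({"delimiter": "|", "column_dates": ["Birth"]})
df = reader.file_to_dataframe(
    io.BytesIO(b"Name|Birth\nname1|1990-02-01\nname2|1995-02-01\n")
)
print(df.dtypes)  # "Birth" is parsed to datetime64[ns], "Name" stays object
```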
diff --git a/superset/commands/database/uploaders/excel_reader.py b/superset/commands/database/uploaders/excel_reader.py
new file mode 100644
index 000000000..54e312fea
--- /dev/null
+++ b/superset/commands/database/uploaders/excel_reader.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Any
+
+import pandas as pd
+from flask_babel import lazy_gettext as _
+
+from superset.commands.database.exceptions import DatabaseUploadFailed
+from superset.commands.database.uploaders.base import BaseDataReader, ReaderOptions
+
+logger = logging.getLogger(__name__)
+
+
+class ExcelReaderOptions(ReaderOptions, total=False):
+    sheet_name: str
+    column_dates: list[str]
+    columns_read: list[str]
+    dataframe_index: str
+    decimal_character: str
+    header_row: int
+    null_values: list[str]
+    rows_to_read: int
+    skip_rows: int
+
+
+class ExcelReader(BaseDataReader):
+    def __init__(
+        self,
+        options: ExcelReaderOptions,
+    ) -> None:
+        super().__init__(
+            options=dict(options),
+        )
+
+    def file_to_dataframe(self, file: Any) -> pd.DataFrame:
+        """
+        Read Excel file into a DataFrame
+
+        :return: pandas DataFrame
+        :throws DatabaseUploadFailed: if there is an error reading the Excel file
+        """
+
+        kwargs = {
+            "header": self._options.get("header_row", 0),
+            "index_col": self._options.get("index_column"),
+            "io": file,
+            "keep_default_na": not self._options.get("null_values"),
+            "decimal": self._options.get("decimal_character", "."),
+            "na_values": self._options.get("null_values")
+            if self._options.get("null_values")  # None if an empty list
+            else None,
+            "parse_dates": self._options.get("column_dates"),
+            "skiprows": self._options.get("skip_rows", 0),
+            "sheet_name": self._options.get("sheet_name", 0),
+            "nrows": self._options.get("rows_to_read"),
+        }
+        if self._options.get("columns_read"):
+            kwargs["usecols"] = self._options.get("columns_read")
+        try:
+            return pd.read_excel(**kwargs)
+        except (
+            pd.errors.ParserError,
+            pd.errors.EmptyDataError,
+            UnicodeDecodeError,
+            ValueError,
+        ) as ex:
+            raise DatabaseUploadFailed(
+                message=_("Parsing error: %(error)s", error=str(ex))
+            ) from ex
+        except Exception as ex:
+            raise DatabaseUploadFailed(_("Error reading Excel file")) from ex
diff --git a/superset/databases/api.py b/superset/databases/api.py
index b5552d1ab..0e8e5be39 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -34,9 +34,7 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
 
 from superset import app, event_logger
 from superset.commands.database.create import CreateDatabaseCommand
-from superset.commands.database.csv_import import CSVImportCommand
 from superset.commands.database.delete import DeleteDatabaseCommand
-from superset.commands.database.excel_import import ExcelImportCommand
 from superset.commands.database.exceptions import (
     DatabaseConnectionFailedError,
     DatabaseCreateFailedError,
@@ -59,6 +57,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
 from superset.commands.database.tables import TablesDatabaseCommand
 from superset.commands.database.test_connection import TestConnectionDatabaseCommand
 from superset.commands.database.update import UpdateDatabaseCommand
+from superset.commands.database.uploaders.base import UploadCommand
+from superset.commands.database.uploaders.csv_reader import CSVReader
+from superset.commands.database.uploaders.excel_reader import ExcelReader
 from superset.commands.database.validate import ValidateDatabaseParametersCommand
 from superset.commands.database.validate_sql import ValidateSQLCommand
 from superset.commands.importers.exceptions import (
@@ -1491,11 +1492,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
             request_form = request.form.to_dict()
             request_form["file"] = request.files.get("file")
             parameters = CSVUploadPostSchema().load(request_form)
-            CSVImportCommand(
+            UploadCommand(
                 pk,
                 parameters["table_name"],
                 parameters["file"],
-                parameters,
+                parameters.get("schema"),
+                CSVReader(parameters),
             ).run()
         except ValidationError as error:
             return self.response_400(message=error.messages)
@@ -1550,11 +1552,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
             request_form = request.form.to_dict()
             request_form["file"] = request.files.get("file")
             parameters = ExcelUploadPostSchema().load(request_form)
-            ExcelImportCommand(
+            UploadCommand(
                 pk,
                 parameters["table_name"],
                 parameters["file"],
-                parameters,
+                parameters.get("schema"),
+                ExcelReader(parameters),
             ).run()
         except ValidationError as error:
             return self.response_400(message=error.messages)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 5843473bc..9a1fc9d6c 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -1116,12 +1116,6 @@ class CSVUploadPostSchema(BaseUploadPostSchema):
             "description": "DD/MM format dates, international and European format"
         }
     )
-    overwrite_duplicates = fields.Boolean(
-        metadata={
-            "description": "If duplicate columns are not overridden,"
-            "they will be presented as 'X.1, X.2 ...X.x'."
-        }
-    )
     skip_blank_lines = fields.Boolean(
         metadata={"description": "Skip blank lines in the CSV file."}
     )
diff --git a/tests/integration_tests/databases/commands/excel_upload_test.py b/tests/integration_tests/databases/commands/excel_upload_test.py
deleted file mode 100644
index 51ac4e71f..000000000
--- a/tests/integration_tests/databases/commands/excel_upload_test.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import json
-from datetime import datetime
-
-import pytest
-
-from superset import db, security_manager
-from superset.commands.database.excel_import import ExcelImportCommand
-from superset.commands.database.exceptions import (
-    DatabaseNotFoundError,
-    DatabaseSchemaUploadNotAllowed,
-    DatabaseUploadFailed,
-)
-from superset.models.core import Database
-from superset.utils.core import override_user
-from superset.utils.database import get_or_create_db
-from tests.integration_tests.conftest import only_postgresql
-from tests.integration_tests.test_app import app
-from tests.unit_tests.fixtures.common import create_excel_file
-
-EXCEL_UPLOAD_DATABASE = "excel_explore_db"
-EXCEL_UPLOAD_TABLE = "excel_upload"
-EXCEL_UPLOAD_TABLE_W_SCHEMA = "excel_upload_w_schema"
-
-
-EXCEL_FILE_1 = {
-    "Name": ["name1", "name2", "name3"],
-    "Age": [30, 29, 28],
-    "City": ["city1", "city2", "city3"],
-    "Birth": ["1-1-1980", "1-1-1981", "1-1-1982"],
-}
-
-EXCEL_FILE_2 = {
-    "Name": ["name1", "name2", "name3"],
-    "Age": ["N/A", 29, 28],
-    "City": ["city1", "None", "city3"],
-    "Birth": ["1-1-1980", "1-1-1981", "1-1-1982"],
-}
-
-
-def _setup_excel_upload(allowed_schemas: list[str] | None = None):
-    upload_db = get_or_create_db(
-        EXCEL_UPLOAD_DATABASE, app.config["SQLALCHEMY_EXAMPLES_URI"]
-    )
-    upload_db.allow_file_upload = True
-    extra = upload_db.get_extra()
-    allowed_schemas = allowed_schemas or []
-    extra["schemas_allowed_for_file_upload"] = allowed_schemas
-    upload_db.extra = json.dumps(extra)
-
-    db.session.commit()
-
-    yield
-
-    upload_db = get_upload_db()
-    with upload_db.get_sqla_engine_with_context() as engine:
-        engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
-        engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE_W_SCHEMA}")
-    db.session.delete(upload_db)
-    db.session.commit()
-
-
-def get_upload_db():
-    return (
-        db.session.query(Database).filter_by(database_name=EXCEL_UPLOAD_DATABASE).one()
-    )
-
-
-@pytest.fixture(scope="function")
-def setup_excel_upload_with_context():
-    with app.app_context():
-        yield from _setup_excel_upload()
-
-
-@pytest.fixture(scope="function")
-def setup_excel_upload_with_context_schema():
-    with app.app_context():
-        yield from _setup_excel_upload(["public"])
-
-
-@only_postgresql
-@pytest.mark.parametrize(
-    "excel_data,options, table_data",
-    [
-        (
-            EXCEL_FILE_1,
-            {},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"columns_read": ["Name", "Age"]},
-            [("name1", 30), ("name2", 29), ("name3", 28)],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"columns_read": []},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"rows_to_read": 1},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-            ],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"rows_to_read": 1, "columns_read": ["Name", "Age"]},
-            [
-                ("name1", 30),
-            ],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"skip_rows": 1},
-            [("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"rows_to_read": 2},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-            ],
-        ),
-        (
-            EXCEL_FILE_1,
-            {"column_dates": ["Birth"]},
-            [
-                ("name1", 30, "city1", datetime(1980, 1, 1, 0, 0)),
-                ("name2", 29, "city2", datetime(1981, 1, 1, 0, 0)),
-                ("name3", 28, "city3", datetime(1982, 1, 1, 0, 0)),
-            ],
-        ),
-        (
-            EXCEL_FILE_2,
-            {"null_values": ["N/A", "None"]},
-            [
-                ("name1", None, "city1", "1-1-1980"),
("name2", 29, None, "1-1-1981"), - ("name3", 28, "city3", "1-1-1982"), - ], - ), - ( - EXCEL_FILE_2, - { - "null_values": ["N/A", "None"], - "column_dates": ["Birth"], - "columns_read": ["Name", "Age", "Birth"], - }, - [ - ("name1", None, datetime(1980, 1, 1, 0, 0)), - ("name2", 29, datetime(1981, 1, 1, 0, 0)), - ("name3", 28, datetime(1982, 1, 1, 0, 0)), - ], - ), - ], -) -@pytest.mark.usefixtures("setup_excel_upload_with_context") -def test_excel_upload_options(excel_data, options, table_data): - admin_user = security_manager.find_user(username="admin") - upload_database = get_upload_db() - - with override_user(admin_user): - ExcelImportCommand( - upload_database.id, - EXCEL_UPLOAD_TABLE, - create_excel_file(excel_data), - options=options, - ).run() - with upload_database.get_sqla_engine_with_context() as engine: - data = engine.execute(f"SELECT * from {EXCEL_UPLOAD_TABLE}").fetchall() - assert data == table_data - - -@only_postgresql -@pytest.mark.usefixtures("setup_excel_upload_with_context") -def test_excel_upload_database_not_found(): - admin_user = security_manager.find_user(username="admin") - - with override_user(admin_user): - with pytest.raises(DatabaseNotFoundError): - ExcelImportCommand( - 1000, - EXCEL_UPLOAD_TABLE, - create_excel_file(EXCEL_FILE_1), - options={}, - ).run() - - -@only_postgresql -@pytest.mark.usefixtures("setup_excel_upload_with_context_schema") -def test_excel_upload_schema_not_allowed(): - admin_user = security_manager.find_user(username="admin") - upload_db_id = get_upload_db().id - with override_user(admin_user): - with pytest.raises(DatabaseSchemaUploadNotAllowed): - ExcelImportCommand( - upload_db_id, - EXCEL_UPLOAD_TABLE, - create_excel_file(EXCEL_FILE_1), - options={}, - ).run() - - with pytest.raises(DatabaseSchemaUploadNotAllowed): - ExcelImportCommand( - upload_db_id, - EXCEL_UPLOAD_TABLE, - create_excel_file(EXCEL_FILE_1), - options={"schema": "schema1"}, - ).run() - - ExcelImportCommand( - upload_db_id, - EXCEL_UPLOAD_TABLE, - create_excel_file(EXCEL_FILE_1), - options={"schema": "public"}, - ).run() - - -@only_postgresql -@pytest.mark.usefixtures("setup_excel_upload_with_context") -def test_excel_upload_broken_file(): - admin_user = security_manager.find_user(username="admin") - - with override_user(admin_user): - with pytest.raises(DatabaseUploadFailed): - ExcelImportCommand( - get_upload_db().id, - EXCEL_UPLOAD_TABLE, - create_excel_file([""]), - options={"column_dates": ["Birth"]}, - ).run() diff --git a/tests/integration_tests/databases/commands/csv_upload_test.py b/tests/integration_tests/databases/commands/upload_test.py similarity index 52% rename from tests/integration_tests/databases/commands/csv_upload_test.py rename to tests/integration_tests/databases/commands/upload_test.py index 18cc6f4a8..f08be099c 100644 --- a/tests/integration_tests/databases/commands/csv_upload_test.py +++ b/tests/integration_tests/databases/commands/upload_test.py @@ -18,17 +18,19 @@ from __future__ import annotations import json -from datetime import datetime import pytest from superset import db, security_manager -from superset.commands.database.csv_import import CSVImportCommand from superset.commands.database.exceptions import ( DatabaseNotFoundError, DatabaseSchemaUploadNotAllowed, DatabaseUploadFailed, + DatabaseUploadNotSupported, ) +from superset.commands.database.uploaders.base import UploadCommand +from superset.commands.database.uploaders.csv_reader import CSVReader +from superset.connectors.sqla.models import SqlaTable from superset.models.core 
@@ -48,27 +50,13 @@ CSV_FILE_1 = [
     ["name3", "28", "city3", "1-1-1982"],
 ]
 
-CSV_FILE_2 = [
-    ["name1", "30", "city1", "1-1-1980"],
-    ["Name", "Age", "City", "Birth"],
-    ["name2", "29", "city2", "1-1-1981"],
-    ["name3", "28", "city3", "1-1-1982"],
-]
-
-CSV_FILE_3 = [
+CSV_FILE_WITH_NULLS = [
     ["Name", "Age", "City", "Birth"],
     ["name1", "N/A", "city1", "1-1-1980"],
     ["name2", "29", "None", "1-1-1981"],
     ["name3", "28", "city3", "1-1-1982"],
 ]
 
-CSV_FILE_BROKEN = [
-    ["Name", "Age", "City", "Birth"],
-    ["name1", "30", "city1", "1-1-1980"],
-    ["name2", "29"],
-    ["name3", "28", "city3", "1-1-1982"],
-]
-
 
 def _setup_csv_upload(allowed_schemas: list[str] | None = None):
     upload_db = get_or_create_db(
@@ -108,122 +96,48 @@ def setup_csv_upload_with_context_schema():
         yield from _setup_csv_upload(["public"])
 
 
-@only_postgresql
-@pytest.mark.parametrize(
-    "csv_data,options, table_data",
-    [
-        (
-            CSV_FILE_1,
-            {},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"columns_read": ["Name", "Age"]},
-            [("name1", 30), ("name2", 29), ("name3", 28)],
-        ),
-        (
-            CSV_FILE_1,
-            {"columns_read": []},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 1},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 1, "columns_read": ["Name", "Age"]},
-            [
-                ("name1", 30),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"skip_rows": 1},
-            [("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
-        ),
-        (
-            CSV_FILE_1,
-            {"rows_to_read": 2},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, "city2", "1-1-1981"),
-            ],
-        ),
-        (
-            CSV_FILE_1,
-            {"column_dates": ["Birth"]},
-            [
-                ("name1", 30, "city1", datetime(1980, 1, 1, 0, 0)),
-                ("name2", 29, "city2", datetime(1981, 1, 1, 0, 0)),
-                ("name3", 28, "city3", datetime(1982, 1, 1, 0, 0)),
-            ],
-        ),
-        (
-            CSV_FILE_2,
-            {"header_row": 1},
-            [("name2", 29, "city2", "1-1-1981"), ("name3", 28, "city3", "1-1-1982")],
-        ),
-        (
-            CSV_FILE_3,
-            {"null_values": ["N/A", "None"]},
-            [
-                ("name1", None, "city1", "1-1-1980"),
-                ("name2", 29, None, "1-1-1981"),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-        (
-            CSV_FILE_3,
-            {
-                "null_values": ["N/A", "None"],
-                "column_dates": ["Birth"],
-                "columns_read": ["Name", "Age", "Birth"],
-            },
-            [
-                ("name1", None, datetime(1980, 1, 1, 0, 0)),
-                ("name2", 29, datetime(1981, 1, 1, 0, 0)),
-                ("name3", 28, datetime(1982, 1, 1, 0, 0)),
-            ],
-        ),
-        (
-            CSV_FILE_BROKEN,
-            {},
-            [
-                ("name1", 30, "city1", "1-1-1980"),
-                ("name2", 29, None, None),
-                ("name3", 28, "city3", "1-1-1982"),
-            ],
-        ),
-    ],
-)
 @pytest.mark.usefixtures("setup_csv_upload_with_context")
-def test_csv_upload_options(csv_data, options, table_data):
+def test_csv_upload_with_nulls():
     admin_user = security_manager.find_user(username="admin")
     upload_database = get_upload_db()
 
     with override_user(admin_user):
-        CSVImportCommand(
+        UploadCommand(
             upload_database.id,
             CSV_UPLOAD_TABLE,
-            create_csv_file(csv_data),
-            options=options,
+            create_csv_file(CSV_FILE_WITH_NULLS),
+            None,
+            CSVReader({"null_values": ["N/A", "None"]}),
         ).run()
-        with upload_database.get_sqla_engine_with_context() as engine:
-            data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
-            assert data == table_data
+        with upload_database.get_sqla_engine_with_context() as engine:
+            data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
+            assert data == [
+                ("name1", None, "city1", "1-1-1980"),
+                ("name2", 29, None, "1-1-1981"),
+                ("name3", 28, "city3", "1-1-1982"),
+            ]
+
+
+@pytest.mark.usefixtures("setup_csv_upload_with_context")
+def test_csv_upload_dataset():
+    admin_user = security_manager.find_user(username="admin")
+    upload_database = get_upload_db()
+
+    with override_user(admin_user):
+        UploadCommand(
+            upload_database.id,
+            CSV_UPLOAD_TABLE,
+            create_csv_file(),
+            None,
+            CSVReader({}),
+        ).run()
+        dataset = (
+            db.session.query(SqlaTable)
+            .filter_by(database_id=upload_database.id, table_name=CSV_UPLOAD_TABLE)
+            .one_or_none()
+        )
+        assert dataset is not None
+        assert security_manager.find_user("admin") in dataset.owners
 
 
 @only_postgresql
@@ -233,14 +147,33 @@ def test_csv_upload_database_not_found():
 
     with override_user(admin_user):
         with pytest.raises(DatabaseNotFoundError):
-            CSVImportCommand(
+            UploadCommand(
                 1000,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={},
+                None,
+                CSVReader({}),
             ).run()
 
 
+@only_postgresql
+@pytest.mark.usefixtures("setup_csv_upload_with_context")
+def test_csv_upload_database_not_supported():
+    admin_user = security_manager.find_user(username="admin")
+    upload_db: Database = get_upload_db()
+    upload_db.db_engine_spec.supports_file_upload = False
+    with override_user(admin_user):
+        with pytest.raises(DatabaseUploadNotSupported):
+            UploadCommand(
+                upload_db.id,
+                CSV_UPLOAD_TABLE,
+                create_csv_file(CSV_FILE_1),
+                None,
+                CSVReader({}),
+            ).run()
+    upload_db.db_engine_spec.supports_file_upload = True
+
+
 @only_postgresql
 @pytest.mark.usefixtures("setup_csv_upload_with_context_schema")
 def test_csv_upload_schema_not_allowed():
@@ -248,39 +181,25 @@ def test_csv_upload_schema_not_allowed():
     upload_db_id = get_upload_db().id
     with override_user(admin_user):
         with pytest.raises(DatabaseSchemaUploadNotAllowed):
-            CSVImportCommand(
+            UploadCommand(
                 upload_db_id,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={},
+                None,
+                CSVReader({}),
             ).run()
-
         with pytest.raises(DatabaseSchemaUploadNotAllowed):
-            CSVImportCommand(
+            UploadCommand(
                 upload_db_id,
                 CSV_UPLOAD_TABLE,
                 create_csv_file(CSV_FILE_1),
-                options={"schema": "schema1"},
+                "schema1",
+                CSVReader({}),
             ).run()
-
-        CSVImportCommand(
+        UploadCommand(
             upload_db_id,
-            CSV_UPLOAD_TABLE,
+            CSV_UPLOAD_TABLE_W_SCHEMA,
             create_csv_file(CSV_FILE_1),
-            options={"schema": "public"},
+            "public",
+            CSVReader({}),
         ).run()
-
-
-@only_postgresql
-@pytest.mark.usefixtures("setup_csv_upload_with_context")
-def test_csv_upload_broken_file():
-    admin_user = security_manager.find_user(username="admin")
-
-    with override_user(admin_user):
-        with pytest.raises(DatabaseUploadFailed):
-            CSVImportCommand(
-                get_upload_db().id,
-                CSV_UPLOAD_TABLE,
-                create_csv_file([""]),
-                options={"column_dates": ["Birth"]},
-            ).run()
diff --git a/tests/unit_tests/commands/databases/__init__.py b/tests/unit_tests/commands/databases/__init__.py
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/tests/unit_tests/commands/databases/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/commands/databases/csv_reader_test.py b/tests/unit_tests/commands/databases/csv_reader_test.py
new file mode 100644
index 000000000..00c861f72
--- /dev/null
+++ b/tests/unit_tests/commands/databases/csv_reader_test.py
@@ -0,0 +1,313 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import io
+from datetime import datetime
+
+import numpy as np
+import pytest
+
+from superset.commands.database.exceptions import DatabaseUploadFailed
+from superset.commands.database.uploaders.csv_reader import CSVReader, CSVReaderOptions
+from tests.unit_tests.fixtures.common import create_csv_file
+
+CSV_DATA = [
+    ["Name", "Age", "City", "Birth"],
+    ["name1", "30", "city1", "1990-02-01"],
+    ["name2", "25", "city2", "1995-02-01"],
+    ["name3", "20", "city3", "2000-02-01"],
+]
+
+CSV_DATA_CHANGED_HEADER = [
+    ["name1", "30", "city1", "1990-02-01"],
+    ["Name", "Age", "City", "Birth"],
+    ["name2", "25", "city2", "1995-02-01"],
+    ["name3", "20", "city3", "2000-02-01"],
+]
+
+CSV_DATA_WITH_NULLS = [
+    ["Name", "Age", "City", "Birth"],
+    ["name1", "N/A", "city1", "1990-02-01"],
+    ["name2", "25", "None", "1995-02-01"],
+    ["name3", "20", "city3", "2000-02-01"],
+]
+
+CSV_DATA_DAY_FIRST = [
+    ["Name", "Age", "City", "Birth"],
+    ["name1", "30", "city1", "01-02-1990"],
+]
+
+CSV_DATA_DECIMAL_CHAR = [
+    ["Name", "Age", "City", "Birth"],
+    ["name1", "30,1", "city1", "1990-02-01"],
+]
+
+CSV_DATA_SKIP_INITIAL_SPACE = [
+    [" Name", "Age", "City", "Birth"],
+    [" name1", "30", "city1", "1990-02-01"],
+]
+
+
+@pytest.mark.parametrize(
+    "file, options, expected_cols, expected_values",
+    [
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA, delimiter="|"),
+            CSVReaderOptions(delimiter="|"),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                columns_read=["Name", "Age"],
+            ),
+            ["Name", "Age"],
+            [
+                ["name1", 30],
+                ["name2", 25],
+                ["name3", 20],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                columns_read=["Name", "Age"],
+                column_data_types={"Age": "float"},
+            ),
+            ["Name", "Age"],
+            [
+                ["name1", 30.0],
+                ["name2", 25.0],
+                ["name3", 20.0],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                columns_read=[],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                columns_read=[],
+                column_data_types={"Age": "float"},
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30.0, "city1", "1990-02-01"],
+                ["name2", 25.0, "city2", "1995-02-01"],
+                ["name3", 20.0, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                rows_to_read=1,
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30.0, "city1", "1990-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                rows_to_read=1,
+                columns_read=["Name", "Age"],
+            ),
+            ["Name", "Age"],
+            [
+                ["name1", 30.0],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                skip_rows=1,
+            ),
+            ["name1", "30", "city1", "1990-02-01"],
+            [
+                ["name2", 25.0, "city2", "1995-02-01"],
+                ["name3", 20.0, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA),
+            CSVReaderOptions(
+                column_dates=["Birth"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
+                ["name2", 25, "city2", datetime(1995, 2, 1, 0, 0)],
+                ["name3", 20, "city3", datetime(2000, 2, 1, 0, 0)],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_CHANGED_HEADER),
+            CSVReaderOptions(
+                header_row=1,
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_WITH_NULLS),
+            CSVReaderOptions(
+                null_values=["N/A", "None"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", np.nan, "city1", "1990-02-01"],
+                ["name2", 25.0, np.nan, "1995-02-01"],
+                ["name3", 20.0, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_DAY_FIRST),
+            CSVReaderOptions(
+                day_first=False,
+                column_dates=["Birth"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", datetime(1990, 1, 2, 0, 0)],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_DAY_FIRST),
+            CSVReaderOptions(
+                day_first=True,
+                column_dates=["Birth"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_DECIMAL_CHAR),
+            CSVReaderOptions(
+                decimal_character=",",
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30.1, "city1", "1990-02-01"],
+            ],
+        ),
+        (
+            create_csv_file(CSV_DATA_SKIP_INITIAL_SPACE),
+            CSVReaderOptions(
+                skip_initial_space=True,
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+            ],
+        ),
+    ],
+)
+def test_csv_reader_file_to_dataframe(file, options, expected_cols, expected_values):
+    csv_reader = CSVReader(
+        options=options,
+    )
+    df = csv_reader.file_to_dataframe(file)
+    assert df.columns.tolist() == expected_cols
+    actual_values = df.values.tolist()
+    for i in range(len(expected_values)):
+        for j in range(len(expected_values[i])):
+            expected_val = expected_values[i][j]
+            actual_val = actual_values[i][j]
+
+            # Check if both values are NaN
+            if isinstance(expected_val, float) and isinstance(actual_val, float):
+                assert np.isnan(expected_val) == np.isnan(actual_val)
+            else:
+                assert expected_val == actual_val
+    file.close()
+
+
+def test_csv_reader_broken_file_no_columns():
+    csv_reader = CSVReader(
+        options=CSVReaderOptions(),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        csv_reader.file_to_dataframe(create_csv_file([""]))
+    assert str(ex.value) == "Parsing error: No columns to parse from file"
+
+
+def test_csv_reader_wrong_columns_to_read():
+    csv_reader = CSVReader(
+        options=CSVReaderOptions(columns_read=["xpto"]),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        csv_reader.file_to_dataframe(create_csv_file(CSV_DATA))
+    assert str(ex.value) == (
+        "Parsing error: Usecols do not match columns, "
+        "columns expected but not found: ['xpto']"
+    )
+
+
+def test_csv_reader_invalid_file():
+    csv_reader = CSVReader(
+        options=CSVReaderOptions(),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        csv_reader.file_to_dataframe(
+            io.StringIO("c1,c2,c3\na,b,c\n1,2,3,4,5,6,7\n1,2,3")
+        )
+    assert str(ex.value) == (
+        "Parsing error: Error tokenizing data. C error:"
+        " Expected 3 fields in line 3, saw 7\n"
+    )
+
+
+def test_csv_reader_invalid_encoding():
+    csv_reader = CSVReader(
+        options=CSVReaderOptions(),
+    )
+    binary_data = b"col1,col2,col3\nv1,v2,\xba\nv3,v4,v5\n"
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        csv_reader.file_to_dataframe(io.BytesIO(binary_data))
+    assert str(ex.value) == (
+        "Parsing error: 'utf-8' codec can't decode byte 0xba in"
+        " position 21: invalid start byte"
+    )
diff --git a/tests/unit_tests/commands/databases/excel_reader_test.py b/tests/unit_tests/commands/databases/excel_reader_test.py
new file mode 100644
index 000000000..763b1b74c
--- /dev/null
+++ b/tests/unit_tests/commands/databases/excel_reader_test.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import io
+from datetime import datetime
+from typing import Any
+
+import numpy as np
+import pytest
+
+from superset.commands.database.exceptions import DatabaseUploadFailed
+from superset.commands.database.uploaders.excel_reader import (
+    ExcelReader,
+    ExcelReaderOptions,
+)
+from tests.unit_tests.fixtures.common import create_excel_file
+
+EXCEL_DATA: dict[str, list[Any]] = {
+    "Name": ["name1", "name2", "name3"],
+    "Age": [30, 25, 20],
+    "City": ["city1", "city2", "city3"],
+    "Birth": ["1990-02-01", "1995-02-01", "2000-02-01"],
+}
+
+EXCEL_WITH_NULLS: dict[str, list[Any]] = {
+    "Name": ["name1", "name2", "name3"],
+    "Age": ["N/A", 25, 20],
+    "City": ["city1", "None", "city3"],
+    "Birth": ["1990-02-01", "1995-02-01", "2000-02-01"],
+}
+
+EXCEL_DATA_DECIMAL_CHAR = {
+    "Name": ["name1"],
+    "Age": ["30,1"],
+    "City": ["city1"],
+    "Birth": ["1990-02-01"],
+}
+
+
+@pytest.mark.parametrize(
+    "file, options, expected_cols, expected_values",
+    [
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                columns_read=["Name", "Age"],
+            ),
+            ["Name", "Age"],
+            [
+                ["name1", 30],
+                ["name2", 25],
+                ["name3", 20],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                columns_read=[],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", "1990-02-01"],
+                ["name2", 25, "city2", "1995-02-01"],
+                ["name3", 20, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                rows_to_read=1,
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30.0, "city1", "1990-02-01"],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                rows_to_read=1,
+                columns_read=["Name", "Age"],
+            ),
+            ["Name", "Age"],
+            [
+                ["name1", 30.0],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                skip_rows=1,
+            ),
+            ["name1", 30, "city1", "1990-02-01"],
+            [
+                ["name2", 25.0, "city2", "1995-02-01"],
+                ["name3", 20.0, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA),
+            ExcelReaderOptions(
+                column_dates=["Birth"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30, "city1", datetime(1990, 2, 1, 0, 0)],
+                ["name2", 25, "city2", datetime(1995, 2, 1, 0, 0)],
+                ["name3", 20, "city3", datetime(2000, 2, 1, 0, 0)],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_WITH_NULLS),
+            ExcelReaderOptions(
+                null_values=["N/A", "None"],
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", np.nan, "city1", "1990-02-01"],
+                ["name2", 25.0, np.nan, "1995-02-01"],
+                ["name3", 20.0, "city3", "2000-02-01"],
+            ],
+        ),
+        (
+            create_excel_file(EXCEL_DATA_DECIMAL_CHAR),
+            ExcelReaderOptions(
+                decimal_character=",",
+            ),
+            ["Name", "Age", "City", "Birth"],
+            [
+                ["name1", 30.1, "city1", "1990-02-01"],
+            ],
+        ),
+    ],
+)
+def test_excel_reader_file_to_dataframe(file, options, expected_cols, expected_values):
+    excel_reader = ExcelReader(
+        options=options,
+    )
+    df = excel_reader.file_to_dataframe(file)
+    assert df.columns.tolist() == expected_cols
+    actual_values = df.values.tolist()
+    for i in range(len(expected_values)):
+        for j in range(len(expected_values[i])):
+            expected_val = expected_values[i][j]
+            actual_val = actual_values[i][j]
+
+            # Check if both values are NaN
+            if isinstance(expected_val, float) and isinstance(actual_val, float):
+                assert np.isnan(expected_val) == np.isnan(actual_val)
+            else:
+                assert expected_val == actual_val
+    file.close()
+
+
+def test_excel_reader_wrong_columns_to_read():
+    excel_reader = ExcelReader(
+        options=ExcelReaderOptions(columns_read=["xpto"]),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        excel_reader.file_to_dataframe(create_excel_file(EXCEL_DATA))
+    assert str(ex.value) == (
+        "Parsing error: Usecols do not match columns, "
+        "columns expected but not found: ['xpto'] (sheet: 0)"
+    )
+
+
+def test_excel_reader_wrong_date():
+    excel_reader = ExcelReader(
+        options=ExcelReaderOptions(column_dates=["xpto"]),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        excel_reader.file_to_dataframe(create_excel_file(EXCEL_DATA))
+    assert str(ex.value) == (
+        "Parsing error: Missing column provided to 'parse_dates':" " 'xpto' (sheet: 0)"
+    )
+
+
+def test_excel_reader_invalid_file():
+    excel_reader = ExcelReader(
+        options=ExcelReaderOptions(),
+    )
+    with pytest.raises(DatabaseUploadFailed) as ex:
+        excel_reader.file_to_dataframe(io.StringIO("c1"))
+    assert str(ex.value) == (
+        "Parsing error: Excel file format cannot be determined, "
+        "you must specify an engine manually."
+    )
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 7f25f28f1..40bb7a019 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -33,8 +33,9 @@ from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
 from superset import db
-from superset.commands.database.csv_import import CSVImportCommand
-from superset.commands.database.excel_import import ExcelImportCommand
+from superset.commands.database.uploaders.base import UploadCommand
+from superset.commands.database.uploaders.csv_reader import CSVReader
+from superset.commands.database.uploaders.excel_reader import ExcelReader
 from superset.db_engine_specs.sqlite import SqliteEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetSecurityException
@@ -829,7 +830,7 @@ def test_oauth2_error(
 
 
 @pytest.mark.parametrize(
-    "payload,cmd_called_with",
+    "payload,upload_called_with,reader_called_with",
     [
         (
             {
@@ -841,6 +842,10 @@
             1,
             "table1",
             ANY,
+            None,
+            ANY,
+        ),
+        (
             {
                 "already_exists": "fail",
                 "delimiter": ",",
@@ -861,6 +866,10 @@
             1,
             "table2",
             ANY,
+            None,
+            ANY,
+        ),
+        (
             {
                 "already_exists": "replace",
                 "column_dates": ["col1", "col2"],
@@ -879,7 +888,6 @@
                 "columns_read": "col1,col2",
                 "day_first": True,
                 "rows_to_read": "1",
-                "overwrite_duplicates": True,
                 "skip_blank_lines": True,
                 "skip_initial_space": True,
                 "skip_rows": "10",
@@ -890,12 +898,15 @@
             1,
             "table2",
             ANY,
+            None,
+            ANY,
+        ),
+        (
             {
                 "already_exists": "replace",
                 "columns_read": ["col1", "col2"],
                 "null_values": ["None", "N/A", "''"],
                 "day_first": True,
-                "overwrite_duplicates": True,
                 "rows_to_read": 1,
                 "skip_blank_lines": True,
                 "skip_initial_space": True,
@@ -911,7 +922,8 @@
 )
 def test_csv_upload(
     payload: dict[str, Any],
-    cmd_called_with: tuple[int, str, Any, dict[str, Any]],
+    upload_called_with: tuple[int, str, Any, dict[str, Any]],
+    reader_called_with: dict[str, Any],
     mocker: MockFixture,
     client: Any,
     full_api_access: None,
 ):
     """
     Test CSV Upload success.
""" - init_mock = mocker.patch.object(CSVImportCommand, "__init__") + init_mock = mocker.patch.object(UploadCommand, "__init__") init_mock.return_value = None - _ = mocker.patch.object(CSVImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") + reader_mock = mocker.patch.object(CSVReader, "__init__") + reader_mock.return_value = None response = client.post( f"/api/v1/database/1/csv_upload/", data=payload, @@ -929,7 +943,8 @@ def test_csv_upload( ) assert response.status_code == 200 assert response.json == {"message": "OK"} - init_mock.assert_called_with(*cmd_called_with) + init_mock.assert_called_with(*upload_called_with) + reader_mock.assert_called_with(*reader_called_with) @pytest.mark.parametrize( @@ -994,16 +1009,6 @@ def test_csv_upload( }, {"message": {"header_row": ["Not a valid integer."]}}, ), - ( - { - "file": (create_csv_file(), "out.csv"), - "table_name": "table1", - "delimiter": ",", - "already_exists": "fail", - "overwrite_duplicates": "test1", - }, - {"message": {"overwrite_duplicates": ["Not a valid boolean."]}}, - ), ( { "file": (create_csv_file(), "out.csv"), @@ -1066,7 +1071,7 @@ def test_csv_upload_validation( """ Test CSV Upload validation fails. """ - _ = mocker.patch.object(CSVImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") response = client.post( f"/api/v1/database/1/csv_upload/", @@ -1085,7 +1090,7 @@ def test_csv_upload_file_size_validation( """ Test CSV Upload validation fails. """ - _ = mocker.patch.object(CSVImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") current_app.config["CSV_UPLOAD_MAX_SIZE"] = 5 response = client.post( f"/api/v1/database/1/csv_upload/", @@ -1127,7 +1132,7 @@ def test_csv_upload_file_extension_invalid( """ Test CSV Upload validation fails. """ - _ = mocker.patch.object(CSVImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") response = client.post( f"/api/v1/database/1/csv_upload/", data={ @@ -1163,7 +1168,7 @@ def test_csv_upload_file_extension_valid( """ Test CSV Upload validation fails. """ - _ = mocker.patch.object(CSVImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") response = client.post( f"/api/v1/database/1/csv_upload/", data={ @@ -1177,7 +1182,7 @@ def test_csv_upload_file_extension_valid( @pytest.mark.parametrize( - "payload,cmd_called_with", + "payload,upload_called_with,reader_called_with", [ ( { @@ -1188,6 +1193,10 @@ def test_csv_upload_file_extension_valid( 1, "table1", ANY, + None, + ANY, + ), + ( { "already_exists": "fail", "file": ANY, @@ -1207,6 +1216,10 @@ def test_csv_upload_file_extension_valid( 1, "table2", ANY, + None, + ANY, + ), + ( { "already_exists": "replace", "column_dates": ["col1", "col2"], @@ -1231,6 +1244,10 @@ def test_csv_upload_file_extension_valid( 1, "table2", ANY, + None, + ANY, + ), + ( { "already_exists": "replace", "columns_read": ["col1", "col2"], @@ -1247,7 +1264,8 @@ def test_csv_upload_file_extension_valid( ) def test_excel_upload( payload: dict[str, Any], - cmd_called_with: tuple[int, str, Any, dict[str, Any]], + upload_called_with: tuple[int, str, Any, dict[str, Any]], + reader_called_with: dict[str, Any], mocker: MockFixture, client: Any, full_api_access: None, @@ -1255,9 +1273,11 @@ def test_excel_upload( """ Test Excel Upload success. 
""" - init_mock = mocker.patch.object(ExcelImportCommand, "__init__") + init_mock = mocker.patch.object(UploadCommand, "__init__") init_mock.return_value = None - _ = mocker.patch.object(ExcelImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") + reader_mock = mocker.patch.object(ExcelReader, "__init__") + reader_mock.return_value = None response = client.post( f"/api/v1/database/1/excel_upload/", data=payload, @@ -1265,7 +1285,8 @@ def test_excel_upload( ) assert response.status_code == 200 assert response.json == {"message": "OK"} - init_mock.assert_called_with(*cmd_called_with) + init_mock.assert_called_with(*upload_called_with) + reader_mock.assert_called_with(*reader_called_with) @pytest.mark.parametrize( @@ -1347,7 +1368,7 @@ def test_excel_upload_validation( """ Test Excel Upload validation fails. """ - _ = mocker.patch.object(ExcelImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") response = client.post( f"/api/v1/database/1/excel_upload/", @@ -1382,7 +1403,7 @@ def test_excel_upload_file_extension_invalid( """ Test Excel Upload file extension fails. """ - _ = mocker.patch.object(ExcelImportCommand, "run") + _ = mocker.patch.object(UploadCommand, "run") response = client.post( f"/api/v1/database/1/excel_upload/", data={ diff --git a/tests/unit_tests/fixtures/common.py b/tests/unit_tests/fixtures/common.py index 841566971..a360d38a6 100644 --- a/tests/unit_tests/fixtures/common.py +++ b/tests/unit_tests/fixtures/common.py @@ -31,7 +31,7 @@ def dttm() -> datetime: return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f") -def create_csv_file(data: list[list[str]] | None = None) -> BytesIO: +def create_csv_file(data: list[list[str]] | None = None, delimiter=",") -> BytesIO: data = ( [ ["Name", "Age", "City"], @@ -42,7 +42,7 @@ def create_csv_file(data: list[list[str]] | None = None) -> BytesIO: ) output = StringIO() - writer = csv.writer(output) + writer = csv.writer(output, delimiter=delimiter) for row in data: writer.writerow(row) output.seek(0)