diff --git a/superset/config.py b/superset/config.py
index 1668f5722..e3d3ffbf8 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -586,11 +586,27 @@ CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
 # The directory within the bucket specified above that will
 # contain all the external tables
 CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
+# Function that creates the upload directory dynamically, based on the
+# target database, the user, and the schema provided.
+CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
+    ["Database", "models.User", str], Optional[str]
+] = lambda database, user, schema: CSV_TO_HIVE_UPLOAD_DIRECTORY
 # The namespace within hive where the tables created from
 # uploading CSVs will be stored.
 UPLOADED_CSV_HIVE_NAMESPACE = None
+# Function that computes the allowed schemas for CSV uploads.
+# The allowed schemas are the union of the schemas_allowed_for_csv_upload
+# database configuration and the result of this function.
+
+# mypy doesn't catch that the if case ensures the list content is always str
+ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
+    ["Database", "models.User"], List[str]
+] = lambda database, user: [
+    UPLOADED_CSV_HIVE_NAMESPACE  # type: ignore
+] if UPLOADED_CSV_HIVE_NAMESPACE else []
+
 # A dictionary of items that gets merged into the Jinja context for
 # SQL Lab. The existing context gets updated with this dictionary,
 # meaning values for existing keys get overwritten by the content of this
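Both hooks are plain callables, so a deployment can override them in its `superset_config.py`. Below is a minimal sketch of such an override; the per-user directory layout and the `user_<username>` schema convention are hypothetical illustrations, not part of this patch:

# superset_config.py -- illustrative overrides of the two new hooks.
from typing import List, Optional


def csv_upload_directory(database, user, schema: Optional[str]) -> Optional[str]:
    # Hypothetical layout: partition uploads by user and schema so that
    # concurrent uploads of identically named tables don't collide in S3.
    return f"EXTERNAL_HIVE_TABLES/{user.username}/{schema or 'default'}/"


def allowed_user_csv_schemas(database, user) -> List[str]:
    # Hypothetical policy: every user may upload into a personal schema,
    # in addition to whatever the database extra already allows.
    return [f"user_{user.username}"]


CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC = csv_upload_directory
ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_user_csv_schemas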
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 845b22c5b..a593f5900 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -18,7 +18,6 @@ import hashlib
 import json
 import logging
-import os
 import re
 from contextlib import closing
 from datetime import datetime
@@ -49,11 +48,11 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
 from sqlalchemy.types import TypeEngine
-from wtforms.form import Form
 
 from superset import app, sql_parse
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 from superset.utils import core as utils
 
 if TYPE_CHECKING:
@@ -454,55 +453,26 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         df.to_sql(**kwargs)
 
     @classmethod
-    def create_table_from_csv(cls, form: Form, database: "Database") -> None:
+    def create_table_from_csv(  # pylint: disable=too-many-arguments
+        cls,
+        filename: str,
+        table: Table,
+        database: "Database",
+        csv_to_df_kwargs: Dict[str, Any],
+        df_to_sql_kwargs: Dict[str, Any],
+    ) -> None:
         """
         Create table from contents of a csv. Note: this method does not create
         metadata for the table.
-
-        :param form: Parameters defining how to process data
-        :param database: Database model object for the target database
         """
-
-        def _allowed_file(filename: str) -> bool:
-            # Only allow specific file extensions as specified in the config
-            extension = os.path.splitext(filename)[1].lower()
-            return (
-                extension is not None and extension[1:] in config["ALLOWED_EXTENSIONS"]
-            )
-
-        filename = form.csv_file.data.filename
-
-        if not _allowed_file(filename):
-            raise Exception("Invalid file type selected")
-        csv_to_df_kwargs = {
-            "filepath_or_buffer": filename,
-            "sep": form.sep.data,
-            "header": form.header.data if form.header.data else 0,
-            "index_col": form.index_col.data,
-            "mangle_dupe_cols": form.mangle_dupe_cols.data,
-            "skipinitialspace": form.skipinitialspace.data,
-            "skiprows": form.skiprows.data,
-            "nrows": form.nrows.data,
-            "skip_blank_lines": form.skip_blank_lines.data,
-            "parse_dates": form.parse_dates.data,
-            "infer_datetime_format": form.infer_datetime_format.data,
-            "chunksize": 10000,
-        }
-        df = cls.csv_to_df(**csv_to_df_kwargs)
-
+        df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs)
         engine = cls.get_engine(database)
-
-        df_to_sql_kwargs = {
-            "df": df,
-            "name": form.name.data,
-            "con": engine,
-            "schema": form.schema.data,
-            "if_exists": form.if_exists.data,
-            "index": form.index.data,
-            "index_label": form.index_label.data,
-            "chunksize": 10000,
-        }
-        cls.df_to_sql(**df_to_sql_kwargs)
+        if table.schema:
+            # only add schema when it is present and non-empty
+            df_to_sql_kwargs["schema"] = table.schema
+        if engine.dialect.supports_multivalues_insert:
+            df_to_sql_kwargs["method"] = "multi"
+        cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)
 
     @classmethod
     def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
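For reference, a sketch of how a caller now drives the refactored method: the engine spec no longer reads the WTForms object itself, and the pandas read_csv/to_sql options arrive as plain dicts. The file path, table name, and kwargs below are hypothetical stand-ins for what the view layer (see views.py further down) actually assembles:

from superset.sql_parse import Table

csv_table = Table(table="trips", schema="staging")
csv_to_df_kwargs = {"sep": ",", "header": 0, "chunksize": 1000}
df_to_sql_kwargs = {"name": csv_table.table, "if_exists": "fail", "chunksize": 1000}

# `database` is assumed to be an existing superset.models.core.Database
# instance fetched elsewhere.
database.db_engine_spec.create_table_from_csv(
    "/tmp/upload_ab12cd.csv",  # temp file written by the view layer
    csv_table,
    database,
    csv_to_df_kwargs,
    df_to_sql_kwargs,
)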
" - "All tables will be uploaded to the `{}` namespace".format( - config["HIVE_NAMESPACE"] - ) - ) - full_table_name = "{}.{}".format( - config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name - ) - else: - if "." in table_name and schema_name: - raise Exception( - "You can't specify a namespace both in the name of the table " - "and in the schema field. Please remove one" - ) - - full_table_name = ( - "{}.{}".format(schema_name, table_name) if schema_name else table_name - ) - - filename = form.csv_file.data.filename - upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"] + upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]( + database, g.user, table.schema + ) # Optional dependency - from tableschema import Table # pylint: disable=import-error + from tableschema import ( # pylint: disable=import-error + Table as TableSchemaTable, + ) - hive_table_schema = Table(filename).infer() + hive_table_schema = TableSchemaTable(filename).infer() column_name_and_type = [] for column_info in hive_table_schema["fields"]: column_name_and_type.append( @@ -173,13 +157,14 @@ class HiveEngineSpec(PrestoEngineSpec): import boto3 # pylint: disable=import-error s3 = boto3.client("s3") - location = os.path.join("s3a://", bucket_path, upload_prefix, table_name) + location = os.path.join("s3a://", bucket_path, upload_prefix, table.table) s3.upload_file( filename, bucket_path, - os.path.join(upload_prefix, table_name, os.path.basename(filename)), + os.path.join(upload_prefix, table.table, os.path.basename(filename)), ) - sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} ) + # TODO(bkyryliuk): support other delimiters + sql = f"""CREATE TABLE {str(table)} ( {schema_definition} ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE LOCATION '{location}' tblproperties ('skip.header.line.count'='1')""" diff --git a/superset/models/core.py b/superset/models/core.py index 94383c87d..abcb210d7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -609,7 +609,13 @@ class Database( def get_schema_access_for_csv_upload( # pylint: disable=invalid-name self, ) -> List[str]: - return self.get_extra().get("schemas_allowed_for_csv_upload", []) + allowed_databases = self.get_extra().get("schemas_allowed_for_csv_upload", []) + if hasattr(g, "user"): + extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( + self, g.user + ) + allowed_databases += extra_allowed_databases + return sorted(set(allowed_databases)) @property def sqlalchemy_uri_decrypted(self) -> str: diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 47bf1c3b8..208d35166 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -30,6 +30,7 @@ from superset import app, db from superset.connectors.sqla.models import SqlaTable from superset.constants import RouteMethod from superset.exceptions import CertificateException +from superset.sql_parse import Table from superset.utils import core as utils from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin @@ -109,66 +110,116 @@ class CsvToDatabaseView(SimpleFormView): def form_post(self, form): database = form.con.data - schema_name = form.schema.data or "" + csv_table = Table(table=form.name.data, schema=form.schema.data) - if not schema_allows_csv_upload(database, schema_name): + if not schema_allows_csv_upload(database, csv_table.schema): message = _( 'Database "%(database_name)s" schema "%(schema_name)s" ' "is not allowed for csv uploads. 
diff --git a/superset/models/core.py b/superset/models/core.py
index 94383c87d..abcb210d7 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -609,7 +609,13 @@ class Database(
     def get_schema_access_for_csv_upload(  # pylint: disable=invalid-name
         self,
     ) -> List[str]:
-        return self.get_extra().get("schemas_allowed_for_csv_upload", [])
+        allowed_schemas = self.get_extra().get("schemas_allowed_for_csv_upload", [])
+        if hasattr(g, "user"):
+            extra_allowed_schemas = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
+                self, g.user
+            )
+            allowed_schemas += extra_allowed_schemas
+        return sorted(set(allowed_schemas))
 
     @property
     def sqlalchemy_uri_decrypted(self) -> str:
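The static allow-list still lives in the database's `extra` JSON blob; the hook's result is merged into it. A small illustration of the union semantics, with hypothetical schema names:

import json

# What an admin might store in the Database "extra" field:
extra = json.loads('{"schemas_allowed_for_csv_upload": ["public", "staging"]}')
static_schemas = extra.get("schemas_allowed_for_csv_upload", [])

# What ALLOWED_USER_CSV_SCHEMA_FUNC might return for the current user:
hook_schemas = ["user_alice"]

# get_schema_access_for_csv_upload() yields the deduplicated, sorted union:
print(sorted(set(static_schemas + hook_schemas)))
# ['public', 'staging', 'user_alice']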
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index 47bf1c3b8..208d35166 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -30,6 +30,7 @@ from superset import app, db
 from superset.connectors.sqla.models import SqlaTable
 from superset.constants import RouteMethod
 from superset.exceptions import CertificateException
+from superset.sql_parse import Table
 from superset.utils import core as utils
 from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin
 
@@ -109,66 +110,116 @@ class CsvToDatabaseView(SimpleFormView):
     def form_post(self, form):
         database = form.con.data
-        schema_name = form.schema.data or ""
+        csv_table = Table(table=form.name.data, schema=form.schema.data)
 
-        if not schema_allows_csv_upload(database, schema_name):
+        if not schema_allows_csv_upload(database, csv_table.schema):
             message = _(
                 'Database "%(database_name)s" schema "%(schema_name)s" '
                 "is not allowed for csv uploads. Please contact your Superset Admin.",
                 database_name=database.database_name,
-                schema_name=schema_name,
+                schema_name=csv_table.schema,
             )
             flash(message, "danger")
             return redirect("/csvtodatabaseview/form")
 
-        csv_filename = form.csv_file.data.filename
-        extension = os.path.splitext(csv_filename)[1].lower()
-        path = tempfile.NamedTemporaryFile(
-            dir=app.config["UPLOAD_FOLDER"], suffix=extension, delete=False
+        if "." in csv_table.table and csv_table.schema:
+            message = _(
+                "You cannot specify a namespace both in the name of the table: "
+                '"%(table)s" and in the schema field: '
+                '"%(schema)s". Please remove one',
+                table=csv_table.table,
+                schema=csv_table.schema,
+            )
+            flash(message, "danger")
+            return redirect("/csvtodatabaseview/form")
+
+        uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
+            dir=app.config["UPLOAD_FOLDER"],
+            suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(),
+            delete=False,
         ).name
-        form.csv_file.data.filename = path
 
         try:
             utils.ensure_path_exists(config["UPLOAD_FOLDER"])
-            upload_stream_write(form.csv_file.data, path)
-            table_name = form.name.data
+            upload_stream_write(form.csv_file.data, uploaded_tmp_file_path)
 
             con = form.data.get("con")
             database = (
                 db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
             )
-            database.db_engine_spec.create_table_from_csv(form, database)
-            table = (
+            csv_to_df_kwargs = {
+                "sep": form.sep.data,
+                "header": form.header.data if form.header.data else 0,
+                "index_col": form.index_col.data,
+                "mangle_dupe_cols": form.mangle_dupe_cols.data,
+                "skipinitialspace": form.skipinitialspace.data,
+                "skiprows": form.skiprows.data,
+                "nrows": form.nrows.data,
+                "skip_blank_lines": form.skip_blank_lines.data,
+                "parse_dates": form.parse_dates.data,
+                "infer_datetime_format": form.infer_datetime_format.data,
+                "chunksize": 1000,
+            }
+            df_to_sql_kwargs = {
+                "name": csv_table.table,
+                "if_exists": form.if_exists.data,
+                "index": form.index.data,
+                "index_label": form.index_label.data,
+                "chunksize": 1000,
+            }
+            database.db_engine_spec.create_table_from_csv(
+                uploaded_tmp_file_path,
+                csv_table,
+                database,
+                csv_to_df_kwargs,
+                df_to_sql_kwargs,
+            )
+
+            # Connect table to the database that should be used for exploration.
+            # E.g. if hive was used to upload a csv, presto will be a better option
+            # to explore the table.
+            explore_database = database
+            explore_database_id = database.get_extra().get("explore_database_id", None)
+            if explore_database_id:
+                explore_database = (
+                    db.session.query(models.Database)
+                    .filter_by(id=explore_database_id)
+                    .one_or_none()
+                    or database
+                )
+
+            sqla_table = (
                 db.session.query(SqlaTable)
                 .filter_by(
-                    table_name=table_name,
-                    schema=form.schema.data,
-                    database_id=database.id,
+                    table_name=csv_table.table,
+                    schema=csv_table.schema,
+                    database_id=explore_database.id,
                 )
                 .one_or_none()
             )
-            if table:
-                table.fetch_metadata()
-            if not table:
-                table = SqlaTable(table_name=table_name)
-                table.database = database
-                table.database_id = database.id
-                table.user_id = g.user.id
-                table.schema = form.schema.data
-                table.fetch_metadata()
-                db.session.add(table)
+
+            if sqla_table:
+                sqla_table.fetch_metadata()
+            if not sqla_table:
+                sqla_table = SqlaTable(table_name=csv_table.table)
+                sqla_table.database = explore_database
+                sqla_table.database_id = explore_database.id
+                sqla_table.user_id = g.user.id
+                sqla_table.schema = csv_table.schema
+                sqla_table.fetch_metadata()
+                db.session.add(sqla_table)
             db.session.commit()
         except Exception as ex:  # pylint: disable=broad-except
             db.session.rollback()
             try:
-                os.remove(path)
+                os.remove(uploaded_tmp_file_path)
             except OSError:
                 pass
             message = _(
                 'Unable to upload CSV file "%(filename)s" to table '
                 '"%(table_name)s" in database "%(db_name)s". '
                 "Error message: %(error_msg)s",
-                filename=csv_filename,
+                filename=form.csv_file.data.filename,
                 table_name=form.name.data,
                 db_name=database.database_name,
                 error_msg=str(ex),
@@ -178,14 +229,14 @@ class CsvToDatabaseView(SimpleFormView):
             stats_logger.incr("failed_csv_upload")
             return redirect("/csvtodatabaseview/form")
 
-        os.remove(path)
+        os.remove(uploaded_tmp_file_path)
         # Go back to welcome page / splash screen
         message = _(
             'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
             'database "%(db_name)s"',
-            csv_filename=csv_filename,
-            table_name=form.name.data,
-            db_name=table.database.database_name,
+            csv_filename=form.csv_file.data.filename,
+            table_name=str(csv_table),
+            db_name=sqla_table.database.database_name,
         )
         flash(message, "info")
         stats_logger.incr("successful_csv_upload")
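The new `explore_database_id` key also lives in the database's `extra` JSON blob, and the test below (`test_import_csv_explore_database`) sets it the same way. A sketch of wiring a Hive upload database to a Presto database for exploration; the database names are hypothetical:

import json

from superset import db
from superset.models import core as models

# Hypothetical setup: CSVs are uploaded through Hive but explored via Presto.
hive_db = (
    db.session.query(models.Database).filter_by(database_name="hive_uploads").one()
)
presto_db = db.session.query(models.Database).filter_by(database_name="presto").one()

extra = hive_db.get_extra()
extra["explore_database_id"] = presto_db.id
hive_db.extra = json.dumps(extra)
db.session.commit()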
test_file_2.write("b,c,d\n") - test_file_2.write("john,1,x\n") - test_file_2.write("paul,2,y\n") - test_file_2.close() - - example_db = utils.get_example_database() - example_db.allow_csv_upload = True - db_id = example_db.id + def enable_csv_upload(self, database: models.Database) -> None: + """Enables csv upload in the given database.""" + database.allow_csv_upload = True db.session.commit() + add_datasource_page = self.get_resp("/databaseview/list/") + self.assertIn("Upload a CSV", add_datasource_page) + + form_get = self.get_resp("/csvtodatabaseview/form") + self.assertIn("CSV to Database configuration", form_get) + + def upload_csv( + self, filename: str, table_name: str, extra: Optional[Dict[str, str]] = None + ): form_data = { - "csv_file": open(filename_1, "rb"), + "csv_file": open(filename, "rb"), "sep": ",", "name": table_name, - "con": db_id, + "con": utils.get_example_database().id, "if_exists": "fail", "index_label": "test_label", "mangle_dupe_cols": False, } - url = "/databaseview/list/" - add_datasource_page = self.get_resp(url) - self.assertIn("Upload a CSV", add_datasource_page) + if extra: + form_data.update(extra) + return self.get_resp("/csvtodatabaseview/form", data=form_data) - url = "/csvtodatabaseview/form" - form_get = self.get_resp(url) - self.assertIn("CSV to Database configuration", form_get) + @mock.patch( + "superset.models.core.config", + {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": lambda d, u: ["admin_database"]}, + ) + def test_import_csv_enforced_schema(self): + if utils.get_example_database().backend == "sqlite": + # sqlite doesn't support schema / database creation + return + self.login(username="admin") + table_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) + full_table_name = f"admin_database.{table_name}" + filename = "testCSV.csv" + self.create_sample_csvfile(filename, ["a,b", "john,1", "paul,2"]) + try: + self.enable_csv_upload(utils.get_example_database()) + + # no schema specified, fail upload + resp = self.upload_csv(filename, table_name) + self.assertIn( + 'Database "examples" schema "None" is not allowed for csv uploads', resp + ) + + # user specified schema matches the expected schema, append + success_msg = f'CSV file "{filename}" uploaded to table "{full_table_name}"' + resp = self.upload_csv( + filename, + table_name, + extra={"schema": "admin_database", "if_exists": "append"}, + ) + self.assertIn(success_msg, resp) + + resp = self.upload_csv( + filename, + table_name, + extra={"schema": "admin_database", "if_exists": "replace"}, + ) + self.assertIn(success_msg, resp) + + # user specified schema doesn't match, fail + resp = self.upload_csv(filename, table_name, extra={"schema": "gold"}) + self.assertIn( + 'Database "examples" schema "gold" is not allowed for csv uploads', + resp, + ) + finally: + os.remove(filename) + + def test_import_csv_explore_database(self): + if utils.get_example_database().backend == "sqlite": + # sqlite doesn't support schema / database creation + return + explore_db_id = utils.get_example_database().id + + upload_db = utils.get_or_create_db( + "csv_explore_db", app.config["SQLALCHEMY_DATABASE_URI"] + ) + upload_db_id = upload_db.id + extra = upload_db.get_extra() + extra["explore_database_id"] = explore_db_id + upload_db.extra = json.dumps(extra) + db.session.commit() + + self.login(username="admin") + self.enable_csv_upload(DatasetDAO.get_database_by_id(upload_db_id)) + table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5)) + + f = "testCSV.csv" + 
self.create_sample_csvfile(f, ["a,b", "john,1", "paul,2"]) + # initial upload with fail mode + resp = self.upload_csv(f, table_name) + self.assertIn(f'CSV file "{f}" uploaded to table "{table_name}"', resp) + table = self.get_table_by_name(table_name) + self.assertEqual(table.database_id, explore_db_id) + + # cleanup + db.session.delete(table) + db.session.delete(DatasetDAO.get_database_by_id(upload_db_id)) + db.session.commit() + os.remove(f) + + def test_import_csv(self): + self.login(username="admin") + table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5)) + + f1 = "testCSV.csv" + self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"]) + f2 = "testCSV2.csv" + self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"]) + self.enable_csv_upload(utils.get_example_database()) try: + success_msg_f1 = f'CSV file "{f1}" uploaded to table "{table_name}"' + # initial upload with fail mode - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp - ) + resp = self.upload_csv(f1, table_name) + self.assertIn(success_msg_f1, resp) # upload again with fail mode; should fail - form_data["csv_file"] = open(filename_1, "rb") - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'Unable to upload CSV file "{filename_1}" to table "{table_name}"', - resp, - ) + fail_msg = f'Unable to upload CSV file "{f1}" to table "{table_name}"' + resp = self.upload_csv(f1, table_name) + self.assertIn(fail_msg, resp) # upload again with append mode - form_data["csv_file"] = open(filename_1, "rb") - form_data["if_exists"] = "append" - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp - ) + resp = self.upload_csv(f1, table_name, extra={"if_exists": "append"}) + self.assertIn(success_msg_f1, resp) # upload again with replace mode - form_data["csv_file"] = open(filename_1, "rb") - form_data["if_exists"] = "replace" - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp - ) + resp = self.upload_csv(f1, table_name, extra={"if_exists": "replace"}) + self.assertIn(success_msg_f1, resp) # try to append to table from file with different schema - form_data["csv_file"] = open(filename_2, "rb") - form_data["if_exists"] = "append" - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'Unable to upload CSV file "{filename_2}" to table "{table_name}"', - resp, - ) + resp = self.upload_csv(f2, table_name, extra={"if_exists": "append"}) + fail_msg_f2 = f'Unable to upload CSV file "{f2}" to table "{table_name}"' + self.assertIn(fail_msg_f2, resp) # replace table from file with different schema - form_data["csv_file"] = open(filename_2, "rb") - form_data["if_exists"] = "replace" - resp = self.get_resp(url, data=form_data) - self.assertIn( - f'CSV file "{filename_2}" uploaded to table "{table_name}"', resp - ) - table = ( - db.session.query(SqlaTable) - .filter_by(table_name=table_name, database_id=db_id) - .first() - ) + resp = self.upload_csv(f2, table_name, extra={"if_exists": "replace"}) + success_msg_f2 = f'CSV file "{f2}" uploaded to table "{table_name}"' + self.assertIn(success_msg_f2, resp) + + table = self.get_table_by_name(table_name) # make sure the new column name is reflected in the table metadata self.assertIn("d", table.column_names) finally: - os.remove(filename_1) - os.remove(filename_2) + os.remove(f1) + os.remove(f2) def test_dataframe_timezone(self): 
tz = pytz.FixedOffset(60)