diff --git a/superset-frontend/src/datasource/DatasourceEditor.jsx b/superset-frontend/src/datasource/DatasourceEditor.jsx index 70abc4693..56bc58a9d 100644 --- a/superset-frontend/src/datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/datasource/DatasourceEditor.jsx @@ -392,14 +392,9 @@ class DatasourceEditor extends React.PureComponent { syncMetadata() { const { datasource } = this.state; - // Handle carefully when the schema is empty - const endpoint = - `/datasource/external_metadata/${ - datasource.type || datasource.datasource_type - }/${datasource.id}/` + - `?db_id=${datasource.database.id}` + - `&schema=${datasource.schema || ''}` + - `&table_name=${datasource.datasource_name || datasource.table_name}`; + const endpoint = `/datasource/external_metadata/${ + datasource.type || datasource.datasource_type + }/${datasource.id}/`; this.setState({ metadataLoading: true }); SupersetClient.get({ endpoint }) @@ -930,12 +925,6 @@ class DatasourceEditor extends React.PureComponent { buttonStyle="primary" onClick={this.syncMetadata} className="sync-from-source" - disabled={!!datasource.sql} - tooltip={ - datasource.sql - ? 
t('This option is not yet available for views') - : null - } > {t('Sync columns from source')} diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b93f5cdda..c7ab7986a 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -17,6 +17,7 @@ import json import logging from collections import defaultdict, OrderedDict +from contextlib import closing from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union @@ -44,18 +45,24 @@ from sqlalchemy import ( Table, Text, ) -from sqlalchemy.exc import CompileError, SQLAlchemyError +from sqlalchemy.exc import CompileError from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.expression import Label, Select, TextAsFrom +from sqlalchemy.types import TypeEngine from superset import app, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression -from superset.exceptions import DatabaseNotFound, QueryObjectValidationError +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import ( + DatabaseNotFound, + QueryObjectValidationError, + SupersetSecurityException, +) from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -64,6 +71,7 @@ from superset.jinja_context import ( from superset.models.annotations import Annotation from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, QueryResult +from superset.result_set import SupersetResultSet from superset.sql_parse import 
ParsedQuery from superset.typing import Metric, QueryObjectDict from superset.utils import core as utils, import_datasource @@ -643,12 +651,52 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at return self.database.sql_url + "?table_name=" + str(self.table_name) def external_metadata(self) -> List[Dict[str, str]]: - cols = self.database.get_columns(self.table_name, schema=self.schema) - for col in cols: - try: - col["type"] = str(col["type"]) - except CompileError: - col["type"] = "UNKNOWN" + db_engine_spec = self.database.db_engine_spec + if self.sql: + engine = self.database.get_sqla_engine(schema=self.schema) + sql = self.get_template_processor().process_template(self.sql) + parsed_query = ParsedQuery(sql) + if not parsed_query.is_readonly(): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + message=_("Only `SELECT` statements are allowed"), + level=ErrorLevel.ERROR, + ) + ) + statements = parsed_query.get_statements() + if len(statements) > 1: + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + message=_("Only single queries supported"), + level=ErrorLevel.ERROR, + ) + ) + # TODO(villebro): refactor to use same code that's used by + # sql_lab.py:execute_sql_statements + with closing(engine.raw_connection()) as conn: + with closing(conn.cursor()) as cursor: + query = self.database.apply_limit_to_sql(statements[0]) + db_engine_spec.execute(cursor, query) + result = db_engine_spec.fetch_data(cursor, limit=1) + result_set = SupersetResultSet( + result, cursor.description, db_engine_spec + ) + cols = result_set.columns + else: + db_dialect = self.database.get_dialect() + cols = self.database.get_columns( + self.table_name, schema=self.schema or None + ) + for col in cols: + try: + if isinstance(col["type"], TypeEngine): + col["type"] = db_engine_spec.column_datatype_to_string( + col["type"], 
db_dialect + ) + except CompileError: + col["type"] = "UNKNOWN" return cols @property @@ -1310,21 +1358,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at :param commit: should the changes be committed or not. :return: Tuple with lists of added, removed and modified column names. """ - try: - new_table = self.get_sqla_table_object() - except SQLAlchemyError: - raise QueryObjectValidationError( - _( - "Table %(table)s doesn't seem to exist in the specified database, " - "couldn't fetch column information", - table=self.table_name, - ) - ) - + new_columns = self.external_metadata() metrics = [] any_date_col = None db_engine_spec = self.database.db_engine_spec - db_dialect = self.database.get_dialect() old_columns = db.session.query(TableColumn).filter(TableColumn.table == self) old_columns_by_name = {col.column_name: col for col in old_columns} @@ -1332,39 +1369,31 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at removed=[ col for col in old_columns_by_name - if col not in {col.name for col in new_table.columns} + if col not in {col["name"] for col in new_columns} ] ) # clear old columns before adding modified columns back self.columns = [] - for col in new_table.columns: - try: - datatype = db_engine_spec.column_datatype_to_string( - col.type, db_dialect - ) - except Exception as ex: # pylint: disable=broad-except - datatype = "UNKNOWN" - logger.error("Unrecognized data type in %s.%s", new_table, col.name) - logger.exception(ex) - old_column = old_columns_by_name.get(col.name, None) + for col in new_columns: + old_column = old_columns_by_name.get(col["name"], None) if not old_column: - results.added.append(col.name) + results.added.append(col["name"]) new_column = TableColumn( - column_name=col.name, type=datatype, table=self + column_name=col["name"], type=col["type"], table=self ) new_column.is_dttm = new_column.is_temporal db_engine_spec.alter_new_orm_column(new_column) else: new_column = 
old_column - if new_column.type != datatype: - results.modified.append(col.name) - new_column.type = datatype + if new_column.type != col["type"]: + results.modified.append(col["name"]) + new_column.type = col["type"] new_column.groupby = True new_column.filterable = True self.columns.append(new_column) if not any_date_col and new_column.is_temporal: - any_date_col = col.name + any_date_col = col["name"] metrics.append( SqlMetric( metric_name="count", diff --git a/superset/views/datasource.py b/superset/views/datasource.py index 2ce11027c..92cf5c154 100644 --- a/superset/views/datasource.py +++ b/superset/views/datasource.py @@ -24,7 +24,7 @@ from sqlalchemy.orm.exc import NoResultFound from superset import db from superset.connectors.connector_registry import ConnectorRegistry -from superset.models.core import Database +from superset.exceptions import SupersetException from superset.typing import FlaskResponse from .base import api, BaseSupersetView, handle_api_exception, json_error_response @@ -100,21 +100,11 @@ class Datasource(BaseSupersetView): self, datasource_type: str, datasource_id: int ) -> FlaskResponse: """Gets column info from the source system""" - if datasource_type == "druid": + try: datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session ) - elif datasource_type == "table": - database = ( - db.session.query(Database).filter_by(id=request.args.get("db_id")).one() - ) - table_class = ConnectorRegistry.sources["table"] - datasource = table_class( - database=database, - table_name=request.args.get("table_name"), - schema=request.args.get("schema") or None, - ) - else: - raise Exception(f"Unsupported datasource_type: {datasource_type}") - external_metadata = datasource.external_metadata() - return self.json_response(external_metadata) + external_metadata = datasource.external_metadata() + return self.json_response(external_metadata) + except SupersetException as ex: + return json_error_response(str(ex), status=400) diff 
--git a/tests/datasource_tests.py b/tests/datasource_tests.py index 5fd81c0e1..31a4c9633 100644 --- a/tests/datasource_tests.py +++ b/tests/datasource_tests.py @@ -18,27 +18,82 @@ import json from copy import deepcopy +from superset import db +from superset.connectors.sqla.models import SqlaTable +from superset.utils.core import get_example_database + from .base_tests import SupersetTestCase from .fixtures.datasource import datasource_post class TestDatasource(SupersetTestCase): - def test_external_metadata(self): + def test_external_metadata_for_physical_table(self): self.login(username="admin") tbl = self.get_table_by_name("birth_names") - schema = tbl.schema or "" - url = ( - f"/datasource/external_metadata/table/{tbl.id}/?" - f"db_id={tbl.database.id}&" - f"table_name={tbl.table_name}&" - f"schema={schema}&" - ) + url = f"/datasource/external_metadata/table/{tbl.id}/" resp = self.get_json_resp(url) col_names = {o.get("name") for o in resp} self.assertEqual( col_names, {"sum_boys", "num", "gender", "name", "ds", "state", "sum_girls"} ) + def test_external_metadata_for_virtual_table(self): + self.login(username="admin") + session = db.session + table = SqlaTable( + table_name="dummy_sql_table", + database=get_example_database(), + sql="select 123 as intcol, 'abc' as strcol", + ) + session.add(table) + session.commit() + + table = self.get_table_by_name("dummy_sql_table") + url = f"/datasource/external_metadata/table/{table.id}/" + resp = self.get_json_resp(url) + assert {o.get("name") for o in resp} == {"intcol", "strcol"} + session.delete(table) + session.commit() + + def test_external_metadata_for_malicious_virtual_table(self): + self.login(username="admin") + session = db.session + table = SqlaTable( + table_name="malicious_sql_table", + database=get_example_database(), + sql="delete table birth_names", + ) + session.add(table) + session.commit() + + table = self.get_table_by_name("malicious_sql_table") + url = 
f"/datasource/external_metadata/table/{table.id}/" + resp = self.get_json_resp(url) + assert "error" in resp + + session.delete(table) + session.commit() + + def test_external_metadata_for_multistatement_virtual_table(self): + self.login(username="admin") + session = db.session + table = SqlaTable( + table_name="multistatement_sql_table", + database=get_example_database(), + sql="select 123 as intcol, 'abc' as strcol;" + "select 123 as intcol, 'abc' as strcol", + ) + session.add(table) + session.commit() + + table = self.get_table_by_name("multistatement_sql_table") + url = f"/datasource/external_metadata/table/{table.id}/" + resp = self.get_json_resp(url) + assert "error" in resp + + session.delete(table) + session.commit() + def compare_lists(self, l1, l2, key): l2_lookup = {o.get(key): o for o in l2} for obj1 in l1: diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index ed6ae7847..0ca8dbdb9 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -113,7 +113,7 @@ class TestImportExport(SupersetTestCase): json_metadata=json.dumps(json_metadata), ) - def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]): + def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]): params = {"remote_id": id, "database_name": "examples"} table = SqlaTable( id=id, schema=schema, table_name=name, params=json.dumps(params)