feat(sqllab): TRINO_EXPAND_ROWS: expand columns from ROWs (#25809)

This commit is contained in:
Rob Moore 2023-11-20 17:59:10 +00:00 committed by GitHub
parent 411dba240b
commit 8d73ab9955
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 268 additions and 23 deletions

View File

@ -202,7 +202,7 @@ const ExtraOptions = ({
/>
</div>
</StyledInputContainer>
<StyledInputContainer>
<StyledInputContainer css={no_margin_bottom}>
<div className="input-container">
<IndeterminateCheckbox
id="disable_data_preview"
@ -220,6 +220,22 @@ const ExtraOptions = ({
/>
</div>
</StyledInputContainer>
<StyledInputContainer>
<div className="input-container">
<IndeterminateCheckbox
id="expand_rows"
indeterminate={false}
checked={!!extraJson?.schema_options?.expand_rows}
onChange={onExtraInputChange}
labelText={t('Enable row expansion in schemas')}
/>
<InfoTooltip
tooltip={t(
'For Trino, describe full schemas of nested ROW types, expanding them with dotted paths',
)}
/>
</div>
</StyledInputContainer>
</StyledExpandableForm>
</StyledInputContainer>
</Collapse.Panel>

View File

@ -674,7 +674,7 @@ describe('DatabaseModal', () => {
const exposeInSQLLabCheckbox = screen.getByRole('checkbox', {
name: /expose database in sql lab/i,
});
// This is both the checkbox and it's respective SVG
// This is both the checkbox and its respective SVG
// const exposeInSQLLabCheckboxSVG = checkboxOffSVGs[0].parentElement;
const exposeInSQLLabText = screen.getByText(
/expose database in sql lab/i,
@ -721,6 +721,13 @@ describe('DatabaseModal', () => {
/Disable SQL Lab data preview queries/i,
);
const enableRowExpansionCheckbox = screen.getByRole('checkbox', {
name: /enable row expansion in schemas/i,
});
const enableRowExpansionText = screen.getByText(
/enable row expansion in schemas/i,
);
// ---------- Assertions ----------
const visibleComponents = [
closeButton,
@ -737,6 +744,7 @@ describe('DatabaseModal', () => {
checkboxOffSVGs[2],
checkboxOffSVGs[3],
checkboxOffSVGs[4],
checkboxOffSVGs[5],
tooltipIcons[0],
tooltipIcons[1],
tooltipIcons[2],
@ -744,6 +752,7 @@ describe('DatabaseModal', () => {
tooltipIcons[4],
tooltipIcons[5],
tooltipIcons[6],
tooltipIcons[7],
exposeInSQLLabText,
allowCTASText,
allowCVASText,
@ -754,6 +763,7 @@ describe('DatabaseModal', () => {
enableQueryCostEstimationText,
allowDbExplorationText,
disableSQLLabDataPreviewQueriesText,
enableRowExpansionText,
];
// These components exist in the DOM but are not visible
const invisibleComponents = [
@ -764,6 +774,7 @@ describe('DatabaseModal', () => {
enableQueryCostEstimationCheckbox,
allowDbExplorationCheckbox,
disableSQLLabDataPreviewQueriesCheckbox,
enableRowExpansionCheckbox,
];
visibleComponents.forEach(component => {
expect(component).toBeVisible();
@ -771,8 +782,8 @@ describe('DatabaseModal', () => {
invisibleComponents.forEach(component => {
expect(component).not.toBeVisible();
});
expect(checkboxOffSVGs).toHaveLength(5);
expect(tooltipIcons).toHaveLength(7);
expect(checkboxOffSVGs).toHaveLength(6);
expect(tooltipIcons).toHaveLength(8);
});
test('renders the "Advanced" - PERFORMANCE tab correctly', async () => {

View File

@ -307,6 +307,18 @@ export function dbReducer(
}),
};
}
if (action.payload.name === 'expand_rows') {
return {
...trimmedState,
extra: JSON.stringify({
...extraJson,
schema_options: {
...extraJson?.schema_options,
[action.payload.name]: !!action.payload.value,
},
}),
};
}
return {
...trimmedState,
extra: JSON.stringify({

View File

@ -226,5 +226,8 @@ export interface ExtraJson {
table_cache_timeout?: number; // in Performance
}; // No field, holds schema and table timeout
schemas_allowed_for_file_upload?: string[]; // in Security
schema_options?: {
expand_rows?: boolean;
};
version?: string;
}

View File

@ -51,7 +51,7 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
@ -1322,8 +1322,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return comment
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
def get_columns( # pylint: disable=unused-argument
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get all columns from a given schema and table
@ -1331,6 +1335,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param inspector: SqlAlchemy Inspector instance
:param table_name: Table name
:param schema: Schema name. If omitted, uses default schema for database
:param options: Extra options to customise the display of columns in
some databases
:return: All columns in table
"""
return convert_inspector_columns(
@ -1382,7 +1388,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
return [column(c["column_name"]) for c in cols]
return [
literal_column(query_as)
if (query_as := c.get("query_as"))
else column(c["column_name"])
for c in cols
]
@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals

View File

@ -23,14 +23,12 @@ from datetime import datetime
from typing import Any, TYPE_CHECKING
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from superset import is_feature_enabled
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.exceptions import SupersetException
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
if TYPE_CHECKING:
@ -130,15 +128,6 @@ class DruidEngineSpec(BaseEngineSpec):
"""
return "MILLIS_TO_TIMESTAMP({col})"
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[ResultSetColumnType]:
"""
Update the Druid type map.
"""
return super().get_columns(inspector, table_name, schema)
@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel

View File

@ -410,9 +410,13 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
return BaseEngineSpec.get_columns(inspector, table_name, schema)
return BaseEngineSpec.get_columns(inspector, table_name, schema, options)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments

View File

@ -981,7 +981,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get columns from a Presto data source. This includes handling row and
@ -989,6 +993,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
:param inspector: object that performs database schema inspection
:param table_name: table name
:param schema: schema name
:param options: Extra configuration options, not used by this backend
:return: a list of results that contain column info
(i.e. column name and data type)
"""

View File

@ -24,8 +24,10 @@ from typing import Any, TYPE_CHECKING
import simplejson as json
from flask import current_app
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from trino.sqlalchemy import datatype
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
@ -33,6 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
if TYPE_CHECKING:
@ -331,3 +334,62 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return {
requests_exceptions.ConnectionError: SupersetDBAPIConnectionError,
}
@classmethod
def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]:
    """
    Recursively expand a column into itself plus any nested ROW fields.

    Only named fields inside ROW types can be navigated this way, so MAP
    and ARRAY types — and unnamed ROW fields, which the trino library
    doesn't parse correctly — are left unexpanded, as is anything nested
    beneath them.

    Expanded columns are named by their dotted path (foo.bar.baz) and
    carry a ``query_as`` expression that quotes each path segment
    individually ("foo"."bar"."baz") and aliases the result back to the
    full dotted name, so the base engine spec can query them correctly.
    """
    expanded: list[ResultSetColumnType] = [col]
    current_type = col.get("type")
    if not isinstance(current_type, datatype.ROW):
        return expanded

    # The parent name is the same for every inner field; hoist it.
    parent_name = col["name"]
    for field_name, field_type in current_type.attr_types:
        dotted = ".".join([parent_name, field_name])
        quoted = ".".join(f'"{segment}"' for segment in dotted.split("."))
        spec = cls.get_column_spec(str(field_type))
        child = ResultSetColumnType(
            name=dotted,
            column_name=dotted,
            type=field_type,
            is_dttm=spec.is_dttm if spec else False,
            query_as=f'{quoted} AS "{dotted}"',
        )
        # Depth-first: a nested ROW field contributes its own children.
        expanded.extend(cls._expand_columns(child))
    return expanded
@classmethod
def get_columns(
    cls,
    inspector: Inspector,
    table_name: str,
    schema: str | None,
    options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
    """
    Get the columns of a table, optionally expanding nested ROW types.

    If the "expand_rows" feature is enabled on the database via
    "schema_options", every ROW column is followed by pseudo-columns for
    each of its nested subfields, named by their dotted paths.
    """
    base_cols = super().get_columns(inspector, table_name, schema, options)
    if not (options or {}).get("expand_rows"):
        return base_cols

    expanded: list[ResultSetColumnType] = []
    for base_col in base_cols:
        expanded.extend(cls._expand_columns(base_col))
    return expanded

View File

@ -237,6 +237,11 @@ class Database(
# this will prevent any 'trash value' strings from going through
return self.get_extra().get("disable_data_preview", False) is True
@property
def schema_options(self) -> dict[str, Any]:
    """Additional schema display config for engines with complex schemas"""
    extra = self.get_extra()
    return extra.get("schema_options", {})
@property
def data(self) -> dict[str, Any]:
return {
@ -248,6 +253,7 @@ class Database(
"allows_cost_estimate": self.allows_cost_estimate,
"allows_virtual_table_explore": self.allows_virtual_table_explore,
"explore_database_id": self.explore_database_id,
"schema_options": self.schema_options,
"parameters": self.parameters,
"disable_data_preview": self.disable_data_preview,
"parameters_schema": self.parameters_schema,
@ -838,7 +844,9 @@ class Database(
self, table_name: str, schema: str | None = None
) -> list[ResultSetColumnType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)
return self.db_engine_spec.get_columns(
inspector, table_name, schema, self.schema_options
)
def get_metrics(
self,

View File

@ -84,6 +84,8 @@ class ResultSetColumnType(TypedDict):
scale: NotRequired[Any]
max_length: NotRequired[Any]
query_as: NotRequired[Any]
CacheConfig = dict[str, Any]
DbapiDescriptionRow = tuple[

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import copy
import json
from datetime import datetime
from typing import Any, Optional
@ -24,9 +25,11 @@ import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from trino.sqlalchemy import datatype
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@ -35,6 +38,24 @@ from tests.unit_tests.db_engine_specs.utils import (
from tests.unit_tests.fixtures.common import dttm
def _assert_columns_equal(actual_cols, expected_cols) -> None:
"""
Assert equality of the given cols, bearing in mind sqlalchemy type
instances can't be compared for equality, so will have to be converted to
strings first.
"""
actual = copy.deepcopy(actual_cols)
expected = copy.deepcopy(expected_cols)
for col in actual:
col["type"] = str(col["type"])
for col in expected:
col["type"] = str(col["type"])
assert actual == expected
@pytest.mark.parametrize(
"extra,expected",
[
@ -395,3 +416,104 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
def test_get_columns(mocker: MockerFixture):
    """Test that ROW columns are not expanded without expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    # Map each column name to its parsed Trino type; insertion order is
    # the order the inspector reports them in.
    col_types = {
        "field1": datatype.parse_sqltype("row(a varchar, b date)"),
        "field2": datatype.parse_sqltype("row(r1 row(a varchar, b varchar))"),
        "field3": datatype.parse_sqltype("int"),
    }

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]

    actual = TrinoEngineSpec.get_columns(inspector, "table", "schema")

    # Without options, every column passes through unexpanded.
    expected = [
        ResultSetColumnType(name=name, column_name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]
    _assert_columns_equal(actual, expected)
def test_get_columns_expand_rows(mocker: MockerFixture):
    """Test that ROW columns are correctly expanded with expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name=name, type=type_, is_dttm=False)
        for name, type_ in (
            ("field1", field1_type),
            ("field2", field2_type),
            ("field3", field3_type),
        )
    ]

    actual = TrinoEngineSpec.get_columns(
        inspector, "table", "schema", {"expand_rows": True}
    )

    # Expected columns as (name, type, is_dttm, query_as) rows; base
    # (unexpanded) columns carry no query_as.  Expansion is depth-first,
    # so nested ROW subfields follow their parent immediately.
    expected = []
    for name, type_, is_dttm, query_as in (
        ("field1", field1_type, False, None),
        ("field1.a", types.VARCHAR(), False, '"field1"."a" AS "field1.a"'),
        ("field1.b", types.DATE(), True, '"field1"."b" AS "field1.b"'),
        ("field2", field2_type, False, None),
        (
            "field2.r1",
            datatype.parse_sqltype("row(a varchar, b varchar)"),
            False,
            '"field2"."r1" AS "field2.r1"',
        ),
        ("field2.r1.a", types.VARCHAR(), False, '"field2"."r1"."a" AS "field2.r1.a"'),
        ("field2.r1.b", types.VARCHAR(), False, '"field2"."r1"."b" AS "field2.r1.b"'),
        ("field3", field3_type, False, None),
    ):
        col = ResultSetColumnType(
            name=name, column_name=name, type=type_, is_dttm=is_dttm
        )
        if query_as is not None:
            col["query_as"] = query_as
        expected.append(col)

    _assert_columns_equal(actual, expected)