feat(sqllab): TRINO_EXPAND_ROWS: expand columns from ROWs (#25809)
This commit is contained in:
parent
411dba240b
commit
8d73ab9955
|
|
@ -202,7 +202,7 @@ const ExtraOptions = ({
|
|||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
<StyledInputContainer>
|
||||
<StyledInputContainer css={no_margin_bottom}>
|
||||
<div className="input-container">
|
||||
<IndeterminateCheckbox
|
||||
id="disable_data_preview"
|
||||
|
|
@ -220,6 +220,22 @@ const ExtraOptions = ({
|
|||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
<StyledInputContainer>
|
||||
<div className="input-container">
|
||||
<IndeterminateCheckbox
|
||||
id="expand_rows"
|
||||
indeterminate={false}
|
||||
checked={!!extraJson?.schema_options?.expand_rows}
|
||||
onChange={onExtraInputChange}
|
||||
labelText={t('Enable row expansion in schemas')}
|
||||
/>
|
||||
<InfoTooltip
|
||||
tooltip={t(
|
||||
'For Trino, describe full schemas of nested ROW types, expanding them with dotted paths',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
</StyledExpandableForm>
|
||||
</StyledInputContainer>
|
||||
</Collapse.Panel>
|
||||
|
|
|
|||
|
|
@ -674,7 +674,7 @@ describe('DatabaseModal', () => {
|
|||
const exposeInSQLLabCheckbox = screen.getByRole('checkbox', {
|
||||
name: /expose database in sql lab/i,
|
||||
});
|
||||
// This is both the checkbox and it's respective SVG
|
||||
// This is both the checkbox and its respective SVG
|
||||
// const exposeInSQLLabCheckboxSVG = checkboxOffSVGs[0].parentElement;
|
||||
const exposeInSQLLabText = screen.getByText(
|
||||
/expose database in sql lab/i,
|
||||
|
|
@ -721,6 +721,13 @@ describe('DatabaseModal', () => {
|
|||
/Disable SQL Lab data preview queries/i,
|
||||
);
|
||||
|
||||
const enableRowExpansionCheckbox = screen.getByRole('checkbox', {
|
||||
name: /enable row expansion in schemas/i,
|
||||
});
|
||||
const enableRowExpansionText = screen.getByText(
|
||||
/enable row expansion in schemas/i,
|
||||
);
|
||||
|
||||
// ---------- Assertions ----------
|
||||
const visibleComponents = [
|
||||
closeButton,
|
||||
|
|
@ -737,6 +744,7 @@ describe('DatabaseModal', () => {
|
|||
checkboxOffSVGs[2],
|
||||
checkboxOffSVGs[3],
|
||||
checkboxOffSVGs[4],
|
||||
checkboxOffSVGs[5],
|
||||
tooltipIcons[0],
|
||||
tooltipIcons[1],
|
||||
tooltipIcons[2],
|
||||
|
|
@ -744,6 +752,7 @@ describe('DatabaseModal', () => {
|
|||
tooltipIcons[4],
|
||||
tooltipIcons[5],
|
||||
tooltipIcons[6],
|
||||
tooltipIcons[7],
|
||||
exposeInSQLLabText,
|
||||
allowCTASText,
|
||||
allowCVASText,
|
||||
|
|
@ -754,6 +763,7 @@ describe('DatabaseModal', () => {
|
|||
enableQueryCostEstimationText,
|
||||
allowDbExplorationText,
|
||||
disableSQLLabDataPreviewQueriesText,
|
||||
enableRowExpansionText,
|
||||
];
|
||||
// These components exist in the DOM but are not visible
|
||||
const invisibleComponents = [
|
||||
|
|
@ -764,6 +774,7 @@ describe('DatabaseModal', () => {
|
|||
enableQueryCostEstimationCheckbox,
|
||||
allowDbExplorationCheckbox,
|
||||
disableSQLLabDataPreviewQueriesCheckbox,
|
||||
enableRowExpansionCheckbox,
|
||||
];
|
||||
visibleComponents.forEach(component => {
|
||||
expect(component).toBeVisible();
|
||||
|
|
@ -771,8 +782,8 @@ describe('DatabaseModal', () => {
|
|||
invisibleComponents.forEach(component => {
|
||||
expect(component).not.toBeVisible();
|
||||
});
|
||||
expect(checkboxOffSVGs).toHaveLength(5);
|
||||
expect(tooltipIcons).toHaveLength(7);
|
||||
expect(checkboxOffSVGs).toHaveLength(6);
|
||||
expect(tooltipIcons).toHaveLength(8);
|
||||
});
|
||||
|
||||
test('renders the "Advanced" - PERFORMANCE tab correctly', async () => {
|
||||
|
|
|
|||
|
|
@ -307,6 +307,18 @@ export function dbReducer(
|
|||
}),
|
||||
};
|
||||
}
|
||||
if (action.payload.name === 'expand_rows') {
|
||||
return {
|
||||
...trimmedState,
|
||||
extra: JSON.stringify({
|
||||
...extraJson,
|
||||
schema_options: {
|
||||
...extraJson?.schema_options,
|
||||
[action.payload.name]: !!action.payload.value,
|
||||
},
|
||||
}),
|
||||
};
|
||||
}
|
||||
return {
|
||||
...trimmedState,
|
||||
extra: JSON.stringify({
|
||||
|
|
|
|||
|
|
@ -226,5 +226,8 @@ export interface ExtraJson {
|
|||
table_cache_timeout?: number; // in Performance
|
||||
}; // No field, holds schema and table timeout
|
||||
schemas_allowed_for_file_upload?: string[]; // in Security
|
||||
schema_options?: {
|
||||
expand_rows?: boolean;
|
||||
};
|
||||
version?: string;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ from sqlalchemy.engine.reflection import Inspector
|
|||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql import literal_column, quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from sqlparse.tokens import CTE
|
||||
|
|
@ -1322,8 +1322,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return comment
|
||||
|
||||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
def get_columns( # pylint: disable=unused-argument
|
||||
cls,
|
||||
inspector: Inspector,
|
||||
table_name: str,
|
||||
schema: str | None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Get all columns from a given schema and table
|
||||
|
|
@ -1331,6 +1335,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param inspector: SqlAlchemy Inspector instance
|
||||
:param table_name: Table name
|
||||
:param schema: Schema name. If omitted, uses default schema for database
|
||||
:param options: Extra options to customise the display of columns in
|
||||
some databases
|
||||
:return: All columns in table
|
||||
"""
|
||||
return convert_inspector_columns(
|
||||
|
|
@ -1382,7 +1388,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
|
||||
return [column(c["column_name"]) for c in cols]
|
||||
return [
|
||||
literal_column(query_as)
|
||||
if (query_as := c.get("query_as"))
|
||||
else column(c["column_name"])
|
||||
for c in cols
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments,too-many-locals
|
||||
|
|
|
|||
|
|
@ -23,14 +23,12 @@ from datetime import datetime
|
|||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -130,15 +128,6 @@ class DruidEngineSpec(BaseEngineSpec):
|
|||
"""
|
||||
return "MILLIS_TO_TIMESTAMP({col})"
|
||||
|
||||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Update the Druid type map.
|
||||
"""
|
||||
return super().get_columns(inspector, table_name, schema)
|
||||
|
||||
@classmethod
|
||||
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
|
|
|||
|
|
@ -410,9 +410,13 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
cls,
|
||||
inspector: Inspector,
|
||||
table_name: str,
|
||||
schema: str | None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> list[ResultSetColumnType]:
|
||||
return BaseEngineSpec.get_columns(inspector, table_name, schema)
|
||||
return BaseEngineSpec.get_columns(inspector, table_name, schema, options)
|
||||
|
||||
@classmethod
|
||||
def where_latest_partition( # pylint: disable=too-many-arguments
|
||||
|
|
|
|||
|
|
@ -981,7 +981,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
cls,
|
||||
inspector: Inspector,
|
||||
table_name: str,
|
||||
schema: str | None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Get columns from a Presto data source. This includes handling row and
|
||||
|
|
@ -989,6 +993,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
:param inspector: object that performs database schema inspection
|
||||
:param table_name: table name
|
||||
:param schema: schema name
|
||||
:param options: Extra configuration options, not used by this backend
|
||||
:return: a list of results that contain column info
|
||||
(i.e. column name and data type)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ from typing import Any, TYPE_CHECKING
|
|||
|
||||
import simplejson as json
|
||||
from flask import current_app
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import Session
|
||||
from trino.sqlalchemy import datatype
|
||||
|
||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
|
@ -33,6 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
|
|||
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
||||
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -331,3 +334,62 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
return {
|
||||
requests_exceptions.ConnectionError: SupersetDBAPIConnectionError,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]:
    """
    Expand the given column out to one or more columns by analysing their types,
    descending into ROWs and expanding out their inner fields recursively.

    We can only navigate named fields in ROWs in this way, so we can't expand out
    MAP or ARRAY types, nor fields in ROWs which have no name (in fact the trino
    library doesn't correctly parse unnamed fields in ROWs). We won't be able to
    expand ROWs which are nested underneath any of those types, either.

    Expanded columns are named foo.bar.baz and we provide a query_as property to
    instruct the base engine spec how to correctly query them: instead of quoting
    the whole string they have to be quoted like "foo"."bar"."baz" and we then
    alias them to the full dotted string for ease of reference.

    The quoted path is built from the actual chain of identifiers rather than by
    splitting the dotted display name, so a top-level column whose own name
    contains a dot is still addressed as a single identifier. Double quotes
    embedded in an identifier are doubled so the generated fragment stays valid.

    :param col: the column to expand; returned unchanged as the first element
    :return: ``col`` followed by its expanded sub-columns, depth-first in
        declaration order
    """

    def quote(identifier: str) -> str:
        # query_as is emitted verbatim (no further quoting is applied to it),
        # so embedded double quotes must be escaped here.
        return '"' + identifier.replace('"', '""') + '"'

    def expand(
        current: ResultSetColumnType, path: list[str]
    ) -> list[ResultSetColumnType]:
        cols = [current]
        col_type = current.get("type")

        if not isinstance(col_type, datatype.ROW):
            return cols

        for inner_name, inner_type in col_type.attr_types:
            inner_path = [*path, inner_name]
            name = ".".join(inner_path)
            query_name = ".".join(quote(piece) for piece in inner_path)
            column_spec = cls.get_column_spec(str(inner_type))
            is_dttm = column_spec.is_dttm if column_spec else False

            inner_col = ResultSetColumnType(
                name=name,
                column_name=name,
                type=inner_type,
                is_dttm=is_dttm,
                query_as=f"{query_name} AS {quote(name)}",
            )
            cols.extend(expand(inner_col, inner_path))

        return cols

    # Non-ROW columns are returned untouched without reading their name.
    if not isinstance(col.get("type"), datatype.ROW):
        return [col]

    return expand(col, [col["name"]])
|
||||
|
||||
@classmethod
|
||||
def get_columns(
    cls,
    inspector: Inspector,
    table_name: str,
    schema: str | None,
    options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
    """
    Fetch the table's columns from the parent implementation and, when the
    "expand_rows" key of ``options`` is truthy, follow each column with the
    dotted-path sub-columns produced by ``_expand_columns``.

    :param inspector: SqlAlchemy Inspector instance
    :param table_name: Table name
    :param schema: Schema name
    :param options: Extra options; only "expand_rows" is read here
    :return: the parent's columns, expanded if requested
    """
    columns = super().get_columns(inspector, table_name, schema, options)

    expand_rows = options.get("expand_rows") if options else None
    if expand_rows:
        expanded: list[ResultSetColumnType] = []
        for column in columns:
            expanded.extend(cls._expand_columns(column))
        return expanded

    return columns
|
||||
|
|
|
|||
|
|
@ -237,6 +237,11 @@ class Database(
|
|||
# this will prevent any 'trash value' strings from going through
|
||||
return self.get_extra().get("disable_data_preview", False) is True
|
||||
|
||||
@property
|
||||
def schema_options(self) -> dict[str, Any]:
    """
    Additional schema display config for engines with complex schemas.

    Read from the "schema_options" key of the database's extra config. Any
    value that is not a dict (missing key, null, stray string, ...) yields an
    empty dict so callers can always treat the result as a mapping.
    """
    options = self.get_extra().get("schema_options")
    return options if isinstance(options, dict) else {}
|
||||
|
||||
@property
|
||||
def data(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
@ -248,6 +253,7 @@ class Database(
|
|||
"allows_cost_estimate": self.allows_cost_estimate,
|
||||
"allows_virtual_table_explore": self.allows_virtual_table_explore,
|
||||
"explore_database_id": self.explore_database_id,
|
||||
"schema_options": self.schema_options,
|
||||
"parameters": self.parameters,
|
||||
"disable_data_preview": self.disable_data_preview,
|
||||
"parameters_schema": self.parameters_schema,
|
||||
|
|
@ -838,7 +844,9 @@ class Database(
|
|||
self, table_name: str, schema: str | None = None
|
||||
) -> list[ResultSetColumnType]:
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return self.db_engine_spec.get_columns(inspector, table_name, schema)
|
||||
return self.db_engine_spec.get_columns(
|
||||
inspector, table_name, schema, self.schema_options
|
||||
)
|
||||
|
||||
def get_metrics(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class ResultSetColumnType(TypedDict):
|
|||
scale: NotRequired[Any]
|
||||
max_length: NotRequired[Any]
|
||||
|
||||
query_as: NotRequired[Any]
|
||||
|
||||
|
||||
CacheConfig = dict[str, Any]
|
||||
DbapiDescriptionRow = tuple[
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
import copy
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
|
@ -24,9 +25,11 @@ import pandas as pd
|
|||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import types
|
||||
from trino.sqlalchemy import datatype
|
||||
|
||||
import superset.config
|
||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
assert_column_spec,
|
||||
|
|
@ -35,6 +38,24 @@ from tests.unit_tests.db_engine_specs.utils import (
|
|||
from tests.unit_tests.fixtures.common import dttm
|
||||
|
||||
|
||||
def _assert_columns_equal(actual_cols, expected_cols) -> None:
|
||||
"""
|
||||
Assert equality of the given cols, bearing in mind sqlalchemy type
|
||||
instances can't be compared for equality, so will have to be converted to
|
||||
strings first.
|
||||
"""
|
||||
actual = copy.deepcopy(actual_cols)
|
||||
expected = copy.deepcopy(expected_cols)
|
||||
|
||||
for col in actual:
|
||||
col["type"] = str(col["type"])
|
||||
|
||||
for col in expected:
|
||||
col["type"] = str(col["type"])
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extra,expected",
|
||||
[
|
||||
|
|
@ -395,3 +416,104 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
|
|||
mock_query.set_extra_json_key.assert_called_once_with(
|
||||
key=QUERY_CANCEL_KEY, value=query_id
|
||||
)
|
||||
|
||||
|
||||
def test_get_columns(mocker: MockerFixture):
    """Without expand_rows, ROW columns come back as single unexpanded columns"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    # Insertion order of this dict is the column order seen by the inspector.
    column_types = {
        "field1": datatype.parse_sqltype("row(a varchar, b date)"),
        "field2": datatype.parse_sqltype("row(r1 row(a varchar, b varchar))"),
        "field3": datatype.parse_sqltype("int"),
    }

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name=name, type=sqla_type, is_dttm=False)
        for name, sqla_type in column_types.items()
    ]

    result = TrinoEngineSpec.get_columns(inspector, "table", "schema")

    wanted = [
        ResultSetColumnType(
            name=name, column_name=name, type=sqla_type, is_dttm=False
        )
        for name, sqla_type in column_types.items()
    ]
    _assert_columns_equal(result, wanted)
|
||||
|
||||
|
||||
def test_get_columns_expand_rows(mocker: MockerFixture):
    """With expand_rows, each named ROW field is listed under its dotted path"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    row_of_scalars = datatype.parse_sqltype("row(a varchar, b date)")
    row_of_row = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    plain_int = datatype.parse_sqltype("int")

    def top_level(name, sqla_type):
        # Columns reported by the inspector carry no query_as.
        return ResultSetColumnType(
            name=name, column_name=name, type=sqla_type, is_dttm=False
        )

    def nested(name, sqla_type, query_as, is_dttm=False):
        # Expanded sub-columns carry the quoted path plus alias in query_as.
        return ResultSetColumnType(
            name=name,
            column_name=name,
            type=sqla_type,
            is_dttm=is_dttm,
            query_as=query_as,
        )

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name="field1", type=row_of_scalars, is_dttm=False),
        SQLAColumnType(name="field2", type=row_of_row, is_dttm=False),
        SQLAColumnType(name="field3", type=plain_int, is_dttm=False),
    ]

    result = TrinoEngineSpec.get_columns(
        inspector, "table", "schema", {"expand_rows": True}
    )

    wanted = [
        top_level("field1", row_of_scalars),
        nested("field1.a", types.VARCHAR(), '"field1"."a" AS "field1.a"'),
        nested(
            "field1.b", types.DATE(), '"field1"."b" AS "field1.b"', is_dttm=True
        ),
        top_level("field2", row_of_row),
        nested(
            "field2.r1",
            datatype.parse_sqltype("row(a varchar, b varchar)"),
            '"field2"."r1" AS "field2.r1"',
        ),
        nested(
            "field2.r1.a", types.VARCHAR(), '"field2"."r1"."a" AS "field2.r1.a"'
        ),
        nested(
            "field2.r1.b", types.VARCHAR(), '"field2"."r1"."b" AS "field2.r1.b"'
        ),
        top_level("field3", plain_int),
    ]
    _assert_columns_equal(result, wanted)
|
||||
|
|
|
|||
Loading…
Reference in New Issue