139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, cast
|
|
|
|
from sqlalchemy.orm import lazyload, load_only
|
|
|
|
from superset.commands.base import BaseCommand
|
|
from superset.commands.database.exceptions import (
|
|
DatabaseNotFoundError,
|
|
DatabaseTablesUnexpectedError,
|
|
)
|
|
from superset.connectors.sqla.models import SqlaTable
|
|
from superset.daos.database import DatabaseDAO
|
|
from superset.exceptions import SupersetException
|
|
from superset.extensions import db, security_manager
|
|
from superset.models.core import Database
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TablesDatabaseCommand(BaseCommand):
    """Build the table/view listing for one schema of a database.

    The database is resolved from its id by ``validate``; ``run`` then
    collects the table and view names of the requested catalog/schema,
    filters each list through
    ``security_manager.get_datasources_accessible_by_user`` and returns them
    as a single list sorted by name.
    """

    # Populated by ``validate``; holds whatever ``DatabaseDAO.find_by_id``
    # returned for ``_db_id``.
    _model: Database

    def __init__(
        self,
        db_id: int,
        catalog_name: str | None,
        schema_name: str,
        force: bool,
    ):
        """Remember the lookup parameters; no database access happens here.

        :param db_id: id handed to ``DatabaseDAO.find_by_id`` by ``validate``
        :param catalog_name: catalog to list, or ``None``
        :param schema_name: schema to list
        :param force: forwarded as ``force`` to the table/view name lookups
        """
        self._db_id = db_id
        self._catalog_name = catalog_name
        self._schema_name = schema_name
        self._force = force

    def run(self) -> dict[str, Any]:
        """Return the accessible tables and views of the configured schema.

        :returns: ``{"count": <int>, "result": <list>}`` where ``count`` is
            the number of accessible tables plus accessible views and
            ``result`` holds one dict per entry, sorted by ``"value"``.
            Table entries carry ``"value"``, ``"type"`` (``"table"``) and
            ``"extra"``; view entries carry only ``"value"`` and ``"type"``
            (``"view"``).
        :raises DatabaseNotFoundError: from ``validate``, which runs outside
            the ``try`` block and therefore is never wrapped
        :raises SupersetException: re-raised unchanged
        :raises DatabaseTablesUnexpectedError: wraps any other exception
            raised while building the payload
        """
        self.validate()
        try:
            database = self._model
            catalog = self._catalog_name
            schema = self._schema_name

            # Tables: fetch every name in the schema, sort, then keep only
            # the ones the security manager reports as accessible.
            table_names = sorted(
                database.get_all_table_names_in_schema(
                    catalog=catalog,
                    schema=schema,
                    force=self._force,
                    cache=database.table_cache_enabled,
                    cache_timeout=database.table_cache_timeout,
                )
            )
            tables = security_manager.get_datasources_accessible_by_user(
                database=database,
                catalog=catalog,
                schema=schema,
                datasource_names=table_names,
            )

            # Views: same procedure as for tables.
            view_names = sorted(
                database.get_all_view_names_in_schema(
                    catalog=catalog,
                    schema=schema,
                    force=self._force,
                    cache=database.table_cache_enabled,
                    cache_timeout=database.table_cache_timeout,
                )
            )
            views = security_manager.get_datasources_accessible_by_user(
                database=database,
                catalog=catalog,
                schema=schema,
                datasource_names=view_names,
            )

            # Datasets registered for this database/catalog/schema. Only the
            # listed columns are loaded, and the ``columns``/``metrics``
            # relationships are set to lazy loading.
            dataset_query = (
                db.session.query(SqlaTable)
                .filter(
                    SqlaTable.database_id == database.id,
                    SqlaTable.catalog == catalog,
                    SqlaTable.schema == schema,
                )
                .options(
                    load_only(
                        SqlaTable.catalog,
                        SqlaTable.schema,
                        SqlaTable.table_name,
                        SqlaTable.extra,
                    ),
                    lazyload(SqlaTable.columns),
                    lazyload(SqlaTable.metrics),
                )
            )
            # NOTE(review): keys come from ``SqlaTable.name`` while the lookup
            # below uses each accessible table's ``.table`` attribute; confirm
            # both yield the same string, otherwise "extra" is always None.
            extra_by_name = {
                dataset.name: dataset.extra_dict for dataset in dataset_query.all()
            }

            entries: list[dict[str, Any]] = [
                {
                    "value": table.table,
                    "type": "table",
                    "extra": extra_by_name.get(table.table),
                }
                for table in tables
            ]
            entries.extend({"value": view.table, "type": "view"} for view in views)
            # The sort is stable, so a table precedes a view of the same name.
            entries.sort(key=lambda entry: entry["value"])

            return {"count": len(tables) + len(views), "result": entries}
        except SupersetException:
            raise
        except Exception as ex:
            raise DatabaseTablesUnexpectedError(ex) from ex

    def validate(self) -> None:
        """Resolve ``_db_id`` to a database and store it on ``_model``.

        ``_model`` is assigned before the check, so it is set even when the
        lookup result is falsy.

        :raises DatabaseNotFoundError: if ``DatabaseDAO.find_by_id`` returns
            a falsy value
        """
        found = DatabaseDAO.find_by_id(self._db_id)
        self._model = cast(Database, found)
        if not found:
            raise DatabaseNotFoundError()
|