[bugfix] Correctly quote table and schema in select_star (#8181)

* Fix select_star table quoting bug

* Add unit test for fully qualified names in select_star

* Rename main_db to db

* Rename test function
This commit is contained in:
Ville Brofeldt 2019-09-05 22:44:34 +03:00 committed by GitHub
parent 8f071e8f7e
commit 4e1e54b538
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 4 deletions

View File

@ -478,7 +478,7 @@ class SqlaTable(Model, BaseDatasource):
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
self.name, show_cols=False, latest_partition=False
self.table_name, schema=self.schema, show_cols=False, latest_partition=False
)
def get_col(self, col_name):

View File

@ -104,9 +104,9 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertNotEquals(example_user, user_name)
def test_select_star(self):
main_db = get_example_database()
db = get_example_database()
table_name = "energy_usage"
sql = main_db.select_star(table_name, show_cols=False, latest_partition=False)
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
expected = textwrap.dedent(
f"""\
SELECT *
@ -115,7 +115,7 @@ class DatabaseModelTestCase(SupersetTestCase):
)
assert sql.startswith(expected)
sql = main_db.select_star(table_name, show_cols=True, latest_partition=False)
sql = db.select_star(table_name, show_cols=True, latest_partition=False)
expected = textwrap.dedent(
f"""\
SELECT source,
@ -126,6 +126,28 @@ class DatabaseModelTestCase(SupersetTestCase):
)
assert sql.startswith(expected)
def test_select_star_fully_qualified_names(self):
db = get_example_database()
schema = "schema.name"
table_name = "table/name"
sql = db.select_star(
table_name, schema=schema, show_cols=False, latest_partition=False
)
fully_qualified_names = {
"sqlite": '"schema.name"."table/name"',
"mysql": "`schema.name`.`table/name`",
"postgres": '"schema.name"."table/name"',
}
fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine)
if fully_qualified_name:
expected = textwrap.dedent(
f"""\
SELECT *
FROM {fully_qualified_name}
LIMIT 100"""
)
assert sql.startswith(expected)
def test_single_statement(self):
main_db = get_main_database()