[sqllab] proper, quoted, select * on the server side (#1404)
* [sqllab] proper, quoted, select * on the server side * fixing tests
This commit is contained in:
parent
4f886d65ec
commit
63161b11c3
|
|
@ -7,6 +7,4 @@ export const STATE_BSSTYLE_MAP = {
|
|||
success: 'success',
|
||||
};
|
||||
|
||||
export const DATA_PREVIEW_ROW_COUNT = 100;
|
||||
|
||||
export const STATUS_OPTIONS = ['success', 'failed', 'running'];
|
||||
|
|
|
|||
|
|
@ -82,15 +82,12 @@ class SqlEditorLeftBar extends React.Component {
|
|||
|
||||
this.setState({ tableLoading: true });
|
||||
$.get(url, (data) => {
|
||||
this.props.actions.mergeTable({
|
||||
this.props.actions.mergeTable(Object.assign(data, {
|
||||
dbId: this.props.queryEditor.dbId,
|
||||
queryEditorId: this.props.queryEditor.id,
|
||||
name: data.name,
|
||||
indexes: data.indexes,
|
||||
schema: qe.schema,
|
||||
columns: data.columns,
|
||||
expanded: true,
|
||||
});
|
||||
}));
|
||||
this.setState({ tableLoading: false });
|
||||
})
|
||||
.fail(() => {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import * as Actions from '../actions';
|
|||
import { ButtonGroup, Well } from 'react-bootstrap';
|
||||
import shortid from 'shortid';
|
||||
|
||||
import { DATA_PREVIEW_ROW_COUNT } from '../common';
|
||||
import CopyToClipboard from '../../components/CopyToClipboard';
|
||||
import Link from './Link';
|
||||
import ModalTrigger from '../../components/ModalTrigger';
|
||||
|
|
@ -23,33 +22,6 @@ const defaultProps = {
|
|||
};
|
||||
|
||||
class TableElement extends React.Component {
|
||||
setSelectStar() {
|
||||
this.props.actions.queryEditorSetSql(this.props.queryEditor, this.selectStar());
|
||||
}
|
||||
|
||||
selectStar(useStar = false, limit = 0) {
|
||||
let cols = '';
|
||||
this.props.table.columns.forEach((col, i) => {
|
||||
cols += col.name;
|
||||
if (i < this.props.table.columns.length - 1) {
|
||||
cols += ', ';
|
||||
}
|
||||
});
|
||||
let tableName = this.props.table.name;
|
||||
if (this.props.table.schema) {
|
||||
tableName = this.props.table.schema + '.' + tableName;
|
||||
}
|
||||
let sql;
|
||||
if (useStar) {
|
||||
sql = `SELECT * FROM ${tableName}`;
|
||||
} else {
|
||||
sql = `SELECT ${cols}\nFROM ${tableName}`;
|
||||
}
|
||||
if (limit > 0) {
|
||||
sql += `\nLIMIT ${limit}`;
|
||||
}
|
||||
return sql;
|
||||
}
|
||||
|
||||
popSelectStar() {
|
||||
const qe = {
|
||||
|
|
@ -57,7 +29,7 @@ class TableElement extends React.Component {
|
|||
title: this.props.table.name,
|
||||
dbId: this.props.table.dbId,
|
||||
autorun: true,
|
||||
sql: this.selectStar(),
|
||||
sql: this.props.table.selectStar,
|
||||
};
|
||||
this.props.actions.addQueryEditor(qe);
|
||||
}
|
||||
|
|
@ -78,7 +50,7 @@ class TableElement extends React.Component {
|
|||
dataPreviewModal() {
|
||||
const query = {
|
||||
dbId: this.props.queryEditor.dbId,
|
||||
sql: this.selectStar(true, DATA_PREVIEW_ROW_COUNT),
|
||||
sql: this.props.table.selectStar,
|
||||
tableName: this.props.table.name,
|
||||
sqlEditorId: null,
|
||||
tab: '',
|
||||
|
|
@ -208,7 +180,7 @@ class TableElement extends React.Component {
|
|||
copyNode={
|
||||
<a className="fa fa-clipboard pull-left m-l-2" />
|
||||
}
|
||||
text={this.selectStar()}
|
||||
text={table.selectStar}
|
||||
shouldShowText={false}
|
||||
tooltipText="Copy SELECT statement to clipboard"
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -689,6 +689,12 @@ class Database(Model, AuditMixinNullable):
|
|||
url.database = schema
|
||||
return create_engine(url, **params)
|
||||
|
||||
def get_reserved_words(self):
|
||||
return self.get_sqla_engine().dialect.preparer.reserved_words
|
||||
|
||||
def get_quoter(self):
|
||||
return self.get_sqla_engine().dialect.identifier_preparer.quote
|
||||
|
||||
def get_df(self, sql, schema):
|
||||
eng = self.get_sqla_engine(schema=schema)
|
||||
cur = eng.execute(sql, schema=schema)
|
||||
|
|
@ -701,12 +707,26 @@ class Database(Model, AuditMixinNullable):
|
|||
compiled = qry.compile(eng, compile_kwargs={"literal_binds": True})
|
||||
return '{}'.format(compiled)
|
||||
|
||||
def select_star(self, table_name, schema=None, limit=1000):
|
||||
def select_star(
|
||||
self, table_name, schema=None, limit=100, show_cols=False,
|
||||
indent=True):
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
qry = select('*').select_from(text(table_name))
|
||||
for i in range(10):
|
||||
print(schema)
|
||||
quote = self.get_quoter()
|
||||
fields = '*'
|
||||
table = self.get_table(table_name, schema=schema)
|
||||
if show_cols:
|
||||
fields = [quote(c.name) for c in table.columns]
|
||||
if schema:
|
||||
table_name = schema + '.' + table_name
|
||||
qry = select(fields).select_from(text(table_name))
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
return self.compile_sqla_query(qry)
|
||||
sql = self.compile_sqla_query(qry)
|
||||
if indent:
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
return sql
|
||||
|
||||
def wrap_sql_limit(self, sql, limit=1000):
|
||||
qry = (
|
||||
|
|
|
|||
|
|
@ -1968,6 +1968,8 @@ class Caravel(BaseCaravelView):
|
|||
tbl = {
|
||||
'name': table_name,
|
||||
'columns': cols,
|
||||
'selectStar': mydb.select_star(
|
||||
table_name, schema=schema, show_cols=True, indent=True),
|
||||
'indexes': indexes,
|
||||
}
|
||||
return Response(json.dumps(tbl), mimetype="application/json")
|
||||
|
|
@ -1988,6 +1990,7 @@ class Caravel(BaseCaravelView):
|
|||
def select_star(self, database_id, table_name):
|
||||
mydb = db.session.query(
|
||||
models.Database).filter_by(id=database_id).first()
|
||||
quote = mydb.get_quoter()
|
||||
t = mydb.get_table(table_name)
|
||||
|
||||
# Prevent exposing column fields to users that cannot access DB.
|
||||
|
|
@ -1996,7 +1999,7 @@ class Caravel(BaseCaravelView):
|
|||
return redirect("/tablemodelview/list/")
|
||||
|
||||
fields = ", ".join(
|
||||
[c.name for c in t.columns] or "*")
|
||||
[quote(c.name) for c in t.columns] or "*")
|
||||
s = "SELECT\n{}\nFROM {}".format(fields, table_name)
|
||||
return self.render_template(
|
||||
"caravel/ajah.html",
|
||||
|
|
|
|||
|
|
@ -125,13 +125,13 @@ class CeleryTestCase(CaravelTestCase):
|
|||
shell=True
|
||||
)
|
||||
|
||||
def run_sql(self, dbid, sql, client_id, cta='false', tmp_table='tmp',
|
||||
def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp',
|
||||
async='false'):
|
||||
self.login()
|
||||
resp = self.client.post(
|
||||
'/caravel/sql_json/',
|
||||
data=dict(
|
||||
database_id=dbid,
|
||||
database_id=db_id,
|
||||
sql=sql,
|
||||
async=async,
|
||||
select_as_cta=cta,
|
||||
|
|
@ -144,12 +144,11 @@ class CeleryTestCase(CaravelTestCase):
|
|||
|
||||
def test_add_limit_to_the_query(self):
|
||||
session = db.session
|
||||
db_to_query = session.query(models.Database).filter_by(
|
||||
id=1).first()
|
||||
eng = db_to_query.get_sqla_engine()
|
||||
main_db = self.get_main_database(db.session)
|
||||
eng = main_db.get_sqla_engine()
|
||||
|
||||
select_query = "SELECT * FROM outer_space;"
|
||||
updated_select_query = db_to_query.wrap_sql_limit(select_query, 100)
|
||||
updated_select_query = main_db.wrap_sql_limit(select_query, 100)
|
||||
# Different DB engines have their own spacing while compiling
|
||||
# the queries, that's why ' '.join(query.split()) is used.
|
||||
# In addition some of the engines do not include OFFSET 0.
|
||||
|
|
@ -159,7 +158,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
)
|
||||
|
||||
select_query_no_semicolon = "SELECT * FROM outer_space"
|
||||
updated_select_query_no_semicolon = db_to_query.wrap_sql_limit(
|
||||
updated_select_query_no_semicolon = main_db.wrap_sql_limit(
|
||||
select_query_no_semicolon, 100)
|
||||
self.assertTrue(
|
||||
"SELECT * FROM (SELECT * FROM outer_space) AS inner_qry "
|
||||
|
|
@ -170,7 +169,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
multi_line_query = (
|
||||
"SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';"
|
||||
)
|
||||
updated_multi_line_query = db_to_query.wrap_sql_limit(multi_line_query, 100)
|
||||
updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100)
|
||||
self.assertTrue(
|
||||
"SELECT * FROM (SELECT * FROM planets WHERE "
|
||||
"Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in
|
||||
|
|
@ -178,21 +177,21 @@ class CeleryTestCase(CaravelTestCase):
|
|||
)
|
||||
|
||||
def test_run_sync_query(self):
|
||||
main_db = db.session.query(models.Database).filter_by(
|
||||
database_name="main").first()
|
||||
main_db = self.get_main_database(db.session)
|
||||
eng = main_db.get_sqla_engine()
|
||||
|
||||
db_id = main_db.id
|
||||
# Case 1.
|
||||
# Table doesn't exist.
|
||||
sql_dont_exist = 'SELECT name FROM table_dont_exist'
|
||||
result1 = self.run_sql(1, sql_dont_exist, "1", cta='true')
|
||||
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta='true')
|
||||
self.assertTrue('error' in result1)
|
||||
|
||||
# Case 2.
|
||||
# Table and DB exists, CTA call to the backend.
|
||||
sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'"
|
||||
result2 = self.run_sql(
|
||||
1, sql_where, "2", tmp_table='tmp_table_2', cta='true')
|
||||
db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true')
|
||||
self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
|
||||
self.assertEqual([], result2['data'])
|
||||
self.assertEqual([], result2['columns'])
|
||||
|
|
@ -207,7 +206,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
# Table and DB exists, CTA call to the backend, no data.
|
||||
sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
|
||||
result3 = self.run_sql(
|
||||
1, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',)
|
||||
db_id, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',)
|
||||
self.assertEqual(QueryStatus.SUCCESS, result3['query']['state'])
|
||||
self.assertEqual([], result3['data'])
|
||||
self.assertEqual([], result3['columns'])
|
||||
|
|
@ -216,8 +215,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
self.assertEqual(QueryStatus.SUCCESS, query3.status)
|
||||
|
||||
def test_run_async_query(self):
|
||||
main_db = db.session.query(models.Database).filter_by(
|
||||
database_name="main").first()
|
||||
main_db = self.get_main_database(db.session)
|
||||
eng = main_db.get_sqla_engine()
|
||||
|
||||
# Schedule queries
|
||||
|
|
@ -226,7 +224,8 @@ class CeleryTestCase(CaravelTestCase):
|
|||
# Table and DB exists, async CTA call to the backend.
|
||||
sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
|
||||
result1 = self.run_sql(
|
||||
1, sql_where, "4", async='true', tmp_table='tmp_async_1', cta='true')
|
||||
main_db.id, sql_where, "4", async='true', tmp_table='tmp_async_1',
|
||||
cta='true')
|
||||
assert result1['query']['state'] in (
|
||||
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)
|
||||
|
||||
|
|
@ -238,7 +237,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
self.assertEqual(QueryStatus.SUCCESS, query1.status)
|
||||
self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records'))
|
||||
self.assertEqual(QueryStatus.SUCCESS, query1.status)
|
||||
self.assertTrue("SELECT * \nFROM tmp_async_1" in query1.select_sql)
|
||||
self.assertTrue("FROM tmp_async_1" in query1.select_sql)
|
||||
self.assertTrue("LIMIT 666" in query1.select_sql)
|
||||
self.assertEqual(
|
||||
"CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role "
|
||||
|
|
@ -252,8 +251,7 @@ class CeleryTestCase(CaravelTestCase):
|
|||
self.assertEqual(True, query1.select_as_cta_used)
|
||||
|
||||
def test_get_columns_dict(self):
|
||||
main_db = db.session.query(models.Database).filter_by(
|
||||
database_name='main').first()
|
||||
main_db = self.get_main_database(db.session)
|
||||
df = main_db.get_df("SELECT * FROM multiformat_time_series", None)
|
||||
cdf = dataframe.CaravelDataFrame(df)
|
||||
if main_db.sqlalchemy_uri.startswith('sqlite'):
|
||||
|
|
|
|||
Loading…
Reference in New Issue