From 63161b11c347d5a6d62f7ae7dc91fa3c30b5dc93 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Fri, 21 Oct 2016 16:55:37 -0700 Subject: [PATCH] [sqllab] proper, quoted, select * on the server side (#1404) * [sqllab] proper, quoted, select * on the server side * fixing tests --- caravel/assets/javascripts/SqlLab/common.js | 2 -- .../SqlLab/components/SqlEditorLeftBar.jsx | 7 ++-- .../SqlLab/components/TableElement.jsx | 34 ++---------------- caravel/models.py | 26 ++++++++++++-- caravel/views.py | 5 ++- tests/celery_tests.py | 36 +++++++++---------- 6 files changed, 49 insertions(+), 61 deletions(-) diff --git a/caravel/assets/javascripts/SqlLab/common.js b/caravel/assets/javascripts/SqlLab/common.js index 4005c591a..ed5726fa0 100644 --- a/caravel/assets/javascripts/SqlLab/common.js +++ b/caravel/assets/javascripts/SqlLab/common.js @@ -7,6 +7,4 @@ export const STATE_BSSTYLE_MAP = { success: 'success', }; -export const DATA_PREVIEW_ROW_COUNT = 100; - export const STATUS_OPTIONS = ['success', 'failed', 'running']; diff --git a/caravel/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx b/caravel/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx index 1b5ff686e..2bd4e61e2 100644 --- a/caravel/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx +++ b/caravel/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx @@ -82,15 +82,12 @@ class SqlEditorLeftBar extends React.Component { this.setState({ tableLoading: true }); $.get(url, (data) => { - this.props.actions.mergeTable({ + this.props.actions.mergeTable(Object.assign(data, { dbId: this.props.queryEditor.dbId, queryEditorId: this.props.queryEditor.id, - name: data.name, - indexes: data.indexes, schema: qe.schema, - columns: data.columns, expanded: true, - }); + })); this.setState({ tableLoading: false }); }) .fail(() => { diff --git a/caravel/assets/javascripts/SqlLab/components/TableElement.jsx b/caravel/assets/javascripts/SqlLab/components/TableElement.jsx index b3295e75b..0c47d06d5 100644 --- a/caravel/assets/javascripts/SqlLab/components/TableElement.jsx +++ b/caravel/assets/javascripts/SqlLab/components/TableElement.jsx @@ -6,7 +6,6 @@ import * as Actions from '../actions'; import { ButtonGroup, Well } from 'react-bootstrap'; import shortid from 'shortid'; -import { DATA_PREVIEW_ROW_COUNT } from '../common'; import CopyToClipboard from '../../components/CopyToClipboard'; import Link from './Link'; import ModalTrigger from '../../components/ModalTrigger'; @@ -23,33 +22,6 @@ const defaultProps = { }; class TableElement extends React.Component { - setSelectStar() { - this.props.actions.queryEditorSetSql(this.props.queryEditor, this.selectStar()); - } - - selectStar(useStar = false, limit = 0) { - let cols = ''; - this.props.table.columns.forEach((col, i) => { - cols += col.name; - if (i < this.props.table.columns.length - 1) { - cols += ', '; - } - }); - let tableName = this.props.table.name; - if (this.props.table.schema) { - tableName = this.props.table.schema + '.' + tableName; - } - let sql; - if (useStar) { - sql = `SELECT * FROM ${tableName}`; - } else { - sql = `SELECT ${cols}\nFROM ${tableName}`; - } - if (limit > 0) { - sql += `\nLIMIT ${limit}`; - } - return sql; - } popSelectStar() { const qe = { @@ -57,7 +29,7 @@ class TableElement extends React.Component { title: this.props.table.name, dbId: this.props.table.dbId, autorun: true, - sql: this.selectStar(), + sql: this.props.table.selectStar, }; this.props.actions.addQueryEditor(qe); } @@ -78,7 +50,7 @@ class TableElement extends React.Component { dataPreviewModal() { const query = { dbId: this.props.queryEditor.dbId, - sql: this.selectStar(true, DATA_PREVIEW_ROW_COUNT), + sql: this.props.table.selectStar, tableName: this.props.table.name, sqlEditorId: null, tab: '', @@ -208,7 +180,7 @@ class TableElement extends React.Component { copyNode={ } - text={this.selectStar()} + text={table.selectStar} shouldShowText={false} tooltipText="Copy SELECT statement to clipboard" /> diff --git a/caravel/models.py b/caravel/models.py index 1d53de828..7950c2a0d 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -689,6 +689,12 @@ class Database(Model, AuditMixinNullable): url.database = schema return create_engine(url, **params) + def get_reserved_words(self): + return self.get_sqla_engine().dialect.preparer.reserved_words + + def get_quoter(self): + return self.get_sqla_engine().dialect.identifier_preparer.quote + def get_df(self, sql, schema): eng = self.get_sqla_engine(schema=schema) cur = eng.execute(sql, schema=schema) @@ -701,12 +707,26 @@ class Database(Model, AuditMixinNullable): compiled = qry.compile(eng, compile_kwargs={"literal_binds": True}) return '{}'.format(compiled) - def select_star(self, table_name, schema=None, limit=1000): + def select_star( + self, table_name, schema=None, limit=100, show_cols=False, + indent=True): """Generates a ``select *`` statement in the proper dialect""" - qry = select('*').select_from(text(table_name)) + for i in range(10): + print(schema) + quote = self.get_quoter() + fields = '*' + table = self.get_table(table_name, schema=schema) + if show_cols: + fields = [quote(c.name) for c in table.columns] + if schema: + table_name = schema + '.' + table_name + qry = select(fields).select_from(text(table_name)) if limit: qry = qry.limit(limit) - return self.compile_sqla_query(qry) + sql = self.compile_sqla_query(qry) + if indent: + sql = sqlparse.format(sql, reindent=True) + return sql def wrap_sql_limit(self, sql, limit=1000): qry = ( diff --git a/caravel/views.py b/caravel/views.py index 5c58fbfe1..033a52d24 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -1968,6 +1968,8 @@ class Caravel(BaseCaravelView): tbl = { 'name': table_name, 'columns': cols, + 'selectStar': mydb.select_star( + table_name, schema=schema, show_cols=True, indent=True), 'indexes': indexes, } return Response(json.dumps(tbl), mimetype="application/json") @@ -1988,6 +1990,7 @@ class Caravel(BaseCaravelView): def select_star(self, database_id, table_name): mydb = db.session.query( models.Database).filter_by(id=database_id).first() + quote = mydb.get_quoter() t = mydb.get_table(table_name) # Prevent exposing column fields to users that cannot access DB. @@ -1996,7 +1999,7 @@ class Caravel(BaseCaravelView): return redirect("/tablemodelview/list/") fields = ", ".join( - [c.name for c in t.columns] or "*") + [quote(c.name) for c in t.columns] or "*") s = "SELECT\n{}\nFROM {}".format(fields, table_name) return self.render_template( "caravel/ajah.html", diff --git a/tests/celery_tests.py b/tests/celery_tests.py index b973ae556..bce23f947 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -125,13 +125,13 @@ class CeleryTestCase(CaravelTestCase): shell=True ) - def run_sql(self, dbid, sql, client_id, cta='false', tmp_table='tmp', + def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', async='false'): self.login() resp = self.client.post( '/caravel/sql_json/', data=dict( - database_id=dbid, + database_id=db_id, sql=sql, async=async, select_as_cta=cta, @@ -144,12 +144,11 @@ class CeleryTestCase(CaravelTestCase): def test_add_limit_to_the_query(self): session = db.session - db_to_query = session.query(models.Database).filter_by( - id=1).first() - eng = db_to_query.get_sqla_engine() + main_db = self.get_main_database(db.session) + eng = main_db.get_sqla_engine() select_query = "SELECT * FROM outer_space;" - updated_select_query = db_to_query.wrap_sql_limit(select_query, 100) + updated_select_query = main_db.wrap_sql_limit(select_query, 100) # Different DB engines have their own spacing while compiling # the queries, that's why ' '.join(query.split()) is used. # In addition some of the engines do not include OFFSET 0. @@ -159,7 +158,7 @@ class CeleryTestCase(CaravelTestCase): ) select_query_no_semicolon = "SELECT * FROM outer_space" - updated_select_query_no_semicolon = db_to_query.wrap_sql_limit( + updated_select_query_no_semicolon = main_db.wrap_sql_limit( select_query_no_semicolon, 100) self.assertTrue( "SELECT * FROM (SELECT * FROM outer_space) AS inner_qry " @@ -170,7 +169,7 @@ class CeleryTestCase(CaravelTestCase): multi_line_query = ( "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" ) - updated_multi_line_query = db_to_query.wrap_sql_limit(multi_line_query, 100) + updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100) self.assertTrue( "SELECT * FROM (SELECT * FROM planets WHERE " "Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in @@ -178,21 +177,21 @@ class CeleryTestCase(CaravelTestCase): ) def test_run_sync_query(self): - main_db = db.session.query(models.Database).filter_by( - database_name="main").first() + main_db = self.get_main_database(db.session) eng = main_db.get_sqla_engine() + db_id = main_db.id # Case 1. # Table doesn't exist. sql_dont_exist = 'SELECT name FROM table_dont_exist' - result1 = self.run_sql(1, sql_dont_exist, "1", cta='true') + result1 = self.run_sql(db_id, sql_dont_exist, "1", cta='true') self.assertTrue('error' in result1) # Case 2. # Table and DB exists, CTA call to the backend. sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'" result2 = self.run_sql( - 1, sql_where, "2", tmp_table='tmp_table_2', cta='true') + db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true') self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) self.assertEqual([], result2['data']) self.assertEqual([], result2['columns']) @@ -207,7 +206,7 @@ class CeleryTestCase(CaravelTestCase): # Table and DB exists, CTA call to the backend, no data. sql_empty_result = 'SELECT * FROM ab_user WHERE id=666' result3 = self.run_sql( - 1, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',) + db_id, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',) self.assertEqual(QueryStatus.SUCCESS, result3['query']['state']) self.assertEqual([], result3['data']) self.assertEqual([], result3['columns']) @@ -216,8 +215,7 @@ class CeleryTestCase(CaravelTestCase): self.assertEqual(QueryStatus.SUCCESS, query3.status) def test_run_async_query(self): - main_db = db.session.query(models.Database).filter_by( - database_name="main").first() + main_db = self.get_main_database(db.session) eng = main_db.get_sqla_engine() # Schedule queries @@ -226,7 +224,8 @@ class CeleryTestCase(CaravelTestCase): # Table and DB exists, async CTA call to the backend. sql_where = "SELECT name FROM ab_role WHERE name='Admin'" result1 = self.run_sql( - 1, sql_where, "4", async='true', tmp_table='tmp_async_1', cta='true') + main_db.id, sql_where, "4", async='true', tmp_table='tmp_async_1', + cta='true') assert result1['query']['state'] in ( QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) @@ -238,7 +237,7 @@ class CeleryTestCase(CaravelTestCase): self.assertEqual(QueryStatus.SUCCESS, query1.status) self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query1.status) - self.assertTrue("SELECT * \nFROM tmp_async_1" in query1.select_sql) + self.assertTrue("FROM tmp_async_1" in query1.select_sql) self.assertTrue("LIMIT 666" in query1.select_sql) self.assertEqual( "CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role " @@ -252,8 +251,7 @@ class CeleryTestCase(CaravelTestCase): self.assertEqual(True, query1.select_as_cta_used) def test_get_columns_dict(self): - main_db = db.session.query(models.Database).filter_by( - database_name='main').first() + main_db = self.get_main_database(db.session) df = main_db.get_df("SELECT * FROM multiformat_time_series", None) cdf = dataframe.CaravelDataFrame(df) if main_db.sqlalchemy_uri.startswith('sqlite'):