diff --git a/caravel/models.py b/caravel/models.py
index 1a4468009..71824b7fd 100644
--- a/caravel/models.py
+++ b/caravel/models.py
@@ -504,6 +504,11 @@ class Database(Model, AuditMixinNullable):
def sql_link(self):
return 'SQL'.format(self.sql_url)
+ @property
+ def perm(self):
+ return (
+ "[{obj.database_name}].(id:{obj.id})").format(obj=self)
+
class SqlaTable(Model, Queryable, AuditMixinNullable):
diff --git a/caravel/utils.py b/caravel/utils.py
index 6a2507332..ef6b35609 100644
--- a/caravel/utils.py
+++ b/caravel/utils.py
@@ -200,7 +200,7 @@ def init(caravel):
perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
- if perm.permission.name == 'datasource_access':
+ if perm.permission.name in ('datasource_access', 'database_access'):
continue
if perm.view_menu and perm.view_menu.name not in (
'UserDBModelView', 'RoleModelView', 'ResetPasswordView',
@@ -226,6 +226,7 @@ def init(caravel):
'can_edit',
'can_save',
'datasource_access',
+ 'database_access',
'muldelete',
)):
sm.add_permission_role(gamma, perm)
@@ -239,6 +240,9 @@ def init(caravel):
for table_perm in table_perms:
merge_perm(sm, 'datasource_access', table_perm)
+ db_perms = [db.perm for db in session.query(models.Database).all()]
+ for db_perm in db_perms:
+ merge_perm(sm, 'database_access', db_perm)
init_metrics_perm(caravel)
diff --git a/caravel/views.py b/caravel/views.py
index b850b1431..8076cb15b 100755
--- a/caravel/views.py
+++ b/caravel/views.py
@@ -407,6 +407,7 @@ class DatabaseView(CaravelModelView, DeleteMixin): # noqa
db.password = conn.password
conn.password = "X" * 10 if conn.password else None
db.sqlalchemy_uri = str(conn) # hides the password
+ utils.merge_perm(sm, 'database_access', db.perm)
def pre_update(self, db):
self.pre_add(db)
@@ -1176,15 +1177,17 @@ class Caravel(BaseCaravelView):
@expose("/sql//")
@log_this
def sql(self, database_id):
- if (
- not self.can_access(
- 'all_datasource_access', 'all_datasource_access')):
- flash(
- "This view requires the `all_datasource_access` "
- "permission", "danger")
- return redirect("/tablemodelview/list/")
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()
+
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
+ flash(
+ "This view requires the specific database or "
+ "`all_datasource_access` permission", "danger"
+ )
+ return redirect("/tablemodelview/list/")
engine = mydb.get_sqla_engine()
tables = engine.table_names()
@@ -1221,6 +1224,18 @@ class Caravel(BaseCaravelView):
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()
t = mydb.get_table(table_name)
+
+ # Prevent exposing column fields to users that cannot access DB.
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm) or
+ self.can_access('datasource_access', t.perm)):
+ flash(
+ "This view requires the specific database, table or "
+ "`all_datasource_access` permission", "danger"
+ )
+ return redirect("/tablemodelview/list/")
+
fields = ", ".join(
[c.name for c in t.columns] or "*")
s = "SELECT\n{}\nFROM {}".format(fields, table_name)
@@ -1242,11 +1257,13 @@ class Caravel(BaseCaravelView):
database_id = data.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first()
- if (
- not self.can_access(
- 'all_datasource_access', 'all_datasource_access')):
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_(
- "SQL Lab requires the `all_datasource_access` permission"))
+ "SQL Lab requires the `all_datasource_access` or "
+ "specific db permission"))
+
content = ""
if mydb:
eng = mydb.get_sqla_engine()
@@ -1254,10 +1271,12 @@ class Caravel(BaseCaravelView):
sql = sql.strip().strip(';')
qry = (
select('*')
- .select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry'))
+ .select_from(TextAsFrom(text(sql), ['*'])
+ .alias('inner_qry'))
.limit(limit)
)
- sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True}))
+ sql = '{}'.format(qry.compile(
+ eng, compile_kwargs={"literal_binds": True}))
try:
df = pd.read_sql_query(sql=sql, con=eng)
content = df.to_html(
@@ -1289,11 +1308,12 @@ class Caravel(BaseCaravelView):
database_id = request.form.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first()
- if (
- not self.can_access(
- 'all_datasource_access', 'all_datasource_access')):
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_(
- "This view requires the `all_datasource_access` permission"))
+ "SQL Lab requires the `all_datasource_access` or "
+ "specific DB permission"))
error_msg = ""
if not mydb:
@@ -1304,10 +1324,12 @@ class Caravel(BaseCaravelView):
sql = sql.strip().strip(';')
qry = (
select('*')
- .select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry'))
+ .select_from(TextAsFrom(text(sql), ['*'])
+ .alias('inner_qry'))
.limit(limit)
)
- sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True}))
+ sql = '{}'.format(qry.compile(
+ eng, compile_kwargs={"literal_binds": True}))
try:
df = pd.read_sql_query(sql=sql, con=eng)
df = df.fillna(0) # TODO make sure NULL
@@ -1328,7 +1350,8 @@ class Caravel(BaseCaravelView):
'columns': [c for c in df.columns],
'data': df.to_dict(orient='records'),
}
- return json.dumps(data, default=utils.json_int_dttm_ser, allow_nan=False)
+ return json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False)
@has_access
@expose("/refresh_datasources/")
@@ -1342,7 +1365,7 @@ class Caravel(BaseCaravelView):
except Exception as e:
flash(
"Error while processing cluster '{}'\n{}".format(
- cluster_name, str(e)),
+ cluster_name, utils.error_msg_from_exception(e)),
"danger")
logging.exception(e)
return redirect('/druidclustermodelview/list/')
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 40e206c23..ee73086f2 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -16,7 +16,7 @@ from flask import escape
from flask_appbuilder.security.sqla import models as ab_models
import caravel
-from caravel import app, db, models, utils, appbuilder
+from caravel import app, db, models, utils, appbuilder, sm
from caravel.models import DruidCluster, DruidDatasource
os.environ['CARAVEL_CONFIG'] = 'tests.caravel_test_config'
@@ -247,8 +247,8 @@ class CoreTests(CaravelTestCase):
resp = self.client.get('/dashboardmodelview/list/')
assert "List Dashboard" in resp.data.decode('utf-8')
- def run_sql(self, sql):
- self.login(username='admin')
+ def run_sql(self, sql, user_name):
+ self.login(username=user_name)
dbid = (
db.session.query(models.Database)
.filter_by(database_name="main")
@@ -258,13 +258,47 @@ class CoreTests(CaravelTestCase):
'/caravel/sql_json/',
data=dict(database_id=dbid, sql=sql),
)
+ self.logout()
return json.loads(resp.data.decode('utf-8'))
- def test_sql_json(self):
- data = self.run_sql("SELECT * FROM ab_user")
+ def test_sql_json_no_access(self):
+ self.assertRaises(
+ utils.CaravelSecurityException,
+ self.run_sql, "SELECT * FROM ab_user", 'gamma')
+
+ def test_sql_json_has_access(self):
+ main_db = (
+ db.session.query(models.Database).filter_by(database_name="main")
+ .first()
+ )
+ utils.merge_perm(sm, 'database_access', main_db.perm)
+ db.session.commit()
+ main_db_permission_view = (
+ db.session.query(ab_models.PermissionView)
+ .join(ab_models.ViewMenu)
+ .filter(ab_models.ViewMenu.name == '[main].(id:1)')
+ .first()
+ )
+ astronaut = sm.add_role("Astronaut")
+ sm.add_permission_role(astronaut, main_db_permission_view)
+ # Astronaut role is Gamme + main db permissions
+ for gamma_perm in sm.find_role('Gamma').permissions:
+ sm.add_permission_role(astronaut, gamma_perm)
+
+ gagarin = appbuilder.sm.find_user('gagarin')
+ if not gagarin:
+ appbuilder.sm.add_user(
+ 'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr',
+ appbuilder.sm.find_role('Astronaut'),
+ password='general')
+ data = self.run_sql('SELECT * FROM ab_user', 'gagarin')
assert len(data['data']) > 0
- data = self.run_sql("SELECT * FROM unexistant_table")
+ def test_sql_json(self):
+ data = self.run_sql("SELECT * FROM ab_user", 'admin')
+ assert len(data['data']) > 0
+
+ data = self.run_sql("SELECT * FROM unexistant_table", 'admin')
assert len(data['error']) > 0
def test_public_user_dashboard_access(self):
@@ -301,7 +335,6 @@ class CoreTests(CaravelTestCase):
data = resp.data.decode('utf-8')
assert "/caravel/dashboard/world_health/" not in data
-
def test_only_owners_can_save(self):
dash = (
db.session