chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)
This commit is contained in:
parent
1f013055d2
commit
a849c29288
|
|
@ -446,6 +446,7 @@ select = [
|
|||
"E7",
|
||||
"E9",
|
||||
"F",
|
||||
"PT009",
|
||||
"TRY201",
|
||||
]
|
||||
ignore = []
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
assert rv.status_code == 200
|
||||
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
|
||||
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
|
||||
self.assertEqual(response, {"result": []})
|
||||
assert response == {"result": []}
|
||||
|
||||
def _test_events_last_id_logic(self, mock_cache):
|
||||
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
|
||||
|
|
@ -69,7 +69,7 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
assert rv.status_code == 200
|
||||
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
|
||||
mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
|
||||
self.assertEqual(response, {"result": []})
|
||||
assert response == {"result": []}
|
||||
|
||||
def _test_events_results_logic(self, mock_cache):
|
||||
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
|
||||
|
|
@ -115,7 +115,7 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
},
|
||||
]
|
||||
}
|
||||
self.assertEqual(response, expected)
|
||||
assert response == expected
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_redis_cache_backend(self, mock_uuid4):
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class TestOpenApiSpec(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/_openapi"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
validate_spec(response)
|
||||
|
||||
|
|
@ -87,20 +87,20 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/model1api/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["list_columns"], ["id"])
|
||||
assert response["list_columns"] == ["id"]
|
||||
for result in response["result"]:
|
||||
self.assertEqual(list(result.keys()), ["id"])
|
||||
assert list(result.keys()) == ["id"]
|
||||
|
||||
# Check get response
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
uri = f"api/v1/model1api/{dashboard.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["show_columns"], ["id"])
|
||||
self.assertEqual(list(response["result"].keys()), ["id"])
|
||||
assert response["show_columns"] == ["id"]
|
||||
assert list(response["result"].keys()) == ["id"]
|
||||
|
||||
def test_default_missing_declaration_put_spec(self):
|
||||
"""
|
||||
|
|
@ -113,17 +113,18 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
uri = "api/v1/_openapi"
|
||||
rv = self.client.get(uri)
|
||||
# dashboard model accepts all fields are null
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_mutation_spec = {
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"type": "object",
|
||||
}
|
||||
self.assertEqual(
|
||||
response["components"]["schemas"]["Model1Api.post"], expected_mutation_spec
|
||||
assert (
|
||||
response["components"]["schemas"]["Model1Api.post"]
|
||||
== expected_mutation_spec
|
||||
)
|
||||
self.assertEqual(
|
||||
response["components"]["schemas"]["Model1Api.put"], expected_mutation_spec
|
||||
assert (
|
||||
response["components"]["schemas"]["Model1Api.put"] == expected_mutation_spec
|
||||
)
|
||||
|
||||
def test_default_missing_declaration_post(self):
|
||||
|
|
@ -145,7 +146,7 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
uri = "api/v1/model1api/"
|
||||
rv = self.client.post(uri, json=dashboard_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
expected_response = {
|
||||
"message": {
|
||||
"css": ["Unknown field."],
|
||||
|
|
@ -156,7 +157,7 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
"slug": ["Unknown field."],
|
||||
}
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
def test_refuse_invalid_format_request(self):
|
||||
"""
|
||||
|
|
@ -169,7 +170,7 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
rv = self.client.post(
|
||||
uri, data="a: value\nb: 1\n", content_type="application/yaml"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_default_missing_declaration_put(self):
|
||||
|
|
@ -185,14 +186,14 @@ class TestBaseModelRestApi(SupersetTestCase):
|
|||
uri = f"api/v1/model1api/{dashboard.id}"
|
||||
rv = self.client.put(uri, json=dashboard_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
expected_response = {
|
||||
"message": {
|
||||
"dashboard_title": ["Unknown field."],
|
||||
"slug": ["Unknown field."],
|
||||
}
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
class ApiOwnersTestCaseMixin:
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ class TestCache(SupersetTestCase):
|
|||
)
|
||||
# restore DATA_CACHE_CONFIG
|
||||
app.config["DATA_CACHE_CONFIG"] = data_cache_config
|
||||
self.assertFalse(resp["is_cached"])
|
||||
self.assertFalse(resp_from_cache["is_cached"])
|
||||
assert not resp["is_cached"]
|
||||
assert not resp_from_cache["is_cached"]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_slice_data_cache(self):
|
||||
|
|
@ -84,20 +84,20 @@ class TestCache(SupersetTestCase):
|
|||
resp_from_cache = self.get_json_resp(
|
||||
json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
|
||||
)
|
||||
self.assertFalse(resp["is_cached"])
|
||||
self.assertTrue(resp_from_cache["is_cached"])
|
||||
assert not resp["is_cached"]
|
||||
assert resp_from_cache["is_cached"]
|
||||
# should fallback to default cache timeout
|
||||
self.assertEqual(resp_from_cache["cache_timeout"], 10)
|
||||
self.assertEqual(resp_from_cache["status"], QueryStatus.SUCCESS)
|
||||
self.assertEqual(resp["data"], resp_from_cache["data"])
|
||||
self.assertEqual(resp["query"], resp_from_cache["query"])
|
||||
assert resp_from_cache["cache_timeout"] == 10
|
||||
assert resp_from_cache["status"] == QueryStatus.SUCCESS
|
||||
assert resp["data"] == resp_from_cache["data"]
|
||||
assert resp["query"] == resp_from_cache["query"]
|
||||
# should exists in `data_cache`
|
||||
self.assertEqual(
|
||||
cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"],
|
||||
resp_from_cache["query"],
|
||||
assert (
|
||||
cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"]
|
||||
== resp_from_cache["query"]
|
||||
)
|
||||
# should not exists in `cache`
|
||||
self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))
|
||||
assert cache_manager.cache.get(resp_from_cache["cache_key"]) is None
|
||||
|
||||
# reset cache config
|
||||
app.config["DATA_CACHE_CONFIG"] = data_cache_config
|
||||
|
|
|
|||
|
|
@ -336,9 +336,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertEqual(model, None)
|
||||
assert model is None
|
||||
|
||||
def test_delete_bulk_charts(self):
|
||||
"""
|
||||
|
|
@ -355,13 +355,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
argument = chart_ids
|
||||
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": f"Deleted {chart_count} charts"}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
for chart_id in chart_ids:
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertEqual(model, None)
|
||||
assert model is None
|
||||
|
||||
def test_delete_bulk_chart_bad_request(self):
|
||||
"""
|
||||
|
|
@ -372,7 +372,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
argument = chart_ids
|
||||
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
def test_delete_not_found_chart(self):
|
||||
"""
|
||||
|
|
@ -382,7 +382,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
chart_id = 1000
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("create_chart_with_report")
|
||||
def test_delete_chart_with_report(self):
|
||||
|
|
@ -398,11 +398,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.client.delete(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
expected_response = {
|
||||
"message": "There are associated alerts or reports: report_with_chart"
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
def test_delete_bulk_charts_not_found(self):
|
||||
"""
|
||||
|
|
@ -413,7 +413,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("create_chart_with_report", "create_charts")
|
||||
def test_bulk_delete_chart_with_report(self):
|
||||
|
|
@ -434,11 +434,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
|
||||
rv = self.client.delete(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
expected_response = {
|
||||
"message": "There are associated alerts or reports: report_with_chart"
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
def test_delete_chart_admin_not_owned(self):
|
||||
"""
|
||||
|
|
@ -450,9 +450,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertEqual(model, None)
|
||||
assert model is None
|
||||
|
||||
def test_delete_bulk_chart_admin_not_owned(self):
|
||||
"""
|
||||
|
|
@ -471,13 +471,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
expected_response = {"message": f"Deleted {chart_count} charts"}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
for chart_id in chart_ids:
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertEqual(model, None)
|
||||
assert model is None
|
||||
|
||||
def test_delete_chart_not_owned(self):
|
||||
"""
|
||||
|
|
@ -493,7 +493,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(username="alpha2", password="password")
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
db.session.delete(chart)
|
||||
db.session.delete(user_alpha1)
|
||||
db.session.delete(user_alpha2)
|
||||
|
|
@ -525,19 +525,19 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
arguments = [chart.id for chart in charts]
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": "Forbidden"}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
# # nothing is deleted in bulk with a list of owned and not owned charts
|
||||
arguments = [chart.id for chart in charts] + [owned_chart.id]
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": "Forbidden"}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
for chart in charts:
|
||||
db.session.delete(chart)
|
||||
|
|
@ -572,7 +572,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.post_assert_metric(uri, chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
model = db.session.query(Slice).get(data.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -590,7 +590,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.post_assert_metric(uri, chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
model = db.session.query(Slice).get(data.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -609,10 +609,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.post_assert_metric(uri, chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": {"owners": ["Owners are invalid"]}}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
def test_create_chart_validate_params(self):
|
||||
"""
|
||||
|
|
@ -627,7 +627,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.post_assert_metric(uri, chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
def test_create_chart_validate_datasource(self):
|
||||
"""
|
||||
|
|
@ -640,29 +640,24 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"datasource_type": "unknown",
|
||||
}
|
||||
rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert response == {
|
||||
"message": {
|
||||
"datasource_type": [
|
||||
"Must be one of: table, dataset, query, saved_query, view."
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
chart_data = {
|
||||
"slice_name": "title1",
|
||||
"datasource_id": 0,
|
||||
"datasource_type": "table",
|
||||
}
|
||||
rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
)
|
||||
assert response == {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_create_chart_validate_user_is_dashboard_owner(self):
|
||||
|
|
@ -682,12 +677,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ALPHA_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.post_assert_metric(uri, chart_data, "post")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{"message": "Changing one or more of these dashboards is forbidden"},
|
||||
)
|
||||
assert response == {
|
||||
"message": "Changing one or more of these dashboards is forbidden"
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_update_chart(self):
|
||||
|
|
@ -720,23 +714,23 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
related_dashboard = db.session.query(Dashboard).filter_by(slug="births").first()
|
||||
self.assertEqual(model.created_by, admin)
|
||||
self.assertEqual(model.slice_name, "title1_changed")
|
||||
self.assertEqual(model.description, "description1")
|
||||
self.assertNotIn(admin, model.owners)
|
||||
self.assertIn(gamma, model.owners)
|
||||
self.assertEqual(model.viz_type, "viz_type1")
|
||||
self.assertEqual(model.params, """{"a": 1}""")
|
||||
self.assertEqual(model.cache_timeout, 1000)
|
||||
self.assertEqual(model.datasource_id, birth_names_table_id)
|
||||
self.assertEqual(model.datasource_type, "table")
|
||||
self.assertEqual(model.datasource_name, full_table_name)
|
||||
self.assertEqual(model.certified_by, "Mario Rossi")
|
||||
self.assertEqual(model.certification_details, "Edited certification")
|
||||
self.assertIn(model.id, [slice.id for slice in related_dashboard.slices])
|
||||
assert model.created_by == admin
|
||||
assert model.slice_name == "title1_changed"
|
||||
assert model.description == "description1"
|
||||
assert admin not in model.owners
|
||||
assert gamma in model.owners
|
||||
assert model.viz_type == "viz_type1"
|
||||
assert model.params == '{"a": 1}'
|
||||
assert model.cache_timeout == 1000
|
||||
assert model.datasource_id == birth_names_table_id
|
||||
assert model.datasource_type == "table"
|
||||
assert model.datasource_name == full_table_name
|
||||
assert model.certified_by == "Mario Rossi"
|
||||
assert model.certification_details == "Edited certification"
|
||||
assert model.id in [slice.id for slice in related_dashboard.slices]
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -755,16 +749,16 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
|
||||
response = self.get_assert_metric("api/v1/chart/", "get_list")
|
||||
res = json.loads(response.data.decode("utf-8"))["result"]
|
||||
|
||||
current_chart = [d for d in res if d["id"] == chart_id][0]
|
||||
self.assertEqual(current_chart["slice_name"], new_name)
|
||||
self.assertNotIn("username", current_chart["changed_by"].keys())
|
||||
self.assertNotIn("username", current_chart["owners"][0].keys())
|
||||
assert current_chart["slice_name"] == new_name
|
||||
assert "username" not in current_chart["changed_by"].keys()
|
||||
assert "username" not in current_chart["owners"][0].keys()
|
||||
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
|
@ -784,14 +778,14 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
|
||||
response = self.get_assert_metric(uri, "get")
|
||||
res = json.loads(response.data.decode("utf-8"))["result"]
|
||||
|
||||
self.assertEqual(res["slice_name"], new_name)
|
||||
self.assertNotIn("username", res["owners"][0].keys())
|
||||
assert res["slice_name"] == new_name
|
||||
assert "username" not in res["owners"][0].keys()
|
||||
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
|
@ -829,10 +823,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertNotIn(admin, model.owners)
|
||||
self.assertIn(gamma, model.owners)
|
||||
assert admin not in model.owners
|
||||
assert gamma in model.owners
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -848,8 +842,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(username="admin")
|
||||
uri = f"api/v1/chart/{self.chart.id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual([admin], self.chart.owners)
|
||||
assert rv.status_code == 200
|
||||
assert [admin] == self.chart.owners
|
||||
|
||||
@pytest.mark.usefixtures("add_dashboard_to_chart")
|
||||
def test_update_chart_clear_owner_list(self):
|
||||
|
|
@ -861,8 +855,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(username="admin")
|
||||
uri = f"api/v1/chart/{self.chart.id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual([], self.chart.owners)
|
||||
assert rv.status_code == 200
|
||||
assert [] == self.chart.owners
|
||||
|
||||
def test_update_chart_populate_owner(self):
|
||||
"""
|
||||
|
|
@ -873,15 +867,15 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
admin = self.get_user("admin")
|
||||
chart_id = self.insert_chart("title", [], 1).id
|
||||
model = db.session.query(Slice).get(chart_id)
|
||||
self.assertEqual(model.owners, [])
|
||||
assert model.owners == []
|
||||
chart_data = {"owners": [gamma.id]}
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model_updated = db.session.query(Slice).get(chart_id)
|
||||
self.assertNotIn(admin, model_updated.owners)
|
||||
self.assertIn(gamma, model_updated.owners)
|
||||
assert admin not in model_updated.owners
|
||||
assert gamma in model_updated.owners
|
||||
db.session.delete(model_updated)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -897,9 +891,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{self.chart.id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertIn(self.new_dashboard, self.chart.dashboards)
|
||||
self.assertNotIn(self.original_dashboard, self.chart.dashboards)
|
||||
assert rv.status_code == 200
|
||||
assert self.new_dashboard in self.chart.dashboards
|
||||
assert self.original_dashboard not in self.chart.dashboards
|
||||
|
||||
@pytest.mark.usefixtures("add_dashboard_to_chart")
|
||||
def test_not_update_chart_none_dashboards(self):
|
||||
|
|
@ -910,9 +904,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{self.chart.id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertIn(self.original_dashboard, self.chart.dashboards)
|
||||
self.assertEqual(len(self.chart.dashboards), 1)
|
||||
assert rv.status_code == 200
|
||||
assert self.original_dashboard in self.chart.dashboards
|
||||
assert len(self.chart.dashboards) == 1
|
||||
|
||||
def test_update_chart_not_owned(self):
|
||||
"""
|
||||
|
|
@ -930,7 +924,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
chart_data = {"slice_name": "title1_changed"}
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
db.session.delete(chart)
|
||||
db.session.delete(user_alpha1)
|
||||
db.session.delete(user_alpha2)
|
||||
|
|
@ -976,13 +970,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/chart/{chart.id}"
|
||||
|
||||
rv = self.put_assert_metric(uri, chart_data_with_invalid_dashboard, "put")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": {"dashboards": ["Dashboards do not exist"]}}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
rv = self.put_assert_metric(uri, chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
db.session.delete(chart)
|
||||
db.session.delete(original_dashboard)
|
||||
|
|
@ -1001,26 +995,21 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
chart_data = {"datasource_id": 1, "datasource_type": "unknown"}
|
||||
rv = self.put_assert_metric(f"/api/v1/chart/{chart.id}", chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert response == {
|
||||
"message": {
|
||||
"datasource_type": [
|
||||
"Must be one of: table, dataset, query, saved_query, view."
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
chart_data = {"datasource_id": 0, "datasource_type": "table"}
|
||||
rv = self.put_assert_metric(f"/api/v1/chart/{chart.id}", chart_data, "put")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
)
|
||||
assert response == {"message": {"datasource_id": ["Datasource does not exist"]}}
|
||||
|
||||
db.session.delete(chart)
|
||||
db.session.commit()
|
||||
|
|
@ -1038,10 +1027,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/" # noqa: F541
|
||||
rv = self.client.post(uri, json=chart_data)
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": {"owners": ["Owners are invalid"]}}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_get_chart(self):
|
||||
|
|
@ -1053,7 +1042,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.get_assert_metric(uri, "get")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
expected_result = {
|
||||
"cache_timeout": None,
|
||||
"certified_by": None,
|
||||
|
|
@ -1075,10 +1064,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"is_managed_externally": False,
|
||||
}
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertIn("changed_on_delta_humanized", data["result"])
|
||||
self.assertIn("id", data["result"])
|
||||
self.assertIn("thumbnail_url", data["result"])
|
||||
self.assertIn("url", data["result"])
|
||||
assert "changed_on_delta_humanized" in data["result"]
|
||||
assert "id" in data["result"]
|
||||
assert "thumbnail_url" in data["result"]
|
||||
assert "url" in data["result"]
|
||||
for key, value in data["result"].items():
|
||||
# We can't assert timestamp values or id/urls
|
||||
if key not in (
|
||||
|
|
@ -1087,7 +1076,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"thumbnail_url",
|
||||
"url",
|
||||
):
|
||||
self.assertEqual(value, expected_result[key])
|
||||
assert value == expected_result[key]
|
||||
db.session.delete(chart)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1099,7 +1088,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{chart_id}"
|
||||
rv = self.get_assert_metric(uri, "get")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_get_chart_no_data_access(self):
|
||||
|
|
@ -1114,7 +1103,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
)
|
||||
uri = f"api/v1/chart/{chart_no_access.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"load_energy_table_with_slice",
|
||||
|
|
@ -1129,9 +1118,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/chart/" # noqa: F541
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 33)
|
||||
assert data["count"] == 33
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice", "add_dashboard_to_chart")
|
||||
def test_get_charts_dashboards(self):
|
||||
|
|
@ -1146,7 +1135,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
}
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["result"][0]["dashboards"] == [
|
||||
{
|
||||
|
|
@ -1172,7 +1161,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
}
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
assert len(result) == 1
|
||||
|
|
@ -1206,26 +1195,20 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
# Filter by tag ID
|
||||
filter_params = get_filter_params("chart_tag_id", tag.id)
|
||||
response_by_id = self.get_list("chart", filter_params)
|
||||
self.assertEqual(response_by_id.status_code, 200)
|
||||
assert response_by_id.status_code == 200
|
||||
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
|
||||
|
||||
# Filter by tag name
|
||||
filter_params = get_filter_params("chart_tags", tag.name)
|
||||
response_by_name = self.get_list("chart", filter_params)
|
||||
self.assertEqual(response_by_name.status_code, 200)
|
||||
assert response_by_name.status_code == 200
|
||||
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
|
||||
|
||||
# Compare results
|
||||
self.assertEqual(
|
||||
data_by_id["count"],
|
||||
data_by_name["count"],
|
||||
len(expected_charts),
|
||||
)
|
||||
self.assertEqual(
|
||||
set(chart["id"] for chart in data_by_id["result"]),
|
||||
set(chart["id"] for chart in data_by_name["result"]),
|
||||
set(chart.id for chart in expected_charts),
|
||||
)
|
||||
assert data_by_id["count"] == data_by_name["count"], len(expected_charts)
|
||||
assert set(chart["id"] for chart in data_by_id["result"]) == set(
|
||||
chart["id"] for chart in data_by_name["result"]
|
||||
), set(chart.id for chart in expected_charts)
|
||||
|
||||
def test_get_charts_changed_on(self):
|
||||
"""
|
||||
|
|
@ -1243,7 +1226,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["result"][0]["changed_on_delta_humanized"] in (
|
||||
"now",
|
||||
|
|
@ -1266,9 +1249,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
arguments = {"filters": [{"col": "slice_name", "opr": "sw", "value": "G"}]}
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 5)
|
||||
assert data["count"] == 5
|
||||
|
||||
@pytest.fixture()
|
||||
def load_energy_charts(self):
|
||||
|
|
@ -1323,9 +1306,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 4)
|
||||
assert data["count"] == 4
|
||||
|
||||
expected_response = [
|
||||
{"description": "ZY_bar", "slice_name": "foo_a", "viz_type": None},
|
||||
|
|
@ -1334,11 +1317,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
{"description": "desc1", "slice_name": "zy_foo", "viz_type": None},
|
||||
]
|
||||
for index, item in enumerate(data["result"]):
|
||||
self.assertEqual(
|
||||
item["description"], expected_response[index]["description"]
|
||||
)
|
||||
self.assertEqual(item["slice_name"], expected_response[index]["slice_name"])
|
||||
self.assertEqual(item["viz_type"], expected_response[index]["viz_type"])
|
||||
assert item["description"] == expected_response[index]["description"]
|
||||
assert item["slice_name"] == expected_response[index]["slice_name"]
|
||||
assert item["viz_type"] == expected_response[index]["viz_type"]
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice", "load_energy_charts")
|
||||
def test_admin_gets_filtered_energy_slices(self):
|
||||
|
|
@ -1390,9 +1371,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], CHARTS_FIXTURE_COUNT)
|
||||
assert data["count"] == CHARTS_FIXTURE_COUNT
|
||||
|
||||
@pytest.mark.usefixtures("create_charts")
|
||||
def test_gets_not_certified_charts_filter(self):
|
||||
|
|
@ -1411,9 +1392,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 17)
|
||||
assert data["count"] == 17
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_charts")
|
||||
def test_user_gets_none_filtered_energy_slices(self):
|
||||
|
|
@ -1433,9 +1414,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(GAMMA_USERNAME)
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 0)
|
||||
assert data["count"] == 0
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_charts")
|
||||
def test_user_gets_all_charts(self):
|
||||
|
|
@ -1445,12 +1426,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
def count_charts():
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.client.get(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = rv.get_json()
|
||||
return data["count"]
|
||||
|
||||
with self.temporary_user(gamma_user, login=True):
|
||||
self.assertEqual(count_charts(), 0)
|
||||
assert count_charts() == 0
|
||||
|
||||
perm = ("all_database_access", "all_database_access")
|
||||
with self.temporary_user(gamma_user, extra_pvms=[perm], login=True):
|
||||
|
|
@ -1462,7 +1443,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
# Back to normal
|
||||
with self.temporary_user(gamma_user, login=True):
|
||||
self.assertEqual(count_charts(), 0)
|
||||
assert count_charts() == 0
|
||||
|
||||
@pytest.mark.usefixtures("create_charts")
|
||||
def test_get_charts_favorite_filter(self):
|
||||
|
|
@ -1645,7 +1626,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/time_range/?q={prison.dumps(humanize_time_range)}"
|
||||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
assert "since" in data["result"][0]
|
||||
assert "until" in data["result"][0]
|
||||
assert "timeRange" in data["result"][0]
|
||||
|
|
@ -1686,10 +1667,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
uri = f"api/v1/form_data/?slice_id={slice.id if slice else None}"
|
||||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(rv.content_type, "application/json")
|
||||
assert rv.status_code == 200
|
||||
assert rv.content_type == "application/json"
|
||||
if slice:
|
||||
self.assertEqual(data["slice_id"], slice.id)
|
||||
assert data["slice_id"] == slice.id
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"load_unicode_dashboard_with_slice",
|
||||
|
|
@ -1706,16 +1687,16 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
arguments = {"page_size": 10, "page": 0}
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(len(data["result"]), 10)
|
||||
assert len(data["result"]) == 10
|
||||
|
||||
arguments = {"page_size": 10, "page": 3}
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(len(data["result"]), 3)
|
||||
assert len(data["result"]) == 3
|
||||
|
||||
def test_get_charts_no_data_access(self):
|
||||
"""
|
||||
|
|
@ -1724,9 +1705,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(GAMMA_USERNAME)
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 0)
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_export_chart(self):
|
||||
"""
|
||||
|
|
@ -1940,9 +1921,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 8)
|
||||
assert data["count"] == 8
|
||||
|
||||
def test_gets_not_created_by_user_charts_filter(self):
|
||||
arguments = {
|
||||
|
|
@ -1954,9 +1935,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
|
||||
rv = self.get_assert_metric(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 8)
|
||||
assert data["count"] == 8
|
||||
|
||||
@pytest.mark.usefixtures("create_charts")
|
||||
def test_gets_owned_created_favorited_by_me_filter(self):
|
||||
|
|
@ -1978,7 +1959,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"page_size": 25,
|
||||
}
|
||||
rv = self.client.get(f"api/v1/chart/?q={prison.dumps(arguments)}")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
assert data["result"][0]["slice_name"] == "name0"
|
||||
|
|
@ -1995,13 +1976,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
slc = self.get_slice(slice_name)
|
||||
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(
|
||||
data["result"],
|
||||
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
|
||||
)
|
||||
assert data["result"] == [
|
||||
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
]
|
||||
|
||||
dashboard = self.get_dash_by_slug("births")
|
||||
|
||||
|
|
@ -2009,12 +1989,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"/api/v1/chart/warm_up_cache",
|
||||
json={"chart_id": slc.id, "dashboard_id": dashboard.id},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data["result"],
|
||||
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
|
||||
)
|
||||
assert data["result"] == [
|
||||
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
]
|
||||
|
||||
rv = self.client.put(
|
||||
"/api/v1/chart/warm_up_cache",
|
||||
|
|
@ -2026,29 +2005,25 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
),
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data["result"],
|
||||
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
|
||||
)
|
||||
assert data["result"] == [
|
||||
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
]
|
||||
|
||||
def test_warm_up_cache_chart_id_required(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1})
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data,
|
||||
{"message": {"chart_id": ["Missing data for required field."]}},
|
||||
)
|
||||
assert data == {"message": {"chart_id": ["Missing data for required field."]}}
|
||||
|
||||
def test_warm_up_cache_chart_not_found(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999})
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data, {"message": "Chart not found"})
|
||||
assert data == {"message": "Chart not found"}
|
||||
|
||||
def test_warm_up_cache_payload_validation(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -2056,18 +2031,15 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
"/api/v1/chart/warm_up_cache",
|
||||
json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data,
|
||||
{
|
||||
assert data == {
|
||||
"message": {
|
||||
"chart_id": ["Not a valid integer."],
|
||||
"dashboard_id": ["Not a valid integer."],
|
||||
"extra_filters": ["Not a valid string."],
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_warm_up_cache_error(self) -> None:
|
||||
|
|
@ -2167,12 +2139,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart.id)
|
||||
|
||||
# Clean up system tags
|
||||
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
|
||||
self.assertEqual(tag_list, new_tags)
|
||||
assert tag_list == new_tags
|
||||
|
||||
@pytest.mark.usefixtures("create_chart_with_tag")
|
||||
def test_update_chart_remove_tags_can_write_on_tag(self):
|
||||
|
|
@ -2194,12 +2166,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart.id)
|
||||
|
||||
# Clean up system tags
|
||||
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
|
||||
self.assertEqual(tag_list, new_tags)
|
||||
assert tag_list == new_tags
|
||||
|
||||
@pytest.mark.usefixtures("create_chart_with_tag")
|
||||
def test_update_chart_add_tags_can_tag_on_chart(self):
|
||||
|
|
@ -2226,12 +2198,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart.id)
|
||||
|
||||
# Clean up system tags
|
||||
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
|
||||
self.assertEqual(tag_list, new_tags)
|
||||
assert tag_list == new_tags
|
||||
|
||||
security_manager.add_permission_role(alpha_role, write_tags_perm)
|
||||
|
||||
|
|
@ -2256,12 +2228,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Slice).get(chart.id)
|
||||
|
||||
# Clean up system tags
|
||||
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
|
||||
self.assertEqual(tag_list, [])
|
||||
assert tag_list == []
|
||||
|
||||
security_manager.add_permission_role(alpha_role, write_tags_perm)
|
||||
|
||||
|
|
@ -2291,10 +2263,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
self.assertEqual(
|
||||
rv.json["message"],
|
||||
"You do not have permission to manage tags on charts",
|
||||
assert rv.status_code == 403
|
||||
assert (
|
||||
rv.json["message"] == "You do not have permission to manage tags on charts"
|
||||
)
|
||||
|
||||
security_manager.add_permission_role(alpha_role, write_tags_perm)
|
||||
|
|
@ -2322,10 +2293,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
self.assertEqual(
|
||||
rv.json["message"],
|
||||
"You do not have permission to manage tags on charts",
|
||||
assert rv.status_code == 403
|
||||
assert (
|
||||
rv.json["message"] == "You do not have permission to manage tags on charts"
|
||||
)
|
||||
|
||||
security_manager.add_permission_role(alpha_role, write_tags_perm)
|
||||
|
|
@ -2353,7 +2323,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/chart/{chart.id}"
|
||||
rv = self.put_assert_metric(uri, update_payload, "put")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
security_manager.add_permission_role(alpha_role, write_tags_perm)
|
||||
security_manager.add_permission_role(alpha_role, tag_charts_perm)
|
||||
|
|
|
|||
|
|
@ -424,15 +424,19 @@ class TestChartWarmUpCacheCommand(SupersetTestCase):
|
|||
def test_warm_up_cache(self):
|
||||
slc = self.get_slice("Top 10 Girl Name Share")
|
||||
result = ChartWarmUpCacheCommand(slc.id, None, None).run()
|
||||
self.assertEqual(
|
||||
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
)
|
||||
assert result == {
|
||||
"chart_id": slc.id,
|
||||
"viz_error": None,
|
||||
"viz_status": "success",
|
||||
}
|
||||
|
||||
# can just pass in chart as well
|
||||
result = ChartWarmUpCacheCommand(slc, None, None).run()
|
||||
self.assertEqual(
|
||||
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
)
|
||||
assert result == {
|
||||
"chart_id": slc.id,
|
||||
"viz_error": None,
|
||||
"viz_status": "success",
|
||||
}
|
||||
|
||||
|
||||
class TestFavoriteChartCommand(SupersetTestCase):
|
||||
|
|
|
|||
|
|
@ -471,19 +471,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
"__time_origin": "now",
|
||||
}
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data["result"][0]["applied_filters"],
|
||||
[
|
||||
assert data["result"][0]["applied_filters"] == [
|
||||
{"column": "gender"},
|
||||
{"column": "num"},
|
||||
{"column": "name"},
|
||||
{"column": "__time_range"},
|
||||
],
|
||||
)
|
||||
]
|
||||
expected_row_count = self.get_expected_row_count("client_id_2")
|
||||
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
|
||||
assert data["result"][0]["rowcount"] == expected_row_count
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_in_op_filter__data_is_returned(self):
|
||||
|
|
@ -533,7 +530,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
dttm_col.type,
|
||||
dttm,
|
||||
)
|
||||
self.assertIn(dttm_expression, result["query"])
|
||||
assert dttm_expression in result["query"]
|
||||
else:
|
||||
raise Exception("ds column not found")
|
||||
|
||||
|
|
@ -563,16 +560,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
}
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
row = result["data"][0]
|
||||
self.assertIn("__timestamp", row)
|
||||
self.assertIn("sum__num", row)
|
||||
self.assertIn("sum__num__yhat", row)
|
||||
self.assertIn("sum__num__yhat_upper", row)
|
||||
self.assertIn("sum__num__yhat_lower", row)
|
||||
self.assertEqual(result["rowcount"], 103)
|
||||
assert "__timestamp" in row
|
||||
assert "sum__num" in row
|
||||
assert "sum__num__yhat" in row
|
||||
assert "sum__num__yhat_upper" in row
|
||||
assert "sum__num__yhat_lower" in row
|
||||
assert result["rowcount"] == 103
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_invalid_post_processing(self):
|
||||
|
|
@ -730,11 +727,11 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
time.sleep(1)
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
time.sleep(1)
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
assert rv.status_code == 202
|
||||
time.sleep(1)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
keys = list(data.keys())
|
||||
self.assertCountEqual(
|
||||
self.assertCountEqual( # noqa: PT009
|
||||
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
|
||||
)
|
||||
|
||||
|
|
@ -764,10 +761,10 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
rv = self.post_assert_metric(
|
||||
CHART_DATA_URI, self.query_context_payload, "data"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
patched_run.assert_called_once_with(force_cached=True)
|
||||
self.assertEqual(data, {"result": [{"query": "select * from foo"}]})
|
||||
assert data == {"result": [{"query": "select * from foo"}]}
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -779,7 +776,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
async_query_manager_factory.init_app(app)
|
||||
self.query_context_payload["result_type"] = "results"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -793,7 +790,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
|
||||
)
|
||||
rv = test_client.post(CHART_DATA_URI, json=self.query_context_payload)
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
assert rv.status_code == 401
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_rowcount(self):
|
||||
|
|
@ -846,10 +843,8 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
|
||||
unique_names = {row["name"] for row in data}
|
||||
self.maxDiff = None
|
||||
self.assertEqual(len(unique_names), SERIES_LIMIT)
|
||||
self.assertEqual(
|
||||
{column for column in data[0].keys()}, {"state", "name", "sum__num"}
|
||||
)
|
||||
assert len(unique_names) == SERIES_LIMIT
|
||||
assert {column for column in data[0].keys()} == {"state", "name", "sum__num"}
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"create_annotation_layers", "load_birth_names_dashboard_with_slices"
|
||||
|
|
@ -888,10 +883,10 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
annotation_layers.append(event)
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# response should only contain interval and event data, not formula
|
||||
self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
|
||||
assert len(data["result"][0]["annotation_data"]) == 2
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_virtual_table_with_colons_as_datasource(self):
|
||||
|
|
@ -1184,8 +1179,8 @@ class TestGetChartDataApi(BaseTestChartDataApi):
|
|||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
expected_row_count = self.get_expected_row_count("client_id_3")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
|
||||
assert rv.status_code == 200
|
||||
assert data["result"][0]["rowcount"] == expected_row_count
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
|
||||
|
|
@ -1202,8 +1197,8 @@ class TestGetChartDataApi(BaseTestChartDataApi):
|
|||
)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(data["message"], "Error loading data from cache")
|
||||
assert rv.status_code == 422
|
||||
assert data["message"] == "Error loading data from cache"
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
|
||||
|
|
@ -1231,7 +1226,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
|
|||
f"{CHART_DATA_URI}/test-cache-key",
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
assert rv.status_code == 401
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
def test_chart_data_cache_key_error(self):
|
||||
|
|
@ -1244,7 +1239,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
|
|||
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_with_adhoc_column(self):
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ class TestSchema(SupersetTestCase):
|
|||
payload["queries"][0]["row_offset"] = -1
|
||||
with self.assertRaises(ValidationError) as context:
|
||||
_ = ChartDataQueryContextSchema().load(payload)
|
||||
self.assertIn("row_limit", context.exception.messages["queries"][0])
|
||||
self.assertIn("row_offset", context.exception.messages["queries"][0])
|
||||
assert "row_limit" in context.exception.messages["queries"][0]
|
||||
assert "row_offset" in context.exception.messages["queries"][0]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_context_null_timegrain(self):
|
||||
|
|
|
|||
|
|
@ -114,15 +114,15 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
def test_login(self):
|
||||
resp = self.get_resp("/login/", data=dict(username="admin", password="general"))
|
||||
self.assertNotIn("User confirmation needed", resp)
|
||||
assert "User confirmation needed" not in resp
|
||||
|
||||
resp = self.get_resp("/logout/", follow_redirects=True)
|
||||
self.assertIn("User confirmation needed", resp)
|
||||
assert "User confirmation needed" in resp
|
||||
|
||||
resp = self.get_resp(
|
||||
"/login/", data=dict(username="admin", password="wrongPassword")
|
||||
)
|
||||
self.assertIn("User confirmation needed", resp)
|
||||
assert "User confirmation needed" in resp
|
||||
|
||||
def test_dashboard_endpoint(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -146,20 +146,17 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
qobj["groupby"] = []
|
||||
cache_key_with_groupby = viz.cache_key(qobj)
|
||||
self.assertNotEqual(cache_key, cache_key_with_groupby)
|
||||
assert cache_key != cache_key_with_groupby
|
||||
|
||||
self.assertNotEqual(
|
||||
viz.cache_key(qobj), viz.cache_key(qobj, time_compare="12 weeks")
|
||||
)
|
||||
assert viz.cache_key(qobj) != viz.cache_key(qobj, time_compare="12 weeks")
|
||||
|
||||
self.assertNotEqual(
|
||||
viz.cache_key(qobj, time_compare="28 days"),
|
||||
viz.cache_key(qobj, time_compare="12 weeks"),
|
||||
assert viz.cache_key(qobj, time_compare="28 days") != viz.cache_key(
|
||||
qobj, time_compare="12 weeks"
|
||||
)
|
||||
|
||||
qobj["inner_from_dttm"] = datetime.datetime(1901, 1, 1)
|
||||
|
||||
self.assertEqual(cache_key_with_groupby, viz.cache_key(qobj))
|
||||
assert cache_key_with_groupby == viz.cache_key(qobj)
|
||||
|
||||
def test_admin_only_menu_views(self):
|
||||
def assert_admin_view_menus_in(role_name, assert_func):
|
||||
|
|
@ -205,9 +202,9 @@ class TestCore(SupersetTestCase):
|
|||
new_slice_id = resp.json["form_data"]["slice_id"]
|
||||
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
|
||||
|
||||
self.assertEqual(slc.slice_name, copy_name)
|
||||
assert slc.slice_name == copy_name
|
||||
form_data.pop("slice_id") # We don't save the slice id when saving as
|
||||
self.assertEqual(slc.viz.form_data, form_data)
|
||||
assert slc.viz.form_data == form_data
|
||||
|
||||
form_data = {
|
||||
"adhoc_filters": [],
|
||||
|
|
@ -224,8 +221,8 @@ class TestCore(SupersetTestCase):
|
|||
data={"form_data": json.dumps(form_data)},
|
||||
)
|
||||
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
|
||||
self.assertEqual(slc.slice_name, new_slice_name)
|
||||
self.assertEqual(slc.viz.form_data, form_data)
|
||||
assert slc.slice_name == new_slice_name
|
||||
assert slc.viz.form_data == form_data
|
||||
|
||||
# Cleanup
|
||||
slices = (
|
||||
|
|
@ -261,21 +258,21 @@ class TestCore(SupersetTestCase):
|
|||
logger.info(f"[{name}]/[{method}]: {url}")
|
||||
print(f"[{name}]/[{method}]: {url}")
|
||||
resp = self.client.get(url)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_add_slice(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
# assert that /chart/add responds with 200
|
||||
url = "/chart/add"
|
||||
resp = self.client.get(url)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_get_user_slices(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
userid = security_manager.find_user("admin").id
|
||||
url = f"/sliceasync/api/read?_flt_0_created_by={userid}"
|
||||
resp = self.client.get(url)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_slices_V2(self):
|
||||
|
|
@ -339,7 +336,7 @@ class TestCore(SupersetTestCase):
|
|||
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
|
||||
self.client.post(url, data=data)
|
||||
database = superset.utils.database.get_example_database()
|
||||
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
|
||||
assert sqlalchemy_uri_decrypted == database.sqlalchemy_uri_decrypted
|
||||
|
||||
# Need to clean up after ourselves
|
||||
database.impersonate_user = False
|
||||
|
|
@ -355,9 +352,9 @@ class TestCore(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
slc = self.get_slice("Top 10 Girl Name Share")
|
||||
data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
|
||||
self.assertEqual(
|
||||
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
|
||||
)
|
||||
assert data == [
|
||||
{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}
|
||||
]
|
||||
|
||||
data = self.get_json_resp(
|
||||
"/superset/warm_up_cache?table_name=energy_usage&db_name=main"
|
||||
|
|
@ -415,29 +412,29 @@ class TestCore(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
|
||||
resp = self.client.get("/kv/10001/")
|
||||
self.assertEqual(404, resp.status_code)
|
||||
assert 404 == resp.status_code
|
||||
|
||||
value = json.dumps({"data": "this is a test"})
|
||||
resp = self.client.post("/kv/store/", data=dict(data=value))
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@with_feature_flags(KV_STORE=True)
|
||||
def test_kv_enabled(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
resp = self.client.get("/kv/10001/")
|
||||
self.assertEqual(404, resp.status_code)
|
||||
assert 404 == resp.status_code
|
||||
|
||||
value = json.dumps({"data": "this is a test"})
|
||||
resp = self.client.post("/kv/store/", data=dict(data=value))
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
assert resp.status_code == 200
|
||||
kv = db.session.query(models.KeyValue).first()
|
||||
kv_value = kv.value
|
||||
self.assertEqual(json.loads(value), json.loads(kv_value))
|
||||
assert json.loads(value) == json.loads(kv_value)
|
||||
|
||||
resp = self.client.get(f"/kv/{kv.id}/")
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))
|
||||
assert resp.status_code == 200
|
||||
assert json.loads(value) == json.loads(resp.data.decode("utf-8"))
|
||||
|
||||
def test_gamma(self):
|
||||
self.login(GAMMA_USERNAME)
|
||||
|
|
@ -451,7 +448,7 @@ class TestCore(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
sql = "SELECT '{{ 1+1 }}' as test"
|
||||
data = self.run_sql(sql, "fdaklj3ws")
|
||||
self.assertEqual(data["data"][0]["test"], "2")
|
||||
assert data["data"][0]["test"] == "2"
|
||||
|
||||
def test_fetch_datasource_metadata(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -466,7 +463,7 @@ class TestCore(SupersetTestCase):
|
|||
"id",
|
||||
]
|
||||
for k in keys:
|
||||
self.assertIn(k, resp.keys())
|
||||
assert k in resp.keys()
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_slice_id_is_always_logged_correctly_on_web_request(self):
|
||||
|
|
@ -475,7 +472,7 @@ class TestCore(SupersetTestCase):
|
|||
slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
|
||||
qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
|
||||
self.get_resp(slc.slice_url)
|
||||
self.assertEqual(1, qry.count())
|
||||
assert 1 == qry.count()
|
||||
|
||||
def create_sample_csvfile(self, filename: str, content: list[str]) -> None:
|
||||
with open(filename, "w+") as test_file:
|
||||
|
|
@ -490,7 +487,7 @@ class TestCore(SupersetTestCase):
|
|||
database.allow_file_upload = True
|
||||
db.session.commit()
|
||||
add_datasource_page = self.get_resp("/databaseview/list/")
|
||||
self.assertIn("Upload a CSV", add_datasource_page)
|
||||
assert "Upload a CSV" in add_datasource_page
|
||||
|
||||
def test_dataframe_timezone(self):
|
||||
tz = pytz.FixedOffset(60)
|
||||
|
|
@ -502,15 +499,15 @@ class TestCore(SupersetTestCase):
|
|||
df = results.to_pandas_df()
|
||||
data = dataframe.df_to_records(df)
|
||||
json_str = json.dumps(data, default=json.pessimistic_json_iso_dttm_ser)
|
||||
self.assertDictEqual(
|
||||
self.assertDictEqual( # noqa: PT009
|
||||
data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)}
|
||||
)
|
||||
self.assertDictEqual(
|
||||
self.assertDictEqual( # noqa: PT009
|
||||
data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)}
|
||||
)
|
||||
self.assertEqual(
|
||||
json_str,
|
||||
'[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]',
|
||||
assert (
|
||||
json_str
|
||||
== '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]'
|
||||
)
|
||||
|
||||
def test_mssql_engine_spec_pymssql(self):
|
||||
|
|
@ -524,11 +521,12 @@ class TestCore(SupersetTestCase):
|
|||
)
|
||||
df = results.to_pandas_df()
|
||||
data = dataframe.df_to_records(df)
|
||||
self.assertEqual(len(data), 2)
|
||||
self.assertEqual(
|
||||
data[0],
|
||||
{"col1": 1, "col2": 1, "col3": pd.Timestamp("2017-10-19 23:39:16.660000")},
|
||||
)
|
||||
assert len(data) == 2
|
||||
assert data[0] == {
|
||||
"col1": 1,
|
||||
"col2": 1,
|
||||
"col3": pd.Timestamp("2017-10-19 23:39:16.660000"),
|
||||
}
|
||||
|
||||
def test_comments_in_sqlatable_query(self):
|
||||
clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl"
|
||||
|
|
@ -554,9 +552,9 @@ class TestCore(SupersetTestCase):
|
|||
)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(
|
||||
data["errors"][0]["message"],
|
||||
"The dataset associated with this chart no longer exists",
|
||||
assert (
|
||||
data["errors"][0]["message"]
|
||||
== "The dataset associated with this chart no longer exists"
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -579,8 +577,8 @@ class TestCore(SupersetTestCase):
|
|||
)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(data["rowcount"], 2)
|
||||
assert rv.status_code == 200
|
||||
assert data["rowcount"] == 2
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_explore_json_dist_bar_order(self):
|
||||
|
|
@ -741,7 +739,7 @@ class TestCore(SupersetTestCase):
|
|||
"/superset/explore_json/?results=true",
|
||||
data={"form_data": json.dumps(form_data)},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
|
|
@ -780,8 +778,8 @@ class TestCore(SupersetTestCase):
|
|||
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(data["rowcount"], 2)
|
||||
assert rv.status_code == 200
|
||||
assert data["rowcount"] == 2
|
||||
|
||||
@mock.patch(
|
||||
"superset.utils.cache_manager.CacheManager.cache",
|
||||
|
|
@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
|
|||
mock_cache.return_value = MockCache()
|
||||
|
||||
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_explore_json_data_invalid_cache_key(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -822,8 +820,8 @@ class TestCore(SupersetTestCase):
|
|||
rv = self.client.get(f"/superset/explore_json/data/{cache_key}")
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
self.assertEqual(data["error"], "Cached data not found")
|
||||
assert rv.status_code == 404
|
||||
assert data["error"] == "Cached data not found"
|
||||
|
||||
def test_results_default_deserialization(self):
|
||||
use_new_deserialization = False
|
||||
|
|
@ -863,14 +861,14 @@ class TestCore(SupersetTestCase):
|
|||
serialized_payload = sql_lab._serialize_payload(
|
||||
payload, use_new_deserialization
|
||||
)
|
||||
self.assertIsInstance(serialized_payload, str)
|
||||
assert isinstance(serialized_payload, str)
|
||||
|
||||
query_mock = mock.Mock()
|
||||
deserialized_payload = superset.views.utils._deserialize_results_payload(
|
||||
serialized_payload, query_mock, use_new_deserialization
|
||||
)
|
||||
|
||||
self.assertDictEqual(deserialized_payload, payload)
|
||||
self.assertDictEqual(deserialized_payload, payload) # noqa: PT009
|
||||
query_mock.assert_not_called()
|
||||
|
||||
def test_results_msgpack_deserialization(self):
|
||||
|
|
@ -911,7 +909,7 @@ class TestCore(SupersetTestCase):
|
|||
serialized_payload = sql_lab._serialize_payload(
|
||||
payload, use_new_deserialization
|
||||
)
|
||||
self.assertIsInstance(serialized_payload, bytes)
|
||||
assert isinstance(serialized_payload, bytes)
|
||||
|
||||
with mock.patch.object(
|
||||
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
|
||||
|
|
@ -925,7 +923,7 @@ class TestCore(SupersetTestCase):
|
|||
df = results.to_pandas_df()
|
||||
payload["data"] = dataframe.df_to_records(df)
|
||||
|
||||
self.assertDictEqual(deserialized_payload, payload)
|
||||
self.assertDictEqual(deserialized_payload, payload) # noqa: PT009
|
||||
expand_data.assert_called_once()
|
||||
|
||||
@mock.patch.dict(
|
||||
|
|
@ -960,7 +958,7 @@ class TestCore(SupersetTestCase):
|
|||
]
|
||||
for url in urls:
|
||||
data = self.get_resp(url)
|
||||
self.assertTrue(html_string in data)
|
||||
assert html_string in data
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -991,7 +989,7 @@ class TestCore(SupersetTestCase):
|
|||
tab_state_id = resp["id"]
|
||||
payload = self.get_json_resp(f"/tabstateview/{tab_state_id}")
|
||||
|
||||
self.assertEqual(payload["label"], "Untitled Query foo")
|
||||
assert payload["label"] == "Untitled Query foo"
|
||||
|
||||
def test_tabstate_update(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -1014,87 +1012,87 @@ class TestCore(SupersetTestCase):
|
|||
client_id = "asdfasdf"
|
||||
data = {"sql": json.dumps("select 1"), "latest_query_id": json.dumps(client_id)}
|
||||
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response.json["error"], "Bad request")
|
||||
assert response.status_code == 400
|
||||
assert response.json["error"] == "Bad request"
|
||||
# generate query
|
||||
db.session.add(Query(client_id=client_id, database_id=1))
|
||||
db.session.commit()
|
||||
# update tab state with a valid client_id
|
||||
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
# nulls should be ok too
|
||||
data["latest_query_id"] = "null"
|
||||
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_virtual_table_explore_visibility(self):
|
||||
# test that default visibility it set to True
|
||||
database = superset.utils.database.get_example_database()
|
||||
self.assertEqual(database.allows_virtual_table_explore, True)
|
||||
assert database.allows_virtual_table_explore is True
|
||||
|
||||
# test that visibility is disabled when extra is set to False
|
||||
extra = database.get_extra()
|
||||
extra["allows_virtual_table_explore"] = False
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.allows_virtual_table_explore, False)
|
||||
assert database.allows_virtual_table_explore is False
|
||||
|
||||
# test that visibility is enabled when extra is set to True
|
||||
extra = database.get_extra()
|
||||
extra["allows_virtual_table_explore"] = True
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.allows_virtual_table_explore, True)
|
||||
assert database.allows_virtual_table_explore is True
|
||||
|
||||
# test that visibility is not broken with bad values
|
||||
extra = database.get_extra()
|
||||
extra["allows_virtual_table_explore"] = "trash value"
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.allows_virtual_table_explore, True)
|
||||
assert database.allows_virtual_table_explore is True
|
||||
|
||||
def test_data_preview_visibility(self):
|
||||
# test that default visibility is allowed
|
||||
database = utils.get_example_database()
|
||||
self.assertEqual(database.disable_data_preview, False)
|
||||
assert database.disable_data_preview is False
|
||||
|
||||
# test that visibility is disabled when extra is set to true
|
||||
extra = database.get_extra()
|
||||
extra["disable_data_preview"] = True
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_data_preview, True)
|
||||
assert database.disable_data_preview is True
|
||||
|
||||
# test that visibility is enabled when extra is set to false
|
||||
extra = database.get_extra()
|
||||
extra["disable_data_preview"] = False
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_data_preview, False)
|
||||
assert database.disable_data_preview is False
|
||||
|
||||
# test that visibility is not broken with bad values
|
||||
extra = database.get_extra()
|
||||
extra["disable_data_preview"] = "trash value"
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_data_preview, False)
|
||||
assert database.disable_data_preview is False
|
||||
|
||||
def test_disable_drill_to_detail(self):
|
||||
# test that disable_drill_to_detail is False by default
|
||||
database = utils.get_example_database()
|
||||
self.assertEqual(database.disable_drill_to_detail, False)
|
||||
assert database.disable_drill_to_detail is False
|
||||
|
||||
# test that disable_drill_to_detail can be set to True
|
||||
extra = database.get_extra()
|
||||
extra["disable_drill_to_detail"] = True
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_drill_to_detail, True)
|
||||
assert database.disable_drill_to_detail is True
|
||||
|
||||
# test that disable_drill_to_detail can be set to False
|
||||
extra = database.get_extra()
|
||||
extra["disable_drill_to_detail"] = False
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_drill_to_detail, False)
|
||||
assert database.disable_drill_to_detail is False
|
||||
|
||||
# test that disable_drill_to_detail is not broken with bad values
|
||||
extra = database.get_extra()
|
||||
extra["disable_drill_to_detail"] = "trash value"
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.disable_drill_to_detail, False)
|
||||
assert database.disable_drill_to_detail is False
|
||||
|
||||
def test_explore_database_id(self):
|
||||
database = superset.utils.database.get_example_database()
|
||||
|
|
@ -1102,13 +1100,13 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
# test that explore_database_id is the regular database
|
||||
# id if none is set in the extra
|
||||
self.assertEqual(database.explore_database_id, database.id)
|
||||
assert database.explore_database_id == database.id
|
||||
|
||||
# test that explore_database_id is correct if the extra is set
|
||||
extra = database.get_extra()
|
||||
extra["explore_database_id"] = explore_database.id
|
||||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.explore_database_id, explore_database.id)
|
||||
assert database.explore_database_id == explore_database.id
|
||||
|
||||
def test_get_column_names_from_metric(self):
|
||||
simple_metric = {
|
||||
|
|
@ -1146,7 +1144,7 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
data = self.get_resp(url)
|
||||
self.assertIn("Error message", data)
|
||||
assert "Error message" in data
|
||||
|
||||
# Assert we can handle a driver exception at the mutator level
|
||||
exception = SQLAlchemyError("Error message")
|
||||
|
|
@ -1156,7 +1154,7 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
data = self.get_resp(url)
|
||||
self.assertIn("Error message", data)
|
||||
assert "Error message" in data
|
||||
|
||||
@pytest.mark.skip(
|
||||
"TODO This test was wrong - 'Error message' was in the language pack"
|
||||
|
|
@ -1176,7 +1174,7 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
data = self.get_resp(url)
|
||||
self.assertIn("Error message", data)
|
||||
assert "Error message" in data
|
||||
|
||||
# Assert we can handle a driver exception at the mutator level
|
||||
exception = SQLAlchemyError("Error message")
|
||||
|
|
@ -1186,7 +1184,7 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
data = self.get_resp(url)
|
||||
self.assertIn("Error message", data)
|
||||
assert "Error message" in data
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
@mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run")
|
||||
|
|
@ -1200,9 +1198,7 @@ class TestCore(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"/superset/explore/?form_data={quote(json.dumps(form_data))}"
|
||||
)
|
||||
self.assertEqual(
|
||||
rv.headers["Location"], f"/explore/?form_data_key={random_key}"
|
||||
)
|
||||
assert rv.headers["Location"] == f"/explore/?form_data_key={random_key}"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_has_table(self):
|
||||
|
|
@ -1223,7 +1219,7 @@ class TestCore(SupersetTestCase):
|
|||
|
||||
expected_url = "/superset/dashboard/1?permalink_key=123&standalone=3"
|
||||
|
||||
self.assertEqual(resp.headers["Location"], expected_url)
|
||||
assert resp.headers["Location"] == expected_url
|
||||
assert resp.status_code == 302
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class TestDashboard(SupersetTestCase):
|
|||
url = "/dashboard/new/"
|
||||
response = self.client.get(url, follow_redirects=False)
|
||||
dash_count_after = db.session.query(func.count(Dashboard.id)).first()[0]
|
||||
self.assertEqual(dash_count_before + 1, dash_count_after)
|
||||
assert dash_count_before + 1 == dash_count_after
|
||||
group = re.match(
|
||||
r"\/superset\/dashboard\/([0-9]*)\/\?edit=true",
|
||||
response.headers["Location"],
|
||||
|
|
@ -145,25 +145,25 @@ class TestDashboard(SupersetTestCase):
|
|||
self.logout()
|
||||
|
||||
resp = self.get_resp("/api/v1/chart/")
|
||||
self.assertNotIn("birth_names", resp)
|
||||
assert "birth_names" not in resp
|
||||
|
||||
resp = self.get_resp("/api/v1/dashboard/")
|
||||
self.assertNotIn("/superset/dashboard/births/", resp)
|
||||
assert "/superset/dashboard/births/" not in resp
|
||||
|
||||
self.grant_public_access_to_table(table)
|
||||
|
||||
# Try access after adding appropriate permissions.
|
||||
self.assertIn("birth_names", self.get_resp("/api/v1/chart/"))
|
||||
assert "birth_names" in self.get_resp("/api/v1/chart/")
|
||||
|
||||
resp = self.get_resp("/api/v1/dashboard/")
|
||||
self.assertIn("/superset/dashboard/births/", resp)
|
||||
assert "/superset/dashboard/births/" in resp
|
||||
|
||||
# Confirm that public doesn't have access to other datasets.
|
||||
resp = self.get_resp("/api/v1/chart/")
|
||||
self.assertNotIn("wb_health_population", resp)
|
||||
assert "wb_health_population" not in resp
|
||||
|
||||
resp = self.get_resp("/api/v1/dashboard/")
|
||||
self.assertNotIn("/superset/dashboard/world_health/", resp)
|
||||
assert "/superset/dashboard/world_health/" not in resp
|
||||
|
||||
# Cleanup
|
||||
self.revoke_public_access_to_table(table)
|
||||
|
|
@ -224,8 +224,8 @@ class TestDashboard(SupersetTestCase):
|
|||
db.session.delete(hidden_dash)
|
||||
db.session.commit()
|
||||
|
||||
self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp)
|
||||
self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp)
|
||||
assert f"/superset/dashboard/{my_dash_slug}/" in resp
|
||||
assert f"/superset/dashboard/{not_my_dash_slug}/" not in resp
|
||||
|
||||
def test_user_can_not_view_unpublished_dash(self):
|
||||
admin_user = security_manager.find_user("admin")
|
||||
|
|
@ -247,7 +247,7 @@ class TestDashboard(SupersetTestCase):
|
|||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
|
||||
self.assertNotIn(f"/superset/dashboard/{slug}/", resp)
|
||||
assert f"/superset/dashboard/{slug}/" not in resp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -63,19 +63,19 @@ class DashboardTestCase(SupersetTestCase):
|
|||
|
||||
def assert_permission_was_created(self, dashboard):
|
||||
view_menu = security_manager.find_view_menu(dashboard.view_name)
|
||||
self.assertIsNotNone(view_menu)
|
||||
self.assertEqual(len(security_manager.find_permissions_view_menu(view_menu)), 1)
|
||||
assert view_menu is not None
|
||||
assert len(security_manager.find_permissions_view_menu(view_menu)) == 1
|
||||
|
||||
def assert_permission_kept_and_changed(self, updated_dashboard, excepted_view_id):
|
||||
view_menu_after_title_changed = security_manager.find_view_menu(
|
||||
updated_dashboard.view_name
|
||||
)
|
||||
self.assertIsNotNone(view_menu_after_title_changed)
|
||||
self.assertEqual(view_menu_after_title_changed.id, excepted_view_id)
|
||||
assert view_menu_after_title_changed is not None
|
||||
assert view_menu_after_title_changed.id == excepted_view_id
|
||||
|
||||
def assert_permissions_were_deleted(self, deleted_dashboard):
|
||||
view_menu = security_manager.find_view_menu(deleted_dashboard.view_name)
|
||||
self.assertIsNone(view_menu)
|
||||
assert view_menu is None
|
||||
|
||||
def clean_created_objects(self):
|
||||
with app.test_request_context():
|
||||
|
|
|
|||
|
|
@ -87,12 +87,12 @@ class TestDashboardDAO(SupersetTestCase):
|
|||
"duplicate_slices": False,
|
||||
}
|
||||
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
|
||||
self.assertNotEqual(dash.id, original_dash.id)
|
||||
self.assertEqual(len(dash.position), len(original_dash.position))
|
||||
self.assertEqual(dash.dashboard_title, "copied dash")
|
||||
self.assertEqual(dash.css, "<css>")
|
||||
self.assertEqual(dash.owners, [security_manager.find_user("admin")])
|
||||
self.assertCountEqual(dash.slices, original_dash.slices)
|
||||
assert dash.id != original_dash.id
|
||||
assert len(dash.position) == len(original_dash.position)
|
||||
assert dash.dashboard_title == "copied dash"
|
||||
assert dash.css == "<css>"
|
||||
assert dash.owners == [security_manager.find_user("admin")]
|
||||
self.assertCountEqual(dash.slices, original_dash.slices) # noqa: PT009
|
||||
|
||||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
|
|
@ -118,9 +118,7 @@ class TestDashboardDAO(SupersetTestCase):
|
|||
"duplicate_slices": False,
|
||||
}
|
||||
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
|
||||
self.assertEqual(
|
||||
dash.params_dict["native_filter_configuration"], [{"mock": "filter"}]
|
||||
)
|
||||
assert dash.params_dict["native_filter_configuration"] == [{"mock": "filter"}]
|
||||
|
||||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
|
|
@ -141,15 +139,15 @@ class TestDashboardDAO(SupersetTestCase):
|
|||
"duplicate_slices": True,
|
||||
}
|
||||
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
|
||||
self.assertNotEqual(dash.id, original_dash.id)
|
||||
self.assertEqual(len(dash.position), len(original_dash.position))
|
||||
self.assertEqual(dash.dashboard_title, "copied dash")
|
||||
self.assertEqual(dash.css, "<css>")
|
||||
self.assertEqual(dash.owners, [security_manager.find_user("admin")])
|
||||
self.assertEqual(len(dash.slices), len(original_dash.slices))
|
||||
assert dash.id != original_dash.id
|
||||
assert len(dash.position) == len(original_dash.position)
|
||||
assert dash.dashboard_title == "copied dash"
|
||||
assert dash.css == "<css>"
|
||||
assert dash.owners == [security_manager.find_user("admin")]
|
||||
assert len(dash.slices) == len(original_dash.slices)
|
||||
for original_slc in original_dash.slices:
|
||||
for slc in dash.slices:
|
||||
self.assertNotEqual(slc.id, original_slc.id)
|
||||
assert slc.id != original_slc.id
|
||||
|
||||
for slc in dash.slices:
|
||||
db.session.delete(slc)
|
||||
|
|
|
|||
|
|
@ -109,8 +109,8 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
|
|||
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
|
||||
|
||||
# assert
|
||||
self.assertIn(my_owned_dashboard.url, get_dashboards_response)
|
||||
self.assertNotIn(not_my_owned_dashboard.url, get_dashboards_response)
|
||||
assert my_owned_dashboard.url in get_dashboards_response
|
||||
assert not_my_owned_dashboard.url not in get_dashboards_response
|
||||
|
||||
def test_get_dashboards__owners_can_view_empty_dashboard(self):
|
||||
# arrange
|
||||
|
|
@ -123,7 +123,7 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
|
|||
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
|
||||
|
||||
# assert
|
||||
self.assertNotIn(dashboard_url, get_dashboards_response)
|
||||
assert dashboard_url not in get_dashboards_response
|
||||
|
||||
def test_get_dashboards__user_can_not_view_unpublished_dash(self):
|
||||
# arrange
|
||||
|
|
@ -139,9 +139,7 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
|
|||
get_dashboards_response_as_gamma = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
|
||||
|
||||
# assert
|
||||
self.assertNotIn(
|
||||
admin_and_draft_dashboard.url, get_dashboards_response_as_gamma
|
||||
)
|
||||
assert admin_and_draft_dashboard.url not in get_dashboards_response_as_gamma
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard")
|
||||
def test_get_dashboards__users_can_view_permitted_dashboard(self):
|
||||
|
|
@ -172,8 +170,8 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
|
|||
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
|
||||
|
||||
# assert
|
||||
self.assertIn(second_dash.url, get_dashboards_response)
|
||||
self.assertIn(first_dash.url, get_dashboards_response)
|
||||
assert second_dash.url in get_dashboards_response
|
||||
assert first_dash.url in get_dashboards_response
|
||||
finally:
|
||||
self.revoke_public_access_to_table(accessed_table)
|
||||
|
||||
|
|
@ -193,5 +191,5 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
|
|||
rv = self.client.get(uri)
|
||||
self.assert200(rv)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(0, data["count"])
|
||||
assert 0 == data["count"]
|
||||
DashboardDAO.delete([dashboard])
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity):
|
|||
|
||||
request_payload = get_query_context("birth_names")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
|
||||
# post
|
||||
revoke_access_to_dashboard(dashboard_to_access, new_role) # noqa: F405
|
||||
|
|
@ -480,12 +480,12 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity):
|
|||
|
||||
self.login(GAMMA_USERNAME)
|
||||
rv = self.client.post(uri, json=data)
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
self.logout()
|
||||
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.post(uri, json=data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
target = (
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_columns = [
|
||||
"allow_ctas",
|
||||
|
|
@ -216,8 +216,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"uuid",
|
||||
]
|
||||
|
||||
self.assertGreater(response["count"], 0)
|
||||
self.assertEqual(list(response["result"][0].keys()), expected_columns)
|
||||
assert response["count"] > 0
|
||||
assert list(response["result"][0].keys()) == expected_columns
|
||||
|
||||
def test_get_items_filter(self):
|
||||
"""
|
||||
|
|
@ -241,8 +241,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response["count"], len(dbs))
|
||||
assert rv.status_code == 200
|
||||
assert response["count"] == len(dbs)
|
||||
|
||||
# Cleanup
|
||||
db.session.delete(test_database)
|
||||
|
|
@ -255,9 +255,9 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(GAMMA_USERNAME)
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["count"], 0)
|
||||
assert response["count"] == 0
|
||||
|
||||
def test_create_database(self):
|
||||
"""
|
||||
|
|
@ -284,7 +284,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
|
||||
|
|
@ -326,14 +326,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX")
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
|
||||
assert model_ssh_tunnel.database_id == response.get("id")
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -385,10 +385,10 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(
|
||||
response.get("message"),
|
||||
"A database port is required when connecting via SSH Tunnel.",
|
||||
assert rv.status_code == 400
|
||||
assert (
|
||||
response.get("message")
|
||||
== "A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
|
||||
@mock.patch(
|
||||
|
|
@ -434,19 +434,19 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
assert model_ssh_tunnel.database_id == response_update.get("id")
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -500,15 +500,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response_create = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
|
||||
uri = "api/v1/database/{}".format(response_create.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(
|
||||
response.get("message"),
|
||||
"A database port is required when connecting via SSH Tunnel.",
|
||||
assert rv.status_code == 400
|
||||
assert (
|
||||
response.get("message")
|
||||
== "A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
|
|
@ -563,19 +563,19 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
assert model_ssh_tunnel.database_id == response_update.get("id")
|
||||
|
||||
database_data_with_ssh_tunnel_null = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
|
|
@ -585,7 +585,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
|
|
@ -651,30 +651,28 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "foo")
|
||||
assert model_ssh_tunnel.database_id == response.get("id")
|
||||
assert model_ssh_tunnel.username == "foo"
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
self.assertEqual(
|
||||
response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX"
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.username, "Test")
|
||||
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
|
||||
self.assertEqual(model_ssh_tunnel.server_port, 8080)
|
||||
assert model_ssh_tunnel.database_id == response_update.get("id")
|
||||
assert response_update.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
|
||||
assert model_ssh_tunnel.username == "Test"
|
||||
assert model_ssh_tunnel.server_address == "123.132.123.1"
|
||||
assert model_ssh_tunnel.server_port == 8080
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -715,13 +713,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
assert model_ssh_tunnel.database_id == response.get("id")
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -769,7 +767,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
|
|
@ -777,7 +775,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
self.assertEqual(response, fail_message)
|
||||
assert response == fail_message
|
||||
|
||||
# Check that rollback was called
|
||||
mock_rollback.assert_called()
|
||||
|
|
@ -824,14 +822,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
|
||||
assert model_ssh_tunnel.database_id == response.get("id")
|
||||
assert response.get("result")["ssh_tunnel"] == response_ssh_tunnel
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
|
|
@ -866,8 +864,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, {"message": "SSH Tunneling is not enabled"})
|
||||
assert rv.status_code == 400
|
||||
assert response == {"message": "SSH Tunneling is not enabled"}
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
|
|
@ -897,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{database.id}/table/{table_name}/null/"
|
||||
rv = self.client.get(uri)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_create_database_invalid_configuration_method(self):
|
||||
"""
|
||||
|
|
@ -959,7 +957,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
assert rv.status_code == 201
|
||||
self.assertIn("sqlalchemy_form", response["result"]["configuration_method"])
|
||||
assert "sqlalchemy_form" in response["result"]["configuration_method"]
|
||||
|
||||
def test_create_database_server_cert_validate(self):
|
||||
"""
|
||||
|
|
@ -981,8 +979,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 400
|
||||
assert response == expected_response
|
||||
|
||||
def test_create_database_json_validate(self):
|
||||
"""
|
||||
|
|
@ -1016,8 +1014,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
],
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 400
|
||||
assert response == expected_response
|
||||
|
||||
def test_create_database_extra_metadata_validate(self):
|
||||
"""
|
||||
|
|
@ -1052,8 +1050,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 400
|
||||
assert response == expected_response
|
||||
|
||||
def test_create_database_unique_validate(self):
|
||||
"""
|
||||
|
|
@ -1078,8 +1076,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"database_name": "A database with the same name already exists."
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 422
|
||||
assert response == expected_response
|
||||
|
||||
def test_create_database_uri_validate(self):
|
||||
"""
|
||||
|
|
@ -1095,11 +1093,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertIn(
|
||||
"Invalid connection string",
|
||||
response["message"]["sqlalchemy_uri"][0],
|
||||
)
|
||||
assert rv.status_code == 400
|
||||
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
|
||||
|
||||
@mock.patch(
|
||||
"superset.views.core.app.config",
|
||||
|
|
@ -1127,8 +1122,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(response_data, expected_response)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
assert response_data == expected_response
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_create_database_conn_fail(self):
|
||||
"""
|
||||
|
|
@ -1192,11 +1187,11 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
expected_response_postgres = {
|
||||
"errors": [dataclasses.asdict(superset_error_postgres)]
|
||||
}
|
||||
self.assertEqual(response.status_code, 500)
|
||||
assert response.status_code == 500
|
||||
if example_db.backend == "mysql":
|
||||
self.assertEqual(response_data, expected_response_mysql)
|
||||
assert response_data == expected_response_mysql
|
||||
else:
|
||||
self.assertEqual(response_data, expected_response_postgres)
|
||||
assert response_data == expected_response_postgres
|
||||
|
||||
def test_update_database(self):
|
||||
"""
|
||||
|
|
@ -1213,7 +1208,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
uri = f"api/v1/database/{test_database.id}"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(test_database.id)
|
||||
db.session.delete(model)
|
||||
|
|
@ -1242,8 +1237,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
expected_response = {
|
||||
"message": "Connection failed, please check your connection settings"
|
||||
}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 422
|
||||
assert response == expected_response
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(test_database.id)
|
||||
db.session.delete(model)
|
||||
|
|
@ -1271,8 +1266,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"database_name": "A database with the same name already exists."
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
assert rv.status_code == 422
|
||||
assert response == expected_response
|
||||
# Cleanup
|
||||
db.session.delete(test_database1)
|
||||
db.session.delete(test_database2)
|
||||
|
|
@ -1286,7 +1281,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
database_data = {"database_name": "test-database-updated"}
|
||||
uri = "api/v1/database/invalid"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_update_database_uri_validate(self):
|
||||
"""
|
||||
|
|
@ -1305,11 +1300,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{test_database.id}"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertIn(
|
||||
"Invalid connection string",
|
||||
response["message"]["sqlalchemy_uri"][0],
|
||||
)
|
||||
assert rv.status_code == 400
|
||||
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
|
||||
|
||||
db.session.delete(test_database)
|
||||
db.session.commit()
|
||||
|
|
@ -1369,9 +1361,9 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{database_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
model = db.session.query(Database).get(database_id)
|
||||
self.assertEqual(model, None)
|
||||
assert model is None
|
||||
|
||||
def test_delete_database_not_found(self):
|
||||
"""
|
||||
|
|
@ -1381,7 +1373,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{max_id + 1}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("create_database_with_dataset")
|
||||
def test_delete_database_with_datasets(self):
|
||||
|
|
@ -1391,7 +1383,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{self._database.id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
|
||||
@pytest.mark.usefixtures("create_database_with_report")
|
||||
def test_delete_database_with_report(self):
|
||||
|
|
@ -1407,11 +1399,11 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{database.id}"
|
||||
rv = self.client.delete(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
expected_response = {
|
||||
"message": "There are associated alerts or reports: report_with_database"
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_get_table_metadata(self):
|
||||
|
|
@ -1422,12 +1414,12 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["name"], "birth_names")
|
||||
self.assertIsNone(response["comment"])
|
||||
self.assertTrue(len(response["columns"]) > 5)
|
||||
self.assertTrue(response.get("selectStar").startswith("SELECT"))
|
||||
assert response["name"] == "birth_names"
|
||||
assert response["comment"] is None
|
||||
assert len(response["columns"]) > 5
|
||||
assert response.get("selectStar").startswith("SELECT")
|
||||
|
||||
def test_info_security_database(self):
|
||||
"""
|
||||
|
|
@ -1456,11 +1448,11 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
uri = "api/v1/database/some_database/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_invalid_table_table_metadata(self):
|
||||
"""
|
||||
|
|
@ -1472,10 +1464,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
if example_db.backend == "sqlite":
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(
|
||||
data,
|
||||
{
|
||||
assert rv.status_code == 200
|
||||
assert data == {
|
||||
"columns": [],
|
||||
"comment": None,
|
||||
"foreignKeys": [],
|
||||
|
|
@ -1483,14 +1473,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"name": "wrong_table",
|
||||
"primaryKey": {"constrained_columns": None, "name": None},
|
||||
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
|
||||
},
|
||||
)
|
||||
}
|
||||
elif example_db.backend == "mysql":
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(data, {"message": "`wrong_table`"})
|
||||
assert rv.status_code == 422
|
||||
assert data == {"message": "`wrong_table`"}
|
||||
else:
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(data, {"message": "wrong_table"})
|
||||
assert rv.status_code == 422
|
||||
assert data == {"message": "wrong_table"}
|
||||
|
||||
def test_get_table_metadata_no_db_permission(self):
|
||||
"""
|
||||
|
|
@ -1500,7 +1489,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_get_table_extra_metadata_deprecated(self):
|
||||
|
|
@ -1511,9 +1500,9 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response, {})
|
||||
assert response == {}
|
||||
|
||||
def test_get_invalid_database_table_extra_metadata_deprecated(self):
|
||||
"""
|
||||
|
|
@ -1523,11 +1512,11 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
uri = "api/v1/database/some_database/table_extra/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_invalid_table_table_extra_metadata_deprecated(self):
|
||||
"""
|
||||
|
|
@ -1539,8 +1528,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(data, {})
|
||||
assert rv.status_code == 200
|
||||
assert data == {}
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_get_select_star(self):
|
||||
|
|
@ -1551,7 +1540,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_get_select_star_not_allowed(self):
|
||||
"""
|
||||
|
|
@ -1561,7 +1550,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_select_star_not_found_database(self):
|
||||
"""
|
||||
|
|
@ -1571,7 +1560,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
max_id = db.session.query(func.max(Database.id)).scalar()
|
||||
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_select_star_not_found_table(self):
|
||||
"""
|
||||
|
|
@ -1585,7 +1574,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
|
||||
rv = self.client.get(uri)
|
||||
# TODO(bkyryliuk): investigate why presto returns 500
|
||||
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
|
||||
assert rv.status_code == (404 if example_db.backend != "presto" else 500)
|
||||
|
||||
def test_get_allow_file_upload_filter(self):
|
||||
"""
|
||||
|
|
@ -1952,13 +1941,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
|
||||
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, set(response["result"]))
|
||||
assert schemas == set(response["result"])
|
||||
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
|
||||
)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, set(response["result"]))
|
||||
assert schemas == set(response["result"])
|
||||
|
||||
def test_database_schemas_not_found(self):
|
||||
"""
|
||||
|
|
@ -1968,7 +1957,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/schemas/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_database_schemas_invalid_query(self):
|
||||
"""
|
||||
|
|
@ -1979,7 +1968,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
def test_database_tables(self):
|
||||
"""
|
||||
|
|
@ -1993,17 +1982,17 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}"
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
if database.backend == "postgresql":
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
schemas = [
|
||||
s[0] for s in database.get_all_table_names_in_schema(None, schema_name)
|
||||
]
|
||||
self.assertEqual(response["count"], len(schemas))
|
||||
assert response["count"] == len(schemas)
|
||||
for option in response["result"]:
|
||||
self.assertEqual(option["extra"], None)
|
||||
self.assertEqual(option["type"], "table")
|
||||
self.assertTrue(option["value"] in schemas)
|
||||
assert option["extra"] is None
|
||||
assert option["type"] == "table"
|
||||
assert option["value"] in schemas
|
||||
|
||||
@patch("superset.utils.log.logger")
|
||||
def test_database_tables_not_found(self, logger_mock):
|
||||
|
|
@ -2014,7 +2003,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
logger_mock.warning.assert_called_once_with(
|
||||
"Database not found.", exc_info=True
|
||||
)
|
||||
|
|
@ -2028,7 +2017,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
@patch("superset.utils.log.logger")
|
||||
@mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database")
|
||||
|
|
@ -2046,7 +2035,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
logger_mock.warning.assert_called_once_with("Test Error", exc_info=True)
|
||||
|
||||
def test_test_connection(self):
|
||||
|
|
@ -2074,8 +2063,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
url = "api/v1/database/test_connection/"
|
||||
rv = self.post_assert_metric(url, data, "test_connection")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
|
||||
assert rv.status_code == 200
|
||||
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
|
||||
|
||||
# validate that the endpoint works with the decrypted sqlalchemy uri
|
||||
data = {
|
||||
|
|
@ -2086,8 +2075,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"server_cert": None,
|
||||
}
|
||||
rv = self.post_assert_metric(url, data, "test_connection")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
|
||||
assert rv.status_code == 200
|
||||
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
|
||||
|
||||
def test_test_connection_failed(self):
|
||||
"""
|
||||
|
|
@ -2103,8 +2092,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
url = "api/v1/database/test_connection/"
|
||||
rv = self.post_assert_metric(url, data, "test_connection")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
|
||||
assert rv.status_code == 422
|
||||
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"errors": [
|
||||
|
|
@ -2123,7 +2112,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
]
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
data = {
|
||||
"sqlalchemy_uri": "mssql+pymssql://url",
|
||||
|
|
@ -2132,8 +2121,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"server_cert": None,
|
||||
}
|
||||
rv = self.post_assert_metric(url, data, "test_connection")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
|
||||
assert rv.status_code == 422
|
||||
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"errors": [
|
||||
|
|
@ -2152,7 +2141,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
]
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
def test_test_connection_unsafe_uri(self):
|
||||
"""
|
||||
|
|
@ -2169,7 +2158,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
}
|
||||
url = "api/v1/database/test_connection/"
|
||||
rv = self.post_assert_metric(url, data, "test_connection")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {
|
||||
|
|
@ -2178,7 +2167,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
assert response == expected_response
|
||||
|
||||
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
|
||||
|
||||
|
|
@ -2250,10 +2239,10 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
database = get_example_database()
|
||||
uri = f"api/v1/database/{database.id}/related_objects/"
|
||||
rv = self.get_assert_metric(uri, "related_objects")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["charts"]["count"], 33)
|
||||
self.assertEqual(response["dashboards"]["count"], 3)
|
||||
assert response["charts"]["count"] == 33
|
||||
assert response["dashboards"]["count"] == 3
|
||||
|
||||
def test_get_database_related_objects_not_found(self):
|
||||
"""
|
||||
|
|
@ -2265,13 +2254,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{invalid_id}/related_objects/"
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.get_assert_metric(uri, "related_objects")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
self.logout()
|
||||
self.login(GAMMA_USERNAME)
|
||||
database = get_example_database()
|
||||
uri = f"api/v1/database/{database.id}/related_objects/"
|
||||
rv = self.get_assert_metric(uri, "related_objects")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_export_database(self):
|
||||
"""
|
||||
|
|
@ -2679,7 +2668,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
.filter(SSHTunnel.database_id == database.id)
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.password, "TEST")
|
||||
assert model_ssh_tunnel.password == "TEST"
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -2797,8 +2786,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
.filter(SSHTunnel.database_id == database.id)
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
|
||||
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
|
||||
assert model_ssh_tunnel.private_key == "TestPrivateKey"
|
||||
assert model_ssh_tunnel.private_key_password == "TEST"
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -3852,8 +3841,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{example_db.id}/validate_sql/"
|
||||
rv = self.client.post(uri, json=request_payload)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response["result"], [])
|
||||
assert rv.status_code == 200
|
||||
assert response["result"] == []
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.config.SQL_VALIDATORS_BY_ENGINE",
|
||||
|
|
@ -3878,18 +3867,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{example_db.id}/validate_sql/"
|
||||
rv = self.client.post(uri, json=request_payload)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(
|
||||
response["result"],
|
||||
[
|
||||
assert rv.status_code == 200
|
||||
assert response["result"] == [
|
||||
{
|
||||
"end_column": None,
|
||||
"line_number": 1,
|
||||
"message": 'ERROR: syntax error at or near "table1"',
|
||||
"start_column": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.config.SQL_VALIDATORS_BY_ENGINE",
|
||||
|
|
@ -3910,7 +3896,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/"
|
||||
)
|
||||
rv = self.client.post(uri, json=request_payload)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.config.SQL_VALIDATORS_BY_ENGINE",
|
||||
|
|
@ -3932,8 +3918,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
)
|
||||
rv = self.client.post(uri, json=request_payload)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}})
|
||||
assert rv.status_code == 400
|
||||
assert response == {"message": {"sql": ["Field may not be null."]}}
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.config.SQL_VALIDATORS_BY_ENGINE",
|
||||
|
|
@ -3956,10 +3942,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{example_db.id}/validate_sql/"
|
||||
rv = self.client.post(uri, json=request_payload)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert rv.status_code == 422
|
||||
assert response == {
|
||||
"errors": [
|
||||
{
|
||||
"message": f"no SQL validator is configured for "
|
||||
|
|
@ -3977,8 +3961,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@mock.patch("superset.commands.database.validate_sql.get_validator_by_name")
|
||||
@mock.patch.dict(
|
||||
|
|
@ -4013,8 +3996,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
# TODO(bkyryliuk): properly handle hive error
|
||||
if get_example_database().backend == "hive":
|
||||
return
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertIn("Kaboom!", response["errors"][0]["message"])
|
||||
assert rv.status_code == 422
|
||||
assert "Kaboom!" in response["errors"][0]["message"]
|
||||
|
||||
def test_get_databases_with_extra_filters(self):
|
||||
"""
|
||||
|
|
@ -4048,14 +4031,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri, json={**database_data, "database_name": "dyntest-create-database-1"}
|
||||
)
|
||||
first_response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(
|
||||
uri, json={**database_data, "database_name": "create-database-2"}
|
||||
)
|
||||
second_response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
|
||||
# The filter function
|
||||
def _base_filter(query):
|
||||
|
|
@ -4074,11 +4057,11 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# All databases must be returned if no filter is present
|
||||
self.assertEqual(data["count"], len(dbs))
|
||||
assert data["count"] == len(dbs)
|
||||
database_names = [item["database_name"] for item in data["result"]]
|
||||
database_names.sort()
|
||||
# All Databases because we are an admin
|
||||
self.assertEqual(database_names, expected_names)
|
||||
assert database_names == expected_names
|
||||
assert rv.status_code == 200
|
||||
# Our filter function wasn't get called
|
||||
base_filter_mock.assert_not_called()
|
||||
|
|
@ -4092,10 +4075,10 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
rv = self.client.get(uri)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Only one database start with dyntest
|
||||
self.assertEqual(data["count"], 1)
|
||||
assert data["count"] == 1
|
||||
database_names = [item["database_name"] for item in data["result"]]
|
||||
# Only the database that starts with tests, even if we are an admin
|
||||
self.assertEqual(database_names, ["dyntest-create-database-1"])
|
||||
assert database_names == ["dyntest-create-database-1"]
|
||||
assert rv.status_code == 200
|
||||
# The filter function is called now that it's defined in our config
|
||||
base_filter_mock.assert_called()
|
||||
|
|
|
|||
|
|
@ -740,7 +740,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
|||
.filter(SSHTunnel.database_id == database.id)
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.password, "TEST")
|
||||
assert model_ssh_tunnel.password == "TEST"
|
||||
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
|
|
@ -787,8 +787,8 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
|||
.filter(SSHTunnel.database_id == database.id)
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
|
||||
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
|
||||
assert model_ssh_tunnel.private_key == "TestPrivateKey"
|
||||
assert model_ssh_tunnel.private_key_password == "TEST"
|
||||
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
def count_datasets():
|
||||
uri = "api/v1/chart/"
|
||||
rv = self.client.get(uri, "get_list")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = rv.get_json()
|
||||
return data["count"]
|
||||
|
||||
|
|
@ -1342,14 +1342,14 @@ class TestDatasetApi(SupersetTestCase):
|
|||
table_data = {"description": "changed_description"}
|
||||
uri = f"api/v1/dataset/{dataset.id}"
|
||||
rv = self.client.put(uri, json=table_data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
response = self.get_assert_metric("api/v1/dataset/", "get_list")
|
||||
res = json.loads(response.data.decode("utf-8"))["result"]
|
||||
|
||||
current_dataset = [d for d in res if d["id"] == dataset.id][0]
|
||||
self.assertEqual(current_dataset["description"], "changed_description")
|
||||
self.assertNotIn("username", current_dataset["changed_by"].keys())
|
||||
assert current_dataset["description"] == "changed_description"
|
||||
assert "username" not in current_dataset["changed_by"].keys()
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
|
@ -1364,13 +1364,13 @@ class TestDatasetApi(SupersetTestCase):
|
|||
table_data = {"description": "changed_description"}
|
||||
uri = f"api/v1/dataset/{dataset.id}"
|
||||
rv = self.client.put(uri, json=table_data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
response = self.get_assert_metric(uri, "get")
|
||||
res = json.loads(response.data.decode("utf-8"))["result"]
|
||||
|
||||
self.assertEqual(res["description"], "changed_description")
|
||||
self.assertNotIn("username", res["changed_by"].keys())
|
||||
assert res["description"] == "changed_description"
|
||||
assert "username" not in res["changed_by"].keys()
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
|
@ -2311,14 +2311,14 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"database_id": get_example_database().id,
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
dataset = (
|
||||
db.session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name == "virtual_dataset")
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(response["result"], {"table_id": dataset.id})
|
||||
assert response["result"] == {"table_id": dataset.id}
|
||||
|
||||
def test_get_or_create_dataset_database_not_found(self):
|
||||
"""
|
||||
|
|
@ -2329,9 +2329,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"api/v1/dataset/get_or_create/",
|
||||
json={"table_name": "virtual_dataset", "database_id": 999},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["message"], {"database": ["Database does not exist"]})
|
||||
assert response["message"] == {"database": ["Database does not exist"]}
|
||||
|
||||
@patch("superset.commands.dataset.create.CreateDatasetCommand.run")
|
||||
def test_get_or_create_dataset_create_fails(self, command_run_mock):
|
||||
|
|
@ -2347,9 +2347,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"database_id": get_example_database().id,
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["message"], "Dataset could not be created.")
|
||||
assert response["message"] == "Dataset could not be created."
|
||||
|
||||
def test_get_or_create_dataset_creates_table(self):
|
||||
"""
|
||||
|
|
@ -2370,7 +2370,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"template_params": '{"param": 1}',
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
table = (
|
||||
db.session.query(SqlaTable)
|
||||
|
|
@ -2410,12 +2410,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"db_name": get_example_database().database_name,
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
len(data["result"]),
|
||||
len(energy_charts),
|
||||
)
|
||||
assert len(data["result"]) == len(energy_charts)
|
||||
for chart_result in data["result"]:
|
||||
assert "chart_id" in chart_result
|
||||
assert "viz_error" in chart_result
|
||||
|
|
@ -2439,12 +2436,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"dashboard_id": dashboard.id,
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
len(data["result"]),
|
||||
len(birth_charts),
|
||||
)
|
||||
assert len(data["result"]) == len(birth_charts)
|
||||
for chart_result in data["result"]:
|
||||
assert "chart_id" in chart_result
|
||||
assert "viz_error" in chart_result
|
||||
|
|
@ -2462,12 +2456,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
),
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
len(data["result"]),
|
||||
len(birth_charts),
|
||||
)
|
||||
assert len(data["result"]) == len(birth_charts)
|
||||
for chart_result in data["result"]:
|
||||
assert "chart_id" in chart_result
|
||||
assert "viz_error" in chart_result
|
||||
|
|
@ -2476,17 +2467,14 @@ class TestDatasetApi(SupersetTestCase):
|
|||
def test_warm_up_cache_db_and_table_name_required(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1})
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data,
|
||||
{
|
||||
assert data == {
|
||||
"message": {
|
||||
"db_name": ["Missing data for required field."],
|
||||
"table_name": ["Missing data for required field."],
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
def test_warm_up_cache_table_not_found(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -2494,9 +2482,8 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"/api/v1/dataset/warm_up_cache",
|
||||
json={"table_name": "not_here", "db_name": "abc"},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data,
|
||||
{"message": "The provided table was not found in the provided database"},
|
||||
)
|
||||
assert data == {
|
||||
"message": "The provided table was not found in the provided database"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -587,8 +587,8 @@ class TestCreateDatasetCommand(SupersetTestCase):
|
|||
.filter_by(table_name="test_create_dataset_command")
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(table, fetched_table)
|
||||
self.assertEqual([owner.username for owner in table.owners], ["admin"])
|
||||
assert table == fetched_table
|
||||
assert [owner.username for owner in table.owners] == ["admin"]
|
||||
|
||||
db.session.delete(table)
|
||||
with examples_db.get_sqla_engine() as engine:
|
||||
|
|
@ -626,7 +626,7 @@ class TestDatasetWarmUpCacheCommand(SupersetTestCase):
|
|||
results = DatasetWarmUpCacheCommand(
|
||||
get_example_database().database_name, "birth_names", None, None
|
||||
).run()
|
||||
self.assertEqual(len(results), len(birth_charts))
|
||||
assert len(results) == len(birth_charts)
|
||||
for chart_result in results:
|
||||
assert "chart_id" in chart_result
|
||||
assert "viz_error" in chart_result
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
for val in range(10):
|
||||
assert val in response["result"]
|
||||
|
|
@ -51,7 +51,7 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
for val in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]:
|
||||
assert val in response["result"]
|
||||
|
|
@ -61,7 +61,7 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col3/values/")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
for val in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]:
|
||||
assert val in response["result"]
|
||||
|
|
@ -71,16 +71,16 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col4/values/")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["result"], [None])
|
||||
assert response["result"] == [None]
|
||||
|
||||
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||
def test_get_column_values_integers_with_nulls(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col6/values/")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
for val in [1, None, 3, 4, 5, 6, 7, 8, 9, 10]:
|
||||
assert val in response["result"]
|
||||
|
|
@ -92,27 +92,27 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/datasource/not_table/{table.id}/column/col1/values/"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["message"], "Invalid datasource type: not_table")
|
||||
assert response["message"] == "Invalid datasource type: not_table"
|
||||
|
||||
@patch("superset.datasource.api.DatasourceDAO.get_datasource")
|
||||
def test_get_column_values_datasource_type_not_supported(self, get_datasource_mock):
|
||||
get_datasource_mock.side_effect = DatasourceTypeNotSupportedError
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.get("api/v1/datasource/table/1/column/col1/values/")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response["message"], "DAO datasource query source type is not supported"
|
||||
assert (
|
||||
response["message"] == "DAO datasource query source type is not supported"
|
||||
)
|
||||
|
||||
def test_get_column_values_datasource_not_found(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.get("api/v1/datasource/table/999/column/col1/values/")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["message"], "Datasource does not exist")
|
||||
assert response["message"] == "Datasource does not exist"
|
||||
|
||||
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||
def test_get_column_values_no_datasource_access(self):
|
||||
|
|
@ -126,12 +126,11 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
self.login(GAMMA_USERNAME)
|
||||
table = self.get_virtual_dataset()
|
||||
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response["message"],
|
||||
f"This endpoint requires the datasource {table.id}, "
|
||||
"database or `all_datasource_access` permission",
|
||||
assert (
|
||||
response["message"] == f"This endpoint requires the datasource {table.id}, "
|
||||
"database or `all_datasource_access` permission"
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||
|
|
@ -188,9 +187,9 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/datasource/table/{table.id}/column/col2/values/"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["result"], ["b"])
|
||||
assert response["result"] == ["b"]
|
||||
|
||||
@pytest.mark.usefixtures("app_context", "virtual_dataset")
|
||||
def test_get_column_values_with_rls_no_values(self):
|
||||
|
|
@ -202,6 +201,6 @@ class TestDatasourceApi(SupersetTestCase):
|
|||
rv = self.client.get(
|
||||
f"api/v1/datasource/table/{table.id}/column/col2/values/"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["result"], [])
|
||||
assert response["result"] == []
|
||||
|
|
|
|||
|
|
@ -101,9 +101,15 @@ class TestDatasource(SupersetTestCase):
|
|||
url = f"/datasource/external_metadata/table/{tbl.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(
|
||||
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
|
||||
)
|
||||
assert col_names == {
|
||||
"num_boys",
|
||||
"num",
|
||||
"gender",
|
||||
"name",
|
||||
"ds",
|
||||
"state",
|
||||
"num_girls",
|
||||
}
|
||||
|
||||
def test_always_filter_main_dttm(self):
|
||||
database = get_example_database()
|
||||
|
|
@ -175,9 +181,15 @@ class TestDatasource(SupersetTestCase):
|
|||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(
|
||||
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
|
||||
)
|
||||
assert col_names == {
|
||||
"num_boys",
|
||||
"num",
|
||||
"gender",
|
||||
"name",
|
||||
"ds",
|
||||
"state",
|
||||
"num_girls",
|
||||
}
|
||||
|
||||
def test_external_metadata_by_name_for_virtual_table(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -235,7 +247,7 @@ class TestDatasource(SupersetTestCase):
|
|||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
col_names = {o.get("column_name") for o in resp}
|
||||
self.assertEqual(col_names, {"first", "second"})
|
||||
assert col_names == {"first", "second"}
|
||||
|
||||
# No databases found
|
||||
params = prison.dumps(
|
||||
|
|
@ -249,10 +261,10 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.client.get(url)
|
||||
self.assertEqual(resp.status_code, DatasetNotFoundError.status)
|
||||
self.assertEqual(
|
||||
json.loads(resp.data.decode("utf-8")).get("error"),
|
||||
DatasetNotFoundError.message,
|
||||
assert resp.status_code == DatasetNotFoundError.status
|
||||
assert (
|
||||
json.loads(resp.data.decode("utf-8")).get("error")
|
||||
== DatasetNotFoundError.message
|
||||
)
|
||||
|
||||
# No table found
|
||||
|
|
@ -267,10 +279,10 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.client.get(url)
|
||||
self.assertEqual(resp.status_code, DatasetNotFoundError.status)
|
||||
self.assertEqual(
|
||||
json.loads(resp.data.decode("utf-8")).get("error"),
|
||||
DatasetNotFoundError.message,
|
||||
assert resp.status_code == DatasetNotFoundError.status
|
||||
assert (
|
||||
json.loads(resp.data.decode("utf-8")).get("error")
|
||||
== DatasetNotFoundError.message
|
||||
)
|
||||
|
||||
# invalid query params
|
||||
|
|
@ -281,7 +293,7 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
url = f"/datasource/external_metadata_by_name/?q={params}"
|
||||
resp = self.get_json_resp(url)
|
||||
self.assertIn("error", resp)
|
||||
assert "error" in resp
|
||||
|
||||
def test_external_metadata_for_virtual_table_template_params(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -308,7 +320,7 @@ class TestDatasource(SupersetTestCase):
|
|||
with db_insert_temp_object(table):
|
||||
url = f"/datasource/external_metadata/table/{table.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
self.assertEqual(resp["error"], "Only `SELECT` statements are allowed")
|
||||
assert resp["error"] == "Only `SELECT` statements are allowed"
|
||||
|
||||
def test_external_metadata_for_multistatement_virtual_table(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -322,7 +334,7 @@ class TestDatasource(SupersetTestCase):
|
|||
with db_insert_temp_object(table):
|
||||
url = f"/datasource/external_metadata/table/{table.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
self.assertEqual(resp["error"], "Only single queries supported")
|
||||
assert resp["error"] == "Only single queries supported"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata")
|
||||
|
|
@ -350,7 +362,7 @@ class TestDatasource(SupersetTestCase):
|
|||
obj2 = l2_lookup.get(obj1.get(key))
|
||||
for k in obj1:
|
||||
if k not in "id" and obj1.get(k):
|
||||
self.assertEqual(obj1.get(k), obj2.get(k))
|
||||
assert obj1.get(k) == obj2.get(k)
|
||||
|
||||
def test_save(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -367,11 +379,11 @@ class TestDatasource(SupersetTestCase):
|
|||
elif k == "metrics":
|
||||
self.compare_lists(datasource_post[k], resp[k], "metric_name")
|
||||
elif k == "database":
|
||||
self.assertEqual(resp[k]["id"], datasource_post[k]["id"])
|
||||
assert resp[k]["id"] == datasource_post[k]["id"]
|
||||
elif k == "owners":
|
||||
self.assertEqual([o["id"] for o in resp[k]], datasource_post["owners"])
|
||||
assert [o["id"] for o in resp[k]] == datasource_post["owners"]
|
||||
else:
|
||||
self.assertEqual(resp[k], datasource_post[k])
|
||||
assert resp[k] == datasource_post[k]
|
||||
|
||||
def test_save_default_endpoint_validation_success(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -404,11 +416,11 @@ class TestDatasource(SupersetTestCase):
|
|||
new_db = self.create_fake_db()
|
||||
datasource_post["database"]["id"] = new_db.id
|
||||
resp = self.save_datasource_from_dict(datasource_post)
|
||||
self.assertEqual(resp["database"]["id"], new_db.id)
|
||||
assert resp["database"]["id"] == new_db.id
|
||||
|
||||
datasource_post["database"]["id"] = db_id
|
||||
resp = self.save_datasource_from_dict(datasource_post)
|
||||
self.assertEqual(resp["database"]["id"], db_id)
|
||||
assert resp["database"]["id"] == db_id
|
||||
|
||||
self.delete_fake_db()
|
||||
|
||||
|
|
@ -440,7 +452,7 @@ class TestDatasource(SupersetTestCase):
|
|||
)
|
||||
data = dict(data=json.dumps(datasource_post))
|
||||
resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
|
||||
self.assertIn("Duplicate column name(s): <new column>", resp["error"])
|
||||
assert "Duplicate column name(s): <new column>" in resp["error"]
|
||||
|
||||
def test_get_datasource(self):
|
||||
admin_user = self.get_user("admin")
|
||||
|
|
@ -454,11 +466,9 @@ class TestDatasource(SupersetTestCase):
|
|||
self.get_json_resp("/datasource/save/", data)
|
||||
url = f"/datasource/get/{tbl.type}/{tbl.id}/"
|
||||
resp = self.get_json_resp(url)
|
||||
self.assertEqual(resp.get("type"), "table")
|
||||
assert resp.get("type") == "table"
|
||||
col_names = {o.get("column_name") for o in resp["columns"]}
|
||||
self.assertEqual(
|
||||
col_names,
|
||||
{
|
||||
assert col_names == {
|
||||
"num_boys",
|
||||
"num",
|
||||
"gender",
|
||||
|
|
@ -467,8 +477,7 @@ class TestDatasource(SupersetTestCase):
|
|||
"state",
|
||||
"num_girls",
|
||||
"num_california",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
def test_get_datasource_with_health_check(self):
|
||||
def my_check(datasource):
|
||||
|
|
@ -491,7 +500,7 @@ class TestDatasource(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "Datasource does not exist")
|
||||
assert resp.get("error") == "Datasource does not exist"
|
||||
|
||||
def test_get_datasource_invalid_datasource_failed(self):
|
||||
from superset.daos.datasource import DatasourceDAO
|
||||
|
|
@ -503,7 +512,7 @@ class TestDatasource(SupersetTestCase):
|
|||
|
||||
self.login(ADMIN_USERNAME)
|
||||
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
|
||||
assert resp.get("error") == "'druid' is not a valid DatasourceType"
|
||||
|
||||
|
||||
def test_get_samples(test_client, login_as_admin, virtual_dataset):
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ class TestAscendDbEngineSpec(TestDbEngineSpec):
|
|||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
AscendEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
|
||||
assert (
|
||||
AscendEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
|
||||
assert (
|
||||
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm)
|
||||
== "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -61,18 +61,18 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
q10 = "select * from mytable limit 20, x"
|
||||
q11 = "select * from mytable limit x offset 20"
|
||||
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
|
||||
assert engine_spec_class.get_limit_from_sql(q0) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q1) == 10
|
||||
assert engine_spec_class.get_limit_from_sql(q2) == 20
|
||||
assert engine_spec_class.get_limit_from_sql(q3) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q4) == 20
|
||||
assert engine_spec_class.get_limit_from_sql(q5) == 10
|
||||
assert engine_spec_class.get_limit_from_sql(q6) == 10
|
||||
assert engine_spec_class.get_limit_from_sql(q7) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q8) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q9) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q10) is None
|
||||
assert engine_spec_class.get_limit_from_sql(q11) is None
|
||||
|
||||
def test_wrapped_semi_tabs(self):
|
||||
self.sql_limit_regex(
|
||||
|
|
@ -141,7 +141,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
)
|
||||
|
||||
def test_get_datatype(self):
|
||||
self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
|
||||
assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR")
|
||||
|
||||
def test_limit_with_implicit_offset(self):
|
||||
self.sql_limit_regex(
|
||||
|
|
@ -198,17 +198,15 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
for engine in load_engine_specs():
|
||||
if engine is not BaseEngineSpec:
|
||||
# make sure time grain functions have been defined
|
||||
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
|
||||
assert len(engine.get_time_grain_expressions()) > 0
|
||||
# make sure all defined time grains are supported
|
||||
defined_grains = {grain.duration for grain in engine.get_time_grains()}
|
||||
intersection = time_grains.intersection(defined_grains)
|
||||
self.assertSetEqual(defined_grains, intersection, engine)
|
||||
self.assertSetEqual(defined_grains, intersection, engine) # noqa: PT009
|
||||
|
||||
def test_get_time_grain_expressions(self):
|
||||
time_grains = MySQLEngineSpec.get_time_grain_expressions()
|
||||
self.assertEqual(
|
||||
list(time_grains.keys()),
|
||||
[
|
||||
assert list(time_grains.keys()) == [
|
||||
None,
|
||||
"PT1S",
|
||||
"PT1M",
|
||||
|
|
@ -219,8 +217,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
"P3M",
|
||||
"P1Y",
|
||||
"1969-12-29T00:00:00Z/P1W",
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_get_table_names(self):
|
||||
inspector = mock.Mock()
|
||||
|
|
@ -255,11 +252,11 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
expected = ["STRING", "STRING", "FLOAT"]
|
||||
else:
|
||||
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
|
||||
self.assertEqual(col_names, expected)
|
||||
assert col_names == expected
|
||||
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm, db_extra=None))
|
||||
assert BaseEngineSpec.convert_dttm("", dttm, db_extra=None) is None
|
||||
|
||||
def test_pyodbc_rows_to_tuples(self):
|
||||
# Test for case when pyodbc.Row is returned (odbc driver)
|
||||
|
|
@ -272,7 +269,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
|
||||
]
|
||||
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
|
||||
self.assertListEqual(result, expected)
|
||||
self.assertListEqual(result, expected) # noqa: PT009
|
||||
|
||||
def test_pyodbc_rows_to_tuples_passthrough(self):
|
||||
# Test for case when tuples are returned
|
||||
|
|
@ -281,7 +278,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
|
||||
]
|
||||
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
|
||||
self.assertListEqual(result, data)
|
||||
self.assertListEqual(result, data) # noqa: PT009
|
||||
|
||||
@mock.patch("superset.models.core.Database.db_engine_spec", BaseEngineSpec)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
|
|||
|
|
@ -33,4 +33,4 @@ class TestDbEngineSpec(SupersetTestCase):
|
|||
):
|
||||
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
|
||||
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
|
||||
self.assertEqual(expected_sql, limited)
|
||||
assert expected_sql == limited
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
}
|
||||
for original, expected in test_cases.items():
|
||||
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
|
||||
self.assertEqual(actual, expected)
|
||||
assert actual == expected
|
||||
|
||||
def test_timegrain_expressions(self):
|
||||
"""
|
||||
|
|
@ -63,7 +63,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
actual = BigQueryEngineSpec.get_timestamp_expr(
|
||||
col=col, pdf=None, time_grain="PT1H"
|
||||
)
|
||||
self.assertEqual(str(actual), expected)
|
||||
assert str(actual) == expected
|
||||
|
||||
def test_custom_minute_timegrain_expressions(self):
|
||||
"""
|
||||
|
|
@ -104,12 +104,12 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
data1 = [(1, "foo")]
|
||||
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1):
|
||||
result = BigQueryEngineSpec.fetch_data(None, 0)
|
||||
self.assertEqual(result, data1)
|
||||
assert result == data1
|
||||
|
||||
data2 = [Row(1), Row(2)]
|
||||
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2):
|
||||
result = BigQueryEngineSpec.fetch_data(None, 0)
|
||||
self.assertEqual(result, [1, 2])
|
||||
assert result == [1, 2]
|
||||
|
||||
def test_get_extra_table_metadata(self):
|
||||
"""
|
||||
|
|
@ -122,7 +122,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
database,
|
||||
Table("some_table", "some_schema"),
|
||||
)
|
||||
self.assertEqual(result, {})
|
||||
assert result == {}
|
||||
|
||||
index_metadata = [
|
||||
{
|
||||
|
|
@ -143,7 +143,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
database,
|
||||
Table("some_table", "some_schema"),
|
||||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
assert result == expected_result
|
||||
|
||||
def test_get_indexes(self):
|
||||
database = mock.Mock()
|
||||
|
|
|
|||
|
|
@ -40,4 +40,4 @@ class TestElasticsearchDbEngineSpec(TestDbEngineSpec):
|
|||
actual = ElasticSearchEngineSpec.get_timestamp_expr(
|
||||
col=col, pdf=None, time_grain=time_grain
|
||||
)
|
||||
self.assertEqual(str(actual), expected_time_grain_expression)
|
||||
assert str(actual) == expected_time_grain_expression
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
def test_get_datatype_mysql(self):
|
||||
"""Tests related to datatype mapping for MySQL"""
|
||||
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
|
||||
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
|
||||
assert "TINY" == MySQLEngineSpec.get_datatype(1)
|
||||
assert "VARCHAR" == MySQLEngineSpec.get_datatype(15)
|
||||
|
||||
def test_column_datatype_to_string(self):
|
||||
test_cases = (
|
||||
|
|
@ -49,7 +49,7 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
|||
actual = MySQLEngineSpec.column_datatype_to_string(
|
||||
original, mysql.dialect()
|
||||
)
|
||||
self.assertEqual(actual, expected)
|
||||
assert actual == expected
|
||||
|
||||
def test_extract_error_message(self):
|
||||
from MySQLdb._exceptions import OperationalError
|
||||
|
|
|
|||
|
|
@ -32,20 +32,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
|
|||
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
|
||||
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_pinot_time_expression_simple_date_format_1d_grain(self):
|
||||
col = column("tstamp")
|
||||
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1D")
|
||||
result = str(expr.compile())
|
||||
expected = "CAST(DATE_TRUNC('day', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_pinot_time_expression_simple_date_format_10m_grain(self):
|
||||
col = column("tstamp")
|
||||
|
|
@ -55,20 +49,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
|
|||
"CAST(ROUND(DATE_TRUNC('minute', CAST(tstamp AS "
|
||||
+ "TIMESTAMP)), 600000) AS TIMESTAMP)"
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_pinot_time_expression_simple_date_format_1w_grain(self):
|
||||
col = column("tstamp")
|
||||
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1W")
|
||||
result = str(expr.compile())
|
||||
expected = "CAST(DATE_TRUNC('week', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_pinot_time_expression_sec_one_1m_grain(self):
|
||||
col = column("tstamp")
|
||||
|
|
@ -79,10 +67,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
|
|||
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
|
||||
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_pinot_time_expression_millisec_one_1m_grain(self):
|
||||
col = column("tstamp")
|
||||
|
|
@ -93,10 +78,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
|
|||
+ "DATETIMECONVERT(tstamp, '1:MILLISECONDS:EPOCH', "
|
||||
+ "'1:MILLISECONDS:EPOCH', '1:MILLISECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_get_time_expression_arguments(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
col = literal_column("COALESCE(a, b)")
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "COALESCE(a, b)")
|
||||
assert result == "COALESCE(a, b)"
|
||||
|
||||
def test_time_exp_literal_1y_grain(self):
|
||||
"""
|
||||
|
|
@ -66,7 +66,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
col = literal_column("COALESCE(a, b)")
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
|
||||
assert result == "DATE_TRUNC('year', COALESCE(a, b))"
|
||||
|
||||
def test_time_ex_lowr_col_no_grain(self):
|
||||
"""
|
||||
|
|
@ -75,7 +75,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
col = column("lower_case")
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "lower_case")
|
||||
assert result == "lower_case"
|
||||
|
||||
def test_time_exp_lowr_col_sec_1y(self):
|
||||
"""
|
||||
|
|
@ -84,10 +84,9 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
col = column("lower_case")
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(
|
||||
result,
|
||||
"DATE_TRUNC('year', "
|
||||
"(timestamp 'epoch' + lower_case * interval '1 second'))",
|
||||
assert (
|
||||
result == "DATE_TRUNC('year', "
|
||||
"(timestamp 'epoch' + lower_case * interval '1 second'))"
|
||||
)
|
||||
|
||||
def test_time_exp_mixed_case_col_1y(self):
|
||||
|
|
@ -97,7 +96,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
col = column("MixedCase")
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
|
||||
assert result == "DATE_TRUNC('year', \"MixedCase\")"
|
||||
|
||||
def test_empty_dbapi_cursor_description(self):
|
||||
"""
|
||||
|
|
@ -107,7 +106,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
# empty description mean no columns, this mocks the following SQL: "SELECT"
|
||||
cursor.description = []
|
||||
results = PostgresEngineSpec.fetch_data(cursor, 1000)
|
||||
self.assertEqual(results, [])
|
||||
assert results == []
|
||||
|
||||
def test_engine_alias_name(self):
|
||||
"""
|
||||
|
|
@ -158,13 +157,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
)
|
||||
sql = "SELECT * FROM birth_names"
|
||||
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
|
||||
self.assertEqual(
|
||||
results,
|
||||
{
|
||||
"Start-up cost": 0.00,
|
||||
"Total cost": 1537.91,
|
||||
},
|
||||
)
|
||||
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
|
||||
|
||||
def test_estimate_statement_invalid_syntax(self):
|
||||
"""
|
||||
|
|
@ -199,19 +192,10 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
},
|
||||
]
|
||||
result = PostgresEngineSpec.query_cost_formatter(raw_cost)
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
"Start-up cost": "0.0",
|
||||
"Total cost": "1537.91",
|
||||
},
|
||||
{
|
||||
"Start-up cost": "10.0",
|
||||
"Total cost": "1537.0",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert result == [
|
||||
{"Start-up cost": "0.0", "Total cost": "1537.91"},
|
||||
{"Start-up cost": "10.0", "Total cost": "1537.0"},
|
||||
]
|
||||
|
||||
def test_extract_errors(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
|||
class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
@skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed")
|
||||
def test_get_datatype_presto(self):
|
||||
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
|
||||
assert "STRING" == PrestoEngineSpec.get_datatype("string")
|
||||
|
||||
def test_get_view_names_with_schema(self):
|
||||
database = mock.MagicMock()
|
||||
|
|
@ -86,10 +86,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
row.Column, row.Type, row.Null = column
|
||||
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
|
||||
results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
|
||||
self.assertEqual(len(expected_results), len(results))
|
||||
assert len(expected_results) == len(results)
|
||||
for expected_result, result in zip(expected_results, results):
|
||||
self.assertEqual(expected_result[0], result["column_name"])
|
||||
self.assertEqual(expected_result[1], str(result["type"]))
|
||||
assert expected_result[0] == result["column_name"]
|
||||
assert expected_result[1] == str(result["type"])
|
||||
|
||||
def test_presto_get_column(self):
|
||||
presto_column = ("column_name", "boolean", "")
|
||||
|
|
@ -192,8 +192,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
},
|
||||
]
|
||||
for actual_result, expected_result in zip(actual_results, expected_results):
|
||||
self.assertEqual(actual_result.element.name, expected_result["column_name"])
|
||||
self.assertEqual(actual_result.name, expected_result["label"])
|
||||
assert actual_result.element.name == expected_result["column_name"]
|
||||
assert actual_result.name == expected_result["label"]
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -260,9 +260,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
"is_dttm": False,
|
||||
}
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
assert actual_cols == expected_cols
|
||||
assert actual_data == expected_data
|
||||
assert actual_expanded_cols == expected_expanded_cols
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -343,9 +343,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
assert actual_cols == expected_cols
|
||||
assert actual_data == expected_data
|
||||
assert actual_expanded_cols == expected_expanded_cols
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -427,9 +427,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
assert actual_cols == expected_cols
|
||||
assert actual_data == expected_data
|
||||
assert actual_expanded_cols == expected_expanded_cols
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -548,9 +548,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
"is_dttm": False,
|
||||
},
|
||||
]
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
assert actual_cols == expected_cols
|
||||
assert actual_data == expected_data
|
||||
assert actual_expanded_cols == expected_expanded_cols
|
||||
|
||||
def test_presto_get_extra_table_metadata(self):
|
||||
database = mock.Mock()
|
||||
|
|
@ -582,7 +582,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
columns,
|
||||
)
|
||||
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
|
||||
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
|
||||
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
|
||||
|
||||
def test_query_cost_formatter(self):
|
||||
raw_cost = [
|
||||
|
|
@ -645,7 +645,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
"Network cost": "354 G",
|
||||
}
|
||||
]
|
||||
self.assertEqual(formatted_cost, expected)
|
||||
assert formatted_cost == expected
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -752,9 +752,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||
},
|
||||
]
|
||||
|
||||
self.assertEqual(actual_cols, expected_cols)
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
assert actual_cols == expected_cols
|
||||
assert actual_data == expected_data
|
||||
assert actual_expanded_cols == expected_expanded_cols
|
||||
|
||||
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
|
||||
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
|
||||
|
|
|
|||
|
|
@ -91,36 +91,32 @@ class TestDictImportExport(SupersetTestCase):
|
|||
def yaml_compare(self, obj_1, obj_2):
|
||||
obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
|
||||
obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)
|
||||
self.assertEqual(obj_1_str, obj_2_str)
|
||||
assert obj_1_str == obj_2_str
|
||||
|
||||
def assert_table_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(expected_ds.schema, actual_ds.schema)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
assert expected_ds.table_name == actual_ds.table_name
|
||||
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
|
||||
assert expected_ds.schema == actual_ds.schema
|
||||
assert len(expected_ds.metrics) == len(actual_ds.metrics)
|
||||
assert len(expected_ds.columns) == len(actual_ds.columns)
|
||||
assert {c.column_name for c in expected_ds.columns} == {
|
||||
c.column_name for c in actual_ds.columns
|
||||
}
|
||||
assert {m.metric_name for m in expected_ds.metrics} == {
|
||||
m.metric_name for m in actual_ds.metrics
|
||||
}
|
||||
|
||||
def assert_datasource_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
assert expected_ds.datasource_name == actual_ds.datasource_name
|
||||
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
|
||||
assert len(expected_ds.metrics) == len(actual_ds.metrics)
|
||||
assert len(expected_ds.columns) == len(actual_ds.columns)
|
||||
assert {c.column_name for c in expected_ds.columns} == {
|
||||
c.column_name for c in actual_ds.columns
|
||||
}
|
||||
assert {m.metric_name for m in expected_ds.metrics} == {
|
||||
m.metric_name for m in actual_ds.metrics
|
||||
}
|
||||
|
||||
def test_import_table_no_metadata(self):
|
||||
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
|
||||
|
|
@ -143,8 +139,8 @@ class TestDictImportExport(SupersetTestCase):
|
|||
db.session.commit()
|
||||
imported = self.get_table_by_id(imported_table.id)
|
||||
self.assert_table_equals(table, imported)
|
||||
self.assertEqual(
|
||||
{DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
|
||||
assert {DBREF: ID_PREFIX + 2, "database_name": "main"} == json.loads(
|
||||
imported.params
|
||||
)
|
||||
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
|
||||
|
||||
|
|
@ -178,7 +174,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
imported_over = self.get_table_by_id(imported_over_table.id)
|
||||
self.assertEqual(imported_table.id, imported_over.id)
|
||||
assert imported_table.id == imported_over.id
|
||||
expected_table, _ = self.create_table(
|
||||
"table_override",
|
||||
id=ID_PREFIX + 3,
|
||||
|
|
@ -209,7 +205,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
imported_over = self.get_table_by_id(imported_over_table.id)
|
||||
self.assertEqual(imported_table.id, imported_over.id)
|
||||
assert imported_table.id == imported_over.id
|
||||
expected_table, _ = self.create_table(
|
||||
"table_override",
|
||||
id=ID_PREFIX + 3,
|
||||
|
|
@ -239,7 +235,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
)
|
||||
imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
|
||||
db.session.commit()
|
||||
self.assertEqual(imported_table.id, imported_copy_table.id)
|
||||
assert imported_table.id == imported_copy_table.id
|
||||
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
|
||||
self.yaml_compare(
|
||||
imported_copy_table.export_to_dict(), imported_table.export_to_dict()
|
||||
|
|
@ -259,12 +255,12 @@ class TestDictImportExport(SupersetTestCase):
|
|||
"/databaseview/action_post", {"action": "yaml_export", "rowid": 1}
|
||||
)
|
||||
ui_export = yaml.safe_load(resp)
|
||||
self.assertEqual(
|
||||
ui_export["databases"][0]["database_name"],
|
||||
cli_export["databases"][0]["database_name"],
|
||||
assert (
|
||||
ui_export["databases"][0]["database_name"]
|
||||
== cli_export["databases"][0]["database_name"]
|
||||
)
|
||||
self.assertEqual(
|
||||
ui_export["databases"][0]["tables"], cli_export["databases"][0]["tables"]
|
||||
assert (
|
||||
ui_export["databases"][0]["tables"] == cli_export["databases"][0]["tables"]
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class TestDynamicPlugins(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "/dynamic-plugins/api"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@with_feature_flags(DYNAMIC_PLUGINS=True)
|
||||
def test_dynamic_plugins_enabled(self):
|
||||
|
|
@ -38,4 +38,4 @@ class TestDynamicPlugins(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "/dynamic-plugins/api"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class TestEmailSmtp(SupersetTestCase):
|
|||
app.config["SMTP_HOST"], app.config["SMTP_PORT"], context=mock.ANY
|
||||
)
|
||||
called_context = mock_smtp_ssl.call_args.kwargs["context"]
|
||||
self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED)
|
||||
assert called_context.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
@mock.patch("smtplib.SMTP")
|
||||
def test_send_mime_tls_server_auth(self, mock_smtp):
|
||||
|
|
@ -233,7 +233,7 @@ class TestEmailSmtp(SupersetTestCase):
|
|||
utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
|
||||
mock_smtp.return_value.starttls.assert_called_with(context=mock.ANY)
|
||||
called_context = mock_smtp.return_value.starttls.call_args.kwargs["context"]
|
||||
self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED)
|
||||
assert called_context.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
@mock.patch("smtplib.SMTP_SSL")
|
||||
@mock.patch("smtplib.SMTP")
|
||||
|
|
|
|||
|
|
@ -36,13 +36,13 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
|
|||
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
||||
db.session.flush()
|
||||
assert dash.embedded
|
||||
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
|
||||
assert dash.embedded[0].allowed_domains == ["test.example.com"]
|
||||
original_uuid = dash.embedded[0].uuid
|
||||
self.assertIsNotNone(original_uuid)
|
||||
assert original_uuid is not None
|
||||
EmbeddedDashboardDAO.upsert(dash, [])
|
||||
db.session.flush()
|
||||
self.assertEqual(dash.embedded[0].allowed_domains, [])
|
||||
self.assertEqual(dash.embedded[0].uuid, original_uuid)
|
||||
assert dash.embedded[0].allowed_domains == []
|
||||
assert dash.embedded[0].uuid == original_uuid
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_get_by_uuid(self):
|
||||
|
|
@ -51,4 +51,4 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
|
|||
db.session.flush()
|
||||
uuid = str(dash.embedded[0].uuid)
|
||||
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
|
||||
self.assertIsNotNone(embedded)
|
||||
assert embedded is not None
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
# unmodified object
|
||||
obj = DBEventLogger()
|
||||
res = get_event_logger_from_cfg_value(obj)
|
||||
self.assertIs(obj, res)
|
||||
assert obj is res
|
||||
|
||||
def test_config_class_deprecation(self):
|
||||
# test that assignment of a class object to EVENT_LOGGER is correctly
|
||||
|
|
@ -51,7 +51,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
res = get_event_logger_from_cfg_value(DBEventLogger)
|
||||
|
||||
# class is instantiated and returned
|
||||
self.assertIsInstance(res, DBEventLogger)
|
||||
assert isinstance(res, DBEventLogger)
|
||||
|
||||
def test_raises_typeerror_if_not_abc(self):
|
||||
# test that assignment of non AbstractEventLogger derived type raises
|
||||
|
|
@ -71,19 +71,16 @@ class TestEventLogger(unittest.TestCase):
|
|||
with app.test_request_context("/superset/dashboard/1/?myparam=foo"):
|
||||
result = test_func()
|
||||
payload = mock_log.call_args[1]
|
||||
self.assertEqual(result, 1)
|
||||
self.assertEqual(
|
||||
payload["records"],
|
||||
[
|
||||
assert result == 1
|
||||
assert payload["records"] == [
|
||||
{
|
||||
"myparam": "foo",
|
||||
"path": "/superset/dashboard/1/",
|
||||
"url_rule": "/superset/dashboard/<dashboard_id_or_slug>/",
|
||||
"object_ref": test_func.__qualname__,
|
||||
}
|
||||
],
|
||||
)
|
||||
self.assertGreaterEqual(payload["duration_ms"], 50)
|
||||
]
|
||||
assert payload["duration_ms"] >= 50
|
||||
|
||||
@patch.object(DBEventLogger, "log")
|
||||
def test_log_this_with_extra_payload(self, mock_log):
|
||||
|
|
@ -98,19 +95,16 @@ class TestEventLogger(unittest.TestCase):
|
|||
with app.test_request_context():
|
||||
result = test_func(1, karg1=2) # pylint: disable=no-value-for-parameter
|
||||
payload = mock_log.call_args[1]
|
||||
self.assertEqual(result, 2)
|
||||
self.assertEqual(
|
||||
payload["records"],
|
||||
[
|
||||
assert result == 2
|
||||
assert payload["records"] == [
|
||||
{
|
||||
"foo": "bar",
|
||||
"path": "/",
|
||||
"karg1": 2,
|
||||
"object_ref": test_func.__qualname__,
|
||||
}
|
||||
],
|
||||
)
|
||||
self.assertGreaterEqual(payload["duration_ms"], 100)
|
||||
]
|
||||
assert payload["duration_ms"] >= 100
|
||||
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
|
||||
|
|
@ -141,9 +135,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
with logger(action="foo", engine="bar"):
|
||||
pass
|
||||
|
||||
self.assertEquals(
|
||||
logger.records,
|
||||
[
|
||||
assert logger.records == [
|
||||
{
|
||||
"records": [{"path": "/", "engine": "bar"}],
|
||||
"database_id": None,
|
||||
|
|
@ -152,8 +144,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
"curated_payload": {},
|
||||
"curated_form_data": {},
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
def test_context_manager_log_with_context(self, mock_g):
|
||||
|
|
@ -188,9 +179,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
payload_override={"engine": "sqlite"},
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
logger.records,
|
||||
[
|
||||
assert logger.records == [
|
||||
{
|
||||
"records": [
|
||||
{
|
||||
|
|
@ -205,8 +194,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
"curated_payload": {},
|
||||
"curated_form_data": {},
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
def test_log_with_context_user_null(self, mock_g):
|
||||
|
|
|
|||
|
|
@ -24,13 +24,13 @@ class TestForm(SupersetTestCase):
|
|||
def test_comma_separated_list_field(self):
|
||||
field = CommaSeparatedListField().bind(Form(), "foo")
|
||||
field.process_formdata([""])
|
||||
self.assertEqual(field.data, [""])
|
||||
assert field.data == [""]
|
||||
|
||||
field.process_formdata(["a,comma,separated,list"])
|
||||
self.assertEqual(field.data, ["a", "comma", "separated", "list"])
|
||||
assert field.data == ["a", "comma", "separated", "list"]
|
||||
|
||||
def test_filter_not_empty_values(self):
|
||||
self.assertEqual(filter_not_empty_values(None), None)
|
||||
self.assertEqual(filter_not_empty_values([]), None)
|
||||
self.assertEqual(filter_not_empty_values([""]), None)
|
||||
self.assertEqual(filter_not_empty_values(["hi"]), ["hi"])
|
||||
assert filter_not_empty_values(None) is None
|
||||
assert filter_not_empty_values([]) is None
|
||||
assert filter_not_empty_values([""]) is None
|
||||
assert filter_not_empty_values(["hi"]) == ["hi"]
|
||||
|
|
|
|||
|
|
@ -148,52 +148,48 @@ class TestImportExport(SupersetTestCase):
|
|||
self, expected_dash, actual_dash, check_position=True, check_slugs=True
|
||||
):
|
||||
if check_slugs:
|
||||
self.assertEqual(expected_dash.slug, actual_dash.slug)
|
||||
self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title)
|
||||
self.assertEqual(len(expected_dash.slices), len(actual_dash.slices))
|
||||
assert expected_dash.slug == actual_dash.slug
|
||||
assert expected_dash.dashboard_title == actual_dash.dashboard_title
|
||||
assert len(expected_dash.slices) == len(actual_dash.slices)
|
||||
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
|
||||
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
|
||||
for e_slc, a_slc in zip(expected_slices, actual_slices):
|
||||
self.assert_slice_equals(e_slc, a_slc)
|
||||
if check_position:
|
||||
self.assertEqual(expected_dash.position_json, actual_dash.position_json)
|
||||
assert expected_dash.position_json == actual_dash.position_json
|
||||
|
||||
def assert_table_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(expected_ds.schema, actual_ds.schema)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
assert expected_ds.table_name == actual_ds.table_name
|
||||
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
|
||||
assert expected_ds.schema == actual_ds.schema
|
||||
assert len(expected_ds.metrics) == len(actual_ds.metrics)
|
||||
assert len(expected_ds.columns) == len(actual_ds.columns)
|
||||
assert {c.column_name for c in expected_ds.columns} == {
|
||||
c.column_name for c in actual_ds.columns
|
||||
}
|
||||
assert {m.metric_name for m in expected_ds.metrics} == {
|
||||
m.metric_name for m in actual_ds.metrics
|
||||
}
|
||||
|
||||
def assert_datasource_equals(self, expected_ds, actual_ds):
|
||||
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
|
||||
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
|
||||
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
|
||||
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
|
||||
self.assertEqual(
|
||||
{c.column_name for c in expected_ds.columns},
|
||||
{c.column_name for c in actual_ds.columns},
|
||||
)
|
||||
self.assertEqual(
|
||||
{m.metric_name for m in expected_ds.metrics},
|
||||
{m.metric_name for m in actual_ds.metrics},
|
||||
)
|
||||
assert expected_ds.datasource_name == actual_ds.datasource_name
|
||||
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
|
||||
assert len(expected_ds.metrics) == len(actual_ds.metrics)
|
||||
assert len(expected_ds.columns) == len(actual_ds.columns)
|
||||
assert {c.column_name for c in expected_ds.columns} == {
|
||||
c.column_name for c in actual_ds.columns
|
||||
}
|
||||
assert {m.metric_name for m in expected_ds.metrics} == {
|
||||
m.metric_name for m in actual_ds.metrics
|
||||
}
|
||||
|
||||
def assert_slice_equals(self, expected_slc, actual_slc):
|
||||
# to avoid bad slice data (no slice_name)
|
||||
expected_slc_name = expected_slc.slice_name or ""
|
||||
actual_slc_name = actual_slc.slice_name or ""
|
||||
self.assertEqual(expected_slc_name, actual_slc_name)
|
||||
self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type)
|
||||
self.assertEqual(expected_slc.viz_type, actual_slc.viz_type)
|
||||
assert expected_slc_name == actual_slc_name
|
||||
assert expected_slc.datasource_type == actual_slc.datasource_type
|
||||
assert expected_slc.viz_type == actual_slc.viz_type
|
||||
exp_params = json.loads(expected_slc.params)
|
||||
actual_params = json.loads(actual_slc.params)
|
||||
diff_params_keys = (
|
||||
|
|
@ -208,7 +204,7 @@ class TestImportExport(SupersetTestCase):
|
|||
actual_params.pop(k)
|
||||
if k in exp_params:
|
||||
exp_params.pop(k)
|
||||
self.assertEqual(exp_params, actual_params)
|
||||
assert exp_params == actual_params
|
||||
|
||||
def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
|
||||
"""only exported json has this params
|
||||
|
|
@ -218,9 +214,9 @@ class TestImportExport(SupersetTestCase):
|
|||
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
|
||||
for e_slc, a_slc in zip(expected_slices, actual_slices):
|
||||
params = a_slc.params_dict
|
||||
self.assertEqual(e_slc.datasource.name, params["datasource_name"])
|
||||
self.assertEqual(e_slc.datasource.schema, params["schema"])
|
||||
self.assertEqual(e_slc.datasource.database.name, params["database_name"])
|
||||
assert e_slc.datasource.name == params["datasource_name"]
|
||||
assert e_slc.datasource.schema == params["schema"]
|
||||
assert e_slc.datasource.database.name == params["database_name"]
|
||||
|
||||
@unittest.skip("Schema needs to be updated")
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -237,17 +233,17 @@ class TestImportExport(SupersetTestCase):
|
|||
birth_dash = self.get_dash_by_slug("births")
|
||||
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEqual(
|
||||
id_,
|
||||
json.loads(
|
||||
assert (
|
||||
id_
|
||||
== json.loads(
|
||||
exported_dashboards[0].json_metadata, object_hook=decode_dashboards
|
||||
)["remote_id"],
|
||||
)["remote_id"]
|
||||
)
|
||||
|
||||
exported_tables = json.loads(
|
||||
resp.data.decode("utf-8"), object_hook=decode_dashboards
|
||||
)["datasources"]
|
||||
self.assertEqual(1, len(exported_tables))
|
||||
assert 1 == len(exported_tables)
|
||||
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
|
||||
|
||||
@unittest.skip("Schema needs to be updated")
|
||||
|
|
@ -269,27 +265,28 @@ class TestImportExport(SupersetTestCase):
|
|||
exported_dashboards = sorted(
|
||||
resp_data.get("dashboards"), key=lambda d: d.dashboard_title
|
||||
)
|
||||
self.assertEqual(2, len(exported_dashboards))
|
||||
assert 2 == len(exported_dashboards)
|
||||
|
||||
birth_dash = self.get_dash_by_slug("births")
|
||||
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
|
||||
self.assert_dash_equals(birth_dash, exported_dashboards[0])
|
||||
self.assertEqual(
|
||||
birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"]
|
||||
assert (
|
||||
birth_dash.id
|
||||
== json.loads(exported_dashboards[0].json_metadata)["remote_id"]
|
||||
)
|
||||
|
||||
world_health_dash = self.get_dash_by_slug("world_health")
|
||||
self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1])
|
||||
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
|
||||
self.assertEqual(
|
||||
world_health_dash.id,
|
||||
json.loads(exported_dashboards[1].json_metadata)["remote_id"],
|
||||
assert (
|
||||
world_health_dash.id
|
||||
== json.loads(exported_dashboards[1].json_metadata)["remote_id"]
|
||||
)
|
||||
|
||||
exported_tables = sorted(
|
||||
resp_data.get("datasources"), key=lambda t: t.table_name
|
||||
)
|
||||
self.assertEqual(2, len(exported_tables))
|
||||
assert 2 == len(exported_tables)
|
||||
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
|
||||
self.assert_table_equals(
|
||||
self.get_table(name="wb_health_population"), exported_tables[1]
|
||||
|
|
@ -302,11 +299,11 @@ class TestImportExport(SupersetTestCase):
|
|||
)
|
||||
slc_id = import_chart(expected_slice, None, import_time=1989)
|
||||
slc = self.get_slice(slc_id)
|
||||
self.assertEqual(slc.datasource.perm, slc.perm)
|
||||
assert slc.datasource.perm == slc.perm
|
||||
self.assert_slice_equals(expected_slice, slc)
|
||||
|
||||
table_id = self.get_table(name="wb_health_population").id
|
||||
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
|
||||
assert table_id == self.get_slice(slc_id).datasource_id
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_import_2_slices_for_same_table(self):
|
||||
|
|
@ -323,13 +320,13 @@ class TestImportExport(SupersetTestCase):
|
|||
|
||||
imported_slc_1 = self.get_slice(slc_id_1)
|
||||
imported_slc_2 = self.get_slice(slc_id_2)
|
||||
self.assertEqual(table_id, imported_slc_1.datasource_id)
|
||||
assert table_id == imported_slc_1.datasource_id
|
||||
self.assert_slice_equals(slc_1, imported_slc_1)
|
||||
self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm)
|
||||
assert imported_slc_1.datasource.perm == imported_slc_1.perm
|
||||
|
||||
self.assertEqual(table_id, imported_slc_2.datasource_id)
|
||||
assert table_id == imported_slc_2.datasource_id
|
||||
self.assert_slice_equals(slc_2, imported_slc_2)
|
||||
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
|
||||
assert imported_slc_2.datasource.perm == imported_slc_2.perm
|
||||
|
||||
def test_import_slices_override(self):
|
||||
schema = get_example_default_schema()
|
||||
|
|
@ -339,7 +336,7 @@ class TestImportExport(SupersetTestCase):
|
|||
imported_slc_1 = self.get_slice(slc_1_id)
|
||||
slc_2 = self.create_slice("Import Me New", id=10005, schema=schema)
|
||||
slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990)
|
||||
self.assertEqual(slc_1_id, slc_2_id)
|
||||
assert slc_1_id == slc_2_id
|
||||
imported_slc_2 = self.get_slice(slc_2_id)
|
||||
self.assert_slice_equals(slc, imported_slc_2)
|
||||
|
||||
|
|
@ -379,21 +376,18 @@ class TestImportExport(SupersetTestCase):
|
|||
self.assert_dash_equals(
|
||||
expected_dash, imported_dash, check_position=False, check_slugs=False
|
||||
)
|
||||
self.assertEqual(
|
||||
{
|
||||
assert {
|
||||
"remote_id": 10002,
|
||||
"import_time": 1990,
|
||||
"native_filter_configuration": [],
|
||||
},
|
||||
json.loads(imported_dash.json_metadata),
|
||||
)
|
||||
} == json.loads(imported_dash.json_metadata)
|
||||
|
||||
expected_position = dash_with_1_slice.position
|
||||
# new slice id (auto-incremental) assigned on insert
|
||||
# id from json is used only for updating position with new id
|
||||
meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"]
|
||||
meta["chartId"] = imported_dash.slices[0].id
|
||||
self.assertEqual(expected_position, imported_dash.position)
|
||||
assert expected_position == imported_dash.position
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_import_dashboard_2_slices(self):
|
||||
|
|
@ -444,9 +438,7 @@ class TestImportExport(SupersetTestCase):
|
|||
},
|
||||
"native_filter_configuration": [],
|
||||
}
|
||||
self.assertEqual(
|
||||
expected_json_metadata, json.loads(imported_dash.json_metadata)
|
||||
)
|
||||
assert expected_json_metadata == json.loads(imported_dash.json_metadata)
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_import_override_dashboard_2_slices(self):
|
||||
|
|
@ -478,7 +470,7 @@ class TestImportExport(SupersetTestCase):
|
|||
imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992)
|
||||
|
||||
# override doesn't change the id
|
||||
self.assertEqual(imported_dash_id_1, imported_dash_id_2)
|
||||
assert imported_dash_id_1 == imported_dash_id_2
|
||||
expected_dash = self.create_dashboard(
|
||||
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
|
||||
)
|
||||
|
|
@ -487,20 +479,17 @@ class TestImportExport(SupersetTestCase):
|
|||
self.assert_dash_equals(
|
||||
expected_dash, imported_dash, check_position=False, check_slugs=False
|
||||
)
|
||||
self.assertEqual(
|
||||
{
|
||||
assert {
|
||||
"remote_id": 10004,
|
||||
"import_time": 1992,
|
||||
"native_filter_configuration": [],
|
||||
},
|
||||
json.loads(imported_dash.json_metadata),
|
||||
)
|
||||
} == json.loads(imported_dash.json_metadata)
|
||||
|
||||
def test_import_new_dashboard_slice_reset_ownership(self):
|
||||
admin_user = security_manager.find_user(username="admin")
|
||||
self.assertTrue(admin_user)
|
||||
assert admin_user
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
self.assertTrue(gamma_user)
|
||||
assert gamma_user
|
||||
g.user = gamma_user
|
||||
|
||||
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
|
||||
|
|
@ -511,35 +500,35 @@ class TestImportExport(SupersetTestCase):
|
|||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
assert imported_dash.created_by == gamma_user
|
||||
assert imported_dash.changed_by == gamma_user
|
||||
assert imported_dash.owners == [gamma_user]
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
assert imported_slc.created_by == gamma_user
|
||||
assert imported_slc.changed_by == gamma_user
|
||||
assert imported_slc.owners == [gamma_user]
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_import_override_dashboard_slice_reset_ownership(self):
|
||||
admin_user = security_manager.find_user(username="admin")
|
||||
self.assertTrue(admin_user)
|
||||
assert admin_user
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
self.assertTrue(gamma_user)
|
||||
assert gamma_user
|
||||
g.user = gamma_user
|
||||
|
||||
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
|
||||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
assert imported_dash.created_by == gamma_user
|
||||
assert imported_dash.changed_by == gamma_user
|
||||
assert imported_dash.owners == [gamma_user]
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
assert imported_slc.created_by == gamma_user
|
||||
assert imported_slc.changed_by == gamma_user
|
||||
assert imported_slc.owners == [gamma_user]
|
||||
|
||||
# re-import with another user shouldn't change the permissions
|
||||
g.user = admin_user
|
||||
|
|
@ -547,14 +536,14 @@ class TestImportExport(SupersetTestCase):
|
|||
|
||||
imported_dash_id = import_dashboard(dash_with_1_slice)
|
||||
imported_dash = self.get_dash(imported_dash_id)
|
||||
self.assertEqual(imported_dash.created_by, gamma_user)
|
||||
self.assertEqual(imported_dash.changed_by, gamma_user)
|
||||
self.assertEqual(imported_dash.owners, [gamma_user])
|
||||
assert imported_dash.created_by == gamma_user
|
||||
assert imported_dash.changed_by == gamma_user
|
||||
assert imported_dash.owners == [gamma_user]
|
||||
|
||||
imported_slc = imported_dash.slices[0]
|
||||
self.assertEqual(imported_slc.created_by, gamma_user)
|
||||
self.assertEqual(imported_slc.changed_by, gamma_user)
|
||||
self.assertEqual(imported_slc.owners, [gamma_user])
|
||||
assert imported_slc.created_by == gamma_user
|
||||
assert imported_slc.changed_by == gamma_user
|
||||
assert imported_slc.owners == [gamma_user]
|
||||
|
||||
def _create_dashboard_for_import(self, id_=10100):
|
||||
slc = self.create_slice(
|
||||
|
|
@ -600,10 +589,11 @@ class TestImportExport(SupersetTestCase):
|
|||
imported_id = import_dataset(table, db_id, import_time=1990)
|
||||
imported = self.get_table_by_id(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
self.assertEqual(
|
||||
{"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
|
||||
json.loads(imported.params),
|
||||
)
|
||||
assert {
|
||||
"remote_id": 10002,
|
||||
"import_time": 1990,
|
||||
"database_name": "examples",
|
||||
} == json.loads(imported.params)
|
||||
|
||||
def test_import_table_2_col_2_met(self):
|
||||
schema = get_example_default_schema()
|
||||
|
|
@ -642,7 +632,7 @@ class TestImportExport(SupersetTestCase):
|
|||
imported_over_id = import_dataset(table_over, db_id, import_time=1992)
|
||||
|
||||
imported_over = self.get_table_by_id(imported_over_id)
|
||||
self.assertEqual(imported_id, imported_over.id)
|
||||
assert imported_id == imported_over.id
|
||||
expected_table = self.create_table(
|
||||
"table_override",
|
||||
id=10003,
|
||||
|
|
@ -673,7 +663,7 @@ class TestImportExport(SupersetTestCase):
|
|||
)
|
||||
imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)
|
||||
|
||||
self.assertEqual(imported_id, imported_id_copy)
|
||||
assert imported_id == imported_id_copy
|
||||
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class TestLogApi(SupersetTestCase):
|
|||
arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]}
|
||||
uri = f"api/v1/log/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_list(self):
|
||||
"""
|
||||
|
|
@ -94,11 +94,11 @@ class TestLogApi(SupersetTestCase):
|
|||
arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]}
|
||||
uri = f"api/v1/log/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(list(response["result"][0].keys()), EXPECTED_COLUMNS)
|
||||
self.assertEqual(response["result"][0]["action"], "some_action")
|
||||
self.assertEqual(response["result"][0]["user"], {"username": "admin"})
|
||||
assert list(response["result"][0].keys()) == EXPECTED_COLUMNS
|
||||
assert response["result"][0]["action"] == "some_action"
|
||||
assert response["result"][0]["user"] == {"username": "admin"}
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -111,10 +111,10 @@ class TestLogApi(SupersetTestCase):
|
|||
self.login(GAMMA_USERNAME)
|
||||
uri = "api/v1/log/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
self.login(ALPHA_USERNAME)
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -127,12 +127,12 @@ class TestLogApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/log/{log.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(list(response["result"].keys()), EXPECTED_COLUMNS)
|
||||
self.assertEqual(response["result"]["action"], "some_action")
|
||||
self.assertEqual(response["result"]["user"], {"username": "admin"})
|
||||
assert list(response["result"].keys()) == EXPECTED_COLUMNS
|
||||
assert response["result"]["action"] == "some_action"
|
||||
assert response["result"]["user"] == {"username": "admin"}
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ class TestLogApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/log/{log.id}"
|
||||
rv = self.client.delete(uri)
|
||||
self.assertEqual(rv.status_code, 405)
|
||||
assert rv.status_code == 405
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -160,7 +160,7 @@ class TestLogApi(SupersetTestCase):
|
|||
log_data = {"action": "some_action"}
|
||||
uri = f"api/v1/log/{log.id}"
|
||||
rv = self.client.put(uri, json=log_data)
|
||||
self.assertEqual(rv.status_code, 405)
|
||||
assert rv.status_code == 405
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ class TestLogApi(SupersetTestCase):
|
|||
|
||||
uri = f"api/v1/log/recent_activity/" # noqa: F541
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
db.session.delete(log1)
|
||||
|
|
@ -184,9 +184,7 @@ class TestLogApi(SupersetTestCase):
|
|||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert response == {
|
||||
"result": [
|
||||
{
|
||||
"action": "dashboard",
|
||||
|
|
@ -197,8 +195,7 @@ class TestLogApi(SupersetTestCase):
|
|||
"time_delta_humanized": ANY,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
def test_get_recent_activity_actions_filter(self):
|
||||
"""
|
||||
|
|
@ -219,9 +216,9 @@ class TestLogApi(SupersetTestCase):
|
|||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(len(response["result"]), 1)
|
||||
assert len(response["result"]) == 1
|
||||
|
||||
def test_get_recent_activity_distinct_false(self):
|
||||
"""
|
||||
|
|
@ -243,9 +240,9 @@ class TestLogApi(SupersetTestCase):
|
|||
db.session.delete(log2)
|
||||
db.session.delete(dash)
|
||||
db.session.commit()
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(len(response["result"]), 2)
|
||||
assert len(response["result"]) == 2
|
||||
|
||||
def test_get_recent_activity_pagination(self):
|
||||
"""
|
||||
|
|
@ -269,11 +266,9 @@ class TestLogApi(SupersetTestCase):
|
|||
uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert response == {
|
||||
"result": [
|
||||
{
|
||||
"action": "dashboard",
|
||||
|
|
@ -292,8 +287,7 @@ class TestLogApi(SupersetTestCase):
|
|||
"time_delta_humanized": ANY,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
arguments = {"page": 1, "page_size": 2}
|
||||
uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}"
|
||||
|
|
@ -307,11 +301,9 @@ class TestLogApi(SupersetTestCase):
|
|||
db.session.delete(dash3)
|
||||
db.session.commit()
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
response,
|
||||
{
|
||||
assert response == {
|
||||
"result": [
|
||||
{
|
||||
"action": "dashboard",
|
||||
|
|
@ -322,5 +314,4 @@ class TestLogApi(SupersetTestCase):
|
|||
"time_delta_humanized": ANY,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,4 +52,4 @@ class TestLoggingConfigurator(unittest.TestCase):
|
|||
cfg.configure_logging(MagicMock(), True)
|
||||
|
||||
logging.info("test", extra={"testattr": "foo"})
|
||||
self.assertTrue(handler.received)
|
||||
assert handler.received
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
import re
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.utils import json
|
||||
import unittest
|
||||
|
|
@ -59,22 +60,22 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/default", db)
|
||||
assert "hive/default" == db
|
||||
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/core_db", db)
|
||||
assert "hive/core_db" == db
|
||||
|
||||
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive", db)
|
||||
assert "hive" == db
|
||||
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("hive/core_db", db)
|
||||
assert "hive/core_db" == db
|
||||
|
||||
def test_database_schema_postgres(self):
|
||||
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
|
||||
|
|
@ -82,11 +83,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("prod", db)
|
||||
assert "prod" == db
|
||||
|
||||
with model.get_sqla_engine(schema="foo") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("prod", db)
|
||||
assert "prod" == db
|
||||
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
|
||||
|
|
@ -100,11 +101,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("default", db)
|
||||
assert "default" == db
|
||||
|
||||
with model.get_sqla_engine(schema="core_db") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("core_db", db)
|
||||
assert "core_db" == db
|
||||
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
|
||||
|
|
@ -115,11 +116,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
with model.get_sqla_engine() as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("superset", db)
|
||||
assert "superset" == db
|
||||
|
||||
with model.get_sqla_engine(schema="staging") as engine:
|
||||
db = make_url(engine.url).database
|
||||
self.assertEqual("staging", db)
|
||||
assert "staging" == db
|
||||
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
|
||||
|
|
@ -133,12 +134,12 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
model.impersonate_user = True
|
||||
with model.get_sqla_engine() as engine:
|
||||
username = make_url(engine.url).username
|
||||
self.assertEqual(example_user.username, username)
|
||||
assert example_user.username == username
|
||||
|
||||
model.impersonate_user = False
|
||||
with model.get_sqla_engine() as engine:
|
||||
username = make_url(engine.url).username
|
||||
self.assertNotEqual(example_user.username, username)
|
||||
assert example_user.username != username
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_presto(self, mocked_create_engine):
|
||||
|
|
@ -344,20 +345,20 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
if main_db.backend == "mysql":
|
||||
df = main_db.get_df("SELECT 1", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
assert df.iat[0, 0] == 1
|
||||
|
||||
df = main_db.get_df("SELECT 1;", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
assert df.iat[0, 0] == 1
|
||||
|
||||
def test_multi_statement(self):
|
||||
main_db = get_example_database()
|
||||
|
||||
if main_db.backend == "mysql":
|
||||
df = main_db.get_df("USE superset; SELECT 1", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
assert df.iat[0, 0] == 1
|
||||
|
||||
df = main_db.get_df("USE superset; SELECT ';';", None, None)
|
||||
self.assertEqual(df.iat[0, 0], ";")
|
||||
assert df.iat[0, 0] == ";"
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_get_sqla_engine(self, mocked_create_engine):
|
||||
|
|
@ -404,20 +405,20 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
self.assertEqual(compiled, "from_unixtime(ds)")
|
||||
assert compiled == "from_unixtime(ds)"
|
||||
|
||||
ds_col.python_date_format = "epoch_s"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
self.assertEqual(compiled, "DATE(from_unixtime(ds))")
|
||||
assert compiled == "DATE(from_unixtime(ds))"
|
||||
|
||||
prev_ds_expr = ds_col.expression
|
||||
ds_col.expression = "DATE_ADD(ds, 1)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
|
||||
assert compiled == "DATE(from_unixtime(DATE_ADD(ds, 1)))"
|
||||
ds_col.expression = prev_ds_expr
|
||||
|
||||
def query_with_expr_helper(self, is_timeseries, inner_join=True):
|
||||
|
|
@ -448,16 +449,16 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
series_limit=15 if inner_join and is_timeseries else None,
|
||||
)
|
||||
qr = tbl.query(query_obj)
|
||||
self.assertEqual(qr.status, QueryStatus.SUCCESS)
|
||||
assert qr.status == QueryStatus.SUCCESS
|
||||
sql = qr.query
|
||||
self.assertIn(arbitrary_gby, sql)
|
||||
self.assertIn("name", sql)
|
||||
assert arbitrary_gby in sql
|
||||
assert "name" in sql
|
||||
if inner_join and is_timeseries:
|
||||
self.assertIn("JOIN", sql.upper())
|
||||
assert "JOIN" in sql.upper()
|
||||
else:
|
||||
self.assertNotIn("JOIN", sql.upper())
|
||||
assert "JOIN" not in sql.upper()
|
||||
spec.allows_joins = old_inner_join
|
||||
self.assertFalse(qr.df.empty)
|
||||
assert not qr.df.empty
|
||||
return qr.df
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
@ -475,7 +476,7 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
name_list1 = canonicalize_df(df1).name.values.tolist()
|
||||
df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
|
||||
name_list2 = canonicalize_df(df1).name.values.tolist()
|
||||
self.assertFalse(df2.empty)
|
||||
assert not df2.empty
|
||||
|
||||
assert name_list2 == name_list1
|
||||
|
||||
|
|
@ -498,14 +499,14 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
extras={},
|
||||
)
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
self.assertNotIn("-- COMMENT", sql)
|
||||
assert "-- COMMENT" not in sql
|
||||
|
||||
def mutator(*args, **kwargs):
|
||||
return "-- COMMENT\n" + args[0]
|
||||
|
||||
app.config["SQL_QUERY_MUTATOR"] = mutator
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
self.assertIn("-- COMMENT", sql)
|
||||
assert "-- COMMENT" in sql
|
||||
|
||||
app.config["SQL_QUERY_MUTATOR"] = None
|
||||
|
||||
|
|
@ -524,15 +525,15 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
extras={},
|
||||
)
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
self.assertNotIn("-- COMMENT", sql)
|
||||
assert "-- COMMENT" not in sql
|
||||
|
||||
def mutator(sql, database=None, **kwargs):
|
||||
return "-- COMMENT\n--" + "\n" + str(database) + "\n" + sql
|
||||
|
||||
app.config["SQL_QUERY_MUTATOR"] = mutator
|
||||
mutated_sql = tbl.get_query_str(query_obj)
|
||||
self.assertIn("-- COMMENT", mutated_sql)
|
||||
self.assertIn(tbl.database.name, mutated_sql)
|
||||
assert "-- COMMENT" in mutated_sql
|
||||
assert tbl.database.name in mutated_sql
|
||||
|
||||
app.config["SQL_QUERY_MUTATOR"] = None
|
||||
|
||||
|
|
@ -554,7 +555,7 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
with self.assertRaises(Exception) as context:
|
||||
tbl.get_query_str(query_obj)
|
||||
|
||||
self.assertTrue("Metric 'invalid' does not exist", context.exception)
|
||||
assert "Metric 'invalid' does not exist", context.exception
|
||||
|
||||
def test_query_label_without_group_by(self):
|
||||
tbl = self.get_table(name="birth_names")
|
||||
|
|
@ -577,7 +578,7 @@ class TestSqlaTableModel(SupersetTestCase):
|
|||
)
|
||||
|
||||
sql = tbl.get_query_str(query_obj)
|
||||
self.assertRegex(sql, r'name AS ["`]?Given Name["`]?')
|
||||
assert re.search('name AS ["`]?Given Name["`]?', sql) # noqa: F821
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_data_for_slices_with_no_query_context(self):
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/query/{query.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
expected_result = {
|
||||
"database": {"id": example_db.id},
|
||||
|
|
@ -163,7 +163,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
"tracking_url": None,
|
||||
}
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertIn("changed_on", data["result"])
|
||||
assert "changed_on" in data["result"]
|
||||
for key, value in data["result"].items():
|
||||
# We can't assert timestamp
|
||||
if key not in (
|
||||
|
|
@ -173,7 +173,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
"start_time",
|
||||
"id",
|
||||
):
|
||||
self.assertEqual(value, expected_result[key])
|
||||
assert value == expected_result[key]
|
||||
# rollback changes
|
||||
db.session.delete(query)
|
||||
db.session.commit()
|
||||
|
|
@ -189,7 +189,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/query/{max_id + 1}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
db.session.delete(query)
|
||||
db.session.commit()
|
||||
|
|
@ -222,30 +222,30 @@ class TestQueryApi(SupersetTestCase):
|
|||
self.login(username="gamma_1", password="password")
|
||||
uri = f"api/v1/query/{query_gamma2.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
uri = f"api/v1/query/{query_gamma1.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
# Gamma2 user, only sees their own queries
|
||||
self.logout()
|
||||
self.login(username="gamma_2", password="password")
|
||||
uri = f"api/v1/query/{query_gamma1.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
uri = f"api/v1/query/{query_gamma2.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
# Admin's have the "all query access" permission
|
||||
self.logout()
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/query/{query_gamma1.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
uri = f"api/v1/query/{query_gamma2.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
# rollback changes
|
||||
db.session.delete(query_gamma1)
|
||||
|
|
@ -262,7 +262,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/query/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] == QUERIES_FIXTURE_COUNT
|
||||
# check expected columns
|
||||
|
|
@ -433,11 +433,11 @@ class TestQueryApi(SupersetTestCase):
|
|||
timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
|
||||
uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
expected_result = updated_query.to_dict()
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(len(data["result"]), 1)
|
||||
assert len(data["result"]) == 1
|
||||
for key, value in data["result"][0].items():
|
||||
# We can't assert timestamp
|
||||
if key not in (
|
||||
|
|
@ -447,7 +447,7 @@ class TestQueryApi(SupersetTestCase):
|
|||
"start_time",
|
||||
"id",
|
||||
):
|
||||
self.assertEqual(value, expected_result[key])
|
||||
assert value == expected_result[key]
|
||||
# rollback changes
|
||||
db.session.delete(old_query)
|
||||
db.session.delete(updated_query)
|
||||
|
|
|
|||
|
|
@ -467,26 +467,22 @@ class TestSavedQueryApi(SupersetTestCase):
|
|||
# Filter by tag ID
|
||||
filter_params = get_filter_params("saved_query_tag_id", tag.id)
|
||||
response_by_id = self.get_list("saved_query", filter_params)
|
||||
self.assertEqual(response_by_id.status_code, 200)
|
||||
assert response_by_id.status_code == 200
|
||||
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
|
||||
|
||||
# Filter by tag name
|
||||
filter_params = get_filter_params("saved_query_tags", tag.name)
|
||||
response_by_name = self.get_list("saved_query", filter_params)
|
||||
self.assertEqual(response_by_name.status_code, 200)
|
||||
assert response_by_name.status_code == 200
|
||||
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
|
||||
|
||||
# Compare results
|
||||
self.assertEqual(
|
||||
data_by_id["count"],
|
||||
data_by_name["count"],
|
||||
len(expected_saved_queries),
|
||||
)
|
||||
self.assertEqual(
|
||||
set(query["id"] for query in data_by_id["result"]),
|
||||
set(query["id"] for query in data_by_name["result"]),
|
||||
set(query.id for query in expected_saved_queries),
|
||||
assert data_by_id["count"] == data_by_name["count"], len(
|
||||
expected_saved_queries
|
||||
)
|
||||
assert set(query["id"] for query in data_by_id["result"]) == set(
|
||||
query["id"] for query in data_by_name["result"]
|
||||
), set(query.id for query in expected_saved_queries)
|
||||
|
||||
@pytest.mark.usefixtures("create_saved_queries")
|
||||
def test_get_saved_query_favorite_filter(self):
|
||||
|
|
|
|||
|
|
@ -70,15 +70,15 @@ class TestQueryContext(SupersetTestCase):
|
|||
|
||||
payload = get_query_context("birth_names", add_postprocessing_operations=True)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
self.assertEqual(len(query_context.queries), len(payload["queries"]))
|
||||
assert len(query_context.queries) == len(payload["queries"])
|
||||
|
||||
for query_idx, query in enumerate(query_context.queries):
|
||||
payload_query = payload["queries"][query_idx]
|
||||
|
||||
# check basic properties
|
||||
self.assertEqual(query.extras, payload_query["extras"])
|
||||
self.assertEqual(query.filter, payload_query["filters"])
|
||||
self.assertEqual(query.columns, payload_query["columns"])
|
||||
assert query.extras == payload_query["extras"]
|
||||
assert query.filter == payload_query["filters"]
|
||||
assert query.columns == payload_query["columns"]
|
||||
|
||||
# metrics are mutated during creation
|
||||
for metric_idx, metric in enumerate(query.metrics):
|
||||
|
|
@ -88,16 +88,16 @@ class TestQueryContext(SupersetTestCase):
|
|||
if "expressionType" in payload_metric
|
||||
else payload_metric["label"]
|
||||
)
|
||||
self.assertEqual(metric, payload_metric)
|
||||
assert metric == payload_metric
|
||||
|
||||
self.assertEqual(query.orderby, payload_query["orderby"])
|
||||
self.assertEqual(query.time_range, payload_query["time_range"])
|
||||
assert query.orderby == payload_query["orderby"]
|
||||
assert query.time_range == payload_query["time_range"]
|
||||
|
||||
# check post processing operation properties
|
||||
for post_proc_idx, post_proc in enumerate(query.post_processing):
|
||||
payload_post_proc = payload_query["post_processing"][post_proc_idx]
|
||||
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
|
||||
self.assertEqual(post_proc["options"], payload_post_proc["options"])
|
||||
assert post_proc["operation"] == payload_post_proc["operation"]
|
||||
assert post_proc["options"] == payload_post_proc["options"]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_cache(self):
|
||||
|
|
@ -128,12 +128,12 @@ class TestQueryContext(SupersetTestCase):
|
|||
rehydrated_qo = rehydrated_qc.queries[0]
|
||||
rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
|
||||
|
||||
self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
|
||||
self.assertEqual(len(rehydrated_qc.queries), 1)
|
||||
self.assertEqual(query_cache_key, rehydrated_query_cache_key)
|
||||
self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
|
||||
self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
|
||||
self.assertFalse(rehydrated_qc.force)
|
||||
assert rehydrated_qc.datasource == query_context.datasource
|
||||
assert len(rehydrated_qc.queries) == 1
|
||||
assert query_cache_key == rehydrated_query_cache_key
|
||||
assert rehydrated_qc.result_type == query_context.result_type
|
||||
assert rehydrated_qc.result_format == query_context.result_format
|
||||
assert not rehydrated_qc.force
|
||||
|
||||
def test_query_cache_key_changes_when_datasource_is_updated(self):
|
||||
payload = get_query_context("birth_names")
|
||||
|
|
@ -164,7 +164,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
cache_key_new = query_context.query_cache_key(query_object)
|
||||
|
||||
# the new cache_key should be different due to updated datasource
|
||||
self.assertNotEqual(cache_key_original, cache_key_new)
|
||||
assert cache_key_original != cache_key_new
|
||||
|
||||
def test_query_cache_key_changes_when_metric_is_updated(self):
|
||||
payload = get_query_context("birth_names")
|
||||
|
|
@ -198,7 +198,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
# the new cache_key should be different due to updated datasource
|
||||
self.assertNotEqual(cache_key_original, cache_key_new)
|
||||
assert cache_key_original != cache_key_new
|
||||
|
||||
def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
|
||||
payload = get_query_context("birth_names", add_postprocessing_operations=True)
|
||||
|
|
@ -228,14 +228,14 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key = query_context.query_cache_key(query_object)
|
||||
self.assertEqual(cache_key_original, cache_key)
|
||||
assert cache_key_original == cache_key
|
||||
|
||||
# ensure query without post processing operation is different
|
||||
payload["queries"][0].pop("post_processing")
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key = query_context.query_cache_key(query_object)
|
||||
self.assertNotEqual(cache_key_original, cache_key)
|
||||
assert cache_key_original != cache_key
|
||||
|
||||
def test_query_cache_key_changes_when_time_offsets_is_updated(self):
|
||||
payload = get_query_context("birth_names", add_time_offsets=True)
|
||||
|
|
@ -248,7 +248,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key = query_context.query_cache_key(query_object)
|
||||
self.assertNotEqual(cache_key_original, cache_key)
|
||||
assert cache_key_original != cache_key
|
||||
|
||||
def test_handle_metrics_field(self):
|
||||
"""
|
||||
|
|
@ -265,7 +265,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])
|
||||
assert query_object.metrics == ["sum__num", "abc", adhoc_metric]
|
||||
|
||||
def test_convert_deprecated_fields(self):
|
||||
"""
|
||||
|
|
@ -280,12 +280,12 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload["queries"][0]["granularity_sqla"] = "timecol"
|
||||
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
self.assertEqual(len(query_context.queries), 1)
|
||||
assert len(query_context.queries) == 1
|
||||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.granularity, "timecol")
|
||||
self.assertEqual(query_object.columns, columns)
|
||||
self.assertEqual(query_object.series_limit, 99)
|
||||
self.assertEqual(query_object.series_limit_metric, "sum__num")
|
||||
assert query_object.granularity == "timecol"
|
||||
assert query_object.columns == columns
|
||||
assert query_object.series_limit == 99
|
||||
assert query_object.series_limit_metric == "sum__num"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_csv_response_format(self):
|
||||
|
|
@ -297,10 +297,10 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload["queries"][0]["row_limit"] = 10
|
||||
query_context: QueryContext = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
assert len(responses) == 1
|
||||
data = responses["queries"][0]["data"]
|
||||
self.assertIn("name,sum__num\n", data)
|
||||
self.assertEqual(len(data.split("\n")), 12)
|
||||
assert "name,sum__num\n" in data
|
||||
assert len(data.split("\n")) == 12
|
||||
|
||||
def test_sql_injection_via_groupby(self):
|
||||
"""
|
||||
|
|
@ -352,11 +352,11 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload["queries"][0]["row_limit"] = 5
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
assert len(responses) == 1
|
||||
data = responses["queries"][0]["data"]
|
||||
self.assertIsInstance(data, list)
|
||||
self.assertEqual(len(data), 5)
|
||||
self.assertNotIn("sum__num", data[0])
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 5
|
||||
assert "sum__num" not in data[0]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_response_type(self):
|
||||
|
|
@ -489,7 +489,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
new_cache_key = responses["queries"][0]["cache_key"]
|
||||
self.assertEqual(orig_cache_key, new_cache_key)
|
||||
assert orig_cache_key == new_cache_key
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_time_offsets_in_query_object(self):
|
||||
|
|
@ -505,21 +505,18 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload["queries"][0]["time_range"] = "1990 : 1991"
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(
|
||||
responses["queries"][0]["colnames"],
|
||||
[
|
||||
assert responses["queries"][0]["colnames"] == [
|
||||
"__timestamp",
|
||||
"name",
|
||||
"sum__num",
|
||||
"sum__num__1 year ago",
|
||||
"sum__num__1 year later",
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
sqls = [
|
||||
sql for sql in responses["queries"][0]["query"].split(";") if sql.strip()
|
||||
]
|
||||
self.assertEqual(len(sqls), 3)
|
||||
assert len(sqls) == 3
|
||||
# 1 year ago
|
||||
assert re.search(r"1989-01-01.+1990-01-01", sqls[1], re.S)
|
||||
assert re.search(r"1990-01-01.+1991-01-01", sqls[1], re.S)
|
||||
|
|
@ -560,9 +557,9 @@ class TestQueryContext(SupersetTestCase):
|
|||
cache_keys = rv["cache_keys"]
|
||||
cache_keys__1_year_ago = cache_keys[0]
|
||||
cache_keys__1_year_later = cache_keys[1]
|
||||
self.assertIsNotNone(cache_keys__1_year_ago)
|
||||
self.assertIsNotNone(cache_keys__1_year_later)
|
||||
self.assertNotEqual(cache_keys__1_year_ago, cache_keys__1_year_later)
|
||||
assert cache_keys__1_year_ago is not None
|
||||
assert cache_keys__1_year_later is not None
|
||||
assert cache_keys__1_year_ago != cache_keys__1_year_later
|
||||
|
||||
# swap offsets
|
||||
payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"]
|
||||
|
|
@ -570,8 +567,8 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_object = query_context.queries[0]
|
||||
rv = query_context.processing_time_offsets(df.copy(), query_object)
|
||||
cache_keys = rv["cache_keys"]
|
||||
self.assertEqual(cache_keys__1_year_ago, cache_keys[1])
|
||||
self.assertEqual(cache_keys__1_year_later, cache_keys[0])
|
||||
assert cache_keys__1_year_ago == cache_keys[1]
|
||||
assert cache_keys__1_year_later == cache_keys[0]
|
||||
|
||||
# remove all offsets
|
||||
payload["queries"][0]["time_offsets"] = []
|
||||
|
|
@ -582,9 +579,9 @@ class TestQueryContext(SupersetTestCase):
|
|||
query_object,
|
||||
)
|
||||
|
||||
self.assertEqual(rv["df"].shape, df.shape)
|
||||
self.assertEqual(rv["queries"], [])
|
||||
self.assertEqual(rv["cache_keys"], [])
|
||||
assert rv["df"].shape == df.shape
|
||||
assert rv["queries"] == []
|
||||
assert rv["cache_keys"] == []
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_time_offsets_sql(self):
|
||||
|
|
@ -732,7 +729,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
row_limit_pattern_with_config_value = r"LIMIT " + re.escape(
|
||||
str(row_limit_value)
|
||||
)
|
||||
self.assertEqual(len(sqls), 2)
|
||||
assert len(sqls) == 2
|
||||
# 1 year ago
|
||||
assert re.search(r"1989-01-01.+1990-01-01", sqls[0], re.S)
|
||||
assert not re.search(r"LIMIT 100", sqls[0], re.S)
|
||||
|
|
|
|||
|
|
@ -1673,7 +1673,7 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
}
|
||||
uri = f"api/v1/report/{report_schedule.id}"
|
||||
rv = self.put_assert_metric(uri, report_schedule_data, "put")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
|
||||
@pytest.mark.usefixtures("create_report_schedules")
|
||||
def test_update_report_preserve_ownership(self):
|
||||
|
|
@ -1819,7 +1819,7 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
self.login(username="alpha2", password="password")
|
||||
uri = f"api/v1/report/{report_schedule.id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
|
||||
@pytest.mark.usefixtures("create_report_schedules")
|
||||
def test_bulk_delete_report_schedule(self):
|
||||
|
|
@ -1876,7 +1876,7 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
self.login(username="alpha2", password="password")
|
||||
uri = f"api/v1/report/?q={prison.dumps(report_schedules_ids)}"
|
||||
rv = self.delete_assert_metric(uri, "bulk_delete")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
assert rv.status_code == 403
|
||||
|
||||
@pytest.mark.usefixtures("create_report_schedules")
|
||||
def test_get_list_report_schedule_logs(self):
|
||||
|
|
|
|||
|
|
@ -28,27 +28,34 @@ from .base_tests import SupersetTestCase
|
|||
|
||||
class TestSupersetResultSet(SupersetTestCase):
|
||||
def test_dedup(self):
|
||||
self.assertEqual(dedup(["foo", "bar"]), ["foo", "bar"])
|
||||
self.assertEqual(
|
||||
dedup(["foo", "bar", "foo", "bar", "Foo"]),
|
||||
["foo", "bar", "foo__1", "bar__1", "Foo"],
|
||||
)
|
||||
self.assertEqual(
|
||||
dedup(["foo", "bar", "bar", "bar", "Bar"]),
|
||||
["foo", "bar", "bar__1", "bar__2", "Bar"],
|
||||
)
|
||||
self.assertEqual(
|
||||
dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False),
|
||||
["foo", "bar", "bar__1", "bar__2", "Bar__3"],
|
||||
)
|
||||
assert dedup(["foo", "bar"]) == ["foo", "bar"]
|
||||
assert dedup(["foo", "bar", "foo", "bar", "Foo"]) == [
|
||||
"foo",
|
||||
"bar",
|
||||
"foo__1",
|
||||
"bar__1",
|
||||
"Foo",
|
||||
]
|
||||
assert dedup(["foo", "bar", "bar", "bar", "Bar"]) == [
|
||||
"foo",
|
||||
"bar",
|
||||
"bar__1",
|
||||
"bar__2",
|
||||
"Bar",
|
||||
]
|
||||
assert dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False) == [
|
||||
"foo",
|
||||
"bar",
|
||||
"bar__1",
|
||||
"bar__2",
|
||||
"Bar__3",
|
||||
]
|
||||
|
||||
def test_get_columns_basic(self):
|
||||
data = [("a1", "b1", "c1"), ("a2", "b2", "c2")]
|
||||
cursor_descr = (("a", "string"), ("b", "string"), ("c", "string"))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
assert results.columns == [
|
||||
{
|
||||
"is_dttm": False,
|
||||
"type": "STRING",
|
||||
|
|
@ -70,16 +77,13 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
"column_name": "c",
|
||||
"name": "c",
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_get_columns_with_int(self):
|
||||
data = [("a1", 1), ("a2", 2)]
|
||||
cursor_descr = (("a", "string"), ("b", "int"))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
assert results.columns == [
|
||||
{
|
||||
"is_dttm": False,
|
||||
"type": "STRING",
|
||||
|
|
@ -94,8 +98,7 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
"column_name": "b",
|
||||
"name": "b",
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_get_columns_type_inference(self):
|
||||
data = [
|
||||
|
|
@ -104,9 +107,7 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
]
|
||||
cursor_descr = (("a", None), ("b", None), ("c", None), ("d", None), ("e", None))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(
|
||||
results.columns,
|
||||
[
|
||||
assert results.columns == [
|
||||
{
|
||||
"is_dttm": False,
|
||||
"type": "FLOAT",
|
||||
|
|
@ -142,34 +143,33 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
"column_name": "e",
|
||||
"name": "e",
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_is_date(self):
|
||||
data = [("a", 1), ("a", 2)]
|
||||
cursor_descr = (("a", "string"), ("a", "string"))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.is_temporal("DATE"), True)
|
||||
self.assertEqual(results.is_temporal("DATETIME"), True)
|
||||
self.assertEqual(results.is_temporal("TIME"), True)
|
||||
self.assertEqual(results.is_temporal("TIMESTAMP"), True)
|
||||
self.assertEqual(results.is_temporal("STRING"), False)
|
||||
self.assertEqual(results.is_temporal(""), False)
|
||||
self.assertEqual(results.is_temporal(None), False)
|
||||
assert results.is_temporal("DATE") is True
|
||||
assert results.is_temporal("DATETIME") is True
|
||||
assert results.is_temporal("TIME") is True
|
||||
assert results.is_temporal("TIMESTAMP") is True
|
||||
assert results.is_temporal("STRING") is False
|
||||
assert results.is_temporal("") is False
|
||||
assert results.is_temporal(None) is False
|
||||
|
||||
def test_dedup_with_data(self):
|
||||
data = [("a", 1), ("a", 2)]
|
||||
cursor_descr = (("a", "string"), ("a", "string"))
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
column_names = [col["column_name"] for col in results.columns]
|
||||
self.assertListEqual(column_names, ["a", "a__1"])
|
||||
self.assertListEqual(column_names, ["a", "a__1"]) # noqa: PT009
|
||||
|
||||
def test_int64_with_missing_data(self):
|
||||
data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)]
|
||||
cursor_descr = [("user_id", "bigint", None, None, None, None, True)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "BIGINT")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC)
|
||||
assert results.columns[0]["type"] == "BIGINT"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC
|
||||
|
||||
def test_data_as_list_of_lists(self):
|
||||
data = [[1, "a"], [2, "b"]]
|
||||
|
|
@ -179,29 +179,26 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
df = results.to_pandas_df()
|
||||
self.assertEqual(
|
||||
df_to_records(df),
|
||||
[{"user_id": 1, "username": "a"}, {"user_id": 2, "username": "b"}],
|
||||
)
|
||||
assert df_to_records(df) == [
|
||||
{"user_id": 1, "username": "a"},
|
||||
{"user_id": 2, "username": "b"},
|
||||
]
|
||||
|
||||
def test_nullable_bool(self):
|
||||
data = [(None,), (True,), (None,), (None,), (None,), (None,)]
|
||||
cursor_descr = [("is_test", "bool", None, None, None, None, True)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "BOOL")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.BOOLEAN)
|
||||
assert results.columns[0]["type"] == "BOOL"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.BOOLEAN
|
||||
df = results.to_pandas_df()
|
||||
self.assertEqual(
|
||||
df_to_records(df),
|
||||
[
|
||||
assert df_to_records(df) == [
|
||||
{"is_test": None},
|
||||
{"is_test": True},
|
||||
{"is_test": None},
|
||||
{"is_test": None},
|
||||
{"is_test": None},
|
||||
{"is_test": None},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_nested_types(self):
|
||||
data = [
|
||||
|
|
@ -220,18 +217,16 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
]
|
||||
cursor_descr = [("id",), ("dict_arr",), ("num_arr",), ("map_col",)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "INT")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC)
|
||||
self.assertEqual(results.columns[1]["type"], "STRING")
|
||||
self.assertEqual(results.columns[1]["type_generic"], GenericDataType.STRING)
|
||||
self.assertEqual(results.columns[2]["type"], "STRING")
|
||||
self.assertEqual(results.columns[2]["type_generic"], GenericDataType.STRING)
|
||||
self.assertEqual(results.columns[3]["type"], "STRING")
|
||||
self.assertEqual(results.columns[3]["type_generic"], GenericDataType.STRING)
|
||||
assert results.columns[0]["type"] == "INT"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC
|
||||
assert results.columns[1]["type"] == "STRING"
|
||||
assert results.columns[1]["type_generic"] == GenericDataType.STRING
|
||||
assert results.columns[2]["type"] == "STRING"
|
||||
assert results.columns[2]["type_generic"] == GenericDataType.STRING
|
||||
assert results.columns[3]["type"] == "STRING"
|
||||
assert results.columns[3]["type_generic"] == GenericDataType.STRING
|
||||
df = results.to_pandas_df()
|
||||
self.assertEqual(
|
||||
df_to_records(df),
|
||||
[
|
||||
assert df_to_records(df) == [
|
||||
{
|
||||
"id": 4,
|
||||
"dict_arr": '[{"table_name": "unicode_test", "database_id": 1}]',
|
||||
|
|
@ -244,8 +239,7 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
"num_arr": "[4, 5, 6]",
|
||||
"map_col": "{'chart_name': 'plot'}",
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_single_column_multidim_nested_types(self):
|
||||
data = [
|
||||
|
|
@ -270,35 +264,30 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
]
|
||||
cursor_descr = [("metadata",)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "STRING")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
|
||||
assert results.columns[0]["type"] == "STRING"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.STRING
|
||||
df = results.to_pandas_df()
|
||||
self.assertEqual(
|
||||
df_to_records(df),
|
||||
[
|
||||
assert df_to_records(df) == [
|
||||
{
|
||||
"metadata": '["test", [["foo", 123456, [[["test"], 3432546, 7657658766], [["fake"], 656756765, 324324324324]]]], ["test2", 43, 765765765], null, null]'
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_nested_list_types(self):
|
||||
data = [([{"TestKey": [123456, "foo"]}],)]
|
||||
cursor_descr = [("metadata",)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "STRING")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
|
||||
assert results.columns[0]["type"] == "STRING"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.STRING
|
||||
df = results.to_pandas_df()
|
||||
self.assertEqual(
|
||||
df_to_records(df), [{"metadata": '[{"TestKey": [123456, "foo"]}]'}]
|
||||
)
|
||||
assert df_to_records(df) == [{"metadata": '[{"TestKey": [123456, "foo"]}]'}]
|
||||
|
||||
def test_empty_datetime(self):
|
||||
data = [(None,)]
|
||||
cursor_descr = [("ds", "timestamp", None, None, None, None, True)]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "TIMESTAMP")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.TEMPORAL)
|
||||
assert results.columns[0]["type"] == "TIMESTAMP"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.TEMPORAL
|
||||
|
||||
def test_no_type_coercion(self):
|
||||
data = [("a", 1), ("b", 2)]
|
||||
|
|
@ -307,10 +296,10 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
("two", "int", None, None, None, None, True),
|
||||
]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns[0]["type"], "VARCHAR")
|
||||
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
|
||||
self.assertEqual(results.columns[1]["type"], "INT")
|
||||
self.assertEqual(results.columns[1]["type_generic"], GenericDataType.NUMERIC)
|
||||
assert results.columns[0]["type"] == "VARCHAR"
|
||||
assert results.columns[0]["type_generic"] == GenericDataType.STRING
|
||||
assert results.columns[1]["type"] == "INT"
|
||||
assert results.columns[1]["type_generic"] == GenericDataType.NUMERIC
|
||||
|
||||
def test_empty_data(self):
|
||||
data = []
|
||||
|
|
@ -319,4 +308,4 @@ class TestSupersetResultSet(SupersetTestCase):
|
|||
("emptytwo", "int", None, None, None, None, True),
|
||||
]
|
||||
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
|
||||
self.assertEqual(results.columns, [])
|
||||
assert results.columns == []
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class TestSecurityCsrfApi(SupersetTestCase):
|
|||
response = self.client.get(uri)
|
||||
self.assert200(response)
|
||||
data = json.loads(response.data.decode("utf-8"))
|
||||
self.assertEqual(generate_csrf(), data["result"])
|
||||
assert generate_csrf() == data["result"]
|
||||
|
||||
def test_get_csrf_token(self):
|
||||
"""
|
||||
|
|
@ -120,8 +120,8 @@ class TestSecurityGuestTokenApi(SupersetTestCase):
|
|||
audience=get_url_host(),
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
self.assertEqual(user, decoded_token["user"])
|
||||
self.assertEqual(resource, decoded_token["resources"][0])
|
||||
assert user == decoded_token["user"]
|
||||
assert resource == decoded_token["resources"][0]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_post_guest_token_bad_resources(self):
|
||||
|
|
|
|||
|
|
@ -55,15 +55,15 @@ class TestGuestUserSecurity(SupersetTestCase):
|
|||
|
||||
def test_is_guest_user__regular_user(self):
|
||||
is_guest = security_manager.is_guest_user(security_manager.find_user("admin"))
|
||||
self.assertFalse(is_guest)
|
||||
assert not is_guest
|
||||
|
||||
def test_is_guest_user__anonymous(self):
|
||||
is_guest = security_manager.is_guest_user(security_manager.get_anonymous_user())
|
||||
self.assertFalse(is_guest)
|
||||
assert not is_guest
|
||||
|
||||
def test_is_guest_user__guest_user(self):
|
||||
is_guest = security_manager.is_guest_user(self.authorized_guest())
|
||||
self.assertTrue(is_guest)
|
||||
assert is_guest
|
||||
|
||||
@patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -71,34 +71,34 @@ class TestGuestUserSecurity(SupersetTestCase):
|
|||
)
|
||||
def test_is_guest_user__flag_off(self):
|
||||
is_guest = security_manager.is_guest_user(self.authorized_guest())
|
||||
self.assertFalse(is_guest)
|
||||
assert not is_guest
|
||||
|
||||
def test_get_guest_user__regular_user(self):
|
||||
g.user = security_manager.find_user("admin")
|
||||
guest_user = security_manager.get_current_guest_user_if_guest()
|
||||
self.assertIsNone(guest_user)
|
||||
assert guest_user is None
|
||||
|
||||
def test_get_guest_user__anonymous_user(self):
|
||||
g.user = security_manager.get_anonymous_user()
|
||||
guest_user = security_manager.get_current_guest_user_if_guest()
|
||||
self.assertIsNone(guest_user)
|
||||
assert guest_user is None
|
||||
|
||||
def test_get_guest_user__guest_user(self):
|
||||
g.user = self.authorized_guest()
|
||||
guest_user = security_manager.get_current_guest_user_if_guest()
|
||||
self.assertEqual(guest_user, g.user)
|
||||
assert guest_user == g.user
|
||||
|
||||
def test_get_guest_user_roles_explicit(self):
|
||||
guest = self.authorized_guest()
|
||||
roles = security_manager.get_user_roles(guest)
|
||||
self.assertEqual(guest.roles, roles)
|
||||
assert guest.roles == roles
|
||||
|
||||
def test_get_guest_user_roles_implicit(self):
|
||||
guest = self.authorized_guest()
|
||||
g.user = guest
|
||||
|
||||
roles = security_manager.get_user_roles()
|
||||
self.assertEqual(guest.roles, roles)
|
||||
assert guest.roles == roles
|
||||
|
||||
|
||||
@patch.dict(
|
||||
|
|
@ -142,17 +142,17 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
|
|||
def test_has_guest_access__regular_user(self):
|
||||
g.user = security_manager.find_user("admin")
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertFalse(has_guest_access)
|
||||
assert not has_guest_access
|
||||
|
||||
def test_has_guest_access__anonymous_user(self):
|
||||
g.user = security_manager.get_anonymous_user()
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertFalse(has_guest_access)
|
||||
assert not has_guest_access
|
||||
|
||||
def test_has_guest_access__authorized_guest_user(self):
|
||||
g.user = self.authorized_guest
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertTrue(has_guest_access)
|
||||
assert has_guest_access
|
||||
|
||||
def test_has_guest_access__authorized_guest_user__non_zero_resource_index(self):
|
||||
# set up a user who has authorized access, plus another resource
|
||||
|
|
@ -163,7 +163,7 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
|
|||
g.user = guest
|
||||
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertTrue(has_guest_access)
|
||||
assert has_guest_access
|
||||
|
||||
def test_has_guest_access__unauthorized_guest_user__different_resource_id(self):
|
||||
g.user = security_manager.get_guest_user_from_token(
|
||||
|
|
@ -173,14 +173,14 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
|
|||
}
|
||||
)
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertFalse(has_guest_access)
|
||||
assert not has_guest_access
|
||||
|
||||
def test_has_guest_access__unauthorized_guest_user__different_resource_type(self):
|
||||
g.user = security_manager.get_guest_user_from_token(
|
||||
{"user": {}, "resources": [{"type": "dirt", "id": self.embedded.uuid}]}
|
||||
)
|
||||
has_guest_access = security_manager.has_guest_access(self.dash)
|
||||
self.assertFalse(has_guest_access)
|
||||
assert not has_guest_access
|
||||
|
||||
def test_raise_for_dashboard_access_as_guest(self):
|
||||
g.user = self.authorized_guest
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
"clause": "client_id=1",
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
rls1 = (
|
||||
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
|
||||
).one_or_none()
|
||||
|
|
@ -214,7 +214,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
"clause": "client_id=1",
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
|
||||
@pytest.mark.usefixtures("create_dataset")
|
||||
def test_model_view_rls_add_tables_required(self):
|
||||
|
|
@ -231,7 +231,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
"clause": "client_id=1",
|
||||
},
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["message"] == {"tables": ["Shorter than minimum length 1."]}
|
||||
|
||||
|
|
@ -326,8 +326,8 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
|
|||
}
|
||||
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(status_code, 422)
|
||||
self.assertEqual(data["message"], "[l'Some roles do not exist']")
|
||||
assert status_code == 422
|
||||
assert data["message"] == "[l'Some roles do not exist']"
|
||||
|
||||
def test_invalid_table_failure(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -340,8 +340,8 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
|
|||
}
|
||||
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(status_code, 422)
|
||||
self.assertEqual(data["message"], "[l'Datasource does not exist']")
|
||||
assert status_code == 422
|
||||
assert data["message"] == "[l'Datasource does not exist']"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_post_success(self):
|
||||
|
|
@ -357,7 +357,7 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
|
|||
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(status_code, 201)
|
||||
assert status_code == 201
|
||||
|
||||
rls = (
|
||||
db.session.query(RowLevelSecurityFilter)
|
||||
|
|
@ -366,11 +366,11 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
|
|||
)
|
||||
|
||||
assert rls
|
||||
self.assertEqual(rls.name, "rls 1")
|
||||
self.assertEqual(rls.clause, "1=1")
|
||||
self.assertEqual(rls.filter_type, "Base")
|
||||
self.assertEqual(rls.tables[0].id, table.id)
|
||||
self.assertEqual(rls.roles[0].id, 1)
|
||||
assert rls.name == "rls 1"
|
||||
assert rls.clause == "1=1"
|
||||
assert rls.filter_type == "Base"
|
||||
assert rls.tables[0].id == table.id
|
||||
assert rls.roles[0].id == 1
|
||||
|
||||
db.session.delete(rls)
|
||||
db.session.commit()
|
||||
|
|
@ -388,8 +388,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
|
|||
}
|
||||
rv = self.client.put("/api/v1/rowlevelsecurity/99999999", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(status_code, 404)
|
||||
self.assertEqual(data["message"], "Not found")
|
||||
assert status_code == 404
|
||||
assert data["message"] == "Not found"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_invalid_role_failure(self):
|
||||
|
|
@ -410,8 +410,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
|
|||
}
|
||||
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(status_code, 422)
|
||||
self.assertEqual(data["message"], "[l'Some roles do not exist']")
|
||||
assert status_code == 422
|
||||
assert data["message"] == "[l'Some roles do not exist']"
|
||||
|
||||
db.session.delete(rls)
|
||||
db.session.commit()
|
||||
|
|
@ -439,8 +439,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
|
|||
}
|
||||
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(status_code, 422)
|
||||
self.assertEqual(data["message"], "[l'Datasource does not exist']")
|
||||
assert status_code == 422
|
||||
assert data["message"] == "[l'Datasource does not exist']"
|
||||
|
||||
db.session.delete(rls)
|
||||
db.session.commit()
|
||||
|
|
@ -472,7 +472,7 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
|
|||
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
|
||||
status_code, _data = rv.status_code, json.loads(rv.data.decode("utf-8")) # noqa: F841
|
||||
|
||||
self.assertEqual(status_code, 201)
|
||||
assert status_code == 201
|
||||
|
||||
rls = (
|
||||
db.session.query(RowLevelSecurityFilter)
|
||||
|
|
@ -480,11 +480,11 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
|
|||
.one_or_none()
|
||||
)
|
||||
|
||||
self.assertEqual(rls.name, "rls put success")
|
||||
self.assertEqual(rls.clause, "2=2")
|
||||
self.assertEqual(rls.filter_type, "Base")
|
||||
self.assertEqual(rls.tables[0].id, tables[1].id)
|
||||
self.assertEqual(rls.roles[0].id, roles[1].id)
|
||||
assert rls.name == "rls put success"
|
||||
assert rls.clause == "2=2"
|
||||
assert rls.filter_type == "Base"
|
||||
assert rls.tables[0].id == tables[1].id
|
||||
assert rls.roles[0].id == roles[1].id
|
||||
|
||||
db.session.delete(rls)
|
||||
db.session.commit()
|
||||
|
|
@ -498,8 +498,8 @@ class TestRowLevelSecurityDeleteAPI(SupersetTestCase):
|
|||
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(status_code, 404)
|
||||
self.assertEqual(data["message"], "Not found")
|
||||
assert status_code == 404
|
||||
assert data["message"] == "Not found"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
|
|
@ -530,8 +530,8 @@ class TestRowLevelSecurityDeleteAPI(SupersetTestCase):
|
|||
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
|
||||
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(status_code, 200)
|
||||
self.assertEqual(data["message"], "Deleted 2 rules")
|
||||
assert status_code == 200
|
||||
assert data["message"] == "Deleted 2 rules"
|
||||
|
||||
|
||||
class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
|
||||
|
|
@ -543,7 +543,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
|
|||
params = prison.dumps({"page": 0, "page_size": 100})
|
||||
|
||||
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
|
|
@ -561,7 +561,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
|
|||
params = prison.dumps({"page": 0, "page_size": 100})
|
||||
|
||||
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
|
|
@ -584,7 +584,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
|
|||
params = prison.dumps({"page": 0, "page_size": 10})
|
||||
|
||||
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
received_tables = {table["text"].split(".")[-1] for table in result}
|
||||
|
|
@ -664,7 +664,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
|
|||
tbl = self.get_table(name="birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
self.assertRegex(sql, RLS_ALICE_REGEX)
|
||||
assert re.search(RLS_ALICE_REGEX, sql)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_rls_filter_does_not_alter_unrelated_query(self):
|
||||
|
|
@ -679,7 +679,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
|
|||
tbl = self.get_table(name="birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
self.assertNotRegex(sql, RLS_ALICE_REGEX)
|
||||
assert not re.search(RLS_ALICE_REGEX, sql)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_multiple_rls_filters_are_unionized(self):
|
||||
|
|
@ -695,8 +695,8 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
|
|||
tbl = self.get_table(name="birth_names")
|
||||
sql = tbl.get_query_str(self.query_obj)
|
||||
|
||||
self.assertRegex(sql, RLS_ALICE_REGEX)
|
||||
self.assertRegex(sql, RLS_GENDER_REGEX)
|
||||
assert re.search(RLS_ALICE_REGEX, sql)
|
||||
assert re.search(RLS_GENDER_REGEX, sql)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
|
|
@ -709,8 +709,8 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
|
|||
births_sql = births.get_query_str(self.query_obj)
|
||||
energy_sql = energy.get_query_str(self.query_obj)
|
||||
|
||||
self.assertRegex(births_sql, RLS_ALICE_REGEX)
|
||||
self.assertRegex(energy_sql, RLS_ALICE_REGEX)
|
||||
assert re.search(RLS_ALICE_REGEX, births_sql)
|
||||
assert re.search(RLS_ALICE_REGEX, energy_sql)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_dataset_id_can_be_string(self):
|
||||
|
|
@ -721,4 +721,4 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
|
|||
)
|
||||
sql = dataset.get_query_str(self.query_obj)
|
||||
|
||||
self.assertRegex(sql, RLS_ALICE_REGEX)
|
||||
assert re.search(RLS_ALICE_REGEX, sql)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -67,7 +67,7 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
result = data.get("result")
|
||||
assert result["active_tab"] is None # noqa: E711
|
||||
assert result["tab_state_ids"] == []
|
||||
self.assertEqual(len(result["databases"]), 0)
|
||||
assert len(result["databases"]) == 0
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -126,7 +126,7 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
# associated with any tabs
|
||||
resp = self.get_json_resp("/api/v1/sqllab/")
|
||||
result = resp["result"]
|
||||
self.assertEqual(result["active_tab"]["id"], tab_state_id)
|
||||
assert result["active_tab"]["id"] == tab_state_id
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
|
|
@ -220,8 +220,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
}
|
||||
}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
data = {"sql": "SELECT 1"}
|
||||
rv = self.client.post(
|
||||
|
|
@ -230,8 +230,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
)
|
||||
failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
data = {"database_id": 1}
|
||||
rv = self.client.post(
|
||||
|
|
@ -240,8 +240,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
)
|
||||
failed_resp = {"message": {"sql": ["Missing data for required field."]}}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
def test_estimate_valid_request(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -270,8 +270,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
|
||||
success_resp = {"result": formatter_response}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, success_resp)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_format_sql_request(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -283,8 +283,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
)
|
||||
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, success_resp)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
|
||||
assert rv.status_code == 200
|
||||
|
||||
@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
|
||||
def test_execute_required_params(self):
|
||||
|
|
@ -303,8 +303,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
}
|
||||
}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
data = {"sql": "SELECT 1", "client_id": client_id}
|
||||
rv = self.client.post(
|
||||
|
|
@ -313,8 +313,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
)
|
||||
failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
data = {"database_id": 1, "client_id": client_id}
|
||||
rv = self.client.post(
|
||||
|
|
@ -323,8 +323,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
)
|
||||
failed_resp = {"message": {"sql": ["Missing data for required field."]}}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, failed_resp)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
|
||||
assert rv.status_code == 400
|
||||
|
||||
@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
|
||||
def test_execute_valid_request(self) -> None:
|
||||
|
|
@ -342,8 +342,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
json=data,
|
||||
)
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(resp_data.get("status"), "success")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert resp_data.get("status") == "success"
|
||||
assert rv.status_code == 200
|
||||
|
||||
@mock.patch(
|
||||
"tests.integration_tests.superset_test_custom_template_processors.datetime"
|
||||
|
|
@ -366,7 +366,7 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
"/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload
|
||||
)
|
||||
assert sql_lab_mock.called
|
||||
self.assertEqual(sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test")
|
||||
assert sql_lab_mock.call_args[0][1] == "SELECT '1970-01-01' as test"
|
||||
|
||||
self.delete_fake_db_for_macros()
|
||||
|
||||
|
|
@ -419,8 +419,8 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
self.get_resp(f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}")
|
||||
)
|
||||
|
||||
self.assertEqual(result_key, expected_key)
|
||||
self.assertEqual(result_limited, expected_limited)
|
||||
assert result_key == expected_key
|
||||
assert result_limited == expected_limited
|
||||
|
||||
app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack
|
||||
|
||||
|
|
@ -454,6 +454,6 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
data = csv.reader(io.StringIO(resp))
|
||||
expected_data = csv.reader(io.StringIO("foo\n1\n2"))
|
||||
|
||||
self.assertEqual(list(expected_data), list(data))
|
||||
assert list(expected_data) == list(data)
|
||||
db.session.delete(query_obj)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
|
||||
errors = self.validator.validate(sql, None, schema, self.database)
|
||||
|
||||
self.assertEqual([], errors)
|
||||
assert [] == errors
|
||||
|
||||
@patch("superset.utils.core.g")
|
||||
def test_validator_db_error(self, flask_g):
|
||||
|
|
@ -95,7 +95,7 @@ class TestPrestoValidator(SupersetTestCase):
|
|||
|
||||
errors = self.validator.validate(sql, None, schema, self.database)
|
||||
|
||||
self.assertEqual(1, len(errors))
|
||||
assert 1 == len(errors)
|
||||
|
||||
|
||||
class TestPostgreSQLValidator(SupersetTestCase):
|
||||
|
|
|
|||
|
|
@ -83,12 +83,12 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
database = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
|
||||
tbl = SqlaTable(table_name="druid_tbl", database=database)
|
||||
col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
|
||||
self.assertEqual(col.is_dttm, None)
|
||||
assert col.is_dttm is None
|
||||
DruidEngineSpec.alter_new_orm_column(col)
|
||||
self.assertEqual(col.is_dttm, True)
|
||||
assert col.is_dttm is True
|
||||
|
||||
col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
|
||||
self.assertEqual(col.is_temporal, False)
|
||||
assert col.is_temporal is False
|
||||
|
||||
def test_temporal_varchar(self):
|
||||
"""Ensure a column with is_dttm set to true evaluates to is_temporal == True"""
|
||||
|
|
@ -125,13 +125,13 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
|
||||
for str_type, db_col_type in test_cases.items():
|
||||
col = TableColumn(column_name="foo", type=str_type, table=tbl)
|
||||
self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
|
||||
self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
|
||||
self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)
|
||||
assert col.is_temporal == (db_col_type == GenericDataType.TEMPORAL)
|
||||
assert col.is_numeric == (db_col_type == GenericDataType.NUMERIC)
|
||||
assert col.is_string == (db_col_type == GenericDataType.STRING)
|
||||
|
||||
for str_type, db_col_type in test_cases.items():
|
||||
col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True)
|
||||
self.assertTrue(col.is_temporal)
|
||||
assert col.is_temporal
|
||||
|
||||
@patch("superset.jinja_context.get_user_id", return_value=1)
|
||||
@patch("superset.jinja_context.get_username", return_value="abc")
|
||||
|
|
@ -161,7 +161,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
query_obj = dict(**base_query_obj, extras={})
|
||||
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
|
||||
assert table1.has_extra_cache_key_calls(query_obj)
|
||||
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
|
||||
|
||||
# Table with Jinja callable disabled.
|
||||
|
|
@ -177,8 +177,8 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
)
|
||||
query_obj = dict(**base_query_obj, extras={})
|
||||
extra_cache_keys = table2.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, [])
|
||||
assert table2.has_extra_cache_key_calls(query_obj)
|
||||
self.assertListEqual(extra_cache_keys, []) # noqa: PT009
|
||||
|
||||
# Table with no Jinja callable.
|
||||
query = "SELECT 'abc' as user"
|
||||
|
|
@ -190,15 +190,15 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
|
||||
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
|
||||
self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, [])
|
||||
assert not table3.has_extra_cache_key_calls(query_obj)
|
||||
self.assertListEqual(extra_cache_keys, []) # noqa: PT009
|
||||
|
||||
# With Jinja callable in SQL expression.
|
||||
query_obj = dict(
|
||||
**base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
|
||||
)
|
||||
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
|
||||
assert table3.has_extra_cache_key_calls(query_obj)
|
||||
assert extra_cache_keys == ["abc"]
|
||||
|
||||
@patch("superset.jinja_context.get_username", return_value="abc")
|
||||
|
|
@ -393,11 +393,9 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sqla_query = table.get_sqla_query(**query_obj)
|
||||
sql = table.database.compile_sqla_query(sqla_query.sqla_query)
|
||||
if isinstance(filter_.expected, list):
|
||||
self.assertTrue(
|
||||
any([candidate in sql for candidate in filter_.expected])
|
||||
)
|
||||
assert any([candidate in sql for candidate in filter_.expected])
|
||||
else:
|
||||
self.assertIn(filter_.expected, sql)
|
||||
assert filter_.expected in sql
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_boolean_type_where_operators(self):
|
||||
|
|
@ -434,7 +432,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
# https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/base.py
|
||||
if not dialect.supports_native_boolean and dialect.name != "mysql":
|
||||
operand = "(1, 0)"
|
||||
self.assertIn(f"IN {operand}", sql)
|
||||
assert f"IN {operand}" in sql
|
||||
|
||||
def test_incorrect_jinja_syntax_raises_correct_exception(self):
|
||||
query_obj = {
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
|
||||
data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
|
||||
self.assertLess(0, len(data["data"]))
|
||||
assert 0 < len(data["data"])
|
||||
|
||||
data = self.run_sql("SELECT * FROM nonexistent_table", "2")
|
||||
if backend() == "presto":
|
||||
|
|
@ -220,8 +220,8 @@ class TestSqlLab(SupersetTestCase):
|
|||
names_count = engine.execute(
|
||||
f"SELECT COUNT(*) FROM birth_names" # noqa: F541
|
||||
).first()
|
||||
self.assertEqual(
|
||||
names_count[0], len(data)
|
||||
assert names_count[0] == len(
|
||||
data
|
||||
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
|
||||
|
||||
# cleanup
|
||||
|
|
@ -238,14 +238,14 @@ class TestSqlLab(SupersetTestCase):
|
|||
SELECT * FROM birth_names LIMIT 2;
|
||||
"""
|
||||
data = self.run_sql(multi_sql, "2234")
|
||||
self.assertLess(0, len(data["data"]))
|
||||
assert 0 < len(data["data"])
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_explain(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
|
||||
self.assertLess(0, len(data["data"]))
|
||||
assert 0 < len(data["data"])
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_sql_json_has_access(self):
|
||||
|
|
@ -261,21 +261,21 @@ class TestSqlLab(SupersetTestCase):
|
|||
data = self.run_sql(QUERY_1, "1", username="Gagarin")
|
||||
db.session.query(Query).delete()
|
||||
db.session.commit()
|
||||
self.assertLess(0, len(data["data"]))
|
||||
assert 0 < len(data["data"])
|
||||
|
||||
def test_sqllab_has_access(self):
|
||||
for username in (ADMIN_USERNAME, GAMMA_SQLLAB_USERNAME):
|
||||
self.login(username)
|
||||
for endpoint in ("/sqllab/", "/sqllab/history/"):
|
||||
resp = self.client.get(endpoint)
|
||||
self.assertEqual(200, resp.status_code)
|
||||
assert 200 == resp.status_code
|
||||
|
||||
def test_sqllab_no_access(self):
|
||||
self.login(GAMMA_USERNAME)
|
||||
for endpoint in ("/sqllab/", "/sqllab/history/"):
|
||||
resp = self.client.get(endpoint)
|
||||
# Redirects to the main page
|
||||
self.assertEqual(302, resp.status_code)
|
||||
assert 302 == resp.status_code
|
||||
|
||||
def test_sql_json_schema_access(self):
|
||||
examples_db = get_example_database()
|
||||
|
|
@ -311,7 +311,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
data = self.run_sql(
|
||||
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
assert 1 == len(data["data"])
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
|
||||
|
|
@ -319,7 +319,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
username="SchemaUser",
|
||||
schema=CTAS_SCHEMA_NAME,
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
assert 1 == len(data["data"])
|
||||
|
||||
# postgres needs a schema as a part of the table name.
|
||||
if db_backend == "mysql":
|
||||
|
|
@ -329,7 +329,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
username="SchemaUser",
|
||||
schema=CTAS_SCHEMA_NAME,
|
||||
)
|
||||
self.assertEqual(1, len(data["data"]))
|
||||
assert 1 == len(data["data"])
|
||||
|
||||
db.session.query(Query).delete()
|
||||
with get_example_database().get_sqla_engine() as engine:
|
||||
|
|
@ -349,77 +349,75 @@ class TestSqlLab(SupersetTestCase):
|
|||
data = [["a", 4, 4.0]]
|
||||
results = SupersetResultSet(data, cols, BaseEngineSpec)
|
||||
|
||||
self.assertEqual(len(data), results.size)
|
||||
self.assertEqual(len(cols), len(results.columns))
|
||||
assert len(data) == results.size
|
||||
assert len(cols) == len(results.columns)
|
||||
|
||||
def test_pa_conversion_tuple(self):
|
||||
cols = ["string_col", "int_col", "list_col", "float_col"]
|
||||
data = [("Text", 111, [123], 1.0)]
|
||||
results = SupersetResultSet(data, cols, BaseEngineSpec)
|
||||
|
||||
self.assertEqual(len(data), results.size)
|
||||
self.assertEqual(len(cols), len(results.columns))
|
||||
assert len(data) == results.size
|
||||
assert len(cols) == len(results.columns)
|
||||
|
||||
def test_pa_conversion_dict(self):
|
||||
cols = ["string_col", "dict_col", "int_col"]
|
||||
data = [["a", {"c1": 1, "c2": 2, "c3": 3}, 4]]
|
||||
results = SupersetResultSet(data, cols, BaseEngineSpec)
|
||||
|
||||
self.assertEqual(len(data), results.size)
|
||||
self.assertEqual(len(cols), len(results.columns))
|
||||
assert len(data) == results.size
|
||||
assert len(cols) == len(results.columns)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_sql_limit(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
test_limit = 1
|
||||
data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
|
||||
self.assertGreater(len(data["data"]), test_limit)
|
||||
assert len(data["data"]) > test_limit
|
||||
data = self.run_sql(
|
||||
"SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
|
||||
)
|
||||
self.assertEqual(len(data["data"]), test_limit)
|
||||
assert len(data["data"]) == test_limit
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM birth_names LIMIT {test_limit}",
|
||||
client_id="sql_limit_3",
|
||||
query_limit=test_limit + 1,
|
||||
)
|
||||
self.assertEqual(len(data["data"]), test_limit)
|
||||
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY)
|
||||
assert len(data["data"]) == test_limit
|
||||
assert data["query"]["limitingFactor"] == LimitingFactor.QUERY
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM birth_names LIMIT {test_limit + 1}",
|
||||
client_id="sql_limit_4",
|
||||
query_limit=test_limit,
|
||||
)
|
||||
self.assertEqual(len(data["data"]), test_limit)
|
||||
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN)
|
||||
assert len(data["data"]) == test_limit
|
||||
assert data["query"]["limitingFactor"] == LimitingFactor.DROPDOWN
|
||||
|
||||
data = self.run_sql(
|
||||
f"SELECT * FROM birth_names LIMIT {test_limit}",
|
||||
client_id="sql_limit_5",
|
||||
query_limit=test_limit,
|
||||
)
|
||||
self.assertEqual(len(data["data"]), test_limit)
|
||||
self.assertEqual(
|
||||
data["query"]["limitingFactor"], LimitingFactor.QUERY_AND_DROPDOWN
|
||||
)
|
||||
assert len(data["data"]) == test_limit
|
||||
assert data["query"]["limitingFactor"] == LimitingFactor.QUERY_AND_DROPDOWN
|
||||
|
||||
data = self.run_sql(
|
||||
"SELECT * FROM birth_names",
|
||||
client_id="sql_limit_6",
|
||||
query_limit=10000,
|
||||
)
|
||||
self.assertEqual(len(data["data"]), 1200)
|
||||
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
|
||||
assert len(data["data"]) == 1200
|
||||
assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED
|
||||
|
||||
data = self.run_sql(
|
||||
"SELECT * FROM birth_names",
|
||||
client_id="sql_limit_7",
|
||||
query_limit=1200,
|
||||
)
|
||||
self.assertEqual(len(data["data"]), 1200)
|
||||
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
|
||||
assert len(data["data"]) == 1200
|
||||
assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data")
|
||||
def test_query_api_filter(self) -> None:
|
||||
|
|
@ -434,7 +432,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
data = self.get_json_resp(url)
|
||||
admin = security_manager.find_user("admin")
|
||||
gamma_sqllab = security_manager.find_user("gamma_sqllab")
|
||||
self.assertEqual(3, len(data["result"]))
|
||||
assert 3 == len(data["result"])
|
||||
user_queries = [
|
||||
result.get("user").get("first_name") for result in data["result"]
|
||||
]
|
||||
|
|
@ -461,7 +459,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
self.login(GAMMA_SQLLAB_USERNAME)
|
||||
url = "/api/v1/query/"
|
||||
data = self.get_json_resp(url)
|
||||
self.assertEqual(3, len(data["result"]))
|
||||
assert 3 == len(data["result"])
|
||||
|
||||
# Remove all_query_access from gamma sqllab
|
||||
all_queries_view = security_manager.find_permission_view_menu(
|
||||
|
|
@ -521,10 +519,9 @@ class TestSqlLab(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
url = f"/api/v1/query/?q={prison.dumps(arguments)}"
|
||||
self.assertEqual(
|
||||
{"SELECT 1", "SELECT 2"},
|
||||
{r.get("sql") for r in self.get_json_resp(url)["result"]},
|
||||
)
|
||||
assert {"SELECT 1", "SELECT 2"} == {
|
||||
r.get("sql") for r in self.get_json_resp(url)["result"]
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data")
|
||||
def test_query_admin_can_access_all_queries(self) -> None:
|
||||
|
|
@ -537,7 +534,7 @@ class TestSqlLab(SupersetTestCase):
|
|||
|
||||
url = "/api/v1/query/"
|
||||
data = self.get_json_resp(url)
|
||||
self.assertEqual(3, len(data["result"]))
|
||||
assert 3 == len(data["result"])
|
||||
|
||||
def test_api_database(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
|
@ -555,10 +552,9 @@ class TestSqlLab(SupersetTestCase):
|
|||
}
|
||||
url = f"api/v1/database/?q={prison.dumps(arguments)}"
|
||||
|
||||
self.assertEqual(
|
||||
{"examples", "fake_db_100", "main"},
|
||||
{r.get("database_name") for r in self.get_json_resp(url)["result"]},
|
||||
)
|
||||
assert {"examples", "fake_db_100", "main"} == {
|
||||
r.get("database_name") for r in self.get_json_resp(url)["result"]
|
||||
}
|
||||
self.delete_fake_db()
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class TestCacheWarmUp(SupersetTestCase):
|
|||
expected = [
|
||||
{"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices
|
||||
]
|
||||
self.assertCountEqual(result, expected)
|
||||
self.assertCountEqual(result, expected) # noqa: PT009
|
||||
|
||||
def reset_tag(self, tag):
|
||||
"""Remove associated object from tag, used to reset tests"""
|
||||
|
|
@ -106,7 +106,7 @@ class TestCacheWarmUp(SupersetTestCase):
|
|||
strategy = DashboardTagsStrategy(["tag1"])
|
||||
result = strategy.get_payloads()
|
||||
expected = []
|
||||
self.assertEqual(result, expected)
|
||||
assert result == expected
|
||||
|
||||
# tag dashboard 'births' with `tag1`
|
||||
tag1 = get_tag("tag1", db.session, TagType.custom)
|
||||
|
|
@ -118,7 +118,7 @@ class TestCacheWarmUp(SupersetTestCase):
|
|||
db.session.add(tagged_object)
|
||||
db.session.commit()
|
||||
|
||||
self.assertCountEqual(strategy.get_payloads(), tag1_urls)
|
||||
self.assertCountEqual(strategy.get_payloads(), tag1_urls) # noqa: PT009
|
||||
|
||||
strategy = DashboardTagsStrategy(["tag2"])
|
||||
tag2 = get_tag("tag2", db.session, TagType.custom)
|
||||
|
|
@ -126,7 +126,7 @@ class TestCacheWarmUp(SupersetTestCase):
|
|||
|
||||
result = strategy.get_payloads()
|
||||
expected = []
|
||||
self.assertEqual(result, expected)
|
||||
assert result == expected
|
||||
|
||||
# tag first slice
|
||||
dash = self.get_dash_by_slug("unicode-test")
|
||||
|
|
@ -140,10 +140,10 @@ class TestCacheWarmUp(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
result = strategy.get_payloads()
|
||||
self.assertCountEqual(result, tag2_urls)
|
||||
self.assertCountEqual(result, tag2_urls) # noqa: PT009
|
||||
|
||||
strategy = DashboardTagsStrategy(["tag1", "tag2"])
|
||||
|
||||
result = strategy.get_payloads()
|
||||
expected = tag1_urls + tag2_urls
|
||||
self.assertCountEqual(result, expected)
|
||||
self.assertCountEqual(result, expected) # noqa: PT009
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a dataset and add it to the db
|
||||
test_dataset = SqlaTable(
|
||||
|
|
@ -71,16 +71,16 @@ class TestTagging(SupersetTestCase):
|
|||
|
||||
# Test to make sure that a dataset tag was added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
self.assertEqual(1, len(tags))
|
||||
self.assertEqual("ObjectType.dataset", str(tags[0].object_type))
|
||||
self.assertEqual(test_dataset.id, tags[0].object_id)
|
||||
assert 1 == len(tags)
|
||||
assert "ObjectType.dataset" == str(tags[0].object_type)
|
||||
assert test_dataset.id == tags[0].object_id
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_dataset)
|
||||
db.session.commit()
|
||||
|
||||
# Test to make sure the tag is deleted when the associated object is deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
@pytest.mark.usefixtures("with_tagging_system_feature")
|
||||
def test_chart_tagging(self):
|
||||
|
|
@ -94,7 +94,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a chart and add it to the db
|
||||
test_chart = Slice(
|
||||
|
|
@ -109,16 +109,16 @@ class TestTagging(SupersetTestCase):
|
|||
|
||||
# Test to make sure that a chart tag was added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
self.assertEqual(1, len(tags))
|
||||
self.assertEqual("ObjectType.chart", str(tags[0].object_type))
|
||||
self.assertEqual(test_chart.id, tags[0].object_id)
|
||||
assert 1 == len(tags)
|
||||
assert "ObjectType.chart" == str(tags[0].object_type)
|
||||
assert test_chart.id == tags[0].object_id
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_chart)
|
||||
db.session.commit()
|
||||
|
||||
# Test to make sure the tag is deleted when the associated object is deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
@pytest.mark.usefixtures("with_tagging_system_feature")
|
||||
def test_dashboard_tagging(self):
|
||||
|
|
@ -132,7 +132,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a dashboard and add it to the db
|
||||
test_dashboard = Dashboard()
|
||||
|
|
@ -145,16 +145,16 @@ class TestTagging(SupersetTestCase):
|
|||
|
||||
# Test to make sure that a dashboard tag was added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
self.assertEqual(1, len(tags))
|
||||
self.assertEqual("ObjectType.dashboard", str(tags[0].object_type))
|
||||
self.assertEqual(test_dashboard.id, tags[0].object_id)
|
||||
assert 1 == len(tags)
|
||||
assert "ObjectType.dashboard" == str(tags[0].object_type)
|
||||
assert test_dashboard.id == tags[0].object_id
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_dashboard)
|
||||
db.session.commit()
|
||||
|
||||
# Test to make sure the tag is deleted when the associated object is deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
@pytest.mark.usefixtures("with_tagging_system_feature")
|
||||
def test_saved_query_tagging(self):
|
||||
|
|
@ -168,7 +168,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a saved query and add it to the db
|
||||
test_saved_query = SavedQuery(id=1, label="test saved query")
|
||||
|
|
@ -178,24 +178,24 @@ class TestTagging(SupersetTestCase):
|
|||
# Test to make sure that a saved query tag was added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
|
||||
self.assertEqual(2, len(tags))
|
||||
assert 2 == len(tags)
|
||||
|
||||
self.assertEqual("ObjectType.query", str(tags[0].object_type))
|
||||
self.assertEqual("owner:None", str(tags[0].tag.name))
|
||||
self.assertEqual("TagType.owner", str(tags[0].tag.type))
|
||||
self.assertEqual(test_saved_query.id, tags[0].object_id)
|
||||
assert "ObjectType.query" == str(tags[0].object_type)
|
||||
assert "owner:None" == str(tags[0].tag.name)
|
||||
assert "TagType.owner" == str(tags[0].tag.type)
|
||||
assert test_saved_query.id == tags[0].object_id
|
||||
|
||||
self.assertEqual("ObjectType.query", str(tags[1].object_type))
|
||||
self.assertEqual("type:query", str(tags[1].tag.name))
|
||||
self.assertEqual("TagType.type", str(tags[1].tag.type))
|
||||
self.assertEqual(test_saved_query.id, tags[1].object_id)
|
||||
assert "ObjectType.query" == str(tags[1].object_type)
|
||||
assert "type:query" == str(tags[1].tag.name)
|
||||
assert "TagType.type" == str(tags[1].tag.type)
|
||||
assert test_saved_query.id == tags[1].object_id
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_saved_query)
|
||||
db.session.commit()
|
||||
|
||||
# Test to make sure the tag is deleted when the associated object is deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
@pytest.mark.usefixtures("with_tagging_system_feature")
|
||||
def test_favorite_tagging(self):
|
||||
|
|
@ -209,7 +209,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a favorited object and add it to the db
|
||||
test_saved_query = FavStar(user_id=1, class_name="slice", obj_id=1)
|
||||
|
|
@ -218,16 +218,16 @@ class TestTagging(SupersetTestCase):
|
|||
|
||||
# Test to make sure that a favorited object tag was added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
self.assertEqual(1, len(tags))
|
||||
self.assertEqual("ObjectType.chart", str(tags[0].object_type))
|
||||
self.assertEqual(test_saved_query.obj_id, tags[0].object_id)
|
||||
assert 1 == len(tags)
|
||||
assert "ObjectType.chart" == str(tags[0].object_type)
|
||||
assert test_saved_query.obj_id == tags[0].object_id
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_saved_query)
|
||||
db.session.commit()
|
||||
|
||||
# Test to make sure the tag is deleted when the associated object is deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
@with_feature_flags(TAGGING_SYSTEM=False)
|
||||
def test_tagging_system(self):
|
||||
|
|
@ -240,7 +240,7 @@ class TestTagging(SupersetTestCase):
|
|||
self.clear_tagged_object_table()
|
||||
|
||||
# Test to make sure nothing is in the tagged_object table
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
||||
# Create a dataset and add it to the db
|
||||
test_dataset = SqlaTable(
|
||||
|
|
@ -282,7 +282,7 @@ class TestTagging(SupersetTestCase):
|
|||
|
||||
# Test to make sure that no tags were added to the tagged_object table
|
||||
tags = self.query_tagged_object_table()
|
||||
self.assertEqual(0, len(tags))
|
||||
assert 0 == len(tags)
|
||||
|
||||
# Cleanup the db
|
||||
db.session.delete(test_dataset)
|
||||
|
|
@ -293,4 +293,4 @@ class TestTagging(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
# Test to make sure all the tags are deleted when the associated objects are deleted
|
||||
self.assertEqual([], self.query_tagged_object_table())
|
||||
assert [] == self.query_tagged_object_table()
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ class TestTagApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/tag/{tag.id}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
expected_result = {
|
||||
"changed_by": None,
|
||||
"changed_on_delta_humanized": "now",
|
||||
|
|
@ -146,7 +146,7 @@ class TestTagApi(SupersetTestCase):
|
|||
}
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
for key, value in expected_result.items():
|
||||
self.assertEqual(value, data["result"][key])
|
||||
assert value == data["result"][key]
|
||||
# rollback changes
|
||||
db.session.delete(tag)
|
||||
db.session.commit()
|
||||
|
|
@ -160,7 +160,7 @@ class TestTagApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/tag/{max_id + 1}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
# cleanup
|
||||
db.session.delete(tag)
|
||||
db.session.commit()
|
||||
|
|
@ -173,7 +173,7 @@ class TestTagApi(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = "api/v1/tag/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] == TAGS_FIXTURE_COUNT
|
||||
# check expected columns
|
||||
|
|
@ -211,7 +211,7 @@ class TestTagApi(SupersetTestCase):
|
|||
}
|
||||
uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] == 2
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class TestTagApi(SupersetTestCase):
|
|||
query["filters"][0]["value"] = False
|
||||
uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] == 3
|
||||
|
||||
|
|
@ -249,10 +249,10 @@ class TestTagApi(SupersetTestCase):
|
|||
data = {"properties": {"tags": example_tag_names}}
|
||||
rv = self.client.post(uri, json=data, follow_redirects=True)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
# check that tags were created in database
|
||||
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
|
||||
self.assertEqual(tags.count(), 2)
|
||||
assert tags.count() == 2
|
||||
# check that tagged objects were created
|
||||
tag_ids = [tags[0].id, tags[1].id]
|
||||
tagged_objects = db.session.query(TaggedObject).filter(
|
||||
|
|
@ -308,7 +308,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = f"api/v1/tag/{dashboard_type.value}/{dashboard_id}/{tags.first().name}"
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
# ensure that tagged object no longer exists
|
||||
tagged_object = (
|
||||
db.session.query(TaggedObject)
|
||||
|
|
@ -358,14 +358,14 @@ class TestTagApi(SupersetTestCase):
|
|||
TaggedObject.object_id == dashboard_id,
|
||||
TaggedObject.object_type == dashboard_type.name,
|
||||
)
|
||||
self.assertEqual(tagged_objects.count(), 2)
|
||||
assert tagged_objects.count() == 2
|
||||
uri = f'api/v1/tag/get_objects/?tags={",".join(tag_names)}'
|
||||
rv = self.client.get(uri)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
fetched_objects = rv.json["result"]
|
||||
self.assertEqual(len(fetched_objects), 1)
|
||||
self.assertEqual(fetched_objects[0]["id"], dashboard_id)
|
||||
assert len(fetched_objects) == 1
|
||||
assert fetched_objects[0]["id"] == dashboard_id
|
||||
# clean up tagged object
|
||||
tagged_objects.delete()
|
||||
|
||||
|
|
@ -394,12 +394,12 @@ class TestTagApi(SupersetTestCase):
|
|||
TaggedObject.object_id == dashboard_id,
|
||||
TaggedObject.object_type == dashboard_type.name,
|
||||
)
|
||||
self.assertEqual(tagged_objects.count(), 2)
|
||||
self.assertEqual(tagged_objects.first().object_id, dashboard_id)
|
||||
assert tagged_objects.count() == 2
|
||||
assert tagged_objects.first().object_id == dashboard_id
|
||||
uri = "api/v1/tag/get_objects/"
|
||||
rv = self.client.get(uri)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
fetched_objects = rv.json["result"]
|
||||
# check that the dashboard object was fetched
|
||||
assert dashboard_id in [obj["id"] for obj in fetched_objects]
|
||||
|
|
@ -413,25 +413,25 @@ class TestTagApi(SupersetTestCase):
|
|||
# check that tags exist in the database
|
||||
example_tag_names = ["example_tag_1", "example_tag_2", "example_tag_3"]
|
||||
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
|
||||
self.assertEqual(tags.count(), 3)
|
||||
assert tags.count() == 3
|
||||
# delete the first tag
|
||||
uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[:1])}"
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
# check that tag does not exist in the database
|
||||
tag = db.session.query(Tag).filter(Tag.name == example_tag_names[0]).first()
|
||||
assert tag is None
|
||||
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
|
||||
self.assertEqual(tags.count(), 2)
|
||||
assert tags.count() == 2
|
||||
# delete multiple tags
|
||||
uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[1:])}"
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
# successful request
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
# check that tags are all gone
|
||||
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
|
||||
self.assertEqual(tags.count(), 0)
|
||||
assert tags.count() == 0
|
||||
|
||||
@pytest.mark.usefixtures("create_tags")
|
||||
def test_delete_favorite_tag(self):
|
||||
|
|
@ -442,7 +442,7 @@ class TestTagApi(SupersetTestCase):
|
|||
tag = db.session.query(Tag).first()
|
||||
rv = self.client.post(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
from sqlalchemy import and_ # noqa: F811
|
||||
from superset.tags.models import user_favorite_tag_table # noqa: F811
|
||||
from flask import g # noqa: F401, F811
|
||||
|
|
@ -463,7 +463,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = f"api/v1/tag/{tag.id}/favorites/"
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
association_row = (
|
||||
db.session.query(user_favorite_tag_table)
|
||||
.filter(
|
||||
|
|
@ -483,7 +483,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = "api/v1/tag/123/favorites/" # noqa: F541
|
||||
rv = self.client.post(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("create_tags")
|
||||
def test_delete_favorite_tag_not_found(self):
|
||||
|
|
@ -491,7 +491,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = "api/v1/tag/123/favorites/" # noqa: F541
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("create_tags")
|
||||
@patch("superset.daos.tag.g")
|
||||
|
|
@ -501,7 +501,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = "api/v1/tag/123/favorites/" # noqa: F541
|
||||
rv = self.client.post(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
|
||||
@pytest.mark.usefixtures("create_tags")
|
||||
@patch("superset.daos.tag.g")
|
||||
|
|
@ -511,7 +511,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = "api/v1/tag/123/favorites/" # noqa: F541
|
||||
rv = self.client.delete(uri, follow_redirects=True)
|
||||
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
assert rv.status_code == 422
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_post_tag(self):
|
||||
|
|
@ -527,7 +527,7 @@ class TestTagApi(SupersetTestCase):
|
|||
json={"name": "my_tag", "objects_to_tag": [["dashboard", dashboard.id]]},
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
assert rv.status_code == 201
|
||||
self.get_user(username="admin").get_id() # noqa: F841
|
||||
tag = (
|
||||
db.session.query(Tag)
|
||||
|
|
@ -550,7 +550,7 @@ class TestTagApi(SupersetTestCase):
|
|||
json={"name": "", "objects_to_tag": [["dashboard", dashboard.id]]},
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
@pytest.mark.usefixtures("create_tags")
|
||||
|
|
@ -563,7 +563,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri, json={"name": "new_name", "description": "new description"}
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
tag = (
|
||||
db.session.query(Tag)
|
||||
|
|
@ -581,7 +581,7 @@ class TestTagApi(SupersetTestCase):
|
|||
uri = f"api/v1/tag/{tag_to_update.id}"
|
||||
rv = self.client.put(uri, json={"foo": "bar"})
|
||||
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||
def test_post_bulk_tag(self):
|
||||
|
|
@ -617,7 +617,7 @@ class TestTagApi(SupersetTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
|
||||
result = TagDAO.get_tagged_objects_for_tags(tags, ["dashboard"])
|
||||
assert len(result) == 1
|
||||
|
|
@ -686,7 +686,7 @@ class TestTagApi(SupersetTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
assert rv.status_code == 200
|
||||
result = rv.json["result"]
|
||||
assert len(result["objects_tagged"]) == 2
|
||||
assert len(result["objects_skipped"]) == 1
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
|
|||
"admin",
|
||||
thumbnail_url,
|
||||
)
|
||||
self.assertEqual(response.getcode(), 202)
|
||||
assert response.getcode() == 202
|
||||
|
||||
|
||||
class TestWebDriverScreenshotErrorDetector(SupersetTestCase):
|
||||
|
|
@ -217,7 +217,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=False)
|
||||
|
|
@ -228,7 +228,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -255,7 +255,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
assert mock_adjust_string.call_args[0][2] == "admin"
|
||||
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
assert rv.status_code == 202
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -283,7 +283,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
assert mock_adjust_string.call_args[0][2] == username
|
||||
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
assert rv.status_code == 202
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -295,7 +295,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/dashboard/{max_id + 1}/thumbnail/1234/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
|
||||
|
|
@ -306,7 +306,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -333,7 +333,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
assert mock_adjust_string.call_args[0][2] == "admin"
|
||||
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
assert rv.status_code == 202
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -361,7 +361,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
assert mock_adjust_string.call_args[0][2] == username
|
||||
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
assert rv.status_code == 202
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -373,7 +373,7 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
uri = f"api/v1/chart/{max_id + 1}/thumbnail/1234/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
assert rv.status_code == 404
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -387,8 +387,8 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
|
||||
rv = self.client.get(f"api/v1/chart/{id_}/thumbnail/1234/")
|
||||
self.assertEqual(rv.status_code, 302)
|
||||
self.assertEqual(rv.headers["Location"], thumbnail_url)
|
||||
assert rv.status_code == 302
|
||||
assert rv.headers["Location"] == thumbnail_url
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -402,8 +402,8 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(rv.data, self.mock_image)
|
||||
assert rv.status_code == 200
|
||||
assert rv.data == self.mock_image
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -417,8 +417,8 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
|
||||
rv = self.client.get(thumbnail_url)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(rv.data, self.mock_image)
|
||||
assert rv.status_code == 200
|
||||
assert rv.data == self.mock_image
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(THUMBNAILS=True)
|
||||
|
|
@ -432,5 +432,5 @@ class TestThumbnails(SupersetTestCase):
|
|||
self.login(ADMIN_USERNAME)
|
||||
id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
|
||||
rv = self.client.get(f"api/v1/dashboard/{id_}/thumbnail/1234/")
|
||||
self.assertEqual(rv.status_code, 302)
|
||||
self.assertEqual(rv.headers["Location"], thumbnail_url)
|
||||
assert rv.status_code == 302
|
||||
assert rv.headers["Location"] == thumbnail_url
|
||||
|
|
|
|||
|
|
@ -35,36 +35,36 @@ class TestCurrentUserApi(SupersetTestCase):
|
|||
|
||||
rv = self.client.get(meUri)
|
||||
|
||||
self.assertEqual(200, rv.status_code)
|
||||
assert 200 == rv.status_code
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual("admin", response["result"]["username"])
|
||||
self.assertEqual(True, response["result"]["is_active"])
|
||||
self.assertEqual(False, response["result"]["is_anonymous"])
|
||||
assert "admin" == response["result"]["username"]
|
||||
assert True is response["result"]["is_active"]
|
||||
assert False is response["result"]["is_anonymous"]
|
||||
|
||||
def test_get_me_with_roles(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
rv = self.client.get(meUri + "roles/")
|
||||
self.assertEqual(200, rv.status_code)
|
||||
assert 200 == rv.status_code
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
roles = list(response["result"]["roles"].keys())
|
||||
self.assertEqual("Admin", roles.pop())
|
||||
assert "Admin" == roles.pop()
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_get_my_roles_anonymous(self, mock_g):
|
||||
mock_g.user = security_manager.get_anonymous_user
|
||||
rv = self.client.get(meUri + "roles/")
|
||||
self.assertEqual(401, rv.status_code)
|
||||
assert 401 == rv.status_code
|
||||
|
||||
def test_get_me_unauthorized(self):
|
||||
rv = self.client.get(meUri)
|
||||
self.assertEqual(401, rv.status_code)
|
||||
assert 401 == rv.status_code
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_get_me_anonymous(self, mock_g):
|
||||
mock_g.user = security_manager.get_anonymous_user
|
||||
rv = self.client.get(meUri)
|
||||
self.assertEqual(401, rv.status_code)
|
||||
assert 401 == rv.status_code
|
||||
|
||||
|
||||
class TestUserApi(SupersetTestCase):
|
||||
|
|
|
|||
|
|
@ -53,8 +53,8 @@ class EncryptedFieldTest(SupersetTestCase):
|
|||
|
||||
def test_create_field(self):
|
||||
field = encrypted_field_factory.create(String(1024))
|
||||
self.assertTrue(isinstance(field, EncryptedType))
|
||||
self.assertEqual(self.app.config["SECRET_KEY"], field.key)
|
||||
assert isinstance(field, EncryptedType)
|
||||
assert self.app.config["SECRET_KEY"] == field.key
|
||||
|
||||
def test_custom_adapter(self):
|
||||
self.app.config["SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"] = (
|
||||
|
|
@ -62,10 +62,10 @@ class EncryptedFieldTest(SupersetTestCase):
|
|||
)
|
||||
encrypted_field_factory.init_app(self.app)
|
||||
field = encrypted_field_factory.create(String(1024))
|
||||
self.assertTrue(isinstance(field, StringEncryptedType))
|
||||
self.assertFalse(isinstance(field, EncryptedType))
|
||||
self.assertTrue(getattr(field, "__created_by_enc_field_adapter__"))
|
||||
self.assertEqual(self.app.config["SECRET_KEY"], field.key)
|
||||
assert isinstance(field, StringEncryptedType)
|
||||
assert not isinstance(field, EncryptedType)
|
||||
assert getattr(field, "__created_by_enc_field_adapter__")
|
||||
assert self.app.config["SECRET_KEY"] == field.key
|
||||
|
||||
def test_ensure_encrypted_field_factory_is_used(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class MachineAuthProviderTests(SupersetTestCase):
|
|||
def test_get_auth_cookies(self):
|
||||
user = self.get_user("admin")
|
||||
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user)
|
||||
self.assertIsNotNone(auth_cookies["session"])
|
||||
assert auth_cookies["session"] is not None
|
||||
|
||||
@patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies")
|
||||
def test_auth_driver_user(self, get_auth_cookies):
|
||||
|
|
|
|||
|
|
@ -132,14 +132,14 @@ class TestUtils(SupersetTestCase):
|
|||
json_str = '{"test": 1}'
|
||||
blob = zlib_compress(json_str)
|
||||
got_str = zlib_decompress(blob)
|
||||
self.assertEqual(json_str, got_str)
|
||||
assert json_str == got_str
|
||||
|
||||
def test_merge_extra_filters(self):
|
||||
# does nothing if no extra filters
|
||||
form_data = {"A": 1, "B": 2, "c": "test"}
|
||||
expected = {**form_data, "adhoc_filters": [], "applied_time_extras": {}}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
# empty extra_filters
|
||||
form_data = {"A": 1, "B": 2, "c": "test", "extra_filters": []}
|
||||
expected = {
|
||||
|
|
@ -150,7 +150,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
# copy over extra filters into empty filters
|
||||
form_data = {
|
||||
"extra_filters": [
|
||||
|
|
@ -182,7 +182,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
# adds extra filters to existing filters
|
||||
form_data = {
|
||||
"extra_filters": [
|
||||
|
|
@ -230,7 +230,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
# adds extra filters to existing filters and sets time options
|
||||
form_data = {
|
||||
"extra_filters": [
|
||||
|
|
@ -262,7 +262,7 @@ class TestUtils(SupersetTestCase):
|
|||
},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_ignores_empty_filters(self):
|
||||
form_data = {
|
||||
|
|
@ -273,7 +273,7 @@ class TestUtils(SupersetTestCase):
|
|||
}
|
||||
expected = {"adhoc_filters": [], "applied_time_extras": {}}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_ignores_nones(self):
|
||||
form_data = {
|
||||
|
|
@ -301,7 +301,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_ignores_equal_filters(self):
|
||||
form_data = {
|
||||
|
|
@ -361,7 +361,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_merges_different_val_types(self):
|
||||
form_data = {
|
||||
|
|
@ -415,7 +415,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
form_data = {
|
||||
"extra_filters": [
|
||||
{"col": "a", "op": "in", "val": "someval"},
|
||||
|
|
@ -467,7 +467,7 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_adds_unequal_lists(self):
|
||||
form_data = {
|
||||
|
|
@ -530,27 +530,24 @@ class TestUtils(SupersetTestCase):
|
|||
"applied_time_extras": {},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_merge_extra_filters_when_applied_time_extras_predefined(self):
|
||||
form_data = {"applied_time_extras": {"__time_range": "Last week"}}
|
||||
merge_extra_filters(form_data)
|
||||
|
||||
self.assertEqual(
|
||||
form_data,
|
||||
{
|
||||
assert form_data == {
|
||||
"applied_time_extras": {"__time_range": "Last week"},
|
||||
"adhoc_filters": [],
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
def test_merge_request_params_when_url_params_undefined(self):
|
||||
form_data = {"since": "2000", "until": "now"}
|
||||
url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
|
||||
merge_request_params(form_data, url_params)
|
||||
self.assertIn("url_params", form_data.keys())
|
||||
self.assertIn("dashboard_ids", form_data["url_params"])
|
||||
self.assertNotIn("form_data", form_data.keys())
|
||||
assert "url_params" in form_data.keys()
|
||||
assert "dashboard_ids" in form_data["url_params"]
|
||||
assert "form_data" not in form_data.keys()
|
||||
|
||||
def test_merge_request_params_when_url_params_predefined(self):
|
||||
form_data = {
|
||||
|
|
@ -560,30 +557,26 @@ class TestUtils(SupersetTestCase):
|
|||
}
|
||||
url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
|
||||
merge_request_params(form_data, url_params)
|
||||
self.assertIn("url_params", form_data.keys())
|
||||
self.assertIn("abc", form_data["url_params"])
|
||||
self.assertEqual(
|
||||
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
|
||||
)
|
||||
assert "url_params" in form_data.keys()
|
||||
assert "abc" in form_data["url_params"]
|
||||
assert url_params["dashboard_ids"] == form_data["url_params"]["dashboard_ids"]
|
||||
|
||||
def test_format_timedelta(self):
|
||||
self.assertEqual(json.format_timedelta(timedelta(0)), "0:00:00")
|
||||
self.assertEqual(json.format_timedelta(timedelta(days=1)), "1 day, 0:00:00")
|
||||
self.assertEqual(json.format_timedelta(timedelta(minutes=-6)), "-0:06:00")
|
||||
self.assertEqual(
|
||||
json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6)),
|
||||
"-1 day, 5:06:00",
|
||||
assert json.format_timedelta(timedelta(0)) == "0:00:00"
|
||||
assert json.format_timedelta(timedelta(days=1)) == "1 day, 0:00:00"
|
||||
assert json.format_timedelta(timedelta(minutes=-6)) == "-0:06:00"
|
||||
assert (
|
||||
json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6))
|
||||
== "-1 day, 5:06:00"
|
||||
)
|
||||
self.assertEqual(
|
||||
json.format_timedelta(
|
||||
timedelta(0) - timedelta(days=16, hours=4, minutes=3)
|
||||
),
|
||||
"-16 days, 4:03:00",
|
||||
assert (
|
||||
json.format_timedelta(timedelta(0) - timedelta(days=16, hours=4, minutes=3))
|
||||
== "-16 days, 4:03:00"
|
||||
)
|
||||
|
||||
def test_validate_json(self):
|
||||
valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
|
||||
self.assertIsNone(json.validate_json(valid))
|
||||
assert json.validate_json(valid) is None
|
||||
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
|
||||
with self.assertRaises(json.JSONDecodeError):
|
||||
json.validate_json(invalid)
|
||||
|
|
@ -601,7 +594,7 @@ class TestUtils(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
convert_legacy_filters_into_adhoc(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_convert_legacy_filters_into_adhoc_filters(self):
|
||||
form_data = {"filters": [{"col": "a", "op": "in", "val": "someval"}]}
|
||||
|
|
@ -618,7 +611,7 @@ class TestUtils(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
convert_legacy_filters_into_adhoc(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_convert_legacy_filters_into_adhoc_present_and_empty(self):
|
||||
form_data = {"adhoc_filters": [], "where": "a = 1"}
|
||||
|
|
@ -633,7 +626,7 @@ class TestUtils(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
convert_legacy_filters_into_adhoc(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_convert_legacy_filters_into_adhoc_having(self):
|
||||
form_data = {"having": "COUNT(1) = 1"}
|
||||
|
|
@ -648,7 +641,7 @@ class TestUtils(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
convert_legacy_filters_into_adhoc(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_convert_legacy_filters_into_adhoc_present_and_nonempty(self):
|
||||
form_data = {
|
||||
|
|
@ -664,23 +657,23 @@ class TestUtils(SupersetTestCase):
|
|||
]
|
||||
}
|
||||
convert_legacy_filters_into_adhoc(form_data)
|
||||
self.assertEqual(form_data, expected)
|
||||
assert form_data == expected
|
||||
|
||||
def test_parse_js_uri_path_items_eval_undefined(self):
|
||||
self.assertIsNone(parse_js_uri_path_item("undefined", eval_undefined=True))
|
||||
self.assertIsNone(parse_js_uri_path_item("null", eval_undefined=True))
|
||||
self.assertEqual("undefined", parse_js_uri_path_item("undefined"))
|
||||
self.assertEqual("null", parse_js_uri_path_item("null"))
|
||||
assert parse_js_uri_path_item("undefined", eval_undefined=True) is None
|
||||
assert parse_js_uri_path_item("null", eval_undefined=True) is None
|
||||
assert "undefined" == parse_js_uri_path_item("undefined")
|
||||
assert "null" == parse_js_uri_path_item("null")
|
||||
|
||||
def test_parse_js_uri_path_items_unquote(self):
|
||||
self.assertEqual("slashed/name", parse_js_uri_path_item("slashed%2fname"))
|
||||
self.assertEqual(
|
||||
"slashed%2fname", parse_js_uri_path_item("slashed%2fname", unquote=False)
|
||||
assert "slashed/name" == parse_js_uri_path_item("slashed%2fname")
|
||||
assert "slashed%2fname" == parse_js_uri_path_item(
|
||||
"slashed%2fname", unquote=False
|
||||
)
|
||||
|
||||
def test_parse_js_uri_path_items_item_optional(self):
|
||||
self.assertIsNone(parse_js_uri_path_item(None))
|
||||
self.assertIsNotNone(parse_js_uri_path_item("item"))
|
||||
assert parse_js_uri_path_item(None) is None
|
||||
assert parse_js_uri_path_item("item") is not None
|
||||
|
||||
def test_get_stacktrace(self):
|
||||
app.config["SHOW_STACKTRACE"] = True
|
||||
|
|
@ -688,7 +681,7 @@ class TestUtils(SupersetTestCase):
|
|||
raise Exception("NONONO!")
|
||||
except Exception:
|
||||
stacktrace = get_stacktrace()
|
||||
self.assertIn("NONONO", stacktrace)
|
||||
assert "NONONO" in stacktrace
|
||||
|
||||
app.config["SHOW_STACKTRACE"] = False
|
||||
try:
|
||||
|
|
@ -698,31 +691,31 @@ class TestUtils(SupersetTestCase):
|
|||
assert stacktrace is None
|
||||
|
||||
def test_split(self):
|
||||
self.assertEqual(list(split("a b")), ["a", "b"])
|
||||
self.assertEqual(list(split("a,b", delimiter=",")), ["a", "b"])
|
||||
self.assertEqual(list(split("a,(b,a)", delimiter=",")), ["a", "(b,a)"])
|
||||
self.assertEqual(
|
||||
list(split('a,(b,a),"foo , bar"', delimiter=",")),
|
||||
["a", "(b,a)", '"foo , bar"'],
|
||||
)
|
||||
self.assertEqual(
|
||||
list(split("a,'b,c'", delimiter=",", quote="'")), ["a", "'b,c'"]
|
||||
)
|
||||
self.assertEqual(list(split('a "b c"')), ["a", '"b c"'])
|
||||
self.assertEqual(list(split(r'a "b \" c"')), ["a", r'"b \" c"'])
|
||||
assert list(split("a b")) == ["a", "b"]
|
||||
assert list(split("a,b", delimiter=",")) == ["a", "b"]
|
||||
assert list(split("a,(b,a)", delimiter=",")) == ["a", "(b,a)"]
|
||||
assert list(split('a,(b,a),"foo , bar"', delimiter=",")) == [
|
||||
"a",
|
||||
"(b,a)",
|
||||
'"foo , bar"',
|
||||
]
|
||||
assert list(split("a,'b,c'", delimiter=",", quote="'")) == ["a", "'b,c'"]
|
||||
assert list(split('a "b c"')) == ["a", '"b c"']
|
||||
assert list(split('a "b \\" c"')) == ["a", '"b \\" c"']
|
||||
|
||||
def test_get_or_create_db(self):
|
||||
get_or_create_db("test_db", "sqlite:///superset.db")
|
||||
database = db.session.query(Database).filter_by(database_name="test_db").one()
|
||||
self.assertIsNotNone(database)
|
||||
self.assertEqual(database.sqlalchemy_uri, "sqlite:///superset.db")
|
||||
self.assertIsNotNone(
|
||||
assert database is not None
|
||||
assert database.sqlalchemy_uri == "sqlite:///superset.db"
|
||||
assert (
|
||||
security_manager.find_permission_view_menu("database_access", database.perm)
|
||||
is not None
|
||||
)
|
||||
# Test change URI
|
||||
get_or_create_db("test_db", "sqlite:///changed.db")
|
||||
database = db.session.query(Database).filter_by(database_name="test_db").one()
|
||||
self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db")
|
||||
assert database.sqlalchemy_uri == "sqlite:///changed.db"
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -738,22 +731,16 @@ class TestUtils(SupersetTestCase):
|
|||
assert database.sqlalchemy_uri == "sqlite:///superset.db"
|
||||
|
||||
def test_as_list(self):
|
||||
self.assertListEqual(as_list(123), [123])
|
||||
self.assertListEqual(as_list([123]), [123])
|
||||
self.assertListEqual(as_list("foo"), ["foo"])
|
||||
self.assertListEqual(as_list(123), [123]) # noqa: PT009
|
||||
self.assertListEqual(as_list([123]), [123]) # noqa: PT009
|
||||
self.assertListEqual(as_list("foo"), ["foo"]) # noqa: PT009
|
||||
|
||||
def test_merge_extra_filters_with_no_extras(self):
|
||||
form_data = {
|
||||
"time_range": "Last 10 days",
|
||||
}
|
||||
merge_extra_form_data(form_data)
|
||||
self.assertEqual(
|
||||
form_data,
|
||||
{
|
||||
"time_range": "Last 10 days",
|
||||
"adhoc_filters": [],
|
||||
},
|
||||
)
|
||||
assert form_data == {"time_range": "Last 10 days", "adhoc_filters": []}
|
||||
|
||||
def test_merge_extra_filters_with_unset_legacy_time_range(self):
|
||||
"""
|
||||
|
|
@ -767,14 +754,11 @@ class TestUtils(SupersetTestCase):
|
|||
"extra_form_data": {"time_range": "Last year"},
|
||||
}
|
||||
merge_extra_filters(form_data)
|
||||
self.assertEqual(
|
||||
form_data,
|
||||
{
|
||||
assert form_data == {
|
||||
"time_range": "Last year",
|
||||
"applied_time_extras": {},
|
||||
"adhoc_filters": [],
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
def test_merge_extra_filters_with_extras(self):
|
||||
form_data = {
|
||||
|
|
@ -817,41 +801,45 @@ class TestUtils(SupersetTestCase):
|
|||
|
||||
def test_ssl_certificate_parse(self):
|
||||
parsed_certificate = parse_ssl_cert(ssl_certificate)
|
||||
self.assertEqual(parsed_certificate.serial_number, 12355228710836649848)
|
||||
assert parsed_certificate.serial_number == 12355228710836649848
|
||||
|
||||
def test_ssl_certificate_file_creation(self):
|
||||
path = create_ssl_cert_file(ssl_certificate)
|
||||
expected_filename = md5_sha_from_str(ssl_certificate)
|
||||
self.assertIn(expected_filename, path)
|
||||
self.assertTrue(os.path.exists(path))
|
||||
assert expected_filename in path
|
||||
assert os.path.exists(path)
|
||||
|
||||
def test_get_email_address_list(self):
|
||||
self.assertEqual(get_email_address_list("a@a"), ["a@a"])
|
||||
self.assertEqual(get_email_address_list(" a@a "), ["a@a"])
|
||||
self.assertEqual(get_email_address_list("a@a\n"), ["a@a"])
|
||||
self.assertEqual(get_email_address_list(",a@a;"), ["a@a"])
|
||||
self.assertEqual(
|
||||
get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f"),
|
||||
["a@a", "b@b", "c@c", "a-c@c", "d@d", "f@f"],
|
||||
)
|
||||
assert get_email_address_list("a@a") == ["a@a"]
|
||||
assert get_email_address_list(" a@a ") == ["a@a"]
|
||||
assert get_email_address_list("a@a\n") == ["a@a"]
|
||||
assert get_email_address_list(",a@a;") == ["a@a"]
|
||||
assert get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f") == [
|
||||
"a@a",
|
||||
"b@b",
|
||||
"c@c",
|
||||
"a-c@c",
|
||||
"d@d",
|
||||
"f@f",
|
||||
]
|
||||
|
||||
def test_get_form_data_default(self) -> None:
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(slc, None)
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_request_args(self) -> None:
|
||||
with app.test_request_context(
|
||||
query_string={"form_data": json.dumps({"foo": "bar"})}
|
||||
):
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(form_data, {"foo": "bar"})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {"foo": "bar"}
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_request_form(self) -> None:
|
||||
with app.test_request_context(data={"form_data": json.dumps({"foo": "bar"})}):
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(form_data, {"foo": "bar"})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {"foo": "bar"}
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_request_form_with_queries(self) -> None:
|
||||
# the CSV export uses for requests, even when sending requests to
|
||||
|
|
@ -862,8 +850,8 @@ class TestUtils(SupersetTestCase):
|
|||
}
|
||||
):
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(form_data, {"url_params": {"foo": "bar"}})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {"url_params": {"foo": "bar"}}
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_request_args_and_form(self) -> None:
|
||||
with app.test_request_context(
|
||||
|
|
@ -871,16 +859,16 @@ class TestUtils(SupersetTestCase):
|
|||
query_string={"form_data": json.dumps({"baz": "bar"})},
|
||||
):
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(form_data, {"baz": "bar", "foo": "bar"})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {"baz": "bar", "foo": "bar"}
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_globals(self) -> None:
|
||||
with app.test_request_context():
|
||||
g.form_data = {"foo": "bar"}
|
||||
form_data, slc = get_form_data()
|
||||
delattr(g, "form_data")
|
||||
self.assertEqual(form_data, {"foo": "bar"})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {"foo": "bar"}
|
||||
assert slc is None
|
||||
|
||||
def test_get_form_data_corrupted_json(self) -> None:
|
||||
with app.test_request_context(
|
||||
|
|
@ -888,8 +876,8 @@ class TestUtils(SupersetTestCase):
|
|||
query_string={"form_data": '{"baz": "bar"'},
|
||||
):
|
||||
form_data, slc = get_form_data()
|
||||
self.assertEqual(form_data, {})
|
||||
self.assertEqual(slc, None)
|
||||
assert form_data == {}
|
||||
assert slc is None
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_log_this(self) -> None:
|
||||
|
|
@ -912,29 +900,29 @@ class TestUtils(SupersetTestCase):
|
|||
.first()
|
||||
)
|
||||
|
||||
self.assertEqual(record.dashboard_id, dashboard_id)
|
||||
self.assertEqual(json.loads(record.json)["dashboard_id"], str(dashboard_id))
|
||||
self.assertEqual(json.loads(record.json)["form_data"]["slice_id"], slc.id)
|
||||
assert record.dashboard_id == dashboard_id
|
||||
assert json.loads(record.json)["dashboard_id"] == str(dashboard_id)
|
||||
assert json.loads(record.json)["form_data"]["slice_id"] == slc.id
|
||||
|
||||
self.assertEqual(
|
||||
json.loads(record.json)["form_data"]["viz_type"],
|
||||
slc.viz.form_data["viz_type"],
|
||||
assert (
|
||||
json.loads(record.json)["form_data"]["viz_type"]
|
||||
== slc.viz.form_data["viz_type"]
|
||||
)
|
||||
|
||||
def test_schema_validate_json(self):
|
||||
valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
|
||||
self.assertIsNone(schema.validate_json(valid))
|
||||
assert schema.validate_json(valid) is None
|
||||
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
|
||||
self.assertRaises(marshmallow.ValidationError, schema.validate_json, invalid)
|
||||
|
||||
def test_schema_one_of_case_insensitive(self):
|
||||
validator = schema.OneOfCaseInsensitive(choices=[1, 2, 3, "FoO", "BAR", "baz"])
|
||||
self.assertEqual(1, validator(1))
|
||||
self.assertEqual(2, validator(2))
|
||||
self.assertEqual("FoO", validator("FoO"))
|
||||
self.assertEqual("FOO", validator("FOO"))
|
||||
self.assertEqual("bar", validator("bar"))
|
||||
self.assertEqual("BaZ", validator("BaZ"))
|
||||
assert 1 == validator(1)
|
||||
assert 2 == validator(2)
|
||||
assert "FoO" == validator("FoO")
|
||||
assert "FOO" == validator("FOO")
|
||||
assert "bar" == validator("bar")
|
||||
assert "BaZ" == validator("BaZ")
|
||||
self.assertRaises(marshmallow.ValidationError, validator, "qwerty")
|
||||
self.assertRaises(marshmallow.ValidationError, validator, 4)
|
||||
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@ class TestBaseViz(SupersetTestCase):
|
|||
"SUM(SP_URB_TOTL)",
|
||||
"count",
|
||||
]
|
||||
self.assertEqual(test_viz.metric_labels, expect_metric_labels)
|
||||
self.assertEqual(test_viz.all_metrics, expect_metric_labels)
|
||||
assert test_viz.metric_labels == expect_metric_labels
|
||||
assert test_viz.all_metrics == expect_metric_labels
|
||||
|
||||
def test_get_df_returns_empty_df(self):
|
||||
form_data = {"dummy": 123}
|
||||
|
|
@ -91,8 +91,8 @@ class TestBaseViz(SupersetTestCase):
|
|||
datasource = self.get_datasource_mock()
|
||||
test_viz = viz.BaseViz(datasource, form_data)
|
||||
result = test_viz.get_df(query_obj)
|
||||
self.assertEqual(type(result), pd.DataFrame)
|
||||
self.assertTrue(result.empty)
|
||||
assert type(result) == pd.DataFrame
|
||||
assert result.empty
|
||||
|
||||
def test_get_df_handles_dttm_col(self):
|
||||
form_data = {"dummy": 123}
|
||||
|
|
@ -148,31 +148,31 @@ class TestBaseViz(SupersetTestCase):
|
|||
datasource = self.get_datasource_mock()
|
||||
datasource.cache_timeout = 0
|
||||
test_viz = viz.BaseViz(datasource, form_data={})
|
||||
self.assertEqual(0, test_viz.cache_timeout)
|
||||
assert 0 == test_viz.cache_timeout
|
||||
|
||||
datasource.cache_timeout = 156
|
||||
test_viz = viz.BaseViz(datasource, form_data={})
|
||||
self.assertEqual(156, test_viz.cache_timeout)
|
||||
assert 156 == test_viz.cache_timeout
|
||||
|
||||
datasource.cache_timeout = None
|
||||
datasource.database.cache_timeout = 0
|
||||
self.assertEqual(0, test_viz.cache_timeout)
|
||||
assert 0 == test_viz.cache_timeout
|
||||
|
||||
datasource.database.cache_timeout = 1666
|
||||
self.assertEqual(1666, test_viz.cache_timeout)
|
||||
assert 1666 == test_viz.cache_timeout
|
||||
|
||||
datasource.database.cache_timeout = None
|
||||
test_viz = viz.BaseViz(datasource, form_data={})
|
||||
self.assertEqual(
|
||||
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"],
|
||||
test_viz.cache_timeout,
|
||||
assert (
|
||||
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
|
||||
== test_viz.cache_timeout
|
||||
)
|
||||
|
||||
data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
|
||||
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None
|
||||
datasource.database.cache_timeout = None
|
||||
test_viz = viz.BaseViz(datasource, form_data={})
|
||||
self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout)
|
||||
assert app.config["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout
|
||||
# restore DATA_CACHE_CONFIG timeout
|
||||
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout
|
||||
|
||||
|
|
@ -195,14 +195,14 @@ class TestDistBarViz(SupersetTestCase):
|
|||
)
|
||||
test_viz = viz.DistributionBarViz(datasource, form_data)
|
||||
data = test_viz.get_data(df)[0]
|
||||
self.assertEqual("votes", data["key"])
|
||||
assert "votes" == data["key"]
|
||||
expected_values = [
|
||||
{"x": "pepperoni", "y": 5},
|
||||
{"x": "cheese", "y": 3},
|
||||
{"x": NULL_STRING, "y": 2},
|
||||
{"x": "anchovies", "y": 1},
|
||||
]
|
||||
self.assertEqual(expected_values, data["values"])
|
||||
assert expected_values == data["values"]
|
||||
|
||||
def test_groupby_nans(self):
|
||||
form_data = {
|
||||
|
|
@ -216,7 +216,7 @@ class TestDistBarViz(SupersetTestCase):
|
|||
df = pd.DataFrame({"beds": [0, 1, nan, 2], "count": [30, 42, 3, 29]})
|
||||
test_viz = viz.DistributionBarViz(datasource, form_data)
|
||||
data = test_viz.get_data(df)[0]
|
||||
self.assertEqual("count", data["key"])
|
||||
assert "count" == data["key"]
|
||||
expected_values = [
|
||||
{"x": "1.0", "y": 42},
|
||||
{"x": "0.0", "y": 30},
|
||||
|
|
@ -224,7 +224,7 @@ class TestDistBarViz(SupersetTestCase):
|
|||
{"x": NULL_STRING, "y": 3},
|
||||
]
|
||||
|
||||
self.assertEqual(expected_values, data["values"])
|
||||
assert expected_values == data["values"]
|
||||
|
||||
def test_column_nulls(self):
|
||||
form_data = {
|
||||
|
|
@ -254,7 +254,7 @@ class TestDistBarViz(SupersetTestCase):
|
|||
"values": [{"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}],
|
||||
},
|
||||
]
|
||||
self.assertEqual(expected, data)
|
||||
assert expected == data
|
||||
|
||||
def test_column_metrics_in_order(self):
|
||||
form_data = {
|
||||
|
|
@ -292,7 +292,7 @@ class TestDistBarViz(SupersetTestCase):
|
|||
},
|
||||
]
|
||||
|
||||
self.assertEqual(expected, data)
|
||||
assert expected == data
|
||||
|
||||
def test_column_metrics_in_order_with_breakdowns(self):
|
||||
form_data = {
|
||||
|
|
@ -342,7 +342,7 @@ class TestDistBarViz(SupersetTestCase):
|
|||
},
|
||||
]
|
||||
|
||||
self.assertEqual(expected, data)
|
||||
assert expected == data
|
||||
|
||||
|
||||
class TestPairedTTest(SupersetTestCase):
|
||||
|
|
@ -445,7 +445,7 @@ class TestPairedTTest(SupersetTestCase):
|
|||
},
|
||||
],
|
||||
}
|
||||
self.assertEqual(data, expected)
|
||||
assert data == expected
|
||||
|
||||
def test_get_data_empty_null_keys(self):
|
||||
form_data = {"groupby": [], "metrics": [""]}
|
||||
|
|
@ -472,7 +472,7 @@ class TestPairedTTest(SupersetTestCase):
|
|||
}
|
||||
],
|
||||
}
|
||||
self.assertEqual(data, expected)
|
||||
assert data == expected
|
||||
|
||||
form_data = {"groupby": [], "metrics": [None]}
|
||||
with self.assertRaises(ValueError):
|
||||
|
|
@ -487,10 +487,10 @@ class TestPartitionViz(SupersetTestCase):
|
|||
test_viz = viz.PartitionViz(datasource, form_data)
|
||||
super_query_obj.return_value = {}
|
||||
query_obj = test_viz.query_obj()
|
||||
self.assertFalse(query_obj["is_timeseries"])
|
||||
assert not query_obj["is_timeseries"]
|
||||
test_viz.form_data["time_series_option"] = "agg_sum"
|
||||
query_obj = test_viz.query_obj()
|
||||
self.assertTrue(query_obj["is_timeseries"])
|
||||
assert query_obj["is_timeseries"]
|
||||
|
||||
def test_levels_for_computes_levels(self):
|
||||
raw = {}
|
||||
|
|
@ -506,37 +506,37 @@ class TestPartitionViz(SupersetTestCase):
|
|||
time_op = "agg_sum"
|
||||
test_viz = viz.PartitionViz(Mock(), {})
|
||||
levels = test_viz.levels_for(time_op, groups, df)
|
||||
self.assertEqual(4, len(levels))
|
||||
assert 4 == len(levels)
|
||||
expected = {DTTM_ALIAS: 1800, "metric1": 45, "metric2": 450, "metric3": 4500}
|
||||
self.assertEqual(expected, levels[0].to_dict())
|
||||
assert expected == levels[0].to_dict()
|
||||
expected = {
|
||||
DTTM_ALIAS: {"a1": 600, "b1": 600, "c1": 600},
|
||||
"metric1": {"a1": 6, "b1": 15, "c1": 24},
|
||||
"metric2": {"a1": 60, "b1": 150, "c1": 240},
|
||||
"metric3": {"a1": 600, "b1": 1500, "c1": 2400},
|
||||
}
|
||||
self.assertEqual(expected, levels[1].to_dict())
|
||||
self.assertEqual(["groupA", "groupB"], levels[2].index.names)
|
||||
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
|
||||
assert expected == levels[1].to_dict()
|
||||
assert ["groupA", "groupB"] == levels[2].index.names
|
||||
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
|
||||
time_op = "agg_mean"
|
||||
levels = test_viz.levels_for(time_op, groups, df)
|
||||
self.assertEqual(4, len(levels))
|
||||
assert 4 == len(levels)
|
||||
expected = {
|
||||
DTTM_ALIAS: 200.0,
|
||||
"metric1": 5.0,
|
||||
"metric2": 50.0,
|
||||
"metric3": 500.0,
|
||||
}
|
||||
self.assertEqual(expected, levels[0].to_dict())
|
||||
assert expected == levels[0].to_dict()
|
||||
expected = {
|
||||
DTTM_ALIAS: {"a1": 200, "c1": 200, "b1": 200},
|
||||
"metric1": {"a1": 2, "b1": 5, "c1": 8},
|
||||
"metric2": {"a1": 20, "b1": 50, "c1": 80},
|
||||
"metric3": {"a1": 200, "b1": 500, "c1": 800},
|
||||
}
|
||||
self.assertEqual(expected, levels[1].to_dict())
|
||||
self.assertEqual(["groupA", "groupB"], levels[2].index.names)
|
||||
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
|
||||
assert expected == levels[1].to_dict()
|
||||
assert ["groupA", "groupB"] == levels[2].index.names
|
||||
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
|
||||
|
||||
def test_levels_for_diff_computes_difference(self):
|
||||
raw = {}
|
||||
|
|
@ -553,15 +553,15 @@ class TestPartitionViz(SupersetTestCase):
|
|||
time_op = "point_diff"
|
||||
levels = test_viz.levels_for_diff(time_op, groups, df)
|
||||
expected = {"metric1": 6, "metric2": 60, "metric3": 600}
|
||||
self.assertEqual(expected, levels[0].to_dict())
|
||||
assert expected == levels[0].to_dict()
|
||||
expected = {
|
||||
"metric1": {"a1": 2, "b1": 2, "c1": 2},
|
||||
"metric2": {"a1": 20, "b1": 20, "c1": 20},
|
||||
"metric3": {"a1": 200, "b1": 200, "c1": 200},
|
||||
}
|
||||
self.assertEqual(expected, levels[1].to_dict())
|
||||
self.assertEqual(4, len(levels))
|
||||
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
|
||||
assert expected == levels[1].to_dict()
|
||||
assert 4 == len(levels)
|
||||
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
|
||||
|
||||
def test_levels_for_time_calls_process_data_and_drops_cols(self):
|
||||
raw = {}
|
||||
|
|
@ -581,16 +581,16 @@ class TestPartitionViz(SupersetTestCase):
|
|||
|
||||
test_viz.process_data = Mock(side_effect=return_args)
|
||||
levels = test_viz.levels_for_time(groups, df)
|
||||
self.assertEqual(4, len(levels))
|
||||
assert 4 == len(levels)
|
||||
cols = [DTTM_ALIAS, "metric1", "metric2", "metric3"]
|
||||
self.assertEqual(sorted(cols), sorted(levels[0].columns.tolist()))
|
||||
assert sorted(cols) == sorted(levels[0].columns.tolist())
|
||||
cols += ["groupA"]
|
||||
self.assertEqual(sorted(cols), sorted(levels[1].columns.tolist()))
|
||||
assert sorted(cols) == sorted(levels[1].columns.tolist())
|
||||
cols += ["groupB"]
|
||||
self.assertEqual(sorted(cols), sorted(levels[2].columns.tolist()))
|
||||
assert sorted(cols) == sorted(levels[2].columns.tolist())
|
||||
cols += ["groupC"]
|
||||
self.assertEqual(sorted(cols), sorted(levels[3].columns.tolist()))
|
||||
self.assertEqual(4, len(test_viz.process_data.mock_calls))
|
||||
assert sorted(cols) == sorted(levels[3].columns.tolist())
|
||||
assert 4 == len(test_viz.process_data.mock_calls)
|
||||
|
||||
def test_nest_values_returns_hierarchy(self):
|
||||
raw = {}
|
||||
|
|
@ -605,12 +605,12 @@ class TestPartitionViz(SupersetTestCase):
|
|||
groups = ["groupA", "groupB", "groupC"]
|
||||
levels = test_viz.levels_for("agg_sum", groups, df)
|
||||
nest = test_viz.nest_values(levels)
|
||||
self.assertEqual(3, len(nest))
|
||||
assert 3 == len(nest)
|
||||
for i in range(0, 3):
|
||||
self.assertEqual("metric" + str(i + 1), nest[i]["name"])
|
||||
self.assertEqual(3, len(nest[0]["children"]))
|
||||
self.assertEqual(1, len(nest[0]["children"][0]["children"]))
|
||||
self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"]))
|
||||
assert "metric" + str(i + 1) == nest[i]["name"]
|
||||
assert 3 == len(nest[0]["children"])
|
||||
assert 1 == len(nest[0]["children"][0]["children"])
|
||||
assert 1 == len(nest[0]["children"][0]["children"][0]["children"])
|
||||
|
||||
def test_nest_procs_returns_hierarchy(self):
|
||||
raw = {}
|
||||
|
|
@ -633,15 +633,15 @@ class TestPartitionViz(SupersetTestCase):
|
|||
)
|
||||
procs[i] = pivot
|
||||
nest = test_viz.nest_procs(procs)
|
||||
self.assertEqual(3, len(nest))
|
||||
assert 3 == len(nest)
|
||||
for i in range(0, 3):
|
||||
self.assertEqual("metric" + str(i + 1), nest[i]["name"])
|
||||
self.assertEqual(None, nest[i].get("val"))
|
||||
self.assertEqual(3, len(nest[0]["children"]))
|
||||
self.assertEqual(3, len(nest[0]["children"][0]["children"]))
|
||||
self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"]))
|
||||
self.assertEqual(
|
||||
1, len(nest[0]["children"][0]["children"][0]["children"][0]["children"])
|
||||
assert "metric" + str(i + 1) == nest[i]["name"]
|
||||
assert None is nest[i].get("val")
|
||||
assert 3 == len(nest[0]["children"])
|
||||
assert 3 == len(nest[0]["children"][0]["children"])
|
||||
assert 1 == len(nest[0]["children"][0]["children"][0]["children"])
|
||||
assert 1 == len(
|
||||
nest[0]["children"][0]["children"][0]["children"][0]["children"]
|
||||
)
|
||||
|
||||
def test_get_data_calls_correct_method(self):
|
||||
|
|
@ -662,33 +662,33 @@ class TestPartitionViz(SupersetTestCase):
|
|||
test_viz.form_data["groupby"] = ["groups"]
|
||||
test_viz.form_data["time_series_option"] = "not_time"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[0][1][0])
|
||||
assert "agg_sum" == test_viz.levels_for.mock_calls[0][1][0]
|
||||
test_viz.form_data["time_series_option"] = "agg_sum"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[1][1][0])
|
||||
assert "agg_sum" == test_viz.levels_for.mock_calls[1][1][0]
|
||||
test_viz.form_data["time_series_option"] = "agg_mean"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("agg_mean", test_viz.levels_for.mock_calls[2][1][0])
|
||||
assert "agg_mean" == test_viz.levels_for.mock_calls[2][1][0]
|
||||
test_viz.form_data["time_series_option"] = "point_diff"
|
||||
test_viz.levels_for_diff = Mock(return_value=1)
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("point_diff", test_viz.levels_for_diff.mock_calls[0][1][0])
|
||||
assert "point_diff" == test_viz.levels_for_diff.mock_calls[0][1][0]
|
||||
test_viz.form_data["time_series_option"] = "point_percent"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("point_percent", test_viz.levels_for_diff.mock_calls[1][1][0])
|
||||
assert "point_percent" == test_viz.levels_for_diff.mock_calls[1][1][0]
|
||||
test_viz.form_data["time_series_option"] = "point_factor"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("point_factor", test_viz.levels_for_diff.mock_calls[2][1][0])
|
||||
assert "point_factor" == test_viz.levels_for_diff.mock_calls[2][1][0]
|
||||
test_viz.levels_for_time = Mock(return_value=1)
|
||||
test_viz.nest_procs = Mock(return_value=1)
|
||||
test_viz.form_data["time_series_option"] = "adv_anal"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual(1, len(test_viz.levels_for_time.mock_calls))
|
||||
self.assertEqual(1, len(test_viz.nest_procs.mock_calls))
|
||||
assert 1 == len(test_viz.levels_for_time.mock_calls)
|
||||
assert 1 == len(test_viz.nest_procs.mock_calls)
|
||||
test_viz.form_data["time_series_option"] = "time_series"
|
||||
test_viz.get_data(df)
|
||||
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[3][1][0])
|
||||
self.assertEqual(7, len(test_viz.nest_values.mock_calls))
|
||||
assert "agg_sum" == test_viz.levels_for.mock_calls[3][1][0]
|
||||
assert 7 == len(test_viz.nest_values.mock_calls)
|
||||
|
||||
|
||||
class TestRoseVis(SupersetTestCase):
|
||||
|
|
@ -724,7 +724,7 @@ class TestRoseVis(SupersetTestCase):
|
|||
{"time": t3, "value": 9, "key": ("c1",), "name": ("c1",)},
|
||||
],
|
||||
}
|
||||
self.assertEqual(expected, res)
|
||||
assert expected == res
|
||||
|
||||
|
||||
class TestTimeSeriesTableViz(SupersetTestCase):
|
||||
|
|
@ -741,13 +741,13 @@ class TestTimeSeriesTableViz(SupersetTestCase):
|
|||
test_viz = viz.TimeTableViz(datasource, form_data)
|
||||
data = test_viz.get_data(df)
|
||||
# Check method correctly transforms data
|
||||
self.assertEqual({"count", "sum__A"}, set(data["columns"]))
|
||||
assert {"count", "sum__A"} == set(data["columns"])
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
expected = {
|
||||
t1.strftime(time_format): {"sum__A": 15, "count": 6},
|
||||
t2.strftime(time_format): {"sum__A": 20, "count": 7},
|
||||
}
|
||||
self.assertEqual(expected, data["records"])
|
||||
assert expected == data["records"]
|
||||
|
||||
def test_get_data_group_by(self):
|
||||
form_data = {"metrics": ["sum__A"], "groupby": ["groupby1"]}
|
||||
|
|
@ -762,13 +762,13 @@ class TestTimeSeriesTableViz(SupersetTestCase):
|
|||
test_viz = viz.TimeTableViz(datasource, form_data)
|
||||
data = test_viz.get_data(df)
|
||||
# Check method correctly transforms data
|
||||
self.assertEqual({"a1", "a2", "a3"}, set(data["columns"]))
|
||||
assert {"a1", "a2", "a3"} == set(data["columns"])
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
expected = {
|
||||
t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25},
|
||||
t2.strftime(time_format): {"a1": 30, "a2": 35, "a3": 40},
|
||||
}
|
||||
self.assertEqual(expected, data["records"])
|
||||
assert expected == data["records"]
|
||||
|
||||
@patch("superset.viz.BaseViz.query_obj")
|
||||
def test_query_obj_throws_metrics_and_groupby(self, super_query_obj):
|
||||
|
|
@ -788,7 +788,7 @@ class TestTimeSeriesTableViz(SupersetTestCase):
|
|||
self.get_datasource_mock(), {"metrics": ["sum__A", "count"], "groupby": []}
|
||||
)
|
||||
query_obj = test_viz.query_obj()
|
||||
self.assertEqual(query_obj["orderby"], [("sum__A", False)])
|
||||
assert query_obj["orderby"] == [("sum__A", False)]
|
||||
|
||||
|
||||
class TestBaseDeckGLViz(SupersetTestCase):
|
||||
|
|
@ -838,7 +838,7 @@ class TestBaseDeckGLViz(SupersetTestCase):
|
|||
with self.assertRaises(NotImplementedError) as context:
|
||||
test_viz_deckgl.get_properties(mock_d)
|
||||
|
||||
self.assertTrue("" in str(context.exception))
|
||||
assert "" in str(context.exception)
|
||||
|
||||
def test_process_spatial_query_obj(self):
|
||||
form_data = load_fixture("deck_path_form_data.json")
|
||||
|
|
@ -850,7 +850,7 @@ class TestBaseDeckGLViz(SupersetTestCase):
|
|||
with self.assertRaises(ValueError) as context:
|
||||
test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb)
|
||||
|
||||
self.assertTrue("Bad spatial key" in str(context.exception))
|
||||
assert "Bad spatial key" in str(context.exception)
|
||||
|
||||
test_form_data = {
|
||||
"latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"},
|
||||
|
|
@ -886,14 +886,14 @@ class TestBaseDeckGLViz(SupersetTestCase):
|
|||
viz_instance = viz.BaseDeckGLViz(datasource, form_data)
|
||||
|
||||
coord = viz_instance.parse_coordinates("1.23, 3.21")
|
||||
self.assertEqual(coord, (1.23, 3.21))
|
||||
assert coord == (1.23, 3.21)
|
||||
|
||||
coord = viz_instance.parse_coordinates("1.23 3.21")
|
||||
self.assertEqual(coord, (1.23, 3.21))
|
||||
assert coord == (1.23, 3.21)
|
||||
|
||||
self.assertEqual(viz_instance.parse_coordinates(None), None)
|
||||
assert viz_instance.parse_coordinates(None) is None
|
||||
|
||||
self.assertEqual(viz_instance.parse_coordinates(""), None)
|
||||
assert viz_instance.parse_coordinates("") is None
|
||||
|
||||
def test_parse_coordinates_raises(self):
|
||||
form_data = load_fixture("deck_path_form_data.json")
|
||||
|
|
@ -1001,7 +1001,7 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
"key": ("Real Madrid C.F.\U0001f1fa\U0001f1f8\U0001f1ec\U0001f1e7",),
|
||||
},
|
||||
]
|
||||
self.assertEqual(expected, viz_data)
|
||||
assert expected == viz_data
|
||||
|
||||
def test_process_data_resample(self):
|
||||
datasource = self.get_datasource_mock()
|
||||
|
|
@ -1015,15 +1015,10 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
viz.NVD3TimeSeriesViz(
|
||||
assert viz.NVD3TimeSeriesViz(
|
||||
datasource,
|
||||
{"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"},
|
||||
)
|
||||
.process_data(df)["y"]
|
||||
.tolist(),
|
||||
[1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0],
|
||||
)
|
||||
).process_data(df)["y"].tolist() == [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0]
|
||||
|
||||
np.testing.assert_equal(
|
||||
viz.NVD3TimeSeriesViz(
|
||||
|
|
@ -1043,8 +1038,7 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
),
|
||||
data={"y": [1.0, 2.0, 3.0, 4.0]},
|
||||
)
|
||||
self.assertEqual(
|
||||
viz.NVD3TimeSeriesViz(
|
||||
assert viz.NVD3TimeSeriesViz(
|
||||
datasource,
|
||||
{
|
||||
"metrics": ["y"],
|
||||
|
|
@ -1052,13 +1046,8 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
"rolling_periods": 0,
|
||||
"min_periods": 0,
|
||||
},
|
||||
)
|
||||
.apply_rolling(df)["y"]
|
||||
.tolist(),
|
||||
[1.0, 3.0, 6.0, 10.0],
|
||||
)
|
||||
self.assertEqual(
|
||||
viz.NVD3TimeSeriesViz(
|
||||
).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 6.0, 10.0]
|
||||
assert viz.NVD3TimeSeriesViz(
|
||||
datasource,
|
||||
{
|
||||
"metrics": ["y"],
|
||||
|
|
@ -1066,13 +1055,8 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
"rolling_periods": 2,
|
||||
"min_periods": 0,
|
||||
},
|
||||
)
|
||||
.apply_rolling(df)["y"]
|
||||
.tolist(),
|
||||
[1.0, 3.0, 5.0, 7.0],
|
||||
)
|
||||
self.assertEqual(
|
||||
viz.NVD3TimeSeriesViz(
|
||||
).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 5.0, 7.0]
|
||||
assert viz.NVD3TimeSeriesViz(
|
||||
datasource,
|
||||
{
|
||||
"metrics": ["y"],
|
||||
|
|
@ -1080,11 +1064,7 @@ class TestTimeSeriesViz(SupersetTestCase):
|
|||
"rolling_periods": 10,
|
||||
"min_periods": 0,
|
||||
},
|
||||
)
|
||||
.apply_rolling(df)["y"]
|
||||
.tolist(),
|
||||
[1.0, 1.5, 2.0, 2.5],
|
||||
)
|
||||
).apply_rolling(df)["y"].tolist() == [1.0, 1.5, 2.0, 2.5]
|
||||
|
||||
def test_apply_rolling_without_data(self):
|
||||
datasource = self.get_datasource_mock()
|
||||
|
|
|
|||
Loading…
Reference in New Issue