chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)

This commit is contained in:
Maxime Beauchemin 2024-10-07 13:17:27 -07:00 committed by GitHub
parent 1f013055d2
commit a849c29288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
62 changed files with 2218 additions and 2422 deletions

View File

@ -446,6 +446,7 @@ select = [
"E7",
"E9",
"F",
"PT009",
"TRY201",
]
ignore = []

View File

@ -59,7 +59,7 @@ class TestAsyncEventApi(SupersetTestCase):
assert rv.status_code == 200
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
self.assertEqual(response, {"result": []})
assert response == {"result": []}
def _test_events_last_id_logic(self, mock_cache):
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
@ -69,7 +69,7 @@ class TestAsyncEventApi(SupersetTestCase):
assert rv.status_code == 200
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
self.assertEqual(response, {"result": []})
assert response == {"result": []}
def _test_events_results_logic(self, mock_cache):
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
@ -115,7 +115,7 @@ class TestAsyncEventApi(SupersetTestCase):
},
]
}
self.assertEqual(response, expected)
assert response == expected
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_redis_cache_backend(self, mock_uuid4):

View File

@ -69,7 +69,7 @@ class TestOpenApiSpec(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/_openapi"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
validate_spec(response)
@ -87,20 +87,20 @@ class TestBaseModelRestApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/model1api/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["list_columns"], ["id"])
assert response["list_columns"] == ["id"]
for result in response["result"]:
self.assertEqual(list(result.keys()), ["id"])
assert list(result.keys()) == ["id"]
# Check get response
dashboard = db.session.query(Dashboard).first()
uri = f"api/v1/model1api/{dashboard.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["show_columns"], ["id"])
self.assertEqual(list(response["result"].keys()), ["id"])
assert response["show_columns"] == ["id"]
assert list(response["result"].keys()) == ["id"]
def test_default_missing_declaration_put_spec(self):
"""
@ -113,17 +113,18 @@ class TestBaseModelRestApi(SupersetTestCase):
uri = "api/v1/_openapi"
rv = self.client.get(uri)
# dashboard model accepts all fields are null
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
expected_mutation_spec = {
"properties": {"id": {"type": "integer"}},
"type": "object",
}
self.assertEqual(
response["components"]["schemas"]["Model1Api.post"], expected_mutation_spec
assert (
response["components"]["schemas"]["Model1Api.post"]
== expected_mutation_spec
)
self.assertEqual(
response["components"]["schemas"]["Model1Api.put"], expected_mutation_spec
assert (
response["components"]["schemas"]["Model1Api.put"] == expected_mutation_spec
)
def test_default_missing_declaration_post(self):
@ -145,7 +146,7 @@ class TestBaseModelRestApi(SupersetTestCase):
uri = "api/v1/model1api/"
rv = self.client.post(uri, json=dashboard_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": {
"css": ["Unknown field."],
@ -156,7 +157,7 @@ class TestBaseModelRestApi(SupersetTestCase):
"slug": ["Unknown field."],
}
}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_refuse_invalid_format_request(self):
"""
@ -169,7 +170,7 @@ class TestBaseModelRestApi(SupersetTestCase):
rv = self.client.post(
uri, data="a: value\nb: 1\n", content_type="application/yaml"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_default_missing_declaration_put(self):
@ -185,14 +186,14 @@ class TestBaseModelRestApi(SupersetTestCase):
uri = f"api/v1/model1api/{dashboard.id}"
rv = self.client.put(uri, json=dashboard_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": {
"dashboard_title": ["Unknown field."],
"slug": ["Unknown field."],
}
}
self.assertEqual(response, expected_response)
assert response == expected_response
class ApiOwnersTestCaseMixin:

View File

@ -59,8 +59,8 @@ class TestCache(SupersetTestCase):
)
# restore DATA_CACHE_CONFIG
app.config["DATA_CACHE_CONFIG"] = data_cache_config
self.assertFalse(resp["is_cached"])
self.assertFalse(resp_from_cache["is_cached"])
assert not resp["is_cached"]
assert not resp_from_cache["is_cached"]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_slice_data_cache(self):
@ -84,20 +84,20 @@ class TestCache(SupersetTestCase):
resp_from_cache = self.get_json_resp(
json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
)
self.assertFalse(resp["is_cached"])
self.assertTrue(resp_from_cache["is_cached"])
assert not resp["is_cached"]
assert resp_from_cache["is_cached"]
# should fallback to default cache timeout
self.assertEqual(resp_from_cache["cache_timeout"], 10)
self.assertEqual(resp_from_cache["status"], QueryStatus.SUCCESS)
self.assertEqual(resp["data"], resp_from_cache["data"])
self.assertEqual(resp["query"], resp_from_cache["query"])
assert resp_from_cache["cache_timeout"] == 10
assert resp_from_cache["status"] == QueryStatus.SUCCESS
assert resp["data"] == resp_from_cache["data"]
assert resp["query"] == resp_from_cache["query"]
# should exists in `data_cache`
self.assertEqual(
cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"],
resp_from_cache["query"],
assert (
cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"]
== resp_from_cache["query"]
)
# should not exists in `cache`
self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))
assert cache_manager.cache.get(resp_from_cache["cache_key"]) is None
# reset cache config
app.config["DATA_CACHE_CONFIG"] = data_cache_config

View File

@ -336,9 +336,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
assert model is None
def test_delete_bulk_charts(self):
"""
@ -355,13 +355,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": f"Deleted {chart_count} charts"}
self.assertEqual(response, expected_response)
assert response == expected_response
for chart_id in chart_ids:
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
assert model is None
def test_delete_bulk_chart_bad_request(self):
"""
@ -372,7 +372,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
def test_delete_not_found_chart(self):
"""
@ -382,7 +382,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
chart_id = 1000
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_chart_with_report")
def test_delete_chart_with_report(self):
@ -398,11 +398,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.client.delete(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": "There are associated alerts or reports: report_with_chart"
}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_delete_bulk_charts_not_found(self):
"""
@ -413,7 +413,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_chart_with_report", "create_charts")
def test_bulk_delete_chart_with_report(self):
@ -434,11 +434,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
rv = self.client.delete(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": "There are associated alerts or reports: report_with_chart"
}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_delete_chart_admin_not_owned(self):
"""
@ -450,9 +450,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
assert model is None
def test_delete_bulk_chart_admin_not_owned(self):
"""
@ -471,13 +471,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
expected_response = {"message": f"Deleted {chart_count} charts"}
self.assertEqual(response, expected_response)
assert response == expected_response
for chart_id in chart_ids:
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
assert model is None
def test_delete_chart_not_owned(self):
"""
@ -493,7 +493,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(username="alpha2", password="password")
uri = f"api/v1/chart/{chart.id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
@ -525,19 +525,19 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
arguments = [chart.id for chart in charts]
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Forbidden"}
self.assertEqual(response, expected_response)
assert response == expected_response
# # nothing is deleted in bulk with a list of owned and not owned charts
arguments = [chart.id for chart in charts] + [owned_chart.id]
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Forbidden"}
self.assertEqual(response, expected_response)
assert response == expected_response
for chart in charts:
db.session.delete(chart)
@ -572,7 +572,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
@ -590,7 +590,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
@ -609,10 +609,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": ["Owners are invalid"]}}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_create_chart_validate_params(self):
"""
@ -627,7 +627,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
def test_create_chart_validate_datasource(self):
"""
@ -640,29 +640,24 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"datasource_type": "unknown",
}
rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{
"message": {
"datasource_type": [
"Must be one of: table, dataset, query, saved_query, view."
]
}
},
)
assert response == {
"message": {
"datasource_type": [
"Must be one of: table, dataset, query, saved_query, view."
]
}
}
chart_data = {
"slice_name": "title1",
"datasource_id": 0,
"datasource_type": "table",
}
rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
assert response == {"message": {"datasource_id": ["Datasource does not exist"]}}
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_create_chart_validate_user_is_dashboard_owner(self):
@ -682,12 +677,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ALPHA_USERNAME)
uri = "api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": "Changing one or more of these dashboards is forbidden"},
)
assert response == {
"message": "Changing one or more of these dashboards is forbidden"
}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_update_chart(self):
@ -720,23 +714,23 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
related_dashboard = db.session.query(Dashboard).filter_by(slug="births").first()
self.assertEqual(model.created_by, admin)
self.assertEqual(model.slice_name, "title1_changed")
self.assertEqual(model.description, "description1")
self.assertNotIn(admin, model.owners)
self.assertIn(gamma, model.owners)
self.assertEqual(model.viz_type, "viz_type1")
self.assertEqual(model.params, """{"a": 1}""")
self.assertEqual(model.cache_timeout, 1000)
self.assertEqual(model.datasource_id, birth_names_table_id)
self.assertEqual(model.datasource_type, "table")
self.assertEqual(model.datasource_name, full_table_name)
self.assertEqual(model.certified_by, "Mario Rossi")
self.assertEqual(model.certification_details, "Edited certification")
self.assertIn(model.id, [slice.id for slice in related_dashboard.slices])
assert model.created_by == admin
assert model.slice_name == "title1_changed"
assert model.description == "description1"
assert admin not in model.owners
assert gamma in model.owners
assert model.viz_type == "viz_type1"
assert model.params == '{"a": 1}'
assert model.cache_timeout == 1000
assert model.datasource_id == birth_names_table_id
assert model.datasource_type == "table"
assert model.datasource_name == full_table_name
assert model.certified_by == "Mario Rossi"
assert model.certification_details == "Edited certification"
assert model.id in [slice.id for slice in related_dashboard.slices]
db.session.delete(model)
db.session.commit()
@ -755,16 +749,16 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
response = self.get_assert_metric("api/v1/chart/", "get_list")
res = json.loads(response.data.decode("utf-8"))["result"]
current_chart = [d for d in res if d["id"] == chart_id][0]
self.assertEqual(current_chart["slice_name"], new_name)
self.assertNotIn("username", current_chart["changed_by"].keys())
self.assertNotIn("username", current_chart["owners"][0].keys())
assert current_chart["slice_name"] == new_name
assert "username" not in current_chart["changed_by"].keys()
assert "username" not in current_chart["owners"][0].keys()
db.session.delete(model)
db.session.commit()
@ -784,14 +778,14 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
response = self.get_assert_metric(uri, "get")
res = json.loads(response.data.decode("utf-8"))["result"]
self.assertEqual(res["slice_name"], new_name)
self.assertNotIn("username", res["owners"][0].keys())
assert res["slice_name"] == new_name
assert "username" not in res["owners"][0].keys()
db.session.delete(model)
db.session.commit()
@ -829,10 +823,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart_id)
self.assertNotIn(admin, model.owners)
self.assertIn(gamma, model.owners)
assert admin not in model.owners
assert gamma in model.owners
db.session.delete(model)
db.session.commit()
@ -848,8 +842,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(username="admin")
uri = f"api/v1/chart/{self.chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
self.assertEqual([admin], self.chart.owners)
assert rv.status_code == 200
assert [admin] == self.chart.owners
@pytest.mark.usefixtures("add_dashboard_to_chart")
def test_update_chart_clear_owner_list(self):
@ -861,8 +855,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(username="admin")
uri = f"api/v1/chart/{self.chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
self.assertEqual([], self.chart.owners)
assert rv.status_code == 200
assert [] == self.chart.owners
def test_update_chart_populate_owner(self):
"""
@ -873,15 +867,15 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
admin = self.get_user("admin")
chart_id = self.insert_chart("title", [], 1).id
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model.owners, [])
assert model.owners == []
chart_data = {"owners": [gamma.id]}
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_updated = db.session.query(Slice).get(chart_id)
self.assertNotIn(admin, model_updated.owners)
self.assertIn(gamma, model_updated.owners)
assert admin not in model_updated.owners
assert gamma in model_updated.owners
db.session.delete(model_updated)
db.session.commit()
@ -897,9 +891,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{self.chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
self.assertIn(self.new_dashboard, self.chart.dashboards)
self.assertNotIn(self.original_dashboard, self.chart.dashboards)
assert rv.status_code == 200
assert self.new_dashboard in self.chart.dashboards
assert self.original_dashboard not in self.chart.dashboards
@pytest.mark.usefixtures("add_dashboard_to_chart")
def test_not_update_chart_none_dashboards(self):
@ -910,9 +904,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{self.chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
self.assertIn(self.original_dashboard, self.chart.dashboards)
self.assertEqual(len(self.chart.dashboards), 1)
assert rv.status_code == 200
assert self.original_dashboard in self.chart.dashboards
assert len(self.chart.dashboards) == 1
def test_update_chart_not_owned(self):
"""
@ -930,7 +924,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
chart_data = {"slice_name": "title1_changed"}
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
@ -976,13 +970,13 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, chart_data_with_invalid_dashboard, "put")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"dashboards": ["Dashboards do not exist"]}}
self.assertEqual(response, expected_response)
assert response == expected_response
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
db.session.delete(chart)
db.session.delete(original_dashboard)
@ -1001,26 +995,21 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
chart_data = {"datasource_id": 1, "datasource_type": "unknown"}
rv = self.put_assert_metric(f"/api/v1/chart/{chart.id}", chart_data, "put")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{
"message": {
"datasource_type": [
"Must be one of: table, dataset, query, saved_query, view."
]
}
},
)
assert response == {
"message": {
"datasource_type": [
"Must be one of: table, dataset, query, saved_query, view."
]
}
}
chart_data = {"datasource_id": 0, "datasource_type": "table"}
rv = self.put_assert_metric(f"/api/v1/chart/{chart.id}", chart_data, "put")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
assert response == {"message": {"datasource_id": ["Datasource does not exist"]}}
db.session.delete(chart)
db.session.commit()
@ -1038,10 +1027,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/" # noqa: F541
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": ["Owners are invalid"]}}
self.assertEqual(response, expected_response)
assert response == expected_response
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_get_chart(self):
@ -1053,7 +1042,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart.id}"
rv = self.get_assert_metric(uri, "get")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
expected_result = {
"cache_timeout": None,
"certified_by": None,
@ -1075,10 +1064,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"is_managed_externally": False,
}
data = json.loads(rv.data.decode("utf-8"))
self.assertIn("changed_on_delta_humanized", data["result"])
self.assertIn("id", data["result"])
self.assertIn("thumbnail_url", data["result"])
self.assertIn("url", data["result"])
assert "changed_on_delta_humanized" in data["result"]
assert "id" in data["result"]
assert "thumbnail_url" in data["result"]
assert "url" in data["result"]
for key, value in data["result"].items():
# We can't assert timestamp values or id/urls
if key not in (
@ -1087,7 +1076,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"thumbnail_url",
"url",
):
self.assertEqual(value, expected_result[key])
assert value == expected_result[key]
db.session.delete(chart)
db.session.commit()
@ -1099,7 +1088,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{chart_id}"
rv = self.get_assert_metric(uri, "get")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_chart_no_data_access(self):
@ -1114,7 +1103,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
)
uri = f"api/v1/chart/{chart_no_access.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures(
"load_energy_table_with_slice",
@ -1129,9 +1118,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/chart/" # noqa: F541
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 33)
assert data["count"] == 33
@pytest.mark.usefixtures("load_energy_table_with_slice", "add_dashboard_to_chart")
def test_get_charts_dashboards(self):
@ -1146,7 +1135,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["dashboards"] == [
{
@ -1172,7 +1161,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
assert len(result) == 1
@ -1206,26 +1195,20 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
# Filter by tag ID
filter_params = get_filter_params("chart_tag_id", tag.id)
response_by_id = self.get_list("chart", filter_params)
self.assertEqual(response_by_id.status_code, 200)
assert response_by_id.status_code == 200
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
# Filter by tag name
filter_params = get_filter_params("chart_tags", tag.name)
response_by_name = self.get_list("chart", filter_params)
self.assertEqual(response_by_name.status_code, 200)
assert response_by_name.status_code == 200
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
# Compare results
self.assertEqual(
data_by_id["count"],
data_by_name["count"],
len(expected_charts),
)
self.assertEqual(
set(chart["id"] for chart in data_by_id["result"]),
set(chart["id"] for chart in data_by_name["result"]),
set(chart.id for chart in expected_charts),
)
assert data_by_id["count"] == data_by_name["count"], len(expected_charts)
assert set(chart["id"] for chart in data_by_id["result"]) == set(
chart["id"] for chart in data_by_name["result"]
), set(chart.id for chart in expected_charts)
def test_get_charts_changed_on(self):
"""
@ -1243,7 +1226,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["changed_on_delta_humanized"] in (
"now",
@ -1266,9 +1249,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
arguments = {"filters": [{"col": "slice_name", "opr": "sw", "value": "G"}]}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 5)
assert data["count"] == 5
@pytest.fixture()
def load_energy_charts(self):
@ -1323,9 +1306,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 4)
assert data["count"] == 4
expected_response = [
{"description": "ZY_bar", "slice_name": "foo_a", "viz_type": None},
@ -1334,11 +1317,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
{"description": "desc1", "slice_name": "zy_foo", "viz_type": None},
]
for index, item in enumerate(data["result"]):
self.assertEqual(
item["description"], expected_response[index]["description"]
)
self.assertEqual(item["slice_name"], expected_response[index]["slice_name"])
self.assertEqual(item["viz_type"], expected_response[index]["viz_type"])
assert item["description"] == expected_response[index]["description"]
assert item["slice_name"] == expected_response[index]["slice_name"]
assert item["viz_type"] == expected_response[index]["viz_type"]
@pytest.mark.usefixtures("load_energy_table_with_slice", "load_energy_charts")
def test_admin_gets_filtered_energy_slices(self):
@ -1390,9 +1371,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], CHARTS_FIXTURE_COUNT)
assert data["count"] == CHARTS_FIXTURE_COUNT
@pytest.mark.usefixtures("create_charts")
def test_gets_not_certified_charts_filter(self):
@ -1411,9 +1392,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 17)
assert data["count"] == 17
@pytest.mark.usefixtures("load_energy_charts")
def test_user_gets_none_filtered_energy_slices(self):
@ -1433,9 +1414,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(GAMMA_USERNAME)
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)
assert data["count"] == 0
@pytest.mark.usefixtures("load_energy_charts")
def test_user_gets_all_charts(self):
@ -1445,12 +1426,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
def count_charts():
uri = "api/v1/chart/"
rv = self.client.get(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = rv.get_json()
return data["count"]
with self.temporary_user(gamma_user, login=True):
self.assertEqual(count_charts(), 0)
assert count_charts() == 0
perm = ("all_database_access", "all_database_access")
with self.temporary_user(gamma_user, extra_pvms=[perm], login=True):
@ -1462,7 +1443,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
# Back to normal
with self.temporary_user(gamma_user, login=True):
self.assertEqual(count_charts(), 0)
assert count_charts() == 0
@pytest.mark.usefixtures("create_charts")
def test_get_charts_favorite_filter(self):
@ -1645,7 +1626,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/time_range/?q={prison.dumps(humanize_time_range)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
assert "since" in data["result"][0]
assert "until" in data["result"][0]
assert "timeRange" in data["result"][0]
@ -1686,10 +1667,10 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/form_data/?slice_id={slice.id if slice else None}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.content_type, "application/json")
assert rv.status_code == 200
assert rv.content_type == "application/json"
if slice:
self.assertEqual(data["slice_id"], slice.id)
assert data["slice_id"] == slice.id
@pytest.mark.usefixtures(
"load_unicode_dashboard_with_slice",
@ -1706,16 +1687,16 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 10)
assert len(data["result"]) == 10
arguments = {"page_size": 10, "page": 3}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 3)
assert len(data["result"]) == 3
def test_get_charts_no_data_access(self):
"""
@ -1724,9 +1705,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(GAMMA_USERNAME)
uri = "api/v1/chart/"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)
assert data["count"] == 0
def test_export_chart(self):
"""
@ -1940,9 +1921,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 8)
assert data["count"] == 8
def test_gets_not_created_by_user_charts_filter(self):
arguments = {
@ -1954,9 +1935,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 8)
assert data["count"] == 8
@pytest.mark.usefixtures("create_charts")
def test_gets_owned_created_favorited_by_me_filter(self):
@ -1978,7 +1959,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"page_size": 25,
}
rv = self.client.get(f"api/v1/chart/?q={prison.dumps(arguments)}")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["slice_name"] == "name0"
@ -1995,13 +1976,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
self.login(ADMIN_USERNAME)
slc = self.get_slice(slice_name)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
assert data["result"] == [
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
]
dashboard = self.get_dash_by_slug("births")
@ -2009,12 +1989,11 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"/api/v1/chart/warm_up_cache",
json={"chart_id": slc.id, "dashboard_id": dashboard.id},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
assert data["result"] == [
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
]
rv = self.client.put(
"/api/v1/chart/warm_up_cache",
@ -2026,29 +2005,25 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
),
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
assert data["result"] == [
{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
]
def test_warm_up_cache_chart_id_required(self):
self.login(ADMIN_USERNAME)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1})
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{"message": {"chart_id": ["Missing data for required field."]}},
)
assert data == {"message": {"chart_id": ["Missing data for required field."]}}
def test_warm_up_cache_chart_not_found(self):
self.login(ADMIN_USERNAME)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999})
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data, {"message": "Chart not found"})
assert data == {"message": "Chart not found"}
def test_warm_up_cache_payload_validation(self):
self.login(ADMIN_USERNAME)
@ -2056,18 +2031,15 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"/api/v1/chart/warm_up_cache",
json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4},
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{
"message": {
"chart_id": ["Not a valid integer."],
"dashboard_id": ["Not a valid integer."],
"extra_filters": ["Not a valid string."],
}
},
)
assert data == {
"message": {
"chart_id": ["Not a valid integer."],
"dashboard_id": ["Not a valid integer."],
"extra_filters": ["Not a valid string."],
}
}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_error(self) -> None:
@ -2167,12 +2139,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart.id)
# Clean up system tags
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
self.assertEqual(tag_list, new_tags)
assert tag_list == new_tags
@pytest.mark.usefixtures("create_chart_with_tag")
def test_update_chart_remove_tags_can_write_on_tag(self):
@ -2194,12 +2166,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart.id)
# Clean up system tags
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
self.assertEqual(tag_list, new_tags)
assert tag_list == new_tags
@pytest.mark.usefixtures("create_chart_with_tag")
def test_update_chart_add_tags_can_tag_on_chart(self):
@ -2226,12 +2198,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart.id)
# Clean up system tags
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
self.assertEqual(tag_list, new_tags)
assert tag_list == new_tags
security_manager.add_permission_role(alpha_role, write_tags_perm)
@ -2256,12 +2228,12 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Slice).get(chart.id)
# Clean up system tags
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
self.assertEqual(tag_list, [])
assert tag_list == []
security_manager.add_permission_role(alpha_role, write_tags_perm)
@ -2291,10 +2263,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 403)
self.assertEqual(
rv.json["message"],
"You do not have permission to manage tags on charts",
assert rv.status_code == 403
assert (
rv.json["message"] == "You do not have permission to manage tags on charts"
)
security_manager.add_permission_role(alpha_role, write_tags_perm)
@ -2322,10 +2293,9 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 403)
self.assertEqual(
rv.json["message"],
"You do not have permission to manage tags on charts",
assert rv.status_code == 403
assert (
rv.json["message"] == "You do not have permission to manage tags on charts"
)
security_manager.add_permission_role(alpha_role, write_tags_perm)
@ -2353,7 +2323,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, update_payload, "put")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
security_manager.add_permission_role(alpha_role, write_tags_perm)
security_manager.add_permission_role(alpha_role, tag_charts_perm)

View File

@ -424,15 +424,19 @@ class TestChartWarmUpCacheCommand(SupersetTestCase):
def test_warm_up_cache(self):
slc = self.get_slice("Top 10 Girl Name Share")
result = ChartWarmUpCacheCommand(slc.id, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
)
assert result == {
"chart_id": slc.id,
"viz_error": None,
"viz_status": "success",
}
# can just pass in chart as well
result = ChartWarmUpCacheCommand(slc, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
)
assert result == {
"chart_id": slc.id,
"viz_error": None,
"viz_status": "success",
}
class TestFavoriteChartCommand(SupersetTestCase):

View File

@ -471,19 +471,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
"__time_origin": "now",
}
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"][0]["applied_filters"],
[
{"column": "gender"},
{"column": "num"},
{"column": "name"},
{"column": "__time_range"},
],
)
assert data["result"][0]["applied_filters"] == [
{"column": "gender"},
{"column": "num"},
{"column": "name"},
{"column": "__time_range"},
]
expected_row_count = self.get_expected_row_count("client_id_2")
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
assert data["result"][0]["rowcount"] == expected_row_count
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_in_op_filter__data_is_returned(self):
@ -533,7 +530,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
dttm_col.type,
dttm,
)
self.assertIn(dttm_expression, result["query"])
assert dttm_expression in result["query"]
else:
raise Exception("ds column not found")
@ -563,16 +560,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
}
]
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
row = result["data"][0]
self.assertIn("__timestamp", row)
self.assertIn("sum__num", row)
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 103)
assert "__timestamp" in row
assert "sum__num" in row
assert "sum__num__yhat" in row
assert "sum__num__yhat_upper" in row
assert "sum__num__yhat_lower" in row
assert result["rowcount"] == 103
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_invalid_post_processing(self):
@ -730,11 +727,11 @@ class TestPostChartDataApi(BaseTestChartDataApi):
time.sleep(1)
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
time.sleep(1)
self.assertEqual(rv.status_code, 202)
assert rv.status_code == 202
time.sleep(1)
data = json.loads(rv.data.decode("utf-8"))
keys = list(data.keys())
self.assertCountEqual(
self.assertCountEqual( # noqa: PT009
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
)
@ -764,10 +761,10 @@ class TestPostChartDataApi(BaseTestChartDataApi):
rv = self.post_assert_metric(
CHART_DATA_URI, self.query_context_payload, "data"
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
patched_run.assert_called_once_with(force_cached=True)
self.assertEqual(data, {"result": [{"query": "select * from foo"}]})
assert data == {"result": [{"query": "select * from foo"}]}
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -779,7 +776,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
async_query_manager_factory.init_app(app)
self.query_context_payload["result_type"] = "results"
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -793,7 +790,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
)
rv = test_client.post(CHART_DATA_URI, json=self.query_context_payload)
self.assertEqual(rv.status_code, 401)
assert rv.status_code == 401
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_rowcount(self):
@ -846,10 +843,8 @@ class TestPostChartDataApi(BaseTestChartDataApi):
unique_names = {row["name"] for row in data}
self.maxDiff = None
self.assertEqual(len(unique_names), SERIES_LIMIT)
self.assertEqual(
{column for column in data[0].keys()}, {"state", "name", "sum__num"}
)
assert len(unique_names) == SERIES_LIMIT
assert {column for column in data[0].keys()} == {"state", "name", "sum__num"}
@pytest.mark.usefixtures(
"create_annotation_layers", "load_birth_names_dashboard_with_slices"
@ -888,10 +883,10 @@ class TestPostChartDataApi(BaseTestChartDataApi):
annotation_layers.append(event)
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
# response should only contain interval and event data, not formula
self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
assert len(data["result"][0]["annotation_data"]) == 2
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_virtual_table_with_colons_as_datasource(self):
@ -1184,8 +1179,8 @@ class TestGetChartDataApi(BaseTestChartDataApi):
data = json.loads(rv.data.decode("utf-8"))
expected_row_count = self.get_expected_row_count("client_id_3")
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
assert rv.status_code == 200
assert data["result"][0]["rowcount"] == expected_row_count
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@ -1202,8 +1197,8 @@ class TestGetChartDataApi(BaseTestChartDataApi):
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(data["message"], "Error loading data from cache")
assert rv.status_code == 422
assert data["message"] == "Error loading data from cache"
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@ -1231,7 +1226,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
f"{CHART_DATA_URI}/test-cache-key",
)
self.assertEqual(rv.status_code, 401)
assert rv.status_code == 401
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
def test_chart_data_cache_key_error(self):
@ -1244,7 +1239,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_with_adhoc_column(self):

View File

@ -46,8 +46,8 @@ class TestSchema(SupersetTestCase):
payload["queries"][0]["row_offset"] = -1
with self.assertRaises(ValidationError) as context:
_ = ChartDataQueryContextSchema().load(payload)
self.assertIn("row_limit", context.exception.messages["queries"][0])
self.assertIn("row_offset", context.exception.messages["queries"][0])
assert "row_limit" in context.exception.messages["queries"][0]
assert "row_offset" in context.exception.messages["queries"][0]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_null_timegrain(self):

View File

@ -114,15 +114,15 @@ class TestCore(SupersetTestCase):
def test_login(self):
resp = self.get_resp("/login/", data=dict(username="admin", password="general"))
self.assertNotIn("User confirmation needed", resp)
assert "User confirmation needed" not in resp
resp = self.get_resp("/logout/", follow_redirects=True)
self.assertIn("User confirmation needed", resp)
assert "User confirmation needed" in resp
resp = self.get_resp(
"/login/", data=dict(username="admin", password="wrongPassword")
)
self.assertIn("User confirmation needed", resp)
assert "User confirmation needed" in resp
def test_dashboard_endpoint(self):
self.login(ADMIN_USERNAME)
@ -146,20 +146,17 @@ class TestCore(SupersetTestCase):
qobj["groupby"] = []
cache_key_with_groupby = viz.cache_key(qobj)
self.assertNotEqual(cache_key, cache_key_with_groupby)
assert cache_key != cache_key_with_groupby
self.assertNotEqual(
viz.cache_key(qobj), viz.cache_key(qobj, time_compare="12 weeks")
)
assert viz.cache_key(qobj) != viz.cache_key(qobj, time_compare="12 weeks")
self.assertNotEqual(
viz.cache_key(qobj, time_compare="28 days"),
viz.cache_key(qobj, time_compare="12 weeks"),
assert viz.cache_key(qobj, time_compare="28 days") != viz.cache_key(
qobj, time_compare="12 weeks"
)
qobj["inner_from_dttm"] = datetime.datetime(1901, 1, 1)
self.assertEqual(cache_key_with_groupby, viz.cache_key(qobj))
assert cache_key_with_groupby == viz.cache_key(qobj)
def test_admin_only_menu_views(self):
def assert_admin_view_menus_in(role_name, assert_func):
@ -205,9 +202,9 @@ class TestCore(SupersetTestCase):
new_slice_id = resp.json["form_data"]["slice_id"]
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, copy_name)
assert slc.slice_name == copy_name
form_data.pop("slice_id") # We don't save the slice id when saving as
self.assertEqual(slc.viz.form_data, form_data)
assert slc.viz.form_data == form_data
form_data = {
"adhoc_filters": [],
@ -224,8 +221,8 @@ class TestCore(SupersetTestCase):
data={"form_data": json.dumps(form_data)},
)
slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, new_slice_name)
self.assertEqual(slc.viz.form_data, form_data)
assert slc.slice_name == new_slice_name
assert slc.viz.form_data == form_data
# Cleanup
slices = (
@ -261,21 +258,21 @@ class TestCore(SupersetTestCase):
logger.info(f"[{name}]/[{method}]: {url}")
print(f"[{name}]/[{method}]: {url}")
resp = self.client.get(url)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
def test_add_slice(self):
self.login(ADMIN_USERNAME)
# assert that /chart/add responds with 200
url = "/chart/add"
resp = self.client.get(url)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
def test_get_user_slices(self):
self.login(ADMIN_USERNAME)
userid = security_manager.find_user("admin").id
url = f"/sliceasync/api/read?_flt_0_created_by={userid}"
resp = self.client.get(url)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_slices_V2(self):
@ -339,7 +336,7 @@ class TestCore(SupersetTestCase):
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
database = superset.utils.database.get_example_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
assert sqlalchemy_uri_decrypted == database.sqlalchemy_uri_decrypted
# Need to clean up after ourselves
database.impersonate_user = False
@ -355,9 +352,9 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
slc = self.get_slice("Top 10 Girl Name Share")
data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
)
assert data == [
{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}
]
data = self.get_json_resp(
"/superset/warm_up_cache?table_name=energy_usage&db_name=main"
@ -415,29 +412,29 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
resp = self.client.get("/kv/10001/")
self.assertEqual(404, resp.status_code)
assert 404 == resp.status_code
value = json.dumps({"data": "this is a test"})
resp = self.client.post("/kv/store/", data=dict(data=value))
self.assertEqual(resp.status_code, 404)
assert resp.status_code == 404
@with_feature_flags(KV_STORE=True)
def test_kv_enabled(self):
self.login(ADMIN_USERNAME)
resp = self.client.get("/kv/10001/")
self.assertEqual(404, resp.status_code)
assert 404 == resp.status_code
value = json.dumps({"data": "this is a test"})
resp = self.client.post("/kv/store/", data=dict(data=value))
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
kv = db.session.query(models.KeyValue).first()
kv_value = kv.value
self.assertEqual(json.loads(value), json.loads(kv_value))
assert json.loads(value) == json.loads(kv_value)
resp = self.client.get(f"/kv/{kv.id}/")
self.assertEqual(resp.status_code, 200)
self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))
assert resp.status_code == 200
assert json.loads(value) == json.loads(resp.data.decode("utf-8"))
def test_gamma(self):
self.login(GAMMA_USERNAME)
@ -451,7 +448,7 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
sql = "SELECT '{{ 1+1 }}' as test"
data = self.run_sql(sql, "fdaklj3ws")
self.assertEqual(data["data"][0]["test"], "2")
assert data["data"][0]["test"] == "2"
def test_fetch_datasource_metadata(self):
self.login(ADMIN_USERNAME)
@ -466,7 +463,7 @@ class TestCore(SupersetTestCase):
"id",
]
for k in keys:
self.assertIn(k, resp.keys())
assert k in resp.keys()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_slice_id_is_always_logged_correctly_on_web_request(self):
@ -475,7 +472,7 @@ class TestCore(SupersetTestCase):
slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
self.get_resp(slc.slice_url)
self.assertEqual(1, qry.count())
assert 1 == qry.count()
def create_sample_csvfile(self, filename: str, content: list[str]) -> None:
with open(filename, "w+") as test_file:
@ -490,7 +487,7 @@ class TestCore(SupersetTestCase):
database.allow_file_upload = True
db.session.commit()
add_datasource_page = self.get_resp("/databaseview/list/")
self.assertIn("Upload a CSV", add_datasource_page)
assert "Upload a CSV" in add_datasource_page
def test_dataframe_timezone(self):
tz = pytz.FixedOffset(60)
@ -502,15 +499,15 @@ class TestCore(SupersetTestCase):
df = results.to_pandas_df()
data = dataframe.df_to_records(df)
json_str = json.dumps(data, default=json.pessimistic_json_iso_dttm_ser)
self.assertDictEqual(
self.assertDictEqual( # noqa: PT009
data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)}
)
self.assertDictEqual(
self.assertDictEqual( # noqa: PT009
data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)}
)
self.assertEqual(
json_str,
'[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]',
assert (
json_str
== '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]'
)
def test_mssql_engine_spec_pymssql(self):
@ -524,11 +521,12 @@ class TestCore(SupersetTestCase):
)
df = results.to_pandas_df()
data = dataframe.df_to_records(df)
self.assertEqual(len(data), 2)
self.assertEqual(
data[0],
{"col1": 1, "col2": 1, "col3": pd.Timestamp("2017-10-19 23:39:16.660000")},
)
assert len(data) == 2
assert data[0] == {
"col1": 1,
"col2": 1,
"col3": pd.Timestamp("2017-10-19 23:39:16.660000"),
}
def test_comments_in_sqlatable_query(self):
clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl"
@ -554,9 +552,9 @@ class TestCore(SupersetTestCase):
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["errors"][0]["message"],
"The dataset associated with this chart no longer exists",
assert (
data["errors"][0]["message"]
== "The dataset associated with this chart no longer exists"
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -579,8 +577,8 @@ class TestCore(SupersetTestCase):
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["rowcount"], 2)
assert rv.status_code == 200
assert data["rowcount"] == 2
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_explore_json_dist_bar_order(self):
@ -741,7 +739,7 @@ class TestCore(SupersetTestCase):
"/superset/explore_json/?results=true",
data={"form_data": json.dumps(form_data)},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
@ -780,8 +778,8 @@ class TestCore(SupersetTestCase):
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["rowcount"], 2)
assert rv.status_code == 200
assert data["rowcount"] == 2
@mock.patch(
"superset.utils.cache_manager.CacheManager.cache",
@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
mock_cache.return_value = MockCache()
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
def test_explore_json_data_invalid_cache_key(self):
self.login(ADMIN_USERNAME)
@ -822,8 +820,8 @@ class TestCore(SupersetTestCase):
rv = self.client.get(f"/superset/explore_json/data/{cache_key}")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 404)
self.assertEqual(data["error"], "Cached data not found")
assert rv.status_code == 404
assert data["error"] == "Cached data not found"
def test_results_default_deserialization(self):
use_new_deserialization = False
@ -863,14 +861,14 @@ class TestCore(SupersetTestCase):
serialized_payload = sql_lab._serialize_payload(
payload, use_new_deserialization
)
self.assertIsInstance(serialized_payload, str)
assert isinstance(serialized_payload, str)
query_mock = mock.Mock()
deserialized_payload = superset.views.utils._deserialize_results_payload(
serialized_payload, query_mock, use_new_deserialization
)
self.assertDictEqual(deserialized_payload, payload)
self.assertDictEqual(deserialized_payload, payload) # noqa: PT009
query_mock.assert_not_called()
def test_results_msgpack_deserialization(self):
@ -911,7 +909,7 @@ class TestCore(SupersetTestCase):
serialized_payload = sql_lab._serialize_payload(
payload, use_new_deserialization
)
self.assertIsInstance(serialized_payload, bytes)
assert isinstance(serialized_payload, bytes)
with mock.patch.object(
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
@ -925,7 +923,7 @@ class TestCore(SupersetTestCase):
df = results.to_pandas_df()
payload["data"] = dataframe.df_to_records(df)
self.assertDictEqual(deserialized_payload, payload)
self.assertDictEqual(deserialized_payload, payload) # noqa: PT009
expand_data.assert_called_once()
@mock.patch.dict(
@ -960,7 +958,7 @@ class TestCore(SupersetTestCase):
]
for url in urls:
data = self.get_resp(url)
self.assertTrue(html_string in data)
assert html_string in data
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -991,7 +989,7 @@ class TestCore(SupersetTestCase):
tab_state_id = resp["id"]
payload = self.get_json_resp(f"/tabstateview/{tab_state_id}")
self.assertEqual(payload["label"], "Untitled Query foo")
assert payload["label"] == "Untitled Query foo"
def test_tabstate_update(self):
self.login(ADMIN_USERNAME)
@ -1014,87 +1012,87 @@ class TestCore(SupersetTestCase):
client_id = "asdfasdf"
data = {"sql": json.dumps("select 1"), "latest_query_id": json.dumps(client_id)}
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json["error"], "Bad request")
assert response.status_code == 400
assert response.json["error"] == "Bad request"
# generate query
db.session.add(Query(client_id=client_id, database_id=1))
db.session.commit()
# update tab state with a valid client_id
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
# nulls should be ok too
data["latest_query_id"] = "null"
response = self.client.put(f"/tabstateview/{tab_state_id}", data=data)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
def test_virtual_table_explore_visibility(self):
# test that default visibility it set to True
database = superset.utils.database.get_example_database()
self.assertEqual(database.allows_virtual_table_explore, True)
assert database.allows_virtual_table_explore is True
# test that visibility is disabled when extra is set to False
extra = database.get_extra()
extra["allows_virtual_table_explore"] = False
database.extra = json.dumps(extra)
self.assertEqual(database.allows_virtual_table_explore, False)
assert database.allows_virtual_table_explore is False
# test that visibility is enabled when extra is set to True
extra = database.get_extra()
extra["allows_virtual_table_explore"] = True
database.extra = json.dumps(extra)
self.assertEqual(database.allows_virtual_table_explore, True)
assert database.allows_virtual_table_explore is True
# test that visibility is not broken with bad values
extra = database.get_extra()
extra["allows_virtual_table_explore"] = "trash value"
database.extra = json.dumps(extra)
self.assertEqual(database.allows_virtual_table_explore, True)
assert database.allows_virtual_table_explore is True
def test_data_preview_visibility(self):
# test that default visibility is allowed
database = utils.get_example_database()
self.assertEqual(database.disable_data_preview, False)
assert database.disable_data_preview is False
# test that visibility is disabled when extra is set to true
extra = database.get_extra()
extra["disable_data_preview"] = True
database.extra = json.dumps(extra)
self.assertEqual(database.disable_data_preview, True)
assert database.disable_data_preview is True
# test that visibility is enabled when extra is set to false
extra = database.get_extra()
extra["disable_data_preview"] = False
database.extra = json.dumps(extra)
self.assertEqual(database.disable_data_preview, False)
assert database.disable_data_preview is False
# test that visibility is not broken with bad values
extra = database.get_extra()
extra["disable_data_preview"] = "trash value"
database.extra = json.dumps(extra)
self.assertEqual(database.disable_data_preview, False)
assert database.disable_data_preview is False
def test_disable_drill_to_detail(self):
# test that disable_drill_to_detail is False by default
database = utils.get_example_database()
self.assertEqual(database.disable_drill_to_detail, False)
assert database.disable_drill_to_detail is False
# test that disable_drill_to_detail can be set to True
extra = database.get_extra()
extra["disable_drill_to_detail"] = True
database.extra = json.dumps(extra)
self.assertEqual(database.disable_drill_to_detail, True)
assert database.disable_drill_to_detail is True
# test that disable_drill_to_detail can be set to False
extra = database.get_extra()
extra["disable_drill_to_detail"] = False
database.extra = json.dumps(extra)
self.assertEqual(database.disable_drill_to_detail, False)
assert database.disable_drill_to_detail is False
# test that disable_drill_to_detail is not broken with bad values
extra = database.get_extra()
extra["disable_drill_to_detail"] = "trash value"
database.extra = json.dumps(extra)
self.assertEqual(database.disable_drill_to_detail, False)
assert database.disable_drill_to_detail is False
def test_explore_database_id(self):
database = superset.utils.database.get_example_database()
@ -1102,13 +1100,13 @@ class TestCore(SupersetTestCase):
# test that explore_database_id is the regular database
# id if none is set in the extra
self.assertEqual(database.explore_database_id, database.id)
assert database.explore_database_id == database.id
# test that explore_database_id is correct if the extra is set
extra = database.get_extra()
extra["explore_database_id"] = explore_database.id
database.extra = json.dumps(extra)
self.assertEqual(database.explore_database_id, explore_database.id)
assert database.explore_database_id == explore_database.id
def test_get_column_names_from_metric(self):
simple_metric = {
@ -1146,7 +1144,7 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
assert "Error message" in data
# Assert we can handle a driver exception at the mutator level
exception = SQLAlchemyError("Error message")
@ -1156,7 +1154,7 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
assert "Error message" in data
@pytest.mark.skip(
"TODO This test was wrong - 'Error message' was in the language pack"
@ -1176,7 +1174,7 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
assert "Error message" in data
# Assert we can handle a driver exception at the mutator level
exception = SQLAlchemyError("Error message")
@ -1186,7 +1184,7 @@ class TestCore(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
assert "Error message" in data
@pytest.mark.usefixtures("load_energy_table_with_slice")
@mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run")
@ -1200,9 +1198,7 @@ class TestCore(SupersetTestCase):
rv = self.client.get(
f"/superset/explore/?form_data={quote(json.dumps(form_data))}"
)
self.assertEqual(
rv.headers["Location"], f"/explore/?form_data_key={random_key}"
)
assert rv.headers["Location"] == f"/explore/?form_data_key={random_key}"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_has_table(self):
@ -1223,7 +1219,7 @@ class TestCore(SupersetTestCase):
expected_url = "/superset/dashboard/1?permalink_key=123&standalone=3"
self.assertEqual(resp.headers["Location"], expected_url)
assert resp.headers["Location"] == expected_url
assert resp.status_code == 302

View File

@ -117,7 +117,7 @@ class TestDashboard(SupersetTestCase):
url = "/dashboard/new/"
response = self.client.get(url, follow_redirects=False)
dash_count_after = db.session.query(func.count(Dashboard.id)).first()[0]
self.assertEqual(dash_count_before + 1, dash_count_after)
assert dash_count_before + 1 == dash_count_after
group = re.match(
r"\/superset\/dashboard\/([0-9]*)\/\?edit=true",
response.headers["Location"],
@ -145,25 +145,25 @@ class TestDashboard(SupersetTestCase):
self.logout()
resp = self.get_resp("/api/v1/chart/")
self.assertNotIn("birth_names", resp)
assert "birth_names" not in resp
resp = self.get_resp("/api/v1/dashboard/")
self.assertNotIn("/superset/dashboard/births/", resp)
assert "/superset/dashboard/births/" not in resp
self.grant_public_access_to_table(table)
# Try access after adding appropriate permissions.
self.assertIn("birth_names", self.get_resp("/api/v1/chart/"))
assert "birth_names" in self.get_resp("/api/v1/chart/")
resp = self.get_resp("/api/v1/dashboard/")
self.assertIn("/superset/dashboard/births/", resp)
assert "/superset/dashboard/births/" in resp
# Confirm that public doesn't have access to other datasets.
resp = self.get_resp("/api/v1/chart/")
self.assertNotIn("wb_health_population", resp)
assert "wb_health_population" not in resp
resp = self.get_resp("/api/v1/dashboard/")
self.assertNotIn("/superset/dashboard/world_health/", resp)
assert "/superset/dashboard/world_health/" not in resp
# Cleanup
self.revoke_public_access_to_table(table)
@ -224,8 +224,8 @@ class TestDashboard(SupersetTestCase):
db.session.delete(hidden_dash)
db.session.commit()
self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp)
self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp)
assert f"/superset/dashboard/{my_dash_slug}/" in resp
assert f"/superset/dashboard/{not_my_dash_slug}/" not in resp
def test_user_can_not_view_unpublished_dash(self):
admin_user = security_manager.find_user("admin")
@ -247,7 +247,7 @@ class TestDashboard(SupersetTestCase):
db.session.delete(dash)
db.session.commit()
self.assertNotIn(f"/superset/dashboard/{slug}/", resp)
assert f"/superset/dashboard/{slug}/" not in resp
if __name__ == "__main__":

File diff suppressed because it is too large Load Diff

View File

@ -63,19 +63,19 @@ class DashboardTestCase(SupersetTestCase):
def assert_permission_was_created(self, dashboard):
view_menu = security_manager.find_view_menu(dashboard.view_name)
self.assertIsNotNone(view_menu)
self.assertEqual(len(security_manager.find_permissions_view_menu(view_menu)), 1)
assert view_menu is not None
assert len(security_manager.find_permissions_view_menu(view_menu)) == 1
def assert_permission_kept_and_changed(self, updated_dashboard, excepted_view_id):
view_menu_after_title_changed = security_manager.find_view_menu(
updated_dashboard.view_name
)
self.assertIsNotNone(view_menu_after_title_changed)
self.assertEqual(view_menu_after_title_changed.id, excepted_view_id)
assert view_menu_after_title_changed is not None
assert view_menu_after_title_changed.id == excepted_view_id
def assert_permissions_were_deleted(self, deleted_dashboard):
view_menu = security_manager.find_view_menu(deleted_dashboard.view_name)
self.assertIsNone(view_menu)
assert view_menu is None
def clean_created_objects(self):
with app.test_request_context():

View File

@ -87,12 +87,12 @@ class TestDashboardDAO(SupersetTestCase):
"duplicate_slices": False,
}
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
self.assertNotEqual(dash.id, original_dash.id)
self.assertEqual(len(dash.position), len(original_dash.position))
self.assertEqual(dash.dashboard_title, "copied dash")
self.assertEqual(dash.css, "<css>")
self.assertEqual(dash.owners, [security_manager.find_user("admin")])
self.assertCountEqual(dash.slices, original_dash.slices)
assert dash.id != original_dash.id
assert len(dash.position) == len(original_dash.position)
assert dash.dashboard_title == "copied dash"
assert dash.css == "<css>"
assert dash.owners == [security_manager.find_user("admin")]
self.assertCountEqual(dash.slices, original_dash.slices) # noqa: PT009
db.session.delete(dash)
db.session.commit()
@ -118,9 +118,7 @@ class TestDashboardDAO(SupersetTestCase):
"duplicate_slices": False,
}
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
self.assertEqual(
dash.params_dict["native_filter_configuration"], [{"mock": "filter"}]
)
assert dash.params_dict["native_filter_configuration"] == [{"mock": "filter"}]
db.session.delete(dash)
db.session.commit()
@ -141,15 +139,15 @@ class TestDashboardDAO(SupersetTestCase):
"duplicate_slices": True,
}
dash = DashboardDAO.copy_dashboard(original_dash, dash_data)
self.assertNotEqual(dash.id, original_dash.id)
self.assertEqual(len(dash.position), len(original_dash.position))
self.assertEqual(dash.dashboard_title, "copied dash")
self.assertEqual(dash.css, "<css>")
self.assertEqual(dash.owners, [security_manager.find_user("admin")])
self.assertEqual(len(dash.slices), len(original_dash.slices))
assert dash.id != original_dash.id
assert len(dash.position) == len(original_dash.position)
assert dash.dashboard_title == "copied dash"
assert dash.css == "<css>"
assert dash.owners == [security_manager.find_user("admin")]
assert len(dash.slices) == len(original_dash.slices)
for original_slc in original_dash.slices:
for slc in dash.slices:
self.assertNotEqual(slc.id, original_slc.id)
assert slc.id != original_slc.id
for slc in dash.slices:
db.session.delete(slc)

View File

@ -109,8 +109,8 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
# assert
self.assertIn(my_owned_dashboard.url, get_dashboards_response)
self.assertNotIn(not_my_owned_dashboard.url, get_dashboards_response)
assert my_owned_dashboard.url in get_dashboards_response
assert not_my_owned_dashboard.url not in get_dashboards_response
def test_get_dashboards__owners_can_view_empty_dashboard(self):
# arrange
@ -123,7 +123,7 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
# assert
self.assertNotIn(dashboard_url, get_dashboards_response)
assert dashboard_url not in get_dashboards_response
def test_get_dashboards__user_can_not_view_unpublished_dash(self):
# arrange
@ -139,9 +139,7 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
get_dashboards_response_as_gamma = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
# assert
self.assertNotIn(
admin_and_draft_dashboard.url, get_dashboards_response_as_gamma
)
assert admin_and_draft_dashboard.url not in get_dashboards_response_as_gamma
@pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard")
def test_get_dashboards__users_can_view_permitted_dashboard(self):
@ -172,8 +170,8 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405
# assert
self.assertIn(second_dash.url, get_dashboards_response)
self.assertIn(first_dash.url, get_dashboards_response)
assert second_dash.url in get_dashboards_response
assert first_dash.url in get_dashboards_response
finally:
self.revoke_public_access_to_table(accessed_table)
@ -193,5 +191,5 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
rv = self.client.get(uri)
self.assert200(rv)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(0, data["count"])
assert 0 == data["count"]
DashboardDAO.delete([dashboard])

View File

@ -207,7 +207,7 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity):
request_payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
# post
revoke_access_to_dashboard(dashboard_to_access, new_role) # noqa: F405
@ -480,12 +480,12 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity):
self.login(GAMMA_USERNAME)
rv = self.client.post(uri, json=data)
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
self.logout()
self.login(ADMIN_USERNAME)
rv = self.client.post(uri, json=data)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
target = (

View File

@ -187,7 +187,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
expected_columns = [
"allow_ctas",
@ -216,8 +216,8 @@ class TestDatabaseApi(SupersetTestCase):
"uuid",
]
self.assertGreater(response["count"], 0)
self.assertEqual(list(response["result"][0].keys()), expected_columns)
assert response["count"] > 0
assert list(response["result"][0].keys()) == expected_columns
def test_get_items_filter(self):
"""
@ -241,8 +241,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["count"], len(dbs))
assert rv.status_code == 200
assert response["count"] == len(dbs)
# Cleanup
db.session.delete(test_database)
@ -255,9 +255,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(GAMMA_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["count"], 0)
assert response["count"] == 0
def test_create_database(self):
"""
@ -284,7 +284,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
# Cleanup
model = db.session.query(Database).get(response.get("id"))
assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
@ -326,14 +326,14 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX")
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@ -385,10 +385,10 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
@mock.patch(
@ -434,19 +434,19 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@ -500,15 +500,15 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
# Cleanup
@ -563,19 +563,19 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
assert model_ssh_tunnel.database_id == response_update.get("id")
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
@ -585,7 +585,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
@ -651,30 +651,28 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(model_ssh_tunnel.username, "foo")
assert model_ssh_tunnel.database_id == response.get("id")
assert model_ssh_tunnel.username == "foo"
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
self.assertEqual(
response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX"
)
self.assertEqual(model_ssh_tunnel.username, "Test")
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
self.assertEqual(model_ssh_tunnel.server_port, 8080)
assert model_ssh_tunnel.database_id == response_update.get("id")
assert response_update.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
assert model_ssh_tunnel.username == "Test"
assert model_ssh_tunnel.server_address == "123.132.123.1"
assert model_ssh_tunnel.server_port == 8080
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@ -715,13 +713,13 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@ -769,7 +767,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
model_ssh_tunnel = (
db.session.query(SSHTunnel)
@ -777,7 +775,7 @@ class TestDatabaseApi(SupersetTestCase):
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
assert response == fail_message
# Check that rollback was called
mock_rollback.assert_called()
@ -824,14 +822,14 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
assert model_ssh_tunnel.database_id == response.get("id")
assert response.get("result")["ssh_tunnel"] == response_ssh_tunnel
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@ -866,8 +864,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, {"message": "SSH Tunneling is not enabled"})
assert rv.status_code == 400
assert response == {"message": "SSH Tunneling is not enabled"}
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
@ -897,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{database.id}/table/{table_name}/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
def test_create_database_invalid_configuration_method(self):
"""
@ -959,7 +957,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
self.assertIn("sqlalchemy_form", response["result"]["configuration_method"])
assert "sqlalchemy_form" in response["result"]["configuration_method"]
def test_create_database_server_cert_validate(self):
"""
@ -981,8 +979,8 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_json_validate(self):
"""
@ -1016,8 +1014,8 @@ class TestDatabaseApi(SupersetTestCase):
],
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_extra_metadata_validate(self):
"""
@ -1052,8 +1050,8 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_unique_validate(self):
"""
@ -1078,8 +1076,8 @@ class TestDatabaseApi(SupersetTestCase):
"database_name": "A database with the same name already exists."
}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
def test_create_database_uri_validate(self):
"""
@ -1095,11 +1093,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertIn(
"Invalid connection string",
response["message"]["sqlalchemy_uri"][0],
)
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
@mock.patch(
"superset.views.core.app.config",
@ -1127,8 +1122,8 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(response_data, expected_response)
self.assertEqual(response.status_code, 400)
assert response_data == expected_response
assert response.status_code == 400
def test_create_database_conn_fail(self):
"""
@ -1192,11 +1187,11 @@ class TestDatabaseApi(SupersetTestCase):
expected_response_postgres = {
"errors": [dataclasses.asdict(superset_error_postgres)]
}
self.assertEqual(response.status_code, 500)
assert response.status_code == 500
if example_db.backend == "mysql":
self.assertEqual(response_data, expected_response_mysql)
assert response_data == expected_response_mysql
else:
self.assertEqual(response_data, expected_response_postgres)
assert response_data == expected_response_postgres
def test_update_database(self):
"""
@ -1213,7 +1208,7 @@ class TestDatabaseApi(SupersetTestCase):
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
@ -1242,8 +1237,8 @@ class TestDatabaseApi(SupersetTestCase):
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
@ -1271,8 +1266,8 @@ class TestDatabaseApi(SupersetTestCase):
"database_name": "A database with the same name already exists."
}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
# Cleanup
db.session.delete(test_database1)
db.session.delete(test_database2)
@ -1286,7 +1281,7 @@ class TestDatabaseApi(SupersetTestCase):
database_data = {"database_name": "test-database-updated"}
uri = "api/v1/database/invalid"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_update_database_uri_validate(self):
"""
@ -1305,11 +1300,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertIn(
"Invalid connection string",
response["message"]["sqlalchemy_uri"][0],
)
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
db.session.delete(test_database)
db.session.commit()
@ -1369,9 +1361,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Database).get(database_id)
self.assertEqual(model, None)
assert model is None
def test_delete_database_not_found(self):
"""
@ -1381,7 +1373,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{max_id + 1}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_database_with_dataset")
def test_delete_database_with_datasets(self):
@ -1391,7 +1383,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{self._database.id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
@pytest.mark.usefixtures("create_database_with_report")
def test_delete_database_with_report(self):
@ -1407,11 +1399,11 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{database.id}"
rv = self.client.delete(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": "There are associated alerts or reports: report_with_database"
}
self.assertEqual(response, expected_response)
assert response == expected_response
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_metadata(self):
@ -1422,12 +1414,12 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["name"], "birth_names")
self.assertIsNone(response["comment"])
self.assertTrue(len(response["columns"]) > 5)
self.assertTrue(response.get("selectStar").startswith("SELECT"))
assert response["name"] == "birth_names"
assert response["comment"] is None
assert len(response["columns"]) > 5
assert response.get("selectStar").startswith("SELECT")
def test_info_security_database(self):
"""
@ -1456,11 +1448,11 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_invalid_table_table_metadata(self):
"""
@ -1472,25 +1464,22 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
if example_db.backend == "sqlite":
self.assertEqual(rv.status_code, 200)
self.assertEqual(
data,
{
"columns": [],
"comment": None,
"foreignKeys": [],
"indexes": [],
"name": "wrong_table",
"primaryKey": {"constrained_columns": None, "name": None},
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
},
)
assert rv.status_code == 200
assert data == {
"columns": [],
"comment": None,
"foreignKeys": [],
"indexes": [],
"name": "wrong_table",
"primaryKey": {"constrained_columns": None, "name": None},
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
}
elif example_db.backend == "mysql":
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "`wrong_table`"})
assert rv.status_code == 422
assert data == {"message": "`wrong_table`"}
else:
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "wrong_table"})
assert rv.status_code == 422
assert data == {"message": "wrong_table"}
def test_get_table_metadata_no_db_permission(self):
"""
@ -1500,7 +1489,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_extra_metadata_deprecated(self):
@ -1511,9 +1500,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response, {})
assert response == {}
def test_get_invalid_database_table_extra_metadata_deprecated(self):
"""
@ -1523,11 +1512,11 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_invalid_table_table_extra_metadata_deprecated(self):
"""
@ -1539,8 +1528,8 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data, {})
assert rv.status_code == 200
assert data == {}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_select_star(self):
@ -1551,7 +1540,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
def test_get_select_star_not_allowed(self):
"""
@ -1561,7 +1550,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_select_star_not_found_database(self):
"""
@ -1571,7 +1560,7 @@ class TestDatabaseApi(SupersetTestCase):
max_id = db.session.query(func.max(Database.id)).scalar()
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_select_star_not_found_table(self):
"""
@ -1585,7 +1574,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
rv = self.client.get(uri)
# TODO(bkyryliuk): investigate why presto returns 500
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
assert rv.status_code == (404 if example_db.backend != "presto" else 500)
def test_get_allow_file_upload_filter(self):
"""
@ -1952,13 +1941,13 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, set(response["result"]))
assert schemas == set(response["result"])
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, set(response["result"]))
assert schemas == set(response["result"])
def test_database_schemas_not_found(self):
"""
@ -1968,7 +1957,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_database_schemas_invalid_query(self):
"""
@ -1979,7 +1968,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
def test_database_tables(self):
"""
@ -1993,17 +1982,17 @@ class TestDatabaseApi(SupersetTestCase):
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}"
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
if database.backend == "postgresql":
response = json.loads(rv.data.decode("utf-8"))
schemas = [
s[0] for s in database.get_all_table_names_in_schema(None, schema_name)
]
self.assertEqual(response["count"], len(schemas))
assert response["count"] == len(schemas)
for option in response["result"]:
self.assertEqual(option["extra"], None)
self.assertEqual(option["type"], "table")
self.assertTrue(option["value"] in schemas)
assert option["extra"] is None
assert option["type"] == "table"
assert option["value"] in schemas
@patch("superset.utils.log.logger")
def test_database_tables_not_found(self, logger_mock):
@ -2014,7 +2003,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
logger_mock.warning.assert_called_once_with(
"Database not found.", exc_info=True
)
@ -2028,7 +2017,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
@patch("superset.utils.log.logger")
@mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@ -2046,7 +2035,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}"
)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
logger_mock.warning.assert_called_once_with("Test Error", exc_info=True)
def test_test_connection(self):
@ -2074,8 +2063,8 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
# validate that the endpoint works with the decrypted sqlalchemy uri
data = {
@ -2086,8 +2075,8 @@ class TestDatabaseApi(SupersetTestCase):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
def test_test_connection_failed(self):
"""
@ -2103,8 +2092,8 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
@ -2123,7 +2112,7 @@ class TestDatabaseApi(SupersetTestCase):
}
]
}
self.assertEqual(response, expected_response)
assert response == expected_response
data = {
"sqlalchemy_uri": "mssql+pymssql://url",
@ -2132,8 +2121,8 @@ class TestDatabaseApi(SupersetTestCase):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
@ -2152,7 +2141,7 @@ class TestDatabaseApi(SupersetTestCase):
}
]
}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_test_connection_unsafe_uri(self):
"""
@ -2169,7 +2158,7 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
@ -2178,7 +2167,7 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(response, expected_response)
assert response == expected_response
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
@ -2250,10 +2239,10 @@ class TestDatabaseApi(SupersetTestCase):
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["charts"]["count"], 33)
self.assertEqual(response["dashboards"]["count"], 3)
assert response["charts"]["count"] == 33
assert response["dashboards"]["count"] == 3
def test_get_database_related_objects_not_found(self):
"""
@ -2265,13 +2254,13 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{invalid_id}/related_objects/"
self.login(ADMIN_USERNAME)
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
self.logout()
self.login(GAMMA_USERNAME)
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_export_database(self):
"""
@ -2679,7 +2668,7 @@ class TestDatabaseApi(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.password, "TEST")
assert model_ssh_tunnel.password == "TEST"
db.session.delete(database)
db.session.commit()
@ -2797,8 +2786,8 @@ class TestDatabaseApi(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
assert model_ssh_tunnel.private_key == "TestPrivateKey"
assert model_ssh_tunnel.private_key_password == "TEST"
db.session.delete(database)
db.session.commit()
@ -3852,8 +3841,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["result"], [])
assert rv.status_code == 200
assert response["result"] == []
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@ -3878,18 +3867,15 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(
response["result"],
[
{
"end_column": None,
"line_number": 1,
"message": 'ERROR: syntax error at or near "table1"',
"start_column": None,
}
],
)
assert rv.status_code == 200
assert response["result"] == [
{
"end_column": None,
"line_number": 1,
"message": 'ERROR: syntax error at or near "table1"',
"start_column": None,
}
]
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@ -3910,7 +3896,7 @@ class TestDatabaseApi(SupersetTestCase):
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/"
)
rv = self.client.post(uri, json=request_payload)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@ -3932,8 +3918,8 @@ class TestDatabaseApi(SupersetTestCase):
)
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}})
assert rv.status_code == 400
assert response == {"message": {"sql": ["Field may not be null."]}}
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@ -3956,29 +3942,26 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(
response,
{
"errors": [
{
"message": f"no SQL validator is configured for "
f"{example_db.backend}",
"error_type": "GENERIC_DB_ENGINE_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1002,
"message": "Issue 1002 - The database returned an "
"unexpected error.",
}
]
},
}
]
},
)
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": f"no SQL validator is configured for "
f"{example_db.backend}",
"error_type": "GENERIC_DB_ENGINE_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1002,
"message": "Issue 1002 - The database returned an "
"unexpected error.",
}
]
},
}
]
}
@mock.patch("superset.commands.database.validate_sql.get_validator_by_name")
@mock.patch.dict(
@ -4013,8 +3996,8 @@ class TestDatabaseApi(SupersetTestCase):
# TODO(bkyryliuk): properly handle hive error
if get_example_database().backend == "hive":
return
self.assertEqual(rv.status_code, 422)
self.assertIn("Kaboom!", response["errors"][0]["message"])
assert rv.status_code == 422
assert "Kaboom!" in response["errors"][0]["message"]
def test_get_databases_with_extra_filters(self):
"""
@ -4048,14 +4031,14 @@ class TestDatabaseApi(SupersetTestCase):
uri, json={**database_data, "database_name": "dyntest-create-database-1"}
)
first_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "create-database-2"}
)
second_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
# The filter function
def _base_filter(query):
@ -4074,11 +4057,11 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
# All databases must be returned if no filter is present
self.assertEqual(data["count"], len(dbs))
assert data["count"] == len(dbs)
database_names = [item["database_name"] for item in data["result"]]
database_names.sort()
# All Databases because we are an admin
self.assertEqual(database_names, expected_names)
assert database_names == expected_names
assert rv.status_code == 200
# Our filter function wasn't get called
base_filter_mock.assert_not_called()
@ -4092,10 +4075,10 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
# Only one database start with dyntest
self.assertEqual(data["count"], 1)
assert data["count"] == 1
database_names = [item["database_name"] for item in data["result"]]
# Only the database that starts with tests, even if we are an admin
self.assertEqual(database_names, ["dyntest-create-database-1"])
assert database_names == ["dyntest-create-database-1"]
assert rv.status_code == 200
# The filter function is called now that it's defined in our config
base_filter_mock.assert_called()

View File

@ -740,7 +740,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.password, "TEST")
assert model_ssh_tunnel.password == "TEST"
db.session.delete(database)
db.session.commit()
@ -787,8 +787,8 @@ class TestImportDatabasesCommand(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
assert model_ssh_tunnel.private_key == "TestPrivateKey"
assert model_ssh_tunnel.private_key_password == "TEST"
db.session.delete(database)
db.session.commit()

View File

@ -192,7 +192,7 @@ class TestDatasetApi(SupersetTestCase):
def count_datasets():
uri = "api/v1/chart/"
rv = self.client.get(uri, "get_list")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = rv.get_json()
return data["count"]
@ -1342,14 +1342,14 @@ class TestDatasetApi(SupersetTestCase):
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = self.get_assert_metric("api/v1/dataset/", "get_list")
res = json.loads(response.data.decode("utf-8"))["result"]
current_dataset = [d for d in res if d["id"] == dataset.id][0]
self.assertEqual(current_dataset["description"], "changed_description")
self.assertNotIn("username", current_dataset["changed_by"].keys())
assert current_dataset["description"] == "changed_description"
assert "username" not in current_dataset["changed_by"].keys()
db.session.delete(dataset)
db.session.commit()
@ -1364,13 +1364,13 @@ class TestDatasetApi(SupersetTestCase):
table_data = {"description": "changed_description"}
uri = f"api/v1/dataset/{dataset.id}"
rv = self.client.put(uri, json=table_data)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = self.get_assert_metric(uri, "get")
res = json.loads(response.data.decode("utf-8"))["result"]
self.assertEqual(res["description"], "changed_description")
self.assertNotIn("username", res["changed_by"].keys())
assert res["description"] == "changed_description"
assert "username" not in res["changed_by"].keys()
db.session.delete(dataset)
db.session.commit()
@ -2311,14 +2311,14 @@ class TestDatasetApi(SupersetTestCase):
"database_id": get_example_database().id,
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
dataset = (
db.session.query(SqlaTable)
.filter(SqlaTable.table_name == "virtual_dataset")
.one()
)
self.assertEqual(response["result"], {"table_id": dataset.id})
assert response["result"] == {"table_id": dataset.id}
def test_get_or_create_dataset_database_not_found(self):
"""
@ -2329,9 +2329,9 @@ class TestDatasetApi(SupersetTestCase):
"api/v1/dataset/get_or_create/",
json={"table_name": "virtual_dataset", "database_id": 999},
)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["message"], {"database": ["Database does not exist"]})
assert response["message"] == {"database": ["Database does not exist"]}
@patch("superset.commands.dataset.create.CreateDatasetCommand.run")
def test_get_or_create_dataset_create_fails(self, command_run_mock):
@ -2347,9 +2347,9 @@ class TestDatasetApi(SupersetTestCase):
"database_id": get_example_database().id,
},
)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["message"], "Dataset could not be created.")
assert response["message"] == "Dataset could not be created."
def test_get_or_create_dataset_creates_table(self):
"""
@ -2370,7 +2370,7 @@ class TestDatasetApi(SupersetTestCase):
"template_params": '{"param": 1}',
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
table = (
db.session.query(SqlaTable)
@ -2410,12 +2410,9 @@ class TestDatasetApi(SupersetTestCase):
"db_name": get_example_database().database_name,
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(energy_charts),
)
assert len(data["result"]) == len(energy_charts)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
@ -2439,12 +2436,9 @@ class TestDatasetApi(SupersetTestCase):
"dashboard_id": dashboard.id,
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(birth_charts),
)
assert len(data["result"]) == len(birth_charts)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
@ -2462,12 +2456,9 @@ class TestDatasetApi(SupersetTestCase):
),
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(birth_charts),
)
assert len(data["result"]) == len(birth_charts)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
@ -2476,17 +2467,14 @@ class TestDatasetApi(SupersetTestCase):
def test_warm_up_cache_db_and_table_name_required(self):
self.login(ADMIN_USERNAME)
rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1})
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{
"message": {
"db_name": ["Missing data for required field."],
"table_name": ["Missing data for required field."],
}
},
)
assert data == {
"message": {
"db_name": ["Missing data for required field."],
"table_name": ["Missing data for required field."],
}
}
def test_warm_up_cache_table_not_found(self):
self.login(ADMIN_USERNAME)
@ -2494,9 +2482,8 @@ class TestDatasetApi(SupersetTestCase):
"/api/v1/dataset/warm_up_cache",
json={"table_name": "not_here", "db_name": "abc"},
)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{"message": "The provided table was not found in the provided database"},
)
assert data == {
"message": "The provided table was not found in the provided database"
}

View File

@ -587,8 +587,8 @@ class TestCreateDatasetCommand(SupersetTestCase):
.filter_by(table_name="test_create_dataset_command")
.one()
)
self.assertEqual(table, fetched_table)
self.assertEqual([owner.username for owner in table.owners], ["admin"])
assert table == fetched_table
assert [owner.username for owner in table.owners] == ["admin"]
db.session.delete(table)
with examples_db.get_sqla_engine() as engine:
@ -626,7 +626,7 @@ class TestDatasetWarmUpCacheCommand(SupersetTestCase):
results = DatasetWarmUpCacheCommand(
get_example_database().database_name, "birth_names", None, None
).run()
self.assertEqual(len(results), len(birth_charts))
assert len(results) == len(birth_charts)
for chart_result in results:
assert "chart_id" in chart_result
assert "viz_error" in chart_result

View File

@ -41,7 +41,7 @@ class TestDatasourceApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
for val in range(10):
assert val in response["result"]
@ -51,7 +51,7 @@ class TestDatasourceApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
for val in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]:
assert val in response["result"]
@ -61,7 +61,7 @@ class TestDatasourceApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col3/values/")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
for val in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]:
assert val in response["result"]
@ -71,16 +71,16 @@ class TestDatasourceApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col4/values/")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["result"], [None])
assert response["result"] == [None]
@pytest.mark.usefixtures("app_context", "virtual_dataset")
def test_get_column_values_integers_with_nulls(self):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col6/values/")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
for val in [1, None, 3, 4, 5, 6, 7, 8, 9, 10]:
assert val in response["result"]
@ -92,27 +92,27 @@ class TestDatasourceApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/datasource/not_table/{table.id}/column/col1/values/"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["message"], "Invalid datasource type: not_table")
assert response["message"] == "Invalid datasource type: not_table"
@patch("superset.datasource.api.DatasourceDAO.get_datasource")
def test_get_column_values_datasource_type_not_supported(self, get_datasource_mock):
get_datasource_mock.side_effect = DatasourceTypeNotSupportedError
self.login(ADMIN_USERNAME)
rv = self.client.get("api/v1/datasource/table/1/column/col1/values/")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response["message"], "DAO datasource query source type is not supported"
assert (
response["message"] == "DAO datasource query source type is not supported"
)
def test_get_column_values_datasource_not_found(self):
self.login(ADMIN_USERNAME)
rv = self.client.get("api/v1/datasource/table/999/column/col1/values/")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["message"], "Datasource does not exist")
assert response["message"] == "Datasource does not exist"
@pytest.mark.usefixtures("app_context", "virtual_dataset")
def test_get_column_values_no_datasource_access(self):
@ -126,12 +126,11 @@ class TestDatasourceApi(SupersetTestCase):
self.login(GAMMA_USERNAME)
table = self.get_virtual_dataset()
rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response["message"],
f"This endpoint requires the datasource {table.id}, "
"database or `all_datasource_access` permission",
assert (
response["message"] == f"This endpoint requires the datasource {table.id}, "
"database or `all_datasource_access` permission"
)
@pytest.mark.usefixtures("app_context", "virtual_dataset")
@ -188,9 +187,9 @@ class TestDatasourceApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/datasource/table/{table.id}/column/col2/values/"
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["result"], ["b"])
assert response["result"] == ["b"]
@pytest.mark.usefixtures("app_context", "virtual_dataset")
def test_get_column_values_with_rls_no_values(self):
@ -202,6 +201,6 @@ class TestDatasourceApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/datasource/table/{table.id}/column/col2/values/"
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["result"], [])
assert response["result"] == []

View File

@ -101,9 +101,15 @@ class TestDatasource(SupersetTestCase):
url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url)
col_names = {o.get("column_name") for o in resp}
self.assertEqual(
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
)
assert col_names == {
"num_boys",
"num",
"gender",
"name",
"ds",
"state",
"num_girls",
}
def test_always_filter_main_dttm(self):
database = get_example_database()
@ -175,9 +181,15 @@ class TestDatasource(SupersetTestCase):
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
col_names = {o.get("column_name") for o in resp}
self.assertEqual(
col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
)
assert col_names == {
"num_boys",
"num",
"gender",
"name",
"ds",
"state",
"num_girls",
}
def test_external_metadata_by_name_for_virtual_table(self):
self.login(ADMIN_USERNAME)
@ -235,7 +247,7 @@ class TestDatasource(SupersetTestCase):
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
col_names = {o.get("column_name") for o in resp}
self.assertEqual(col_names, {"first", "second"})
assert col_names == {"first", "second"}
# No databases found
params = prison.dumps(
@ -249,10 +261,10 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.client.get(url)
self.assertEqual(resp.status_code, DatasetNotFoundError.status)
self.assertEqual(
json.loads(resp.data.decode("utf-8")).get("error"),
DatasetNotFoundError.message,
assert resp.status_code == DatasetNotFoundError.status
assert (
json.loads(resp.data.decode("utf-8")).get("error")
== DatasetNotFoundError.message
)
# No table found
@ -267,10 +279,10 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.client.get(url)
self.assertEqual(resp.status_code, DatasetNotFoundError.status)
self.assertEqual(
json.loads(resp.data.decode("utf-8")).get("error"),
DatasetNotFoundError.message,
assert resp.status_code == DatasetNotFoundError.status
assert (
json.loads(resp.data.decode("utf-8")).get("error")
== DatasetNotFoundError.message
)
# invalid query params
@ -281,7 +293,7 @@ class TestDatasource(SupersetTestCase):
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
self.assertIn("error", resp)
assert "error" in resp
def test_external_metadata_for_virtual_table_template_params(self):
self.login(ADMIN_USERNAME)
@ -308,7 +320,7 @@ class TestDatasource(SupersetTestCase):
with db_insert_temp_object(table):
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
self.assertEqual(resp["error"], "Only `SELECT` statements are allowed")
assert resp["error"] == "Only `SELECT` statements are allowed"
def test_external_metadata_for_multistatement_virtual_table(self):
self.login(ADMIN_USERNAME)
@ -322,7 +334,7 @@ class TestDatasource(SupersetTestCase):
with db_insert_temp_object(table):
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
self.assertEqual(resp["error"], "Only single queries supported")
assert resp["error"] == "Only single queries supported"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata")
@ -350,7 +362,7 @@ class TestDatasource(SupersetTestCase):
obj2 = l2_lookup.get(obj1.get(key))
for k in obj1:
if k not in "id" and obj1.get(k):
self.assertEqual(obj1.get(k), obj2.get(k))
assert obj1.get(k) == obj2.get(k)
def test_save(self):
self.login(ADMIN_USERNAME)
@ -367,11 +379,11 @@ class TestDatasource(SupersetTestCase):
elif k == "metrics":
self.compare_lists(datasource_post[k], resp[k], "metric_name")
elif k == "database":
self.assertEqual(resp[k]["id"], datasource_post[k]["id"])
assert resp[k]["id"] == datasource_post[k]["id"]
elif k == "owners":
self.assertEqual([o["id"] for o in resp[k]], datasource_post["owners"])
assert [o["id"] for o in resp[k]] == datasource_post["owners"]
else:
self.assertEqual(resp[k], datasource_post[k])
assert resp[k] == datasource_post[k]
def test_save_default_endpoint_validation_success(self):
self.login(ADMIN_USERNAME)
@ -404,11 +416,11 @@ class TestDatasource(SupersetTestCase):
new_db = self.create_fake_db()
datasource_post["database"]["id"] = new_db.id
resp = self.save_datasource_from_dict(datasource_post)
self.assertEqual(resp["database"]["id"], new_db.id)
assert resp["database"]["id"] == new_db.id
datasource_post["database"]["id"] = db_id
resp = self.save_datasource_from_dict(datasource_post)
self.assertEqual(resp["database"]["id"], db_id)
assert resp["database"]["id"] == db_id
self.delete_fake_db()
@ -440,7 +452,7 @@ class TestDatasource(SupersetTestCase):
)
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
self.assertIn("Duplicate column name(s): <new column>", resp["error"])
assert "Duplicate column name(s): <new column>" in resp["error"]
def test_get_datasource(self):
admin_user = self.get_user("admin")
@ -454,21 +466,18 @@ class TestDatasource(SupersetTestCase):
self.get_json_resp("/datasource/save/", data)
url = f"/datasource/get/{tbl.type}/{tbl.id}/"
resp = self.get_json_resp(url)
self.assertEqual(resp.get("type"), "table")
assert resp.get("type") == "table"
col_names = {o.get("column_name") for o in resp["columns"]}
self.assertEqual(
col_names,
{
"num_boys",
"num",
"gender",
"name",
"ds",
"state",
"num_girls",
"num_california",
},
)
assert col_names == {
"num_boys",
"num",
"gender",
"name",
"ds",
"state",
"num_girls",
"num_california",
}
def test_get_datasource_with_health_check(self):
def my_check(datasource):
@ -491,7 +500,7 @@ class TestDatasource(SupersetTestCase):
self.login(ADMIN_USERNAME)
resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "Datasource does not exist")
assert resp.get("error") == "Datasource does not exist"
def test_get_datasource_invalid_datasource_failed(self):
from superset.daos.datasource import DatasourceDAO
@ -503,7 +512,7 @@ class TestDatasource(SupersetTestCase):
self.login(ADMIN_USERNAME)
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
assert resp.get("error") == "'druid' is not a valid DatasourceType"
def test_get_samples(test_client, login_as_admin, virtual_dataset):

View File

@ -22,11 +22,11 @@ class TestAscendDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
AscendEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
assert (
AscendEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
assert (
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"
)

View File

@ -61,18 +61,18 @@ class TestDbEngineSpecs(TestDbEngineSpec):
q10 = "select * from mytable limit 20, x"
q11 = "select * from mytable limit x offset 20"
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
assert engine_spec_class.get_limit_from_sql(q0) is None
assert engine_spec_class.get_limit_from_sql(q1) == 10
assert engine_spec_class.get_limit_from_sql(q2) == 20
assert engine_spec_class.get_limit_from_sql(q3) is None
assert engine_spec_class.get_limit_from_sql(q4) == 20
assert engine_spec_class.get_limit_from_sql(q5) == 10
assert engine_spec_class.get_limit_from_sql(q6) == 10
assert engine_spec_class.get_limit_from_sql(q7) is None
assert engine_spec_class.get_limit_from_sql(q8) is None
assert engine_spec_class.get_limit_from_sql(q9) is None
assert engine_spec_class.get_limit_from_sql(q10) is None
assert engine_spec_class.get_limit_from_sql(q11) is None
def test_wrapped_semi_tabs(self):
self.sql_limit_regex(
@ -141,7 +141,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
)
def test_get_datatype(self):
self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR")
def test_limit_with_implicit_offset(self):
self.sql_limit_regex(
@ -198,29 +198,26 @@ class TestDbEngineSpecs(TestDbEngineSpec):
for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
assert len(engine.get_time_grain_expressions()) > 0
# make sure all defined time grains are supported
defined_grains = {grain.duration for grain in engine.get_time_grains()}
intersection = time_grains.intersection(defined_grains)
self.assertSetEqual(defined_grains, intersection, engine)
self.assertSetEqual(defined_grains, intersection, engine) # noqa: PT009
def test_get_time_grain_expressions(self):
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
list(time_grains.keys()),
[
None,
"PT1S",
"PT1M",
"PT1H",
"P1D",
"P1W",
"P1M",
"P3M",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
],
)
assert list(time_grains.keys()) == [
None,
"PT1S",
"PT1M",
"PT1H",
"P1D",
"P1W",
"P1M",
"P3M",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
]
def test_get_table_names(self):
inspector = mock.Mock()
@ -255,11 +252,11 @@ class TestDbEngineSpecs(TestDbEngineSpec):
expected = ["STRING", "STRING", "FLOAT"]
else:
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
self.assertEqual(col_names, expected)
assert col_names == expected
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm, db_extra=None))
assert BaseEngineSpec.convert_dttm("", dttm, db_extra=None) is None
def test_pyodbc_rows_to_tuples(self):
# Test for case when pyodbc.Row is returned (odbc driver)
@ -272,7 +269,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, expected)
self.assertListEqual(result, expected) # noqa: PT009
def test_pyodbc_rows_to_tuples_passthrough(self):
# Test for case when tuples are returned
@ -281,7 +278,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, data)
self.assertListEqual(result, data) # noqa: PT009
@mock.patch("superset.models.core.Database.db_engine_spec", BaseEngineSpec)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")

View File

@ -33,4 +33,4 @@ class TestDbEngineSpec(SupersetTestCase):
):
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
self.assertEqual(expected_sql, limited)
assert expected_sql == limited

View File

@ -45,7 +45,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
}
for original, expected in test_cases.items():
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
assert actual == expected
def test_timegrain_expressions(self):
"""
@ -63,7 +63,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
actual = BigQueryEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain="PT1H"
)
self.assertEqual(str(actual), expected)
assert str(actual) == expected
def test_custom_minute_timegrain_expressions(self):
"""
@ -104,12 +104,12 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
data1 = [(1, "foo")]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, data1)
assert result == data1
data2 = [Row(1), Row(2)]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, [1, 2])
assert result == [1, 2]
def test_get_extra_table_metadata(self):
"""
@ -122,7 +122,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
database,
Table("some_table", "some_schema"),
)
self.assertEqual(result, {})
assert result == {}
index_metadata = [
{
@ -143,7 +143,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
database,
Table("some_table", "some_schema"),
)
self.assertEqual(result, expected_result)
assert result == expected_result
def test_get_indexes(self):
database = mock.Mock()

View File

@ -40,4 +40,4 @@ class TestElasticsearchDbEngineSpec(TestDbEngineSpec):
actual = ElasticSearchEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain=time_grain
)
self.assertEqual(str(actual), expected_time_grain_expression)
assert str(actual) == expected_time_grain_expression

View File

@ -30,8 +30,8 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
)
def test_get_datatype_mysql(self):
"""Tests related to datatype mapping for MySQL"""
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
assert "TINY" == MySQLEngineSpec.get_datatype(1)
assert "VARCHAR" == MySQLEngineSpec.get_datatype(15)
def test_column_datatype_to_string(self):
test_cases = (
@ -49,7 +49,7 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
actual = MySQLEngineSpec.column_datatype_to_string(
original, mysql.dialect()
)
self.assertEqual(actual, expected)
assert actual == expected
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError

View File

@ -32,20 +32,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_1d_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1D")
result = str(expr.compile())
expected = "CAST(DATE_TRUNC('day', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_10m_grain(self):
col = column("tstamp")
@ -55,20 +49,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
"CAST(ROUND(DATE_TRUNC('minute', CAST(tstamp AS "
+ "TIMESTAMP)), 600000) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_1w_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1W")
result = str(expr.compile())
expected = "CAST(DATE_TRUNC('week', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_sec_one_1m_grain(self):
col = column("tstamp")
@ -79,10 +67,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_millisec_one_1m_grain(self):
col = column("tstamp")
@ -93,10 +78,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:MILLISECONDS:EPOCH', "
+ "'1:MILLISECONDS:EPOCH', '1:MILLISECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_invalid_get_time_expression_arguments(self):
with self.assertRaises(NotImplementedError):

View File

@ -57,7 +57,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "COALESCE(a, b)")
assert result == "COALESCE(a, b)"
def test_time_exp_literal_1y_grain(self):
"""
@ -66,7 +66,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
assert result == "DATE_TRUNC('year', COALESCE(a, b))"
def test_time_ex_lowr_col_no_grain(self):
"""
@ -75,7 +75,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "lower_case")
assert result == "lower_case"
def test_time_exp_lowr_col_sec_1y(self):
"""
@ -84,10 +84,9 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(
result,
"DATE_TRUNC('year', "
"(timestamp 'epoch' + lower_case * interval '1 second'))",
assert (
result == "DATE_TRUNC('year', "
"(timestamp 'epoch' + lower_case * interval '1 second'))"
)
def test_time_exp_mixed_case_col_1y(self):
@ -97,7 +96,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("MixedCase")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
assert result == "DATE_TRUNC('year', \"MixedCase\")"
def test_empty_dbapi_cursor_description(self):
"""
@ -107,7 +106,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
# empty description mean no columns, this mocks the following SQL: "SELECT"
cursor.description = []
results = PostgresEngineSpec.fetch_data(cursor, 1000)
self.assertEqual(results, [])
assert results == []
def test_engine_alias_name(self):
"""
@ -158,13 +157,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
self.assertEqual(
results,
{
"Start-up cost": 0.00,
"Total cost": 1537.91,
},
)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
def test_estimate_statement_invalid_syntax(self):
"""
@ -199,19 +192,10 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
},
]
result = PostgresEngineSpec.query_cost_formatter(raw_cost)
self.assertEqual(
result,
[
{
"Start-up cost": "0.0",
"Total cost": "1537.91",
},
{
"Start-up cost": "10.0",
"Total cost": "1537.0",
},
],
)
assert result == [
{"Start-up cost": "0.0", "Total cost": "1537.91"},
{"Start-up cost": "10.0", "Total cost": "1537.0"},
]
def test_extract_errors(self):
"""

View File

@ -33,7 +33,7 @@ from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestPrestoDbEngineSpec(TestDbEngineSpec):
@skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed")
def test_get_datatype_presto(self):
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
assert "STRING" == PrestoEngineSpec.get_datatype("string")
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
@ -86,10 +86,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
row.Column, row.Type, row.Null = column
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
self.assertEqual(len(expected_results), len(results))
assert len(expected_results) == len(results)
for expected_result, result in zip(expected_results, results):
self.assertEqual(expected_result[0], result["column_name"])
self.assertEqual(expected_result[1], str(result["type"]))
assert expected_result[0] == result["column_name"]
assert expected_result[1] == str(result["type"])
def test_presto_get_column(self):
presto_column = ("column_name", "boolean", "")
@ -192,8 +192,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
},
]
for actual_result, expected_result in zip(actual_results, expected_results):
self.assertEqual(actual_result.element.name, expected_result["column_name"])
self.assertEqual(actual_result.name, expected_result["label"])
assert actual_result.element.name == expected_result["column_name"]
assert actual_result.name == expected_result["label"]
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -260,9 +260,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
}
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -343,9 +343,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -427,9 +427,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -548,9 +548,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
def test_presto_get_extra_table_metadata(self):
database = mock.Mock()
@ -582,7 +582,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
def test_query_cost_formatter(self):
raw_cost = [
@ -645,7 +645,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"Network cost": "354 G",
}
]
self.assertEqual(formatted_cost, expected)
assert formatted_cost == expected
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -752,9 +752,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")

View File

@ -91,36 +91,32 @@ class TestDictImportExport(SupersetTestCase):
def yaml_compare(self, obj_1, obj_2):
obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)
self.assertEqual(obj_1_str, obj_2_str)
assert obj_1_str == obj_2_str
def assert_table_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(expected_ds.schema, actual_ds.schema)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
assert expected_ds.table_name == actual_ds.table_name
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
assert expected_ds.schema == actual_ds.schema
assert len(expected_ds.metrics) == len(actual_ds.metrics)
assert len(expected_ds.columns) == len(actual_ds.columns)
assert {c.column_name for c in expected_ds.columns} == {
c.column_name for c in actual_ds.columns
}
assert {m.metric_name for m in expected_ds.metrics} == {
m.metric_name for m in actual_ds.metrics
}
def assert_datasource_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
assert expected_ds.datasource_name == actual_ds.datasource_name
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
assert len(expected_ds.metrics) == len(actual_ds.metrics)
assert len(expected_ds.columns) == len(actual_ds.columns)
assert {c.column_name for c in expected_ds.columns} == {
c.column_name for c in actual_ds.columns
}
assert {m.metric_name for m in expected_ds.metrics} == {
m.metric_name for m in actual_ds.metrics
}
def test_import_table_no_metadata(self):
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
@ -143,8 +139,8 @@ class TestDictImportExport(SupersetTestCase):
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
self.assertEqual(
{DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
assert {DBREF: ID_PREFIX + 2, "database_name": "main"} == json.loads(
imported.params
)
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
@ -178,7 +174,7 @@ class TestDictImportExport(SupersetTestCase):
db.session.commit()
imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
assert imported_table.id == imported_over.id
expected_table, _ = self.create_table(
"table_override",
id=ID_PREFIX + 3,
@ -209,7 +205,7 @@ class TestDictImportExport(SupersetTestCase):
db.session.commit()
imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
assert imported_table.id == imported_over.id
expected_table, _ = self.create_table(
"table_override",
id=ID_PREFIX + 3,
@ -239,7 +235,7 @@ class TestDictImportExport(SupersetTestCase):
)
imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
assert imported_table.id == imported_copy_table.id
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
self.yaml_compare(
imported_copy_table.export_to_dict(), imported_table.export_to_dict()
@ -259,12 +255,12 @@ class TestDictImportExport(SupersetTestCase):
"/databaseview/action_post", {"action": "yaml_export", "rowid": 1}
)
ui_export = yaml.safe_load(resp)
self.assertEqual(
ui_export["databases"][0]["database_name"],
cli_export["databases"][0]["database_name"],
assert (
ui_export["databases"][0]["database_name"]
== cli_export["databases"][0]["database_name"]
)
self.assertEqual(
ui_export["databases"][0]["tables"], cli_export["databases"][0]["tables"]
assert (
ui_export["databases"][0]["tables"] == cli_export["databases"][0]["tables"]
)

View File

@ -28,7 +28,7 @@ class TestDynamicPlugins(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "/dynamic-plugins/api"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@with_feature_flags(DYNAMIC_PLUGINS=True)
def test_dynamic_plugins_enabled(self):
@ -38,4 +38,4 @@ class TestDynamicPlugins(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "/dynamic-plugins/api"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200

View File

@ -222,7 +222,7 @@ class TestEmailSmtp(SupersetTestCase):
app.config["SMTP_HOST"], app.config["SMTP_PORT"], context=mock.ANY
)
called_context = mock_smtp_ssl.call_args.kwargs["context"]
self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED)
assert called_context.verify_mode == ssl.CERT_REQUIRED
@mock.patch("smtplib.SMTP")
def test_send_mime_tls_server_auth(self, mock_smtp):
@ -233,7 +233,7 @@ class TestEmailSmtp(SupersetTestCase):
utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
mock_smtp.return_value.starttls.assert_called_with(context=mock.ANY)
called_context = mock_smtp.return_value.starttls.call_args.kwargs["context"]
self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED)
assert called_context.verify_mode == ssl.CERT_REQUIRED
@mock.patch("smtplib.SMTP_SSL")
@mock.patch("smtplib.SMTP")

View File

@ -36,13 +36,13 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
assert dash.embedded
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
assert dash.embedded[0].allowed_domains == ["test.example.com"]
original_uuid = dash.embedded[0].uuid
self.assertIsNotNone(original_uuid)
assert original_uuid is not None
EmbeddedDashboardDAO.upsert(dash, [])
db.session.flush()
self.assertEqual(dash.embedded[0].allowed_domains, [])
self.assertEqual(dash.embedded[0].uuid, original_uuid)
assert dash.embedded[0].allowed_domains == []
assert dash.embedded[0].uuid == original_uuid
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_get_by_uuid(self):
@ -51,4 +51,4 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
db.session.flush()
uuid = str(dash.embedded[0].uuid)
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
self.assertIsNotNone(embedded)
assert embedded is not None

View File

@ -39,7 +39,7 @@ class TestEventLogger(unittest.TestCase):
# unmodified object
obj = DBEventLogger()
res = get_event_logger_from_cfg_value(obj)
self.assertIs(obj, res)
assert obj is res
def test_config_class_deprecation(self):
# test that assignment of a class object to EVENT_LOGGER is correctly
@ -51,7 +51,7 @@ class TestEventLogger(unittest.TestCase):
res = get_event_logger_from_cfg_value(DBEventLogger)
# class is instantiated and returned
self.assertIsInstance(res, DBEventLogger)
assert isinstance(res, DBEventLogger)
def test_raises_typeerror_if_not_abc(self):
# test that assignment of non AbstractEventLogger derived type raises
@ -71,19 +71,16 @@ class TestEventLogger(unittest.TestCase):
with app.test_request_context("/superset/dashboard/1/?myparam=foo"):
result = test_func()
payload = mock_log.call_args[1]
self.assertEqual(result, 1)
self.assertEqual(
payload["records"],
[
{
"myparam": "foo",
"path": "/superset/dashboard/1/",
"url_rule": "/superset/dashboard/<dashboard_id_or_slug>/",
"object_ref": test_func.__qualname__,
}
],
)
self.assertGreaterEqual(payload["duration_ms"], 50)
assert result == 1
assert payload["records"] == [
{
"myparam": "foo",
"path": "/superset/dashboard/1/",
"url_rule": "/superset/dashboard/<dashboard_id_or_slug>/",
"object_ref": test_func.__qualname__,
}
]
assert payload["duration_ms"] >= 50
@patch.object(DBEventLogger, "log")
def test_log_this_with_extra_payload(self, mock_log):
@ -98,19 +95,16 @@ class TestEventLogger(unittest.TestCase):
with app.test_request_context():
result = test_func(1, karg1=2) # pylint: disable=no-value-for-parameter
payload = mock_log.call_args[1]
self.assertEqual(result, 2)
self.assertEqual(
payload["records"],
[
{
"foo": "bar",
"path": "/",
"karg1": 2,
"object_ref": test_func.__qualname__,
}
],
)
self.assertGreaterEqual(payload["duration_ms"], 100)
assert result == 2
assert payload["records"] == [
{
"foo": "bar",
"path": "/",
"karg1": 2,
"object_ref": test_func.__qualname__,
}
]
assert payload["duration_ms"] >= 100
@patch("superset.utils.core.g", spec={})
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
@ -141,19 +135,16 @@ class TestEventLogger(unittest.TestCase):
with logger(action="foo", engine="bar"):
pass
self.assertEquals(
logger.records,
[
{
"records": [{"path": "/", "engine": "bar"}],
"database_id": None,
"user_id": 2,
"duration": 15000,
"curated_payload": {},
"curated_form_data": {},
}
],
)
assert logger.records == [
{
"records": [{"path": "/", "engine": "bar"}],
"database_id": None,
"user_id": 2,
"duration": 15000,
"curated_payload": {},
"curated_form_data": {},
}
]
@patch("superset.utils.core.g", spec={})
def test_context_manager_log_with_context(self, mock_g):
@ -188,25 +179,22 @@ class TestEventLogger(unittest.TestCase):
payload_override={"engine": "sqlite"},
)
self.assertEquals(
logger.records,
[
{
"records": [
{
"path": "/",
"object_ref": {"baz": "food"},
"payload_override": {"engine": "sqlite"},
}
],
"database_id": None,
"user_id": 2,
"duration": 5558756000,
"curated_payload": {},
"curated_form_data": {},
}
],
)
assert logger.records == [
{
"records": [
{
"path": "/",
"object_ref": {"baz": "food"},
"payload_override": {"engine": "sqlite"},
}
],
"database_id": None,
"user_id": 2,
"duration": 5558756000,
"curated_payload": {},
"curated_form_data": {},
}
]
@patch("superset.utils.core.g", spec={})
def test_log_with_context_user_null(self, mock_g):

View File

@ -24,13 +24,13 @@ class TestForm(SupersetTestCase):
def test_comma_separated_list_field(self):
field = CommaSeparatedListField().bind(Form(), "foo")
field.process_formdata([""])
self.assertEqual(field.data, [""])
assert field.data == [""]
field.process_formdata(["a,comma,separated,list"])
self.assertEqual(field.data, ["a", "comma", "separated", "list"])
assert field.data == ["a", "comma", "separated", "list"]
def test_filter_not_empty_values(self):
self.assertEqual(filter_not_empty_values(None), None)
self.assertEqual(filter_not_empty_values([]), None)
self.assertEqual(filter_not_empty_values([""]), None)
self.assertEqual(filter_not_empty_values(["hi"]), ["hi"])
assert filter_not_empty_values(None) is None
assert filter_not_empty_values([]) is None
assert filter_not_empty_values([""]) is None
assert filter_not_empty_values(["hi"]) == ["hi"]

View File

@ -148,52 +148,48 @@ class TestImportExport(SupersetTestCase):
self, expected_dash, actual_dash, check_position=True, check_slugs=True
):
if check_slugs:
self.assertEqual(expected_dash.slug, actual_dash.slug)
self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title)
self.assertEqual(len(expected_dash.slices), len(actual_dash.slices))
assert expected_dash.slug == actual_dash.slug
assert expected_dash.dashboard_title == actual_dash.dashboard_title
assert len(expected_dash.slices) == len(actual_dash.slices)
expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
for e_slc, a_slc in zip(expected_slices, actual_slices):
self.assert_slice_equals(e_slc, a_slc)
if check_position:
self.assertEqual(expected_dash.position_json, actual_dash.position_json)
assert expected_dash.position_json == actual_dash.position_json
def assert_table_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.table_name, actual_ds.table_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(expected_ds.schema, actual_ds.schema)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
assert expected_ds.table_name == actual_ds.table_name
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
assert expected_ds.schema == actual_ds.schema
assert len(expected_ds.metrics) == len(actual_ds.metrics)
assert len(expected_ds.columns) == len(actual_ds.columns)
assert {c.column_name for c in expected_ds.columns} == {
c.column_name for c in actual_ds.columns
}
assert {m.metric_name for m in expected_ds.metrics} == {
m.metric_name for m in actual_ds.metrics
}
def assert_datasource_equals(self, expected_ds, actual_ds):
self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
{c.column_name for c in expected_ds.columns},
{c.column_name for c in actual_ds.columns},
)
self.assertEqual(
{m.metric_name for m in expected_ds.metrics},
{m.metric_name for m in actual_ds.metrics},
)
assert expected_ds.datasource_name == actual_ds.datasource_name
assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
assert len(expected_ds.metrics) == len(actual_ds.metrics)
assert len(expected_ds.columns) == len(actual_ds.columns)
assert {c.column_name for c in expected_ds.columns} == {
c.column_name for c in actual_ds.columns
}
assert {m.metric_name for m in expected_ds.metrics} == {
m.metric_name for m in actual_ds.metrics
}
def assert_slice_equals(self, expected_slc, actual_slc):
# to avoid bad slice data (no slice_name)
expected_slc_name = expected_slc.slice_name or ""
actual_slc_name = actual_slc.slice_name or ""
self.assertEqual(expected_slc_name, actual_slc_name)
self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type)
self.assertEqual(expected_slc.viz_type, actual_slc.viz_type)
assert expected_slc_name == actual_slc_name
assert expected_slc.datasource_type == actual_slc.datasource_type
assert expected_slc.viz_type == actual_slc.viz_type
exp_params = json.loads(expected_slc.params)
actual_params = json.loads(actual_slc.params)
diff_params_keys = (
@ -208,7 +204,7 @@ class TestImportExport(SupersetTestCase):
actual_params.pop(k)
if k in exp_params:
exp_params.pop(k)
self.assertEqual(exp_params, actual_params)
assert exp_params == actual_params
def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
"""only exported json has this params
@ -218,9 +214,9 @@ class TestImportExport(SupersetTestCase):
actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
for e_slc, a_slc in zip(expected_slices, actual_slices):
params = a_slc.params_dict
self.assertEqual(e_slc.datasource.name, params["datasource_name"])
self.assertEqual(e_slc.datasource.schema, params["schema"])
self.assertEqual(e_slc.datasource.database.name, params["database_name"])
assert e_slc.datasource.name == params["datasource_name"]
assert e_slc.datasource.schema == params["schema"]
assert e_slc.datasource.database.name == params["database_name"]
@unittest.skip("Schema needs to be updated")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -237,17 +233,17 @@ class TestImportExport(SupersetTestCase):
birth_dash = self.get_dash_by_slug("births")
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEqual(
id_,
json.loads(
assert (
id_
== json.loads(
exported_dashboards[0].json_metadata, object_hook=decode_dashboards
)["remote_id"],
)["remote_id"]
)
exported_tables = json.loads(
resp.data.decode("utf-8"), object_hook=decode_dashboards
)["datasources"]
self.assertEqual(1, len(exported_tables))
assert 1 == len(exported_tables)
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
@unittest.skip("Schema needs to be updated")
@ -269,27 +265,28 @@ class TestImportExport(SupersetTestCase):
exported_dashboards = sorted(
resp_data.get("dashboards"), key=lambda d: d.dashboard_title
)
self.assertEqual(2, len(exported_dashboards))
assert 2 == len(exported_dashboards)
birth_dash = self.get_dash_by_slug("births")
self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
self.assert_dash_equals(birth_dash, exported_dashboards[0])
self.assertEqual(
birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"]
assert (
birth_dash.id
== json.loads(exported_dashboards[0].json_metadata)["remote_id"]
)
world_health_dash = self.get_dash_by_slug("world_health")
self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1])
self.assert_dash_equals(world_health_dash, exported_dashboards[1])
self.assertEqual(
world_health_dash.id,
json.loads(exported_dashboards[1].json_metadata)["remote_id"],
assert (
world_health_dash.id
== json.loads(exported_dashboards[1].json_metadata)["remote_id"]
)
exported_tables = sorted(
resp_data.get("datasources"), key=lambda t: t.table_name
)
self.assertEqual(2, len(exported_tables))
assert 2 == len(exported_tables)
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
self.assert_table_equals(
self.get_table(name="wb_health_population"), exported_tables[1]
@ -302,11 +299,11 @@ class TestImportExport(SupersetTestCase):
)
slc_id = import_chart(expected_slice, None, import_time=1989)
slc = self.get_slice(slc_id)
self.assertEqual(slc.datasource.perm, slc.perm)
assert slc.datasource.perm == slc.perm
self.assert_slice_equals(expected_slice, slc)
table_id = self.get_table(name="wb_health_population").id
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
assert table_id == self.get_slice(slc_id).datasource_id
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_2_slices_for_same_table(self):
@ -323,13 +320,13 @@ class TestImportExport(SupersetTestCase):
imported_slc_1 = self.get_slice(slc_id_1)
imported_slc_2 = self.get_slice(slc_id_2)
self.assertEqual(table_id, imported_slc_1.datasource_id)
assert table_id == imported_slc_1.datasource_id
self.assert_slice_equals(slc_1, imported_slc_1)
self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm)
assert imported_slc_1.datasource.perm == imported_slc_1.perm
self.assertEqual(table_id, imported_slc_2.datasource_id)
assert table_id == imported_slc_2.datasource_id
self.assert_slice_equals(slc_2, imported_slc_2)
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
assert imported_slc_2.datasource.perm == imported_slc_2.perm
def test_import_slices_override(self):
schema = get_example_default_schema()
@ -339,7 +336,7 @@ class TestImportExport(SupersetTestCase):
imported_slc_1 = self.get_slice(slc_1_id)
slc_2 = self.create_slice("Import Me New", id=10005, schema=schema)
slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990)
self.assertEqual(slc_1_id, slc_2_id)
assert slc_1_id == slc_2_id
imported_slc_2 = self.get_slice(slc_2_id)
self.assert_slice_equals(slc, imported_slc_2)
@ -379,21 +376,18 @@ class TestImportExport(SupersetTestCase):
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False, check_slugs=False
)
self.assertEqual(
{
"remote_id": 10002,
"import_time": 1990,
"native_filter_configuration": [],
},
json.loads(imported_dash.json_metadata),
)
assert {
"remote_id": 10002,
"import_time": 1990,
"native_filter_configuration": [],
} == json.loads(imported_dash.json_metadata)
expected_position = dash_with_1_slice.position
# new slice id (auto-incremental) assigned on insert
# id from json is used only for updating position with new id
meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"]
meta["chartId"] = imported_dash.slices[0].id
self.assertEqual(expected_position, imported_dash.position)
assert expected_position == imported_dash.position
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_import_dashboard_2_slices(self):
@ -444,9 +438,7 @@ class TestImportExport(SupersetTestCase):
},
"native_filter_configuration": [],
}
self.assertEqual(
expected_json_metadata, json.loads(imported_dash.json_metadata)
)
assert expected_json_metadata == json.loads(imported_dash.json_metadata)
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_import_override_dashboard_2_slices(self):
@ -478,7 +470,7 @@ class TestImportExport(SupersetTestCase):
imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992)
# override doesn't change the id
self.assertEqual(imported_dash_id_1, imported_dash_id_2)
assert imported_dash_id_1 == imported_dash_id_2
expected_dash = self.create_dashboard(
"override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004
)
@ -487,20 +479,17 @@ class TestImportExport(SupersetTestCase):
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False, check_slugs=False
)
self.assertEqual(
{
"remote_id": 10004,
"import_time": 1992,
"native_filter_configuration": [],
},
json.loads(imported_dash.json_metadata),
)
assert {
"remote_id": 10004,
"import_time": 1992,
"native_filter_configuration": [],
} == json.loads(imported_dash.json_metadata)
def test_import_new_dashboard_slice_reset_ownership(self):
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
assert admin_user
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
assert gamma_user
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
@ -511,35 +500,35 @@ class TestImportExport(SupersetTestCase):
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
assert imported_dash.created_by == gamma_user
assert imported_dash.changed_by == gamma_user
assert imported_dash.owners == [gamma_user]
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
assert imported_slc.created_by == gamma_user
assert imported_slc.changed_by == gamma_user
assert imported_slc.owners == [gamma_user]
@pytest.mark.skip
def test_import_override_dashboard_slice_reset_ownership(self):
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
assert admin_user
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
assert gamma_user
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
assert imported_dash.created_by == gamma_user
assert imported_dash.changed_by == gamma_user
assert imported_dash.owners == [gamma_user]
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
assert imported_slc.created_by == gamma_user
assert imported_slc.changed_by == gamma_user
assert imported_slc.owners == [gamma_user]
# re-import with another user shouldn't change the permissions
g.user = admin_user
@ -547,14 +536,14 @@ class TestImportExport(SupersetTestCase):
imported_dash_id = import_dashboard(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
assert imported_dash.created_by == gamma_user
assert imported_dash.changed_by == gamma_user
assert imported_dash.owners == [gamma_user]
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
assert imported_slc.created_by == gamma_user
assert imported_slc.changed_by == gamma_user
assert imported_slc.owners == [gamma_user]
def _create_dashboard_for_import(self, id_=10100):
slc = self.create_slice(
@ -600,10 +589,11 @@ class TestImportExport(SupersetTestCase):
imported_id = import_dataset(table, db_id, import_time=1990)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.assertEqual(
{"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
json.loads(imported.params),
)
assert {
"remote_id": 10002,
"import_time": 1990,
"database_name": "examples",
} == json.loads(imported.params)
def test_import_table_2_col_2_met(self):
schema = get_example_default_schema()
@ -642,7 +632,7 @@ class TestImportExport(SupersetTestCase):
imported_over_id = import_dataset(table_over, db_id, import_time=1992)
imported_over = self.get_table_by_id(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
assert imported_id == imported_over.id
expected_table = self.create_table(
"table_override",
id=10003,
@ -673,7 +663,7 @@ class TestImportExport(SupersetTestCase):
)
imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
assert imported_id == imported_id_copy
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))

View File

@ -82,7 +82,7 @@ class TestLogApi(SupersetTestCase):
arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]}
uri = f"api/v1/log/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_list(self):
"""
@ -94,11 +94,11 @@ class TestLogApi(SupersetTestCase):
arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]}
uri = f"api/v1/log/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(list(response["result"][0].keys()), EXPECTED_COLUMNS)
self.assertEqual(response["result"][0]["action"], "some_action")
self.assertEqual(response["result"][0]["user"], {"username": "admin"})
assert list(response["result"][0].keys()) == EXPECTED_COLUMNS
assert response["result"][0]["action"] == "some_action"
assert response["result"][0]["user"] == {"username": "admin"}
db.session.delete(log)
db.session.commit()
@ -111,10 +111,10 @@ class TestLogApi(SupersetTestCase):
self.login(GAMMA_USERNAME)
uri = "api/v1/log/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
self.login(ALPHA_USERNAME)
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
db.session.delete(log)
db.session.commit()
@ -127,12 +127,12 @@ class TestLogApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/log/{log.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(list(response["result"].keys()), EXPECTED_COLUMNS)
self.assertEqual(response["result"]["action"], "some_action")
self.assertEqual(response["result"]["user"], {"username": "admin"})
assert list(response["result"].keys()) == EXPECTED_COLUMNS
assert response["result"]["action"] == "some_action"
assert response["result"]["user"] == {"username": "admin"}
db.session.delete(log)
db.session.commit()
@ -145,7 +145,7 @@ class TestLogApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/log/{log.id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 405)
assert rv.status_code == 405
db.session.delete(log)
db.session.commit()
@ -160,7 +160,7 @@ class TestLogApi(SupersetTestCase):
log_data = {"action": "some_action"}
uri = f"api/v1/log/{log.id}"
rv = self.client.put(uri, json=log_data)
self.assertEqual(rv.status_code, 405)
assert rv.status_code == 405
db.session.delete(log)
db.session.commit()
@ -176,7 +176,7 @@ class TestLogApi(SupersetTestCase):
uri = f"api/v1/log/recent_activity/" # noqa: F541
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
db.session.delete(log1)
@ -184,21 +184,18 @@ class TestLogApi(SupersetTestCase):
db.session.delete(dash)
db.session.commit()
self.assertEqual(
response,
{
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash_slug/",
"item_title": "dash_title",
"time": ANY,
"time_delta_humanized": ANY,
}
]
},
)
assert response == {
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash_slug/",
"item_title": "dash_title",
"time": ANY,
"time_delta_humanized": ANY,
}
]
}
def test_get_recent_activity_actions_filter(self):
"""
@ -219,9 +216,9 @@ class TestLogApi(SupersetTestCase):
db.session.delete(dash)
db.session.commit()
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(response["result"]), 1)
assert len(response["result"]) == 1
def test_get_recent_activity_distinct_false(self):
"""
@ -243,9 +240,9 @@ class TestLogApi(SupersetTestCase):
db.session.delete(log2)
db.session.delete(dash)
db.session.commit()
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(response["result"]), 2)
assert len(response["result"]) == 2
def test_get_recent_activity_pagination(self):
"""
@ -269,31 +266,28 @@ class TestLogApi(SupersetTestCase):
uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash3_slug/",
"item_title": "dash3_title",
"time": ANY,
"time_delta_humanized": ANY,
},
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash2_slug/",
"item_title": "dash2_title",
"time": ANY,
"time_delta_humanized": ANY,
},
]
},
)
assert response == {
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash3_slug/",
"item_title": "dash3_title",
"time": ANY,
"time_delta_humanized": ANY,
},
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash2_slug/",
"item_title": "dash2_title",
"time": ANY,
"time_delta_humanized": ANY,
},
]
}
arguments = {"page": 1, "page_size": 2}
uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}"
@ -307,20 +301,17 @@ class TestLogApi(SupersetTestCase):
db.session.delete(dash3)
db.session.commit()
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash_slug/",
"item_title": "dash_title",
"time": ANY,
"time_delta_humanized": ANY,
}
]
},
)
assert response == {
"result": [
{
"action": "dashboard",
"item_type": "dashboard",
"item_url": "/superset/dashboard/dash_slug/",
"item_title": "dash_title",
"time": ANY,
"time_delta_humanized": ANY,
}
]
}

View File

@ -52,4 +52,4 @@ class TestLoggingConfigurator(unittest.TestCase):
cfg.configure_logging(MagicMock(), True)
logging.info("test", extra={"testattr": "foo"})
self.assertTrue(handler.received)
assert handler.received

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import re
from superset.utils.core import DatasourceType
from superset.utils import json
import unittest
@ -59,22 +60,22 @@ class TestDatabaseModel(SupersetTestCase):
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive/default", db)
assert "hive/default" == db
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
assert "hive/core_db" == db
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive", db)
assert "hive" == db
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
assert "hive/core_db" == db
def test_database_schema_postgres(self):
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
@ -82,11 +83,11 @@ class TestDatabaseModel(SupersetTestCase):
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
assert "prod" == db
with model.get_sqla_engine(schema="foo") as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
assert "prod" == db
@unittest.skipUnless(
SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
@ -100,11 +101,11 @@ class TestDatabaseModel(SupersetTestCase):
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("default", db)
assert "default" == db
with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("core_db", db)
assert "core_db" == db
@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
@ -115,11 +116,11 @@ class TestDatabaseModel(SupersetTestCase):
with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("superset", db)
assert "superset" == db
with model.get_sqla_engine(schema="staging") as engine:
db = make_url(engine.url).database
self.assertEqual("staging", db)
assert "staging" == db
@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
@ -133,12 +134,12 @@ class TestDatabaseModel(SupersetTestCase):
model.impersonate_user = True
with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertEqual(example_user.username, username)
assert example_user.username == username
model.impersonate_user = False
with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertNotEqual(example_user.username, username)
assert example_user.username != username
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
@ -344,20 +345,20 @@ class TestDatabaseModel(SupersetTestCase):
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
assert df.iat[0, 0] == 1
df = main_db.get_df("SELECT 1;", None, None)
self.assertEqual(df.iat[0, 0], 1)
assert df.iat[0, 0] == 1
def test_multi_statement(self):
main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
assert df.iat[0, 0] == 1
df = main_db.get_df("USE superset; SELECT ';';", None, None)
self.assertEqual(df.iat[0, 0], ";")
assert df.iat[0, 0] == ";"
@mock.patch("superset.models.core.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
@ -404,20 +405,20 @@ class TestSqlaTableModel(SupersetTestCase):
sqla_literal = ds_col.get_timestamp_expression(None)
compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "from_unixtime(ds)")
assert compiled == "from_unixtime(ds)"
ds_col.python_date_format = "epoch_s"
sqla_literal = ds_col.get_timestamp_expression("P1D")
compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(from_unixtime(ds))")
assert compiled == "DATE(from_unixtime(ds))"
prev_ds_expr = ds_col.expression
ds_col.expression = "DATE_ADD(ds, 1)"
sqla_literal = ds_col.get_timestamp_expression("P1D")
compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
assert compiled == "DATE(from_unixtime(DATE_ADD(ds, 1)))"
ds_col.expression = prev_ds_expr
def query_with_expr_helper(self, is_timeseries, inner_join=True):
@ -448,16 +449,16 @@ class TestSqlaTableModel(SupersetTestCase):
series_limit=15 if inner_join and is_timeseries else None,
)
qr = tbl.query(query_obj)
self.assertEqual(qr.status, QueryStatus.SUCCESS)
assert qr.status == QueryStatus.SUCCESS
sql = qr.query
self.assertIn(arbitrary_gby, sql)
self.assertIn("name", sql)
assert arbitrary_gby in sql
assert "name" in sql
if inner_join and is_timeseries:
self.assertIn("JOIN", sql.upper())
assert "JOIN" in sql.upper()
else:
self.assertNotIn("JOIN", sql.upper())
assert "JOIN" not in sql.upper()
spec.allows_joins = old_inner_join
self.assertFalse(qr.df.empty)
assert not qr.df.empty
return qr.df
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -475,7 +476,7 @@ class TestSqlaTableModel(SupersetTestCase):
name_list1 = canonicalize_df(df1).name.values.tolist()
df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
name_list2 = canonicalize_df(df1).name.values.tolist()
self.assertFalse(df2.empty)
assert not df2.empty
assert name_list2 == name_list1
@ -498,14 +499,14 @@ class TestSqlaTableModel(SupersetTestCase):
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertNotIn("-- COMMENT", sql)
assert "-- COMMENT" not in sql
def mutator(*args, **kwargs):
return "-- COMMENT\n" + args[0]
app.config["SQL_QUERY_MUTATOR"] = mutator
sql = tbl.get_query_str(query_obj)
self.assertIn("-- COMMENT", sql)
assert "-- COMMENT" in sql
app.config["SQL_QUERY_MUTATOR"] = None
@ -524,15 +525,15 @@ class TestSqlaTableModel(SupersetTestCase):
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertNotIn("-- COMMENT", sql)
assert "-- COMMENT" not in sql
def mutator(sql, database=None, **kwargs):
return "-- COMMENT\n--" + "\n" + str(database) + "\n" + sql
app.config["SQL_QUERY_MUTATOR"] = mutator
mutated_sql = tbl.get_query_str(query_obj)
self.assertIn("-- COMMENT", mutated_sql)
self.assertIn(tbl.database.name, mutated_sql)
assert "-- COMMENT" in mutated_sql
assert tbl.database.name in mutated_sql
app.config["SQL_QUERY_MUTATOR"] = None
@ -554,7 +555,7 @@ class TestSqlaTableModel(SupersetTestCase):
with self.assertRaises(Exception) as context:
tbl.get_query_str(query_obj)
self.assertTrue("Metric 'invalid' does not exist", context.exception)
assert "Metric 'invalid' does not exist", context.exception
def test_query_label_without_group_by(self):
tbl = self.get_table(name="birth_names")
@ -577,7 +578,7 @@ class TestSqlaTableModel(SupersetTestCase):
)
sql = tbl.get_query_str(query_obj)
self.assertRegex(sql, r'name AS ["`]?Given Name["`]?')
assert re.search('name AS ["`]?Given Name["`]?', sql) # noqa: F821
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices_with_no_query_context(self):

View File

@ -138,7 +138,7 @@ class TestQueryApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/query/{query.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
expected_result = {
"database": {"id": example_db.id},
@ -163,7 +163,7 @@ class TestQueryApi(SupersetTestCase):
"tracking_url": None,
}
data = json.loads(rv.data.decode("utf-8"))
self.assertIn("changed_on", data["result"])
assert "changed_on" in data["result"]
for key, value in data["result"].items():
# We can't assert timestamp
if key not in (
@ -173,7 +173,7 @@ class TestQueryApi(SupersetTestCase):
"start_time",
"id",
):
self.assertEqual(value, expected_result[key])
assert value == expected_result[key]
# rollback changes
db.session.delete(query)
db.session.commit()
@ -189,7 +189,7 @@ class TestQueryApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/query/{max_id + 1}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
db.session.delete(query)
db.session.commit()
@ -222,30 +222,30 @@ class TestQueryApi(SupersetTestCase):
self.login(username="gamma_1", password="password")
uri = f"api/v1/query/{query_gamma2.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = f"api/v1/query/{query_gamma1.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# Gamma2 user, only sees their own queries
self.logout()
self.login(username="gamma_2", password="password")
uri = f"api/v1/query/{query_gamma1.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = f"api/v1/query/{query_gamma2.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# Admin's have the "all query access" permission
self.logout()
self.login(ADMIN_USERNAME)
uri = f"api/v1/query/{query_gamma1.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
uri = f"api/v1/query/{query_gamma2.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# rollback changes
db.session.delete(query_gamma1)
@ -262,7 +262,7 @@ class TestQueryApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/query/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == QUERIES_FIXTURE_COUNT
# check expected columns
@ -433,11 +433,11 @@ class TestQueryApi(SupersetTestCase):
timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
expected_result = updated_query.to_dict()
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 1)
assert len(data["result"]) == 1
for key, value in data["result"][0].items():
# We can't assert timestamp
if key not in (
@ -447,7 +447,7 @@ class TestQueryApi(SupersetTestCase):
"start_time",
"id",
):
self.assertEqual(value, expected_result[key])
assert value == expected_result[key]
# rollback changes
db.session.delete(old_query)
db.session.delete(updated_query)

View File

@ -467,26 +467,22 @@ class TestSavedQueryApi(SupersetTestCase):
# Filter by tag ID
filter_params = get_filter_params("saved_query_tag_id", tag.id)
response_by_id = self.get_list("saved_query", filter_params)
self.assertEqual(response_by_id.status_code, 200)
assert response_by_id.status_code == 200
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
# Filter by tag name
filter_params = get_filter_params("saved_query_tags", tag.name)
response_by_name = self.get_list("saved_query", filter_params)
self.assertEqual(response_by_name.status_code, 200)
assert response_by_name.status_code == 200
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
# Compare results
self.assertEqual(
data_by_id["count"],
data_by_name["count"],
len(expected_saved_queries),
)
self.assertEqual(
set(query["id"] for query in data_by_id["result"]),
set(query["id"] for query in data_by_name["result"]),
set(query.id for query in expected_saved_queries),
assert data_by_id["count"] == data_by_name["count"], len(
expected_saved_queries
)
assert set(query["id"] for query in data_by_id["result"]) == set(
query["id"] for query in data_by_name["result"]
), set(query.id for query in expected_saved_queries)
@pytest.mark.usefixtures("create_saved_queries")
def test_get_saved_query_favorite_filter(self):

View File

@ -70,15 +70,15 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context("birth_names", add_postprocessing_operations=True)
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), len(payload["queries"]))
assert len(query_context.queries) == len(payload["queries"])
for query_idx, query in enumerate(query_context.queries):
payload_query = payload["queries"][query_idx]
# check basic properties
self.assertEqual(query.extras, payload_query["extras"])
self.assertEqual(query.filter, payload_query["filters"])
self.assertEqual(query.columns, payload_query["columns"])
assert query.extras == payload_query["extras"]
assert query.filter == payload_query["filters"]
assert query.columns == payload_query["columns"]
# metrics are mutated during creation
for metric_idx, metric in enumerate(query.metrics):
@ -88,16 +88,16 @@ class TestQueryContext(SupersetTestCase):
if "expressionType" in payload_metric
else payload_metric["label"]
)
self.assertEqual(metric, payload_metric)
assert metric == payload_metric
self.assertEqual(query.orderby, payload_query["orderby"])
self.assertEqual(query.time_range, payload_query["time_range"])
assert query.orderby == payload_query["orderby"]
assert query.time_range == payload_query["time_range"]
# check post processing operation properties
for post_proc_idx, post_proc in enumerate(query.post_processing):
payload_post_proc = payload_query["post_processing"][post_proc_idx]
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
assert post_proc["operation"] == payload_post_proc["operation"]
assert post_proc["options"] == payload_post_proc["options"]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
@ -128,12 +128,12 @@ class TestQueryContext(SupersetTestCase):
rehydrated_qo = rehydrated_qc.queries[0]
rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
self.assertEqual(len(rehydrated_qc.queries), 1)
self.assertEqual(query_cache_key, rehydrated_query_cache_key)
self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
self.assertFalse(rehydrated_qc.force)
assert rehydrated_qc.datasource == query_context.datasource
assert len(rehydrated_qc.queries) == 1
assert query_cache_key == rehydrated_query_cache_key
assert rehydrated_qc.result_type == query_context.result_type
assert rehydrated_qc.result_format == query_context.result_format
assert not rehydrated_qc.force
def test_query_cache_key_changes_when_datasource_is_updated(self):
payload = get_query_context("birth_names")
@ -164,7 +164,7 @@ class TestQueryContext(SupersetTestCase):
cache_key_new = query_context.query_cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
assert cache_key_original != cache_key_new
def test_query_cache_key_changes_when_metric_is_updated(self):
payload = get_query_context("birth_names")
@ -198,7 +198,7 @@ class TestQueryContext(SupersetTestCase):
db.session.commit()
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
assert cache_key_original != cache_key_new
def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
payload = get_query_context("birth_names", add_postprocessing_operations=True)
@ -228,14 +228,14 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key)
assert cache_key_original == cache_key
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key)
assert cache_key_original != cache_key
def test_query_cache_key_changes_when_time_offsets_is_updated(self):
payload = get_query_context("birth_names", add_time_offsets=True)
@ -248,7 +248,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key)
assert cache_key_original != cache_key
def test_handle_metrics_field(self):
"""
@ -265,7 +265,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])
assert query_object.metrics == ["sum__num", "abc", adhoc_metric]
def test_convert_deprecated_fields(self):
"""
@ -280,12 +280,12 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
assert len(query_context.queries) == 1
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertEqual(query_object.columns, columns)
self.assertEqual(query_object.series_limit, 99)
self.assertEqual(query_object.series_limit_metric, "sum__num")
assert query_object.granularity == "timecol"
assert query_object.columns == columns
assert query_object.series_limit == 99
assert query_object.series_limit_metric == "sum__num"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_csv_response_format(self):
@ -297,10 +297,10 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["row_limit"] = 10
query_context: QueryContext = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
assert len(responses) == 1
data = responses["queries"][0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
assert "name,sum__num\n" in data
assert len(data.split("\n")) == 12
def test_sql_injection_via_groupby(self):
"""
@ -352,11 +352,11 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["row_limit"] = 5
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
assert len(responses) == 1
data = responses["queries"][0]["data"]
self.assertIsInstance(data, list)
self.assertEqual(len(data), 5)
self.assertNotIn("sum__num", data[0])
assert isinstance(data, list)
assert len(data) == 5
assert "sum__num" not in data[0]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_response_type(self):
@ -489,7 +489,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
new_cache_key = responses["queries"][0]["cache_key"]
self.assertEqual(orig_cache_key, new_cache_key)
assert orig_cache_key == new_cache_key
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_in_query_object(self):
@ -505,21 +505,18 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["time_range"] = "1990 : 1991"
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(
responses["queries"][0]["colnames"],
[
"__timestamp",
"name",
"sum__num",
"sum__num__1 year ago",
"sum__num__1 year later",
],
)
assert responses["queries"][0]["colnames"] == [
"__timestamp",
"name",
"sum__num",
"sum__num__1 year ago",
"sum__num__1 year later",
]
sqls = [
sql for sql in responses["queries"][0]["query"].split(";") if sql.strip()
]
self.assertEqual(len(sqls), 3)
assert len(sqls) == 3
# 1 year ago
assert re.search(r"1989-01-01.+1990-01-01", sqls[1], re.S)
assert re.search(r"1990-01-01.+1991-01-01", sqls[1], re.S)
@ -560,9 +557,9 @@ class TestQueryContext(SupersetTestCase):
cache_keys = rv["cache_keys"]
cache_keys__1_year_ago = cache_keys[0]
cache_keys__1_year_later = cache_keys[1]
self.assertIsNotNone(cache_keys__1_year_ago)
self.assertIsNotNone(cache_keys__1_year_later)
self.assertNotEqual(cache_keys__1_year_ago, cache_keys__1_year_later)
assert cache_keys__1_year_ago is not None
assert cache_keys__1_year_later is not None
assert cache_keys__1_year_ago != cache_keys__1_year_later
# swap offsets
payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"]
@ -570,8 +567,8 @@ class TestQueryContext(SupersetTestCase):
query_object = query_context.queries[0]
rv = query_context.processing_time_offsets(df.copy(), query_object)
cache_keys = rv["cache_keys"]
self.assertEqual(cache_keys__1_year_ago, cache_keys[1])
self.assertEqual(cache_keys__1_year_later, cache_keys[0])
assert cache_keys__1_year_ago == cache_keys[1]
assert cache_keys__1_year_later == cache_keys[0]
# remove all offsets
payload["queries"][0]["time_offsets"] = []
@ -582,9 +579,9 @@ class TestQueryContext(SupersetTestCase):
query_object,
)
self.assertEqual(rv["df"].shape, df.shape)
self.assertEqual(rv["queries"], [])
self.assertEqual(rv["cache_keys"], [])
assert rv["df"].shape == df.shape
assert rv["queries"] == []
assert rv["cache_keys"] == []
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_sql(self):
@ -732,7 +729,7 @@ class TestQueryContext(SupersetTestCase):
row_limit_pattern_with_config_value = r"LIMIT " + re.escape(
str(row_limit_value)
)
self.assertEqual(len(sqls), 2)
assert len(sqls) == 2
# 1 year ago
assert re.search(r"1989-01-01.+1990-01-01", sqls[0], re.S)
assert not re.search(r"LIMIT 100", sqls[0], re.S)

View File

@ -1673,7 +1673,7 @@ class TestReportSchedulesApi(SupersetTestCase):
}
uri = f"api/v1/report/{report_schedule.id}"
rv = self.put_assert_metric(uri, report_schedule_data, "put")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
@pytest.mark.usefixtures("create_report_schedules")
def test_update_report_preserve_ownership(self):
@ -1819,7 +1819,7 @@ class TestReportSchedulesApi(SupersetTestCase):
self.login(username="alpha2", password="password")
uri = f"api/v1/report/{report_schedule.id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
@pytest.mark.usefixtures("create_report_schedules")
def test_bulk_delete_report_schedule(self):
@ -1876,7 +1876,7 @@ class TestReportSchedulesApi(SupersetTestCase):
self.login(username="alpha2", password="password")
uri = f"api/v1/report/?q={prison.dumps(report_schedules_ids)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 403)
assert rv.status_code == 403
@pytest.mark.usefixtures("create_report_schedules")
def test_get_list_report_schedule_logs(self):

View File

@ -28,74 +28,77 @@ from .base_tests import SupersetTestCase
class TestSupersetResultSet(SupersetTestCase):
def test_dedup(self):
self.assertEqual(dedup(["foo", "bar"]), ["foo", "bar"])
self.assertEqual(
dedup(["foo", "bar", "foo", "bar", "Foo"]),
["foo", "bar", "foo__1", "bar__1", "Foo"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"]),
["foo", "bar", "bar__1", "bar__2", "Bar"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False),
["foo", "bar", "bar__1", "bar__2", "Bar__3"],
)
assert dedup(["foo", "bar"]) == ["foo", "bar"]
assert dedup(["foo", "bar", "foo", "bar", "Foo"]) == [
"foo",
"bar",
"foo__1",
"bar__1",
"Foo",
]
assert dedup(["foo", "bar", "bar", "bar", "Bar"]) == [
"foo",
"bar",
"bar__1",
"bar__2",
"Bar",
]
assert dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False) == [
"foo",
"bar",
"bar__1",
"bar__2",
"Bar__3",
]
def test_get_columns_basic(self):
data = [("a1", "b1", "c1"), ("a2", "b2", "c2")]
cursor_descr = (("a", "string"), ("b", "string"), ("c", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "b",
"name": "b",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "c",
"name": "c",
},
],
)
assert results.columns == [
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "b",
"name": "b",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "c",
"name": "c",
},
]
def test_get_columns_with_int(self):
data = [("a1", 1), ("a2", 2)]
cursor_descr = (("a", "string"), ("b", "int"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "INT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "b",
"name": "b",
},
],
)
assert results.columns == [
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "INT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "b",
"name": "b",
},
]
def test_get_columns_type_inference(self):
data = [
@ -104,72 +107,69 @@ class TestSupersetResultSet(SupersetTestCase):
]
cursor_descr = (("a", None), ("b", None), ("c", None), ("d", None), ("e", None))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{
"is_dttm": False,
"type": "FLOAT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "INT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "b",
"name": "b",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "c",
"name": "c",
},
{
"is_dttm": True,
"type": "DATETIME",
"type_generic": GenericDataType.TEMPORAL,
"column_name": "d",
"name": "d",
},
{
"is_dttm": False,
"type": "BOOL",
"type_generic": GenericDataType.BOOLEAN,
"column_name": "e",
"name": "e",
},
],
)
assert results.columns == [
{
"is_dttm": False,
"type": "FLOAT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "a",
"name": "a",
},
{
"is_dttm": False,
"type": "INT",
"type_generic": GenericDataType.NUMERIC,
"column_name": "b",
"name": "b",
},
{
"is_dttm": False,
"type": "STRING",
"type_generic": GenericDataType.STRING,
"column_name": "c",
"name": "c",
},
{
"is_dttm": True,
"type": "DATETIME",
"type_generic": GenericDataType.TEMPORAL,
"column_name": "d",
"name": "d",
},
{
"is_dttm": False,
"type": "BOOL",
"type_generic": GenericDataType.BOOLEAN,
"column_name": "e",
"name": "e",
},
]
def test_is_date(self):
data = [("a", 1), ("a", 2)]
cursor_descr = (("a", "string"), ("a", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.is_temporal("DATE"), True)
self.assertEqual(results.is_temporal("DATETIME"), True)
self.assertEqual(results.is_temporal("TIME"), True)
self.assertEqual(results.is_temporal("TIMESTAMP"), True)
self.assertEqual(results.is_temporal("STRING"), False)
self.assertEqual(results.is_temporal(""), False)
self.assertEqual(results.is_temporal(None), False)
assert results.is_temporal("DATE") is True
assert results.is_temporal("DATETIME") is True
assert results.is_temporal("TIME") is True
assert results.is_temporal("TIMESTAMP") is True
assert results.is_temporal("STRING") is False
assert results.is_temporal("") is False
assert results.is_temporal(None) is False
def test_dedup_with_data(self):
data = [("a", 1), ("a", 2)]
cursor_descr = (("a", "string"), ("a", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
column_names = [col["column_name"] for col in results.columns]
self.assertListEqual(column_names, ["a", "a__1"])
self.assertListEqual(column_names, ["a", "a__1"]) # noqa: PT009
def test_int64_with_missing_data(self):
data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)]
cursor_descr = [("user_id", "bigint", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "BIGINT")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC)
assert results.columns[0]["type"] == "BIGINT"
assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC
def test_data_as_list_of_lists(self):
data = [[1, "a"], [2, "b"]]
@ -179,29 +179,26 @@ class TestSupersetResultSet(SupersetTestCase):
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[{"user_id": 1, "username": "a"}, {"user_id": 2, "username": "b"}],
)
assert df_to_records(df) == [
{"user_id": 1, "username": "a"},
{"user_id": 2, "username": "b"},
]
def test_nullable_bool(self):
data = [(None,), (True,), (None,), (None,), (None,), (None,)]
cursor_descr = [("is_test", "bool", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "BOOL")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.BOOLEAN)
assert results.columns[0]["type"] == "BOOL"
assert results.columns[0]["type_generic"] == GenericDataType.BOOLEAN
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[
{"is_test": None},
{"is_test": True},
{"is_test": None},
{"is_test": None},
{"is_test": None},
{"is_test": None},
],
)
assert df_to_records(df) == [
{"is_test": None},
{"is_test": True},
{"is_test": None},
{"is_test": None},
{"is_test": None},
{"is_test": None},
]
def test_nested_types(self):
data = [
@ -220,32 +217,29 @@ class TestSupersetResultSet(SupersetTestCase):
]
cursor_descr = [("id",), ("dict_arr",), ("num_arr",), ("map_col",)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "INT")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC)
self.assertEqual(results.columns[1]["type"], "STRING")
self.assertEqual(results.columns[1]["type_generic"], GenericDataType.STRING)
self.assertEqual(results.columns[2]["type"], "STRING")
self.assertEqual(results.columns[2]["type_generic"], GenericDataType.STRING)
self.assertEqual(results.columns[3]["type"], "STRING")
self.assertEqual(results.columns[3]["type_generic"], GenericDataType.STRING)
assert results.columns[0]["type"] == "INT"
assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC
assert results.columns[1]["type"] == "STRING"
assert results.columns[1]["type_generic"] == GenericDataType.STRING
assert results.columns[2]["type"] == "STRING"
assert results.columns[2]["type_generic"] == GenericDataType.STRING
assert results.columns[3]["type"] == "STRING"
assert results.columns[3]["type_generic"] == GenericDataType.STRING
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[
{
"id": 4,
"dict_arr": '[{"table_name": "unicode_test", "database_id": 1}]',
"num_arr": "[1, 2, 3]",
"map_col": "{'chart_name': 'scatter'}",
},
{
"id": 3,
"dict_arr": '[{"table_name": "birth_names", "database_id": 1}]',
"num_arr": "[4, 5, 6]",
"map_col": "{'chart_name': 'plot'}",
},
],
)
assert df_to_records(df) == [
{
"id": 4,
"dict_arr": '[{"table_name": "unicode_test", "database_id": 1}]',
"num_arr": "[1, 2, 3]",
"map_col": "{'chart_name': 'scatter'}",
},
{
"id": 3,
"dict_arr": '[{"table_name": "birth_names", "database_id": 1}]',
"num_arr": "[4, 5, 6]",
"map_col": "{'chart_name': 'plot'}",
},
]
def test_single_column_multidim_nested_types(self):
data = [
@ -270,35 +264,30 @@ class TestSupersetResultSet(SupersetTestCase):
]
cursor_descr = [("metadata",)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "STRING")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
assert results.columns[0]["type"] == "STRING"
assert results.columns[0]["type_generic"] == GenericDataType.STRING
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[
{
"metadata": '["test", [["foo", 123456, [[["test"], 3432546, 7657658766], [["fake"], 656756765, 324324324324]]]], ["test2", 43, 765765765], null, null]'
}
],
)
assert df_to_records(df) == [
{
"metadata": '["test", [["foo", 123456, [[["test"], 3432546, 7657658766], [["fake"], 656756765, 324324324324]]]], ["test2", 43, 765765765], null, null]'
}
]
def test_nested_list_types(self):
data = [([{"TestKey": [123456, "foo"]}],)]
cursor_descr = [("metadata",)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "STRING")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
assert results.columns[0]["type"] == "STRING"
assert results.columns[0]["type_generic"] == GenericDataType.STRING
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df), [{"metadata": '[{"TestKey": [123456, "foo"]}]'}]
)
assert df_to_records(df) == [{"metadata": '[{"TestKey": [123456, "foo"]}]'}]
def test_empty_datetime(self):
data = [(None,)]
cursor_descr = [("ds", "timestamp", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "TIMESTAMP")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.TEMPORAL)
assert results.columns[0]["type"] == "TIMESTAMP"
assert results.columns[0]["type_generic"] == GenericDataType.TEMPORAL
def test_no_type_coercion(self):
data = [("a", 1), ("b", 2)]
@ -307,10 +296,10 @@ class TestSupersetResultSet(SupersetTestCase):
("two", "int", None, None, None, None, True),
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "VARCHAR")
self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING)
self.assertEqual(results.columns[1]["type"], "INT")
self.assertEqual(results.columns[1]["type_generic"], GenericDataType.NUMERIC)
assert results.columns[0]["type"] == "VARCHAR"
assert results.columns[0]["type_generic"] == GenericDataType.STRING
assert results.columns[1]["type"] == "INT"
assert results.columns[1]["type_generic"] == GenericDataType.NUMERIC
def test_empty_data(self):
data = []
@ -319,4 +308,4 @@ class TestSupersetResultSet(SupersetTestCase):
("emptytwo", "int", None, None, None, None, True),
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns, [])
assert results.columns == []

View File

@ -43,7 +43,7 @@ class TestSecurityCsrfApi(SupersetTestCase):
response = self.client.get(uri)
self.assert200(response)
data = json.loads(response.data.decode("utf-8"))
self.assertEqual(generate_csrf(), data["result"])
assert generate_csrf() == data["result"]
def test_get_csrf_token(self):
"""
@ -120,8 +120,8 @@ class TestSecurityGuestTokenApi(SupersetTestCase):
audience=get_url_host(),
algorithms=["HS256"],
)
self.assertEqual(user, decoded_token["user"])
self.assertEqual(resource, decoded_token["resources"][0])
assert user == decoded_token["user"]
assert resource == decoded_token["resources"][0]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_post_guest_token_bad_resources(self):

View File

@ -55,15 +55,15 @@ class TestGuestUserSecurity(SupersetTestCase):
def test_is_guest_user__regular_user(self):
is_guest = security_manager.is_guest_user(security_manager.find_user("admin"))
self.assertFalse(is_guest)
assert not is_guest
def test_is_guest_user__anonymous(self):
is_guest = security_manager.is_guest_user(security_manager.get_anonymous_user())
self.assertFalse(is_guest)
assert not is_guest
def test_is_guest_user__guest_user(self):
is_guest = security_manager.is_guest_user(self.authorized_guest())
self.assertTrue(is_guest)
assert is_guest
@patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -71,34 +71,34 @@ class TestGuestUserSecurity(SupersetTestCase):
)
def test_is_guest_user__flag_off(self):
is_guest = security_manager.is_guest_user(self.authorized_guest())
self.assertFalse(is_guest)
assert not is_guest
def test_get_guest_user__regular_user(self):
g.user = security_manager.find_user("admin")
guest_user = security_manager.get_current_guest_user_if_guest()
self.assertIsNone(guest_user)
assert guest_user is None
def test_get_guest_user__anonymous_user(self):
g.user = security_manager.get_anonymous_user()
guest_user = security_manager.get_current_guest_user_if_guest()
self.assertIsNone(guest_user)
assert guest_user is None
def test_get_guest_user__guest_user(self):
g.user = self.authorized_guest()
guest_user = security_manager.get_current_guest_user_if_guest()
self.assertEqual(guest_user, g.user)
assert guest_user == g.user
def test_get_guest_user_roles_explicit(self):
guest = self.authorized_guest()
roles = security_manager.get_user_roles(guest)
self.assertEqual(guest.roles, roles)
assert guest.roles == roles
def test_get_guest_user_roles_implicit(self):
guest = self.authorized_guest()
g.user = guest
roles = security_manager.get_user_roles()
self.assertEqual(guest.roles, roles)
assert guest.roles == roles
@patch.dict(
@ -142,17 +142,17 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
def test_has_guest_access__regular_user(self):
g.user = security_manager.find_user("admin")
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertFalse(has_guest_access)
assert not has_guest_access
def test_has_guest_access__anonymous_user(self):
g.user = security_manager.get_anonymous_user()
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertFalse(has_guest_access)
assert not has_guest_access
def test_has_guest_access__authorized_guest_user(self):
g.user = self.authorized_guest
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertTrue(has_guest_access)
assert has_guest_access
def test_has_guest_access__authorized_guest_user__non_zero_resource_index(self):
# set up a user who has authorized access, plus another resource
@ -163,7 +163,7 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
g.user = guest
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertTrue(has_guest_access)
assert has_guest_access
def test_has_guest_access__unauthorized_guest_user__different_resource_id(self):
g.user = security_manager.get_guest_user_from_token(
@ -173,14 +173,14 @@ class TestGuestUserDashboardAccess(SupersetTestCase):
}
)
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertFalse(has_guest_access)
assert not has_guest_access
def test_has_guest_access__unauthorized_guest_user__different_resource_type(self):
g.user = security_manager.get_guest_user_from_token(
{"user": {}, "resources": [{"type": "dirt", "id": self.embedded.uuid}]}
)
has_guest_access = security_manager.has_guest_access(self.dash)
self.assertFalse(has_guest_access)
assert not has_guest_access
def test_raise_for_dashboard_access_as_guest(self):
g.user = self.authorized_guest

View File

@ -188,7 +188,7 @@ class TestRowLevelSecurity(SupersetTestCase):
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
@ -214,7 +214,7 @@ class TestRowLevelSecurity(SupersetTestCase):
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
@ -231,7 +231,7 @@ class TestRowLevelSecurity(SupersetTestCase):
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
data = json.loads(rv.data.decode("utf-8"))
assert data["message"] == {"tables": ["Shorter than minimum length 1."]}
@ -326,8 +326,8 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
}
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Some roles do not exist']")
assert status_code == 422
assert data["message"] == "[l'Some roles do not exist']"
def test_invalid_table_failure(self):
self.login(ADMIN_USERNAME)
@ -340,8 +340,8 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
}
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Datasource does not exist']")
assert status_code == 422
assert data["message"] == "[l'Datasource does not exist']"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_post_success(self):
@ -357,7 +357,7 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 201)
assert status_code == 201
rls = (
db.session.query(RowLevelSecurityFilter)
@ -366,11 +366,11 @@ class TestRowLevelSecurityCreateAPI(SupersetTestCase):
)
assert rls
self.assertEqual(rls.name, "rls 1")
self.assertEqual(rls.clause, "1=1")
self.assertEqual(rls.filter_type, "Base")
self.assertEqual(rls.tables[0].id, table.id)
self.assertEqual(rls.roles[0].id, 1)
assert rls.name == "rls 1"
assert rls.clause == "1=1"
assert rls.filter_type == "Base"
assert rls.tables[0].id == table.id
assert rls.roles[0].id == 1
db.session.delete(rls)
db.session.commit()
@ -388,8 +388,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
}
rv = self.client.put("/api/v1/rowlevelsecurity/99999999", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 404)
self.assertEqual(data["message"], "Not found")
assert status_code == 404
assert data["message"] == "Not found"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_invalid_role_failure(self):
@ -410,8 +410,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
}
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Some roles do not exist']")
assert status_code == 422
assert data["message"] == "[l'Some roles do not exist']"
db.session.delete(rls)
db.session.commit()
@ -439,8 +439,8 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
}
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Datasource does not exist']")
assert status_code == 422
assert data["message"] == "[l'Datasource does not exist']"
db.session.delete(rls)
db.session.commit()
@ -472,7 +472,7 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, _data = rv.status_code, json.loads(rv.data.decode("utf-8")) # noqa: F841
self.assertEqual(status_code, 201)
assert status_code == 201
rls = (
db.session.query(RowLevelSecurityFilter)
@ -480,11 +480,11 @@ class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
.one_or_none()
)
self.assertEqual(rls.name, "rls put success")
self.assertEqual(rls.clause, "2=2")
self.assertEqual(rls.filter_type, "Base")
self.assertEqual(rls.tables[0].id, tables[1].id)
self.assertEqual(rls.roles[0].id, roles[1].id)
assert rls.name == "rls put success"
assert rls.clause == "2=2"
assert rls.filter_type == "Base"
assert rls.tables[0].id == tables[1].id
assert rls.roles[0].id == roles[1].id
db.session.delete(rls)
db.session.commit()
@ -498,8 +498,8 @@ class TestRowLevelSecurityDeleteAPI(SupersetTestCase):
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 404)
self.assertEqual(data["message"], "Not found")
assert status_code == 404
assert data["message"] == "Not found"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
@ -530,8 +530,8 @@ class TestRowLevelSecurityDeleteAPI(SupersetTestCase):
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 200)
self.assertEqual(data["message"], "Deleted 2 rules")
assert status_code == 200
assert data["message"] == "Deleted 2 rules"
class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
@ -543,7 +543,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
params = prison.dumps({"page": 0, "page_size": 100})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
@ -561,7 +561,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
params = prison.dumps({"page": 0, "page_size": 100})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
@ -584,7 +584,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
params = prison.dumps({"page": 0, "page_size": 10})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
received_tables = {table["text"].split(".")[-1] for table in result}
@ -664,7 +664,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
self.assertRegex(sql, RLS_ALICE_REGEX)
assert re.search(RLS_ALICE_REGEX, sql)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_does_not_alter_unrelated_query(self):
@ -679,7 +679,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
self.assertNotRegex(sql, RLS_ALICE_REGEX)
assert not re.search(RLS_ALICE_REGEX, sql)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_multiple_rls_filters_are_unionized(self):
@ -695,8 +695,8 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
self.assertRegex(sql, RLS_ALICE_REGEX)
self.assertRegex(sql, RLS_GENDER_REGEX)
assert re.search(RLS_ALICE_REGEX, sql)
assert re.search(RLS_GENDER_REGEX, sql)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
@ -709,8 +709,8 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
births_sql = births.get_query_str(self.query_obj)
energy_sql = energy.get_query_str(self.query_obj)
self.assertRegex(births_sql, RLS_ALICE_REGEX)
self.assertRegex(energy_sql, RLS_ALICE_REGEX)
assert re.search(RLS_ALICE_REGEX, births_sql)
assert re.search(RLS_ALICE_REGEX, energy_sql)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_dataset_id_can_be_string(self):
@ -721,4 +721,4 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
)
sql = dataset.get_query_str(self.query_obj)
self.assertRegex(sql, RLS_ALICE_REGEX)
assert re.search(RLS_ALICE_REGEX, sql)

File diff suppressed because it is too large Load Diff

View File

@ -67,7 +67,7 @@ class TestSqlLabApi(SupersetTestCase):
result = data.get("result")
assert result["active_tab"] is None # noqa: E711
assert result["tab_state_ids"] == []
self.assertEqual(len(result["databases"]), 0)
assert len(result["databases"]) == 0
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -126,7 +126,7 @@ class TestSqlLabApi(SupersetTestCase):
# associated with any tabs
resp = self.get_json_resp("/api/v1/sqllab/")
result = resp["result"]
self.assertEqual(result["active_tab"]["id"], tab_state_id)
assert result["active_tab"]["id"] == tab_state_id
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@ -220,8 +220,8 @@ class TestSqlLabApi(SupersetTestCase):
}
}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
data = {"sql": "SELECT 1"}
rv = self.client.post(
@ -230,8 +230,8 @@ class TestSqlLabApi(SupersetTestCase):
)
failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
data = {"database_id": 1}
rv = self.client.post(
@ -240,8 +240,8 @@ class TestSqlLabApi(SupersetTestCase):
)
failed_resp = {"message": {"sql": ["Missing data for required field."]}}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
def test_estimate_valid_request(self):
self.login(ADMIN_USERNAME)
@ -270,8 +270,8 @@ class TestSqlLabApi(SupersetTestCase):
success_resp = {"result": formatter_response}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp)
self.assertEqual(rv.status_code, 200)
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
def test_format_sql_request(self):
self.login(ADMIN_USERNAME)
@ -283,8 +283,8 @@ class TestSqlLabApi(SupersetTestCase):
)
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp)
self.assertEqual(rv.status_code, 200)
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
@ -303,8 +303,8 @@ class TestSqlLabApi(SupersetTestCase):
}
}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
data = {"sql": "SELECT 1", "client_id": client_id}
rv = self.client.post(
@ -313,8 +313,8 @@ class TestSqlLabApi(SupersetTestCase):
)
failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
data = {"database_id": 1, "client_id": client_id}
rv = self.client.post(
@ -323,8 +323,8 @@ class TestSqlLabApi(SupersetTestCase):
)
failed_resp = {"message": {"sql": ["Missing data for required field."]}}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)
self.assertDictEqual(resp_data, failed_resp) # noqa: PT009
assert rv.status_code == 400
@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
def test_execute_valid_request(self) -> None:
@ -342,8 +342,8 @@ class TestSqlLabApi(SupersetTestCase):
json=data,
)
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(resp_data.get("status"), "success")
self.assertEqual(rv.status_code, 200)
assert resp_data.get("status") == "success"
assert rv.status_code == 200
@mock.patch(
"tests.integration_tests.superset_test_custom_template_processors.datetime"
@ -366,7 +366,7 @@ class TestSqlLabApi(SupersetTestCase):
"/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload
)
assert sql_lab_mock.called
self.assertEqual(sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test")
assert sql_lab_mock.call_args[0][1] == "SELECT '1970-01-01' as test"
self.delete_fake_db_for_macros()
@ -419,8 +419,8 @@ class TestSqlLabApi(SupersetTestCase):
self.get_resp(f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}")
)
self.assertEqual(result_key, expected_key)
self.assertEqual(result_limited, expected_limited)
assert result_key == expected_key
assert result_limited == expected_limited
app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack
@ -454,6 +454,6 @@ class TestSqlLabApi(SupersetTestCase):
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO("foo\n1\n2"))
self.assertEqual(list(expected_data), list(data))
assert list(expected_data) == list(data)
db.session.delete(query_obj)
db.session.commit()

View File

@ -58,7 +58,7 @@ class TestPrestoValidator(SupersetTestCase):
errors = self.validator.validate(sql, None, schema, self.database)
self.assertEqual([], errors)
assert [] == errors
@patch("superset.utils.core.g")
def test_validator_db_error(self, flask_g):
@ -95,7 +95,7 @@ class TestPrestoValidator(SupersetTestCase):
errors = self.validator.validate(sql, None, schema, self.database)
self.assertEqual(1, len(errors))
assert 1 == len(errors)
class TestPostgreSQLValidator(SupersetTestCase):

View File

@ -83,12 +83,12 @@ class TestDatabaseModel(SupersetTestCase):
database = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
tbl = SqlaTable(table_name="druid_tbl", database=database)
col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
self.assertEqual(col.is_dttm, None)
assert col.is_dttm is None
DruidEngineSpec.alter_new_orm_column(col)
self.assertEqual(col.is_dttm, True)
assert col.is_dttm is True
col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
self.assertEqual(col.is_temporal, False)
assert col.is_temporal is False
def test_temporal_varchar(self):
"""Ensure a column with is_dttm set to true evaluates to is_temporal == True"""
@ -125,13 +125,13 @@ class TestDatabaseModel(SupersetTestCase):
tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
for str_type, db_col_type in test_cases.items():
col = TableColumn(column_name="foo", type=str_type, table=tbl)
self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)
assert col.is_temporal == (db_col_type == GenericDataType.TEMPORAL)
assert col.is_numeric == (db_col_type == GenericDataType.NUMERIC)
assert col.is_string == (db_col_type == GenericDataType.STRING)
for str_type, db_col_type in test_cases.items():
col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True)
self.assertTrue(col.is_temporal)
assert col.is_temporal
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@ -161,7 +161,7 @@ class TestDatabaseModel(SupersetTestCase):
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
assert table1.has_extra_cache_key_calls(query_obj)
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}
# Table with Jinja callable disabled.
@ -177,8 +177,8 @@ class TestDatabaseModel(SupersetTestCase):
)
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table2.get_extra_cache_keys(query_obj)
self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
self.assertListEqual(extra_cache_keys, [])
assert table2.has_extra_cache_key_calls(query_obj)
self.assertListEqual(extra_cache_keys, []) # noqa: PT009
# Table with no Jinja callable.
query = "SELECT 'abc' as user"
@ -190,15 +190,15 @@ class TestDatabaseModel(SupersetTestCase):
query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
self.assertListEqual(extra_cache_keys, [])
assert not table3.has_extra_cache_key_calls(query_obj)
self.assertListEqual(extra_cache_keys, []) # noqa: PT009
# With Jinja callable in SQL expression.
query_obj = dict(
**base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
)
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
assert table3.has_extra_cache_key_calls(query_obj)
assert extra_cache_keys == ["abc"]
@patch("superset.jinja_context.get_username", return_value="abc")
@ -393,11 +393,9 @@ class TestDatabaseModel(SupersetTestCase):
sqla_query = table.get_sqla_query(**query_obj)
sql = table.database.compile_sqla_query(sqla_query.sqla_query)
if isinstance(filter_.expected, list):
self.assertTrue(
any([candidate in sql for candidate in filter_.expected])
)
assert any([candidate in sql for candidate in filter_.expected])
else:
self.assertIn(filter_.expected, sql)
assert filter_.expected in sql
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_boolean_type_where_operators(self):
@ -434,7 +432,7 @@ class TestDatabaseModel(SupersetTestCase):
# https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/base.py
if not dialect.supports_native_boolean and dialect.name != "mysql":
operand = "(1, 0)"
self.assertIn(f"IN {operand}", sql)
assert f"IN {operand}" in sql
def test_incorrect_jinja_syntax_raises_correct_exception(self):
query_obj = {

View File

@ -91,7 +91,7 @@ class TestSqlLab(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
self.assertLess(0, len(data["data"]))
assert 0 < len(data["data"])
data = self.run_sql("SELECT * FROM nonexistent_table", "2")
if backend() == "presto":
@ -220,8 +220,8 @@ class TestSqlLab(SupersetTestCase):
names_count = engine.execute(
f"SELECT COUNT(*) FROM birth_names" # noqa: F541
).first()
self.assertEqual(
names_count[0], len(data)
assert names_count[0] == len(
data
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
# cleanup
@ -238,14 +238,14 @@ class TestSqlLab(SupersetTestCase):
SELECT * FROM birth_names LIMIT 2;
"""
data = self.run_sql(multi_sql, "2234")
self.assertLess(0, len(data["data"]))
assert 0 < len(data["data"])
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_explain(self):
self.login(ADMIN_USERNAME)
data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
self.assertLess(0, len(data["data"]))
assert 0 < len(data["data"])
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_sql_json_has_access(self):
@ -261,21 +261,21 @@ class TestSqlLab(SupersetTestCase):
data = self.run_sql(QUERY_1, "1", username="Gagarin")
db.session.query(Query).delete()
db.session.commit()
self.assertLess(0, len(data["data"]))
assert 0 < len(data["data"])
def test_sqllab_has_access(self):
for username in (ADMIN_USERNAME, GAMMA_SQLLAB_USERNAME):
self.login(username)
for endpoint in ("/sqllab/", "/sqllab/history/"):
resp = self.client.get(endpoint)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
def test_sqllab_no_access(self):
self.login(GAMMA_USERNAME)
for endpoint in ("/sqllab/", "/sqllab/history/"):
resp = self.client.get(endpoint)
# Redirects to the main page
self.assertEqual(302, resp.status_code)
assert 302 == resp.status_code
def test_sql_json_schema_access(self):
examples_db = get_example_database()
@ -311,7 +311,7 @@ class TestSqlLab(SupersetTestCase):
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
)
self.assertEqual(1, len(data["data"]))
assert 1 == len(data["data"])
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
@ -319,7 +319,7 @@ class TestSqlLab(SupersetTestCase):
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
assert 1 == len(data["data"])
# postgres needs a schema as a part of the table name.
if db_backend == "mysql":
@ -329,7 +329,7 @@ class TestSqlLab(SupersetTestCase):
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
assert 1 == len(data["data"])
db.session.query(Query).delete()
with get_example_database().get_sqla_engine() as engine:
@ -349,77 +349,75 @@ class TestSqlLab(SupersetTestCase):
data = [["a", 4, 4.0]]
results = SupersetResultSet(data, cols, BaseEngineSpec)
self.assertEqual(len(data), results.size)
self.assertEqual(len(cols), len(results.columns))
assert len(data) == results.size
assert len(cols) == len(results.columns)
def test_pa_conversion_tuple(self):
cols = ["string_col", "int_col", "list_col", "float_col"]
data = [("Text", 111, [123], 1.0)]
results = SupersetResultSet(data, cols, BaseEngineSpec)
self.assertEqual(len(data), results.size)
self.assertEqual(len(cols), len(results.columns))
assert len(data) == results.size
assert len(cols) == len(results.columns)
def test_pa_conversion_dict(self):
cols = ["string_col", "dict_col", "int_col"]
data = [["a", {"c1": 1, "c2": 2, "c3": 3}, 4]]
results = SupersetResultSet(data, cols, BaseEngineSpec)
self.assertEqual(len(data), results.size)
self.assertEqual(len(cols), len(results.columns))
assert len(data) == results.size
assert len(cols) == len(results.columns)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_sql_limit(self):
self.login(ADMIN_USERNAME)
test_limit = 1
data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
self.assertGreater(len(data["data"]), test_limit)
assert len(data["data"]) > test_limit
data = self.run_sql(
"SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
)
self.assertEqual(len(data["data"]), test_limit)
assert len(data["data"]) == test_limit
data = self.run_sql(
f"SELECT * FROM birth_names LIMIT {test_limit}",
client_id="sql_limit_3",
query_limit=test_limit + 1,
)
self.assertEqual(len(data["data"]), test_limit)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY)
assert len(data["data"]) == test_limit
assert data["query"]["limitingFactor"] == LimitingFactor.QUERY
data = self.run_sql(
f"SELECT * FROM birth_names LIMIT {test_limit + 1}",
client_id="sql_limit_4",
query_limit=test_limit,
)
self.assertEqual(len(data["data"]), test_limit)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN)
assert len(data["data"]) == test_limit
assert data["query"]["limitingFactor"] == LimitingFactor.DROPDOWN
data = self.run_sql(
f"SELECT * FROM birth_names LIMIT {test_limit}",
client_id="sql_limit_5",
query_limit=test_limit,
)
self.assertEqual(len(data["data"]), test_limit)
self.assertEqual(
data["query"]["limitingFactor"], LimitingFactor.QUERY_AND_DROPDOWN
)
assert len(data["data"]) == test_limit
assert data["query"]["limitingFactor"] == LimitingFactor.QUERY_AND_DROPDOWN
data = self.run_sql(
"SELECT * FROM birth_names",
client_id="sql_limit_6",
query_limit=10000,
)
self.assertEqual(len(data["data"]), 1200)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
assert len(data["data"]) == 1200
assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED
data = self.run_sql(
"SELECT * FROM birth_names",
client_id="sql_limit_7",
query_limit=1200,
)
self.assertEqual(len(data["data"]), 1200)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
assert len(data["data"]) == 1200
assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_filter(self) -> None:
@ -434,7 +432,7 @@ class TestSqlLab(SupersetTestCase):
data = self.get_json_resp(url)
admin = security_manager.find_user("admin")
gamma_sqllab = security_manager.find_user("gamma_sqllab")
self.assertEqual(3, len(data["result"]))
assert 3 == len(data["result"])
user_queries = [
result.get("user").get("first_name") for result in data["result"]
]
@ -461,7 +459,7 @@ class TestSqlLab(SupersetTestCase):
self.login(GAMMA_SQLLAB_USERNAME)
url = "/api/v1/query/"
data = self.get_json_resp(url)
self.assertEqual(3, len(data["result"]))
assert 3 == len(data["result"])
# Remove all_query_access from gamma sqllab
all_queries_view = security_manager.find_permission_view_menu(
@ -521,10 +519,9 @@ class TestSqlLab(SupersetTestCase):
]
}
url = f"/api/v1/query/?q={prison.dumps(arguments)}"
self.assertEqual(
{"SELECT 1", "SELECT 2"},
{r.get("sql") for r in self.get_json_resp(url)["result"]},
)
assert {"SELECT 1", "SELECT 2"} == {
r.get("sql") for r in self.get_json_resp(url)["result"]
}
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_admin_can_access_all_queries(self) -> None:
@ -537,7 +534,7 @@ class TestSqlLab(SupersetTestCase):
url = "/api/v1/query/"
data = self.get_json_resp(url)
self.assertEqual(3, len(data["result"]))
assert 3 == len(data["result"])
def test_api_database(self):
self.login(ADMIN_USERNAME)
@ -555,10 +552,9 @@ class TestSqlLab(SupersetTestCase):
}
url = f"api/v1/database/?q={prison.dumps(arguments)}"
self.assertEqual(
{"examples", "fake_db_100", "main"},
{r.get("database_name") for r in self.get_json_resp(url)["result"]},
)
assert {"examples", "fake_db_100", "main"} == {
r.get("database_name") for r in self.get_json_resp(url)["result"]
}
self.delete_fake_db()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")

View File

@ -86,7 +86,7 @@ class TestCacheWarmUp(SupersetTestCase):
expected = [
{"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices
]
self.assertCountEqual(result, expected)
self.assertCountEqual(result, expected) # noqa: PT009
def reset_tag(self, tag):
"""Remove associated object from tag, used to reset tests"""
@ -106,7 +106,7 @@ class TestCacheWarmUp(SupersetTestCase):
strategy = DashboardTagsStrategy(["tag1"])
result = strategy.get_payloads()
expected = []
self.assertEqual(result, expected)
assert result == expected
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagType.custom)
@ -118,7 +118,7 @@ class TestCacheWarmUp(SupersetTestCase):
db.session.add(tagged_object)
db.session.commit()
self.assertCountEqual(strategy.get_payloads(), tag1_urls)
self.assertCountEqual(strategy.get_payloads(), tag1_urls) # noqa: PT009
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagType.custom)
@ -126,7 +126,7 @@ class TestCacheWarmUp(SupersetTestCase):
result = strategy.get_payloads()
expected = []
self.assertEqual(result, expected)
assert result == expected
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
@ -140,10 +140,10 @@ class TestCacheWarmUp(SupersetTestCase):
db.session.commit()
result = strategy.get_payloads()
self.assertCountEqual(result, tag2_urls)
self.assertCountEqual(result, tag2_urls) # noqa: PT009
strategy = DashboardTagsStrategy(["tag1", "tag2"])
result = strategy.get_payloads()
expected = tag1_urls + tag2_urls
self.assertCountEqual(result, expected)
self.assertCountEqual(result, expected) # noqa: PT009

View File

@ -55,7 +55,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a dataset and add it to the db
test_dataset = SqlaTable(
@ -71,16 +71,16 @@ class TestTagging(SupersetTestCase):
# Test to make sure that a dataset tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectType.dataset", str(tags[0].object_type))
self.assertEqual(test_dataset.id, tags[0].object_id)
assert 1 == len(tags)
assert "ObjectType.dataset" == str(tags[0].object_type)
assert test_dataset.id == tags[0].object_id
# Cleanup the db
db.session.delete(test_dataset)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_chart_tagging(self):
@ -94,7 +94,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a chart and add it to the db
test_chart = Slice(
@ -109,16 +109,16 @@ class TestTagging(SupersetTestCase):
# Test to make sure that a chart tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectType.chart", str(tags[0].object_type))
self.assertEqual(test_chart.id, tags[0].object_id)
assert 1 == len(tags)
assert "ObjectType.chart" == str(tags[0].object_type)
assert test_chart.id == tags[0].object_id
# Cleanup the db
db.session.delete(test_chart)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_dashboard_tagging(self):
@ -132,7 +132,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a dashboard and add it to the db
test_dashboard = Dashboard()
@ -145,16 +145,16 @@ class TestTagging(SupersetTestCase):
# Test to make sure that a dashboard tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectType.dashboard", str(tags[0].object_type))
self.assertEqual(test_dashboard.id, tags[0].object_id)
assert 1 == len(tags)
assert "ObjectType.dashboard" == str(tags[0].object_type)
assert test_dashboard.id == tags[0].object_id
# Cleanup the db
db.session.delete(test_dashboard)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_saved_query_tagging(self):
@ -168,7 +168,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a saved query and add it to the db
test_saved_query = SavedQuery(id=1, label="test saved query")
@ -178,24 +178,24 @@ class TestTagging(SupersetTestCase):
# Test to make sure that a saved query tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(2, len(tags))
assert 2 == len(tags)
self.assertEqual("ObjectType.query", str(tags[0].object_type))
self.assertEqual("owner:None", str(tags[0].tag.name))
self.assertEqual("TagType.owner", str(tags[0].tag.type))
self.assertEqual(test_saved_query.id, tags[0].object_id)
assert "ObjectType.query" == str(tags[0].object_type)
assert "owner:None" == str(tags[0].tag.name)
assert "TagType.owner" == str(tags[0].tag.type)
assert test_saved_query.id == tags[0].object_id
self.assertEqual("ObjectType.query", str(tags[1].object_type))
self.assertEqual("type:query", str(tags[1].tag.name))
self.assertEqual("TagType.type", str(tags[1].tag.type))
self.assertEqual(test_saved_query.id, tags[1].object_id)
assert "ObjectType.query" == str(tags[1].object_type)
assert "type:query" == str(tags[1].tag.name)
assert "TagType.type" == str(tags[1].tag.type)
assert test_saved_query.id == tags[1].object_id
# Cleanup the db
db.session.delete(test_saved_query)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
@pytest.mark.usefixtures("with_tagging_system_feature")
def test_favorite_tagging(self):
@ -209,7 +209,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a favorited object and add it to the db
test_saved_query = FavStar(user_id=1, class_name="slice", obj_id=1)
@ -218,16 +218,16 @@ class TestTagging(SupersetTestCase):
# Test to make sure that a favorited object tag was added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(1, len(tags))
self.assertEqual("ObjectType.chart", str(tags[0].object_type))
self.assertEqual(test_saved_query.obj_id, tags[0].object_id)
assert 1 == len(tags)
assert "ObjectType.chart" == str(tags[0].object_type)
assert test_saved_query.obj_id == tags[0].object_id
# Cleanup the db
db.session.delete(test_saved_query)
db.session.commit()
# Test to make sure the tag is deleted when the associated object is deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
@with_feature_flags(TAGGING_SYSTEM=False)
def test_tagging_system(self):
@ -240,7 +240,7 @@ class TestTagging(SupersetTestCase):
self.clear_tagged_object_table()
# Test to make sure nothing is in the tagged_object table
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()
# Create a dataset and add it to the db
test_dataset = SqlaTable(
@ -282,7 +282,7 @@ class TestTagging(SupersetTestCase):
# Test to make sure that no tags were added to the tagged_object table
tags = self.query_tagged_object_table()
self.assertEqual(0, len(tags))
assert 0 == len(tags)
# Cleanup the db
db.session.delete(test_dataset)
@ -293,4 +293,4 @@ class TestTagging(SupersetTestCase):
db.session.commit()
# Test to make sure all the tags are deleted when the associated objects are deleted
self.assertEqual([], self.query_tagged_object_table())
assert [] == self.query_tagged_object_table()

View File

@ -135,7 +135,7 @@ class TestTagApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/tag/{tag.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
expected_result = {
"changed_by": None,
"changed_on_delta_humanized": "now",
@ -146,7 +146,7 @@ class TestTagApi(SupersetTestCase):
}
data = json.loads(rv.data.decode("utf-8"))
for key, value in expected_result.items():
self.assertEqual(value, data["result"][key])
assert value == data["result"][key]
# rollback changes
db.session.delete(tag)
db.session.commit()
@ -160,7 +160,7 @@ class TestTagApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/tag/{max_id + 1}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
# cleanup
db.session.delete(tag)
db.session.commit()
@ -173,7 +173,7 @@ class TestTagApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/tag/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == TAGS_FIXTURE_COUNT
# check expected columns
@ -211,7 +211,7 @@ class TestTagApi(SupersetTestCase):
}
uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 2
@ -219,7 +219,7 @@ class TestTagApi(SupersetTestCase):
query["filters"][0]["value"] = False
uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 3
@ -249,10 +249,10 @@ class TestTagApi(SupersetTestCase):
data = {"properties": {"tags": example_tag_names}}
rv = self.client.post(uri, json=data, follow_redirects=True)
# successful request
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
# check that tags were created in database
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
self.assertEqual(tags.count(), 2)
assert tags.count() == 2
# check that tagged objects were created
tag_ids = [tags[0].id, tags[1].id]
tagged_objects = db.session.query(TaggedObject).filter(
@ -308,7 +308,7 @@ class TestTagApi(SupersetTestCase):
uri = f"api/v1/tag/{dashboard_type.value}/{dashboard_id}/{tags.first().name}"
rv = self.client.delete(uri, follow_redirects=True)
# successful request
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# ensure that tagged object no longer exists
tagged_object = (
db.session.query(TaggedObject)
@ -358,14 +358,14 @@ class TestTagApi(SupersetTestCase):
TaggedObject.object_id == dashboard_id,
TaggedObject.object_type == dashboard_type.name,
)
self.assertEqual(tagged_objects.count(), 2)
assert tagged_objects.count() == 2
uri = f'api/v1/tag/get_objects/?tags={",".join(tag_names)}'
rv = self.client.get(uri)
# successful request
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
fetched_objects = rv.json["result"]
self.assertEqual(len(fetched_objects), 1)
self.assertEqual(fetched_objects[0]["id"], dashboard_id)
assert len(fetched_objects) == 1
assert fetched_objects[0]["id"] == dashboard_id
# clean up tagged object
tagged_objects.delete()
@ -394,12 +394,12 @@ class TestTagApi(SupersetTestCase):
TaggedObject.object_id == dashboard_id,
TaggedObject.object_type == dashboard_type.name,
)
self.assertEqual(tagged_objects.count(), 2)
self.assertEqual(tagged_objects.first().object_id, dashboard_id)
assert tagged_objects.count() == 2
assert tagged_objects.first().object_id == dashboard_id
uri = "api/v1/tag/get_objects/"
rv = self.client.get(uri)
# successful request
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
fetched_objects = rv.json["result"]
# check that the dashboard object was fetched
assert dashboard_id in [obj["id"] for obj in fetched_objects]
@ -413,25 +413,25 @@ class TestTagApi(SupersetTestCase):
# check that tags exist in the database
example_tag_names = ["example_tag_1", "example_tag_2", "example_tag_3"]
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
self.assertEqual(tags.count(), 3)
assert tags.count() == 3
# delete the first tag
uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[:1])}"
rv = self.client.delete(uri, follow_redirects=True)
# successful request
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# check that tag does not exist in the database
tag = db.session.query(Tag).filter(Tag.name == example_tag_names[0]).first()
assert tag is None
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
self.assertEqual(tags.count(), 2)
assert tags.count() == 2
# delete multiple tags
uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[1:])}"
rv = self.client.delete(uri, follow_redirects=True)
# successful request
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# check that tags are all gone
tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
self.assertEqual(tags.count(), 0)
assert tags.count() == 0
@pytest.mark.usefixtures("create_tags")
def test_delete_favorite_tag(self):
@ -442,7 +442,7 @@ class TestTagApi(SupersetTestCase):
tag = db.session.query(Tag).first()
rv = self.client.post(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
from sqlalchemy import and_ # noqa: F811
from superset.tags.models import user_favorite_tag_table # noqa: F811
from flask import g # noqa: F401, F811
@ -463,7 +463,7 @@ class TestTagApi(SupersetTestCase):
uri = f"api/v1/tag/{tag.id}/favorites/"
rv = self.client.delete(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
association_row = (
db.session.query(user_favorite_tag_table)
.filter(
@ -483,7 +483,7 @@ class TestTagApi(SupersetTestCase):
uri = "api/v1/tag/123/favorites/" # noqa: F541
rv = self.client.post(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_tags")
def test_delete_favorite_tag_not_found(self):
@ -491,7 +491,7 @@ class TestTagApi(SupersetTestCase):
uri = "api/v1/tag/123/favorites/" # noqa: F541
rv = self.client.delete(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_tags")
@patch("superset.daos.tag.g")
@ -501,7 +501,7 @@ class TestTagApi(SupersetTestCase):
uri = "api/v1/tag/123/favorites/" # noqa: F541
rv = self.client.post(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
@pytest.mark.usefixtures("create_tags")
@patch("superset.daos.tag.g")
@ -511,7 +511,7 @@ class TestTagApi(SupersetTestCase):
uri = "api/v1/tag/123/favorites/" # noqa: F541
rv = self.client.delete(uri, follow_redirects=True)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_post_tag(self):
@ -527,7 +527,7 @@ class TestTagApi(SupersetTestCase):
json={"name": "my_tag", "objects_to_tag": [["dashboard", dashboard.id]]},
)
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
self.get_user(username="admin").get_id() # noqa: F841
tag = (
db.session.query(Tag)
@ -550,7 +550,7 @@ class TestTagApi(SupersetTestCase):
json={"name": "", "objects_to_tag": [["dashboard", dashboard.id]]},
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@pytest.mark.usefixtures("create_tags")
@ -563,7 +563,7 @@ class TestTagApi(SupersetTestCase):
uri, json={"name": "new_name", "description": "new description"}
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
tag = (
db.session.query(Tag)
@ -581,7 +581,7 @@ class TestTagApi(SupersetTestCase):
uri = f"api/v1/tag/{tag_to_update.id}"
rv = self.client.put(uri, json={"foo": "bar"})
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_post_bulk_tag(self):
@ -617,7 +617,7 @@ class TestTagApi(SupersetTestCase):
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
result = TagDAO.get_tagged_objects_for_tags(tags, ["dashboard"])
assert len(result) == 1
@ -686,7 +686,7 @@ class TestTagApi(SupersetTestCase):
},
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
result = rv.json["result"]
assert len(result["objects_tagged"]) == 2
assert len(result["objects_skipped"]) == 1

View File

@ -73,7 +73,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
"admin",
thumbnail_url,
)
self.assertEqual(response.getcode(), 202)
assert response.getcode() == 202
class TestWebDriverScreenshotErrorDetector(SupersetTestCase):
@ -217,7 +217,7 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=False)
@ -228,7 +228,7 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -255,7 +255,7 @@ class TestThumbnails(SupersetTestCase):
assert mock_adjust_string.call_args[0][2] == "admin"
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 202)
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -283,7 +283,7 @@ class TestThumbnails(SupersetTestCase):
assert mock_adjust_string.call_args[0][2] == username
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 202)
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -295,7 +295,7 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/dashboard/{max_id + 1}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@ -306,7 +306,7 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -333,7 +333,7 @@ class TestThumbnails(SupersetTestCase):
assert mock_adjust_string.call_args[0][2] == "admin"
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 202)
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -361,7 +361,7 @@ class TestThumbnails(SupersetTestCase):
assert mock_adjust_string.call_args[0][2] == username
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 202)
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -373,7 +373,7 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/{max_id + 1}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -387,8 +387,8 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
rv = self.client.get(f"api/v1/chart/{id_}/thumbnail/1234/")
self.assertEqual(rv.status_code, 302)
self.assertEqual(rv.headers["Location"], thumbnail_url)
assert rv.status_code == 302
assert rv.headers["Location"] == thumbnail_url
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -402,8 +402,8 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
assert rv.status_code == 200
assert rv.data == self.mock_image
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -417,8 +417,8 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
rv = self.client.get(thumbnail_url)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
assert rv.status_code == 200
assert rv.data == self.mock_image
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -432,5 +432,5 @@ class TestThumbnails(SupersetTestCase):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(f"api/v1/dashboard/{id_}/thumbnail/1234/")
self.assertEqual(rv.status_code, 302)
self.assertEqual(rv.headers["Location"], thumbnail_url)
assert rv.status_code == 302
assert rv.headers["Location"] == thumbnail_url

View File

@ -35,36 +35,36 @@ class TestCurrentUserApi(SupersetTestCase):
rv = self.client.get(meUri)
self.assertEqual(200, rv.status_code)
assert 200 == rv.status_code
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual("admin", response["result"]["username"])
self.assertEqual(True, response["result"]["is_active"])
self.assertEqual(False, response["result"]["is_anonymous"])
assert "admin" == response["result"]["username"]
assert True is response["result"]["is_active"]
assert False is response["result"]["is_anonymous"]
def test_get_me_with_roles(self):
self.login(ADMIN_USERNAME)
rv = self.client.get(meUri + "roles/")
self.assertEqual(200, rv.status_code)
assert 200 == rv.status_code
response = json.loads(rv.data.decode("utf-8"))
roles = list(response["result"]["roles"].keys())
self.assertEqual("Admin", roles.pop())
assert "Admin" == roles.pop()
@patch("superset.security.manager.g")
def test_get_my_roles_anonymous(self, mock_g):
mock_g.user = security_manager.get_anonymous_user
rv = self.client.get(meUri + "roles/")
self.assertEqual(401, rv.status_code)
assert 401 == rv.status_code
def test_get_me_unauthorized(self):
rv = self.client.get(meUri)
self.assertEqual(401, rv.status_code)
assert 401 == rv.status_code
@patch("superset.security.manager.g")
def test_get_me_anonymous(self, mock_g):
mock_g.user = security_manager.get_anonymous_user
rv = self.client.get(meUri)
self.assertEqual(401, rv.status_code)
assert 401 == rv.status_code
class TestUserApi(SupersetTestCase):

View File

@ -53,8 +53,8 @@ class EncryptedFieldTest(SupersetTestCase):
def test_create_field(self):
field = encrypted_field_factory.create(String(1024))
self.assertTrue(isinstance(field, EncryptedType))
self.assertEqual(self.app.config["SECRET_KEY"], field.key)
assert isinstance(field, EncryptedType)
assert self.app.config["SECRET_KEY"] == field.key
def test_custom_adapter(self):
self.app.config["SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"] = (
@ -62,10 +62,10 @@ class EncryptedFieldTest(SupersetTestCase):
)
encrypted_field_factory.init_app(self.app)
field = encrypted_field_factory.create(String(1024))
self.assertTrue(isinstance(field, StringEncryptedType))
self.assertFalse(isinstance(field, EncryptedType))
self.assertTrue(getattr(field, "__created_by_enc_field_adapter__"))
self.assertEqual(self.app.config["SECRET_KEY"], field.key)
assert isinstance(field, StringEncryptedType)
assert not isinstance(field, EncryptedType)
assert getattr(field, "__created_by_enc_field_adapter__")
assert self.app.config["SECRET_KEY"] == field.key
def test_ensure_encrypted_field_factory_is_used(self):
"""

View File

@ -25,7 +25,7 @@ class MachineAuthProviderTests(SupersetTestCase):
def test_get_auth_cookies(self):
user = self.get_user("admin")
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user)
self.assertIsNotNone(auth_cookies["session"])
assert auth_cookies["session"] is not None
@patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies")
def test_auth_driver_user(self, get_auth_cookies):

View File

@ -132,14 +132,14 @@ class TestUtils(SupersetTestCase):
json_str = '{"test": 1}'
blob = zlib_compress(json_str)
got_str = zlib_decompress(blob)
self.assertEqual(json_str, got_str)
assert json_str == got_str
def test_merge_extra_filters(self):
# does nothing if no extra filters
form_data = {"A": 1, "B": 2, "c": "test"}
expected = {**form_data, "adhoc_filters": [], "applied_time_extras": {}}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
# empty extra_filters
form_data = {"A": 1, "B": 2, "c": "test", "extra_filters": []}
expected = {
@ -150,7 +150,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
# copy over extra filters into empty filters
form_data = {
"extra_filters": [
@ -182,7 +182,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
# adds extra filters to existing filters
form_data = {
"extra_filters": [
@ -230,7 +230,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
# adds extra filters to existing filters and sets time options
form_data = {
"extra_filters": [
@ -262,7 +262,7 @@ class TestUtils(SupersetTestCase):
},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_ignores_empty_filters(self):
form_data = {
@ -273,7 +273,7 @@ class TestUtils(SupersetTestCase):
}
expected = {"adhoc_filters": [], "applied_time_extras": {}}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_ignores_nones(self):
form_data = {
@ -301,7 +301,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_ignores_equal_filters(self):
form_data = {
@ -361,7 +361,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_merges_different_val_types(self):
form_data = {
@ -415,7 +415,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
form_data = {
"extra_filters": [
{"col": "a", "op": "in", "val": "someval"},
@ -467,7 +467,7 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_adds_unequal_lists(self):
form_data = {
@ -530,27 +530,24 @@ class TestUtils(SupersetTestCase):
"applied_time_extras": {},
}
merge_extra_filters(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_merge_extra_filters_when_applied_time_extras_predefined(self):
form_data = {"applied_time_extras": {"__time_range": "Last week"}}
merge_extra_filters(form_data)
self.assertEqual(
form_data,
{
"applied_time_extras": {"__time_range": "Last week"},
"adhoc_filters": [],
},
)
assert form_data == {
"applied_time_extras": {"__time_range": "Last week"},
"adhoc_filters": [],
}
def test_merge_request_params_when_url_params_undefined(self):
form_data = {"since": "2000", "until": "now"}
url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
merge_request_params(form_data, url_params)
self.assertIn("url_params", form_data.keys())
self.assertIn("dashboard_ids", form_data["url_params"])
self.assertNotIn("form_data", form_data.keys())
assert "url_params" in form_data.keys()
assert "dashboard_ids" in form_data["url_params"]
assert "form_data" not in form_data.keys()
def test_merge_request_params_when_url_params_predefined(self):
form_data = {
@ -560,30 +557,26 @@ class TestUtils(SupersetTestCase):
}
url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
merge_request_params(form_data, url_params)
self.assertIn("url_params", form_data.keys())
self.assertIn("abc", form_data["url_params"])
self.assertEqual(
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
)
assert "url_params" in form_data.keys()
assert "abc" in form_data["url_params"]
assert url_params["dashboard_ids"] == form_data["url_params"]["dashboard_ids"]
def test_format_timedelta(self):
self.assertEqual(json.format_timedelta(timedelta(0)), "0:00:00")
self.assertEqual(json.format_timedelta(timedelta(days=1)), "1 day, 0:00:00")
self.assertEqual(json.format_timedelta(timedelta(minutes=-6)), "-0:06:00")
self.assertEqual(
json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6)),
"-1 day, 5:06:00",
assert json.format_timedelta(timedelta(0)) == "0:00:00"
assert json.format_timedelta(timedelta(days=1)) == "1 day, 0:00:00"
assert json.format_timedelta(timedelta(minutes=-6)) == "-0:06:00"
assert (
json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6))
== "-1 day, 5:06:00"
)
self.assertEqual(
json.format_timedelta(
timedelta(0) - timedelta(days=16, hours=4, minutes=3)
),
"-16 days, 4:03:00",
assert (
json.format_timedelta(timedelta(0) - timedelta(days=16, hours=4, minutes=3))
== "-16 days, 4:03:00"
)
def test_validate_json(self):
valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
self.assertIsNone(json.validate_json(valid))
assert json.validate_json(valid) is None
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
with self.assertRaises(json.JSONDecodeError):
json.validate_json(invalid)
@ -601,7 +594,7 @@ class TestUtils(SupersetTestCase):
]
}
convert_legacy_filters_into_adhoc(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_convert_legacy_filters_into_adhoc_filters(self):
form_data = {"filters": [{"col": "a", "op": "in", "val": "someval"}]}
@ -618,7 +611,7 @@ class TestUtils(SupersetTestCase):
]
}
convert_legacy_filters_into_adhoc(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_convert_legacy_filters_into_adhoc_present_and_empty(self):
form_data = {"adhoc_filters": [], "where": "a = 1"}
@ -633,7 +626,7 @@ class TestUtils(SupersetTestCase):
]
}
convert_legacy_filters_into_adhoc(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_convert_legacy_filters_into_adhoc_having(self):
form_data = {"having": "COUNT(1) = 1"}
@ -648,7 +641,7 @@ class TestUtils(SupersetTestCase):
]
}
convert_legacy_filters_into_adhoc(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_convert_legacy_filters_into_adhoc_present_and_nonempty(self):
form_data = {
@ -664,23 +657,23 @@ class TestUtils(SupersetTestCase):
]
}
convert_legacy_filters_into_adhoc(form_data)
self.assertEqual(form_data, expected)
assert form_data == expected
def test_parse_js_uri_path_items_eval_undefined(self):
self.assertIsNone(parse_js_uri_path_item("undefined", eval_undefined=True))
self.assertIsNone(parse_js_uri_path_item("null", eval_undefined=True))
self.assertEqual("undefined", parse_js_uri_path_item("undefined"))
self.assertEqual("null", parse_js_uri_path_item("null"))
assert parse_js_uri_path_item("undefined", eval_undefined=True) is None
assert parse_js_uri_path_item("null", eval_undefined=True) is None
assert "undefined" == parse_js_uri_path_item("undefined")
assert "null" == parse_js_uri_path_item("null")
def test_parse_js_uri_path_items_unquote(self):
self.assertEqual("slashed/name", parse_js_uri_path_item("slashed%2fname"))
self.assertEqual(
"slashed%2fname", parse_js_uri_path_item("slashed%2fname", unquote=False)
assert "slashed/name" == parse_js_uri_path_item("slashed%2fname")
assert "slashed%2fname" == parse_js_uri_path_item(
"slashed%2fname", unquote=False
)
def test_parse_js_uri_path_items_item_optional(self):
self.assertIsNone(parse_js_uri_path_item(None))
self.assertIsNotNone(parse_js_uri_path_item("item"))
assert parse_js_uri_path_item(None) is None
assert parse_js_uri_path_item("item") is not None
def test_get_stacktrace(self):
app.config["SHOW_STACKTRACE"] = True
@ -688,7 +681,7 @@ class TestUtils(SupersetTestCase):
raise Exception("NONONO!")
except Exception:
stacktrace = get_stacktrace()
self.assertIn("NONONO", stacktrace)
assert "NONONO" in stacktrace
app.config["SHOW_STACKTRACE"] = False
try:
@ -698,31 +691,31 @@ class TestUtils(SupersetTestCase):
assert stacktrace is None
def test_split(self):
self.assertEqual(list(split("a b")), ["a", "b"])
self.assertEqual(list(split("a,b", delimiter=",")), ["a", "b"])
self.assertEqual(list(split("a,(b,a)", delimiter=",")), ["a", "(b,a)"])
self.assertEqual(
list(split('a,(b,a),"foo , bar"', delimiter=",")),
["a", "(b,a)", '"foo , bar"'],
)
self.assertEqual(
list(split("a,'b,c'", delimiter=",", quote="'")), ["a", "'b,c'"]
)
self.assertEqual(list(split('a "b c"')), ["a", '"b c"'])
self.assertEqual(list(split(r'a "b \" c"')), ["a", r'"b \" c"'])
assert list(split("a b")) == ["a", "b"]
assert list(split("a,b", delimiter=",")) == ["a", "b"]
assert list(split("a,(b,a)", delimiter=",")) == ["a", "(b,a)"]
assert list(split('a,(b,a),"foo , bar"', delimiter=",")) == [
"a",
"(b,a)",
'"foo , bar"',
]
assert list(split("a,'b,c'", delimiter=",", quote="'")) == ["a", "'b,c'"]
assert list(split('a "b c"')) == ["a", '"b c"']
assert list(split('a "b \\" c"')) == ["a", '"b \\" c"']
def test_get_or_create_db(self):
get_or_create_db("test_db", "sqlite:///superset.db")
database = db.session.query(Database).filter_by(database_name="test_db").one()
self.assertIsNotNone(database)
self.assertEqual(database.sqlalchemy_uri, "sqlite:///superset.db")
self.assertIsNotNone(
assert database is not None
assert database.sqlalchemy_uri == "sqlite:///superset.db"
assert (
security_manager.find_permission_view_menu("database_access", database.perm)
is not None
)
# Test change URI
get_or_create_db("test_db", "sqlite:///changed.db")
database = db.session.query(Database).filter_by(database_name="test_db").one()
self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db")
assert database.sqlalchemy_uri == "sqlite:///changed.db"
db.session.delete(database)
db.session.commit()
@ -738,22 +731,16 @@ class TestUtils(SupersetTestCase):
assert database.sqlalchemy_uri == "sqlite:///superset.db"
def test_as_list(self):
self.assertListEqual(as_list(123), [123])
self.assertListEqual(as_list([123]), [123])
self.assertListEqual(as_list("foo"), ["foo"])
self.assertListEqual(as_list(123), [123]) # noqa: PT009
self.assertListEqual(as_list([123]), [123]) # noqa: PT009
self.assertListEqual(as_list("foo"), ["foo"]) # noqa: PT009
def test_merge_extra_filters_with_no_extras(self):
form_data = {
"time_range": "Last 10 days",
}
merge_extra_form_data(form_data)
self.assertEqual(
form_data,
{
"time_range": "Last 10 days",
"adhoc_filters": [],
},
)
assert form_data == {"time_range": "Last 10 days", "adhoc_filters": []}
def test_merge_extra_filters_with_unset_legacy_time_range(self):
"""
@ -767,14 +754,11 @@ class TestUtils(SupersetTestCase):
"extra_form_data": {"time_range": "Last year"},
}
merge_extra_filters(form_data)
self.assertEqual(
form_data,
{
"time_range": "Last year",
"applied_time_extras": {},
"adhoc_filters": [],
},
)
assert form_data == {
"time_range": "Last year",
"applied_time_extras": {},
"adhoc_filters": [],
}
def test_merge_extra_filters_with_extras(self):
form_data = {
@ -817,41 +801,45 @@ class TestUtils(SupersetTestCase):
def test_ssl_certificate_parse(self):
parsed_certificate = parse_ssl_cert(ssl_certificate)
self.assertEqual(parsed_certificate.serial_number, 12355228710836649848)
assert parsed_certificate.serial_number == 12355228710836649848
def test_ssl_certificate_file_creation(self):
path = create_ssl_cert_file(ssl_certificate)
expected_filename = md5_sha_from_str(ssl_certificate)
self.assertIn(expected_filename, path)
self.assertTrue(os.path.exists(path))
assert expected_filename in path
assert os.path.exists(path)
def test_get_email_address_list(self):
self.assertEqual(get_email_address_list("a@a"), ["a@a"])
self.assertEqual(get_email_address_list(" a@a "), ["a@a"])
self.assertEqual(get_email_address_list("a@a\n"), ["a@a"])
self.assertEqual(get_email_address_list(",a@a;"), ["a@a"])
self.assertEqual(
get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f"),
["a@a", "b@b", "c@c", "a-c@c", "d@d", "f@f"],
)
assert get_email_address_list("a@a") == ["a@a"]
assert get_email_address_list(" a@a ") == ["a@a"]
assert get_email_address_list("a@a\n") == ["a@a"]
assert get_email_address_list(",a@a;") == ["a@a"]
assert get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f") == [
"a@a",
"b@b",
"c@c",
"a-c@c",
"d@d",
"f@f",
]
def test_get_form_data_default(self) -> None:
form_data, slc = get_form_data()
self.assertEqual(slc, None)
assert slc is None
def test_get_form_data_request_args(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"foo": "bar"})}
):
form_data, slc = get_form_data()
self.assertEqual(form_data, {"foo": "bar"})
self.assertEqual(slc, None)
assert form_data == {"foo": "bar"}
assert slc is None
def test_get_form_data_request_form(self) -> None:
with app.test_request_context(data={"form_data": json.dumps({"foo": "bar"})}):
form_data, slc = get_form_data()
self.assertEqual(form_data, {"foo": "bar"})
self.assertEqual(slc, None)
assert form_data == {"foo": "bar"}
assert slc is None
def test_get_form_data_request_form_with_queries(self) -> None:
# the CSV export uses for requests, even when sending requests to
@ -862,8 +850,8 @@ class TestUtils(SupersetTestCase):
}
):
form_data, slc = get_form_data()
self.assertEqual(form_data, {"url_params": {"foo": "bar"}})
self.assertEqual(slc, None)
assert form_data == {"url_params": {"foo": "bar"}}
assert slc is None
def test_get_form_data_request_args_and_form(self) -> None:
with app.test_request_context(
@ -871,16 +859,16 @@ class TestUtils(SupersetTestCase):
query_string={"form_data": json.dumps({"baz": "bar"})},
):
form_data, slc = get_form_data()
self.assertEqual(form_data, {"baz": "bar", "foo": "bar"})
self.assertEqual(slc, None)
assert form_data == {"baz": "bar", "foo": "bar"}
assert slc is None
def test_get_form_data_globals(self) -> None:
with app.test_request_context():
g.form_data = {"foo": "bar"}
form_data, slc = get_form_data()
delattr(g, "form_data")
self.assertEqual(form_data, {"foo": "bar"})
self.assertEqual(slc, None)
assert form_data == {"foo": "bar"}
assert slc is None
def test_get_form_data_corrupted_json(self) -> None:
with app.test_request_context(
@ -888,8 +876,8 @@ class TestUtils(SupersetTestCase):
query_string={"form_data": '{"baz": "bar"'},
):
form_data, slc = get_form_data()
self.assertEqual(form_data, {})
self.assertEqual(slc, None)
assert form_data == {}
assert slc is None
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_log_this(self) -> None:
@ -912,29 +900,29 @@ class TestUtils(SupersetTestCase):
.first()
)
self.assertEqual(record.dashboard_id, dashboard_id)
self.assertEqual(json.loads(record.json)["dashboard_id"], str(dashboard_id))
self.assertEqual(json.loads(record.json)["form_data"]["slice_id"], slc.id)
assert record.dashboard_id == dashboard_id
assert json.loads(record.json)["dashboard_id"] == str(dashboard_id)
assert json.loads(record.json)["form_data"]["slice_id"] == slc.id
self.assertEqual(
json.loads(record.json)["form_data"]["viz_type"],
slc.viz.form_data["viz_type"],
assert (
json.loads(record.json)["form_data"]["viz_type"]
== slc.viz.form_data["viz_type"]
)
def test_schema_validate_json(self):
valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
self.assertIsNone(schema.validate_json(valid))
assert schema.validate_json(valid) is None
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
self.assertRaises(marshmallow.ValidationError, schema.validate_json, invalid)
def test_schema_one_of_case_insensitive(self):
validator = schema.OneOfCaseInsensitive(choices=[1, 2, 3, "FoO", "BAR", "baz"])
self.assertEqual(1, validator(1))
self.assertEqual(2, validator(2))
self.assertEqual("FoO", validator("FoO"))
self.assertEqual("FOO", validator("FOO"))
self.assertEqual("bar", validator("bar"))
self.assertEqual("BaZ", validator("BaZ"))
assert 1 == validator(1)
assert 2 == validator(2)
assert "FoO" == validator("FoO")
assert "FOO" == validator("FOO")
assert "bar" == validator("bar")
assert "BaZ" == validator("BaZ")
self.assertRaises(marshmallow.ValidationError, validator, "qwerty")
self.assertRaises(marshmallow.ValidationError, validator, 4)

View File

@ -82,8 +82,8 @@ class TestBaseViz(SupersetTestCase):
"SUM(SP_URB_TOTL)",
"count",
]
self.assertEqual(test_viz.metric_labels, expect_metric_labels)
self.assertEqual(test_viz.all_metrics, expect_metric_labels)
assert test_viz.metric_labels == expect_metric_labels
assert test_viz.all_metrics == expect_metric_labels
def test_get_df_returns_empty_df(self):
form_data = {"dummy": 123}
@ -91,8 +91,8 @@ class TestBaseViz(SupersetTestCase):
datasource = self.get_datasource_mock()
test_viz = viz.BaseViz(datasource, form_data)
result = test_viz.get_df(query_obj)
self.assertEqual(type(result), pd.DataFrame)
self.assertTrue(result.empty)
assert type(result) == pd.DataFrame
assert result.empty
def test_get_df_handles_dttm_col(self):
form_data = {"dummy": 123}
@ -148,31 +148,31 @@ class TestBaseViz(SupersetTestCase):
datasource = self.get_datasource_mock()
datasource.cache_timeout = 0
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(0, test_viz.cache_timeout)
assert 0 == test_viz.cache_timeout
datasource.cache_timeout = 156
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(156, test_viz.cache_timeout)
assert 156 == test_viz.cache_timeout
datasource.cache_timeout = None
datasource.database.cache_timeout = 0
self.assertEqual(0, test_viz.cache_timeout)
assert 0 == test_viz.cache_timeout
datasource.database.cache_timeout = 1666
self.assertEqual(1666, test_viz.cache_timeout)
assert 1666 == test_viz.cache_timeout
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"],
test_viz.cache_timeout,
assert (
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
== test_viz.cache_timeout
)
data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout)
assert app.config["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout
# restore DATA_CACHE_CONFIG timeout
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout
@ -195,14 +195,14 @@ class TestDistBarViz(SupersetTestCase):
)
test_viz = viz.DistributionBarViz(datasource, form_data)
data = test_viz.get_data(df)[0]
self.assertEqual("votes", data["key"])
assert "votes" == data["key"]
expected_values = [
{"x": "pepperoni", "y": 5},
{"x": "cheese", "y": 3},
{"x": NULL_STRING, "y": 2},
{"x": "anchovies", "y": 1},
]
self.assertEqual(expected_values, data["values"])
assert expected_values == data["values"]
def test_groupby_nans(self):
form_data = {
@ -216,7 +216,7 @@ class TestDistBarViz(SupersetTestCase):
df = pd.DataFrame({"beds": [0, 1, nan, 2], "count": [30, 42, 3, 29]})
test_viz = viz.DistributionBarViz(datasource, form_data)
data = test_viz.get_data(df)[0]
self.assertEqual("count", data["key"])
assert "count" == data["key"]
expected_values = [
{"x": "1.0", "y": 42},
{"x": "0.0", "y": 30},
@ -224,7 +224,7 @@ class TestDistBarViz(SupersetTestCase):
{"x": NULL_STRING, "y": 3},
]
self.assertEqual(expected_values, data["values"])
assert expected_values == data["values"]
def test_column_nulls(self):
form_data = {
@ -254,7 +254,7 @@ class TestDistBarViz(SupersetTestCase):
"values": [{"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}],
},
]
self.assertEqual(expected, data)
assert expected == data
def test_column_metrics_in_order(self):
form_data = {
@ -292,7 +292,7 @@ class TestDistBarViz(SupersetTestCase):
},
]
self.assertEqual(expected, data)
assert expected == data
def test_column_metrics_in_order_with_breakdowns(self):
form_data = {
@ -342,7 +342,7 @@ class TestDistBarViz(SupersetTestCase):
},
]
self.assertEqual(expected, data)
assert expected == data
class TestPairedTTest(SupersetTestCase):
@ -445,7 +445,7 @@ class TestPairedTTest(SupersetTestCase):
},
],
}
self.assertEqual(data, expected)
assert data == expected
def test_get_data_empty_null_keys(self):
form_data = {"groupby": [], "metrics": [""]}
@ -472,7 +472,7 @@ class TestPairedTTest(SupersetTestCase):
}
],
}
self.assertEqual(data, expected)
assert data == expected
form_data = {"groupby": [], "metrics": [None]}
with self.assertRaises(ValueError):
@ -487,10 +487,10 @@ class TestPartitionViz(SupersetTestCase):
test_viz = viz.PartitionViz(datasource, form_data)
super_query_obj.return_value = {}
query_obj = test_viz.query_obj()
self.assertFalse(query_obj["is_timeseries"])
assert not query_obj["is_timeseries"]
test_viz.form_data["time_series_option"] = "agg_sum"
query_obj = test_viz.query_obj()
self.assertTrue(query_obj["is_timeseries"])
assert query_obj["is_timeseries"]
def test_levels_for_computes_levels(self):
raw = {}
@ -506,37 +506,37 @@ class TestPartitionViz(SupersetTestCase):
time_op = "agg_sum"
test_viz = viz.PartitionViz(Mock(), {})
levels = test_viz.levels_for(time_op, groups, df)
self.assertEqual(4, len(levels))
assert 4 == len(levels)
expected = {DTTM_ALIAS: 1800, "metric1": 45, "metric2": 450, "metric3": 4500}
self.assertEqual(expected, levels[0].to_dict())
assert expected == levels[0].to_dict()
expected = {
DTTM_ALIAS: {"a1": 600, "b1": 600, "c1": 600},
"metric1": {"a1": 6, "b1": 15, "c1": 24},
"metric2": {"a1": 60, "b1": 150, "c1": 240},
"metric3": {"a1": 600, "b1": 1500, "c1": 2400},
}
self.assertEqual(expected, levels[1].to_dict())
self.assertEqual(["groupA", "groupB"], levels[2].index.names)
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
assert expected == levels[1].to_dict()
assert ["groupA", "groupB"] == levels[2].index.names
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
time_op = "agg_mean"
levels = test_viz.levels_for(time_op, groups, df)
self.assertEqual(4, len(levels))
assert 4 == len(levels)
expected = {
DTTM_ALIAS: 200.0,
"metric1": 5.0,
"metric2": 50.0,
"metric3": 500.0,
}
self.assertEqual(expected, levels[0].to_dict())
assert expected == levels[0].to_dict()
expected = {
DTTM_ALIAS: {"a1": 200, "c1": 200, "b1": 200},
"metric1": {"a1": 2, "b1": 5, "c1": 8},
"metric2": {"a1": 20, "b1": 50, "c1": 80},
"metric3": {"a1": 200, "b1": 500, "c1": 800},
}
self.assertEqual(expected, levels[1].to_dict())
self.assertEqual(["groupA", "groupB"], levels[2].index.names)
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
assert expected == levels[1].to_dict()
assert ["groupA", "groupB"] == levels[2].index.names
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
def test_levels_for_diff_computes_difference(self):
raw = {}
@ -553,15 +553,15 @@ class TestPartitionViz(SupersetTestCase):
time_op = "point_diff"
levels = test_viz.levels_for_diff(time_op, groups, df)
expected = {"metric1": 6, "metric2": 60, "metric3": 600}
self.assertEqual(expected, levels[0].to_dict())
assert expected == levels[0].to_dict()
expected = {
"metric1": {"a1": 2, "b1": 2, "c1": 2},
"metric2": {"a1": 20, "b1": 20, "c1": 20},
"metric3": {"a1": 200, "b1": 200, "c1": 200},
}
self.assertEqual(expected, levels[1].to_dict())
self.assertEqual(4, len(levels))
self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names)
assert expected == levels[1].to_dict()
assert 4 == len(levels)
assert ["groupA", "groupB", "groupC"] == levels[3].index.names
def test_levels_for_time_calls_process_data_and_drops_cols(self):
raw = {}
@ -581,16 +581,16 @@ class TestPartitionViz(SupersetTestCase):
test_viz.process_data = Mock(side_effect=return_args)
levels = test_viz.levels_for_time(groups, df)
self.assertEqual(4, len(levels))
assert 4 == len(levels)
cols = [DTTM_ALIAS, "metric1", "metric2", "metric3"]
self.assertEqual(sorted(cols), sorted(levels[0].columns.tolist()))
assert sorted(cols) == sorted(levels[0].columns.tolist())
cols += ["groupA"]
self.assertEqual(sorted(cols), sorted(levels[1].columns.tolist()))
assert sorted(cols) == sorted(levels[1].columns.tolist())
cols += ["groupB"]
self.assertEqual(sorted(cols), sorted(levels[2].columns.tolist()))
assert sorted(cols) == sorted(levels[2].columns.tolist())
cols += ["groupC"]
self.assertEqual(sorted(cols), sorted(levels[3].columns.tolist()))
self.assertEqual(4, len(test_viz.process_data.mock_calls))
assert sorted(cols) == sorted(levels[3].columns.tolist())
assert 4 == len(test_viz.process_data.mock_calls)
def test_nest_values_returns_hierarchy(self):
raw = {}
@ -605,12 +605,12 @@ class TestPartitionViz(SupersetTestCase):
groups = ["groupA", "groupB", "groupC"]
levels = test_viz.levels_for("agg_sum", groups, df)
nest = test_viz.nest_values(levels)
self.assertEqual(3, len(nest))
assert 3 == len(nest)
for i in range(0, 3):
self.assertEqual("metric" + str(i + 1), nest[i]["name"])
self.assertEqual(3, len(nest[0]["children"]))
self.assertEqual(1, len(nest[0]["children"][0]["children"]))
self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"]))
assert "metric" + str(i + 1) == nest[i]["name"]
assert 3 == len(nest[0]["children"])
assert 1 == len(nest[0]["children"][0]["children"])
assert 1 == len(nest[0]["children"][0]["children"][0]["children"])
def test_nest_procs_returns_hierarchy(self):
raw = {}
@ -633,15 +633,15 @@ class TestPartitionViz(SupersetTestCase):
)
procs[i] = pivot
nest = test_viz.nest_procs(procs)
self.assertEqual(3, len(nest))
assert 3 == len(nest)
for i in range(0, 3):
self.assertEqual("metric" + str(i + 1), nest[i]["name"])
self.assertEqual(None, nest[i].get("val"))
self.assertEqual(3, len(nest[0]["children"]))
self.assertEqual(3, len(nest[0]["children"][0]["children"]))
self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"]))
self.assertEqual(
1, len(nest[0]["children"][0]["children"][0]["children"][0]["children"])
assert "metric" + str(i + 1) == nest[i]["name"]
assert None is nest[i].get("val")
assert 3 == len(nest[0]["children"])
assert 3 == len(nest[0]["children"][0]["children"])
assert 1 == len(nest[0]["children"][0]["children"][0]["children"])
assert 1 == len(
nest[0]["children"][0]["children"][0]["children"][0]["children"]
)
def test_get_data_calls_correct_method(self):
@ -662,33 +662,33 @@ class TestPartitionViz(SupersetTestCase):
test_viz.form_data["groupby"] = ["groups"]
test_viz.form_data["time_series_option"] = "not_time"
test_viz.get_data(df)
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[0][1][0])
assert "agg_sum" == test_viz.levels_for.mock_calls[0][1][0]
test_viz.form_data["time_series_option"] = "agg_sum"
test_viz.get_data(df)
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[1][1][0])
assert "agg_sum" == test_viz.levels_for.mock_calls[1][1][0]
test_viz.form_data["time_series_option"] = "agg_mean"
test_viz.get_data(df)
self.assertEqual("agg_mean", test_viz.levels_for.mock_calls[2][1][0])
assert "agg_mean" == test_viz.levels_for.mock_calls[2][1][0]
test_viz.form_data["time_series_option"] = "point_diff"
test_viz.levels_for_diff = Mock(return_value=1)
test_viz.get_data(df)
self.assertEqual("point_diff", test_viz.levels_for_diff.mock_calls[0][1][0])
assert "point_diff" == test_viz.levels_for_diff.mock_calls[0][1][0]
test_viz.form_data["time_series_option"] = "point_percent"
test_viz.get_data(df)
self.assertEqual("point_percent", test_viz.levels_for_diff.mock_calls[1][1][0])
assert "point_percent" == test_viz.levels_for_diff.mock_calls[1][1][0]
test_viz.form_data["time_series_option"] = "point_factor"
test_viz.get_data(df)
self.assertEqual("point_factor", test_viz.levels_for_diff.mock_calls[2][1][0])
assert "point_factor" == test_viz.levels_for_diff.mock_calls[2][1][0]
test_viz.levels_for_time = Mock(return_value=1)
test_viz.nest_procs = Mock(return_value=1)
test_viz.form_data["time_series_option"] = "adv_anal"
test_viz.get_data(df)
self.assertEqual(1, len(test_viz.levels_for_time.mock_calls))
self.assertEqual(1, len(test_viz.nest_procs.mock_calls))
assert 1 == len(test_viz.levels_for_time.mock_calls)
assert 1 == len(test_viz.nest_procs.mock_calls)
test_viz.form_data["time_series_option"] = "time_series"
test_viz.get_data(df)
self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[3][1][0])
self.assertEqual(7, len(test_viz.nest_values.mock_calls))
assert "agg_sum" == test_viz.levels_for.mock_calls[3][1][0]
assert 7 == len(test_viz.nest_values.mock_calls)
class TestRoseVis(SupersetTestCase):
@ -724,7 +724,7 @@ class TestRoseVis(SupersetTestCase):
{"time": t3, "value": 9, "key": ("c1",), "name": ("c1",)},
],
}
self.assertEqual(expected, res)
assert expected == res
class TestTimeSeriesTableViz(SupersetTestCase):
@ -741,13 +741,13 @@ class TestTimeSeriesTableViz(SupersetTestCase):
test_viz = viz.TimeTableViz(datasource, form_data)
data = test_viz.get_data(df)
# Check method correctly transforms data
self.assertEqual({"count", "sum__A"}, set(data["columns"]))
assert {"count", "sum__A"} == set(data["columns"])
time_format = "%Y-%m-%d %H:%M:%S"
expected = {
t1.strftime(time_format): {"sum__A": 15, "count": 6},
t2.strftime(time_format): {"sum__A": 20, "count": 7},
}
self.assertEqual(expected, data["records"])
assert expected == data["records"]
def test_get_data_group_by(self):
form_data = {"metrics": ["sum__A"], "groupby": ["groupby1"]}
@ -762,13 +762,13 @@ class TestTimeSeriesTableViz(SupersetTestCase):
test_viz = viz.TimeTableViz(datasource, form_data)
data = test_viz.get_data(df)
# Check method correctly transforms data
self.assertEqual({"a1", "a2", "a3"}, set(data["columns"]))
assert {"a1", "a2", "a3"} == set(data["columns"])
time_format = "%Y-%m-%d %H:%M:%S"
expected = {
t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25},
t2.strftime(time_format): {"a1": 30, "a2": 35, "a3": 40},
}
self.assertEqual(expected, data["records"])
assert expected == data["records"]
@patch("superset.viz.BaseViz.query_obj")
def test_query_obj_throws_metrics_and_groupby(self, super_query_obj):
@ -788,7 +788,7 @@ class TestTimeSeriesTableViz(SupersetTestCase):
self.get_datasource_mock(), {"metrics": ["sum__A", "count"], "groupby": []}
)
query_obj = test_viz.query_obj()
self.assertEqual(query_obj["orderby"], [("sum__A", False)])
assert query_obj["orderby"] == [("sum__A", False)]
class TestBaseDeckGLViz(SupersetTestCase):
@ -838,7 +838,7 @@ class TestBaseDeckGLViz(SupersetTestCase):
with self.assertRaises(NotImplementedError) as context:
test_viz_deckgl.get_properties(mock_d)
self.assertTrue("" in str(context.exception))
assert "" in str(context.exception)
def test_process_spatial_query_obj(self):
form_data = load_fixture("deck_path_form_data.json")
@ -850,7 +850,7 @@ class TestBaseDeckGLViz(SupersetTestCase):
with self.assertRaises(ValueError) as context:
test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb)
self.assertTrue("Bad spatial key" in str(context.exception))
assert "Bad spatial key" in str(context.exception)
test_form_data = {
"latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"},
@ -886,14 +886,14 @@ class TestBaseDeckGLViz(SupersetTestCase):
viz_instance = viz.BaseDeckGLViz(datasource, form_data)
coord = viz_instance.parse_coordinates("1.23, 3.21")
self.assertEqual(coord, (1.23, 3.21))
assert coord == (1.23, 3.21)
coord = viz_instance.parse_coordinates("1.23 3.21")
self.assertEqual(coord, (1.23, 3.21))
assert coord == (1.23, 3.21)
self.assertEqual(viz_instance.parse_coordinates(None), None)
assert viz_instance.parse_coordinates(None) is None
self.assertEqual(viz_instance.parse_coordinates(""), None)
assert viz_instance.parse_coordinates("") is None
def test_parse_coordinates_raises(self):
form_data = load_fixture("deck_path_form_data.json")
@ -1001,7 +1001,7 @@ class TestTimeSeriesViz(SupersetTestCase):
"key": ("Real Madrid C.F.\U0001f1fa\U0001f1f8\U0001f1ec\U0001f1e7",),
},
]
self.assertEqual(expected, viz_data)
assert expected == viz_data
def test_process_data_resample(self):
datasource = self.get_datasource_mock()
@ -1015,15 +1015,10 @@ class TestTimeSeriesViz(SupersetTestCase):
}
)
self.assertEqual(
viz.NVD3TimeSeriesViz(
datasource,
{"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"},
)
.process_data(df)["y"]
.tolist(),
[1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0],
)
assert viz.NVD3TimeSeriesViz(
datasource,
{"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"},
).process_data(df)["y"].tolist() == [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0]
np.testing.assert_equal(
viz.NVD3TimeSeriesViz(
@ -1043,48 +1038,33 @@ class TestTimeSeriesViz(SupersetTestCase):
),
data={"y": [1.0, 2.0, 3.0, 4.0]},
)
self.assertEqual(
viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "cumsum",
"rolling_periods": 0,
"min_periods": 0,
},
)
.apply_rolling(df)["y"]
.tolist(),
[1.0, 3.0, 6.0, 10.0],
)
self.assertEqual(
viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "sum",
"rolling_periods": 2,
"min_periods": 0,
},
)
.apply_rolling(df)["y"]
.tolist(),
[1.0, 3.0, 5.0, 7.0],
)
self.assertEqual(
viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "mean",
"rolling_periods": 10,
"min_periods": 0,
},
)
.apply_rolling(df)["y"]
.tolist(),
[1.0, 1.5, 2.0, 2.5],
)
assert viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "cumsum",
"rolling_periods": 0,
"min_periods": 0,
},
).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 6.0, 10.0]
assert viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "sum",
"rolling_periods": 2,
"min_periods": 0,
},
).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 5.0, 7.0]
assert viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
"rolling_type": "mean",
"rolling_periods": 10,
"min_periods": 0,
},
).apply_rolling(df)["y"].tolist() == [1.0, 1.5, 2.0, 2.5]
def test_apply_rolling_without_data(self):
datasource = self.get_datasource_mock()