Improve examples & related tests (#7773)

* [WiP] improve load_examples

Related to #7472. Longer term, we will generate the examples by exporting
them into a tarball as in #7472. In the meantime, we need this subset of
the features:

* allow specifying an alternate database connection for examples
* allow a --only-metadata flag to `load_examples` to load only
  dashboard and chart definitions; no actual data is loaded

* Improve logging

* Rename data->examples

* Load only if not exist

* By default do not load, add a force flag

* fix build

* set published to true
This commit is contained in:
Maxime Beauchemin 2019-07-16 21:36:56 -07:00 committed by GitHub
parent 86fdceb236
commit d65b039219
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 580 additions and 488 deletions

View File

@ -18,7 +18,7 @@ include NOTICE
include LICENSE.txt
graft licenses/
include README.md
recursive-include superset/data *
recursive-include superset/examples *
recursive-include superset/migrations *
recursive-include superset/static *
recursive-exclude superset/static/assets/docs *

View File

@ -26,7 +26,7 @@ from colorama import Fore, Style
from pathlib2 import Path
import yaml
from superset import app, appbuilder, data, db, security_manager
from superset import app, appbuilder, db, examples, security_manager
from superset.utils import core as utils, dashboard_import_export, dict_import_export
config = app.config
@ -46,6 +46,7 @@ def make_shell_context():
def init():
"""Inits the Superset application"""
utils.get_or_create_main_db()
utils.get_example_database()
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@ -67,66 +68,76 @@ def version(verbose):
print(Style.RESET_ALL)
def load_examples_run(load_test_data):
print("Loading examples into {}".format(db))
def load_examples_run(load_test_data, only_metadata=False, force=False):
if only_metadata:
print("Loading examples metadata")
else:
examples_db = utils.get_example_database()
print(f"Loading examples metadata and related data into {examples_db}")
data.load_css_templates()
examples.load_css_templates()
print("Loading energy related dataset")
data.load_energy()
examples.load_energy(only_metadata, force)
print("Loading [World Bank's Health Nutrition and Population Stats]")
data.load_world_bank_health_n_pop()
examples.load_world_bank_health_n_pop(only_metadata, force)
print("Loading [Birth names]")
data.load_birth_names()
examples.load_birth_names(only_metadata, force)
print("Loading [Unicode test data]")
data.load_unicode_test_data()
examples.load_unicode_test_data(only_metadata, force)
if not load_test_data:
print("Loading [Random time series data]")
data.load_random_time_series_data()
examples.load_random_time_series_data(only_metadata, force)
print("Loading [Random long/lat data]")
data.load_long_lat_data()
examples.load_long_lat_data(only_metadata, force)
print("Loading [Country Map data]")
data.load_country_map_data()
examples.load_country_map_data(only_metadata, force)
print("Loading [Multiformat time series]")
data.load_multiformat_time_series()
examples.load_multiformat_time_series(only_metadata, force)
print("Loading [Paris GeoJson]")
data.load_paris_iris_geojson()
examples.load_paris_iris_geojson(only_metadata, force)
print("Loading [San Francisco population polygons]")
data.load_sf_population_polygons()
examples.load_sf_population_polygons(only_metadata, force)
print("Loading [Flights data]")
data.load_flights()
examples.load_flights(only_metadata, force)
print("Loading [BART lines]")
data.load_bart_lines()
examples.load_bart_lines(only_metadata, force)
print("Loading [Multi Line]")
data.load_multi_line()
examples.load_multi_line(only_metadata)
print("Loading [Misc Charts] dashboard")
data.load_misc_dashboard()
examples.load_misc_dashboard()
print("Loading DECK.gl demo")
data.load_deck_dash()
examples.load_deck_dash()
print("Loading [Tabbed dashboard]")
data.load_tabbed_dashboard()
examples.load_tabbed_dashboard(only_metadata)
@app.cli.command()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
def load_examples(load_test_data):
@click.option(
"--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data"
)
@click.option(
"--force", "-f", is_flag=True, help="Force load data even if table already exists"
)
def load_examples(load_test_data, only_metadata=False, force=False):
"""Loads a set of Slices and Dashboards and a supporting dataset """
load_examples_run(load_test_data)
load_examples_run(load_test_data, only_metadata, force)
@app.cli.command()
@ -405,7 +416,7 @@ def load_test_users_run():
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = utils.get_main_database(security_manager.get_session).perm
db_perm = utils.get_main_database().perm
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name="database_access"

View File

@ -617,6 +617,10 @@ TALISMAN_CONFIG = {
"force_https_permanent": False,
}
# URI to database storing the example data, points to
# SQLALCHEMY_DATABASE_URI by default if set to `None`
SQLALCHEMY_EXAMPLES_URI = None
try:
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not in pythonpath; useful

View File

@ -55,18 +55,9 @@ class ConnectorRegistry(object):
cls, session, datasource_type, datasource_name, schema, database_name
):
datasource_class = ConnectorRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all()
# Filter datasources that don't have a database.
db_ds = [
d
for d in datasources
if d.database
and d.database.name == database_name
and d.name == datasource_name
and schema == schema
]
return db_ds[0]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)
@classmethod
def query_datasources_by_permissions(cls, session, database, permissions):

View File

@ -732,6 +732,16 @@ class DruidDatasource(Model, BaseDatasource):
return 6 * 24 * 3600 * 1000 # 6 days
return 0
@classmethod
def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
query = (
session.query(cls)
.join(DruidCluster)
.filter(cls.datasource_name == datasource_name)
.filter(DruidCluster.cluster_name == database_name)
)
return query.first()
# uses https://en.wikipedia.org/wiki/ISO_8601
# http://druid.io/docs/0.8.0/querying/granularities.html
# TODO: pass origin from the UI

View File

@ -374,6 +374,21 @@ class SqlaTable(Model, BaseDatasource):
def database_name(self):
return self.database.name
@classmethod
def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
schema = schema or None
query = (
session.query(cls)
.join(Database)
.filter(cls.table_name == datasource_name)
.filter(Database.database_name == database_name)
)
# Handle schema being '' or None; this is easier to do in Python
# than in the SQLA query in a multi-dialect way
for tbl in query.all():
if schema == (tbl.schema or None):
return tbl
@property
def link(self):
name = escape(self.name)

View File

@ -21,37 +21,42 @@ import polyline
from sqlalchemy import String, Text
from superset import db
from superset.utils.core import get_or_create_main_db
from .helpers import TBL, get_example_data
from superset.utils.core import get_example_database
from .helpers import get_example_data, TBL
def load_bart_lines():
def load_bart_lines(only_metadata=False, force=False):
tbl_name = "bart_lines"
content = get_example_data("bart-lines.json.gz")
df = pd.read_json(content, encoding="latin-1")
df["path_json"] = df.path.map(json.dumps)
df["polyline"] = df.path.map(polyline.encode)
del df["path"]
database = get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
content = get_example_data("bart-lines.json.gz")
df = pd.read_json(content, encoding="latin-1")
df["path_json"] = df.path.map(json.dumps)
df["polyline"] = df.path.map(polyline.encode)
del df["path"]
df.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"polyline": Text,
"path_json": Text,
},
index=False,
)
df.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"polyline": Text,
"path_json": Text,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "BART lines"
tbl.database = get_or_create_main_db()
tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -23,7 +23,7 @@ from sqlalchemy.sql import column
from superset import db, security_manager
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.utils.core import get_or_create_main_db
from superset.utils.core import get_example_database
from .helpers import (
config,
Dash,
@ -36,33 +36,39 @@ from .helpers import (
)
def load_birth_names():
def load_birth_names(only_metadata=False, force=False):
"""Loading birth name dataset from a zip file in the repo"""
data = get_example_data("birth_names.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf.to_sql(
"birth_names",
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"ds": DateTime,
"gender": String(16),
"state": String(10),
"name": String(255),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
# pylint: disable=too-many-locals
tbl_name = "birth_names"
database = get_example_database()
table_exists = database.has_table_by_name(tbl_name)
print("Creating table [birth_names] reference")
obj = db.session.query(TBL).filter_by(table_name="birth_names").first()
if not only_metadata and (not table_exists or force):
pdf = pd.read_json(get_example_data("birth_names.json.gz"))
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"ds": DateTime,
"gender": String(16),
"state": String(10),
"name": String(255),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="birth_names")
print(f"Creating table [{tbl_name}] reference")
obj = TBL(table_name=tbl_name)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = get_or_create_main_db()
obj.database = database
obj.filter_select_enabled = True
if not any(col.column_name == "num_california" for col in obj.columns):
@ -79,7 +85,6 @@ def load_birth_names():
col = str(column("num").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
@ -384,10 +389,12 @@ def load_birth_names():
merge_slice(slc)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title="Births").first()
dash = db.session.query(Dash).filter_by(slug="births").first()
if not dash:
dash = Dash()
db.session.add(dash)
dash.published = True
js = textwrap.dedent(
# pylint: disable=line-too-long
"""\
@ -649,5 +656,4 @@ def load_birth_names():
dash.dashboard_title = "Births"
dash.position_json = json.dumps(pos, indent=4)
dash.slug = "births"
db.session.merge(dash)
db.session.commit()

View File

@ -33,44 +33,50 @@ from .helpers import (
)
def load_country_map_data():
def load_country_map_data(only_metadata=False, force=False):
"""Loading data for map with country map"""
csv_bytes = get_example_data(
"birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
)
data = pd.read_csv(csv_bytes, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql( # pylint: disable=no-member
"birth_france_by_region",
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"DEPT_ID": String(10),
"2003": BigInteger,
"2004": BigInteger,
"2005": BigInteger,
"2006": BigInteger,
"2007": BigInteger,
"2008": BigInteger,
"2009": BigInteger,
"2010": BigInteger,
"2011": BigInteger,
"2012": BigInteger,
"2013": BigInteger,
"2014": BigInteger,
"dttm": Date(),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
tbl_name = "birth_france_by_region"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
csv_bytes = get_example_data(
"birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
)
data = pd.read_csv(csv_bytes, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql( # pylint: disable=no-member
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"DEPT_ID": String(10),
"2003": BigInteger,
"2004": BigInteger,
"2005": BigInteger,
"2006": BigInteger,
"2007": BigInteger,
"2008": BigInteger,
"2009": BigInteger,
"2010": BigInteger,
"2011": BigInteger,
"2012": BigInteger,
"2013": BigInteger,
"2014": BigInteger,
"dttm": Date(),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name="birth_france_by_region").first()
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="birth_france_by_region")
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db()
obj.database = database
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))

View File

@ -501,6 +501,7 @@ def load_deck_dash():
if not dash:
dash = Dash()
dash.published = True
js = POSITION_JSON
pos = json.loads(js)
update_slice_ids(pos, slices)

View File

@ -25,36 +25,33 @@ from sqlalchemy.sql import column
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.utils import core as utils
from .helpers import (
DATA_FOLDER,
get_example_data,
merge_slice,
misc_dash_slices,
Slice,
TBL,
)
from .helpers import get_example_data, merge_slice, misc_dash_slices, Slice, TBL
def load_energy():
def load_energy(only_metadata=False, force=False):
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = "energy_usage"
data = get_example_data("energy.json.gz")
pdf = pd.read_json(data)
pdf.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=500,
dtype={"source": String(255), "target": String(255), "value": Float()},
index=False,
)
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
data = get_example_data("energy.json.gz")
pdf = pd.read_json(data)
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"source": String(255), "target": String(255), "value": Float()},
index=False,
)
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Energy consumption"
tbl.database = utils.get_or_create_main_db()
tbl.database = database
if not any(col.metric_name == "sum__value" for col in tbl.metrics):
col = str(column("value").compile(db.engine))

View File

@ -22,38 +22,45 @@ from superset.utils import core as utils
from .helpers import get_example_data, TBL
def load_flights():
def load_flights(only_metadata=False, force=False):
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
data = get_example_data("flight_data.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="latin-1")
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
# Loading airports info to join and get lat/long
airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
airports = pd.read_csv(airports_bytes, encoding="latin-1")
airports = airports.set_index("IATA_CODE")
if not only_metadata and (not table_exists or force):
data = get_example_data("flight_data.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="latin-1")
pdf["ds"] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
pdf.ds = pd.to_datetime(pdf.ds)
del pdf["YEAR"]
del pdf["MONTH"]
del pdf["DAY"]
# Loading airports info to join and get lat/long
airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
airports = pd.read_csv(airports_bytes, encoding="latin-1")
airports = airports.set_index("IATA_CODE")
pdf["ds"] = (
pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
)
pdf.ds = pd.to_datetime(pdf.ds)
del pdf["YEAR"]
del pdf["MONTH"]
del pdf["DAY"]
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Random set of flights in the US"
tbl.database = utils.get_or_create_main_db()
tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -38,7 +38,7 @@ TBL = ConnectorRegistry.sources["table"]
config = app.config
DATA_FOLDER = os.path.join(config.get("BASE_DIR"), "data")
EXAMPLES_FOLDER = os.path.join(config.get("BASE_DIR"), "examples")
misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard

View File

@ -33,52 +33,59 @@ from .helpers import (
)
def load_long_lat_data():
def load_long_lat_data(only_metadata=False, force=False):
"""Loading lat/long data from a csv file in the repo"""
data = get_example_data("san_francisco.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="utf-8")
start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
pdf["datetime"] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf))
]
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql( # pylint: disable=no-member
"long_lat",
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"longitude": Float(),
"latitude": Float(),
"number": Float(),
"street": String(100),
"unit": String(10),
"city": String(50),
"district": String(50),
"region": String(50),
"postcode": Float(),
"id": String(100),
"datetime": DateTime(),
"occupancy": Float(),
"radius_miles": Float(),
"geohash": String(12),
"delimited": String(60),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
tbl_name = "long_lat"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
data = get_example_data("san_francisco.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="utf-8")
start = datetime.datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0
)
pdf["datetime"] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf))
]
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql( # pylint: disable=no-member
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"longitude": Float(),
"latitude": Float(),
"number": Float(),
"street": String(100),
"unit": String(10),
"city": String(50),
"district": String(50),
"region": String(50),
"postcode": Float(),
"id": String(100),
"datetime": DateTime(),
"occupancy": Float(),
"radius_miles": Float(),
"geohash": String(12),
"delimited": String(60),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name="long_lat").first()
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="long_lat")
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "datetime"
obj.database = utils.get_or_create_main_db()
obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()

View File

@ -22,9 +22,9 @@ from .helpers import merge_slice, misc_dash_slices, Slice
from .world_bank import load_world_bank_health_n_pop
def load_multi_line():
load_world_bank_health_n_pop()
load_birth_names()
def load_multi_line(only_metadata=False):
load_world_bank_health_n_pop(only_metadata)
load_birth_names(only_metadata)
ids = [
row.id
for row in db.session.query(Slice).filter(

View File

@ -19,7 +19,7 @@ import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, String
from superset import db
from superset.utils import core as utils
from superset.utils.core import get_example_database
from .helpers import (
config,
get_example_data,
@ -31,38 +31,44 @@ from .helpers import (
)
def load_multiformat_time_series():
def load_multiformat_time_series(only_metadata=False, force=False):
"""Loading time series data from a zip file in the repo"""
data = get_example_data("multiformat_time_series.json.gz")
pdf = pd.read_json(data)
tbl_name = "multiformat_time_series"
database = get_example_database()
table_exists = database.has_table_by_name(tbl_name)
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.to_sql(
"multiformat_time_series",
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"ds": Date,
"ds2": DateTime,
"epoch_s": BigInteger,
"epoch_ms": BigInteger,
"string0": String(100),
"string1": String(100),
"string2": String(100),
"string3": String(100),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table [multiformat_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name="multiformat_time_series").first()
if not only_metadata and (not table_exists or force):
data = get_example_data("multiformat_time_series.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"ds": Date,
"ds2": DateTime,
"epoch_s": BigInteger,
"epoch_ms": BigInteger,
"string0": String(100),
"string1": String(100),
"string2": String(100),
"string3": String(100),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
print(f"Creating table [{tbl_name}] reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="multiformat_time_series")
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db()
obj.database = database
dttm_and_expr_dict = {
"ds": [None, None],
"ds2": [None, None],

View File

@ -21,35 +21,39 @@ from sqlalchemy import String, Text
from superset import db
from superset.utils import core as utils
from .helpers import TBL, get_example_data
from .helpers import get_example_data, TBL
def load_paris_iris_geojson():
def load_paris_iris_geojson(only_metadata=False, force=False):
tbl_name = "paris_iris_mapping"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
data = get_example_data("paris_iris.json.gz")
df = pd.read_json(data)
df["features"] = df.features.map(json.dumps)
if not only_metadata and (not table_exists or force):
data = get_example_data("paris_iris.json.gz")
df = pd.read_json(data)
df["features"] = df.features.map(json.dumps)
df.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"features": Text,
"type": Text,
},
index=False,
)
df.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"color": String(255),
"name": String(255),
"features": Text,
"type": Text,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Map of Paris"
tbl.database = utils.get_or_create_main_db()
tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -23,28 +23,33 @@ from superset.utils import core as utils
from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL
def load_random_time_series_data():
def load_random_time_series_data(only_metadata=False, force=False):
"""Loading random time series data from a zip file in the repo"""
data = get_example_data("random_time_series.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.to_sql(
"random_time_series",
db.engine,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
print("Done loading table!")
print("-" * 80)
tbl_name = "random_time_series"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
print("Creating table [random_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name="random_time_series").first()
if not only_metadata and (not table_exists or force):
data = get_example_data("random_time_series.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
index=False,
)
print("Done loading table!")
print("-" * 80)
print(f"Creating table [{tbl_name}] reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="random_time_series")
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db()
obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()

View File

@ -21,35 +21,39 @@ from sqlalchemy import BigInteger, Text
from superset import db
from superset.utils import core as utils
from .helpers import TBL, get_example_data
from .helpers import get_example_data, TBL
def load_sf_population_polygons():
def load_sf_population_polygons(only_metadata=False, force=False):
tbl_name = "sf_population_polygons"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
data = get_example_data("sf_population.json.gz")
df = pd.read_json(data)
df["contour"] = df.contour.map(json.dumps)
if not only_metadata and (not table_exists or force):
data = get_example_data("sf_population.json.gz")
df = pd.read_json(data)
df["contour"] = df.contour.map(json.dumps)
df.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"zipcode": BigInteger,
"population": BigInteger,
"contour": Text,
"area": BigInteger,
},
index=False,
)
df.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"zipcode": BigInteger,
"population": BigInteger,
"contour": Text,
"area": BigInteger,
},
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Population density of San Francisco"
tbl.database = utils.get_or_create_main_db()
tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -17,30 +17,13 @@
"""Loads datasets, dashboards and slices in a new superset instance"""
# pylint: disable=C,R,W
import json
import os
import textwrap
import pandas as pd
from sqlalchemy import DateTime, String
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.utils import core as utils
from .helpers import (
config,
Dash,
DATA_FOLDER,
get_example_data,
get_slice_json,
merge_slice,
misc_dash_slices,
Slice,
TBL,
update_slice_ids,
)
from .helpers import Dash, Slice, update_slice_ids
def load_tabbed_dashboard():
def load_tabbed_dashboard(only_metadata=False):
"""Creating a tabbed dashboard"""
print("Creating a dashboard with nested tabs")

View File

@ -35,38 +35,43 @@ from .helpers import (
)
def load_unicode_test_data():
def load_unicode_test_data(only_metadata=False, force=False):
"""Loading unicode test dataset from a csv file in the repo"""
data = get_example_data(
"unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
)
df = pd.read_csv(data, encoding="utf-8")
# generate date/numeric data
df["dttm"] = datetime.datetime.now().date()
df["value"] = [random.randint(1, 100) for _ in range(len(df))]
df.to_sql( # pylint: disable=no-member
"unicode_test",
db.engine,
if_exists="replace",
chunksize=500,
dtype={
"phrase": String(500),
"short_phrase": String(10),
"with_missing": String(100),
"dttm": Date(),
"value": Float(),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
tbl_name = "unicode_test"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
data = get_example_data(
"unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
)
df = pd.read_csv(data, encoding="utf-8")
# generate date/numeric data
df["dttm"] = datetime.datetime.now().date()
df["value"] = [random.randint(1, 100) for _ in range(len(df))]
df.to_sql( # pylint: disable=no-member
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={
"phrase": String(500),
"short_phrase": String(10),
"with_missing": String(100),
"dttm": Date(),
"value": Float(),
},
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table [unicode_test] reference")
obj = db.session.query(TBL).filter_by(table_name="unicode_test").first()
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name="unicode_test")
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db()
obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
@ -104,7 +109,7 @@ def load_unicode_test_data():
merge_slice(slc)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first()
dash = db.session.query(Dash).filter_by(slug="unicode-test").first()
if not dash:
dash = Dash()

View File

@ -30,7 +30,7 @@ from superset.utils import core as utils
from .helpers import (
config,
Dash,
DATA_FOLDER,
EXAMPLES_FOLDER,
get_example_data,
get_slice_json,
merge_slice,
@ -41,34 +41,38 @@ from .helpers import (
)
def load_world_bank_health_n_pop():
def load_world_bank_health_n_pop(only_metadata=False, force=False):
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
data = get_example_data("countries.json.gz")
pdf = pd.read_json(data)
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
pdf.year = pd.to_datetime(pdf.year)
pdf.to_sql(
tbl_name,
db.engine,
if_exists="replace",
chunksize=50,
dtype={
"year": DateTime(),
"country_code": String(3),
"country_name": String(255),
"region": String(255),
},
index=False,
)
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)
if not only_metadata and (not table_exists or force):
data = get_example_data("countries.json.gz")
pdf = pd.read_json(data)
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
pdf.year = pd.to_datetime(pdf.year)
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
if_exists="replace",
chunksize=50,
dtype={
"year": DateTime(),
"country_code": String(3),
"country_name": String(255),
"region": String(255),
},
index=False,
)
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, "countries.md"))
tbl.description = utils.readfile(os.path.join(EXAMPLES_FOLDER, "countries.md"))
tbl.main_dttm_col = "year"
tbl.database = utils.get_or_create_main_db()
tbl.database = database
tbl.filter_select_enabled = True
metrics = [
@ -328,6 +332,7 @@ def load_world_bank_health_n_pop():
if not dash:
dash = Dash()
dash.published = True
js = textwrap.dedent(
"""\
{

View File

@ -666,12 +666,13 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
)
make_transient(copied_dashboard)
for slc in copied_dashboard.slices:
make_transient(slc)
datasource_ids.add((slc.datasource_id, slc.datasource_type))
# add extra params for the import
slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.name,
schema=slc.datasource.name,
schema=slc.datasource.schema,
database_name=slc.datasource.database.name,
)
copied_dashboard.alter_params(remote_id=dashboard_id)
@ -1169,6 +1170,10 @@ class Database(Model, AuditMixinNullable, ImportMixin):
engine = self.get_sqla_engine()
return engine.has_table(table.table_name, table.schema or None)
def has_table_by_name(self, table_name, schema=None):
engine = self.get_sqla_engine()
return engine.has_table(table_name, schema)
@utils.memoized
def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)

View File

@ -81,7 +81,7 @@ class TaggedObject(Model, AuditMixinNullable):
object_id = Column(Integer)
object_type = Column(Enum(ObjectTypes))
tag = relationship("Tag")
tag = relationship("Tag", backref="objects")
def get_tag(name, session, type_):

View File

@ -18,10 +18,8 @@
import json
import logging
import urllib.parse
from celery.utils.log import get_task_logger
from flask import url_for
import requests
from requests.exceptions import RequestException
from sqlalchemy import and_, func
@ -75,13 +73,13 @@ def get_form_data(chart_id, dashboard=None):
return form_data
def get_url(params):
def get_url(chart):
"""Return external URL for warming up a given chart/table cache."""
baseurl = "http://{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}/".format(
**app.config
)
with app.test_request_context():
return urllib.parse.urljoin(baseurl, url_for("Superset.explore_json", **params))
baseurl = "{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}".format(
**app.config
)
return f"{baseurl}{chart.url}"
class Strategy:
@ -136,7 +134,7 @@ class DummyStrategy(Strategy):
session = db.create_scoped_session()
charts = session.query(Slice).all()
return [get_url({"form_data": get_form_data(chart.id)}) for chart in charts]
return [get_url(chart) for chart in charts]
class TopNDashboardsStrategy(Strategy):
@ -180,7 +178,7 @@ class TopNDashboardsStrategy(Strategy):
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)}))
urls.append(get_url(chart))
return urls
@ -229,7 +227,7 @@ class DashboardTagsStrategy(Strategy):
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)}))
urls.append(get_url(chart))
# add charts that are tagged
tagged_objects = (
@ -245,7 +243,7 @@ class DashboardTagsStrategy(Strategy):
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
urls.append(get_url({"form_data": get_form_data(chart.id)}))
urls.append(get_url(chart))
return urls

View File

@ -942,25 +942,37 @@ def user_label(user: User) -> Optional[str]:
def get_or_create_main_db():
from superset import conf, db
get_main_database()
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
from superset import db
from superset.models import core as models
logging.info("Creating database reference")
dbobj = get_main_database(db.session)
if not dbobj:
dbobj = models.Database(
database_name="main", allow_csv_upload=True, expose_in_sqllab=True
)
dbobj.set_sqlalchemy_uri(conf.get("SQLALCHEMY_DATABASE_URI"))
db.session.add(dbobj)
database = (
db.session.query(models.Database).filter_by(database_name=database_name).first()
)
if not database:
logging.info(f"Creating database reference for {database_name}")
database = models.Database(database_name=database_name, *args, **kwargs)
db.session.add(database)
database.set_sqlalchemy_uri(sqlalchemy_uri)
db.session.commit()
return dbobj
return database
def get_main_database(session):
from superset.models import core as models
def get_main_database():
from superset import conf
return session.query(models.Database).filter_by(database_name="main").first()
return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
def get_example_database():
from superset import conf
db_uri = conf.get("SQLALCHEMY_EXAMPLES_URI") or conf.get("SQLALCHEMY_DATABASE_URI")
return get_or_create_db("examples", db_uri)
def is_adhoc_metric(metric) -> bool:

View File

@ -1746,7 +1746,7 @@ class WorldMapViz(BaseViz):
return qry
def get_data(self, df):
from superset.data import countries
from superset.examples import countries
fd = self.form_data
cols = [fd.get("entity")]

View File

@ -31,7 +31,7 @@ ROLE_TABLES_PERM_DATA = {
"database": [
{
"datasource_type": "table",
"name": "main",
"name": "examples",
"schema": [{"name": "", "datasources": ["birth_names"]}],
}
],
@ -42,7 +42,7 @@ ROLE_ALL_PERM_DATA = {
"database": [
{
"datasource_type": "table",
"name": "main",
"name": "examples",
"schema": [{"name": "", "datasources": ["birth_names"]}],
},
{

View File

@ -168,9 +168,6 @@ class SupersetTestCase(unittest.TestCase):
):
security_manager.del_permission_role(public_role, perm)
def get_main_database(self):
return get_main_database(db.session)
def run_sql(
self,
sql,
@ -182,7 +179,7 @@ class SupersetTestCase(unittest.TestCase):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = self.get_main_database().id
dbid = get_main_database().id
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
@ -202,7 +199,7 @@ class SupersetTestCase(unittest.TestCase):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = self.get_main_database().id
dbid = get_main_database().id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,
@ -223,3 +220,7 @@ class SupersetTestCase(unittest.TestCase):
def test_feature_flags(self):
self.assertEquals(is_feature_enabled("foo"), "bar")
self.assertEquals(is_feature_enabled("super"), "set")
def get_dash_by_slug(self, dash_slug):
sesh = db.session()
return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()

View File

@ -128,14 +128,14 @@ class CeleryTestCase(SupersetTestCase):
return json.loads(resp.data)
def test_run_sync_query_dont_exist(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1)
def test_run_sync_query_cta(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
@ -158,7 +158,7 @@ class CeleryTestCase(SupersetTestCase):
self.assertGreater(len(results["data"]), 0)
def test_run_sync_query_cta_no_data(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
db_id = main_db.id
sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
result3 = self.run_sql(db_id, sql_empty_result, "3")
@ -179,7 +179,7 @@ class CeleryTestCase(SupersetTestCase):
return self.run_sql(db_id, sql)
def test_run_async_query(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db)
@ -212,7 +212,7 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_2", main_db)

View File

@ -39,7 +39,6 @@ from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import get_main_database
from superset.views.core import DatabaseView
from .base_tests import SupersetTestCase
from .fixtures.pyodbcRow import Row
@ -345,7 +344,7 @@ class CoreTests(SupersetTestCase):
def test_testconn(self, username="admin"):
self.login(username=username)
database = get_main_database(db.session)
database = utils.get_main_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps(
@ -376,7 +375,7 @@ class CoreTests(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
def test_custom_password_store(self):
database = get_main_database(db.session)
database = utils.get_main_database()
conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
def custom_password_store(uri):
@ -394,13 +393,13 @@ class CoreTests(SupersetTestCase):
# validate that sending a password-masked uri does not over-write the decrypted
# uri
self.login(username=username)
database = get_main_database(db.session)
database = utils.get_main_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
url = "databaseview/edit/{}".format(database.id)
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
database = get_main_database(db.session)
database = utils.get_main_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
def test_warm_up_cache(self):
@ -483,27 +482,27 @@ class CoreTests(SupersetTestCase):
def test_extra_table_metadata(self):
self.login("admin")
dbid = get_main_database(db.session).id
dbid = utils.get_main_database().id
self.get_json_resp(
f"/superset/extra_table_metadata/{dbid}/" "ab_permission_view/panoramix/"
)
def test_process_template(self):
maindb = get_main_database(db.session)
maindb = utils.get_main_database()
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_get_template_kwarg(self):
maindb = get_main_database(db.session)
maindb = utils.get_main_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb, foo="bar")
rendered = tp.process_template(s)
self.assertEqual("bar", rendered)
def test_template_kwarg(self):
maindb = get_main_database(db.session)
maindb = utils.get_main_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(s, foo="bar")
@ -516,7 +515,7 @@ class CoreTests(SupersetTestCase):
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
def test_table_metadata(self):
maindb = get_main_database(db.session)
maindb = utils.get_main_database()
backend = maindb.backend
data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id))
self.assertEqual(data["name"], "ab_user")
@ -615,15 +614,16 @@ class CoreTests(SupersetTestCase):
test_file.write("john,1\n")
test_file.write("paul,2\n")
test_file.close()
main_db_uri = (
db.session.query(models.Database).filter_by(database_name="main").one()
)
example_db = utils.get_example_database()
example_db.allow_csv_upload = True
db_id = example_db.id
db.session.commit()
test_file = open(filename, "rb")
form_data = {
"csv_file": test_file,
"sep": ",",
"name": table_name,
"con": main_db_uri.id,
"con": db_id,
"if_exists": "append",
"index_label": "test_label",
"mangle_dupe_cols": False,
@ -638,8 +638,8 @@ class CoreTests(SupersetTestCase):
try:
# ensure uploaded successfully
form_post = self.get_resp(url, data=form_data)
assert 'CSV file "testCSV.csv" uploaded to table' in form_post
resp = self.get_resp(url, data=form_data)
assert 'CSV file "testCSV.csv" uploaded to table' in resp
finally:
os.remove(filename)
@ -769,7 +769,8 @@ class CoreTests(SupersetTestCase):
def test_select_star(self):
self.login(username="admin")
resp = self.get_resp("/superset/select_star/1/birth_names")
examples_db = utils.get_example_database()
resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names")
self.assertIn("gender", resp)

View File

@ -39,6 +39,7 @@ from superset.db_engine_specs.pinot import PinotEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.core import Database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@ -925,14 +926,14 @@ class DbEngineSpecsTestCase(SupersetTestCase):
) # noqa
def test_column_datatype_to_string(self):
main_db = self.get_main_database()
sqla_table = main_db.get_table("energy_usage")
dialect = main_db.get_dialect()
example_db = get_example_database()
sqla_table = example_db.get_table("energy_usage")
dialect = example_db.get_dialect()
col_names = [
main_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
for c in sqla_table.columns
]
if main_db.backend == "postgresql":
if example_db.backend == "postgresql":
expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
else:
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]

View File

@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
params = {DBREF: id, "database_name": database_name}
dict_rep = {
"database_id": get_main_database(db.session).id,
"database_id": get_main_database().id,
"table_name": name,
"schema": schema,
"id": id,

View File

@ -63,7 +63,7 @@ class ImportExportTests(SupersetTestCase):
name,
ds_id=None,
id=None,
db_name="main",
db_name="examples",
table_name="wb_health_population",
):
params = {
@ -102,7 +102,7 @@ class ImportExportTests(SupersetTestCase):
)
def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]):
params = {"remote_id": id, "database_name": "main"}
params = {"remote_id": id, "database_name": "examples"}
table = SqlaTable(
id=id, schema=schema, table_name=name, params=json.dumps(params)
)
@ -135,10 +135,6 @@ class ImportExportTests(SupersetTestCase):
def get_dash(self, dash_id):
return db.session.query(models.Dashboard).filter_by(id=dash_id).first()
def get_dash_by_slug(self, dash_slug):
sesh = db.session()
return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()
def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
@ -192,9 +188,21 @@ class ImportExportTests(SupersetTestCase):
self.assertEquals(expected_slc_name, actual_slc_name)
self.assertEquals(expected_slc.datasource_type, actual_slc.datasource_type)
self.assertEquals(expected_slc.viz_type, actual_slc.viz_type)
self.assertEquals(
json.loads(expected_slc.params), json.loads(actual_slc.params)
exp_params = json.loads(expected_slc.params)
actual_params = json.loads(actual_slc.params)
diff_params_keys = (
"schema",
"database_name",
"datasource_name",
"remote_id",
"import_time",
)
for k in diff_params_keys:
if k in actual_params:
actual_params.pop(k)
if k in exp_params:
exp_params.pop(k)
self.assertEquals(exp_params, actual_params)
def test_export_1_dashboard(self):
self.login("admin")
@ -233,11 +241,11 @@ class ImportExportTests(SupersetTestCase):
birth_dash.id, world_health_dash.id
)
resp = self.client.get(export_dash_url)
resp_data = json.loads(
resp.data.decode("utf-8"), object_hook=utils.decode_dashboards
)
exported_dashboards = sorted(
json.loads(resp.data.decode("utf-8"), object_hook=utils.decode_dashboards)[
"dashboards"
],
key=lambda d: d.dashboard_title,
resp_data.get("dashboards"), key=lambda d: d.dashboard_title
)
self.assertEquals(2, len(exported_dashboards))
@ -255,10 +263,7 @@ class ImportExportTests(SupersetTestCase):
)
exported_tables = sorted(
json.loads(resp.data.decode("utf-8"), object_hook=utils.decode_dashboards)[
"datasources"
],
key=lambda t: t.table_name,
resp_data.get("datasources"), key=lambda t: t.table_name
)
self.assertEquals(2, len(exported_tables))
self.assert_table_equals(
@ -297,7 +302,7 @@ class ImportExportTests(SupersetTestCase):
self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm)
def test_import_slices_for_non_existent_table(self):
with self.assertRaises(IndexError):
with self.assertRaises(AttributeError):
models.Slice.import_obj(
self.create_slice("Import Me 3", id=10004, table_name="non_existent"),
None,
@ -447,7 +452,7 @@ class ImportExportTests(SupersetTestCase):
imported = self.get_table(imported_id)
self.assert_table_equals(table, imported)
self.assertEquals(
{"remote_id": 10002, "import_time": 1990, "database_name": "main"},
{"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
json.loads(imported.params),
)

View File

@ -14,23 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset import data
from superset import examples
from superset.cli import load_test_users_run
from .base_tests import SupersetTestCase
class SupersetDataFrameTestCase(SupersetTestCase):
def test_load_css_templates(self):
data.load_css_templates()
examples.load_css_templates()
def test_load_energy(self):
data.load_energy()
examples.load_energy()
def test_load_world_bank_health_n_pop(self):
data.load_world_bank_health_n_pop()
examples.load_world_bank_health_n_pop()
def test_load_birth_names(self):
data.load_birth_names()
examples.load_birth_names()
def test_load_test_users_run(self):
load_test_users_run()
def test_load_unicode_test_data(self):
examples.load_unicode_test_data()

View File

@ -20,9 +20,9 @@ import unittest
import pandas
from sqlalchemy.engine.url import make_url
from superset import app, db
from superset import app
from superset.models.core import Database
from superset.utils.core import get_main_database, QueryStatus
from superset.utils.core import get_example_database, get_main_database, QueryStatus
from .base_tests import SupersetTestCase
@ -101,7 +101,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertNotEquals(example_user, user_name)
def test_select_star(self):
main_db = get_main_database(db.session)
main_db = get_example_database()
table_name = "energy_usage"
sql = main_db.select_star(table_name, show_cols=False, latest_partition=False)
expected = textwrap.dedent(
@ -124,7 +124,7 @@ class DatabaseModelTestCase(SupersetTestCase):
assert sql.startswith(expected)
def test_single_statement(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None)
@ -134,7 +134,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(df.iat[0, 0], 1)
def test_multi_statement(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None)

View File

@ -83,7 +83,7 @@ class SqlLabTests(SupersetTestCase):
self.assertLess(0, len(data["data"]))
def test_sql_json_has_access(self):
main_db = get_main_database(db.session)
main_db = get_main_database()
security_manager.add_permission_view_menu("database_access", main_db.perm)
db.session.commit()
main_db_permission_view = (

View File

@ -23,14 +23,13 @@ from superset.models.core import Log
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import (
DashboardTagsStrategy,
DummyStrategy,
get_form_data,
TopNDashboardsStrategy,
)
from .base_tests import SupersetTestCase
TEST_URL = "http://0.0.0.0:8081/superset/explore_json"
URL_PREFIX = "0.0.0.0:8081"
class CacheWarmUpTests(SupersetTestCase):
@ -141,61 +140,61 @@ class CacheWarmUpTests(SupersetTestCase):
}
self.assertEqual(result, expected)
def test_dummy_strategy(self):
strategy = DummyStrategy()
result = sorted(strategy.get_urls())
expected = [
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+1%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+17%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+18%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+19%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+8%7D",
]
self.assertEqual(result, expected)
def test_top_n_dashboards_strategy(self):
# create a top visited dashboard
db.session.query(Log).delete()
self.login(username="admin")
dash = self.get_dash_by_slug("births")
for _ in range(10):
self.client.get("/superset/dashboard/3/")
self.client.get(f"/superset/dashboard/{dash.id}/")
strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls())
expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"]
expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
self.assertEqual(result, expected)
def test_dashboard_tags(self):
strategy = DashboardTagsStrategy(["tag1"])
def reset_tag(self, tag):
"""Remove associated object from tag, used to reset tests"""
if tag.objects:
for o in tag.objects:
db.session.delete(o)
db.session.commit()
def test_dashboard_tags(self):
tag1 = get_tag("tag1", db.session, TagTypes.custom)
# delete first to make test idempotent
self.reset_tag(tag1)
strategy = DashboardTagsStrategy(["tag1"])
result = sorted(strategy.get_urls())
expected = []
self.assertEqual(result, expected)
# tag dashboard 3 with `tag1`
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
object_id = 3
dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=object_id, object_type=ObjectTypes.dashboard
tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
)
db.session.add(tagged_object)
db.session.commit()
result = sorted(strategy.get_urls())
expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"]
self.assertEqual(result, expected)
self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagTypes.custom)
self.reset_tag(tag2)
result = sorted(strategy.get_urls())
expected = []
self.assertEqual(result, expected)
# tag chart 30 with `tag2`
tag2 = get_tag("tag2", db.session, TagTypes.custom)
object_id = 30
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
slc = dash.slices[0]
tag2_urls = [f"{URL_PREFIX}{slc.url}"]
object_id = slc.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
)
@ -203,14 +202,10 @@ class CacheWarmUpTests(SupersetTestCase):
db.session.commit()
result = sorted(strategy.get_urls())
expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D"]
self.assertEqual(result, expected)
self.assertEqual(result, tag2_urls)
strategy = DashboardTagsStrategy(["tag1", "tag2"])
result = sorted(strategy.get_urls())
expected = [
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D",
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D",
]
expected = sorted(tag1_urls + tag2_urls)
self.assertEqual(result, expected)

View File

@ -109,7 +109,6 @@ class BaseVizTestCase(SupersetTestCase):
datasource.get_col = Mock(return_value=mock_dttm_col)
mock_dttm_col.python_date_format = "epoch_ms"
result = test_viz.get_df(query_obj)
print(result)
import logging
logging.info(result)

View File

@ -46,7 +46,7 @@ setenv =
PYTHONPATH = {toxinidir}
SUPERSET_CONFIG = tests.superset_test_config
SUPERSET_HOME = {envtmpdir}
py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
py36-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
py36-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
whitelist_externals =