[mypy] Enforcing typing for superset.examples (#9469)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-04-06 09:11:49 -07:00 committed by GitHub
parent c0807c1af7
commit dcb7b8350e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 75 additions and 49 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -26,7 +26,7 @@ from superset.utils.core import get_example_database
from .helpers import get_example_data, TBL
def load_bart_lines(only_metadata=False, force=False):
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "bart_lines"
database = get_example_database()
table_exists = database.has_table_by_name(tbl_name)

View File

@ -16,6 +16,7 @@
# under the License.
import json
import textwrap
from typing import Dict, Union
import pandas as pd
from sqlalchemy import DateTime, String
@ -23,6 +24,7 @@ from sqlalchemy.sql import column
from superset import db, security_manager
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_database
@ -38,7 +40,9 @@ from .helpers import (
)
def gen_filter(subject, comparator, operator="=="):
def gen_filter(
subject: str, comparator: str, operator: str = "=="
) -> Dict[str, Union[bool, str]]:
return {
"clause": "WHERE",
"comparator": comparator,
@ -49,7 +53,7 @@ def gen_filter(subject, comparator, operator="=="):
}
def load_data(tbl_name, database):
def load_data(tbl_name: str, database: Database) -> None:
pdf = pd.read_json(get_example_data("birth_names.json.gz"))
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf.to_sql(
@ -69,7 +73,7 @@ def load_data(tbl_name, database):
print("-" * 80)
def load_birth_names(only_metadata=False, force=False):
def load_birth_names(only_metadata: bool = False, force: bool = False) -> None:
"""Loading birth name dataset from a zip file in the repo"""
# pylint: disable=too-many-locals
tbl_name = "birth_names"

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains data related to countries and is used for geo mapping"""
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
countries: List[Dict[str, Any]] = [
{
@ -2498,13 +2498,8 @@ for lookup in lookups:
all_lookups[lookup][country[lookup].lower()] = country
def get(field, symbol):
def get(field: str, symbol: str) -> Optional[Dict[str, Any]]:
"""
Get country data based on a standard code and a symbol
>>> get('cioc', 'CUB')['name']
"Cuba"
>>> get('cca2', 'CA')['name']
"Canada"
"""
return all_lookups[field].get(symbol.lower())

View File

@ -34,7 +34,7 @@ from .helpers import (
)
def load_country_map_data(only_metadata=False, force=False):
def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
"""Loading data for map with country map"""
tbl_name = "birth_france_by_region"
database = utils.get_example_database()

View File

@ -20,7 +20,7 @@ from superset import db
from superset.models.core import CssTemplate
def load_css_templates():
def load_css_templates() -> None:
"""Loads 2 css templates to demonstrate the feature"""
print("Creating default CSS templates")

View File

@ -167,7 +167,7 @@ POSITION_JSON = """\
}"""
def load_deck_dash():
def load_deck_dash() -> None:
print("Loading deck.gl dashboard")
slices = []
tbl = db.session.query(TBL).filter_by(table_name="long_lat").first()

View File

@ -29,7 +29,7 @@ from superset.utils import core as utils
from .helpers import get_example_data, merge_slice, misc_dash_slices, TBL
def load_energy(only_metadata=False, force=False):
def load_energy(only_metadata: bool = False, force: bool = False) -> None:
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = "energy_usage"
database = utils.get_example_database()

View File

@ -23,7 +23,7 @@ from superset.utils import core as utils
from .helpers import get_example_data, TBL
def load_flights(only_metadata=False, force=False):
def load_flights(only_metadata: bool = False, force: bool = False) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
database = utils.get_example_database()

View File

@ -19,7 +19,7 @@ import json
import os
import zlib
from io import BytesIO
from typing import Set
from typing import Any, Dict, List, Set
from urllib import request
from superset import app, db
@ -41,7 +41,7 @@ EXAMPLES_FOLDER = os.path.join(config["BASE_DIR"], "examples")
misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard
def update_slice_ids(layout_dict, slices):
def update_slice_ids(layout_dict: Dict[Any, Any], slices: List[Slice]) -> None:
charts = [
component
for component in layout_dict.values()
@ -53,7 +53,7 @@ def update_slice_ids(layout_dict, slices):
chart_component["meta"]["chartId"] = int(slices[i].id)
def merge_slice(slc):
def merge_slice(slc: Slice) -> None:
o = db.session.query(Slice).filter_by(slice_name=slc.slice_name).first()
if o:
db.session.delete(o)
@ -61,13 +61,15 @@ def merge_slice(slc):
db.session.commit()
def get_slice_json(defaults, **kwargs):
def get_slice_json(defaults: Dict[Any, Any], **kwargs: Any) -> str:
d = defaults.copy()
d.update(kwargs)
return json.dumps(d, indent=4, sort_keys=True)
def get_example_data(filepath, is_gzip=True, make_bytes=False):
def get_example_data(
filepath: str, is_gzip: bool = True, make_bytes: bool = False
) -> BytesIO:
content = request.urlopen(f"{BASE_URL}{filepath}?raw=true").read()
if is_gzip:
content = zlib.decompress(content, zlib.MAX_WBITS | 16)

View File

@ -34,7 +34,7 @@ from .helpers import (
)
def load_long_lat_data(only_metadata=False, force=False):
def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
"""Loading lat/long data from a csv file in the repo"""
tbl_name = "long_lat"
database = utils.get_example_database()

View File

@ -26,7 +26,7 @@ from .helpers import misc_dash_slices, update_slice_ids
DASH_SLUG = "misc_charts"
def load_misc_dashboard():
def load_misc_dashboard() -> None:
"""Loading a dashboard featuring misc charts"""
print("Creating the dashboard")

View File

@ -24,7 +24,7 @@ from .helpers import merge_slice, misc_dash_slices
from .world_bank import load_world_bank_health_n_pop
def load_multi_line(only_metadata=False):
def load_multi_line(only_metadata: bool = False) -> None:
load_world_bank_health_n_pop(only_metadata)
load_birth_names(only_metadata)
ids = [

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, Optional, Tuple
import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, String
@ -32,7 +33,9 @@ from .helpers import (
)
def load_multiformat_time_series(only_metadata=False, force=False):
def load_multiformat_time_series(
only_metadata: bool = False, force: bool = False
) -> None:
"""Loading time series data from a zip file in the repo"""
tbl_name = "multiformat_time_series"
database = get_example_database()
@ -70,15 +73,15 @@ def load_multiformat_time_series(only_metadata=False, force=False):
obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = database
dttm_and_expr_dict = {
"ds": [None, None],
"ds2": [None, None],
"epoch_s": ["epoch_s", None],
"epoch_ms": ["epoch_ms", None],
"string2": ["%Y%m%d-%H%M%S", None],
"string1": ["%Y-%m-%d^%H:%M:%S", None],
"string0": ["%Y-%m-%d %H:%M:%S.%f", None],
"string3": ["%Y/%m/%d%H:%M:%S.%f", None],
dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = {
"ds": (None, None),
"ds2": (None, None),
"epoch_s": ("epoch_s", None),
"epoch_ms": ("epoch_ms", None),
"string2": ("%Y%m%d-%H%M%S", None),
"string1": ("%Y-%m-%d^%H:%M:%S", None),
"string0": ("%Y-%m-%d %H:%M:%S.%f", None),
"string3": ("%Y/%m/%d%H:%M:%S.%f", None),
}
for col in obj.columns:
dttm_and_expr = dttm_and_expr_dict[col.column_name]

View File

@ -25,7 +25,7 @@ from superset.utils import core as utils
from .helpers import get_example_data, TBL
def load_paris_iris_geojson(only_metadata=False, force=False):
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "paris_iris_mapping"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)

View File

@ -25,7 +25,9 @@ from superset.utils import core as utils
from .helpers import config, get_example_data, get_slice_json, merge_slice, TBL
def load_random_time_series_data(only_metadata=False, force=False):
def load_random_time_series_data(
only_metadata: bool = False, force: bool = False
) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "random_time_series"
database = utils.get_example_database()

View File

@ -25,7 +25,9 @@ from superset.utils import core as utils
from .helpers import get_example_data, TBL
def load_sf_population_polygons(only_metadata=False, force=False):
def load_sf_population_polygons(
only_metadata: bool = False, force: bool = False
) -> None:
tbl_name = "sf_population_polygons"
database = utils.get_example_database()
table_exists = database.has_table_by_name(tbl_name)

View File

@ -25,7 +25,7 @@ from superset.models.slice import Slice
from .helpers import update_slice_ids
def load_tabbed_dashboard(_=False):
def load_tabbed_dashboard(_: bool = False) -> None:
"""Creating a tabbed dashboard"""
print("Creating a dashboard with nested tabs")

View File

@ -36,7 +36,7 @@ from .helpers import (
)
def load_unicode_test_data(only_metadata=False, force=False):
def load_unicode_test_data(only_metadata: bool = False, force: bool = False) -> None:
"""Loading unicode test dataset from a csv file in the repo"""
tbl_name = "unicode_test"
database = utils.get_example_database()

View File

@ -41,9 +41,9 @@ from .helpers import (
)
def load_world_bank_health_n_pop(
only_metadata=False, force=False
): # pylint: disable=too-many-locals
def load_world_bank_health_n_pop( # pylint: disable=too-many-locals
only_metadata: bool = False, force: bool = False
) -> None:
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
database = utils.get_example_database()

View File

@ -37,7 +37,18 @@ from email.mime.text import MIMEText
from email.utils import formatdate
from enum import Enum
from time import struct_time
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union
from typing import (
Any,
Dict,
Iterator,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib.parse import unquote_plus
import bleach
@ -72,6 +83,9 @@ try:
except ImportError:
pass
if TYPE_CHECKING:
from superset.models.core import Database
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
@ -944,7 +958,7 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
return database
def get_example_database():
def get_example_database() -> "Database":
from superset import conf
db_uri = conf.get("SQLALCHEMY_EXAMPLES_URI") or conf.get("SQLALCHEMY_DATABASE_URI")
@ -1057,11 +1071,15 @@ def get_since_until(
else:
rel, num, grain = time_range.split()
if rel == "Last":
since = relative_start - relativedelta(**{grain: int(num)}) # type: ignore
since = relative_start - relativedelta( # type: ignore
**{grain: int(num)} # type: ignore
)
until = relative_end
else: # rel == 'Next'
since = relative_start
until = relative_end + relativedelta(**{grain: int(num)}) # type: ignore
until = relative_end + relativedelta( # type: ignore
**{grain: int(num)} # type: ignore
)
else:
since = since or ""
if since:

View File

@ -1875,8 +1875,8 @@ class WorldMapViz(BaseViz):
for row in d:
country = None
if isinstance(row["country"], str):
country = countries.get(fd.get("country_fieldtype"), row["country"])
if "country_fieldtype" in fd:
country = countries.get(fd["country_fieldtype"], row["country"])
if country:
row["country"] = country["cca3"]
row["latitude"] = country["lat"]