chore: Enforce Mypy for non-tests (#15757)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
d26254099e
commit
ab4e3b9bf9
|
|
@ -24,6 +24,7 @@ from dataclasses import dataclass
|
|||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import click
|
||||
from click.core import Context
|
||||
|
||||
try:
|
||||
from github import BadCredentialsException, Github, PullRequest, Repository
|
||||
|
|
@ -50,7 +51,7 @@ class GitLog:
|
|||
author_email: str = ""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
""" A log entry is considered equal if it has the same PR number """
|
||||
"""A log entry is considered equal if it has the same PR number"""
|
||||
if isinstance(other, self.__class__):
|
||||
return other.pr_number == self.pr_number
|
||||
return False
|
||||
|
|
@ -170,7 +171,7 @@ class GitChangeLog:
|
|||
|
||||
def _parse_change_log(
|
||||
self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str,
|
||||
):
|
||||
) -> None:
|
||||
formatted_pr = (
|
||||
f"- [#{pr_info.get('id')}]"
|
||||
f"(https://github.com/{SUPERSET_REPO}/pull/{pr_info.get('id')}) "
|
||||
|
|
@ -324,8 +325,8 @@ def print_title(message: str) -> None:
|
|||
@click.pass_context
|
||||
@click.option("--previous_version", help="The previous release version", required=True)
|
||||
@click.option("--current_version", help="The current release version", required=True)
|
||||
def cli(ctx, previous_version: str, current_version: str) -> None:
|
||||
""" Welcome to change log generator """
|
||||
def cli(ctx: Context, previous_version: str, current_version: str) -> None:
|
||||
"""Welcome to change log generator"""
|
||||
previous_logs = GitLogs(previous_version)
|
||||
current_logs = GitLogs(current_version)
|
||||
previous_logs.fetch()
|
||||
|
|
@ -337,7 +338,7 @@ def cli(ctx, previous_version: str, current_version: str) -> None:
|
|||
@cli.command("compare")
|
||||
@click.pass_obj
|
||||
def compare(base_parameters: BaseParameters) -> None:
|
||||
""" Compares both versions (by PR) """
|
||||
"""Compares both versions (by PR)"""
|
||||
previous_logs = base_parameters.previous_logs
|
||||
current_logs = base_parameters.current_logs
|
||||
print_title(
|
||||
|
|
@ -369,7 +370,7 @@ def compare(base_parameters: BaseParameters) -> None:
|
|||
def change_log(
|
||||
base_parameters: BaseParameters, csv: str, access_token: str, risk: bool
|
||||
) -> None:
|
||||
""" Outputs a changelog (by PR) """
|
||||
"""Outputs a changelog (by PR)"""
|
||||
previous_logs = base_parameters.previous_logs
|
||||
current_logs = base_parameters.current_logs
|
||||
previous_diff_logs = previous_logs.diff(current_logs)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@
|
|||
#
|
||||
import smtplib
|
||||
import ssl
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from click.core import Context
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
|
|
@ -50,7 +52,7 @@ def send_email(
|
|||
sender_email: str,
|
||||
receiver_email: str,
|
||||
message: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Send a simple text email (SMTP)
|
||||
"""
|
||||
|
|
@ -61,7 +63,7 @@ def send_email(
|
|||
server.sendmail(sender_email, receiver_email, message)
|
||||
|
||||
|
||||
def render_template(template_file: str, **kwargs) -> str:
|
||||
def render_template(template_file: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
Simple render template based on named parameters
|
||||
|
||||
|
|
@ -73,7 +75,9 @@ def render_template(template_file: str, **kwargs) -> str:
|
|||
return template.render(kwargs)
|
||||
|
||||
|
||||
def inter_send_email(username, password, sender_email, receiver_email, message):
|
||||
def inter_send_email(
|
||||
username: str, password: str, sender_email: str, receiver_email: str, message: str
|
||||
) -> None:
|
||||
print("--------------------------")
|
||||
print("SMTP Message")
|
||||
print("--------------------------")
|
||||
|
|
@ -102,16 +106,16 @@ def inter_send_email(username, password, sender_email, receiver_email, message):
|
|||
|
||||
class BaseParameters(object):
|
||||
def __init__(
|
||||
self, email=None, username=None, password=None, version=None, version_rc=None
|
||||
):
|
||||
self, email: str, username: str, password: str, version: str, version_rc: str,
|
||||
) -> None:
|
||||
self.email = email
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.version = version
|
||||
self.version_rc = version_rc
|
||||
self.template_arguments = dict()
|
||||
self.template_arguments: Dict[str, Any] = {}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Apache Credentials: {self.email}/{self.username}/{self.version}/{self.version_rc}"
|
||||
|
||||
|
||||
|
|
@ -133,8 +137,15 @@ class BaseParameters(object):
|
|||
)
|
||||
@click.option("--version", envvar="SUPERSET_VERSION")
|
||||
@click.option("--version_rc", envvar="SUPERSET_VERSION_RC")
|
||||
def cli(ctx, apache_email, apache_username, apache_password, version, version_rc):
|
||||
""" Welcome to releasing send email CLI interface! """
|
||||
def cli(
|
||||
ctx: Context,
|
||||
apache_email: str,
|
||||
apache_username: str,
|
||||
apache_password: str,
|
||||
version: str,
|
||||
version_rc: str,
|
||||
) -> None:
|
||||
"""Welcome to releasing send email CLI interface!"""
|
||||
base_parameters = BaseParameters(
|
||||
apache_email, apache_username, apache_password, version, version_rc
|
||||
)
|
||||
|
|
@ -155,7 +166,7 @@ def cli(ctx, apache_email, apache_username, apache_password, version, version_rc
|
|||
prompt="The receiver email (To:)",
|
||||
)
|
||||
@click.pass_obj
|
||||
def vote_pmc(base_parameters, receiver_email):
|
||||
def vote_pmc(base_parameters: BaseParameters, receiver_email: str) -> None:
|
||||
template_file = "email_templates/vote_pmc.j2"
|
||||
base_parameters.template_arguments["receiver_email"] = receiver_email
|
||||
message = render_template(template_file, **base_parameters.template_arguments)
|
||||
|
|
@ -202,13 +213,13 @@ def vote_pmc(base_parameters, receiver_email):
|
|||
)
|
||||
@click.pass_obj
|
||||
def result_pmc(
|
||||
base_parameters,
|
||||
receiver_email,
|
||||
vote_bindings,
|
||||
vote_nonbindings,
|
||||
vote_negatives,
|
||||
vote_thread,
|
||||
):
|
||||
base_parameters: BaseParameters,
|
||||
receiver_email: str,
|
||||
vote_bindings: str,
|
||||
vote_nonbindings: str,
|
||||
vote_negatives: str,
|
||||
vote_thread: str,
|
||||
) -> None:
|
||||
template_file = "email_templates/result_pmc.j2"
|
||||
base_parameters.template_arguments["receiver_email"] = receiver_email
|
||||
base_parameters.template_arguments["vote_bindings"] = string_comma_to_list(
|
||||
|
|
@ -239,7 +250,7 @@ def result_pmc(
|
|||
prompt="The receiver email (To:)",
|
||||
)
|
||||
@click.pass_obj
|
||||
def announce(base_parameters, receiver_email):
|
||||
def announce(base_parameters: BaseParameters, receiver_email: str) -> None:
|
||||
template_file = "email_templates/announce.j2"
|
||||
base_parameters.template_arguments["receiver_email"] = receiver_email
|
||||
message = render_template(template_file, **base_parameters.template_arguments)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
import logging
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from cachelib.file import FileSystemCache
|
||||
from celery.schedules import crontab
|
||||
|
|
@ -30,7 +31,7 @@ from celery.schedules import crontab
|
|||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_env_variable(var_name, default=None):
|
||||
def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
|
||||
"""Get the environment variable or raise exception."""
|
||||
try:
|
||||
return os.environ[var_name]
|
||||
|
|
@ -63,8 +64,8 @@ SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
|
|||
|
||||
REDIS_HOST = get_env_variable("REDIS_HOST")
|
||||
REDIS_PORT = get_env_variable("REDIS_PORT")
|
||||
REDIS_CELERY_DB = get_env_variable("REDIS_CELERY_DB", 0)
|
||||
REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", 1)
|
||||
REDIS_CELERY_DB = get_env_variable("REDIS_CELERY_DB", "0")
|
||||
REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", "1")
|
||||
|
||||
RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ Example:
|
|||
./cancel_github_workflows.py 1024 --include-last
|
||||
"""
|
||||
import os
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
|
@ -45,7 +45,9 @@ github_token = os.environ.get("GITHUB_TOKEN")
|
|||
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
|
||||
|
||||
|
||||
def request(method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs):
|
||||
def request(
|
||||
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
resp = requests.request(
|
||||
method,
|
||||
f"https://api.github.com/{endpoint.lstrip('/')}",
|
||||
|
|
@ -57,10 +59,14 @@ def request(method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kw
|
|||
return resp
|
||||
|
||||
|
||||
def list_runs(repo: str, params=None):
|
||||
def list_runs(
|
||||
repo: str, params: Optional[Dict[str, str]] = None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""List all github workflow runs.
|
||||
Returns:
|
||||
An iterator that will iterate through all pages of matching runs."""
|
||||
if params is None:
|
||||
params = {}
|
||||
page = 1
|
||||
total_count = 10000
|
||||
while page * 100 < total_count:
|
||||
|
|
@ -75,11 +81,11 @@ def list_runs(repo: str, params=None):
|
|||
page += 1
|
||||
|
||||
|
||||
def cancel_run(repo: str, run_id: Union[str, int]):
|
||||
def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
|
||||
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
|
||||
|
||||
|
||||
def get_pull_request(repo: str, pull_number: Union[str, int]):
|
||||
def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
|
||||
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
|
||||
|
||||
|
||||
|
|
@ -89,7 +95,7 @@ def get_runs(
|
|||
user: Optional[str] = None,
|
||||
statuses: Iterable[str] = ("queued", "in_progress"),
|
||||
events: Iterable[str] = ("pull_request", "push"),
|
||||
):
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get workflow runs associated with the given branch"""
|
||||
return [
|
||||
item
|
||||
|
|
@ -101,7 +107,7 @@ def get_runs(
|
|||
]
|
||||
|
||||
|
||||
def print_commit(commit, branch):
|
||||
def print_commit(commit: Dict[str, Any], branch: str) -> None:
|
||||
"""Print out commit message for verification"""
|
||||
indented_message = " \n".join(commit["message"].split("\n"))
|
||||
date_str = (
|
||||
|
|
@ -151,7 +157,7 @@ def cancel_github_workflows(
|
|||
event: List[str],
|
||||
include_last: bool,
|
||||
include_running: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""Cancel running or queued GitHub workflows by branch or pull request ID"""
|
||||
if not github_token:
|
||||
raise ClickException("Please provide GITHUB_TOKEN as an env variable")
|
||||
|
|
@ -231,7 +237,7 @@ def cancel_github_workflows(
|
|||
try:
|
||||
print(f"[{entry['status']}] {entry['name']}", end="\r")
|
||||
cancel_run(repo, entry["id"])
|
||||
print(f"[Cancled] {entry['name']} ")
|
||||
print(f"[Canceled] {entry['name']} ")
|
||||
except ClickException as error:
|
||||
print(f"[Error: {error.message}] {entry['name']} ")
|
||||
print("")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from collections import defaultdict
|
|||
from superset import security_manager
|
||||
|
||||
|
||||
def cleanup_permissions():
|
||||
def cleanup_permissions() -> None:
|
||||
# 1. Clean up duplicates.
|
||||
pvms = security_manager.get_session.query(
|
||||
security_manager.permissionview_model
|
||||
|
|
|
|||
16
setup.cfg
16
setup.cfg
|
|
@ -30,21 +30,23 @@ combine_as_imports = true
|
|||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
known_first_party = superset
|
||||
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
|
||||
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
|
||||
multi_line_output = 3
|
||||
order_by_type = false
|
||||
|
||||
[mypy]
|
||||
check_untyped_defs = true
|
||||
disallow_any_generics = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
no_implicit_optional = true
|
||||
warn_unused_ignores = true
|
||||
|
||||
[mypy-superset.*]
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = false
|
||||
|
||||
[mypy-superset.migrations.versions.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-tests.*]
|
||||
check_untyped_defs = false
|
||||
disallow_untyped_calls = false
|
||||
disallow_untyped_defs = false
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -32,7 +32,7 @@ with io.open("README.md", "r", encoding="utf-8") as f:
|
|||
long_description = f.read()
|
||||
|
||||
|
||||
def get_git_sha():
|
||||
def get_git_sha() -> str:
|
||||
try:
|
||||
s = subprocess.check_output(["git", "rev-parse", "HEAD"])
|
||||
return s.decode().strip()
|
||||
|
|
|
|||
|
|
@ -170,9 +170,7 @@ class QueryObject:
|
|||
# 2. { label: 'label_name' } - legacy format for a predefined metric
|
||||
# 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
|
||||
self.metrics = metrics and [
|
||||
x
|
||||
if isinstance(x, str) or is_adhoc_metric(x)
|
||||
else x["label"] # type: ignore
|
||||
x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]
|
||||
for x in metrics
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -734,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
self.table_name, schema=self.schema, show_cols=False, latest_partition=False
|
||||
)
|
||||
|
||||
@property # type: ignore
|
||||
@property
|
||||
def health_check_message(self) -> Optional[str]:
|
||||
check = config["DATASET_HEALTH_CHECK"]
|
||||
return check(self) if check else None
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
|
||||
# try to connect
|
||||
sqlalchemy_uri = engine_spec.build_sqlalchemy_uri(
|
||||
self._properties.get("parameters", None), # type: ignore
|
||||
self._properties.get("parameters"), # type: ignore
|
||||
encrypted_extra,
|
||||
)
|
||||
if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
|
||||
|
|
|
|||
|
|
@ -1444,7 +1444,7 @@ class BasicParametersMixin:
|
|||
errors: List[SupersetError] = []
|
||||
|
||||
required = {"host", "port", "username", "database"}
|
||||
present = {key for key in parameters if parameters.get(key, ())} # type: ignore
|
||||
present = {key for key in parameters if parameters.get(key, ())}
|
||||
missing = sorted(required - present)
|
||||
|
||||
if missing:
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
|||
return message
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
def get_column_spec(
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
|
|
|
|||
|
|
@ -276,7 +276,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
|||
return extra
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
def get_column_spec(
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
|
|
|
|||
|
|
@ -1212,7 +1212,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
|||
return super().is_readonly_query(parsed_query) or parsed_query.is_show()
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
def get_column_spec(
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class SupersetAppInitializer:
|
|||
self.manifest: Dict[Any, Any] = {}
|
||||
|
||||
@deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore # pylint: disable=line-too-long
|
||||
@property # type: ignore
|
||||
@property
|
||||
def flask_app(self) -> SupersetApp:
|
||||
return self.superset_app
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ class SupersetAppInitializer:
|
|||
|
||||
# Grab each call into the task and set up an app context
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
with superset_app.app_context(): # type: ignore
|
||||
with superset_app.app_context():
|
||||
return task_base.__call__(self, *args, **kwargs)
|
||||
|
||||
celery_app.Task = AppContextTask
|
||||
|
|
@ -573,7 +573,7 @@ class SupersetAppInitializer:
|
|||
self.configure_middlewares()
|
||||
self.configure_cache()
|
||||
|
||||
with self.superset_app.app_context(): # type: ignore
|
||||
with self.superset_app.app_context():
|
||||
self.init_app_in_ctx()
|
||||
|
||||
self.post_init()
|
||||
|
|
@ -689,7 +689,7 @@ class SupersetAppInitializer:
|
|||
def setup_db(self) -> None:
|
||||
db.init_app(self.superset_app)
|
||||
|
||||
with self.superset_app.app_context(): # type: ignore
|
||||
with self.superset_app.app_context():
|
||||
pessimistic_connection_handling(db.engine)
|
||||
|
||||
migrate.init_app(self.superset_app, db=db, directory=APP_DIR + "/migrations")
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from superset.utils import core as utils
|
|||
from superset.utils.hashing import md5_sha_from_str
|
||||
from superset.utils.memoized import memoized
|
||||
from superset.utils.urls import get_url_path
|
||||
from superset.viz import BaseViz, viz_types # type: ignore
|
||||
from superset.viz import BaseViz, viz_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ def _get_slice_data(
|
|||
|
||||
# Parse the csv file and generate HTML
|
||||
columns = rows.pop(0)
|
||||
with app.app_context(): # type: ignore
|
||||
with app.app_context():
|
||||
body = render_template(
|
||||
"superset/reports/slice_data.html",
|
||||
columns=columns,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from superset.utils.hashing import md5_sha_from_dict
|
|||
if TYPE_CHECKING:
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
|
||||
config = app.config # type: ignore
|
||||
config = app.config
|
||||
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -15,6 +15,9 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, invalid-name
|
||||
from flask.ctx import AppContext
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
|
||||
|
||||
|
|
@ -24,16 +27,30 @@ class ProgrammingError(Exception):
|
|||
"""
|
||||
|
||||
|
||||
def test_validate_parameters_simple(mocker, app_context):
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
def test_validate_parameters_simple(
|
||||
mocker: MockFixture, app_context: AppContext,
|
||||
) -> None:
|
||||
from superset.db_engine_specs.gsheets import (
|
||||
GSheetsEngineSpec,
|
||||
GSheetsParametersType,
|
||||
)
|
||||
|
||||
parameters = {}
|
||||
parameters: GSheetsParametersType = {
|
||||
"credentials_info": {},
|
||||
"query": {},
|
||||
"table_catalog": {},
|
||||
}
|
||||
errors = GSheetsEngineSpec.validate_parameters(parameters)
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_parameters_catalog(mocker, app_context):
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
def test_validate_parameters_catalog(
|
||||
mocker: MockFixture, app_context: AppContext,
|
||||
) -> None:
|
||||
from superset.db_engine_specs.gsheets import (
|
||||
GSheetsEngineSpec,
|
||||
GSheetsParametersType,
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.email = "admin@example.com"
|
||||
|
|
@ -47,7 +64,9 @@ def test_validate_parameters_catalog(mocker, app_context):
|
|||
ProgrammingError("Unsupported table: https://www.google.com/"),
|
||||
]
|
||||
|
||||
parameters = {
|
||||
parameters: GSheetsParametersType = {
|
||||
"credentials_info": {},
|
||||
"query": {},
|
||||
"table_catalog": {
|
||||
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
|
||||
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
|
||||
|
|
@ -114,12 +133,17 @@ def test_validate_parameters_catalog(mocker, app_context):
|
|||
),
|
||||
]
|
||||
create_engine.assert_called_with(
|
||||
"gsheets://", service_account_info=None, subject="admin@example.com",
|
||||
"gsheets://", service_account_info={}, subject="admin@example.com",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parameters_catalog_and_credentials(mocker, app_context):
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
def test_validate_parameters_catalog_and_credentials(
|
||||
mocker: MockFixture, app_context: AppContext,
|
||||
) -> None:
|
||||
from superset.db_engine_specs.gsheets import (
|
||||
GSheetsEngineSpec,
|
||||
GSheetsParametersType,
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.email = "admin@example.com"
|
||||
|
|
@ -133,13 +157,14 @@ def test_validate_parameters_catalog_and_credentials(mocker, app_context):
|
|||
ProgrammingError("Unsupported table: https://www.google.com/"),
|
||||
]
|
||||
|
||||
parameters = {
|
||||
parameters: GSheetsParametersType = {
|
||||
"credentials_info": {},
|
||||
"query": {},
|
||||
"table_catalog": {
|
||||
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
|
||||
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
|
||||
"not_a_sheet": "https://www.google.com/",
|
||||
},
|
||||
"credentials_info": "SECRET",
|
||||
}
|
||||
errors = GSheetsEngineSpec.validate_parameters(parameters)
|
||||
assert errors == [
|
||||
|
|
@ -173,5 +198,5 @@ def test_validate_parameters_catalog_and_credentials(mocker, app_context):
|
|||
),
|
||||
]
|
||||
create_engine.assert_called_with(
|
||||
"gsheets://", service_account_info="SECRET", subject="admin@example.com",
|
||||
"gsheets://", service_account_info={}, subject="admin@example.com",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue