diff --git a/freqtrade/rpc/api_server.py b/freqtrade/rpc/api_server.py index 711202b27..67bbfdc78 100644 --- a/freqtrade/rpc/api_server.py +++ b/freqtrade/rpc/api_server.py @@ -2,7 +2,7 @@ import logging import threading from datetime import date, datetime from ipaddress import IPv4Address -from typing import Dict +from typing import Dict, Callable, Any from arrow import Arrow from flask import Flask, jsonify, request @@ -34,41 +34,45 @@ class ArrowJSONEncoder(JSONEncoder): return JSONEncoder.default(self, obj) +# Type should really be Callable[[ApiServer, Any], Any], but that will create a circular dependency +def require_login(func: Callable[[Any, Any], Any]): + + def func_wrapper(obj, *args, **kwargs): + + auth = request.authorization + if auth and obj.check_auth(auth.username, auth.password): + return func(obj, *args, **kwargs) + else: + return jsonify({"error": "Unauthorized"}), 401 + + return func_wrapper + + +# Type should really be Callable[[ApiServer], Any], but that will create a circular dependency +def rpc_catch_errors(func: Callable[[Any], Any]): + + def func_wrapper(obj, *args, **kwargs): + + try: + return func(obj, *args, **kwargs) + except RPCException as e: + logger.exception("API Error calling %s: %s", func.__name__, e) + return obj.rest_error(f"Error querying {func.__name__}: {e}") + + return func_wrapper + + class ApiServer(RPC): """ This class runs api server and provides rpc.rpc functionality to it - This class starts a none blocking thread the api server runs within + This class starts a non-blocking thread the api server runs within """ - def rpc_catch_errors(func): - - def func_wrapper(self, *args, **kwargs): - - try: - return func(self, *args, **kwargs) - except RPCException as e: - logger.exception("API Error calling %s: %s", func.__name__, e) - return self.rest_error(f"Error querying {func.__name__}: {e}") - - return func_wrapper - def check_auth(self, username, password): return (username == self._config['api_server'].get('username') and password == self._config['api_server'].get('password')) - def require_login(func): - - def func_wrapper(self, *args, **kwargs): - - auth = request.authorization - if auth and self.check_auth(auth.username, auth.password): - return func(self, *args, **kwargs) - else: - return jsonify({"error": "Unauthorized"}), 401 - - return func_wrapper - def __init__(self, freqtrade) -> None: """ Init the api server, and init the super class RPC diff --git a/requirements-dev.txt b/requirements-dev.txt index 0578c03bd..f5cde59e8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ coveralls==1.8.2 flake8==3.7.8 flake8-type-annotations==0.1.0 flake8-tidy-imports==3.0.0 -mypy==0.730 +mypy==0.740 pytest==5.2.1 pytest-asyncio==0.10.0 pytest-cov==2.8.1