From 659c8c237f7a7e30ad0929fed448c449a01fb2bf Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Mon, 14 Nov 2022 20:27:45 -0700 Subject: [PATCH 01/21] initial revision --- freqtrade/rpc/api_server/api_ws.py | 170 +++++--- freqtrade/rpc/api_server/deps.py | 4 +- freqtrade/rpc/api_server/webserver.py | 162 +++++--- freqtrade/rpc/api_server/ws/__init__.py | 3 +- freqtrade/rpc/api_server/ws/channel.py | 365 ++++++++++++------ freqtrade/rpc/api_server/ws/message_stream.py | 23 ++ freqtrade/rpc/api_server/ws/serializer.py | 8 +- 7 files changed, 494 insertions(+), 241 deletions(-) create mode 100644 freqtrade/rpc/api_server/ws/message_stream.py diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 785773b39..a9b88aadb 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,16 +1,17 @@ +import asyncio import logging from typing import Any, Dict -from fastapi import APIRouter, Depends, WebSocketDisconnect -from fastapi.websockets import WebSocket, WebSocketState +from fastapi import APIRouter, Depends +from fastapi.websockets import WebSocket, WebSocketDisconnect from pydantic import ValidationError -from websockets.exceptions import WebSocketException +from websockets.exceptions import ConnectionClosed from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token -from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc +from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.ws import WebSocketChannel -from freqtrade.rpc.api_server.ws.channel import ChannelManager +from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC @@ -22,23 +23,63 @@ logger = logging.getLogger(__name__) router = APIRouter() -async def is_websocket_alive(ws: WebSocket) -> bool: +# async def is_websocket_alive(ws: WebSocket) -> bool: +# """ +# Check if a FastAPI Websocket is still open +# """ +# if ( +# ws.application_state == WebSocketState.CONNECTED and +# ws.client_state == WebSocketState.CONNECTED +# ): +# return True +# return False + + +class WebSocketChannelClosed(Exception): """ - Check if a FastAPI Websocket is still open + General WebSocket exception to signal closing the channel """ - if ( - ws.application_state == WebSocketState.CONNECTED and - ws.client_state == WebSocketState.CONNECTED + pass + + +async def channel_reader(channel: WebSocketChannel, rpc: RPC): + """ + Iterate over the messages from the channel and process the request + """ + try: + async for message in channel: + await _process_consumer_request(message, channel, rpc) + except ( + RuntimeError, + WebSocketDisconnect, + ConnectionClosed ): - return True - return False + raise WebSocketChannelClosed + except asyncio.CancelledError: + return + + +async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): + """ + Iterate over messages in the message stream and send them + """ + try: + async for message in message_stream: + await channel.send(message) + except ( + RuntimeError, + WebSocketDisconnect, + ConnectionClosed + ): + raise WebSocketChannelClosed + except asyncio.CancelledError: + return async def _process_consumer_request( request: Dict[str, Any], channel: WebSocketChannel, - rpc: RPC, - channel_manager: ChannelManager + rpc: RPC ): """ Validate and handle a request from a websocket consumer @@ -75,7 +116,7 @@ async def _process_consumer_request( # Format response response = WSWhitelistMessage(data=whitelist) # Send it back - await channel_manager.send_direct(channel, response.dict(exclude_none=True)) + await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: limit = None @@ -86,53 +127,76 @@ async def _process_consumer_request( # For every pair in the generator, send a separate message for message in rpc._ws_request_analyzed_df(limit): + # Format response response = WSAnalyzedDFMessage(data=message) - await channel_manager.send_direct(channel, response.dict(exclude_none=True)) + await channel.send(response.dict(exclude_none=True)) @router.websocket("/message/ws") async def message_endpoint( - ws: WebSocket, + websocket: WebSocket, + token: str = Depends(validate_ws_token), rpc: RPC = Depends(get_rpc), - channel_manager=Depends(get_channel_manager), - token: str = Depends(validate_ws_token) + message_stream: MessageStream = Depends(get_message_stream) ): - """ - Message WebSocket endpoint, facilitates sending RPC messages - """ - try: - channel = await channel_manager.on_connect(ws) - if await is_websocket_alive(ws): + async with WebSocketChannel(websocket).connect() as channel: + try: + logger.info(f"Channel connected - {channel}") - logger.info(f"Consumer connected - {channel}") + channel_tasks = asyncio.gather( + channel_reader(channel, rpc), + channel_broadcaster(channel, message_stream) + ) + await channel_tasks - # Keep connection open until explicitly closed, and process requests - try: - while not channel.is_closed(): - request = await channel.recv() + finally: + logger.info(f"Channel disconnected - {channel}") + channel_tasks.cancel() - # Process the request here - await _process_consumer_request(request, channel, rpc, channel_manager) - except (WebSocketDisconnect, WebSocketException): - # Handle client disconnects - logger.info(f"Consumer disconnected - {channel}") - except RuntimeError: - # Handle cases like - - # RuntimeError('Cannot call "send" once a closed message has been sent') - pass - except Exception as e: - logger.info(f"Consumer connection failed - {channel}: {e}") - logger.debug(e, exc_info=e) +# @router.websocket("/message/ws") +# async def message_endpoint( +# ws: WebSocket, +# rpc: RPC = Depends(get_rpc), +# channel_manager=Depends(get_channel_manager), +# token: str = Depends(validate_ws_token) +# ): +# """ +# Message WebSocket endpoint, facilitates sending RPC messages +# """ +# try: +# channel = await channel_manager.on_connect(ws) +# if await is_websocket_alive(ws): - except RuntimeError: - # WebSocket was closed - # Do nothing - pass - except Exception as e: - logger.error(f"Failed to serve - {ws.client}") - # Log tracebacks to keep track of what errors are happening - logger.exception(e) - finally: - if channel: - await channel_manager.on_disconnect(ws) +# logger.info(f"Consumer connected - {channel}") + +# # Keep connection open until explicitly closed, and process requests +# try: +# while not channel.is_closed(): +# request = await channel.recv() + +# # Process the request here +# await _process_consumer_request(request, channel, rpc, channel_manager) + +# except (WebSocketDisconnect, WebSocketException): +# # Handle client disconnects +# logger.info(f"Consumer disconnected - {channel}") +# except RuntimeError: +# # Handle cases like - +# # RuntimeError('Cannot call "send" once a closed message has been sent') +# pass +# except Exception as e: +# logger.info(f"Consumer connection failed - {channel}: {e}") +# logger.debug(e, exc_info=e) + +# except RuntimeError: +# # WebSocket was closed +# # Do nothing +# pass +# except Exception as e: +# logger.error(f"Failed to serve - {ws.client}") +# # Log tracebacks to keep track of what errors are happening +# logger.exception(e) +# finally: +# if channel: +# await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/deps.py b/freqtrade/rpc/api_server/deps.py index abd3db036..aed97367b 100644 --- a/freqtrade/rpc/api_server/deps.py +++ b/freqtrade/rpc/api_server/deps.py @@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)): return ApiServer._exchange -def get_channel_manager(): - return ApiServer._ws_channel_manager +def get_message_stream(): + return ApiServer._message_stream def is_webserver_mode(config=Depends(get_config)): diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index e9a12e4df..7e2c3f39f 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -1,7 +1,6 @@ import asyncio import logging from ipaddress import IPv4Address -from threading import Thread from typing import Any, Dict import orjson @@ -15,7 +14,7 @@ from starlette.responses import JSONResponse from freqtrade.constants import Config from freqtrade.exceptions import OperationalException from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer -from freqtrade.rpc.api_server.ws import ChannelManager +from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler @@ -51,9 +50,10 @@ class ApiServer(RPCHandler): # Exchange - only available in webserver mode. _exchange = None # websocket message queue stuff - _ws_channel_manager = None - _ws_thread = None - _ws_loop = None + # _ws_channel_manager = None + # _ws_thread = None + # _ws_loop = None + _message_stream = None def __new__(cls, *args, **kwargs): """ @@ -71,14 +71,15 @@ class ApiServer(RPCHandler): return self._standalone: bool = standalone self._server = None + self._ws_queue = None - self._ws_background_task = None + self._ws_publisher_task = None ApiServer.__initialized = True api_config = self._config['api_server'] - ApiServer._ws_channel_manager = ChannelManager() + # ApiServer._ws_channel_manager = ChannelManager() self.app = FastAPI(title="Freqtrade API", docs_url='/docs' if api_config.get('enable_openapi', False) else None, @@ -107,18 +108,18 @@ class ApiServer(RPCHandler): logger.info("Stopping API Server") self._server.cleanup() - if self._ws_thread and self._ws_loop: - logger.info("Stopping API Server background tasks") + # if self._ws_thread and self._ws_loop: + # logger.info("Stopping API Server background tasks") - if self._ws_background_task: - # Cancel the queue task - self._ws_background_task.cancel() + # if self._ws_background_task: + # # Cancel the queue task + # self._ws_background_task.cancel() - self._ws_thread.join() + # self._ws_thread.join() - self._ws_thread = None - self._ws_loop = None - self._ws_background_task = None + # self._ws_thread = None + # self._ws_loop = None + # self._ws_background_task = None @classmethod def shutdown(cls): @@ -170,51 +171,102 @@ class ApiServer(RPCHandler): ) app.add_exception_handler(RPCException, self.handle_rpc_exception) + app.add_event_handler( + event_type="startup", + func=self._api_startup_event + ) + app.add_event_handler( + event_type="shutdown", + func=self._api_shutdown_event + ) - def start_message_queue(self): - if self._ws_thread: - return + async def _api_startup_event(self): + if not ApiServer._message_stream: + ApiServer._message_stream = MessageStream() - # Create a new loop, as it'll be just for the background thread - self._ws_loop = asyncio.new_event_loop() + if not self._ws_queue: + self._ws_queue = ThreadedQueue() - # Start the thread - self._ws_thread = Thread(target=self._ws_loop.run_forever) - self._ws_thread.start() + if not self._ws_publisher_task: + self._ws_publisher_task = asyncio.create_task( + self._publish_messages() + ) - # Finally, submit the coro to the thread - self._ws_background_task = asyncio.run_coroutine_threadsafe( - self._broadcast_queue_data(), loop=self._ws_loop) + async def _api_shutdown_event(self): + if ApiServer._message_stream: + ApiServer._message_stream = None - async def _broadcast_queue_data(self): - # Instantiate the queue in this coroutine so it's attached to our loop - self._ws_queue = ThreadedQueue() - async_queue = self._ws_queue.async_q - - try: - while True: - logger.debug("Getting queue messages...") - # Get data from queue - message: WSMessageSchemaType = await async_queue.get() - logger.debug(f"Found message of type: {message.get('type')}") - async_queue.task_done() - # Broadcast it - await self._ws_channel_manager.broadcast(message) - except asyncio.CancelledError: - pass - - # For testing, shouldn't happen when stable - except Exception as e: - logger.exception(f"Exception happened in background task: {e}") - - finally: - # Disconnect channels and stop the loop on cancel - await self._ws_channel_manager.disconnect_all() - self._ws_loop.stop() - # Avoid adding more items to the queue if they aren't - # going to get broadcasted. + if self._ws_queue: self._ws_queue = None + if self._ws_publisher_task: + self._ws_publisher_task.cancel() + + async def _publish_messages(self): + """ + Background task that reads messages from the queue and adds them + to the message stream + """ + try: + async_queue = self._ws_queue.async_q + message_stream = ApiServer._message_stream + + while message_stream: + message: WSMessageSchemaType = await async_queue.get() + message_stream.publish(message) + + # Make sure to throttle how fast we + # publish messages as some clients will be + # slower than others + await asyncio.sleep(0.01) + async_queue.task_done() + finally: + self._ws_queue = None + + # def start_message_queue(self): + # if self._ws_thread: + # return + + # # Create a new loop, as it'll be just for the background thread + # self._ws_loop = asyncio.new_event_loop() + + # # Start the thread + # self._ws_thread = Thread(target=self._ws_loop.run_forever) + # self._ws_thread.start() + + # # Finally, submit the coro to the thread + # self._ws_background_task = asyncio.run_coroutine_threadsafe( + # self._broadcast_queue_data(), loop=self._ws_loop) + + # async def _broadcast_queue_data(self): + # # Instantiate the queue in this coroutine so it's attached to our loop + # self._ws_queue = ThreadedQueue() + # async_queue = self._ws_queue.async_q + + # try: + # while True: + # logger.debug("Getting queue messages...") + # # Get data from queue + # message: WSMessageSchemaType = await async_queue.get() + # logger.debug(f"Found message of type: {message.get('type')}") + # async_queue.task_done() + # # Broadcast it + # await self._ws_channel_manager.broadcast(message) + # except asyncio.CancelledError: + # pass + + # # For testing, shouldn't happen when stable + # except Exception as e: + # logger.exception(f"Exception happened in background task: {e}") + + # finally: + # # Disconnect channels and stop the loop on cancel + # await self._ws_channel_manager.disconnect_all() + # self._ws_loop.stop() + # # Avoid adding more items to the queue if they aren't + # # going to get broadcasted. + # self._ws_queue = None + def start_api(self): """ Start API ... should be run in thread. @@ -253,7 +305,7 @@ class ApiServer(RPCHandler): if self._standalone: self._server.run() else: - self.start_message_queue() + # self.start_message_queue() self._server.run_in_thread() except Exception: logger.exception("Api server failed to start.") diff --git a/freqtrade/rpc/api_server/ws/__init__.py b/freqtrade/rpc/api_server/ws/__init__.py index 055b20a9d..0b94d3fee 100644 --- a/freqtrade/rpc/api_server/ws/__init__.py +++ b/freqtrade/rpc/api_server/ws/__init__.py @@ -3,4 +3,5 @@ from freqtrade.rpc.api_server.ws.types import WebSocketType from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer -from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel +from freqtrade.rpc.api_server.ws.channel import WebSocketChannel +from freqtrade.rpc.api_server.ws.message_stream import MessageStream diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 88b4db9ba..b98bd13c9 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,12 +1,9 @@ import asyncio import logging -import time -from threading import RLock +from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 -from fastapi import WebSocket as FastAPIWebSocket - from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) @@ -21,32 +18,21 @@ class WebSocketChannel: """ Object to help facilitate managing a websocket connection """ - def __init__( self, websocket: WebSocketType, channel_id: Optional[str] = None, - drain_timeout: int = 3, - throttle: float = 0.01, serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer ): - self.channel_id = channel_id if channel_id else uuid4().hex[:8] - - # The WebSocket object self._websocket = WebSocketProxy(websocket) - self.drain_timeout = drain_timeout - self.throttle = throttle - - self._subscriptions: List[str] = [] - # 32 is the size of the receiving queue in websockets package - self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) - self._relay_task = asyncio.create_task(self.relay()) - # Internal event to signify a closed websocket self._closed = asyncio.Event() + # Throttle how fast we send messages + self._throttle = 0.01 + # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) @@ -61,40 +47,16 @@ class WebSocketChannel: def remote_addr(self): return self._websocket.remote_addr - async def _send(self, data): + async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]): """ - Send data on the wrapped websocket + Send a message on the wrapped websocket """ - await self._wrapped_ws.send(data) - - async def send(self, data) -> bool: - """ - Add the data to the queue to be sent. - :returns: True if data added to queue, False otherwise - """ - - # This block only runs if the queue is full, it will wait - # until self.drain_timeout for the relay to drain the outgoing queue - # We can't use asyncio.wait_for here because the queue may have been created with a - # different eventloop - start = time.time() - while self.queue.full(): - await asyncio.sleep(1) - if (time.time() - start) > self.drain_timeout: - return False - - # If for some reason the queue is still full, just return False - try: - self.queue.put_nowait(data) - except asyncio.QueueFull: - return False - - # If we got here everything is ok - return True + await asyncio.sleep(self._throttle) + await self._wrapped_ws.send(message) async def recv(self): """ - Receive data on the wrapped websocket + Receive a message on the wrapped websocket """ return await self._wrapped_ws.recv() @@ -104,18 +66,23 @@ class WebSocketChannel: """ return await self._websocket.ping() + async def accept(self): + """ + Accept the underlying websocket connection + """ + return await self._websocket.accept() + async def close(self): """ Close the WebSocketChannel """ try: - await self.raw_websocket.close() + await self._websocket.close() except Exception: pass self._closed.set() - self._relay_task.cancel() def is_closed(self) -> bool: """ @@ -139,99 +106,243 @@ class WebSocketChannel: """ return message_type in self._subscriptions - async def relay(self): + async def __aiter__(self): """ - Relay messages from the channel's queue and send them out. This is started - as a task. + Generator for received messages """ - while not self._closed.is_set(): - message = await self.queue.get() + while True: try: - await self._send(message) - self.queue.task_done() + yield await self.recv() + except Exception: + break - # Limit messages per sec. - # Could cause problems with queue size if too low, and - # problems with network traffik if too high. - # 0.01 = 100/s - await asyncio.sleep(self.throttle) - except RuntimeError: - # The connection was closed, just exit the task - return - - -class ChannelManager: - def __init__(self): - self.channels = dict() - self._lock = RLock() # Re-entrant Lock - - async def on_connect(self, websocket: WebSocketType): + @asynccontextmanager + async def connect(self): """ - Wrap websocket connection into Channel and add to list - - :param websocket: The WebSocket object to attach to the Channel + Context manager for safely opening and closing the websocket connection """ - if isinstance(websocket, FastAPIWebSocket): - try: - await websocket.accept() - except RuntimeError: - # The connection was closed before we could accept it - return + try: + await self.accept() + yield self + finally: + await self.close() - ws_channel = WebSocketChannel(websocket) - with self._lock: - self.channels[websocket] = ws_channel +# class WebSocketChannel: +# """ +# Object to help facilitate managing a websocket connection +# """ - return ws_channel +# def __init__( +# self, +# websocket: WebSocketType, +# channel_id: Optional[str] = None, +# drain_timeout: int = 3, +# throttle: float = 0.01, +# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer +# ): - async def on_disconnect(self, websocket: WebSocketType): - """ - Call close on the channel if it's not, and remove from channel list +# self.channel_id = channel_id if channel_id else uuid4().hex[:8] - :param websocket: The WebSocket objet attached to the Channel - """ - with self._lock: - channel = self.channels.get(websocket) - if channel: - logger.info(f"Disconnecting channel {channel}") - if not channel.is_closed(): - await channel.close() +# # The WebSocket object +# self._websocket = WebSocketProxy(websocket) - del self.channels[websocket] +# self.drain_timeout = drain_timeout +# self.throttle = throttle - async def disconnect_all(self): - """ - Disconnect all Channels - """ - with self._lock: - for websocket in self.channels.copy().keys(): - await self.on_disconnect(websocket) +# self._subscriptions: List[str] = [] +# # 32 is the size of the receiving queue in websockets package +# self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) +# self._relay_task = asyncio.create_task(self.relay()) - async def broadcast(self, message: WSMessageSchemaType): - """ - Broadcast a message on all Channels +# # Internal event to signify a closed websocket +# self._closed = asyncio.Event() - :param message: The message to send - """ - with self._lock: - for channel in self.channels.copy().values(): - if channel.subscribed_to(message.get('type')): - await self.send_direct(channel, message) +# # Wrap the WebSocket in the Serializing class +# self._wrapped_ws = serializer_cls(self._websocket) - async def send_direct( - self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): - """ - Send a message directly through direct_channel only +# def __repr__(self): +# return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" - :param direct_channel: The WebSocketChannel object to send the message through - :param message: The message to send - """ - if not await channel.send(message): - await self.on_disconnect(channel.raw_websocket) +# @property +# def raw_websocket(self): +# return self._websocket.raw_websocket - def has_channels(self): - """ - Flag for more than 0 channels - """ - return len(self.channels) > 0 +# @property +# def remote_addr(self): +# return self._websocket.remote_addr + +# async def _send(self, data): +# """ +# Send data on the wrapped websocket +# """ +# await self._wrapped_ws.send(data) + +# async def send(self, data) -> bool: +# """ +# Add the data to the queue to be sent. +# :returns: True if data added to queue, False otherwise +# """ + +# # This block only runs if the queue is full, it will wait +# # until self.drain_timeout for the relay to drain the outgoing queue +# # We can't use asyncio.wait_for here because the queue may have been created with a +# # different eventloop +# start = time.time() +# while self.queue.full(): +# await asyncio.sleep(1) +# if (time.time() - start) > self.drain_timeout: +# return False + +# # If for some reason the queue is still full, just return False +# try: +# self.queue.put_nowait(data) +# except asyncio.QueueFull: +# return False + +# # If we got here everything is ok +# return True + +# async def recv(self): +# """ +# Receive data on the wrapped websocket +# """ +# return await self._wrapped_ws.recv() + +# async def ping(self): +# """ +# Ping the websocket +# """ +# return await self._websocket.ping() + +# async def close(self): +# """ +# Close the WebSocketChannel +# """ + +# try: +# await self.raw_websocket.close() +# except Exception: +# pass + +# self._closed.set() +# self._relay_task.cancel() + +# def is_closed(self) -> bool: +# """ +# Closed flag +# """ +# return self._closed.is_set() + +# def set_subscriptions(self, subscriptions: List[str] = []) -> None: +# """ +# Set which subscriptions this channel is subscribed to + +# :param subscriptions: List of subscriptions, List[str] +# """ +# self._subscriptions = subscriptions + +# def subscribed_to(self, message_type: str) -> bool: +# """ +# Check if this channel is subscribed to the message_type + +# :param message_type: The message type to check +# """ +# return message_type in self._subscriptions + +# async def relay(self): +# """ +# Relay messages from the channel's queue and send them out. This is started +# as a task. +# """ +# while not self._closed.is_set(): +# message = await self.queue.get() +# try: +# await self._send(message) +# self.queue.task_done() + +# # Limit messages per sec. +# # Could cause problems with queue size if too low, and +# # problems with network traffik if too high. +# # 0.01 = 100/s +# await asyncio.sleep(self.throttle) +# except RuntimeError: +# # The connection was closed, just exit the task +# return + + +# class ChannelManager: +# def __init__(self): +# self.channels = dict() +# self._lock = RLock() # Re-entrant Lock + +# async def on_connect(self, websocket: WebSocketType): +# """ +# Wrap websocket connection into Channel and add to list + +# :param websocket: The WebSocket object to attach to the Channel +# """ +# if isinstance(websocket, FastAPIWebSocket): +# try: +# await websocket.accept() +# except RuntimeError: +# # The connection was closed before we could accept it +# return + +# ws_channel = WebSocketChannel(websocket) + +# with self._lock: +# self.channels[websocket] = ws_channel + +# return ws_channel + +# async def on_disconnect(self, websocket: WebSocketType): +# """ +# Call close on the channel if it's not, and remove from channel list + +# :param websocket: The WebSocket objet attached to the Channel +# """ +# with self._lock: +# channel = self.channels.get(websocket) +# if channel: +# logger.info(f"Disconnecting channel {channel}") +# if not channel.is_closed(): +# await channel.close() + +# del self.channels[websocket] + +# async def disconnect_all(self): +# """ +# Disconnect all Channels +# """ +# with self._lock: +# for websocket in self.channels.copy().keys(): +# await self.on_disconnect(websocket) + +# async def broadcast(self, message: WSMessageSchemaType): +# """ +# Broadcast a message on all Channels + +# :param message: The message to send +# """ +# with self._lock: +# for channel in self.channels.copy().values(): +# if channel.subscribed_to(message.get('type')): +# await self.send_direct(channel, message) + +# async def send_direct( +# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): +# """ +# Send a message directly through direct_channel only + +# :param direct_channel: The WebSocketChannel object to send the message through +# :param message: The message to send +# """ +# if not await channel.send(message): +# await self.on_disconnect(channel.raw_websocket) + +# def has_channels(self): +# """ +# Flag for more than 0 channels +# """ +# return len(self.channels) > 0 diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py new file mode 100644 index 000000000..f77242719 --- /dev/null +++ b/freqtrade/rpc/api_server/ws/message_stream.py @@ -0,0 +1,23 @@ +import asyncio + + +class MessageStream: + """ + A message stream for consumers to subscribe to, + and for producers to publish to. + """ + def __init__(self): + self._loop = asyncio.get_running_loop() + self._waiter = self._loop.create_future() + + def publish(self, message): + waiter, self._waiter = self._waiter, self._loop.create_future() + waiter.set_result((message, self._waiter)) + + async def subscribe(self): + waiter = self._waiter + while True: + message, waiter = await waiter + yield message + + __aiter__ = subscribe diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py index 6c402a100..85703136b 100644 --- a/freqtrade/rpc/api_server/ws/serializer.py +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from typing import Any, Dict, Union import orjson import rapidjson @@ -7,6 +8,7 @@ from pandas import DataFrame from freqtrade.misc import dataframe_to_json, json_to_dataframe from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy +from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ class WebSocketSerializer(ABC): def _deserialize(self, data): raise NotImplementedError() - async def send(self, data: bytes): + async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]): await self._websocket.send(self._serialize(data)) async def recv(self) -> bytes: @@ -32,8 +34,8 @@ class WebSocketSerializer(ABC): return self._deserialize(data) - async def close(self, code: int = 1000): - await self._websocket.close(code) + # async def close(self, code: int = 1000): + # await self._websocket.close(code) class HybridJSONWebSocketSerializer(WebSocketSerializer): From d713af045fbd51df67825836d9fe3a17f1424622 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Mon, 14 Nov 2022 22:21:40 -0700 Subject: [PATCH 02/21] remove main queue completely --- freqtrade/rpc/api_server/api_ws.py | 3 +- freqtrade/rpc/api_server/webserver.py | 47 ++------------------------ freqtrade/rpc/api_server/ws/channel.py | 5 ++- 3 files changed, 6 insertions(+), 49 deletions(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index a9b88aadb..3f207eac3 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -148,7 +148,8 @@ async def message_endpoint( channel_broadcaster(channel, message_stream) ) await channel_tasks - + except WebSocketChannelClosed: + pass finally: logger.info(f"Channel disconnected - {channel}") channel_tasks.cancel() diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 7e2c3f39f..d0695e06d 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -1,4 +1,3 @@ -import asyncio import logging from ipaddress import IPv4Address from typing import Any, Dict @@ -7,15 +6,12 @@ import orjson import uvicorn from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware -# Look into alternatives -from janus import Queue as ThreadedQueue from starlette.responses import JSONResponse from freqtrade.constants import Config from freqtrade.exceptions import OperationalException from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.ws.message_stream import MessageStream -from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler @@ -72,9 +68,6 @@ class ApiServer(RPCHandler): self._standalone: bool = standalone self._server = None - self._ws_queue = None - self._ws_publisher_task = None - ApiServer.__initialized = True api_config = self._config['api_server'] @@ -130,9 +123,8 @@ class ApiServer(RPCHandler): cls._rpc = None def send_msg(self, msg: Dict[str, Any]) -> None: - if self._ws_queue: - sync_q = self._ws_queue.sync_q - sync_q.put(msg) + if ApiServer._message_stream: + ApiServer._message_stream.publish(msg) def handle_rpc_exception(self, request, exc): logger.exception(f"API Error calling: {exc}") @@ -184,45 +176,10 @@ class ApiServer(RPCHandler): if not ApiServer._message_stream: ApiServer._message_stream = MessageStream() - if not self._ws_queue: - self._ws_queue = ThreadedQueue() - - if not self._ws_publisher_task: - self._ws_publisher_task = asyncio.create_task( - self._publish_messages() - ) - async def _api_shutdown_event(self): if ApiServer._message_stream: ApiServer._message_stream = None - if self._ws_queue: - self._ws_queue = None - - if self._ws_publisher_task: - self._ws_publisher_task.cancel() - - async def _publish_messages(self): - """ - Background task that reads messages from the queue and adds them - to the message stream - """ - try: - async_queue = self._ws_queue.async_q - message_stream = ApiServer._message_stream - - while message_stream: - message: WSMessageSchemaType = await async_queue.get() - message_stream.publish(message) - - # Make sure to throttle how fast we - # publish messages as some clients will be - # slower than others - await asyncio.sleep(0.01) - async_queue.task_done() - finally: - self._ws_queue = None - # def start_message_queue(self): # if self._ws_thread: # return diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index b98bd13c9..39c8db516 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -30,8 +30,8 @@ class WebSocketChannel: # Internal event to signify a closed websocket self._closed = asyncio.Event() - # Throttle how fast we send messages - self._throttle = 0.01 + # The subscribed message types + self._subscriptions: List[str] = [] # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) @@ -51,7 +51,6 @@ class WebSocketChannel: """ Send a message on the wrapped websocket """ - await asyncio.sleep(self._throttle) await self._wrapped_ws.send(message) async def recv(self): From 442467e8aed2ff639bfba04e7a2f6e175f774af1 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Mon, 14 Nov 2022 22:26:34 -0700 Subject: [PATCH 03/21] remove old comments and code --- freqtrade/rpc/api_server/api_ws.py | 60 ------ freqtrade/rpc/api_server/webserver.py | 75 ++------ freqtrade/rpc/api_server/ws/channel.py | 220 ---------------------- freqtrade/rpc/api_server/ws/serializer.py | 3 - 4 files changed, 12 insertions(+), 346 deletions(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 3f207eac3..01243b0cc 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -23,18 +23,6 @@ logger = logging.getLogger(__name__) router = APIRouter() -# async def is_websocket_alive(ws: WebSocket) -> bool: -# """ -# Check if a FastAPI Websocket is still open -# """ -# if ( -# ws.application_state == WebSocketState.CONNECTED and -# ws.client_state == WebSocketState.CONNECTED -# ): -# return True -# return False - - class WebSocketChannelClosed(Exception): """ General WebSocket exception to signal closing the channel @@ -153,51 +141,3 @@ async def message_endpoint( finally: logger.info(f"Channel disconnected - {channel}") channel_tasks.cancel() - - -# @router.websocket("/message/ws") -# async def message_endpoint( -# ws: WebSocket, -# rpc: RPC = Depends(get_rpc), -# channel_manager=Depends(get_channel_manager), -# token: str = Depends(validate_ws_token) -# ): -# """ -# Message WebSocket endpoint, facilitates sending RPC messages -# """ -# try: -# channel = await channel_manager.on_connect(ws) -# if await is_websocket_alive(ws): - -# logger.info(f"Consumer connected - {channel}") - -# # Keep connection open until explicitly closed, and process requests -# try: -# while not channel.is_closed(): -# request = await channel.recv() - -# # Process the request here -# await _process_consumer_request(request, channel, rpc, channel_manager) - -# except (WebSocketDisconnect, WebSocketException): -# # Handle client disconnects -# logger.info(f"Consumer disconnected - {channel}") -# except RuntimeError: -# # Handle cases like - -# # RuntimeError('Cannot call "send" once a closed message has been sent') -# pass -# except Exception as e: -# logger.info(f"Consumer connection failed - {channel}: {e}") -# logger.debug(e, exc_info=e) - -# except RuntimeError: -# # WebSocket was closed -# # Do nothing -# pass -# except Exception as e: -# logger.error(f"Failed to serve - {ws.client}") -# # Log tracebacks to keep track of what errors are happening -# logger.exception(e) -# finally: -# if channel: -# await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index d0695e06d..f100a46ef 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -45,10 +45,7 @@ class ApiServer(RPCHandler): _config: Config = {} # Exchange - only available in webserver mode. _exchange = None - # websocket message queue stuff - # _ws_channel_manager = None - # _ws_thread = None - # _ws_loop = None + # websocket message stuff _message_stream = None def __new__(cls, *args, **kwargs): @@ -72,8 +69,6 @@ class ApiServer(RPCHandler): api_config = self._config['api_server'] - # ApiServer._ws_channel_manager = ChannelManager() - self.app = FastAPI(title="Freqtrade API", docs_url='/docs' if api_config.get('enable_openapi', False) else None, redoc_url=None, @@ -101,19 +96,6 @@ class ApiServer(RPCHandler): logger.info("Stopping API Server") self._server.cleanup() - # if self._ws_thread and self._ws_loop: - # logger.info("Stopping API Server background tasks") - - # if self._ws_background_task: - # # Cancel the queue task - # self._ws_background_task.cancel() - - # self._ws_thread.join() - - # self._ws_thread = None - # self._ws_loop = None - # self._ws_background_task = None - @classmethod def shutdown(cls): cls.__initialized = False @@ -123,6 +105,9 @@ class ApiServer(RPCHandler): cls._rpc = None def send_msg(self, msg: Dict[str, Any]) -> None: + """ + Publish the message to the message stream + """ if ApiServer._message_stream: ApiServer._message_stream.publish(msg) @@ -173,57 +158,21 @@ class ApiServer(RPCHandler): ) async def _api_startup_event(self): + """ + Creates the MessageStream class on startup + so it has access to the same event loop + as uvicorn + """ if not ApiServer._message_stream: ApiServer._message_stream = MessageStream() async def _api_shutdown_event(self): + """ + Removes the MessageStream class on shutdown + """ if ApiServer._message_stream: ApiServer._message_stream = None - # def start_message_queue(self): - # if self._ws_thread: - # return - - # # Create a new loop, as it'll be just for the background thread - # self._ws_loop = asyncio.new_event_loop() - - # # Start the thread - # self._ws_thread = Thread(target=self._ws_loop.run_forever) - # self._ws_thread.start() - - # # Finally, submit the coro to the thread - # self._ws_background_task = asyncio.run_coroutine_threadsafe( - # self._broadcast_queue_data(), loop=self._ws_loop) - - # async def _broadcast_queue_data(self): - # # Instantiate the queue in this coroutine so it's attached to our loop - # self._ws_queue = ThreadedQueue() - # async_queue = self._ws_queue.async_q - - # try: - # while True: - # logger.debug("Getting queue messages...") - # # Get data from queue - # message: WSMessageSchemaType = await async_queue.get() - # logger.debug(f"Found message of type: {message.get('type')}") - # async_queue.task_done() - # # Broadcast it - # await self._ws_channel_manager.broadcast(message) - # except asyncio.CancelledError: - # pass - - # # For testing, shouldn't happen when stable - # except Exception as e: - # logger.exception(f"Exception happened in background task: {e}") - - # finally: - # # Disconnect channels and stop the loop on cancel - # await self._ws_channel_manager.disconnect_all() - # self._ws_loop.stop() - # # Avoid adding more items to the queue if they aren't - # # going to get broadcasted. - # self._ws_queue = None - def start_api(self): """ Start API ... should be run in thread. diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 39c8db516..ee16a95c6 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -125,223 +125,3 @@ class WebSocketChannel: yield self finally: await self.close() - - -# class WebSocketChannel: -# """ -# Object to help facilitate managing a websocket connection -# """ - -# def __init__( -# self, -# websocket: WebSocketType, -# channel_id: Optional[str] = None, -# drain_timeout: int = 3, -# throttle: float = 0.01, -# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer -# ): - -# self.channel_id = channel_id if channel_id else uuid4().hex[:8] - -# # The WebSocket object -# self._websocket = WebSocketProxy(websocket) - -# self.drain_timeout = drain_timeout -# self.throttle = throttle - -# self._subscriptions: List[str] = [] -# # 32 is the size of the receiving queue in websockets package -# self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) -# self._relay_task = asyncio.create_task(self.relay()) - -# # Internal event to signify a closed websocket -# self._closed = asyncio.Event() - -# # Wrap the WebSocket in the Serializing class -# self._wrapped_ws = serializer_cls(self._websocket) - -# def __repr__(self): -# return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" - -# @property -# def raw_websocket(self): -# return self._websocket.raw_websocket - -# @property -# def remote_addr(self): -# return self._websocket.remote_addr - -# async def _send(self, data): -# """ -# Send data on the wrapped websocket -# """ -# await self._wrapped_ws.send(data) - -# async def send(self, data) -> bool: -# """ -# Add the data to the queue to be sent. -# :returns: True if data added to queue, False otherwise -# """ - -# # This block only runs if the queue is full, it will wait -# # until self.drain_timeout for the relay to drain the outgoing queue -# # We can't use asyncio.wait_for here because the queue may have been created with a -# # different eventloop -# start = time.time() -# while self.queue.full(): -# await asyncio.sleep(1) -# if (time.time() - start) > self.drain_timeout: -# return False - -# # If for some reason the queue is still full, just return False -# try: -# self.queue.put_nowait(data) -# except asyncio.QueueFull: -# return False - -# # If we got here everything is ok -# return True - -# async def recv(self): -# """ -# Receive data on the wrapped websocket -# """ -# return await self._wrapped_ws.recv() - -# async def ping(self): -# """ -# Ping the websocket -# """ -# return await self._websocket.ping() - -# async def close(self): -# """ -# Close the WebSocketChannel -# """ - -# try: -# await self.raw_websocket.close() -# except Exception: -# pass - -# self._closed.set() -# self._relay_task.cancel() - -# def is_closed(self) -> bool: -# """ -# Closed flag -# """ -# return self._closed.is_set() - -# def set_subscriptions(self, subscriptions: List[str] = []) -> None: -# """ -# Set which subscriptions this channel is subscribed to - -# :param subscriptions: List of subscriptions, List[str] -# """ -# self._subscriptions = subscriptions - -# def subscribed_to(self, message_type: str) -> bool: -# """ -# Check if this channel is subscribed to the message_type - -# :param message_type: The message type to check -# """ -# return message_type in self._subscriptions - -# async def relay(self): -# """ -# Relay messages from the channel's queue and send them out. This is started -# as a task. -# """ -# while not self._closed.is_set(): -# message = await self.queue.get() -# try: -# await self._send(message) -# self.queue.task_done() - -# # Limit messages per sec. -# # Could cause problems with queue size if too low, and -# # problems with network traffik if too high. -# # 0.01 = 100/s -# await asyncio.sleep(self.throttle) -# except RuntimeError: -# # The connection was closed, just exit the task -# return - - -# class ChannelManager: -# def __init__(self): -# self.channels = dict() -# self._lock = RLock() # Re-entrant Lock - -# async def on_connect(self, websocket: WebSocketType): -# """ -# Wrap websocket connection into Channel and add to list - -# :param websocket: The WebSocket object to attach to the Channel -# """ -# if isinstance(websocket, FastAPIWebSocket): -# try: -# await websocket.accept() -# except RuntimeError: -# # The connection was closed before we could accept it -# return - -# ws_channel = WebSocketChannel(websocket) - -# with self._lock: -# self.channels[websocket] = ws_channel - -# return ws_channel - -# async def on_disconnect(self, websocket: WebSocketType): -# """ -# Call close on the channel if it's not, and remove from channel list - -# :param websocket: The WebSocket objet attached to the Channel -# """ -# with self._lock: -# channel = self.channels.get(websocket) -# if channel: -# logger.info(f"Disconnecting channel {channel}") -# if not channel.is_closed(): -# await channel.close() - -# del self.channels[websocket] - -# async def disconnect_all(self): -# """ -# Disconnect all Channels -# """ -# with self._lock: -# for websocket in self.channels.copy().keys(): -# await self.on_disconnect(websocket) - -# async def broadcast(self, message: WSMessageSchemaType): -# """ -# Broadcast a message on all Channels - -# :param message: The message to send -# """ -# with self._lock: -# for channel in self.channels.copy().values(): -# if channel.subscribed_to(message.get('type')): -# await self.send_direct(channel, message) - -# async def send_direct( -# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): -# """ -# Send a message directly through direct_channel only - -# :param direct_channel: The WebSocketChannel object to send the message through -# :param message: The message to send -# """ -# if not await channel.send(message): -# await self.on_disconnect(channel.raw_websocket) - -# def has_channels(self): -# """ -# Flag for more than 0 channels -# """ -# return len(self.channels) > 0 diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py index 85703136b..625a0990c 100644 --- a/freqtrade/rpc/api_server/ws/serializer.py +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -34,9 +34,6 @@ class WebSocketSerializer(ABC): return self._deserialize(data) - # async def close(self, code: int = 1000): - # await self._websocket.close(code) - class HybridJSONWebSocketSerializer(WebSocketSerializer): def _serialize(self, data) -> str: From 0cb6f71c026bd2f771a862c43c5b2c744a64264e Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Fri, 18 Nov 2022 13:32:27 -0700 Subject: [PATCH 04/21] better error handling, true async sending, more readable api --- freqtrade/rpc/api_server/api_ws.py | 66 +++----------- freqtrade/rpc/api_server/webserver.py | 1 + freqtrade/rpc/api_server/ws/channel.py | 89 +++++++++++++++---- freqtrade/rpc/api_server/ws/message_stream.py | 3 +- 4 files changed, 88 insertions(+), 71 deletions(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 01243b0cc..2454646ea 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,16 +1,14 @@ -import asyncio import logging from typing import Any, Dict from fastapi import APIRouter, Depends -from fastapi.websockets import WebSocket, WebSocketDisconnect +from fastapi.websockets import WebSocket from pydantic import ValidationError -from websockets.exceptions import ConnectionClosed from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc -from freqtrade.rpc.api_server.ws import WebSocketChannel +from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) @@ -23,45 +21,20 @@ logger = logging.getLogger(__name__) router = APIRouter() -class WebSocketChannelClosed(Exception): - """ - General WebSocket exception to signal closing the channel - """ - pass - - async def channel_reader(channel: WebSocketChannel, rpc: RPC): """ Iterate over the messages from the channel and process the request """ - try: - async for message in channel: - await _process_consumer_request(message, channel, rpc) - except ( - RuntimeError, - WebSocketDisconnect, - ConnectionClosed - ): - raise WebSocketChannelClosed - except asyncio.CancelledError: - return + async for message in channel: + await _process_consumer_request(message, channel, rpc) async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): """ Iterate over messages in the message stream and send them """ - try: - async for message in message_stream: - await channel.send(message) - except ( - RuntimeError, - WebSocketDisconnect, - ConnectionClosed - ): - raise WebSocketChannelClosed - except asyncio.CancelledError: - return + async for message in message_stream: + await channel.send(message) async def _process_consumer_request( @@ -103,15 +76,11 @@ async def _process_consumer_request( # Format response response = WSWhitelistMessage(data=whitelist) - # Send it back await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: - limit = None - - if data: - # Limit the amount of candles per dataframe to 'limit' or 1500 - limit = max(data.get('limit', 1500), 1500) + # Limit the amount of candles per dataframe to 'limit' or 1500 + limit = min(data.get('limit', 1500), 1500) if data else None # For every pair in the generator, send a separate message for message in rpc._ws_request_analyzed_df(limit): @@ -127,17 +96,8 @@ async def message_endpoint( rpc: RPC = Depends(get_rpc), message_stream: MessageStream = Depends(get_message_stream) ): - async with WebSocketChannel(websocket).connect() as channel: - try: - logger.info(f"Channel connected - {channel}") - - channel_tasks = asyncio.gather( - channel_reader(channel, rpc), - channel_broadcaster(channel, message_stream) - ) - await channel_tasks - except WebSocketChannelClosed: - pass - finally: - logger.info(f"Channel disconnected - {channel}") - channel_tasks.cancel() + async with create_channel(websocket) as channel: + await channel.run_channel_tasks( + channel_reader(channel, rpc), + channel_broadcaster(channel, message_stream) + ) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index f100a46ef..4a9f089d1 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -94,6 +94,7 @@ class ApiServer(RPCHandler): del ApiServer._rpc if self._server and not self._standalone: logger.info("Stopping API Server") + # self._server.force_exit, self._server.should_exit = True, True self._server.cleanup() @classmethod diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 8e248d368..d4d4d6453 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -29,6 +29,7 @@ class WebSocketChannel: # Internal event to signify a closed websocket self._closed = asyncio.Event() + self._send_timeout_high_limit = 2 # The subscribed message types self._subscriptions: List[str] = [] @@ -36,6 +37,9 @@ class WebSocketChannel: # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) + # The async tasks created for the channel + self._channel_tasks: List[asyncio.Task] = [] + def __repr__(self): return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" @@ -51,7 +55,14 @@ class WebSocketChannel: """ Send a message on the wrapped websocket """ - await self._wrapped_ws.send(message) + + # Without this sleep, messages would send to one channel + # first then another after the first one finished. + # With the sleep call, it gives control to the event + # loop to schedule other channel send methods. + await asyncio.sleep(0) + + return await self._wrapped_ws.send(message) async def recv(self): """ @@ -77,7 +88,6 @@ class WebSocketChannel: """ self._closed.set() - self._relay_task.cancel() try: await self._websocket.close() @@ -106,23 +116,68 @@ class WebSocketChannel: """ return message_type in self._subscriptions + async def run_channel_tasks(self, *tasks, **kwargs): + """ + Create and await on the channel tasks unless an exception + was raised, then cancel them all. + + :params *tasks: All coros or tasks to be run concurrently + :param **kwargs: Any extra kwargs to pass to gather + """ + + # Wrap the coros into tasks if they aren't already + self._channel_tasks = [ + task if isinstance(task, asyncio.Task) else asyncio.create_task(task) + for task in tasks + ] + + try: + await asyncio.gather(*self._channel_tasks, **kwargs) + except Exception: + # If an exception occurred, cancel the rest of the tasks and bubble up + # the error that was caught here + await self.cancel_channel_tasks() + raise + + async def cancel_channel_tasks(self): + """ + Cancel and wait on all channel tasks + """ + for task in self._channel_tasks: + task.cancel() + + # Wait for tasks to finish cancelling + try: + await asyncio.wait(self._channel_tasks) + except asyncio.CancelledError: + pass + + self._channel_tasks = [] + async def __aiter__(self): """ Generator for received messages """ - while True: - try: - yield await self.recv() - except Exception: - break + # We can not catch any errors here as websocket.recv is + # the first to catch any disconnects and bubble it up + # so the connection is garbage collected right away + while not self.is_closed(): + yield await self.recv() - @asynccontextmanager - async def connect(self): - """ - Context manager for safely opening and closing the websocket connection - """ - try: - await self.accept() - yield self - finally: - await self.close() + +@asynccontextmanager +async def create_channel(websocket: WebSocketType, **kwargs): + """ + Context manager for safely opening and closing a WebSocketChannel + """ + channel = WebSocketChannel(websocket, **kwargs) + try: + await channel.accept() + logger.info(f"Connected to channel - {channel}") + + yield channel + except Exception: + pass + finally: + await channel.close() + logger.info(f"Disconnected from channel - {channel}") diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py index f77242719..9592908ab 100644 --- a/freqtrade/rpc/api_server/ws/message_stream.py +++ b/freqtrade/rpc/api_server/ws/message_stream.py @@ -17,7 +17,8 @@ class MessageStream: async def subscribe(self): waiter = self._waiter while True: - message, waiter = await waiter + # Shield the future from being cancelled by a task waiting on it + message, waiter = await asyncio.shield(waiter) yield message __aiter__ = subscribe From c1a73a551225424591891c8bb15491de85a79a36 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sat, 19 Nov 2022 13:21:26 -0700 Subject: [PATCH 05/21] move sleep call in send, minor cleanup --- freqtrade/rpc/api_server/ws/channel.py | 20 +++++++++----------- freqtrade/rpc/api_server/ws/serializer.py | 1 - 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index d4d4d6453..7a1191d62 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -55,14 +55,16 @@ class WebSocketChannel: """ Send a message on the wrapped websocket """ + await self._wrapped_ws.send(message) # Without this sleep, messages would send to one channel - # first then another after the first one finished. + # first then another after the first one finished and prevent + # any normal Rest API calls from processing at the same time. # With the sleep call, it gives control to the event - # loop to schedule other channel send methods. - await asyncio.sleep(0) - - return await self._wrapped_ws.send(message) + # loop to schedule other channel send methods, and helps + # throttle how fast we send. + # 0.01 = 100 messages/second max throughput + await asyncio.sleep(0.01) async def recv(self): """ @@ -132,12 +134,10 @@ class WebSocketChannel: ] try: - await asyncio.gather(*self._channel_tasks, **kwargs) + return await asyncio.gather(*self._channel_tasks, **kwargs) except Exception: - # If an exception occurred, cancel the rest of the tasks and bubble up - # the error that was caught here + # If an exception occurred, cancel the rest of the tasks await self.cancel_channel_tasks() - raise async def cancel_channel_tasks(self): """ @@ -176,8 +176,6 @@ async def create_channel(websocket: WebSocketType, **kwargs): logger.info(f"Connected to channel - {channel}") yield channel - except Exception: - pass finally: await channel.close() logger.info(f"Disconnected from channel - {channel}") diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py index 625a0990c..9a894e1bf 100644 --- a/freqtrade/rpc/api_server/ws/serializer.py +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -31,7 +31,6 @@ class WebSocketSerializer(ABC): async def recv(self) -> bytes: data = await self._websocket.recv() - return self._deserialize(data) From 3714d7074b91b9f0219e9fbac9c3effed9b4aecd Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sat, 19 Nov 2022 13:29:23 -0700 Subject: [PATCH 06/21] smaller throttle in channel send --- freqtrade/rpc/api_server/ws/channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 7a1191d62..80b2ec220 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -63,8 +63,8 @@ class WebSocketChannel: # With the sleep call, it gives control to the event # loop to schedule other channel send methods, and helps # throttle how fast we send. - # 0.01 = 100 messages/second max throughput - await asyncio.sleep(0.01) + # 0.005 = 200 messages/second max throughput + await asyncio.sleep(0.005) async def recv(self): """ From 60a167bdefac8ba1cdf5224aee00dfdc26145020 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sun, 20 Nov 2022 14:09:45 -0700 Subject: [PATCH 07/21] add dynamic send timeout --- freqtrade/rpc/api_server/api_ws.py | 2 +- freqtrade/rpc/api_server/ws/channel.py | 65 +++++++++++++++++++------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 2454646ea..618490ec8 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -34,7 +34,7 @@ async def channel_broadcaster(channel: WebSocketChannel, message_stream: Message Iterate over messages in the message stream and send them """ async for message in message_stream: - await channel.send(message) + await channel.send(message, timeout=True) async def _process_consumer_request( diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 80b2ec220..5424d7440 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,7 +1,9 @@ import asyncio import logging +import time +from collections import deque from contextlib import asynccontextmanager -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Deque, Dict, List, Optional, Type, Union from uuid import uuid4 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy @@ -29,7 +31,13 @@ class WebSocketChannel: # Internal event to signify a closed websocket self._closed = asyncio.Event() - self._send_timeout_high_limit = 2 + # The async tasks created for the channel + self._channel_tasks: List[asyncio.Task] = [] + + # Deque for average send times + self._send_times: Deque[float] = deque([], maxlen=10) + # High limit defaults to 3 to start + self._send_high_limit = 3 # The subscribed message types self._subscriptions: List[str] = [] @@ -37,9 +45,6 @@ class WebSocketChannel: # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) - # The async tasks created for the channel - self._channel_tasks: List[asyncio.Task] = [] - def __repr__(self): return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" @@ -51,20 +56,48 @@ class WebSocketChannel: def remote_addr(self): return self._websocket.remote_addr - async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]): + def _calc_send_limit(self): """ - Send a message on the wrapped websocket + Calculate the send high limit for this channel """ - await self._wrapped_ws.send(message) - # Without this sleep, messages would send to one channel - # first then another after the first one finished and prevent - # any normal Rest API calls from processing at the same time. - # With the sleep call, it gives control to the event - # loop to schedule other channel send methods, and helps - # throttle how fast we send. - # 0.005 = 200 messages/second max throughput - await asyncio.sleep(0.005) + # Only update if we have enough data + if len(self._send_times) == self._send_times.maxlen: + # At least 1s or twice the average of send times + self._send_high_limit = max( + (sum(self._send_times) / len(self._send_times)) * 2, + 1 + ) + + async def send( + self, + message: Union[WSMessageSchemaType, Dict[str, Any]], + timeout: bool = False + ): + """ + Send a message on the wrapped websocket. If the sending + takes too long, it will raise a TimeoutError and + disconnect the connection. + + :param message: The message to send + :param timeout: Enforce send high limit, defaults to False + """ + try: + _ = time.time() + # If the send times out, it will raise + # a TimeoutError and bubble up to the + # message_endpoint to close the connection + await asyncio.wait_for( + self._wrapped_ws.send(message), + timeout=self._send_high_limit if timeout else None + ) + total_time = time.time() - _ + self._send_times.append(total_time) + + self._calc_send_limit() + except asyncio.TimeoutError: + logger.info(f"Connection for {self} is too far behind, disconnecting") + raise async def recv(self): """ From 48a1f2418ffb89c148e3417f65545ec7248a6faf Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sun, 20 Nov 2022 16:18:24 -0700 Subject: [PATCH 08/21] update typing, remove unneeded try block, readd sleep --- freqtrade/rpc/api_server/ws/channel.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 5424d7440..4bd7b0e4b 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -3,7 +3,7 @@ import logging import time from collections import deque from contextlib import asynccontextmanager -from typing import Any, Deque, Dict, List, Optional, Type, Union +from typing import Any, AsyncGenerator, Deque, Dict, List, Optional, Type, Union from uuid import uuid4 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy @@ -99,6 +99,15 @@ class WebSocketChannel: logger.info(f"Connection for {self} is too far behind, disconnecting") raise + # Without this sleep, messages would send to one channel + # first then another after the first one finished and prevent + # any normal Rest API calls from processing at the same time. + # With the sleep call, it gives control to the event + # loop to schedule other channel send methods, and helps + # throttle how fast we send. + # 0.01 = 100 messages/second max throughput + await asyncio.sleep(0.01) + async def recv(self): """ Receive a message on the wrapped websocket @@ -180,10 +189,7 @@ class WebSocketChannel: task.cancel() # Wait for tasks to finish cancelling - try: - await asyncio.wait(self._channel_tasks) - except asyncio.CancelledError: - pass + await asyncio.wait(self._channel_tasks) self._channel_tasks = [] @@ -199,7 +205,10 @@ class WebSocketChannel: @asynccontextmanager -async def create_channel(websocket: WebSocketType, **kwargs): +async def create_channel( + websocket: WebSocketType, + **kwargs +) -> AsyncGenerator[WebSocketChannel, None]: """ Context manager for safely opening and closing a WebSocketChannel """ From d2870d48ea8e7d19782f6a2c753ea622c16d36ae Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sun, 20 Nov 2022 16:24:44 -0700 Subject: [PATCH 09/21] change typing to async iterator --- freqtrade/rpc/api_server/ws/channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 4bd7b0e4b..8699de66c 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -3,7 +3,7 @@ import logging import time from collections import deque from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Deque, Dict, List, Optional, Type, Union +from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union from uuid import uuid4 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy @@ -208,7 +208,7 @@ class WebSocketChannel: async def create_channel( websocket: WebSocketType, **kwargs -) -> AsyncGenerator[WebSocketChannel, None]: +) -> AsyncIterator[WebSocketChannel]: """ Context manager for safely opening and closing a WebSocketChannel """ From d9d7df70bfcbc2094ed51518438b238254d193f6 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Mon, 21 Nov 2022 12:21:40 -0700 Subject: [PATCH 10/21] fix tests, log unknown errors --- freqtrade/rpc/api_server/webserver.py | 1 - freqtrade/rpc/api_server/ws/channel.py | 14 ++++++++++- tests/rpc/test_rpc_apiserver.py | 34 ++++++++++++-------------- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 4a9f089d1..e4eb3895d 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -212,7 +212,6 @@ class ApiServer(RPCHandler): if self._standalone: self._server.run() else: - # self.start_message_queue() self._server.run_in_thread() except Exception: logger.exception("Api server failed to start.") diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 8699de66c..9dea21f3b 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -6,6 +6,9 @@ from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union from uuid import uuid4 +from fastapi import WebSocketDisconnect +from websockets.exceptions import ConnectionClosed + from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) @@ -189,7 +192,16 @@ class WebSocketChannel: task.cancel() # Wait for tasks to finish cancelling - await asyncio.wait(self._channel_tasks) + try: + await task + except ( + asyncio.CancelledError, + WebSocketDisconnect, + ConnectionClosed + ): + pass + except Exception as e: + logger.info(f"Encountered unknown exception: {e}", exc_info=e) self._channel_tasks = [] diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 969728b6f..25d6a32e3 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -57,7 +57,10 @@ def botclient(default_conf, mocker): try: apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(rpc) - yield ftbot, TestClient(apiserver.app) + # We need to use the TestClient as a context manager to + # handle lifespan events correctly + with TestClient(apiserver.app) as client: + yield ftbot, client # Cleanup ... ? finally: if apiserver: @@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog): apiserver.cleanup() assert apiserver._server.cleanup.call_count == 1 assert log_has("Stopping API Server", caplog) - assert log_has("Stopping API Server background tasks", caplog) ApiServer.shutdown() @@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker): with client.websocket_connect(ws_url) as ws: ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) + time.sleep(1) # Check call count is now 1 as we sent a valid subscribe request assert sub_mock.call_count == 1 with client.websocket_connect(ws_url) as ws: ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) + time.sleep(1) # Call count hasn't changed as the subscribe request was invalid assert sub_mock.call_count == 1 @@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog): mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api') apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf))) - apiserver.start_message_queue() - # Give the queue thread time to start - time.sleep(0.2) - # Test message_queue coro receives the message - test_message = {"type": "status", "data": "test"} - apiserver.send_msg(test_message) - time.sleep(0.1) # Not sure how else to wait for the coro to receive the data - assert log_has("Found message of type: status", caplog) + # Start test client context manager to run lifespan events + with TestClient(apiserver.app): + # Test message is published on the Message Stream + test_message = {"type": "status", "data": "test"} + first_waiter = apiserver._message_stream._waiter + apiserver.send_msg(test_message) + assert first_waiter.result()[0] == test_message - # Test if exception logged when error occurs in sending - mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast', - side_effect=Exception) - - apiserver.send_msg(test_message) - time.sleep(0.1) # Not sure how else to wait for the coro to receive the data - assert log_has_re(r"Exception happened in background task.*", caplog) + second_waiter = apiserver._message_stream._waiter + apiserver.send_msg(test_message) + assert first_waiter != second_waiter finally: - apiserver.cleanup() ApiServer.shutdown() From a5442772fc22138dc18fcd3c99c2727f1e9007dd Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Tue, 22 Nov 2022 09:42:09 -0700 Subject: [PATCH 11/21] ensure only broadcasting to subscribed topics --- freqtrade/rpc/api_server/api_ws.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 618490ec8..fe2968c05 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -34,7 +34,8 @@ async def channel_broadcaster(channel: WebSocketChannel, message_stream: Message Iterate over messages in the message stream and send them """ async for message in message_stream: - await channel.send(message, timeout=True) + if channel.subscribed_to(message.get('type')): + await channel.send(message, timeout=True) async def _process_consumer_request( From 48242ca02b0f819d0d0318e89ad2b1804017b076 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Tue, 22 Nov 2022 12:43:45 -0700 Subject: [PATCH 12/21] update catch block in cancel channel tasks --- freqtrade/rpc/api_server/ws/channel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 9dea21f3b..ad183ce5b 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -99,7 +99,7 @@ class WebSocketChannel: self._calc_send_limit() except asyncio.TimeoutError: - logger.info(f"Connection for {self} is too far behind, disconnecting") + logger.info(f"Connection for {self} timed out, disconnecting") raise # Without this sleep, messages would send to one channel @@ -138,7 +138,7 @@ class WebSocketChannel: try: await self._websocket.close() - except Exception: + except RuntimeError: pass def is_closed(self) -> bool: @@ -196,8 +196,10 @@ class WebSocketChannel: await task except ( asyncio.CancelledError, + asyncio.TimeoutError, WebSocketDisconnect, - ConnectionClosed + ConnectionClosed, + RuntimeError ): pass except Exception as e: From 101dec461e40c2b8ed15a7075bb4b7dc9099c7b2 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Thu, 24 Nov 2022 11:35:50 -0700 Subject: [PATCH 13/21] close ws channel if can't accept --- freqtrade/rpc/api_server/ws/channel.py | 56 ++++++++++++++------------ 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index ad183ce5b..7343bc306 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -125,9 +125,14 @@ class WebSocketChannel: async def accept(self): """ - Accept the underlying websocket connection + Accept the underlying websocket connection, + if the connection has been closed before we can + accept, just close the channel. """ - return await self._websocket.accept() + try: + return await self._websocket.accept() + except RuntimeError: + await self.close() async def close(self): """ @@ -172,17 +177,18 @@ class WebSocketChannel: :param **kwargs: Any extra kwargs to pass to gather """ - # Wrap the coros into tasks if they aren't already - self._channel_tasks = [ - task if isinstance(task, asyncio.Task) else asyncio.create_task(task) - for task in tasks - ] + if not self.is_closed(): + # Wrap the coros into tasks if they aren't already + self._channel_tasks = [ + task if isinstance(task, asyncio.Task) else asyncio.create_task(task) + for task in tasks + ] - try: - return await asyncio.gather(*self._channel_tasks, **kwargs) - except Exception: - # If an exception occurred, cancel the rest of the tasks - await self.cancel_channel_tasks() + try: + return await asyncio.gather(*self._channel_tasks, **kwargs) + except Exception: + # If an exception occurred, cancel the rest of the tasks + await self.cancel_channel_tasks() async def cancel_channel_tasks(self): """ @@ -191,19 +197,19 @@ class WebSocketChannel: for task in self._channel_tasks: task.cancel() - # Wait for tasks to finish cancelling - try: - await task - except ( - asyncio.CancelledError, - asyncio.TimeoutError, - WebSocketDisconnect, - ConnectionClosed, - RuntimeError - ): - pass - except Exception as e: - logger.info(f"Encountered unknown exception: {e}", exc_info=e) + # Wait for tasks to finish cancelling + try: + await task + except ( + asyncio.CancelledError, + asyncio.TimeoutError, + WebSocketDisconnect, + ConnectionClosed, + RuntimeError + ): + pass + except Exception as e: + logger.info(f"Encountered unknown exception: {e}", exc_info=e) self._channel_tasks = [] From fc59b02255e3b91e8329b6bf02517102b05d0996 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Thu, 24 Nov 2022 13:41:10 -0700 Subject: [PATCH 14/21] prevent ws endpoint from running without valid token --- freqtrade/rpc/api_server/api_auth.py | 2 -- freqtrade/rpc/api_server/api_ws.py | 11 ++++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index ee66fce2b..71f1145a9 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -81,8 +81,6 @@ async def validate_ws_token( except HTTPException: pass - # No checks passed, deny the connection - logger.debug("Denying websocket request.") # If it doesn't match, close the websocket connection await ws.close(code=status.WS_1008_POLICY_VIOLATION) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index fe2968c05..77950923d 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -97,8 +97,9 @@ async def message_endpoint( rpc: RPC = Depends(get_rpc), message_stream: MessageStream = Depends(get_message_stream) ): - async with create_channel(websocket) as channel: - await channel.run_channel_tasks( - channel_reader(channel, rpc), - channel_broadcaster(channel, message_stream) - ) + if token: + async with create_channel(websocket) as channel: + await channel.run_channel_tasks( + channel_reader(channel, rpc), + channel_broadcaster(channel, message_stream) + ) From afc00bc30a94abd64fee000535e66287fd91595f Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Fri, 25 Nov 2022 12:48:57 -0700 Subject: [PATCH 15/21] log warning if channel too far behind, add docstrings to message stream --- freqtrade/rpc/api_server/api_ws.py | 11 +++++++++- freqtrade/rpc/api_server/ws/channel.py | 12 ++++++----- freqtrade/rpc/api_server/ws/message_stream.py | 21 ++++++++++++------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 77950923d..a80250c1b 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,4 +1,5 @@ import logging +import time from typing import Any, Dict from fastapi import APIRouter, Depends @@ -33,8 +34,16 @@ async def channel_broadcaster(channel: WebSocketChannel, message_stream: Message """ Iterate over messages in the message stream and send them """ - async for message in message_stream: + async for message, ts in message_stream: if channel.subscribed_to(message.get('type')): + # Log a warning if this channel is behind + # on the message stream by a lot + if (time.time() - ts) > 60: + logger.warning("Channel {channel} is behind MessageStream by 1 minute," + " this can cause a memory leak if you see this message" + " often, consider reducing pair list size or amount of" + " consumers.") + await channel.send(message, timeout=True) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 7343bc306..a5f3b6216 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -59,6 +59,10 @@ class WebSocketChannel: def remote_addr(self): return self._websocket.remote_addr + @property + def avg_send_time(self): + return sum(self._send_times) / len(self._send_times) + def _calc_send_limit(self): """ Calculate the send high limit for this channel @@ -66,11 +70,9 @@ class WebSocketChannel: # Only update if we have enough data if len(self._send_times) == self._send_times.maxlen: - # At least 1s or twice the average of send times - self._send_high_limit = max( - (sum(self._send_times) / len(self._send_times)) * 2, - 1 - ) + # At least 1s or twice the average of send times, with a + # maximum of 3 seconds per message + self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3) async def send( self, diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py index 9592908ab..a55a0da3c 100644 --- a/freqtrade/rpc/api_server/ws/message_stream.py +++ b/freqtrade/rpc/api_server/ws/message_stream.py @@ -1,4 +1,5 @@ import asyncio +import time class MessageStream: @@ -11,14 +12,20 @@ class MessageStream: self._waiter = self._loop.create_future() def publish(self, message): - waiter, self._waiter = self._waiter, self._loop.create_future() - waiter.set_result((message, self._waiter)) + """ + Publish a message to this MessageStream - async def subscribe(self): + :param message: The message to publish + """ + waiter, self._waiter = self._waiter, self._loop.create_future() + waiter.set_result((message, time.time(), self._waiter)) + + async def __aiter__(self): + """ + Iterate over the messages in the message stream + """ waiter = self._waiter while True: # Shield the future from being cancelled by a task waiting on it - message, waiter = await asyncio.shield(waiter) - yield message - - __aiter__ = subscribe + message, ts, waiter = await asyncio.shield(waiter) + yield message, ts From f268187e9b357127151ae45704538aed6c89f7f5 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Fri, 25 Nov 2022 12:56:33 -0700 Subject: [PATCH 16/21] offload initial df computation to thread --- freqtrade/misc.py | 43 ++++++++++++++++++++++++++++++ freqtrade/rpc/api_server/api_ws.py | 3 ++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/freqtrade/misc.py b/freqtrade/misc.py index 2d2c7513a..349735dcd 100644 --- a/freqtrade/misc.py +++ b/freqtrade/misc.py @@ -1,9 +1,11 @@ """ Various tool function for Freqtrade and scripts """ +import asyncio import gzip import logging import re +import threading from datetime import datetime from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Union @@ -301,3 +303,44 @@ def remove_entry_exit_signals(dataframe: pd.DataFrame): dataframe[SignalTagType.EXIT_TAG.value] = None return dataframe + + +def sync_to_async_iter(iter): + """ + Wrap blocking iterator into an asynchronous by + offloading computation to thread and using + pubsub pattern for yielding results + + :param iter: A synchronous iterator + :returns: An asynchronous iterator + """ + + loop = asyncio.get_event_loop() + q = asyncio.Queue(1) + exception = None + _END = object() + + async def yield_queue_items(): + while True: + next_item = await q.get() + if next_item is _END: + break + yield next_item + if exception is not None: + # The iterator has raised, propagate the exception + raise exception + + def iter_to_queue(): + nonlocal exception + try: + for item in iter: + # This runs outside the event loop thread, so we + # must use thread-safe API to talk to the queue. + asyncio.run_coroutine_threadsafe(q.put(item), loop).result() + except Exception as e: + exception = e + finally: + asyncio.run_coroutine_threadsafe(q.put(_END), loop).result() + + threading.Thread(target=iter_to_queue).start() + return yield_queue_items() diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index a80250c1b..6ecc1ef2a 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -7,6 +7,7 @@ from fastapi.websockets import WebSocket from pydantic import ValidationError from freqtrade.enums import RPCMessageType, RPCRequestType +from freqtrade.misc import sync_to_async_iter from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel @@ -93,7 +94,7 @@ async def _process_consumer_request( limit = min(data.get('limit', 1500), 1500) if data else None # For every pair in the generator, send a separate message - for message in rpc._ws_request_analyzed_df(limit): + async for message in sync_to_async_iter(rpc._ws_request_analyzed_df(limit)): # Format response response = WSAnalyzedDFMessage(data=message) await channel.send(response.dict(exclude_none=True)) From 4aa4c6f49d27aa724ec8a120003c20215aa90195 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Fri, 25 Nov 2022 13:08:41 -0700 Subject: [PATCH 17/21] change sleep in channel send to 0 --- freqtrade/rpc/api_server/ws/channel.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index a5f3b6216..76e48d889 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -104,14 +104,9 @@ class WebSocketChannel: logger.info(f"Connection for {self} timed out, disconnecting") raise - # Without this sleep, messages would send to one channel - # first then another after the first one finished and prevent - # any normal Rest API calls from processing at the same time. - # With the sleep call, it gives control to the event - # loop to schedule other channel send methods, and helps - # throttle how fast we send. - # 0.01 = 100 messages/second max throughput - await asyncio.sleep(0.01) + # Explicitly give control back to event loop as + # websockets.send does not + await asyncio.sleep(0) async def recv(self): """ From bd95392eea3c4cdae7c5f97557a359599664ba34 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Fri, 25 Nov 2022 13:10:22 -0700 Subject: [PATCH 18/21] fix formatted string in warning message :) --- freqtrade/rpc/api_server/api_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 6ecc1ef2a..9e7bb17a4 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -40,7 +40,7 @@ async def channel_broadcaster(channel: WebSocketChannel, message_stream: Message # Log a warning if this channel is behind # on the message stream by a lot if (time.time() - ts) > 60: - logger.warning("Channel {channel} is behind MessageStream by 1 minute," + logger.warning(f"Channel {channel} is behind MessageStream by 1 minute," " this can cause a memory leak if you see this message" " often, consider reducing pair list size or amount of" " consumers.") From 7b0a76fb7010eac44d7000626d9f167201b87f1a Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 25 Nov 2022 10:41:37 +0100 Subject: [PATCH 19/21] Improve typehint --- freqtrade/rpc/api_server/webserver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index e4eb3895d..92bded1c5 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -1,6 +1,6 @@ import logging from ipaddress import IPv4Address -from typing import Any, Dict +from typing import Any, Dict, Optional import orjson import uvicorn @@ -46,7 +46,7 @@ class ApiServer(RPCHandler): # Exchange - only available in webserver mode. _exchange = None # websocket message stuff - _message_stream = None + _message_stream: Optional[MessageStream] = None def __new__(cls, *args, **kwargs): """ From fcf13580f14aea8e889eaf1af82140eb17596d5c Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 26 Nov 2022 13:33:54 +0100 Subject: [PATCH 20/21] Revert "offload initial df computation to thread" This reverts commit f268187e9b357127151ae45704538aed6c89f7f5. --- freqtrade/misc.py | 43 ------------------------------ freqtrade/rpc/api_server/api_ws.py | 3 +-- 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/freqtrade/misc.py b/freqtrade/misc.py index 349735dcd..2d2c7513a 100644 --- a/freqtrade/misc.py +++ b/freqtrade/misc.py @@ -1,11 +1,9 @@ """ Various tool function for Freqtrade and scripts """ -import asyncio import gzip import logging import re -import threading from datetime import datetime from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Union @@ -303,44 +301,3 @@ def remove_entry_exit_signals(dataframe: pd.DataFrame): dataframe[SignalTagType.EXIT_TAG.value] = None return dataframe - - -def sync_to_async_iter(iter): - """ - Wrap blocking iterator into an asynchronous by - offloading computation to thread and using - pubsub pattern for yielding results - - :param iter: A synchronous iterator - :returns: An asynchronous iterator - """ - - loop = asyncio.get_event_loop() - q = asyncio.Queue(1) - exception = None - _END = object() - - async def yield_queue_items(): - while True: - next_item = await q.get() - if next_item is _END: - break - yield next_item - if exception is not None: - # The iterator has raised, propagate the exception - raise exception - - def iter_to_queue(): - nonlocal exception - try: - for item in iter: - # This runs outside the event loop thread, so we - # must use thread-safe API to talk to the queue. - asyncio.run_coroutine_threadsafe(q.put(item), loop).result() - except Exception as e: - exception = e - finally: - asyncio.run_coroutine_threadsafe(q.put(_END), loop).result() - - threading.Thread(target=iter_to_queue).start() - return yield_queue_items() diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 9e7bb17a4..e183cd7e7 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -7,7 +7,6 @@ from fastapi.websockets import WebSocket from pydantic import ValidationError from freqtrade.enums import RPCMessageType, RPCRequestType -from freqtrade.misc import sync_to_async_iter from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel @@ -94,7 +93,7 @@ async def _process_consumer_request( limit = min(data.get('limit', 1500), 1500) if data else None # For every pair in the generator, send a separate message - async for message in sync_to_async_iter(rpc._ws_request_analyzed_df(limit)): + for message in rpc._ws_request_analyzed_df(limit): # Format response response = WSAnalyzedDFMessage(data=message) await channel.send(response.dict(exclude_none=True)) From a26b3a9ca8031753f406df690abd638b09ca8d31 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sat, 26 Nov 2022 09:40:22 -0700 Subject: [PATCH 21/21] change sleep call back to 0.01 --- freqtrade/rpc/api_server/ws/channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 76e48d889..c50aff8be 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -106,7 +106,7 @@ class WebSocketChannel: # Explicitly give control back to event loop as # websockets.send does not - await asyncio.sleep(0) + await asyncio.sleep(0.01) async def recv(self): """