diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py
index ee66fce2b..71f1145a9 100644
--- a/freqtrade/rpc/api_server/api_auth.py
+++ b/freqtrade/rpc/api_server/api_auth.py
@@ -81,8 +81,6 @@ async def validate_ws_token(
     except HTTPException:
         pass

-    # No checks passed, deny the connection
-    logger.debug("Denying websocket request.")
     # If it doesn't match, close the websocket connection
     await ws.close(code=status.WS_1008_POLICY_VIOLATION)
diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py
index 785773b39..e183cd7e7 100644
--- a/freqtrade/rpc/api_server/api_ws.py
+++ b/freqtrade/rpc/api_server/api_ws.py
@@ -1,16 +1,16 @@
 import logging
+import time
 from typing import Any, Dict

-from fastapi import APIRouter, Depends, WebSocketDisconnect
-from fastapi.websockets import WebSocket, WebSocketState
+from fastapi import APIRouter, Depends
+from fastapi.websockets import WebSocket
 from pydantic import ValidationError
-from websockets.exceptions import WebSocketException

 from freqtrade.enums import RPCMessageType, RPCRequestType
 from freqtrade.rpc.api_server.api_auth import validate_ws_token
-from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
-from freqtrade.rpc.api_server.ws import WebSocketChannel
-from freqtrade.rpc.api_server.ws.channel import ChannelManager
+from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
+from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream
 from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
                                                  WSRequestSchema, WSWhitelistMessage)
 from freqtrade.rpc.rpc import RPC
@@ -22,23 +22,35 @@ logger = logging.getLogger(__name__)

 router = APIRouter()


-async def is_websocket_alive(ws: WebSocket) -> bool:
+async def channel_reader(channel: WebSocketChannel, rpc: RPC):
     """
-    Check if a FastAPI Websocket is still open
+    Iterate over the messages from the channel and process the request
     """
-    if (
-        ws.application_state == WebSocketState.CONNECTED and
-        ws.client_state == WebSocketState.CONNECTED
-    ):
-        return True
-    return False
+    async for message in channel:
+        await _process_consumer_request(message, channel, rpc)
+
+
+async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
+    """
+    Iterate over messages in the message stream and send them
+    """
+    async for message, ts in message_stream:
+        if channel.subscribed_to(message.get('type')):
+            # Log a warning if this channel is behind
+            # on the message stream by a lot
+            if (time.time() - ts) > 60:
+                logger.warning(f"Channel {channel} is behind MessageStream by 1 minute,"
+                               " this can cause a memory leak if you see this message"
+                               " often, consider reducing pair list size or amount of"
+                               " consumers.")
+
+            await channel.send(message, timeout=True)


 async def _process_consumer_request(
     request: Dict[str, Any],
     channel: WebSocketChannel,
-    rpc: RPC,
-    channel_manager: ChannelManager
+    rpc: RPC
 ):
     """
     Validate and handle a request from a websocket consumer
@@ -74,65 +86,29 @@ async def _process_consumer_request(
         # Format response
         response = WSWhitelistMessage(data=whitelist)
-        # Send it back
-        await channel_manager.send_direct(channel, response.dict(exclude_none=True))
+        await channel.send(response.dict(exclude_none=True))

     elif type == RPCRequestType.ANALYZED_DF:
-        limit = None
-
-        if data:
-            # Limit the amount of candles per dataframe to 'limit' or 1500
-            limit = max(data.get('limit', 1500), 1500)
+        # Limit the amount of candles per dataframe to 'limit' or 1500
+        limit = min(data.get('limit', 1500), 1500) if data else None

         # For every pair in the generator, send a separate message
         for message in rpc._ws_request_analyzed_df(limit):
+            # Format response
             response = WSAnalyzedDFMessage(data=message)
-            await channel_manager.send_direct(channel, response.dict(exclude_none=True))
+            await channel.send(response.dict(exclude_none=True))


 @router.websocket("/message/ws")
 async def message_endpoint(
-    ws: WebSocket,
+    websocket: WebSocket,
+    token: str = Depends(validate_ws_token),
     rpc: RPC = Depends(get_rpc),
-    channel_manager=Depends(get_channel_manager),
-    token: str = Depends(validate_ws_token)
+    message_stream: MessageStream = Depends(get_message_stream)
 ):
-    """
-    Message WebSocket endpoint, facilitates sending RPC messages
-    """
-    try:
-        channel = await channel_manager.on_connect(ws)
-        if await is_websocket_alive(ws):
-
-            logger.info(f"Consumer connected - {channel}")
-
-            # Keep connection open until explicitly closed, and process requests
-            try:
-                while not channel.is_closed():
-                    request = await channel.recv()
-
-                    # Process the request here
-                    await _process_consumer_request(request, channel, rpc, channel_manager)
-
-            except (WebSocketDisconnect, WebSocketException):
-                # Handle client disconnects
-                logger.info(f"Consumer disconnected - {channel}")
-            except RuntimeError:
-                # Handle cases like -
-                # RuntimeError('Cannot call "send" once a closed message has been sent')
-                pass
-            except Exception as e:
-                logger.info(f"Consumer connection failed - {channel}: {e}")
-                logger.debug(e, exc_info=e)
-
-    except RuntimeError:
-        # WebSocket was closed
-        # Do nothing
-        pass
-    except Exception as e:
-        logger.error(f"Failed to serve - {ws.client}")
-        # Log tracebacks to keep track of what errors are happening
-        logger.exception(e)
-    finally:
-        if channel:
-            await channel_manager.on_disconnect(ws)
+    if token:
+        async with create_channel(websocket) as channel:
+            await channel.run_channel_tasks(
+                channel_reader(channel, rpc),
+                channel_broadcaster(channel, message_stream)
+            )
diff --git a/freqtrade/rpc/api_server/deps.py b/freqtrade/rpc/api_server/deps.py
index abd3db036..aed97367b 100644
--- a/freqtrade/rpc/api_server/deps.py
+++ b/freqtrade/rpc/api_server/deps.py
@@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
     return ApiServer._exchange


-def get_channel_manager():
-    return ApiServer._ws_channel_manager
+def get_message_stream():
+    return ApiServer._message_stream


 def is_webserver_mode(config=Depends(get_config)):
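Note: the rewritten message_endpoint above reduces to a single pattern — one reader task and one broadcaster task per connection, where the first exception (typically a client disconnect surfacing in the reader) tears down the sibling task. The sketch below shows that pattern with plain FastAPI/Starlette primitives only; it does not use the patch's channel classes, and the ticker generator and endpoint names are illustrative stand-ins.

# Generic sketch of the per-connection task pattern (not part of the patch)
import asyncio

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()


async def ticker():
    # Stand-in for the MessageStream: emit a message every second
    while True:
        await asyncio.sleep(1)
        yield {"type": "status", "data": "ping"}


async def read_requests(ws: WebSocket):
    # ws.receive_json() raises WebSocketDisconnect when the client goes away
    while True:
        request = await ws.receive_json()
        print("consumer request:", request)


async def push_messages(ws: WebSocket):
    async for message in ticker():
        await ws.send_json(message)


@app.websocket("/ws")
async def endpoint(ws: WebSocket):
    await ws.accept()
    tasks = [
        asyncio.create_task(read_requests(ws)),
        asyncio.create_task(push_messages(ws)),
    ]
    try:
        # gather() re-raises the first exception but does not cancel the
        # sibling task by itself, hence the explicit cleanup in finally
        await asyncio.gather(*tasks)
    except WebSocketDisconnect:
        pass
    finally:
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

run_channel_tasks/cancel_channel_tasks in channel.py below wrap exactly this gather-then-cancel sequence, with the extra exception types a real websocket connection can raise during cancellation.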
diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py
index ec4907e67..92bded1c5 100644
--- a/freqtrade/rpc/api_server/webserver.py
+++ b/freqtrade/rpc/api_server/webserver.py
@@ -1,22 +1,17 @@
-import asyncio
 import logging
 from ipaddress import IPv4Address
-from threading import Thread
 from typing import Any, Dict, Optional

 import orjson
 import uvicorn
 from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-# Look into alternatives
-from janus import Queue as ThreadedQueue
 from starlette.responses import JSONResponse

 from freqtrade.constants import Config
 from freqtrade.exceptions import OperationalException
 from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
-from freqtrade.rpc.api_server.ws import ChannelManager
-from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream
 from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@@ -50,10 +45,8 @@ class ApiServer(RPCHandler):
     _config: Config = {}
     # Exchange - only available in webserver mode.
     _exchange = None
-    # websocket message queue stuff
-    _ws_channel_manager: ChannelManager
-    _ws_thread = None
-    _ws_loop: Optional[asyncio.AbstractEventLoop] = None
+    # websocket message stuff
+    _message_stream: Optional[MessageStream] = None

     def __new__(cls, *args, **kwargs):
         """
@@ -71,15 +64,11 @@ class ApiServer(RPCHandler):
             return
         self._standalone: bool = standalone
         self._server = None
-        self._ws_queue: Optional[ThreadedQueue] = None
-        self._ws_background_task = None

         ApiServer.__initialized = True

         api_config = self._config['api_server']

-        ApiServer._ws_channel_manager = ChannelManager()
-
         self.app = FastAPI(title="Freqtrade API",
                            docs_url='/docs' if api_config.get('enable_openapi', False) else None,
                            redoc_url=None,
@@ -105,21 +94,9 @@ class ApiServer(RPCHandler):
         del ApiServer._rpc
         if self._server and not self._standalone:
             logger.info("Stopping API Server")
+            # self._server.force_exit, self._server.should_exit = True, True
             self._server.cleanup()

-        if self._ws_thread and self._ws_loop:
-            logger.info("Stopping API Server background tasks")
-
-            if self._ws_background_task:
-                # Cancel the queue task
-                self._ws_background_task.cancel()
-
-            self._ws_thread.join()
-
-            self._ws_thread = None
-            self._ws_loop = None
-            self._ws_background_task = None
-
     @classmethod
     def shutdown(cls):
         cls.__initialized = False
@@ -129,9 +106,11 @@ class ApiServer(RPCHandler):
         cls._rpc = None

     def send_msg(self, msg: Dict[str, Any]) -> None:
-        if self._ws_queue:
-            sync_q = self._ws_queue.sync_q
-            sync_q.put(msg)
+        """
+        Publish the message to the message stream
+        """
+        if ApiServer._message_stream:
+            ApiServer._message_stream.publish(msg)

     def handle_rpc_exception(self, request, exc):
         logger.exception(f"API Error calling: {exc}")
@@ -170,54 +149,30 @@ class ApiServer(RPCHandler):
         )
         app.add_exception_handler(RPCException, self.handle_rpc_exception)

+        app.add_event_handler(
+            event_type="startup",
+            func=self._api_startup_event
+        )
+        app.add_event_handler(
+            event_type="shutdown",
+            func=self._api_shutdown_event
+        )

-    def start_message_queue(self):
-        if self._ws_thread:
-            return
+    async def _api_startup_event(self):
+        """
+        Creates the MessageStream class on startup
+        so it has access to the same event loop
+        as uvicorn
+        """
+        if not ApiServer._message_stream:
+            ApiServer._message_stream = MessageStream()

-        # Create a new loop, as it'll be just for the background thread
-        self._ws_loop = asyncio.new_event_loop()
-
-        # Start the thread
-        self._ws_thread = Thread(target=self._ws_loop.run_forever)
-        self._ws_thread.start()
-
-        # Finally, submit the coro to the thread
-        self._ws_background_task = asyncio.run_coroutine_threadsafe(
-            self._broadcast_queue_data(), loop=self._ws_loop)
-
-    async def _broadcast_queue_data(self) -> None:
-        # Instantiate the queue in this coroutine so it's attached to our loop
-        self._ws_queue = ThreadedQueue()
-        async_queue = self._ws_queue.async_q
-
-        try:
-            while True:
-                logger.debug("Getting queue messages...")
-                if (qsize := async_queue.qsize()) > 20:
-                    # If the queue becomes too big for too long, this may indicate a problem.
-                    logger.warning(f"Queue size now {qsize}")
-                # Get data from queue
-                message: WSMessageSchemaType = await async_queue.get()
-                logger.debug(f"Found message of type: {message.get('type')}")
-                async_queue.task_done()
-                # Broadcast it
-                await self._ws_channel_manager.broadcast(message)
-        except asyncio.CancelledError:
-            pass
-
-        # For testing, shouldn't happen when stable
-        except Exception as e:
-            logger.exception(f"Exception happened in background task: {e}")
-
-        finally:
-            # Disconnect channels and stop the loop on cancel
-            await self._ws_channel_manager.disconnect_all()
-            if self._ws_loop:
-                self._ws_loop.stop()
-            # Avoid adding more items to the queue if they aren't
-            # going to get broadcasted.
-            self._ws_queue = None
+    async def _api_shutdown_event(self):
+        """
+        Removes the MessageStream class on shutdown
+        """
+        if ApiServer._message_stream:
+            ApiServer._message_stream = None

     def start_api(self):
         """
@@ -257,7 +212,6 @@ class ApiServer(RPCHandler):
             if self._standalone:
                 self._server.run()
             else:
-                self.start_message_queue()
                 self._server.run_in_thread()
         except Exception:
             logger.exception("Api server failed to start.")
diff --git a/freqtrade/rpc/api_server/ws/__init__.py b/freqtrade/rpc/api_server/ws/__init__.py
index 055b20a9d..0b94d3fee 100644
--- a/freqtrade/rpc/api_server/ws/__init__.py
+++ b/freqtrade/rpc/api_server/ws/__init__.py
@@ -3,4 +3,5 @@ from freqtrade.rpc.api_server.ws.types import WebSocketType
 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
 from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
-from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
+from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream
diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py
index 4eef738d4..c50aff8be 100644
--- a/freqtrade/rpc/api_server/ws/channel.py
+++ b/freqtrade/rpc/api_server/ws/channel.py
@@ -1,11 +1,13 @@
 import asyncio
 import logging
 import time
-from threading import RLock
-from typing import Any, Dict, List, Optional, Type, Union
+from collections import deque
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
 from uuid import uuid4

-from fastapi import WebSocket as FastAPIWebSocket
+from fastapi import WebSocketDisconnect
+from websockets.exceptions import ConnectionClosed

 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
 from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
@@ -21,31 +23,27 @@ class WebSocketChannel:
     """
     Object to help facilitate managing a websocket connection
     """
-
     def __init__(
         self,
         websocket: WebSocketType,
         channel_id: Optional[str] = None,
-        drain_timeout: int = 3,
-        throttle: float = 0.01,
         serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
     ):
-
         self.channel_id = channel_id if channel_id else uuid4().hex[:8]
-
-        # The WebSocket object
         self._websocket = WebSocketProxy(websocket)
-        self.drain_timeout = drain_timeout
-        self.throttle = throttle
-
-        self._subscriptions: List[str] = []
-        # 32 is the size of the receiving queue in websockets package
-        self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
-        self._relay_task = asyncio.create_task(self.relay())
-
         # Internal event to signify a closed websocket
         self._closed = asyncio.Event()
+        # The async tasks created for the channel
+        self._channel_tasks: List[asyncio.Task] = []
+
+        # Deque for average send times
+        self._send_times: Deque[float] = deque([], maxlen=10)
+        # High limit defaults to 3 to start
+        self._send_high_limit = 3
+
+        # The subscribed message types
+        self._subscriptions: List[str] = []

         # Wrap the WebSocket in the Serializing class
         self._wrapped_ws = serializer_cls(self._websocket)
@@ -61,43 +59,58 @@ class WebSocketChannel:
     def remote_addr(self):
         return self._websocket.remote_addr

-    async def _send(self, data):
-        """
-        Send data on the wrapped websocket
-        """
-        await self._wrapped_ws.send(data)
+    @property
+    def avg_send_time(self):
+        return sum(self._send_times) / len(self._send_times)

-    async def send(self, data) -> bool:
+    def _calc_send_limit(self):
         """
-        Add the data to the queue to be sent.
-        :returns: True if data added to queue, False otherwise
+        Calculate the send high limit for this channel
         """
-        # This block only runs if the queue is full, it will wait
-        # until self.drain_timeout for the relay to drain the outgoing queue
-        # We can't use asyncio.wait_for here because the queue may have been created with a
-        # different eventloop
-        if not self.is_closed():
-            start = time.time()
-            while self.queue.full():
-                await asyncio.sleep(1)
-                if (time.time() - start) > self.drain_timeout:
-                    return False
+        # Only update if we have enough data
+        if len(self._send_times) == self._send_times.maxlen:
+            # At least 1s or twice the average of send times, with a
+            # maximum of 3 seconds per message
+            self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)

-            # If for some reason the queue is still full, just return False
-            try:
-                self.queue.put_nowait(data)
-            except asyncio.QueueFull:
-                return False
+    async def send(
+        self,
+        message: Union[WSMessageSchemaType, Dict[str, Any]],
+        timeout: bool = False
+    ):
+        """
+        Send a message on the wrapped websocket. If the sending
+        takes too long, it will raise a TimeoutError and
+        disconnect the connection.

-            # If we got here everything is ok
-            return True
-        else:
-            return False
+        :param message: The message to send
+        :param timeout: Enforce send high limit, defaults to False
+        """
+        try:
+            _ = time.time()
+            # If the send times out, it will raise
+            # a TimeoutError and bubble up to the
+            # message_endpoint to close the connection
+            await asyncio.wait_for(
+                self._wrapped_ws.send(message),
+                timeout=self._send_high_limit if timeout else None
+            )
+            total_time = time.time() - _
+            self._send_times.append(total_time)
+
+            self._calc_send_limit()
+        except asyncio.TimeoutError:
+            logger.info(f"Connection for {self} timed out, disconnecting")
+            raise
+
+        # Explicitly give control back to event loop as
+        # websockets.send does not
+        await asyncio.sleep(0.01)

     async def recv(self):
         """
-        Receive data on the wrapped websocket
+        Receive a message on the wrapped websocket
         """
         return await self._wrapped_ws.recv()

@@ -107,17 +120,27 @@ class WebSocketChannel:
         """
         return await self._websocket.ping()

+    async def accept(self):
+        """
+        Accept the underlying websocket connection,
+        if the connection has been closed before we can
+        accept, just close the channel.
+        """
+        try:
+            return await self._websocket.accept()
+        except RuntimeError:
+            await self.close()
+
     async def close(self):
         """
         Close the WebSocketChannel
         """
         self._closed.set()
-        self._relay_task.cancel()

         try:
-            await self.raw_websocket.close()
-        except Exception:
+            await self._websocket.close()
+        except RuntimeError:
             pass

     def is_closed(self) -> bool:
@@ -142,99 +165,76 @@ class WebSocketChannel:
         """
         return message_type in self._subscriptions

-    async def relay(self):
+    async def run_channel_tasks(self, *tasks, **kwargs):
         """
-        Relay messages from the channel's queue and send them out. This is started
-        as a task.
+        Create and await on the channel tasks unless an exception
+        was raised, then cancel them all.
+
+        :params *tasks: All coros or tasks to be run concurrently
+        :param **kwargs: Any extra kwargs to pass to gather
         """
-        while not self._closed.is_set():
-            message = await self.queue.get()
+
+        if not self.is_closed():
+            # Wrap the coros into tasks if they aren't already
+            self._channel_tasks = [
+                task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
+                for task in tasks
+            ]
+
             try:
-                await self._send(message)
-                self.queue.task_done()
+                return await asyncio.gather(*self._channel_tasks, **kwargs)
+            except Exception:
+                # If an exception occurred, cancel the rest of the tasks
+                await self.cancel_channel_tasks()

-                # Limit messages per sec.
-                # Could cause problems with queue size if too low, and
-                # problems with network traffik if too high.
-                # 0.01 = 100/s
-                await asyncio.sleep(self.throttle)
-            except RuntimeError:
-                # The connection was closed, just exit the task
-                return
-
-
-class ChannelManager:
-    def __init__(self):
-        self.channels = dict()
-        self._lock = RLock()  # Re-entrant Lock
-
-    async def on_connect(self, websocket: WebSocketType):
+    async def cancel_channel_tasks(self):
         """
-        Wrap websocket connection into Channel and add to list
-
-        :param websocket: The WebSocket object to attach to the Channel
+        Cancel and wait on all channel tasks
        """
-        if isinstance(websocket, FastAPIWebSocket):
+        for task in self._channel_tasks:
+            task.cancel()
+
+            # Wait for tasks to finish cancelling
             try:
-                await websocket.accept()
-            except RuntimeError:
-                # The connection was closed before we could accept it
-                return
+                await task
+            except (
+                asyncio.CancelledError,
+                asyncio.TimeoutError,
+                WebSocketDisconnect,
+                ConnectionClosed,
+                RuntimeError
+            ):
+                pass
+            except Exception as e:
+                logger.info(f"Encountered unknown exception: {e}", exc_info=e)

-        ws_channel = WebSocketChannel(websocket)
+        self._channel_tasks = []

-        with self._lock:
-            self.channels[websocket] = ws_channel
-
-        return ws_channel
-
-    async def on_disconnect(self, websocket: WebSocketType):
+    async def __aiter__(self):
         """
-        Call close on the channel if it's not, and remove from channel list
+        Generator for received messages
+        """
+        # We can not catch any errors here as websocket.recv is
+        # the first to catch any disconnects and bubble it up
+        # so the connection is garbage collected right away
+        while not self.is_closed():
+            yield await self.recv()

-        :param websocket: The WebSocket objet attached to the Channel
-        """
-        with self._lock:
-            channel = self.channels.get(websocket)
-            if channel:
-                logger.info(f"Disconnecting channel {channel}")
-                if not channel.is_closed():
-                    await channel.close()
-                del self.channels[websocket]

+@asynccontextmanager
+async def create_channel(
+    websocket: WebSocketType,
+    **kwargs
+) -> AsyncIterator[WebSocketChannel]:
+    """
+    Context manager for safely opening and closing a WebSocketChannel
+    """
+    channel = WebSocketChannel(websocket, **kwargs)
+    try:
+        await channel.accept()
+        logger.info(f"Connected to channel - {channel}")

-    async def disconnect_all(self):
-        """
-        Disconnect all Channels
-        """
-        with self._lock:
-            for websocket in self.channels.copy().keys():
-                await self.on_disconnect(websocket)
-
-    async def broadcast(self, message: WSMessageSchemaType):
-        """
-        Broadcast a message on all Channels
-
-        :param message: The message to send
-        """
-        with self._lock:
-            for channel in self.channels.copy().values():
-                if channel.subscribed_to(message.get('type')):
-                    await self.send_direct(channel, message)
-
-    async def send_direct(
-            self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
-        """
-        Send a message directly through direct_channel only
-
-        :param direct_channel: The WebSocketChannel object to send the message through
-        :param message: The message to send
-        """
-        if not await channel.send(message):
-            await self.on_disconnect(channel.raw_websocket)
-
-    def has_channels(self):
-        """
-        Flag for more than 0 channels
-        """
-        return len(self.channels) > 0
+        yield channel
+    finally:
+        await channel.close()
+        logger.info(f"Disconnected from channel - {channel}")
diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py
new file mode 100644
index 000000000..a55a0da3c
--- /dev/null
+++ b/freqtrade/rpc/api_server/ws/message_stream.py
@@ -0,0 +1,31 @@
+import asyncio
+import time
+
+
+class MessageStream:
+    """
+    A message stream for consumers to subscribe to,
+    and for producers to publish to.
+    """
+    def __init__(self):
+        self._loop = asyncio.get_running_loop()
+        self._waiter = self._loop.create_future()
+
+    def publish(self, message):
+        """
+        Publish a message to this MessageStream
+
+        :param message: The message to publish
+        """
+        waiter, self._waiter = self._waiter, self._loop.create_future()
+        waiter.set_result((message, time.time(), self._waiter))
+
+    async def __aiter__(self):
+        """
+        Iterate over the messages in the message stream
+        """
+        waiter = self._waiter
+        while True:
+            # Shield the future from being cancelled by a task waiting on it
+            message, ts, waiter = await asyncio.shield(waiter)
+            yield message, ts
diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py
index 6c402a100..9a894e1bf 100644
--- a/freqtrade/rpc/api_server/ws/serializer.py
+++ b/freqtrade/rpc/api_server/ws/serializer.py
@@ -1,5 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
+from typing import Any, Dict, Union

 import orjson
 import rapidjson
@@ -7,6 +8,7 @@ from pandas import DataFrame

 from freqtrade.misc import dataframe_to_json, json_to_dataframe
 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
+from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType

 logger = logging.getLogger(__name__)

@@ -24,17 +26,13 @@ class WebSocketSerializer(ABC):
     def _deserialize(self, data):
         raise NotImplementedError()

-    async def send(self, data: bytes):
+    async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
         await self._websocket.send(self._serialize(data))

     async def recv(self) -> bytes:
         data = await self._websocket.recv()
-
         return self._deserialize(data)

-    async def close(self, code: int = 1000):
-        await self._websocket.close(code)
-

 class HybridJSONWebSocketSerializer(WebSocketSerializer):
     def _serialize(self, data) -> str:
diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py
index 969728b6f..25d6a32e3 100644
--- a/tests/rpc/test_rpc_apiserver.py
+++ b/tests/rpc/test_rpc_apiserver.py
@@ -57,7 +57,10 @@ def botclient(default_conf, mocker):
     try:
         apiserver = ApiServer(default_conf)
         apiserver.add_rpc_handler(rpc)
-        yield ftbot, TestClient(apiserver.app)
+        # We need to use the TestClient as a context manager to
+        # handle lifespan events correctly
+        with TestClient(apiserver.app) as client:
+            yield ftbot, client
         # Cleanup ... ?
     finally:
         if apiserver:
@@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog):
     apiserver.cleanup()
     assert apiserver._server.cleanup.call_count == 1
     assert log_has("Stopping API Server", caplog)
-    assert log_has("Stopping API Server background tasks", caplog)

     ApiServer.shutdown()

@@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker):
     with client.websocket_connect(ws_url) as ws:
         ws.send_json({'type': 'subscribe', 'data': ['whitelist']})
+        time.sleep(1)

     # Check call count is now 1 as we sent a valid subscribe request
     assert sub_mock.call_count == 1

     with client.websocket_connect(ws_url) as ws:
         ws.send_json({'type': 'subscribe', 'data': 'whitelist'})
+        time.sleep(1)

     # Call count hasn't changed as the subscribe request was invalid
     assert sub_mock.call_count == 1
@@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog):
         mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api')
         apiserver = ApiServer(default_conf)
         apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf)))
-        apiserver.start_message_queue()
-        # Give the queue thread time to start
-        time.sleep(0.2)

-        # Test message_queue coro receives the message
-        test_message = {"type": "status", "data": "test"}
-        apiserver.send_msg(test_message)
-        time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data
-        assert log_has("Found message of type: status", caplog)
+        # Start test client context manager to run lifespan events
+        with TestClient(apiserver.app):
+            # Test message is published on the Message Stream
+            test_message = {"type": "status", "data": "test"}
+            first_waiter = apiserver._message_stream._waiter
+            apiserver.send_msg(test_message)
+            assert first_waiter.result()[0] == test_message

-        # Test if exception logged when error occurs in sending
-        mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast',
-                     side_effect=Exception)
-
-        apiserver.send_msg(test_message)
-        time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data
-        assert log_has_re(r"Exception happened in background task.*", caplog)
+            second_waiter = apiserver._message_stream._waiter
+            apiserver.send_msg(test_message)
+            assert first_waiter != second_waiter
     finally:
-        apiserver.cleanup()
         ApiServer.shutdown()
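Note on the new MessageStream: it calls asyncio.get_running_loop() in __init__, which is why it is created in _api_startup_event (inside uvicorn's event loop) rather than when ApiServer is constructed. Each publish() resolves the current future with (message, timestamp, next_future), so every subscriber walks the same chain of futures at its own pace; a slow consumer keeps old links of the chain alive, which is what the 60-second warning in channel_broadcaster is about. Below is a runnable sketch of that fan-out — the class body is copied from message_stream.py above, while subscriber and demo are illustrative only.

# Standalone sketch of the future-chaining broadcast (not part of the patch)
import asyncio
import time


class MessageStream:
    # Copied from freqtrade/rpc/api_server/ws/message_stream.py above
    def __init__(self):
        self._loop = asyncio.get_running_loop()
        self._waiter = self._loop.create_future()

    def publish(self, message):
        waiter, self._waiter = self._waiter, self._loop.create_future()
        waiter.set_result((message, time.time(), self._waiter))

    async def __aiter__(self):
        waiter = self._waiter
        while True:
            # Shield the future from being cancelled by a task waiting on it
            message, ts, waiter = await asyncio.shield(waiter)
            yield message, ts


async def subscriber(name, stream):
    # Every subscriber sees every published message; no per-consumer queue needed
    async for message, ts in stream:
        print(f"{name}: {message['type']} (published {time.time() - ts:.3f}s ago)")


async def demo():
    stream = MessageStream()
    tasks = [asyncio.create_task(subscriber(n, stream)) for n in ("a", "b")]
    await asyncio.sleep(0)  # let both subscribers start waiting on the first future
    stream.publish({"type": "whitelist", "data": []})
    stream.publish({"type": "status", "data": "running"})
    await asyncio.sleep(0.1)  # both subscribers print both messages
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(demo())

The asyncio.shield() in __aiter__ matters here: run_channel_tasks cancels the broadcaster on disconnect, and without the shield that cancellation would also cancel the shared waiter future that publish() and every other channel still depend on.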
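For completeness, this is roughly how an external consumer exercises the reworked endpoint end to end: connect, send the same subscribe request test_api_ws_subscribe uses, then read whatever channel_broadcaster pushes from the MessageStream. The /api/v1 prefix, the port and the token query parameter are assumptions based on freqtrade's default api_server/ws_token configuration; they are not defined by this patch.

# Hypothetical external consumer (not part of the patch)
import asyncio
import json

import websockets


async def consume(uri: str):
    async with websockets.connect(uri) as ws:
        # Same request shape as test_api_ws_subscribe above
        await ws.send(json.dumps({"type": "subscribe", "data": ["whitelist"]}))
        while True:
            message = json.loads(await ws.recv())
            print(message.get("type"), message.get("data"))


# Assumed URL layout: default /api/v1 prefix, default port 8080,
# ws_token passed as the "token" query parameter
asyncio.run(consume("ws://127.0.0.1:8080/api/v1/message/ws?token=my-ws-token"))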