From 19f96d60e3bc899818af6833ba941e4df8df8ff7 Mon Sep 17 00:00:00 2001 From: xzmeng Date: Thu, 14 Nov 2024 08:09:59 +0800 Subject: [PATCH] refactor: streamline error handling by raising instead of returning --- freqtrade/exchange/binance_public_data.py | 48 ++++++++-------- tests/exchange/test_binance_public_data.py | 66 +++++++++++++++------- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/freqtrade/exchange/binance_public_data.py b/freqtrade/exchange/binance_public_data.py index a8231d2b6..c27bb74a6 100644 --- a/freqtrade/exchange/binance_public_data.py +++ b/freqtrade/exchange/binance_public_data.py @@ -135,11 +135,12 @@ async def _download_archive_ohlcv( for date in dates ] for task in tasks: - result = await task current_day += 1 - if isinstance(result, Http404): + try: + df = await task + except Http404 as e: if stop_on_404: - logger.debug(f"Failed to download {result.url} due to 404.") + logger.debug(f"Failed to download {e.url} due to 404.") # A 404 error on the first day indicates missing data # on https://data.binance.vision, we provide the warning and the advice. @@ -147,7 +148,7 @@ async def _download_archive_ohlcv( if current_day == 1: logger.warning( f"Fast download is unavailable due to missing data: " - f"{result.url}. Falling back to the slower REST API, " + f"{e.url}. Falling back to the slower REST API, " "which may take more time." ) if pair in ["BTC/USDT:USDT", "ETH/USDT:USDT", "BCH/USDT:USDT"]: @@ -158,31 +159,31 @@ async def _download_archive_ohlcv( ) else: logger.warning( - f"Binance fast download for {pair} stopped at {result.date} due to " - f"missing data: {result.url}, falling back to rest API for the " + f"Binance fast download for {pair} stopped at {e.date} due to " + f"missing data: {e.url}, falling back to rest API for the " "remaining data, this can take more time." ) - await cancel_uncompleted_tasks(tasks) + await cancel_and_await_tasks(tasks[tasks.index(task) + 1 :]) return concat(dfs) else: dfs.append(None) - elif isinstance(result, BaseException): - logger.warning(f"An exception raised: : {result}") + except BaseException as e: + logger.warning(f"An exception raised: : {e}") # Directly return the existing data, do not allow the gap within the data - await cancel_uncompleted_tasks(tasks) + await cancel_and_await_tasks(tasks[tasks.index(task) + 1 :]) return concat(dfs) else: - dfs.append(result) + dfs.append(df) return concat(dfs) -async def cancel_uncompleted_tasks(tasks): +async def cancel_and_await_tasks(unawaited_tasks): + """Cancel and await the tasks""" logger.debug("Try to cancel uncompleted download tasks.") - uncompleted_tasks = [task for task in tasks if not task.done()] - for task in uncompleted_tasks: + for task in unawaited_tasks: task.cancel() - await asyncio.gather(*uncompleted_tasks) - logger.debug("All uncompleted download tasks were successfully cancelled.") + await asyncio.gather(*unawaited_tasks, return_exceptions=True) + logger.debug("All download tasks were awaited.") def date_range(start: datetime.date, end: datetime.date): @@ -221,7 +222,7 @@ async def get_daily_ohlcv( session: aiohttp.ClientSession, retry_count: int = 3, retry_delay: float = 0.0, -) -> DataFrame | Exception: +) -> DataFrame: """ Get daily OHLCV from https://data.binance.vision See https://github.com/binance/binance-public-data @@ -233,7 +234,7 @@ async def get_daily_ohlcv( :session: an aiohttp.ClientSession instance :retry_count: times to retry before returning the exceptions :retry_delay: the time to wait before every retry - :return: This function won't raise any exceptions, it will catch and return them + :return: A dataframe containing columns date,open,high,low,close,volume """ url = zip_url(asset_type_url_segment, symbol, timeframe, date) @@ -276,13 +277,8 @@ async def get_daily_ohlcv( raise Http404(f"404: {url}", date, url) else: raise BadHttpStatus(f"{resp.status} - {resp.reason}") - except asyncio.CancelledError as e: - return e except Exception as e: - if isinstance(e, Http404): - return e - else: - if retry >= retry_count: - logger.debug(f"Failed to get data from {url}: {e}") - return e retry += 1 + if isinstance(e, Http404) or retry > retry_count: + logger.debug(f"Failed to get data from {url}: {e}") + raise diff --git a/tests/exchange/test_binance_public_data.py b/tests/exchange/test_binance_public_data.py index 6b23cb2d6..f6e0c2e6a 100644 --- a/tests/exchange/test_binance_public_data.py +++ b/tests/exchange/test_binance_public_data.py @@ -104,10 +104,11 @@ def make_response_from_url(start_date, end_date): @pytest.mark.parametrize( - "candle_type,since,until,first_date,last_date,stop_on_404", + "candle_type,pair,since,until,first_date,last_date,stop_on_404", [ ( CandleType.SPOT, + "BTC/USDT", dt_utc(2020, 1, 1), dt_utc(2020, 1, 2), dt_utc(2020, 1, 1), @@ -116,6 +117,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2020, 1, 1), dt_utc(2020, 1, 1, 23, 59, 59), dt_utc(2020, 1, 1), @@ -124,6 +126,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2020, 1, 1), dt_utc(2020, 1, 5), dt_utc(2020, 1, 1), @@ -132,6 +135,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2019, 12, 25), dt_utc(2020, 1, 5), dt_utc(2020, 1, 1), @@ -140,6 +144,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2019, 1, 1), dt_utc(2019, 1, 5), None, @@ -148,6 +153,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2021, 1, 1), dt_utc(2021, 1, 5), None, @@ -156,6 +162,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, + "BTC/USDT", dt_utc(2020, 1, 2), None, dt_utc(2020, 1, 2), @@ -164,14 +171,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.SPOT, - dt_utc(2019, 12, 25), - dt_utc(2020, 1, 5), - None, - None, - True, - ), - ( - CandleType.SPOT, + "BTC/USDT", dt_utc(2020, 1, 5), dt_utc(2020, 1, 1), None, @@ -180,6 +180,7 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.FUTURES, + "BTC/USDT:USDT", dt_utc(2020, 1, 1), dt_utc(2020, 1, 1, 23, 59, 59), dt_utc(2020, 1, 1), @@ -188,24 +189,49 @@ def make_response_from_url(start_date, end_date): ), ( CandleType.INDEX, + "N/A", dt_utc(2020, 1, 1), dt_utc(2020, 1, 1, 23, 59, 59), None, None, False, ), + # stop_on_404 = True + ( + CandleType.SPOT, + "BTC/USDT", + dt_utc(2019, 12, 25), + dt_utc(2020, 1, 5), + None, + None, + True, + ), + ( + CandleType.SPOT, + "BTC/USDT", + dt_utc(2020, 1, 1), + dt_utc(2020, 1, 5), + dt_utc(2020, 1, 1), + dt_utc(2020, 1, 3, 23), + True, + ), + ( + CandleType.FUTURES, + "BTC/USDT:USDT", + dt_utc(2019, 12, 25), + dt_utc(2020, 1, 5), + None, + None, + True, + ), ], ) async def test_download_archive_ohlcv( - mocker, candle_type, since, until, first_date, last_date, stop_on_404 + mocker, candle_type, pair, since, until, first_date, last_date, stop_on_404 ): history_start = dt_utc(2020, 1, 1).date() history_end = dt_utc(2020, 1, 3).date() timeframe = "1h" - if candle_type == CandleType.SPOT: - pair = "BTC/USDT" - else: - pair = "BTC/USDT:USDT" since_ms = dt_ts(since) until_ms = dt_ts(until) @@ -283,23 +309,23 @@ async def test_get_daily_ohlcv(mocker, testdatadir): "freqtrade.exchange.binance_public_data.aiohttp.ClientSession.get", return_value=MockResponse(b"", 404), ) - df = await get_daily_ohlcv("spot", symbol, timeframe, date, session, retry_delay=0) + with pytest.raises(Http404): + df = await get_daily_ohlcv("spot", symbol, timeframe, date, session, retry_delay=0) assert get.call_count == 1 - assert isinstance(df, Http404) get = mocker.patch( "freqtrade.exchange.binance_public_data.aiohttp.ClientSession.get", return_value=MockResponse(b"", 500), ) mocker.patch("asyncio.sleep") - df = await get_daily_ohlcv("spot", symbol, timeframe, date, session) + with pytest.raises(BadHttpStatus): + df = await get_daily_ohlcv("spot", symbol, timeframe, date, session) assert get.call_count == 4 # 1 + 3 default retries - assert isinstance(df, BadHttpStatus) get = mocker.patch( "freqtrade.exchange.binance_public_data.aiohttp.ClientSession.get", return_value=MockResponse(b"nop", 200), ) - df = await get_daily_ohlcv("spot", symbol, timeframe, date, session) + with pytest.raises(zipfile.BadZipFile): + df = await get_daily_ohlcv("spot", symbol, timeframe, date, session) assert get.call_count == 4 # 1 + 3 default retries - assert isinstance(df, zipfile.BadZipFile)