From 59cc6077614dd5f45c7983077edcdbc44e25008e Mon Sep 17 00:00:00 2001 From: Matthias Date: Sun, 14 Jan 2024 15:18:10 +0100 Subject: [PATCH] Don't force-patch torch if it ain't installed. --- tests/freqai/conftest.py | 7 ++++++- tests/freqai/test_freqai_interface.py | 7 +------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index 57ba3f64b..81d72d92a 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -1,4 +1,5 @@ import platform +import sys from copy import deepcopy from pathlib import Path from typing import Any, Dict @@ -15,6 +16,10 @@ from freqtrade.resolvers.freqaimodel_resolver import FreqaiModelResolver from tests.conftest import get_patched_exchange +def is_py12() -> bool: + return sys.version_info >= (3, 12) + + def is_mac() -> bool: machine = platform.system() return "Darwin" in machine @@ -31,7 +36,7 @@ def patch_torch_initlogs(mocker) -> None: module_name = 'torch' mocked_module = types.ModuleType(module_name) sys.modules[module_name] = mocked_module - else: + elif not is_py12(): mocker.patch("torch._logging._init_logs") diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index cc5a9b326..c020b098b 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -1,7 +1,6 @@ import logging import platform import shutil -import sys from pathlib import Path from unittest.mock import MagicMock @@ -16,14 +15,10 @@ from freqtrade.optimize.backtesting import Backtesting from freqtrade.persistence import Trade from freqtrade.plugins.pairlistmanager import PairListManager from tests.conftest import EXMS, create_mock_trades, get_patched_exchange, log_has_re -from tests.freqai.conftest import (get_patched_freqai_strategy, is_mac, make_rl_config, +from tests.freqai.conftest import (get_patched_freqai_strategy, is_mac, is_py12, make_rl_config, mock_pytorch_mlp_model_training_parameters) -def is_py12() -> bool: - return sys.version_info >= (3, 12) - - def is_arm() -> bool: machine = platform.machine() return "arm" in machine or "aarch64" in machine