diff --git a/freqtrade/persistence/custom_data.py b/freqtrade/persistence/custom_data.py index beae8c478..42b267e95 100644 --- a/freqtrade/persistence/custom_data.py +++ b/freqtrade/persistence/custom_data.py @@ -1,8 +1,8 @@ from datetime import datetime -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Self, Sequence from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select -from sqlalchemy.orm import Mapped, Query, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from freqtrade.constants import DATETIME_PRINT_FORMAT from freqtrade.persistence.base import ModelBase, SessionType @@ -45,16 +45,18 @@ class CustomData(ModelBase): f'value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, ' + f'updated={update_time})') - @staticmethod - def query_cd(key: Optional[str] = None, trade_id: Optional[int] = None) -> Query: + @classmethod + def query_cd(cls, key: Optional[str] = None, + trade_id: Optional[int] = None) -> Sequence['CustomData']: """ Get all CustomData, if trade_id is not specified return will be for generic values not tied to a trade :param trade_id: id of the Trade """ filters = [] - filters.append(CustomData.ft_trade_id == trade_id if trade_id is not None else 0) + if trade_id is not None: + filters.append(CustomData.ft_trade_id == trade_id) if key is not None: filters.append(CustomData.cd_key.ilike(key)) - return CustomData.session.scalars(select(CustomData)) + return CustomData.session.scalars(select(CustomData)).all() diff --git a/freqtrade/persistence/custom_data_middleware.py b/freqtrade/persistence/custom_data_middleware.py index cf7b83abc..2f99d9c75 100644 --- a/freqtrade/persistence/custom_data_middleware.py +++ b/freqtrade/persistence/custom_data_middleware.py @@ -37,11 +37,11 @@ class CustomDataWrapper: trade_id = 0 if CustomDataWrapper.use_db: - filtered_custom_data = CustomData.query_cd(trade_id=trade_id, key=key).all() - for index, data_entry in enumerate(filtered_custom_data): + filtered_custom_data = [] + for data_entry in CustomData.query_cd(trade_id=trade_id, key=key): if data_entry.cd_type not in CustomDataWrapper.unserialized_types: data_entry.cd_value = json.loads(data_entry.cd_value) - filtered_custom_data[index] = data_entry + filtered_custom_data.append(data_entry) return filtered_custom_data else: filtered_custom_data = [ @@ -110,6 +110,6 @@ class CustomDataWrapper: def get_all_custom_data() -> List[CustomData]: if CustomDataWrapper.use_db: - return CustomData.session.scalars(select(CustomData)).all() + return list(CustomData.query_cd()) else: return CustomDataWrapper.custom_data