Fix some issues with types

This commit is contained in:
Matthias
2024-02-07 19:28:06 +01:00
parent 2393a9fecf
commit 626c904103
2 changed files with 12 additions and 10 deletions

View File

@@ -1,8 +1,8 @@
from datetime import datetime
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Self, Sequence
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
from sqlalchemy.orm import Mapped, Query, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship
from freqtrade.constants import DATETIME_PRINT_FORMAT
from freqtrade.persistence.base import ModelBase, SessionType
@@ -45,16 +45,18 @@ class CustomData(ModelBase):
f'value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, ' +
f'updated={update_time})')
@staticmethod
def query_cd(key: Optional[str] = None, trade_id: Optional[int] = None) -> Query:
@classmethod
def query_cd(cls, key: Optional[str] = None,
trade_id: Optional[int] = None) -> Sequence['CustomData']:
"""
Get all CustomData, if trade_id is not specified
return will be for generic values not tied to a trade
:param trade_id: id of the Trade
"""
filters = []
filters.append(CustomData.ft_trade_id == trade_id if trade_id is not None else 0)
if trade_id is not None:
filters.append(CustomData.ft_trade_id == trade_id)
if key is not None:
filters.append(CustomData.cd_key.ilike(key))
return CustomData.session.scalars(select(CustomData))
return CustomData.session.scalars(select(CustomData)).all()

View File

@@ -37,11 +37,11 @@ class CustomDataWrapper:
trade_id = 0
if CustomDataWrapper.use_db:
filtered_custom_data = CustomData.query_cd(trade_id=trade_id, key=key).all()
for index, data_entry in enumerate(filtered_custom_data):
filtered_custom_data = []
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
data_entry.cd_value = json.loads(data_entry.cd_value)
filtered_custom_data[index] = data_entry
filtered_custom_data.append(data_entry)
return filtered_custom_data
else:
filtered_custom_data = [
@@ -110,6 +110,6 @@ class CustomDataWrapper:
def get_all_custom_data() -> List[CustomData]:
if CustomDataWrapper.use_db:
return CustomData.session.scalars(select(CustomData)).all()
return list(CustomData.query_cd())
else:
return CustomDataWrapper.custom_data