mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
Combine custom_data classes to one file
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# flake8: noqa: F401
|
# flake8: noqa: F401
|
||||||
|
|
||||||
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
|
from freqtrade.persistence.custom_data import CustomDataWrapper
|
||||||
from freqtrade.persistence.key_value_store import KeyStoreKeys, KeyValueStore
|
from freqtrade.persistence.key_value_store import KeyStoreKeys, KeyValueStore
|
||||||
from freqtrade.persistence.models import init_db
|
from freqtrade.persistence.models import init_db
|
||||||
from freqtrade.persistence.pairlock_middleware import PairLocks
|
from freqtrade.persistence.pairlock_middleware import PairLocks
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import ClassVar, Optional, Sequence
|
from typing import Any, ClassVar, List, Optional, Sequence
|
||||||
|
|
||||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
|
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
@@ -9,6 +11,9 @@ from freqtrade.persistence.base import ModelBase, SessionType
|
|||||||
from freqtrade.util import dt_now
|
from freqtrade.util import dt_now
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CustomData(ModelBase):
|
class CustomData(ModelBase):
|
||||||
"""
|
"""
|
||||||
CustomData database model
|
CustomData database model
|
||||||
@@ -60,3 +65,107 @@ class CustomData(ModelBase):
|
|||||||
filters.append(CustomData.cd_key.ilike(key))
|
filters.append(CustomData.cd_key.ilike(key))
|
||||||
|
|
||||||
return CustomData.session.scalars(select(CustomData).filter(*filters)).all()
|
return CustomData.session.scalars(select(CustomData).filter(*filters)).all()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDataWrapper:
|
||||||
|
"""
|
||||||
|
CustomData middleware class
|
||||||
|
Abstracts the database layer away so it becomes optional - which will be necessary to support
|
||||||
|
backtesting and hyperopt in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
use_db = True
|
||||||
|
custom_data: List[CustomData] = []
|
||||||
|
unserialized_types = ['bool', 'float', 'int', 'str']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset_custom_data() -> None:
|
||||||
|
"""
|
||||||
|
Resets all key-value pairs. Only active for backtesting mode.
|
||||||
|
"""
|
||||||
|
if not CustomDataWrapper.use_db:
|
||||||
|
CustomDataWrapper.custom_data = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_custom_data(key: Optional[str] = None,
|
||||||
|
trade_id: Optional[int] = None) -> CustomData:
|
||||||
|
if trade_id is None:
|
||||||
|
trade_id = 0
|
||||||
|
|
||||||
|
if CustomDataWrapper.use_db:
|
||||||
|
filtered_custom_data = []
|
||||||
|
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
|
||||||
|
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
|
||||||
|
data_entry.cd_value = json.loads(data_entry.cd_value)
|
||||||
|
filtered_custom_data.append(data_entry)
|
||||||
|
return filtered_custom_data
|
||||||
|
else:
|
||||||
|
filtered_custom_data = [
|
||||||
|
data_entry for data_entry in CustomDataWrapper.custom_data
|
||||||
|
if (data_entry.ft_trade_id == trade_id)
|
||||||
|
]
|
||||||
|
if key is not None:
|
||||||
|
filtered_custom_data = [
|
||||||
|
data_entry for data_entry in filtered_custom_data
|
||||||
|
if (data_entry.cd_key.casefold() == key.casefold())
|
||||||
|
]
|
||||||
|
return filtered_custom_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
||||||
|
|
||||||
|
value_type = type(value).__name__
|
||||||
|
value_db = None
|
||||||
|
|
||||||
|
if value_type not in CustomDataWrapper.unserialized_types:
|
||||||
|
try:
|
||||||
|
value_db = json.dumps(value)
|
||||||
|
except TypeError as e:
|
||||||
|
logger.warning(f"could not serialize {key} value due to {e}")
|
||||||
|
else:
|
||||||
|
value_db = str(value)
|
||||||
|
|
||||||
|
if trade_id is None:
|
||||||
|
trade_id = 0
|
||||||
|
|
||||||
|
custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id)
|
||||||
|
if custom_data:
|
||||||
|
data_entry = custom_data[0]
|
||||||
|
data_entry.cd_value = value_db
|
||||||
|
data_entry.updated_at = dt_now()
|
||||||
|
else:
|
||||||
|
data_entry = CustomData(
|
||||||
|
ft_trade_id=trade_id,
|
||||||
|
cd_key=key,
|
||||||
|
cd_type=value_type,
|
||||||
|
cd_value=value_db,
|
||||||
|
created_at=dt_now()
|
||||||
|
)
|
||||||
|
|
||||||
|
if CustomDataWrapper.use_db and value_db is not None:
|
||||||
|
data_entry.cd_value = value_db
|
||||||
|
CustomData.session.add(data_entry)
|
||||||
|
CustomData.session.commit()
|
||||||
|
elif not CustomDataWrapper.use_db:
|
||||||
|
cd_index = -1
|
||||||
|
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
|
||||||
|
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
|
||||||
|
cd_index = index
|
||||||
|
break
|
||||||
|
|
||||||
|
if cd_index >= 0:
|
||||||
|
data_entry.cd_type = value_type
|
||||||
|
data_entry.cd_value = value_db
|
||||||
|
data_entry.updated_at = dt_now()
|
||||||
|
|
||||||
|
CustomDataWrapper.custom_data[cd_index] = data_entry
|
||||||
|
else:
|
||||||
|
CustomDataWrapper.custom_data.append(data_entry)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_custom_data() -> List[CustomData]:
|
||||||
|
|
||||||
|
if CustomDataWrapper.use_db:
|
||||||
|
return list(CustomData.query_cd())
|
||||||
|
else:
|
||||||
|
return CustomDataWrapper.custom_data
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from freqtrade.persistence.custom_data import CustomData
|
|
||||||
from freqtrade.util import dt_now
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomDataWrapper:
|
|
||||||
"""
|
|
||||||
CustomData middleware class
|
|
||||||
Abstracts the database layer away so it becomes optional - which will be necessary to support
|
|
||||||
backtesting and hyperopt in the future.
|
|
||||||
"""
|
|
||||||
|
|
||||||
use_db = True
|
|
||||||
custom_data: List[CustomData] = []
|
|
||||||
unserialized_types = ['bool', 'float', 'int', 'str']
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reset_custom_data() -> None:
|
|
||||||
"""
|
|
||||||
Resets all key-value pairs. Only active for backtesting mode.
|
|
||||||
"""
|
|
||||||
if not CustomDataWrapper.use_db:
|
|
||||||
CustomDataWrapper.custom_data = []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_custom_data(key: Optional[str] = None,
|
|
||||||
trade_id: Optional[int] = None) -> CustomData:
|
|
||||||
if trade_id is None:
|
|
||||||
trade_id = 0
|
|
||||||
|
|
||||||
if CustomDataWrapper.use_db:
|
|
||||||
filtered_custom_data = []
|
|
||||||
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
|
|
||||||
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
|
|
||||||
data_entry.cd_value = json.loads(data_entry.cd_value)
|
|
||||||
filtered_custom_data.append(data_entry)
|
|
||||||
return filtered_custom_data
|
|
||||||
else:
|
|
||||||
filtered_custom_data = [
|
|
||||||
data_entry for data_entry in CustomDataWrapper.custom_data
|
|
||||||
if (data_entry.ft_trade_id == trade_id)
|
|
||||||
]
|
|
||||||
if key is not None:
|
|
||||||
filtered_custom_data = [
|
|
||||||
data_entry for data_entry in filtered_custom_data
|
|
||||||
if (data_entry.cd_key.casefold() == key.casefold())
|
|
||||||
]
|
|
||||||
return filtered_custom_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
|
||||||
|
|
||||||
value_type = type(value).__name__
|
|
||||||
value_db = None
|
|
||||||
|
|
||||||
if value_type not in CustomDataWrapper.unserialized_types:
|
|
||||||
try:
|
|
||||||
value_db = json.dumps(value)
|
|
||||||
except TypeError as e:
|
|
||||||
logger.warning(f"could not serialize {key} value due to {e}")
|
|
||||||
else:
|
|
||||||
value_db = str(value)
|
|
||||||
|
|
||||||
if trade_id is None:
|
|
||||||
trade_id = 0
|
|
||||||
|
|
||||||
custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id)
|
|
||||||
if custom_data:
|
|
||||||
data_entry = custom_data[0]
|
|
||||||
data_entry.cd_value = value_db
|
|
||||||
data_entry.updated_at = dt_now()
|
|
||||||
else:
|
|
||||||
data_entry = CustomData(
|
|
||||||
ft_trade_id=trade_id,
|
|
||||||
cd_key=key,
|
|
||||||
cd_type=value_type,
|
|
||||||
cd_value=value_db,
|
|
||||||
created_at=dt_now()
|
|
||||||
)
|
|
||||||
|
|
||||||
if CustomDataWrapper.use_db and value_db is not None:
|
|
||||||
data_entry.cd_value = value_db
|
|
||||||
CustomData.session.add(data_entry)
|
|
||||||
CustomData.session.commit()
|
|
||||||
elif not CustomDataWrapper.use_db:
|
|
||||||
cd_index = -1
|
|
||||||
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
|
|
||||||
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
|
|
||||||
cd_index = index
|
|
||||||
break
|
|
||||||
|
|
||||||
if cd_index >= 0:
|
|
||||||
data_entry.cd_type = value_type
|
|
||||||
data_entry.cd_value = value_db
|
|
||||||
data_entry.updated_at = dt_now()
|
|
||||||
|
|
||||||
CustomDataWrapper.custom_data[cd_index] = data_entry
|
|
||||||
else:
|
|
||||||
CustomDataWrapper.custom_data.append(data_entry)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_all_custom_data() -> List[CustomData]:
|
|
||||||
|
|
||||||
if CustomDataWrapper.use_db:
|
|
||||||
return list(CustomData.query_cd())
|
|
||||||
else:
|
|
||||||
return CustomDataWrapper.custom_data
|
|
||||||
@@ -23,8 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi
|
|||||||
from freqtrade.leverage import interest
|
from freqtrade.leverage import interest
|
||||||
from freqtrade.misc import safe_value_fallback
|
from freqtrade.misc import safe_value_fallback
|
||||||
from freqtrade.persistence.base import ModelBase, SessionType
|
from freqtrade.persistence.base import ModelBase, SessionType
|
||||||
from freqtrade.persistence.custom_data import CustomData
|
from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper
|
||||||
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
|
|
||||||
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
|
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
|
||||||
|
|
||||||
|
|
||||||
@@ -345,7 +344,7 @@ class LocalTrade:
|
|||||||
id: int = 0
|
id: int = 0
|
||||||
|
|
||||||
orders: List[Order] = []
|
orders: List[Order] = []
|
||||||
custom_data: List[CustomData] = []
|
custom_data: List[_CustomData] = []
|
||||||
|
|
||||||
exchange: str = ''
|
exchange: str = ''
|
||||||
pair: str = ''
|
pair: str = ''
|
||||||
@@ -1209,7 +1208,7 @@ class LocalTrade:
|
|||||||
def set_custom_data(self, key: str, value: Any) -> None:
|
def set_custom_data(self, key: str, value: Any) -> None:
|
||||||
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
|
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
|
||||||
|
|
||||||
def get_custom_data(self, key: Optional[str]) -> List[CustomData]:
|
def get_custom_data(self, key: Optional[str]) -> List[_CustomData]:
|
||||||
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
|
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1467,7 +1466,7 @@ class Trade(ModelBase, LocalTrade):
|
|||||||
orders: Mapped[List[Order]] = relationship(
|
orders: Mapped[List[Order]] = relationship(
|
||||||
"Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin",
|
"Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin",
|
||||||
innerjoin=True) # type: ignore
|
innerjoin=True) # type: ignore
|
||||||
custom_data: Mapped[List[CustomData]] = relationship(
|
custom_data: Mapped[List[_CustomData]] = relationship(
|
||||||
"CustomData", order_by="CustomData.id", cascade="all, delete-orphan",
|
"CustomData", order_by="CustomData.id", cascade="all, delete-orphan",
|
||||||
lazy="raise") # type: ignore
|
lazy="raise") # type: ignore
|
||||||
|
|
||||||
@@ -1574,9 +1573,9 @@ class Trade(ModelBase, LocalTrade):
|
|||||||
Order.session.delete(order)
|
Order.session.delete(order)
|
||||||
|
|
||||||
for entry in self.custom_data:
|
for entry in self.custom_data:
|
||||||
CustomData.session.delete(entry)
|
_CustomData.session.delete(entry)
|
||||||
|
|
||||||
CustomData.session.commit()
|
_CustomData.session.commit()
|
||||||
Trade.session.delete(self)
|
Trade.session.delete(self)
|
||||||
Trade.commit()
|
Trade.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
|
from freqtrade.persistence.custom_data import CustomDataWrapper
|
||||||
from freqtrade.persistence.pairlock_middleware import PairLocks
|
from freqtrade.persistence.pairlock_middleware import PairLocks
|
||||||
from freqtrade.persistence.trade_model import Trade
|
from freqtrade.persistence.trade_model import Trade
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user