refactor: move more stuff into generator

This commit is contained in:
Matthias
2025-01-14 19:38:05 +01:00
parent b525ba85c8
commit 8029729ab1

View File

@@ -1433,6 +1433,12 @@ class Backtesting:
detail_data.loc[:, "exit_tag"] = row[EXIT_TAG_IDX]
return detail_data
def time_generator(self, start_date: datetime, end_date: datetime):
current_time = start_date + self.timeframe_td
while current_time <= end_date:
yield current_time
current_time += self.timeframe_td
def time_pair_generator(
self, start_date: datetime, end_date: datetime, increment: timedelta, pairs: list[str]
):
@@ -1445,14 +1451,23 @@ class Backtesting:
self.progress.init_step(
BacktestState.BACKTEST, int((end_date - start_date) / self.timeframe_td)
)
while current_time <= end_date:
is_first = True
for current_time in self.time_generator(start_date, end_date):
# Loop for each time point.
self.check_abort()
# Reset open trade count for this candle
# Critical to avoid exceeding max_open_trades in backtesting
# when timeframe-detail is used and trades close within the opening candle.
LocalTrade.bt_open_open_trade_count_candle = LocalTrade.bt_open_open_trade_count
strategy_safe_wrapper(self.strategy.bot_loop_start, supress_error=True)(
current_time=current_time
)
# Pairs that have open trades should be processed first
new_pairlist = list(dict.fromkeys([t.pair for t in LocalTrade.bt_trades_open] + pairs))
for pair in new_pairlist:
yield current_time, pair, is_first
is_first = False
yield current_time, pair
self.progress.increment()
current_time += increment
@@ -1482,18 +1497,9 @@ class Backtesting:
indexes: dict = defaultdict(int)
# Loop timerange and get candle for each pair at that point in time
for current_time, pair, is_first_call in self.time_pair_generator(
for current_time, pair in self.time_pair_generator(
start_date, end_date, self.timeframe_td, list(data.keys())
):
if is_first_call:
self.check_abort()
# Reset open trade count for this candle
# Critical to avoid exceeding max_open_trades in backtesting
# when timeframe-detail is used and trades close within the opening candle.
LocalTrade.bt_open_open_trade_count_candle = LocalTrade.bt_open_open_trade_count
strategy_safe_wrapper(self.strategy.bot_loop_start, supress_error=True)(
current_time=current_time
)
row_index = indexes[pair]
row = self.validate_row(data, pair, row_index, current_time)
if not row: