refactor: move index-handling into generator

This commit is contained in:
Matthias
2025-01-14 19:57:39 +01:00
parent 14acc6609f
commit 96cea99d4f

View File

@@ -1440,7 +1440,12 @@ class Backtesting:
current_time += self.timeframe_td
def time_pair_generator(
self, start_date: datetime, end_date: datetime, increment: timedelta, pairs: list[str]
self,
start_date: datetime,
end_date: datetime,
increment: timedelta,
pairs: list[str],
data: dict[str, list[tuple]],
):
"""
Backtest time and pair generator
@@ -1451,9 +1456,11 @@ class Backtesting:
self.progress.init_step(
BacktestState.BACKTEST, int((end_date - start_date) / self.timeframe_td)
)
for current_time in self.time_generator(start_date, end_date):
# Loop for each time point.
# Indexes per pair, so some pairs are allowed to have a missing start.
indexes: dict = defaultdict(int)
for current_time in self.time_generator(start_date, end_date):
# Loop for each main candle.
self.check_abort()
# Reset open trade count for this candle
# Critical to avoid exceeding max_open_trades in backtesting
@@ -1467,7 +1474,18 @@ class Backtesting:
new_pairlist = list(dict.fromkeys([t.pair for t in LocalTrade.bt_trades_open] + pairs))
for pair in new_pairlist:
yield current_time, pair
row_index = indexes[pair]
row = self.validate_row(data, pair, row_index, current_time)
if not row:
continue
row_index += 1
indexes[pair] = row_index
is_last_row = current_time == end_date
self.dataprovider._set_dataframe_max_index(self.required_startup + row_index)
self.dataprovider._set_dataframe_max_date(current_time)
yield current_time, pair, row, is_last_row
self.progress.increment()
@@ -1492,23 +1510,10 @@ class Backtesting:
# (looping lists is a lot faster than pandas DataFrames)
data: dict = self._get_ohlcv_as_lists(processed)
# Indexes per pair, so some pairs are allowed to have a missing start.
indexes: dict = defaultdict(int)
# Loop timerange and get candle for each pair at that point in time
for current_time, pair in self.time_pair_generator(
start_date, end_date, self.timeframe_td, list(data.keys())
for current_time, pair, row, is_last_row in self.time_pair_generator(
start_date, end_date, self.timeframe_td, list(data.keys()), data
):
row_index = indexes[pair]
row = self.validate_row(data, pair, row_index, current_time)
if not row:
continue
row_index += 1
indexes[pair] = row_index
is_last_row = current_time == end_date
self.dataprovider._set_dataframe_max_index(self.required_startup + row_index)
self.dataprovider._set_dataframe_max_date(current_time)
trade_dir: LongShort | None = self.check_for_trade_entry(row)
pair_has_open_trades = len(LocalTrade.bt_trades_open_pp[pair]) > 0