diff --git a/freqtrade/persistence/migrations.py b/freqtrade/persistence/migrations.py index 1720b4d29..ca3ecf8f2 100644 --- a/freqtrade/persistence/migrations.py +++ b/freqtrade/persistence/migrations.py @@ -30,24 +30,18 @@ def get_backup_name(tabs: list[str], backup_prefix: str): return table_back_name -def get_last_sequence_ids( - engine, trade_back_name: str, order_back_name: str -) -> tuple[int | None, int | None]: - order_id: int | None = None - trade_id: int | None = None +def get_last_sequence_ids(engine, sequence_name: str, table_back_name: str) -> int | None: + last_id: int | None = None if engine.name == "postgresql": with engine.begin() as connection: - trade_id = connection.execute(text("select nextval('trades_id_seq')")).fetchone()[0] - order_id = connection.execute(text("select nextval('orders_id_seq')")).fetchone()[0] + last_id = connection.execute(text(f"select nextval('{sequence_name}')")).fetchone()[0] with engine.begin() as connection: connection.execute( - text(f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak") + text(f"ALTER SEQUENCE {sequence_name} rename to {table_back_name}_id_seq_bak") ) - connection.execute( - text(f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak") - ) - return order_id, trade_id + + return last_id def set_sequence_ids( @@ -185,7 +179,8 @@ def migrate_trades_and_orders_table( drop_index_on_table(engine, inspector, trade_back_name) - order_id, trade_id = get_last_sequence_ids(engine, trade_back_name, order_back_name) + order_id = get_last_sequence_ids(engine, "order_id_seq", order_back_name) + trade_id = get_last_sequence_ids(engine, "trades_id_seq", trade_back_name) drop_orders_table(engine, order_back_name) diff --git a/tests/persistence/test_migrations.py b/tests/persistence/test_migrations.py index 1e99060c7..2118193fa 100644 --- a/tests/persistence/test_migrations.py +++ b/tests/persistence/test_migrations.py @@ -360,14 +360,14 @@ def test_migrate_get_last_sequence_ids(): engine = MagicMock() engine.begin = MagicMock() engine.name = "postgresql" - get_last_sequence_ids(engine, "trades_bak", "orders_bak") + get_last_sequence_ids(engine, "trades_id_seq", "trades_bak") assert engine.begin.call_count == 2 engine.reset_mock() engine.begin.reset_mock() engine.name = "somethingelse" - get_last_sequence_ids(engine, "trades_bak", "orders_bak") + get_last_sequence_ids(engine, "trades_id_seq", "trades_bak") assert engine.begin.call_count == 0