chore: refactor get_last_sequence_ids

This commit is contained in:
Matthias
2026-02-15 18:24:02 +01:00
parent da62e614d7
commit 8ed1e6c394
2 changed files with 10 additions and 15 deletions

View File

@@ -30,24 +30,18 @@ def get_backup_name(tabs: list[str], backup_prefix: str):
return table_back_name
def get_last_sequence_ids(
engine, trade_back_name: str, order_back_name: str
) -> tuple[int | None, int | None]:
order_id: int | None = None
trade_id: int | None = None
def get_last_sequence_ids(engine, sequence_name: str, table_back_name: str) -> int | None:
last_id: int | None = None
if engine.name == "postgresql":
with engine.begin() as connection:
trade_id = connection.execute(text("select nextval('trades_id_seq')")).fetchone()[0]
order_id = connection.execute(text("select nextval('orders_id_seq')")).fetchone()[0]
last_id = connection.execute(text(f"select nextval('{sequence_name}')")).fetchone()[0]
with engine.begin() as connection:
connection.execute(
text(f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak")
text(f"ALTER SEQUENCE {sequence_name} rename to {table_back_name}_id_seq_bak")
)
connection.execute(
text(f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak")
)
return order_id, trade_id
return last_id
def set_sequence_ids(
@@ -185,7 +179,8 @@ def migrate_trades_and_orders_table(
drop_index_on_table(engine, inspector, trade_back_name)
order_id, trade_id = get_last_sequence_ids(engine, trade_back_name, order_back_name)
order_id = get_last_sequence_ids(engine, "order_id_seq", order_back_name)
trade_id = get_last_sequence_ids(engine, "trades_id_seq", trade_back_name)
drop_orders_table(engine, order_back_name)

View File

@@ -360,14 +360,14 @@ def test_migrate_get_last_sequence_ids():
engine = MagicMock()
engine.begin = MagicMock()
engine.name = "postgresql"
get_last_sequence_ids(engine, "trades_bak", "orders_bak")
get_last_sequence_ids(engine, "trades_id_seq", "trades_bak")
assert engine.begin.call_count == 2
engine.reset_mock()
engine.begin.reset_mock()
engine.name = "somethingelse"
get_last_sequence_ids(engine, "trades_bak", "orders_bak")
get_last_sequence_ids(engine, "trades_id_seq", "trades_bak")
assert engine.begin.call_count == 0