From e1c1aa6317c4cd79b52abdb0f3cc158568e45748 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 16:30:18 -0600 Subject: [PATCH] fix ut --- air_llm/tests/test_compression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/air_llm/tests/test_compression.py b/air_llm/tests/test_compression.py index bb8479a..bf7e3d9 100644 --- a/air_llm/tests/test_compression.py +++ b/air_llm/tests/test_compression.py @@ -34,7 +34,7 @@ class TestCompression(unittest.TestCase): for k in aa.keys(): if compression is None: - self.assertAlmostEqual(aa[k], a[k]) + self.assertAlmostEqual(aa[k], a_state_dict[k]) else: - RMSE_loss = torch.sqrt(loss_fn(aa[k], a[k])) + RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])) self.assertLess(RMSE_loss.detach().numpy()[0], 0.5) \ No newline at end of file