diff --git a/air_llm/tests/test_compression.py b/air_llm/tests/test_compression.py index 42788f2..4957de8 100644 --- a/air_llm/tests/test_compression.py +++ b/air_llm/tests/test_compression.py @@ -36,5 +36,6 @@ class TestCompression(unittest.TestCase): if compression is None: self.assertTrue(torch.equal(aa[k], a_state_dict[k])) else: - RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])) - self.assertLess(RMSE_loss.detach().cpu().item(), 2.5) \ No newline at end of file + RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item() + print(f"compression {compression} loss: {RMSE_loss}") + self.assertLess(RMSE_loss, 2.5) \ No newline at end of file