diff --git a/air_llm/tests/test_compression.py b/air_llm/tests/test_compression.py index 4957de8..52f77db 100644 --- a/air_llm/tests/test_compression.py +++ b/air_llm/tests/test_compression.py @@ -16,7 +16,7 @@ class TestCompression(unittest.TestCase): pass def test_should_compress_uncompress(self): - torch.manual_seed(0) + #torch.manual_seed(0) a0 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda() a1 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda() @@ -24,18 +24,20 @@ class TestCompression(unittest.TestCase): loss_fn = torch.nn.MSELoss() - for compression in [None, '4bit', '8bit']: - b = compress_layer_state_dict(a_state_dict, compression) + for iloop in range(10): + for compression in [None, '4bit', '8bit']: + b = compress_layer_state_dict(a_state_dict, compression) - print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }") + if iloop < 2: + print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }") - aa = uncompress_layer_state_dict(b) + aa = uncompress_layer_state_dict(b) - for k in aa.keys(): + for k in aa.keys(): - if compression is None: - self.assertTrue(torch.equal(aa[k], a_state_dict[k])) - else: - RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item() - print(f"compression {compression} loss: {RMSE_loss}") - self.assertLess(RMSE_loss, 2.5) \ No newline at end of file + if compression is None: + self.assertTrue(torch.equal(aa[k], a_state_dict[k])) + else: + RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item() + print(f"compression {compression} loss: {RMSE_loss}") + self.assertLess(RMSE_loss, 2.5) \ No newline at end of file