loop 10 times in ut

This commit is contained in:
Yu Li
2023-12-01 16:43:31 -06:00
parent 5cdf48b598
commit 6ad5ae82ff

View File

@@ -16,7 +16,7 @@ class TestCompression(unittest.TestCase):
pass
def test_should_compress_uncompress(self):
torch.manual_seed(0)
#torch.manual_seed(0)
a0 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
a1 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
@@ -24,18 +24,20 @@ class TestCompression(unittest.TestCase):
loss_fn = torch.nn.MSELoss()
for compression in [None, '4bit', '8bit']:
b = compress_layer_state_dict(a_state_dict, compression)
for iloop in range(10):
for compression in [None, '4bit', '8bit']:
b = compress_layer_state_dict(a_state_dict, compression)
print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }")
if iloop < 2:
print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }")
aa = uncompress_layer_state_dict(b)
aa = uncompress_layer_state_dict(b)
for k in aa.keys():
for k in aa.keys():
if compression is None:
self.assertTrue(torch.equal(aa[k], a_state_dict[k]))
else:
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item()
print(f"compression {compression} loss: {RMSE_loss}")
self.assertLess(RMSE_loss, 2.5)
if compression is None:
self.assertTrue(torch.equal(aa[k], a_state_dict[k]))
else:
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item()
print(f"compression {compression} loss: {RMSE_loss}")
self.assertLess(RMSE_loss, 2.5)