mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
loop 10 times in ut
This commit is contained in:
@@ -16,7 +16,7 @@ class TestCompression(unittest.TestCase):
|
||||
pass
|
||||
|
||||
def test_should_compress_uncompress(self):
|
||||
torch.manual_seed(0)
|
||||
#torch.manual_seed(0)
|
||||
a0 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
|
||||
a1 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
|
||||
|
||||
@@ -24,18 +24,20 @@ class TestCompression(unittest.TestCase):
|
||||
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
|
||||
for compression in [None, '4bit', '8bit']:
|
||||
b = compress_layer_state_dict(a_state_dict, compression)
|
||||
for iloop in range(10):
|
||||
for compression in [None, '4bit', '8bit']:
|
||||
b = compress_layer_state_dict(a_state_dict, compression)
|
||||
|
||||
print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }")
|
||||
if iloop < 2:
|
||||
print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }")
|
||||
|
||||
aa = uncompress_layer_state_dict(b)
|
||||
aa = uncompress_layer_state_dict(b)
|
||||
|
||||
for k in aa.keys():
|
||||
for k in aa.keys():
|
||||
|
||||
if compression is None:
|
||||
self.assertTrue(torch.equal(aa[k], a_state_dict[k]))
|
||||
else:
|
||||
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item()
|
||||
print(f"compression {compression} loss: {RMSE_loss}")
|
||||
self.assertLess(RMSE_loss, 2.5)
|
||||
if compression is None:
|
||||
self.assertTrue(torch.equal(aa[k], a_state_dict[k]))
|
||||
else:
|
||||
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item()
|
||||
print(f"compression {compression} loss: {RMSE_loss}")
|
||||
self.assertLess(RMSE_loss, 2.5)
|
||||
Reference in New Issue
Block a user