This commit is contained in:
Yu Li
2023-12-01 16:30:18 -06:00
parent c2d13c1063
commit e1c1aa6317

View File

@@ -34,7 +34,7 @@ class TestCompression(unittest.TestCase):
for k in aa.keys():
if compression is None:
self.assertAlmostEqual(aa[k], a[k])
self.assertAlmostEqual(aa[k], a_state_dict[k])
else:
RMSE_loss = torch.sqrt(loss_fn(aa[k], a[k]))
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k]))
self.assertLess(RMSE_loss.detach().numpy()[0], 0.5)