mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
fix ut
This commit is contained in:
@@ -36,5 +36,6 @@ class TestCompression(unittest.TestCase):
|
||||
if compression is None:
|
||||
self.assertTrue(torch.equal(aa[k], a_state_dict[k]))
|
||||
else:
|
||||
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k]))
|
||||
self.assertLess(RMSE_loss.detach().cpu().item(), 2.5)
|
||||
RMSE_loss = torch.sqrt(loss_fn(aa[k], a_state_dict[k])).detach().cpu().item()
|
||||
print(f"compression {compression} loss: {RMSE_loss}")
|
||||
self.assertLess(RMSE_loss, 2.5)
|
||||
Reference in New Issue
Block a user