Initial commit
This commit is contained in:
26
metric_depth/util/metric.py
Normal file
26
metric_depth/util/metric.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
|
||||
|
||||
def eval_depth(pred, target):
|
||||
assert pred.shape == target.shape
|
||||
|
||||
thresh = torch.max((target / pred), (pred / target))
|
||||
|
||||
d1 = torch.sum(thresh < 1.25).float() / len(thresh)
|
||||
d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
|
||||
d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
|
||||
|
||||
diff = pred - target
|
||||
diff_log = torch.log(pred) - torch.log(target)
|
||||
|
||||
abs_rel = torch.mean(torch.abs(diff) / target)
|
||||
sq_rel = torch.mean(torch.pow(diff, 2) / target)
|
||||
|
||||
rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
|
||||
rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
|
||||
|
||||
log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
|
||||
silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
|
||||
|
||||
return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 'sq_rel': sq_rel.item(),
|
||||
'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 'log10':log10.item(), 'silog':silog.item()}
|
||||
Reference in New Issue
Block a user