Initial commit
This commit is contained in:
41
metric_depth/util/dist_helper.py
Normal file
41
metric_depth/util/dist_helper.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def setup_distributed(backend="nccl", port=None):
|
||||
"""AdaHessian Optimizer
|
||||
Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
|
||||
Originally licensed MIT, Copyright (c) 2020 Wei Li
|
||||
"""
|
||||
num_gpus = torch.cuda.device_count()
|
||||
|
||||
if "SLURM_JOB_ID" in os.environ:
|
||||
rank = int(os.environ["SLURM_PROCID"])
|
||||
world_size = int(os.environ["SLURM_NTASKS"])
|
||||
node_list = os.environ["SLURM_NODELIST"]
|
||||
addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
|
||||
# specify master port
|
||||
if port is not None:
|
||||
os.environ["MASTER_PORT"] = str(port)
|
||||
elif "MASTER_PORT" not in os.environ:
|
||||
os.environ["MASTER_PORT"] = "10685"
|
||||
if "MASTER_ADDR" not in os.environ:
|
||||
os.environ["MASTER_ADDR"] = addr
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["LOCAL_RANK"] = str(rank % num_gpus)
|
||||
os.environ["RANK"] = str(rank)
|
||||
else:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
torch.cuda.set_device(rank % num_gpus)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
return rank, world_size
|
||||
16
metric_depth/util/loss.py
Normal file
16
metric_depth/util/loss.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SiLogLoss(nn.Module):
|
||||
def __init__(self, lambd=0.5):
|
||||
super().__init__()
|
||||
self.lambd = lambd
|
||||
|
||||
def forward(self, pred, target, valid_mask):
|
||||
valid_mask = valid_mask.detach()
|
||||
diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
|
||||
loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
|
||||
self.lambd * torch.pow(diff_log.mean(), 2))
|
||||
|
||||
return loss
|
||||
26
metric_depth/util/metric.py
Normal file
26
metric_depth/util/metric.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
|
||||
|
||||
def eval_depth(pred, target):
|
||||
assert pred.shape == target.shape
|
||||
|
||||
thresh = torch.max((target / pred), (pred / target))
|
||||
|
||||
d1 = torch.sum(thresh < 1.25).float() / len(thresh)
|
||||
d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
|
||||
d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
|
||||
|
||||
diff = pred - target
|
||||
diff_log = torch.log(pred) - torch.log(target)
|
||||
|
||||
abs_rel = torch.mean(torch.abs(diff) / target)
|
||||
sq_rel = torch.mean(torch.pow(diff, 2) / target)
|
||||
|
||||
rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
|
||||
rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
|
||||
|
||||
log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
|
||||
silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
|
||||
|
||||
return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 'sq_rel': sq_rel.item(),
|
||||
'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 'log10':log10.item(), 'silog':silog.item()}
|
||||
26
metric_depth/util/utils.py
Normal file
26
metric_depth/util/utils.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
logs = set()
|
||||
|
||||
|
||||
def init_log(name, level=logging.INFO):
|
||||
if (name, level) in logs:
|
||||
return
|
||||
logs.add((name, level))
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(level)
|
||||
if "SLURM_PROCID" in os.environ:
|
||||
rank = int(os.environ["SLURM_PROCID"])
|
||||
logger.addFilter(lambda record: rank == 0)
|
||||
else:
|
||||
rank = 0
|
||||
format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
|
||||
formatter = logging.Formatter(format_str)
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
return logger
|
||||
Reference in New Issue
Block a user