Initial commit

Lihe Yang
2024-06-14 03:44:54 +08:00
committed by GitHub
parent a0c63fccc6
commit 2cbc36a8ce
73 changed files with 91693 additions and 2 deletions

@@ -0,0 +1,41 @@
import os
import subprocess

import torch
import torch.distributed as dist


def setup_distributed(backend="nccl", port=None):
    """Initialize torch.distributed, supporting both SLURM and torchrun launchers.

    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        # launched via srun: derive rank/world size from SLURM variables
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        # launched via torchrun: RANK and WORLD_SIZE are already set
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size
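
Usage sketch (editor's note, not part of the commit): how a training script launched with torchrun, or srun under SLURM, would typically call this helper; the model, dimensions, and port below are illustrative assumptions.

# hypothetical launch: torchrun --nproc_per_node=4 train.py
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

rank, world_size = setup_distributed(port=20596)     # port value is arbitrary
local_rank = rank % torch.cuda.device_count()
model = torch.nn.Linear(16, 1).cuda()                # stand-in for the real network
model = DDP(model, device_ids=[local_rank])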

metric_depth/util/loss.py

@@ -0,0 +1,16 @@
import torch
from torch import nn


class SiLogLoss(nn.Module):
    """Scale-invariant logarithmic (SiLog) loss for monocular depth estimation."""

    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target, valid_mask):
        valid_mask = valid_mask.detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
                          self.lambd * torch.pow(diff_log.mean(), 2))

        return loss
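
Usage sketch (editor's note, not part of the commit): a minimal forward/backward pass through SiLogLoss, with random tensors standing in for network output and ground truth; shapes and the mask criterion are illustrative.

criterion = SiLogLoss(lambd=0.5)

pred = torch.rand(2, 480, 640) + 0.1      # stand-in prediction, kept strictly positive for log()
pred.requires_grad_()
target = torch.rand(2, 480, 640)          # stand-in ground truth; zeros mark invalid pixels
valid_mask = target > 0                   # supervise only pixels with valid depth

loss = criterion(pred, target, valid_mask)
loss.backward()                           # gradients flow only through the masked pixels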


@@ -0,0 +1,26 @@
import torch


def eval_depth(pred, target):
    assert pred.shape == target.shape

    # threshold accuracy: share of pixels with max(pred/target, target/pred) below 1.25**k
    thresh = torch.max(target / pred, pred / target)

    d1 = torch.sum(thresh < 1.25).float() / len(thresh)
    d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
    d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)

    diff = pred - target
    diff_log = torch.log(pred) - torch.log(target)

    # standard depth-estimation error metrics
    abs_rel = torch.mean(torch.abs(diff) / target)
    sq_rel = torch.mean(torch.pow(diff, 2) / target)

    rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
    rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))

    log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
    silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))

    return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 'sq_rel': sq_rel.item(),
            'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 'log10': log10.item(), 'silog': silog.item()}
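
Usage sketch (editor's note, not part of the commit): eval_depth takes ratios and logs of its inputs, so the caller is assumed to pass already-masked 1-D tensors of strictly positive depths (len(thresh) counts elements); the tensors and mask below are illustrative.

pred = torch.rand(480, 640) + 0.1     # illustrative predicted depth, strictly positive
target = torch.rand(480, 640) + 0.1   # illustrative ground-truth depth
valid_mask = target > 0.2             # hypothetical validity criterion

metrics = eval_depth(pred[valid_mask], target[valid_mask])   # boolean indexing flattens to 1-D
print({k: round(v, 3) for k, v in metrics.items()})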


@@ -0,0 +1,26 @@
import os
import re
import numpy as np
import logging

logs = set()


def init_log(name, level=logging.INFO):
    if (name, level) in logs:
        return
    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    # under SLURM, filter so that only the rank-0 process emits records
    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)
    else:
        rank = 0
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger
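
Usage sketch (editor's note, not part of the commit): the first call for a given (name, level) pair returns a configured logger; later calls return None, so each process is assumed to initialize once. The logger name is an arbitrary example.

logger = init_log("global", logging.INFO)   # returns None if ("global", INFO) was already registered
logger.propagate = False                    # avoid duplicate lines via the root logger
logger.info("world size: %d", 4)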