-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLoss.py
More file actions
81 lines (60 loc) · 2.55 KB
/
Loss.py
File metadata and controls
81 lines (60 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import torch.nn.functional as F
class LossIn(nn.Module):
    """Intensity loss for image fusion.

    Encourages the fused image to match, pixel-wise, the brighter of the
    visible (Y channel) and infrared inputs, measured with an L1 distance.
    """

    def __init__(self):
        super().__init__()

    def forward(self, image_y, image_ir, generate_img):
        # Target intensity is the element-wise maximum of the two sources.
        target_intensity = torch.max(image_y, image_ir)
        return F.l1_loss(target_intensity, generate_img)
class LossGrad(nn.Module):
    """Gradient loss for image fusion.

    Encourages the fused image's Sobel gradient magnitude to match the
    element-wise maximum of the two source images' gradient magnitudes,
    measured with an L1 distance.
    """

    def __init__(self):
        super().__init__()
        self.sobelconv = Sobelxy()

    def forward(self, image_y, image_ir, generate_img):
        grad_y = self.sobelconv(image_y)
        grad_ir = self.sobelconv(image_ir)
        grad_fused = self.sobelconv(generate_img)
        # The target gradient keeps the strongest edge from either source.
        joint_grad = torch.max(grad_y, grad_ir)
        return F.l1_loss(joint_grad, grad_fused)
class Sobelxy(nn.Module):
    """Sobel gradient-magnitude operator for single-channel (N, 1, H, W) input.

    Applies fixed horizontal and vertical Sobel kernels via ``conv2d`` with
    padding=1 (output keeps the spatial size) and returns |Gx| + |Gy|.
    """

    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        # BUG FIX: the original used nn.Parameter(...).cuda(), which (a) returns
        # a plain Tensor that is never registered with the module, so the
        # kernels would NOT follow .to(device), and (b) hard-crashes on
        # CPU-only machines. Registering non-persistent buffers keeps the
        # kernels frozen, moves them with the module, and leaves the
        # state_dict unchanged (the originals were not in it either).
        self.register_buffer("weightx", kernelx, persistent=False)
        self.register_buffer("weighty", kernely, persistent=False)

    def forward(self, x):
        sobelx = F.conv2d(x, self.weightx, padding=1)
        sobely = F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx) + torch.abs(sobely)
def cc(img1, img2):
    """Mean per-channel Pearson correlation coefficient for (N, C, H, W) images.

    Each channel is flattened and mean-centered; the normalized cross-
    correlation is computed per (sample, channel), clamped to [-1, 1], and
    averaged over the batch and channels.

    Note: the docstring was previously placed after the first statement, so
    it was not attached to the function; it is now the first statement.
    """
    eps = torch.finfo(torch.float32).eps
    N, C, _, _ = img1.shape
    img1 = img1.reshape(N, C, -1)
    img2 = img2.reshape(N, C, -1)
    # Center each flattened channel so the dot product measures covariance.
    img1 = img1 - img1.mean(dim=-1, keepdim=True)
    img2 = img2 - img2.mean(dim=-1, keepdim=True)
    # eps guards against division by zero for constant channels.
    cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
                                           2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
    cc = torch.clamp(cc, -1., 1.)
    return cc.mean()
def total_variation_loss(image):
    """Anisotropic (L1) total-variation loss of an (N, C, H, W) image batch.

    Sums the absolute differences between horizontally and vertically
    adjacent pixels over the whole batch; a smoothness regularizer that
    penalizes high-frequency noise. Returns a scalar tensor (0 for any
    constant image).
    """
    # Differences between horizontally adjacent pixels.
    horizontal_diff = image[:, :, :, 1:] - image[:, :, :, :-1]
    # Differences between vertically adjacent pixels.
    vertical_diff = image[:, :, 1:, :] - image[:, :, :-1, :]
    # Removed the unused (batch_size, channels, height, width) unpacking.
    return torch.sum(torch.abs(horizontal_diff)) + torch.sum(torch.abs(vertical_diff))