语义分割中常用的损失函数
语义分割中常用的损失函数
语义分割
语义分割是计算机视觉领域中的一项任务,旨在将图像中的每个像素分类为不同的语义类别。与对象检测任务不同,语义分割不仅需要识别图像中的物体,还需要对每个像素进行分类,从而实现对图像的细粒度理解和分析。
语义分割可以被看作是像素级别的图像分割,其目标是为图像中的每个像素分配一个特定的语义类别标签。每个像素都被视为图像的基本单位,因此语义分割可以提供更详细和准确的图像分析结果。
语义分割 vs 分类 :
在语义分割任务中,由于需要对每个像素进行分类,因此需要使用像素级别的损失函数。
语义分割任务中,图像中各个类别的像素数量通常不均衡,例如背景像素可能占据了大部分。
语义分割任务需要对图像中的每个像素进行分类,同时保持空间连续性。
损失函数
Dice Loss
Dice Loss 是一种常用于语义分割任务的损失函数,尤其在目标区域较小、类别不平衡(class imbalance)的情况下表现优异。它来源于 Dice 系数(Dice Coefficient) ,又称为 Sørensen-Dice 系数 ,是衡量两个样本集合之间重叠程度的一种指标。
Dice 系数衡量的是预测掩码与真实标签之间的相似性,公式如下:
其中:
:模型预测出的功能区域(如经过 sigmoid 后的概率值); :Ground Truth 掩码(二值化或软标签); :预测为正类且实际也为正类的部分(交集); :预测和真实中所有正类区域之和;
⚠️ 注意:Dice 系数范围是 [0, 1],越大越好。
Dice Loss 为了将其作为损失函数使用,我们通常取其补集:
有时也会加入一个平滑项 ϵ 防止除以零:
Dice Loss 的优势:
优势 | 描述 |
---|---|
对类别不平衡不敏感,更关注“有没有覆盖正确区域”,而不是“有多少点被正确分类” | 不像 BCE Loss 那样对负样本过多敏感 |
直接优化 IoU 的替代指标 | Dice 和 IoU 表现类似,但更易梯度下降 |
支持 soft mask 输入 | 可处理连续概率值,不需要先 threshold |
更关注整体区域匹配 | 而不是逐点分类 |
代码实现:
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数,支持加权和平均损失。
参数:
weight (Tensor): 各类别的权重(可选)
size_average (bool): 是否对 batch 中的样本取平均 loss
"""
super(DiceLoss, self).__init__()
# 该参数未在当前代码中使用,但保留接口以备后续扩展
self.weight = weight
# 控制是否对 batch 内 loss 取均值或求和
self.size_average = size_average
def forward(self, inputs, targets, smooth=1):
"""
前向传播函数,计算 Dice Loss。
参数:
inputs (Tensor): 模型输出的预测值(logits 或 raw output),形状为 [B, N]
targets (Tensor): 真实标签(ground truth mask),形状为 [B, N]
smooth (float): 平滑项,防止除零错误,默认为 1
返回:
dice_loss (Tensor): 计算得到的 Dice Loss
"""
# 如果你的模型最后没有 sigmoid 层,则需要在这里激活,否则应注释掉这行
inputs = F.sigmoid(inputs) # 将 logits 映射到 [0,1] 区间
# 将输入展平成一维张量,便于后续计算
# inputs: [B*N]
# targets: [B*N]
inputs = inputs.view(-1)
targets = targets.view(-1)
# 计算交集:预测与 GT 的重合部分
intersection = (inputs * targets).sum()
# 计算 Dice Coefficient,加入 smooth 防止除以零
dice_score = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
# 返回 Dice Loss,用 1 - Dice Coefficient
# 值越小表示匹配越好
return 1 - dice_score
BCE-Dice Loss
BCE-Dice Loss是将Dice Loss和标准的二元交叉熵(Binary Cross-Entropy, BCE)损失结合在一起的一种损失函数,通常用于分割模型中。它结合了两种 loss 的优点:
- BCE Loss :关注每个点的分类误差;
- Dice Loss :关注整体区域匹配度;
Binary Cross Entropy Loss(BCE Loss)
公式(逐点):
其中:
:真实标签(binary 或 soft mask); :模型输出的概率值;
特点:
- 对每个点单独计算分类误差;
- 强调预测与 GT 的一致性;
- 在类别平衡时效果好,但在前景远少于背景时容易偏向负样本;
Dice Loss
公式(简化版):
其中:
:预测概率; :真实标签; :平滑项,防止除以零;
特点:
- 不依赖绝对数量,而是关注预测和 GT 的交并比;
- 更适合前景极少的小区域识别;
- 能缓解类别不平衡问题;
为什么要把它们结合起来?
模型 | 缺陷 | 补充方式 |
---|---|---|
BCE Loss | 对前景响应弱,易受类别不平衡影响 | 加入 Dice Loss 增强区域匹配 |
Dice Loss | 对单个点的分类精度不够敏感 | 加入 BCE Loss 提高逐点判别能力 |
组合后的优势:
优势 | 描述 |
---|---|
✔️ 抗类别不平衡能力强 | Dice Loss 起主导作用 |
✔️ 对细节更敏感 | BCE Loss 提升边缘识别精度 |
✔️ 支持 soft mask 输入 | 可处理连续值掩码 |
✔️ 更稳定地收敛 | 两者互补,避免训练震荡 |
代码实现:
class DiceBCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数,构建一个组合损失函数 Dice + BCE。
参数:
weight (Tensor): 可选参数,用于类别加权;
size_average (bool): 是否对 batch 内样本取平均 loss(已弃用);
"""
super(DiceBCELoss, self).__init__()
# 这里暂时未使用 weight 和 size_average,保留接口以备扩展
def forward(self, inputs, targets, smooth=1):
"""
前向传播函数,计算预测输出与真实标签之间的 Dice Loss 与 BCE Loss 的加权和。
参数:
inputs (Tensor): 模型输出的 logits 或 raw 分数,形状为 [B, N]
targets (Tensor): 真实掩码(ground truth mask),形状为 [B, N]
smooth (float): 平滑项,防止除零错误,默认为 1
返回:
Dice_BCE (Tensor): Dice + BCE 组合损失值
"""
# 如果模型最后没有 sigmoid 层,这里需要激活
# 如果已经包含 sigmoid,则应注释掉这一行
inputs = F.sigmoid(inputs) # 将输入映射到概率空间 [0, 1]
# 将输入和目标展平成一维张量,便于后续计算
# inputs: [B*N]
# targets: [B*N]
inputs = inputs.view(-1)
targets = targets.view(-1)
# 计算交集:预测值和真实值都为 1 的区域
intersection = (inputs * targets).sum()
# 计算 Dice Loss:
# Dice Coefficient = (2 * intersection) / (inputs_sum + targets_sum)
# Dice Loss = 1 - Dice Coefficient
inputs_sum = inputs.sum()
targets_sum = targets.sum()
dice_score = (2. * intersection + smooth) / (inputs_sum + targets_sum + smooth)
dice_loss = 1 - dice_score
# 计算 Binary Cross Entropy Loss(BCE)
# 注意:F.binary_cross_entropy 默认要求 inputs 已经经过 sigmoid
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
# 组合损失:BCE + Dice Loss
Dice_BCE = BCE + dice_loss
return Dice_BCE
Jaccard/Intersection over Union (IoU) Loss
Jaccard Loss,也称为Intersection over Union (IoU) Loss,是一种常用的损失函数,用于语义分割任务中评估模型的分割结果与真实分割标签之间的相似性。它基于Jaccard指数(Jaccard Index),也称为 交并比(Intersection over Union, IoU)指标,用于度量两个集合之间的重叠程度。
- Jaccard Index(IoU)
其中:
:模型输出的概率值或二值化结果; :ground truth 掩码;- 分子是预测和 GT 的交集;
- 分母是两者的并集;
⚠️ IoU 值 ∈ [0, 1],越大越好。
- Jaccard Loss(IoU Loss)
为了将 IoU 转换为可优化的损失函数,我们取其补集:
这样,损失越小表示预测越接近真实标签。
为了避免除以零,通常加入平滑项
- Jaccard Loss 有以下几个优点:
特性 | 描述 |
---|---|
✔️ 对类别不平衡不敏感 | 不像 BCE Loss 那样偏向背景点 |
✔️ 关注整体区域匹配 | 强调预测与 GT 的空间一致性 |
✔️ 更适合评估边界模糊区域 | 如功能区域边缘不确定性较高 |
- 与其他 Loss 的对比
损失函数 | 是否支持 soft mask | 是否对类别不平衡敏感 | 是否直接优化 IoU | 输出范围 |
---|---|---|---|---|
BCE Loss | ❌ 否(需二值化) | ✅ 是 | ❌ 否 | [0, ∞) |
Focal Loss | ✅ 是(加权) | ✅ 是(缓解) | ❌ 否 | [0, ∞) |
Dice Loss | ✅ 是 | ✅ 是 | 近似于 IoU | [0, 1] |
Jaccard (IoU) Loss | ✅ 是 | ✅ 是 | ✅ 是 | [0, 1] |
虽然 Dice Loss 在实际训练中更稳定,但 Jaccard Loss 更贴近最终评估指标(IoU),适合在推理阶段作为验证标准。
代码实现:
class IoULoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数,构建一个基于 IoU(交并比)的损失函数。
参数:
weight (Tensor): 可选参数,用于类别加权(未使用)
size_average (bool): 是否对 batch 内样本取平均 loss(已弃用)
"""
super(IoULoss, self).__init__()
# weight 和 size_average 在此实现中未使用,保留接口以备后续扩展
def forward(self, inputs, targets, smooth=1):
"""
前向传播函数,计算预测输出与真实标签之间的 IoU Loss。
参数:
inputs (Tensor): 模型输出的原始 logit 或经过 sigmoid 的概率值;
形状为 [B, N]
targets (Tensor): ground truth 掩码,形状为 [B, N]
smooth (float): 平滑项,防止除零错误,默认为 1
返回:
iou_loss (Tensor): 计算得到的 IoU Loss
"""
# 如果模型最后没有 sigmoid 层,则在这里激活
# 如果已经包含 sigmoid,则应注释掉这一行
inputs = torch.sigmoid(inputs) # 将输入映射到 [0,1] 区间
# 将输入和目标展平成一维张量便于计算
# inputs: [B*N]
# targets: [B*N]
inputs = inputs.view(-1)
targets = targets.view(-1)
# 计算交集(Intersection),等价于 TP(True Positive)
intersection = (inputs * targets).sum()
# 计算并集:Union = input + target - intersection
total = (inputs + targets).sum()
union = total - intersection
# 计算 IoU Score,加入平滑项防止除以零
iou_score = (intersection + smooth) / (union + smooth)
# IoU Loss = 1 - IoU score,这样越接近 1,loss 越小
iou_loss = 1. - iou_score
return iou_loss
Focal Loss
Focal Loss 是一种针对类别不平衡(Class Imbalance)问题的损失函数改进方案,由何恺明团队在2017年论文《Focal Loss for Dense Object Detection》中提出,主要用于解决目标检测任务中前景-背景类别极端不平衡的问题(如1:1000)。其核心思想是通过调整难易样本的权重,使模型更关注难分类的样本。
Focal Loss 基于交叉熵损失进行扩展,将样本的权重进行动态调整。与交叉熵损失函数相比,Focal Loss引入了一个衰减因子
核心思想:
(1) 类别不平衡的问题
在分类任务中(尤其是目标检测),负样本(背景)往往远多于正样本(目标),导致:
模型被大量简单负样本主导,难以学习有效特征。
简单样本的梯度贡献淹没难样本的梯度。
(2) Focal Loss 的改进
降低易分类样本的权重:对模型已经分类正确的样本(高置信度)减少损失贡献。
聚焦难分类样本:对分类错误的样本(低置信度)保持高损失权重。
Focal Loss 基于标准交叉熵损失(Cross-Entropy Loss)改进而来。
(1) 标准交叉熵损失(CE Loss)
其中:
- p 是模型预测的概率(经过sigmoid/softmax)。
- y 是真实标签(0或1)。
(2) Focal Loss 定义
:类别平衡权重(通常 ),用于平衡正负样本数量差异。 :调节因子(通常 ),控制难易样本的权重衰减程度。
γ 参数用于抑制容易分类的样本,而 α 参数用于平衡正负类别的权重。两者解决的是不同维度的问题:
α:防止前景点(功能区域)被背景淹没,解决数据集中“类别数量不平衡”的问题(数据集级别);
γ:防止模型只关注简单样本,忽略难分类样本,解决模型训练时“简单样本主导梯度”的问题(样本级别);
综上,先通过 α 平衡类别数量,再通过 γ 抑制简单样本,两者协同提升模型性能。
关键参数的作用:
参数 | 作用 | 典型值 |
---|---|---|
控制难易样本权重: • • | 0.5 ~ 5 | |
平衡正负样本数量: • | 0.25 ~ 0.75 |
难样本vs易样:
易分类样本(如 p=0.9 ):
接近0,损失被大幅降低。难分类样本(如 p=0.1 ):
接近1,损失几乎不受影响。
假设两个正样本:
易样本:
(模型已自信分类)
标准 CE Loss:
Focal Loss(
): 损失权重降低 100 倍! 难样本:
(模型分类错误)
标准 CE Loss:
Focal Loss(
): 损失权重仅降低 20%。
应用场景:
目标检测(如RetinaNet): 解决前景(目标)与背景的极端不平衡问题。
医学图像分割: 病灶区域像素远少于正常组织。
任何类别不平衡的分类任务: 如欺诈检测、罕见疾病诊断等。
优缺点:
优点 | 缺点 |
---|---|
显著提升难样本的分类性能 | 需调参( |
抑制简单样本的梯度主导 | 对噪声标签敏感 |
兼容大多数分类模型 | 计算量略高于CE Loss |
Focal Loss 通过
动态调整样本权重,使模型聚焦难分类样本。参数选择:
:一般从2开始调优(值越大,简单样本抑制越强)。 :根据正负样本比例调整(如正样本少则增大 )。
适用场景:类别不平衡越严重,Focal Loss 效果越显著。
代码实现:
# 设置全局参数(可调)
ALPHA = 0.8 # 控制正样本(目标点)与负样本(非目标点)之间的损失权重;
# 若前景点稀疏(如 grasping area),建议设为较高值(如 0.25~0.75);
GAMMA = 2 # 聚焦参数,用于抑制易分类样本,放大难分类样本的影响;
class FocalLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数,构建一个基于 BCE 的改进版 Focal Loss。
参数:
weight (Tensor): 可选参数,用于类别加权(未使用);
size_average (bool): 是否对 batch 内样本取平均 loss(已弃用);
"""
super(FocalLoss, self).__init__()
# 当前实现未使用 weight 和 size_average,保留接口以备扩展
def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
"""
前向传播函数,计算预测输出与真实标签之间的 Focal Loss。
参数:
inputs (Tensor): 模型输出的原始 logit 或经过 sigmoid 的概率值;
形状为 [B, N](batch_size × 点数)
targets (Tensor): ground truth 掩码,形状为 [B, N]
alpha (float): 平衡因子,控制正类(功能区域)和负类(非功能区域)之间的损失权重;
前景点少 → alpha 高(如 0.75),防止被背景淹没;
gamma (float): 聚焦参数,抑制 easy examples,放大 hard examples;
smooth (float): 平滑项,防止除零错误,默认为 1
返回:
focal_loss (Tensor): 计算得到的 Focal Loss 值
"""
# 如果模型最后没有 sigmoid 层,则在这里激活
inputs = torch.sigmoid(inputs)
# 将输入展平便于后续计算
# inputs: [B*N], 表示每个点属于功能区域的概率;
# targets: [B*N], 表示每个点是否属于目标功能区域(soft/hard label);
inputs = inputs.view(-1)
targets = targets.view(-1)
# Step 1: 计算 Binary Cross Entropy Loss(BCE)
# 这里使用 'mean' reduction,表示对 batch 内取平均
ce_loss = F.binary_cross_entropy(inputs, targets, reduction='mean')
# Step 2: 计算 pt = exp(-ce_loss),即 e^{-ce_loss}
pt = torch.exp(-ce_loss) # shape: scalar
# Step 3: 按类别分配 alpha
alpha = torch.where(targets == 1, alpha, 1 - alpha)
# Step 4: 构建 Focal Weight:
# focal_weight = α * (1 - pt)^γ
# 目的是:让难分类样本获得更大的 loss 权重,从而引导模型学习更多语义信息
focal_weight = alpha * (1 - pt) ** gamma
# Step 5: 最终 Focal Loss = focal_weight × ce_loss
focal_loss = focal_weight * ce_loss
return focal_loss
关于计算 p_t(模型对真实类别的预测概率)代码解析:
pt = torch.exp(-ce_loss) # p_t = softmax(output)[target_class]
ce_loss = F.cross_entropy(...)
→ 这是交叉熵损失;-ce_loss
→ 负号;torch.exp(-ce_loss)
→ 求 exp(自然指数);
但实际上这行代码的意图是计算
所以:
pt = torch.exp(-ce_loss)
这个表达式其实是通过 CE loss 反推出来的
Tversky Loss
Tversky Loss的设计灵感来自Tversky指数(Tversky index),它是一种用于度量集合之间相似性的指标,同时也是 Dice Loss 的一种泛化形式,通过引入两个可调节参数来增强模型对假阳性(False Positives)和假阴性(False Negatives)的敏感度控制。
Tversky Loss 的核心是 Tversky 系数:
然后损失就是:
其中:
TP :真阳性(True Positive)= 预测为正类,且真实也为正类的样本数
FP :假阳性(False Positive)= 预测为正类,但真实是负类的样本数
FN :假阴性(False Negative)= 预测为负类,但真实是正类的样本数
α 和 β 是两个可调节的超参数
α 越大,FP 的影响就越大 → 模型更不喜欢“误报”
β 越大,FN 的影响就越大 → 模型更不喜欢“漏报”
如果你设置 α>β ,说明你更讨厌“误检”
如果你设置 β>α ,说明你更讨厌“漏检”
分母中的 TP+α⋅FP+β⋅FN 构成了一个“加权惩罚项”
例如:
α=0.3, β=0.7 → 更重视召回率(Recall)
α=0.7, β=0.3 → 更重视精确率(Precision)
当 alpha=beta=0.5 时,Tversky指数简化为Dice系数,该系数也等于F1得分。
当 alpha=beta=1 时,公式转化为Tanimoto系数,而当 alpha+beta=1 时,得到一组F-beta得分。
# 设置默认参数:当 alpha = beta = 0.5 时,等价于 Dice Loss
ALPHA = 0.5 # 控制假阳性(FP)的权重
BETA = 0.5 # 控制假阴性(FN)的权重
class TverskyLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数
参数:
weight: 可选,类别权重(用于处理类别不平衡)
size_average: 如果为 True,则返回所有样本损失的平均值
"""
super(TverskyLoss, self).__init__()
# 本类中不直接使用 weight 和 size_average,但保留它们作为接口兼容
self.weight = weight
self.size_average = size_average
def forward(self, inputs, targets, smooth=1, alpha=ALPHA, beta=BETA):
"""
前向传播计算损失值
参数:
inputs: 模型输出的预测结果(logits),形状如 (N, H, W) 或 (N, C, H, W)
targets: 真实标签(ground truth),形状与 inputs 相同
smooth: 平滑系数,防止除以零
alpha: FP 的惩罚权重
beta: FN 的惩罚权重
返回:
loss: 计算得到的 Tversky Loss
"""
# 如果模型最后一层没有 Sigmoid 激活函数,请取消下面这行注释
# 对输出应用 Sigmoid 函数,将 logits 转换为概率 [0,1]
inputs = F.sigmoid(inputs)
# 将输入和目标张量展平为一维,便于后续计算 TP、FP、FN
inputs = inputs.view(-1)
targets = targets.view(-1)
# 真阳性(True Positive):预测为正且实际也为正的像素数量
TP = (inputs * targets).sum()
# 假阳性(False Positive):预测为正但实际为负的像素数量
FP = ((1 - targets) * inputs).sum()
# 假阴性(False Negative):预测为负但实际为正的像素数量
FN = (targets * (1 - inputs)).sum()
# 计算 Tversky 系数(相似度指标)
# 分母中:TP + α·FP + β·FN
Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
# 最终损失是 1 - Tversky,这样在训练中最小化损失就等于最大化重叠度
return 1 - Tversky
Lovasz Hinge Loss
Lovasz Hinge Loss的设计思想是,在计算IoU得分之前,根据预测误差对预测结果进行排序,然后累积计算每个误差对IoU得分的影响。然后,将该梯度向量与初始误差向量相乘,以最大程度地惩罚降低IoU得分的预测结果。
https://github.com/bermanmaxim/LovaszSoftmax
Combo Loss
Combo Loss 是一种结合了多个损失函数优点的混合损失函数,特别适用于图像分割任务。它将 Dice Loss 和 交叉熵损失(CrossEntropy Loss) 相结合,并引入一个可调节的权重参数,使得模型在训练过程中可以更灵活地平衡这两部分损失。
核心思想:
Combo Loss = α × CrossEntropy + (1 - α) × Dice Loss
或者更广义地:
Combo Loss = α × 分类误差(CE)+ β × 区域重叠误差(Dice)
其中 α + β = 1,α 控制分类误差的重要性,β 控制区域匹配误差的重要性。
数学定义:
假设我们有预测概率图
- 交叉熵损失(Binary Cross Entropy):
- Dice Loss:
- Combo Loss 定义为:
其中:
:控制两个损失之间的权重比例若
:仅使用交叉熵损失若
:仅使用 Dice Loss
为什么使用 Combo Loss:
优势 | 描述 |
---|---|
✔️ 兼顾像素级精度和区域重叠度 | CE 关注每个像素的分类准确性,Dice 关注整体区域匹配程度 |
✔️ 对类别不平衡问题鲁棒 | 在前景像素远少于背景像素时表现良好(如医学图像) |
✔️ 更稳定的训练过程 | 避免单一损失可能带来的训练不稳定性 |
✔️ 可调性强 | 通过调整 α 参数,适应不同任务需求 |
对比其他损失函数:
损失函数 | 是否关注像素分类? | 是否关注区域匹配? | 是否可调? | 是否适合类别不平衡? |
---|---|---|---|---|
CrossEntropy Loss | ✅ | ❌ | ❌ | ❌ |
Dice Loss | ❌ | ✅ | ❌ | ✅ |
Tversky Loss | ❌ | ✅ ✅ | ✅ | ✅ ✅ |
Combo Loss | ✅ ✅ | ✅ | ✅ | ✅ ✅ |
代码实现:
# 超参数设置说明:
ALPHA = 0.5 # 控制交叉熵中正负样本的权重
# 如果 ALPHA < 0.5:对假阳性(FP)惩罚更重(更关注精确率)
# 如果 ALPHA > 0.5:对假阴性(FN)惩罚更重(更关注召回率)
CE_RATIO = 0.5 # 控制交叉熵损失和 Dice 损失之间的权重分配
# CE_RATIO 越大,交叉熵在总损失中的占比越高
class ComboLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
"""
初始化函数
参数:
weight: 可选,类别权重(用于处理类别不平衡)
size_average: 如果为 True,则返回所有样本损失的平均值
"""
super(ComboLoss, self).__init__()
# 这里不直接使用 weight 和 size_average,但保留作为接口兼容
self.weight = weight
self.size_average = size_average
def forward(self, inputs, targets, smooth=1, alpha=ALPHA, beta=BETA, eps=1e-9):
"""
前向传播计算 Combo Loss
参数:
inputs: 模型输出的概率值(经过 Sigmoid),形状如 (N, H, W)
targets: 真实标签,形状与 inputs 相同,值为 0 或 1
smooth: 平滑系数,防止除以零
alpha: 控制 FP/FN 的惩罚比例(用于交叉熵部分)
eps: 防止 log(0) 出现的小常数
返回:
combo_loss: 计算得到的 Combo Loss
"""
# 将输入和目标张量展平为一维,便于后续计算
inputs = inputs.view(-1)
targets = targets.view(-1)
# 计算 Dice Loss 所需的交集
intersection = (inputs * targets).sum()
# Dice Score(区域匹配度)
dice_score = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
# 加入数值稳定性处理,防止 log(0) 出现 NaN
# torch.clamp(x, min=a, max=b) 是 PyTorch 中的一个函数,用于将张量 x 中的每个元素限制在 [a, b] 区间内:
# 这里把所有 inputs 中的值限制在区间 [eps, 1.0 - eps] 内,防止出现 0 或 1 的极端值。
inputs = torch.clamp(inputs, eps, 1.0 - eps)
# 加权交叉熵损失(Weighted Cross Entropy)
# 根据 ALPHA 参数调整正类和负类的权重
weighted_ce = - (ALPHA * targets * torch.log(inputs)) - ((1 - ALPHA) * (1 - targets) * torch.log(1 - inputs))
# 对损失求均值
weighted_ce = weighted_ce.mean()
# Combo Loss 是交叉熵和 Dice Loss 的加权组合
# 注意:这里使用的是负的 Dice Score(因为要最小化损失)
combo_loss = (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice_score)
return combo_loss
上面代码实现中使用的是加权交叉熵损失:
如何选择?
任务需求:根据特定的分割任务的需求和特点,选择适合的损失函数。例如,对于类别不平衡的数据集,可以考虑使用Tversky Loss或Combo Loss等能够处理不平衡情况的损失函数。
实验评估:在实验中,使用不同的损失函数进行训练,并评估它们在验证集或测试集上的性能。比较它们在IoU、准确率、召回率等指标上的表现,选择性能最佳的损失函数。
超参数调整:一些损失函数具有额外的超参数,如Tversky Loss中的alpha和beta,可以通过调整这些超参数来进一步优化损失函数的性能。