Binary Cross-Entropy with Logits and Class Weights
The BCEWithLogitsLoss function in PyTorch is a numerically stable implementation of the binary cross-entropy (BCE) loss.
It combines a Sigmoid activation and the BCE computation in a single operation, which avoids numerical instability for large positive or negative logits.
It is used for binary classification, where the model predicts a single continuous value, called a logit, for each sample.
That logit can be converted into a probability using the Sigmoid function:
σ(z) = 1 / (1 + exp(−z))
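For instance, a few illustrative logits map to probabilities as follows (values rounded; the numbers are made up):
import torch
logits = torch.tensor([-2.0, 0.0, 3.0])   # illustrative raw model outputs
probs = torch.sigmoid(logits)              # apply σ element-wise
print(probs)                               # ≈ tensor([0.1192, 0.5000, 0.9526])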
Given a target label y ∈ {0, 1} and a predicted logit z, the binary cross-entropy loss is defined as:
BCE(z, y) = − [ y · log(σ(z)) + (1 − y) · log(1 − σ(z)) ]
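As a quick sanity check, this formula can be evaluated directly on a couple of made-up samples and compared with nn.BCELoss applied to the Sigmoid probabilities:
import torch
import torch.nn as nn
z = torch.tensor([1.5, -0.5])                               # made-up logits
y = torch.tensor([1.0, 0.0])                                # made-up binary labels
p = torch.sigmoid(z)
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))   # BCE(z, y) per sample, from the formula
reference = nn.BCELoss(reduction='none')(p, y)              # PyTorch's BCE on probabilities
print(torch.allclose(manual, reference))                    # True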
To handle class imbalance, PyTorch provides a positive-class weight, denoted α (the pos_weight argument):
BCE[α](z, y) = − [ α · y · log(σ(z)) + (1 − y) · log(1 − σ(z)) ]
When the dataset contains far fewer positive than negative examples, you can set pos_weight = number_of_negatives / number_of_positives to make the model pay more attention to positive samples.
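A minimal sketch of that heuristic, assuming the training labels are available as a 0/1 tensor (the counts below are made up):
import torch
labels = torch.tensor([0., 0., 0., 0., 0., 0., 1., 1.])  # hypothetical training labels: 6 negatives, 2 positives
num_pos = labels.sum()
num_neg = labels.numel() - num_pos
pos_weight = num_neg / num_pos                            # 6 / 2 = 3.0
print(pos_weight)                                         # tensor(3.)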
For numerical stability, PyTorch computes this loss internally without explicitly calling the Sigmoid function, using the equivalent expression (shown here for α = 1):
ℓ(z, y) = max(z, 0) − z · y + log(1 + exp(−|z|))
This formulation avoids overflowing the exponential for large |z|, so the loss stays finite and the gradients remain correct.
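The identity can be checked against torch.nn.functional.binary_cross_entropy_with_logits; a minimal sketch with made-up logits, including large magnitudes:
import torch
import torch.nn.functional as F
z = torch.tensor([15.0, -20.0, 0.5])                       # made-up logits, some large in magnitude
y = torch.tensor([0., 1., 1.])
stable = torch.clamp(z, min=0) - z * y + torch.log1p(torch.exp(-z.abs()))
reference = F.binary_cross_entropy_with_logits(z, y, reduction='none')
print(torch.allclose(stable, reference))                   # True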
Typical PyTorch usage
import torch
import torch.nn as nn
logits = torch.tensor([0.3, -1.2, 2.1]) # raw model outputs
targets = torch.tensor([1., 0., 1.]) # binary labels
pos_weight = torch.tensor([3.0]) # emphasize positive samples
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)
print(loss.item())
Key points
- Pass raw logits to the loss (do not apply Sigmoid beforehand)
- Use pos_weight to counteract class imbalance
- The loss is averaged over all elements by default (reduction='mean')
- Works for binary and multi-label classification (see the sketch after this list)
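For the multi-label case, logits and targets have shape (batch_size, num_classes) and pos_weight can hold one value per class. A minimal sketch with made-up shapes and weights:
import torch
import torch.nn as nn
num_classes = 4
logits = torch.randn(8, num_classes)                      # one logit per class for each of 8 samples
targets = torch.randint(0, 2, (8, num_classes)).float()   # independent binary label per class
pos_weight = torch.tensor([1.0, 2.0, 5.0, 0.5])           # illustrative per-class negatives/positives ratios
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)                         # scalar, averaged over all batch × class entries
print(loss.item())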