Binary Cross-Entropy with Logits and Class Weights

PyTorch's BCEWithLogitsLoss is a numerically stable implementation of the binary cross-entropy (BCE) loss.

It combines the Sigmoid activation and the BCE computation in a single operation, which avoids numerical instabilities for large positive or negative logits.

It is used for binary classification, where the model predicts a single continuous value, called a logit, for each sample.

That logit can be converted into a probability using the Sigmoid function:

  σ(z) = 1 / (1 + exp(−z))
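
As a quick sanity check, torch.sigmoid applies exactly this formula element-wise (the logit values below are arbitrary):

import torch

z = torch.tensor([-2.0, 0.0, 3.0])        # arbitrary example logits
probs = torch.sigmoid(z)                  # σ(z) = 1 / (1 + exp(-z))
manual = 1.0 / (1.0 + torch.exp(-z))      # the formula written out
print(probs)                              # ≈ tensor([0.1192, 0.5000, 0.9526])
print(torch.allclose(probs, manual))      # True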

Given a target label y ∈ {0, 1} and a predicted logit z, the binary cross-entropy loss is defined as:

  BCE(z, y) = − [ y · log(σ(z)) + (1 − y) · log(1 − σ(z)) ]
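
As an illustration, the formula can be evaluated by hand and compared with PyTorch's functional form. A minimal sketch, using the same made-up logits as the usage example further down:

import torch
import torch.nn.functional as F

z = torch.tensor([0.3, -1.2, 2.1])   # made-up logits
y = torch.tensor([1., 0., 1.])       # binary targets

# Literal translation of BCE(z, y), averaged over the three samples
p = torch.sigmoid(z)
manual = (-(y * torch.log(p) + (1 - y) * torch.log(1 - p))).mean()

# Built-in equivalent (no pos_weight, default mean reduction)
builtin = F.binary_cross_entropy_with_logits(z, y)

print(manual.item(), builtin.item())  # the two values should agree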

To handle class imbalance, PyTorch introduces a positive-class weight, denoted α (the pos_weight argument):

  BCE[α](z, y) = − [ α · y · log(σ(z)) + (1 − y) · log(1 − σ(z)) ]
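
The effect of α is easiest to see on the per-element losses (reduction='none'): only the terms with y = 1 are scaled. A small sketch with an arbitrary α = 3:

import torch
import torch.nn.functional as F

z = torch.tensor([0.3, -1.2, 2.1])
y = torch.tensor([1., 0., 1.])
alpha = torch.tensor([3.0])          # illustrative pos_weight

plain = F.binary_cross_entropy_with_logits(z, y, reduction='none')
weighted = F.binary_cross_entropy_with_logits(z, y, pos_weight=alpha, reduction='none')

# The ratio is α for the y = 1 entries and 1 for the y = 0 entry
print(weighted / plain)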

When the dataset contains far fewer positive than negative examples, you can set pos_weight = number_of_negatives / number_of_positives so that the model pays more attention to positive samples.
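
For example, the ratio can be computed directly from the training labels (a minimal sketch with a made-up, imbalanced label tensor):

import torch
import torch.nn as nn

targets = torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1.])  # 8 negatives, 2 positives

num_pos = targets.sum()                  # 2
num_neg = targets.numel() - num_pos      # 8
pos_weight = num_neg / num_pos           # 4.0

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)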

For numerical stability, PyTorch computes this loss internally without explicitly calling the Sigmoid function, using the equivalent expression (shown here for the unweighted case):

  ℓ(z, y) = max(z, 0) − z · y + log(1 + exp(−|z|))

This formulation keeps the loss finite and yields correct gradients even for large |z|.
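
The difference matters in practice: evaluating the textbook formula through σ(z) overflows to infinity once the Sigmoid saturates, while the rearranged expression stays finite. A small sketch with extreme logits:

import torch
import torch.nn.functional as F

z = torch.tensor([200., -200.])      # extreme logits
y = torch.tensor([0., 1.])           # labels chosen so both losses are large

# Naive formula: σ(z) saturates to exactly 1.0 or 0.0, so log() returns -inf
p = torch.sigmoid(z)
naive = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))
print(naive)                          # tensor([inf, inf])

# Rearranged formula max(z, 0) - z*y + log(1 + exp(-|z|)) stays finite
stable = torch.clamp(z, min=0) - z * y + torch.log1p(torch.exp(-z.abs()))
print(stable)                         # tensor([200., 200.])

# BCEWithLogitsLoss matches the stable version
print(F.binary_cross_entropy_with_logits(z, y, reduction='none'))  # tensor([200., 200.])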

Typical PyTorch usage

import torch
import torch.nn as nn

logits = torch.tensor([0.3, -1.2, 2.1])  # raw model outputs
targets = torch.tensor([1., 0., 1.])     # binary labels
pos_weight = torch.tensor([3.0])         # emphasize positive samples

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)
print(loss.item())

Key points

  • Pass raw logits (no Sigmoid before the loss)

  • Use pos_weight to correct class imbalance

  • The loss is averaged over all elements by default (reduction='mean')

  • Works for binary and multi-label tasks (see the sketch below)
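
In the multi-label case, each output column is an independent binary problem, and pos_weight can hold one weight per label. A minimal sketch with made-up shapes and weights:

import torch
import torch.nn as nn

logits = torch.randn(8, 4)                       # batch of 8 samples, 4 labels each
targets = torch.randint(0, 2, (8, 4)).float()    # 0/1 targets, same shape as logits
pos_weight = torch.tensor([1.0, 2.0, 5.0, 1.5])  # one illustrative weight per label

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)                # scalar: mean over all 8 × 4 entries
print(loss.item())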
