| import torch |
| import torch.nn as nn |
|
|
|
|
class WeibullNLLLoss(nn.Module):
    """
    Negative log-likelihood loss for a Weibull survival distribution.

    Handles right-censored observations. Uncensored rows contribute the
    negative log-density ``-log f(t)``. Censored rows contribute the negative
    log-survival ``-log S(t)``.

    With scale ``lam`` and shape ``k``::

        H(t)      = (t / lam) ** k                           # cumulative hazard
        log h(t)  = log k + (k - 1) * log t - k * log lam    # log hazard
        -log S(t) = H(t)
        -log f(t) = H(t) - log h(t)

    Inputs to ``forward``:
        weibull_params: (batch, 2) — columns are log_scale, log_shape.
        halflife_days: (batch,) — observed time (1.0 for censored rows).
        censored: (batch,) — True = right-censored. Anything convertible to
            a bool tensor (bool, 0/1 int or float tensor, list of bools) is
            accepted.

    Returns:
        Scalar tensor: the mean per-row loss over the batch.
    """

    def forward(self, weibull_params, halflife_days, censored):
        log_scale = weibull_params[:, 0]
        log_shape = weibull_params[:, 1]

        # Bound the log-parameters so exp() stays finite. Note that clamp
        # passes no gradient for values outside [-10, 10].
        log_scale = torch.clamp(log_scale, -10, 10)
        log_shape = torch.clamp(log_shape, -10, 10)
        shape = torch.exp(log_shape)

        # Floor the times to avoid log(0) for zero (or negative) durations.
        log_t = torch.log(torch.clamp(halflife_days, min=1e-6))

        # Cumulative hazard (t / scale) ** shape, evaluated in log space and
        # computed once. It is shared by both likelihood branches.
        cum_hazard = torch.exp(shape * (log_t - log_scale))
        log_hazard = log_shape + (shape - 1) * log_t - shape * log_scale

        # torch.where requires a bool condition, so coerce 0/1 flags.
        censored = torch.as_tensor(censored, device=cum_hazard.device).bool()

        # Censored: -log S(t) = H(t). Observed: -log f(t) = H(t) - log h(t).
        loss = torch.where(censored, cum_hazard, cum_hazard - log_hazard)
        return loss.mean()
|
|