Wondering if weights are handled correctly #61
Replies: 13 comments
-
Thanks for your interest in the project. Good question. I was following Tianqi Chen's (the xgboost author) answer on the question of how to use the weights. As you point out, one could also use the weights directly in the calculation of the loss. However, weighting means increasing the contribution of an example to the loss function, which means the contribution of that example's gradient and hessian will also be larger. That's why xgboost multiplies the gradient and the hessian by the weights, rather than weighting the target values/loss function.
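For illustration, here is a minimal sketch of that idea in a plain xgboost custom objective, i.e. scaling the per-sample gradient and hessian by the sample weights rather than weighting the loss. The squared-error loss and the function name are placeholders, not the XGBoostLSS implementation:

```python
import numpy as np
import xgboost as xgb


def weighted_squared_error(predt: np.ndarray, dtrain: xgb.DMatrix):
    """Custom objective sketch: per-sample grad/hess scaled by the sample weights."""
    y = dtrain.get_label()
    weights = dtrain.get_weight()
    if weights.size == 0:            # DMatrix without weights set
        weights = np.ones_like(y)
    grad = (predt - y) * weights     # derivative of 0.5 * (predt - y)**2, times weight
    hess = np.ones_like(y) * weights
    return grad, hess


# Usage sketch:
# xgb.train(params, dtrain, num_boost_round=100, obj=weighted_squared_error)
```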
-
I think in xgboost, the loss, gradient, and hessian are calculated per instance first, and the weights are applied per instance too, as you can see here and here. In XGBoostLSS, I think you compute the total loss across the whole dataset first, without weights, and then get the gradient and hessian from that total loss. I feel that's not equivalent? The whole reason I'm asking is that I tried to get the average loss per sample from the total loss output by XGBoostLSS, but I couldn't get values similar to the ones I got from an XGB accelerated failure time model.
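As a rough sketch of what I mean by the average loss per sample (the arrays here are just placeholders, not XGBoostLSS code):

```python
import numpy as np

# Hypothetical per-sample negative log-likelihoods and the corresponding sample weights
nll_per_sample = np.array([1.2, 0.8, 2.5, 0.3])
weights = np.array([1.0, 1.0, 0.5, 0.5])

# Weighted average loss per sample
weighted_avg_nll = np.sum(weights * nll_per_sample) / np.sum(weights)
print(weighted_avg_nll)
```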
-
Whether you incorporate the weights directly into the loss or post-multiply the gradients and hessians, the effect on the boosting procedure is the same. To be more precise: with a per-sample negative log-likelihood $\ell_i(\theta_i)$ and sample weights $w_i$, the weighted loss is $L = \sum_i w_i \, \ell_i(\theta_i)$, so that

$$\frac{\partial L}{\partial \theta_i} = w_i \frac{\partial \ell_i}{\partial \theta_i}, \qquad \frac{\partial^2 L}{\partial \theta_i^2} = w_i \frac{\partial^2 \ell_i}{\partial \theta_i^2}.$$

Hence both give the same gradients and hessians. We can double check with the following example:

```python
# Imports
import torch
from torch.distributions import Normal
from torch.autograd import grad as autograd
import numpy as np

np.set_printoptions(suppress=True)


# Functions
def get_derivs(nll: torch.Tensor, predt: list) -> tuple:
    """Calculates gradients and hessians.

    Args:
        nll: torch.Tensor, calculated NLL
        predt: list of torch.Tensor, predicted parameters

    Returns:
        grad, hess
    """
    # Gradient and Hessian
    grad = autograd(nll, inputs=predt, create_graph=True)
    hess = [autograd(grad[i].nansum(), inputs=predt[i], retain_graph=True)[0] for i in range(len(grad))]
    return grad, hess


# Data
torch.manual_seed(123)
y = torch.randn(10).reshape(-1, 1)
weights = torch.abs(torch.randn(10).reshape(-1, 1))
weights /= weights.sum()
loc = torch.randn(10).reshape(-1, 1)
loc.requires_grad = True
scale = torch.randn(10).reshape(-1, 1)
scale.requires_grad = True
scale_exp = torch.exp(scale)

# Gradients and Hessians: derive from the unweighted NLL, then multiply by the weights
params = [loc, scale]
normal_dist = Normal(loc=loc, scale=scale_exp)
nll = -torch.nansum(normal_dist.log_prob(y))  # negative log-likelihood
grad, hess = get_derivs(nll, params)
grad_weights = torch.round(torch.cat(grad, dim=1).detach() * weights, decimals=4)
hess_weights = torch.round(torch.cat(hess, dim=1).detach() * weights, decimals=4)

# Weighted loss: incorporate the weights into the NLL, then derive
params = [loc, scale]
normal_dist = Normal(loc=loc, scale=scale_exp)
nll_weight = -torch.nansum(normal_dist.log_prob(y) * weights)
grad, hess = get_derivs(nll_weight, params)
grad_loss_weight = torch.round(torch.cat(grad, dim=1).detach(), decimals=4)
hess_loss_weight = torch.round(torch.cat(hess, dim=1).detach(), decimals=4)

# Check that the two are the same
print(torch.equal(grad_weights, grad_loss_weight))  # True
print(torch.equal(hess_weights, hess_loss_weight))  # True
```
-
That example is very helpful, thanks a lot! I guess to get the weighted average loss I have to modify the function to use the first method, but that won't change learning.
-
You can modify the metric function that is used for cv/early stopping etc. That way you can use a weighted loss without affecting training.
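As a rough sketch of that idea in plain xgboost (not the actual XGBoostLSS code), a custom evaluation metric can read the weights from the evaluation DMatrix and report a weighted NLL, while the training objective stays untouched. The helper `per_sample_nll` below is a hypothetical placeholder for mapping predictions to per-sample negative log-likelihoods:

```python
import numpy as np
import xgboost as xgb


def per_sample_nll(predt: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical placeholder: a real LSS metric would evaluate the predicted distribution."""
    return 0.5 * (predt - y) ** 2


def weighted_nll_metric(predt: np.ndarray, deval: xgb.DMatrix):
    """Custom eval metric: weighted average NLL, used only for cv/early stopping."""
    y = deval.get_label()
    weights = deval.get_weight()
    if weights.size == 0:            # evaluation set without weights
        weights = np.ones_like(y)
    nll = per_sample_nll(predt, y)
    return "weighted_nll", float(np.sum(weights * nll) / np.sum(weights))


# Usage sketch (xgboost >= 1.6 uses custom_metric; older versions use feval):
# xgb.train(params, dtrain, evals=[(deval, "val")],
#           custom_metric=weighted_nll_metric, early_stopping_rounds=30)
```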
-
Thanks. I tried changing both the get_params_loss and compute_gradients_and_hessians functions to apply the weights when computing the loss but not when computing the gradients and hessians. This gives me the typical validation loss I see with the XGB AFT framework. In any case, the current approach you have definitely trains correctly with weights. Thanks for helping me understand the method!
-
Reopening this, since I just realized that not using weights when calculating the validation loss creates a problem: early stopping relies on a loss that accounts for the weights, otherwise it might stop prematurely or too late. I'd suggest that we still go with the other approach, i.e. calculate the loss using the weights first, rather than scaling the gradient or hessian by the weights afterwards.
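To make the concern concrete (just a sketch of the arithmetic, not the exact metric XGBoostLSS reports): if the evaluation metric ignores the weights, early stopping tracks a different quantity than the weighted loss,

$$\text{NLL}_{\text{unweighted}} = \sum_i \ell_i \qquad \text{vs.} \qquad \text{NLL}_{\text{weighted}} = \sum_i w_i\,\ell_i ,$$

and the two can reach their minimum at different boosting rounds, so early stopping can trigger at different points.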
-
Thanks for your comment. Can you please verify your statement with a reproducible example, using both an unweighted and a weighted loss for early stopping? Thanks.
-
Hi, here's an example:

```python
from xgboostlss.model import *
from xgboostlss.distributions.Gaussian import *
from xgboostlss.datasets.data_loader import load_simulated_gaussian_data

train, test = load_simulated_gaussian_data()
X_train, y_train = train.filter(regex="x"), train["y"].values
X_test, y_test = test.filter(regex="x"), test["y"].values

# Create a weighted y where the last 1000 samples are multiplied by 2 but weighted by 1/2.
# This should produce the same result as the unweighted case.
y_test_weighted = y_test.copy()
y_test_weighted[2000:] *= 2
weights = np.ones_like(y_test)
weights[2000:] = 1/2

dtrain = xgb.DMatrix(X_train, label=y_train)
deval = xgb.DMatrix(X_test, label=y_test)
deval_weighted = xgb.DMatrix(X_test, label=y_test_weighted, weight=weights)

m = XGBoostLSS(Gaussian())
opt_params = {
    "eta": 0.10015347345470738,
    "max_depth": 8,
    "gamma": 24.75078796889987,
    "subsample": 0.6161756203438147,
    "colsample_bytree": 0.851057889242629,
    "min_child_weight": 147.09687376037445,
    "booster": "gbtree",
}

m.train(opt_params, dtrain, evals=[(deval, "val")], verbose_eval=5, num_boost_round=100, early_stopping_rounds=30)
m.train(opt_params, dtrain, evals=[(deval_weighted, "val")], verbose_eval=5, num_boost_round=100, early_stopping_rounds=30)
```

The first training call runs all 100 iterations. The second one only runs for 30 rounds and gets stopped.
-
Sorry, I realize that the example wasn't right. I need to repeat the last thousand elements twice, not multiply them by 2. I will update it tomorrow.
-
OK, so here's the updated example:

```python
from xgboostlss.model import *
from xgboostlss.distributions.Gaussian import *
from xgboostlss.datasets.data_loader import load_simulated_gaussian_data

train, test = load_simulated_gaussian_data()
X_train, y_train = train.filter(regex="x"), train["y"].values
X_test, y_test = test.filter(regex="x"), test["y"].values

# Add some more errors to the test set
y_test[2000:] = y_test[2000:] * 1.2
# Repeat those errors 2 times
X_test_weighted = pd.concat([X_test, X_test.iloc[2000:, :]], axis=0)
y_test_weighted = np.concatenate([y_test, y_test[2000:]])
weights = np.ones_like(y_test_weighted)
# Assign weight 1/2 to every copy of the repeated errors (two copies x 1/2 = total weight 1)
weights[2000:] = 1/2

dtrain = xgb.DMatrix(X_train, label=y_train)
deval = xgb.DMatrix(X_test, label=y_test)
deval_weighted = xgb.DMatrix(X_test_weighted, label=y_test_weighted, weight=weights)

m = XGBoostLSS(Gaussian())
opt_params = {
    "eta": 0.10015347345470738,
    "max_depth": 8,
    "gamma": 24.75078796889987,
    "subsample": 0.6161756203438147,
    "colsample_bytree": 0.851057889242629,
    "min_child_weight": 147.09687376037445,
    "booster": "gbtree",
}

m.train(opt_params, dtrain, evals=[(deval, "val")], verbose_eval=5, num_boost_round=200, early_stopping_rounds=30)
m.train(opt_params, dtrain, evals=[(deval_weighted, "val")], verbose_eval=5, num_boost_round=200, early_stopping_rounds=30)
```

The first training run, without weights, stops at round 34. The second run, with weights, stops at round 32. You can see that the one with the down-weighted repeated errors stops earlier than the one without those repeated errors. This is basically what I see in my data, but there the effect is even stronger.
-
Thanks for the code snippet. I am on vacation until the end of October, so please expect some delay in my reply.
-
@yunfeng-eiq I am not fully sure I understand your code example. Can you please adapt the objective_fn and the metric_fn for your example? You can easily overwrite the functions once you've specified the …
-
Hi,
I'm using xgboostlss with a dataset that has sample weights, and I wonder if the library is handling the weights correctly. I see a function here that computes the total loss over the entire dataset, but it does so without accounting for the weights. I feel like this should compute a weighted sum?