Feat/bayesian ridge #247

Merged: 90 commits, merged Apr 21, 2024

Changes from 57 commits

Commits (90)
ae425f1
mean pinball loss
JoaquinIglesiasTurina Feb 27, 2024
ad92978
add links to origin of cases and formulas
JoaquinIglesiasTurina Feb 27, 2024
8cbcabb
Update lib/scholar/metrics/regression.ex
JoaquinIglesiasTurina Feb 27, 2024
67b716d
Update lib/scholar/metrics/regression.ex
JoaquinIglesiasTurina Feb 27, 2024
864a0dd
fix tests to use opts in optional arguments
JoaquinIglesiasTurina Feb 28, 2024
3be9884
add default alpha 0.5, same as sklearn
JoaquinIglesiasTurina Feb 28, 2024
8572534
added sample_weights option to mean_pinball_loss
JoaquinIglesiasTurina Feb 28, 2024
30952da
added multioutput support
JoaquinIglesiasTurina Feb 28, 2024
0ca6725
fixed multioutput behavior to be on par with sklearn
JoaquinIglesiasTurina Feb 28, 2024
cb2a2fa
added comments for better multioutput understanding
JoaquinIglesiasTurina Feb 28, 2024
0ab4741
add option to allow for 2 dimensional sample weights
JoaquinIglesiasTurina Mar 1, 2024
b7f4a9d
add nimble options and rename sample_weight to
JoaquinIglesiasTurina Mar 1, 2024
ae8b3b0
fix tests
JoaquinIglesiasTurina Mar 1, 2024
afe88ee
fixed sample_weights: as tensor behaviour
JoaquinIglesiasTurina Mar 2, 2024
421a8ef
fixed call to NimbleOptions to be consistent
JoaquinIglesiasTurina Mar 2, 2024
9be10e4
fixed multi_weights option and docs
JoaquinIglesiasTurina Mar 2, 2024
02b285d
Update lib/scholar/metrics/regression.ex
JoaquinIglesiasTurina Mar 3, 2024
373eda0
use assert_all_close on multi output pinball loss tests
JoaquinIglesiasTurina Mar 3, 2024
3c508a4
run formatter
JoaquinIglesiasTurina Mar 3, 2024
2cf1c98
working on bayesian ridge
JoaquinIglesiasTurina Mar 21, 2024
895dd97
bayesian ridge algorithm works for simplest case
JoaquinIglesiasTurina Mar 23, 2024
1989c3f
run formatter
JoaquinIglesiasTurina Mar 23, 2024
f5ba5b6
refactor recursion to use while loop
JoaquinIglesiasTurina Mar 25, 2024
115c447
cleanup code and add convergence message
JoaquinIglesiasTurina Mar 25, 2024
1a3070c
add test case
JoaquinIglesiasTurina Mar 26, 2024
c0ebe68
simple case works
JoaquinIglesiasTurina Mar 27, 2024
5533ab5
reshape for multi feat
JoaquinIglesiasTurina Mar 27, 2024
df1be29
expanded test passing
JoaquinIglesiasTurina Mar 27, 2024
afa417d
better test name
JoaquinIglesiasTurina Mar 27, 2024
a761eab
use new_axis where I should
JoaquinIglesiasTurina Mar 27, 2024
9cdeecc
iterations and eps as options. remove solve
JoaquinIglesiasTurina Mar 27, 2024
d5f2510
Options documented and renamed
JoaquinIglesiasTurina Mar 27, 2024
0f054c6
refactor options
JoaquinIglesiasTurina Mar 28, 2024
7cc4683
fix alpha default option
JoaquinIglesiasTurina Mar 28, 2024
7d8d470
model contains all fitted parameters
JoaquinIglesiasTurina Mar 28, 2024
e1e62d6
add next test
JoaquinIglesiasTurina Mar 28, 2024
a575c3b
add intrecept option
JoaquinIglesiasTurina Mar 28, 2024
6d5d2d4
add test weights options
JoaquinIglesiasTurina Mar 28, 2024
fc4e124
better naming + next steps
JoaquinIglesiasTurina Mar 28, 2024
c96a6f0
add eps to initial alpha to avoid division by 0
JoaquinIglesiasTurina Mar 28, 2024
bd75143
comment on previous commit
JoaquinIglesiasTurina Mar 28, 2024
37de967
clean gamma calculation
JoaquinIglesiasTurina Mar 28, 2024
469a9b9
add another todo
JoaquinIglesiasTurina Mar 28, 2024
4d2d782
run formatter
JoaquinIglesiasTurina Mar 28, 2024
3cf69c2
remove unused variable
JoaquinIglesiasTurina Mar 29, 2024
988243c
fix regularization parameter types
JoaquinIglesiasTurina Mar 29, 2024
889de37
working on test to compute score
JoaquinIglesiasTurina Mar 29, 2024
8c969c6
better naming + data
JoaquinIglesiasTurina Mar 29, 2024
514b167
add diabetes data
JoaquinIglesiasTurina Mar 30, 2024
e4cd5b0
add test using diabetes data
JoaquinIglesiasTurina Mar 30, 2024
f7b9a57
fix score computation
JoaquinIglesiasTurina Mar 30, 2024
60004c7
formatter
JoaquinIglesiasTurina Mar 30, 2024
0248749
add mini test to check linear regression timeout
JoaquinIglesiasTurina Apr 3, 2024
9f114f4
fixed multiple scores
JoaquinIglesiasTurina Apr 3, 2024
92e434e
added sigma
JoaquinIglesiasTurina Apr 3, 2024
7a53d4d
add required underscores
JoaquinIglesiasTurina Apr 3, 2024
cd3822c
fix predict function. scores are list and cannot be defn
JoaquinIglesiasTurina Apr 3, 2024
fdd0324
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
287eb6a
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
3694a0d
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
199e1f2
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
c409e83
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
d9628a6
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
a4c6728
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
f14f3fa
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
8e0fb45
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
9ba1bc4
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
44d10e6
Update test/scholar/linear/bayesian_ridge_regression_test.exs
JoaquinIglesiasTurina Apr 5, 2024
5d59f87
remove show test
JoaquinIglesiasTurina Apr 5, 2024
3ecde56
remove debug inspec
JoaquinIglesiasTurina Apr 5, 2024
0d43838
linear regression dependency is no longer required
JoaquinIglesiasTurina Apr 5, 2024
0006e71
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 5, 2024
de7b972
dot product without transpose
JoaquinIglesiasTurina Apr 5, 2024
7762408
minor fixes, and cleanup
JoaquinIglesiasTurina Apr 5, 2024
5ae61c4
fix test types
JoaquinIglesiasTurina Apr 5, 2024
7ce464b
fixed optional scores and n_features > n_samples
JoaquinIglesiasTurina Apr 5, 2024
0fe5c1a
reduce diabetes data sample
JoaquinIglesiasTurina Apr 5, 2024
e4ed46a
test if jit compilable
JoaquinIglesiasTurina Apr 5, 2024
95d673f
formatter run
JoaquinIglesiasTurina Apr 5, 2024
5823c97
remove multi_weights duplication
JoaquinIglesiasTurina Apr 14, 2024
58f36ce
wrote docs
JoaquinIglesiasTurina Apr 14, 2024
98c51cc
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 21, 2024
52b2084
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 21, 2024
4570fc5
Update lib/scholar/linear/bayesian_ridge_regression.ex
JoaquinIglesiasTurina Apr 21, 2024
06eea91
fix references
JoaquinIglesiasTurina Apr 21, 2024
ccc2358
undo scale suggestion
JoaquinIglesiasTurina Apr 21, 2024
8253e0b
remove jit compilable test case
JoaquinIglesiasTurina Apr 21, 2024
4670db6
formatter
JoaquinIglesiasTurina Apr 21, 2024
f0e87a7
refactored rescale fuction
JoaquinIglesiasTurina Apr 21, 2024
db92d39
formatter
JoaquinIglesiasTurina Apr 21, 2024
320 changes: 320 additions & 0 deletions lib/scholar/linear/bayesian_ridge_regression.ex
@@ -0,0 +1,320 @@
defmodule Scholar.Linear.BayesianRidgeRegression do
require Nx
import Nx.Defn
import Scholar.Shared

@derive {Nx.Container,
containers: [:coefficients, :intercept, :alpha, :lambda, :sigma, :rmse, :iterations, :scores]}
defstruct [:coefficients, :intercept, :alpha, :lambda, :sigma, :rmse, :iterations, :scores]

opts = [
iterations: [
type: :pos_integer,
default: 300,
doc: """
Maximum number of iterations before stopping the fitting algorithm.
The number of iterations may be lower if the parameters converge.
"""
],
sample_weights: [
type:
{:or,
Review comment (Member):
I would allow sample_weights to be only type: {:custom, Scholar.Options, :weights, []}.

[
{:custom, Scholar.Options, :non_negative_number, []},
{:list, {:custom, Scholar.Options, :non_negative_number, []}},
{:custom, Scholar.Options, :weights, []}
]},
doc: """
The weights for each observation. If not provided,
all observations are assigned equal weight.
"""
],
fit_intercept?: [
type: :boolean,
default: true,
doc: """
If set to `true`, a model will fit the intercept. Otherwise,
the intercept is set to `0.0`. The intercept is an independent term
in a linear model. Specifically, it is the expected mean value
of targets for a zero-vector on input.
"""
],
alpha_init: [
type: {:custom, Scholar.Options, :non_negative_number, []},
doc: ~S"""
The initial value for alpha. This parameter influences the precision of the noise.
`:alpha` must be a non-negative float i.e. in [0, inf).
Defaults to 1/Var(y).
"""
],
lambda_init: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0,
doc: ~S"""
The initial value for lambda. This parameter influences the precision of the weights.
`:lambda` must be a non-negative float i.e. in [0, inf).
Defaults to 1.
"""
],
alpha_1: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0e-6,
doc: ~S"""
Hyper-parameter : shape parameter for the Gamma distribution prior
over the alpha parameter.
"""
],
alpha_2: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0e-6,
doc: ~S"""
Hyper-parameter : inverse scale (rate) parameter for the Gamma distribution prior
over the alpha parameter.
"""
],
lambda_1: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0e-6,
doc: ~S"""
Hyper-parameter : shape parameter for the Gamma distribution prior
over the lambda parameter.
"""
],
lambda_2: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0e-6,
doc: ~S"""
Hyper-parameter : inverse scale (rate) parameter for the Gamma distribution prior
over the lambda parameter.
"""
],
eps: [
type: :float,
default: 1.0e-8,
doc:
"The convergence tolerance. When `Nx.sum(Nx.abs(coef - coef_new)) < :eps`, the algorithm is considered to have converged."
]
]

@opts_schema NimbleOptions.new!(opts)
deftransform fit(x, y, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)

opts =
[
sample_weights_flag: opts[:sample_weights] != nil
] ++
opts

{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
Review comment (Member):
I would set default sample_weights to Nx.broadcast(Nx.as_type(1.0, x_type), {num_samples}).
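A minimal sketch of how that default could look inside fit/3 (hypothetical placement; num_samples taken from the shape of x):

{num_samples, _num_features} = Nx.shape(x)
x_type = to_float_type(x)

sample_weights =
  Keyword.get_lazy(opts, :sample_weights, fn ->
    # one weight of 1.0 per sample, in the input's float type
    Nx.broadcast(Nx.tensor(1.0, type: x_type), {num_samples})
  end)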

x_type = to_float_type(x)

sample_weights =
Review comment (@krstopro, Member, Apr 4, 2024):

There is no need to check if sample_weights is a tensor here. Doing Nx.tensor on a tensor will simply return it. E.g.

tensor = Nx.tensor([1, 2, 3, 4, 5, 6])
Nx.tensor(tensor)

will give

#Nx.Tensor<
  s64[6]
  [1, 2, 3, 4, 5, 6]
>

if Nx.is_tensor(sample_weights),
do: Nx.as_type(sample_weights, x_type),
else: Nx.tensor(sample_weights, type: x_type)

# handle vector types
# handle default alpha value, add eps to avoid division by 0
eps = Nx.Constants.smallest_positive_normal(x_type)
default_alpha = Nx.divide(1, Nx.add(Nx.variance(x), eps))
JoaquinIglesiasTurina marked this conversation as resolved.
alpha = Keyword.get(opts, :alpha_init, default_alpha)
alpha = Nx.tensor(alpha, type: x_type)
opts = Keyword.put(opts, :alpha_init, alpha)
Review comment (@josevalim, Contributor, Apr 4, 2024):

Tensors should never be passed as options. You should always pass all tensors as arguments to the defn function. Your job in this function is to:

  1. If any option should become a tensor, you remove it from the options and pass it as an argument
  2. All other options (which are not tensors) can be passed as a group of options to the defn

Reply (Contributor Author):

My understanding is that model hyperparameters, which are plain scalars, are OK to be passed as opts to defn.
Is that understanding correct?

Please note that those hyperparameters need to be passed to the while loop. So, they get converted to tensors there, right?

Reply (Contributor):

It depends. Everything that is passed as an option to a defn results in a different compilation. So, if you can convert something to a tensor and pass it as a tensor, that's ideally better. But that's not always possible. Sometimes the value is used as a shape, and different shapes always lead to different compilations.

> Please note that those hyperparameters need to be passed to the while loop. So, they get converted to tensors there, right?

Correct.
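A self-contained sketch of that split, using hypothetical names (OptsExample, scale/2) rather than this PR's API: tensor-like values are popped from the options and passed as positional defn arguments, while plain scalars stay in opts.

defmodule OptsExample do
  import Nx.Defn

  deftransform scale(x, opts \\ []) do
    opts = Keyword.validate!(opts, factor: 1.0, bias: 0.0)

    # :factor should become a tensor, so it is removed from opts
    # and passed as a regular argument
    {factor, opts} = Keyword.pop!(opts, :factor)
    factor = Nx.tensor(factor)

    scale_n(x, factor, opts)
  end

  defnp scale_n(x, factor, opts) do
    # opts[:bias] is a plain number; it is inlined at compile time,
    # so every distinct value triggers a separate compilation
    x * factor + opts[:bias]
  end
end

For example, OptsExample.scale(Nx.tensor([1.0, 2.0, 3.0]), factor: 2.0, bias: 1.0) recompiles scale_n for each distinct bias, but reuses the compiled code for any factor tensor of the same type and shape.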


{lambda, opts} = Keyword.pop!(opts, :lambda_init)
lambda = Nx.tensor(lambda, type: x_type)
opts = Keyword.put(opts, :lambda_init, lambda)
zeros_list = for k <- 0..opts[:iterations], do: 0

GitHub Actions warning (main, 1.14.5/25.3), line 128: variable "k" is unused (if the variable is not meant to be used, prefix it with an underscore)
JoaquinIglesiasTurina marked this conversation as resolved.
scores = Nx.tensor(zeros_list, type: x_type)
JoaquinIglesiasTurina marked this conversation as resolved.

{coefficients, intercept, alpha, lambda, rmse, iterations, has_converged, scores, sigma} =
fit_n(x, y, sample_weights, scores, opts)
iterations = Nx.to_number(iterations)
scores = scores
|> Nx.to_list()
|> Enum.take(iterations)

if Nx.to_number(has_converged) == 1 do
IO.puts("Convergence after #{Nx.to_number(iterations)} iterations")
end
Review comment (Contributor):

You really can't call this here. You have to assume you never can really read the tensor values, even inside deftransform. You can call: Nx.Defn.jit(&BayesianRidgeRegression.fit/3).(x, y, opts) and you will see these operations will fail. :) You can even add this as a test to check the operation is fully jittable. This means you can't do Nx.to_list() |> Enum.take(iterations) either. Whoever calls fit has to do that instead.
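A sketch of the kind of check being suggested (hypothetical test; with the code as it stands it would fail until the Nx.to_number/IO.puts/Nx.to_list calls are moved out of fit):

test "fit/3 is fully jit-compilable" do
  x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
  y = Nx.tensor([1.0, 2.0, 3.0, 4.0])

  # jitting keeps every intermediate tensor symbolic, so calls like
  # Nx.to_number or Nx.to_list inside fit raise instead of returning values
  model = Nx.Defn.jit(&Scholar.Linear.BayesianRidgeRegression.fit/3).(x, y, [])

  assert %Scholar.Linear.BayesianRidgeRegression{} = model
end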


%__MODULE__{
coefficients: coefficients,
intercept: intercept,
alpha: Nx.to_number(alpha),
Review comment (Contributor):

Here the same as José mentioned: you cannot use Nx.to_number here.

lambda: Nx.to_number(lambda),
sigma: sigma,
rmse: Nx.to_number(rmse),
iterations: iterations,
scores: scores
}
end

defnp fit_n(x, y, sample_weights, scores, opts) do
x = to_float(x)
y = to_float(y)

{x_offset, y_offset} =
if opts[:fit_intercept?] do
preprocess_data(x, y, sample_weights, opts)
else
x_offset_shape = Nx.axis_size(x, 1)
y_reshaped = if Nx.rank(y) > 1, do: y, else: Nx.reshape(y, {:auto, 1})
y_offset_shape = Nx.axis_size(y_reshaped, 1)

{Nx.broadcast(Nx.tensor(0.0, type: Nx.type(x)), {x_offset_shape}),
Nx.broadcast(Nx.tensor(0.0, type: Nx.type(y)), {y_offset_shape})}
end

{x, y} = {x - x_offset, y - y_offset}

{x, y} =
if opts[:sample_weights_flag] do
rescale(x, y, sample_weights)
else
{x, y}
end

alpha = opts[:alpha_init]
lambda = opts[:lambda_init]

alpha_1 = opts[:alpha_1]
alpha_2 = opts[:alpha_2]
lambda_1 = opts[:lambda_1]
lambda_2 = opts[:lambda_2]

iterations = opts[:iterations]

xt_y = Nx.dot(Nx.transpose(x), y)
JoaquinIglesiasTurina marked this conversation as resolved.
Show resolved Hide resolved
{u, s, vh} = Nx.LinAlg.svd(x, full_matrices?: false)
eigenvals = Nx.pow(s, 2)
JoaquinIglesiasTurina marked this conversation as resolved.
Show resolved Hide resolved
{n_samples, n_features} = Nx.shape(x)
{coef, rmse} = update_coef(x, y, n_samples, n_features, xt_y, u, vh, eigenvals, alpha, lambda)

{{coef, alpha, lambda, rmse, iter, has_converged, scores}, _} =
while {{coef, rmse, alpha, lambda, iter = 0, has_converged = Nx.u8(0), scores = scores},
Review comment (Contributor), suggested change:
- while {{coef, rmse, alpha, lambda, iter = 0, has_converged = Nx.u8(0), scores = scores},
+ while {{coef, rmse, alpha, lambda, iter = Nx.s64(0), has_converged = Nx.u8(0), scores = scores},

JoaquinIglesiasTurina marked this conversation as resolved.
{x, y, xt_y, u, s, vh, eigenvals, alpha_1, alpha_2, lambda_1, lambda_2, iterations}},
iter <= iterations and not has_converged do
new_score =
log_marginal_likelihood(
coef,
rmse,
n_samples,
n_features,
eigenvals,
alpha,
lambda,
alpha_1,
alpha_2,
lambda_1,
lambda_2
)
scores = Nx.put_slice(scores, [iter], Nx.new_axis(new_score, -1))

gamma = Nx.sum(alpha * eigenvals / (lambda + alpha * eigenvals))
lambda = (gamma + 2 * lambda_1) / (Nx.sum(coef ** 2) + 2 * lambda_2)
alpha = (n_samples - gamma + 2 * alpha_1) / (rmse + 2 * alpha_2)

{coef_new, rmse} =
update_coef(x, y, n_samples, n_features, xt_y, u, vh, eigenvals, alpha, lambda)

has_converged = Nx.sum(Nx.abs(coef - coef_new)) < 1.0e-8

{{coef_new, alpha, lambda, rmse, iter + 1, has_converged, scores},
{x, y, xt_y, u, s, vh, eigenvals, alpha_1, alpha_2, lambda_1, lambda_2, iterations}}
end

intercept = set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
scaled_sigma = Nx.dot(Nx.transpose(vh), vh / Nx.new_axis(eigenvals + lambda / alpha, -1))
JoaquinIglesiasTurina marked this conversation as resolved.
sigma = scaled_sigma / alpha
{coef, intercept, alpha, lambda, rmse, iter, has_converged, scores, sigma}
end

defnp update_coef(
x,
y,
n_samples,

GitHub Actions warning (main, 1.14.5/25.3), line 237: variable "n_samples" is unused (if the variable is not meant to be used, prefix it with an underscore)
n_features,

GitHub Actions warning (main, 1.14.5/25.3), line 238: variable "n_features" is unused (if the variable is not meant to be used, prefix it with an underscore)
xt_y,
u,

GitHub Actions warning (main, 1.14.5/25.3), line 240: variable "u" is unused (if the variable is not meant to be used, prefix it with an underscore)
vh,
eigenvals,
alpha,
lambda
) do
scaled_eigens = eigenvals + lambda / alpha
regularization = vh / Nx.new_axis(scaled_eigens, -1)
reg_transpose = Nx.dot(regularization, xt_y)
coef = Nx.dot(Nx.transpose(vh), reg_transpose)
JoaquinIglesiasTurina marked this conversation as resolved.

error = y - Nx.dot(x, coef)
squared_error = error ** 2
rmse = Nx.sum(squared_error)

{coef, rmse}
end

defnp log_marginal_likelihood(
coef,
rmse,
n_samples,
n_features,
eigenvals,
alpha,
lambda,
alpha_1,
alpha_2,
lambda_1,
lambda_2
) do
logdet_sigma = -1 * Nx.sum(Nx.log(lambda + alpha * eigenvals))
score_lambda = lambda_1 * Nx.log(lambda) - lambda_2 * lambda
score_alpha = alpha_1 * Nx.log(alpha) - alpha_2 * alpha

score_parameters =
n_features * Nx.log(lambda) + n_samples * Nx.log(alpha) - alpha * rmse -
lambda * Nx.sum(coef ** 2)

score =
0.5 * (score_parameters + logdet_sigma - n_samples * Nx.log(2 * Nx.Constants.pi()))

score_alpha + score_lambda + score
end

deftransform predict(%__MODULE__{coefficients: coeff, intercept: intercept} = _model, x) do
predict_n(coeff, intercept, x)
end

defnp predict_n(coeff, intercept, x), do: Nx.dot(x, [-1], coeff, [-1]) + intercept

# Implements sample weighting by rescaling inputs and
# targets by sqrt(sample_weight).
defnp rescale(x, y, sample_weights) do
case Nx.shape(sample_weights) do
{} = scalar ->
scalar = Nx.sqrt(scalar)
{scalar * x, scalar * y}

_ ->
scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
Review comment (Member): This can be simplified by doing

Suggested change:
- scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
+ scale = sample_weights |> Nx.sqrt() |> Nx.new_axis(1)

and using * instead of Nx.dot.

Reply (Contributor Author):

I do not believe this change is possible.
Please note that a is a {n_samples} sized vector, b is an {n_samples, n_features} matrix and sample_weights is an {n_samples} sized vector.

sample_weights |> Nx.sqrt() |> Nx.new_axis(1) yields an {n_samples, 1} sized matrix, that cannot be multiplied with a.

As far as I can tell, sample_weights |> Nx.sqrt() |> Nx.make_diagonal() yields the only matrix that is dottable with both a and b.

An alternative would be keeping 2 scale tensors, one for a and one for b. I personally do not like this option as it would unbalance the function. You would have a different tensor for each piece of data, and the operations would look different than the other branch of the case statement.

Please, let me know if you think of a better way.

Reply (@krstopro, Member, Apr 21, 2024):

You are correct, but we are able to use * instead of Nx.dot. This should work:

defnp rescale(x, y, sample_weights) do
  factor = Nx.sqrt(sample_weights)
  x_scaled = case Nx.shape(factor) do
    {} -> factor * x
    _ -> Nx.new_axis(factor, 1) * x
  end
  y_scaled = factor * y
  {x_scaled, y_scaled}
end

Reply (Contributor Author):

This works and I find it's a pretty clean solution. Thank you for your comments.

{Nx.dot(scale, x), Nx.dot(scale, y)}
end
end

defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
if fit_intercept? do
y_offset - Nx.dot(x_offset, coeff)
else
Nx.tensor(0.0, type: Nx.type(coeff))
end
end

defnp preprocess_data(x, y, sample_weights, opts) do
if opts[:sample_weights_flag],
do:
{Nx.weighted_mean(x, sample_weights, axes: [0]),
Nx.weighted_mean(y, sample_weights, axes: [0])},
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
end
end
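For reference, a minimal usage sketch of the new module as defined above (toy data; iterations: 300 and fit_intercept?: true are the documented defaults):

x = Nx.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
y = Nx.tensor([2.1, 3.9, 6.2, 7.8, 10.1])

model = Scholar.Linear.BayesianRidgeRegression.fit(x, y, iterations: 300, fit_intercept?: true)
predictions = Scholar.Linear.BayesianRidgeRegression.predict(model, Nx.tensor([[6.0]]))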
10 changes: 10 additions & 0 deletions lib/scholar/options.ex
@@ -93,6 +93,16 @@
end
end

def multi_weights(weights) do

GitHub Actions warning (main, 1.14.5/25.3), line 96: this clause for multi_weights/1 cannot match because a previous clause at line 86 always matches
if is_nil(weights) or
(Nx.is_tensor(weights) and Nx.rank(weights) > 1) do
{:ok, weights}
else
{:error,
"expected weights to be a tensor with rank greater than 1, got: #{inspect(weights)}"}
end
end

def key(key) do
if Nx.is_tensor(key) and Nx.type(key) == {:u, 32} and Nx.shape(key) == {2} do
{:ok, key}