Skip to content

Commit

Permalink
extend tests. (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Apr 11, 2024
1 parent 1ef957d commit e4bd42d
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions tests/mdn_test.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
from typing import Optional

import pytest
import torch
from torch import Tensor, eye, ones, zeros
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch import Tensor, eye

from pyknos.mdn.mdn import MultivariateGaussianMDN


def linear_gaussian(
theta: Tensor,
likelihood_shift: Tensor,
likelihood_cov: Tensor,
theta: Tensor, likelihood_shift: Tensor, likelihood_cov: Tensor
) -> Tensor:

chol_factor = torch.cholesky(likelihood_cov)

return likelihood_shift + theta + torch.mm(chol_factor, torch.randn_like(theta).T).T


@pytest.mark.parametrize(
"dim",
([1, 5, 10]),
)
@pytest.mark.parametrize("dim", ([1, 5, 10]))
@pytest.mark.parametrize("device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("hidden_features", (50, None))
def test_mdn_for_diff_dimension_data(
dim: int, device: str, hidden_features: int = 50, num_components: int = 10
dim: int, device: str, hidden_features: Optional[int], num_components: int = 10
) -> None:

if device == "cuda:0" and not torch.cuda.is_available():
pass
else:
Expand All @@ -38,14 +33,15 @@ def test_mdn_for_diff_dimension_data(
x_numel = theta[0].numel()
y_numel = context[0].numel()

net_features = hidden_features if hidden_features is not None else 50
distribution = MultivariateGaussianMDN(
features=x_numel,
context_features=y_numel,
hidden_features=hidden_features,
hidden_net=nn.Sequential(
nn.Linear(y_numel, hidden_features),
nn.Linear(y_numel, net_features),
nn.ReLU(),
nn.Linear(hidden_features, hidden_features),
nn.Linear(net_features, net_features),
nn.ReLU(),
),
num_components=num_components,
Expand Down

0 comments on commit e4bd42d

Please sign in to comment.