diff --git a/proxtorch/operators/graphnet.py b/proxtorch/operators/graphnet.py index dd52457..082500b 100644 --- a/proxtorch/operators/graphnet.py +++ b/proxtorch/operators/graphnet.py @@ -15,9 +15,9 @@ def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor: def _smooth(self, x: torch.Tensor) -> torch.Tensor: # The last channel is the for the l1 norm - grad = self.gradient(x)[:-1] - # norm of the gradient - norm = torch.norm(grad) ** 2 + grad = self.gradient(x)[:-1]/(1-self.l1_ratio) + # sum of squares of the gradients + norm = torch.sum(grad ** 2) return 0.5 * norm * self.alpha * (1 - self.l1_ratio) def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor: diff --git a/proxtorch/operators/tv_2d.py b/proxtorch/operators/tv_2d.py index 77dfef4..afeffab 100644 --- a/proxtorch/operators/tv_2d.py +++ b/proxtorch/operators/tv_2d.py @@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50. tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2. """ - super().__init__(alpha, max_iter, tol, l1_ratio=0.0) + super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol) diff --git a/proxtorch/operators/tv_3d.py b/proxtorch/operators/tv_3d.py index 001e781..2f5ab8b 100644 --- a/proxtorch/operators/tv_3d.py +++ b/proxtorch/operators/tv_3d.py @@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50. tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2. """ - super().__init__(alpha, max_iter, tol, l1_ratio=0.0) + super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol) diff --git a/proxtorch/operators/tvl1_3d.py b/proxtorch/operators/tvl1_3d.py index 9d3efe3..45ace7c 100644 --- a/proxtorch/operators/tvl1_3d.py +++ b/proxtorch/operators/tvl1_3d.py @@ -41,9 +41,9 @@ class TVL1_3D(ProxOperator): def __init__( self, alpha: float, + l1_ratio=0.0, max_iter: int = 200, tol: float = 1e-4, - l1_ratio=0.0, ) -> None: """ Initialize the 3D Total Variation proximal operator. diff --git a/test/test_graphnet.py b/test/test_graphnet.py index 740997c..eba82b9 100644 --- a/test/test_graphnet.py +++ b/test/test_graphnet.py @@ -3,57 +3,58 @@ from proxtorch.operators import GraphNet3D, GraphNet2D -# -# def test_converges_to_sparse_smooth(): -# torch.manual_seed(0) -# -# # generate a spatially sparse signal -# x_true = torch.zeros(10, 10) -# x_true[3:7, 3:7] = torch.ones(4, 4) -# x_true = x_true.flatten() -# -# # generate a random matrix -# A = torch.rand(100, 100) -# -# # generate measurements -# y = A @ x_true -# -# # define the proximal operator -# alpha = 10 -# l1_ratio = 0.0 # 0.5 -# prox = GraphNet2D(alpha, l1_ratio) -# -# # define the objective function -# def objective(x): -# return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2 -# -# # define the step size -# tau = 1 / torch.norm(A.t() @ A) -# -# # initialize the solution -# x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True)) -# -# # optimizer -# optimizer = torch.optim.SGD([x], lr=tau) -# -# # optimization loop -# for i in range(1000): -# optimizer.zero_grad() -# obj = objective(x) + prox(x) -# obj.backward() -# optimizer.step() -# x.data = prox.prox(x.data, tau) -# -# # check that the result is smooth -# plt.imshow(x.data.detach().numpy()) -# plt.show() -# -# # compare with x_true -# difference = torch.norm(x.data.flatten() - x_true) -# assert difference < 1e-3 -# -# # check that the result is sparse -# assert torch.sum(x.data == 0) > 50 +def test_converges_to_sparse_smooth(): + import matplotlib.pyplot as plt + torch.manual_seed(0) + + # generate a spatially sparse signal + x_true = torch.zeros(10, 10) + x_true[3:7, 3:7] = torch.ones(4, 4) + x_true = x_true.flatten() + + # generate a random matrix + A = torch.rand(100, 100) + + # generate measurements + y = A @ x_true + + # define the proximal operator + alpha = 1000 + l1_ratio = 0.0 # 0.5 + prox = GraphNet2D(alpha, l1_ratio) + + # define the objective function + def objective(x): + return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2 + + # define the step size + tau = 0.1 / torch.norm(A.t() @ A) + + # initialize the solution + x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True)) + + # optimizer + optimizer = torch.optim.SGD([x], lr=tau, nesterov=True, momentum=0.9) + + # optimization loop + for i in range(20000): + optimizer.zero_grad() + p=prox(x) + obj = objective(x) + prox(x) + obj.backward() + optimizer.step() + x.data = prox.prox(x.data, tau) + + # check that the result is smooth + plt.imshow(x.data.detach().numpy()) + plt.show() + + # compare with x_true + difference = torch.norm(x.data.flatten() - x_true) + assert difference < 1e-3 + + # check that the result is sparse + assert torch.sum(x.data == 0) > 50 def test_graph_net_3d_prox():