
Commit

First commit
jameschapman19 committed Sep 4, 2023
1 parent 4c2b7ed commit 4b38156
Showing 5 changed files with 58 additions and 57 deletions.
6 changes: 3 additions & 3 deletions proxtorch/operators/graphnet.py
@@ -15,9 +15,9 @@ def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
 
     def _smooth(self, x: torch.Tensor) -> torch.Tensor:
         # The last channel is the for the l1 norm
-        grad = self.gradient(x)[:-1]
-        # norm of the gradient
-        norm = torch.norm(grad) ** 2
+        grad = self.gradient(x)[:-1]/(1-self.l1_ratio)
+        # sum of squares of the gradients
+        norm = torch.sum(grad ** 2)
         return 0.5 * norm * self.alpha * (1 - self.l1_ratio)
 
     def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
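For reference, the rewritten _smooth term can be read on its own as below. This is an illustrative sketch, not code from the commit (graphnet_smooth and its arguments are made-up names); it assumes l1_ratio < 1 and that self.gradient(x) stacks the spatial gradient components with the last channel reserved for the l1 part, as the in-code comment states. Since torch.sum(grad ** 2) equals torch.norm(grad) ** 2, the numerical change in the diff is the extra 1 / (1 - self.l1_ratio) scaling applied to the gradient before squaring.

import torch

def graphnet_smooth(grad_with_l1: torch.Tensor, alpha: float, l1_ratio: float) -> torch.Tensor:
    # Drop the last channel (reserved for the l1 term) and rescale the
    # remaining gradient channels by 1 / (1 - l1_ratio); assumes l1_ratio < 1.
    grad = grad_with_l1[:-1] / (1 - l1_ratio)
    # Sum of squares of the gradient entries (same value as torch.norm(grad) ** 2).
    norm = torch.sum(grad ** 2)
    return 0.5 * norm * alpha * (1 - l1_ratio)

# Example: graphnet_smooth(torch.randn(3, 8, 8), alpha=1000.0, l1_ratio=0.0)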
2 changes: 1 addition & 1 deletion proxtorch/operators/tv_2d.py
@@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None
             max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50.
             tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2.
         """
-        super().__init__(alpha, max_iter, tol, l1_ratio=0.0)
+        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)
2 changes: 1 addition & 1 deletion proxtorch/operators/tv_3d.py
@@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None
             max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50.
             tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2.
         """
-        super().__init__(alpha, max_iter, tol, l1_ratio=0.0)
+        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)
2 changes: 1 addition & 1 deletion proxtorch/operators/tvl1_3d.py
@@ -41,9 +41,9 @@ class TVL1_3D(ProxOperator):
     def __init__(
         self,
         alpha: float,
+        l1_ratio=0.0,
         max_iter: int = 200,
         tol: float = 1e-4,
-        l1_ratio=0.0,
     ) -> None:
         """
         Initialize the 3D Total Variation proximal operator.
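The three super().__init__ edits above all replace positional arguments with keyword arguments. A minimal sketch of why that is the safer call style, using hypothetical Parent/Child classes rather than the actual proxtorch ones, and assuming the parent signature now lists l1_ratio before max_iter as in the TVL1_3D hunk above:

class Parent:
    def __init__(self, alpha, l1_ratio=0.0, max_iter=200, tol=1e-4):
        self.alpha, self.l1_ratio = alpha, l1_ratio
        self.max_iter, self.tol = max_iter, tol

class Child(Parent):
    def __init__(self, alpha, max_iter=200, tol=1e-4):
        # Positional style binds by position: max_iter would land in the
        # l1_ratio slot, and the explicit l1_ratio=0.0 would then clash:
        #   super().__init__(alpha, max_iter, tol, l1_ratio=0.0)
        #   -> TypeError: got multiple values for argument 'l1_ratio'
        # Keyword style binds by name, so it survives parameter reordering:
        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)

c = Child(3.0, max_iter=100, tol=1e-3)
assert (c.l1_ratio, c.max_iter, c.tol) == (0.0, 100, 1e-3)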
103 changes: 52 additions & 51 deletions test/test_graphnet.py
@@ -3,57 +3,58 @@
 from proxtorch.operators import GraphNet3D, GraphNet2D


-#
-# def test_converges_to_sparse_smooth():
-#     torch.manual_seed(0)
-#
-#     # generate a spatially sparse signal
-#     x_true = torch.zeros(10, 10)
-#     x_true[3:7, 3:7] = torch.ones(4, 4)
-#     x_true = x_true.flatten()
-#
-#     # generate a random matrix
-#     A = torch.rand(100, 100)
-#
-#     # generate measurements
-#     y = A @ x_true
-#
-#     # define the proximal operator
-#     alpha = 10
-#     l1_ratio = 0.0 # 0.5
-#     prox = GraphNet2D(alpha, l1_ratio)
-#
-#     # define the objective function
-#     def objective(x):
-#         return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2
-#
-#     # define the step size
-#     tau = 1 / torch.norm(A.t() @ A)
-#
-#     # initialize the solution
-#     x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True))
-#
-#     # optimizer
-#     optimizer = torch.optim.SGD([x], lr=tau)
-#
-#     # optimization loop
-#     for i in range(1000):
-#         optimizer.zero_grad()
-#         obj = objective(x) + prox(x)
-#         obj.backward()
-#         optimizer.step()
-#         x.data = prox.prox(x.data, tau)
-#
-#     # check that the result is smooth
-#     plt.imshow(x.data.detach().numpy())
-#     plt.show()
-#
-#     # compare with x_true
-#     difference = torch.norm(x.data.flatten() - x_true)
-#     assert difference < 1e-3
-#
-#     # check that the result is sparse
-#     assert torch.sum(x.data == 0) > 50
+def test_converges_to_sparse_smooth():
+    import matplotlib.pyplot as plt
+    torch.manual_seed(0)
+
+    # generate a spatially sparse signal
+    x_true = torch.zeros(10, 10)
+    x_true[3:7, 3:7] = torch.ones(4, 4)
+    x_true = x_true.flatten()
+
+    # generate a random matrix
+    A = torch.rand(100, 100)
+
+    # generate measurements
+    y = A @ x_true
+
+    # define the proximal operator
+    alpha = 1000
+    l1_ratio = 0.0 # 0.5
+    prox = GraphNet2D(alpha, l1_ratio)
+
+    # define the objective function
+    def objective(x):
+        return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2
+
+    # define the step size
+    tau = 0.1 / torch.norm(A.t() @ A)
+
+    # initialize the solution
+    x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True))
+
+    # optimizer
+    optimizer = torch.optim.SGD([x], lr=tau, nesterov=True, momentum=0.9)
+
+    # optimization loop
+    for i in range(20000):
+        optimizer.zero_grad()
+        p=prox(x)
+        obj = objective(x) + prox(x)
+        obj.backward()
+        optimizer.step()
+        x.data = prox.prox(x.data, tau)
+
+    # check that the result is smooth
+    plt.imshow(x.data.detach().numpy())
+    plt.show()
+
+    # compare with x_true
+    difference = torch.norm(x.data.flatten() - x_true)
+    assert difference < 1e-3
+
+    # check that the result is sparse
+    assert torch.sum(x.data == 0) > 50


 def test_graph_net_3d_prox():
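The replacement test follows the usual proximal-gradient pattern: a gradient step on the smooth objective (taken here by SGD with Nesterov momentum), followed by the operator's prox step on the iterate. Below is a generic sketch of that loop, assuming only what the test itself uses, namely that the operator is callable for its penalty value and exposes a .prox(x, tau) method; the helper name proximal_gradient is illustrative and not part of proxtorch.

import torch

def proximal_gradient(objective, prox_op, x0, tau, n_iter=1000):
    # ISTA-style loop: gradient step on the smooth terms, then a prox step.
    x = x0.detach().clone()
    for _ in range(n_iter):
        x.requires_grad_(True)
        loss = objective(x) + prox_op(x)      # smooth data term + penalty value
        grad, = torch.autograd.grad(loss, x)
        with torch.no_grad():
            x = x - tau * grad                # gradient step
            x = prox_op.prox(x, tau)          # proximal step for the non-smooth part
    return x

With the test's objective and GraphNet2D operator, this plays the same role as the hand-rolled loop in the test above, minus the momentum.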
