FBSDE_Helper.py
import torch
from torchani.utils import _get_derivatives_not_none as derivative
class FBSDE:
"""
Class implementing the Forward-Backward Stochastic Differential Equation (FBSDE) solver.
Parameters:
grad (bool): Whether to use gradient information in loss computation (default is True).
alpha (float): Weight for the terminal loss term (default is 1).
beta (float): Weight for the derivative terminal loss term (default is 1).
"""
def __init__(self, grad=True, alpha=1, beta=1):
self.grad = grad
self.alpha = alpha
self.beta = beta
def compute_YZ(self, X, t, model, graph=True):
"""
Compute the model output (Y) and its derivatives (Z) with respect to the input (X) at time (t).
Parameters:
X (torch.Tensor): Input tensor of shape (batch_size, p).
t (float): Time value.
model: The PyTorch model for computing Y and its derivatives.
            graph (bool): Whether to build the graph of the derivative computation (create_graph), so that Z can itself be differentiated during training (default is True).
Returns:
tuple: A tuple containing the model output (Y) and its derivatives (Z).
"""
X.requires_grad_()
Y = model(X, t)
Z = derivative(X, Y, retain_graph=True, create_graph=graph)
return Y, Z
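    # Usage sketch (assumed shapes): for X of shape (batch_size, particle, x12)
    # and a model returning one scalar per sample, Y has shape (batch_size, 1)
    # and Z has the same shape as X, e.g.:
    #   Y, Z = fbsde.compute_YZ(X, t=0.0, model=net)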
def compute_stepwise_Ys(self, Y, Z, wiener_increments, B):
"""
Compute the stepwise values of Y.
Parameters:
Y (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing Y values.
Z (torch.Tensor): Tensor of shape (batch_size, time_steps, particle, x12) representing Z values.
wiener_increments (torch.Tensor): Tensor of shape (batch_size, time_steps, particle, x12) representing Wiener increments.
B (float): Scaling factor for the increments.
Returns:
torch.Tensor: Tensor of shape (batch_size, time_steps, p) representing stepwise Y values.
"""
increments = B * torch.einsum('ijkl, ijkl -> ijk', Z, wiener_increments).sum(-1)
return Y + increments.view(Y.size(0), Y.size(1), Y.size(2))
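    # Shape note (inferred from the code): the einsum contracts the coordinate
    # axis (x12) and the trailing .sum(-1) contracts the particle axis, leaving
    # one increment per (batch, time) pair; the final .view back to
    # (batch, time, p) therefore assumes p == 1.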
def OLD_Y_loss(self, Y, Y_incr):
"""
Compute the loss based on the difference between Y and the incremental Y values.
Parameters:
Y (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing Y values.
Y_incr (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing the incremental Y values.
Returns:
torch.Tensor: Tensor of shape (batch_size,) representing the loss.
"""
delta = Y[:, 1:, :] - Y_incr[:, :-1, :]
return (delta**2).sum(-1).mean(-1)
def Y_step(self, Y_t, Z_t, wiener_increment_t, B):
"""
        Perform a single Euler-Maruyama integration step starting from the current state (non-driven case).
Parameters:
Y_t (torch.Tensor): Tensor of shape (batch_size, p) representing Y values at time t.
Z_t (torch.Tensor): Tensor of shape (batch_size, particle, x12) representing Z values at time t.
wiener_increment_t (torch.Tensor): Tensor of shape (batch_size, particle, x12) representing Wiener increments at time t.
B (float): Scaling factor for the increments.
Returns:
torch.Tensor: Tensor of shape (batch_size, p) representing the updated Y values.
"""
scalar = torch.einsum('ijk, ijk -> ij', Z_t, wiener_increment_t).sum(-1).view(-1, 1)
increment = B * scalar
return Y_t + increment
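    # One Euler-Maruyama step, Y_{t+dt} = Y_t + B * <Z_t, dW_t>, where the
    # inner product runs over both the particle and coordinate axes; the
    # .view(-1, 1) yields one scalar increment per sample, which broadcasts
    # across Y's last dimension.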
def Y_run(self, Y_0, Z, wiener_increments, B):
"""
Perform the Euler-Maruyama integration for Y values over time.
Parameters:
Y_0 (torch.Tensor): Tensor of shape (batch_size, p) representing the initial Y values.
Z (torch.Tensor): Tensor of shape (batch_size, time_steps, particle, x12) representing Z values.
wiener_increments (torch.Tensor): Tensor of shape (batch_size, time_steps, particle, x12) representing Wiener increments.
B (float): Scaling factor for the increments.
Returns:
torch.Tensor: Tensor of shape (batch_size, time_steps, p) representing the Y trajectories.
"""
steps = wiener_increments.size(1)
trajs = Y_0.view(Y_0.size(0), 1, Y_0.size(1))
Y_t = Y_0
wiener_increment_t = wiener_increments[:, 0, :, :]
Z_t = Z[:, 0, :, :]
for t in range(1, steps):
Y_t = self.Y_step(Y_t, Z_t, wiener_increment_t, B)
trajs = torch.cat((trajs, Y_t.view(Y_t.size(0), 1, Y_t.size(1))), dim=1)
wiener_increment_t = wiener_increments[:, t, :, :]
Z_t = Z[:, t, :, :]
return trajs
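    # Note the Ito (left-point) convention: the value appended at time t is
    # produced from Z and the Wiener increment of time t - 1, so the returned
    # trajectory has wiener_increments.size(1) points, starting with Y_0.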
def Y_path_loss(self, Y, Y_star):
"""
Compute the loss based on the difference between Y and the target Y values.
Parameters:
Y (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing Y values.
Y_star (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing the target Y values.
Returns:
torch.Tensor: Tensor of shape (batch_size,) representing the loss.
"""
path_loss_Y = ((Y - Y_star)**2).mean(dim=-1)
return path_loss_Y.mean(dim=-1)
def Y_terminal_loss(self, Y_star_N, X_N, boundary_func):
"""
        Compute the terminal loss from the squared difference between Y_star_N and the boundary condition evaluated at X_N.
Parameters:
Y_star_N (torch.Tensor): Tensor of shape (batch_size, p) representing Y_star_N values.
X_N (torch.Tensor): Tensor of shape (batch_size, p) representing the input X values at the terminal time step.
boundary_func: The boundary function.
Returns:
torch.Tensor: Tensor of shape (batch_size,) representing the terminal loss.
"""
terminal_loss_Y = ((Y_star_N - boundary_func(X_N))**2).sum(dim=1)
return terminal_loss_Y
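    # In symbols: L_term = sum_p (Y_star_N - g(X_N))^2 per batch element,
    # where g is the boundary (terminal) condition.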
def dY_terminal_loss(self, Z_N, X_N, boundary_func):
"""
Compute the derivative terminal loss based on the difference between Z_N and the derivative of the boundary function.
Parameters:
Z_N (torch.Tensor): Tensor of shape (batch_size, particle, x12) representing Z_N values.
X_N (torch.Tensor): Tensor of shape (batch_size, p) representing the input X values at the terminal time step.
boundary_func: The boundary function.
Returns:
torch.Tensor: Tensor of shape (batch_size,) representing the derivative terminal loss.
"""
if self.grad:
X_N.requires_grad_()
Z_N_star = derivative(X_N, boundary_func(X_N), retain_graph=True)
terminal_loss_Z = ((Z_N - Z_N_star).norm(dim=-1)**2).mean(dim=-1)
            # Optionally down-weight samples whose target gradient is large by uncommenting the mask below:
# mask = torch.ones_like(terminal_loss_Z).to(X_N.device)
# mask[Z_N_star.norm(dim=-1).view(-1) >= 1] = 1e-3
# terminal_loss_Z = mask * terminal_loss_Z
else:
terminal_loss_Z = (Z_N**2).sum(dim=-1).sum(dim=-1)
return terminal_loss_Z
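    # Note: create_graph is not requested for Z_N_star, so the target gradient
    # of the boundary condition is treated as a constant during backpropagation;
    # gradients flow to the model only through Z_N.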
def FBSDE_Loss(self, Y, Y_star, Y_star_N, Z_N, X_N, boundary_func):
"""
Compute the FBSDE loss.
Parameters:
Y (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing Y values.
Y_star (torch.Tensor): Tensor of shape (batch_size, time_steps, p) representing the target Y values.
Y_star_N (torch.Tensor): Tensor of shape (batch_size, p) representing Y_star_N values.
Z_N (torch.Tensor): Tensor of shape (batch_size, particle, x12) representing Z_N values.
X_N (torch.Tensor): Tensor of shape (batch_size, p) representing the input X values at the terminal time step.
boundary_func: The boundary function.
Returns:
tuple: A tuple containing the total loss, the path loss, the terminal Y loss, and the derivative terminal loss.
"""
Y_path_loss = self.Y_path_loss(Y, Y_star)
Y_terminal_loss = self.Y_terminal_loss(Y_star_N, X_N, boundary_func)
dY_terminal_loss = self.dY_terminal_loss(Z_N, X_N, boundary_func)
Loss = Y_path_loss + self.alpha * Y_terminal_loss + self.beta * dY_terminal_loss
return Loss, Y_path_loss, Y_terminal_loss, dY_terminal_loss
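# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The sizes, the B
# value, the quadratic boundary condition, and the random stand-ins for the
# network outputs Y and Z are all illustrative assumptions; in practice Y and
# Z would come from compute_YZ along a simulated trajectory.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, steps, particles, dim = 4, 10, 3, 3  # assumed problem sizes
    B = 0.1                                     # assumed noise scaling

    fbsde = FBSDE(grad=True, alpha=1, beta=1)

    # Random stand-ins for model outputs along a trajectory.
    Y = torch.randn(batch, steps, 1)
    Z = torch.randn(batch, steps, particles, dim)
    dW = torch.randn(batch, steps, particles, dim)

    # Integrate Y forward with Euler-Maruyama, starting from the first Y value.
    Y_star = fbsde.Y_run(Y[:, 0, :], Z, dW, B)

    # Terminal configuration and an assumed quadratic boundary condition.
    X_N = torch.randn(batch, particles, dim)

    def boundary(x):
        # g(X) = sum of squared coordinates, returned as shape (batch, 1).
        return (x ** 2).flatten(1).sum(dim=1, keepdim=True)

    loss, path_l, term_l, dterm_l = fbsde.FBSDE_Loss(
        Y, Y_star, Y_star[:, -1, :], Z[:, -1, :, :], X_N, boundary
    )
    print(loss.shape)  # torch.Size([4]): one loss value per batch element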