-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse_generic_orbit.py
200 lines (172 loc) · 6.93 KB
/
parse_generic_orbit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from sympy.parsing import mathematica
from sympy import cse, ccode, symbols
from sympy.utilities.iterables import numbered_symbols
from sympy.codegen.rewriting import create_expand_pow_optimization
from collections import deque
order = 1
with open(f"generic/Psi{order}acc.txt") as f:
psi_str = f.read()
with open(f"generic/dtPsi{order}acc.txt") as f:
dtpsi_str = f.read()
with open(f"generic/dxPsi{order}acc.txt") as f:
dxpsi_str = f.read()
with open(f"generic/dyPsi{order}acc.txt") as f:
dypsi_str = f.read()
with open(f"generic/dzPsi{order}acc.txt") as f:
dzpsi_str = f.read()
xp, yp, zp, xpdot, xpddot, ypdot, ypddot, zpdot, Dx, Dy, Dz, rp, rpdot, fx, fy, ft, fxdot, fydot, ftdot, Duft, Dufx, Dufy, Duftdot, Dufxdot, Dufydot, M, z, l = symbols("xp yp zp xpdot xpddot ypdot ypddot zpdot Dx Dy Dz rp rpdot fx fy ft fxdot fydot ftdot Duft Dufx Dufy Duftdot Dufxdot Dufydot M z l")
psi = mathematica.parse_mathematica(psi_str)
dtpsi = mathematica.parse_mathematica(dtpsi_str)
dxpsi = mathematica.parse_mathematica(dxpsi_str)
dypsi = mathematica.parse_mathematica(dypsi_str)
dzpsi = mathematica.parse_mathematica(dzpsi_str)
subexpr, res = cse([psi, dtpsi, dxpsi, dypsi, dzpsi], optimizations='basic')
cpp_file = """
// Distributed under the MIT License.
// See LICENSE.txt for details.
#include "Evolution/Systems/CurvedScalarWave/Worldtube/PunctureField.hpp"
#include "DataStructures/DataBox/Prefixes.hpp"
#include "DataStructures/DataVector.hpp"
#include "DataStructures/DynamicBuffer.hpp"
#include "DataStructures/Tags/TempTensor.hpp"
#include "DataStructures/Tensor/EagerMath/Magnitude.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables.hpp"
#include "Evolution/Systems/CurvedScalarWave/Tags.hpp"
#include "NumericalAlgorithms/LinearOperators/PartialDerivatives.hpp"
#include "Utilities/Gsl.hpp"
namespace CurvedScalarWave::Worldtube {{
void puncture_field_acc_0(
gsl::not_null<Variables<tmpl::list<
CurvedScalarWave::Tags::Psi, ::Tags::dt<CurvedScalarWave::Tags::Psi>,
::Tags::deriv<CurvedScalarWave::Tags::Psi, tmpl::size_t<3>,
Frame::Inertial>>>*>
result,
const tnsr::I<DataVector, 3, Frame::Inertial>& centered_coords,
const tnsr::I<double, 3>& particle_position,
const tnsr::I<double, 3>& particle_velocity,
const tnsr::I<double, 3>& particle_acceleration, const double BH_mass) {{
const size_t grid_size = get<0>(centered_coords).size();
result->initialize(grid_size);
const double xp = particle_position[0];
const double yp = particle_position[1];
const double xpdot = particle_velocity[0];
const double ypdot = particle_velocity[1];
const double xpddot = particle_acceleration[0];
const double ypddot = particle_acceleration[1];
const double rp = get(magnitude(particle_position));
const double rpdot = (xp * xpdot + yp * ypdot) / rp;
const auto& Dx = get<0>(centered_coords);
const auto& Dy = get<1>(centered_coords);
const auto& z = get<2>(centered_coords);
const double M = BH_mass;
DynamicBuffer<DataVector> temps({number_of_temps}, grid_size);
{aux_vars}
get(get<CurvedScalarWave::Tags::Psi>(*result)) = {sol_psi};
get(get<::Tags::dt<CurvedScalarWave::Tags::Psi>>(*result)) = {sol_dtpsi};
get<0>(get<::Tags::deriv<CurvedScalarWave::Tags::Psi, tmpl::size_t<3>, Frame::Inertial>>(
*result)) = {sol_dxpsi};
get<1>(get<::Tags::deriv<CurvedScalarWave::Tags::Psi, tmpl::size_t<3>, Frame::Inertial>>(
*result)) = {sol_dypsi};
get<2>(get<::Tags::deriv<CurvedScalarWave::Tags::Psi, tmpl::size_t<3>, Frame::Inertial>>(
*result)) = {sol_dzpsi};
}}
}}
"""
symbol_gen = numbered_symbols()
temp_symbols = {}
for i in range(10000):
temp_symbols[next(symbol_gen)] = i
temp_dict = {}
root_symbols = [Dx, Dy, Dz, xp, yp, xpdot, xpddot, ypdot, ypddot, rp, rpdot, ft, fx, fy, fxdot, fydot, ftdot, Duft, Dufx, Dufy, Duftdot, Dufxdot, Dufydot, z, M]
double_symbols = numbered_symbols("d_")
dv_symbols = numbered_symbols("dv_")
dvs = []
doubles = []
exprs = []
for i, (name, expr) in enumerate(subexpr):
old_expr = expr
for sym in expr.free_symbols:
if sym not in root_symbols:
expr = expr.subs(sym, exprs[temp_symbols[sym]])
new_symbol = None
if (Dx in expr.free_symbols or Dy in expr.free_symbols
or z in expr.free_symbols):
new_symbol = next(dv_symbols)
dvs.append((new_symbol, old_expr))
else:
new_symbol = next(double_symbols)
doubles.append((new_symbol, old_expr))
temp_dict[name] = new_symbol
exprs.append(expr)
print(f"{len(doubles)} double vars")
print(f"{len(dvs)} datavector vars")
expand_opt = create_expand_pow_optimization(10)
for i, (name, expr) in enumerate(doubles):
for sym in expr.free_symbols:
if sym in temp_dict:
expr = expr.subs(sym, temp_dict[sym])
doubles[i] = (name, expand_opt(expr))
for i, (name, expr) in enumerate(dvs):
for sym in expr.free_symbols:
if sym in temp_dict:
expr = expr.subs(sym, temp_dict[sym])
dvs[i] = (name, expand_opt(expr))
for i, expr in enumerate(res):
for sym in expr.free_symbols:
if sym in temp_dict:
expr = expr.subs(sym, temp_dict[sym])
res[i] = expand_opt(expr)
def check_last_usage(sym, exprs, res):
for _, expr in exprs:
if sym in expr.free_symbols:
return False
for expr in res:
if sym in expr.free_symbols:
return False
return True
free_storage = deque()
occ_storage = []
total_storage = 0
dv_symbols = [name for name, _ in dvs]
for i, (name, expr) in enumerate(dvs):
for sym in expr.free_symbols:
if sym in dv_symbols:
if check_last_usage(sym, dvs[i + 1:], res):
free_storage.append(sym)
if len(free_storage) > 0:
occ_storage.append(free_storage.pop())
else:
total_storage += 1
occ_storage.append(name)
print(f"{total_storage} allocations")
free_storage = deque(range(total_storage))
occ_storage = []
current_storage = {}
aux_dvs = ""
for i, (name, expr) in enumerate(dvs):
for sym in expr.free_symbols:
if sym in dv_symbols:
if check_last_usage(sym, dvs[i + 1:], res):
free_storage.appendleft(current_storage[sym])
index = free_storage.popleft()
current_storage[name] = index
code = f"DataVector& {name} = temps.at({index});\n{name} = {ccode(expr)};\n"
aux_dvs += code
aux_doubles = "\n".join(f"const double {name} = {ccode(expr)};"
for name, expr in doubles)
aux_vars = aux_doubles + "\n" + aux_dvs
cse_dict = {
"aux_vars": aux_vars,
"sol_psi": ccode(res[0]),
"sol_dtpsi": ccode(res[1]),
"sol_dxpsi": ccode(res[2]),
"sol_dypsi": ccode(res[3]),
"sol_dzpsi": ccode(res[4]),
"number_of_temps": total_storage
}
full_file_to_write = cpp_file.format(**cse_dict)
import re
full_file_to_write = re.sub(r"\s(\d)\s", r"\1\.0", full_file_to_write)
with open("PunctureField.cpp", "w") as f:
f.write(full_file_to_write)