-
Notifications
You must be signed in to change notification settings - Fork 3
/
jit.cpp
164 lines (127 loc) · 4.76 KB
/
jit.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#ifdef ENABLE_PYTROCH_JIT
#include "error.hpp"
#include "predictor.hpp"
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iosfwd>
#include <iostream>
#include <memory>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>
// Forward declaration of the torch::IValue -> Torch_IValue converter used by
// Torch_JITModuleMethodRun. Its definition is not part of the code visible here.
Torch_IValue Torch_ConvertIValueToTorchIValue(torch::IValue value);
// Heap-allocated wrapper that is passed through the C API as an opaque
// Torch_JITModuleContext (see Torch_LoadJITModule / Torch_DeleteJITModule).
struct Torch_JITModule {
  // Shared handle to the script module produced by torch::jit::compile or
  // torch::jit::load.
  std::shared_ptr<torch::jit::script::Module> module;
};
// Heap-allocated wrapper that is passed through the C API as an opaque
// Torch_JITModuleMethodContext (see Torch_JITModuleGetMethod).
struct Torch_JITModule_Method {
  // Reference to the method returned by Module::get_method. This struct does
  // not own the method.
  // NOTE(review): presumably the reference is only valid while the originating
  // module is alive - confirm against the libtorch version in use.
  torch::jit::script::Method& run;
};
// Compiles TorchScript source text into a module and returns it as an opaque
// context.
//
// cstring_script: NUL-terminated TorchScript source; must not be null.
// error:          not referenced directly in this body; the error macros are
//                 given Torch_GlobalError instead.
// Returns a heap-allocated Torch_JITModule (as void*) that the caller releases
// with Torch_DeleteJITModule, or the NULL fallback passed to
// END_HANDLE_TH_ERRORS.
// NOTE(review): HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS are defined outside
// this view (presumably error.hpp) and appear to wrap the body in exception
// handling - confirm.
Torch_JITModuleContext Torch_CompileTorchScript(char* cstring_script, Torch_Error* error) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  const std::string script(cstring_script);
  // Hold the wrapper in a unique_ptr so it is freed if compilation throws;
  // ownership is handed to the caller only on success.
  std::unique_ptr<Torch_JITModule> mod(new Torch_JITModule());
  mod->module = torch::jit::compile(script);
  return static_cast<void*>(mod.release());
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}
// Loads a serialized JIT module from disk and returns it as an opaque context.
//
// cstring_path: NUL-terminated path of the module file; must not be null.
// error:        not referenced directly in this body; the error macros are
//               given Torch_GlobalError instead.
// Returns a heap-allocated Torch_JITModule (as void*) that the caller releases
// with Torch_DeleteJITModule, or the NULL fallback passed to
// END_HANDLE_TH_ERRORS.
// NOTE(review): HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS are defined outside
// this view (presumably error.hpp) and appear to wrap the body in exception
// handling - confirm.
Torch_JITModuleContext Torch_LoadJITModule(char* cstring_path, Torch_Error* error) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  const std::string module_path(cstring_path);
  // Hold the wrapper in a unique_ptr so it is freed if loading throws;
  // ownership is handed to the caller only on success.
  std::unique_ptr<Torch_JITModule> mod(new Torch_JITModule());
  mod->module = torch::jit::load(module_path);
  return static_cast<void*>(mod.release());
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}
// Writes the module behind `ctx` to the file named by `cstring_path` by
// calling Module::save.
//
// ctx:          context previously returned by Torch_CompileTorchScript or
//               Torch_LoadJITModule.
// cstring_path: NUL-terminated destination path.
// error:        not referenced directly in this body; the error macros are
//               given Torch_GlobalError instead.
void Torch_ExportJITModule(Torch_JITModuleContext ctx, char* cstring_path, Torch_Error* error) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  const std::string destination(cstring_path);
  auto* wrapper = static_cast<Torch_JITModule*>(ctx);
  wrapper->module->save(destination);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, )
}
// Looks up a method by name on the module behind `ctx` and returns it wrapped
// in a heap-allocated Torch_JITModule_Method.
//
// ctx:            module context.
// cstring_method: NUL-terminated method name.
// error:          not referenced directly in this body; the error macros are
//                 given Torch_GlobalError instead.
// Returns a context to release with Torch_DeleteJITModuleMethod, or the NULL
// fallback passed to END_HANDLE_TH_ERRORS.
// NOTE(review): the wrapper stores a reference obtained from the module, so it
// presumably must not be used after the module is deleted - confirm.
Torch_JITModuleMethodContext Torch_JITModuleGetMethod(Torch_JITModuleContext ctx, char* cstring_method,
                                                      Torch_Error* error) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  const std::string wanted(cstring_method);
  auto* wrapper = static_cast<Torch_JITModule*>(ctx);
  torch::jit::script::Method& found = wrapper->module->get_method(wanted);
  return static_cast<void*>(new Torch_JITModule_Method{found});
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}
// Returns the names of all methods on the module behind `ctx`.
//
// ctx: module context.
// len: out-parameter; receives the number of entries in the returned array.
//      It is set to 0 before any work is done and only updated on success.
// Returns an array allocated with malloc whose elements are NUL-terminated
// strings allocated with new char[] (the same allocators as before, so
// existing callers' release code is unaffected). Returns NULL if the array
// cannot be allocated or if an error is caught by the error macros.
// NOTE(review): HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS are defined outside
// this view; they are used here exactly as in the sibling functions so that
// exceptions do not escape the C API - confirm their catch behaviour.
char** Torch_JITModuleGetMethodNames(Torch_JITModuleContext ctx, size_t* len) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  *len = 0;
  auto* mod = static_cast<Torch_JITModule*>(ctx);
  auto&& methods = mod->module->get_methods();

  // Stage the copies in owning pointers so nothing leaks if a later
  // allocation throws part-way through.
  std::vector<std::unique_ptr<char[]>> names;
  names.reserve(methods.size());
  for (auto& method : methods) {
    const auto& key = method.value()->name();
    std::unique_ptr<char[]> copy(new char[key.length() + 1]);
    std::memcpy(copy.get(), key.c_str(), key.length() + 1);
    names.push_back(std::move(copy));
  }

  auto** result = static_cast<char**>(std::malloc(sizeof(char*) * names.size()));
  if (result == nullptr) {
    // Allocation failure (or malloc(0) yielding NULL): report no entries; the
    // staged strings are freed automatically.
    return nullptr;
  }
  for (size_t i = 0; i < names.size(); ++i) {
    result[i] = names[i].release();
  }
  *len = names.size();
  return result;
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}
// Invokes the wrapped method with the given inputs and returns its result.
//
// ctx:        method context returned by Torch_JITModuleGetMethod.
// inputs:     array of `input_size` Torch_IValue entries; each is converted
//             with Torch_ConvertTorchIValueToIValue before the call.
// input_size: number of entries in `inputs`.
// error:      not referenced directly in this body; the error macros are
//             given Torch_GlobalError instead.
// Returns the method's result converted with Torch_ConvertIValueToTorchIValue,
// or the Torch_IValue{} fallback passed to END_HANDLE_TH_ERRORS.
Torch_IValue Torch_JITModuleMethodRun(Torch_JITModuleMethodContext ctx, Torch_IValue* inputs, size_t input_size,
                                      Torch_Error* error) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  std::vector<torch::IValue> inputs_vec;
  inputs_vec.reserve(input_size);
  // Index with size_t to match input_size; an int counter mixes signedness in
  // the comparison and overflows for very large counts.
  for (size_t i = 0; i < input_size; ++i) {
    inputs_vec.push_back(Torch_ConvertTorchIValueToIValue(inputs[i]));
  }
  // The vector is not used again, so hand it over instead of copying it.
  auto res = met->run(std::move(inputs_vec));
  return Torch_ConvertIValueToTorchIValue(res);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, Torch_IValue{})
}
namespace {

// Copies `text` into a NUL-terminated buffer allocated with new char[].
std::unique_ptr<char[]> Torch_DupSchemaString(const std::string& text) {
  std::unique_ptr<char[]> buffer(new char[text.length() + 1]);
  std::memcpy(buffer.get(), text.c_str(), text.length() + 1);
  return buffer;
}

// Converts a list of schema entries (arguments or returns) into a malloc'd
// array of Torch_ModuleMethodArgument, writing the entry count to *res_size.
// Returns NULL with *res_size == 0 when the array cannot be allocated.
template <typename SchemaEntryList>
Torch_ModuleMethodArgument* Torch_ExportSchemaEntries(const SchemaEntryList& entries, size_t* res_size) {
  *res_size = 0;
  const size_t count = entries.size();

  // Stage every string in an owning pointer so nothing leaks if a later
  // allocation throws part-way through.
  std::vector<std::unique_ptr<char[]>> names;
  std::vector<std::unique_ptr<char[]>> types;
  names.reserve(count);
  types.reserve(count);
  for (const auto& entry : entries) {
    names.push_back(Torch_DupSchemaString(entry.name()));
    types.push_back(Torch_DupSchemaString(entry.type()->str()));
  }

  auto* result =
      static_cast<Torch_ModuleMethodArgument*>(std::malloc(sizeof(Torch_ModuleMethodArgument) * count));
  if (result == nullptr) {
    // Allocation failure (or malloc(0) yielding NULL): report no entries; the
    // staged strings are freed automatically.
    return nullptr;
  }
  for (size_t i = 0; i < count; ++i) {
    // Value-initialise, then assign by field name: this avoids designated
    // initialisers (not standard before C++20) while still zeroing any other
    // members.
    Torch_ModuleMethodArgument item{};
    item.name = names[i].release();
    item.typ = types[i].release();
    result[i] = item;
  }
  *res_size = count;
  return result;
}

}  // namespace

// Describes the arguments of the wrapped method, as reported by its schema.
//
// ctx:      method context returned by Torch_JITModuleGetMethod.
// res_size: out-parameter; receives the number of entries returned (0 on
//           failure).
// Returns an array allocated with malloc; each entry's `name` and `typ` are
// NUL-terminated strings allocated with new char[] (same allocators as
// before). Returns NULL if allocation fails or an error is caught by the
// error macros.
// NOTE(review): HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS are defined outside
// this view; they are used here as in the sibling functions so exceptions do
// not escape the C API - confirm their catch behaviour.
Torch_ModuleMethodArgument* Torch_JITModuleMethodArguments(Torch_JITModuleMethodContext ctx, size_t* res_size) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  *res_size = 0;
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  const auto& schema = met->run.getSchema();
  return Torch_ExportSchemaEntries(schema.arguments(), res_size);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}

// Describes the return values of the wrapped method, as reported by its schema.
//
// ctx:      method context returned by Torch_JITModuleGetMethod.
// res_size: out-parameter; receives the number of entries returned (0 on
//           failure).
// Returns an array with the same layout and allocation scheme as
// Torch_JITModuleMethodArguments, or NULL on failure.
Torch_ModuleMethodArgument* Torch_JITModuleMethodReturns(Torch_JITModuleMethodContext ctx, size_t* res_size) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  *res_size = 0;
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  const auto& schema = met->run.getSchema();
  return Torch_ExportSchemaEntries(schema.returns(), res_size);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, NULL)
}
// Frees a method context created by Torch_JITModuleGetMethod. Only the wrapper
// is deleted; the referenced method itself is not owned by it.
void Torch_DeleteJITModuleMethod(Torch_JITModuleMethodContext ctx) {
  delete static_cast<Torch_JITModule_Method*>(ctx);
}
// Frees a module context created by Torch_CompileTorchScript or
// Torch_LoadJITModule, dropping the wrapper's shared_ptr to the module.
// NOTE(review): method contexts obtained from this module hold references
// that presumably dangle afterwards - confirm and delete those first.
void Torch_DeleteJITModule(Torch_JITModuleContext ctx) {
  delete static_cast<Torch_JITModule*>(ctx);
}
#endif // ENABLE_PYTROCH_JIT