-
Notifications
You must be signed in to change notification settings - Fork 4
/
setup.py
152 lines (123 loc) · 5 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from os import environ, makedirs
from pathlib import Path
from shutil import rmtree
from subprocess import check_output
from sys import executable
from packaging.version import parse as parse_version
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as build_ext_base
from setuptools_scm import Version, dump_version, get_version
class CMakeExtension(Extension):
    """Extension whose artifacts are produced by CMake, not by setuptools.

    It is registered with an empty source list; the custom ``build_ext``
    command recognizes instances of this class and builds them via CMake.
    """

    def __init__(self, name):
        # `sources` is the second positional parameter of Extension; CMake
        # owns the real sources, so none are handed to setuptools.
        super().__init__(name, [])
# We need to name commands like this since introspection is used to print the
# usage message.
class build_ext(build_ext_base):
    """``build_ext`` command that also builds ``CMakeExtension`` modules.

    Ordinary extensions are handled by the stock setuptools command. Every
    ``CMakeExtension`` is then configured, built and installed by invoking
    the ``cmake`` executable (``cmake -S/-B``, ``cmake --build``,
    ``cmake --install``).
    """

    user_options = build_ext_base.user_options + [
        ('cuda', None, "build with CUDA support"),
        ('cuda-arch=', None,
         "list of CUDA architectures to generate PTX code"),
        ('cmake-prefix-path=', None,
         "semicolon-separated list of directories specifying search prefixes"),
        ('cmake-generator=', None, "supply CMake generator"),
        ('cmake-options=', None, "supply auxiliary CMake arguments"),
    ]

    boolean_options = build_ext_base.boolean_options + ['cuda']

    def initialize_options(self):
        super().initialize_options()
        self.cmake_generator = None
        self.cmake_options = None
        # When left empty, it is filled from Torch in build_extension().
        self.cmake_prefix_path = None
        self.cuda = False
        # Passed (capitalized) as TORCH_CUDA_ARCH_LIST when CUDA is enabled.
        self.cuda_arch = 'common'

    def run(self):
        """Build plain extensions via setuptools, then CMake extensions."""
        cmake_extensions = []
        rest_extensions = []
        for ext in self.extensions:
            if isinstance(ext, CMakeExtension):
                cmake_extensions.append(ext)
            else:
                rest_extensions.append(ext)
        # Hide CMake extensions from the base command: setuptools would
        # otherwise try to compile them (and it forces `inplace` off while
        # building, which the CMake install step relies on).
        self.extensions = rest_extensions
        super().run()
        for ext in cmake_extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        """Configure, build and install ``ext`` with CMake.

        Extensions that are not ``CMakeExtension`` instances (reached through
        the base class' ``build_extensions``) are delegated to setuptools.
        """
        if not isinstance(ext, CMakeExtension):
            return super().build_extension(ext)

        build_dir = self.build_temp
        install_dir = self.build_lib
        source_dir = '.'
        if self.inplace:
            # Let `cmake --install` symlink (or copy) artifacts directly into
            # the source tree for in-place/development builds.
            environ['CMAKE_INSTALL_MODE'] = 'SYMLINK_OR_COPY'
            install_dir = Path.cwd()

        # NOTE: CMake's configuration is spelled "RelWithDebInfo"; an unknown
        # name would silently build without optimization/debug flags.
        build_type = 'Debug' if self.debug else 'RelWithDebInfo'

        # Do nothing on dry run.
        if self.dry_run:
            return

        # If we are forced to rebuild, drop the old build tree (if any) and
        # run as usual.
        if self.force and Path(build_dir).exists():
            rmtree(build_dir)
        makedirs(build_dir, exist_ok=True)

        # Obtain CMake prefix path to PyTorch scripts.
        if not self.cmake_prefix_path:
            self.cmake_prefix_path = get_torch_cmake_prefix_path()

        # Generate project build system. Arguments are assembled as lists
        # (not split strings) so paths containing spaces stay intact.
        cmd = ['cmake', '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
               '-S', source_dir, '-B', build_dir,
               f'-DCMAKE_BUILD_TYPE={build_type}']
        if self.cmake_generator:
            cmd.extend(('-G', self.cmake_generator))
        if self.cmake_prefix_path:
            cmd.append(f'-DCMAKE_PREFIX_PATH={self.cmake_prefix_path}')
        if self.cuda:
            cmd.append('-DUSE_CUDA=ON')
            if self.cuda_arch:
                cmd.append('-DTORCH_CUDA_ARCH_LIST='
                           f'{self.cuda_arch.capitalize()}')
        if self.cmake_options:
            cmd.extend(self.cmake_options.split())
        self.spawn(cmd)

        # Build project. `--config` selects the build type for multi-config
        # generators and is ignored by single-config ones.
        cmd = ['cmake', '--build', build_dir, '--config', build_type]
        if self.parallel:
            cmd.extend(('-j', str(self.parallel)))
        self.spawn(cmd)

        # Install project to build (or source, if in-place) directory.
        cmd = ['cmake', '--install', build_dir, '--config', build_type,
               '--prefix', str(install_dir)]
        self.spawn(cmd)
def get_torch_attr(script):
    """Execute ``script`` in a fresh interpreter and return its stripped stdout.

    Torch is queried out-of-process instead of being imported here, so the
    build process does not pay the 200+ MB memory cost of ``import torch``.
    Raises ``subprocess.CalledProcessError`` if the script exits non-zero and
    ``subprocess.TimeoutExpired`` if it runs longer than 60 seconds.
    """
    stdout = check_output([executable, '-c', script],
                          encoding='utf-8', timeout=60)
    return stdout.strip()
def get_torch_cmake_prefix_path():
    """Return ``torch.utils.cmake_prefix_path``, queried in a subprocess."""
    return get_torch_attr('import torch.utils; '
                          'print(torch.utils.cmake_prefix_path)')
def get_torch_version():
    """Return ``torch.version.__version__``, queried in a subprocess."""
    return get_torch_attr('import torch as T; print(T.version.__version__)')
# Resolve the FewBit version from SCM metadata (falling back to 0.0.0 when no
# metadata is available) and the version of the Torch we build against.
try:
    fewbit_version = parse_version(get_version())
except LookupError:
    fewbit_version = Version('0.0.0')
torch_version = parse_version(get_torch_version())

# Published version: <fewbit-base>+[<torch-local>.]pt<torch-base>; e.g. Torch
# "2.0.1+cu117" yields "<fewbit-base>+cu117.pt2.0.1". `base_version` keeps only
# the release segment, so pre/dev/post parts of the FewBit version are dropped.
# NOTE(review): a setuptools_scm dev version therefore collapses to its bare
# release number — confirm this is intended.
torch_tag = f'pt{torch_version.base_version}'
if torch_version.local:
    torch_tag = f'{torch_version.local}.{torch_tag}'
version = f'{fewbit_version.base_version}+{torch_tag}'

# Write the computed FewBit version to fewbit/version.py.
dump_version('.', version, 'fewbit/version.py')

# Pin Torch exactly to keep the extension compatible with Torch itself as well
# as with its CUDA ABI.
install_requires = ['numpy', f'torch=={torch_version.base_version}']

setup(
    name='fewbit',
    version=version,
    install_requires=install_requires,
    ext_modules=[CMakeExtension('fewbit.fewbit')],
    cmdclass={'build_ext': build_ext},
)