-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen.rkt
344 lines (284 loc) · 10.9 KB
/
gen.rkt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
#lang rosette
(require
"util.rkt"
srfi/1
racket/match)
(provide (all-defined-out))
;; gen : stmt diff init-pos -> list of statements
;;
;; Top-level driver.  Collects shuffle candidates from STMT, prunes them with
;; `optimize-shuffle`, reports statistics on the error port, and returns a
;; new statement list made of:
;;   1. one `.reg` declaration per distinct 'src entry (register type taken
;;      from index 2 of the entry, name "%__pw_<index 0 of the entry>"),
;;   2. a fixed prologue that declares "%warp_id" and sets it to
;;      "%tid.x" rem 32,
;;   3. STMT with the shuffle sequences implanted.
;;
;; INIT-POS is accepted but not consulted here: the `put-init` pass that
;; used it is currently disabled.
(define (gen stmt diff init-pos)
  (define candidates (contrive-shuffle stmt diff)) ; may still contain redundant shuffles
  (define shuffles (optimize-shuffle candidates diff))
  (print-shuffle-info shuffles diff)

  (define (src-entry? entry) (eq? (cadr entry) 'src))
  (define (same-source? a b) (eq? (car a) (car b)))
  (define (declare-holder entry)
    `(storageN ".reg" ,(~ entry 2) ,(format "%__pw_~a" (~ entry 0))))

  ;; A source that feeds several destinations appears once per destination;
  ;; its holding register must only be declared once.
  (define holders
    (map declare-holder
         (delete-duplicates (filter src-entry? shuffles) same-source?)))

  ;; Disabled pass, kept for reference:
  ;;   (define stmt-inited (put-init stmt init-pos src))

  (define lane-prologue
    '((storageN ".reg" ".u32" "%warp_id")
      (instruction "mov" ".u32" "%warp_id" "%tid.x")
      (instruction "rem" ".u32" "%warp_id" "%warp_id" "32")))

  (append holders
          lane-prologue
          (implant-shuffle stmt shuffles)))
;; print-shuffle-info : shfl diff -> void
;;
;; Writes a single summary line to (current-error-port):
;;   LOAD           - number of entries in DIFF
;;   SHUFFLE        - number of 'dst entries in SHFL
;;   AVERAGE        - mean of |field 2| over the 'dst entries, taken after
;;                    dropping each entry's two leading elements, printed
;;                    with exactly two decimals; #f when there are none
;;   POSSIBLE-BIDIR - 'dst entries whose remaining field list is not of
;;                    length 3
(define (print-shuffle-info shfl diff)
  (define (dst-entry? entry) (eq? (cadr entry) 'dst))
  (define (one-way? fields) (= (length fields) 3))
  (let* ([n-loads     (length diff)]
         [payloads    (map cddr (filter dst-entry? shfl))]
         [n-one-way   (length (filter one-way? payloads))]
         [total-width (foldl (lambda (fields acc) (+ acc (abs (~ fields 2))))
                             0
                             payloads)]
         [n-shuffles  (length payloads)]
         [average     (if (zero? n-shuffles)
                          #f
                          (~r (/ total-width n-shuffles) #:precision '(= 2)))])
    (displayln
     (format "LOAD:~a SHUFFLE:~a AVERAGE:~a POSSIBLE-BIDIR:~a"
             n-loads
             n-shuffles
             average
             (- n-shuffles n-one-way))
     (current-error-port))))
;; init-val : type -> string or #f
;;
;; Maps a type specifier string (".u32", ".s64", ".f32", ".pred", ...) to the
;; literal used to initialise a register of that type.  Every literal is the
;; all-ones bit pattern for the type's width ("-1" for the unsigned types).
;; Returns #f for anything that is not a recognised type string, including
;; non-string values such as #f.
(define (init-val type)
  (match type
    ;; any type whose name starts with ".u"
    [(regexp "^\\.[u]") "-1"]
    [(or ".s8" ".b8" ".f8") "0xff"]
    [(or ".s16" ".b16") "0xffff"]
    [(or ".s32" ".b32" ".f16x2") "0xffffffff"]
    ;; Was (or ".b64" ".b64"): the duplicated alternative left ".s64" without
    ;; an initial value, unlike the 8/16/32-bit signed types above.
    [(or ".s64" ".b64") "0xffffffffffffffff"]
    [".f32" "0fffffffff"]
    [".f64" "0dffffffffffffffff"]
    [".pred" "0b1"]
    [_ #f]))
;; put-init : s init-pos src -> s'
;;
;; Places initialisation on top of the innermost loops (loads outside loops
;; need no initialisation).
;;
;; Walks the statement tree S.  For every `instruction` node whose
;; eq-hash-code equals the cdr of one or more INIT-POS pairs, and whose
;; paired car is found in SRC via `assoc-ref`, the node is replaced by a
;; `group` holding one "mov" of `(init-val type)` into "%__pw_<car>" per such
;; pair, followed by the original instruction.  Everything else is rebuilt
;; structurally; ".file"/".loc" debugging nodes and non-list values are
;; returned as they are.
(define (put-init s init-pos src)
  (define (recur node) (put-init node init-pos src))

  ;; entry is (hash . type)
  (define (init-insn entry)
    (let ([holder-hash (car entry)]
          [type (cdr entry)])
      `(instruction "mov" ,type ,(format "%__pw_~a" holder-hash)
                    ,(init-val type))))

  (match s
    [(list 'group _ ...)
     `(group ,@(recur (cdr s)))]
    [(list 'instruction _ ...)
     (define h (eq-hash-code s))
     (define pending
       (filter-map
        (lambda (pos)
          (and (eq? h (cdr pos))
               (and-let1 info (assoc-ref src (car pos))
                 (cons (car pos) (~ info 1)))))
        init-pos))
     (cond
       [(null? pending) s]
       [else `(group ,@(map init-insn pending) ,s)])]
    [(list 'debugging _ ...)
     (match (~ s 1)
       [".section"
        (list 'debugging ".section" (~ s 2) (recur (~ s 3)))]
       [(or ".file" ".loc") s])]
    [(list _ ...)
     (map recur s)]
    [_ s]))
;; optimize-shuffle : orig diff -> list of 'src / 'dst entries
;;
;; Avoids redundant shuffles (a shuffle fed by another shuffle) and looks for
;; bi-directional shuffles.  Each element of ORIG is refined by
;; `optimize-shuffle1`, which also receives the cars of all ORIG elements;
;; the per-element results are concatenated in order.
(define (optimize-shuffle orig diff)
  (define candidate-hashes (map car orig))
  (define (refine candidate)
    (optimize-shuffle1 (~ candidate 0)
                       (~ candidate 1)
                       candidate-hashes
                       diff))
  (apply append (map refine orig)))
;; optimize-shuffle1 : h i r diff -> list of entries
;;
;;   h    - hash of the instruction being considered as a shuffle destination
;;   i    - the instruction itself
;;   r    - hashes of all shuffle candidates
;;   diff - association from an instruction hash to (hash . width) pairs
;;
;; Picks, among the non-#f-width pairs recorded for H and in order of
;; increasing |width|, the first pair whose hash
;;   * is not a member of R (so a shuffle never reads from another shuffle),
;;   * has one single distinct pair recorded for it under H.
;; When that choice has a non-zero width, a partner of opposite sign with
;; |w1| + |w2| <= 32 is also looked up among the surviving pairs.
;;
;; Result: '() when nothing qualifies, otherwise
;;   (src-hash 'src type h width)                        ; for the choice
;;   (partner-hash 'src type h width)                    ; only with a partner
;;   (h 'dst type src-hash width [partner-hash width])
(define (optimize-shuffle1 h i r diff)
  (define deltas (assoc-ref diff h))
  (define (smaller-magnitude? a b) (< (abs a) (abs b)))
  (define ranked
    (and deltas
         (sort (filter cdr deltas) smaller-magnitude? #:key cdr)))

  ;; every pair recorded for the same hash as E, #f widths included
  (define (peers e)
    (filter (lambda (other) (= (car e) (car other))) deltas))
  (define (consistent? group)
    (null? (cdr (delete-duplicates group))))
  (define (fresh-source? group)
    (not (member (caar group) r)))
  (define (usable? e)
    (let ([group (peers e)])
      (and (fresh-source? group) (consistent? group))))
  ;; Dropping `fresh-source?` from `usable?` would allow bi-directional use.

  (define viable (and (pair? ranked) (filter usable? ranked)))
  (define best (and (pair? viable) (car viable)))

  ;; last "."-prefixed specifier of the instruction
  (define type
    (and best
         (match (cdr i)
           [(list opcode (and spec (regexp #rx"^[.]")) ... oprand ...)
            (last spec)])))

  ;; bi-directional partner
  (define (complements-best? e)
    (let ([we (cdr e)]
          [wc (cdr best)])
      (and (not (and (> we 0) (> wc 0)))
           (not (and (< we 0) (< wc 0)))
           (<= (+ (abs we) (abs wc)) 32))))
  (define partner
    (and best
         (not (= (cdr best) 0))
         (findf complements-best? viable)))

  ;; hash 'src type dst-hash width
  (define (src-entry p)
    (list (car p) 'src type h (cdr p)))
  (define sources
    (append (if best (list (src-entry best)) '())
            (if partner (list (src-entry partner)) '())))

  ;; hash 'dst type src-hash width [bi-directional-hash width]
  (define destinations
    (if best
        (list (append (list h 'dst type (car best) (cdr best))
                      (if partner
                          (list (car partner) (cdr partner))
                          '())))
        '()))

  (append sources destinations))
;; contrive-shuffle : stmt diff -> list of (hash instruction)
;;
;; Runs `contrive1` over every statement of STMT and concatenates the
;; resulting candidate lists, preserving statement order.
(define (contrive-shuffle stmt diff)
  (apply append
         (map (lambda (s) (contrive1 s diff))
              stmt)))
;; contrive1 : s diff -> list of (hash instruction)
;;
;; Collects shuffle candidates from a single statement S:
;;   (instruction ...)            -> whatever `contrive-insn` yields
;;   (group stmt ...)             -> candidates of the contained statements
;;   (debugging ".section" _ x)   -> candidates of element 3
;;   (debugging ".file"/".loc" …) -> none
;;   anything else                -> none
(define (contrive1 s diff)
  (match s
    [(list 'instruction _ ...)
     (contrive-insn s diff)]
    [(list 'group _ ...)
     (contrive-shuffle (cdr s) diff)]
    [(list 'debugging _ ...)
     (match (~ s 1)
       [".section" (contrive1 (~ s 3) diff)]
       [".file" '()]
       [".loc" '()])]
    [_ '()]))
;; contrive-insn : i diff -> '() or (list (list hash i))
;;
;; Decides whether instruction I is a shuffle candidate.  DIFF is looked up
;; with I's eq-hash-code; the lookup yields (hash . width) pairs, or #f when
;; nothing is recorded.  I is a candidate when at least one pair with a
;; non-#f width has a consistent shuffle size, i.e. all pairs recorded for
;; that pair's hash are identical.
;;
;; Returns a one-element list `((h i))` for a candidate and '() otherwise.
;; Choosing the actual source and width is left to `optimize-shuffle1`.
;;
;; Changes from the previous version:
;;   * the unused local `type` is gone; it was never read, yet evaluating it
;;     could raise for an instruction without any "."-prefixed specifier;
;;   * the pairs are no longer sorted by |width|: only the existence of a
;;     consistent pair affects the result, so the ordering was irrelevant.
(define (contrive-insn i diff)
  (define h (eq-hash-code i))
  (define deltas (assoc-ref diff h))

  ;; all pairs sharing E's hash, #f widths included
  (define (peers e)
    (filter (lambda (other) (= (car e) (car other))) deltas))
  ;; true when the shuffle size recorded for E's hash is consistent
  (define (consistent? e)
    (null? (cdr (delete-duplicates (peers e)))))

  (define candidate
    (and deltas
         (findf consistent? (filter cdr deltas))))

  (if candidate
      (list (list h i))
      '()))
;; implant-shuffle : stmt shfl -> stmt'
;;
;; Rewrites every statement of STMT with `implant1`, keeping their order.
(define (implant-shuffle stmt shfl)
  (map (lambda (s) (implant1 s shfl))
       stmt))
;; implant1 : s shfl -> s'
;;
;; Rewrites a single statement S according to the shuffle table SHFL:
;;   (instruction ...)            -> result of `implant-insn`
;;   (group stmt ...)             -> group of the rewritten statements
;;   (debugging ".section" n x)   -> same node with element 3 rewritten
;;   (debugging ".file"/".loc" …) -> unchanged
;;   anything else                -> unchanged
(define (implant1 s shfl)
  (match s
    [(list 'instruction _ ...)
     (implant-insn s shfl)]
    [(list 'group _ ...)
     (cons 'group (implant-shuffle (cdr s) shfl))]
    [(list 'debugging _ ...)
     (match (~ s 1)
       [".section"
        (list 'debugging ".section" (~ s 2) (implant1 (~ s 3) shfl))]
       [(or ".file" ".loc") s])]
    [_ s]))
;; implant-insn : i shfl -> statement
;;
;; Rewrites one instruction I according to the shuffle table SHFL.
;; SHFL entries whose first element equals I's eq-hash-code are selected and
;; looked up by their tag:
;;   'src - I produces a value that other instructions obtain by shuffle;
;;          a "mov" copying I's first operand into "%__pw_<hash>" is added
;;          after I.
;;   'dst - I's value can be obtained by shuffle; I is replaced by the
;;          sequence `di` built below (a plain "mov" when the width is 0).
;; Returns I itself when neither tag is present, otherwise a `group`.
;;
;; NOTE(review): the field positions read from `s` and `d` below assume that
;; `assoc-ref` (util.rkt) returns the elements that follow the tag - confirm.
(define (implant-insn i shfl)
(define h (eq-hash-code i))
;; entries recorded for this instruction, with the leading hash removed
(define t (map cdr (filter (lambda (s) (= h (car s))) shfl)))
(define s (assoc-ref t 'src))
(define d (assoc-ref t 'dst))
;; r: first operand of I, i.e. the first element after the opcode and its
;; run of "."-prefixed specifiers.  Only computed when I takes part in a
;; shuffle.
(define r (and (or s d)
(match (cdr i)
[(list opcode (and spec (regexp #rx"^[.]")) ... oprand ...)
(car oprand)])))
;; si: copy of I's first operand into this instruction's own %__pw_ register
(define si (and s `(instruction "mov" ,(~ s 0) ,(format "%__pw_~a" h) ,r)))
(define type (and d (~ d 0)))
;; temp: the %__pw_ register named after element 1 of the 'dst record
(define temp (and d (format "%__pw_~a" (~ d 1) )))
(define width (and d (~ d 2)))
;; NOTE(review): `init` is not referenced anywhere else in this function.
(define init (init-val type))
;; Names of the per-instruction scratch registers.
(define pred (format "%__pw_~a_p" h))
(define pred2 (format "%__pw_~a_p2" h))
(define pred3 (format "%__pw_~a_p3" h))
(define mask (format "%__pw_~a_m" h))
(define tmp0 (format "%__pw_~a_t0" h))
(define tmp1 (format "%__pw_~a_t1" h))
;; "%warp_id" is declared and computed by the prologue emitted in `gen`.
(define laneidx "%warp_id")
;; diN: replacement sequence for a non-zero shuffle width.
;;   - declares the scratch registers (tmp1 is declared but not used by any
;;     instruction emitted here),
;;   - pred  <- activemask != -1,
;;   - pred2 <- laneidx < |width| for a negative width, otherwise
;;              laneidx > 31 - |width|,
;;   - pred3 <- pred or pred2,
;;   - a "shfl" ".sync" (".up" for a negative width, ".down" otherwise)
;;     whose destination operand is the string "<r>|<pred>",
;;   - the bare string "@<pred3>" followed by the original instruction I
;;     (presumably rendered as a guard on that instruction - confirm with
;;     the printer).
;; When the environment variable PW_NOLOAD is set the shuffle is replaced by
;; a self-"mov" of tmp0; when PW_NOLOAD or PW_NOCORNER is set the original
;; instruction is replaced by the same self-"mov".
(define diN
(and d
(list
`(storageN ".reg" ".pred" ,pred)
`(storageN ".reg" ".pred" ,pred2)
`(storageN ".reg" ".pred" ,pred3)
`(storageN ".reg" ".u32" ,mask)
`(storageN ".reg" ,type ,tmp0)
`(storageN ".reg" ,type ,tmp1)
`(instruction "activemask" ".b32" ,mask)
`(instruction "setp" ".ne" ".s32" ,pred ,mask "-1")
(if (< width 0)
`(instruction "setp" ".lt" ".u32"
,pred2 ,laneidx
,(number->string (abs width)))
`(instruction "setp" ".gt" ".u32"
,pred2 ,laneidx
,(number->string (- 31 (abs width)))))
`(instruction "or" ".pred" ,pred3 ,pred ,pred2)
(if (getenv "PW_NOLOAD")
`(instruction "mov" ".b32" ,tmp0 ,tmp0)
`(instruction
"shfl" ".sync"
,(if (< width 0) ".up" ".down") ".b32"
,(format "~a|~a" r pred)
,temp ,(number->string (abs width))
,(if (< width 0) "0" "31") ,mask))
(format "@~a" pred3)
(if (or (getenv "PW_NOLOAD") (getenv "PW_NOCORNER"))
`(instruction "mov" ".b32" ,tmp0 ,tmp0)
i)
)))
;; ;; Bi-directional (No such case exists if redundant shuffle is eliminated)
;; (define b (and d (= (length d) 5)))
;; (define b-temp (and b (format "%__pw_~a" (~ d 3) )))
;; (define b-width (and b (~ d 4)))
;; ;; Too much overhead with the corner case
;; (define b-diN
;; (and b
;; (list
;; `(storageN ".reg" ".pred" ,pred)
;; `(storageN ".reg" ".pred" ,pred2)
;; `(storageN ".reg" ".pred" ,pred3)
;; `(storageN ".reg" ".u32" ,mask)
;; `(storageN ".reg" ,type ,tmp0)
;; `(storageN ".reg" ,type ,tmp1)
;; `(storageN ".reg" ".u32" ,laneidx)
;; `(instruction "mov" ".u32" ,laneidx "%tid.x")
;; `(instruction "rem" ".u32" ,laneidx ,laneidx "32")
;; `(instruction "activemask" ".b32" ,mask)
;; `(instruction "setp" ".ne" ".s32" ,pred ,mask "-1")
;; (if (< width 0)
;; `(instruction "setp" ".lt" ".u32"
;; ,pred2 ,laneidx
;; ,(number->string (abs width)))
;; `(instruction "setp" ".gt" ".u32"
;; ,pred2 ,laneidx
;; ,(number->string (- 31 (abs width)))))
;; `(instruction "or" ".pred" ,pred3 ,pred ,pred2)
;; `(instruction
;; "shfl" ".sync"
;; ,(if (< width 0) ".up" ".down") ".b32"
;; ,(format "~a|~a" r pred)
;; ,temp ,(number->string (abs width))
;; ,(if (< width 0) "0" "31") ,mask)
;; `(instruction
;; "shfl" ".sync"
;; ,(if (< b-width 0) ".up" ".down") ".b32"
;; ,(format "~a|~a" tmp0 pred)
;; ,b-temp ,(number->string (abs b-width))
;; ,(if (< b-width 0) "0" "31") ,mask)
;; (format "@~a" pred2)
;; `(instruction "mov" ,type ,r ,tmp0)
;; (format "@~a" pred) i
;; )))
;; di0: width 0 - the value is already in `temp`, so a plain "mov" suffices.
(define di0 (and d (list `(instruction "mov" ,type ,r ,temp))))
(define di (and d (if (= width 0) di0 diN)))
;; (define di (cond [(not d) #f] [(= width 0) di0] [b b-diN] [else diN]))
;;
;; // If src found:
;; %__pw_HASH <- dst
;;
;; // If dst found:
;; dst|p = __shfl_(up|down)_sync(__activemask(), %__pw_HASH, Shuffle_Width)
;; if (!p || dst == ~0)
;; dst = LD
;;
;; Assemble the result: the 'dst sequence (which already contains I where
;; needed) comes first, then the copy into this instruction's own register.
(cond [(not (or s d)) i]
[(and s d)
`(group ,@di ,si)]
[s `(group ,i ,si)]
[d `(group ,@di)]))
;; Test submodule.  It currently holds a single placeholder case that does
;; not exercise any definition of this file.
(module+ test
  (require rackunit rackunit/text-ui)

  (define gen-suite
    (test-suite "gen"
      (test-case "gen-???"
        (check-equal? (values #t) #t))))

  (run-tests gen-suite))