Thursday 24 July 2014

how ffast can a fmadd fm & add?

Someone asked on the Parallella forums how to get that one-fused-multiply-add-per-cycle thing to work. I'm pretty sure it's impossible to get it out of the compiler as it stands right now, so I didn't even try, but I had a look at doing it in assembly language. I was going to post this there, but I remembered the forum doesn't preserve pre-formatting for code blocks, and it's kind of interesting anyway.

The basic technique is straightforward: double-word loads must be used to load every floating-point value, otherwise there are too many IALU ops (each fmadd needs two operands, but only one load can issue alongside each FPU instruction). Once that is established, one just needs to unroll the calculation enough times that the loop holds enough independent work to remove all dependency stalls.
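To make the unrolling part concrete, here is a minimal C sketch (not the author's code, and the name `dot8_acc` is hypothetical) of a dot product split across eight independent accumulators, the same role r16..r23 play in the listing below. Each accumulator only depends on its own previous value, so consecutive multiply-adds don't wait on each other's results. It assumes the element count is a non-zero multiple of 8; on the Epiphany each pair `a[i+k], a[i+k+1]` is what one `ldrd` would fetch.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: dot product using 8 independent partial sums.
 * Assumes n is a non-zero multiple of 8. */
float dot8_acc(const float *a, const float *b, size_t n)
{
    float acc[8] = {0};  /* like r16..r23, zeroed up front */

    assert(n > 0 && n % 8 == 0);
    for (size_t i = 0; i < n; i += 8) {
        /* acc[k] only ever depends on acc[k], so the 8 updates form
         * 8 separate dependency chains */
        for (int k = 0; k < 8; k++)
            acc[k] += a[i + k] * b[i + k];
    }

    /* combine the partial sums (a simple left-to-right sum here) */
    float sum = 0.0f;
    for (int k = 0; k < 8; k++)
        sum += acc[k];
    return sum;
}
```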

The details are important though: my first cut didn't delay the fmadds enough, but ezetime showed this very obviously, so it was easy enough to fix.

Actually it's not that straightforward: the inner loop itself also needs to be pipelined. Not only is it unrolled 8 times, the 8 steps have been split into two stages separated in time by half a loop each, so it's "effectively" been unrolled 16x. In fact it's a bit better than that, because no amount of loop unrolling could hide the data loads completely if each iteration were independent. In this case it just needs to perform 0.75 of a loop incoming (all the loads and half the flops) and 0.25 of a loop outgoing (the remaining half of the flops) outside the loop to prepare/complete the calculation, so the loop count is set to one less than required.
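The schedule described above can be modelled in C as a rough sketch. It is not what the compiler or hardware does; it only reproduces the same ordering of loads and multiply-adds as the listing. `va`/`vb` stand in for r48..r55 and r56..r63, `acc` for r16..r23, and the function name `fmadd_pipelined` is hypothetical. The prologue is the "0.75 loop", the loop body runs `len8s1` times, and the epilogue is the "0.25 loop".

```c
#include <assert.h>

/* Hypothetical C model of the software-pipelined schedule.
 * len8s1 == element count / 8 - 1, as passed to the asm routine. */
float fmadd_pipelined(const float *a, const float *b, int len8s1)
{
    float acc[8] = {0};  /* r16..r23 */
    float va[8], vb[8];  /* stand-ins for r48..r55 and r56..r63 */

    assert(len8s1 >= 0);

    /* prologue (~0.75 loop): load a whole block, do the first 4 fmadds */
    for (int k = 0; k < 8; k++) { va[k] = a[k]; vb[k] = b[k]; }
    for (int k = 0; k < 4; k++) acc[k] += va[k] * vb[k];

    /* hardware loop body: runs len8s1 times */
    for (int it = 1; it <= len8s1; it++) {
        const float *na = a + 8 * it, *nb = b + 8 * it;
        /* first half: load next block's elements 0-3 while finishing
         * elements 4-7 of the current block */
        for (int k = 0; k < 4; k++) {
            va[k] = na[k]; vb[k] = nb[k];
            acc[k + 4] += va[k + 4] * vb[k + 4];
        }
        /* second half: load next block's elements 4-7 while starting
         * on its elements 0-3 */
        for (int k = 0; k < 4; k++) {
            va[k + 4] = na[k + 4]; vb[k + 4] = nb[k + 4];
            acc[k] += va[k] * vb[k];
        }
    }

    /* epilogue (~0.25 loop): the last 4 fmadds of the final block */
    for (int k = 4; k < 8; k++) acc[k] += va[k] * vb[k];

    float sum = 0.0f;
    for (int k = 0; k < 8; k++) sum += acc[k];
    return sum;
}
```

Because every block's loads happen half a loop before they are consumed, the loop body never waits on a load, which is the point of the skew.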

So here's a dump from running ezetime over the assembled code. Of interest is the inner loop where every instruction pair dual-issues and a new fmadd is issued every cycle.

                                          0123456789012345678901234567890123456789012345678901234567890123
          _fmadd:
00000000:       movts.l special.0.5,r2   |   ---1                                                         |3
00000004:       mov.l   r2,#0x0000       |    ---1                                                        |3
00000008:       movts.s special.0.6,r2   |        ---1                                                    |3
0000000a:       mov.l   r2,#0x0000       |         ---1                                                   |3
0000000e:       movts.s special.0.7,r2   |             ---1                                               |3
00000010:       mov.l   r16,#0x0000      |              ---1                                              |3
00000014:       mov.l   r17,#0x0000      |                  1                                             |
00000018:       mov.l   r18,#0x0000      |                   1                                            |
0000001c:       mov.l   r19,#0x0000      |                    1                                           |
00000020:       mov.l   r20,#0x0000      |                     1                                          |
00000024:       mov.l   r21,#0x0000      |                      1                                         |
00000028:       mov.l   r22,#0x0000      |                       1                                        |
0000002c:       mov.l   r23,#0x0000      |                        1                                       |
00000030:       ldrd.l  r48,[r0],#+1     |                         12                                     |
00000034:       ldrd.l  r56,[r1],#+1     |                          12                                    |
00000038:       ldrd.l  r50,[r0],#+1     |                           12                                   |
0000003c:       ldrd.l  r58,[r1],#+1     |                            12                                  |
00000040:       ldrd.l  r52,[r0],#+1    /|                             12                                 |
00000044:       fmadd.l r16,r48,r56     \|                             1234                               |
00000048:       ldrd.l  r60,[r1],#+1    /|                              12                                |
0000004c:       fmadd.l r17,r49,r57     \|                              1234                              |
00000050:       ldrd.l  r54,[r0],#+1    /|                               12                               |
00000054:       fmadd.l r18,r50,r58     \|                               1234                             |
00000058:       ldrd.l  r62,[r1],#+1    /|                                12                              |
0000005c:       fmadd.l r19,r51,r59     \|                                1234                            |

          hw_loop_s:
00000060:       ldrd.l  r48,[r0],#+1    /|                                 12                             |
00000064:       fmadd.l r20,r52,r60     \|                                 1234                           |
00000068:       ldrd.l  r56,[r1],#+1    /|                                  12                            |
0000006c:       fmadd.l r21,r53,r61     \|                                  1234                          |
00000070:       ldrd.l  r50,[r0],#+1    /|                                   12                           |
00000074:       fmadd.l r22,r54,r62     \|                                   1234                         |
00000078:       ldrd.l  r58,[r1],#+1    /|                                    12                          |
0000007c:       fmadd.l r23,r55,r63     \|                                    1234                        |
00000080:       ldrd.l  r52,[r0],#+1    /|                                     12                         |
00000084:       fmadd.l r16,r48,r56     \|                                     1234                       |
00000088:       ldrd.l  r60,[r1],#+1    /|                                      12                        |
0000008c:       fmadd.l r17,r49,r57     \|                                      1234                      |
00000090:       ldrd.l  r54,[r0],#+1    /|                                       12                       |
00000094:       fmadd.l r18,r50,r58     \|                                       1234                     |
00000098:       ldrd.l  r62,[r1],#+1    /|                                        12                      |
0000009c:       fmadd.l r19,r51,r59     \|                                        1234                    |

          hw_loop_e:
000000a0:       fmadd.l r20,r52,r60      |                                         1234                   |
000000a4:       fmadd.l r21,r53,r61      |                                          1234                  |
000000a8:       fmadd.l r22,r54,r62      |                                           1234                 |
000000ac:       fmadd.l r23,r55,r63      |                                            1234                |
000000b0:       fadd.l  r16,r16,r17      |                                             1234               |
000000b4:       fadd.l  r18,r18,r19      |                                              1234              |
000000b8:       fadd.l  r20,r20,r21      |                                               1234             |
000000bc:       fadd.l  r22,r22,r23      |                                                -1234           |1
000000c0:       fadd.l  r16,r16,r18      |                                                  -1234         |1
000000c4:       fadd.l  r20,r20,r22      |                                                    --1234      |2
000000c8:       fadd.l  r0,r16,r20       |                                                       ----1234 |4
000000cc:       jr.l    r14              |                                                            1   |

Over 2048 data elements it executes in 2089 cycles, plus a couple of dozen for the function invocation and hardware timer setup overheads. I used two 8 KB buffers, one in bank 1 and the other in bank 2.
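As a back-of-envelope check of those numbers, here is a small C sketch. The 8 elements per iteration and the 8-cycle loop body (16 instructions dual-issued) come from the listing; the helper names are hypothetical and everything else is plain arithmetic on the quoted figures.

```c
#include <assert.h>

enum { ELEMS = 2048, PER_BLOCK = 8, LOOP_CYCLES = 8, MEASURED = 2089 };

static_assert(ELEMS % PER_BLOCK == 0, "element count must be a multiple of 8");

/* lc: one less than the number of 8-element blocks */
int loop_iterations(void) { return ELEMS / PER_BLOCK - 1; }
/* cycles spent inside the hardware loop */
int loop_cycles(void)     { return loop_iterations() * LOOP_CYCLES; }
/* what's left over for setup, prologue, epilogue and the final sums */
int other_cycles(void)    { return MEASURED - loop_cycles(); }
/* one fmadd per element, as a whole-number percentage of one per cycle */
int efficiency_pct(void)  { return ELEMS * 100 / MEASURED; }
/* size of each input buffer */
int buffer_bytes(void)    { return ELEMS * (int)sizeof(float); }
```

So 255 iterations account for 2040 of the cycles, leaving 49 for everything outside the loop, and the whole routine runs at roughly 98% of one fmadd per cycle. 2048 floats is also exactly 8 KB, which is why each buffer fits in a bank.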

Once it finishes the inner loop it completes the calculations for the data pre-loaded during the final iteration and then sums across the 8 partial sums in 3 parallel steps.
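The final reduction, written out in C as a sketch with the same pairing as the fadds at the end of the listing (r16+r17, r18+r19, r20+r21, r22+r23, then r16+r18 and r20+r22, then the final add into r0). The function name `reduce8` is hypothetical. The adds within a step are independent of each other, but each step has to wait for the previous one's results, which is where the stall cycles at the bottom of the listing come from.

```c
#include <assert.h>

/* Sum 8 partial sums in 3 dependent steps, pairing them as the listing does. */
float reduce8(const float p[8])
{
    assert(p);
    /* step 1: four independent adds */
    float s01 = p[0] + p[1], s23 = p[2] + p[3];
    float s45 = p[4] + p[5], s67 = p[6] + p[7];
    /* step 2: two independent adds */
    float s0123 = s01 + s23, s4567 = s45 + s67;
    /* step 3: the result (r0 in the listing) */
    return s0123 + s4567;
}
```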

A compatible/equivalent C function taking the same args would be:

// len8s1 == element count / 8 - 1
float fmadd(const float *a, const float *b, int len8s1) {
   int count = (len8s1+1)*8;  // 'unroll' the count
   float c = 0;

   // note: the asm keeps 8 partial sums, so rounding may differ slightly
   for (int i=0; i < count; i++)
      c += a[i] * b[i];

   return c;
}

I haven't validated that it produces the correct result, but apart from a typo or something it should be correct.
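If someone did want to validate it, a sketch of one approach: run the assembly routine and a serial C reference on the same buffers and compare with a relative tolerance rather than `==`, since the assembly adds up its 8 partial sums in a different order than a serial loop and so can round differently. The names `fmadd_ref` and `close_enough` are hypothetical, and the tolerance would be up to whoever runs the test.

```c
#include <assert.h>

/* Serial reference; len8s1 == element count / 8 - 1 */
float fmadd_ref(const float *a, const float *b, int len8s1)
{
    assert(len8s1 >= 0);
    int count = (len8s1 + 1) * 8;
    float c = 0;
    for (int i = 0; i < count; i++)
        c += a[i] * b[i];
    return c;
}

static float abs_f(float x) { return x < 0 ? -x : x; }

/* Hypothetical check: is 'got' within rel_tol of 'want', relative to the
 * larger magnitude of the two? */
int close_enough(float got, float want, float rel_tol)
{
    float diff = abs_f(got - want);
    float scale = abs_f(got) > abs_f(want) ? abs_f(got) : abs_f(want);
    return diff == 0.0f || diff <= rel_tol * scale;
}
```

Usage would be something like `close_enough(fmadd(a, b, n / 8 - 1), fmadd_ref(a, b, n / 8 - 1), 1e-5f)` with the assembly `fmadd` linked in.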

The movts instructions near the start of the listing above set lc, ls, and le respectively (loop count, loop start, loop end) for the hardware loop feature; ezetime doesn't output the register aliases. This is also an unlinked object, so the addresses are all zero, but it sets le to (hw_loop_e-4) for those who might understand what that means; I just put the label where it is to make the loop more readable. I fiddled with the size of the movts instructions until I got the alignment right, so it doesn't need any nops for alignment. Also, the movts instruction cycle timing isn't meant to be correct.
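To show why le points at hw_loop_e-4, here is a simplified C model of the hardware loop. This is an assumption for illustration, not a description of the silicon: ls holds the address of the first loop instruction, le the address of the last one, every instruction in the loop is 32 bits, and the body runs lc times (which is what "loop count set to one less than required" implies). With the listing's addresses, hw_loop_s is 0x60 and hw_loop_e is 0xa0, so le is 0x9c, the final fmadd. The function name is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Simplified model: count how many instructions execute from ls until
 * the loop exits, assuming 32-bit instructions and lc > 0. */
long hw_loop_instructions(uint32_t ls, uint32_t le, uint32_t lc)
{
    long executed = 0;
    uint32_t pc = ls;

    assert(le >= ls && (le - ls) % 4 == 0 && lc > 0);
    for (;;) {
        executed++;            /* execute the instruction at pc */
        if (pc == le) {        /* reached the last loop instruction */
            if (--lc == 0)
                break;         /* fall through to le + 4, i.e. hw_loop_e */
            pc = ls;           /* branch back to the loop start */
        } else {
            pc += 4;
        }
    }
    return executed;
}
```

For the 2048-element run, `hw_loop_instructions(0x60, 0xa0 - 4, 255)` counts 16 instructions per pass over 255 passes, i.e. 4080 instructions in 2040 dual-issued cycles.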

PS: Another 8 cycles could be knocked off if the first loop just used fmul, since the eight loads of 0.0 into the accumulators could then be removed; but then it would need 1.75 loops before starting the inner loop.
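The arithmetic side of that idea, as a C sketch: the first block initialises the accumulators with a plain multiply instead of zeroing them and multiply-adding. It deliberately ignores the pipelining (the 1.75-loop prologue) and only shows that the result is unchanged; the name `fmadd_peeled` is hypothetical.

```c
#include <assert.h>

/* len8s1 == element count / 8 - 1 */
float fmadd_peeled(const float *a, const float *b, int len8s1)
{
    float acc[8];

    assert(len8s1 >= 0);
    for (int k = 0; k < 8; k++)
        acc[k] = a[k] * b[k];               /* first block: fmul, no zeroing */
    for (int i = 8; i < (len8s1 + 1) * 8; i += 8)
        for (int k = 0; k < 8; k++)
            acc[k] += a[i + k] * b[i + k];  /* remaining blocks: fmadd */

    /* 3-step tree sum, paired as at the end of the listing */
    return ((acc[0] + acc[1]) + (acc[2] + acc[3]))
         + ((acc[4] + acc[5]) + (acc[6] + acc[7]));
}
```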
