Thursday 23 April 2015

spliterators are just iterators

I've been mucking about on-and-off with a 2D pixel library (i.e. images) and as one part of that I was trying to utilise the stream functions from Java 8 and/or at least look into some functional callbacks.

I can't just use the jre stuff because they don't have byte or float streams and they don't handle 2D data; I wanted to be able to support sub-regions of a rectangular array accessed by offset with different strides.

It took me a little while to wrap my head around spliterators because I was looking at them as if their primary job was to split the work up into smaller chunks to be worked across threads; but this is completely the wrong way to frame them. They are instead just iterators that can be broken into non-overlapping parts.
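To make that framing concrete, here is a minimal sketch using only the JRE's own array spliterator: a single trySplit() call breaks the iteration into two non-overlapping parts, and each part is then drained like a plain iterator. The names SplitDemo and drain are made up for illustration.

```java
import java.util.Arrays;
import java.util.Spliterator;
import java.util.function.IntConsumer;

class SplitDemo {
    // Sums whatever elements the given part still covers.
    // A null part (trySplit() declined to split) contributes nothing.
    static int drain(Spliterator.OfInt part) {
        if (part == null) return 0;
        int[] sum = new int[1];
        IntConsumer add = v -> sum[0] += v;
        part.forEachRemaining(add);
        return sum[0];
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6, 7, 8};
        Spliterator.OfInt rest = Arrays.spliterator(data);
        // trySplit() hands back a new spliterator covering a leading part of
        // the elements; 'rest' keeps the remainder, so the two never overlap.
        Spliterator.OfInt head = rest.trySplit();
        System.out.println(drain(head) + drain(rest));
    }
}
```

Whether the two parts are drained on one thread or two is up to the caller; the spliterator itself only defines how the iteration can be divided.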

Once I realised that, I had no trouble coming up with some 2D array streams of various types and trying to work out how to get better performance.

Given the code complexity I'm fairly impressed with the performance of the stream code in general, but it doesn't always seem to interact with the JVM very well. Some code I have gets slower once HotSpot has completed its optimisations, which is odd. And sometimes it's really not too fast at all.

So I have a test-case which creates a 2D byte array of 1024x1024 elements and sums up every element as an integer 1000 times in a tight loop.
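A harness for a test like this could be sketched as follows. This is an assumption about the general shape only, not the code behind the timings in this post: the class name SumBenchmark, the random fill and the use of System.nanoTime() are all illustrative, and it works on a plain array rather than any particular pixel class.

```java
import java.util.Random;

// Hypothetical timing harness; names and structure are illustrative only.
class SumBenchmark {
    // One pass over a width x height region, reading each byte as unsigned.
    static int sumOnce(byte[] pixels, int width, int height, int stride, int offset) {
        int sum = 0;
        for (int y = 0; y < height; y++)
            for (int x = 0; x < width; x++)
                sum += pixels[x + y * stride + offset] & 0xff;
        return sum;
    }

    public static void main(String[] args) {
        int width = 1024, height = 1024;
        byte[] pixels = new byte[width * height];
        new Random(1).nextBytes(pixels);   // arbitrary but repeatable content

        long checksum = 0;                 // printed below so the work cannot be discarded
        long start = System.nanoTime();
        for (int i = 0; i < 1000; i++)
            checksum += sumOnce(pixels, width, height, width, 0);
        long ms = (System.nanoTime() - start) / 1_000_000;
        System.out.println(ms + "ms, checksum " + checksum);
    }
}
```

A single timed run like this includes JIT warm-up, so running the timed loop several times and comparing the first run with later ones is one way to see the before/after-optimisation effects.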

The basic object I'm using for these tests is:

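class BytePixels {
    public final byte[] pixels;
    public final int width, height, stride, offset;

    // Constructor assumed here so the final fields get initialised.
    public BytePixels(byte[] pixels, int width, int height, int stride, int offset) {
        this.pixels = pixels; this.width = width; this.height = height;
        this.stride = stride; this.offset = offset;
    }

    // Array index of pixel (x, y) within the sub-region.
    public int offset(int x, int y)  { return x+y*stride+offset; }
    // Pixel value read as an unsigned byte (0..255).
    public int getb(int x, int y)    { return pixels[offset(x,y)] & 0xff; }
    public int getb(int i)           { return pixels[i] & 0xff; }
}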

The machine is a quad-core CPU and 4 threads are used whenever concurrency is involved.

I'll go through some of the cases I tried and show the inner loops. For the most part I'm assuming offset is not zero; the zero case is a little simpler but makes minimal difference to the runtime. I'm applying obvious micro-optimisations like copying loop bounds or array pointers to local variables.

The first examples are just using basic Java & jre facilities.

376ms Nested loops with getb(x,y)

This loops over the dimensions of the image and calls getb(x,y) to retrieve each element.

for (int y = 0; y < height; y++)
    for (int x = 0; x < width; x++)
        sum += bp.getb(x, y);

This sets a base-line for the most obvious implementation.

387ms Nested loops with pixels[o++]

This loops over the dimensions but removes the internal array calculation by calculating the offset at the start of each row and then accessing the pixel values directly.

for (int y = 0; y < height; y++)
    for (int x = 0, o = bp.offset(0, y); x < width; x++)
        sum += pixels[o++] & 0xff;

This is supposedly optimised ... but executes marginally slower or about the same. Ahah what? I can't explain this one.

328ms Nested loops with pixels[o++] take 2

This is the same as the previous but changes the inner loop bounds checking so that the incrementing value is the same one used as the bounds check.

for (int y = 0; y < height; y++)
    for (int o = bp.offset(0, y), e = o + width; o < e; o++)
        sum += pixels[o] & 0xff;

Given the simplicity of the change, this makes a pretty big difference. I must remember this for the future.

I only tried this whilst writing this post so went back and applied it to some of the other algorithms ...

2980ms map/reduce in parallel with flat indices.

This uses IntStream.range().parallel() to create a stream of all indices in the array and uses .map() to turn it into a stream of elements. Because range is only 1-dimensional this is not a general solution, merely a 'best case'.

IntStream.range(0, width * height)
    .parallel()
    .map((i) -> getb(i))
    .sum();

So ... it is not very 'best'. Actually a further wrinkle is that unlike the other examples this one runs slower after HotSpot has completed its optimisations, and the first run is the quickest. And we're talking 2-3x slower here, not just a bit.

3373ms map/reduce in parallel with 2D indices.

This uses IntStream.range().parallel() to create a stream of all indices within the bounds of the image and then remaps these back to 2D coordinates inside the .map() function. This is the general solution which implements the required features.

IntStream.range(0, width * height)
    .parallel()
    .map((i) -> getb(offset + (i / width) * stride + (i % width)))
    .sum();

Not surprisingly it's a bit slower, and like the previous example also gets slower once optimised.

Basically, that's a no then; this just isn't going to cut it. So time to see if a custom spliterator will do it.

1960ms map/reduce in parallel with custom 2D spliterator

This uses a custom 2D spliterator that splits (at most) into rows and then uses map() to retrieve the pixel values.

StreamSupport.intStream(new Spliterator2DIndex(width, height, offset, stride), true)
    .map((i) -> getb(i))
    .sum();

Still "slow as shit" to use the correct technical jargon.

This is about when I worked out how to use spliterators and tried the next obvious thing: using the spliterator to do the map() since that seems to be "problematic".

125ms stream reduce in parallel with custom 2D array spliterator

The custom 2D spliterator performs the array lookups itself and feeds out the values as the stream. It supports the full features required.

StreamSupport.intStream(new Spliterator2DByteArray(width, height, offset, stride, pixels), true)
    .sum();

Ok, so now this is more like it. It's finally managed to beat that single threaded code.

Until this point I was ready to ditch even bothering with the stream code; one can deal with a bit of a performance drop for some convenience and code re-use, but 10x is just unacceptable.
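For reference, a spliterator of the kind used for the 125ms result might look roughly like the sketch below. It is a reconstruction under assumptions, not the class from the library: only the class name and constructor argument order come from the call above, while the row-based splitting policy, the private constructor and the chosen characteristics are guesses.

```java
import java.util.Spliterator;
import java.util.function.IntConsumer;
import java.util.stream.StreamSupport;

// Hypothetical reconstruction: streams the unsigned byte values of a
// width x height region, splitting no finer than whole rows.
class Spliterator2DByteArray implements Spliterator.OfInt {
    private final byte[] pixels;
    private final int width, stride, offset;
    private int x, y;      // next element to emit
    private int height;    // exclusive end row of this part

    Spliterator2DByteArray(int width, int height, int offset, int stride, byte[] pixels) {
        this(width, 0, 0, height, offset, stride, pixels);
    }

    private Spliterator2DByteArray(int width, int x, int y, int height,
                                   int offset, int stride, byte[] pixels) {
        this.width = width;
        this.x = x;
        this.y = y;
        // A non-positive width or height leaves nothing to emit.
        this.height = (width > 0) ? Math.max(height, y) : y;
        this.offset = offset;
        this.stride = stride;
        this.pixels = pixels;
    }

    @Override
    public Spliterator.OfInt trySplit() {
        int rows = height - y;
        if (rows < 2) return null;     // never split inside a row
        int mid = y + rows / 2;
        // The returned part is the prefix [y, mid); this one keeps [mid, height).
        Spliterator.OfInt prefix =
            new Spliterator2DByteArray(width, x, y, mid, offset, stride, pixels);
        x = 0;
        y = mid;
        return prefix;
    }

    @Override
    public boolean tryAdvance(IntConsumer action) {
        if (y >= height) return false;
        action.accept(pixels[x + y * stride + offset] & 0xff);
        if (++x >= width) { x = 0; y++; }
        return true;
    }

    @Override
    public void forEachRemaining(IntConsumer action) {
        byte[] pixels = this.pixels;   // locals, as in the post's micro-optimisations
        int height = this.height;
        int x = this.x;
        for (int y = this.y; y < height; y++, x = 0) {
            int row = y * stride + offset;
            for (int o = row + x, e = row + width; o < e; o++)
                action.accept(pixels[o] & 0xff);
        }
        this.x = 0;
        this.y = height;               // mark this part as fully consumed
    }

    @Override
    public long estimateSize() {
        return (long) (height - y) * width - x;
    }

    @Override
    public int characteristics() {
        return ORDERED | SIZED | SUBSIZED | NONNULL;
    }
}
```

Because the sizes of both halves stay exact after a split, reporting SIZED and SUBSIZED is safe here; a spliterator that only emitted indices would have the same structure with the array lookup moved out into a map() stage.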

115ms stream reduce in parallel with custom 1D array spliterator

This loops over all elements of an array and feeds out the array values. This is a non-conformant solution to determine the overheads of the 2D indexing.

StreamSupport.intStream(new SpliteratorByteArray(offset, width * height, pixels), true)
    .sum();

The overheads of the 2D case are not zero but they are modest.

I tried a bunch of other things right through to a completely different multi-threaded forEach() call and polled queues and a more OpenCL-like approach to the processing model; my best result was under 100ms.

By 'OpenCL-like approach' I mean that I made the thread index an explicitly available parameter, so that for example a reduction step can just write a partial result directly to a pre-allocated array of values by 'group id' rather than having to allocate and pass back working storage. This together with an exotic single-reader work queue produced the best result. I also learned how to use park() ... maybe something for a later post.
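The group-id idea can be sketched on its own, leaving out the work queue and park(). The sketch below is an assumption about the general shape, not the code that produced the sub-100ms result: it uses one plain thread per group, and the names GroupedSum, RowKernel and forEachGroup are hypothetical.

```java
// Hypothetical sketch of an OpenCL-like dispatch: every kernel invocation is
// told its group id, so it can write into a slot of a pre-allocated array.
class GroupedSum {
    interface RowKernel {
        void run(int groupId, int y0, int y1);   // rows [y0, y1) belong to this group
    }

    // Runs the kernel once per group, each on its own thread, over a
    // contiguous band of rows, and waits for all of them to finish.
    static void forEachGroup(int groups, int height, RowKernel kernel) {
        Thread[] workers = new Thread[groups];
        for (int g = 0; g < groups; g++) {
            final int id = g;
            final int y0 = (int) ((long) height * g / groups);
            final int y1 = (int) ((long) height * (g + 1) / groups);
            workers[g] = new Thread(() -> kernel.run(id, y0, y1));
            workers[g].start();
        }
        for (Thread t : workers) {
            try {
                t.join();   // join() also makes each worker's writes visible here
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(ex);
            }
        }
    }

    static long sum(byte[] pixels, int width, int height, int stride, int offset, int groups) {
        long[] partial = new long[groups];       // one slot per group, allocated once
        forEachGroup(groups, height, (id, y0, y1) -> {
            long s = 0;
            for (int y = y0; y < y1; y++) {
                int row = y * stride + offset;
                for (int o = row, e = row + width; o < e; o++)
                    s += pixels[o] & 0xff;
            }
            partial[id] = s;                     // direct write, nothing allocated or passed back
        });
        long total = 0;
        for (long p : partial) total += p;
        return total;
    }
}
```

With 4 groups on a quad-core machine each thread sums a quarter of the rows, and the final combine is a loop over four values on the calling thread.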

Optimisations

I found it particularly critical to optimise the spliterator iteration function forEachRemaining(). Even trivial changes can have a big impact.

The difference between this:

private final byte[] pixels;
private final int width, stride;
private int height, offset;
private int x, y;

public void forEachRemaining(IntConsumer action) {
    while (y < height) {
        while (x < width)
            action.accept(pixels[(x++) + y * stride + offset] & 0xff);
        x = 0;
        y += 1;
    }
}

And this:

private final byte[] pixels;
private final int width, stride;
private int height, offset;
private int x, y;

public void forEachRemaining(IntConsumer action) {
    int y = this.y;
    int height = this.height;
    int x = this.x;
    byte[] pixels = this.pixels;
    int offset = this.offset;

    while (y < height) {
        for (int o = x + y * stride + offset, e = y * stride + offset + width; o < e; o++)
            action.accept(pixels[o] & 0xff);
        y += 1;
        x = 0;
    }
    this.x = x;
    this.y = y;
}

Is an over 600% difference in running time.

For images

These streams aren't really very useful for a lot of image processing; only handling single channel data and losing the positional information leaves only a few basic and not widely useful operations left over. The other stream interfaces like collections just won't scale very well to images either: they require intermediate blocks to be allocated rather than just referencing sub-ranges.

So I have been and probably will continue to use the streams in these cases as multi-threaded forEach() and gain the concurrency serialisation implicitly by each task working on independent values. But I can also just borrow ideas and implement similar interfaces even if they are not compatible, e.g. a reduce call that just runs directly on the data without forming a stream first. This might let me get some of the code re-usability and brevity benefits without trying to force a not-quite-right processing model onto the problem.
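As a rough illustration of the 'similar interface, no stream' idea, a reduce that runs directly over the region could look like the sketch below. It is single-threaded and entirely hypothetical: PixelReduce is a made-up name, and the signature simply borrows identity and IntBinaryOperator from IntStream.reduce().

```java
import java.util.function.IntBinaryOperator;

// Hypothetical helper: a stream-style reduce applied straight to the array.
class PixelReduce {
    // Folds every unsigned byte value of a width x height region into one
    // int, starting from 'identity', without creating a stream.
    static int reduce(byte[] pixels, int width, int height, int stride, int offset,
                      int identity, IntBinaryOperator op) {
        int acc = identity;
        for (int y = 0; y < height; y++) {
            int row = y * stride + offset;
            for (int o = row, e = row + width; o < e; o++)
                acc = op.applyAsInt(acc, pixels[o] & 0xff);
        }
        return acc;
    }
}
```

A call such as PixelReduce.reduce(pixels, width, height, stride, offset, 0, Integer::sum) gives the sum and swapping in Math::max gives the brightest value, so the calling code reads much like the stream version.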

But being able to create custom/multi-dimensional spliterators is pretty handy all the same.
