From 6c765f89917191fa498860a22684cfeade748cdb Mon Sep 17 00:00:00 2001 From: JP Martin Date: Thu, 5 May 2016 11:10:03 -0700 Subject: [PATCH] Refactor to use ExecutorService and BaseEncoding --- .../google/cloud/examples/nio/CountBytes.java | 6 +- .../examples/nio/ParallelCountBytes.java | 126 ++++++++---------- 2 files changed, 60 insertions(+), 72 deletions(-) diff --git a/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/CountBytes.java b/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/CountBytes.java index 184df901f714..a3f9779a0790 100644 --- a/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/CountBytes.java +++ b/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/CountBytes.java @@ -17,8 +17,8 @@ package com.google.cloud.examples.nio; import com.google.common.base.Stopwatch; +import com.google.common.io.BaseEncoding; -import javax.xml.bind.annotation.adapters.HexBinaryAdapter; import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; @@ -35,7 +35,7 @@ *

This example shows how to read a file size using NIO. * File.size returns the size of the file as saved in Storage metadata. * This class also shows how to read all of the file's contents using NIO, - * and reports how long it took. + * computes an MD5 hash, and reports how long it took. * *

See the README for compilation instructions. Run this code with * {@code target/appassembler/bin/CountBytes } @@ -85,7 +85,7 @@ private static void countFile(String fname) { long elapsed = sw.elapsed(TimeUnit.SECONDS); System.out.println("Read all " + total + " bytes in " + elapsed + "s. " + "(" + readCalls +" calls to chan.read)"); - String hex = (new HexBinaryAdapter()).marshal(md.digest()); + String hex = String.valueOf(BaseEncoding.base16().encode(md.digest())); System.out.println("The MD5 is: 0x" + hex); if (total != size) { System.out.println("Wait, this doesn't match! We saw " + total + " bytes, " + diff --git a/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/ParallelCountBytes.java b/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/ParallelCountBytes.java index 388af029e905..f443699bbc3c 100644 --- a/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/ParallelCountBytes.java +++ b/gcloud-java-examples/src/main/java/com/google/cloud/examples/nio/ParallelCountBytes.java @@ -17,8 +17,8 @@ package com.google.cloud.examples.nio; import com.google.common.base.Stopwatch; +import com.google.common.io.BaseEncoding; -import javax.xml.bind.annotation.adapters.HexBinaryAdapter; import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; @@ -27,28 +27,59 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.security.MessageDigest; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; /** * ParallelCountBytes will read through the whole file given as input. * *

This example shows how to go through all the contents of a file, - * in order, using multithreaded NIO reads.It also reports how long it took. + * in order, using multithreaded NIO reads. + * It prints an MD5 hash and reports how long it took. * *

See the README for compilation instructions. Run this code with * {@code target/appassembler/bin/ParallelCountBytes } */ public class ParallelCountBytes { - private class BufWithLock { - public Object lock; - public ByteBuffer buf; - public boolean full; - public Thread t; + /** + * WorkUnit holds a buffer and the instructions for what to put in it. + */ + private class WorkUnit implements Callable { + public final ByteBuffer buf; + final SeekableByteChannel chan; + final int blockSize; + int blockIndex; - public BufWithLock(int size) { - this.buf = ByteBuffer.allocate(size); - this.lock = new Object(); + public WorkUnit(SeekableByteChannel chan, int blockSize, int blockIndex) { + this.chan = chan; + this.buf = ByteBuffer.allocate(blockSize); + this.blockSize = blockSize; + this.blockIndex = blockIndex; + } + + @Override + public WorkUnit call() throws IOException { + int pos = blockSize * blockIndex; + if (pos > chan.size()) { + return this; + } + chan.position(pos); + // read until buffer is full, or EOF + while (chan.read(buf) > 0) {}; + return this; + } + + public WorkUnit resetForIndex(int blockIndex) { + this.blockIndex = blockIndex; + buf.flip(); + return this; } } @@ -69,37 +100,6 @@ public void start(String[] args) throws IOException { } } - private void stridedRead(SeekableByteChannel chan, int blockSize, int firstBlock, int stride, BufWithLock output) { - try { - // stagger the threads a little bit. 
- Thread.sleep(250 * firstBlock); - long pos = firstBlock * blockSize; - synchronized(output.lock) { - while (true) { - if (pos > chan.size()) { - break; - } - chan.position(pos); - // read until buffer is full, or EOF - while (chan.read(output.buf) > 0) {}; - output.full = true; - output.lock.notifyAll(); - if (output.buf.hasRemaining()) { - break; - } - // wait for main thread to process it - while (output.full) { - output.lock.wait(); - } - output.buf.flip(); - pos += stride * blockSize; - } - } - } catch (InterruptedException | IOException o) { - // this simple example doesn't handle errors, sorry. - } - } - /** * Print the length of the indicated file. * @@ -109,49 +109,37 @@ private void stridedRead(SeekableByteChannel chan, int blockSize, int firstBlock private void countFile(String fname) throws IOException{ // large buffers pay off final int bufSize = 50 * 1024 * 1024; + Queue> work = new ArrayDeque<>(); try { Path path = Paths.get(new URI(fname)); long size = Files.size(path); System.out.println(fname + ": " + size + " bytes."); - ByteBuffer buf = ByteBuffer.allocate(bufSize); - int nBlocks = (int)Math.ceil( size / (double)bufSize); + int nBlocks = (int) Math.ceil(size / (double) bufSize); int nThreads = nBlocks; if (nThreads > 4) nThreads = 4; System.out.println("Reading the whole file using " + nThreads + " threads..."); Stopwatch sw = Stopwatch.createStarted(); - final BufWithLock[] bufs = new BufWithLock[nThreads]; - for (int i = 0; i < nThreads; i++) { - bufs[i] = new BufWithLock(bufSize); - final SeekableByteChannel chan = Files.newByteChannel(path); - final int finalNThreads = nThreads; - final int finalI = i; - bufs[i].t = new Thread(new Runnable() { - @Override - public void run() { - stridedRead(chan, bufSize, finalI, finalNThreads, bufs[finalI]); - } - }); - bufs[i].t.start(); - } - long total = 0; MessageDigest md = MessageDigest.getInstance("MD5"); - for (int block = 0; block < nBlocks; block++) { - BufWithLock bwl = bufs[block % 
bufs.length]; - synchronized (bwl.lock) { - while (!bwl.full) { - bwl.lock.wait(); - } - md.update(bwl.buf.array(), 0, bwl.buf.position()); - total += bwl.buf.position(); - bwl.full = false; - bwl.lock.notifyAll(); + + ExecutorService exec = Executors.newFixedThreadPool(nThreads); + int blockIndex; + for (blockIndex = 0; blockIndex < nThreads; blockIndex++) { + work.add(exec.submit(new WorkUnit(Files.newByteChannel(path), bufSize, blockIndex))); + } + while (true) { + WorkUnit full = work.remove().get(); + md.update(full.buf.array(), 0, full.buf.position()); + total += full.buf.position(); + if (full.buf.hasRemaining()) { + break; } + work.add(exec.submit(full.resetForIndex(blockIndex++))); } long elapsed = sw.elapsed(TimeUnit.SECONDS); System.out.println("Read all " + total + " bytes in " + elapsed + "s. "); - String hex = (new HexBinaryAdapter()).marshal(md.digest()); + String hex = String.valueOf(BaseEncoding.base16().encode(md.digest())); System.out.println("The MD5 is: 0x" + hex); if (total != size) { System.out.println("Wait, this doesn't match! We saw " + total + " bytes, " +