JDK1.7 Fork Join

Aug 19, 2016


介绍

Java7引入了Fork/Join的概念,来更好地支持并行运算。Fork/Join类似于流程语言中分支、合并的概念。也就是说Java7 SE原生支持在一个主线程中开辟多个分支线程,并且根据分支线程的逻辑来等待(或者不等待)汇集。当然你也可以在fork出的某一个分支线程中再开辟Fork Join,这样就可以实现Fork Join的嵌套。

两个核心类ForkJoinPool和ForkJoinTask。

ForkJoinPool实现了ExecutorService接口,起到线程池的作用。所以它的用法和Executor框架的使用是一样的,当然Fork Join本身就是Executor框架的扩展。ForkJoinPool有3个关键的方法,用来启动任务。

行为 客户端非fork/join调用 内部调用fork/join
异步执行 execute(ForkJoinTask) ForkJoinTask.fork
阻塞,等待获取结果 invoke(ForkJoinTask) ForkJoinTask.invokeAll(Collection tasks)
执行,获取Future submit(ForkJoinTask) ForkJoinTask.fork, join

ForkJoinTask有两个子抽象类RecursiveAction和RecursiveTask。

  1. RecursiveAction用于没有返回结果的任务;

     abstract class RecursiveAction extends ForkJoinTask<Void>
    
  2. RecursiveTask用于有返回结果的任务。

     abstract class RecursiveTask<V> extends ForkJoinTask<V>
    

示例1, RecursiveAction的使用

大数组的初始化,采用三种方式实现,比较效率。

  • 顺序实现,单线程顺序初始化数组;
  • 使用Fork/Join实现;
  • 使用并发包中的线程池来实现。

代码:

package edu.zju.chwl.forkjoin;

import static java.util.Arrays.asList;

import java.util.Random;
import java.util.concurrent.*;

/**
 * Benchmarks three ways of filling a large {@code int} array with random
 * values: a single-threaded loop, a recursive fork/join action, and a fixed
 * thread pool that is fed pre-cut chunks.
 */
public class ForkJoinRandomFillAction {

    /** Number of elements in the array that each benchmark fills. */
    static final int arrayLength = 30000000;

    /** How many times each benchmark repeats its timed fill. */
    static final int iterations = 4;

    /** Largest range a single task fills before the work is split further. */
    static final int splitSize = 40000;

    /** Worker count used for both the fork/join pool and the fixed thread pool. */
    private static final int POOL_SIZE = 64;

    /**
     * Runs the sequential, fork/join and thread-pool benchmarks in that order.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        sequential();
        forkjoin();
        threadPool();
    }

    /**
     * Fills {@code array[low .. high)} with random values in {@code [0, 10000)}.
     *
     * @param array the array to write into
     * @param low   first index to fill (inclusive)
     * @param high  last index to fill (exclusive)
     */
    public static void loadArray(int[] array, int low, int high) {
        // One generator per call, so concurrent callers never share a Random.
        Random random = new Random();
        for (int i = low; i < high; i++) {
            array[i] = random.nextInt(10000);
        }
    }

    /** Times {@link #iterations} single-threaded fills of the whole array. */
    public static void sequential() {
        int[] array = new int[arrayLength];

        for (int i = 0; i < iterations; i++) {
            long start = System.currentTimeMillis();
            loadArray(array, 0, array.length);
            System.out.println("Sequential processing time: "
                    + (System.currentTimeMillis() - start) + " ms");
        }
    }

    /**
     * Times {@link #iterations} fills performed by recursively splitting the
     * array inside a {@link ForkJoinPool}, then prints the pool's steal count.
     */
    public static void forkjoin() {
        int[] array = new int[arrayLength];

        System.out.println("Number of processor available: " + Runtime.getRuntime().availableProcessors());

        ForkJoinPool fjpool = new ForkJoinPool(POOL_SIZE);
        try {
            for (int i = 0; i < iterations; i++) {
                // Create a task that covers the complete array.
                RecursiveAction task = new RandomFillAction(array, 0, array.length);
                long start = System.currentTimeMillis();
                fjpool.invoke(task); // blocks until the whole task tree has finished
                System.out.println("Parallel processing time: "
                        + (System.currentTimeMillis() - start) + " ms");
            }

            System.out
                    .println("Number of steals: " + fjpool.getStealCount() + "\n");
        } finally {
            // Release the worker threads once the benchmark is over.
            fjpool.shutdown();
        }
    }

    /**
     * Times {@link #iterations} fills performed by a fixed thread pool. Unlike
     * the fork/join variant, the chunk assigned to each task is decided up
     * front rather than by recursive splitting.
     */
    public static void threadPool() {
        int[] array = new int[arrayLength];

        ExecutorService threadPool = Executors.newFixedThreadPool(POOL_SIZE);
        try {
            for (int i = 0; i < iterations; i++) {
                long start = System.currentTimeMillis();
                fillInChunks(array, splitSize, threadPool);
                System.out.println("Thread Pool processing time: "
                        + (System.currentTimeMillis() - start) + " ms");
            }
        } finally {
            // The pool's threads are non-daemon; without this the JVM never exits.
            threadPool.shutdown();
        }
    }

    /**
     * Fills the whole array by submitting one {@link FillTask} per chunk to
     * {@code executor} and waiting for every task to complete.
     *
     * @param array     the array to fill
     * @param chunkSize maximum number of elements per task; must be positive
     * @param executor  executor that runs the fill tasks
     * @throws IllegalArgumentException if {@code chunkSize} is not positive
     * @throws IllegalStateException    if a task fails or the wait is interrupted
     */
    static void fillInChunks(int[] array, int chunkSize, Executor executor) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive: " + chunkSize);
        }

        ExecutorCompletionService<Boolean> completionService =
                new ExecutorCompletionService<Boolean>(executor);

        int submitted = 0;
        int low = 0;
        while (low < array.length) {
            // Clamp the final chunk so it never runs past the end of the array.
            int high = low + Math.min(chunkSize, array.length - low);
            completionService.submit(new FillTask(array, low, high), true);
            submitted++;
            low = high;
        }

        for (int j = 0; j < submitted; j++) {
            try {
                // get() rethrows any failure from the task instead of hiding it.
                completionService.take().get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for fill tasks", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("Fill task failed", e.getCause());
            }
        }
    }

    /** Runnable that fills one fixed range of an array via {@link #loadArray}. */
    static class FillTask implements Runnable {

        final int low;
        final int high;
        private final int[] array;

        /**
         * @param array the array to write into
         * @param low   first index to fill (inclusive)
         * @param high  last index to fill (exclusive)
         */
        public FillTask(int[] array, int low, int high) {
            this.low = low;
            this.high = high;
            this.array = array;
        }

        @Override
        public void run() {
            loadArray(array, low, high);
        }

    }

    /**
     * Fork/join action that fills a range of an array, halving the range
     * recursively until it is no larger than {@link #splitSize}.
     */
    static class RandomFillAction extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        final int low;
        final int high;
        private final int[] array;

        /**
         * @param array the array to write into
         * @param low   first index to fill (inclusive)
         * @param high  last index to fill (exclusive)
         */
        public RandomFillAction(int[] array, int low, int high) {
            this.low = low;
            this.high = high;
            this.array = array;
        }

        @Override
        protected void compute() {
            if (high - low > splitSize) {
                int mid = (low + high) >>> 1;
                // Both halves must go through a single invokeAll call: calling
                // invoke() on each half in turn would wait for the first half
                // to finish before the second one even starts.
                invokeAll(asList(new RandomFillAction(array, low, mid),
                        new RandomFillAction(array, mid, high)));
            } else {
                loadArray(array, low, high);
            }
        }
    }
}

执行结果:

Sequential processing time: 378 ms
Sequential processing time: 360 ms
Sequential processing time: 360 ms
Sequential processing time: 362 ms
Number of processor available: 4
Parallel processing time: 238 ms
Parallel processing time: 123 ms
Parallel processing time: 166 ms
Parallel processing time: 187 ms
Number of steals: 260

Thread Pool processing time: 134 ms
Thread Pool processing time: 133 ms
Thread Pool processing time: 119 ms
Thread Pool processing time: 133 ms

示例2, RecursiveTask的使用

n个随机数相加

代码:

package edu.zju.chwl.forkjoin;

import java.util.Random;
import java.util.concurrent.*;

/**
 * Compares a sequential loop with a fork/join {@link RecursiveTask} for
 * summing {@link #n} random values that are each 0 or 1.
 */
public class ForkJoinCount {

    /** Ranges shorter than this are summed directly instead of being split. */
    static final int splitSize = 20000000;

    /** How many random values each benchmark adds up. */
    static final int n = 100000000;

    /**
     * Adds up one random 0/1 value per index in {@code [low, high)}.
     *
     * @param low  start of the range (inclusive)
     * @param high end of the range (exclusive)
     * @return the sum, between 0 and {@code high - low}; 0 for an empty range
     */
    public static int sum(int low, int high) {
        // One generator per call, so parallel leaf tasks never share a Random.
        Random random = new Random();
        int sum = 0;
        for (int i = low; i < high; i++) {
            sum += random.nextInt(2);
        }
        return sum;
    }

    /**
     * Runs the fork/join benchmark followed by the sequential one.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        forkjoin();
        sequential();
    }

    /**
     * Times the summation of {@link #n} values when it is submitted to a
     * {@link ForkJoinPool} and awaited through the returned future.
     *
     * @throws IllegalStateException if the task fails or the wait is interrupted
     */
    public static void forkjoin() {
        ForkJoinPool pool = new ForkJoinPool(64);
        try {
            long start = System.currentTimeMillis();
            // [0, n) covers exactly n values.
            ForkJoinTask<Integer> retFuture = pool.submit(new CountTask(0, n));
            try {
                retFuture.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for the fork/join sum", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("Fork/join sum failed", e.getCause());
            }
            System.out.println("Parallel processing time: "
                    + (System.currentTimeMillis() - start) + " ms");
        } finally {
            // Release the worker threads once the measurement is done.
            pool.shutdown();
        }
    }

    /** Times the summation of {@link #n} values on the calling thread. */
    public static void sequential() {
        long start = System.currentTimeMillis();
        sum(0, n);
        System.out.println("Sequential processing time: "
                + (System.currentTimeMillis() - start) + " ms");
    }

    /**
     * Recursive task that sums random 0/1 values over {@code [start, end)},
     * halving the range until it is shorter than the task's threshold.
     */
    static class CountTask extends RecursiveTask<Integer> {

        private static final long serialVersionUID = 1L;

        private final int start;

        private final int end;

        private final int threshold;

        /**
         * Creates a task that splits at the class-wide {@link #splitSize}.
         *
         * @param start start of the range (inclusive)
         * @param end   end of the range (exclusive)
         */
        CountTask(int start, int end) {
            this(start, end, splitSize);
        }

        /**
         * Creates a task with an explicit split threshold.
         *
         * @param start     start of the range (inclusive)
         * @param end       end of the range (exclusive)
         * @param threshold ranges shorter than this are summed directly
         */
        CountTask(int start, int end, int threshold) {
            this.start = start;
            this.end = end;
            this.threshold = threshold;
        }

        @Override
        protected Integer compute() {
            // A range of fewer than two elements cannot be halved any further.
            if (end - start < threshold || end - start < 2) {
                return sum(start, end);
            }

            int mid = (start + end) >>> 1;
            CountTask left = new CountTask(start, mid, threshold);
            CountTask right = new CountTask(mid, end, threshold);

            // Fork one half and compute the other on this worker, so the
            // current thread does useful work instead of only waiting.
            // Never write fork().join() back to back on the same task: join()
            // blocks until that task is done, which serializes the two halves.
            left.fork();
            int rightSum = right.compute();
            return rightSum + left.join();
        }
    }
}

执行结果:

Parallel processing time: 470 ms
Sequential processing time: 1199 ms

注意:ForkJoinTask.join()会阻塞当前线程并等待获取结果


上一篇博客:NIO实例