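Introduction
Java 7 introduced the Fork/Join framework to better support parallel computation. Fork/Join resembles the branch-and-merge concept found in workflow languages. In other words, Java SE 7 natively supports splitting work from a main thread into several branch tasks (run by pool worker threads) and then, depending on each branch's logic, waiting (or not waiting) for them to join. A branch can in turn fork and join its own subtasks, so Fork/Join can be nested.
The two core classes are ForkJoinPool and ForkJoinTask.
ForkJoinPool implements the ExecutorService interface and acts as the thread pool, so it is used the same way as the rest of the Executor framework; Fork/Join is itself an extension of that framework. ForkJoinPool has three key methods for starting tasks.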
Behavior | Call from non-fork/join clients | Call from within fork/join computations |
---|---|---|
Asynchronous execution | execute(ForkJoinTask) | ForkJoinTask.fork |
Block and wait for the result | invoke(ForkJoinTask) | ForkJoinTask.invokeAll(Collection<T>) |
Execute and obtain a Future | submit(ForkJoinTask) | ForkJoinTask.fork, join |
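The three pool entry points in the table can be tried side by side. The following is a minimal sketch rather than a reference implementation: the class PoolEntryPoints, the task RangeSum, and the cutoff of 1,000 elements are hypothetical names and values chosen for illustration.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class PoolEntryPoints {

    // Hypothetical task for illustration: sums the integers in [from, to).
    static class RangeSum extends RecursiveTask<Long> {
        private static final long serialVersionUID = 1L;
        private final int from;
        private final int to;

        RangeSum(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= 1_000) { // assumed cutoff: small ranges are summed directly
                long sum = 0;
                for (int i = from; i < to; i++) {
                    sum += i;
                }
                return sum;
            }
            int mid = (from + to) >>> 1;
            RangeSum left = new RangeSum(from, mid);
            RangeSum right = new RangeSum(mid, to);
            invokeAll(left, right);            // runs both halves and waits for them
            return left.join() + right.join(); // both are complete, so join returns at once
        }
    }

    // Returns the results obtained through invoke, submit and execute, in that order.
    static long[] runAll() {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // invoke: blocks the caller until the result is ready.
            long viaInvoke = pool.invoke(new RangeSum(0, 10_000));

            // submit: returns immediately; the returned ForkJoinTask is a Future.
            ForkJoinTask<Long> future = pool.submit(new RangeSum(0, 10_000));
            long viaSubmit = future.get();

            // execute: asynchronous with no return value; keep the task to read it later.
            RangeSum task = new RangeSum(0, 10_000);
            pool.execute(task);
            long viaExecute = task.join();

            return new long[] {viaInvoke, viaSubmit, viaExecute};
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        for (long result : runAll()) {
            System.out.println(result);
        }
    }
}
```

All three calls compute the same sum; they differ only in whether and how the caller waits for it.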
ForkJoinTask has two abstract subclasses, RecursiveAction and RecursiveTask.
- RecursiveAction is for tasks that return no result:
  abstract class RecursiveAction extends ForkJoinTask<Void>
- RecursiveTask is for tasks that return a result:
  abstract class RecursiveTask<V> extends ForkJoinTask<V>
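To make the distinction concrete, here is a small sketch that pairs one task of each kind. The names ActionVersusTask, DoubleAll, SumAll and the THRESHOLD value are assumptions made for this illustration only.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

class ActionVersusTask {
    static final int THRESHOLD = 1_000; // assumed cutoff below which work is done directly

    // No result: doubles every element of data[lo, hi) in place.
    static class DoubleAll extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        private final int[] data;
        private final int lo;
        private final int hi;

        DoubleAll(int[] data, int lo, int hi) {
            this.data = data;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        protected void compute() {
            if (hi - lo <= THRESHOLD) {
                for (int i = lo; i < hi; i++) {
                    data[i] *= 2;
                }
                return;
            }
            int mid = (lo + hi) >>> 1;
            invokeAll(new DoubleAll(data, lo, mid), new DoubleAll(data, mid, hi));
        }
    }

    // Has a result: returns the sum of data[lo, hi).
    static class SumAll extends RecursiveTask<Long> {
        private static final long serialVersionUID = 1L;
        private final int[] data;
        private final int lo;
        private final int hi;

        SumAll(int[] data, int lo, int hi) {
            this.data = data;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        protected Long compute() {
            if (hi - lo <= THRESHOLD) {
                long sum = 0;
                for (int i = lo; i < hi; i++) {
                    sum += data[i];
                }
                return sum;
            }
            int mid = (lo + hi) >>> 1;
            SumAll left = new SumAll(data, lo, mid);
            SumAll right = new SumAll(data, mid, hi);
            left.fork();
            right.fork();
            return left.join() + right.join();
        }
    }

    static long doubleThenSum(int[] data) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            pool.invoke(new DoubleAll(data, 0, data.length));     // Void task: nothing to read back
            return pool.invoke(new SumAll(data, 0, data.length)); // value-returning task
        } finally {
            pool.shutdown();
        }
    }
}
```

The action communicates only through its side effect on the array, while the task hands its value back through compute().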
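Example 1: Using RecursiveAction
Initialize a large array in three different ways and compare their performance.
- Sequential: a single thread initializes the array in order;
- Fork/Join;
- A thread pool from the concurrency package.
Code: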
package edu.zju.chwl.forkjoin;

import static java.util.Arrays.asList;

import java.util.Random;
import java.util.concurrent.*;

public class ForkJoinRandomFillAction {

    public static void main(String[] args) {
        sequential();
        forkjoin();
        threadPool();
    }

    // Fill array[low, high) with random values in [0, 10000)
    public static void loadArray(int[] array, int low, int high) {
        Random random = new Random();
        for (int i = low; i < high; i++) {
            array[i] = random.nextInt(10000);
        }
    }

    static final int arrayLength = 30000000;
    static final int iterations = 4;
    static final int splitSize = 40000;

    // Sequential implementation
    public static void sequential() {
        int[] array = new int[arrayLength];
        for (int i = 0; i < iterations; i++) {
            long start = System.currentTimeMillis();
            loadArray(array, 0, array.length);
            System.out.println("Sequential processing time: "
                    + (System.currentTimeMillis() - start) + " ms");
        }
    }

    // Parallel recursive implementation using fork/join
    public static void forkjoin() {
        int[] array = new int[arrayLength];
        System.out.println("Number of processor available: "
                + Runtime.getRuntime().availableProcessors());
        ForkJoinPool fjpool = new ForkJoinPool(64);
        for (int i = 0; i < iterations; i++) {
            // Create a task with the complete array
            RecursiveAction task = new RandomFillAction(array, 0, array.length);
            long start = System.currentTimeMillis();
            fjpool.invoke(task); // blocks until the task has completed
            System.out.println("Parallel processing time: "
                    + (System.currentTimeMillis() - start) + " ms");
        }
        System.out.println("Number of steals: " + fjpool.getStealCount() + "\n");
        fjpool.shutdown();
    }

    // Parallel (non-recursive) implementation using an ordinary thread pool.
    // Performance is close to fork/join, but it is less flexible: the range
    // each task handles has to be decided up front.
    public static void threadPool() {
        int[] array = new int[arrayLength];
        ExecutorService threadPool = Executors.newFixedThreadPool(64);
        // Wrap the pool in a completion service so finished tasks can be collected
        ExecutorCompletionService<Boolean> completionService =
                new ExecutorCompletionService<Boolean>(threadPool);
        try {
            for (int i = 0; i < iterations; i++) {
                long start = System.currentTimeMillis();
                int size = (array.length - 1) / splitSize + 1;
                for (int j = 0; j < size; j++) {
                    int low = j * splitSize;
                    // Clamp the last chunk so it never runs past the end of the array
                    int high = Math.min(low + splitSize, array.length);
                    completionService.submit(new FillTask(array, low, high), true);
                }
                for (int j = 0; j < size; j++) {
                    completionService.take().get(); // wait until every task has finished
                }
                System.out.println("Thread Pool processing time: "
                        + (System.currentTimeMillis() - start) + " ms");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Fill task failed", e.getCause());
        } finally {
            // Without this, the non-daemon worker threads keep the JVM alive
            threadPool.shutdown();
        }
    }

    static class FillTask implements Runnable {
        final int low;
        final int high;
        private final int[] array;

        public FillTask(int[] array, int low, int high) {
            this.low = low;
            this.high = high;
            this.array = array;
        }

        @Override
        public void run() {
            loadArray(array, low, high);
        }
    }

    static class RandomFillAction extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        final int low;
        final int high;
        private final int[] array;

        public RandomFillAction(int[] array, int low, int high) {
            this.low = low;
            this.high = high;
            this.array = array;
        }

        @Override
        protected void compute() {
            if (high - low > splitSize) {
                int mid = (low + high) >>> 1;
                invokeAll(asList(new RandomFillAction(array, low, mid),
                        new RandomFillAction(array, mid, high)));
                // Do not invoke the two actions separately: invoke() waits for
                // completion, so the halves would run one after the other.
                // new RandomFillAction(array, low, mid).invoke();
                // new RandomFillAction(array, mid, high).invoke();
            } else {
                loadArray(array, low, high);
            }
        }
    }
}
Output:
Sequential processing time: 378 ms
Sequential processing time: 360 ms
Sequential processing time: 360 ms
Sequential processing time: 362 ms
Number of processor available: 4
Parallel processing time: 238 ms
Parallel processing time: 123 ms
Parallel processing time: 166 ms
Parallel processing time: 187 ms
Number of steals: 260
Thread Pool processing time: 134 ms
Thread Pool processing time: 133 ms
Thread Pool processing time: 119 ms
Thread Pool processing time: 133 ms
Example 2: Using RecursiveTask
Sum n random numbers.
Code:
package edu.zju.chwl.forkjoin;

import java.util.Random;
import java.util.concurrent.*;

public class ForkJoinCount {

    static final int splitSize = 20000000;
    static final int n = 100000000;

    // Sum (high - low) random values, each either 0 or 1
    public static int sum(int low, int high) {
        Random random = new Random();
        int sum = 0;
        for (int i = low; i < high; i++) {
            sum += random.nextInt(2);
        }
        return sum;
    }

    public static void main(String[] args) {
        forkjoin();
        sequential();
    }

    public static void forkjoin() {
        ForkJoinPool pool = new ForkJoinPool(64);
        long start = System.currentTimeMillis();
        // Use the range [0, n) so that exactly n numbers are summed
        ForkJoinTask<Integer> retFuture = pool.submit(new CountTask(0, n));
        try {
            retFuture.get(); // wait for the result
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Count task failed", e.getCause());
        } finally {
            pool.shutdown();
        }
        System.out.println("Parallel processing time: "
                + (System.currentTimeMillis() - start) + " ms");
    }

    public static void sequential() {
        long start = System.currentTimeMillis();
        sum(0, n);
        System.out.println("Sequential processing time: "
                + (System.currentTimeMillis() - start) + " ms");
    }

    static class CountTask extends RecursiveTask<Integer> {
        private static final long serialVersionUID = 1L;
        private final int start;
        private final int end;

        CountTask(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            if ((end - start) < splitSize) {
                return sum(start, end);
            } else {
                int mid = (start + end) >>> 1;
                ForkJoinTask<Integer> leftTask = new CountTask(start, mid).fork();
                ForkJoinTask<Integer> rightTask = new CountTask(mid, end).fork();
                return leftTask.join() + rightTask.join();
                // Do not write it as below: join() waits for the result, so the
                // right task would not even be forked until the left one finished.
                // return new CountTask(start, mid).fork().join() + new CountTask(mid, end).fork().join();
            }
        }
    }
}
Output:
Parallel processing time: 470 ms
Sequential processing time: 1199 ms
Note: ForkJoinTask.join() does not return until the task's result is available, so fork every subtask before joining any of them.
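Because join() waits, a common way to arrange the calls is to fork only one half, compute the other half in the current thread, and join last. The sketch below illustrates that ordering; ForkThenCompute, HalfSum and THRESHOLD are hypothetical names and values, not part of the examples above.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ForkThenCompute {
    static final int THRESHOLD = 10_000; // assumed cutoff for direct computation

    // Hypothetical task: sums the integers in [start, end).
    static class HalfSum extends RecursiveTask<Long> {
        private static final long serialVersionUID = 1L;
        private final int start;
        private final int end;

        HalfSum(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected Long compute() {
            if (end - start <= THRESHOLD) {
                long sum = 0;
                for (int i = start; i < end; i++) {
                    sum += i;
                }
                return sum;
            }
            int mid = (start + end) >>> 1;
            HalfSum left = new HalfSum(start, mid);
            HalfSum right = new HalfSum(mid, end);
            left.fork();                        // schedule the left half asynchronously
            long rightResult = right.compute(); // work on the right half in this thread
            return left.join() + rightResult;   // only now wait for the left half
        }
    }

    static long sum(int n) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new HalfSum(0, n));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(sum(100_000));
    }
}
```

Computing one half directly keeps the current worker busy instead of queuing both halves and then waiting on them.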