An Introduction to ForkJoin in Java

2022-07-15

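Preface:

ForkJoin was introduced in Java 7. Many developers are unfamiliar with it, yet the parallel stream (parallelStream) of the Java 8 Stream API is built on it. The two key classes in the ForkJoin framework are ForkJoinTask and ForkJoinPool. ForkJoin applies divide and conquer: a large task is split (fork) into smaller tasks according to some rule, and their results are then combined (join).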

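What is ForkJoin?

Literally, fork means to branch and join means to combine. A large task is split into small tasks, each small task is solved, and the partial results are combined into the solution of the large task. The small tasks can be handed to different threads, so the computation runs in parallel. This is the same divide-and-conquer idea found in distributed computing, for example MapReduce for offline big-data processing, except that ForkJoin runs on threads inside a single JVM rather than across machines. The best-known use of ForkJoin is the Java 8 Stream API: streams are either sequential or parallel, and parallel streams (parallelStream) rely on ForkJoin for parallel execution.

Let's look at the two core classes, ForkJoinTask and ForkJoinPool.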
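To make the parallel-stream connection concrete, here is a minimal sketch that sums a range once with a sequential stream and once with a parallel stream. The class name ParallelStreamDemo and its method names are made up for this illustration; by default a parallel stream runs its subtasks on the shared ForkJoinPool.commonPool().

```java
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

class ParallelStreamDemo {

    // Sums 1..n with a sequential stream.
    static long sequentialSum(int n) {
        return IntStream.rangeClosed(1, n).asLongStream().sum();
    }

    // Sums 1..n with a parallel stream; the range is split into subtasks
    // that run on the common ForkJoinPool (and possibly the calling thread).
    static long parallelSum(int n) {
        return IntStream.rangeClosed(1, n).asLongStream().parallel().sum();
    }

    public static void main(String[] args) {
        System.out.println("common pool parallelism: "
                + ForkJoinPool.commonPool().getParallelism());
        System.out.println("same result: "
                + (sequentialSum(1_000_000) == parallelSum(1_000_000)));
    }
}
```

Both methods return the same value; only the way the work is scheduled differs.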

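ForkJoinTask: the task

ForkJoinTask has a simple type hierarchy. Like the asynchronous task class FutureTask, it implements the Future interface. FutureTask was covered in an earlier article, "Java: Asynchronous Task Computation with FutureTask, from the Source Code", for readers who are interested.

Below we study the core source of ForkJoinTask to see how a task is computed by divide and conquer.

The heart of ForkJoinTask is the pair of methods fork() and join().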

fork()

  • Checks whether the current thread is a ForkJoinWorkerThread
    • Yes: push this task directly onto that worker thread's work queue
    • No: call ForkJoinPool's externalPush method

ForkJoinPool holds a static common instance; the call made here is common.externalPush().

join()

  • Calls doJoin(), which waits for the task to complete

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

    // The method that returns the result is implemented by subclasses
    public abstract V getRawResult();
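As the join() source shows, a task that does not finish with status NORMAL goes through reportException(s) instead of returning a result. The sketch below illustrates that behavior from the caller's side. FailingTask and JoinExceptionDemo are hypothetical names invented for the example, and the pool size of 2 is arbitrary.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class JoinExceptionDemo {

    // Hypothetical task whose compute() always fails.
    static class FailingTask extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            throw new IllegalStateException("boom");
        }
    }

    // Returns the simple class name of the exception reported by join(),
    // or "none" if join() returned normally.
    static String joinAndDescribe() {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            FailingTask task = new FailingTask();
            pool.execute(task);   // submitted from a non-worker thread
            task.join();          // abnormal completion is rethrown here
            return "none";
        } catch (RuntimeException e) {
            return e.getClass().getSimpleName();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(joinAndDescribe());
    }
}
```

Unlike get(), which wraps failures in an ExecutionException, join() rethrows an unchecked exception of the original type.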

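RecursiveTask is a subclass of ForkJoinTask that mainly implements the result accessors and fixes the result type through its generic parameter. To create our own task we extend RecursiveTask and write the core computation in compute().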

public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    V result;

    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    protected final boolean exec() {
        result = compute();
        return true;
    }

}
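A minimal sketch of a custom RecursiveTask may help here. It sums the integers in a half-open range by splitting the range in two until it is small enough to add up directly. RangeSumTask, its THRESHOLD value, and the sum helper are assumptions made for this example, not part of the JDK.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class RangeSumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1_000; // assumed cutoff for this sketch, not tuned
    private final int from; // inclusive
    private final int to;   // exclusive

    RangeSumTask(int from, int to) {
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += i;
            }
            return sum;
        }
        int mid = from + (to - from) / 2;
        RangeSumTask left = new RangeSumTask(from, mid);
        RangeSumTask right = new RangeSumTask(mid, to);
        left.fork();                      // schedule the left half asynchronously
        long rightSum = right.compute();  // work on the right half in this thread
        return left.join() + rightSum;    // combine once the left half is done
    }

    // Convenience entry point that runs the task on the common pool.
    static long sum(int from, int to) {
        return ForkJoinPool.commonPool().invoke(new RangeSumTask(from, to));
    }

    public static void main(String[] args) {
        System.out.println(sum(0, 100_000));
    }
}
```

Forking one half and computing the other in the current thread keeps the calling worker busy instead of leaving it to wait on two forked subtasks.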

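ForkJoinPool: the thread pool

Much of ForkJoinTask's behavior depends on ForkJoinPool, so a ForkJoinTask cannot run without one. ForkJoinPool has a lot in common with ThreadPoolExecutor; it is a thread pool dedicated to executing ForkJoinTask instances. Thread pools were also covered in an earlier article for readers who are interested.

ForkJoinPool and ThreadPoolExecutor have almost the same inheritance hierarchy (both extend AbstractExecutorService), so they are effectively siblings.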

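The work-stealing algorithm

ForkJoinPool uses a work-stealing algorithm. Creating a new thread for every forked subtask would be very expensive, so a thread pool is required. An ordinary thread pool has a single task queue, but the subtasks forked from one task are independent peers, so ForkJoinPool spreads them over separate queues to improve efficiency and reduce contention between threads. Threads finish at different speeds, so one thread may drain its own queue while others still have work. To keep it busy, that thread is allowed to "steal" tasks from another queue. This is the work-stealing algorithm.

In an ordinary queue, elements are added at the tail and removed from the head. To support work stealing, a task queue must also allow removal at the opposite end, so that the owner and a thief work on different ends and rarely conflict. This calls for a double-ended queue (deque). ForkJoinPool does not use a blocking deque from java.util.concurrent; each worker has its own array-based WorkQueue, where by default the owner pushes and pops at top while stealing threads take from base.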
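The two-ended access pattern can be sketched with an ordinary deque. This is only a toy model: WorkStealingSketch and its methods are invented for the illustration, a ConcurrentLinkedDeque stands in for ForkJoinPool's internal WorkQueue, and the demo runs on a single thread. Which end is called head or tail is a naming matter; here the owner uses the last element and the thief the first.

```java
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;

class WorkStealingSketch {
    // One deque per worker; a stand-in for ForkJoinPool's internal WorkQueue.
    private final Deque<String> queue = new ConcurrentLinkedDeque<>();

    // Owner adds new tasks at its own end.
    void push(String task) {
        queue.addLast(task);
    }

    // Owner takes the most recently pushed task (LIFO); null if empty.
    String pop() {
        return queue.pollLast();
    }

    // A thief takes the oldest task from the opposite end; null if empty.
    String steal() {
        return queue.pollFirst();
    }

    // Single-threaded walk-through of the access pattern.
    static List<String> demo() {
        WorkStealingSketch busy = new WorkStealingSketch();
        busy.push("t1");
        busy.push("t2");
        busy.push("t3");
        List<String> order = new ArrayList<>();
        order.add("thief:" + busy.steal());
        order.add("owner:" + busy.pop());
        order.add("owner:" + busy.pop());
        order.add("thief:" + busy.steal());
        return order;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [thief:t1, owner:t3, owner:t2, thief:null]
    }
}
```

Because the owner and the thief touch opposite ends, they only compete when the deque is down to its last task.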

Constructors

First, the constructors of ForkJoinPool. It offers three of them; the most complete one takes four parameters, which mean the following:

  • int parallelism: the parallelism level (not the maximum number of threads that may exist)
  • ForkJoinWorkerThreadFactory factory: the factory that creates worker threads
  • UncaughtExceptionHandler handler: the handler for uncaught exceptions in worker threads
  • boolean asyncMode: true selects first-in-first-out scheduling for forked tasks, false (the default) selects last-in-first-out

    public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
             defaultForkJoinWorkerThreadFactory, null, false);
    }

    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }

    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission();
    }
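The four-argument constructor can be exercised as in the following sketch. PoolConstructionDemo and newPool are names chosen for this example, and the handler that prints to System.err is just a placeholder.

```java
import java.util.concurrent.ForkJoinPool;

class PoolConstructionDemo {

    // Builds a pool with an explicit parallelism level and scheduling mode.
    static ForkJoinPool newPool(int parallelism, boolean asyncMode) {
        return new ForkJoinPool(
                parallelism,
                ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                // Placeholder handler: log uncaught errors from worker threads.
                (thread, error) ->
                        System.err.println(thread.getName() + " failed: " + error),
                asyncMode);
    }

    public static void main(String[] args) {
        ForkJoinPool pool = newPool(4, true);
        System.out.println("parallelism = " + pool.getParallelism());
        System.out.println("asyncMode   = " + pool.getAsyncMode());
        pool.shutdown();
    }
}
```

A parallelism of zero or less is rejected with an IllegalArgumentException, which corresponds to the checkParallelism call in the constructor above.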

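Submitting tasks

Now let's look at the methods that submit a task.

externalPush should look familiar: fork() calls it when the current thread is not a ForkJoinWorkerThread, and newly submitted tasks go through the same method. In other words, fork amounts to creating a subtask and submitting it.

externalSubmit is the core method and the full version of externalPush. It handles the first submission of the first task to the pool, performing the secondary initialization, and it detects the first submission from an external thread and creates a new shared queue for it.

signalWork(ws, q) signals that work is available so that a worker is activated or created to process the queue.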

    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<?> job;
        if (task instanceof ForkJoinTask<?>) // avoid re-wrap
            job = (ForkJoinTask<?>) task;
        else
            job = new ForkJoinTask.AdaptedRunnableAction(task);
        externalPush(job);
        return job;
    }

    final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        int rs = runState;
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
            ForkJoinTask<?>[] a; int am, n, s;
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP, s + 1);
                U.putOrderedInt(q, QLOCK, 0);
                if (n <= 1)
                    signalWork(ws, q);
                return;
            }
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        externalSubmit(task);
    }

    private void externalSubmit(ForkJoinTask<?> task) {
        int r;                                    // initialize caller's probe
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            if ((rs = runState) < 0) {
                tryTerminate(false, false);     // help terminate
                throw new RejectedExecutionException();
            }
            else if ((rs & STARTED) == 0 ||     // initialize
                     ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
                int ns = 0;
                rs = lockRunState();
                try {
                    if ((rs & STARTED) == 0) {
                        U.compareAndSwapObject(this, STEALCOUNTER, null,
                                               new AtomicLong());
                        // create workQueues array with size a power of two
                        int p = config & SMASK; // ensure at least 2 slots
                        int n = (p > 1) ? p - 1 : 1;
                        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                        n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                        workQueues = new WorkQueue[n];
                        ns = STARTED;
                    }
                } finally {
                    unlockRunState(rs, (rs & ~RSLOCK) | ns);
                }
            }
            else if ((q = ws[k = r & m & SQMASK]) != null) {
                if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                    ForkJoinTask<?>[] a = q.array;
                    int s = q.top;
                    boolean submitted = false; // initial submission or resizing
                    try {                      // locked version of push
                        if ((a != null && a.length > s + 1 - q.base) ||
                            (a = q.growArray()) != null) {
                            int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                            U.putOrderedObject(a, j, task);
                            U.putOrderedInt(q, QTOP, s + 1);
                            submitted = true;
                        }
                    } finally {
                        U.compareAndSwapInt(q, QLOCK, 1, 0);
                    }
                    if (submitted) {
                        signalWork(ws, q);
                        return;
                    }
                }
                move = true;                   // move on failure
            }
            else if (((rs = runState) & RSLOCK) == 0) { // create new queue
                q = new WorkQueue(this, null);
                q.hint = r;
                q.config = k | SHARED_QUEUE;
                q.scanState = INACTIVE;
                rs = lockRunState();           // publish index
                if (rs > 0 &&  (ws = workQueues) != null &&
                    k < ws.length && ws[k] == null)
                    ws[k] = q;                 // else terminated
                unlockRunState(rs, rs & ~RSLOCK);
            }
            else
                move = true;                   // move if busy
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }
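From the caller's side, the submit path above looks like the following sketch. SubmitDemo and its local variable names are made up for the example. It submits a plain Runnable, which the pool wraps in an adapter task, and then an existing ForkJoinTask, which is queued without re-wrapping.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicInteger;

class SubmitDemo {

    // Returns 42 when both submissions behave as described, otherwise -1.
    static int run() {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            AtomicInteger counter = new AtomicInteger();

            // A plain Runnable is wrapped in an adapter task before it is queued.
            Runnable increment = () -> counter.incrementAndGet();
            ForkJoinTask<?> wrapped = pool.submit(increment);
            wrapped.join(); // wait until the counter has been incremented to 1

            // An existing ForkJoinTask is queued as-is and handed back to the caller.
            Callable<Integer> plus41 = () -> counter.get() + 41;
            ForkJoinTask<Integer> task = ForkJoinTask.adapt(plus41);
            ForkJoinTask<Integer> returned = pool.submit(task);
            int result = returned.join();

            return returned == task ? result : -1;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```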

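Creating workers (threads)

After a task is submitted, signalWork(ws, q) signals that work is available. While too few workers are active (ctl < 0), it loops: if there is no idle worker and the pool has too few threads, it calls tryAddWorker, which creates a thread through the thread factory and starts it; otherwise it wakes up an idle worker. If creating or starting the thread fails, deregisterWorker cleans up. The same method also handles workers that leave the pool.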

    final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        while ((c = ctl) < 0L) {                       // too few active
            if ((sp = (int)c) == 0) {                  // no idle workers
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    tryAddWorker(c);
                break;
            }
            if (ws == null)                            // unstarted/terminated
                break;
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            if ((v = ws[i]) == null)                   // terminating
                break;
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int d = sp - v.scanState;                  // screen CAS
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }

    private void tryAddWorker(long c) {
        boolean add = false;
        do {
            long nc = ((AC_MASK & (c + AC_UNIT)) |
                       (TC_MASK & (c + TC_UNIT)));
            if (ctl == c) {
                int rs, stop;                 // check if terminating
                if ((stop = (rs = lockRunState()) & STOP) == 0)
                    add = U.compareAndSwapLong(this, CTL, c, nc);
                unlockRunState(rs, rs & ~RSLOCK);
                if (stop != 0)
                    break;
                if (add) {
                    createWorker();
                    break;
                }
            }
        } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
    }

    private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }

    final void deregisterWorker(ForkJoinWorkerThread wt, Throwable ex) {
        WorkQueue w = null;
        if (wt != null && (w = wt.workQueue) != null) {
            WorkQueue[] ws;                           // remove index from array
            int idx = w.config & SMASK;
            int rs = lockRunState();
            if ((ws = workQueues) != null && ws.length > idx && ws[idx] == w)
                ws[idx] = null;
            unlockRunState(rs, rs & ~RSLOCK);
        }
        long c;                                       // decrement counts
        do {} while (!U.compareAndSwapLong
                     (this, CTL, c = ctl, ((AC_MASK & (c - AC_UNIT)) |
                                           (TC_MASK & (c - TC_UNIT)) |
                                           (SP_MASK & c))));
        if (w != null) {
            w.qlock = -1;                             // ensure set
            w.transferStealCount(this);
            w.cancelAll();                            // cancel remaining tasks
        }
        for (;;) {                                    // possibly replace
            WorkQueue[] ws; int m, sp;
            if (tryTerminate(false, false) || w == null || w.array == null ||
                (runState & STOP) != 0 || (ws = workQueues) == null ||
                (m = ws.length - 1) < 0)              // already terminating
                break;
            if ((sp = (int)(c = ctl)) != 0) {         // wake up replacement
                if (tryRelease(c, ws[sp & m], AC_UNIT))
                    break;
            }
            else if (ex != null && (c & ADD_WORKER) != 0L) {
                tryAddWorker(c);                      // create replacement
                break;
            }
            else                                      // don't need replacement
                break;
        }
        if (ex == null)                               // help clean on way out
            ForkJoinTask.helpExpungeStaleExceptions();
        else                                          // rethrow
            ForkJoinTask.rethrow(ex);
    }

    public static interface ForkJoinWorkerThreadFactory {
        public ForkJoinWorkerThread newThread(ForkJoinPool pool);
    }

    static final class DefaultForkJoinWorkerThreadFactory
        implements ForkJoinWorkerThreadFactory {
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            return new ForkJoinWorkerThread(pool);
        }
    }

    protected ForkJoinWorkerThread(ForkJoinPool pool) {
        // Use a placeholder until a useful name can be set in registerWorker
        super("aForkJoinWorkerThread");
        this.pool = pool;
        this.workQueue = pool.registerWorker(this);
    }

    final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
        UncaughtExceptionHandler handler;
        wt.setDaemon(true);                           // configure thread
        if ((handler = ueh) != null)
            wt.setUncaughtExceptionHandler(handler);
        WorkQueue w = new WorkQueue(this, wt);
        int i = 0;                                    // assign a pool index
        int mode = config & MODE_MASK;
        int rs = lockRunState();
        try {
            WorkQueue[] ws; int n;                    // skip if no array
            if ((ws = workQueues) != null && (n = ws.length) > 0) {
                int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
                int m = n - 1;
                i = ((s << 1) | 1) & m;               // odd-numbered indices
                if (ws[i] != null) {                  // collision
                    int probes = 0;                   // step by approx half n
                    int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                    while (ws[i = (i + step) & m] != null) {
                        if (++probes >= n) {
                            workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                            m = n - 1;
                            probes = 0;
                        }
                    }
                }
                w.hint = s;                           // use as random seed
                w.config = i | mode;
                w.scanState = i;                      // publication fence
                ws[i] = w;
            }
        } finally {
            unlockRunState(rs, rs & ~RSLOCK);
        }
        wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
        return w;
    }
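The thread factory is also the hook for customizing workers from application code. The sketch below plugs in a factory that delegates to the default one and only renames the thread, then asks a task which thread it ran on. WorkerFactoryDemo, describeWorker, and the "demo-worker-" prefix are all invented for this illustration.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

class WorkerFactoryDemo {

    // Returns { thread name, isDaemon, isForkJoinWorkerThread } as seen by a task.
    static String[] describeWorker() {
        AtomicInteger ids = new AtomicInteger();
        ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
            // Delegate creation to the default factory, then rename the thread.
            ForkJoinWorkerThread worker =
                    ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            worker.setName("demo-worker-" + ids.getAndIncrement());
            return worker;
        };
        ForkJoinPool pool = new ForkJoinPool(2, factory, null, false);
        try {
            return pool.submit(() -> {
                Thread t = Thread.currentThread();
                return new String[] {
                    t.getName(),
                    String.valueOf(t.isDaemon()),
                    String.valueOf(t instanceof ForkJoinWorkerThread)
                };
            }).join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(String.join(", ", describeWorker()));
    }
}
```

The worker reports itself as a daemon thread, which matches the wt.setDaemon(true) call in registerWorker.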

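Example: merge sort with ForkJoinTask

We use the classic merge sort as an example. We build our own ForkJoinTask, override its core compute() method following the merge sort approach, submit the task with forkJoinPool.submit(task), and obtain the result synchronously with get().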

package com.zhj.interview;

import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class Test16 {

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        int[] bigArr = new int[10000000];
        for (int i = 0; i < 10000000; i++) {
            bigArr[i] = (int) (Math.random() * 10000000);
        }
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        MyForkJoinTask task = new MyForkJoinTask(bigArr);
        long start = System.currentTimeMillis();
        forkJoinPool.submit(task).get();
        long end = System.currentTimeMillis();
        System.out.println("Elapsed (ms): " + (end - start));
    }

}

class MyForkJoinTask extends RecursiveTask<int[]> {

    private int[] source;

    public MyForkJoinTask(int[] source) {
        if (source == null) {
            throw new RuntimeException("Invalid argument: source is null");
        }
        this.source = source;
    }

    @Override
    protected int[] compute() {
        int l = source.length;
        // Zero or one element: already sorted
        if (l < 2) {
            return Arrays.copyOf(source, l);
        }
        // Two elements: swap them if they are out of order
        if (l == 2) {
            if (source[0] > source[1]) {
                int[] tar = new int[2];
                tar[0] = source[1];
                tar[1] = source[0];
                return tar;
            } else {
                return Arrays.copyOf(source, l);
            }
        }
        // More than two elements: split in half, sort both halves as subtasks, then merge
        if (l > 2) {
            int mid = l / 2;
            MyForkJoinTask task1 = new MyForkJoinTask(Arrays.copyOf(source, mid));
            task1.fork();
            MyForkJoinTask task2 = new MyForkJoinTask(Arrays.copyOfRange(source, mid, l));
            task2.fork();
            int[] res1 = task1.join();
            int[] res2 = task2.join();
            int[] tar = merge(res1, res2);
            return tar;
        }
        return null;
    }

    // Merge two sorted arrays into one sorted array
    private int[] merge(int[] res1, int[] res2) {
        int l1 = res1.length;
        int l2 = res2.length;
        int l = l1 + l2;
        int[] tar = new int[l];
        for (int i = 0, i1 = 0, i2 = 0; i < l; i++) {
            // An exhausted array contributes Integer.MAX_VALUE as a sentinel
            int v1 = i1 >= l1 ? Integer.MAX_VALUE : res1[i1];
            int v2 = i2 >= l2 ? Integer.MAX_VALUE : res2[i2];
            // If true, take the next value from res1
            if (v1 < v2) {
                tar[i] = v1;
                i1++;
            } else {
                tar[i] = v2;
                i2++;
            }
        }
        return tar;
    }
}
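The example above copies the array at every split and recurses down to two elements. A common variation, sketched here under stated assumptions rather than as the definitive approach, sorts in place with a RecursiveAction, switches to Arrays.sort below a cutoff, and starts both halves with invokeAll. ParallelMergeSort, its THRESHOLD value, and the shared buffer are choices made for this sketch.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class ParallelMergeSort extends RecursiveAction {
    private static final int THRESHOLD = 8_192; // assumed cutoff, not tuned
    private final int[] data;
    private final int[] buffer; // scratch space shared by all subtasks
    private final int from;     // inclusive
    private final int to;       // exclusive

    ParallelMergeSort(int[] data, int[] buffer, int from, int to) {
        this.data = data;
        this.buffer = buffer;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() {
        if (to - from <= THRESHOLD) {
            Arrays.sort(data, from, to); // small range: sort sequentially
            return;
        }
        int mid = (from + to) >>> 1;
        // Sort both halves as subtasks; invokeAll returns when both are done.
        invokeAll(new ParallelMergeSort(data, buffer, from, mid),
                  new ParallelMergeSort(data, buffer, mid, to));
        merge(mid);
    }

    // Merges the sorted halves [from, mid) and [mid, to) back into data.
    private void merge(int mid) {
        System.arraycopy(data, from, buffer, from, to - from);
        int i = from, j = mid, k = from;
        while (i < mid && j < to) {
            data[k++] = buffer[i] <= buffer[j] ? buffer[i++] : buffer[j++];
        }
        while (i < mid) {
            data[k++] = buffer[i++];
        }
        while (j < to) {
            data[k++] = buffer[j++];
        }
    }

    // Sorts the array in place using the common pool.
    static void sort(int[] data) {
        ForkJoinPool.commonPool().invoke(
                new ParallelMergeSort(data, new int[data.length], 0, data.length));
    }
}
```

Sibling subtasks only touch disjoint index ranges of data and buffer, so they can share both arrays without locking.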

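The ForkJoin computation flow

Submitting a task through ForkJoinPool and obtaining its result follows this flow: the task is submitted to the pool, split into subtasks with fork(), the subtasks are computed by worker threads, and their results are combined with join(). Subtasks do not have to be split in two; the split can follow the MapReduce pattern or be designed freely to fit the requirement at hand.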
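As an illustration of a split that is not binary, the sketch below cuts an array into a caller-chosen number of chunks, runs all chunks with invokeAll, and adds up the partial sums. ChunkedSum, RangeSum, and the chunks parameter are names assumed for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ChunkedSum extends RecursiveTask<Long> {
    private final int[] data;
    private final int chunks;

    ChunkedSum(int[] data, int chunks) {
        if (chunks < 1) {
            throw new IllegalArgumentException("chunks must be >= 1");
        }
        this.data = data;
        this.chunks = chunks;
    }

    // Leaf task: sums data[from, to) sequentially.
    private static final class RangeSum extends RecursiveTask<Long> {
        private final int[] data;
        private final int from;
        private final int to;

        RangeSum(int[] data, int from, int to) {
            this.data = data;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data[i];
            }
            return sum;
        }
    }

    @Override
    protected Long compute() {
        // Chunk size by ceiling division, at least 1 so the loop always advances.
        int size = Math.max(1, (data.length + chunks - 1) / chunks);
        List<RangeSum> parts = new ArrayList<>();
        for (int start = 0; start < data.length; start += size) {
            parts.add(new RangeSum(data, start, Math.min(data.length, start + size)));
        }
        long total = 0;
        for (RangeSum part : invokeAll(parts)) { // runs every chunk, waits for all
            total += part.join();
        }
        return total;
    }

    static long sum(int[] data, int chunks) {
        return ForkJoinPool.commonPool().invoke(new ChunkedSum(data, chunks));
    }

    public static void main(String[] args) {
        System.out.println(sum(new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 3)); // 55
    }
}
```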
