熟悉这个错误么?还记得在写递归方法时的抓狂么?在本文中,我们将基于HotSpot虚拟机讨论一种叫Trampoline的技术,它可以彻底消除StackOverflowError
。另外本文中保留了一些英文单词,因为个人觉得中文翻译并不会增加读者的理解程度,反而有可能造成信息丢失,造成的不便还请谅解。
什么是 Java Virtual Machine Stack?
要搞懂这个错误,我们必须先了解Java Virtual Machine Stack。
Java Virtual Machine Stack 以前也叫Java Stack。它可以记录当前线程中当前方法的当前状态,这种状态是以stack frame的形式存储的。Java Virtual Machine Stack是在创建新线程时被同时创建的,而且只有这个新线程可以操作它。
一个stack frame由三部分组成:局部变量(local variables),操作对象栈(operand stacks)和动态链接(dynamic links)。当方法A调用方法B时,B的stack frame会被创建并压入栈。当B运行结束时,B的stack frame会被弹出并销毁。A的stack frame这时会重新回到栈顶,且会被用来恢复A的状态以继续执行后续操作。
假设我们有一个方法叫factorial
public Long factorial(Long n) {
if (n == 1) {
return 1l;
}
return n * factorial(n - 1);
}
当调用factorial(4)
时, Java Virtual Machine Stack的空间变化如下
每个线程的Java Virtual Machine Stack是有最大空间限制的,我们可以通过下面的命令查看默认最大空间
java -XX:+PrintFlagsFinal -version|grep ThreadStackSize
我们也可以通过命令行参数来指定它的最大空间
java -Xss2M //设置最大空间为2M
OS | Default Stack Size |
---|---|
Windows |
40KB |
Linux AArch64 | 72KB |
Linux RISC-V | 72KB |
Linux s390 | 32KB |
Linux ARM | 32KB |
Linux x86 | 40KB |
什么是 StackOverflowError?
根据Java Virtual Machine Specification的描述,在一个线程中,如果运行需要的Java Virtual Machine Stack空间大于分配空间,Java虚拟机就会抛出StackOverflowError
。
例如,当我们调用factorial(10000l)
时,它就会抛出StackOverflowError
。因为在调用factorial(1l)
并开始弹出stack frame之前,Java Virtual Machine Stack的空间就已经被用光了。
factorial(10000l);
java.lang.StackOverflowError
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
.....
尾调用和尾递归
尾调用是指该方法调用是当前方法的最后一步,例如
public Integer add(Integer x, Integer y) {
return x + y;
}
public Ineger substract(Integer x, Integer y) {
return add(x, -1*y); // 尾调用
}
在上面的代码中,add(x, <result of -1*y>)
就是一个尾调用。
但是对factorial
来说,factorial(n-1)
就不是一个尾调用,因为它的最后一步其实是*
public Long factorial(Long n) {
if (n == 1) {
return 1l;
}
Long nextFacorial = factorial(n-1); // 不是尾调用
return n * nextFactorial; // 方法的最后一步
}
如果尾调用的方法名和当前方法名相同,我们就称当前方法为尾递归,例如
public Long fibonacci(Long n, Long a, Long b) { // 尾递归
if (n == 0) {
return a;
}
if (n == 1) {
return b;
}
return fibonacci(n - 1, b, a + b); // 尾调用
}
如何消除尾递归?
尾递归可以被重构成一个while
循环,例如
public Long fibonacci(Long n, Long a, Long b) {
Long nParam = n;
Long aParam = a;
Long bParam = b;
while (true) {
if (nParam == 0) {
return aParam;
}
if (nParam == 1) {
return bParam;
}
nParam = nParam - 1;
Long aCurrent = aParam;
aParam = bParam;
bParam = aCurrent + bParam;
}
}
我们可以按照下面的步骤来消除尾递归
- 为每个参数创建一个局部变量,例如
nParam
,aParam
和bParam
- 将方法体包在一个
while(true)
循环里,并将参数引用全部替换为对应局部变量的引用,例如行5到行15。 - 将尾调用替换为局部变量赋值,也就是将尾调用的各个参数赋给对应的局部变量,例如行12到行15。
尾递归完全可以由编译器自动消除,Scala已经做到了这一点,但Java目前还不支持。Java不支持的其中一个原因是多态,编译器无法知道当前方法是否被子类重写,也就没有办法用当前方法的逻辑来消除尾递归。即便在Scala中,编译器也要求尾递归的方法是private
或final
的,这样它们就不能被重写了。
如何消除 StackOverflowError?
没有方法调用的方法是不可能抛出StackOverflowError
的。让我们把注意力集中在那些有方法调用的方法。目前已知
StackOverflowError
是由Java Virtual Machine Stack的空间限制引起的- 我们可以用
while(loop)
来消除尾递归,这样需要的Java Virtual Machine Stack空间更少 - 堆通常要比Java Virtual Machine Stack的空间大
我们是否可以利用堆和尾递归来消除StackOverflowError
呢?这里给出的答案就是trampoline技术。
什么是 trampoline?
为了描述方便,这里我们将Java Virtual Machine Stack简称为栈。
Trampoline 是一种用堆空间来替换栈空间的技术。StackOverflowError
的根本原因是当前方法需要等到被调方法返回才能释放栈空间。Trampoline可以让当前方法在调用完方法后立即释放栈空间,不需要等待被调方法返回,因为它是将调用状态保存在堆里,而不是栈里。
这样做的代价就是我们需要自己从堆中获取调用状态然后执行,trampoline会有一个专门的方法来做这件事情,而且这个方法可以用尾递归来实现。
和普通方法调用不同的是,trampoline执行时栈的空间大小是有规律的升高和降低,就像蹦床一样。
CPS
为了使trampoline的应用更加容易,我们需要先将方法从常规风格转换为CPS 风格。
CPS的英文全称是continuation-passing style,是一种编码风格,它的显著特征是将方法执行完成后的后续操作显式的传递给当前方法。
对于CPS风格的方法,我们需要添加一个额外参数,参数的类型是函数,一般叫它continuation。在当前方法执行完成后,它不会将结果返回,而是将结果作为参数调用continuation。
例如,我们可以将factorial
重写为CPS风格
public void factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
continuation.accept(1l);
return;
}
factorial(n - 1, (Long result) -> continuation.accept(n * result));
}
现在它变成了尾递归,我们可以将它转换为while
循环
public void factorial(Long n, Consumer<Long> continuation) {
Long nParam = n;
Consumer<Long> continuationParam = continuation;
while (true) {
if (nParam == 1) {
continuationParam.accept(1l);
return;
}
Long nCurrent = nParam;
nParam = nParam - 1;
final Consumer<Long> currentContinuation = continuationParam;
continuationParam = (Long result) -> currentContinuation.accept(nCurrent * result);
}
}
但是转换后的方法还是会抛出StackOverflowError
factorial(10000l, (x) -> {});
java.lang.StackOverflowError
at java.base/java.lang.Long.longValue(Long.java:1353)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
....
错误是由第6行的continuationParam.accept(1l)
抛出的,因为在每次循环后第13行都会调用一次currentConinuation
,造成了多层函数的嵌套调用。
并不是所有的方法在重写成CPS风格后就会变成尾递归,例如
public boolean isEven(Long n) {
if (n == 0)
return true;
return isOdd(n - 1);
}
public boolean isOdd(Long n) {
if (n == 0)
return false;
return isEven(n - 1);
}
我们可以把这两个方法重写成CPS风格
public void isEven(Long n, Consumer<Boolean> continuation) {
if (n == 0) {
continuation.accept(true);
return;
}
isOdd(n - 1, (result) -> continuation.accept(result));
}
public void isOdd(Long n, Consumer<Boolean> continuation) {
if (n == 0) {
continuation.accept(false);
return;
}
isEven(n - 1, (result) -> continuation.accept(result));
}
上面代码中的方法调用是尾调用,但isEven
和isOdd
并不是尾递归。
我们并不能通过CPS重写来消除StackOverflowError
,而且CPS风格的代码是很容易出错的。但它的好处是所有的方法调用都是尾调用,且方法间的执行顺序是显式的,这些都让trampoline的应用更加容易。
如何实现 trampoline?
简单来说trampoline就是先将每次方法调用封装在一个thunk 函数里,然后将所有thunk函数在一个循环里挨个调用,直到得到最终结果。
Thunk函数的中文翻译是形实转换程序,鉴于它的中文名并不能提升我们的认知程度,后续我们仍将称其为thunk函数。
Thunk函数是一个没有参数的函数,它的实现类似于Java中的Supplier
或Scala中的Lazy
。
Supplier<Long> thunk = () -> 1l
为了消除StackOverflowError
,我们需要把方法调用的控制权从JVM那里拿过来。最简单的方式就是把方法调用放在一个thunk函数里,那么只有在我们调用thunk函数时对应的方法才会被调用,控制权也就落到了我们手里。
Supplier<Long> thunk = () -> factorial(4l); // 创建一个 Supplier 实例,factorial(4l) 并不会被调用
thunk.get(); // 调用thunk时,factorial(4) 被调用
如果我们用常规的编码风格,thunk函数必须在当前方法里调用,因为后续操作依赖于它的返回值,但这也就失去了使用它的意义
private Long factorial(Long n) {
if (n == 1) {
return 1l;
}
Supplier<Long> thunk = () -> factorial(n-1);
return n * thunk.get();
}
如果我们使用CPS风格,也同样面临这个问题
public void factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
Supplier<Void> thunk = () -> {
continuation.accept(1l);
return null;
};
think.get();
return;
}
Supplier<Void> thunk = () -> {
factorial(n - 1, (Long result) -> continuation.accept(n * result));
return null;
};
thunk.get();
}
但我们看到thunk.get()
是尾调用,我们可以将thunk返回,这样什么时候调用它就由我们决定了。
public Supplier<Void> factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
Supplier<Void> thunk = () -> {
continuation.accept(1l);
return null;
};
return thunk;
}
Supplier<Void> thunk = () -> {
Supplier<Void> thunkContinuation =
factorial(n - 1, (Long result) -> continuation.accept(n * result));
thunkContinuation.get();
return null;
};
return thunk;
}
不过这个方法仍然会抛出StackOverflowError
,因为在我们调用thunk.get()
时,它又会调用thunkContinuation.get()
,造成了多层函数的嵌套调用。根本原因是我们虽然把factorial
的尾调用作为thunk函数返回了,却忽略了thunk里的尾调用。
为了返回thunk里的尾调用,我们需要将thunk的函数签名从Supplier<Void>
变为 Supplier<Supplier<Supplier<....>>>
。但我们无法只用Supplier
来定义这种递归类型,这里需要引入一个新类型
public class More {
private Supplier<More> thunk;
public More(Supplier<More> thunk){
this.thunk = thunk;
}
}
然后我们可以把More
作为thunk函数返回
public More factorial(Long n, Function<Long, More> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1, (Long result) -> new More(() -> continuation.apply(n * result))));
}
这里我们同时改变了continuation
的类型,有两个原因
continuation
也可能抛出StackOverflowError
,它也需要应用trampoline- 在
n==1
时,我们没有办法用Supplier<Void>
构造出More
做了上面的修改之后,我们可以循环调用thunk函数来计算结果了
public static void run(More trampoline) {
run(trampoline.getThunk().get());
}
大家可能已经发现,我们现在没有办法调用factorial
来返回thunk函数,因为我们没有办法在continuation
里面实例化More
。More
的实例化需要另外一个More
的实例,这就陷入了死循环。
More trampoline = factorial(4l, x -> new More(y -> new More(z -> ....)))
所以我们需要一个新的类型Done
来表示计算已经结束,而且它需要和More
有共同的父类。
public interface Trampoline {
}
public class Done implements Trampoline {
public Done() {
}
}
public class More implements Trampoline {
private Supplier<Trampoline> thunk;
public More(Supplier<Trampoline> thunk) {
this.thunk = thunk;
}
}
我们的run
函数同样需要能够识别Done
来及时结束计算。
public static void run(Trampoline trampoline) {
if (trampoline instanceof Done) {
return;
}
run(((More) trampoline).getThunk().get());
}
现在我们的factorial
函数变成了这个样子
public Trampoline factorial(Long n, Function<Long, Trampoline> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1, (Long result) -> new More(() -> continuation.apply(n * result))));
}
我们也可以调用它了
Trampoline trampoline = factorial(4l, x -> new Done())
run(trampoline);
不要忘了,尾递归是可以被消除的,消除后即便我们调用factorial(10000l, x -> Done())
,也不会有StackOverflowError
抛出。
public static void run(Trampoline trampoline) {
Trampoline trampolineParam = trampoline;
while (true) {
if (trampolineParam instanceof Done) {
return;
}
trampolineParam = ((More) trampolineParam).getThunk().get();
}
}
这里总结一下我们是如何在CPS风格的方法上应用trampoline的
- 将返回类型从
void
变为Trampoline
- 将参数
continuation
的类型从Consumer<Long>
变为Function<Long, Trampoline>
- 将所有的尾调用替换为
More
,这里也包括continuation
中的尾调用 - 使用
run
方法来遍历Tramopline
调用run(facorial(4))
时栈的空间变化如下
我们可以看到每一次方法调用都立即返回,没有出现stack frame的持续叠加。
如果我们不把continuattion
中的尾调用替换为More
,factorial
的实现会变成
public Trampoline factorial(Long n, Function<Long, Trampoline> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1,
result -> continuation.apply(n * result))); // call continuation.apply directly
}
它会抛出StackOverflowError
,以run(facorial(4))
为例,栈的空间变化如下
可以看到在调用continuation3(1)
时,stack frame在持续叠加。这是因为它没有返回thunk而是直接调用了continuation.apply
,这样JVM就会将调用状态存储在栈上,trampoline也就无法对其进行控制了。
如何让trampoline更容易使用?
使用上节实现的trampoline有几个痛点
- CPS风格不是我们日常习惯的风格,难以使用
- 在将尾调用替换为
More
时非常容易出错
其实CPS风格的方法是用栈来存储当前方法执行完成后的后续操作,也就是把后续操作作为参数传给当前方法。考虑到所有方法都会应用trampoline,我们可以使用堆来存储当前方法的执行结果与后续操作的关系,存储这种关系的类叫FlatMap
。
public class FlatMap<A, B> implements Trampoline<B> {
private Trampoline<A> lastResult;
private Function<A, Trampoline<B>> continuation;
public FlatMap(Trampoline<A> lastResult, Function<A, Trampoline<B>> continuation) {
this.lastResult = lastResult;
this.continuation = continuation;
}
}
这个类包含了当前方法的执行结果lastResult
以及后续操作continuation
,这里我们将Trampoline
变成了泛型,用以表示方法返回结果的类型,Do
和More
也要做相应的重构。
public interface Trampoline<T> {
}
public class Done<T> implements Trampoline<T> {
private T result;
public Done(T result) {
this.result = result;
}
}
public class More<T> implements Trampoline<T> {
private Supplier<Trampoline<T>> thunk;
public More(Supplier<Trampoline<T>> thunk) {
this.thunk = thunk;
}
}
同时run
方法需要按下面的场景处理FlatMap
- 如果方法返回的是
Done
,直接调用continuation
- 如果方法返回的是
More
,调用More
,将得到的返回值与当前FlatMap
的continuation
组成新的FlatMap
- 如果方法返回的是
FlatMap
,将FlatMap(FlatMap(trampoline, g), f)
转换为FlatMap(trampoline, x -> FlatMap(g(x), f))
public static <S> S run(Trampoline<S> trampoline) {
if (trampoline instanceof Done) {
return ((Done<S>) trampoline).getResult();
} else if (trampoline instanceof More) {
return run(((More<S>) trampoline).getThunk().get());
} else {
FlatMap<Object, S> continuation = (FlatMap<Object, S>) trampoline;
Trampoline<Object> lastResult = continuation.getLastResult();
Function<Object, Trampoline<S>> continuationFunc = continuation.getContinuation();
if (lastResult instanceof FlatMap) {
FlatMap<Object, Object> lastResultContinuation =
(Continuation<Object, Object>) lastResult;
return run(new FlatMap<Object, S>(lastResultContinuation.getLastResult(),
x -> new FlatMap<Object, S>(
lastResultContinuation.getContinuation().apply(x),
continuationFunc)));
} else if (lastResult instanceof More) {
return run(new FlatMap<Object, S>(((More<Object>) lastResult).getThunk().get(),
continuationFunc));
} else {
return run(continuationFunc.apply(((Done<Object>) lastResult).getResult()));
}
}
}
现在我们不再需要将方法转换为CPS风格就可以直接应用trampoline
private Trampoline<Long> factorial(Long n) {
if (n == 1) {
return new Done<Long>(1l);
}
return new FlatMap<Long, Long>(factorial(n - 1), x -> new Done<Long>(n * x));
}
尴尬,它还是抛了StackOverflowError
,为什么?
Trampoline<Long> trampoline = factorial(10000l);
Trampoline.run(trampoline);
java.lang.StackOverflowError
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
....
原因是我们直接调用了factorial(n - 1)
new FlatMap<Long, Long>(factorial(n - 1), x -> new Done<Long>(n * x))
在我们创建FlatMap
实例之前,我们需要等待factorial(n - 1)
返回,而它会创建另外一个FlatMap
实例,又要等待factorial(n - 2)
返回,循环往复,stack frame不断堆叠。
我们可以使用More
来避免这种情况
new FlatMap<Long, Long>(new More<Long>(() -> factorial(n - 1)),
x -> new Done<Long>(n * x));
这里最重要的事情是一定要把所有的方法调用替换为More
。
从上面的例子可以看出,在创建Trampoline实例时还是不够方便,我们可以添加一些快捷方法来让它更易使用
public static <A> Trampoline<A> of(A v) {
return new Done<A>(v);
}
public static <A> Trampoline<A> suspend(Supplier<Trampoline<A>> thunk) {
return new More<A>(thunk);
}
<B> Trampoline<B> flatMap(Function<T, Trampoline<B>> continuation);
flatMap
需要每个子类单独实现
// Done and More
public <B> Trampoline<B> flatMap(Function<T, Trampoline<B>> continuation) {
return new FlatMap<T, B>(this, continuation);
}
//FlatMap
public <C> Trampoline<C> flatMap(Function<B, Trampoline<C>> nextContinuation) {
return new FlatMap<A, C>(lastResult,
x -> Trampoline.suspend(() -> continuation.apply(x)).flatMap(nextContinuation));
}
我们还会将FlatMap
的构造函数设为protected
,这样我们就可以对FlatMap
的构造制定一些规则,用这些规则来避免上面提到的直接调用问题。
通过使用这些快捷方法,factorial
会变的更加简洁
public Trampoline<Long> factorial(Long n) {
if (n == 1) {
return Trampoline.of(1l);
}
return Trampoline.suspend(() -> factorial(n - 1))
.flatMap(x -> Trampoline.of(n * x));
}
总结
Trampoline在函数式编程中是消除StackOverflowError
的重要技术。它也是我们要看懂一些函数式编程库源码的必备知识。
根据本文的实现,trampoline实际上是一个Free Monad,目前缺少的部分是map
public <B> Trampoline<B> map(Function<T, B> continuation) {
return new FlatMap<T, B>(this, x -> Trampoline.suspend(() -> Trampoline.of(continuation.apply(x))));
}
本文大多数的内容是Stackless Scala With Free Monads在Java中的解释,但是该文章中的几个观点我不是很理解
-
在4.3 An easy thing to get wrong中由左连接FlatMap引起的
SackOverflowError
(TheStackOverflowError
caused by the left-leaning tower of FlatMaps)我试图用下面的代码来复现这个错误
@Test public void tooManyLeftAssociateContinuationWillNotThrowError() { Trampoline<Long> trampoline = new Done<Long>(1l); for (int i = 1; i < 50000; i++) { trampoline = new FlatMap<Long, Long>(trampoline, x -> new Done<Long>(x)); } assertThat(Trampoline.run(trampoline)).isEqualTo(1); }
但是它并没有抛出
StackOverflowError
,我也画出了它的栈空间变化图,除非continuation抛出StackOverflowError
,否则这个场景是不会抛出StackOverflowError
的。所以问题不在于左连接的FlatMap
,问题在于continuation,但它又不在我们控制范围内 。 -
在 4.3 An easy thing to get wrong 中
flatMap
的实现方式文章中
flatMap
的实现如下def flatMap [B ]( f: A => Trampoline [B ]): Trampoline [ B] = this match { case FlatMap (a , g) => FlatMap (a , (x: Any ) = > g (x) flatMap f) case x => FlatMap (x , f ) }
它在下面的场景中抛出了
StackOverflowError
@Test public void tooManyFlatMapWillThrowError() { Trampoline<Long> trampoline = Trampoline.of(1l); for (int i = 1; i < 50000; i++) { trampoline = trampoline.flatMap(x -> Trampoline.of(x)); } assertThat(Trampoline.run(trampoline)).isEqualTo(1); }
对应栈的空间变化如下
我们可以通过将
continuation.apply(x)
(也就是文章中的)替换为More
来修复这个问题//FlatMap public <C> Trampoline<C> flatMap(Function<B, Trampoline<C>> nextContinuation) { return new FlatMap<A, C>(lastResult, x -> Trampoline.suspend(() -> continuation.apply(x)).flatMap(nextContinuation)); }
修复后栈的空间变化如下
我已经把本文中所有的代码上传到了trampoline-example,欢迎大家审阅,如果发现任何问题也请告知我,希望这篇文章能够帮助大家理解Trampoline。