ベイズ統計の最高事後密度区間を Java で求める

ベイズ統計における「信用区間」の一つである最高事後密度区間 (Highest posterior density interval, HPDI) について、その区間を求める方法を調べて Java で実装してみたメモです。

最高事後密度区間とは?

そもそもの 最高事後密度区間 を自分自身がちゃんと理解していないので、まずは定義を確認しておきます。

Web 上の日本語文献だと「確率分布の密度が高い部分の n%」みたいなゆるふわな説明が多かったので、あえて Cross Validated の What is a Highest Density Region (HDR)? の回答 を参考にしてみると、その定義は以下のようになるようです。

確率変数 $X$ の確率分布 $P(X)$ について、その確率密度関数を $f(x)$ とする。この $f(x)$ に対し、値が $f_\alpha$ 以上となる $x$ の集合 $R(f_\alpha) = \lbrace x \ \vert \ f(x) \ge f_\alpha \rbrace$ を $100(1-\alpha)\%$ 最高事後密度区間 とする。

ここで $f_\alpha$ は $P(X \in R(f_a)) \ge 1 - \alpha$ を満たす最大の定数である。

このように文章で表現するとなかなかイメージが掴みづらいですが、これを単峰性の確率分布を例に (具体的にはベータ分布 $Beta(8,4)$) 図で表現すると

Unimodal HPDI

のようになります。つまりは、

  • 区間内の確率密度が常に $f_\alpha$ 以上になる
  • 逆に、区間外の確率密度は常に $f_\alpha$ 未満となる
  • 区間内の確率は $1 - \alpha$ となる

これらの条件を満たす区間が 最高事後密度区間 であると言えます。

最高事後密度区間を Java で求める

さて、任意の 単峰性の確率分布 における最高事後密度区間をどのように求めるのかというと、これまた Cross Validated の Credible set for beta distribution の回答 に具体的な手順が紹介されています。

ここで紹介されているのは 一変数関数の最適化問題として解く 方法で、具体的には確率分布 $P(X)$ の累積分布関数を $F(x)$、また

\[y = F^{-1}(F(x) + 1 - \alpha) \tag{1}\]

としたとき、目的関数は次の式になります。

\[(f(y) - f(x))^2 + ((F(y) - F(x)) - (1 - \alpha))^2 \tag{2}\]

この目的関数を $x \in (-\infty, F^{-1}(\alpha))$ の区間で最小化することで得られる $x$ が最高事後密度区間の左端に、また式 $(1)$ に従って得られる $y$ が右端となります。

この手順を Java で実装するとなるとまず、一変数関数を最適化する手段を用意する必要があります。こちらは Java / Commons Math で一変数の関数を最適化する のエントリにあるように、Commons Math の BrentOptimizer を利用すれば一変数関数の最適化は実現できます。

後は、先の Cross Validated の回答にある R 実装を参考にしつつ Java で実装すればよいわけで、具体的な実装は次のようになります。

import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;

public class HighestPosteriorDensityInterval {
    private final double lowerBound;
    private final double upperBound;

    HighestPosteriorDensityInterval(double lowerBound, double upperBound) {
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
    }

    public double getLowerBound() {
        return lowerBound;
    }

    public double getUpperBound() {
        return upperBound;
    }

    public static HighestPosteriorDensityInterval calculate(RealDistribution distribution, double alpha) {
        Solver solver = new Solver(distribution, alpha);

        double lowerBound = solver.solve();
        double upperBound = solver.offset(lowerBound);

        return new HighestPosteriorDensityInterval(lowerBound, upperBound);
    }

    static class Solver implements UnivariateFunction {
        private static final double THRESHOLD = 1e-15;
        private static final MaxEval MAX_EVAL = new MaxEval(100);

        private final RealDistribution distribution;
        private final double alpha;

        Solver(RealDistribution distribution, double alpha) {
            this.distribution = distribution;
            this.alpha = alpha;
        }

        double solve() {
            UnivariatePointValuePair result = new BrentOptimizer(THRESHOLD, THRESHOLD)
                    .optimize(
                            GoalType.MINIMIZE,
                            new UnivariateObjectiveFunction(this),
                            MAX_EVAL,
                            new SearchInterval(
                                    distribution.getSupportLowerBound(),
                                    distribution.inverseCumulativeProbability(alpha),
                                    distribution.inverseCumulativeProbability(alpha / 2)));

            return result.getPoint();
        }

        double offset(double x) {
            double q = distribution.cumulativeProbability(x);
            return distribution.inverseCumulativeProbability(Math.min(q + 1 - alpha, 1));
        }

        @Override
        public double value(double x) {
            double y = offset(x);
            double d1 = distribution.density(y) - distribution.density(x);
            double d2 = (distribution.cumulativeProbability(y) - distribution.cumulativeProbability(x)) - (1 - alpha);
            return d1 * d1 + d2 * d2;
        }
    }
}

上記のクラスの利用方法は次のとおり。

public class HpdiDemo {
    public static void main(String[] args) {
        BetaDistribution distribution = new BetaDistribution(8, 4);

        HighestPosteriorDensityInterval interval = HighestPosteriorDensityInterval.calculate(distribution, 0.05);

        // 95% 最高事後密度区間 [0.412047441090850, 0.906627667219367] が出力される
        System.out.printf("[%.15f, %.15f]", interval.getLowerBound(), interval.getUpperBound());
    }
}