読者です 読者をやめる 読者になる 読者になる

SHIROBAKO大好き人間のブログ

SHIROBAKOが好きなエンジニアによる技術ブログ

ベルマンフォード法

ダイクストラは知ってるけどベルマンフォード法は知らなかったので自分用にまとめておきます。

ベルマンフォード法

目的

重み付きの有向グラフにおいてあるノード sからその他のノードへの最短経路を見つける。

重みが負でもOK。

重みが負の場合、負の重みの閉路が存在すると最短経路がいくらでも小さくなってしまうが、このアルゴリズムは負の閉路が存在するかどうかも検出できる。

アルゴリズム

 Vをノードの集合、 Eをエッジの集合とする。

 uから vへの重み wの辺を (u, v, w)と表す。

開始ノード sからノード v \in Vへの距離を dist(v)とする。

最短経路におけるノード vの親を pre(v)と表す。

  1.  dist(s)=0、その他のノード vに対して dist(v)= \inftyと初期化する

  2. すべての (u, v, w) \in Eに対して以下の操作を行う
     dist(v) > dist(u) + wなら、 dist(v) dist(v) = dist(u) + wと更新する
    また、 pre(v)=uとする

  3. 2.の操作を |V|-1回繰り返す

  4. 負の閉路が存在するかを以下の手順でチェックする
    各エッジ (u, v, w) \in Eに対して dist(u) + w \lt dist(v)となる辺があるかどうか確かめる
    そのような辺がある場合には、グラフに負の閉路が存在する

負の閉路が存在しない場合には dist(v) sから vまでの最短距離が格納され、 pre(v)をたどっていくと最短経路が分かります。

具体例

簡単な例があった方が自分は理解しやすいので単純なグラフの例を載せておきます。

以下のようなグラフがあったとします。

f:id:phoro3:20170515224044p:plain

ノード sからこのアルゴリズムを適用した流れを示します。

 dist(s)=0 dist(v_1)=dist(v_2)=dist(v_3)=\infty

ループ1回目

エッジ (s, v_1, 1)
 dist(v_1) > dist(s) + 1なので dist(v_1) = dist(s) + 1 = 1, pre(v_1) = s

エッジ (v_1, v_2, -1)
 dist(v_2) > dist(v_1) - 1なので dist(v_2) = dist(v_1) - 1 = 0, pre(v_2) = v_1

エッジ (v_2, s, 2)
 dist(s) \lt dist(v_2) + 2なので更新しない

エッジ (v_1, v_3, -2)
 dist(v_3) > dist(v_1) - 2なので dist(v_3) = dist(v_1) - 2 = -1, pre(v_3) = v_1

ループ2回目

エッジ (s, v_1, 1)
 dist(v_1) = dist(s) + 1なので更新しない

エッジ (v_1, v_2, -1)
 dist(v_2) = dist(v_1) - 1なので更新しない

・・・という流れで続きます。

ループを |V|-1 = 3回繰り返すと、 dist(v_1) = 1というように distにちゃんと最短経路の長さが格納されます。

今回は単純な例なので、途中から自明な結果になりましたが、負の重みの閉路を含む例だとダイクストラとの違いが出てきて面白いです。

コード

実装例を載せておきます。

シンプルなアルゴリズムなので実装もしやすいです。

from collections import defaultdict

#E = [(u, v, w), ...]
def BellmanFord(V, E, s):
    dist = defaultdict(lambda :10**20)
    dist[s] = 0
    pre = {}

    for i in range(len(V)):
        for edge in E:
            u = edge[0]
            v = edge[1]
            w = edge[2]

            if dist[v] > dist[u] + w:
                dist[v] = dist[u] + w
                pre[v] = u

    #負閉路の検出
    is_cycle = False
    for edge in E:
        u = edge[0]
        v = edge[1]
        w = edge[2]
        if dist[u] + w < dist[v]:
            is_cycle = True

    return (dist, pre, is_cycle)