Introduction
A first-order HMM (hidden Markov model) is a tuple \$(H, \Sigma, T, E, \mathbb{P})\$, where \$H = \{1, \ldots, \vert H \vert\}\$ is the set of hidden states, \$\Sigma\$ is the set of symbols, \$T \subseteq H \times H\$ is the set of transitions, \$E \subseteq H \times \Sigma\$ is the set of emissions, and \$\mathbb{P}\$ is the probability function for elements of \$T\$ and \$E\$, satisfying the following conditions:
- There is a single start state \$h_{\texttt{start}} \in H\$ with no incoming transitions \$(h, h_{\texttt{start}}) \in T\$, and no emissions.
- There is a single end state \$h_{\texttt{end}} \in H\$ with no outgoing transitions \$(h_{\texttt{end}}, h) \in T\$, and no emissions.
- Let \$\mathbb{P}(h \, \vert \, h^\prime)\$ denote the probability for the transition \$(h^\prime, h) \in T\$, and let \$\mathbb{P}(c \, \vert \, h)\$ denote the probability of an emission \$(h, c) \in E\$, for \$h^\prime, h \in H\$ and \$c \in \Sigma\$. It must hold that $$ \sum_{h \in H} \mathbb{P}(h \, \vert \, h^\prime) = 1, \text{ for all } h^\prime \in H \setminus \{ h_{\texttt{end}} \}, $$ and $$\sum_{c \in \Sigma} \mathbb{P}(c \, \vert \, h) = 1, \text{ for all } h \in H \setminus \{ h_{\texttt{start}}, h_{\texttt{end}} \}.$$
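The three conditions above can be checked mechanically. Below is a minimal sketch that is not part of the repository: it assumes the model is given as a dense transition matrix `t` (with `t[a][b]` standing for \$\mathbb{P}(b \, \vert \, a)\$) and a dense emission matrix `e`, both indexed by state number. The class name `HmmConditionsSketch`, the method names and the tolerance `EPSILON` are made up for illustration; the numbers in `main` are the ones used in `HMMDemo`.

```java
final class HmmConditionsSketch {

    // Assumed tolerance for comparing floating-point sums against 1.
    private static final double EPSILON = 1e-9;

    // True if the entries of the row sum to one, up to EPSILON.
    static boolean sumsToOne(double[] row) {
        double sum = 0.0;
        for (double p : row) {
            sum += p;
        }
        return Math.abs(sum - 1.0) < EPSILON;
    }

    // True if every entry of the row is exactly zero.
    static boolean allZero(double[] row) {
        for (double p : row) {
            if (p != 0.0) {
                return false;
            }
        }
        return true;
    }

    // Checks the start/end state restrictions and both normalization sums.
    static boolean satisfiesConditions(double[][] t,
                                       double[][] e,
                                       int start,
                                       int end) {
        // No transition may enter the start state:
        for (double[] row : t) {
            if (row[start] != 0.0) {
                return false;
            }
        }
        // The end state has no outgoing transitions; start and end emit
        // nothing:
        if (!allZero(t[end]) || !allZero(e[start]) || !allZero(e[end])) {
            return false;
        }
        for (int h = 0; h < t.length; h++) {
            if (h != end && !sumsToOne(t[h])) {
                return false;
            }
            if (h != start && h != end && !sumsToOne(e[h])) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // States: 0 = start, 1 = coding, 2 = noncoding, 3 = end.
        double[][] t = {
            {0.0, 0.49, 0.49, 0.02},
            {0.0, 0.40, 0.30, 0.30},
            {0.0, 0.35, 0.30, 0.35},
            {0.0, 0.00, 0.00, 0.00},
        };
        // Columns: A, C, G, T.
        double[][] e = {
            {0.00, 0.00, 0.00, 0.00},
            {0.18, 0.32, 0.32, 0.18},
            {0.25, 0.25, 0.25, 0.25},
            {0.00, 0.00, 0.00, 0.00},
        };
        System.out.println(satisfiesConditions(t, e, 0, 3));
    }
}
```

A tolerance is used because sums such as 0.49 + 0.49 + 0.02 are not guaranteed to be exactly 1 in floating-point arithmetic.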
A path through an HMM is a sequence \$P\$ of hidden states \$P=p_0p_1 \cdots p_n p_{n+1}\$, where \$(p_i, p_{i + 1}) \in T\$, for each \$i \in \{ 0, \ldots, n \}. \$ The joint probability of \$P\$ and a sequence \$S = s_1 s_2 \cdots s_n\$, with each \$s_i \in \Sigma,\$ is $$ \mathbb{P}(P, S) = \prod_{i = 0}^n \mathbb{P}(p_{i + 1} \, \vert \, p_i) \prod_{i = 1}^n \mathbb{P}(s_i \, \vert \, p_i). $$ Also, we define \$\mathcal{P}(n)\$ as the set of all paths \$p_0 p_1 \cdots p_{n + 1}\$ through the HMM, of length \$n + 2\$, such that \$p_0 = h_{\texttt{start}}\$ and \$p_{n + 1} = h_{\texttt{end}}.\$
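To make the joint probability formula concrete, here is a small sketch that evaluates it directly as a product. It is only an illustration under assumptions: states are plain integers (0 = start, 1 = coding, 2 = noncoding, 3 = end, as in `HMMDemo`), the matrices `t` and `e` and the class `JointProbabilitySketch` are hypothetical, and the symbol order `ACGT` is chosen arbitrarily. For the path 0, 1, 1, 3 and the sequence `GC`, the product is 0.49 · 0.4 · 0.3 · 0.32 · 0.32 = 0.00602112.

```java
final class JointProbabilitySketch {

    // Assumed symbol order for the columns of the emission matrix.
    private static final String ALPHABET = "ACGT";

    // path holds p_0, ..., p_{n+1}; s holds s_1, ..., s_n.
    static double jointProbability(int[] path,
                                   String s,
                                   double[][] t,
                                   double[][] e) {
        if (path.length != s.length() + 2) {
            throw new IllegalArgumentException(
                    "The path must contain n + 2 states.");
        }
        double probability = 1.0;
        // Transition factors P(p_{i+1} | p_i) for i = 0, ..., n:
        for (int i = 0; i + 1 < path.length; i++) {
            probability *= t[path[i]][path[i + 1]];
        }
        // Emission factors P(s_i | p_i) for i = 1, ..., n:
        for (int i = 1; i <= s.length(); i++) {
            int symbolIndex = ALPHABET.indexOf(s.charAt(i - 1));
            if (symbolIndex < 0) {
                throw new IllegalArgumentException(
                        "Unknown symbol: " + s.charAt(i - 1));
            }
            probability *= e[path[i]][symbolIndex];
        }
        return probability;
    }

    public static void main(String[] args) {
        double[][] t = {
            {0.0, 0.49, 0.49, 0.02},
            {0.0, 0.40, 0.30, 0.30},
            {0.0, 0.35, 0.30, 0.35},
            {0.0, 0.00, 0.00, 0.00},
        };
        double[][] e = {
            {0.00, 0.00, 0.00, 0.00},
            {0.18, 0.32, 0.32, 0.18},
            {0.25, 0.25, 0.25, 0.25},
            {0.00, 0.00, 0.00, 0.00},
        };
        // start -> coding -> coding -> end, emitting "GC":
        System.out.println(
                jointProbability(new int[] {0, 1, 1, 3}, "GC", t, e));
    }
}
```

Multiplying many small factors underflows for long sequences, which is why `HiddenMarkovModelStatePath` sums logarithms instead of multiplying probabilities directly.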
We need to solve two problems here. First, we need to find the most probable path \$P^{\star}\$ that is consistent with the input sequence \$S\$: $$ P^\star = \arg \max_{P \in \mathcal{P}(n)} \mathbb{P}(P, S). $$ Second, we want to enumerate all the state paths consistent with \$S\$ and return the sum of their joint probabilities: $$ \mathbb{P}(S) = \sum_{P \in \mathcal{P}(n)} \mathbb{P}(P, S). $$
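Both problems have a standard dynamic programming solution that fills one table row per symbol: the forward algorithm sums over the predecessor states, while the Viterbi algorithm takes the maximum and remembers which predecessor achieved it. The sketch below is an iterative, array-based illustration of that idea and is not the code under review. The class `ForwardViterbiSketch`, the constants `T`, `E` and `ALPHABET`, and the integer state numbering (0 = start, 1 = coding, 2 = noncoding, 3 = end) are assumptions; the probabilities are those of `HMMDemo`.

```java
import java.util.Arrays;

final class ForwardViterbiSketch {

    static final int START = 0;
    static final int END = 3;
    static final String ALPHABET = "ACGT";

    // T[a][b] is P(b | a).
    static final double[][] T = {
        {0.0, 0.49, 0.49, 0.02},
        {0.0, 0.40, 0.30, 0.30},
        {0.0, 0.35, 0.30, 0.35},
        {0.0, 0.00, 0.00, 0.00},
    };

    // E[h][c] is the probability that state h emits ALPHABET.charAt(c).
    static final double[][] E = {
        {0.00, 0.00, 0.00, 0.00},
        {0.18, 0.32, 0.32, 0.18},
        {0.25, 0.25, 0.25, 0.25},
        {0.00, 0.00, 0.00, 0.00},
    };

    private static int symbolIndex(String s, int i) {
        int index = ALPHABET.indexOf(s.charAt(i));
        if (index < 0) {
            throw new IllegalArgumentException("Unknown symbol: " + s.charAt(i));
        }
        return index;
    }

    // f[i][h]: total probability of emitting s_1..s_i and being in state h.
    static double forward(String s) {
        int n = s.length();
        double[][] f = new double[n + 1][END + 1];
        f[0][START] = 1.0;
        for (int i = 1; i <= n; i++) {
            int c = symbolIndex(s, i - 1);
            for (int h = 0; h <= END; h++) {
                double sum = 0.0;
                for (int prev = 0; prev <= END; prev++) {
                    sum += f[i - 1][prev] * T[prev][h];
                }
                f[i][h] = sum * E[h][c];
            }
        }
        // Close every path with its transition into the end state:
        double total = 0.0;
        for (int h = 0; h <= END; h++) {
            total += f[n][h] * T[h][END];
        }
        return total;
    }

    // Returns p_0, ..., p_{n+1} of a most probable path. Assumes that at
    // least one path has a non-zero joint probability with s.
    static int[] viterbi(String s) {
        int n = s.length();
        double[][] v = new double[n + 1][END + 1];
        int[][] back = new int[n + 1][END + 1];
        v[0][START] = 1.0;
        for (int i = 1; i <= n; i++) {
            int c = symbolIndex(s, i - 1);
            for (int h = 0; h <= END; h++) {
                double best = 0.0;
                for (int prev = 0; prev <= END; prev++) {
                    double candidate = v[i - 1][prev] * T[prev][h];
                    if (candidate > best) {
                        best = candidate;
                        back[i][h] = prev;
                    }
                }
                v[i][h] = best * E[h][c];
            }
        }
        // Pick the last emitting state, including its transition to END:
        int last = START;
        double best = 0.0;
        for (int h = 0; h <= END; h++) {
            double candidate = v[n][h] * T[h][END];
            if (candidate > best) {
                best = candidate;
                last = h;
            }
        }
        int[] path = new int[n + 2];
        path[n + 1] = END;
        for (int i = n; i >= 1; i--) {
            path[i] = last;
            last = back[i][last];
        }
        path[0] = START;
        return path;
    }

    public static void main(String[] args) {
        System.out.println(forward("GC"));
        System.out.println(Arrays.toString(viterbi("GC")));
    }
}
```

Both methods run in time proportional to \$n \vert H \vert^2\$, whereas enumerating \$\mathcal{P}(n)\$ by brute force grows exponentially in \$n\$. Note that the termination step weights the last row by the transition into the end state before choosing the final hidden state.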
Code
(The entire repository is in HiddenMarkovModel.java.)
com.github.coderodde.hmm.HiddenMarkovModel.java:
package com.github.coderodde.hmm;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
/**
* This class implements an HMM (hidden Markov model).
*/
public final class HiddenMarkovModel {
/**
* Used to denote the Viterbi and forward matrix cells that are not yet
* computed.
*/
private static final double UNSET_PROBABILITY = -1.0;
/**
* The start state of the process.
*/
private final HiddenMarkovModelState startState;
/**
* The end state of the process.
*/
private final HiddenMarkovModelState endState;
private final Random random;
public HiddenMarkovModel(HiddenMarkovModelState startState,
HiddenMarkovModelState endState,
Random random) {
this.startState = startState;
this.endState = endState;
this.random = random;
}
public static double sumPathProbabilities(
List<HiddenMarkovModelStatePath> paths) {
double psum = 0.0;
for (HiddenMarkovModelStatePath path : paths) {
psum += path.getProbability();
}
return psum;
}
/**
* Performs a brute-force computation of the list of all possible paths in
* this HMM.
*
* @param sequence the target text.
* @return the list of state paths, sorted from the most probable to the
* least probable.
*/
public List<HiddenMarkovModelStatePath>
computeAllStatePaths(String sequence) {
int expectedStatePathSize = sequence.length() + 2;
List<List<HiddenMarkovModelState>> statePaths = new ArrayList<>();
List<HiddenMarkovModelState> currentPath = new ArrayList<>();
// BEGIN: Do the search:
currentPath.add(startState);
depthFirstSearchImpl(statePaths,
currentPath,
expectedStatePathSize,
startState);
// END: Searching done.
// Construct the sequences:
List<HiddenMarkovModelStatePath> sequenceList =
new ArrayList<>(statePaths.size());
for (List<HiddenMarkovModelState> statePath : statePaths) {
sequenceList.add(
new HiddenMarkovModelStatePath(
statePath,
sequence,
this));
}
// Put into descending order by probabilities:
Collections.sort(sequenceList);
Collections.reverse(sequenceList);
return sequenceList;
}
/**
* Returns the most probable state path for the input sequence using the
* <a href="https://en.wikipedia.org/wiki/Viterbi_algorithm">
* Viterbi algorithm.</a>
*
* @param sequence the target sequence.
* @return the state path.
*/
public HiddenMarkovModelStatePath runViterbiAlgorithm(String sequence) {
// Get all the states reachable from the start state:
Set<HiddenMarkovModelState> graph = computeAllStates();
if (!graph.contains(endState)) {
// End state unreachable. Abort.
throw new IllegalStateException("End state is unreachable.");
}
// Maps the column index to the corresponding state:
Map<Integer, HiddenMarkovModelState> stateMap =
new HashMap<>(graph.size());
// Maps the state to the corresponding column index:
Map<HiddenMarkovModelState, Integer> inverseStateMap =
new HashMap<>(graph.size());
// Initialize maps for start and end states:
stateMap.put(0, startState);
stateMap.put(graph.size() - 1, endState);
inverseStateMap.put(startState, 0);
inverseStateMap.put(endState, graph.size() - 1);
int stateId = 1;
// Assign indices to hidden states:
for (HiddenMarkovModelState state : graph) {
if (!state.equals(startState) && !state.equals(endState)) {
stateMap.put(stateId, state);
inverseStateMap.put(state, stateId);
stateId++;
}
}
// Computes the entire Viterbi matrix:
double[][] viterbiMatrix =
computeViterbiMatrix(
sequence,
stateMap,
inverseStateMap);
// Uses the dynamic programming result reconstruction:
return tracebackStateSequenceViterbi(viterbiMatrix,
sequence,
stateMap,
inverseStateMap);
}
/**
* Returns the sum of the joint probabilities over all the feasible state
* paths that are consistent with the {@code sequence}.
*
* @param sequence the sequence text.
*
* @return the sum of probabilities.
*/
public double runForwardAlgorithm(String sequence) {
// Get all the states reachable from the start state:
Set<HiddenMarkovModelState> graph = computeAllStates();
if (!graph.contains(endState)) {
// End state unreachable. Abort.
throw new IllegalStateException("End state is unreachable.");
}
// Maps the column index to the corresponding state:
Map<Integer, HiddenMarkovModelState> stateMap =
new HashMap<>(graph.size());
// Maps the state to the corresponding column index:
Map<HiddenMarkovModelState, Integer> inverseStateMap =
new HashMap<>(graph.size());
// Initialize maps for start and end states:
stateMap.put(0, startState);
stateMap.put(graph.size() - 1, endState);
inverseStateMap.put(startState, 0);
inverseStateMap.put(endState, graph.size() - 1);
int stateId = 1;
// Assign indices to hidden states:
for (HiddenMarkovModelState state : graph) {
if (!state.equals(startState) && !state.equals(endState)) {
stateMap.put(stateId, state);
inverseStateMap.put(state, stateId);
stateId++;
}
}
// Computes the entire forward matrix:
double[][] forwardMatrix =
computeForwardMatrix(
sequence,
stateMap,
inverseStateMap);
// Sum the last row, weighted by the transitions into the end state:
return tracebackStateSequenceForward(forwardMatrix,
inverseStateMap);
}
/**
* Composes a sequence according to this HMM.
*
* @return a sequence.
*/
public String compose() {
StringBuilder sb = new StringBuilder();
HiddenMarkovModelState currentState = startState;
while (true) {
currentState = doStateTransition(currentState);
if (currentState.equals(endState)) {
// Once here, we are done:
return sb.toString();
}
sb.append(doEmit(currentState));
}
}
/**
* Computes the entire Viterbi matrix.
*
* @param sequence the input text.
* @param stateMap the state map. From column index to the state.
* @param inverseStateMap the inverse state map. From state to column index.
*
* @return the entire Viterbi matrix.
*/
private double[][] computeViterbiMatrix(
String sequence,
Map<Integer, HiddenMarkovModelState> stateMap,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
double[][] matrix = new double[sequence.length() + 1]
[stateMap.size()];
// Set all required cells to unset sentinel:
for (int rowIndex = 1; rowIndex < matrix.length; rowIndex++) {
Arrays.fill(matrix[rowIndex], UNSET_PROBABILITY);
}
// BEGIN: Base case initialization.
matrix[0][0] = 1.0;
for (int columnIndex = 1;
columnIndex < matrix[0].length;
columnIndex++) {
matrix[0][columnIndex] = 0.0;
}
// END: Done with the base case initialization.
// Recursively build the matrix:
for (int h = 1; h < matrix[0].length - 1; h++) {
matrix[sequence.length()][h] =
computeViterbiMatrixImpl(
sequence.length(),
h,
matrix,
sequence,
stateMap,
inverseStateMap);
}
return matrix;
}
/**
* Computes the entire forward matrix.
*
* @param sequence the input text.
* @param stateMap the state map. From column index to the state.
* @param inverseStateMap the inverse state map. From state to column index.
*
* @return the entire forward matrix.
*/
private double[][] computeForwardMatrix(
String sequence,
Map<Integer, HiddenMarkovModelState> stateMap,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
double[][] forwardMatrix = new double[sequence.length() + 1]
[stateMap.size()];
// Set all required cells to unset sentinel:
for (int rowIndex = 1; rowIndex < forwardMatrix.length; rowIndex++) {
Arrays.fill(forwardMatrix[rowIndex], UNSET_PROBABILITY);
}
// BEGIN: Base case initialization.
forwardMatrix[0][0] = 1.0;
for (int columnIndex = 1;
columnIndex < forwardMatrix[0].length;
columnIndex++) {
forwardMatrix[0][columnIndex] = 0.0;
}
// END: Done with the base case initialization.
// Recursively build the matrix:
for (int h = 1; h < forwardMatrix[0].length - 1; h++) {
forwardMatrix[sequence.length()][h] =
computeForwardMatrixImpl(
sequence.length(),
h,
forwardMatrix,
sequence,
stateMap,
inverseStateMap);
}
return forwardMatrix;
}
/**
* Computes the actual Viterbi matrix.
*
* @param i the {@code i} variable; the length of the sequence
* prefix.
* @param h the state index.
* @param viterbiMatrix the actual Viterbi matrix under construction.
* @param sequence the symbol sequence.
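* @param stateMap the map mapping state IDs to states.
* @param inverseStateMap the map mapping states to state IDs.
* @return the Viterbi probability of the cell {@code (i, h)}.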
*/
private double computeViterbiMatrixImpl(
int i,
int h,
double[][] viterbiMatrix,
String sequence,
Map<Integer, HiddenMarkovModelState> stateMap,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
if (viterbiMatrix[i][h] != UNSET_PROBABILITY) {
return viterbiMatrix[i][h];
}
final int NUMBER_OF_STATES = stateMap.size();
if (h >= NUMBER_OF_STATES - 1) {
return UNSET_PROBABILITY;
}
if (h == 0) {
// The start state cannot be occupied once a symbol has been emitted:
return i == 0 ? 1.0 : 0.0;
}
char symbol = sequence.charAt(i - 1);
HiddenMarkovModelState state = stateMap.get(h);
// A symbol the state cannot emit has probability zero:
double psih = state.getEmissions().getOrDefault(symbol, 0.0);
Set<HiddenMarkovModelState> parentStates = state.getIncomingStates();
double maximumProbability = Double.NEGATIVE_INFINITY;
for (HiddenMarkovModelState parent : parentStates) {
double v =
computeViterbiMatrixImpl(
i - 1,
inverseStateMap.get(parent),
viterbiMatrix,
sequence,
stateMap,
inverseStateMap);
v *= parent.getFollowingStates().get(state);
maximumProbability = Math.max(maximumProbability, v);
}
viterbiMatrix[i][h] = maximumProbability * psih;
return viterbiMatrix[i][h];
}
/**
* Computes the actual forward matrix.
*
* @param i the {@code i} variable; the length of the sequence
* prefix.
* @param h the state index.
* @param forwardMatrix the actual forward matrix under construction.
* @param sequence the symbol sequence.
* @param stateMap the map mapping state IDs to states.
* @param inverseStateMap the map mapping states to state IDs.
*
* @return the forward probability of the cell {@code (i, h)}.
*/
private double computeForwardMatrixImpl(
int i,
int h,
double[][] forwardMatrix,
String sequence,
Map<Integer, HiddenMarkovModelState> stateMap,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
if (forwardMatrix[i][h] != UNSET_PROBABILITY) {
return forwardMatrix[i][h];
}
final int NUMBER_OF_STATES = stateMap.size();
if (h >= NUMBER_OF_STATES - 1) {
return UNSET_PROBABILITY;
}
if (h == 0) {
return i == 0 ? 1.0 : 0.0;
}
char symbol = sequence.charAt(i - 1);
HiddenMarkovModelState state = stateMap.get(h);
// A symbol the state cannot emit has probability zero:
double psih = state.getEmissions().getOrDefault(symbol, 0.0);
Set<HiddenMarkovModelState> parentStates = state.getIncomingStates();
double totalProbability = 0.0;
for (HiddenMarkovModelState parent : parentStates) {
double f =
computeForwardMatrixImpl(
i - 1,
inverseStateMap.get(parent),
forwardMatrix,
sequence,
stateMap,
inverseStateMap);
f *= parent.getFollowingStates().get(state);
totalProbability += f ;
}
forwardMatrix[i][h] = totalProbability * psih;
return forwardMatrix[i][h];
}
private HiddenMarkovModelStatePath
tracebackStateSequenceViterbi(
double[][] viterbiMatrix,
String sequence,
Map<Integer, HiddenMarkovModelState> stateMap,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
int bottomMaximumIndex = computeBottomMaximumIndex(viterbiMatrix);
HiddenMarkovModelState bottomState = stateMap.get(bottomMaximumIndex);
List<HiddenMarkovModelState> stateList =
new ArrayList<>(viterbiMatrix[0].length);
stateList.add(endState);
final int HIGHEST_I = viterbiMatrix.length - 1;
tracebackStateSequenceImpl(viterbiMatrix,
HIGHEST_I,
bottomState,
stateList,
inverseStateMap);
stateList.add(startState);
Collections.reverse(stateList);
return new HiddenMarkovModelStatePath(stateList, sequence, this);
}
private double
tracebackStateSequenceForward(
double[][] forwardMatrix,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
final int ROW_INDEX = forwardMatrix.length - 1;
double probability = 0.0;
Set<HiddenMarkovModelState> parents = endState.getIncomingStates();
for (HiddenMarkovModelState parent : parents) {
if (parent.equals(startState) || parent.equals(endState)) {
// Omit the start and end states:
continue;
}
int parentIndex = inverseStateMap.get(parent);
double p = forwardMatrix[ROW_INDEX][parentIndex];
p *= parent.getFollowingStates().get(endState);
probability += p;
}
return probability;
}
private void tracebackStateSequenceImpl(
double[][] viterbiMatrix,
int i,
HiddenMarkovModelState state,
List<HiddenMarkovModelState> stateList,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
if (state.equals(startState)) {
return;
}
stateList.add(state);
HiddenMarkovModelState nextState =
computeNextState(viterbiMatrix,
i,
state,
inverseStateMap);
tracebackStateSequenceImpl(viterbiMatrix,
i - 1,
nextState,
stateList,
inverseStateMap);
}
private HiddenMarkovModelState
computeNextState(
double[][] viterbiMatrix,
int i,
HiddenMarkovModelState state,
Map<HiddenMarkovModelState, Integer> inverseStateMap) {
Set<HiddenMarkovModelState> parents = state.getIncomingStates();
HiddenMarkovModelState nextState = null;
double maximumProbability = Double.NEGATIVE_INFINITY;
for (HiddenMarkovModelState parent : parents) {
int parentIndex = inverseStateMap.get(parent);
double probability = parent.getFollowingStates().get(state);
probability *= viterbiMatrix[i - 1][parentIndex];
if (maximumProbability < probability) {
maximumProbability = probability;
nextState = parent;
}
}
return nextState;
}
private int computeBottomMaximumIndex(double[][] viterbiMatrix) {
int maximumIndex = -1;
double maximumProbability = Double.NEGATIVE_INFINITY;
final int ROW_INDEX = viterbiMatrix.length - 1;
for (int i = 1; i < viterbiMatrix[0].length; i++) {
double currentProbability = viterbiMatrix[ROW_INDEX][i];
if (maximumProbability < currentProbability) {
maximumProbability = currentProbability;
maximumIndex = i;
}
}
return maximumIndex;
}
private HiddenMarkovModelState
doStateTransition(HiddenMarkovModelState currentState) {
double coin = random.nextDouble();
for (Map.Entry<HiddenMarkovModelState, Double> e
: currentState.getFollowingStates().entrySet()) {
if (coin >= e.getValue()) {
coin -= e.getValue();
} else {
return e.getKey();
}
}
throw new IllegalStateException("Should not get here.");
}
private char doEmit(HiddenMarkovModelState currentState) {
double coin = random.nextDouble();
for (Map.Entry<Character, Double> e
: currentState.getEmissions().entrySet()) {
if (coin >= e.getValue()) {
coin -= e.getValue();
} else {
return e.getKey();
}
}
throw new IllegalStateException("Should not get here.");
}
private Set<HiddenMarkovModelState> computeAllStates() {
Deque<HiddenMarkovModelState> queue = new ArrayDeque<>();
Set<HiddenMarkovModelState> visited = new HashSet<>();
queue.addLast(startState);
visited.add(startState);
while (!queue.isEmpty()) {
HiddenMarkovModelState currentState = queue.removeFirst();
for (HiddenMarkovModelState followerState
: currentState.getFollowingStates().keySet()) {
if (!visited.contains(followerState)) {
visited.add(followerState);
queue.addLast(followerState);
}
}
}
return visited;
}
private void depthFirstSearchImpl(
List<List<HiddenMarkovModelState>> statePaths,
List<HiddenMarkovModelState> currentPath,
int expectedStatePathSize,
HiddenMarkovModelState currentState) {
if (currentPath.size() == expectedStatePathSize) {
// End recursion. If the current state equals the end state, we have
// a path:
if (currentState.equals(endState)) {
statePaths.add(new ArrayList<>(currentPath));
}
return;
}
// For each child, do...
for (HiddenMarkovModelState followerState
: currentState.getFollowingStates().keySet()) {
// Do state:
currentPath.add(followerState);
// Recur deeper:
depthFirstSearchImpl(statePaths,
currentPath,
expectedStatePathSize,
followerState);
// Undo state:
currentPath.remove(currentPath.size() - 1);
}
}
}
com.github.coderodde.hmm.HiddenMarkovModelState.java:
package com.github.coderodde.hmm;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* This class defines the hidden states of a hidden Markov model.
*/
public final class HiddenMarkovModelState {
/**
* The ID of this state. Used to differentiate between the states.
*/
private final int id;
/**
* The state type of this state.
*/
private final HiddenMarkovModelStateType type;
/**
* Maps each transition target to the transition probability.
*/
private final Map<HiddenMarkovModelState, Double> transitionMap =
new HashMap<>();
/**
* Holds all incoming states.
*/
private final Set<HiddenMarkovModelState> incomingTransitions =
new HashSet<>();
/**
* Maps each emission character to its probability.
*/
private final Map<Character, Double> emissionMap = new HashMap<>();
public HiddenMarkovModelState(int id, HiddenMarkovModelStateType type) {
this.id = id;
this.type = type;
}
public int getId() {
return id;
}
public Map<HiddenMarkovModelState, Double> getFollowingStates() {
return Collections.unmodifiableMap(transitionMap);
}
public Map<Character, Double> getEmissions() {
return Collections.unmodifiableMap(emissionMap);
}
public Set<HiddenMarkovModelState> getIncomingStates() {
return Collections.unmodifiableSet(incomingTransitions);
}
public void normalize() {
normalizeEmissionMap();
normalizeTransitionMap();
}
public void addStateTransition(HiddenMarkovModelState followerState,
Double probability) {
if (type.equals(HiddenMarkovModelStateType.END)) {
throw new IllegalArgumentException(
"End HMM states may not have outgoing state transitions.");
}
transitionMap.put(followerState, probability);
followerState.incomingTransitions.add(this);
}
public void addEmissionTransition(Character character, Double probability) {
switch (type) {
case START:
case END:
throw new IllegalArgumentException(
"Start and end HMM states may not have emissions.");
}
emissionMap.put(character, probability);
}
@Override
public boolean equals(Object o) {
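if (this == o) {
return true;
}
// Guards against null and foreign types instead of throwing:
if (!(o instanceof HiddenMarkovModelState)) {
return false;
}
return id == ((HiddenMarkovModelState) o).id;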
}
@Override
public int hashCode() {
return id;
}
@Override
public String toString() {
return String.format("[HMM state, ID = %d, type = %s]", id, type);
}
private void normalizeTransitionMap() {
double sumOfProbabilities = computeTransitionProbabilitySum();
for (Map.Entry<HiddenMarkovModelState, Double> e
: transitionMap.entrySet()) {
e.setValue(e.getValue() / sumOfProbabilities);
}
}
private void normalizeEmissionMap() {
double sumOfProbabilities = computeEmissionProbabilitySum();
for (Map.Entry<Character, Double> e : emissionMap.entrySet()) {
e.setValue(e.getValue() / sumOfProbabilities);
}
}
private double computeTransitionProbabilitySum() {
double sumOfProbabilities = 0.0;
for (Double probability : transitionMap.values()) {
sumOfProbabilities += probability;
}
return sumOfProbabilities;
}
private double computeEmissionProbabilitySum() {
double sumOfProbabilities = 0.0;
for (Double probability : emissionMap.values()) {
sumOfProbabilities += probability;
}
return sumOfProbabilities;
}
}
com.github.coderodde.hmm.HiddenMarkovModelStatePath.java:
package com.github.coderodde.hmm;
import static java.lang.Math.E;
import static java.lang.Math.log;
import static java.lang.Math.pow;
import java.util.List;
/**
* This class stores the state path over the hidden states of an HMM.
*/
public final class HiddenMarkovModelStatePath
implements Comparable<HiddenMarkovModelStatePath> {
private final List<HiddenMarkovModelState> stateSequence;
private final double probability;
HiddenMarkovModelStatePath(List<HiddenMarkovModelState> stateSequence,
String observedSymbols,
HiddenMarkovModel hiddenMarkovModel) {
this.stateSequence = stateSequence;
this.probability = computeJointProbability(observedSymbols);
}
public int size() {
return stateSequence.size();
}
public HiddenMarkovModelState getState(int stateIndex) {
return stateSequence.get(stateIndex);
}
public double getProbability() {
return probability;
}
@Override
public int compareTo(HiddenMarkovModelStatePath o) {
return Double.compare(probability, o.probability);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
boolean first = true;
for (HiddenMarkovModelState state : stateSequence) {
if (first) {
first = false;
} else {
sb.append(", ");
}
sb.append(state.getId());
}
sb.append("| p = ");
sb.append(probability);
sb.append("]");
return sb.toString();
}
/**
* Computes the joint probability of this path.
*
* @param observedSymbols the observation text.
*
* @return the joint probability of this path and the input text.
*/
private double computeJointProbability(String observedSymbols) {
double logProbability = computeEmissionProbabilities(observedSymbols) +
computeTransitionProbabilities();
// Convert to probability:
return Math.exp(logProbability);
}
/**
* Computes the sum of the logarithms of the emission probabilities over the
* input text.
*
* @param observedSymbols the input text.
*
* @return the total emission log-probability.
*/
private double computeEmissionProbabilities(String observedSymbols) {
double probability = 0.0;
for (int i = 0; i != observedSymbols.length(); i++) {
char observedSymbol = observedSymbols.charAt(i);
HiddenMarkovModelState state = stateSequence.get(i + 1);
probability += log(state.getEmissions().get(observedSymbol));
}
return probability;
}
/**
* Computes the sum of the logarithms of the transition probabilities over
* this path.
*
* @return the total transition log-probability.
*/
private double computeTransitionProbabilities() {
double probability = 0.0;
for (int i = 0; i < stateSequence.size() - 1; i++) {
HiddenMarkovModelState sourceState = stateSequence.get(i);
HiddenMarkovModelState targetState = stateSequence.get(i + 1);
probability += log(sourceState.getFollowingStates()
.get(targetState));
}
return probability;
}
}
com.github.coderodde.hmm.demo.HMMDemo.java:
package com.github.coderodde.hmm.demo;
import com.github.coderodde.hmm.HiddenMarkovModel;
import com.github.coderodde.hmm.HiddenMarkovModelState;
import com.github.coderodde.hmm.HiddenMarkovModelStatePath;
import com.github.coderodde.hmm.HiddenMarkovModelStateType;
import java.util.List;
import java.util.Random;
/**
* This class implements the demonstration of the hidden Markov model.
*/
public final class HMMDemo {
public static void main(String[] args) {
Random random = new Random(13L);
HiddenMarkovModelState startState =
new HiddenMarkovModelState(0, HiddenMarkovModelStateType.START);
HiddenMarkovModelState endState =
new HiddenMarkovModelState(3, HiddenMarkovModelStateType.END);
HiddenMarkovModelState codingState =
new HiddenMarkovModelState(
1,
HiddenMarkovModelStateType.HIDDEN);
HiddenMarkovModelState noncodingState =
new HiddenMarkovModelState(
2,
HiddenMarkovModelStateType.HIDDEN);
HiddenMarkovModel hmm = new HiddenMarkovModel(startState,
endState,
random);
// BEGIN: State transitions.
startState.addStateTransition(noncodingState, 0.49);
startState.addStateTransition(codingState, 0.49);
startState.addStateTransition(endState, 0.02);
codingState.addStateTransition(codingState, 0.4);
codingState.addStateTransition(endState, 0.3);
codingState.addStateTransition(noncodingState, 0.3);
noncodingState.addStateTransition(noncodingState, 0.3);
noncodingState.addStateTransition(codingState, 0.35);
noncodingState.addStateTransition(endState, 0.35);
// END: State transitions.
// BEGIN: Emissions.
codingState.addEmissionTransition('A', 0.18);
codingState.addEmissionTransition('C', 0.32);
codingState.addEmissionTransition('G', 0.32);
codingState.addEmissionTransition('T', 0.18);
noncodingState.addEmissionTransition('A', 0.25);
noncodingState.addEmissionTransition('C', 0.25);
noncodingState.addEmissionTransition('G', 0.25);
noncodingState.addEmissionTransition('T', 0.25);
// END: Emissions.
startState.normalize();
codingState.normalize();
noncodingState.normalize();
endState.normalize();
System.out.println("--- Composing random walks ---");
for (int i = 0; i < 10; i++) {
int lineNumber = i + 1;
System.out.printf("%2d: %s\n", lineNumber, hmm.compose());
}
String sequence = "AGCG";
List<HiddenMarkovModelStatePath> statePathSequences =
hmm.computeAllStatePaths(sequence);
System.out.printf("Brute-force path inference for sequence \"%s\", " +
"total probability = %f.\n",
sequence,
HiddenMarkovModel.sumPathProbabilities(
statePathSequences));
double hmmProbabilitySum = hmm.runForwardAlgorithm(sequence);
System.out.printf("HMM total probability: %f.\n", hmmProbabilitySum);
System.out.println("Brute-force state paths:");
int lineNumber = 1;
for (HiddenMarkovModelStatePath stateSequence : statePathSequences) {
System.out.printf("%4d: %s\n", lineNumber++, stateSequence);
}
}
}
com.github.coderodde.hmm.HiddenMarkovModelStateType.java:
package com.github.coderodde.hmm;
public enum HiddenMarkovModelStateType {
START,
HIDDEN,
END;
@Override
public String toString() {
switch (this) {
case START:
return "S";
case HIDDEN:
return "H";
case END:
return "E";
default:
throw new EnumConstantNotPresentException(
HiddenMarkovModelStateType.class,
"Unknown enum constant: " + this);
}
}
}
Critique request
Please, tell me anything that comes to mind.