Back to index

python-biopython  1.60
test_HMMCasino.py
Go to the documentation of this file.
#!/usr/bin/env python
"""Test out HMMs using the Occasionally Dishonest Casino.

This uses the occasionally dishonest casino example from Biological
Sequence Analysis by Durbin et al.

In this example, we are dealing with a casino that has two dice: a fair
die that has 1/6 probability of rolling any number, and a loaded die that
has 1/2 probability to roll a 6 and 1/10 probability to roll any other
number. The probability of switching from the fair to the loaded die is
.05, and the probability of switching from the loaded to the fair die
is .1.
"""

import os

if os.name == 'java':
    # This test can cause a fatal error on Jython with some versions of
    # Java, so skip it there. Raising MissingExternalDependencyError is a
    # slight misuse of that exception, but it serves in the short term to
    # skip this unit test on Jython.
    from Bio import MissingExternalDependencyError
    raise MissingExternalDependencyError(
        "This test can cause a fatal error on Jython with some versions of Java")
00022 
00023 # standard modules
00024 import random
00025 
00026 # biopython
00027 from Bio import Alphabet
00028 from Bio.Seq import MutableSeq
00029 from Bio.Seq import Seq
00030 
00031 # HMM stuff we are testing
00032 from Bio.HMM import MarkovModel
00033 from Bio.HMM import Trainer
00034 from Bio.HMM import Utilities
00035 
# Whether to print everything out: the trained model parameters, the
# Viterbi predictions and the per-iteration log likelihood change.
# Keep this False for regression testing.
VERBOSE = False
00039 
00040 # -- set up our alphabets
class DiceRollAlphabet(Alphabet.Alphabet):
    """Alphabet of observed symbols: the die faces '1' through '6'."""
    letters = list("123456")
00043 
class DiceTypeAlphabet(Alphabet.Alphabet):
    """Alphabet of hidden states: 'F' (fair die) and 'L' (loaded die)."""
    letters = list("FL")
00046 
00047 # -- useful functions
00048 def _loaded_dice_roll(chance_num, cur_state):
00049     """Generate a loaded dice roll based on the state and a random number
00050     """
00051     if cur_state == 'F':
00052         if chance_num <= (float(1) / float(6)):
00053             return '1'
00054         elif chance_num <= (float(2) / float(6)):
00055             return '2'
00056         elif chance_num <= (float(3) / float(6)):
00057             return '3'
00058         elif chance_num <= (float(4) / float(6)):
00059             return '4'
00060         elif chance_num <= (float(5) / float(6)):
00061             return '5'
00062         else:
00063             return '6'
00064     elif cur_state == 'L':
00065         if chance_num <= (float(1) / float(10)):
00066             return '1'
00067         elif chance_num <= (float(2) / float(10)):
00068             return '2'
00069         elif chance_num <= (float(3) / float(10)):
00070             return '3'
00071         elif chance_num <= (float(4) / float(10)):
00072             return '4'
00073         elif chance_num <= (float(5) / float(10)):
00074             return '5'
00075         else:
00076             return '6'
00077     else:
00078         raise ValueError("Unexpected cur_state %s" % cur_state)
00079 
def generate_rolls(num_rolls):
    """Simulate num_rolls throws at the casino.

    The simulation starts with the fair die ('F'). For every throw the
    current die is recorded, the throw is made with that die, and then the
    casino may switch dice: fair -> loaded with probability .05, and
    loaded -> fair with probability .1.

    Returns:

    o The generated roll sequence

    o The state sequence that generated the roll.
    """
    # for each state: (state we may switch to, probability of switching)
    switch_to = {'F': ('L', .05), 'L': ('F', .1)}

    rolls = MutableSeq('', DiceRollAlphabet())
    states = MutableSeq('', DiceTypeAlphabet())

    state = 'F'
    for _ in range(num_rolls):
        states.append(state)
        rolls.append(_loaded_dice_roll(random.random(), state))

        # a second random draw decides whether the casino swaps dice
        other_state, switch_prob = switch_to[state]
        if random.random() <= switch_prob:
            state = other_state

    return rolls.toseq(), states.toseq()
00114     
# -- build a MarkovModel
mm_builder = MarkovModel.MarkovModelBuilder(DiceTypeAlphabet(),
                                            DiceRollAlphabet())

# Start from random parameters rather than the true ones listed below.
mm_builder.allow_all_transitions()
mm_builder.set_random_probabilities()

# For reference, the true parameters of the casino simulated by
# generate_rolls() are:
#
#   transitions: F->F .95, F->L .05, L->F .10, L->L .90
#   emissions:   F: 1/6 for each of '1'..'6'
#                L: .1 for each of '1'..'5', .5 for '6'
#
# They could be set explicitly instead of randomly, e.g. with
# mm_builder.set_transition_score('F', 'L', .05) and
# mm_builder.set_emission_score('L', '6', .5).

# just get two different Markov Models -- we'll train one using
# Baum Welch, and one using the Standard trainer
baum_welch_mm = mm_builder.get_markov_model()
standard_mm = mm_builder.get_markov_model()

# get a sequence of rolls (and the hidden states that produced them) to
# train the markov models with
rolls, states = generate_rolls(3000)
00151 
00152 
00153 # -- now train the model
def stop_training(log_likelihood_change, num_iterations,
                  tolerance=0.01, max_iterations=10, verbose=None):
    """Tell the training model when to stop.

    Used as the stopping callback passed to BaumWelchTrainer.train.

    Arguments:

    o log_likelihood_change - change in log likelihood reported for the
    latest training iteration.

    o num_iterations - number of iterations performed so far.

    o tolerance - stop once log_likelihood_change drops below this value
    (default 0.01).

    o max_iterations - stop once this many iterations have been performed
    (default 10).

    o verbose - print the change on every call; when None (the default)
    the module-level VERBOSE flag decides.

    Returns True if training should stop, False otherwise.
    """
    if verbose is None:
        verbose = VERBOSE
    if verbose:
        # sys.stdout.write instead of a print statement so this works on
        # both Python 2 and Python 3 with identical output.
        import sys
        sys.stdout.write("ll change: %s\n" % (log_likelihood_change,))
    return (log_likelihood_change < tolerance
            or num_iterations >= max_iterations)
00165 
# -- Standard Training with known states
# Output uses single-argument print(...) calls or pre-formatted strings,
# which produce identical text on Python 2 and Python 3.
print("Training with the Standard Trainer...")
known_training_seq = Trainer.TrainingSequence(rolls, states)

trainer = Trainer.KnownStateTrainer(standard_mm)
trained_mm = trainer.train([known_training_seq])

if VERBOSE:
    print(trained_mm.transition_prob)
    print(trained_mm.emission_prob)

# decode a freshly generated test sequence with the trained model
test_rolls, test_states = generate_rolls(300)

predicted_states, prob = trained_mm.viterbi(test_rolls, DiceTypeAlphabet())
if VERBOSE:
    print("Prediction probability: %s" % (prob,))
    Utilities.pretty_print_prediction(test_rolls, test_states, predicted_states)
00183 
# -- Baum-Welch training without known state sequences
# Output uses single-argument print(...) calls or pre-formatted strings,
# which produce identical text on Python 2 and Python 3.
print("Training with Baum-Welch...")
# the state sequence is left empty because the states are treated as unknown
training_seq = Trainer.TrainingSequence(rolls, Seq("", DiceTypeAlphabet()))

trainer = Trainer.BaumWelchTrainer(baum_welch_mm)
trained_mm = trainer.train([training_seq], stop_training)

if VERBOSE:
    print(trained_mm.transition_prob)
    print(trained_mm.emission_prob)

# decode a freshly generated test sequence with the trained model
test_rolls, test_states = generate_rolls(300)

predicted_states, prob = trained_mm.viterbi(test_rolls, DiceTypeAlphabet())
if VERBOSE:
    print("Prediction probability: %s" % (prob,))
    Utilities.pretty_print_prediction(test_rolls, test_states, predicted_states)
00201