Back to index

python-biopython  1.60
test_NNGene.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 # This code is part of the Biopython distribution and governed by its
00004 # license.  Please see the LICENSE file that should have been included
00005 # as part of this package.
00006 
00007 """Test the different representations of Genes.
00008 
00009 This exercises the Motif, Schema and Signature methods of representing
00010 genes, as well as generic Pattern methods.
00011 """
00012 # standard library
00013 import os
00014 import unittest
00015 
00016 # Biopython
00017 from Bio import SeqIO
00018 from Bio.Seq import Seq
00019 from Bio.Alphabet import IUPAC
00020 
00021 # stuff we are testing
00022 from Bio.NeuralNetwork.Gene import Schema
00023 from Bio.NeuralNetwork.Gene import Motif
00024 from Bio.NeuralNetwork.Gene import Signature
00025 from Bio.NeuralNetwork.Gene import Pattern
00026 
# Set to a true value to run the slow schema-finding tests and to print the
# schemas / schema values produced by the SchemaFactory tests.
VERBOSE = 0
00028 
00029 
00030 # --- Tests for Pattern
00031 
class PatternIOTest(unittest.TestCase):
    """Tests for reading and writing patterns to a file.

    Every test writes a list of patterns to a scratch file through
    Pattern.PatternIO, reads the file back and compares the result with
    what was written.
    """
    def setUp(self):
        self.alphabet = IUPAC.ambiguous_dna
        self.test_file = os.path.join("NeuralNetwork", "patternio.txt")
        # Remove any existing copy of the output file,
        if os.path.isfile(self.test_file):
            os.remove(self.test_file)
        self.pattern_io = Pattern.PatternIO(self.alphabet)

    def tearDown(self):
        # Clean up by removing our output file,
        if os.path.isfile(self.test_file):
            os.remove(self.test_file)

    def _write_patterns(self, write_method, patterns):
        """Write patterns to the scratch file with the given PatternIO method.

        The handle is closed even if the write raises, so tearDown can
        always remove the file.
        """
        output_handle = open(self.test_file, "w")
        try:
            write_method(patterns, output_handle)
        finally:
            output_handle.close()

    def _read_patterns(self):
        """Return the patterns read back from the scratch file.

        The handle is closed even if the read raises.
        """
        input_handle = open(self.test_file, "r")
        try:
            return self.pattern_io.read(input_handle)
        finally:
            input_handle.close()

    def test_motif(self):
        """Reading and writing motifs to a file
        """
        # write to a file
        motifs = ["GAC", "AAA", "TTT", "GGG"]
        self._write_patterns(self.pattern_io.write, motifs)

        # read 'em back
        read_motifs = self._read_patterns()
        self.assertEqual(read_motifs, motifs,
                         "Failed to get back expected motifs %s, got %s"
                         % (motifs, read_motifs))

        # write seqs
        seq_motifs = [Seq(motif, self.alphabet) for motif in motifs]
        self._write_patterns(self.pattern_io.write_seq, seq_motifs)

        # read the seqs back; they are expected to come back as the
        # plain motif strings
        read_motifs = self._read_patterns()
        self.assertEqual(read_motifs, motifs,
                         "Failed to get back expected motifs %s from seqs, "
                         "got %s" % (motifs, read_motifs))

    def test_schema(self):
        """Reading and writing schemas to a file.
        """
        schemas = ["GTR", "GAC"]
        # write out the schemas
        self._write_patterns(self.pattern_io.write, schemas)

        # read back the schemas
        read_schemas = self._read_patterns()
        self.assertEqual(schemas, read_schemas,
                         "Read incorrect schemas %s, expected %s."
                         % (read_schemas, schemas))

        # --- make sure inappropriate alphabets are reported
        schemas = ["GTR", "G*C"]  # '*' is not in the DNA alphabet used here
        self._write_patterns(self.pattern_io.write, schemas)

        # Only ValueError counts as the expected outcome.  Any other
        # exception propagates unchanged so its real traceback is reported,
        # and the "not reported" failure is raised outside the try body so
        # it cannot be swallowed by an except clause.
        try:
            self._read_patterns()
        except ValueError:
            pass  # expected behavior
        else:
            self.fail("Did not report error on bad alphabet.")

    def test_signature(self):
        """Reading and writing signatures to a file.
        """
        signatures = [("GAC", "GAC"), ("AAA", "TTT")]
        self._write_patterns(self.pattern_io.write, signatures)

        read_sigs = self._read_patterns()
        self.assertEqual(read_sigs, signatures,
                         "Got back unexpected signatures %s, wanted %s"
                         % (read_sigs, signatures))
00129 
class PatternRepositoryTest(unittest.TestCase):
    """Tests for retrieving info from a repository of patterns.

    The fixture is a Pattern.PatternRepository built from five motifs whose
    counts range from 30 down to -20.
    """
    def setUp(self):
        self.motifs = {"GATC": 30,
                       "GGGG": 10,
                       "GTAG": 0,
                       "AAAA": -10,
                       "ATAT": -20}

        self.repository = Pattern.PatternRepository(self.motifs)

    def _check_known_patterns(self, patterns, num_expected):
        """Check the number of patterns and that each one is in the fixture."""
        self.assertEqual(len(patterns), num_expected,
                         "Got unexpected number of patterns %s, expected %s"
                         % (len(patterns), num_expected))
        for pattern in patterns:
            self.assertTrue(pattern in self.motifs,
                            "Got unexpected pattern %s" % pattern)

    def test_get_all(self):
        """Retrieve all patterns from a repository.
        """
        all_motifs = self.repository.get_all()

        # the expected order is the fixture ordered by descending count
        self.assertEqual(all_motifs,
                         ["GATC", "GGGG", "GTAG", "AAAA", "ATAT"],
                         "Unexpected motifs returned %s" % all_motifs)

    def test_get_random(self):
        """Retrieve random patterns from the repository.
        """
        for num_patterns in range(5):
            patterns = self.repository.get_random(num_patterns)
            self._check_known_patterns(patterns, num_patterns)

    def test_get_top_percentage(self):
        """Retrieve the top percentage of patterns from the repository.
        """
        for num_patterns, percentage in ((1, 0.2), (2, .4), (5, 1.0)):
            patterns = self.repository.get_top_percentage(percentage)
            self._check_known_patterns(patterns, num_patterns)

    def test_get_top(self):
        """Retrieve a certain number of the top patterns.
        """
        for num_patterns in range(5):
            patterns = self.repository.get_top(num_patterns)
            self._check_known_patterns(patterns, num_patterns)

    def test_get_differing(self):
        """Retrieve patterns from both sides of the list (top and bottom).
        """
        patterns = self.repository.get_differing(2, 2)
        self.assertEqual(patterns, ["GATC", "GGGG", "AAAA", "ATAT"],
                         "Got unexpected patterns %s" % patterns)

    def test_remove_polyA(self):
        """Test the ability to remove A rich patterns from the repository.
        """
        patterns = self.repository.get_all()
        self.assertEqual(len(patterns), 5,
                         "Unexpected starting: %s" % patterns)

        self.repository.remove_polyA()

        patterns = self.repository.get_all()
        self.assertEqual(len(patterns), 3, "Unexpected ending: %s" % patterns)
        self.assertEqual(patterns, ["GATC", "GGGG", "GTAG"],
                         "Unexpected patterns: %s" % patterns)

    def test_count(self):
        """Retrieve counts for particular patterns in the repository.
        """
        num_times = self.repository.count("GGGG")
        self.assertEqual(num_times, 10,
                         "Did not count item in the respository: %s"
                         % num_times)

        # a pattern that was never added is expected to count as zero
        num_times = self.repository.count("NOT_IN_THERE")
        self.assertEqual(num_times, 0,
                         "Counted items not in repository: %s" % num_times)
00219 
00220 # --- Tests for motifs
00221 
class MotifFinderTest(unittest.TestCase):
    """Tests for finding motifs from sequences.

    setUp loads two FASTA files (enolase.fasta and repeat.fasta) from the
    NeuralNetwork directory into lists of records.
    """
    def setUp(self):
        test_file = os.path.join('NeuralNetwork', 'enolase.fasta')
        diff_file = os.path.join('NeuralNetwork', 'repeat.fasta')

        self.test_records = []
        self.diff_records = []

        # load the records
        for filename, records in ((test_file, self.test_records),
                                  (diff_file, self.diff_records)):
            handle = open(filename, 'r')
            try:
                # Consume the parser as a plain iterable (same form as
                # SchemaFinderTest) rather than calling .next() by hand,
                # which only exists on Python 2 iterators.
                records.extend(SeqIO.parse(handle, "fasta",
                                           alphabet=IUPAC.unambiguous_dna))
            finally:
                # close the file even if parsing fails part way through
                handle.close()

        self.motif_finder = Motif.MotifFinder()

    def test_find(self):
        """Find all motifs in a set of sequences.
        """
        motif_repository = self.motif_finder.find(self.test_records, 8)
        top_motif = motif_repository.get_top(1)

        self.assertEqual(top_motif[0], 'TTGGAAAG',
                         "Got unexpected motif %s" % top_motif[0])

    def test_find_differences(self):
        """Find the difference in motif counts between two sets of sequences.
        """
        motif_repository = \
            self.motif_finder.find_differences(self.test_records,
                                               self.diff_records, 8)

        # one pattern from each end of the repository
        top, bottom = motif_repository.get_differing(1, 1)

        self.assertEqual(top, "TTGGAAAG",
                         "Got unexpected top motif %s" % top)
        self.assertEqual(bottom, "AATGGCAT",
                         "Got unexpected bottom motif %s" % bottom)
00274 
class MotifCoderTest(unittest.TestCase):
    """Test the ability to encode sequences as a set of motifs.
    """
    def setUp(self):
        motifs = ["GAG", "GAT", "GCC", "ATA"]

        # (sequence, expected representation) pairs; each expected list has
        # one value per motif, in the order the motifs are listed above
        self.match_strings = (("GATCGCC", [0.0, 1.0, 1.0, 0.0]),
                              ("GATGATCGAGCC", [.5, 1.0, .5, 0.0]))

        self.coder = Motif.MotifCoder(motifs)

    def test_representation(self):
        """Convert a sequence into its motif representation.
        """
        for match_string, expected in self.match_strings:
            seq_to_code = Seq(match_string, IUPAC.unambiguous_dna)
            matches = self.coder.representation(seq_to_code)

            # assertEqual (unlike a bare assert) still runs under python -O
            self.assertEqual(matches, expected,
                             "Did not match representation, expected %s, "
                             "got %s" % (expected, matches))
00297 
00298 # --- Tests for schemas
00299 
class SchemaTest(unittest.TestCase):
    """Matching ambiguous motifs with multiple ambiguity characters.
    """
    def setUp(self):
        # maps every schema character to the letters it may stand for;
        # 'R' and '*' are the ambiguous ones
        ambiguity_chars = {"G": "G",
                           "A": "A",
                           "T": "T",
                           "C": "C",
                           "R": "AG",
                           "*": "AGTC"}

        self.motif_coder = Schema.Schema(ambiguity_chars)

        # (motif, matches expected in match_string) pairs
        self.match_string = "GATAG"
        self.match_info = [("GA", ["GA"]),
                           ("GATAG", ["GATAG"]),
                           ("GA*AG", ["GATAG"]),
                           ("GATRG", ["GATAG"]),
                           ("*A", ["GA", "TA"])]

    def test_find_matches(self):
        """Find all matches in a sequence.
        """
        for motif, expected in self.match_info:
            found_matches = self.motif_coder.find_matches(motif,
                                                          self.match_string)
            self.assertEqual(found_matches, expected,
                             "Expected %s, got %s"
                             % (expected, found_matches))

    def test_num_matches(self):
        """Find how many matches are present in a sequence.
        """
        for motif, expected in self.match_info:
            num_matches = self.motif_coder.num_matches(motif,
                                                       self.match_string)
            # the message reports the expected count first, then the count
            # actually found
            self.assertEqual(num_matches, len(expected),
                             "Expected %s, got %s"
                             % (len(expected), num_matches))

    def test_find_ambiguous(self):
        """Find the positions of ambiguous items in a sequence.
        """
        ambig_info = (("GATC", []),
                      ("G***", [1, 2, 3]),
                      ("GART", [2]),
                      ("*R*R", [0, 1, 2, 3]))

        for motif, expected in ambig_info:
            found_positions = self.motif_coder.find_ambiguous(motif)
            self.assertEqual(found_positions, expected,
                             "Expected %s, got %s for %s"
                             % (expected, found_positions, motif))

    def test_num_ambiguous(self):
        """Find the number of ambiguous items in a sequence.
        """
        ambig_info = (("GATC", 0),
                      ("G***", 3),
                      ("GART", 1),
                      ("*R*R", 4))

        for motif, expected in ambig_info:
            found_num = self.motif_coder.num_ambiguous(motif)
            self.assertEqual(found_num, expected,
                             "Expected %s, got %s for %s"
                             % (expected, found_num, motif))

    def test_motif_cache(self):
        """Make sure motif compiled regular expressions are cached properly.
        """
        test_motif = "GATC"

        self.motif_coder.find_matches(test_motif, "GATCGATC")

        self.assertTrue(test_motif in self.motif_coder._motif_cache,
                        "Did not find motif cached properly.")

        # make sure we don't bomb out if we use the same motif twice
        self.motif_coder.find_matches(test_motif, "GATCGATC")

    def test_all_unambiguous(self):
        """Return all unambiguous characters that can be in a motif.
        """
        found_unambig = self.motif_coder.all_unambiguous()

        expected = ["A", "C", "G", "T"]
        self.assertEqual(found_unambig, expected,
                         "Got %s, expected %s" % (found_unambig, expected))
00386 
class SchemaFinderTest(unittest.TestCase):
    """Test finding schemas from a set of sequences.

    Both tests only do real work when the module-level VERBOSE flag is
    true, because the searches take too long for a normal run.
    """
    def setUp(self):
        test_file = os.path.join('NeuralNetwork', 'enolase.fasta')
        diff_file = os.path.join('NeuralNetwork', 'repeat.fasta')

        self.test_records = []
        self.diff_records = []

        # load the records
        for filename, records in ((test_file, self.test_records),
                                  (diff_file, self.diff_records)):
            handle = open(filename, 'r')
            try:
                records.extend(SeqIO.parse(handle, "fasta",
                                           alphabet=IUPAC.unambiguous_dna))
            finally:
                # close the file even if parsing fails part way through
                handle.close()

        self.num_schemas = 2
        schema_ga = Schema.GeneticAlgorithmFinder()
        schema_ga.min_generations = 1
        self.finder = Schema.SchemaFinder(num_schemas=self.num_schemas,
                                          schema_finder=schema_ga)

    def test_find(self):
        """Find schemas from sequence inputs.
        """
        # this test takes too long
        if VERBOSE:
            repository = self.finder.find(self.test_records +
                                          self.diff_records)
            schemas = repository.get_all()

            self.assertTrue(len(schemas) >= self.num_schemas,
                            "Got too few schemas.")

    def test_find_differences(self):
        """Find schemas that differentiate between two sets of sequences.
        """
        # this test takes too long
        if VERBOSE:
            repository = self.finder.find_differences(self.test_records,
                                                      self.diff_records)
            schemas = repository.get_all()

            self.assertTrue(len(schemas) >= self.num_schemas,
                            "Got too few schemas.")
00432         
class SchemaCoderTest(unittest.TestCase):
    """Test encoding sequences as a grouping of motifs.
    """
    def setUp(self):
        # maps every schema character to the letters it may stand for
        ambiguity_chars = {"G": "G",
                           "A": "A",
                           "T": "T",
                           "C": "C",
                           "R": "AG",
                           "*": "AGTC"}

        motif_representation = Schema.Schema(ambiguity_chars)
        motifs = ("GA", "GATAG", "GA*AG", "GATRG", "*A")
        self.motif_coder = Schema.SchemaCoder(motifs,
                                              motif_representation)

        # (sequence, expected representation) pairs; each expected list has
        # one value per motif, in the order the motifs are listed above
        self.match_strings = [("GATAG", [.5, .5, .5, .5, 1.0]),
                              ("GAGAGATA", [0.75, 0, 0.25, 0, 1])]

    def test_representation(self):
        """Convert a string into a representation of motifs.
        """
        for match_string, expected in self.match_strings:
            match_seq = Seq(match_string, IUPAC.unambiguous_dna)
            found_rep = self.motif_coder.representation(match_seq)
            # assertEqual (unlike a bare assert) still runs under python -O
            self.assertEqual(found_rep, expected,
                             "Got %s, expected %s" % (found_rep, expected))
00462     
class SchemaMatchingTest(unittest.TestCase):
    """Matching schema to strings works correctly.
    """
    # NOTE: the class docstring above is emitted at runtime by
    # shortDescription, so it is part of the test output.

    def shortDescription(self):
        # describe the test as "<class name>:<class docstring>"
        return "%s:%s" % (self.__class__.__name__, self.__doc__)

    def runTest(self):
        # (first argument, second argument, expected result, failure message)
        # for each Schema.matches_schema call
        cases = (
            ("GATC", "AAAAA", 0,
             "Expected no match because of length differences"),
            ("GATC", "GAT*", 1, "Expected match"),
            ("GATC", "GATC", 1, "Expected match"),
            ("GATC", "C*TC", 0,
             "Expected no match because of char mismatch."),
            ("G*TC", "*TTC", 1, "Expected match because of ambiguity."),
        )

        for first, second, expected, message in cases:
            match = Schema.matches_schema(first, second)
            # assertEqual (unlike a bare assert) still runs under python -O
            self.assertEqual(match, expected, message)
00484 
class SchemaFactoryTest(unittest.TestCase):
    """Test the SchemaFactory for generating Schemas.
    """
    # A cached schema bank, so we don't have to load it multiple times.
    # This lives on the class rather than the instance: unittest builds a
    # separate TestCase instance for every test method, so a per-instance
    # cache would always start out empty and never be reused.
    schema_bank = None

    def __init__(self, method):
        # kept so the constructor signature stays the same for callers
        unittest.TestCase.__init__(self, method)

    def setUp(self):
        self.factory = Schema.SchemaFactory()

        self.test_file = os.path.join(os.getcwd(), "NeuralNetwork",
                                      "enolase.fasta")

        # maps every schema character to the letters it may stand for
        ambiguity_chars = {"G": "G",
                           "A": "A",
                           "T": "T",
                           "C": "C",
                           "R": "AG",
                           "*": "AGTC"}

        self.schema = Schema.Schema(ambiguity_chars)

    def test_easy_from_motifs(self):
        """Generating schema from a simple list of motifs.
        """
        motifs = {"GATCGAA": 20,
                  "GATCGAT": 15,
                  "GATTGAC": 25,
                  "TTTTTTT": 10}

        motif_bank = Pattern.PatternRepository(motifs)

        # nothing is asserted on the result; the test only checks that the
        # call completes, and prints the schemas when VERBOSE is set
        schema_bank = self.factory.from_motifs(motif_bank, .5, 2)
        if VERBOSE:
            # single-argument print(...) behaves the same as the print
            # statement on Python 2 and is also valid on Python 3
            print("\nSchemas:")
            for schema in schema_bank.get_all():
                print("%s: %s" % (schema, schema_bank.count(schema)))

    def test_hard_from_motifs(self):
        """Generating schema from a real life set of motifs.
        """
        schema_bank = self._load_schema_repository()

        if VERBOSE:
            print("\nSchemas:")
            for schema in schema_bank.get_top(5):
                print("%s: %s" % (schema, schema_bank.count(schema)))

    def _load_schema_repository(self):
        """Helper function to load a schema repository from a file.

        This also caches a schema bank, to prevent having to do this
        time consuming operation multiple times.  The cache is stored on
        the class, so it is shared by every test method in this class.
        """
        # if we already have a cached repository, return it
        if self.schema_bank is not None:
            return self.schema_bank

        # otherwise, we'll read in a new schema bank

        # read in the all of the motif records
        motif_handle = open(self.test_file, 'r')
        try:
            seq_records = list(SeqIO.parse(motif_handle, "fasta",
                                           alphabet=IUPAC.unambiguous_dna))
        finally:
            # close the file even if parsing fails
            motif_handle.close()

        # find motifs from the file
        motif_finder = Motif.MotifFinder()
        motif_size = 9

        motif_bank = motif_finder.find(seq_records, motif_size)

        schema_bank = self.factory.from_motifs(motif_bank, .1, 2)

        # cache the repository on the class so later tests reuse it
        SchemaFactoryTest.schema_bank = schema_bank

        return schema_bank

    def test_schema_representation(self):
        """Convert sequences into schema representations.
        """
        # get a set of schemas we want to code the sequence in
        schema_bank = self._load_schema_repository()
        top_schemas = schema_bank.get_top(25)
        schema_coder = Schema.SchemaCoder(top_schemas, self.schema)

        # get the sequences one at a time, and encode them
        fasta_handle = open(self.test_file, 'r')
        try:
            for seq_record in SeqIO.parse(fasta_handle, "fasta",
                                          alphabet=IUPAC.unambiguous_dna):
                schema_values = schema_coder.representation(seq_record.seq)
                if VERBOSE:
                    print("Schema values: %s" % (schema_values,))
        finally:
            # close the file even if encoding a record fails
            fasta_handle.close()
00581 
00582 # --- Tests for Signatures
class SignatureFinderTest(unittest.TestCase):
    """Test the ability to find signatures in a set of sequences.
    """
    def setUp(self):
        test_file = os.path.join('NeuralNetwork', 'enolase.fasta')

        # load the records
        handle = open(test_file, 'r')
        try:
            self.test_records = list(
                SeqIO.parse(handle, "fasta",
                            alphabet=IUPAC.unambiguous_dna))
        finally:
            # close the file even if parsing fails
            handle.close()

        self.sig_finder = Signature.SignatureFinder()

    def test_find(self):
        """Find signatures from sequence inputs.
        """
        repository = self.sig_finder.find(self.test_records, 6, 9)
        top_sig = repository.get_top(1)

        # a signature is compared as a pair of motif strings
        self.assertEqual(top_sig[0], ('TTGGAA', 'TGGAAA'),
                         "Unexpected signature %s" % (top_sig[0],))
00607 
class SignatureCoderTest(unittest.TestCase):
    """Test the ability to encode sequences as a set of signatures.
    """
    def setUp(self):
        signatures = [("GAC", "GAC"), ("AAA", "TTT"), ("CAA", "TTG")]

        self.coder = Signature.SignatureCoder(signatures, 9)

        # (sequence, expected representation) pairs; each expected list has
        # one value per signature, in the order they are listed above
        self.test_seqs = [("GACAAAGACTTT", [1.0, 1.0, 0.0]),
                          ("CAAAGACGACTTTAAATTT", [0.5, 1.0, 0.0]),
                          ("AAATTTAAAGACTTTGAC", [1.0 / 3.0, 1.0, 0.0]),
                          ("GACGAC", [1.0, 0.0, 0.0]),
                          ("GACAAAAAAAAAGAC", [1.0, 0.0, 0.0]),
                          ("GACAAAAAAAAAAGAC", [0.0, 0.0, 0.0])]

    def test_representation(self):
        """Convert a sequence into its signature representation.
        """
        for seq_string, expected in self.test_seqs:
            test_seq = Seq(seq_string, IUPAC.unambiguous_dna)
            predicted = self.coder.representation(test_seq)

            # assertEqual (unlike a bare assert) still runs under python -O
            self.assertEqual(predicted, expected,
                             "Non-expected representation %s for %s, "
                             "wanted %s" % (predicted, seq_string, expected))
00633 
# When executed as a script, run every test in this module with verbose
# (one line per test) output.
if __name__ == "__main__":
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))