Package nltk_lite :: Package contrib :: Package classifier_tests :: Module decisiontreetests
[hide private]
[frames] | [no frames]

Source Code for Module nltk_lite.contrib.classifier_tests.decisiontreetests

 1  # Natural Language Toolkit 
 2  # 
 3  # Author: Sumukh Ghodke <sumukh dot ghodke at gmail dot com> 
 4  # 
 5  # URL: <http://nltk.sf.net> 
 6  # This software is distributed under GPL, for license information see LICENSE.TXT 
 7   
 8  from nltk_lite.contrib.classifier_tests import * 
 9  from nltk_lite.contrib.classifier import decisiontree, decisionstump as ds, instances as ins, format 
10  from nltk_lite.contrib.classifier.exceptions import invaliddataerror as inv 
11   
12 -class DecisionTreeTestCase(unittest.TestCase):
def test_tree_creation(self):
    # Build a decision tree from the C4.5-format 'phoney' dataset and check
    # the attributes chosen at the root and at its single 'tri' branch.
    phoney = datasetsDir(self) + 'test_phones' + SEP + 'phoney'
    c45 = format.C45_FORMAT
    training = c45.get_training_instances(phoney)
    attributes = c45.get_attributes(phoney)
    klass = c45.get_klass(phoney)
    tree = decisiontree.DecisionTree(training, attributes, klass)

    self.assertNotEqual(None, tree)
    root = tree.root
    self.assertNotEqual(None, root)
    self.assertEqual('band', root.attribute.name)
    # Only one branch ('tri') is expected under the root for this dataset.
    self.assertEqual(1, len(root.children))
    self.assertEqual('size', root.children['tri'].attribute.name)
21
23 path = datasetsDir(self) + 'minigolf' + SEP + 'weather' 24 tree = decisiontree.DecisionTree(format.C45_FORMAT.get_training_instances(path), format.C45_FORMAT.get_attributes(path), format.C45_FORMAT.get_klass(path)) 25 outlook = tree.attributes[0] 26 self.assertEqual(9, len(tree.training)) 27 filtered = tree.training.filter(outlook, 'sunny') 28 self.assertEqual(9, len(tree.training)) 29 self.assertEqual(4, len(filtered))
30
32 path = datasetsDir(self) + 'test_phones' + SEP + 'phoney' 33 tree = decisiontree.DecisionTree(format.C45_FORMAT.get_training_instances(path), format.C45_FORMAT.get_attributes(path), format.C45_FORMAT.get_klass(path)) 34 max_ig_stump = tree.maximum_information_gain() 35 self.assertEqual('size', max_ig_stump.attribute.name)
36
38 path = datasetsDir(self) + 'test_phones' + SEP + 'phoney' 39 tree = decisiontree.DecisionTree(format.C45_FORMAT.get_training_instances(path), format.C45_FORMAT.get_attributes(path), format.C45_FORMAT.get_klass(path)) 40 max_gr_stump = tree.maximum_gain_ratio() 41 self.assertEqual('pda', max_gr_stump.attribute.name)
42
43
44      #                  outlook
45      #          sunny  /   |   \  rainy
46      #                /    |    \
47      #      temperature          windy
48      #
50 path = datasetsDir(self) + 'minigolf' + SEP + 'weather' 51 tree = decisiontree.DecisionTree(format.C45_FORMAT.get_training_instances(path), format.C45_FORMAT.get_attributes(path), format.C45_FORMAT.get_klass(path)) 52 self.assertEqual('outlook', tree.root.attribute.name) 53 children = tree.root.children 54 self.assertEqual(2, len(children)) 55 56 sunny = children['sunny'] 57 self.assertEqual('temperature', sunny.attribute.name) 58 self.assertEqual(0, len(sunny.children)) 59 60 rainy = children['rainy'] 61 self.assertEqual('windy', rainy.attribute.name) 62 self.assertEqual(0, len(rainy.children))
63
65 try: 66 path = datasetsDir(self) + 'numerical' + SEP + 'weather' 67 decisiontree.DecisionTree(format.C45_FORMAT.get_training_instances(path), format.C45_FORMAT.get_attributes(path), format.C45_FORMAT.get_klass(path)) 68 self.fail('should have thrown an error') 69 except inv.InvalidDataError: 70 pass
71