Skip to content

Commit 4ada232

Browse files
committed
Added code for creating a random forest with all its trees
1 parent e79352a commit 4ada232

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

decision_trees/vhdl_generators/VHDLCreator.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,6 @@ def _add_architecture_signal_section(self) -> str:
135135
def _add_architecture_process_section(self) -> str:
136136
return
137137

138-
def create_vhdl_file(self, path: str) -> str:
139-
# open file for writing
140-
file_to_write = open(path + "/" + self._filename, "w")
141-
# add necessary headers
142-
text = ""
143-
text += self._add_headers()
144-
text += self._add_entity()
145-
text += self._add_architecture()
146-
147-
file_to_write.write(text)
148-
149-
file_to_write.close()
138+
@abc.abstractmethod
139+
def create_vhdl_file(self, path: str):
140+
return

decision_trees/vhdl_generators/__init__.py

Whitespace-only changes.

decision_trees/vhdl_generators/random_forest.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
from decision_trees.vhdl_generators.VHDLCreator import VHDLCreator
24
from decision_trees.vhdl_generators.tree import Tree
35

@@ -10,14 +12,14 @@
1012
class RandomForest(VHDLCreator):
1113

1214
def __init__(self, name: str, number_of_features: int, number_of_bits_per_feature: int):
13-
self.random_forest = []
15+
self.random_forest: List[Tree] = []
1416

1517
VHDLCreator.__init__(self, name, ClassifierType.RANDOM_FOREST.name,
1618
number_of_features, number_of_bits_per_feature)
1719

1820
def build(self, random_forest: sklearn.ensemble.RandomForestClassifier):
1921
for i, tree in enumerate(random_forest.estimators_):
20-
tree_builder = Tree("tree_" + str(i), self._number_of_features, self._number_of_bits_per_feature)
22+
tree_builder = Tree(f'tree_{i:02}', self._number_of_features, self._number_of_bits_per_feature)
2123
tree_builder.build(tree)
2224

2325
self.random_forest.append(tree_builder)
@@ -126,3 +128,15 @@ def _add_port_mapping(self, index: int) -> str:
126128
self.current_indent -= 1
127129

128130
return text
131+
132+
def create_vhdl_file(self, path: str):
133+
for d in self.random_forest:
134+
d.create_vhdl_file(path)
135+
136+
with open(path + '/' + self._filename, 'w') as f:
137+
text = ''
138+
text += self._add_headers()
139+
text += self._add_entity()
140+
text += self._add_architecture()
141+
f.write(text)
142+

decision_trees/vhdl_generators/tree.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import sklearn.tree
55

6-
from decision_trees.utils.convert_to_fixed_point import convert_to_fixed_point
6+
from decision_trees.utils.convert_to_fixed_point import convert_to_fixed_point, convert_fixed_point_to_integer
77
from decision_trees.utils.constants import ClassifierType
88

99

@@ -230,7 +230,7 @@ def _add_architecture_process_compare(self):
230230
text += self._insert_text_line_with_indent(
231231
"if unsigned(features(" +
232232
str(split.var_idx) + ")) <= to_unsigned(" +
233-
str(int(split.value_to_compare)) +
233+
str(convert_fixed_point_to_integer(split.value_to_compare, self._number_of_bits_per_feature)) +
234234
", features'length) then")
235235

236236
self.current_indent += 1
@@ -323,6 +323,9 @@ def _preorder(self, tree_, features, following_splits_IDs, following_splits_comp
323323
# print("Feature: " + str(tree_.threshold[node]) + ", after conversion: "
324324
# + str(convert_to_fixed_point(tree_.threshold[node], self._number_of_bits_per_feature)))
325325

326+
#print(f'tree_.threshold[node]: {tree_.threshold[node]:{1}.{3}}')
327+
#print(f'convert_to_fixed_point: {convert_to_fixed_point(tree_.threshold[node], self._number_of_bits_per_feature):{1}.{3}}')
328+
326329
# then create a split
327330
self._add_new_split(
328331
self._current_split_index,
@@ -364,3 +367,11 @@ def _preorder(self, tree_, features, following_splits_IDs, following_splits_comp
364367
following_splits_IDs,
365368
following_splits_compare_values
366369
)
370+
371+
def create_vhdl_file(self, path: str):
372+
with open(path + '/' + self._filename, 'w') as f:
373+
text = ''
374+
text += self._add_headers()
375+
text += self._add_entity()
376+
text += self._add_architecture()
377+
f.write(text)

0 commit comments

Comments
 (0)