from io import open
from operator import itemgetter
from .armine import ARM
from .rule import ClassificationRule
class ARMClassifier(ARM):
"""Utility class for Classification Rule Mining.
This class provides methods to generate a set of Classification rules
from a transactional dataset or a tabular dataset. You can then use this
class to classify unclassified data instances. The classification is done
using a modified version of the CBA Algorithm.
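
    Examples
    --------
    A minimal end-to-end sketch (the toy data and threshold values are
    illustrative; ``learn`` is assumed to take the same thresholds as the
    ``_learn`` helper below)::

        >>> clf = ARMClassifier()
        >>> clf.load({('sunny', 'hot'): 'no',
        ...           ('sunny', 'mild'): 'yes',
        ...           ('rainy', 'mild'): 'yes'})
        >>> clf.learn(support_threshold=30, confidence_threshold=60,
        ...           coverage_threshold=20)
        >>> label = clf.classify(('sunny', 'mild'))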
"""
def __init__(self):
super(ARMClassifier, self).__init__()
self._classes = []
self._default_class = None
self._transactional_database = False
    def load(self, data, transactional_database=False):
"""Load dataset from a Dictionary.
Parameters
----------
data : dict
Dictionary with keys as features and values as labels.
transactional_database : bool
Whether the database is transactional(Default False).
Note
----
A database is transactional, if it contains transactions accompanied
with respective labels. On the other hand, A non transactional
database is basically a tabular dataset, with each column representing
a distinct feature.
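
        Examples
        --------
        A minimal sketch of the expected input (the toy data is
        illustrative)::

            >>> clf = ARMClassifier()
            >>> # Tabular data: keys are tuples of feature values.
            >>> clf.load({('sunny', 'hot'): 'no', ('rainy', 'mild'): 'yes'})
            >>> # Transactional data: keys are tuples of items.
            >>> clf.load({('bread', 'milk'): 'weekday'},
            ...          transactional_database=True)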
"""
self._clear()
for features, label in data.items():
if not transactional_database:
features = ["feature{}-{}".format(i+1, feature)
for i, feature in enumerate(features)]
self._dataset.append(tuple(features))
self._classes.append(label)
self._transactional_database = transactional_database
    def load_from_csv(self, filename, label_index=0,
                      transactional_database=False):
"""Load dataset from a csv file.
Parameters
----------
filename : string
Name of the csv file which contains the dataset.
label_index : int
Index of the column which contains the labels for each row.
Supports negative indexing(Default -1 which corresponds to the
last column).
transactional_database : bool
Whether the database is transactional(Default False).
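
        Examples
        --------
        A minimal sketch (``weather.csv`` is a hypothetical file whose
        first column holds the label)::

            >>> clf = ARMClassifier()
            >>> clf.load_from_csv('weather.csv', label_index=0)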
"""
self._clear()
import csv
with open(filename, newline='') as csvfile:
mycsv = csv.reader(csvfile)
for row in mycsv:
label = row[label_index]
if label_index >= 0:
features = row[:label_index] + row[label_index + 1:]
else:
features = (row[:len(row) + label_index]
+ row[len(row) + label_index + 1:])
if not transactional_database:
features = ["feature{}-{}".format(i+1, feature)
for i, feature in enumerate(features)]
self._dataset.append(tuple(features))
self._classes.append(label)
self._transactional_database = transactional_database
def _clear(self):
super(ARMClassifier, self)._clear()
self._classes = []
    def _clean_items(self, items):
        if not self._transactional_database:
            # Strip the "feature<i>-" prefix; use maxsplit=1 so feature
            # values that themselves contain '-' are kept intact.
            return tuple(feature.split('-', 1)[1] for feature in items)
        else:
            return tuple(items)
def _get_itemcount(self, items):
try:
classwise_count = self._itemcounts[tuple(set(items))]
except KeyError:
classwise_count = self._get_classwise_count(items)
return self._get_itemcount_from_classwise_count(classwise_count)
    def _should_join_candidate(self, candidate1, candidate2):
        if not self._transactional_database:
            # In a non-transactional database, if the last entries of both
            # candidates come from the same feature column (e.g. two
            # different values of feature2), the joined candidate would
            # require a row to hold two values for one column, so its
            # support would be 0 and the join is rejected.
            feature1 = candidate1[-1].split('-')[0]
            feature2 = candidate2[-1].split('-')[0]
            if feature1 == feature2:
                return False
        return super(ARMClassifier, self)._should_join_candidate(candidate1,
                                                                  candidate2)
def _get_classwise_count(self, items):
count_class = dict()
for key in set(self._classes):
count_class[key] = [0, 0]
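        # count_class[label] ends up as [rows containing every item in
        # `items` that have this label, total rows that have this label].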
for i, data in enumerate(self._dataset):
found = True
for item in items:
if item not in data:
found = False
break
if found:
count_class[self._classes[i]][0] += 1
count_class[self._classes[i]][1] += 1
return count_class
@staticmethod
def _get_itemcount_from_classwise_count(classwise_count):
net_itemcount = 0
for itemcount, _ in classwise_count.values():
net_itemcount += itemcount
return net_itemcount
    def _generate_rules(self, itemset):
        """Generate candidate classification rules for each itemset and
        append the best one (by rule key) to the list of rules."""
        for items in itemset:
            if len(items) > 0:
                rules = []
                # These values do not depend on the label, so compute
                # them once per itemset.
                classwise_count = self._get_classwise_count(tuple(items))
                count_lhs = self._get_itemcount_from_classwise_count(
                    classwise_count)
                antecedent = self._clean_items(items)
                for label in set(self._classes):
                    count_rhs = classwise_count[label][1]
                    count_both = classwise_count[label][0]
                    rule = ClassificationRule(antecedent, label,
                                              count_both, count_lhs,
                                              count_rhs, len(self._dataset))
                    if rule.confidence >= self._real_confidence_threshold:
                        rules.append(rule)
                rules.sort(key=self._rule_key)
                try:
                    self._rules.append(rules[-1])
                except IndexError:
                    pass
    def _update_default_class(self):
        # The default class is the majority class among the rows that no
        # learned rule classifies correctly.
        counter = dict.fromkeys(set(self._classes), 0)
        for i, _ in enumerate(self._dataset):
            is_match = False
            items = self._clean_items(self._dataset[i])
            for rule in self.rules:
                if (rule.match_antecedent(items) and
                        rule.match_consequent(self._classes[i])):
                    is_match = True
                    break
            if not is_match:
                counter[self._classes[i]] += 1
        self._default_class = max(counter.items(), key=itemgetter(1))[0]
def _learn(self, support_threshold, confidence_threshold,
coverage_threshold):
super(ARMClassifier, self)._learn(support_threshold,
confidence_threshold,
coverage_threshold)
self._update_default_class()
    def classify(self, data_instance, top_k_rules=25):
"""Classify `data_instance` using rules generated by `learn` method.
Parameters
----------
data_instance : array_like
Unclassified input.
top_k_rules : int
Maximum number of rules, which will be used to classify
`data_instance`.
Returns
-------
str
Predicted label for the `data_instance`.
Note
----
If the support_threshold and confidence_threshold passed to classify
are both greater than the values at which learning was done,
The result is same as if the learning is done at those higher values.
This helps in optimization purposes where you only need to learn once
at a low support and confidence_threshold, which reduces optimization
time.
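
        Examples
        --------
        Assuming ``clf`` has been loaded and trained as in the class-level
        example above, a hypothetical call looks like::

            >>> label = clf.classify(('sunny', 'mild'), top_k_rules=10)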
"""
matching_rules = []
for rule in self.rules:
if rule.match_antecedent(data_instance):
matching_rules.append(rule)
if len(matching_rules) == top_k_rules:
break
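        # Score each candidate label by the summed lift of its matching
        # rules; if no rule matches, fall back to the default class.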
if len(matching_rules) > 0:
score = dict()
for rule in matching_rules:
label = rule.consequent
score[label] = (score.get(label, 0) + rule.lift)
return max(score.items(), key=itemgetter(1))[0]
else:
return self._default_class