-
Notifications
You must be signed in to change notification settings - Fork 1
/
condition.py
124 lines (96 loc) · 4.46 KB
/
condition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import abc
import typing
import numpy as np
import data_preprocessing
import label
class Condition(typing.Hashable, typing.Callable, abc.ABC):
    """Represents a condition that can be evaluated on an example.

    When treated as a function, it takes an example (e.g., image, data) as input
    and returns a value between 0 and 1 indicating whether the condition is satisfied.
    A value of 0 means the condition is not met, while a value of 1 means it is fully met.

    This is an abstract base: concrete conditions must implement ``__call__``,
    ``__hash__`` and ``__eq__``. ``__hash__`` and ``__eq__`` are declared together
    because defining ``__eq__`` on a class without ``__hash__`` makes it unhashable.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the base class stores no state."""

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> typing.Union[bool, np.ndarray]:
        """Evaluate the condition on the given data.

        :return: A bool, or an array of per-example values.
        """

    @abc.abstractmethod
    def __hash__(self):
        """Return a hash that agrees with ``__eq__``."""

    @abc.abstractmethod
    def __eq__(self, other):
        """Return whether ``other`` represents the same condition."""
class PredCondition(Condition):
    """Represents a condition based on a model's prediction of a specific class.

    It evaluates to 1 if the model predicts the specified class for a given example,
    and 0 otherwise. When ``negated`` is set the two outputs are swapped.
    """

    def __init__(self,
                 l: label.Label,
                 secondary_model_name: str = None,
                 lower_prediction_index: int = None,
                 binary: bool = False,
                 negated: bool = False):
        """Initializes a PredCondition instance.

        :param l: The target Label for which the condition is evaluated.
        :param secondary_model_name: When not None, ``__call__`` reads the
            ``secondary_*`` inputs instead of the main predictions.
        :param lower_prediction_index: When not None (and no secondary model is
            set), ``__call__`` reads the ``lower_predictions_*`` inputs.
        :param binary: When True (and neither of the above applies),
            ``__call__`` reads ``binary_data[l]``.
        :param negated: When True, a match yields 0 and a mismatch yields 1.
        """
        super().__init__()
        self.l = l
        self.secondary_model_name = secondary_model_name
        self.binary = binary
        self.lower_prediction_index = lower_prediction_index
        self.negated = negated

    def __call__(self,
                 fine_data: np.array,
                 coarse_data: np.array,
                 secondary_fine_data: np.array = None,
                 secondary_coarse_data: np.array = None,
                 lower_predictions_fine_data: dict = None,
                 lower_predictions_coarse_data: dict = None,
                 binary_data: typing.Dict[label.Label, np.array] = None) -> np.array:
        """Compare the selected predictions against the label's index.

        The data source is chosen in priority order: secondary model, lower
        prediction, binary, then the main fine/coarse predictions. Within each
        source the fine or coarse variant is picked from the label's granularity.

        :return: Array with 1 where the prediction equals ``self.l.index`` and 0
            elsewhere (swapped when ``negated``).
        :raises ValueError: If the selected data source was not supplied.
        """
        fine = self.l.g.g_str == 'fine'

        if self.secondary_model_name is not None:
            granularity_data = secondary_fine_data if fine else secondary_coarse_data
        elif self.lower_prediction_index is not None:
            # NOTE(review): the selected object is compared to the label index
            # as-is, although these parameters are annotated as dict -- confirm
            # that callers pass the prediction array for this index.
            granularity_data = lower_predictions_fine_data if fine else lower_predictions_coarse_data
        elif self.binary:
            granularity_data = None if binary_data is None else binary_data[self.l]
        else:
            granularity_data = fine_data if fine else coarse_data

        if granularity_data is None:
            raise ValueError(f'Condition with parameter: l={self.l}, '
                             f'secondary={self.secondary_model_name}, '
                             f'binary={self.binary}, '
                             f'lower prediction index={self.lower_prediction_index} '
                             f'do not have associate data when do inference')

        positive_result = 0 if self.negated else 1
        return np.where(granularity_data == self.l.index, positive_result, 1 - positive_result)

    def __str__(self) -> str:
        """Return a string key built from the label and every configuration flag."""
        secondary_str = f'_secondary_{self.secondary_model_name}' if self.secondary_model_name is not None else ''
        lower_prediction_index_str = f'_lower_{self.lower_prediction_index}' \
            if self.lower_prediction_index is not None else ''
        binary_str = '_binary' if self.binary else ''
        negated_str = '_negated' if self.negated else ''
        g_str = f'_{self.l.g}' if self.l.g is not None else ''

        return f'pred{g_str}_{self.l}{secondary_str}{lower_prediction_index_str}{binary_str}{negated_str}'

    def __hash__(self):
        """Hash the string key so that equal conditions hash equally."""
        return hash(str(self))

    def __eq__(self, other):
        """Two PredConditions are equal when their string keys are identical.

        The keys are compared directly rather than through their hashes, so a
        hash collision (or a plain string equal to the key) is never mistaken
        for an equal condition.
        """
        if not isinstance(other, PredCondition):
            return NotImplemented
        return str(self) == str(other)
class InconsistencyCondition(Condition):
    """Condition that flags examples whose fine and coarse predictions disagree.

    A fine prediction is translated through ``preprocessor.fine_to_course_idx``;
    the example is flagged when the result differs from the coarse prediction.
    """

    def __init__(self,
                 preprocessor: data_preprocessing.FineCoarseDataPreprocessor):
        """Store the preprocessor that supplies the fine-to-coarse index mapping.

        :param preprocessor: Object exposing ``fine_to_course_idx``.
        """
        super().__init__()
        self.preprocessor = preprocessor

    def __call__(self,
                 fine_data: np.array,
                 coarse_data: np.array) -> np.array:
        """Evaluate the condition on paired fine and coarse predictions.

        :param fine_data: Fine prediction indices, one per example.
        :param coarse_data: Coarse prediction indices, paired with ``fine_data``.
        :return: Array with 1 where the mapped fine prediction differs from the
            coarse prediction and 0 where they agree.
        """
        # Look the mapping up once instead of on every iteration.
        fine_to_coarse = self.preprocessor.fine_to_course_idx
        return np.array([int(fine_to_coarse[fine_prediction_index] != coarse_prediction_index)
                         for fine_prediction_index, coarse_prediction_index in zip(fine_data, coarse_data)])

    def __hash__(self):
        """Return the same hash for every instance.

        The hashed literal is left unchanged so hash values stay the same as before.
        """
        return hash('ConsistencyCondition')

    def __eq__(self, other):
        """All InconsistencyCondition instances compare equal to one another.

        The type is checked explicitly: comparing hashes alone would also report
        the string ``'ConsistencyCondition'`` as equal to this condition.
        """
        if not isinstance(other, InconsistencyCondition):
            return NotImplemented
        return True