-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
33 lines (27 loc) · 874 Bytes
/
utils.py
File metadata and controls
33 lines (27 loc) · 874 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from datasets import load_dataset
from nltk.stem import PorterStemmer
def plot_confusion_matrix(y_true, y_pred, display_labels=None, ax=None):
    """Draw the confusion matrix of ``y_pred`` against ``y_true``.

    The matrix is computed with ``sklearn.metrics.confusion_matrix`` and
    rendered through ``ConfusionMatrixDisplay.plot``. Nothing is returned.

    Args:
        y_true: Reference labels, forwarded to ``confusion_matrix``.
        y_pred: Predicted labels, forwarded to ``confusion_matrix``.
        display_labels: Optional labels handed to ``ConfusionMatrixDisplay``.
        ax: Optional target axes forwarded to ``plot``.
    """
    matrix = confusion_matrix(y_true, y_pred)
    ConfusionMatrixDisplay(matrix, display_labels=display_labels).plot(ax=ax)
_STEMMER = None  # shared PorterStemmer, created on first use


def _get_stemmer():
    """Return the module-wide PorterStemmer, building it on the first call."""
    global _STEMMER
    if _STEMMER is None:
        _STEMMER = PorterStemmer()
    return _STEMMER


def preprocess_word(word: str, stem: bool = True) -> str:
    """Preprocess a single word, optionally reducing it to its stem.

    Args:
        word: The word to preprocess.
        stem: When true (the default), the word is passed through NLTK's
            ``PorterStemmer.stem``; when false, it is returned unchanged.

    Returns:
        The stemmed word, or ``word`` itself when ``stem`` is false.
    """
    if not stem:
        # No stemmer is needed on this path, so none is constructed.
        return word
    # Reuse one stemmer instead of building a new one for every word.
    return _get_stemmer().stem(word)
class EmotionDataset:
    """Holder for the ``dair-ai/emotion`` dataset and its label names.

    Attributes:
        LABEL_TO_EMOTION_DICT: Maps each integer label id to its emotion
            name. Defined on the class, so it can be read without
            instantiating (and therefore without loading the dataset).
        dataset: The object returned by ``load_dataset("dair-ai/emotion")``.
    """

    # Class-level constant: identical for every instance and still reachable
    # as ``self.LABEL_TO_EMOTION_DICT``.
    LABEL_TO_EMOTION_DICT = {
        0: "sadness",
        1: "joy",
        2: "love",
        3: "anger",
        4: "fear",
        5: "surprise",
    }

    def __init__(self):
        """Load the ``dair-ai/emotion`` dataset via ``load_dataset``."""
        self.dataset = load_dataset("dair-ai/emotion")