Confusion matrix can now be drawn

This commit is contained in:
Drew Giffin
2025-10-20 16:12:24 -04:00
parent de7fd0d384
commit b7580a4d1c
2 changed files with 19 additions and 0 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

+19
View File
@@ -41,6 +41,25 @@ def main():
# evaluation
evaluate_model(model, X, y_pred, y_test, le)
draw_confusion_matrix(y_test, y_pred, le)
def draw_confusion_matrix(y_test, y_pred, le):
y_test_decoded = le.inverse_transform(y_test)
y_pred_decoded = le.inverse_transform(y_pred)
cm = confusion_matrix(y_test_decoded, y_pred_decoded, labels=le.classes_)
# Plot
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_,
yticklabels=le.classes_)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("images/confusion_matrix.png", dpi=300) # Save for README
plt.show()
def predict_target(model, X_test):
y_pred = model.predict(X_test)