Skip to content

Commit 361c00c

Browse files
committed
Implement generalization to allow for custom data
1 parent 1fd2831 commit 361c00c

File tree

5 files changed

+173
-177
lines changed

5 files changed

+173
-177
lines changed

cmp_test_label.py

+31-48
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33
from model import UNet
44
import torch
55
import os
6-
7-
P_label = "result\PTest_test_label.npy"
8-
T_label = "result\TTest_test_label.npy"
9-
model_path = "result\saved_model_pt"
6+
from config import *
107

118
# Check if the model file exists
12-
if os.path.exists(model_path):
13-
whole_model = torch.load(model_path)
9+
if os.path.exists(os.path.join(RESULT_PATH, "saved_model_pt")):
10+
whole_model = torch.load(os.path.join(RESULT_PATH, "saved_model_pt"))
1411
# Construct the model details string
1512
model_details = ""
1613
if "data" in whole_model:
@@ -19,20 +16,17 @@
1916
model_details += (
2017
"Learning rate: " + str(whole_model.get("learning_rate")) + "\n"
2118
)
22-
if "phantom_IOU" in whole_model:
23-
model_details += "Phantom IoU: " + str(whole_model.get("phantom_IOU")) + "\n"
24-
if "phantom_DC" in whole_model:
25-
model_details += "Phantom DC: " + str(whole_model.get("phantom_DC")) + "\n"
26-
if "T1T6_IOU" in whole_model:
27-
model_details += "T1T6 IoU: " + str(whole_model.get("T1T6_IOU")) + "\n"
28-
if "T1T6_DC" in whole_model:
29-
model_details += "T1T6 DC: " + str(whole_model.get("T1T6_DC")) + "\n"
3019
if "pools" in whole_model:
3120
model_details += "Pooling layers used: " + str(whole_model.get("pools")) + "\n"
21+
if "reverse_pools" in whole_model:
22+
model_details += "Reverse pooling used: " + str(whole_model.get("reverse_pools")) + "\n"
3223
if "data_augmentations" in whole_model:
3324
model_details += (
3425
"Data Augmentations: " + str(whole_model.get("data_augmentations")) + "\n"
3526
)
27+
if "tested_on" in whole_model:
28+
for i, data in enumerate(whole_model["tested_on"]):
29+
model_details += f"{data} IOU: {whole_model['IOU'][i]}\n{data} DC: {whole_model['DC'][i]}\n"
3630
if "image_size" in whole_model:
3731
model_details += "Image Size: " + str(whole_model.get("image_size")) + "\n"
3832
if "optimizer" in whole_model:
@@ -52,49 +46,38 @@
5246
else:
5347
model_details = "Model not found or loaded"
5448

55-
# Create subplots for images and model details
56-
fig, axs = plt.subplots(2, 4, gridspec_kw={"width_ratios": [1, 1, 1, 0.5]})
57-
58-
# Check if P_label exists
59-
if os.path.exists(P_label):
60-
Phantom_created = np.load(P_label)
61-
Pimage = np.load("data\\PTest\\frame_0000.npy")
62-
Pmask = np.load("data\\PTest_label\\frame_0000.npy")
6349

64-
# Plot Phantom images
65-
axs[0, 0].imshow(np.squeeze(Pimage), cmap="gray")
66-
axs[0, 0].set_title("Phantom Original Image")
50+
# Get the list of files in the directory
51+
all_files = os.listdir(RESULT_PATH)
6752

68-
axs[0, 1].imshow(np.squeeze(Pmask), cmap="gray")
69-
axs[0, 1].set_title("Phantom Mask")
53+
# Filter the files based on the extension
54+
filtered_files = [file for file in all_files if file.endswith(".npy")]
7055

71-
axs[0, 2].imshow(np.squeeze(Phantom_created), cmap="gray")
72-
axs[0, 2].set_title("Phantom Model Created Mask")
56+
# Create subplots for images and model details
57+
fig, axs = plt.subplots(len(filtered_files), 4, gridspec_kw={"width_ratios": [1, 1, 1, 0.5]})
7358

74-
# Check if T_label exists
75-
if os.path.exists(T_label):
76-
T1T6_created = np.load(T_label)
77-
Timage = np.load("data\\TTest\\frame_0000.npy")
78-
Tmask = np.load("data\\TTest_label\\frame_0000.npy")
59+
idx = 0
60+
for file in filtered_files:
61+
data = file[:-15]
62+
created_mask = np.load(os.path.join(RESULT_PATH, file))
63+
original_mask = np.load(os.path.join(DATA_PATH, TESTING_DATA_MASK_LOCATION[TESTING_DATA.index(data)], MASK_DEFINITION % 0000))
64+
original_image = np.load(os.path.join(DATA_PATH, TESTING_DATA_LOCATION[TESTING_DATA.index(data)], IMAGE_DEFINITION % 0000) )
7965

80-
# Plot T1-T6 images
81-
axs[1, 0].imshow(np.squeeze(Timage), cmap="gray")
82-
axs[1, 0].set_title("T1-T6 Original Image")
66+
# Plot Phantom images
67+
axs[idx, 0].imshow(np.squeeze(original_image), cmap="gray")
68+
axs[idx, 0].set_title(data + " Original Image")
8369

84-
axs[1, 1].imshow(np.squeeze(Tmask), cmap="gray")
85-
axs[1, 1].set_title("T1-T6 Mask")
70+
axs[idx, 1].imshow(np.squeeze(original_mask), cmap="gray")
71+
axs[idx, 1].set_title(data + " Mask")
8672

87-
axs[1, 2].imshow(np.squeeze(T1T6_created), cmap="gray")
88-
axs[1, 2].set_title("T1-T6 Model Created Mask")
73+
axs[idx, 2].imshow(np.squeeze(created_mask), cmap="gray")
74+
axs[idx, 2].set_title(data + " Model Created Mask")
8975

90-
# Hide ticks and labels for the empty subplot
91-
axs[0, 3].axis("off")
76+
axs[idx, 3].axis("off")
9277

93-
# Adjust layout to add space for the text
94-
plt.subplots_adjust(bottom=0.1, top=0.9)
78+
idx += 1
9579

96-
# Add model details below the images
80+
# Hide ticks and labels for the empty subplot
9781
axs[0, 3].text(0, 0.5, model_details, fontsize=14, ha="left", va="center")
98-
axs[1, 3].axis("off")
9982

100-
plt.show()
83+
plt.show()

config.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
DATA = "P"
1+
DATA = "Phantom"
22
TRAIN = True
3-
TEST_ON_BOTH_DATA = False
3+
TEST_ON = [0, 1]
44
POOL = 4
55
REVERSE_POOL = True
66
LEARNING_RATE = 1e-2
@@ -17,12 +17,29 @@
1717
CREATE_TEST_MASK = True
1818

1919
RESULT_PATH = "./result"
20-
DATA_PATH = "data"
2120

2221
CREATE_FOLDER = True
2322

2423
NOTIFY = True
2524

25+
# DATA LOADING
26+
DATA_PATH = "data"
27+
28+
TRAINING_DATA = ["Phantom", "T1T6"]
29+
TESTING_DATA = ["Phantom", "T1T6"]
30+
31+
TRAINING_DATA_LOCATION = ["PTrain", "TTrain"]
32+
TESTING_DATA_LOCATION = ["PTest", "TTest"]
33+
34+
TRAINING_DATA_MASK_LOCATION = ["PTrain_label", "TTrain_label"]
35+
TESTING_DATA_MASK_LOCATION = ["PTest_label", "TTest_label"]
36+
37+
IMAGE_DEFINITION = "frame_%04d.npy"
38+
MASK_DEFINITION = "frame_%04d.npy"
39+
40+
TRAINING_DATA_COUNT = [1400, 845]
41+
TESTING_DATA_COUNT = [600, 362]
42+
2643
# Transformations
2744
TRANSFORM = False
2845

graph_progress.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import openpyxl
22
import matplotlib.pyplot as plt
33
import mplcursors
4+
import os
5+
from config import RESULT_PATH
46

57
# Load the workbook
6-
workbook = openpyxl.load_workbook("./result/progress.xlsx")
8+
workbook = openpyxl.load_workbook(os.path.join(RESULT_PATH, "progress.xlsx"))
79
worksheet = workbook.active
810

911
# Initialize lists to store data

0 commit comments

Comments
 (0)