Có bệnh Không có bệnh Tổng
Train set 509 352 861
Validation set 78 26 104
Test set 114 38 152
Các mô hình thử nghiệm sẽ lần lượt được huấn luyện trên tập Train và kiểm định trên tập Validation cho đến khi không có sự cải thiện trên tập này nữa thì kết thúc. Mô hình kết quả của quá trình huấn luyện sẽ được kiểm tra độ chính xác trên tập Test.
4.1.2. Chuẩn bị dữ liệu
Các ảnh trích xuất từ phần mềm chụp ảnh của máy SPECT đôi khi sẽ bị lẫn vào các ký tự trên ảnh, do đó các ảnh sẽ được tiền xử lý bằng cách loại bỏ các pixel nhiễu trên ảnh cực (hình 4.1) giá trị các pixel được đưa về khoảng (0, 1) trước khi đưa vào huấn luyện.
Hình 4. 1: Ảnh cực được tiền xử lý. Ảnh bên trái là ảnh ban đầu có một vài kí tự lẫn vào ở góc dưới. Ảnh bên phải là ảnh đã qua xử lý loại bỏ ký tự trong ảnh. lẫn vào ở góc dưới. Ảnh bên phải là ảnh đã qua xử lý loại bỏ ký tự trong ảnh.
Các giá trị pixel nằm ngoài bán kính ảnh cực sẽ được đặt về giá trị 0:
( ) ( ) ( ) (4.1)
Trong đó:
là tọa độ của pixel, gốc tọa độ được coi là điểm phía trên bên trái của ảnh.
( ) là giá trị pixel ở vị trí
= 176 bằng 1/2 kích thước ảnh (352 x 352) , là tâm ảnh có tọa độ ( ⁄ , ⁄ )
Ảnh cực sẽ được chuyển đổi sang dạng đồ thị bằng cách sử dụng các super-pixel. Mỗi super-pixel biểu diễn cho một vùng nhỏ trên ảnh mà các pixel ở trong vùng đó có giá trị độ sáng xấp xỉ nhau. Đầu tiên, các super-pixel ở mỗi ảnh sẽ được tính ra bằng thuật toán SLIC, mỗi super-pixel này sẽ được coi là một nút của đồ thị. Ma trận trọng số W của đồ thị sẽ được xây dựng tương tự với thuật toán k-nearest neighbor, với các phần tử được xác định như sau:
( ) (4.2)
Trong đó lần lượt là các tọa độ hai chiều trên ảnh của các super- pixel thứ i và thứ j, được gọi là tham số tỉ lệ, tham số này được xác định cho từng nút và được tính bằng trung bình khoảng cách của k nút gần nhất với nút đó.
Sau khi tính ra ma trận kề, mỗi nút sẽ tạo liên kết với k nút gần nó nhất tính theo các giá trị , vì thế ma trận kề sẽ được cập nhật lại sao cho mỗi nút chỉ liên kết với k nút gần nó nhất, các nút còn lại sẽ có trọng số bằng 0. Ví dụ về ảnh cực được chuyển sang dạng đồ thị được cho ở hình 4.2
Hình 4. 2: Ảnh cực được chuyển sang dạng đồ thị. Từ trái sang phải lần lượt là ảnh cực, ảnh các đỉnh tạo ra từ SLIC và ảnh gồm cả các đỉnh và cạnh. ảnh cực, ảnh các đỉnh tạo ra từ SLIC và ảnh gồm cả các đỉnh và cạnh.
Đặc trưng của mỗi nút sẽ bao gồm các giá trị màu RGB và tọa độ của super-pixel ứng với nút đó. Đặc trưng của cạnh từ nút i đến nút j sẽ là sẽ là giá trị .
Mô hình áp dụng mạng nơ-ron đồ thị cho bài toán phân loại ảnh cực được cho ở hình 4.3
Hình 4. 3: Mô hình mạng nơ-ron đồ thị cho bài toán phân loại ảnh cực
4.2. Kết quả
Hàm mất mát được sử dụng trong các mô hình mạng nơ-ron tích chập và nơ-ron đồ thị là hàm Cross Entropy có công thức:
( ) ( ( , -)
∑ ( , -) ) (4.3)
Trong đó:
: là vec-tơ dự đoán đầu ra của các mạng cho mẫu dữ liệu với [i]
thể hiện xác suất dự đoán cho nhãn thứ i. : số loại nhãn có trong dữ liệu
: thứ tự nhãn đúng của mẫu dữ liệu
Hàm mất mát của mô hình sẽ là tổng các giá trị mất mát của toàn bộ mẫu dữ liệu:
∑ ( )
(4.4)
Với là tập hợp chứa toàn bộ các mẫu dữ liệu, là nhãn tương ứng với từng mẫu dữ liệu
Các mô hình được thử nghiệm sẽ huấn luyện theo phương pháp tương tự trong bài báo [17]. Đó là mỗi mô hình sẽ được chọn một giá trị learning rate khởi tạo và một giá trị learning rate tối thiểu. Giá trị learning rate khởi tạo sẽ là giá trị learning rate ban đầu được sử dụng trong thuật toán tối ưu hàm mất mát của mô hình. Ở mỗi vòng lặp trong quá trình huấn luyện, độ chính xác của mô hình sẽ được kiểm tra trên tập Validation Set, nếu sau một số vòng lặp nhất định mà độ chính xác này không có sự cải thiện thì giá trị learning rate của thuật toán huấn luyện sẽ bị giảm đi một nửa. Quá trình huấn luyện tiếp tục cho đến khi giá trị learning rate bị giảm xuống mức nhỏ hơn giá trị learning rate tối thiểu đã định nghĩa thì dừng lại. Hiệu năng của mô hình sẽ được đánh giá trên tập Test Set.
Các mô hình trong luận văn được cài đặt bằng cách sử dụng framework Pytorch và bộ thư viện DGL (Deep Graph Library) – một thư viện viết trên ngôn ngữ Python hỗ trợ các bài toán với dữ liệu đồ thị. Hình 4.4 là ví dụ về mã nguồn dùng để đánh giá mô hình sau khi huấn luyện hoàn thành. Các tham số chọn trước của mỗi mô hình (hyperparameter) sẽ được thử nghiệm và lựa chọn bằng kĩ thuật tìm kiếm kiếm Grid (Grid Search) – tức là thử nghiệm toàn diện các giá trị cụ thể của từng tham số để tìm ra các giá trị tham số tốt nhất cho từng mô hình.
Ví dụ về mã nguồn khi cài đặt mạng GCN được cho như dưới đây (Mã nguồn 4.1): 1 import torch 2 import torch.nn as nn 3 import torch.nn.functional as F 4 5 import dgl
6 from dgl.nn.pytorch import GraphConv
7
8 # Lớp mạng fully connected (FC)
9 # ở phần cuối của mô hình
10 class MLPReadout(nn.Module):
11
12 def __init__(self, input_dim, output_dim, L=2):
13 super().__init__()
14 list_FC_layers=[ nn.Linear(input_dim//2**l,
15 input_dim//2**(l+1),bias=True ) for l in range(L)]
16 list_FC_layers.append(nn.Linear(
17 input_dim//2**L , output_dim , bias=True ))
18 self.FC_layers=nn.ModuleList(list_FC_layers)
19 self.L = L
20
21 def forward(self, x):
22 y = x
23 for l in range(self.L):
24 y = self.FC_layers[l](y)
25 y = F.relu(y)
26 y = self.FC_layers[self.L](y)
27 return y
28
29 # Lớp mạng GCN
30 class GCNLayer(nn.Module):
31
32 def __init__(self, in_dim, out_dim):
33 super().__init__()
34 self.in_channels = in_dim
35 self.out_channels = out_dim
36 self.conv = GraphConv(in_dim, out_dim)
37
38 def forward(self, g, feature):
39
40 h = F.relu(self.conv(g, feature))
41 return h
42
43 # Mô hình mạng gồm nhiều lớp mạng GCN
44 class GCNNet(nn.Module):
45 def __init__(self, num_node_feat, n_layers):
47
48 # Lớp embedding biến đổi đặc trưng của đỉnh
49 self.embedding_h = nn.Embedding(num_node_feat
50 , n_layers[0])
51 # Một danh sách các lớp GCN
52 self.layers = nn.ModuleList(
53 [GCNLayer(n_layers[i], n_layers[i+1])
54 for i in range(n_layers-1)])
55 # Các lớp FC ở phần cuối của mạng
56 self.MLP_layer = MLPReadout(n_layers[-1], 2)
57
58 def forward(self, g, h, e):
59 # các đặc trưng đỉnh đi qua lớp embedding
60 h = self.embedding_h(h)
61 # các đặc trưng đỉnh được cập nhật
62 # bởi các lớp mạng GCN
63 for conv in self.layers:
64 h = conv(g, h)
65 g.ndata['h'] = h
66 # Lấy trung bình các đặc trưng đỉnh
67 hg = dgl.mean_nodes(g, 'h')
68 # vec-tơ cuối cùng qua các lớp FC
69 return self.MLP_layer(hg)
Mã nguồn 4. 1: Mạng GCN bằng ngôn ngữ Python
Và mã nguồn huấn luyện được viết dưới dạng giả mã dựa trên ngôn ngữ Python như sau (Mã nguồn 4.2):
1 def train(model, learning_rate, data_loader):
2
3 """
4 model: Mô hình được huấn luyện
5 learning_rate: Tham số learning_rate dùng cho
6 thuật toán tối ưu hàm mất mát dựa trên đạo hàm
7 data_loader: Tập dữ liệu bao gồm một danh sách
8 các phần tử (đồ thị, tập đặc trưng đỉnh,
9 tập đặc trưng cạnh, nhãn)
10 """
11
12 for (batch_graphs, batch_x, batch_e, batch_labels) \
13 in (data_loader):
14
15 # Hàm forward của model
16 # tính ra giá trị đầu ra dự đoán
17 batch_predict = model.forward(batch_graphs, \
19
20 # Hàm lossfn thực hiện việc
21 # tính giá trị hàm mất mát
22 # và giá trị cần cập nhật ở tham số
23 # sau mỗi vòng lặp
24 loss, gradient_descent = lossfn(model, batch_predict, \ 25 label, learning_rate)
26
27 # Hàm update thực hiện việc cập nhật lại các
28 # tham số huấn luyện của mô hình
29 update(model, gradient_descent)
30
31 def evaluate(model, data_loader):
32
33 """
34 model: Mô hình được huấn luyện
35 data_loader: Tập dữ liệu bao gồm
36 một danh sách các phần tử 37 (đồ thị, tập đặc trưng đỉnh, 38 tập đặc trưng cạnh, nhãn) 39 """ 40 nb_data = 0 41 epoch_test_acc = 0
42 for (batch_graphs, batch_x, batch_e, batch_labels) \
43 in (data_loader): 44 45 predict = model.forward(batch_graphs, \ 46 batch_x, batch_e) 47 48 # Hàm accuracy tính ra độ chính xác
49 # của kết quả dự đoán theo nhãn
50 epoch_test_acc += accuracy(batch_scores, \ 51 batch_labels) 52 53 nb_data += len(batch_labels) 54 55 epoch_test_acc /= nb_data 56 57 return epoch_test_acc 58
59 def train_model(model, learning_rate, num_epoch, train_data,
60 val_data, test_data, min_lr, patient):
61
62 """
63 model: Mô hình được huấn luyện
64 learning_rate: Tham số learning_rate dùng cho
65 thuật toán tối ưu hàm mất mát dựa trên đạo hàm
66 train_data: Tập dữ liệu huấn luyện bao gồm một
68 (đồ thị, tập đặc trưng đỉnh, tập đặc trưng cạnh, nhãn)
69 val_data: Tập dữ liệu Validation bao gồm
70 một danh sách các phần tử tương tự tập train
71 test_data: Tập dữ liệu Test bao gồm
72 một danh sách các phần tử tương tự tập train
73 min_lr: khi learning_rate nhỏ hơn giá trị này
74 thì dừng huấn luyện
75 patient: Khi sau một số lượt epoch bằng patient
76 mà mô hình không có sự cải thiện trên tập Validation
77 thì giảm learning_rate 78 79 """ 80 best_acc = 0 81 p = patient 82 for i in range(num_epoch):
83 train(model, learning_rate, train_data)
84
85 # Đánh giá độ chính xác trên tập validatation
86 val_acc = evaluate(model, val_data)
87
88 # Nếu độ chính xác không có sự cải thiện
89 # thì giảm biến đếm đi 1, khi biến đếm giảm về 0
90 # thì giảm learning_rate đi 1/2
91
92 if (val_acc < best_acc):
93 p = p - 1 94 else:
95
96 # Nếu độ chính xác có cải thiện thì
97 # cập nhật lại độ chính xác tốt nhất 98 p = patient 99 best_acc = val_acc 100 101 if p == 0: 102 learning_rate /= 2 103
104 # Nếu learning_rate giảm nhỏ dưới ngưỡng
105 # thì dừng huấn luyện
106 if (learning_rate < min_lr):
107 break
108
109 # Khi quá trình huấn luyện kết thúc
110 # thì đánh giá kết quả trên tập Test
111 test_acc = evaluate(model, test_data)
Ngoài độ chính xác, các mô hình cũng được đánh giá dựa trên các giá trị Precision, Recall, F1: (4.5) (4.6) (4.7) Trong đó:
: Số mẫu có bệnh được dự đoán đúng
: Số mẫu không có bệnh bị dự đoán sai
: Số mẫu có bệnh bị dự đoán sai
(4.8)