Tính toán bên trong mạng nơ-ron hồi quy

Một phần của tài liệu Nghiên cứu ứng dụng mạng nơ ron hồi quy trong nhận dạng tiếng nói (Trang 31 - 35)

Bên trong mỗi tế bào nơ-ron bao gồm các tính toán phức tạp dựa trên các giá trị như trọng số W, giá trị ngõ vào thời điểm hiện tại xt và giá trị đầu ra của tế bào phía trước đó at-1. Bên dưới là mô tả các thông số liên quan đến quá trình tính toán bên trong của một tế bào RNN.

Hình 2.22: Mô tả tính toán bên trong 1 tế bào RNN [18]

𝑎+,- = tanh(𝑊./𝑥+,- + 𝑊..𝑥+,)$-+ 𝑏.) 𝑦A+,- = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑊0.𝑥+,-+ 𝑏0)

Tại thời điểm t, 𝒚s𝒕 là giá trị dự đoán, at-1 là giá trị ngõ ra tính toán được ở nơ- ron trước đó, at là giá trị tính toán được sử dụng chia sẻ cho nơ-ron tính toán tiếp theo. Các trọng số tương ứng với giá trị ngõ vào trước đó là Waaa(t-1), giá trị ngõ vào

tại thời điểm t hiện tại là xt , trọng số tương ứng với giá trị ngõ vào hiện tại Waax(t)và giá trị biastại thời điểm t. Giá trị tính được tại tế bào thời điểm t là atđược sử dụng để tính cho tế bào ở thời điểm xt+1 tiếp theo. Hàm kích hoạt là hàm tanh thường sử (2.16) (2.17)

dụng tính toán trong mạng RNN. Mỗi đầu ra at (sử dụng hàm tanh) và giá trị dự đoán

yt (sử dụng hàm softmax) được xác định bởi công thức (2.16) và (2.17).

Huấn luyện mạng RNN gồm lan truyền tiến (forward) và lan truyền ngược (backward) để cập nhật các trọng số của mô hình.

v Lan truyền tiến trong mạng RNN

Với một tế bào RNN tính toán lan truyền tiến (rnn_cell_forward) được thực hiện như sau:

Bước 1:Tính các trạng thái ẩn bởi hàm kích hoạt tanh. 𝑎7 = tanh (𝑊88𝑎7!)+ 𝑊89𝑎7+ 𝑏8)

Bước 2: Sử dụng trạng thái ẩn mới, tính giá trị dự đoán 𝒚s(𝒕) bởi hàm kích hoạt

softmax:

𝑦y7 = softmax(𝑊:8𝑎7+ 𝑏:) Bước 3: Lưu trữ trạng thái tạm thời

𝑐𝑎𝑐ℎ𝑒 = {𝑎7, 𝑎7!) , 𝑥7, 𝑝𝑎𝑟𝑎𝑚𝑒𝑡𝑒𝑟𝑠} Bước 4: Trả ra bộ giá trị gồm: 𝑎7, 𝑦y7, 𝑣à cache.

Kiến trúc mạng RNN là một sự lặp lại của một tế bào, các tế bào kết nối với nhau hình thành nên mạng RNN, tổng quan mạng RNN có dạng như sau:

Hình 2.23: Mô tả tính toán liên kết giữa các tế bào RNN [18]

Khi đó, việc tính toán lan truyền tiến toàn bộ mạng RNN được thực hiện bởi các bước sau:

(2.18)

(2.19)

Bước 1: Khởi tại vec-tơ zero cho ngõ ra at, vec-tơ này được sử dụng lưu trữ các trạng thái được tính toán của mạng RNN.

Bước 2: Khởi tạo trạng thái ẩn phía trước nơ-ron đầu tiên là a0.

Bước 3: Bắt đầu lặp lại các bước, với sự tăng lên của thời điểm t: Cập nhập trạng thái ẩn tiếp theo và cache bởi gọi hàm tính toán trên một tế bào ở trên (rnn_cell_forward); Lưu trữ trạng thái ẩn tiếp theo trên a (thời điểm thứ t ); Lưu trữ giá trị dự đoán 𝒚s; Thêm cache tới danh sách của caches đã có trước đó.

Bước 4: Trả ra 𝒂, 𝒚s, 𝑣à 𝑑𝑎𝑛ℎ 𝑠á𝑐ℎ 𝒄𝒂𝒄𝒉𝒆.

v Lan truyền ngược liên hồi

Huấn luyện mạng nơ-ron hồi quy RNN sử dụng giải thuật lan truyền ngược liên hồi (BPTT - Backpropagation Through Time) bởi vì đạo hàm tại mỗi đầu ra phụ thuộc không chỉ vào các tính toán tại bước hiện tại, mà còn phụ thuộc vào các bước đã tính trước đó, và các tham số trong mạng RNN được sử dụng chung cho tất cả các bước trong mạng.

Xem xét công thức tính ngõ ra tại mỗi tế bào RNN: 𝑠7 = tanh(𝑈𝑥7+ 𝑊𝑠7!)) 𝑦s = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑉𝑠7 7)

Với 𝑦7 là giá trị đích (target) ở bước t, và 𝑦s7 là giá trị dự đoán. Định nghĩa hàm mất mát (hay hàm lỗi) dạng cross entropy như sau:

𝐸7(𝑦7, 𝑦s ) = − 𝑦7 7log 𝑦s 7 𝐸(𝑦, 𝑦s ) = S 𝐸7 7(𝑦7, 𝑦s ) 7

7

= − S 𝑦7 𝑙𝑜𝑔( 𝑦s ) 7

7

Mỗi chuỗi đầy đủ (một câu) là một mẫu. Khi đó, tổng số lỗi chính là tổng của tất cả các lỗi ở mỗi bước (mỗi từ) – hình 2.24.

(2.21)

(2.23) (2.22)

(2.24) (2.25)

Hình 2.24: Mô tả lỗi của toàn mạng RNN [18]

Mục tiêu là tính đạo hàm của hàm lỗi với tham số U, V, Wtương ứng bởi áp dụng phương pháp SGD (Stochastic gradient descent – Gradient descent ngẫu nhiên). Tương tự như việc cộng tổng các lỗi, ta cũng sẽ cộng tổng các đạo hàm tại mỗi bước cho mỗi mẫu huấn luyện:

𝜕𝐸

𝜕𝑊 = S 𝜕𝐸7 𝜕𝑊

7

Áp dụng quy tắt vi phân (Chain Rule) để tính lan truyền ngược lỗi: 𝜕𝐸; 𝜕𝑉 = 𝜕𝐸; 𝜕𝑦y; 𝜕𝑦•; 𝜕𝑉 = 𝜕𝐸; 𝜕𝑦y; 𝜕𝑦•; 𝜕𝑧; 𝜕𝑧; 𝜕𝑉 = (𝑦• − 𝑦; ;) ⨂ 𝑠;

Trong đó, ⨂ là phép nhân 2 vectơ, 𝑧; = 𝑉<!. Qua công thức này, 1213! chỉ còn phụ thuộc vào 𝐲• 𝟑 , 𝑦; và 𝑠; . Nhưng với W và U thì được tính như sau:

𝜕𝐸; 𝜕𝑊 = 𝜕𝐸; 𝜕𝑦y; 𝜕𝑦y; 𝜕𝑠; 𝜕𝑠; 𝜕𝑊

Với, 𝑠; = tanh (𝑈𝑥7 + 𝑊𝑠#) phụ thuộc vào 𝑠#và 𝑠# lại phụ thuộc vào W, 𝑠# không thể được xem là hằng số để tính toán với V được. Vì thế biến đổi tiếp như sau:

𝜕𝐸; 𝜕𝑊 = 𝜕𝐸; 𝜕𝑦y; 𝜕𝑦y; 𝜕𝑠; 𝜕𝑠; 𝜕𝑠% 𝜕𝑠% 𝜕𝑊 (2.26) (2.27) (2.28) (2.29) (2.30) (2.31)

Như vậy, với W, phải lan truyền từ t3 về t0 bằng cách phải cộng tất cả các đầu ra ở các bước trước.

Hình 2.25: Mô tả tính lỗi lan truyền trong mạng RNN [18]

Điểm khác với lan truyền ngược truyền thống là cộng tổng các đạo hàm của

W tại mỗi bước thời gian. Tương tự lan truyền ngược dạng truyền thống, vec-tơ 𝜹

khi lan truyền ngược như sau.

𝜹𝒙(𝟑) = 𝜕𝐸; 𝜕𝑧# = 𝜕𝐸; 𝜕𝑠; 𝜕𝑠; 𝜕𝑠# 𝜕𝑠# 𝜕𝑧# Với, 𝑧# = 𝑈𝑥#+ 𝑊𝑠#

Các bước tiếp theo có thể tính toán tương tự như lan truyền ngược dạng truyền thống.

Một phần của tài liệu Nghiên cứu ứng dụng mạng nơ ron hồi quy trong nhận dạng tiếng nói (Trang 31 - 35)

Tải bản đầy đủ (PDF)

(74 trang)