Qua mục 3.6.2 đã trình bày kiến trúc và lý do LSTM có khả năng tránh vanishing gradient, trong mục này, quá trình huấn luyện của mạng LSTM được mô tả như sau:
- Hàm chi phí (cost function):
Cost function được sử dụng là một hàm L2-norm có công thức như sau:
𝐸(𝑥, 𝑥̂) = (𝑥 − 𝑥̂)
2
2
Hàm L2-norm cost function là tổng bình phương của của giá trị dự đoán và giá trị nhãn.
Đạo hàm của cost function là:
𝜕𝑥𝐸(𝑥, 𝑥̂) = 𝑥 − 𝑥̂
Quá trình forward trong mạng LSTM được biểu diễn qua hình ảnh sau, các thông tin về các gate: output, hidden, forget,.. đã được lược bỏ để dễ quan sát hơn.
Vào mỗi thời điểm T, một cell state 𝑐𝑇nhận giá trị 𝑥𝑇và hidden state
ℎ𝑇−1 và đảm nhận tính toán hidden state ℎ𝑇 và cell state tiếp theo 𝑐𝑇+1. Hidden state ℎ𝑇 lúc này được đưa qua một (activation) layers để tính toán giá trị cost function 𝐶𝑇.
Quá trình Backpropagation trong mạng LSTM được biểu diễn qua hình ảnh sau. Dễ dàng nhận thấy, toàn bộ qúa trình này tương tự với quá trình Forward, tuy nhiên toàn bộ các dấu mũi tên đã được đảo ngược.
Tại mỗi thời điểm T, cell state 𝑐𝑇sẽ nhận giá trị tích luỹ của đạo hàm cost function của hidden state ở thời điểm T+1 với giá trị của output state T+1. Từ đó, các giá trị đạo hàm của các gate thuộc cell state T sẽ được tính toán lại, dựa theo các công thức như sau. Đầu ra của cell state 𝑐𝑇 sẽ là một giá trị 𝑇−1
và sẽ được lặp lại như giá trị đầu vào của cell state 𝑐𝑇−1 tiếp theo. Giá trị của các gate được cập nhập theo các công thức sau:
Sau tính toán, các trọng số W, U, b được tính toán bằng cách tính tổng thay đổi: 𝛿𝑊 = ∑ 𝛿𝑔𝑎𝑡𝑒𝑠𝑡⨂ 𝑥𝑡 𝑇 𝑡=0 𝛿𝑈 = ∑ 𝛿𝑔𝑎𝑡𝑒𝑠𝑡+1⨂ 𝑜𝑢𝑡𝑡 𝑇−1 𝑡=0 𝛿𝑏 = ∑ 𝛿𝑔𝑎𝑡𝑒𝑠𝑡+1 𝑇 𝑡=0
Sau cùng, các trọng số được cập nhật thông qua hàm Stochastic Gradient Descent (SGD) với trọng số learning rate :
𝑊𝑛𝑒𝑤 = 𝑊𝑜𝑙𝑑 − 𝛿𝑊𝑜𝑙𝑑