First, import the required libraries.
from __future__ import unicode_literals, print_function, division
from io import open
import random
import os
import numpy as np
import pickle
import time
import math
import glob
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
# Base seed for reproducible shuffling: trainIters adds the epoch number to it
# when constructing random.Random(epoch + random_seed_number).
random_seed_number: int = 1
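The model has five linear layers (layer1, layer2, layer3, layer4, layer5). A LeakyReLU activation follows each of layer2, layer3 and layer4. layer1 feeds directly into layer2 with no activation in between, and layer5 is the linear output layer.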
class simpleNN(nn.Module):
    """Fully connected regressor built from five ``nn.Linear`` layers.

    Layer widths are ``input -> hidden -> 2*hidden -> 2*hidden -> hidden -> output``.
    A shared LeakyReLU (negative slope 0.1) is applied after layer2, layer3
    and layer4. layer1 feeds straight into layer2, and layer5 is a plain
    linear output.

    Args:
        input: number of input features.
        output: number of output features.
        hidden: base width of the hidden layers.
    """

    def __init__(self, input, output, hidden):
        super().__init__()
        wide = hidden * 2
        # Registration order matters for state_dict keys / printing; keep it.
        self.layer1 = nn.Linear(input, hidden)
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.layer2 = nn.Linear(hidden, wide)
        self.layer3 = nn.Linear(wide, wide)
        self.layer4 = nn.Linear(wide, hidden)
        self.layer5 = nn.Linear(hidden, output)

    def forward(self, inputData):
        """Map ``inputData`` (last dimension ``input``) to outputs (last dimension ``output``)."""
        # NOTE: no activation between layer1 and layer2, so together they act
        # as a single affine map; kept to preserve the original architecture.
        h = self.layer2(self.layer1(inputData))
        for layer in (self.layer3, self.layer4):
            h = layer(self.leakyrelu(h))
        return self.layer5(self.leakyrelu(h))
The training and test functions are defined below:
trainIters()
: trains the network using input_data
(EMG data) and output_data
(finger-angle data). The Adam
optimizer and the MSELoss
loss function are used to train the model.
train()
: performs one training step; it is called from trainIters()
test()
# test(): evaluates the network during training or after training has finished.

def train(input_tensor, target_tensor, time_length, simplenetwork, simplenetwork_optimizer, criterion, tfr_prob_list, iter):
    """Run one optimisation step on a single batch and return its scaled loss.

    Args:
        input_tensor: batch of inputs. Its first axis must have size 1, and up
            to two leading size-1 axes are squeezed away before the forward
            pass. The original comment states ``(time_length, 1, bs, 4)`` with
            ``time_length == 1`` (TODO: confirm against ``dataloader``).
        target_tensor: matching targets, squeezed the same way.
        time_length: divisor applied to the returned loss value.
        simplenetwork: model to train (updated in place by the optimizer).
        simplenetwork_optimizer: optimizer over ``simplenetwork``'s parameters.
        criterion: loss function called as ``criterion(prediction, target)``.
        tfr_prob_list: unused; kept for call compatibility.
        iter: unused; kept for call compatibility.

    Returns:
        ``criterion(prediction, target).item() / time_length`` as a float.
    """
    # Only a single time step per sample is supported.
    assert input_tensor.shape[0] == 1
    inputs = input_tensor.squeeze(dim=0).squeeze(dim=0)
    targets = target_tensor.squeeze(dim=0).squeeze(dim=0)

    prediction = simplenetwork(inputs)
    simplenetwork_optimizer.zero_grad()
    batch_loss = criterion(prediction, targets)
    batch_loss.backward()
    simplenetwork_optimizer.step()
    return batch_loss.item() / time_length
def trainIters(input_data, output_data, input_data_eval, output_data_eval, time_length, simplenetwork, n_epochs, print_every, plot_every, learning_rate, batch_size):
    """Train ``simplenetwork`` with Adam + MSE loss and log progress to TensorBoard.

    Besides the training data passed in, a test set is loaded with
    ``dataprepare(test_path, test=True)``. It is then shifted by
    ``shiftLength`` samples: the inputs are truncated at the end and the
    targets advanced. Finally its targets are inverted (``1 - y``). Every
    ``print_every`` epochs the model is evaluated on the eval set and on this
    test set.

    Relies on module-level names defined elsewhere in the file/notebook:
    ``dataprepare``, ``test_path``, ``shiftLength``, ``dataloader``,
    ``writer`` (used via ``add_scalar``), ``timeSince``, ``test`` and
    ``random_seed_number``.

    Args:
        input_data: training inputs; the original comment gives ``(4, data_length)``.
        output_data: training targets; the original comment gives ``(14, data_length)``.
        input_data_eval: evaluation inputs passed to ``test``.
        output_data_eval: evaluation targets passed to ``test``.
        time_length: window length used when sampling windows (the original
            comments note it is 1).
        simplenetwork: the model to train (updated in place).
        n_epochs: number of passes over the training data.
        print_every: evaluate and print a summary every this many epochs.
        plot_every: unused; kept for call compatibility.
        learning_rate: Adam learning rate.
        batch_size: number of windows per training step.

    Raises:
        ValueError: if the training data cannot fill a single batch, or the
            test targets are too short for the configured shift.
    """
    num_windows = input_data.shape[1] - time_length
    num_iters = num_windows // batch_size
    if num_iters <= 0:
        # Previously this surfaced as a ZeroDivisionError after an empty epoch.
        raise ValueError(
            'not enough training data for one batch: %d windows available, batch_size=%d'
            % (num_windows, batch_size))

    # --- test set: drop the last `shiftLength` inputs, advance the targets ---
    test_input_data, test_output_data, _, _ = dataprepare(test_path, test=True)
    test_input_data = test_input_data[:, :test_input_data.shape[1] - shiftLength]
    n_keep = test_input_data.shape[1]
    if test_output_data.shape[1] < shiftLength + n_keep:
        raise ValueError(
            'test targets have %d samples, need at least %d for shiftLength=%d'
            % (test_output_data.shape[1], shiftLength + n_keep, shiftLength))
    # Vectorised form of "column idx <- column idx + shiftLength"; astype makes
    # a float64 copy, matching the np.zeros buffer the loop used to fill.
    test_output_data = test_output_data[:, shiftLength:shiftLength + n_keep].astype(np.float64)
    # Invert the test targets (0 <-> 1).
    test_output_data = 1 - test_output_data

    start = time.time()
    simplenetwork_optimizer = optim.Adam(simplenetwork.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # Print roughly three times per epoch; max(1, ...) avoids a modulo by zero
    # when an epoch has only a handful of iterations.
    log_interval = max(1, int(0.3 * num_windows // batch_size))

    for epoch in range(1, n_epochs + 1):
        print('========== epoch : %d ==========' % (epoch))
        simplenetwork.train()

        # Deterministic per-epoch shuffle of the window start positions.
        randomindex = list(range(num_windows))
        random.Random(epoch + random_seed_number).shuffle(randomindex)
        # Not used by train(), but still drawn so the global NumPy RNG stream
        # stays identical to earlier runs of this script.
        tfr_prob_list = np.random.random(num_iters)

        print_loss_total = 0
        for it in range(num_iters):
            input_tensor, target_tensor = dataloader(it, time_length, input_data, output_data, randomindex, batchsize=batch_size)
            loss = train(input_tensor, target_tensor, time_length, simplenetwork, simplenetwork_optimizer, criterion, tfr_prob_list, it)
            writer.add_scalar('Loss/iter', loss, (epoch - 1) * num_iters + it)
            print_loss_total += loss
            if it % log_interval == 0:
                print('iter : %d , loss : %.9f' % (it, loss))
        print_loss_avg = print_loss_total / num_iters
        writer.add_scalar('Loss/epoch', print_loss_avg, epoch)

        simplenetwork.eval()
        if epoch % print_every == 0:
            print('%s (%d %d%%) loss_avg : %.9f' % (timeSince(start, epoch / n_epochs),
                                                   epoch, epoch / n_epochs * 100, print_loss_avg))
            _, eval_loss_avg = test(input_data_eval, output_data_eval, time_length, simplenetwork)
            _, test_loss_avg = test(test_input_data, test_output_data, time_length, simplenetwork)
            writer.add_scalar('Loss/eval', eval_loss_avg, epoch)
            writer.add_scalar('Loss/test', test_loss_avg, epoch)
        print('==============================')
def test(input_data, output_data, time_length, simplenetwork):
    """Evaluate ``simplenetwork`` on a dataset without updating it.

    Samples come from ``testdataloader`` (defined elsewhere in this
    file/notebook) and are fed to the network one at a time under
    ``torch.no_grad()``. This function does not switch the network to eval
    mode; the caller is expected to do that.

    Args:
        input_data: model inputs passed to ``testdataloader``.
        output_data: targets passed to ``testdataloader``. They also fix the
            shape of the returned prediction array
            (``np.zeros_like(output_data)``).
        time_length: window length passed to ``testdataloader``.
        simplenetwork: model to evaluate.

    Returns:
        ``(predictions, loss)``:
            predictions: an array shaped like ``output_data`` whose column
                ``idx`` holds the prediction for sample ``idx``.
            loss: the mean ``MSELoss`` over all samples, as a tensor.

    Raises:
        ValueError: if ``testdataloader`` yields no samples, or fewer targets
            than inputs.
    """
    criterion = nn.MSELoss()
    input_tensor_list, target_tensor_list = testdataloader(time_length, input_data, output_data)
    n_samples = len(input_tensor_list)
    if n_samples == 0:
        # Previously this crashed with an unbound `idx` after the loop.
        raise ValueError('testdataloader returned no samples')
    if len(target_tensor_list) < n_samples:
        raise ValueError('testdataloader returned %d inputs but only %d targets'
                         % (n_samples, len(target_tensor_list)))

    predict_target_tensor = np.zeros_like(output_data)
    loss = 0
    with torch.no_grad():
        for idx, (input_tensor, target_tensor) in enumerate(zip(input_tensor_list, target_tensor_list)):
            assert input_tensor.shape[0] == 1
            input_tensor = input_tensor.squeeze(dim=0).squeeze(dim=0)
            target_tensor = target_tensor.squeeze(dim=0).squeeze(dim=0)
            # Transpose so the prediction can be stored as a column of
            # predict_target_tensor; the target is transposed to match.
            output = torch.transpose(simplenetwork(input_tensor), 0, 1)
            predict_target_tensor[:, idx] = output.cpu().numpy().squeeze()
            loss += criterion(output, torch.transpose(target_tensor, 0, 1))
    loss = loss / n_samples
    return predict_target_tensor, loss
Define some utility functions used to train the model:
dataprepare()
: prepares the data for training and the data for evaluation.
dataloader()
: loads a batch for each training iteration (called inside trainIters()).
synctime()
: plots the EMG data and the finger-angle data after aligning their start times.
testdataloader()
: loads data when testing the model (called inside test()). One of its arguments, time_length, is set to 1.