First, import the required libraries.
from __future__ import unicode_literals, print_function, division
from io import open
import random
import os
import numpy as np
import pickle
import time
import math
import glob
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torchsummary import summary
The model is composed of an encoder and a decoder, each built around a single gated recurrent unit (GRU).
class EncoderRNN(nn.Module):
    """Encoder: a single GRU layer mapping EMG feature frames to a hidden state."""

    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        # One-layer GRU, sequence-first layout: inputs are (seq, batch, features).
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, input, hidden):
        # input:  (1, bs, input_size)
        # hidden: (1, bs, hidden_size)
        output, hidden = self.gru(input, hidden)
        # output: (1, bs, hidden_size); hidden: (1, bs, hidden_size)
        return output, hidden
class DecoderRNN(nn.Module):
    """Decoder: embeds the previous angle frame, runs one GRU step, and
    projects the hidden state back to the angle space."""

    def __init__(self, hidden_size, output_size, dropout_p):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout_p)
        self.embedding = nn.Linear(output_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # input: (1, bs, output_size); hidden: (1, bs, hidden_size)
        projected = self.dropout(self.embedding(input))  # (1, bs, hidden_size)
        gru_out, hidden = self.gru(projected, hidden)
        # Drop the length-1 time axis before the output projection:
        # returns (bs, output_size) and (1, bs, hidden_size).
        return self.out(gru_out[0]), hidden

    def initHidden(self):
        # NOTE(review): relies on a module-level `device` defined elsewhere.
        return torch.zeros(1, 1, self.hidden_size, device=device)
Train and test functions are defined below:
trainIters()
: trains the network using input_tensor (EMG data) and output_tensor (finger-angle data). The Adam optimizer and the MSELoss criterion are used for training.
train()
: performs one training step; called inside trainIters().
test()
# : tests the network in the middle of training or after the training.
def train(input_tensor, target_tensor, time_length, encoder, decoder,
          encoder_optimizer, decoder_optimizer, criterion, tfr_prob_list, iter):
    """Run one optimization step over a single mini-batch sequence.

    input_tensor : (time_length, 1, bs, 4)   EMG features
    target_tensor: (time_length, 1, bs, 14)  finger-angle targets
    tfr_prob_list: per-iteration random draws (made once per epoch) compared
                   against the global `teacher_forcing_ratio`.
    Returns the scalar batch loss averaged over the time dimension.
    """
    # NOTE(review): relies on module-level `device` and `teacher_forcing_ratio`.
    encoder_hidden = torch.zeros(1, input_tensor.shape[2], encoder.hidden_size,
                                 device=device)

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    assert time_length == input_tensor.shape[0]
    assert time_length == target_tensor.shape[0]
    # BUGFIX: the original compared input_tensor.shape[2] with itself (always
    # true); the intent is that input and target share the same batch size.
    assert input_tensor.shape[2] == target_tensor.shape[2]

    target_size = target_tensor.shape[3]
    # Kept for parity with the attention variant of this model; unused here.
    encoder_outputs = torch.zeros(time_length, input_tensor.shape[2],
                                  encoder.hidden_size, device=device)
    # The first decoder input is an all-zero "start" frame.
    decoder_input = torch.zeros(1, input_tensor.shape[2], target_size, device=device)

    loss = 0
    for ei in range(time_length):
        # input_tensor[ei]: (1, bs, 4) -> encoder_output/hidden: (1, bs, hs)
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0]

    # Hand the final encoder state to the decoder: (1, bs, hs).
    decoder_hidden = encoder_hidden

    use_teacher_forcing = tfr_prob_list[iter] < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: feed the ground-truth frame as the next input.
        for di in range(time_length):
            # decoder_output: (bs, target_size); decoder_hidden: (1, bs, hs)
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di].squeeze(0))
            decoder_input = target_tensor[di]
    else:
        # Free running: feed the model's own prediction back in.
        # NOTE(review): the original comment says "detach from history" but no
        # .detach() is applied, so gradients flow through the whole rollout —
        # kept as-is to preserve training behavior.
        for di in range(time_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            decoder_input = decoder_output.unsqueeze(0)
            loss += criterion(decoder_output, target_tensor[di].squeeze(0))

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / time_length
def trainIters(input_data, output_data, input_data_eval, output_data_eval,
               time_length, encoder, decoder, n_epochs, print_every, plot_every,
               learning_rate, batch_size):
    """Full training loop: trains encoder/decoder with Adam + MSELoss,
    logs to TensorBoard, and checkpoints the best model on a held-out test set.

    input_data : (4, data_length)   EMG channels
    output_data: (14, data_length)  finger angles
    """
    # NOTE(review): relies on module-level `writer`, `model_path`, `save_path`,
    # `name`, and the sibling functions dataprepare/dataloader/train/test.
    test_path = '/home/hyuninlee/PycharmProjects/xcorps/seq2seq_attentionmodel/testData/lhi'
    # BUGFIX: start from +inf rather than the magic number 100 so the first
    # evaluation always becomes the incumbent best model.
    best_test_mse = float('inf')
    test_input_data, test_output_data, _, _ = dataprepare(test_path, test=True)
    # Angles are stored inverted in the test files; map 0 <-> 1.
    test_output_data = 1 - test_output_data

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # reset every print_every epochs
    plot_loss_total = 0   # reset every plot_every epochs

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(1, n_epochs + 1):
        print('========== epoch : %d ==========' % (epoch))
        # Deterministic per-epoch shuffle of the window start indices.
        randomindex = [x for x in range(input_data.shape[1] - time_length)]
        random.Random(epoch).shuffle(randomindex)
        num_iters = (input_data.shape[1] - time_length) // batch_size
        # One teacher-forcing coin flip per iteration, drawn up front.
        tfr_prob_list = np.random.random(num_iters)
        for iter in range(num_iters):
            input_tensor, target_tensor = dataloader(iter, time_length, input_data,
                                                     output_data, randomindex,
                                                     batchsize=batch_size)
            # NOTE(review): reseeding numpy every iteration looks intentional
            # for reproducibility but pins all later np.random draws — confirm.
            np.random.seed(epoch)
            loss = train(input_tensor, target_tensor, time_length, encoder, decoder,
                         encoder_optimizer, decoder_optimizer, criterion,
                         tfr_prob_list, iter)
            writer.add_scalar('Loss/iter', loss, (epoch - 1) * num_iters + iter)
            print_loss_total += loss
            plot_loss_total += loss
            if iter % int(0.3 * (input_data.shape[1] - time_length) // batch_size) == 0:
                print('iter : %d , loss : %.9f' % (iter, loss))
        writer.add_scalar('Loss/epoch', print_loss_total, epoch)

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / (print_every * num_iters)
            print_loss_total = 0
            print('%s (%d %d%%) loss_avg : %.9f' % (timeSince(start, epoch / n_epochs),
                                                    epoch, epoch / n_epochs * 100,
                                                    print_loss_avg))
            _, eval_loss_avg = test(input_data_eval, output_data_eval, time_length,
                                    encoder, decoder)
            writer.add_scalar('Loss/eval', eval_loss_avg, epoch)

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

        # BUGFIX: the original evaluated and saved the globals `encoder1` /
        # `attn_decoder1` instead of the models actually being trained here.
        test_pred_target, _ = test(test_input_data, test_output_data, time_length,
                                   encoder, decoder)
        test_mse = gettestACC(test_pred_target, test_output_data)
        print("current test mse : %.3f" % (test_mse))
        print("best test mse : %.3f" % (best_test_mse))
        writer.add_scalar('bestTestAcc/epoch', best_test_mse, epoch)
        print('==============================')
        if test_mse < best_test_mse:
            best_test_mse = test_mse
            print("new test mse : %.3f" % (best_test_mse))
            torch.save(encoder.state_dict(), model_path + name + '_encoder')
            torch.save(decoder.state_dict(), model_path + name + '_attention_decoder')
            print('savemodel!')
            np.save(save_path + name + '/test_pred_target.npy', test_pred_target)
            print('save target npy')
    # showPlot(plot_losses)
def test(input_data, output_data, time_length, encoder, decoder):
    """Evaluate the seq2seq model on consecutive non-overlapping windows.

    input_data : (4, data_length), output_data : (14, data_length)
    Returns (predicted array shaped like output_data, mean per-window loss).
    """
    loss_list = []
    criterion = nn.MSELoss()
    input_tensor_list, target_tensor_list = testdataloader(time_length, input_data,
                                                           output_data)
    predict_target_tensor = np.zeros_like(output_data)
    with torch.no_grad():
        for idx, (input_tensor, target_tensor) in enumerate(
                zip(input_tensor_list, target_tensor_list)):
            encoder_hidden = torch.zeros(1, input_tensor.shape[2],
                                         encoder.hidden_size, device=device)
            assert time_length == input_tensor.shape[0]
            assert time_length == target_tensor.shape[0]
            # BUGFIX: the original compared input_tensor.shape[2] with itself;
            # input and target must share the same batch size.
            assert input_tensor.shape[2] == target_tensor.shape[2]
            target_size = target_tensor.shape[3]
            decoder_input = torch.zeros(1, input_tensor.shape[2], target_size,
                                        device=device)
            # BUGFIX: reset the loss per window; the original accumulated across
            # windows, so loss_list held running totals and loss_avg was inflated.
            loss = 0
            for ei in range(time_length):
                encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                         encoder_hidden)
            decoder_hidden = encoder_hidden
            for di in range(time_length):
                # BUGFIX: the original assigned the new state to `decoder_hidden0`,
                # so the decoder was re-fed the encoder state at every step
                # instead of its own recurrent state.
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                predict_target_tensor[:, idx * time_length + di] = \
                    np.transpose(decoder_output.cpu().numpy()).squeeze()
                decoder_input = decoder_output.unsqueeze(0)
                loss += criterion(decoder_output, target_tensor[di].squeeze(0))
            loss_list.append(loss.item() / time_length)
    loss_avg = sum(loss_list) / len(loss_list)
    print("eval loss : %.9f " % (loss_avg))
    return predict_target_tensor, loss_avg
class Data():
    """Pickle container pairing EMG inputs (x_data) with angle targets (y_data)."""

    def __init__(self, x_data, y_data):
        # x_data: EMG samples; y_data: matching finger-angle samples.
        self.y_data = y_data
        self.x_data = x_data
def mse(y, t):
    """Root of half the mean squared error between y and t.

    NOTE: despite the name this is sqrt(0.5 * MSE), not the plain MSE.
    """
    diff = y - t
    return np.sqrt(0.5 * np.mean(diff * diff))
def gettestACC(y, t):
    """Average the per-channel mse() over all channels (rows) of y and t.

    y, t: (n_channels, n_samples) arrays. Generalized from the original
    hard-coded 14 channels; identical for 14-row inputs.
    """
    n_channels = y.shape[0]
    total = 0.0  # renamed from `sum`, which shadowed the builtin
    for idx in range(n_channels):
        total += mse(y[idx, :], t[idx, :])
    return total / n_channels
def synctime(inputdata, outputdata, starttime, videofs, emgfs):
    """Plot the first `starttime` seconds of EMG and angle data to eyeball
    whether the two streams are synchronized.

    videofs / emgfs are the sampling rates (Hz) of the video-derived angles
    and the EMG stream respectively.
    """
    import matplotlib.pyplot as plt
    emg_end = starttime * emgfs
    video_end = starttime * videofs
    plt.plot(inputdata[:emg_end, 0], inputdata[:emg_end, 1])
    # Row 14 is presumably a timestamp channel — TODO confirm against the data.
    plt.plot(outputdata[14, :video_end], outputdata[0, :video_end])
    plt.show()
def testdataloader(timelength, inputdata, outputdata):
    """Cut the input/output streams into consecutive non-overlapping windows.

    inputdata : (4, data_length); outputdata : (14, data_length)
    Returns two lists of float32 tensors shaped (timelength, 1, 1, channels).
    """
    # NOTE(review): relies on a module-level `device` defined elsewhere.
    inputTensorList = []
    targetTensorList = []
    n_samples = inputdata.shape[1]
    start = 0
    while start + timelength <= n_samples:
        window_in = np.transpose(inputdata[:, start:start + timelength])
        window_out = np.transpose(outputdata[:, start:start + timelength])
        in_t = torch.tensor(window_in, dtype=torch.float32, device=device)
        out_t = torch.tensor(window_out, dtype=torch.float32, device=device)
        # Add singleton seq/batch axes: (timelength, 1, 1, channels).
        inputTensorList.append(in_t.unsqueeze(1).unsqueeze(1))
        targetTensorList.append(out_t.unsqueeze(1).unsqueeze(1))
        start += timelength
    return inputTensorList, targetTensorList
def dataloader(iter, timelength, inputdata, outputdata, randomindex, batchsize):
    """Build one shuffled mini-batch of length-`timelength` windows.

    inputdata  : (4, data_length)   EMG channels
    outputdata : (14, data_length)  finger angles
    randomindex: shuffled window start positions for this epoch.
    Returns (input (timelength,1,bs,4), target (timelength,1,bs,14)); the final
    batch of an epoch may be smaller than `batchsize`.
    """
    assert inputdata.shape[0] == 4
    assert outputdata.shape[0] == 14
    assert inputdata.shape[1] == outputdata.shape[1]

    # BUGFIX: the original read the global `input_data` here instead of the
    # `inputdata` parameter, silently coupling this function to outer state.
    if batchsize * (iter + 1) > inputdata.shape[1] - timelength:
        # BUGFIX: the original used indexend = -1, whose slice excludes the
        # last index and silently dropped one sample from the final batch.
        indexend = len(randomindex)
    else:
        indexend = batchsize * (iter + 1)
    batch_starts = randomindex[batchsize * iter:indexend]

    input_tensor_group, target_tensor_group = None, None
    for idx in range(timelength):
        cols = [x + idx for x in batch_starts]
        input_tensor = torch.tensor(np.transpose(inputdata[:, cols]),
                                    dtype=torch.float32, device=device)  # (bs, 4)
        input_tensor = torch.unsqueeze(torch.unsqueeze(input_tensor, 0), 0)  # (1,1,bs,4)
        target_tensor = torch.tensor(np.transpose(outputdata[:, cols]),
                                     dtype=torch.float32, device=device)  # (bs, 14)
        target_tensor = torch.unsqueeze(torch.unsqueeze(target_tensor, 0), 0)  # (1,1,bs,14)
        if idx == 0:
            input_tensor_group = input_tensor
            target_tensor_group = target_tensor
        else:
            # Stack along the time axis: (timelength, 1, bs, channels).
            input_tensor_group = torch.cat((input_tensor_group, input_tensor), dim=0)
            target_tensor_group = torch.cat((target_tensor_group, target_tensor), dim=0)

    # Full (non-truncated) batches must match the requested batch size.
    if indexend == batchsize * (iter + 1):
        assert input_tensor.shape[2] == batchsize
    return input_tensor_group, target_tensor_group
def dataprepare(datapath, doesEval=False, test=False):
    """Load pickled Data objects from `datapath` and concatenate experiments.

    datapath: directory of *.pkl files, each holding a Data object whose
              x_data (EMG) and y_data (angles) are lists of arrays.
    doesEval: if True, hold out the last 5 shuffled experiments as an eval set;
              otherwise fold them back into the training arrays.
    test:     if True, load only the first file found.

    Returns (emgarray, anglearray, emgarray_eval, anglearray_eval); the eval
    arrays are None when doesEval is False.
    """
    emglist, anglelist = None, None
    for filepath in glob.glob(os.path.join(datapath, '*.pkl')):
        print(filepath)
        with open(filepath, 'rb') as f:
            # NOTE(review): pickle.load is unsafe on untrusted files; assumed
            # safe here because the data directory is produced locally.
            data = pickle.load(f)
        # BUGFIX/idiom: compare to None with `is`, not `==` — the original
        # `emglist == None` misbehaves if x_data ever becomes a numpy array.
        if emglist is None and anglelist is None:
            emglist = data.x_data
            anglelist = data.y_data
        else:
            emglist.extend(data.x_data)
            anglelist.extend(data.y_data)
        if test:
            print("test on a single experiment data")
            break
    assert len(emglist) == len(anglelist)
    # Shuffling both lists with a freshly-seeded generator keeps pairs aligned.
    random.Random(0).shuffle(emglist)
    random.Random(0).shuffle(anglelist)

    emgarray, anglearray = None, None
    emgarray_eval, anglearray_eval = None, None
    for idx, (emg, angle) in enumerate(zip(emglist, anglelist)):
        if idx == 0:
            emgarray = emg
            anglearray = angle
        elif idx >= len(emglist) - 5:
            # The last 5 experiments form the evaluation split.
            if idx == len(emglist) - 5:
                emgarray_eval = emg
                anglearray_eval = angle
            else:
                emgarray_eval = np.concatenate((emgarray_eval, emg), axis=1)
                anglearray_eval = np.concatenate((anglearray_eval, angle), axis=1)
        else:
            emgarray = np.concatenate((emgarray, emg), axis=1)
            anglearray = np.concatenate((anglearray, angle), axis=1)

    if not doesEval:
        # No separate eval split requested: merge the held-out block back in.
        emgarray = np.concatenate((emgarray_eval, emgarray), axis=1)
        anglearray = np.concatenate((anglearray_eval, anglearray), axis=1)
        emgarray_eval = None
        anglearray_eval = None
    return emgarray, anglearray, emgarray_eval, anglearray_eval
def asMinutes(s):
    """Format a duration of `s` seconds as 'Xm Ys'."""
    minutes, seconds = divmod(s, 60)
    return '%dm %ds' % (minutes, seconds)
def timeSince(since, percent):
    """Return 'elapsed (- estimated remaining)' given fraction `percent` done.

    since:   a time.time() timestamp taken at the start of training.
    percent: fraction of the work completed, in (0, 1].
    """
    elapsed = time.time() - since
    total_estimate = elapsed / percent
    remaining = total_estimate - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))
def showPlot(points):
    """Display `points` (e.g. per-epoch losses) as a simple line plot."""
    fig, ax = plt.subplots()
    ax.plot(points)
    plt.show()