Commit c2efce6a authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK_fixed train+test functions

parent 92ee84d9
# Internal
# Internal
import torch
from optim import SGD
import modules as n
def train_model(model, train_input, train_target, criterion, lr, gamma, mini_batch_size=100, nb_epochs=100, optimizer=None):
def train_model(model, train_input, train_target, criterion, lr, gamma, mini_batch_size, nb_epochs, optimizer=None):
if type(model.mods[-1]).__name__ == 'Sigmoid':
threshold = 0.5
elif type(model.mods[-1]).__name__ == 'Tanh':
threshold = 0
loss_list = []
nb_errors = 0
acc_list = []
for e in range(nb_epochs):
acc_loss = 0
nb_errors = 0
......@@ -48,24 +48,32 @@ def train_model(model, train_input, train_target, criterion, lr, gamma, mini_bat
output[output>threshold]=1
if y[k] != output[k]:
nb_errors = nb_errors + 1
print(f'acc: {1 - nb_errors/train_input.size(0)}')
accuracy_e = 1 - acc_loss/train_input.size(0)
print(e, accuracy_e)
loss_list.append(acc_loss)
accuracy_e = 1 - nb_errors/train_input.size(0)
print(f'e, acc: {accuracy_e}')
print(e, acc_loss/train_input.size(0))
loss_list.append(acc_loss/train_input.size(0))
acc_list.append(accuracy_e)
return nb_errors, loss_list
return loss_list, acc_list #last of acc_list is the final accuracy
def test_model(model, test_input, test_target, mini_batch_size):
all_output = torch.zeros((1000, 1))
if type(model.mods[-1]).__name__ == 'Sigmoid':
threshold = 0.5
elif type(model.mods[-1]).__name__ == 'Tanh':
threshold = 0
nb_errors = 0
for b in range(0, test_input.size(0), mini_batch_size):
output = model.forward(test_input.narrow(0, b, mini_batch_size))
output[output<=0]=0
output[output>0]=1
output[output<=threshold]=0
output[output>threshold]=1
y = test_target.narrow(0, b, mini_batch_size)
for k in range(mini_batch_size):
if y[k] != output[k]:
nb_errors = nb_errors + 1
return nb_errors
\ No newline at end of file
all_output[b: b+mini_batch_size] = output
accuracy_e = 1 - nb_errors/test_input.size(0)
return accuracy_e, all_output.squeeze()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment