Commit 20efc7bd authored by Elif Ceylan

GROUPWORK_minor changes

parent 638a1f1a
# Internal
import torch
from optim import SGD
import modules as n
def train_model(model, train_input, train_target, criterion, lr, gamma, mini_batch_size, nb_epochs, optimizer=None):
    if type(model.mods[-1]).__name__ == 'Sigmoid':
@@ -18,8 +17,8 @@ def train_model(model, train_input, train_target, criterion, lr, gamma, mini_batch_size, nb_epochs, optimizer=None):
        for b in range(0, train_input.size(0), mini_batch_size):
            ## forward
            output = model.forward(train_input.narrow(0, b, mini_batch_size))
            y = train_target.narrow(0, b, mini_batch_size)
            ## loss
            loss = criterion.forward(output.squeeze(), y)
            acc_loss = acc_loss + loss.item()
@@ -41,14 +40,14 @@ def train_model(model, train_input, train_target, criterion, lr, gamma, mini_batch_size, nb_epochs, optimizer=None):
                optimizer.step()
            else:
                model.update_params(lr)
            # binarize the network output into hard 0/1 predictions
            output[output <= threshold] = 0
            output[output > threshold] = 1
            # if(e==(nb_epochs-1)):
            for k in range(mini_batch_size):
                if y[k] != output[k]:
                    nb_errors = nb_errors + 1
        accuracy_e = 1 - nb_errors / train_input.size(0)
        print(f'{e}, acc: {accuracy_e}')
        print(e, acc_loss / train_input.size(0))
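For context, the loop in this file walks the training set in mini-batches with narrow, thresholds the network output into hard 0/1 predictions, and counts mismatches against the targets. The snippet below is a minimal, self-contained sketch of just that bookkeeping; the toy tensors, the stand-in output, and the 0.5 threshold are assumptions made for illustration and are not taken from this repository.

import torch

# Toy stand-ins for the real data; sizes and the threshold are assumptions.
train_input = torch.randn(8, 2)
train_target = (train_input.sum(dim=1) > 0).float()
mini_batch_size = 4
threshold = 0.5

nb_errors = 0
for b in range(0, train_input.size(0), mini_batch_size):
    # narrow(0, b, n) selects rows b .. b+n-1, i.e. one mini-batch
    x = train_input.narrow(0, b, mini_batch_size)
    y = train_target.narrow(0, b, mini_batch_size)

    # Stand-in for model.forward(x): any score in [0, 1] works for the demo.
    output = torch.sigmoid(x.sum(dim=1))

    # Same binarisation as in train_model: map scores to hard 0/1 labels.
    output[output <= threshold] = 0
    output[output > threshold] = 1

    # Count per-sample mismatches exactly like the inner k-loop does.
    for k in range(mini_batch_size):
        if y[k] != output[k]:
            nb_errors = nb_errors + 1

print(1 - nb_errors / train_input.size(0))

Note that narrow(0, b, mini_batch_size) assumes train_input.size(0) is a multiple of mini_batch_size; the last batch would raise an error otherwise.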