Commit 2660aa3b authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK_Fix Train+Test functions for WSAL

parent d5277a2d
......@@ -24,9 +24,7 @@ def train_model(model, train_input, train_target, lr, criterion, mini_batch_size
for p in model.parameters():
p -= lr * p.grad
print(output)
output_b = (output.view(-1)>0.5).float()
print(output_b)
for k in range(mini_batch_size):
if y[k] != output_b[k]:
nb_errors = nb_errors + 1
......@@ -40,7 +38,7 @@ def train_model(model, train_input, train_target, lr, criterion, mini_batch_size
return loss_list, acc_list
## Test Model, works for both Basic + Weight Sharing Networks
def test_model(model, test_input, test_target, criterion, mini_batch_size):
def test_model(model, test_input, test_target, mini_batch_size):
nb_errors = 0
for b in range(0, test_input.size(0), mini_batch_size):
output = model(test_input.narrow(0, b, mini_batch_size))
......@@ -182,7 +180,7 @@ def train_model_WSAL(model, train_input, train_target, train_classes, lr, criter
loss = 0.5*loss1 + 0.5*loss2 + 1.0*loss3
acc_loss = acc_loss + loss.item()
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-5)
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=0.9)
optimizer.zero_grad()
loss.backward()
optimizer.step()
......@@ -191,20 +189,28 @@ def train_model_WSAL(model, train_input, train_target, train_classes, lr, criter
for p in model.parameters():
p -= lr * p.grad
print(e, acc_loss)
output_b = (res.view(-1)>0.5).float()
for k in range(mini_batch_size):
if y[k] != output_b[k]:
nb_errors = nb_errors + 1
# print(e, acc_loss)
loss_list.append(acc_loss)
accuracy_e = 1 - nb_errors/train_input.size(0)
print(f'{e}, acc: {accuracy_e}')
# print(f'{e}, acc: {accuracy_e}')
acc_list.append(accuracy_e)
def test_model_WSAL(model, input, target, mini_batch_size):
return loss_list, acc_list
def test_model_WSAL(model, test_input, test_target, mini_batch_size):
nb_errors = 0
for b in range(0, input.size(0), mini_batch_size):
_, output = model(input.narrow(0, b, mini_batch_size))
for b in range(0, test_input.size(0), mini_batch_size):
_, output = model(test_input.narrow(0, b, mini_batch_size))
output_b = (output.view(-1)>0.5).float()
y = test_target.narrow(0, b, mini_batch_size)
for k in range(mini_batch_size):
if target[b+k] != output_b[k]:
if y[k] != output_b[k]:
nb_errors = nb_errors + 1
return nb_errors
accuracy_e = 1 - nb_errors/test_input.size(0)
return accuracy_e
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment