Commit 45c50793 authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK train.py

parent 627023ac
......@@ -41,8 +41,7 @@ def train_model(model, train_input, train_target, criterion, lr, gamma, mini_bat
else:
model.update_params(lr)
output[output<=threshold]=0
output[output>threshold]=1
output = (output>threshold).long()
# if(e==(nb_epochs-1)):
for k in range(mini_batch_size):
if y[k] != output[k]:
......@@ -66,8 +65,7 @@ def test_model(model, test_input, test_target, mini_batch_size):
nb_errors = 0
for b in range(0, test_input.size(0), mini_batch_size):
output = model.forward(test_input.narrow(0, b, mini_batch_size))
output[output<=threshold]=0
output[output>threshold]=1
output = (output>threshold).long()
y = test_target.narrow(0, b, mini_batch_size)
for k in range(mini_batch_size):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment