Commit 0fe15956 authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK_Fix testing

parent 2660aa3b
......@@ -14,16 +14,72 @@ from torch import nn
mini_batch_size=100
lr = 1e-4
for _ in range(10):
model = NN(200)
acc_train_NN_list, acc_test_NN_list, acc_train_CNN_list, acc_test_CNN_list = [], [], [], []
acc_train_WS_NN_list, acc_test_WS_NN_list, acc_train_WS_CNN_list, acc_test_WS_CNN_list = [], [], [], []
acc_train_WSAL_NN_list, acc_test_WSAL_NN_list, acc_train_WSAL_CNN_list, acc_test_WSAL_CNN_list = [], [], [], []
for _ in range(1):
# basic models
model_basic_NN = NN(200)
model_basic_CNN = CNN_VGG(200)
# models with weight sharing
model_WS_NN = NN_WS(200)
model_WS_CNN = CNN_WS(200)
# models with weight sharing and auxiliary loss
model_WSAL_NN = NN_WS_AL(200)
model_WSAL_CNN = CNN_WS_AL(200)
train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
# print(train_input[0])
loss_train, acc_train = train_model(model, train_input, train_target, lr, nn.BCELoss(), mini_batch_size, nb_epochs=25)
acc_test = test_model(model, test_input, test_target, nn.BCELoss(), mini_batch_size)
print(f' Basic Network train_acc = {acc_train[-1]}')
print(f' Basic Network test_acc = {acc_test}')
# # train + test basic models (NN+CNN)
# #NN
# loss_train_NN, acc_train_NN = train_model(model_basic_NN, train_input, train_target, lr, nn.BCELoss(), mini_batch_size, nb_epochs=25)
# acc_test_NN = test_model(model_basic_NN, test_input, test_target, mini_batch_size)
# acc_train_NN_list.append(acc_train_NN[-1])
# acc_test_NN_list.append(acc_test_NN)
# print(f' NN: Basic Network train_acc = {acc_train_NN[-1]}')
# print(f' NN: Basic Network test_acc = {acc_test_NN}')
# #CNN
# loss_train_CNN, acc_train_CNN = train_model(model_basic_CNN, train_input, train_target, lr, nn.BCELoss(), mini_batch_size, nb_epochs=25)
# acc_test_CNN = test_model(model_basic_CNN, test_input, test_target, mini_batch_size)
# acc_train_CNN_list.append(acc_train_CNN[-1])
# acc_test_CNN_list.append(acc_test_CNN)
# print(f' CNN: Basic Network train_acc = {acc_train_CNN[-1]}')
# print(f' CNN: Basic Network test_acc = {acc_test_CNN}')
# # train + test Weight Sharing models (NN+CNN)
# #NN
# loss_train_WS_NN, acc_train_WS_NN = train_model(model_WS_NN, train_input, train_target, lr, nn.BCELoss(), mini_batch_size, nb_epochs=25)
# acc_test_WS_NN = test_model(model_WS_NN, test_input, test_target, mini_batch_size)
# acc_train_WS_NN_list.append(acc_train_WS_NN[-1])
# acc_test_WS_NN_list.append(acc_test_WS_NN)
# print(f' NN: Weight Sharing Network train_acc = {acc_train_WS_NN[-1]}')
# print(f' NN: Weight Sharing Network test_acc = {acc_test_WS_NN}')
# #CNN
# loss_train_WS_CNN, acc_train_WS_CNN = train_model(model_WS_CNN, train_input, train_target, lr, nn.BCELoss(), mini_batch_size, nb_epochs=25)
# acc_test_WS_CNN = test_model(model_WS_CNN, test_input, test_target, mini_batch_size)
# acc_train_WS_CNN_list.append(acc_train_WS_CNN[-1])
# acc_test_WS_CNN_list.append(acc_test_WS_CNN)
# print(f' CNN: Weight Sharing Network train_acc = {acc_train_WS_CNN[-1]}')
# print(f' CNN: Weight Sharing Network test_acc = {acc_test_WS_CNN}')
# train + test Weight Sharing+Auxiliary Loss models (NN+CNN)
#NN
loss_train_WSAL_NN, acc_train_WSAL_NN = train_model_WSAL(model_WSAL_NN, train_input, train_target, train_classes, lr, nn.CrossEntropyLoss(), nn.BCELoss(), mini_batch_size, nb_epochs=40)
acc_test_WSAL_NN = test_model_WSAL(model_WSAL_NN, test_input, test_target, mini_batch_size)
acc_train_WSAL_NN_list.append(acc_train_WSAL_NN[-1])
acc_test_WSAL_NN_list.append(acc_test_WSAL_NN)
print(f' NN: Weight Sharing + Auxiliary Loss Network train_acc = {acc_train_WSAL_NN[-1]}')
print(f' NN: Weight Sharing + Auxiliary Loss Network test_acc = {acc_test_WSAL_NN}')
#CNN
loss_train_WSAL_CNN, acc_train_WSAL_CNN = train_model_WSAL(model_WSAL_CNN, train_input, train_target, train_classes, lr, nn.CrossEntropyLoss(), nn.BCELoss(), mini_batch_size, nb_epochs=40)
acc_test_WSAL_CNN = test_model_WSAL(model_WSAL_CNN, test_input, test_target, mini_batch_size)
acc_train_WSAL_CNN_list.append(acc_train_WSAL_CNN[-1])
acc_test_WSAL_CNN_list.append(acc_test_WSAL_CNN)
print(f' CNN: Weight Sharing + Auxiliary Loss Network train_acc = {acc_train_WSAL_CNN[-1]}')
print(f' CNN: Weight Sharing + Auxiliary Loss Network test_acc = {acc_test_WSAL_CNN}')
# model2 = NN_Classification(200)
# output = train_model_class(model2, train_input, train_classes, mini_batch_size, nb_epochs=100)
......@@ -51,7 +107,7 @@ for _ in range(10):
# print(f' test_acc = {acc_test}')
# for _ in range(5):
# for _ in range(1):
# model_WSAL = CNN_WS_AL(200)
# train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
# train_model_WSAL(model_WSAL, train_input, train_target, train_classes, mini_batch_size, nb_epochs=40)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment