Commit 1d4077bc authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK_refactoring

parent 92480fd3
import torch
import dlc_practical_prologue
\ No newline at end of file
import dlc_practical_prologue as prologue
from net import *
from train import *
from helper import *
# ## file that they will run
# ## Basic Model
# model1 = CNN_VGG(200)
# model = CNN_VGG(200)
mini_batch_size=100
# # # target = 0 if x1 > x2, target = 1 if x1 <= x2
# train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
# for _ in range(3):
# model = NN(200)
# train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
# # print(train_input[0])
# loss_train = train_model_basic(model, train_input, train_target, mini_batch_size, nb_epochs=19)
# errors = compute_nb_errors(model, train_input, train_target, mini_batch_size)
# errors_test, loss_test = compute_nb_errors_test(model, test_input, test_target, mini_batch_size, nb_epochs=19)
# print(f'accuracy of Basic = {100-(errors/10)}')
# print(f'accuracy of Basic, testing = {100-(errors_test/10)}')
# # plt.plot(loss_train)
# # plt.plot(loss_test)
# # plt.legend(['train', 'test'])
# # plt.show()
# model2 = NN_Classification(200)
# output = train_model_class(model2, train_input, train_classes, mini_batch_size, nb_epochs=100)
# # output1, output2 = train_model_class(model2, train_input, train_classes, mini_batch_size, nb_epochs=100)
# output_class = torch.cat((output1, output2), 1).detach()
# # print(output_class.size())
# # errors = compute_nb_errors_class(model2, train_input, train_classes, mini_batch_size)
# # print((1-errors/20000)*100)
# # print(output1.size(), output2.size())
# # print(torch.cat((output1, output2), 1).size())
# model3 = MLP_Comparer(200)
# train_model_comp(model3, output_class, train_target, mini_batch_size, nb_epochs=100)
# # error_comp = compute_nb_errors_comp(model3, output_class, train_target, mini_batch_size)
# # print(f'accuracy of comparing = {100-(error_comp/10)}')
for _ in range(5):
model_WS = CNN_WS(200)
train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
train_model_basic(model_WS, train_input, train_target, mini_batch_size, nb_epochs=25)
errors_WS = compute_nb_errors(model_WS, train_input, train_target, mini_batch_size)
errors_WS_test = compute_nb_errors(model_WS, test_input, test_target, mini_batch_size)
print(f'accuracy of Weight Sharing = {100-(errors_WS/10)}')
print(f'accuracy of Weight Sharing, testing = {100-(errors_WS_test/10)}')
# for _ in range(5):
# model_WSAL = CNN_WS_AL(200)
# train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(1000)
# train_model_WSAL(model_WSAL, train_input, train_target, train_classes, mini_batch_size, nb_epochs=40)
# errors_WSAL = compute_nb_errors_AL(model_WSAL, train_input, train_target, mini_batch_size)
# errors_WSAL_test = compute_nb_errors_AL(model_WSAL, test_input, test_target, mini_batch_size)
# print(f'accuracy of Weight Sharing +AL= {100-(errors_WSAL/10)}')
# print(f'accuracy of Weight Sharing, testing +AL= {100-(errors_WSAL_test/10)}')
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment