Commit 638a1f1a authored by Elif Ceylan's avatar Elif Ceylan
Browse files

GROUPWORK_minor changes

parent a49e768a
......@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
# Internal
import modules as n
from optim import *
from train import *
from net import Network
from helper import *
......@@ -13,18 +14,13 @@ from helper import *
# autograd globally off
torch.set_grad_enabled(False)
# generate train and test data
train_input, train_target = generate_disc_set(1000)
test_input, test_target = generate_disc_set(1000)
# normalize train and test inputs
# mean, std = train_input.mean(), train_input.std()
# train_input.sub_(mean).div_(std)
# test_input.sub_(mean).div_(std)
# generate train and test data for a given center and radius of a circle
center = (0.5, 0.5)
radius = 1/math.sqrt(2*math.pi)
train_input, train_target = generate_disc_set(1000, center, radius)
test_input, test_target = generate_disc_set(1000, center, radius)
# network parameters
lr = 1e-4
gamma = 0.9
......@@ -65,9 +61,8 @@ networks = {
],
}
# # initialize a network from the networks dictionary
model = Network(networks[1])
# train the network
# initialize a network from the networks dictionary
model = Network(networks[3])
optimizer = SGD(model, mini_batch_size, lr, gamma)
loss_train, acc_train = train_model(model, train_input, train_target, n.BCE(), lr, gamma, mini_batch_size, nb_epochs, optimizer=optimizer)
# test the network
......@@ -78,11 +73,15 @@ print(f' test_acc = {acc_test}')
print(output.size())
plot_figure(test_input, test_target, center=center, radius=radius)
plot_figure(test_input, output, center=center, radius=radius)
plot_figure(test_input, ~test_target, center=center, radius=radius)
plt.title('Test input x Test target')
plot_figure(test_input, ~output.int(), center=center, radius=radius)
plt.title('Test input x Network output')
plt.plot(range(nb_epochs), loss_train)
plt.plot(range(nb_epochs), acc_train)
plt.figure()
plt.plot(range(nb_epochs), loss_train, label='Train Loss')
plt.plot(range(nb_epochs), acc_train, label='Train Accuracy')
plt.legend()
plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment