Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Mira Arabi Haddad
DL Mini Projects Group M
Commits
638a1f1a
Commit
638a1f1a
authored
Dec 17, 2021
by
Elif Ceylan
Browse files
GROUPWORK_minor changes
parent
a49e768a
Changes
1
Hide whitespace changes
Inline
Side-by-side
p2/test.py
View file @
638a1f1a
...
...
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
# Internal
import
modules
as
n
from
optim
import
*
from
train
import
*
from
net
import
Network
from
helper
import
*
...
...
@@ -13,18 +14,13 @@ from helper import *
# autograd globally off
torch
.
set_grad_enabled
(
False
)
# generate train and test data
train_input
,
train_target
=
generate_disc_set
(
1000
)
test_input
,
test_target
=
generate_disc_set
(
1000
)
# normalize train and test inputs
# mean, std = train_input.mean(), train_input.std()
# train_input.sub_(mean).div_(std)
# test_input.sub_(mean).div_(std)
# generate train and test data for a given center and radius of a circle
center
=
(
0.5
,
0.5
)
radius
=
1
/
math
.
sqrt
(
2
*
math
.
pi
)
train_input
,
train_target
=
generate_disc_set
(
1000
,
center
,
radius
)
test_input
,
test_target
=
generate_disc_set
(
1000
,
center
,
radius
)
# network parameters
lr
=
1e-4
gamma
=
0.9
...
...
@@ -65,9 +61,8 @@ networks = {
],
}
# # initialize a network from the networks dictionary
model
=
Network
(
networks
[
1
])
# train the network
# initialize a network from the networks dictionary
model
=
Network
(
networks
[
3
])
optimizer
=
SGD
(
model
,
mini_batch_size
,
lr
,
gamma
)
loss_train
,
acc_train
=
train_model
(
model
,
train_input
,
train_target
,
n
.
BCE
(),
lr
,
gamma
,
mini_batch_size
,
nb_epochs
,
optimizer
=
optimizer
)
# test the network
...
...
@@ -78,11 +73,15 @@ print(f' test_acc = {acc_test}')
print
(
output
.
size
())
plot_figure
(
test_input
,
test_target
,
center
=
center
,
radius
=
radius
)
plot_figure
(
test_input
,
output
,
center
=
center
,
radius
=
radius
)
plot_figure
(
test_input
,
~
test_target
,
center
=
center
,
radius
=
radius
)
plt
.
title
(
'Test input x Test target'
)
plot_figure
(
test_input
,
~
output
.
int
(),
center
=
center
,
radius
=
radius
)
plt
.
title
(
'Test input x Network output'
)
plt
.
plot
(
range
(
nb_epochs
),
loss_train
)
plt
.
plot
(
range
(
nb_epochs
),
acc_train
)
plt
.
figure
()
plt
.
plot
(
range
(
nb_epochs
),
loss_train
,
label
=
'Train Loss'
)
plt
.
plot
(
range
(
nb_epochs
),
acc_train
,
label
=
'Train Accuracy'
)
plt
.
legend
()
plt
.
show
()
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment