Commit 8c8f2b85 authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

GROUPWORK_Finalize plotting functions

parent aa1b88c9
......@@ -7,17 +7,41 @@ def generate_disc_set(nb, center, radius):
center = torch.tensor(center)
# (x, y)
input = torch.empty(nb, 2).uniform_(0, 1)
# if (x-c_x)^2 + (y-c_y)^2 - r^2 >= 0 then sign=1,
# if (x-c_x)^2 + (y-c_y)^2 - r^2 < 0 then sign=-1
target = input.sub(center).pow(2).sum(1).sub(radius**2).sign().add(1).div(2).long()
# if (x-c_x)^2 + (y-c_y)^2 - r^2 >= 0 then 0,
# if (x-c_x)^2 + (y-c_y)^2 - r^2 < 0 then 1
target = (input.sub(center).pow(2).sum(1).sub(radius**2) < 0).long()
return input, target
# we will use this for both, at the beginning, and then to visually show the results for the test data
def plot_figure(input, output, center, radius):
def plot_data(input, output, center, radius, title):
plt.figure()
# data points
plt.scatter(input[:,0], input[:,1], c=output, cmap='RdYlGn')
# disc with given center and radius
circle = plt.Circle(center, radius, color='black', fill=False, lw=5)
plt.gca().add_patch(circle)
plt.gca().set_aspect('equal', adjustable='box')
plt.title(title)
plt.savefig(f'{title}.png')
# Plot function for plotting train accuracies, and train losses
def plot_figures(d, title, acc=False):
for value in d.values():
plt.plot(value)
plt.xlabel("Epochs")
if acc == False:
plt.ylabel("Loss")
else:
plt.ylabel("Accuracy")
plt.legend([k for k in d.keys()])
plt.title(title)
plt.savefig(f'{title}.png')
plt.show()
plt.close('all')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment