Commit 2689acf9 authored by Elif Ceylan's avatar Elif Ceylan
Browse files

GROUPWORK_minor corrections

parent df35ccf2
# External
import torch
import math
import matplotlib.pyplot as plt
def generate_disc_set(nb):
def generate_disc_set(nb, center, radius):
# center, convert to tensor for element-wise subtraction
center = torch.tensor(center)
# (x, y)
input = torch.empty(nb, 2).uniform_(0, 1)
# y^2 = r^2 - x^2
target = input.add(-0.5).pow(2).sum(1).sub(1/(2*math.pi)).sign().add(-3).div(2).long()
return input, ~target
# if (x-c_x)^2 + (y-c_y)^2 - r^2 >= 0 then sign=1,
# if (x-c_x)^2 + (y-c_y)^2 - r^2 < 0 then sign=-1
target = input.sub(center).pow(2).sum(1).sub(radius**2).sign().add(1).div(2).long()
return input, target
# we will use this for both, at the beginning, and then to visually show the results for the test data
def plot_figure(input, output, center=(0.5,0.5), radius=1/math.sqrt(2*math.pi)):
def plot_figure(input, output, center, radius):
plt.figure()
plt.scatter(input[:,0], input[:,1], c=output, cmap='RdYlGn')
circle = plt.Circle(center, radius, color='black', fill=False, lw=5)
plt.gca().add_patch(circle)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment