Commit 2689acf9 by Elif Ceylan

### GROUPWORK_minor corrections

parent df35ccf2
 # External import torch import math import matplotlib.pyplot as plt def generate_disc_set(nb): def generate_disc_set(nb, center, radius): # center, convert to tensor for element-wise subtraction center = torch.tensor(center) # (x, y) input = torch.empty(nb, 2).uniform_(0, 1) # y^2 = r^2 - x^2 target = input.add(-0.5).pow(2).sum(1).sub(1/(2*math.pi)).sign().add(-3).div(2).long() return input, ~target # if (x-c_x)^2 + (y-c_y)^2 - r^2 >= 0 then sign=1, # if (x-c_x)^2 + (y-c_y)^2 - r^2 < 0 then sign=-1 target = input.sub(center).pow(2).sum(1).sub(radius**2).sign().add(1).div(2).long() return input, target # we will use this for both, at the beginning, and then to visually show the results for the test data def plot_figure(input, output, center=(0.5,0.5), radius=1/math.sqrt(2*math.pi)): def plot_figure(input, output, center, radius): plt.figure() plt.scatter(input[:,0], input[:,1], c=output, cmap='RdYlGn') circle = plt.Circle(center, radius, color='black', fill=False, lw=5) plt.gca().add_patch(circle) plt.gca().set_aspect('equal', adjustable='box') plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!