import time
import numpy as np
import numpy.random
from math import pi, sin, cos, atan2, sqrt
import turtle

def wrap_angle(theta):
    """Return the angle *theta* (radians) wrapped into the range [-pi, pi].

    Angles already in range (and NaN) are returned unchanged. Out-of-range
    angles have the nearest whole number of turns removed, so inputs several
    turns away are wrapped too, not only those within one turn of the range
    (up to floating-point rounding for very large inputs).
    """
    if not (theta > pi or theta < -pi):
        return theta
    # For inputs within one turn of the range this is exactly theta -/+ 2*pi.
    return theta - 2*pi*round(theta / (2*pi))

class Robot:
    """The simulated 'true' robot: its noisy pose and its landmark readings.

    Heading convention: theta == 0 faces +y, so the world-frame direction
    the robot faces is theta + pi/2 (see draw(), which adds 90 degrees).
    """

    def __init__(self):
        self.x = 0
        self.y = 0
        self.theta = 0
        # One entry per landmark, refreshed by set_pose().
        self.lm_distances = []
        self.lm_bearings = []

    def draw(self):
        """Replace the robot's stamp with a large black triangle at its pose."""
        pen = t_rob
        pen.clear()
        pen.penup()
        pen.shapesize(3)
        pen.shape('tri')
        pen.color('black')
        pen.setposition(self.x, self.y)
        pen.setheading(self.theta * 180/pi + 90)
        pen.stamp()

    def set_pose(self, x, y, theta, pos_variance=5, hdg_variance=pi/40):
        """Place the robot at (x, y, theta) plus Gaussian noise and re-measure.

        pos_variance fills the diagonal of the position covariance;
        hdg_variance is passed to the normal draw as its scale (std. dev.).
        Noise comes from the global collection's generator (coll.rnd_gen).
        After redrawing, the distance and heading-relative bearing to every
        landmark are recomputed.
        """
        rng = coll.rnd_gen
        offset = rng.multivariate_normal((0,0), ((pos_variance, 0), (0, pos_variance)))
        heading_noise = rng.normal(loc=0, scale=hdg_variance)
        self.x, self.y = x + offset[0], y + offset[1]
        self.theta = theta + heading_noise
        self.draw()

        distances = []
        bearings = []
        for mark in landmarks:
            dx = mark.x - self.x
            dy = mark.y - self.y
            distances.append(sqrt(dx**2 + dy**2))
            # World angle to the landmark, made relative to the robot's facing direction.
            bearings.append(wrap_angle(atan2(dy, dx) - self.theta - pi/2))
        self.lm_distances = distances
        self.lm_bearings = bearings

class Particle:
    """One pose hypothesis for the robot: position, heading and a weight.

    Headings use the same convention as Robot: theta == 0 faces +y, so the
    direction of travel in world coordinates is theta + pi/2.
    """

    def __init__(self, index, pos=(0,0), theta=0):
        self.index = index  # slot in the owning ParticleCollection's list
        self.x = pos[0]
        self.y = pos[1]
        self.theta = theta
        self.weight = 1     # multiplied by a factor <= 1 in evaluate()

    def __repr__(self):
        return f'<Particle {self.x:.1f}, {self.y:.1f} @ {self.theta*180/pi:.1f}>'

    def advance(self, distance):
        """Move *distance* units along the current heading."""
        self.x += distance * cos(self.theta + pi/2)
        self.y += distance * sin(self.theta + pi/2)

    def turn(self, phi):
        """Rotate by *phi* radians, keeping theta wrapped to [-pi, pi]."""
        self.theta = wrap_angle(self.theta + phi)

    @staticmethod
    def _weight_factor(error, tolerance=20):
        """Return the factor a weight is multiplied by for a total range *error*.

        Errors up to *tolerance* leave the weight unchanged (factor 1); larger
        errors shrink it by sqrt(tolerance / error). Testing the threshold
        first means an exact match (error == 0) no longer divides by zero.
        """
        if error > tolerance:
            return sqrt(tolerance / error)
        return 1

    def evaluate(self):
        """Down-weight this particle by how far its predicted landmark
        distances are from the robot's measured ones (global ``robot``)."""
        predicted = [sqrt((lm.x - self.x) ** 2 + (lm.y - self.y) ** 2) for lm in landmarks]
        error = sqrt(sum((p - m) ** 2 for p, m in zip(predicted, robot.lm_distances)))
        self.weight *= self._weight_factor(error)

    def draw(self):
        """Stamp this particle, coloured from blue (weight 0) to red (weight 1)."""
        t.shape('tri')
        t.setposition(self.x, self.y)
        t.setheading(self.theta * 180/pi + 90)
        t.shapesize(2.0, 2.0, 1)
        t.color((self.weight, 0, 1-self.weight))
        t.stamp()

    def draw_landmark_predictions(self):
        """Briefly draw, from this particle, a line for each of the robot's
        landmark measurements (bearing relative to this particle's heading,
        measured distance as length), then erase them again."""
        for i in range(len(landmarks)):
            lm = landmarks[i]
            t.penup()
            t.goto(self.x, self.y)
            heading = wrap_angle(self.theta + robot.lm_bearings[i] + pi/2)
            t.setheading(heading*180/pi)
            t.pendown()
            t.pensize(2)
            t.color(lm.color)
            t.forward(robot.lm_distances[i])
        # Keep the lines on screen for a moment: switch on animation with a
        # 20 ms delay and make a short invisible move out and back.
        turtle.tracer(1, 20)
        t.penup()
        t.forward(50)
        t.forward(-50)
        turtle.tracer(50000,0)
        # Undo exactly what this method did. Each turtle call adds one undo
        # entry: 7 per landmark (penup, goto, setheading, pendown, pensize,
        # color, forward) plus the 3 pause moves (penup, forward, forward).
        for _ in range(7 * len(landmarks) + 3):
            t.undo()

class ParticleCollection:
    """A cloud of Particle objects plus the random generator that drives them.

    ``pos_variance`` / ``pos_covariance`` are passed to ``multivariate_normal``
    as a covariance, whereas ``hdg_variance`` and the ``variance`` arguments of
    ``advance``/``turn`` are passed to ``normal`` as its ``scale`` (a standard
    deviation).
    """

    def __init__(self, num_particles=500, pos_mean=(0,0), pos_variance=None,
                 pos_covariance=((1,0),(0,1)), hdg_mean=0, hdg_variance=0):
        if pos_variance is not None:
            pos_covariance = ((pos_variance, 0), (0, pos_variance))
        self.num_particles = num_particles
        self.rnd_gen = numpy.random.Generator(numpy.random.PCG64())
        positions = self.rnd_gen.multivariate_normal(pos_mean, pos_covariance, size=num_particles)
        headings = self.rnd_gen.normal(loc=hdg_mean, scale=hdg_variance, size=num_particles)
        self.particles = [Particle(i, positions[i], headings[i]) for i in range(num_particles)]

    def weighted_pose(self):
        """Return the weight-averaged pose (x, y, theta).

        Headings are averaged as unit vectors (cos, sin) so that angles on
        either side of +/-pi combine correctly.
        """
        wsum = xsum = ysum = qcos = qsin = 0
        for p in self.particles:
            wsum += p.weight
            xsum += p.x * p.weight
            ysum += p.y * p.weight
            qcos += cos(p.theta) * p.weight
            qsin += sin(p.theta) * p.weight
        x = xsum / wsum
        y = ysum / wsum
        theta = atan2(qsin,qcos)
        return (x, y, theta)

    def draw(self):
        """Redraw every particle, the weighted-mean pose (large green
        triangle) and the robot."""
        t.penup()
        t.clearstamps()
        for p in self.particles:
            p.draw()
        (x, y, theta) = self.weighted_pose()
        t.setpos(x,y)
        t.setheading(theta*180/pi + 90)
        t.color('green')
        t.shapesize(6.0, 6.0, 1)
        t.stamp()
        robot.draw()

    def sample_particle_landmarks(self):
        """Show landmark predictions for every 40th particle of this collection."""
        t.pensize(6)
        for p in self.particles:
            if p.index % 40 == 0:
                p.draw_landmark_predictions()

    def advance(self, distance, variance=0):
        """Move every particle forward by a noisy copy of *distance*, then redraw."""
        distances = self.rnd_gen.normal(loc=distance, scale=variance, size=self.num_particles)
        for p in self.particles:
            p.advance(distances[p.index])
        self.draw()

    def turn(self, angle, variance=0):
        """Turn every particle by a noisy copy of *angle*, then redraw."""
        angles = self.rnd_gen.normal(loc=angle, scale=variance, size=self.num_particles)
        for p in self.particles:
            p.turn(angles[p.index])
        self.draw()

    def evaluate(self):
        """Re-weight every particle against the robot's landmark readings."""
        for p in self.particles:
            p.evaluate()

    @staticmethod
    def _systematic_indices(weights, offset):
        """Return source indices chosen by systematic (low-variance) resampling.

        ``weights`` are non-negative and need not be normalised; ``offset`` in
        [0, 1) places the first of ``len(weights)`` evenly spaced pointers into
        the cumulative distribution. All-zero (or NaN) weights are treated as
        uniform.
        """
        weights = np.asarray(weights, dtype=float)
        n = len(weights)
        if n == 0:
            return np.zeros(0, dtype=int)
        cdf = np.cumsum(weights)
        if not cdf[-1] > 0:
            cdf = np.arange(1, n + 1, dtype=float)
        cdf /= cdf[-1]
        pointers = (offset + np.arange(n)) / n
        # First index whose cumulative weight reaches each pointer; clip so
        # float rounding at the top of the CDF can never run past the end.
        return np.minimum(np.searchsorted(cdf, pointers, side='left'), n - 1)

    def resample(self):
        """Replace the particles by systematic resampling on weight**4, then redraw.

        Raising weights to the 4th power sharpens the distribution so that
        well-matching particles are duplicated more aggressively. New
        particles start with weight 1; the chosen source indices are kept in
        ``self.new_indices``.
        """
        weights = [p.weight ** 4 for p in self.particles]
        # Use the collection's own generator, like advance()/turn(), so a run
        # depends on a single random stream.
        self.new_indices = self._systematic_indices(weights, self.rnd_gen.random())
        old = self.particles
        self.particles = [Particle(i, pos=(old[k].x, old[k].y), theta=old[k].theta)
                          for i, k in enumerate(self.new_indices)]
        self.draw()

class Landmark:
    """A fixed, drawable marker whose distance and bearing the robot measures."""

    def __init__(self, x, y, color, size, shape_fn):
        # World position of the marker.
        self.x, self.y = x, y
        # Appearance: a turtle colour, a size handed to shape_fn, and a
        # callable invoked by draw() once t_lm is positioned at (x, y).
        self.color, self.size, self.shape_fn = color, size, shape_fn

    def __repr__(self):
        shape_name = self.shape_fn.__name__
        return f'<Landmark {shape_name} ({self.x},{self.y})>'

    def draw(self):
        """Draw this landmark's shape with the landmark pen at its position."""
        pen = t_lm
        pen.penup()
        pen.setposition(self.x, self.y)
        pen.setheading(0)
        pen.color(self.color)
        self.shape_fn(self.size)

# Screen configuration: only every 50000th update is actually drawn, with no
# delay, so large particle clouds render quickly; the default turtle is hidden.
turtle.tracer(50000, delay=0)
turtle.speed(0)
turtle.hideturtle()  # must come before turtle.setup(...)
turtle.setup(width=800, height=600)
turtle.register_shape("tri", ((-3, -2), (0, 3), (3, -2), (0, 0)))


def _hidden_pen():
    """Create a new turtle that is invisible and has its pen lifted."""
    pen = turtle.Turtle()
    pen.hideturtle()
    pen.penup()
    return pen


t = _hidden_pen()      # particles, the track and landmark-prediction lines
t_txt = _hidden_pen()  # captions written by show_text()
t_rob = _hidden_pen()  # the robot, drawn by Robot.draw()
t_lm = _hidden_pen()   # landmarks (Landmark.draw, star, surveyor)

def clear_all():
    """Erase the drawings of all four module-level turtles."""
    for pen in (t, t_txt, t_rob, t_lm):
        pen.clear()

def show_text(text):
    """Replace the caption near the bottom of the window with *text*."""
    caption_font = ('Arial', 16, 'bold')
    t_txt.clear()
    t_txt.penup()
    t_txt.goto(-200, -250)
    t_txt.write(text, move=False, font=caption_font)

# Waypoints the robot drives through, in order. The final leg repeats the
# direction of the first one, so the robot only turns on that last step.
# A shorter alternative for quick runs: ((0, 0), (0, 150), (-300, 200))
track = (
    (0, 0),
    (0, 150),
    (-300, 200),
    (-300, -150),
    (0, 0),
    (0, 150),
)

def draw_track():
    """Draw ``track`` as a thin black polyline through its waypoints."""
    start, *remaining = track
    t.penup()
    t.setpos(start[0], start[1])
    t.pendown()
    t.pensize(1)
    t.color('black')
    for waypoint in remaining:
        t.goto(*waypoint)

def run_track(max_dist_variance=20, turn_variance=pi/20, pos_variance=5,
              hdg_variance=pi/40, evaluate=False, sample_landmarks=False,
              do_resample=lambda i: False):
    """Drive the robot along ``track``, one leg per <Enter> press.

    On each leg the particles (global ``coll``) are turned toward the next
    waypoint with noise ``turn_variance``. Except on the final leg, they are
    then advanced by the leg length with noise ``min(max_dist_variance, dist/5)``.
    The robot is re-posed at the waypoint (on the final leg it only turns in
    place) with ``pos_variance``/``hdg_variance`` noise. Optionally the
    particles are then re-weighted (``evaluate``), sample predictions are shown
    (``sample_landmarks``), and the particles are resampled whenever
    ``do_resample(leg_number)`` is true (legs are numbered from 1).
    """
    pose_noise = {'pos_variance': pos_variance, 'hdg_variance': hdg_variance}
    robot.set_pose(track[0][0], track[0][1], 0, **pose_noise)
    coll.draw()
    final_leg = len(track) - 1
    for leg, (origin, target) in enumerate(zip(track, track[1:]), start=1):
        pause()
        dx = target[0] - origin[0]
        dy = target[1] - origin[1]
        length = (dx ** 2 + dy ** 2) ** 0.5
        # Turn needed to face the leg's direction from the robot's current heading.
        turn_amount = wrap_angle(atan2(dy, dx) - robot.theta - pi/2)
        coll.turn(turn_amount, turn_variance)
        if leg == final_leg:
            # Final leg: turn toward the last waypoint but do not move.
            robot.set_pose(robot.x, robot.y, wrap_angle(robot.theta + turn_amount),
                           **pose_noise)
        else:
            coll.advance(length, min(max_dist_variance, length/5))
            robot.set_pose(*target, wrap_angle(robot.theta + turn_amount),
                           **pose_noise)
        if evaluate:
            coll.evaluate()
        if sample_landmarks:
            coll.sample_particle_landmarks()
        coll.draw()
        if do_resample(leg):
            coll.resample()

def star(side=10):
    """Draw a filled five-pointed star with the landmark pen.

    The pen first lifts and moves side * golden-ratio along heading 108
    degrees, then traces ten edges of length *side*, two per loop pass.
    """
    golden = (1 + sqrt(5)) / 2
    t_lm.penup()
    t_lm.setheading(90 + 18)
    t_lm.forward(golden * side)
    t_lm.pendown()
    t_lm.pensize(1)
    t_lm.begin_fill()
    for _ in range(5):
        t_lm.left(144)
        t_lm.forward(side)
        t_lm.right(72)
        t_lm.forward(side)
    t_lm.end_fill()

def surveyor(radius=10):
    """Draw a surveyor's mark with the landmark pen.

    The mark is a circle of *radius* centred on the pen's current position,
    plus two opposite filled quarter-circle wedges (headings 0 and 180).
    """
    cx, cy = t_lm.position()

    def quarter(heading):
        # Filled wedge from the centre spanning heading .. heading + 90 degrees.
        t_lm.penup()
        t_lm.setposition(cx, cy)
        t_lm.setheading(heading)
        t_lm.pensize(1)
        t_lm.pendown()
        t_lm.begin_fill()
        t_lm.forward(radius)
        t_lm.setheading(heading + 90)
        t_lm.circle(radius, 90)
        t_lm.goto(cx, cy)
        t_lm.end_fill()

    # Outline: step out to the rim, face north, and go once around the centre.
    t_lm.penup()
    t_lm.setheading(0)
    t_lm.forward(radius)
    t_lm.setheading(90)
    t_lm.pensize(2)
    t_lm.pendown()
    t_lm.circle(radius)
    t_lm.penup()
    for heading in (0, 180):
        quarter(heading)



# Initial particle cloud centred on the start of the track; the demo_*
# functions replace it with a fresh one.
coll = ParticleCollection(
    num_particles=400,
    pos_mean=track[0],
    pos_variance=250,
)

# Landmarks whose distance and bearing the robot measures.
lm_star = Landmark(x=-50, y=250, color='blue', size=15, shape_fn=star)
lm_surveyor = Landmark(x=150, y=-50, color='magenta', size=12, shape_fn=surveyor)
landmarks = (lm_star, lm_surveyor)

def draw_landmarks():
    """Erase the landmark pen's drawings and draw every landmark afresh."""
    t_lm.clear()
    for marker in landmarks:
        marker.draw()

robot = Robot()

# Running a full track by hand, without the demo menu:
#     draw_landmarks()
#     draw_track()
#     run_track(do_resample=lambda i: 1 == i%2)

def pause():
    """Block until the user presses <Enter> in the console; the input is discarded."""
    prompt = 'Press <Enter> to continue:'
    input(prompt)

def demo_distribution():
    """Demo: particle spread as a measure of confidence in the pose.

    Shows a single particle, then a tight cloud (variance 50), then a wide one
    (variance 800), waiting for <Enter> between stages.
    """
    global coll
    stages = (
        (1, 0, 'Where is the robot?'),
        (400, 50, 'Low variance = High confidence'),
        (400, 800, 'High variance = Low confidence'),
    )
    for stage_no, (count, spread, caption) in enumerate(stages):
        if stage_no > 0:
            pause()
        clear_all()
        coll = ParticleCollection(num_particles=count, pos_mean=track[0], pos_variance=spread)
        coll.draw()
        show_text(caption)

def _start_track_demo(caption, show_landmarks=False):
    """Reset the screen for a track-following demo.

    Clears everything, creates a fresh 400-particle cloud (variance 250) at the
    start of the track, draws it and the track, optionally draws the
    landmarks, and shows *caption*.
    """
    global coll
    clear_all()
    coll = ParticleCollection(num_particles=400, pos_mean=track[0], pos_variance=250)
    coll.draw()
    draw_track()
    if show_landmarks:
        draw_landmarks()
    show_text(caption)


def demo_drag_particles():
    """Demo: noise-free motion, where particles simply follow the robot."""
    _start_track_demo('As the robot moves, drag the particles along')
    run_track(max_dist_variance=0, turn_variance=0, pos_variance=0, hdg_variance=0)


def demo_motion_model():
    """Demo: noisy motion, with particles spread by the motion model."""
    _start_track_demo('Motion is unreliable, so add noise to the particles')
    run_track()


def demo_landmarks():
    """Demo: particles are weighted against landmark measurements."""
    _start_track_demo('Use landmarks to evaluate particles', show_landmarks=True)
    run_track(evaluate=True, sample_landmarks=True)


def demo_resample():
    """Demo: weighting plus resampling on every odd-numbered leg."""
    _start_track_demo('Resample to replace useless particles', show_landmarks=True)
    run_track(evaluate=True, sample_landmarks=True, do_resample = lambda i: 1 == i%2)

demo_list = [demo_distribution, demo_drag_particles, demo_motion_model, demo_landmarks, demo_resample]


def demo_loop():
    """Repeatedly show a numbered menu of ``demo_list`` and run the chosen demo.

    Only the numbers 1..len(demo_list) select a demo; any other input just
    redisplays the menu. An error raised inside a demo is reported instead of
    being silently ignored, and end of input (EOF) ends the loop.
    """
    while True:
        for number, demo in enumerate(demo_list, start=1):
            print(f'  {number}: {demo.__name__}')
        try:
            choice = input('Demo? ')
        except EOFError:
            break
        try:
            selected = int(choice)
        except ValueError:
            selected = 0
        # Reject 0 and negatives explicitly: as list indices they would
        # silently pick demos counted from the end of the list.
        if 1 <= selected <= len(demo_list):
            demo = demo_list[selected - 1]
            try:
                demo()
            except Exception as err:  # keep the menu alive, but say what failed
                print(f'{demo.__name__} failed: {err!r}')
        print()


demo_loop()
