from aim_fsm import *

import PIL
import numpy as np
import time
import cv2

import torch
import torchvision
from torchvision import transforms
from torchvision.models import MobileNetV2
from labels import imagenet_labels

# Pretrained MobileNetV2 with torchvision's default weights.  nn.Module.eval()
# returns the module itself, so it can be chained onto construction to put
# the network in inference mode (affects dropout / batch-norm behavior).
model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.DEFAULT,
).eval()

# Standard ImageNet per-channel mean and standard deviation (RGB order).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> normalized network input: keep the central 480x480 square,
# shrink it to 224x224, convert to a (C, H, W) float tensor in [0, 1], then
# normalize each channel with the statistics above.
preprocess = transforms.Compose([
    transforms.CenterCrop((480, 480)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

class MobileNet(StateMachineProgram):
  def start(self):
    self.mob_image = np.zeros((480,480,3))
    super().start()
#    namedWindow('mobilenet')
    print("Type 'tm' to take a picture and classify it.")

  def user_image(self, image, gray):
    imshow('mobilenet', self.mob_image)

  def user_annotate(self,image):
    half_box = 478 // 2
    cv2.rectangle(image, (320-half_box, 240-half_box), (320+half_box, 240+half_box), (255,255,0), 2)
    return image

  class MobileNetClassify(StateNode):
    def start(self,event=None):
      super().start(event)
      image = PIL.Image.fromarray(self.robot.camera_image)
      input_tensor = preprocess(image)
      im = input_tensor.detach().numpy().transpose((1,2,0))
      min = im.min()
      max = im.max()
      im = (im - min) / (max-min)
      self.parent.mob_image = im * 256
     
      input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
      output = model(input_batch)
      probs = torch.nn.functional.softmax(output[0], dim=0).detach().numpy()
      z = sorted(zip(probs,range(1000)), key=lambda x: x[0])
      top5 = list(reversed(z[-5:]))
      for t in top5:
        print('%7.5f  %s' % (t[0], imagenet_labels[t[1]]))
      raw_label = imagenet_labels[top5[0][1]]
      sep = raw_label.find(',')
      label = raw_label[:sep] if sep > -1 else raw_label
      self.post_data(label)


  $setup{
    start: StateNode() =TM=> mobile

    mobile: self.MobileNetClassify() =SayData=> Say() =N=> start
  }
