from aim_fsm import *

import PIL
import numpy as np
import time
import cv2

import torch
import torchvision
from torchvision import transforms
from torchvision.models import MobileNetV2
from labels import imagenet_labels

# Pretrained MobileNetV2 (torchvision DEFAULT weights) switched to inference
# mode. nn.Module.eval() returns the module itself, so the call is chained.
model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
).eval()

# Camera frame -> model input:
#   1. keep the centered 480x480 square (an int size means a square crop),
#   2. shrink it to the 224x224 network input,
#   3. convert the PIL image to a CHW float tensor,
#   4. normalize with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.CenterCrop(480),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

class MobileNet(StateMachineProgram):
  """Classify camera snapshots with the pretrained MobileNetV2 `model`.

  On a text message (see the start-up prompt: 'tm'), MobileNetClassify
  classifies the current camera frame, prints the top-5 labels, and hands
  the best label to Say(). The preprocessed crop that was fed to the
  network is displayed in a 'mobilenet' window.
  """

  def start(self):
    # Black placeholder shown until the first classification. uint8 so it is
    # displayed the same way as the crops MobileNetClassify stores later.
    self.mob_image = np.zeros((480,480,3), dtype=np.uint8)
    super().start()
    print("Type 'tm' to take a picture and classify it.")

  def user_image(self, image, gray):
    """Framework image hook: show the most recent classified crop."""
    imshow('mobilenet', self.mob_image)

  def user_annotate(self, image):
    """Outline the centered square that `preprocess` crops for the model.

    The box is centered on the frame (computed from image.shape rather than
    assuming a 640x480 frame). Its 478-px side sits just inside the 480-px
    CenterCrop, presumably so the 2-px outline stays within the crop.

    Returns the annotated image, which is drawn on in place.
    """
    height, width = image.shape[:2]
    cx, cy = width // 2, height // 2
    half_box = 478 // 2
    cv2.rectangle(image, (cx-half_box, cy-half_box), (cx+half_box, cy+half_box), (255,255,0), 2)
    return image

  class MobileNetClassify(StateNode):
    """Classify the current camera frame and post the top label as data."""

    def start(self,event=None):
      super().start(event)
      # NOTE(review): assumes camera_image is a 3-channel RGB array, which is
      # what PIL.Image.fromarray and the ImageNet normalization expect.
      image = PIL.Image.fromarray(self.robot.camera_image)
      input_tensor = preprocess(image)

      # Display copy of the network input: CHW -> HWC, min/max stretched to
      # [0, 1], then stored as uint8 in [0, 255]. cv2.imshow interprets float
      # pixels as [0, 1], so a float array scaled to [0, 256] would show up
      # almost entirely white. 255 (not 256) is the largest uint8 value.
      im = input_tensor.numpy().transpose((1, 2, 0))
      lo, hi = im.min(), im.max()
      if hi > lo:
        im = (im - lo) / (hi - lo)
      else:
        im = np.zeros_like(im)  # flat frame: avoid 0/0 -> NaN
      self.parent.mob_image = (im * 255).astype(np.uint8)

      # Inference only: skip building an autograd graph.
      with torch.no_grad():
        output = model(input_tensor.unsqueeze(0))  # mini-batch of one
        probs = torch.nn.functional.softmax(output[0], dim=0).numpy()

      # Indices of the 5 most probable classes, best first. A stable sort
      # reversed keeps the same tie order as sorting (prob, index) pairs.
      top5 = np.argsort(probs, kind='stable')[::-1][:5]
      for idx in top5:
        print('%7.5f  %s' % (probs[idx], imagenet_labels[int(idx)]))

      # Post only the text before the first comma of the winning label.
      label = imagenet_labels[int(top5[0])].partition(',')[0]
      self.post_data(label)


  def setup(self):
      #     start: StateNode() =TM=> mobile
      # 
      #     mobile: self.MobileNetClassify() =SayData=> Say() =N=> start
      
      # Code generated by genfsm on Fri Mar 13 03:41:16 2026:
      
      start = StateNode() .set_name("start") .set_parent(self)
      mobile = self.MobileNetClassify() .set_name("mobile") .set_parent(self)
      say1 = Say() .set_name("say1") .set_parent(self)
      
      textmsgtrans1 = TextMsgTrans() .set_name("textmsgtrans1")
      textmsgtrans1 .add_sources(start) .add_destinations(mobile)
      
      saydatatrans1 = SayDataTrans() .set_name("saydatatrans1")
      saydatatrans1 .add_sources(mobile) .add_destinations(say1)
      
      nulltrans1 = NullTrans() .set_name("nulltrans1")
      nulltrans1 .add_sources(say1) .add_destinations(start)
      
      return self
