from aim_fsm import *

import cv2
import numpy as np
from ultralytics import YOLO

# Name of the ultralytics training run whose weights get loaded ("train", "train2", "train3", ...).
WEIGHTS_VERSION = "train"

# Relative path of the best checkpoint produced by that training run.
WEIGHTS_PATH = f'runs/segment/{WEIGHTS_VERSION}/weights/best.pt'

def roboflow_fit_resize(img, size=160):
    """
    Replicate Roboflow 'Resize: fit to NxN with black edges'.

    The image is scaled by a single factor so that it fits inside a
    size x size square, then centered on a zero-filled (black) canvas of
    that size.

    Args:
        img: image array of shape (h, w) or (h, w, c).
        size: side length of the square output canvas.

    Returns:
        canvas: transformed image of shape (size, size) + img.shape[2:] with
            the same dtype as img -- i.e. (size, size, 3) uint8 for a
            3-channel uint8 image
        scale: resize scale applied to original image
        pad_x: left padding in transformed image
        pad_y: top padding in transformed image
        new_w: resized width before padding
        new_h: resized height before padding
    """
    h, w = img.shape[:2]

    scale = min(size / w, size / h)
    # At least one pixel: for a very elongated image the short side could
    # otherwise round to 0, which cv2.resize rejects.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # cv2.resize returns a 2-D array for single-channel input, dropping a
    # trailing (h, w, 1) channel axis; restore the input's layout so the
    # paste below lines up with the canvas.
    resized = resized.reshape((new_h, new_w) + img.shape[2:])

    # Canvas follows the input's channel count and dtype rather than
    # assuming 3-channel uint8.
    canvas = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)

    pad_x = (size - new_w) // 2
    pad_y = (size - new_h) // 2

    canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized

    return canvas, scale, pad_x, pad_y, new_w, new_h

def map_mask_to_original(mask_reduced, scale, pad_x, pad_y, orig_w, orig_h, canvas_size=160):
    """
    Map one binary mask from letterboxed canvas coordinates back to original image size.

    The letterboxed canvas is the canvas_size x canvas_size image produced by
    roboflow_fit_resize.  The mask is allowed to have a different resolution
    than that canvas (e.g. if the model's inference size differs from the
    canvas size -- NOTE(review): confirm against the model's imgsz).  The
    padding offsets and content extent are therefore rescaled by
    mask.shape / canvas_size before cropping.  For a mask that is exactly
    canvas_size x canvas_size this rescaling is the identity.

    mask_reduced: 2-D numpy array, values 0/1 or 0..255
    scale, pad_x, pad_y: values returned by roboflow_fit_resize for this image
    orig_w, orig_h: width and height of the original image
    canvas_size: side length of the letterboxed canvas (the `size` argument
        given to roboflow_fit_resize)
    Returns:
        orig_mask: uint8 mask of shape (orig_h, orig_w), values 0 or 255
    """
    mask_h, mask_w = mask_reduced.shape[:2]
    # Ratio of mask resolution to canvas resolution along each axis.
    ky = mask_h / canvas_size
    kx = mask_w / canvas_size

    # Crop away padding region, expressed in mask pixel coordinates.
    y0 = int(round(pad_y * ky))
    x0 = int(round(pad_x * kx))
    cropped = mask_reduced[y0:y0 + int(round(orig_h * scale * ky)),
                           x0:x0 + int(round(orig_w * scale * kx))]

    # Resize back to original image size; nearest-neighbour keeps the mask binary.
    restored = cv2.resize(cropped.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    # Normalize to 0/255
    restored = (restored > 0).astype(np.uint8) * 255
    return restored

def mask_annotate(original_image, masks_reduced, scale, pad_x, pad_y):
    """
    Overlay segmentation masks on an image.

    Each mask is mapped from letterboxed canvas coordinates back to the size
    of original_image with map_mask_to_original.  Pixels covered by any mask
    are painted (255, 0, 0) on a copy of the image, and that copy is blended
    at weight 0.3 over the original (weight 0.7).

    masks_reduced: iterable of per-detection mask objects, each exposing
        .data[0] with .cpu().numpy() (presumably ultralytics Masks entries --
        TODO confirm)
    scale, pad_x, pad_y: letterbox parameters from roboflow_fit_resize
    Returns:
        the blended image
    """
    height, width = original_image.shape[:2]
    overlay = original_image.copy()
    for entry in masks_reduced:
        mask_array = entry.data[0].cpu().numpy()
        full_size = map_mask_to_original(mask_array, scale, pad_x, pad_y, width, height)
        overlay[full_size > 0] = (255, 0, 0)
    # Unmasked pixels come out essentially unchanged (0.7 + 0.3 = 1);
    # masked pixels get a 30% tint of the paint colour.
    return cv2.addWeighted(original_image, 0.7, overlay, 0.3, 0)

class DominoSegment(StateMachineProgram):
    def __init__(self):
        super().__init__()
        print('*****>> Loading YOLO weights from', WEIGHTS_PATH)
        self.model = YOLO(WEIGHTS_PATH)
        self.model.eval()
        self.results = []

    class DetectDominoes(StateNode):
        def start(self, event=None):
            super().start(event)
            self.parent.image = self.robot.camera_image
            self.parent.fitted = roboflow_fit_resize(self.parent.image)
            (canvas, scale, pad_x, pad_y, orig_w, orig_h) = self.parent.fitted
            self.parent.results = self.parent.model(canvas, conf=0.3)
            self.post_data(self.parent.results)

    class DisplayResults(StateNode):
        def start(self, event=None):
            super().start(event)
            results = event.data if isinstance(event, DataEvent) else self.parent.results
            if len(results) == 0:
                return
            elif len(results) > 1:
                print(f'{len(results)=}.  Displaying first result.')
            result = results[0]
            print('Recognized', result.verbose(), 'confidence:', result.boxes.conf.cpu().numpy())
            masks_reduced = result.masks
            
            if masks_reduced:
                (canvas, scale, pad_x, pad_y, new_w, new_h) = self.parent.fitted
                annotated_image =  mask_annotate(self.parent.image, masks_reduced, scale, pad_x, pad_y)
            else:
                annotated_image = self.parent.image
            imshow('domino', annotated_image)
            imshow('dominoplot', result.plot(line_width=1, labels=False))

    $setup{
        StateNode() =T(2)=> loop   # weight for viewers to load

        loop: Print("Type 'tm' to recognize a domino") =TM=>
          self.DetectDominoes() =D=> self.DisplayResults() =N=> loop
    }
