# Watershed demo from OpenCV documentation.

import numpy as np
import cv2
from matplotlib import pyplot as plt

# One 2x4 grid of panels, one per pipeline stage; axes.flat walks them
# row by row, so ax1..ax4 are the top row and ax5..ax8 the bottom row.
fig, axes = plt.subplots(nrows=2, ncols=4)
ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8 = axes.flat
fig.set_size_inches(8, 6)
fig.suptitle('Watershed Algorithm')

# Load the test image (OpenCV returns it in BGR channel order).
img = cv2.imread('water_coins.jpg')
if img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise fail later inside cvtColor with an opaque error.
    raise FileNotFoundError(
        "Could not read 'water_coins.jpg'; run this demo from the directory "
        "that contains the image."
    )
# Matplotlib expects RGB, so convert for display.
cimg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
ax1.imshow(cimg)
ax1.axis(False)
ax1.set_title('Input')

# Binarize the grayscale image. With THRESH_OTSU the threshold argument (0)
# is ignored and Otsu's method chooses it. THRESH_BINARY_INV makes pixels at
# or below that threshold 255 and the rest 0.
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
threshold_flags = cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
ret, thresh = cv2.threshold(gray, 0, 255, threshold_flags)
ax2.imshow(thresh, cmap='gray')
ax2.axis(False)
ax2.set_title('Thresholded')

# Noise removal by morphological opening (erosion followed by dilation, two
# passes with a 3x3 kernel), then a Euclidean distance transform of the mask.
kernel = np.ones((3,3),np.uint8)
opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
# Each nonzero pixel gets its distance to the nearest zero pixel, so values
# are largest deep inside each foreground blob.
dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
ax3.imshow(dist_transform)
ax3.axis(False)
ax3.set_title('Distance')

# Sure background: dilate the cleaned mask three times; anything still
# outside it is treated as background.
sure_bg = cv2.dilate(opening, kernel, iterations=3)

# Sure foreground: keep only pixels whose distance value exceeds 70% of the
# largest distance, then cast the float result to 8-bit for later OpenCV calls.
fg_cutoff = 0.7 * dist_transform.max()
ret, sure_fg = cv2.threshold(dist_transform, fg_cutoff, 255, cv2.THRESH_BINARY)
sure_fg = np.uint8(sure_fg)

ax4.imshow(sure_fg)
ax4.axis(False)
ax4.set_title('Centers')

# Unknown region: pixels that are in the sure-background mask but not in the
# sure-foreground mask (cv2.subtract saturates at 0 for uint8 inputs).
unknown = cv2.subtract(sure_bg, sure_fg)
ax5.imshow(unknown)
ax5.set_title('Unknowns')
ax5.axis(False)

# Give every sure-foreground blob its own integer label (background gets 0).
ret, markers = cv2.connectedComponents(sure_fg)
# Shift all labels up by one so the sure background becomes 1 and 0 is free
# to mean "unknown" for the watershed step.
markers += 1
markers[unknown == 255] = 0

# Display-only copy: push object labels (>1) 50 steps further along the
# colormap so they contrast with the background label.
m2 = np.where(markers > 1, markers + 50, markers)
ax6.imshow(m2)
ax6.axis(False)
ax6.set_title('ConnComponent')

# Flood the image from the labelled markers; pixels on region boundaries
# come back labelled -1.
water_markers = cv2.watershed(img, markers)
ax7.imshow(water_markers, cmap='jet')
ax7.set_title('Watershed')
ax7.axis(False)

# Overlay detected boundaries (label -1) in red on a copy of the RGB input.
# The copy keeps `cimg`, the array shown in the 'Input' panel, unmodified,
# instead of relying on matplotlib having copied it at imshow time.
segmented = cimg.copy()
segmented[water_markers == -1] = [255, 0, 0]
ax8.imshow(segmented)
ax8.axis(False)
ax8.set_title('Segmented')

# Let the GUI event loop draw the figure briefly (non-blocking), then keep
# the process and its window alive until the user presses Enter.
plt.pause(1e-2)
input('Press Enter to exit...')
