# NOTE: install dependencies before running: pip install mediapipe opencv-python
# !pip install mediapipe  # notebook-only shell magic; invalid in a plain .py file

import cv2  # used below (cv2.imread / cv2.cvtColor) but was never imported
import numpy as np

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame

_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat
# Solid RGB colors used to visualize the segmentation result.
BG_COLOR = (192, 192, 192)    # gray
MASK_COLOR = (255, 255, 255)  # white

# Configure the ImageSegmenter: load the hair-segmentation model from disk
# and request a per-pixel category mask in the result.
base_options = python.BaseOptions(model_asset_path='hair.tflite')
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
# Create the image segmenter
# Run hair segmentation over each demo image and visualize the mask:
# segmented pixels are painted MASK_COLOR, everything else BG_COLOR.
# NOTE(review): IMAGE_FILENAMES and resize_and_show must be defined in an
# earlier cell/section of this script — confirm they are in scope.
with vision.ImageSegmenter.create_from_options(options) as segmenter:
    for image_file_name in IMAGE_FILENAMES:
        # Read the image with OpenCV. cv2.imread returns pixels in BGR
        # channel order, or None when the file cannot be read.
        # Fix: was hard-coded to "john3.png", ignoring the loop variable.
        bgr_image = cv2.imread(image_file_name)
        if bgr_image is None:
            raise FileNotFoundError(f'Cannot read image file: {image_file_name}')

        # Convert BGR -> RGBA for MediaPipe and zero the alpha channel
        # (the SRGBA image format requires 4 channels; alpha is unused).
        rgba_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGBA)
        rgba_image[:, :, 3] = 0

        # Wrap the numpy array in a MediaPipe Image object.
        image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

        # Run segmentation and retrieve the per-pixel category mask.
        segmentation_result = segmenter.segment(image)
        category_mask = segmentation_result.category_mask

        # Build solid-color foreground/background images with the same
        # (H, W, 4) shape as the input view.
        image_data = image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:, :, 0:3] = MASK_COLOR  # alpha channel stays 0
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        # Fix: assigning a 3-tuple to all 4 channels raised a broadcast
        # error; write only the RGB channels, matching fg_image above.
        bg_image[:, :, 0:3] = BG_COLOR

        # Broadcast the 2-D mask across all 4 channels so the condition
        # matches fg/bg image shapes (fix: stacking x3 could not broadcast
        # against the 4-channel images). The mask holds integer category
        # indices, so "> 0.2" selects every non-background category.
        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        # Fix: was f'...{name}', an undefined variable (NameError).
        print(f'Segmentation mask of {image_file_name}:')
        resize_and_show(output_image)