|
## Evan's very own object detection program using the TensorFlow MobileNet-SSD model

## Some of this will be copied from Google's example at
## https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

## and some will be copied from this guy's example at
## https://github.com/datitran/object_detector_app/blob/master/object_detection_app.py

## but I will change it to make it more understandable to me.
| 10 | + |
| 11 | + |
| 12 | +# Import packages |
| 13 | +import os |
| 14 | +import cv2 |
| 15 | +import numpy as np |
| 16 | +import tensorflow as tf |
| 17 | +import sys |
| 18 | + |
| 19 | +# This is needed since the notebook is stored in the object_detection folder. |
| 20 | +sys.path.append("..") |
| 21 | + |
| 22 | +from utils import label_map_util |
| 23 | +from utils import visualization_utils as vis_util |
| 24 | + |
# Number of classes the object detector can identify.
NUM_CLASSES = 6

# Directory holding the exported detection model, and the file name of the
# image that detection is run on.
MODEL_NAME = 'card_inference_graph'
IMAGE_NAME = 'test3.jpg'

# Every path below is built from the directory the script is launched in.
CWD_PATH = os.getcwd()

# Frozen detection graph (.pb file) containing the model used for detection.
PATH_TO_CKPT = os.path.join(
    CWD_PATH, MODEL_NAME, 'frozen_inference_graph.pb')

# Label map file.
PATH_TO_LABELS = os.path.join(
    CWD_PATH, 'training', 'card-labelmap.pbtxt')

# Image to run detection on.
PATH_TO_IMAGE = os.path.join(CWD_PATH, IMAGE_NAME)
| 49 | + |
## Load the label map.
# The network reports each detected class as an integer; the label map is what
# ties those integers to category names. The object_detection helper functions
# are used here, but anything that produces a dictionary mapping integers to
# the appropriate string labels would do.
# TODO(Evan): read through label_map_util and check whether all three steps
# are really needed -- this seems kind of excessive.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True)

# category_index is handed to the visualization utility further down.
category_index = label_map_util.create_category_index(categories)
| 62 | + |
# Load the TensorFlow model into memory.
# `detection_graph.as_default()` makes detection_graph the default graph for
# the duration of the outer `with` block, so the nodes imported below land in
# it. The inner `with` closes the .pb file as soon as its bytes are read.
detection_graph = tf.Graph()
with detection_graph.as_default():
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as pb_file:
        serialized_graph = pb_file.read()

    # Parse the serialized bytes into a GraphDef and import it with no name
    # prefix, so tensors keep their exported names.
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

# The session is given its graph explicitly.
sess = tf.Session(graph=detection_graph)
| 75 | + |
## Look up the input and output tensors of detection_graph by name.

# Input tensor: the image fed to the model.
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

# Output tensors, in order:
#   detection_boxes   - each box is a part of the image where a particular
#                       object was detected
#   detection_scores  - level of confidence for each object; shown on the
#                       result image together with the class label
#   detection_classes - class of each detected object
#   num_detections    - number of objects detected
(detection_boxes,
 detection_scores,
 detection_classes,
 num_detections) = [
    detection_graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'detection_boxes:0',
        'detection_scores:0',
        'detection_classes:0',
        'num_detections:0',
    )
]
| 95 | + |
# Load the image. cv2.imread returns None instead of raising when the file is
# missing or cannot be decoded, so stop here with a clear message rather than
# failing later with an obscure error.
image = cv2.imread(PATH_TO_IMAGE)
if image is None:
    raise IOError('Could not read image: {}'.format(PATH_TO_IMAGE))

# OpenCV loads pixels in BGR order, but the TensorFlow model needs RGB, so
# convert before feeding it. A leading batch axis is then added, giving an
# array of shape [1, height, width, 3].
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_rgb_expanded = np.expand_dims(image_rgb, axis=0)

# Perform the actual detection by running the model with the image as input.
(boxes, scores, classes, num) = sess.run(
    [detection_boxes, detection_scores, detection_classes, num_detections],
    feed_dict={image_tensor: image_rgb_expanded})

# Draw the results of the detection (aka 'visualize the results') onto
# image_rgb. np.squeeze drops the batch axis from each model output.
# TODO(Evan): figure out what this visualization utility is doing so it can
# be replaced.
vis_util.visualize_boxes_and_labels_on_image_array(
    image_rgb,
    np.squeeze(boxes),
    np.squeeze(classes).astype(np.int32),
    np.squeeze(scores),
    category_index,
    use_normalized_coordinates=True,
    line_thickness=8)

# All the results have been drawn on image_rgb. cv2.imshow expects BGR, so
# convert back before displaying.
final_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
cv2.imshow('Ayy', final_image)

# Wait for any key press, then close the window.
cv2.waitKey(0)
cv2.destroyAllWindows()

# Release the resources held by the TensorFlow session.
sess.close()
| 128 | + |
0 commit comments