Skip to content

Screw Counting

This example demonstrates how to use Tzara, the Telekinesis Physical AI Agent, to build a screw counting pipeline. Tzara generates code that captures an RGB frame from a RealSense D405 camera, detects all visible screws using Grounding DINO open-vocabulary detection, applies Non-Maximum Suppression to remove duplicate boxes (the screws are known to be distinct and non-overlapping), and saves both the original image and a visualization annotated with bounding boxes and the total count.

The Natural Language Instructions

Prompt 1: I have a metal box with screws that extrude, write a pipeline to detect 
all the visible screws in the image. Visualize and save the result and the count 
in the top, also save the original image.

Prompt 2: All the screws are distinct and not overlapping, handle this.

The Generated Code

The generated code captures an RGB frame from the RealSense D405, runs Grounding DINO with the prompt "a screw ." to localize each screw, applies a low-IoU NMS pass (since the screws are distinct and non-overlapping) to remove duplicate boxes, and writes both the original capture and a visualization with bounding boxes and a count overlay to disk.

python
# Screw detection pipeline using Realsense D405 static camera
# Workflow:
#   1. Capture RGB frame from RealSense D405
#   2. Save original image
#   3. Detect screws using open-vocabulary detector (Grounding DINO with prompt "a screw .")
#      Rationale: No specific screw model exists in Telekinesis Retina. Open-vocab is
#      preferred over classical (e.g. circle Hough) because screws on metal surfaces
#      have variable appearance/lighting and we need instance-level detections.
#   4. Visualize detections (bounding boxes + labels) overlaid on the original image
#   5. Save visualization to disk
#   6. Display visualization

import sys
import numpy as np
import cv2
from loguru import logger

from datatypes import datatypes, io
from telekinesis import retina
from telekinesis.medulla.cameras import RealSense

# ============================================================================
# TUNABLE PARAMETERS (edit here)
# ============================================================================
# Camera
CAMERA_NAME = "d405_static"
CAMERA_SERIAL = None  # None -> first available RealSense
WARMUP_FRAMES = 30    # AE/AWB convergence frames

# Detection
DETECTION_PROMPT = "a screw ."  # Grounding DINO requires dot-separated, lowercase, ending with dot
BOX_THRESHOLD = 0.20   # passed to the detector as box_threshold (minimum box confidence)
TEXT_THRESHOLD = 0.20  # passed to the detector as text_threshold (minimum text-match confidence)

# NMS
IOU_THRESHOLD = 0.3  # Low because screws are distinct/non-overlapping

# Output paths
ORIGINAL_IMAGE_PATH = "captured_rgb.png"
VISUALIZATION_PATH = "screw_detections.png"

# Visualization (all colors are BGR tuples, as expected by OpenCV drawing calls)
BBOX_COLOR_BGR = (0, 255, 0)   # green
BBOX_THICKNESS = 2
LABEL_COLOR_BGR = (0, 0, 0)
LABEL_BG_COLOR_BGR = (0, 255, 0)
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE = 0.5
FONT_THICKNESS = 1

# Count overlay
COUNT_TEXT_POSITION = (10, 10)  # top-left margin (x, y)
COUNT_FONT_SCALE = 0.8
COUNT_FONT_THICKNESS = 2
COUNT_TEXT_COLOR_BGR = (255, 255, 255)  # white
COUNT_BG_COLOR_BGR = (0, 0, 0)  # black

DISPLAY_WINDOW_NAME = "Screw Detections"
# ============================================================================


def _capture_rgb_frame():
    """Capture a single RGB frame from the RealSense D405.

    Connects, warms up, captures one color frame, and ALWAYS disconnects
    (in ``finally``) before returning — the frame is already in memory.

    Returns:
        The RGB image array, or None on any capture failure.
    """
    camera = RealSense(name=CAMERA_NAME, serial_number=CAMERA_SERIAL)
    try:
        logger.info("Connecting to RealSense D405...")
        camera.connect(warmup_frames=WARMUP_FRAMES)

        logger.info("Capturing RGB frame...")
        rgb_image = camera.capture_single_color_frame()
        if rgb_image is None:
            logger.error("Failed to capture RGB frame from RealSense.")
            return None
        logger.info(f"Captured RGB frame with shape {rgb_image.shape}, dtype {rgb_image.dtype}.")
        return rgb_image
    except Exception as e:
        logger.exception(f"Error during camera capture: {e}")
        return None
    finally:
        # Single disconnect path. (The original code also disconnected inside
        # the except handler, making this a redundant second call on errors.)
        try:
            camera.disconnect()
            logger.info("Camera disconnected.")
        except Exception as e:
            logger.warning(f"Error during camera disconnect: {e}")


def _save_original_image(rgb_image) -> bool:
    """Save the captured frame to ORIGINAL_IMAGE_PATH (converted to BGR for cv2).

    Returns:
        True on success, False on any save failure.
    """
    try:
        image_dt = datatypes.Image(image=rgb_image, color_model="RGB")
        ok = io.save_image(image_dt, ORIGINAL_IMAGE_PATH, as_bgr_for_opencv=True)
        if not ok:
            logger.error(f"Failed to save original image to {ORIGINAL_IMAGE_PATH}.")
            return False
        logger.info(f"Saved original RGB image to {ORIGINAL_IMAGE_PATH}.")
        return True
    except Exception as e:
        logger.exception(f"Error saving original image: {e}")
        return False


def _detect_screws(rgb_image):
    """Run Grounding DINO on the frame with DETECTION_PROMPT.

    Open-vocab is used because Telekinesis Retina has no specific screw model
    and COCO detectors (YOLOX/RF-DETR) have no "screw" class.

    Returns:
        (annotation_list, category_list) as plain lists, or None on failure.
    """
    try:
        logger.info(f"Running Grounding DINO with prompt: {DETECTION_PROMPT!r}")
        annotations, categories = retina.detect_objects_using_grounding_dino(
            image=rgb_image,
            prompt=DETECTION_PROMPT,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
        )
        anno_list = annotations.to_list()
        cat_list = categories.to_list()
        logger.info(f"Detected {len(anno_list)} screw candidates.")
        return anno_list, cat_list
    except Exception as e:
        logger.exception(f"Detection failed: {e}")
        return None


def _apply_nms(anno_list):
    """Remove duplicate boxes with cv2.dnn.NMSBoxes and return the survivors.

    Boxes are COCO-format [x, y, w, h]. A low IOU_THRESHOLD is used because
    the screws are known to be distinct and non-overlapping.

    Bug fixed vs. the original: kept indices returned by NMSBoxes refer to the
    *filtered* candidate list (annotations missing bbox/score are skipped), so
    they must index that list — indexing the raw annotation list selects the
    wrong detections whenever a malformed annotation was skipped.
    """
    if not anno_list:
        return anno_list

    candidates = []  # annotations actually fed to NMS, parallel to boxes/scores
    boxes = []
    scores = []
    for ann in anno_list:
        bbox = ann.get("bbox", None)
        score = ann.get("score", None)
        if bbox is None or len(bbox) < 4 or score is None:
            continue
        boxes.append([float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])])
        try:
            scores.append(float(score))
        except Exception:
            scores.append(0.0)
        candidates.append(ann)

    if boxes:
        keep = cv2.dnn.NMSBoxes(boxes, scores, BOX_THRESHOLD, IOU_THRESHOLD)
        if len(keep) > 0:
            keep_indices = [int(i) for i in np.array(keep).flatten()]
            anno_list = [candidates[i] for i in keep_indices]
        else:
            anno_list = []
    logger.info(f"{len(anno_list)} detections remained after NMS (IOU_THRESHOLD={IOU_THRESHOLD}).")
    return anno_list


def _draw_detections(vis_bgr, anno_list, cat_name_by_id):
    """Draw one bounding box + label ("<name> #<i> <score>") per annotation, in place."""
    for i, ann in enumerate(anno_list):
        bbox = ann.get("bbox", None)  # COCO format: [x, y, w, h]
        if bbox is None or len(bbox) < 4:
            continue
        x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        x1, y1 = int(round(x)), int(round(y))
        x2, y2 = int(round(x + w)), int(round(y + h))

        cv2.rectangle(vis_bgr, (x1, y1), (x2, y2), BBOX_COLOR_BGR, BBOX_THICKNESS)

        cat_id = ann.get("category_id", None)
        score = ann.get("score", None)
        name = cat_name_by_id.get(cat_id, "screw")
        label = f"{name} #{i}"
        if score is not None:
            try:
                label += f" {float(score):.2f}"
            except Exception:
                pass

        # Filled label background just above the box (clamped to the top edge).
        (tw, th), baseline = cv2.getTextSize(label, FONT, FONT_SCALE, FONT_THICKNESS)
        ly1 = max(0, y1 - th - baseline - 2)
        ly2 = ly1 + th + baseline + 2
        lx2 = x1 + tw + 4
        cv2.rectangle(vis_bgr, (x1, ly1), (lx2, ly2), LABEL_BG_COLOR_BGR, -1)
        cv2.putText(
            vis_bgr,
            label,
            (x1 + 2, ly2 - baseline),
            FONT,
            FONT_SCALE,
            LABEL_COLOR_BGR,
            FONT_THICKNESS,
            cv2.LINE_AA,
        )


def _draw_count_overlay(vis_bgr, count):
    """Draw the total screw count on a filled background in the top-left corner, in place."""
    count_label = f"Screws detected: {count}"
    (ctw, cth), cbaseline = cv2.getTextSize(count_label, FONT, COUNT_FONT_SCALE, COUNT_FONT_THICKNESS)
    cx, cy = COUNT_TEXT_POSITION
    cbg_x2 = cx + ctw + 4
    cbg_y2 = cy + cth + cbaseline + 4
    cv2.rectangle(vis_bgr, (cx, cy), (cbg_x2, cbg_y2), COUNT_BG_COLOR_BGR, -1)
    cv2.putText(
        vis_bgr,
        count_label,
        (cx + 2, cbg_y2 - cbaseline - 2),
        FONT,
        COUNT_FONT_SCALE,
        COUNT_TEXT_COLOR_BGR,
        COUNT_FONT_THICKNESS,
        cv2.LINE_AA,
    )


def main() -> int:
    """Run the screw-counting pipeline end to end.

    Steps: capture RGB frame from the D405, save the original image, detect
    screws with Grounding DINO, deduplicate with NMS, draw boxes + count,
    save the visualization, and (best-effort) display it.

    Returns:
        0 on success, 1 on any unrecoverable error.
    """
    # 1. Connect & capture
    rgb_image = _capture_rgb_frame()
    if rgb_image is None:
        return 1

    # 2. Save original captured RGB image
    if not _save_original_image(rgb_image):
        return 1

    # 3. Detect screws using Grounding DINO (open-vocabulary)
    detection = _detect_screws(rgb_image)
    if detection is None:
        return 1
    anno_list, cat_list = detection

    # Apply Non-Maximum Suppression to remove overlapping bounding boxes.
    anno_list = _apply_nms(anno_list)

    # Build category id -> name map for labels
    cat_name_by_id = {c["id"]: c.get("name", "screw") for c in cat_list}

    # 4. Visualize detections on the original image (work in BGR for cv2 drawing)
    vis_bgr = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR).copy()
    _draw_detections(vis_bgr, anno_list, cat_name_by_id)
    _draw_count_overlay(vis_bgr, len(anno_list))

    # 5. Save visualization
    try:
        ok = cv2.imwrite(VISUALIZATION_PATH, vis_bgr)
        if not ok:
            logger.error(f"Failed to save visualization to {VISUALIZATION_PATH}.")
            return 1
        logger.info(f"Saved visualization to {VISUALIZATION_PATH}.")
    except Exception as e:
        logger.exception(f"Error saving visualization: {e}")
        return 1

    # 6. Display (best-effort: may fail in headless environments)
    try:
        cv2.imshow(DISPLAY_WINDOW_NAME, vis_bgr)
        logger.info("Displaying visualization. Press any key in the image window to close.")
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    except Exception as e:
        logger.warning(f"Could not display visualization (headless?): {e}")

    return 0


# Script entry point: exit with main()'s status code (0 = success, 1 = failure).
if __name__ == "__main__":
    sys.exit(main())