Object tracking using mediapipe for web

216 views Asked by At

I am using mediapipe javascript api and able to do object detection with good results. I could not find any api related to the ability to track objects uniquely... Am I'm missing something? Can you suggest how to do object tracking?

I want to count objects passing a line.

this is my current code:

import { FilesetResolver, ObjectDetector } from "@mediapipe/tasks-vision";
import {
  FiberManualRecordRounded,
  MonitorHeartRounded,
  SyncRounded,
  VideocamRounded,
} from "@mui/icons-material";
import {
  AppBar,
  BottomNavigation,
  BottomNavigationAction,
  Box,
  Paper,
  Toolbar,
  Typography,
} from "@mui/material";
import { useRouter } from "next/router";
import { useEffect, useRef, useState } from "react";
import BoundingBoxHighlighter from "./BoundingBoxHighlighter";
import MonitorDialog from "./MonitorDialog";

const ObjectDetection = () => {
  const [action, setAction] = useState(0);
  const [openDialog, setOpenDialog] = useState(false);
  const [pairingInfo, setPairingInfo] = useState(null);
  const router = useRouter();

  /* OBJECT DETECTION CODE START HERE*/
  const [aspectRatio, setAspectRatio] = useState(null);
  const [objectDetector, setObjectDetector] = useState(null);
  const [detections, setDetections] = useState([]);
  const [runningMode, setRunningMode] = useState("VIDEO");
  const videoRef = useRef(null);
  const videoDimensionsRef = useRef({});
  const [count, setCount] = useState(0);
  const prevDetectionsRef = useRef([]);

  const detectionBufferRef = useRef([]); // A buffer to store recent detections
  const countedIdsRef = useRef(new Set()); // To store identifiers of counted detections

  useEffect(() => {
    const currentDetections = detections.detections || [];

    // Sort detections by their centerY values
    const sortedDetections = currentDetections.sort(
      (a, b) => a.boundingBox.originY - b.boundingBox.originY
    );

    // Push the current sorted detections into the buffer
    detectionBufferRef.current.push(sortedDetections);

    // Keep the buffer size to a maximum of 10
    if (detectionBufferRef.current.length > 10) {
      detectionBufferRef.current.shift();
    }

    // Check each detection's centerY value in the buffer
    // If it consistently moves past the 50% line, count it
    detectionBufferRef.current.forEach((bufferedDetections) => {
      bufferedDetections.forEach((detection) => {
        const normalizedCenterY =
          (detection.boundingBox.originY + detection.boundingBox.height / 2) /
          videoDimensionsRef.current.actualHeight;

        // Generate a unique identifier for the detection
        const detectionId = `${detection.boundingBox.originX.toFixed(
          2
        )}-${detection.boundingBox.originY.toFixed(2)}`;

        // If the detection hasn't been counted yet and its center is past the counting line
        if (
          !countedIdsRef.current.has(detectionId) &&
          normalizedCenterY > 0.5
        ) {
          countedIdsRef.current.add(detectionId);
          setCount((prevCount) => prevCount + 1);
        }
      });
    });
  }, [detections]);

  // initialize object detector on load
  useEffect(() => {
    async function initializeObjectDetector() {
      const vision = await FilesetResolver.forVisionTasks(
        "https://cdn.jsdelivr.net/npm/@mediapipe/[email protected]/wasm"
      );
      const detector = await ObjectDetector.createFromOptions(vision, {
        baseOptions: {
          modelAssetPath: `https://storage.googleapis.com/mediapipe-models/object_detector/efficientdet_lite0/float16/1/efficientdet_lite0.tflite`,
          delegate: "GPU",
        },
        scoreThreshold: 0.5,
        runningMode: runningMode,
        categoryAllowlist: ["cup"],
      });
      setObjectDetector(detector);
    }
    initializeObjectDetector();
  }, []);

  useEffect(() => {
    if (videoRef.current) {
      if (hasGetUserMedia()) {
        videoRef.current.addEventListener("loadeddata", predictWebcam);
      }
    }

    return () => {
      if (videoRef.current) {
        videoRef.current.removeEventListener("loadeddata", predictWebcam);
      }
      if (videoRef.current && videoRef.current.srcObject) {
        const tracks = videoRef.current.srcObject.getTracks();
        tracks.forEach((track) => track.stop());
      }
    };
  }, [objectDetector]);

  function hasGetUserMedia() {
    return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
  }

  const [facingMode, setFacingMode] = useState(false); // true is user false is environment

  const enableCam = async () => {
    if (!objectDetector) {
      console.log("Wait! objectDetector not loaded yet.");
      return;
    }

    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: {
          facingMode: { exact: "environment" },
        },
        /* video: true, */
      });
      videoRef.current.srcObject = stream;
      // Get video's actual width and height after it's loaded.
      videoRef.current.onloadedmetadata = (e) => {
        const videoWidth = e.target.videoWidth;
        const videoHeight = e.target.videoHeight;

        // Calculate aspect ratio and set it
        setAspectRatio(videoHeight / videoWidth);

        // Store dimensions in useRef.
        videoDimensionsRef.current = {
          actualWidth: videoWidth,
          actualHeight: videoHeight,
        };
      };
    } catch (err) {
      console.error(err);
    }
  };

  // set detections from video
  const predictWebcam = async () => {
    if (runningMode === "IMAGE") {
      setRunningMode("VIDEO");
      await objectDetector.setOptions({
        runningMode: "VIDEO",
      });
    }
    let startTimeMs = performance.now();
    const newDetections = await objectDetector.detectForVideo(
      videoRef.current,
      startTimeMs
    );
    setDetections(newDetections);
    window.requestAnimationFrame(predictWebcam);
  };

  /* OBJECT DETECTION CODE END */

  /* INITIALIZATION CODE START*/
  useEffect(() => {
    sequence();
  }, []);

  // get pairing info from local storage
  const getPairingInfoLocal = async () => {
    if (typeof window !== "undefined" && window.localStorage) {
      const _pairingInfoLocal = localStorage.getItem("pairingInfo");
      const pairingInfoLocal = JSON.parse(_pairingInfoLocal);
      return pairingInfoLocal;
    }
  };

  const sequence = async () => {
    /* 1. LOOKING FOR MODEL, API_KEY AND DATA_SOURCE_ID in localStorage */
    const pairingInfoLocal = await getPairingInfoLocal();
    if (pairingInfoLocal) {
      setPairingInfo(pairingInfoLocal);
    } else {
      /* 2. SCAN QR CODE */
      router.push("/onboarding");
    }
  };
  /* INITIALIZATION CODE END */

  return (
    <>
      <AppBar position="fixed" color="default">
        <Toolbar>
          <Typography>{pairingInfo?.dataSourceName}</Typography>
        </Toolbar>
      </AppBar>
      <Box
        sx={{
          display: "flex",
          flexDirection: "column",
          justifyContent: "center",
          height: "calc(100vh - 64px - 56px)", // Subtract the AppBar and Bottom Navigation heights
          marginTop: "64px", // Offset by the AppBar's height
          position: "relative",
        }}
      >
        <Box
          sx={{
            position: "relative",
            width: "100%",
            flex: 1, // Make sure it takes the remaining space in the container
            paddingBottom: `${aspectRatio * 100}%`, // Dynamic based on aspect ratio
            overflow: "hidden",
          }}
        >
          <Box
            component="video"
            sx={{
              position: "absolute",
              top: 0,
              left: 0,
              width: "100%",
              height: "100%",
              objectFit: "cover",
            }}
            autoPlay
            playsInline
            ref={videoRef}
          ></Box>
          {/* Counting Line */}
          <Box
            sx={{
              position: "absolute",
              top: "50%",
              left: 0,
              width: "100%",
              height: "2px", // Thickness of the line
              backgroundColor: "yellow", // Color of the line
            }}
          ></Box>
          <Typography
            sx={{
              position: "absolute",
              bottom: "100px",
              backgroundColor: "orange",
            }}
          >
            Count: {count} {/* <-- Display the count here */}
          </Typography>
          {detections &&
            detections.detections &&
            detections.detections.map((detection, idx) => (
              <BoundingBoxHighlighter
                key={idx}
                detection={detection}
                videoRef={videoRef}
              />
            ))}
        </Box>
      </Box>

      <Paper
        sx={{ position: "fixed", bottom: 0, left: 0, right: 0 }}
        elevation={3}
      >
        <BottomNavigation
          showLabels
          value={action}
          onChange={(event, newValue) => {
            setAction(newValue);
          }}
        >
          {/* enable webcam  */}
          <BottomNavigationAction
            onClick={enableCam}
            label="Start Video"
            icon={<VideocamRounded />}
          />
          {/* TODO: call api to post the detections counter  */}
          <BottomNavigationAction
            label="Transmit"
            icon={<FiberManualRecordRounded />}
          />
          {/* open dialog to show detections on top of video  */}
          <BottomNavigationAction
            label="Monitor"
            onClick={() => {
              setOpenDialog((prev) => !prev);
            }}
            icon={<MonitorHeartRounded />}
          />
          {/* TODO: call api to sync pairing info  */}
          <BottomNavigationAction label="Sync" icon={<SyncRounded />} />
        </BottomNavigation>
      </Paper>
      {/* detections dialog */}
      <MonitorDialog
        openDialog={openDialog}
        setOpenDialog={setOpenDialog}
        pairingInfo={pairingInfo}
        detections={detections}
      />
    </>
  );
};
export default ObjectDetection;

Im trying to track and count objects but its over counting and consuming too much resources. so im thinking of integrating opencv.js... any suggestions?

0

There are 0 answers