Multiple samples for one target/pose in ml5.js to train a model

Need some help. I'm trying to detect and classify some poses. For better results, I want to train the model with multiple/many images (snapshots from the webcam/video) for a single pose.

You can check the code in jsFiddle (also included below): https://jsfiddle.net/zahedkamal87/1v8fcuyz/6/

I tried to add multiple samples for the same target/pose, but it seems to store only one (the last one). So how do I add 10-20 samples for a single target?

$(document).ready(function () {
    function get_new_height(width) {
        var screen_width = 1920;
        var screen_height = 1080;
        // (original height / original width) x new width = new height
        var new_height = Math.round((screen_height / screen_width) * width);

        return new_height;
    }

    let video;
    let canvas;
    let poseNet;
    let poseNetOptions;
    let poses = [];
    let pose;
    let skeleton;
    let targetPose;
    let state = "waiting";

    let brain;
    let poseLabel = "Y";

    let webcam_res_x = parseInt($("#tool-preview").width());
    let webcam_res_y = get_new_height(webcam_res_x);

    video = document.getElementById("video");
    canvas = document.getElementById("canvas");
    canvas.width = webcam_res_x;
    canvas.height = webcam_res_y;
    var ctx = canvas.getContext("2d");

    var constraints = {
        video: true,
        audio: false
    };
    var streaming = false;

    video.addEventListener(
        "canplay",
        function (ev) {
            if (!streaming) {
                video.setAttribute("width", webcam_res_x);
                video.setAttribute("height", webcam_res_y);
                streaming = true;
            }
        },
        false
    );

    navigator.mediaDevices
        .getUserMedia(constraints)
        .then(function (stream) {
            video.srcObject = stream;
            video.play();
        })
        .catch(function (err) {
            console.log("An error occurred: ", err);
        });

    poseNetOptions = {
        // flipHorizontal: true
    };

    poseNet = ml5.poseNet(video, poseNetOptions, function () {
        console.log("poseNet ready");
    });

    poseNet.on("pose", function (results) {
        poses = results;
        // console.log(results);
        if (poses.length > 0) {
            pose = poses[0].pose;
            skeleton = poses[0].skeleton;
        }
    });

    let options = {
        inputs: 34,
        outputs: ["poses"],
        task: "classification",
        debug: true
    };

    brain = ml5.neuralNetwork(options);

    function getInputs() {
        // 17 PoseNet keypoints × (x, y) = 34 values, matching options.inputs
        let keypoints = poses[0].pose.keypoints;
        let inputs = [];
        for (let i = 0; i < keypoints.length; i++) {
            inputs.push(keypoints[i].position.x);
            inputs.push(keypoints[i].position.y);
        }
        return inputs;
    }

    function trainModel() {
        brain.normalizeData();
        let options = {
            epochs: 50
        };
        brain.train(options, finishedTraining);
    }

    // Begin prediction
    function finishedTraining() {
        classify();
    }

    // Classify
    function classify() {
        if (poses.length > 0) {
            let inputs = getInputs();
            brain.classify(inputs, gotResults);
        }
    }

    function gotResults(error, results) {
        if (error) {
            console.error(error);
        }
        if (results && results.length > 0) {
            $("#classified").html(
                results[0].label + " " + Math.floor(results[0].confidence * 100) + "%"
            );
        }

        // Keep classifying on each new result
        classify();
    }

    function drawCameraIntoCanvas() {
        ctx.drawImage(video, 0, 0, webcam_res_x, webcam_res_y);
        drawKeypoints();
        drawSkeleton();

        window.requestAnimationFrame(drawCameraIntoCanvas);
    }

    drawCameraIntoCanvas();

    function drawKeypoints() {
        for (let i = 0; i < poses.length; i += 1) {
            for (let j = 0; j < poses[i].pose.keypoints.length; j += 1) {
                let keypoint = poses[i].pose.keypoints[j];
                // Only draw keypoints the model is reasonably confident about
                if (keypoint.score > 0.2) {
                    ctx.beginPath();
                    ctx.arc(keypoint.position.x, keypoint.position.y, 10, 0, 2 * Math.PI);
                    // Set the style before stroking, otherwise the first
                    // circle is drawn with the default style
                    ctx.strokeStyle = "red";
                    ctx.lineWidth = 3;
                    ctx.stroke();
                }
            }
        }
    }

    function drawSkeleton() {
        for (let i = 0; i < poses.length; i += 1) {
            for (let j = 0; j < poses[i].skeleton.length; j += 1) {
                let partA = poses[i].skeleton[j][0];
                let partB = poses[i].skeleton[j][1];
                ctx.beginPath();
                ctx.moveTo(partA.position.x, partA.position.y);
                ctx.lineTo(partB.position.x, partB.position.y);
                // Style before stroke, as above
                ctx.strokeStyle = "red";
                ctx.lineWidth = 3;
                ctx.stroke();
            }
        }
    }

    $(document).on("click", ".take-snaps", function (e) {
        e.preventDefault();
        if (poses.length > 0) {
            let target = $(this)
                .closest(".tool-controls-group")
                .find(".input-pose-name")
                .val();
            let inputs = getInputs();
            brain.addData(inputs, [target]);

            var image = canvas.toDataURL("image/png");
            var $image = $("<img/>", {
                class: "snaped-image border p-1 my-1 mx-1",
                style: "max-width: 150px; transform: scaleX(-1);",
                src: image
            });

            $(this).closest(".tool-controls-group").find(".pose-images").append($image);
        }
    });

    $(document).on("click", "#train", function (e) {
        e.preventDefault();
        trainModel();
    });

    $(document).on("click", ".snaped-image", function (e) {
        e.preventDefault();

        var src = $(this).attr("src");

        $("#image-preview").find("img").attr("src", src);
        $("#image-preview").modal("show");
    });
});
.tool {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -ms-flex-wrap: wrap;
    flex-wrap: wrap;
    max-width: 1920px;
    margin: auto;
    padding-top: 3rem;
}

.tool-preview,
.tool-controls {
    position: relative;
}

.tool-preview {
    width: 70%;
}

.tool-preview video,
.tool-preview canvas {
    -webkit-transform: scaleX(-1);
    transform: scaleX(-1);
}

.tool-controls {
    width: 30%;
    padding-left: 2rem;
}
<link href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.1.3/css/bootstrap.min.css" rel="stylesheet"/>
<script src="https://unpkg.com/[email protected]/dist/ml5.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.1.3/js/bootstrap.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>

<div class="px-3 py-4 text-center bg-dark text-white">
    <h2 class="mb-0 fw-bold">Save Pose</h2>
</div>

<div class="container-fluid">
    <div class="tool">
        <div class="tool-preview" id="tool-preview">
            <canvas id="canvas"></canvas>
            <video id="video" autoplay style="display: none"></video>

            <div class="mt-4 alert alert-info">
                <p class="mb-0" id="classified">No information</p>
            </div>
        </div>

        <div class="tool-controls">
            <form method="post">
                <div class="tool-controls-group mb-4">
                    <div class="mb-3">
                        <label class="form-label">Pose Name</label>
                        <input type="text" class="form-control input-pose-name" />
                    </div>

                    <div class="pose-images text-center"></div>

                    <div class="mt-3 text-end">
                        <button type="button" class="btn btn-primary take-snaps">Add a Snap</button>
                    </div>
                </div>

                <div class="pt-3 border-top text-end">
                    <button type="button" class="btn btn-success" id="train">Train</button>
                </div>
            </form>
        </div>
    </div>
</div>

1 Answer

Answer by Kokodoko:

Your question is a bit unclear. Your code seems to be working correctly.

Pose detection only returns several poses when more than one person is detected; otherwise each frame gives you a single pose, which is in poses[0].
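
For example, here is a minimal sketch (the loop and logging are illustrative, not part of your original code) of what the callback receives when several people are in frame:

poseNet.on("pose", function (results) {
    // One entry per detected person; with a single person in frame,
    // results[0] is the only entry you will ever see.
    for (let i = 0; i < results.length; i++) {
        console.log("person " + i + ", confidence: " + results[i].pose.score);
    }
});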

The way to train a model is to call addData(inputs, [label]) as many times as possible for each pose that you want to learn. Your model will get better the more examples you add for each pose. In fact, your code already does this: every click on "Add a Snap" calls brain.addData(...), so each click stores one more sample instead of replacing the previous one.
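
If clicking "Add a Snap" 10-20 times per pose feels tedious, you can capture a short burst instead. Here is a minimal sketch, assuming it lives in the same document-ready scope as your brain, poses, and getInputs(); collectSamples and its parameters are hypothetical names, not part of the ml5 API:

// Hypothetical helper: store `count` samples for one label,
// one sample every `intervalMs` milliseconds, while the pose is held.
function collectSamples(label, count, intervalMs) {
    let collected = 0;
    let timer = setInterval(function () {
        if (poses.length > 0) {
            // Every addData() call appends one example for this label;
            // it does not overwrite previously added samples.
            brain.addData(getInputs(), [label]);
            collected += 1;
        }
        if (collected >= count) {
            clearInterval(timer);
            console.log(collected + " samples stored for '" + label + "'");
        }
    }, intervalMs);
}

// Usage: hold the pose, then grab 15 snapshots, one every 200 ms.
// collectSamples("Y", 15, 200);

Once enough samples are stored for every label, call trainModel() exactly as in your code.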