Quadratic Model Estimation With RANSAC


I'm trying to use RANSAC for model fitting, following this example: https://scikit-image.org/docs/dev/auto_examples/transform/plot_ransac.html#sphx-glr-auto-examples-transform-plot-ransac-py

According to https://scikit-image.org/docs/0.13.x/api/skimage.measure.html#skimage.measure.LineModelND, if I choose LineModelND as the model_class, it will fit my data with the standard line model y = ax + b. Instead, I want to fit my data with a quadratic function y = ax^2 + bx + c. Is there a way to do that with the scikit-image or OpenCV libraries?

user185160 On

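Here is a simple example that uses scikit-learn instead of scikit-image or OpenCV. The idea is to expand x into the polynomial features [1, x, x^2] and let RANSACRegressor fit a LinearRegression on them, so the robust linear fit in feature space is a quadratic in x: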

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import r2_score

# Get example sea-ice data
df = pd.read_csv(
    "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/seaice.csv"
)
df["index1"] = df.index

# Define X and y
start = 2950
X = (df[["index1"]].values)[start:start+100]
y = (df[["Extent"]].values)[start:start+100]

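# Define RANSAC regressor
ransac = RANSACRegressor(
    LinearRegression(),
    max_trials=100,  # maximum number of random subsets to try
    min_samples=50,  # number of points drawn for each random subset
    residual_threshold=0.15,  # max absolute residual (in units of y) for an inlier
    random_state=0,  # make the random sampling reproducible
)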

# Fit RANSAC model to data
quadratic = PolynomialFeatures(degree=2)
X_quad = quadratic.fit_transform(X)
ransac = ransac.fit(X_quad, y)

# Get fitted RANSAC curve
X_fit = np.arange(X.min(), X.max(), 1)[:, np.newaxis]
y_quad_fit = ransac.predict(quadratic.transform(X_fit))

# Get R2 value
quadratic_r2 = r2_score(y, ransac.predict(X_quad))

# Plot inliers
inlier_mask = ransac.inlier_mask_
plt.scatter(X[inlier_mask], y[inlier_mask], c="blue", marker="o", label="Inliers")

# Plot outliers
outlier_mask = np.logical_not(inlier_mask)
plt.scatter(
    X[outlier_mask], y[outlier_mask], c="lightgreen", marker="s", label="Outliers"
)

# Plot fitted RANSAC curve
plt.plot(
    X_fit,
    y_quad_fit,
    label="quadratic (d=2), $R^2=%.2f$" % quadratic_r2,
    color="red",
    lw=2,
    linestyle="-",
)

plt.xlabel("X")
plt.ylabel("Sea ice extent")
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()
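
If you want the coefficients a, b and c themselves, a variation on the same idea is to put PolynomialFeatures and LinearRegression into a pipeline and hand that pipeline to RANSACRegressor. The following is only a minimal sketch on synthetic data, not the sea-ice data above: the true curve y = 2x^2 - 3x + 1, the noise level, the number of injected outliers and the residual_threshold are all assumptions chosen for illustration, so tune them for your own data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)

# Synthetic data (assumed for illustration): y = 2x^2 - 3x + 1 plus small noise
x = np.linspace(-5, 5, 200)
y = 2.0 * x**2 - 3.0 * x + 1.0 + rng.normal(scale=0.1, size=x.size)

# Corrupt 40 of the 200 points with large positive offsets to act as outliers
outlier_idx = rng.choice(x.size, size=40, replace=False)
y[outlier_idx] += rng.uniform(20, 40, size=outlier_idx.size)

X = x[:, np.newaxis]  # scikit-learn expects a 2-D feature array

# Quadratic model: features [x, x^2]; the constant term is the intercept
quadratic_model = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    LinearRegression(),
)

# A quadratic has three unknowns, so three points per random subset suffice
ransac = RANSACRegressor(
    quadratic_model,
    min_samples=3,
    residual_threshold=0.5,
    max_trials=200,
    random_state=0,
)
ransac.fit(X, y)

# The final estimator is refitted on all inliers; read off a, b, c from it
linreg = ransac.estimator_.named_steps["linearregression"]
b, a = linreg.coef_
c = linreg.intercept_
print(f"a={a:.2f}, b={b:.2f}, c={c:.2f}")
print("inliers:", int(ransac.inlier_mask_.sum()), "of", x.size)
```

Wrapping the pipeline (rather than transforming X by hand) means ransac.predict accepts the raw x values directly, so there is no separate transform step to forget when plotting or predicting.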
