Weighted curve fitting

Tags: curve-fitting, development

#1

I usually use the CurveFitter for my regressions, but I now need weighted regressions for robust estimation, for example IRLS (iteratively reweighted least squares). I have searched the API and weighted fitting does not seem to be implemented. Is there a workaround?


#2

I would use R for this.

There are some packages offering IRLS:

https://cran.r-project.org/web/packages/robustreg/index.html

https://www.rdocumentation.org/packages/msme/versions/0.5.1/topics/irls

Export your values as *.csv and then import them in R, see:

http://www.r-tutor.com/r-introduction/data-frame/data-import


#3

The problem is that I do these fits inside a Fiji plugin, so I need the functionality inside Fiji itself. Is there any way to call R from a Fiji plugin, or to use another Fiji package for this? I am surprised that not even weighted regression is implemented in Fiji.


#4

There is an R scripting interface available through Renjin. However, it is hard to tell whether these packages are fully functional in Renjin:

http://packages.renjin.org/

I would ask here:

https://groups.google.com/forum/#!forum/renjin-dev


#5

I managed to use Apache Commons Math 3 (math3).

Here are my abstract classes:

import ij.IJ;
import ij.gui.Plot;
import ij.gui.PlotWindow;
import ij.util.Tools;
import java.awt.Color;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.AbstractCurveFitter;

/**
 *
 * @author alex.vergara
 */
public class RobustFitter {

    public static abstract class MyParametricUnivariateFunction implements ParametricUnivariateFunction {

        String Equation;

        public String getEquation() {
            return Equation;
        }
    }

    public static abstract class MyAbstractCurveFitter extends AbstractCurveFitter {

        double[] params;
        double[] xData, yData, weights;
        double sumY = Double.NaN, sumY2, sumRes2; // NaN marks the sums as not yet computed (see getRSquared)

        public abstract double[] fit(double[] initialGuess, double[] xpoints, double[] ypoints, double[] weights);

        public abstract double f(double x);

        public abstract String getFormula();
        
        public abstract String getName();

        public double[] getParams() {
            return params;
        }

        public int getNumParams() {
            return params.length;
        }

        public double[] getXPoints() {
            return xData;
        }

        public double[] getYPoints() {
            return yData;
        }

        public double[] getResiduals() {
            double[] residuals = new double[xData.length];
            for (int i = 0; i < xData.length; i++) {
                residuals[i] = yData[i] - f(xData[i]);
            }
            return residuals;
        }

        /**
         * calculates the sum of y and y^2
         */
        private void calculateSumYandY2() {
            sumY = 0.0;
            sumY2 = 0.0;
            for (int i = 0; i < yData.length; i++) {
                double y = yData[i];
                sumY += y;
                sumY2 += y * y;
            }
        }
        
        private void getSumResidualsSqr() {
            sumRes2 = 0.0;
            double[] residuals = getResiduals();
            for (int i = 0; i < residuals.length; i++) {
                sumRes2 += residuals[i] * residuals[i]; // accumulate the squares instead of overwriting
            }
        }

        public double getRSquared() {
            if (Double.isNaN(sumY)) {
                calculateSumYandY2();
                getSumResidualsSqr();
            }
            double sumMeanDiffSqr = sumY2 - sumY * sumY / yData.length;
            double rSquared = 0.0;
            if (sumMeanDiffSqr > 0.0) {
                rSquared = 1.0 - sumRes2 / sumMeanDiffSqr;
            }
            return rSquared;
        }

        /**
         *
         * @param Title The title of the plot
         * @param xLabel label of x axis
         * @param yLabel label of y axis
         * @param eightBitCalibrationPlot resample image to 8 bit
         */
        public void plot(String Title, String xLabel, String yLabel, boolean eightBitCalibrationPlot) {
            String title = "".equals(Title) ? getFormula() : Title;
            if (getParams().length < getNumParams()) {
                Plot plot = new Plot(title, xLabel, yLabel, xData, yData);
                plot.setColor(Color.RED);
                plot.addLabel(0.02, 0.1, getName());
                plot.show();
                return;
            }
            int npoints = Math.min(Math.max(xData.length, 100), 1000);
            double[] a = Tools.getMinMax(xData);
            double xmin = a[0], xmax = a[1];
            if (eightBitCalibrationPlot) {
                npoints = 256;
                xmin = 0;
                xmax = 255;
            }
            a = Tools.getMinMax(yData);
            double ymin = a[0], ymax = a[1]; //y range of data points
            double[] px = new double[npoints];
            double[] py = new double[npoints];
            double inc = (xmax - xmin) / (npoints - 1);
            double tmp = xmin;
            for (int i = 0; i < npoints; i++) {
                px[i] = tmp;
                tmp += inc;
            }
            for (int i = 0; i < npoints; i++) {
                py[i] = f(px[i]);
            }
            a = Tools.getMinMax(py);
            double dataRange = ymax - ymin;
            ymin = Math.max(ymin - dataRange, Math.min(ymin, a[0])); //expand y range for curve, but not too much
            ymax = Math.min(ymax + dataRange, Math.max(ymax, a[1]));
            Plot plot = new Plot(title, xLabel, yLabel, px, py);
            plot.setLimits(xmin, xmax, ymin, ymax);
            plot.setColor(Color.RED);
            plot.addPoints(xData, yData, PlotWindow.CIRCLE);
            plot.setColor(Color.BLUE);

            StringBuilder legend = new StringBuilder(100);
            legend.append(getName()).append('\n');
            legend.append(getFormula()).append('\n');
            double[] p = getParams();
            int n = getNumParams();
            char pChar = 'a';
            for (int i = 0; i < n; i++) {
                legend.append(pChar).append(" = ").append(IJ.d2s(p[i], 5, 9)).append('\n');
                pChar++;
            }
            legend.append("R^2 = ").append(IJ.d2s(getRSquared(), 4)).append('\n');
            plot.addLabel(0.02, 0.1, legend.toString());
            plot.setColor(Color.BLUE);
            plot.show();
        }
    }

}

And the fitting class is:

import java.util.ArrayList;
import java.util.Collection;
import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.iaea.utils.MathUtils;

/**
 *
 * @author alex.vergara
 */
public class SineFit {
    
    static class FSine extends RobustFitter.MyParametricUnivariateFunction {

        @Override
        public double value(double t, double... parameters) {
            //"y = a + b * sin(c * x + d)"
            return parameters[0] + parameters[1] * Math.sin(parameters[2] * t + parameters[3]);
        }

        // Jacobian matrix of the above. In this case, this is just an array of
        // partial derivatives of the above function, with one element for each parameter.
        @Override
        public double[] gradient(double t, double... parameters) {
            final double a = parameters[0];
            final double b = parameters[1];
            final double c = parameters[2];
            final double d = parameters[3];

            return new double[]{
                1,
                Math.sin(c * t + d),
                b * t * Math.cos(c * t + d),
                b * Math.cos(c * t + d)
            };
        }
    }
    
    public static class SineFitter extends RobustFitter.MyAbstractCurveFitter {

        @Override
        public double[] fit(double[] initialGuess, double[] xpoints, double[] ypoints, double[] weights) {
            this.params = initialGuess;
            this.xData = xpoints;
            this.yData = ypoints;
            this.weights = weights;
            ArrayList<WeightedObservedPoint> points = new ArrayList<>();
            for (int i = 0; i < xpoints.length; i++) {
                WeightedObservedPoint point = new WeightedObservedPoint(weights[i], xpoints[i], ypoints[i]);
                points.add(point);
            }
            this.params = fit(points);
            return getParams();
        }

        public double[] fit(double[] xpoints, double[] ypoints, double[] weights) {
            // Using default initialization
            this.params = new double[4];
            this.params[0] = ypoints[0];
            this.params[1] = 0.5 * (MathUtils.Max(ypoints) - MathUtils.Min(ypoints));
            this.params[2] = 2 * Math.PI / ypoints.length;
            this.params[3] = 0.0;
            return fit(params, xpoints, ypoints, weights);
        }

        @Override
        protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points) {

            final AbstractCurveFitter.TheoreticalValuesFunction model = new AbstractCurveFitter.TheoreticalValuesFunction(new FSine(), points);

            return new LeastSquaresBuilder().
                    maxEvaluations(Integer.MAX_VALUE).
                    maxIterations(Integer.MAX_VALUE).
                    start(params).
                    target(yData).
                    weight(new DiagonalMatrix(weights)).
                    model(model.getModelFunction(), model.getModelFunctionJacobian()).
                    build();
        }

        @Override
        public String getFormula() {
            return (new FSine()).getEquation();
        }

        @Override
        public double f(double x) {
            return (new FSine()).value(x, params);
        }

        @Override
        public String getName() {
            return "Sine Fit";
        }

    }
    
}

However, all I get is:

org.apache.commons.math3.exception.ConvergenceException: illegal state: unable to perform Q.R decomposition on the 32x4 jacobian matrix
	at org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer.qrDecomposition(LevenbergMarquardtOptimizer.java:974)
	at org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer.optimize(LevenbergMarquardtOptimizer.java:341)
	at org.apache.commons.math3.fitting.AbstractCurveFitter.fit(AbstractCurveFitter.java:63)
	at org.iaea.utils.fitting.SineFit$SineFitter.fit(SineFit.java:72)

Any help is appreciated. What is wrong here?


#6

Still no luck with this. Does anyone have any hints?


#7

That exception suggests that there is a numerical issue - I’d have a look at the input data.

If your data are all zeros, for instance, that would be bad.

Less obvious cases of data being the problem might be if all your points were scaled versions of each other, for example.

Just an idea,
John


#8

I have checked the input data. The initial unweighted least-squares fit succeeds, but when I try the weighted version the ConvergenceException appears again. Here is the code where I use the classes above:

                double[] Rx = Fitter.SineFit(it[n], cmx[n], true).getResiduals(); //Performed nicely
                double[] xweights = MathUtils.Normalize(Rx);
                for (int i = 0; i < Rx.length; i++) {
                    xweights[i] = xweights[i] == 0 ? 1 : 1 / xweights[i]; //Weights are never 0
                }
                SineFitter rsf = new SineFitter();
                double params[] = rsf.fit(it[n], cmx[n], xweights); // the exception is thrown on this line
                Rx = rsf.getResiduals();
                rsf.plot("Sine Fit", "Angle (rad)", "Position (px)", false);

#9

I have managed to make it work! The problem was the definition of the weights.

   double[] Rx = Fitter.SineFit(it[n], cmx[n], false).getResiduals();
   double[] xweights = MathUtils.Normalize(Rx);
   for (int i = 0; i < Rx.length; i++) {
        xweights[i] = 1 - xweights[i];
   }
   SineFitter rsf = SineFitter.create(it[n], cmx[n], xweights);
   double params[] = rsf.fit();
   Rx = rsf.getResiduals();

The weights need to be normalized, and their norm must be > 0, for this to work.
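
A possible explanation for the earlier failure, offered as an assumption because MathUtils.Normalize is not shown in this thread: the residuals are signed, so 1 / (normalized residual) can come out negative or very large. As far as I understand it, Commons Math takes the square root of each diagonal weight, so a negative weight would put NaN into the weighted Jacobian and the Q.R step would then fail. The sketch below is a hypothetical helper (the class WeightCheck and its methods are not part of Commons Math or of the code above) that builds weights from absolute residuals and checks them before they are handed to a fitter.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of Commons Math or of the classes in this thread. */
final class WeightCheck {

    /** True if every weight is finite and >= 0, and at least one weight is > 0. */
    static boolean isUsable(double[] weights) {
        boolean anyPositive = false;
        for (double w : weights) {
            if (Double.isNaN(w) || Double.isInfinite(w) || w < 0.0) {
                return false;
            }
            if (w > 0.0) {
                anyPositive = true;
            }
        }
        return anyPositive;
    }

    /**
     * Maps signed residuals to weights in [0, 1] as 1 - |r| / max|r|.
     * This is one guess at what "normalize, then 1 - x" means; the real
     * MathUtils.Normalize may behave differently.
     */
    static double[] fromResiduals(double[] residuals) {
        double maxAbs = 0.0;
        for (double r : residuals) {
            maxAbs = Math.max(maxAbs, Math.abs(r));
        }
        double[] w = new double[residuals.length];
        for (int i = 0; i < w.length; i++) {
            // A perfect fit (all residuals zero) falls back to uniform weights.
            w[i] = maxAbs == 0.0 ? 1.0 : 1.0 - Math.abs(residuals[i]) / maxAbs;
        }
        return w;
    }

    public static void main(String[] args) {
        double[] residuals = {0.5, -2.0, 1.0};
        double[] w = fromResiduals(residuals);
        System.out.println(Arrays.toString(w) + " usable=" + isUsable(w));
    }
}
```

Calling isUsable on the weights just before the fit gives a clearer error than the ConvergenceException. Note that the point with the largest residual always gets weight 0 under this scheme.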


#10

For the record, this is the final abstract class:

import ij.IJ;
import ij.gui.Plot;
import ij.gui.PlotWindow;
import ij.util.Tools;
import java.awt.Color;
import java.util.Collection;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.iaea.utils.MathUtils;

/**
 *
 * @author alex.vergara
 */
public class RobustFitter {

    public static abstract class MyParametricUnivariateFunction implements ParametricUnivariateFunction {

        public abstract String getEquation();
    }

    public static abstract class MyAbstractCurveFitter extends AbstractCurveFitter {

        protected MyParametricUnivariateFunction function;
        protected double[] params;
        protected double[] xData, yData, weights;

        public void fit(double[] initialGuess) {
            this.params = initialGuess.clone();
            WeightedObservedPoints points = new WeightedObservedPoints();
            for (int i = 0; i < xData.length; i++) {
                points.add(new WeightedObservedPoint(weights[i], xData[i], yData[i]));
            }
            this.params = fit(points.toList());
        }

        public abstract void fit();

        public void setWeights(double[] lweights) {
            this.weights = lweights.clone();
        }

        public double f(double x) {
            return function.value(x, params);
        }

        public String getFormula() {
            return function.getEquation();
        }

        public abstract String getName();

        public double[] getParams() {
            return params;
        }

        public int getNumParams() {
            return params.length;
        }

        public double[] getXPoints() {
            return xData;
        }

        public double[] getYPoints() {
            return yData;
        }

        public double[] getResiduals() {
            double[] residuals = new double[xData.length];
            for (int i = 0; i < xData.length; i++) {
                residuals[i] = yData[i] - f(xData[i]);
            }
            return residuals;
        }

        private double getSumResidualsSqr() {
            double[] residuals = getResiduals();
            return MathUtils.Variance(residuals);
        }

        public double getRSquared() {
            double sumMeanDiffSqr = MathUtils.Variance(yData);
            double rSquared = 0.0;
            if (sumMeanDiffSqr > 0.0) {
                rSquared = 1.0 - getSumResidualsSqr() / sumMeanDiffSqr;
            }
            return rSquared;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
            // Prepare least-squares problem.
            final int len = observations.size();
            final double[] target = new double[len];
            final double[] lweights = new double[len];

            int count = 0;
            for (WeightedObservedPoint obs : observations) {
                target[count] = obs.getY();
                lweights[count] = obs.getWeight();
                ++count;
            }

            final AbstractCurveFitter.TheoreticalValuesFunction model
                    = new AbstractCurveFitter.TheoreticalValuesFunction(function,
                            observations);

            // Create an optimizer for fitting the curve to the observed points.
            return new LeastSquaresBuilder().
                    maxEvaluations(Integer.MAX_VALUE).
                    maxIterations(Integer.MAX_VALUE).
                    start(params).
                    target(target).
                    weight(new DiagonalMatrix(lweights)).
                    model(model.getModelFunction(), model.getModelFunctionJacobian()).
                    build();
        }

        /**
         *
         * @param Title The title of the plot
         * @param xLabel label of x axis
         * @param yLabel label of y axis
         * @param eightBitCalibrationPlot resample image to 8 bit
         */
        public void plot(String Title, String xLabel, String yLabel, boolean eightBitCalibrationPlot) {
            String title = "".equals(Title) ? getFormula() : Title;
            if (getParams().length < getNumParams()) {
                Plot plot = new Plot(title, xLabel, yLabel, xData, yData);
                plot.setColor(Color.RED);
                plot.addLabel(0.02, 0.1, getName());
                plot.show();
                return;
            }
            int npoints = Math.min(Math.max(xData.length, 100), 1000);
            double[] a = Tools.getMinMax(xData);
            double xmin = a[0], xmax = a[1];
            if (eightBitCalibrationPlot) {
                npoints = 256;
                xmin = 0;
                xmax = 255;
            }
            a = Tools.getMinMax(yData);
            double ymin = a[0], ymax = a[1]; //y range of data points
            double[] px = new double[npoints];
            double[] py = new double[npoints];
            double inc = (xmax - xmin) / (npoints - 1);
            double tmp = xmin;
            for (int i = 0; i < npoints; i++) {
                px[i] = tmp;
                tmp += inc;
            }
            for (int i = 0; i < npoints; i++) {
                py[i] = f(px[i]);
            }
            a = Tools.getMinMax(py);
            double dataRange = ymax - ymin;
            ymin = Math.max(ymin - dataRange, Math.min(ymin, a[0])); //expand y range for curve, but not too much
            ymax = Math.min(ymax + dataRange, Math.max(ymax, a[1]));
            Plot plot = new Plot(title, xLabel, yLabel, px, py);
            plot.setLimits(xmin, xmax, ymin, ymax);
            plot.setColor(Color.RED);
            plot.addPoints(xData, yData, PlotWindow.CIRCLE);
            plot.setColor(Color.BLUE);

            StringBuilder legend = new StringBuilder(100);
            legend.append(getName()).append('\n');
            legend.append(getFormula()).append('\n');
            double[] p = getParams();
            int n = getNumParams();
            char pChar = 'a';
            for (int i = 0; i < n; i++) {
                legend.append(pChar).append(" = ").append(IJ.d2s(p[i], 5, 9)).append('\n');
                pChar++;
            }
            legend.append("R^2 = ").append(IJ.d2s(getRSquared(), 4)).append('\n');
            plot.addLabel(0.02, 0.1, legend.toString());
            plot.setColor(Color.BLUE);
            plot.show();
        }

        public void initializeIRLS() {
            fit();
            double[] Rx = getResiduals();
            for (int i = 0; i < Rx.length; i++) {
                Rx[i] = Math.abs(Rx[i]);
            }
            double[] xweights = MathUtils.Normalize(Rx);
            for (int i = 0; i < Rx.length; i++) {
                xweights[i] = 1 - xweights[i];
            }
            this.weights = xweights.clone();
        }

        public void runIRLS(int iterations) {
            int iter = 0;
            while (iter < iterations) {
                fit(params);
                if (getRSquared() > 0.9) {
                    break;
                }
                double[] Rx = getResiduals();
                for (int i = 0; i < Rx.length; i++) {
                    Rx[i] = Math.abs(Rx[i]);
                }
                double[] xweights = MathUtils.Normalize(Rx);
                for (int i = 0; i < Rx.length; i++) {
                    xweights[i] = 1 - xweights[i];
                }
                setWeights(xweights);
                ++iter;
            }
        }
    }

}

And an example implementation class:

import java.util.Objects;
import org.apache.commons.math3.exception.NoDataException;
import org.iaea.utils.MathUtils;

/**
 *
 * @author alex.vergara
 */
public class SineFit {

    static class FSine extends RobustFitter.MyParametricUnivariateFunction {

        @Override
        public double value(double t, double... parameters) {
            return parameters[0] + parameters[1] * Math.sin(parameters[2] * t + parameters[3]);
        }

        // Jacobian matrix of the above. In this case, this is just an array of
        // partial derivatives of the above function, with one element for each parameter.
        @Override
        public double[] gradient(double t, double... parameters) throws NoDataException {
            final double a = parameters[0];
            final double b = parameters[1];
            final double c = parameters[2];
            final double d = parameters[3];

            final double a1 = 1;
            final double b1 = Math.sin(c * t + d);
            final double c1 = b * t * Math.cos(c * t + d);
            final double d1 = b * Math.cos(c * t + d);

            return new double[]{a1, b1, c1, d1};
        }

        @Override
        public String getEquation() {
            return "y = a + b * sin(c * x + d)";
        }
    }

    public static class SineFitter extends RobustFitter.MyAbstractCurveFitter {

        private SineFitter(double[] xpoints, double[] ypoints, double[] weights) {
            this.xData = xpoints.clone();
            this.yData = ypoints.clone();
            if (Objects.isNull(weights)) {
                this.weights = new double[yData.length];
                for (int i = 0; i < yData.length; i++) {
                    this.weights[i] = 1;
                }
            } else {
                this.weights = weights.clone();
            }
            this.function = new FSine();
        }

        public static SineFitter create(double[] xpoints, double[] ypoints, double[] weights) {
            return new SineFitter(xpoints, ypoints, weights);
        }

        @Override
        public void fit() {
            // Using default initialization
            double[] initialGuess = new double[]{
                yData[0],
                0.5 * (MathUtils.Max(yData) - MathUtils.Min(yData)),
                2 * Math.PI / yData.length,
                0.0
            };
            fit(initialGuess);
        }

        @Override
        public String getName() {
            return "Sine Fit";
        }

    }

}
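
Since a wrong Jacobian is another common source of convergence trouble, it can help to compare the analytic gradient with a finite-difference estimate. The sketch below is standalone and only an illustration: SineGradientCheck is a hypothetical class that copies the formulas from FSine instead of depending on Commons Math, and the test point and step size are arbitrary choices.

```java
/** Hypothetical standalone check; copies the sine model from FSine above. */
final class SineGradientCheck {

    /** y = a + b * sin(c * t + d), with p = {a, b, c, d}. */
    static double value(double t, double[] p) {
        return p[0] + p[1] * Math.sin(p[2] * t + p[3]);
    }

    /** Analytic partial derivatives with respect to a, b, c and d. */
    static double[] gradient(double t, double[] p) {
        double phase = p[2] * t + p[3];
        return new double[]{
            1.0,
            Math.sin(phase),
            p[1] * t * Math.cos(phase),
            p[1] * Math.cos(phase)
        };
    }

    /** Largest absolute gap between the analytic and central-difference gradients. */
    static double maxGradientError(double t, double[] p, double h) {
        double[] analytic = gradient(t, p);
        double worst = 0.0;
        for (int k = 0; k < p.length; k++) {
            double[] plus = p.clone();
            double[] minus = p.clone();
            plus[k] += h;
            minus[k] -= h;
            double numeric = (value(t, plus) - value(t, minus)) / (2.0 * h);
            worst = Math.max(worst, Math.abs(numeric - analytic[k]));
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] p = {1.0, 2.0, 0.5, 0.3};
        System.out.println("max gradient error: " + maxGradientError(1.7, p, 1e-6));
    }
}
```

An error many orders of magnitude smaller than the gradient entries themselves suggests the derivatives are consistent with the model. A large error points at a sign or index mistake.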

#11

I have added an IRLS technique to the code above. To use it, just do:

SineFitter rsf = SineFitter.create(xi, yi, null);
rsf.initializeIRLS();
rsf.runIRLS(100);
double[] Rx = rsf.getResiduals();
double[] params = rsf.getParams();
rsf.plot("Robust Sine Fit", "Angle (rad)", "Position (px) (mm)", false);

I have also gotten rid of the ImageJ CurveFitter, so this code now depends only on Apache Commons Math.

If there is a core developer who wishes to include this code in Fiji core, I can provide more fitting functions. You would just need to adapt it to Fiji standards.
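
For anyone who wants to see the reweighting loop in isolation, here is a minimal sketch that runs without any library. It is not the code above: it fits a straight line instead of a sine, and it swaps in the textbook Huber weights with a MAD-based scale (the usual constants 1.345 and 0.6745) in place of my "1 - normalized residual" weights. The class name IrlsLineSketch and all method names are made up for this illustration.

```java
import java.util.Arrays;

/** Illustrative IRLS for a straight line y = a + b * x (hypothetical helper). */
final class IrlsLineSketch {

    /** Closed-form weighted least squares for y = a + b * x; returns {a, b}. */
    static double[] weightedLineFit(double[] x, double[] y, double[] w) {
        double sw = 0, swx = 0, swy = 0, swxx = 0, swxy = 0;
        for (int i = 0; i < x.length; i++) {
            sw += w[i];
            swx += w[i] * x[i];
            swy += w[i] * y[i];
            swxx += w[i] * x[i] * x[i];
            swxy += w[i] * x[i] * y[i];
        }
        double det = sw * swxx - swx * swx;
        double b = (sw * swxy - swx * swy) / det;
        double a = (swy - b * swx) / sw;
        return new double[]{a, b};
    }

    /** Median of the absolute residuals, used as a robust scale estimate. */
    static double medianAbs(double[] r) {
        double[] abs = new double[r.length];
        for (int i = 0; i < r.length; i++) {
            abs[i] = Math.abs(r[i]);
        }
        Arrays.sort(abs);
        int n = abs.length;
        return n % 2 == 1 ? abs[n / 2] : 0.5 * (abs[n / 2 - 1] + abs[n / 2]);
    }

    /** IRLS with Huber weights: w = 1 if |r| <= cutoff, otherwise cutoff / |r|. */
    static double[] fit(double[] x, double[] y, int maxIterations) {
        double[] w = new double[x.length];
        Arrays.fill(w, 1.0);
        double[] p = weightedLineFit(x, y, w); // start from ordinary least squares
        for (int iter = 0; iter < maxIterations; iter++) {
            double[] r = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                r[i] = y[i] - (p[0] + p[1] * x[i]);
            }
            double scale = medianAbs(r) / 0.6745; // MAD-based sigma estimate
            if (scale == 0.0) {
                break; // at least half the points are fitted exactly
            }
            double cutoff = 1.345 * scale;
            for (int i = 0; i < x.length; i++) {
                double ar = Math.abs(r[i]);
                w[i] = ar <= cutoff ? 1.0 : cutoff / ar; // always in (0, 1]
            }
            double[] next = weightedLineFit(x, y, w);
            boolean converged = Math.abs(next[0] - p[0]) < 1e-12
                    && Math.abs(next[1] - p[1]) < 1e-12;
            p = next;
            if (converged) {
                break;
            }
        }
        return p;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = 1.0 + 2.0 * x[i];
        }
        y[7] = 60.0; // a single gross outlier
        double[] ones = new double[x.length];
        Arrays.fill(ones, 1.0);
        System.out.println("OLS:  " + Arrays.toString(weightedLineFit(x, y, ones)));
        System.out.println("IRLS: " + Arrays.toString(fit(x, y, 50)));
    }
}
```

The Huber weights stay strictly positive, so they avoid the negative-weight problem discussed earlier in the thread. The same loop structure should carry over to the sine fit by replacing weightedLineFit with a weighted nonlinear fit.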


#12

It seems the right people haven't read this yet, so I am bumping it to see if I get an answer.

The original proposal is: “If there is a core developer who wishes to include this code in Fiji core, I can provide more fitting functions. You would just need to adapt it to Fiji standards.”

I just want to know whether depending on the Apache Commons Math package is a blocker for including this as a Fiji core package. Maybe it would be better to start a new thread on this. Anyway, I have created a GitHub project if you want to contribute.