public class OnlineNormalEstimator
extends java.lang.Object
implements gnu.trove.procedure.TLongProcedure
OnlineNormalEstimator provides an object that estimates means, variances, and standard deviations for a stream of numbers presented one at a time. Given a set of samples x[0],...,x[N-1], the mean is defined by:
$$\mathrm{mean}(x) = \frac{1}{N} \sum_{i<N} x[i]$$
The variance is defined as the average squared difference from the mean:
$$\mathrm{var}(x) = \frac{1}{N} \sum_{i<N} \left(x[i] - \mathrm{mean}(x)\right)^2$$
and the standard deviation is the square root of variance:
$$\mathrm{dev}(x) = \sqrt{\mathrm{var}(x)}$$
By convention, the mean and variance of a zero-length sequence of numbers are both returned as 0.0.
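The definitions above can be checked directly with a two-pass computation. This is a minimal sketch, not part of this class's API; the class name MaxLikelihoodStats and the sample data are illustrative assumptions. It follows the stated convention of returning 0.0 for an empty sequence.

```java
// Hypothetical helper (not part of OnlineNormalEstimator): computes the
// maximum likelihood mean, variance, and deviation straight from the definitions.
class MaxLikelihoodStats {

    static double mean(double[] x) {
        if (x.length == 0) return 0.0; // convention for an empty sequence
        double sum = 0.0;
        for (double xi : x) sum += xi;
        return sum / x.length;
    }

    static double variance(double[] x) {
        if (x.length == 0) return 0.0; // convention for an empty sequence
        double mu = mean(x);           // first pass: the mean
        double sumSq = 0.0;            // second pass: squared differences
        for (double xi : x) sumSq += (xi - mu) * (xi - mu);
        return sumSq / x.length;       // divide by N, not N - 1
    }

    static double deviation(double[] x) {
        return Math.sqrt(variance(x));
    }

    public static void main(String[] args) {
        double[] x = {2, 4, 4, 4, 5, 5, 7, 9};
        System.out.println(mean(x));      // 5.0
        System.out.println(variance(x));  // 4.0
        System.out.println(deviation(x)); // 2.0
    }
}
```

For this data the squared differences from the mean 5 are 9, 1, 1, 1, 0, 0, 4 and 16, which sum to 32, so the variance is 32/8 = 4.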
The above functions provide the maximum likelihood estimates of the mean, variance and standard deviation for a normal distribution generating the values. That is, the estimated parameters are the parameters which assign the observed data sequence the highest probability.
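The maximum likelihood claim follows from a standard calculation, sketched here under the assumption that the samples are independent draws from a normal distribution with mean μ and variance σ². Setting the partial derivatives of the log likelihood to zero recovers mean(x) and var(x):

$$\ell(\mu, \sigma^2) = \sum_{i<N} \log \mathcal{N}(x[i] \mid \mu, \sigma^2) = -\frac{N}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i<N} (x[i] - \mu)^2$$

$$\frac{\partial \ell}{\partial \mu} = \frac{1}{\sigma^2} \sum_{i<N} (x[i] - \mu) = 0 \;\Rightarrow\; \hat{\mu} = \frac{1}{N} \sum_{i<N} x[i] = \mathrm{mean}(x)$$

$$\frac{\partial \ell}{\partial \sigma^2} = -\frac{N}{2\sigma^2} + \frac{1}{2\sigma^4} \sum_{i<N} (x[i] - \hat{\mu})^2 = 0 \;\Rightarrow\; \hat{\sigma}^2 = \frac{1}{N} \sum_{i<N} (x[i] - \hat{\mu})^2 = \mathrm{var}(x)$$

Because maximum likelihood estimates are invariant under reparameterization, the estimate of σ is then sqrt(var(x)) = dev(x).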
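Unfortunately, the maximum likelihood variance and deviation estimates are biased: they systematically underestimate the variance, because for independent samples the expected value of var(x) is only (N-1)/N times the true variance. The unbiased estimates divide by N-1 instead of N, which adjusts variance and deviation upwards (strictly speaking, only the variance estimate is unbiased; the square root of an unbiased variance estimate is still a slightly biased estimate of the deviation):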
$$\mathrm{varUnbiased}(x) = \frac{N}{N-1} \, \mathrm{var}(x)$$
$$\mathrm{devUnbiased}(x) = \sqrt{\mathrm{varUnbiased}(x)}$$
Note that varUnbiased(x) >= var(x) and devUnbiased(x) >= dev(x), since N/(N-1) > 1.
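A minimal sketch of the unbiased adjustment on the same illustrative data. The class name UnbiasedStats is hypothetical, and returning 0.0 when fewer than two samples are available is an assumption: N/(N-1) is undefined for N = 1, and this page does not say what the class returns in that case.

```java
// Hypothetical helper: unbiased variance and deviation obtained by rescaling
// the maximum likelihood variance by N / (N - 1), as in the formulas above.
class UnbiasedStats {

    static double mlVariance(double[] x) {
        if (x.length == 0) return 0.0;
        double mu = 0.0;
        for (double xi : x) mu += xi;
        mu /= x.length;
        double sumSq = 0.0;
        for (double xi : x) sumSq += (xi - mu) * (xi - mu);
        return sumSq / x.length;
    }

    static double varianceUnbiased(double[] x) {
        int n = x.length;
        // Assumption: return 0.0 for N < 2, where N / (N - 1) is undefined
        // (N = 1) or there is no data (N = 0).
        if (n < 2) return 0.0;
        return ((double) n / (n - 1)) * mlVariance(x);
    }

    static double deviationUnbiased(double[] x) {
        return Math.sqrt(varianceUnbiased(x));
    }

    public static void main(String[] args) {
        double[] x = {2, 4, 4, 4, 5, 5, 7, 9};
        System.out.println(mlVariance(x));        // 4.0
        System.out.println(varianceUnbiased(x));  // 32/7, about 4.571
        System.out.println(deviationUnbiased(x)); // about 2.138
    }
}
```

With N = 8 the correction factor is 8/7, so the unbiased variance is 32/7 rather than 32/8.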
Welford's Algorithm
This class uses Welford's algorithm for estimation. This algorithm is far more numerically stable than using a single pass to accumulate the sufficient statistics (the sum and the sum of squares of all entries), and, unlike the two-pass approach (first computing the sum, then the sum of squared differences from the mean), it needs to see each value only once. The algorithm keeps three member variables in the class and performs the following update when it sees a new value x:
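    long n = 0;      // number of samples
    double mu = 0.0; // running mean
    double sq = 0.0; // running sum of squared differences from the mean

    void update(double x) {
        ++n;
        double muNew = mu + (x - mu) / n;
        sq += (x - mu) * (x - muNew); // uses both the old and new mean
        mu = muNew;
    }

    double mean() { return mu; }

    double var() { return n > 1 ? sq / n : 0.0; }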
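The update above can be packaged as a runnable class and compared with the naive single-pass formula (mean of the squares minus the square of the mean). The names WelfordSketch and naiveVariance are illustrative, not part of this class's API; shifting small values by a large offset is simply a convenient way to expose cancellation in the naive formula.

```java
// Sketch of the Welford update shown above, packaged as a runnable class.
// WelfordSketch and naiveVariance are illustrative names, not this library's API.
class WelfordSketch {
    private long n = 0;
    private double mu = 0.0; // running mean
    private double sq = 0.0; // running sum of squared differences from the mean

    void update(double x) {
        ++n;
        double muNew = mu + (x - mu) / n;
        sq += (x - mu) * (x - muNew);
        mu = muNew;
    }

    double mean() { return mu; }

    double var() { return n > 1 ? sq / n : 0.0; }

    // Naive single pass over the sufficient statistics (sum and sum of squares),
    // included only for comparison.
    static double naiveVariance(double[] xs) {
        double sum = 0.0, sumSq = 0.0;
        for (double x : xs) { sum += x; sumSq += x * x; }
        double m = sum / xs.length;
        return sumSq / xs.length - m * m;
    }

    public static void main(String[] args) {
        // Values 4, 7, 13, 16 shifted by a large offset: the true variance is 22.5.
        double offset = 1e9;
        double[] xs = {offset + 4, offset + 7, offset + 13, offset + 16};
        WelfordSketch w = new WelfordSketch();
        for (double x : xs) w.update(x);
        System.out.println(w.var());           // 22.5
        System.out.println(naiveVariance(xs)); // not 22.5: cancellation error
    }
}
```

With an offset of 1e9 the squares are near 1e18, where adjacent doubles are 128 apart, so the naive difference cannot resolve a variance of 22.5; the Welford update only works with small differences from the running mean.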
Welford's Algorithm with Deletes
LingPipe extends Welford's algorithm to support deletes by value. Given the current values of n, mu, and sq, and any value x that was added at some point, we can compute the values n, mu, and sq would have if x had never been added. The delete method is:
    void delete(double x) {
        if (n == 0) throw new IllegalStateException();
        if (n == 1) { n = 0; mu = 0.0; sq = 0.0; return; }
        double muOld = (n * mu - x) / (n - 1); // mean without x
        sq -= (x - mu) * (x - muOld);          // remove x's contribution
        mu = muOld;
        --n;
    }
Because the data are exchangeable for mean and variance calculations (that is, permuting the inputs produces the same mean and variance), the order of removal does not need to match the order of addition.
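To check that deletion undoes an addition regardless of order, the following sketch adds 1 through 5 and then removes 4 and 2 (not the reverse of the insertion order); the result should match an estimator that only saw {1, 3, 5}. RemovableWelford and its method names are hypothetical; in the documented class, removal is exposed as unHandle(double).

```java
// Sketch of Welford's update plus the delete-by-value step described above.
// RemovableWelford is a hypothetical name; it mirrors the pseudocode, not the real class.
class RemovableWelford {
    private long n = 0;
    private double mu = 0.0;
    private double sq = 0.0;

    void add(double x) {
        ++n;
        double muNew = mu + (x - mu) / n;
        sq += (x - mu) * (x - muNew);
        mu = muNew;
    }

    void delete(double x) {
        if (n == 0) throw new IllegalStateException("no samples to delete");
        if (n == 1) { n = 0; mu = 0.0; sq = 0.0; return; }
        double muOld = (n * mu - x) / (n - 1); // mean without x
        sq -= (x - mu) * (x - muOld);          // undo x's contribution
        mu = muOld;
        --n;
    }

    long count() { return n; }

    double mean() { return mu; }

    double var() { return n > 1 ? sq / n : 0.0; }

    public static void main(String[] args) {
        RemovableWelford est = new RemovableWelford();
        for (double x : new double[] {1, 2, 3, 4, 5}) est.add(x);
        // Remove 4 and then 2: not the reverse of the insertion order.
        est.delete(4);
        est.delete(2);
        // Remaining samples are {1, 3, 5}: count 3, mean 3, variance 8/3.
        System.out.println(est.count() + " " + est.mean() + " " + est.var());
    }
}
```

Tracing the arithmetic: after the five additions sq is 10; deleting 4 gives mu = 2.75 and sq = 8.75; deleting 2 gives mu = 3 and sq = 8, exactly the state for {1, 3, 5}.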
| Constructor | Description |
|---|---|
| OnlineNormalEstimator() | Construct an instance of an online normal estimator that has seen no data. |
| Modifier and Type | Method | Description |
|---|---|---|
| boolean | execute(long x) | Add the specified value to the collection of samples for this estimator. |
| long | getCountBelowZero() | |
| double | mean() | Returns the mean of the samples. |
| long | numSamples() | Returns the number of samples seen by this estimator. |
| double | standardDeviation() | Returns the maximum likelihood estimate of the standard deviation of the samples. |
| double | standardDeviationUnbiased() | Returns the unbiased estimate of the standard deviation of the samples. |
| java.lang.String | toString() | Returns a string-based representation of the mean and standard deviation and number of samples for this estimator. |
| void | unHandle(double x) | Removes the specified value from the sample set. |
| double | variance() | Returns the maximum likelihood estimate of the variance of the samples. |
| double | varianceUnbiased() | Returns the unbiased estimate of the variance of the samples. |
public OnlineNormalEstimator()
public boolean execute(long x)
Specified by: execute in interface gnu.trove.procedure.TLongProcedure
Parameters: x - Value to add.

public void unHandle(double x)
Parameters: x - Value to remove from sample.
Throws: java.lang.IllegalStateException - If the current number of samples is 0.

public long numSamples()
public double mean()
public double variance()
public double varianceUnbiased()
public double standardDeviation()
public double standardDeviationUnbiased()
public java.lang.String toString()
Overrides: toString in class java.lang.Object
public long getCountBelowZero()