Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
*.jar
*.war
*.ear
.DS_Store
.DS_Store
/target
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.nickferraro.bayesian.report.calc;

import java.util.List;

import com.nickferraro.bayesian.IDataRow;

/**
* The IAccuracyCalculator interface is a simple contract for calculating the accuracy of
* a bayesian system in classifying a list of data rows.
* @author Nick Ferraro
*
* @param <T> The category data type the bayesian system will be using
*/
public interface IAccuracyCalculator<T> {
/**
* Classify the list of data rows and calculate the accuracy of classifying the row category.
* @param dataRows A list of data rows to classify and use in a clean set of calculations.
* @return The accuracy of classifying this data set
*/
public double calculateAccuracy(List<IDataRow<T>> dataRows);

/**
* Classify the list of data rows and calculate the accuracy of classifying the row category.
* @param dataRows A list of data rows to classify and use in a clean or aggregated set of calculations.
* @param cleanSlate Whether or not to aggregate previous calculations with this new data set or start from a clean slate. Defaults to TRUE.
* @return The accuracy of classifying this data set
*/
public double calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate);

/**
* Get the last accuracy calculated.
* @return The last accuracy calculated.
*/
public double getAccuracy();

/**
* Get the number of correct classifications from the last calculation.
* @return The number of correct classifications
*/
public int getCorrectCount();

/**
* Get the number of incorrect classifications from the last calculation.
* @return The number of incorrect classifications
*/
public int getIncorrectCount();

/**
* Get the number of total classifications calculated
* @return The number of total classifications
*/
public int getTotalCount();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.nickferraro.bayesian.report.calc;

/**
* An interface for calculating a confusion matrix.
* @author Nick Ferraro
*
* @param <T> The category data type of the bayesian model
*/
public interface IConfusionMatrix<T> extends IAccuracyCalculator<T> {
/**
* Get the count value of a matrix cell.
* @param actualCategory The actual category (row)
* @param classifiedCategory The classified category (column)
* @return The count value of the cell. Will never be less than 0.
*/
public int getCellCount(T actualCategory, T classifiedCategory);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package com.nickferraro.bayesian.report.calc.core;

import java.security.InvalidParameterException;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import com.nickferraro.bayesian.IBayesianSystem;
import com.nickferraro.bayesian.IClassification;
import com.nickferraro.bayesian.IDataRow;
import com.nickferraro.bayesian.report.calc.IAccuracyCalculator;

/**
* A thread-safe implementation of IAccuracyCalculator. This class will classify data rows
* and return the accuracy rate calculated. Multiple sets of data rows can be classified and aggregated
* into a single accuracy calculation. This implementation does not use the IDataRow id, nor does
* it verify that duplicate ids are only counted once.
* @author Nick Ferraro
*
* @param <T> The category data type the bayesian system will be using
*/
public class AccuracyCalculator<T> implements IAccuracyCalculator<T> {
protected final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
protected final Lock readLock = readWriteLock.readLock();
protected final Lock writeLock = readWriteLock.writeLock();
protected IBayesianSystem<T> bayesianSystem;
protected int total = 0;
protected int correct = 0;

/**
* Constructor for AccuracyCalculator. Requires a non-null IBayesianSystem parameter.
* @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL.
* @throws InvalidParameterException Thrown when IBayesianSystem is NULL.
*/
public AccuracyCalculator(IBayesianSystem<T> bayesianSystem) throws InvalidParameterException {
// DRY: Set the bayesian system and reset counts
setBayesianSystem(bayesianSystem);
resetCounts();
}

/**
* Set the IBayesianSystem used for classification.
* @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL.
* @throws InvalidParameterException Thrown when IBayesianSystem is NULL.
*/
public void setBayesianSystem(IBayesianSystem<T> bayesianSystem) throws InvalidParameterException {
// Validate the bayesianSystem parameter
if( bayesianSystem == null ) {
throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem");
}

// Lock
writeLock.lock();

try {
this.bayesianSystem = bayesianSystem;
} finally {
// Unlock
writeLock.unlock();
}
}

@Override
public double calculateAccuracy(List<IDataRow<T>> dataRows) {
// DRY: Call calculateAccuracy with cleanSlate = TRUE
return calculateAccuracy(dataRows, true);
}

@Override
public double calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate) {
// Lock
writeLock.lock();

try {
// Call thread-unsafe private method
return _calculateAccuracy(dataRows, cleanSlate);
} finally {
// Unlock
writeLock.unlock();
}
}

@Override
public double getAccuracy() {
// Lock
readLock.lock();

try {
// Call thead-unsafe private method
return _getAccuracy();
} finally {
// Unlock
readLock.unlock();
}
}

@Override
public int getCorrectCount() {
// Lock
readLock.lock();

try {
return correct;
} finally {
// Unlock
readLock.unlock();
}
}

@Override
public int getIncorrectCount() {
// Lock
readLock.lock();

try {
// Calculate and return incorrect
return total - correct;
} finally {
// Unlock
readLock.unlock();
}
}

@Override
public int getTotalCount() {
// Lock
readLock.lock();

try {
return total;
} finally {
// Unlock
readLock.unlock();
}
}

/**
* Reset the current calculations and counts
*/
public void resetCounts() {
// Lock
writeLock.lock();

try {
// Call thread-unsafe private method
_resetCounts();
} finally {
writeLock.unlock();
}
}

/**
* Private thread-unsafe method for calculating the accuracy of a bayesian system.
* @param dataRows The data rows to classify and use in accuracy calculations
* @param cleanSlate Whether or not to aggregate calculations or start from a clean slate
* @return The current accuracy of the bayesian system
*/
protected double _calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate) {
// Reset our counts if cleanSlate is TRUE
if( cleanSlate ) {
_resetCounts();
}

// Return 0% accuracy for null dataRows
if( dataRows == null ) {
return _getAccuracy();
}

// Iterate over data rows
for(IDataRow<T> dataRow : dataRows) {
// If a row is null, skip it
if( dataRow == null ) {
continue;
}

// Get the first classification and check if it is correct. NULL classifications are ignored.
List<IClassification<T>> classifications = bayesianSystem.classifyRow(dataRow);
if( classifications != null && classifications.size() > 0 ) {
IClassification<T> classification = classifications.get(0);
// Increase the total rows calculated
++total;

// If the classification category matches the data row category, increase the correct count
if( classification != null && classification.getCategory().equals(dataRow.getCategory()) ) {
++correct;
}
}
}

// Return the current accuracy using the thread-unsafe private method
return _getAccuracy();
}

/**
* Private thread-unsafe method for calculating the accuracy
* @return The last calculated accuracy
*/
protected double _getAccuracy() {
return total == 0 ? 0.0 : ((double)correct / (double)total);
}

/**
* Private thread-unsafe method for reseting the calculations and counts
*/
protected void _resetCounts() {
total = 0;
correct = 0;
}
}
Loading