Using some Java8 features on DecisionTreeLearner and DataSet by AdrianBZG · Pull Request #296 · aimacode/aima-java
Expand Up
@@ -4,6 +4,8 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import aima.core.util.Util;
Expand Down Expand Up @@ -39,31 +41,39 @@ public Example getExample(int number) {
public DataSet removeExample(Example e) { DataSet ds = new DataSet(specification); for (Example eg : examples) { if (!(e.equals(eg))) { ds.add(eg);
// We stream the examples, filter the elements that match the given // example so we don't get them, then we loop over the result (not filtered elements) // and add them to the DataSet (ds) to return it afterwards examples.stream().filter(example -> { if (e.equals(example)) { return false; } } return true; }).forEach(example -> ds.add(example));
return ds; }
public double getInformationFor() { String attributeName = specification.getTarget(); Hashtable<String, Integer> counts = new Hashtable<String, Integer>(); for (Example e : examples) {
String val = e.getAttributeValueAsString(attributeName); examples.stream().forEach(example -> { String val = example.getAttributeValueAsString(attributeName); if (counts.containsKey(val)) { counts.put(val, counts.get(val) + 1); } else { counts.put(val, 1); } } });
// Consider avoiding primitive data types, use wrappers instead to allow the usage // of useful JDK8 features double[] data = new double[counts.keySet().size()]; Iterator<Integer> iter = counts.values().iterator(); Iterator<Integer> iterator = counts.values().iterator(); for (int i = 0; i < data.length; i++) { data[i] = iter.next(); data[i] = iterator.next(); } data = Util.normalize(data);
Expand All @@ -72,30 +82,36 @@ public double getInformationFor() {
public Hashtable<String, DataSet> splitByAttribute(String attributeName) { Hashtable<String, DataSet> results = new Hashtable<String, DataSet>(); for (Example e : examples) { String val = e.getAttributeValueAsString(attributeName);
examples.stream().forEach(example -> { String val = example.getAttributeValueAsString(attributeName); if (results.containsKey(val)) { results.get(val).add(e); results.get(val).add(example); } else { DataSet ds = new DataSet(specification); ds.add(e); ds.add(example); results.put(val, ds); } } });
return results; }
public double calculateGainFor(String parameterName) { Hashtable<String, DataSet> hash = splitByAttribute(parameterName); double totalSize = examples.size(); double remainder = 0.0; for (String parameterValue : hash.keySet()) { double reducedDataSetSize = hash.get(parameterValue).examples .size(); remainder += (reducedDataSetSize / totalSize) * hash.get(parameterValue).getInformationFor(); } return getInformationFor() - remainder;
final AtomicReference<Double> remainder = new AtomicReference<>(); remainder.set(0.0);
hash.keySet().stream() .forEach(parameterValue -> { double reducedDataSetSize = hash.get(parameterValue).examples.size(); remainder.set(remainder.get() + ((reducedDataSetSize / totalSize) * hash.get(parameterValue).getInformationFor())); });
return getInformationFor() - remainder.get(); }
@Override Expand All @@ -121,9 +137,11 @@ public Iterator<Example> iterator() {
public DataSet copy() { DataSet ds = new DataSet(specification); for (Example e : examples) { ds.add(e); }
// We stream the examples, and loop over it's elements to add // them to the DataSet (ds) examples.stream().forEach(example -> ds.add(example));
return ds; }
Expand Down Expand Up @@ -154,12 +172,17 @@ public List<String> getPossibleAttributeValues(String attributeName) {
public DataSet matchingDataSet(String attributeName, String attributeValue) { DataSet ds = new DataSet(specification); for (Example e : examples) { if (e.getAttributeValueAsString(attributeName).equals( attributeValue)) { ds.add(e);
// We stream the examples, don't filter the elements that match the given // attributeName and attributeValue so we get them, then we loop over the result (not filtered elements) // and add them to the DataSet (ds) to return it afterwards examples.stream().filter(example -> { if (example.getAttributeValueAsString(attributeName).equals(attributeValue)) { return false; } } return true; }).forEach(example -> ds.add(example));
return ds; }
Expand Down
import aima.core.util.Util;
Expand Down Expand Up @@ -39,31 +41,39 @@ public Example getExample(int number) {
public DataSet removeExample(Example e) { DataSet ds = new DataSet(specification); for (Example eg : examples) { if (!(e.equals(eg))) { ds.add(eg);
// We stream the examples, filter the elements that match the given // example so we don't get them, then we loop over the result (not filtered elements) // and add them to the DataSet (ds) to return it afterwards examples.stream().filter(example -> { if (e.equals(example)) { return false; } } return true; }).forEach(example -> ds.add(example));
return ds; }
public double getInformationFor() { String attributeName = specification.getTarget(); Hashtable<String, Integer> counts = new Hashtable<String, Integer>(); for (Example e : examples) {
String val = e.getAttributeValueAsString(attributeName); examples.stream().forEach(example -> { String val = example.getAttributeValueAsString(attributeName); if (counts.containsKey(val)) { counts.put(val, counts.get(val) + 1); } else { counts.put(val, 1); } } });
// Consider avoiding primitive data types, use wrappers instead to allow the usage // of useful JDK8 features double[] data = new double[counts.keySet().size()]; Iterator<Integer> iter = counts.values().iterator(); Iterator<Integer> iterator = counts.values().iterator(); for (int i = 0; i < data.length; i++) { data[i] = iter.next(); data[i] = iterator.next(); } data = Util.normalize(data);
Expand All @@ -72,30 +82,36 @@ public double getInformationFor() {
public Hashtable<String, DataSet> splitByAttribute(String attributeName) { Hashtable<String, DataSet> results = new Hashtable<String, DataSet>(); for (Example e : examples) { String val = e.getAttributeValueAsString(attributeName);
examples.stream().forEach(example -> { String val = example.getAttributeValueAsString(attributeName); if (results.containsKey(val)) { results.get(val).add(e); results.get(val).add(example); } else { DataSet ds = new DataSet(specification); ds.add(e); ds.add(example); results.put(val, ds); } } });
return results; }
public double calculateGainFor(String parameterName) { Hashtable<String, DataSet> hash = splitByAttribute(parameterName); double totalSize = examples.size(); double remainder = 0.0; for (String parameterValue : hash.keySet()) { double reducedDataSetSize = hash.get(parameterValue).examples .size(); remainder += (reducedDataSetSize / totalSize) * hash.get(parameterValue).getInformationFor(); } return getInformationFor() - remainder;
final AtomicReference<Double> remainder = new AtomicReference<>(); remainder.set(0.0);
hash.keySet().stream() .forEach(parameterValue -> { double reducedDataSetSize = hash.get(parameterValue).examples.size(); remainder.set(remainder.get() + ((reducedDataSetSize / totalSize) * hash.get(parameterValue).getInformationFor())); });
return getInformationFor() - remainder.get(); }
@Override Expand All @@ -121,9 +137,11 @@ public Iterator<Example> iterator() {
public DataSet copy() { DataSet ds = new DataSet(specification); for (Example e : examples) { ds.add(e); }
// We stream the examples, and loop over it's elements to add // them to the DataSet (ds) examples.stream().forEach(example -> ds.add(example));
return ds; }
Expand Down Expand Up @@ -154,12 +172,17 @@ public List<String> getPossibleAttributeValues(String attributeName) {
public DataSet matchingDataSet(String attributeName, String attributeValue) { DataSet ds = new DataSet(specification); for (Example e : examples) { if (e.getAttributeValueAsString(attributeName).equals( attributeValue)) { ds.add(e);
// We stream the examples, don't filter the elements that match the given // attributeName and attributeValue so we get them, then we loop over the result (not filtered elements) // and add them to the DataSet (ds) to return it afterwards examples.stream().filter(example -> { if (example.getAttributeValueAsString(attributeName).equals(attributeValue)) { return false; } } return true; }).forEach(example -> ds.add(example));
return ds; }
Expand Down