Alink漫谈(八) : 二分类评估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何实现

前端之家收集整理的这篇文章主要介绍了Alink漫谈(八) : 二分类评估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何实现前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

Alink漫谈(八) : 二分类评估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何实现

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。二分类评估是对二分类算法的预测结果进行效果评估。本文将剖析Alink中对应代码实现。

0x01 相关概念

如果对本文某些概念有疑惑,可以参见之前文章 [白话解析] 通过实例来梳理概念 :准确率 (Accuracy)、精准率(Precision)、召回率(Recall) 和 F值(F-Measure)

0x02 示例代码

  1. public class EvalBinaryClassExample {
  2. AlgoOperator getData(boolean isBatch) {
  3. Row[] rows = new Row[]{
  4. Row.of("prefix1","{\"prefix1\": 0.9,\"prefix0\": 0.1}"),Row.of("prefix1","{\"prefix1\": 0.8,\"prefix0\": 0.2}"),"{\"prefix1\": 0.7,\"prefix0\": 0.3}"),Row.of("prefix0","{\"prefix1\": 0.75,\"prefix0\": 0.25}"),"{\"prefix1\": 0.6,\"prefix0\": 0.4}")
  5. };
  6. String[] schema = new String[]{"label","detailInput"};
  7. if (isBatch) {
  8. return new MemSourceBatchOp(rows,schema);
  9. } else {
  10. return new MemSourceStreamOp(rows,schema);
  11. }
  12. }
  13. public static void main(String[] args) throws Exception {
  14. EvalBinaryClassExample test = new EvalBinaryClassExample();
  15. BatchOperator batchData = (BatchOperator) test.getData(true);
  16. BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
  17. .setLabelCol("label")
  18. .setPredictionDetailCol("detailInput")
  19. .linkFrom(batchData)
  20. .collectMetrics();
  21. System.out.println("RocCurve:" + metrics.getRocCurve());
  22. System.out.println("AUC:" + metrics.getAuc());
  23. System.out.println("KS:" + metrics.getKs());
  24. System.out.println("PRC:" + metrics.getPrc());
  25. System.out.println("Accuracy:" + metrics.getAccuracy());
  26. System.out.println("Macro Precision:" + metrics.getMacroPrecision());
  27. System.out.println("Micro Recall:" + metrics.getMicroRecall());
  28. System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
  29. }
  30. }

程序输出

  1. RocCurve:([0.0,0.0,0.5,1.0,1.0],[0.0,0.3333333333333333,0.6666666666666666,1.0])
  2. AUC:0.8333333333333333
  3. KS:0.6666666666666666
  4. PRC:0.9027777777777777
  5. Accuracy:0.6
  6. Macro Precision:0.3
  7. Micro Recall:0.6
  8. Weighted Sensitivity:0.6

在 Alink 中,二分类评估有批处理,流处理两种实现,下面一一为大家介绍( Alink 复杂之一在于大量精细的数据结构,所以下文会大量打印程序中变量以便大家理解)。

2.1 主要思路

  • 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。

  • 根据输入给positiveBin / negativeBin赋值。positiveBin就是 TP + FP,negativeBin就是 TN + FN。这些是后续计算的基础。

  • 遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算该点的混淆矩阵,tpr,以及rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;

  • 依据曲线内容计算并且存储 AUC/PRC/KS

具体后续还有详细调用关系综述。

0x03 批处理

3.1 EvalBinaryClassBatchOp

EvalBinaryClassBatchOp是二分类评估的实现,功能是计算二分类的评估指标(evaluation metrics)。

输入有两种:

  • label column and predResult column
  • label column and predDetail column。如果有predDetail,则predResult被忽略

我们例子中 "prefix1" 就是 label,"{\"prefix1\": 0.9,\"prefix0\": 0.1}" 就是 predDetail

  1. Row.of("prefix1",\"prefix0\": 0.1}")

具体类摘录如下:

  1. public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>,EvaluationMetricsCollector<BinaryClassMetrics> {
  2. @Override
  3. public BinaryClassMetrics collectMetrics() {
  4. return new BinaryClassMetrics(this.collect().get(0));
  5. }
  6. }

可以看到,其主要工作都是在基类BaseEvalClassBatchOp中完成,所以我们会首先看BaseEvalClassBatchOp。

3.2 BaseEvalClassBatchOp

我们还是从 linkFrom 函数入手,其主要是做了几件事:

  • 获取配置信息
  • 从输入中提取某些列:"label","detailInput"
  • calLabelPredDetailLocal会按照partition分别计算evaluation metrics
  • 综合reduce上述计算结果
  • SaveDataAsParams函数会把最终数值输入到 output table

具体代码如下

  1. @Override
  2. public T linkFrom(BatchOperator<?>... inputs) {
  3. BatchOperator<?> in = checkAndGetFirst(inputs);
  4. String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
  5. String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);
  6. // Judge the evaluation type from params.
  7. ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
  8. DataSet<BaseMetricsSummary> res;
  9. switch (type) {
  10. case PRED_DETAIL: {
  11. String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
  12. // 从输入中提取某些列:"label","detailInput"
  13. DataSet<Row> data = in.select(new String[] {labelColName,predDetailColName}).getDataSet();
  14. // 按照partition分别计算evaluation metrics
  15. res = calLabelPredDetailLocal(data,positiveValue,binary);
  16. break;
  17. }
  18. ......
  19. }
  20. // 综合reduce上述计算结果
  21. DataSet<BaseMetricsSummary> metrics = res
  22. .reduce(new EvaluationUtil.ReduceBaseMetrics());
  23. // 把最终数值输入到 output table
  24. this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),new String[] {DATA_OUTPUT},new TypeInformation[] {Types.STRING});
  25. return (T)this;
  26. }
  27. // 执行中一些变量如下
  28. labelColName = "label"
  29. predDetailColName = "detailInput"
  30. type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
  31. binary = true
  32. positiveValue = null

3.2.0 调用关系综述

因为后续代码调用关系复杂,所以先给出一个调用关系

  • 从输入中提取某些列:"label","detailInput",in.select(new String[] {labelColName,predDetailColName}).getDataSet()。因为可能输入还有其他列,而只有某些列是我们计算需要的,所以只提取这些列。
  • 按照partition分别计算evaluation metrics,即调用 calLabelPredDetailLocal(data,binary);
    • flatMap会从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。
    • reduceGroup主要功能是通过 buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,得到一个 <labels,ID>的map,最后返回是二元组(map,labels),即({prefix1=0,prefix0=1},[prefix1,prefix0])。从后文看,<labels,ID>Map看来是多分类才用到。二分类只用到了labels。
    • mapPartition 分区调用 CalLabelDetailLocal 来计算混淆矩阵,主要是分区调用getDetailStatistics,前文中得到的二元组(map,labels)会作为参数传递进来 。
      • getDetailStatistics 遍历 rows 数据,提取每一个item(比如 "prefix1,{"prefix1": 0.8,"prefix0": 0.2}"),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
        • updateBinaryMetricsSummary 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。positiveBin就是 TP + FP,negativeBin就是 TN + FN。
          • 如果某个 sample 为 正例 (positive value) 的概率是 p,则该 sample 对应的 bin index 就是 p * 100000。如果 p 被预测为正例 (positive value) ,则positiveBin[index]++,
          • 否则就是被预测为负例(negative value) ,则negativeBin[index]++。
  • 综合reduce上述计算结果,metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
    • 具体计算是在BinaryMetricsSummary.merge,其作用就是Merge the bins,and add the logLoss。
  • 把最终数值输入到 output table,setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
    • 归并所有BaseMetrics后,得到total BaseMetrics,计算indexes存入params。collector.collect(t.toMetrics().serialize());
      • 实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params。
        • extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵),threshold array,rocCurve/recallPrecisionCurve/LiftChart.
          • 遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算:
          • curTrue += positiveBin[index]; curFalse += negativeBin[index];
          • 得到该点的混淆矩阵 new ConfusionMatrix(new long[][] {{curTrue,curFalse},{totalTrue - curTrue,totalFalse - curFalse}});
          • 得到 tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
          • rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;
        • 依据曲线内容计算并且存储 AUC/PRC/KS
        • 生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
        • 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
        • 存储正例样本的度量指标
        • 存储Logloss
        • Pick the middle point where threshold is 0.5.

3.2.1 calLabelPredDetailLocal

函数按照partition分别计算评估指标 evaluation metrics。是的,这代码很短,但是有个地方需要注意。有时候越简单的地方越容易疏漏。容易疏漏点是:

第一行代码的结果 labels 是第二行代码的参数,而并非第二行主体。第二行代码主体和第一行代码主体一样,都是data。

  1. private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data,final String positiveValue,oolean binary) {
  2. DataSet<Tuple2<Map<String,Integer>,String[]>> labels = data.flatMap(new FlatMapFunction<Row,String>() {
  3. @Override
  4. public void flatMap(Row row,Collector<String> collector) {
  5. TreeMap<String,Double> labelProbMap;
  6. if (EvaluationUtil.checkRowFieldNotNull(row)) {
  7. labelProbMap = EvaluationUtil.extractLabelProbMap(row);
  8. labelProbMap.keySet().forEach(collector::collect);
  9. collector.collect(row.getField(0).toString());
  10. }
  11. }
  12. }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary,positiveValue));
  13. return data
  14. .rebalance()
  15. .mapPartition(new CalLabelDetailLocal(binary))
  16. .withBroadcastSet(labels,LABELS);
  17. }

calLabelPredDetailLocal中具体分为三步骤:

  • 在flatMap会从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。
  • reduceGroup的主要功能是去重 "labels名字",然后给每一个label一个ID,最后结果是一个<labels,ID>Map。
  • mapPartition 是分区调用 CalLabelDetailLocal 来计算混淆矩阵。

下面具体看看。

3.2.1.1 flatMap

在flatMap中,主要是从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。

EvaluationUtil.extractLabelProbMap 作用就是解析输入的json,获得具体detailInput中的信息。

下游算子是reduceGroup,所以Flink runtime会对这些labels自动去重。如果对这部分有兴趣,可以参见我之前介绍reduce的文章。CSDN : [源码解析] Flink的groupBy和reduce究竟做了什么 博客园 : [源码解析] Flink的groupBy和reduce究竟做了什么

程序中变量如下

  1. row = {Row@8922} "prefix1,{"prefix1": 0.9,"prefix0": 0.1}"
  2. fields = {Object[2]@8925}
  3. 0 = "prefix1"
  4. 1 = "{"prefix1": 0.9,"prefix0": 0.1}"
  5. labelProbMap = {TreeMap@9008} size = 2
  6. "prefix0" -> {Double@9015} 0.1
  7. "prefix1" -> {Double@9017} 0.9
  8. labelProbMap.keySet().forEach(collector::collect); //这里发送 "prefix0","prefix1"
  9. collector.collect(row.getField(0).toString()); // 这里发送 "prefix1"
  10. // 因为下一个操作是reduceGroup,所以这些label会被runtime去重
3.2.1.2 reduceGroup

主要功能是通过buildLabelIndexLabelArray去重labels,然后给每一个label一个ID,最后结果是一个<labels,ID>的Map。

  1. reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary,positiveValue));

DistinctLabelIndexMap的作用是从label列和prediction列中,取出所有不同的labels,返回一个<labels,ID>的map,根据后续代码看,这个map是多分类才用到。Get all the distinct labels from label column and prediction column,and return the map of labels and their IDs.

前面已经提到,这里的参数rows已经被自动去重。

  1. public static class DistinctLabelIndexMap implements
  2. GroupReduceFunction<String,Tuple2<Map<String,String[]>> {
  3. ......
  4. @Override
  5. public void reduce(Iterable<String> rows,Collector<Tuple2<Map<String,String[]>> collector) throws Exception {
  6. HashSet<String> labels = new HashSet<>();
  7. rows.forEach(labels::add);
  8. collector.collect(buildLabelIndexLabelArray(labels,binary,positiveValue));
  9. }
  10. }
  11. // 变量为
  12. labels = {HashSet@9008} size = 2
  13. 0 = "prefix1"
  14. 1 = "prefix0"
  15. binary = true

buildLabelIndexLabelArray的作用是给每一个label一个ID,得到一个 <labels,prefix0])。

  1. // Give each label an ID,return a map of label and ID.
  2. public static Tuple2<Map<String,String[]> buildLabelIndexLabelArray(HashSet<String> set,boolean binary,String positiveValue) {
  3. String[] labels = set.toArray(new String[0]);
  4. Arrays.sort(labels,Collections.reverSEOrder());
  5. Map<String,Integer> map = new HashMap<>(labels.length);
  6. if (binary && null != positiveValue) {
  7. if (labels[1].equals(positiveValue)) {
  8. labels[1] = labels[0];
  9. labels[0] = positiveValue;
  10. }
  11. map.put(labels[0],0);
  12. map.put(labels[1],1);
  13. } else {
  14. for (int i = 0; i < labels.length; i++) {
  15. map.put(labels[i],i);
  16. }
  17. }
  18. return Tuple2.of(map,labels);
  19. }
  20. // 程序变量如下
  21. labels = {String[2]@9013}
  22. 0 = "prefix1"
  23. 1 = "prefix0"
  24. map = {HashMap@9014} size = 2
  25. "prefix1" -> {Integer@9020} 0
  26. "prefix0" -> {Integer@9021} 1
3.2.1.3 mapPartition

这里主要功能是分区调用 CalLabelDetailLocal 来为后来计算混淆矩阵做准备。

  1. return data
  2. .rebalance()
  3. .mapPartition(new CalLabelDetailLocal(binary)) //这里是业务所在
  4. .withBroadcastSet(labels,LABELS);

具体工作是 CalLabelDetailLocal 完成的,其作用是分区调用getDetailStatistics

  1. // Calculate the confusion matrix based on the label and predResult.
  2. static class CalLabelDetailLocal extends RichMapPartitionFunction<Row,BaseMetricsSummary> {
  3. private Tuple2<Map<String,String[]> map;
  4. private boolean binary;
  5. @Override
  6. public void open(Configuration parameters) throws Exception {
  7. List<Tuple2<Map<String,String[]>> list = getRuntimeContext().getBroadcastVariable(LABELS);
  8. this.map = list.get(0);// 前文生成的二元组(map,labels)
  9. }
  10. @Override
  11. public void mapPartition(Iterable<Row> rows,Collector<BaseMetricsSummary> collector) {
  12. // 调用到了 getDetailStatistics
  13. collector.collect(getDetailStatistics(rows,map));
  14. }
  15. }

getDetailStatistics 的作用是:初始化分类评估的度量指标 base classification evaluation metrics,累积计算混淆矩阵需要的数据。主要就是遍历 rows 数据,提取每一个item(比如 "prefix1,"prefix0": 0.2}"),然后累积计算混淆矩阵所需数据。

  1. // Initialize the base classification evaluation metrics. There are two cases: BinaryClassMetrics and MultiClassMetrics.
  2. private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,String positiveValue,String[]> tuple) {
  3. BinaryMetricsSummary binaryMetricsSummary = null;
  4. MultiMetricsSummary multiMetricsSummary = null;
  5. Tuple2<Map<String,String[]> labelIndexLabelArray = tuple; // 前文生成的二元组(map,labels)
  6. Iterator<Row> iterator = rows.iterator();
  7. Row row = null;
  8. while (iterator.hasNext() && !checkRowFieldNotNull(row)) {
  9. row = iterator.next();
  10. }
  11. Map<String,Integer> labelIndexMap = null;
  12. if (binary) {
  13. // 二分法在这里
  14. binaryMetricsSummary = new BinaryMetricsSummary(
  15. new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],labelIndexLabelArray.f1,0L);
  16. } else {
  17. //
  18. labelIndexMap = labelIndexLabelArray.f0; // 前文生成的<labels,ID>Map看来是多分类才用到。
  19. multiMetricsSummary = new MultiMetricsSummary(
  20. new long[labelIndexMap.size()][labelIndexMap.size()],0L);
  21. }
  22. while (null != row) {
  23. if (checkRowFieldNotNull(row)) {
  24. TreeMap<String,Double> labelProbMap = extractLabelProbMap(row);
  25. String label = row.getField(0).toString();
  26. if (ArrayUtils.indexOf(labelIndexLabelArray.f1,label) >= 0) {
  27. if (binary) {
  28. // 二分法在这里
  29. updateBinaryMetricsSummary(labelProbMap,label,binaryMetricsSummary);
  30. } else {
  31. updateMultiMetricsSummary(labelProbMap,labelIndexMap,multiMetricsSummary);
  32. }
  33. }
  34. }
  35. row = iterator.hasNext() ? iterator.next() : null;
  36. }
  37. return binary ? binaryMetricsSummary : multiMetricsSummary;
  38. }
  39. //变量如下
  40. tuple = {Tuple2@9252} "({prefix1=0,prefix0])"
  41. f0 = {HashMap@9257} size = 2
  42. "prefix1" -> {Integer@9264} 0
  43. "prefix0" -> {Integer@9266} 1
  44. f1 = {String[2]@9258}
  45. 0 = "prefix1"
  46. 1 = "prefix0"
  47. row = {Row@9271} "prefix1,"prefix0": 0.2}"
  48. fields = {Object[2]@9276}
  49. 0 = "prefix1"
  50. 1 = "{"prefix1": 0.8,"prefix0": 0.2}"
  51. labelIndexLabelArray = {Tuple2@9240} "({prefix1=0,prefix0])"
  52. f0 = {HashMap@9288} size = 2
  53. "prefix1" -> {Integer@9294} 0
  54. "prefix0" -> {Integer@9296} 1
  55. f1 = {String[2]@9242}
  56. 0 = "prefix1"
  57. 1 = "prefix0"
  58. labelProbMap = {TreeMap@9342} size = 2
  59. "prefix0" -> {Double@9378} 0.1
  60. "prefix1" -> {Double@9380} 0.9

先回忆下混淆矩阵:

预测值 0 预测值 1
真实值 0 TN FP
真实值 1 FN TP

针对混淆矩阵,BinaryMetricsSummary 的作用是Save the evaluation data for binary classification。函数具体计算思路是:

  • 把 [0,1] 分成ClassificationEvaluationUtil.DETAIL_BIN_NUMBER(100000)这么多桶(bin)。所以binaryMetricsSummary的positiveBin/negativeBin分别是两个100000的数组。如果某一个 sample 为 正例(positive value) 的概率是 p,则该 sample 对应的 bin index 就是 p * 100000。如果 p 被预测为正例(positive value) ,则positiveBin[index]++,否则就是被预测为负例(negative value) ,则negativeBin[index]++。positiveBin就是 TP + FP,negativeBin就是 TN + FN。

  • 所以这里会遍历输入,如果某一个输入(以"prefix1",\"prefix0\": 0.1}"为例),0.9 是prefix1(正例) 的概率,0.1 是为prefix0(负例) 的概率。

    • 既然这个算法选择了 prefix1(正例) ,所以就说明此算法是判别成 positive 的,所以在 positiveBin 的 90000 处 + 1。
    • 假设这个算法选择了 prefix0(负例) ,则说明此算法是判别成 negative 的,所以应该在 negativeBin 的 90000 处 + 1。

具体对应我们示例代码的5个采样,分类如下:

  1. Row.of("prefix1",positiveBin 90000处+1
  2. Row.of("prefix1",positiveBin 80000处+1
  3. Row.of("prefix1",positiveBin 70000处+1
  4. Row.of("prefix0",negativeBin 75000处+1
  5. Row.of("prefix0",\"prefix0\": 0.4}") negativeBin 60000处+1

具体代码如下

  1. public static void updateBinaryMetricsSummary(TreeMap<String,Double> labelProbMap,String label,BinaryMetricsSummary binaryMetricsSummary) {
  2. binaryMetricsSummary.total++;
  3. binaryMetricsSummary.logLoss += extractLogloss(labelProbMap,label);
  4. double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
  5. int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
  6. (int)Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
  7. if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
  8. if (label.equals(binaryMetricsSummary.labels[0])) {
  9. binaryMetricsSummary.positiveBin[idx] += 1;
  10. } else if (label.equals(binaryMetricsSummary.labels[1])) {
  11. binaryMetricsSummary.negativeBin[idx] += 1;
  12. } else {
  13. .....
  14. }
  15. }
  16. }
  17. private static double extractLogloss(TreeMap<String,String label) {
  18. Double prob = labelProbMap.get(label);
  19. prob = null == prob ? 0. : prob;
  20. return -Math.log(Math.max(Math.min(prob,1 - LOG_LOSS_EPS),LOG_LOSS_EPS));
  21. }
  22. // 变量如下
  23. ClassificationEvaluationUtil.DETAIL_BIN_NUMBER=100000
  24. // 当 "prefix1",\"prefix0\": 0.1}" 时候
  25. labelProbMap = {TreeMap@9305} size = 2
  26. "prefix0" -> {Double@9331} 0.1
  27. "prefix1" -> {Double@9333} 0.9
  28. d = 0.9
  29. idx = 90000
  30. binaryMetricsSummary = {BinaryMetricsSummary@9262}
  31. labels = {String[2]@9242}
  32. 0 = "prefix1"
  33. 1 = "prefix0"
  34. total = 1
  35. positiveBin = {long[100000]@9263} // 90000处+1
  36. negativeBin = {long[100000]@9264}
  37. logLoss = 0.10536051565782628
  38. // 当 "prefix0",\"prefix0\": 0.4}" 时候
  39. labelProbMap = {TreeMap@9514} size = 2
  40. "prefix0" -> {Double@9546} 0.4
  41. "prefix1" -> {Double@9547} 0.6
  42. d = 0.6
  43. idx = 60000
  44. binaryMetricsSummary = {BinaryMetricsSummary@9262}
  45. labels = {String[2]@9242}
  46. 0 = "prefix1"
  47. 1 = "prefix0"
  48. total = 2
  49. positiveBin = {long[100000]@9263}
  50. negativeBin = {long[100000]@9264} // 60000处+1
  51. logLoss = 1.0216512475319812

3.2.2 ReduceBaseMetrics

ReduceBaseMetrics作用是把局部计算的 BaseMetrics 聚合起来。

  1. DataSet<BaseMetricsSummary> metrics = res
  2. .reduce(new EvaluationUtil.ReduceBaseMetrics());

ReduceBaseMetrics如下

  1. public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
  2. @Override
  3. public BaseMetricsSummary reduce(BaseMetricsSummary t1,BaseMetricsSummary t2) throws Exception {
  4. return null == t1 ? t2 : t1.merge(t2);
  5. }
  6. }

具体计算是在BinaryMetricsSummary.merge,其作用就是Merge the bins,and add the logLoss。

  1. @Override
  2. public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
  3. for (int i = 0; i < this.positiveBin.length; i++) {
  4. this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
  5. }
  6. for (int i = 0; i < this.negativeBin.length; i++) {
  7. this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
  8. }
  9. this.logLoss += binaryClassMetrics.logLoss;
  10. this.total += binaryClassMetrics.total;
  11. return this;
  12. }
  13. // 程序变量是
  14. this = {BinaryMetricsSummary@9316}
  15. labels = {String[2]@9322}
  16. 0 = "prefix1"
  17. 1 = "prefix0"
  18. total = 2
  19. positiveBin = {long[100000]@9320}
  20. negativeBin = {long[100000]@9323}
  21. logLoss = 1.742969305058623

3.2.3 SaveDataAsParams

  1. this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),new TypeInformation[] {Types.STRING});

当归并所有BaseMetrics之后,得到了total BaseMetrics,计算indexes,存入到params。

  1. public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary,Row> {
  2. @Override
  3. public void flatMap(BaseMetricsSummary t,Collector<Row> collector) throws Exception {
  4. collector.collect(t.toMetrics().serialize());
  5. }
  6. }

实际业务在BinaryMetricsSummary.toMetrics中完成,即基于bin的信息计算,得到confusionMatrix array,rocCurve/recallPrecisionCurve/LiftChart等等,然后存储到params。

  1. public BinaryClassMetrics toMetrics() {
  2. Params params = new Params();
  3. // 生成若干曲线,比如rocCurve/recallPrecisionCurve/LiftChart
  4. Tuple3<ConfusionMatrix[],double[],EvaluationCurve[]> matrixThreCurve =
  5. extractMatrixThreCurve(positiveBin,negativeBin,total);
  6. // 依据曲线内容计算并且存储 AUC/PRC/KS
  7. setCurveAreaParams(params,matrixThreCurve.f2);
  8. // 对生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
  9. Tuple3<ConfusionMatrix[],EvaluationCurve[]> sampledMatrixThreCurve = sample(
  10. PROBABILITY_INTERVAL,matrixThreCurve);
  11. // 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
  12. setCurvePointsParams(params,sampledMatrixThreCurve);
  13. ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0;
  14. // 存储正例样本的度量指标
  15. setComputationsArrayParams(params,sampledMatrixThreCurve.f1,sampledMatrixThreCurve.f0);
  16. // 存储Logloss
  17. setLoglossParams(params,logLoss,total);
  18. // Pick the middle point where threshold is 0.5.
  19. int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);
  20. setMiddleThreParams(params,matrices[middleIndex],labels);
  21. return new BinaryClassMetrics(params);
  22. }

extractMatrixThreCurve是全文重点。这里是 Extract the bins who are not empty,keep the middle threshold 0.5,然后初始化了 RocCurve,Recall-Precision Curve and Lift Curve,计算出ConfusionMatrix array(混淆矩阵),rocCurve/recallPrecisionCurve/LiftChart.。

  1. /**
  2. * Extract the bins who are not empty,keep the middle threshold 0.5.
  3. * Initialize the RocCurve,Recall-Precision Curve and Lift Curve.
  4. * RocCurve: (FPR,TPR),starts with (0,0). Recall-Precision Curve: (recall,precision),p),p is the precision with the lowest. LiftChart: (TP+FP/total,TP),0). confusion matrix = [TP FP][FN * TN].
  5. *
  6. * @param positiveBin positiveBins.
  7. * @param negativeBin negativeBins.
  8. * @param total sample number
  9. * @return ConfusionMatrix array,rocCurve/recallPrecisionCurve/LiftChart.
  10. */
  11. static Tuple3<ConfusionMatrix[],EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin,long[] negativeBin,long total) {
  12. ArrayList<Integer> effectiveIndices = new ArrayList<>();
  13. long totalTrue = 0,totalFalse = 0;
  14. // 计算totalTrue,totalFalse,effectiveIndices
  15. for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
  16. if (0L != positiveBin[i] || 0L != negativeBin[i]
  17. || i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
  18. effectiveIndices.add(i);
  19. totalTrue += positiveBin[i];
  20. totalFalse += negativeBin[i];
  21. }
  22. }
  23. // 以我们例子,得到
  24. effectiveIndices = {ArrayList@9273} size = 6
  25. 0 = {Integer@9277} 50000 //这里加入了中间点
  26. 1 = {Integer@9278} 60000
  27. 2 = {Integer@9279} 70000
  28. 3 = {Integer@9280} 75000
  29. 4 = {Integer@9281} 80000
  30. 5 = {Integer@9282} 90000
  31. totalTrue = 3
  32. totalFalse = 2
  33. // 继续初始化,生成若干curve
  34. final int length = effectiveIndices.size();
  35. final int newLen = length + 1;
  36. final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
  37. EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
  38. EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
  39. EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
  40. ConfusionMatrix[] data = new ConfusionMatrix[newLen];
  41. double[] threshold = new double[newLen];
  42. long curTrue = 0;
  43. long curFalse = 0;
  44. // 以我们例子,得到
  45. length = 6
  46. newLen = 7
  47. m = 1.0E-5
  48. // 计算,其中rocCurve,recallPrecisionCurve,liftChart 都可以从代码中看出
  49. for (int i = 1; i < newLen; i++) {
  50. int index = effectiveIndices.get(length - i);
  51. curTrue += positiveBin[index];
  52. curFalse += negativeBin[index];
  53. threshold[i] = index * m;
  54. // 计算出混淆矩阵
  55. data[i] = new ConfusionMatrix(
  56. new long[][] {{curTrue,totalFalse - curFalse}});
  57. double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
  58. // 比如当 90000 这点,得到 curTrue = 1 curFalse = 0 i = 1 index = 90000 tpr = 0.3333333333333333。totalTrue = 3 totalFalse = 2,
  59. // 我们也知道,TPR = TP / (TP + FN) ,所以可以计算 tpr = 1 / 3
  60. rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse,tpr,threshold[i]);
  61. recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr,curTrue + curTrue == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse),threshold[i]);
  62. liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total,curTrue,threshold[i]);
  63. }
  64. // 以我们例子,得到
  65. curTrue = 3
  66. curFalse = 2
  67. threshold = {double[7]@9349}
  68. 0 = 0.0
  69. 1 = 0.9
  70. 2 = 0.8
  71. 3 = 0.7500000000000001
  72. 4 = 0.7000000000000001
  73. 5 = 0.6000000000000001
  74. 6 = 0.5
  75. rocCurve = {EvaluationCurvePoint[7]@9315}
  76. 1 = {EvaluationCurvePoint@9440}
  77. x = 0.0
  78. y = 0.3333333333333333
  79. p = 0.9
  80. 2 = {EvaluationCurvePoint@9448}
  81. x = 0.0
  82. y = 0.6666666666666666
  83. p = 0.8
  84. 3 = {EvaluationCurvePoint@9449}
  85. x = 0.5
  86. y = 0.6666666666666666
  87. p = 0.7500000000000001
  88. 4 = {EvaluationCurvePoint@9450}
  89. x = 0.5
  90. y = 1.0
  91. p = 0.7000000000000001
  92. 5 = {EvaluationCurvePoint@9451}
  93. x = 1.0
  94. y = 1.0
  95. p = 0.6000000000000001
  96. 6 = {EvaluationCurvePoint@9452}
  97. x = 1.0
  98. y = 1.0
  99. p = 0.5
  100. recallPrecisionCurve = {EvaluationCurvePoint[7]@9320}
  101. 1 = {EvaluationCurvePoint@9444}
  102. x = 0.3333333333333333
  103. y = 1.0
  104. p = 0.9
  105. 2 = {EvaluationCurvePoint@9453}
  106. x = 0.6666666666666666
  107. y = 1.0
  108. p = 0.8
  109. 3 = {EvaluationCurvePoint@9454}
  110. x = 0.6666666666666666
  111. y = 0.6666666666666666
  112. p = 0.7500000000000001
  113. 4 = {EvaluationCurvePoint@9455}
  114. x = 1.0
  115. y = 0.75
  116. p = 0.7000000000000001
  117. 5 = {EvaluationCurvePoint@9456}
  118. x = 1.0
  119. y = 0.6
  120. p = 0.6000000000000001
  121. 6 = {EvaluationCurvePoint@9457}
  122. x = 1.0
  123. y = 0.6
  124. p = 0.5
  125. liftChart = {EvaluationCurvePoint[7]@9325}
  126. 1 = {EvaluationCurvePoint@9458}
  127. x = 0.2
  128. y = 1.0
  129. p = 0.9
  130. 2 = {EvaluationCurvePoint@9459}
  131. x = 0.4
  132. y = 2.0
  133. p = 0.8
  134. 3 = {EvaluationCurvePoint@9460}
  135. x = 0.6
  136. y = 2.0
  137. p = 0.7500000000000001
  138. 4 = {EvaluationCurvePoint@9461}
  139. x = 0.8
  140. y = 3.0
  141. p = 0.7000000000000001
  142. 5 = {EvaluationCurvePoint@9462}
  143. x = 1.0
  144. y = 3.0
  145. p = 0.6000000000000001
  146. 6 = {EvaluationCurvePoint@9463}
  147. x = 1.0
  148. y = 3.0
  149. p = 0.5
  150. data = {ConfusionMatrix[7]@9339}
  151. 0 = {ConfusionMatrix@9486}
  152. longMatrix = {LongMatrix@9488}
  153. matrix = {long[2][]@9491}
  154. 0 = {long[2]@9492}
  155. 0 = 0
  156. 1 = 0
  157. 1 = {long[2]@9493}
  158. 0 = 3
  159. 1 = 2
  160. rowNum = 2
  161. colNum = 2
  162. labelCnt = 2
  163. total = 5
  164. actualLabelFrequency = {long[2]@9489}
  165. 0 = 3
  166. 1 = 2
  167. predictLabelFrequency = {long[2]@9490}
  168. 0 = 0
  169. 1 = 5
  170. tpCount = 2.0
  171. tnCount = 2.0
  172. fpCount = 3.0
  173. fnCount = 3.0
  174. 1 = {ConfusionMatrix@9435}
  175. longMatrix = {LongMatrix@9469}
  176. matrix = {long[2][]@9472}
  177. 0 = {long[2]@9474}
  178. 0 = 1
  179. 1 = 0
  180. 1 = {long[2]@9475}
  181. 0 = 2
  182. 1 = 2
  183. rowNum = 2
  184. colNum = 2
  185. labelCnt = 2
  186. total = 5
  187. actualLabelFrequency = {long[2]@9470}
  188. 0 = 3
  189. 1 = 2
  190. predictLabelFrequency = {long[2]@9471}
  191. 0 = 1
  192. 1 = 4
  193. tpCount = 3.0
  194. tnCount = 3.0
  195. fpCount = 2.0
  196. fnCount = 2.0
  197. ......
  198. threshold[0] = 1.0;
  199. data[0] = new ConfusionMatrix(new long[][] {{0,0},{totalTrue,totalFalse}});
  200. rocCurve[0] = new EvaluationCurvePoint(0,threshold[0]);
  201. recallPrecisionCurve[0] = new EvaluationCurvePoint(0,recallPrecisionCurve[1].getY(),threshold[0]);
  202. liftChart[0] = new EvaluationCurvePoint(0,threshold[0]);
  203. return Tuple3.of(data,threshold,new EvaluationCurve[] {new EvaluationCurve(rocCurve),new EvaluationCurve(recallPrecisionCurve),new EvaluationCurve(liftChart)});
  204. }

3.2.4 计算混淆矩阵

这里再给大家讲讲混淆矩阵如何计算,这里思路比较绕。

3.2.4.1 原始矩阵

调用之处是:

  1. // 调用之处
  2. data[i] = new ConfusionMatrix(
  3. new long[][] {{curTrue,totalFalse - curFalse}});
  4. // 调用时候各种赋值
  5. i = 1
  6. index = 90000
  7. totalTrue = 3
  8. totalFalse = 2
  9. curTrue = 1
  10. curFalse = 0

得到原始矩阵,以下都有cur,说明只针对当前点来说

curTrue = 1 curFalse = 0
totalTrue - curTrue = 2 totalFalse - curFalse = 2
3.2.4.2 计算标签

后续ConfusionMatrix计算中,由此可以得到

  1. actualLabelFrequency = longMatrix.getColSums();
  2. predictLabelFrequency = longMatrix.getRowSums();
  3. actualLabelFrequency = {long[2]@9322}
  4. 0 = 3
  5. 1 = 2
  6. predictLabelFrequency = {long[2]@9323}
  7. 0 = 1
  8. 1 = 4

可以看出来,Alink算法认为:每列的sum和实际标签有关;每行sum和预测标签有关。

得到新矩阵如下

predictLabelFrequency
curTrue = 1 curFalse = 0 1 = curTrue + curFalse
totalTrue - curTrue = 2 totalFalse - curFalse = 2 4 = total - curTrue - curFalse
actualLabelFrequency 3 = totalTrue 2 = totalFalse

后续计算将要基于这些来计算:

计算中就用到longMatrix 对角线上的数据,即longMatrix(0)(0)和 longMatrix(1)(1)。一定要注意,这里考虑的都是 当前状态 (画重点强调)

longMatrix(0)(0) :curTrue

longMatrix(1)(1) :totalFalse - curFalse

totalFalse :( TN + FN )

totalTrue :( TP + FP )

  1. double numTrueNegative(Integer labelIndex) {
  2. // labelIndex为 0 时候,return 1 + 5 - 1 - 3 = 2;
  3. // labelIndex为 1 时候,return 2 + 5 - 4 - 2 = 1;
  4. return null == labelIndex ? tnCount : longMatrix.getValue(labelIndex,labelIndex) + total - predictLabelFrequency[labelIndex] - actualLabelFrequency[labelIndex];
  5. }
  6. double numTruePositive(Integer labelIndex) {
  7. // labelIndex为 0 时候,return 1; 这个是 curTrue,就是真实标签是True,判别也是True。是TP
  8. // labelIndex为 1 时候,return 2; 这个是 totalFalse - curFalse,总判别错 - 当前判别错。这就意味着“本来判别错了但是当前没有发现”,所以认为在当前状态下,这也算是TP
  9. return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex,labelIndex);
  10. }
  11. double numFalseNegative(Integer labelIndex) {
  12. // labelIndex为 0 时候,return 3 - 1;
  13. // actualLabelFrequency[0] = totalTrue。所以return totalTrue - curTrue,即当前“全部正确”中没有“判别为正确”,这个就可以认为是“判别错了且判别为负”
  14. // labelIndex为 1 时候,return 2 - 2;
  15. // actualLabelFrequency[1] = totalFalse。所以return totalFalse - ( totalFalse - curFalse ) = curFalse
  16. return null == labelIndex ? fnCount : actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex,labelIndex);
  17. }
  18. double numFalsePositive(Integer labelIndex) {
  19. // labelIndex为 0 时候,return 1 - 1;
  20. // predictLabelFrequency[0] = curTrue + curFalse。
  21. // 所以 return = curTrue + curFalse - curTrue = curFalse = current( TN + FN ) 这可以认为是判断错了实际是正确标签
  22. // labelIndex为 1 时候,return 4 - 2;
  23. // predictLabelFrequency[1] = total - curTrue - curFalse。
  24. // 所以 return = total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue = ( TP + FP ) - currentTP = currentFP
  25. return null == labelIndex ? fpCount : predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex,labelIndex);
  26. }
  27. // 最后得到
  28. tpCount = 3.0
  29. tnCount = 3.0
  30. fpCount = 2.0
  31. fnCount = 2.0
3.2.4.3 具体代码
  1. // 具体计算
  2. public ConfusionMatrix(LongMatrix longMatrix) {
  3. longMatrix = {LongMatrix@9297}
  4. 0 = {long[2]@9324}
  5. 0 = 1
  6. 1 = 0
  7. 1 = {long[2]@9325}
  8. 0 = 2
  9. 1 = 2
  10. this.longMatrix = longMatrix;
  11. labelCnt = this.longMatrix.getRowNum();
  12. // 这里就是计算
  13. actualLabelFrequency = longMatrix.getColSums();
  14. predictLabelFrequency = longMatrix.getRowSums();
  15. actualLabelFrequency = {long[2]@9322}
  16. 0 = 3
  17. 1 = 2
  18. predictLabelFrequency = {long[2]@9323}
  19. 0 = 1
  20. 1 = 4
  21. labelCnt = 2
  22. total = 5
  23. total = longMatrix.getTotal();
  24. for (int i = 0; i < labelCnt; i++) {
  25. tnCount += numTrueNegative(i);
  26. tpCount += numTruePositive(i);
  27. fnCount += numFalseNegative(i);
  28. fpCount += numFalsePositive(i);
  29. }
  30. }

0x04 流处理

4.1 示例

Alink原有python示例代码中,Stream部分是没有输出的,因为MemSourceStreamOp没有和时间相关联,而Alink中没有提供基于时间的StreamOperator,所以只能自己仿照MemSourceBatchOp写了一个。虽然代码有些丑,但是至少可以提供输出,这样就能够调试。

4.1.1 主类

  1. public class EvalBinaryClassExampleStream {
  2. AlgoOperator getData(boolean isBatch) {
  3. Row[] rows = new Row[]{
  4. Row.of("prefix1",\"prefix0\": 0.1}")
  5. };
  6. String[] schema = new String[]{"label","detailInput"};
  7. if (isBatch) {
  8. return new MemSourceBatchOp(rows,schema);
  9. } else {
  10. return new TimeMemSourceStreamOp(rows,schema,new EvalBinaryStreamSource());
  11. }
  12. }
  13. public static void main(String[] args) throws Exception {
  14. EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
  15. StreamOperator streamData = (StreamOperator) test.getData(false);
  16. StreamOperator sOp = new EvalBinaryClassStreamOp()
  17. .setLabelCol("label")
  18. .setPredictionDetailCol("detailInput")
  19. .setTimeInterval(1)
  20. .linkFrom(streamData);
  21. sOp.print();
  22. StreamOperator.execute();
  23. }
  24. }

4.1.2 TimeMemSourceStreamOp

这个是我自己炮制的。借鉴了MemSourceStreamOp。

  1. public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {
  2. public TimeMemSourceStreamOp(Row[] rows,String[] colNames,EvalBinaryStrSource source) {
  3. super(null);
  4. init(source,Arrays.asList(rows),colNames);
  5. }
  6. private void init(EvalBinaryStreamSource source,List <Row> rows,String[] colNames) {
  7. Row first = rows.iterator().next();
  8. int arity = first.getArity();
  9. TypeInformation <?>[] types = new TypeInformation[arity];
  10. for (int i = 0; i < arity; ++i) {
  11. types[i] = TypeExtractor.getForObject(first.getField(i));
  12. }
  13. init(source,colNames,types);
  14. }
  15. private void init(EvalBinaryStreamSource source,TypeInformation <?>[] colTypes) {
  16. DataStream <Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
  17. .getStreamExecutionEnvironment().addSource(source);
  18. StringBuilder sbd = new StringBuilder();
  19. sbd.append(colNames[0]);
  20. for (int i = 1; i < colNames.length; i++) {
  21. sbd.append(",").append(colNames[i]);
  22. }
  23. this.setOutput(dastr,colTypes);
  24. }
  25. @Override
  26. public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
  27. return null;
  28. }
  29. }

4.1.3 Source

定时提供Row,加入了随机数,让概率有变化。

  1. class EvalBinaryStreamSource extends RichSourceFunction[Row] {
  2. override def run(ctx: SourceFunction.SourceContext[Row]) = {
  3. while (true) {
  4. val rdm = Math.random() // 这里加入了随机数,让概率有变化
  5. val rows: Array[Row] = Array[Row](
  6. Row.of("prefix1","{\"prefix1\": " + rdm + ",\"prefix0\": " + (1-rdm) + "}"),\"prefix0\": 0.4}"))
  7. for(row <- rows) {
  8. println(s"当前值:$row")
  9. ctx.collect(row)
  10. }
  11. Thread.sleep(1000)
  12. }
  13. }
  14. override def cancel() = ???
  15. }

4.2 BaseEvalClassStreamOp

Alink流处理类是 EvalBinaryClassStreamOp,主要工作在其基类 BaseEvalClassStreamOp,所以我们重点看后者。

  1. public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
  2. @Override
  3. public T linkFrom(StreamOperator<?>... inputs) {
  4. StreamOperator<?> in = checkAndGetFirst(inputs);
  5. String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
  6. String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
  7. Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL);
  8. ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
  9. DataStream<BaseMetricsSummary> statistics;
  10. switch (type) {
  11. case PRED_RESULT: {
  12. ......
  13. }
  14. case PRED_DETAIL: {
  15. String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
  16. //
  17. PredDetailLabel eval = new PredDetailLabel(positiveValue,binary);
  18. // 获取输入数据,重点是timeWindowAll
  19. statistics = in.select(new String[] {labelColName,predDetailColName})
  20. .getDataStream()
  21. .timeWindowAll(Time.of(timeInterval,TimeUnit.SECONDS))
  22. .apply(eval);
  23. break;
  24. }
  25. }
  26. // 把各个窗口的数据累积到 totalStatistics,注意,这里是新变量了。
  27. DataStream<BaseMetricsSummary> totalStatistics = statistics
  28. .map(new EvaluationUtil.AllDataMerge())
  29. .setParallelism(1); // 并行度设置为1
  30. // 基于两种 bins 计算&序列化,得到当前的 statistics
  31. DataStream<Row> windowOutput = statistics.map(
  32. new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
  33. // 基于bins计算&序列化,得到累积的 totalStatistics
  34. DataStream<Row> allOutput = totalStatistics.map(
  35. new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
  36. // "当前" 和 "累积" 做联合,最终返回
  37. DataStream<Row> union = windowOutput.union(allOutput);
  38. this.setOutput(union,new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT,DATA_OUTPUT},new TypeInformation[] {Types.STRING,Types.STRING});
  39. return (T)this;
  40. }
  41. }

具体业务是:

  • PredDetailLabel 会进行去重标签名字 和 累积计算混淆矩阵所需数据
    • buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,最后结果是一个<labels,ID>Map。
    • getDetailStatistics 遍历 rows 数据,提取每一个item(比如 "prefix1,"prefix0": 0.2}"),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
  • 根据标签从Window中获取数据 statistics = in.select().getDataStream().timeWindowAll() .apply(eval);
  • EvaluationUtil.AllDataMerge 把各个窗口的数据累积到 totalStatistics 。
  • 得到windowOutput -------- EvaluationUtil.SaveDataStream,对"当前数据statistics"做处理。实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params,并序列化返回Row。
    • extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵),rocCurve/recallPrecisionCurve/LiftChart.
    • 依据曲线内容计算并且存储 AUC/PRC/KS
    • 生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
    • 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
    • 存储正例样本的度量指标
    • 存储Logloss
    • Pick the middle point where threshold is 0.5.
  • 得到allOutput -------- EvaluationUtil.SaveDataStream,对"累积数据totalStatistics"做处理。
    • 详细处理流程同windowOutput。
  • windowOutput 和 allOutput 做联合。最终返回 DataStream union = windowOutput.union(allOutput);

4.2.1 PredDetailLabel

  1. static class PredDetailLabel implements AllWindowFunction<Row,BaseMetricsSummary,TimeWindow> {
  2. @Override
  3. public void apply(TimeWindow timeWindow,Iterable<Row> rows,Collector<BaseMetricsSummary> collector) throws Exception {
  4. HashSet<String> labels = new HashSet<>();
  5. // 首先还是获取 labels 名字
  6. for (Row row : rows) {
  7. if (EvaluationUtil.checkRowFieldNotNull(row)) {
  8. labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
  9. labels.add(row.getField(0).toString());
  10. }
  11. }
  12. labels = {HashSet@9757} size = 2
  13. 0 = "prefix1"
  14. 1 = "prefix0"
  15. // 之前介绍过,buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,最后结果是一个<labels,ID>Map。
  16. // getDetailStatistics 遍历 rows 数据,累积计算混淆矩阵所需数据( "TP + FN" / "TN + FP")。
  17. if (labels.size() > 0) {
  18. collector.collect(
  19. getDetailStatistics(rows,buildLabelIndexLabelArray(labels,positiveValue)));
  20. }
  21. }
  22. }

4.2.2 AllDataMerge

EvaluationUtil.AllDataMerge 把各个窗口的数据累积

  1. /**
  2. * Merge data from different windows.
  3. */
  4. public static class AllDataMerge implements MapFunction<BaseMetricsSummary,BaseMetricsSummary> {
  5. private BaseMetricsSummary statistics;
  6. @Override
  7. public BaseMetricsSummary map(BaseMetricsSummary value) {
  8. this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
  9. return this.statistics;
  10. }
  11. }

4.2.3 SaveDataStream

SaveDataStream具体调用函数之前批处理介绍过,实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,存储到params。

这里与批处理不同的是直接就把"构建出的度量信息“返回给用户

  1. public static class SaveDataStream implements MapFunction<BaseMetricsSummary,Row> {
  2. @Override
  3. public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
  4. BaseMetricsSummary metrics = baseMetricsSummary;
  5. BaseMetrics baseMetrics = metrics.toMetrics();
  6. Row row = baseMetrics.serialize();
  7. return Row.of(funtionName,row.getField(0));
  8. }
  9. }
  10. // 最后得到的 row 其实就是最终返回给用户的度量信息
  11. row = {Row@10008} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,1.0]" ...... 还有很多其他的

4.2.4 Union

  1. DataStream<Row> windowOutput = statistics.map(
  2. new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
  3. DataStream<Row> allOutput = totalStatistics.map(
  4. new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
  5. DataStream<Row> union = windowOutput.union(allOutput);

最后返回两种统计数据

4.2.4.1 allOutput
  1. all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.7333333333333333,0.8,0.8666666666666667,0.9333333333333333,10],[2,"MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","TruePositiveRateArray":"[0.3333333333333333,"AUC":"0.5666666666666667","MacroAccuracy":"0.52",......

4.2.4.2 windowOutput

  1. window|{"PRC":"0.7638888888888888","ConfusionMatrix":"[[3,2],"AUC":"0.6666666666666666","MacroAccuracy":"0.6","RecallArray":"[0.3333333333333333,"KappaArray":"[0.28571428571428564,-0.15384615384615377,0.1666666666666666,0.5454545454545455,0.0]","MicroFalseNegativeRate":"0.4","WeightedRecall":"0.6","WeightedPrecision":"0.36","Recall":"1.0","MacroPrecision":"0.3",......

0xFF 参考

[[白话解析] 通过实例来梳理概念 :准确率 (Accuracy)、精准率(Precision)、召回率(Recall) 和 F值(F-Measure)](

猜你在找的大数据相关文章