Evaluate data frame analytics API

Evaluates the data frame analytics for an annotated index. The API accepts an EvaluateDataFrameRequest object and returns an EvaluateDataFrameResponse.

Evaluate data frame analytics request

EvaluateDataFrameRequest request =
    new EvaluateDataFrameRequest( 
        indexName, 
        new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), 
        evaluation); 

Constructing a new evaluation request

Reference to an existing index

The query used to select the data to evaluate from the index

Evaluation to be performed

Evaluation

The evaluation to be performed. Currently supported evaluations are OutlierDetection, Classification, and Regression.

Outlier detection

Evaluation evaluation =
    new OutlierDetection( 
        "label", 
        "p", 
        // Evaluation metrics 
        org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), 
        org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), 
        ConfusionMatrixMetric.at(0.5), 
        org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); 

Constructing a new evaluation

Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false.

Name of the field in the index. Its value denotes the probability, as computed by an ML algorithm, that the example belongs to the positive class.

The remaining parameters are the metrics to be calculated based on the two fields described above.

Precision calculated at thresholds: 0.4, 0.5 and 0.6

Recall calculated at thresholds: 0.5 and 0.7

Confusion matrix calculated at threshold 0.5

AUC ROC calculated, with the curve points returned
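To make the threshold semantics concrete, here is a minimal plain-Java sketch of what precision, recall, and the confusion matrix count at a single threshold. It uses no client classes and is not the Elasticsearch implementation; the class and method names are hypothetical, and it assumes an example is predicted positive when its probability is at or above the threshold.

```java
/**
 * Hypothetical plain-Java sketch of threshold-based binary metrics.
 * Assumption: an example is predicted positive when probability >= threshold.
 */
final class OutlierMetricsSketch {

    private OutlierMetricsSketch() {}

    /** Returns the counts {truePositives, falsePositives, trueNegatives, falseNegatives}. */
    static long[] confusionAt(boolean[] actual, double[] probability, double threshold) {
        if (actual.length != probability.length) {
            throw new IllegalArgumentException("actual and probability must have the same length");
        }
        long tp = 0, fp = 0, tn = 0, fn = 0;
        for (int i = 0; i < actual.length; i++) {
            boolean predictedPositive = probability[i] >= threshold;
            if (predictedPositive && actual[i]) {
                tp++;
            } else if (predictedPositive) {
                fp++;
            } else if (actual[i]) {
                fn++;
            } else {
                tn++;
            }
        }
        return new long[] {tp, fp, tn, fn};
    }

    /** Precision = TP / (TP + FP); NaN if no example is predicted positive. */
    static double precisionAt(boolean[] actual, double[] probability, double threshold) {
        long[] counts = confusionAt(actual, probability, threshold);
        return (double) counts[0] / (counts[0] + counts[1]);
    }

    /** Recall = TP / (TP + FN); NaN if there are no actual positives. */
    static double recallAt(boolean[] actual, double[] probability, double threshold) {
        long[] counts = confusionAt(actual, probability, threshold);
        return (double) counts[0] / (counts[0] + counts[3]);
    }
}
```

For example, with actual labels {true, true, false, false, true} and probabilities {0.9, 0.4, 0.6, 0.2, 0.7}, lowering the threshold from 0.5 to 0.4 moves the second example into the positive set: recall rises from 2/3 to 1.0 and precision goes from 2/3 to 0.75.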

Classification

Evaluation evaluation =
    new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( 
        "actual_class", 
        "predicted_class", 
        "ml.top_classes", 
        // Evaluation metrics 
        new AccuracyMetric(), 
        new PrecisionMetric(), 
        new RecallMetric(), 
        new MulticlassConfusionMatrixMetric(3), 
        AucRocMetric.forClass("cat")); 

Constructing a new evaluation

Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.

Name of the field in the index. Its value denotes the class predicted for the example by an ML algorithm.

Name of the field in the index. Its value denotes the array of top classes. Must be nested.

The remaining parameters are the metrics to be calculated based on the fields described above.

Accuracy

Precision

Recall

Multiclass confusion matrix of size 3

AUC ROC calculated with class "cat" treated as positive and all other classes treated as negative
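The sketch below shows, in plain Java, what accuracy and averaged precision and recall measure for a multiclass problem. It is a hypothetical helper, not part of the client, and it assumes the averages are unweighted (macro) means of the per-class values, taken over the classes that actually occur as predictions (for precision) or as ground truth (for recall). The Elasticsearch evaluation API reference is the authority on the exact definitions used server-side.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical plain-Java sketch of multiclass accuracy and macro-averaged
 * precision and recall. Assumption: averages are unweighted means over classes.
 */
final class ClassificationMetricsSketch {

    private ClassificationMetricsSketch() {}

    /** Fraction of examples whose predicted class equals the actual class. */
    static double accuracy(String[] actual, String[] predicted) {
        checkLengths(actual, predicted);
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i].equals(predicted[i])) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    /** Mean over predicted classes c of: correct predictions of c / all predictions of c. */
    static double avgPrecision(String[] actual, String[] predicted) {
        return macroAverage(predicted, actual, predicted);
    }

    /** Mean over actual classes c of: correct predictions of c / all examples of class c. */
    static double avgRecall(String[] actual, String[] predicted) {
        return macroAverage(actual, actual, predicted);
    }

    // Groups examples by the class found in groupBy, computes the per-class
    // fraction of correct predictions, and returns the unweighted mean.
    private static double macroAverage(String[] groupBy, String[] actual, String[] predicted) {
        checkLengths(actual, predicted);
        Map<String, int[]> counts = new HashMap<>(); // class -> {correct, total}
        for (int i = 0; i < groupBy.length; i++) {
            int[] c = counts.computeIfAbsent(groupBy[i], k -> new int[2]);
            if (actual[i].equals(predicted[i])) {
                c[0]++;
            }
            c[1]++;
        }
        double sum = 0;
        for (int[] c : counts.values()) {
            sum += (double) c[0] / c[1];
        }
        return sum / counts.size();
    }

    private static void checkLengths(String[] actual, String[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("actual and predicted must have the same length");
        }
    }
}
```

With actual classes {cat, cat, cat, dog} and predictions {cat, cat, dog, dog}, accuracy is 0.75, average precision is (1.0 + 0.5) / 2 = 0.75, and average recall is (2/3 + 1.0) / 2 = 5/6, which shows how the two averages can diverge.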

Regression

Evaluation evaluation =
    new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( 
        "actual_value", 
        "predicted_value", 
        // Evaluation metrics 
        new MeanSquaredErrorMetric(), 
        new MeanSquaredLogarithmicErrorMetric(1.0), 
        new HuberMetric(1.0), 
        new RSquaredMetric()); 

Constructing a new evaluation

Name of the field in the index. Its value denotes the actual (i.e. ground truth) value for an example.

Name of the field in the index. Its value denotes the value predicted for the example by an ML algorithm.

The remaining parameters are the metrics to be calculated based on the two fields described above.

Mean squared error

Mean squared logarithmic error

Pseudo Huber loss

R squared
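The regression metrics can be sketched as plain-Java formulas. This is an illustration under stated assumptions, not the server-side implementation, and the class name is hypothetical. It assumes the mean squared logarithmic error is computed on log(x + offset), with offset corresponding to the 1.0 passed to MeanSquaredLogarithmicErrorMetric, and that the pseudo Huber loss is delta^2 * (sqrt(1 + (error / delta)^2) - 1), with delta corresponding to the 1.0 passed to HuberMetric. Each loss is averaged over all examples.

```java
/**
 * Hypothetical plain-Java sketch of the regression metrics.
 * Assumptions: MSLE uses log(x + offset); pseudo Huber uses
 * delta^2 * (sqrt(1 + (error / delta)^2) - 1); losses are averaged.
 */
final class RegressionMetricsSketch {

    private RegressionMetricsSketch() {}

    static double meanSquaredError(double[] actual, double[] predicted) {
        checkLengths(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            double error = actual[i] - predicted[i];
            sum += error * error;
        }
        return sum / actual.length;
    }

    static double meanSquaredLogarithmicError(double[] actual, double[] predicted, double offset) {
        checkLengths(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            double error = Math.log(actual[i] + offset) - Math.log(predicted[i] + offset);
            sum += error * error;
        }
        return sum / actual.length;
    }

    static double pseudoHuber(double[] actual, double[] predicted, double delta) {
        checkLengths(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            double scaled = (actual[i] - predicted[i]) / delta;
            sum += delta * delta * (Math.sqrt(1 + scaled * scaled) - 1);
        }
        return sum / actual.length;
    }

    /** R squared = 1 - (residual sum of squares / total sum of squares). */
    static double rSquared(double[] actual, double[] predicted) {
        checkLengths(actual, predicted);
        double mean = 0;
        for (double y : actual) {
            mean += y;
        }
        mean /= actual.length;
        double ssRes = 0;
        double ssTot = 0;
        for (int i = 0; i < actual.length; i++) {
            double residual = actual[i] - predicted[i];
            double deviation = actual[i] - mean;
            ssRes += residual * residual;
            ssTot += deviation * deviation;
        }
        return 1 - ssRes / ssTot;
    }

    private static void checkLengths(double[] actual, double[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("actual and predicted must have the same length");
        }
    }
}
```

For actual values {1, 2, 3} and predictions {1, 2, 5}, the mean squared error is 4/3 and R squared is -1, because the residual sum of squares (4) is twice the total sum of squares (2). A negative R squared means the predictions fit worse than always predicting the mean.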

Synchronous execution

When an EvaluateDataFrameRequest is executed in the following manner, the client waits for the EvaluateDataFrameResponse to be returned before continuing with code execution:

EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);

Synchronous calls may throw an IOException if the high-level REST client fails to parse the REST response, if the request times out, or in similar cases where no response comes back from the server.

If the server returns a 4xx or 5xx error code, the high-level client tries to parse the error details in the response body, then throws a generic ElasticsearchException with the original ResponseException attached to it as a suppressed exception.

Asynchronous execution

An EvaluateDataFrameRequest can also be executed asynchronously, so that the client can return directly. Users need to specify how the response or potential failures will be handled by passing the request and a listener to the asynchronous evaluate-data-frame method:

client.machineLearning().evaluateDataFrameAsync(request, RequestOptions.DEFAULT, listener); 

The EvaluateDataFrameRequest to execute and the ActionListener to use when the execution completes

The asynchronous method does not block and returns immediately. Once the request completes, the ActionListener is called back through its onResponse method if the execution succeeded, or through its onFailure method if it failed. Failure scenarios and expected exceptions are the same as in the synchronous execution case.

A typical listener for evaluate-data-frame looks like:

ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() {
    @Override
    public void onResponse(EvaluateDataFrameResponse response) {
        // Called when the execution is successfully completed
    }

    @Override
    public void onFailure(Exception e) {
        // Called when the whole EvaluateDataFrameRequest fails
    }
};

Called when the execution is successfully completed.

Called when the whole EvaluateDataFrameRequest fails.
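The callback contract described above (the async call returns immediately, and the outcome is delivered later through onResponse or onFailure) can be illustrated without the client. The sketch below is hypothetical throughout: ResponseListener stands in for ActionListener, and executeAsync stands in for evaluateDataFrameAsync. It also shows one way to adapt the callbacks to a CompletableFuture when the caller eventually needs to wait for the result.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

/** Hypothetical stand-in with the same two callbacks as ActionListener. */
interface ResponseListener<T> {
    void onResponse(T response);

    void onFailure(Exception e);
}

final class AsyncSketch {

    private AsyncSketch() {}

    /**
     * Hypothetical async executor: runs the work on the given executor, returns
     * immediately, and reports the outcome through one of the listener's callbacks.
     */
    static <T> void executeAsync(ExecutorService executor, Supplier<T> work, ResponseListener<T> listener) {
        executor.execute(() -> {
            T result;
            try {
                result = work.get();
            } catch (Exception e) {
                listener.onFailure(e);
                return;
            }
            // Called outside the try block so that an exception thrown by
            // onResponse itself is not reported back through onFailure.
            listener.onResponse(result);
        });
    }

    /** Adapts the callback style to a CompletableFuture the caller can wait on. */
    static <T> ResponseListener<T> completing(CompletableFuture<T> future) {
        return new ResponseListener<T>() {
            @Override
            public void onResponse(T response) {
                future.complete(response);
            }

            @Override
            public void onFailure(Exception e) {
                future.completeExceptionally(e);
            }
        };
    }
}
```

Completing a future from the listener keeps the non-blocking call intact while still letting the calling code join on the result where blocking is acceptable. The executor passed in is owned by the caller and should be shut down when no longer needed.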

Response

The returned EvaluateDataFrameResponse contains the requested evaluation metrics.

List<EvaluationMetric.Result> metrics = response.getMetrics(); 

Fetching all the calculated metrics results

Results

Outlier detection

org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
    response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); 
double precision = precisionResult.getScoreByThreshold("0.4"); 

ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); 
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); 

Fetching precision metric by name

Fetching precision at a given (0.4) threshold

Fetching confusion matrix metric by name

Fetching confusion matrix at a given (0.5) threshold

Classification

AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); 
double accuracy = accuracyResult.getOverallAccuracy(); 

PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); 
double precision = precisionResult.getAvgPrecision(); 

RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); 
double recall = recallResult.getAvgRecall(); 

MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
    response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); 

List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); 
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); 

AucRocResult aucRocResult = response.getMetricByName(AucRocMetric.NAME); 
double aucRocScore = aucRocResult.getValue(); 

Fetching accuracy metric by name

Fetching the actual accuracy value

Fetching precision metric by name

Fetching the actual precision value

Fetching recall metric by name

Fetching the actual recall value

Fetching multiclass confusion matrix metric by name

Fetching the contents of the confusion matrix

Fetching the number of classes that were not included in the matrix

Fetching the AUC ROC metric by name

Fetching the actual AUC ROC score
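The AUC ROC score fetched above summarizes how well the predicted probability for one class (here "cat") separates examples of that class from all others. A minimal sketch of the textbook pairwise definition follows: AUC is the fraction of (positive, negative) example pairs in which the positive example has the higher score, with ties counting as one half. The helper is hypothetical and not necessarily how Elasticsearch computes the value; it assumes scores[i] holds the predicted probability of the positive class for example i.

```java
/**
 * Hypothetical sketch of one-vs-rest AUC ROC using the pairwise definition.
 * Assumption: scores[i] is the predicted probability of positiveClass for example i.
 */
final class AucRocSketch {

    private AucRocSketch() {}

    static double aucRoc(String[] actual, double[] scores, String positiveClass) {
        if (actual.length != scores.length) {
            throw new IllegalArgumentException("actual and scores must have the same length");
        }
        double wins = 0;
        long pairs = 0;
        for (int i = 0; i < actual.length; i++) {
            if (!positiveClass.equals(actual[i])) {
                continue; // i must be a positive example
            }
            for (int j = 0; j < actual.length; j++) {
                if (positiveClass.equals(actual[j])) {
                    continue; // j must be a negative example
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        return wins / pairs; // NaN when either group is empty
    }
}
```

With actual classes {cat, dog, cat, fish} and "cat" probabilities {0.9, 0.3, 0.6, 0.7}, three of the four cat/non-cat pairs are ordered correctly, giving 0.75. This pairwise loop is quadratic and meant only to show the meaning of the score.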

Regression

MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); 
double meanSquaredError = meanSquaredErrorResult.getValue(); 

MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
    response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); 
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); 

HuberMetric.Result huberResult = response.getMetricByName(HuberMetric.NAME); 
double huber = huberResult.getValue(); 

RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); 
double rSquared = rSquaredResult.getValue(); 

Fetching mean squared error metric by name

Fetching the actual mean squared error value

Fetching mean squared logarithmic error metric by name

Fetching the actual mean squared logarithmic error value

Fetching pseudo Huber loss metric by name

Fetching the actual pseudo Huber loss value

Fetching R squared metric by name

Fetching the actual R squared value