aom: Corner match: Refine using optical flow

From ac4ff32713a4f6961f8ef7251eb98497b5512188 Mon Sep 17 00:00:00 2001
From: Rachel Barker <[EMAIL REDACTED]>
Date: Wed, 17 Jan 2024 13:40:28 +0000
Subject: [PATCH] Corner match: Refine using optical flow

Replace the previous refinement step (an exhaustive 9x9 cross-correlation
search around each matched point, run in both directions) with one round
of the optical flow refinement used by disflow. This gives much better
results, much faster.

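Results in "good" mode at speed 4, comparing corner match with this change
against each baseline (negative numbers are improvements):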

     Compared to      | BDRATE-PSNR | BDRATE-SSIM | Encode time
----------------------+-------------+-------------+-------------
Previous corner match |   -0.227%   |   -0.203%   |   -2.846%
       Disflow        |   +0.047%   |   +0.045%   |  +11.833%

No change to encoder output, as the encoder currently uses disflow rather
than corner match.

Change-Id: Ie2491891c77cbd4b49cc3b45f0a8906b75604cbf
---
 aom_dsp/flow_estimation/corner_match.c | 89 ++++++--------------------
 1 file changed, 19 insertions(+), 70 deletions(-)

diff --git a/aom_dsp/flow_estimation/corner_match.c b/aom_dsp/flow_estimation/corner_match.c
index dc7589a8c6..dd524c1acb 100644
--- a/aom_dsp/flow_estimation/corner_match.c
+++ b/aom_dsp/flow_estimation/corner_match.c
@@ -17,14 +17,12 @@
 
 #include "aom_dsp/flow_estimation/corner_detect.h"
 #include "aom_dsp/flow_estimation/corner_match.h"
+#include "aom_dsp/flow_estimation/disflow.h"
 #include "aom_dsp/flow_estimation/flow_estimation.h"
 #include "aom_dsp/flow_estimation/ransac.h"
 #include "aom_dsp/pyramid.h"
 #include "aom_scale/yv12config.h"
 
-#define SEARCH_SZ 9
-#define SEARCH_SZ_BY2 ((SEARCH_SZ - 1) / 2)
-
 #define THRESHOLD_NCC 0.75
 
 /* Compute var(frame) * MATCH_SZ_SQ over a MATCH_SZ by MATCH_SZ window of frame,
@@ -87,66 +85,6 @@ static int is_eligible_distance(int point1x, int point1y, int point2x,
           (point1y - point2y) * (point1y - point2y)) <= thresh * thresh;
 }
 
-static void improve_correspondence(const unsigned char *src,
-                                   const unsigned char *ref, int width,
-                                   int height, int src_stride, int ref_stride,
-                                   Correspondence *correspondences,
-                                   int num_correspondences) {
-  int i;
-  for (i = 0; i < num_correspondences; ++i) {
-    int x, y, best_x = 0, best_y = 0;
-    double best_match_ncc = 0.0;
-    // For this algorithm, all points have integer coordinates.
-    // It's a little more efficient to convert them to ints once,
-    // before the inner loops
-    int x0 = (int)correspondences[i].x;
-    int y0 = (int)correspondences[i].y;
-    int rx0 = (int)correspondences[i].rx;
-    int ry0 = (int)correspondences[i].ry;
-    for (y = -SEARCH_SZ_BY2; y <= SEARCH_SZ_BY2; ++y) {
-      for (x = -SEARCH_SZ_BY2; x <= SEARCH_SZ_BY2; ++x) {
-        double match_ncc;
-        if (!is_eligible_point(rx0 + x, ry0 + y, width, height)) continue;
-        if (!is_eligible_distance(x0, y0, rx0 + x, ry0 + y, width, height))
-          continue;
-        match_ncc = av1_compute_cross_correlation(src, src_stride, x0, y0, ref,
-                                                  ref_stride, rx0 + x, ry0 + y);
-        if (match_ncc > best_match_ncc) {
-          best_match_ncc = match_ncc;
-          best_y = y;
-          best_x = x;
-        }
-      }
-    }
-    correspondences[i].rx += best_x;
-    correspondences[i].ry += best_y;
-  }
-  for (i = 0; i < num_correspondences; ++i) {
-    int x, y, best_x = 0, best_y = 0;
-    double best_match_ncc = 0.0;
-    int x0 = (int)correspondences[i].x;
-    int y0 = (int)correspondences[i].y;
-    int rx0 = (int)correspondences[i].rx;
-    int ry0 = (int)correspondences[i].ry;
-    for (y = -SEARCH_SZ_BY2; y <= SEARCH_SZ_BY2; ++y)
-      for (x = -SEARCH_SZ_BY2; x <= SEARCH_SZ_BY2; ++x) {
-        double match_ncc;
-        if (!is_eligible_point(x0 + x, y0 + y, width, height)) continue;
-        if (!is_eligible_distance(x0 + x, y0 + y, rx0, ry0, width, height))
-          continue;
-        match_ncc = av1_compute_cross_correlation(
-            ref, ref_stride, rx0, ry0, src, src_stride, x0 + x, y0 + y);
-        if (match_ncc > best_match_ncc) {
-          best_match_ncc = match_ncc;
-          best_y = y;
-          best_x = x;
-        }
-      }
-    correspondences[i].x += best_x;
-    correspondences[i].y += best_y;
-  }
-}
-
 static int determine_correspondence(const unsigned char *src,
                                     const int *src_corners, int num_src_corners,
                                     const unsigned char *ref,
@@ -187,16 +125,27 @@ static int determine_correspondence(const unsigned char *src,
     template_norm = compute_variance(src, src_stride, src_corners[2 * i],
                                      src_corners[2 * i + 1]);
     if (best_match_ncc > THRESHOLD_NCC * sqrt(template_norm)) {
-      correspondences[num_correspondences].x = src_corners[2 * i];
-      correspondences[num_correspondences].y = src_corners[2 * i + 1];
-      correspondences[num_correspondences].rx = ref_corners[2 * best_match_j];
-      correspondences[num_correspondences].ry =
-          ref_corners[2 * best_match_j + 1];
+      // Apply refinement
+      const int sx = src_corners[2 * i];
+      const int sy = src_corners[2 * i + 1];
+      const int rx = ref_corners[2 * best_match_j];
+      const int ry = ref_corners[2 * best_match_j + 1];
+      double u = (double)(rx - sx);
+      double v = (double)(ry - sy);
+
+      const int patch_tl_x = sx - DISFLOW_PATCH_CENTER;
+      const int patch_tl_y = sy - DISFLOW_PATCH_CENTER;
+
+      aom_compute_flow_at_point(src, ref, patch_tl_x, patch_tl_y, width, height,
+                                src_stride, &u, &v);
+
+      correspondences[num_correspondences].x = (double)sx;
+      correspondences[num_correspondences].y = (double)sy;
+      correspondences[num_correspondences].rx = (double)sx + u;
+      correspondences[num_correspondences].ry = (double)sy + v;
       num_correspondences++;
     }
   }
-  improve_correspondence(src, ref, width, height, src_stride, ref_stride,
-                         correspondences, num_correspondences);
   return num_correspondences;
 }