Skip to content

Commit a4fd0ea

Browse files
authored
fix edt float16 kernel (#2568)
* fix edt float16 kernel * fix a small bug
1 parent 1a5a2bf commit a4fd0ea

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tensorflow_addons/custom_ops/image/cc/kernels/euclidean_distance_transform_op.h

100644100755
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Distance(const T* f, T* d, int* v,
4949
// compute horizontal position of intersection between the parabola from
5050
// q and the current lowest parabola
5151
s = (f[q] - f[v[k]]) / static_cast<T>(2 * (q - v[k])) +
52-
static_cast<T>((q + v[k]) / 2);
52+
static_cast<T>((q + v[k]) / 2.0);
5353
} while (s <= z[k]);
5454
k++;
5555
v[k] = q;
@@ -97,6 +97,9 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void EuclideanDistanceTransformSample(
9797
Distance<T>(f, d, vw, zw, width);
9898
for (int j = 0; j < width; j++) {
9999
int index = GET_INDEX(i, j, k, c);
100+
if (Eigen::numext::isinf(d[j])) {
101+
d[j] = Eigen::NumTraits<T>::highest();
102+
}
100103
output[index] = d[j];
101104
}
102105
}
@@ -108,6 +111,9 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void EuclideanDistanceTransformSample(
108111
Distance<T>(f, d, vh, zh, height);
109112
for (int i = 0; i < height; i++) {
110113
int index = GET_INDEX(i, j, k, c);
114+
if (Eigen::numext::isinf(d[i])) {
115+
d[i] = Eigen::NumTraits<T>::highest();
116+
}
111117
output[index] = Eigen::numext::sqrt(d[i]);
112118
}
113119
}

0 commit comments

Comments
 (0)