The original operation is
RightShiftWithRounding(A*w0 + B*w1, 8)
with saturation, where w0 + w1 == 16. The largest value of w0 is 13.
13 = 0xD, so 13 << 13 = 0x1A000, which is out of range for a 16-bit input. Shifting the weight left by 7 instead would make vqdmulhq_s16 produce (x*weight) >> 8, which is close to what we want.
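As a rough check of the shift-by-7 claim, here is a scalar Python model of one lane of vqdmulhq_s16 (saturating doubling multiply, high half). The helper name sqdmulh_s16 and the sample inputs are made up for illustration; this is a sketch of the arithmetic, not the NEON code itself.

```python
INT16_MIN, INT16_MAX = -(1 << 15), (1 << 15) - 1


def sqdmulh_s16(a: int, b: int) -> int:
    """Scalar model of one vqdmulhq_s16 lane: saturate((2 * a * b) >> 16)."""
    assert INT16_MIN <= a <= INT16_MAX and INT16_MIN <= b <= INT16_MAX
    result = (2 * a * b) >> 16  # Python's >> is an arithmetic (floor) shift
    return max(INT16_MIN, min(INT16_MAX, result))


# With the weight shifted left by 7, the lane computes (x * weight) >> 8.
for weight in range(0, 14):
    for x in (-5132, -1, 0, 1, 9212):
        assert sqdmulh_s16(x, weight << 7) == (x * weight) >> 8

# 13 << 13 needs 17 bits, so it cannot be a 16-bit operand.
assert (13 << 13) == 0x1A000 and (13 << 13) > 0xFFFF
```

Note that this only covers a single product; the rounding term and the sum of the two products are not modelled here.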
In general, switching to right-shifting before adding loses information, which results in missing carry bits. This case is no exception.
I think this is possible if weight and the shifts are factored out correctly using the equivalence of w0 = 16 - w1 and w1 = 16 - w0.
Original equation:
(p0w0 + p1w1 + 128) >> 8
Factored forms, substituting w1 = 16 - w0:
= (p0w0 + p1(16 - w0) + 128) >> 8
= (p0w0 + 16p1 - p1w0 + 128) >> 8
= ((p0 - p1)w0 + 16p1 + 128) >> 8
= ((((p0 - p1)w0) >> 4) + p1 + 8) >> 4
And substituting w0 = 16 - w1:
= (p0(16 - w1) + p1w1 + 128) >> 8
= (16p0 - p0w1 + p1w1 + 128) >> 8
= ((p1 - p0)w1 + 16p0 + 128) >> 8
= ((((p1 - p0)w1) >> 4) + p0 + 8) >> 4
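A quick way to gain confidence in the factored forms is to compare them against the original equation over a sample of inputs. The sketch below uses plain Python integers, omits saturation, and picks an arbitrary sample range for p0 and p1; the function names are hypothetical. With floor (arithmetic) shifts the forms should agree exactly, because (x >> 4) >> 4 == x >> 8 and adding the integers p1 + 8 before the second shift is the same as adding 16*p1 + 128 before the first.

```python
def rounded_blend(p0, p1, w0, w1):
    """Original equation: (p0*w0 + p1*w1 + 128) >> 8, without saturation."""
    return (p0 * w0 + p1 * w1 + 128) >> 8


def factored_w0(p0, p1, w0):
    """Form obtained by substituting w1 = 16 - w0."""
    return ((((p0 - p1) * w0) >> 4) + p1 + 8) >> 4


def factored_w1(p0, p1, w1):
    """Form obtained by substituting w0 = 16 - w1."""
    return ((((p1 - p0) * w1) >> 4) + p0 + 8) >> 4


def count_mismatches(lo=-6000, hi=10000, step=251):
    """Count sampled (p0, p1, w0) triples where a factored form disagrees."""
    mismatches = 0
    for w0 in range(0, 17):
        w1 = 16 - w0
        for p0 in range(lo, hi, step):
            for p1 in range(lo, hi, step):
                expected = rounded_blend(p0, p1, w0, w1)
                if factored_w0(p0, p1, w0) != expected:
                    mismatches += 1
                if factored_w1(p0, p1, w1) != expected:
                    mismatches += 1
    return mismatches


assert count_mismatches() == 0
```

This only checks the arithmetic in unbounded integers; whether every intermediate fits in a 16-bit lane is a separate question.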
For the weight scaling we then have:
= ((((p0 - p1)(w0 << 12)) >> 16) + p1 + 8) >> 4
= ((((p1 - p0)(w1 << 12)) >> 16) + p0 + 8) >> 4
pmulhw (x86) and sqdmulh (NEON) provide the multiply-and-descale operation for the weighted term, and the final shift with saturation should take care of the rest.
For the sign change with w0 << 12 the equation could likely be adjusted based on the fact that -w0 = w1 - 16.
The NEON version is done in cl/446264625.
SSE4 is next.
cl/446264625 covered the NEON case, but there's a problem with SSE4. _mm_mulhi_epi16 has no doubling, so the weight has to be shifted by 12 rather than 11. Additionally, we can't predict the order of the weights (see GetDistanceWeights[1]), so 13 << 12 is always possible. That value does not fit in a signed 16-bit lane, so _mm_mulhi_epi16 treats it as negative, unless we start with a branch that switches which derived formula is used. I'm exploring that approach now.
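The sign problem can be reproduced with a scalar model of one _mm_mulhi_epi16 lane (high 16 bits of a signed 16x16 product). The helpers to_int16 and mulhi_epi16 below are illustrative Python stand-ins, and the difference value is just p0 - p1 from the example inputs used later in this thread.

```python
def to_int16(value: int) -> int:
    """Reinterpret the low 16 bits of value as a signed 16-bit integer."""
    value &= 0xFFFF
    return value - 0x10000 if value & 0x8000 else value


def mulhi_epi16(a: int, b: int) -> int:
    """Scalar model of one _mm_mulhi_epi16 lane: (int16 a * int16 b) >> 16."""
    return (to_int16(a) * to_int16(b)) >> 16


diff = 9212 - (-5132)  # p0 - p1 from the example values

# A weight of 7 fits after the shift by 12, so the lane yields (diff * 7) >> 4.
assert mulhi_epi16(diff, 7 << 12) == (diff * 7) >> 4

# 13 << 12 is 0xD000, which reads back as -12288, i.e. (13 - 16) << 12.
assert to_int16(13 << 12) == -12288
assert mulhi_epi16(diff, 13 << 12) == (diff * -3) >> 4
assert mulhi_epi16(diff, 13 << 12) != (diff * 13) >> 4
```

So with a weight of 13 the lane effectively multiplies by -3, which is why the larger weight cannot be used directly.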
I think the negation identity still involves having to swap p0-p1 with p1-p0. I think it's reasonable to just start the function with:
const auto* pred_0 = (weight_0 < 9)
                         ? static_cast<const int16_t*>(prediction_0)
                         : static_cast<const int16_t*>(prediction_1);
const auto* pred_1 = (weight_0 < 9)
                         ? static_cast<const int16_t*>(prediction_1)
                         : static_cast<const int16_t*>(prediction_0);
const uint8_t weight = (weight_0 < 9) ? weight_0 : weight_1;
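The effect of that swap can be sketched in scalar Python, again with made-up helper names and without saturation. It assumes p0 - p1 fits in a signed 16-bit lane, and it skips weight_0 == 8, because 8 << 12 does not fit in int16 either; whether an 8/8 pair can occur is not stated in this thread.

```python
def to_int16(value: int) -> int:
    """Reinterpret the low 16 bits of value as a signed 16-bit integer."""
    value &= 0xFFFF
    return value - 0x10000 if value & 0x8000 else value


def mulhi_epi16(a: int, b: int) -> int:
    """Scalar model of one _mm_mulhi_epi16 lane."""
    return (to_int16(a) * to_int16(b)) >> 16


def reference_blend(p0, p1, w0, w1):
    """Original equation, without saturation."""
    return (p0 * w0 + p1 * w1 + 128) >> 8


def branch_blend(p0, p1, w0, w1):
    """Pick the smaller weight so that weight << 12 stays positive in int16."""
    if w0 < 9:
        pred_0, pred_1, weight = p0, p1, w0
    else:
        pred_0, pred_1, weight = p1, p0, w1
    diff = pred_0 - pred_1
    assert -(1 << 15) <= diff < (1 << 15), "assumed to fit in one lane"
    weighted = mulhi_epi16(diff, weight << 12)  # == (diff * weight) >> 4
    return (weighted + pred_1 + 8) >> 4


for w0 in (3, 4, 5, 7, 9, 11, 12, 13):
    for p0, p1 in ((9212, -5132), (-5132, 9212), (100, 100)):
        assert branch_blend(p0, p1, w0, 16 - w0) == reference_blend(p0, p1, w0, 16 - w0)
```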
I believe ((((p1 - p0) * -w0) >> 4) + p1 + 8) >> 4 will work in all cases.
I revisited the negative weight idea because it seems like it should work in principle, but found the following:
>>> p0 = 9212
>>> p1 = -5132
>>> hex((p0 - p1) * (13 << 12))
'0x2d868000'
>>> hex((p1 - p0) * ((-13) << 12))
'0x2d868000'
>>> hex((2**16 - 13) << 12)
'0xfff3000'
>>> hex((p1 - p0) * 0x3000)
'-0xa818000'
However, ((p0-p1) << 1) * (w0 << 11) does work, at the cost of one more shift. That extra shift seems worth it to remove the overhead of those branches.
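A scalar sketch of that branch-free variant follows, using the same illustrative mulhi_epi16 model. It assumes (p0 - p1) << 1 still fits in a signed 16-bit lane (true for the example values, not verified here for the full prediction range) and leaves out the final saturation.

```python
def to_int16(value: int) -> int:
    """Reinterpret the low 16 bits of value as a signed 16-bit integer."""
    value &= 0xFFFF
    return value - 0x10000 if value & 0x8000 else value


def mulhi_epi16(a: int, b: int) -> int:
    """Scalar model of one _mm_mulhi_epi16 lane."""
    return (to_int16(a) * to_int16(b)) >> 16


def reference_blend(p0, p1, w0):
    """Original equation with w1 = 16 - w0, without saturation."""
    return (p0 * w0 + p1 * (16 - w0) + 128) >> 8


def doubled_diff_blend(p0, p1, w0):
    """Branch-free form: ((p0 - p1) << 1) * (w0 << 11), high half, then round."""
    diff2 = (p0 - p1) << 1
    assert -(1 << 15) <= diff2 < (1 << 15), "assumed to fit in one lane"
    assert (w0 << 11) < (1 << 15)  # holds for w0 up to 15, so 13 is fine
    weighted = mulhi_epi16(diff2, w0 << 11)  # == ((p0 - p1) * w0) >> 4
    return (weighted + p1 + 8) >> 4


for w0 in range(3, 14):
    for p0, p1 in ((9212, -5132), (-5132, 9212), (255, -255)):
        assert doubled_diff_blend(p0, p1, w0) == reference_blend(p0, p1, w0)
```

The doubling moves one bit of the shift from the weight to the difference, so the weight operand stays positive for both orderings of the weights.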