It took a while (and some LeetCode discussions) to build up the intuition for this problem.
def countWaviness(num):
    """Return the total waviness of every integer in the range [0, num].

    The waviness of a number is the count of its interior digits that are a
    strict peak (greater than both neighbours) or a strict valley (smaller
    than both neighbours). The first and last digits never count.

    Computed with a digit DP over the decimal representation of ``num``.
    For negative ``num`` the range is empty and 0 is returned, so callers
    computing ``countWaviness(hi) - countWaviness(lo - 1)`` also work when
    ``lo == 0``.
    """
    from functools import lru_cache

    if num < 0:
        # Empty range; also avoids parsing the '-' sign as a digit.
        return 0

    digits = [int(ch) for ch in str(num)]  # hoisted: read once, not per loop step
    num_length = len(digits)

    @lru_cache(None)
    def dp(i, prev1, prev2, tight, started):
        """Return (count, waviness) over all completions of positions i.. .

        prev1 / prev2: the digits placed two positions back and one position
            back (-1 when not placed yet).
        tight: the prefix so far equals the prefix of ``num``, so the
            current digit is capped at ``digits[i]``.
        started: a non-zero digit has been placed (leading zeros are skipped).
        """
        if i == num_length:
            # One number is completed here, and no further digit can add waviness.
            return (1, 0)

        limit = digits[i] if tight else 9
        total_count, total_waviness = 0, 0
        for new_digit in range(limit + 1):
            new_tight = tight and new_digit == digits[i]

            if not started and new_digit == 0:
                # Still inside leading zeros: no real digit placed yet.
                child_count, child_waviness = dp(i + 1, -1, -1, new_tight, False)
                total_count += child_count
                total_waviness += child_waviness
                continue

            child_count, child_waviness = dp(i + 1, prev2, new_digit, new_tight, True)
            total_count += child_count
            total_waviness += child_waviness

            # prev2 now has a left neighbour (prev1, if placed) and a right
            # neighbour (new_digit). The prev1 != -1 guard is required: with
            # prev1 == -1, a comparison like -1 < prev2 > new_digit would give
            # a false positive. If prev2 is a peak or valley, every completion
            # of this prefix gains one unit of waviness.
            if started and prev1 != -1 and (
                prev1 < prev2 > new_digit or prev1 > prev2 < new_digit
            ):
                total_waviness += child_count

        return (total_count, total_waviness)

    return dp(0, -1, -1, True, False)[1]
return countWaviness(num2) - countWaviness(num1-1)