If I understand correctly, the 5 in values.shape is not related to the (5, 5) in map.shape; it only reflects the range of values in map. The first thing to do is to make the values in map follow Python's zero-based indexing convention: start counting from zero, not from one.
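If your map currently holds 1-based labels, shifting it once up front is usually simpler than writing `i-1` everywhere. A minimal sketch, assuming labels run from 1 to p; the array `map_1based` is a hypothetical example, not data from the question:

```python
import numpy as np

p = 4  # number of slices along the last axis of values (assumed)

# Hypothetical 1-based map: labels 1..p.
map_1based = np.array([
    [1, 1, 2],
    [4, 3, 3],
])

# Shift to Python's 0-based convention; valid indices are now 0..p-1.
map0 = map_1based - 1
assert map0.min() >= 0 and map0.max() < p, "map label out of range"
print(map0)
```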
Here is how to do it, using a map whose shape differs from that of values:
```python
import numpy as np

m, n, p = 2, 3, 4
values = np.arange(m*n*p).reshape(m, n, p)

# Zero-based indices into the last axis of values (0 <= entry < p).
# Note: the name `map` shadows Python's built-in map().
map = np.array([
    [0, 0, 0, 3, 3, 3],
    [0, 0, 1, 3, 3, 2],
    [1, 1, 1, 3, 2, 2],
    [2, 1, 2, 2, 2, 1],
    [2, 2, 2, 2, 1, 1]
])
q, r = map.shape

# Flatten the map so it can be looped over once.
map_r = map.reshape(q*r)
# result_r: shape (q*r, m, n)
# If you want to keep the 1-based numbering, index as [:, :, i-1].
result_r = np.array([values[:, :, i] for i in map_r])
# Move the map axis to the end and unflatten it: shape (m, n, q, r).
result = result_r.transpose(1, 2, 0).reshape(m, n, q, r)

# test
jq, jr = 4, 5
mapval = map[jq, jr]
print(f'map[{jq}, {jr}] = {mapval}')
print(f'result[:, :, {jq}, {jr}] =\n{result[:, :, jq, jr]}')
print(f'values[:, :, {mapval}] =\n{values[:, :, mapval]}')
```
Output:
```
map[4, 5] = 1
result[:, :, 4, 5] =
[[ 1  5  9]
 [13 17 21]]
values[:, :, 1] =
[[ 1  5  9]
 [13 17 21]]
```
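The list comprehension and transpose can most likely be replaced by NumPy advanced indexing: indexing the last axis with an integer array replaces that axis with the index array's shape, so `values[:, :, map]` should already have shape (m, n, q, r). A sketch reusing the same example data, with `np.take` shown as an equivalent spelling and the loop version kept for comparison:

```python
import numpy as np

m, n, p = 2, 3, 4
values = np.arange(m*n*p).reshape(m, n, p)
map = np.array([
    [0, 0, 0, 3, 3, 3],
    [0, 0, 1, 3, 3, 2],
    [1, 1, 1, 3, 2, 2],
    [2, 1, 2, 2, 2, 1],
    [2, 2, 2, 2, 1, 1]
])
q, r = map.shape

# Advanced indexing on the last axis: result shape is (m, n) + map.shape.
result_fancy = values[:, :, map]

# Equivalent spelling with np.take along axis 2.
result_take = np.take(values, map, axis=2)

# Reference: the loop + transpose version.
result_r = np.array([values[:, :, i] for i in map.reshape(q*r)])
result_loop = result_r.transpose(1, 2, 0).reshape(m, n, q, r)

print(result_fancy.shape)
print(np.array_equal(result_fancy, result_loop))
print(np.array_equal(result_take, result_loop))
```

Besides being shorter, this avoids building a Python list of q*r small arrays.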
As an aside: multidimensional arrays often work more cleanly if the left-most index represents the highest-level concept. If you can phrase it as "a list of N lists of M widgets", the natural array shape is (N, M), not (M, N). So you would have values.shape = (p, m, n) and result.shape = (q, r, m, n) (and maybe swap m and n as well, depending on how you will process the data later on). That avoids a lot of [:, :] indices and transpose operations.
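To illustrate that layout, here is a sketch assuming values is stored as (p, m, n). For the example, it is built by transposing the earlier array; in practice you would construct it in that order from the start. Indexing the first axis with map then gives (q, r, m, n) directly, with no [:, :] or transpose:

```python
import numpy as np

m, n, p = 2, 3, 4
# Same data as before, but with the "which slice" axis first: shape (p, m, n).
values_pmn = np.arange(m*n*p).reshape(m, n, p).transpose(2, 0, 1)
map = np.array([
    [0, 0, 0, 3, 3, 3],
    [0, 0, 1, 3, 3, 2],
    [1, 1, 1, 3, 2, 2],
    [2, 1, 2, 2, 2, 1],
    [2, 2, 2, 2, 1, 1]
])

# Integer-array indexing on the first axis: shape map.shape + (m, n).
result_qrmn = values_pmn[map]

print(result_qrmn.shape)
# result_qrmn[jq, jr] is the (m, n) slice selected by map[jq, jr].
print(result_qrmn[4, 5])
```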