From d71f430256f921cc1a637ce64cd5b0e3d3770ee5 Mon Sep 17 00:00:00 2001 From: Daniel Toyama Date: Sun, 21 Jun 2026 11:46:22 -0700 Subject: [PATCH] Optimize convert_int_to_float in pixel_fns and add microbenchmarks to unit tests. PiperOrigin-RevId: 935683438 --- android_env/components/pixel_fns.py | 21 ++++- android_env/components/pixel_fns_test.py | 110 ++++++++++++++++++++++- 2 files changed, 127 insertions(+), 4 deletions(-) diff --git a/android_env/components/pixel_fns.py b/android_env/components/pixel_fns.py index 2f491d46..77f129be 100644 --- a/android_env/components/pixel_fns.py +++ b/android_env/components/pixel_fns.py @@ -54,7 +54,9 @@ def orient_pixels(frame: np.ndarray, orientation: int) -> np.ndarray: ) -def convert_int_to_float(data: np.ndarray, data_spec: specs.Array): +def convert_int_to_float( + data: np.ndarray, data_spec: specs.Array +) -> np.ndarray: """Converts an array of int values to floats between 0 and 1.""" if not np.issubdtype(data.dtype, np.integer): @@ -67,4 +69,19 @@ def convert_int_to_float(data: np.ndarray, data_spec: specs.Array): iinfo = np.iinfo(data_spec.dtype) value_min = iinfo.min value_max = iinfo.max - return np.float32(1.0 * (data - value_min) / (value_max - value_min)) + # Optimize performance by: + # 1. Performing all calculations in float32 to avoid default float64 + # precision overhead. + # 2. Reusing the allocated float32 array for in-place operations to + # minimize memory allocation. + # 3. Using multiplication instead of division. + span = np.float32(value_max - value_min) + inv_span = np.float32(1.0) / span + out = data.astype(np.float32) # Allocate output array once + if np.all(value_min == 0): + # Skip subtraction if minimum is 0 (common for image data). + out *= inv_span # In-place multiplication is faster than division + else: + out -= np.float32(value_min) # In-place subtraction + out *= inv_span + return out diff --git a/android_env/components/pixel_fns_test.py b/android_env/components/pixel_fns_test.py index e53e7afb..fcbdd767 100644 --- a/android_env/components/pixel_fns_test.py +++ b/android_env/components/pixel_fns_test.py @@ -13,14 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pixel_fns.""" - +import timeit +from absl import flags from absl.testing import absltest from absl.testing import parameterized from android_env.components import pixel_fns from dm_env import specs import numpy as np +# Benchmarks take ~2 minutes to run, so they are disabled by default. +# Run with --test_arg=--run_benchmarks to enable. +_RUN_BENCHMARKS = flags.DEFINE_bool( + 'run_benchmarks', False, 'Whether to run microbenchmarks.' +) + class UtilsTest(parameterized.TestCase): @@ -103,5 +109,105 @@ def test_convert_int_to_float_no_bounds(self): np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data) +class PixelFnsBenchmark(absltest.TestCase): + """Microbenchmarks for pixel functions. + + These are implemented as unit tests but are skipped by default because they + are slow. They are useful for verifying optimizations. + + NOTE: We use inlined strings with `timeit.Timer` instead of callables + (lambdas) to avoid Python function call overhead in the measurement loop. + For very fast operations like `transpose_pixels` (view) which take ~1 us, + the ~100ns lambda overhead would introduce a significant (~10%) measurement + error. + """ + + def setUp(self): + super().setUp() + if not _RUN_BENCHMARKS.value: + self.skipTest('Benchmark disabled. Run with --test_arg=--run_benchmarks') + + def test_touch_position_to_pixel_position(self): + setup = ( + 'from android_env.components import pixel_fns; import numpy as np; ' + 'touch_pos = np.array([0.5, 0.5]); width_height = [1080, 1920]' + ) + stmt = 'pixel_fns.touch_position_to_pixel_position(touch_pos, width_height)' + t = timeit.Timer(stmt, setup=setup) + number = 100000 + res = t.timeit(number=number) + print( + f'\ntouch_position_to_pixel_position: {res / number * 1e6:.3f} us per' + ' loop' + ) + + def test_transpose_pixels(self): + for size in [(320, 480), (1080, 1920)]: + setup = ( + 'from android_env.components import pixel_fns; import numpy as np;' + f' img = np.zeros(({size[1]}, {size[0]}, 3), dtype=np.uint8)' + ) + stmt = 'pixel_fns.transpose_pixels(img)' + t = timeit.Timer(stmt, setup=setup) + number = 1000 + res = t.timeit(number=number) + print( + f'\ntranspose_pixels {size} (view): {res / number * 1e3:.3f} ms per' + ' loop' + ) + + stmt_copy = 'pixel_fns.transpose_pixels(img).copy()' + t_copy = timeit.Timer(stmt_copy, setup=setup) + res_copy = t_copy.timeit(number=number) + print( + f'transpose_pixels {size} (copy): {res_copy / number * 1e3:.3f} ms' + ' per loop' + ) + + def test_orient_pixels(self): + for size in [(320, 480), (1080, 1920)]: + for orientation in [1, 2, 3]: + setup = ( + 'from android_env.components import pixel_fns; import numpy as np; ' + f'img = np.zeros(({size[1]}, {size[0]}, 3), dtype=np.uint8); ' + f'orientation = {orientation}' + ) + stmt = 'pixel_fns.orient_pixels(img, orientation)' + t = timeit.Timer(stmt, setup=setup) + number = 1000 + res = t.timeit(number=number) + print( + f'\norient_pixels {size}, orientation={orientation} (view):' + f' {res / number * 1e3:.3f} ms per loop' + ) + + stmt_copy = 'pixel_fns.orient_pixels(img, orientation).copy()' + t_copy = timeit.Timer(stmt_copy, setup=setup) + res_copy = t_copy.timeit(number=number) + print( + f'orient_pixels {size}, orientation={orientation} (copy):' + f' {res_copy / number * 1e3:.3f} ms per loop' + ) + + def test_convert_int_to_float(self): + for size in [(320, 480), (1080, 1920)]: + setup = ( + 'from android_env.components import pixel_fns\nimport numpy as' + ' np\nfrom dm_env import specs\nspec =' + f' specs.BoundedArray(shape=({size[1]}, {size[0]}, 3),' + ' dtype=np.uint8, minimum=0, maximum=255)\ndata =' + f' np.random.randint(0, 255, size=({size[1]}, {size[0]}, 3),' + ' dtype=np.uint8)\n' + ) + stmt = 'pixel_fns.convert_int_to_float(data, spec)' + t = timeit.Timer(stmt, setup=setup) + number = 100 + res = t.timeit(number=number) + print( + f'\nconvert_int_to_float {size} (BoundedArray):' + f' {res / number * 1e3:.3f} ms per loop' + ) + + if __name__ == '__main__': absltest.main()