
Trailing singleton dimensions are removed during dtype conversion #491

@jakemoran


Bug Description

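When using PyArrayLikeDyn with AllowTypeChange, trailing singleton axes may be removed from inputs that are already NumPy ndarrays but whose dtype differs from the requested element type (for example, an int32 array passed where f64 is expected).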
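For comparison, NumPy's own dtype conversion keeps every axis, which is presumably the behaviour AllowTypeChange is meant to mirror. This minimal check uses plain NumPy only (no rust-numpy involved):

```python
import numpy as np

# Same input as in the reproduction below: an int32 column vector of shape (3, 1).
a = np.ones((3, 1), dtype=np.int32)

# Converting the dtype in NumPy itself preserves the trailing singleton axis.
converted = np.asarray(a, dtype=np.float64)
print(converted.shape, converted.dtype)
```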

Steps to Reproduce

Cargo.toml

[package]
name = "singleton-removed"
version = "0.1.0"
edition = "2024"

[dependencies]
numpy = "0.24.0"
pyo3 = { version = "0.24.2", features = ["auto-initialize"] }

main.rs

use numpy::{AllowTypeChange, PyArrayDyn, PyArrayLikeDyn, PyArrayMethods};
use pyo3::ffi::c_str;
use pyo3::prelude::*;

#[pyfunction]
fn double<'py>(
    py: Python<'py>,
    a: PyArrayLikeDyn<'py, f64, AllowTypeChange>,
) -> Bound<'py, PyArrayDyn<f64>> {
    PyArrayDyn::from_owned_array(py, a.to_owned_array() * 2.0)
}

#[pymodule]
fn singleton_removed(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(double, m)?)?;
    Ok(())
}

fn main() -> PyResult<()> {
    pyo3::append_to_inittab!(singleton_removed);
    Python::with_gil(|py| {
        let code = c_str!(include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/example.py"
        )));
        py.run(code, None, None)?;

        Ok(())
    })
}

example.py

import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"

This results in the following error. NumPy also emits a deprecation warning about converting an array with ndim > 0 to a scalar, which suggests the singleton axis is being dropped through an implicit array-to-scalar conversion:

<string>:5: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
Error: PyErr { type: <class 'AssertionError'>, value: AssertionError('a.shape=(3, 1), b.shape=(3,)'), traceback: Some("Traceback (most recent call last):\n  File \"<string>\", line 6, in <module>\n") }

If no type change occurs, the axis is preserved, e.g.

import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.float64)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"

succeeds. Non-array inputs such as nested lists also keep their shape, e.g.

import singleton_removed
import numpy as np

a = [[1], [2], [3]]
b = singleton_removed.double(a)
assert b.shape == (3, 1), f"{b.shape=}"

also succeeds.

Oddly enough, if I add a third axis, the trailing singleton dimension is no longer removed:

import singleton_removed
import numpy as np

a = np.ones((3, 2, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
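One model accounts for all four observations. It assumes (unconfirmed) that for dynamic or 1-D dimensionality the extraction first tries to read the input as a flat sequence of f64, and only falls back to numpy.asarray(..., dtype=...) when that fails. The function like_dyn_f64 below is hypothetical, written in Python purely to illustrate that dispatch order:

```python
import warnings
import numpy as np

def like_dyn_f64(obj):
    """Hypothetical model of extracting an f64 array-like argument."""
    # 1. Already float64: used as-is, so every axis survives.
    if isinstance(obj, np.ndarray) and obj.dtype == np.float64:
        return obj
    # 2. Assumed fast path: read obj as a flat sequence of floats.
    #    float() accepts size-1 rows (deprecated), so a (3, 1) input collapses to (3,).
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            return np.array([float(item) for item in obj], dtype=np.float64)
    except TypeError:
        # float() rejects lists and arrays with more than one element.
        pass
    # 3. Fallback: let NumPy convert the dtype, which preserves the shape.
    return np.asarray(obj, dtype=np.float64)

cases = {
    "int32 (3, 1)": np.ones((3, 1), dtype=np.int32),
    "float64 (3, 1)": np.ones((3, 1), dtype=np.float64),
    "list [[1], [2], [3]]": [[1], [2], [3]],
    "int32 (3, 2, 1)": np.ones((3, 2, 1), dtype=np.int32),
}
for name, value in cases.items():
    print(name, "->", like_dyn_f64(value).shape)
```

Under this model the (3, 2, 1) case keeps its shape only because float() rejects the size-2 rows, which forces the asarray fallback. If the model is right, skipping the flat fast path for ndarray inputs (or checking the input's ndim first) would avoid the collapse.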

Relevant Info

Python Version

3.13.3

NumPy Version

2.2.5

PyO3 Version

0.24.2

rust-numpy Version

0.24.0

rustc Version

1.86.0

OS

Distributor ID: Ubuntu
Description:    Ubuntu 24.04.2 LTS
Release:        24.04
Codename:       noble

(via WSL)
