Error while running bash command: run_sample_video.sh | Error: "TypeError: missing a required argument: 'segment_ids'" #77

@samitm-123

I receive this error when I run the following bash command: !bash LWM/scripts/run_sample_video.sh. I have followed all the directions listed in the repo.

/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
    run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/LWM/lwm/vision_generation.py", line 92, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/content/LWM/lwm/llama.py", line 945, in __call__
    hidden_states, _ = nn.scan(
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/content/LWM/lwm/llama.py", line 724, in __call__
    attn_outputs = self.attention(
  File "/content/LWM/lwm/llama.py", line 615, in __call__
    attn_output = ring_attention_sharded(
  File "/usr/lib/python3.10/inspect.py", line 3186, in bind
    return self._bind(args, kwargs)
  File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
    raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'
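
The last frames show the error being raised inside inspect.Signature.bind, directly under the ring_attention_sharded call in lwm/llama.py. The wrapped function's signature declares a required segment_ids parameter that this call does not supply, so the call fails during argument binding. The attention body never runs. The minimal sketch below reproduces the same message with only the standard library. ring_attention_stub and its parameter list are hypothetical. They are not the real LWM or ring attention signature.

```python
import inspect


def ring_attention_stub(q, k, v, attn_bias, segment_ids, axis_name):
    """Hypothetical stand-in whose signature requires `segment_ids`."""
    return q


def bind_without_segment_ids():
    """Mimic a caller written for a signature that had no `segment_ids`."""
    sig = inspect.signature(ring_attention_stub)
    # Signature.bind validates arguments without calling the function and
    # raises TypeError when a required parameter is not supplied.
    return sig.bind(1, 2, 3, None, axis_name="sp")


def missing_argument_message():
    """Return the TypeError message produced by the failed bind, if any."""
    try:
        bind_without_segment_ids()
    except TypeError as err:
        return str(err)
    return None


print(missing_argument_message())  # missing a required argument: 'segment_ids'
```

Supplying a value for segment_ids (for example, sig.bind(1, 2, 3, None, None, axis_name="sp")) makes the bind succeed. This suggests that the caller and the callee signature are out of sync.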

Would appreciate some help here.
