Description
Hello,
I have been trying to pass different prompts to the infer_evo2 script using the --prompt flag; however, no matter what prompt I pass, the model generates nearly identical sequences, as if the prompt had no effect on generation. I have tried this with both the 1-billion 8k and 7-billion 8k base models, as well as analogous versions fine-tuned on a custom dataset, testing different prompts and sequence lengths.
Example
If I pass no prompt, the script falls back to the default prompt (i.e., "|d__Bacteria;p__Pseudomonadota;c__Gammaproteobacteria;..." etc.). Running
infer_evo2 --ckpt-dir $BIONEMO_CACHE/nemo2_evo2_7b_8k/ --temperature 0.7 --top-k 4 --max-new-tokens 100
produces the following output sequence:
"
CTCTTCTGGTATTTGGTGTGCGATGGACCTATTGAGCTGATAGTGTGGATCCTGTTGCTCCGGATGGTTTTGCTGGCCAACGGCCCCGCAACGCGTGCAC
"
If I instead prompt the model with a very specific sequence from a gene unrelated to the default prompt,
infer_evo2 --ckpt-dir $BIONEMO_CACHE/nemo2_evo2_7b_8k/ --temperature 0.7 --top-k 4 --max-new-tokens 1000 --prompt GAGCCTCCTTAACTCTTTACCGTTCACGTTAAATCACTTTCGCTACCGCGTAAACGTGCACGAGCCACTATAGTGGCGC
the model outputs a sequence that is roughly 90% identical to the sequence generated with the default prompt, suggesting the provided prompt had no effect on the output:
"
GTCTTGCGGTATTTTGTGTGCGATGGACCTATTGAGCGAATAGTGTGGATCCTGTTGCTCCGGATGGTTT
GACTGGCCAACGGCCCCCCAACGCGGGCAC
"
I have tried this with several different prompts and with different temperature and top-k values, always observing the same behaviour.
Additionally, during inference the following warning is emitted six times:
"
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention.py:5213: UserWarning: window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=no_mask warnings.warn(
"
I am not sure whether I am prompting the model incorrectly or whether this reflects a deeper issue. Any help would be greatly appreciated, thanks.