vllm.model_executor.models.whisper

WhisperAudioInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • nmb: Number of mel bins
  • t: Time frames (M)
Source code in vllm/model_executor/models/whisper.py
class WhisperAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
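
For reference, a minimal sketch of input features matching this schema, assuming the standard Whisper feature-extractor defaults (80 mel bins, 16 kHz sampling rate, 160-sample hop, 30 s chunks); large-v3 checkpoints use 128 mel bins, and how the list wrapper maps onto the "b" dimension is determined by TensorSchema, so treat this purely as an illustration of the shapes:

import torch

# Assumed defaults; check the model's preprocessor_config.json for the real values.
num_mel_bins = 80                                            # "nmb"
sampling_rate, hop_length, chunk_length_s = 16_000, 160, 30
time_frames = chunk_length_s * sampling_rate // hop_length   # "t" = 3000

# One batched log-mel spectrogram, shaped (b, nmb, t), wrapped in a list.
input_features = [torch.zeros(1, num_mel_bins, time_frames)]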

WhisperEncoderAttention

Bases: MMEncoderAttention

Multi-headed attention for Whisper encoder with 2D tensor support.

Source code in vllm/model_executor/models/whisper.py
class WhisperEncoderAttention(MMEncoderAttention):
    """Multi-headed attention for Whisper encoder with 2D tensor support."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input shape: batch_size x seq_len x hidden_size
                     or seq_len x hidden_size
        """
        is_2d = query.dim() == 2
        if is_2d:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)

        # Call the parent forward method
        out = super().forward(query, key, value)

        if is_2d:
            out = out.squeeze(0)

        return out

forward

forward(
    query: Tensor, key: Tensor, value: Tensor
) -> Tensor
Input shape: batch_size x seq_len x hidden_size
             or seq_len x hidden_size

Source code in vllm/model_executor/models/whisper.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """
    Input shape: batch_size x seq_len x hidden_size
                 or seq_len x hidden_size
    """
    is_2d = query.dim() == 2
    if is_2d:
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)

    # Call the parent forward method
    out = super().forward(query, key, value)

    if is_2d:
        out = out.squeeze(0)

    return out
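
The same promote-then-restore pattern can be shown standalone with plain PyTorch; this is only a sketch of the dimension handling and does not use vLLM's attention backends:

import torch
import torch.nn.functional as F

def attend(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Accept (seq_len, hidden) or (batch, seq_len, hidden), like the wrapper above.
    is_2d = query.dim() == 2
    if is_2d:
        query, key, value = (x.unsqueeze(0) for x in (query, key, value))
    out = F.scaled_dot_product_attention(query, key, value)
    return out.squeeze(0) if is_2d else out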

WhisperForConditionalGeneration

Bases: Module, SupportsTranscription, SupportsMultiModal

Source code in vllm/model_executor/models/whisper.py
@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module,
    SupportsTranscription,
    SupportsMultiModal,
):
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supports_segment_timestamp = True
    supports_explicit_language_detection = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        if language is None:
            logger.debug(
                "No language specified. Language will be auto-detected "
                "from audio. To skip detection, pass the `language` field "
                "in the TranscriptionRequest."
            )
            return None
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )

        decoder_text = (
            f"<|prev|>{request_prompt}" if request_prompt else ""
        ) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"

        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt="",  # Whisper does not support encoder prompt.
                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
            ),
            decoder_prompt=TextPrompt(prompt=decoder_text),
        )

    @classmethod
    def get_language_token_ids(
        cls,
        tokenizer: object,
    ) -> list[int]:
        """Return token IDs for all supported language tokens.

        Used with ``SamplingParams.allowed_token_ids`` to constrain
        language detection to only produce valid language tokens.
        """
        token_ids = [
            tokenizer.convert_tokens_to_ids(f"<|{lang_code}|>")
            for lang_code in cls.supported_languages
        ]
        return token_ids

    @classmethod
    def get_language_detection_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
    ) -> PromptType:
        """Return a prompt that elicits a single language token from Whisper.

        Feed only ``<|startoftranscript|>`` as the decoder input so the model
        predicts the most likely language token (e.g. ``<|de|>``).
        """
        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt="",
                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
            ),
            decoder_prompt=TextPrompt(prompt="<|startoftranscript|>"),
        )

    @classmethod
    def parse_language_detection_output(
        cls,
        token_ids: list[int],
        tokenizer: object,
    ) -> str | None:
        """Parse the language token predicted by Whisper.

        Decodes the first token ID and extracts the language code from the
        ``<|xx|>`` format. Expects a valid language token from constrained generation.
        """

        decoded = tokenizer.decode(
            [token_ids[0]],
            skip_special_tokens=False,
        )
        # Whisper language tokens have the form <|xx|>
        assert decoded.startswith("<|") and decoded.endswith("|>")
        lang_code = decoded[2:-2]
        assert lang_code in cls.supported_languages
        return lang_code

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        processor = cached_processor_from_config(model_config)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) Users can't pass encoder prompts
        # directly, at least not to Whisper. One indicator of
        # the amount of encoder processing is the log-mel
        # spectrogram length.
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        with self._mark_composite_model(
            vllm_config,
            language_targets=WhisperDecoder,
            tower_targets={"audio": WhisperEncoder},
        ):
            self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        # Required as part of SupportsMultiModal interface.
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        # Split concatenated encoder outputs into one tensor per audio input
        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
        # The assumption is we can only process whole mm items (audios)
        return enc_output.unbind(dim=0)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

        return WhisperAudioInputs(input_features=input_features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
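
As a concrete illustration of the decoder prompt that get_generation_prompt assembles (a sketch; the token strings come from the code above, while the language, task, and request_prompt values are hypothetical):

language = "en"
task_type = "transcribe"
request_prompt = "Glossary: vLLM, Whisper"  # optional context injected via <|prev|>

decoder_text = (
    f"<|prev|>{request_prompt}" if request_prompt else ""
) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
# -> "<|prev|>Glossary: vLLM, Whisper<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

For get_num_audio_tokens, a full 30 s clip at a 16 kHz sample rate with hop_length=160 gives math.ceil(30 * 16000 / 160) = 3000 mel frames, which is the quantity the method uses as a proxy for encoder processing.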

get_language_detection_prompt classmethod

get_language_detection_prompt(
    audio: ndarray, stt_config: SpeechToTextConfig
) -> PromptType

Return a prompt that elicits a single language token from Whisper.

Feed only <|startoftranscript|> as the decoder input so the model predicts the most likely language token (e.g. <|de|>).

Source code in vllm/model_executor/models/whisper.py
@classmethod
def get_language_detection_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
) -> PromptType:
    """Return a prompt that elicits a single language token from Whisper.

    Feed only ``<|startoftranscript|>`` as the decoder input so the model
    predicts the most likely language token (e.g. ``<|de|>``).
    """
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(
            prompt="",
            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
        ),
        decoder_prompt=TextPrompt(prompt="<|startoftranscript|>"),
    )

get_language_token_ids classmethod

get_language_token_ids(tokenizer: object) -> list[int]

Return token IDs for all supported language tokens.

Used with SamplingParams.allowed_token_ids to constrain language detection to only produce valid language tokens.

Source code in vllm/model_executor/models/whisper.py
@classmethod
def get_language_token_ids(
    cls,
    tokenizer: object,
) -> list[int]:
    """Return token IDs for all supported language tokens.

    Used with ``SamplingParams.allowed_token_ids`` to constrain
    language detection to only produce valid language tokens.
    """
    token_ids = [
        tokenizer.convert_tokens_to_ids(f"<|{lang_code}|>")
        for lang_code in cls.supported_languages
    ]
    return token_ids
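
A sketch of how these token IDs can be plugged into sampling to constrain language detection; SamplingParams and its allowed_token_ids field are part of vLLM's public API, but the surrounding wiring (tokenizer handle, request setup) is assumed here:

from vllm import SamplingParams

lang_token_ids = WhisperForConditionalGeneration.get_language_token_ids(tokenizer)
detect_params = SamplingParams(
    max_tokens=1,                      # a single language token is all we need
    temperature=0.0,                   # greedy choice of the most likely language
    allowed_token_ids=lang_token_ids,  # only <|xx|> language tokens may be sampled
)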

parse_language_detection_output classmethod

parse_language_detection_output(
    token_ids: list[int], tokenizer: object
) -> str | None

Parse the language token predicted by Whisper.

Decodes the first token ID and extracts the language code from the <|xx|> format. Expects a valid language token from constrained generation.

Source code in vllm/model_executor/models/whisper.py
@classmethod
def parse_language_detection_output(
    cls,
    token_ids: list[int],
    tokenizer: object,
) -> str | None:
    """Parse the language token predicted by Whisper.

    Decodes the first token ID and extracts the language code from the
    ``<|xx|>`` format. Expects a valid language token from constrained generation.
    """

    decoded = tokenizer.decode(
        [token_ids[0]],
        skip_special_tokens=False,
    )
    # Whisper language tokens have the form <|xx|>
    assert decoded.startswith("<|") and decoded.endswith("|>")
    lang_code = decoded[2:-2]
    assert lang_code in cls.supported_languages
    return lang_code
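
For example, under constrained generation the first output token decodes to a language token such as <|de|> (a sketch; the literal string stands in for the tokenizer call):

decoded = "<|de|>"         # tokenizer.decode([token_ids[0]], skip_special_tokens=False)
assert decoded.startswith("<|") and decoded.endswith("|>")
lang_code = decoded[2:-2]  # -> "de", which must appear in supported_languages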

WhisperProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/whisper.py
class WhisperProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_default_tok_params(self) -> TokenizeParams:
        # Special tokens should be provided by the user based on the
        # task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        return super().get_default_tok_params().with_kwargs(add_special_tokens=False)

    def get_data_parser(self):
        feature_extractor = self.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    @property
    def skip_prompt_length_check(self) -> bool:
        return True  # Because the encoder prompt is padded

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Whisper models (mono)."""
        return 1

    def get_num_audio_tokens(self) -> int:
        return self.get_hf_config().max_source_positions

get_target_channels

get_target_channels() -> int

Return target audio channels for Whisper models (mono).

Source code in vllm/model_executor/models/whisper.py
def get_target_channels(self) -> int:
    """Return target audio channels for Whisper models (mono)."""
    return 1

_create_fake_bias_for_k_proj

_create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, Tensor]],
    fake_bias_key_name: str,
) -> Iterable[tuple[str, Tensor]]

Create an all-zeros bias for the k_proj weight in self-attention and cross-attention layers, so that the bias for k_proj in the fused qkv_proj can be initialized with zeros.

Source code in vllm/model_executor/models/whisper.py
def _create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
        yield name, weight
        if name.endswith(fake_bias_key_name):
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield bias_name, bias
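
A minimal sketch of the generator's behavior on a single hypothetical weight entry (the key name and shape are illustrative, not actual checkpoint contents):

import torch

weights = [("model.encoder.layers.0.self_attn.k_proj.weight", torch.randn(1280, 1280))]
out = list(_create_fake_bias_for_k_proj(weights, ".k_proj.weight"))
# The weight is passed through unchanged, then a zero bias of shape (1280,) is emitted
# under the matching ".k_proj.bias" key, so the fused qkv_proj bias can load cleanly.
assert out[1][0].endswith(".k_proj.bias") and torch.all(out[1][1] == 0)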