# Source: Covo-Audio-Chat / configuration_covo_audio.py
# Commit 8b5398f (verified): "Upload folder using huggingface_hub"; committer shown as xu-jj24; file size 6.37 kB.
from typing import Optional, Union

from transformers import Qwen2Config, WhisperConfig
from transformers.configuration_utils import PretrainedConfig
class CovoAudioConfig(PretrainedConfig):
    """Composite configuration holding a Qwen2 LLM config and a Whisper encoder config.

    The two nested configurations are stored on the instance as ``llm_config``
    and ``encoder_config``; their classes are declared in ``sub_configs``.

    Args:
        llm_config: A ``Qwen2Config`` instance, a dict of ``Qwen2Config``
            fields (as found in a serialized ``config.json``), or ``None`` to
            use the built-in default (28 layers, hidden size 3584).
        encoder_config: A ``WhisperConfig`` instance, a dict of
            ``WhisperConfig`` fields, or ``None`` to use the built-in default
            (``openai/whisper-large-v3`` shaped, ``d_model`` 1280).
        audio_token_index: Stored unchanged as ``self.audio_token_index``.
            NOTE(review): presumably the id of the audio placeholder token in
            the LLM vocabulary -- confirm against the modeling code.
        adapter_downsample: Stored unchanged as ``self.adapter_downsample``.
            NOTE(review): presumably the downsampling factor of the audio
            adapter -- confirm against the modeling code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. ``dtype``
            defaults to ``"bfloat16"`` when not supplied.
    """

    model_type = "covo_audio"
    sub_configs = {"llm_config": Qwen2Config, "encoder_config": WhisperConfig}  # type: ignore
    has_no_defaults_at_init = True

    def __init__(
        self,
        llm_config: Optional[Union[Qwen2Config, dict]] = None,
        encoder_config: Optional[Union[WhisperConfig, dict]] = None,
        audio_token_index=151671,
        adapter_downsample=8,
        **kwargs
    ):
        if llm_config is None:
            llm_config = self._default_llm_config()
        if encoder_config is None:
            encoder_config = self._default_encoder_config()

        # Accept serialized (dict) sub-configs so that the generic
        # ``cls(**config_dict)`` construction path works as well.
        llm_config = self._coerce_sub_config("llm_config", llm_config)
        encoder_config = self._coerce_sub_config("encoder_config", encoder_config)

        self.audio_token_index = audio_token_index
        self.adapter_downsample = adapter_downsample
        self.llm_config = llm_config
        self.encoder_config = encoder_config
        # Width of the encoder output features, mirrored from the encoder config.
        self.whisper_feats_dim = encoder_config.d_model

        kwargs.setdefault("dtype", "bfloat16")
        self.dtype = kwargs["dtype"]
        super().__init__(**kwargs)

    @staticmethod
    def _default_llm_config() -> Qwen2Config:
        """Build the Qwen2 configuration used when no ``llm_config`` is given."""
        return Qwen2Config(
            architectures=["Qwen2ForCausalLM"],
            attention_dropout=0.0,
            bos_token_id=151643,
            eos_token_id=151643,
            hidden_act="silu",
            hidden_size=3584,
            initializer_range=0.02,
            intermediate_size=18944,
            max_position_embeddings=131072,
            max_window_layers=28,
            model_type="qwen2",
            num_attention_heads=28,
            num_hidden_layers=28,
            num_key_value_heads=4,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000.0,
            sliding_window=131072,
            torch_dtype="bfloat16",
            use_cache=True,
            use_mrope=False,
            use_sliding_window=False,
            vocab_size=168055,
        )

    @staticmethod
    def _default_encoder_config() -> WhisperConfig:
        """Build the Whisper configuration used when no ``encoder_config`` is given."""
        return WhisperConfig(
            _name_or_path="openai/whisper-large-v3",
            activation_dropout=0.0,
            activation_function="gelu",
            apply_spec_augment=False,
            architectures=["WhisperForConditionalGeneration"],
            attention_dropout=0.0,
            begin_suppress_tokens=[220, 50257],
            bos_token_id=50257,
            classifier_proj_size=256,
            d_model=1280,
            decoder_attention_heads=20,
            decoder_ffn_dim=5120,
            decoder_layerdrop=0.0,
            decoder_layers=32,
            decoder_start_token_id=50258,
            dropout=0.0,
            encoder_attention_heads=20,
            encoder_ffn_dim=5120,
            encoder_layerdrop=0.0,
            encoder_layers=32,
            eos_token_id=50257,
            init_std=0.02,
            mask_feature_length=10,
            mask_feature_min_masks=0,
            mask_feature_prob=0.0,
            mask_time_length=10,
            mask_time_min_masks=2,
            mask_time_prob=0.05,
            max_length=448,
            max_source_positions=1500,
            max_target_positions=448,
            median_filter_width=7,
            model_type="whisper",
            num_hidden_layers=32,
            num_mel_bins=128,
            scale_embedding=False,
            torch_dtype="float16",
            use_cache=True,
            use_weighted_layer_sum=False,
            vocab_size=51866,
        )

    @classmethod
    def _coerce_sub_config(cls, name, value):
        """Return ``value`` as an instance of the class declared for ``name``.

        Dicts are rebuilt through the sub-config class' ``from_dict``; a copy
        is passed so the caller's dict is never modified. Any other value
        (an existing config object, or ``None``) is returned unchanged.
        """
        if isinstance(value, dict):
            return cls.sub_configs[name].from_dict(dict(value))
        return value

    @property
    def num_hidden_layers(self):
        """Number of hidden layers of the nested LLM config."""
        return self.llm_config.num_hidden_layers

    @property
    def hidden_size(self):
        """Hidden size of the nested LLM config."""
        return self.llm_config.hidden_size

    def to_dict(self):
        """Serialize this instance to a Python dictionary.

        Nested ``PretrainedConfig`` objects are serialized through their own
        ``to_dict()``; for each one a helper key ``_<name>_type`` recording its
        ``model_type`` is added (``from_dict`` strips these keys again).
        """
        output = super().to_dict()
        for name in self.sub_configs:
            nested = getattr(self, name, None)
            if isinstance(nested, PretrainedConfig):
                output[name] = nested.to_dict()
                output[f"_{name}_type"] = getattr(nested, "model_type", None)
        return output

    @classmethod
    def from_dict(cls, config_dict: dict, **kwargs):
        """Create a ``CovoAudioConfig`` from a dict, rebuilding nested configs.

        Args:
            config_dict: Serialized configuration, e.g. the output of
                ``to_dict()``. It is not modified.
            **kwargs: Forwarded to ``PretrainedConfig.from_dict`` so that
                attribute overrides and ``return_unused_kwargs`` keep their
                standard meaning.

        Returns:
            Whatever ``PretrainedConfig.from_dict`` returns for this class:
            the config, or a ``(config, unused_kwargs)`` tuple when
            ``return_unused_kwargs=True`` is passed.
        """
        # Shallow copy so the caller's dict is left untouched.
        data = dict(config_dict)

        # Helper keys written by ``to_dict``; they are not constructor arguments.
        data.pop("_llm_config_type", None)
        data.pop("_encoder_config_type", None)

        # Rebuild nested configs in place. A ``None`` entry is kept as-is and
        # makes ``__init__`` fall back to the built-in default.
        for name in cls.sub_configs:
            if name in data:
                data[name] = cls._coerce_sub_config(name, data[name])

        # Delegate instantiation and kwargs handling to the base implementation.
        return super().from_dict(data, **kwargs)