Gemma-3 모델에서 LoRA 학습 후 추론이 깨지는 문제를 해결하며, vLLM에 tie_word_embeddings=False 지원을 기여한 경험을 공유합니다.
1. Sanity Check X
1.1 증상
Gemma-3-4B-IT 모델을 LoRA로 파인튜닝했다.
[Training Log - v2]
Epoch 3/3: loss=0.0812, eval_loss=0.0756
Training completed successfully!
그런데 추론을 실행하니까
Input: "Packet Analysis: modbus.fc=3, modbus.addr=100"
Output (Expected):
"**Field Description:** modbus.fc: Function Code..."
Output (Actual):
"트트트트트트트트트트트트트<end_of_turn><end_of_turn><end_of_turn>..."
무의미한 반복 토큰..
1.2 디버깅 과정
처음에는 단순한 학습 문제로 생각했다:
- 학습률이 너무 높았나? → 조정해도 동일
- 데이터에 문제가 있나? → 다른 데이터로도 동일
- Epoch을 너무 많이 돌렸나? → 1 epoch만 해도 동일
근본적인 원인은 다른 곳에 있었다.
2. 원인 분석: tie_word_embeddings의 함정
2.1 Gemma-3의 아키텍처 특성
Gemma-3 모델은 기본적으로 tie_word_embeddings=True 설정을 사용한다. 이는 **입력 임베딩(embed_tokens)**과 **출력 임베딩(lm_head)**이 동일한 가중치 텐서를 공유한다는 의미다.
# Gemma-3 기본 설정
config.tie_word_embeddings = True
# 실제로는 이런 상태:
model.embed_tokens.weight is model.lm_head.weight # True (동일 객체!)
2.2 PEFT/LoRA의 알려진 버그
문제는 이 tied된 두 레이어를 모두 LoRA target으로 지정할 때 발생한다.
내가 사용한 LoRA 설정:
# finetuned/h4mg_03_train_lora_v2.py (문제가 발생한 버전)
lora_config = LoraConfig(
r=32,
lora_alpha=64,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"embed_tokens", # 입력 임베딩 ← 문제!
"lm_head" # 출력 임베딩 ← 문제!
],
...
)
PEFT 버그 메커니즘:
- embed_tokens와 lm_head가 같은 weight tensor를 참조
- PEFT가 두 레이어 각각에 LoRA adapter를 추가
- merge_and_unload() 호출 시 같은 텐서에 두 번 병합
- 결과: 가중치가 두 배로 적용되어 완전히 손상됨
참고: PEFT Issue #2018 - 이 문제는 공식적으로 알려진 버그다.
With tied embeddings adapter merged to tied layers · Issue #2018 · huggingface/peft
With tied embeddings adapter merged to tied layers · Issue #2018 · huggingface/peft
System Info peft=0.12.0 transformers =4.44.0 Who can help? No response Information The official example scripts My own modified scripts Tasks An officially supported task in the examples folder My ...
github.com
2.3 두 번째 문제: Vocab Size 불일치
커스텀 토크나이저를 만들 때 또 다른 실수가 있었다.
Base model embedding size: 262,208
Custom tokenizer v1 vocab: 262,191 (← 더 작음!)
토크나이저 vocab이 모델 임베딩보다 작으면, resize_token_embeddings() 호출 시 임베딩이 잘린다(truncation). 이로 인해 일부 토큰의 임베딩이 손실되었다.
3. 해결 방법
3.1 토크나이저 수정
vocab size가 base model embedding보다 반드시 크도록 패딩 토큰 추가
# h4mg_01_create_tokenizer_v3.py:49-101
def create_tokenizer(args):
"""Create custom tokenizer with H4MG field tokens"""
# 상수 정의: Gemma-3 실제 임베딩 크기
BASE_MODEL_EMBED_SIZE = 262208
# 1. 기본 토크나이저 로드
tokenizer = AutoTokenizer.from_pretrained(args.base_model)
base_vocab_size = len(tokenizer) # 262,144
print(f"Base vocab size: {base_vocab_size}")
# 2. 도메인 토큰 로드 (46개 ICS 프로토콜 토큰)
field_tokens = load_field_tokens(args.field_tokens)
print(f"Total unique tokens: {len(field_tokens)}")
# 3. 토큰 추가
num_added = tokenizer.add_tokens(field_tokens)
print(f"Actually added: {num_added} tokens")
print(f"New vocab size: {len(tokenizer)}") # 262,190
# 4. ★ 핵심 로직: 임베딩 절단 방지
if len(tokenizer) <= BASE_MODEL_EMBED_SIZE:
padding_needed = BASE_MODEL_EMBED_SIZE - len(tokenizer) + 1
# 262,208 - 262,190 + 1 = 19개 패딩 토큰 필요
print(f"⚠️ Warning: Vocab size ({len(tokenizer)}) <= Base embedding size ({BASE_MODEL_EMBED_SIZE})")
print(f"Adding {padding_needed} padding tokens to prevent embedding truncation...")
padding_tokens = [f"<h4mg_pad_{i}>" for i in range(padding_needed)]
tokenizer.add_tokens(padding_tokens)
print(f"Final vocab size: {len(tokenizer)}") # 262,209
# 5. 토크나이저 저장
tokenizer.save_pretrained(args.output_path)
return tokenizer
결과:
- v1: 262,191 (문제 발생)
- v3: 262,209 (안전)
3.2 모델 로드 방식 변경 (tie=False)
tie_word_embeddings=False로 모델을 로드하여 두 레이어를 완전히 분리
# h4mg_03_train_lora_v3.py:142-202
def load_model_untied(model_path: str, tokenizer_vocab_size: int):
"""
Load model with tie_word_embeddings=False to fix LoRA issues.
Critical for vLLM compatibility when using embed_tokens and lm_head
as LoRA targets.
"""
print(f"Loading model with tie_word_embeddings=False...")
# 1. Config 로드 및 수정
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
original_tie = config.tie_word_embeddings # True
config.tie_word_embeddings = False # ★ 분리 설정
print(f"Original tie_word_embeddings: {original_tie}")
print(f"Modified tie_word_embeddings: {config.tie_word_embeddings}")
# 2. 분리된 임베딩으로 모델 로드
model = AutoModelForCausalLM.from_pretrained(
model_path,
config=config,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
attn_implementation="eager", # Flash Attention 비활성화 (안정성)
)
# 3. 임베딩 레이어 확인
embed_tokens = model.get_input_embeddings()
lm_head = model.get_output_embeddings()
print(f"embed_tokens shape: {embed_tokens.weight.shape}") # [262208, 2560]
print(f"lm_head shape: {lm_head.weight.shape}") # [262208, 2560]
# 4. ★ 핵심: embed_tokens 가중치를 lm_head에 복사
# tie=False로 로드하면 lm_head가 랜덤 초기화되므로 복사 필수!
print(f"Copying embed_tokens weights to lm_head...")
with torch.no_grad():
lm_head.weight.data = embed_tokens.weight.data.clone()
# 5. 어휘 크기 리사이즈 (커스텀 토큰 포함)
current_vocab_size = embed_tokens.weight.shape[0] # 262208
if tokenizer_vocab_size != current_vocab_size: # 262209
print(f"Resizing embeddings: {current_vocab_size} -> {tokenizer_vocab_size}")
model.resize_token_embeddings(tokenizer_vocab_size)
# 6. 새 토큰 임베딩 초기화 (기존 평균으로)
with torch.no_grad():
input_emb = model.get_input_embeddings().weight
output_emb = model.get_output_embeddings().weight
# 기존 어휘의 평균 임베딩 계산
avg_emb = input_emb[:current_vocab_size].mean(dim=0)
# 새 토큰에 평균 임베딩 할당
input_emb[current_vocab_size:] = avg_emb
output_emb[current_vocab_size:] = avg_emb
print(f"Initialized {tokenizer_vocab_size - current_vocab_size} new token embeddings")
return model
3.3 v3 학습 결과
학습 환경:
- GPU: NVIDIA H100 80GB
- Base Model: google/gemma-3-4b-it
- 학습 데이터: 13,000 샘플 (26개 카테고리)
학습 로그:
[2024-12-09 15:10:17] Training started...
[2024-12-09 15:10:17] Using untied embeddings (tie_word_embeddings=False)
[2024-12-09 15:10:18] embed_tokens and lm_head are now independent
Epoch 1/3: loss=0.2341 → Epoch 3/3: loss=0.0521
Final Eval Loss: 0.0739
Token Accuracy: 97.20%
Training Time: 1시간 45분
Total Steps: 2,196
4. 새로운 문제: vLLM 서빙 불가
4.1 vLLM의 제약사항
학습은 성공했지만, 프로덕션 서빙을 위해 vLLM을 사용하려고 하니 또 다른 문제가 발생했다.
# vLLM으로 모델 로드 시도
from vllm import LLM
llm = LLM(
model="path/to/merged_model",
trust_remote_code=True,
)
에러 발생:
AssertionError: Gemma3 requires tie_word_embeddings=True
vLLM의 Gemma3 구현체(vllm/model_executor/models/gemma3.py)가 tie_word_embeddings=True만 지원하도록 하드코딩되어 있었다.
# vllm 기존 코드 (v0.6.x)
class Gemma3ForCausalLM(nn.Module):
def __init__(self, ...):
...
assert config.tie_word_embeddings # ← 이 줄이 문제!
4.2 선택지
- vLLM 포기: 다른 추론 엔진 사용 (성능 손실)
- Tied로 재학습: 해봤는데
- vLLM 수정: 직접 고쳐서 PR 올리기
울며 겨자먹기로 3번을 선택했다. 왜냐면....
5090에서 전체 패턴뽑는데 113일이 걸리던 걸 vLLM으로 4.5일로 줄였단말이지...
아무튼 transformers 에서 113일이 걸리던 것을 vLLM으로 4.5일로 줄여뒀는데, 이걸 포기할 수는 없었다.
5. vLLM 오픈소스 기여
5.1 vLLM이란?
"The Standard of LLM Serving"
- GitHub Stars: 30,000+
- 주요 사용처: NVIDIA, AWS, Azure 등 주요 클라우드 플랫폼
- 핵심 기술: PagedAttention, Continuous Batching
- 의의: vLLM에 코드가 머지되면, 전 세계 수천 개 서비스의 프로덕션 환경에서 실행됨
[Bug]: Custom Hugging Face model with `tie_word_embeddings=False` causes `lm_head` loading issue · Issue #16555 · vllm-project
Your current environment The output of `python collect_env.py` PyTorch version: 2.5.0a0+e000cf0ad9.nv24.10 Is debug build: False CUDA used to build PyTorch: 12.6 ROCM used to build PyTorch: N/A OS:...
github.com

나같은 버그 이슈가 있었다.
- vLLM PR: [Model] Gemma3: Support untied word embeddings by www-spam · Pull Request #30827 · vllm-project/vllm
[Model] Gemma3: Support untied word embeddings by www-spam · Pull Request #30827 · vllm-project/vllm
Summary Add support for tie_word_embeddings=False in Gemma3ForCausalLM This enables loading merged LoRA models where embeddings are untied after merging Changes Add ColumnParallelLinear import R...
github.com
5.2 기술적 구현 과정
Phase 1: 초기 구현
접근 방식: tie_word_embeddings=False일 때 별도의 lm_head 레이어 생성
# 초기 구현 (v1)
from vllm.model_executor.layers.linear import ColumnParallelLinear
class Gemma3ForCausalLM(nn.Module):
def __init__(self, ...):
# 기존 assert 제거
# assert config.tie_word_embeddings ← 삭제
if not config.tie_word_embeddings:
# Untied일 때 별도 lm_head 생성
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
)
else:
self.lm_head = None
문제점: vLLM의 LogitsProcessor는 VocabParallelEmbedding 타입을 기대함. ColumnParallelLinear는 타입이 달라서 호환성 문제 발생.
Phase 2: 타입 안전성 확보
개선: ColumnParallelLinear → ParallelLMHead 교체
# 개선된 구현 (v2)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
class Gemma3ForCausalLM(nn.Module):
def __init__(self, ...):
if not config.tie_word_embeddings:
# ParallelLMHead는 VocabParallelEmbedding을 상속받음
self.lm_head = ParallelLMHead(
config.vocab_size, # ⚠️ 인자 순서 주의!
config.hidden_size,
bias=False,
)
주의사항: 두 클래스의 생성자 인자 순서가 다름!
ColumnParallelLinear(input_size, output_size) # hidden → vocab
ParallelLMHead(vocab_size, embedding_dim) # vocab → hidden
Phase 3: 메인테이너 피드백 반영
DarkLight1337(메인테이너) 제안:
"compute_logits 내에 if/else 분기를 두면 유지보수가 어려워집니다. 항상 lm_head를 생성하되, tied 설정이면 가중치를 공유하는 방식을 추천합니다."
최종 구현:
# 최종 구현 (v3) - vllm/model_executor/models/gemma3.py
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
class Gemma3ForCausalLM(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
# 1. Transformer 모델 생성
self.model = Gemma3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
# 2. ★ 항상 lm_head 생성
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=vllm_config.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
# 3. ★ Tied 설정이면 가중치 공유
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# 4. Logits Processor
self.logits_processor = LogitsProcessor(config.vocab_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# 분기 없이 항상 self.lm_head 사용
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# embed_tokens 처리
if "embed_tokens" in name:
# ... (기존 로직)
pass
# ★ lm_head 처리 (untied일 때만)
elif "lm_head" in name:
if not self.config.tie_word_embeddings:
# Untied면 별도로 로드
param = params_dict.get("lm_head.weight")
if param is not None:
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add("lm_head.weight")
# Tied면 embed_tokens에서 이미 로드됨 (무시)
else:
# ... (기타 가중치 처리)
pass
return loaded_params
5.3 tie_weights() 메서드의 의미
# 메인테이너 제안 방식 (직접 할당)
self.lm_head.weight = self.model.embed_tokens.weight
# 내가 적용한 방식 (메서드 활용)
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
tie_weights() 사용의 장점:
- 의도 명확화: 코드를 읽는 사람에게 "가중치를 묶는다"는 의도가 명확히 전달됨
- 일관성: vLLM 내 다른 모델(Llama, Mistral 등) 구현체와 패턴 통일
- 캡슐화: ParallelLMHead 내부 구조가 변경되어도 메서드 내부에서 처리 가능
5.4 코드 변경 요약
구분 초기 구현 (Phase 1) 최종 구현 (Phase 3)
| Class | ColumnParallelLinear | ParallelLMHead |
| Weight Sharing | lm_head = None 분기 | tie_weights() 사용 |
| Logic Flow | if lm_head is not None: | 분기 제거 (Unified) |
| Type Safety | Type Hint 불일치 | 완벽한 호환성 |
6. LoRA → vLLM 병합 모델 생성
6.1 PEFT merge_and_unload() 문제
vLLM에서 LoRA 어댑터를 직접 로드할 수도 있지만, 성능을 위해 미리 병합된 모델을 사용하는 것이 좋다. 그러나 PEFT의 merge_and_unload()는 여전히 버그가 있다.
# 이 코드는 여전히 문제가 있을 수 있음
model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = model.merge_and_unload() # ⚠️ 가중치 손상 가능
6.2 수동 병합 스크립트
안전한 병합을 위해 수동으로 LoRA 가중치를 적용하는 스크립트를 작성했다.
# merge_lora_manual.py
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
import json
def merge_lora_weights(
base_model_path: str,
adapter_path: str,
output_path: str,
lora_alpha: int = 64,
lora_r: int = 32,
):
"""
Manually merge LoRA weights into base model.
LoRA 병합 공식:
- Linear 레이어: W' = W + (lora_alpha / lora_r) * (B @ A)
- Embedding 레이어: W' = W + (lora_alpha / lora_r) * (B @ A).T
"""
scaling = lora_alpha / lora_r # 64 / 32 = 2.0
print(f"LoRA scaling factor: {scaling}")
# 1. Base model weights 로드
base_weights = {}
base_path = Path(base_model_path)
for shard_file in base_path.glob("model*.safetensors"):
print(f"Loading base weights from {shard_file.name}...")
base_weights.update(load_file(str(shard_file)))
# 2. LoRA adapter weights 로드
adapter_path = Path(adapter_path)
adapter_weights = load_file(str(adapter_path / "adapter_model.safetensors"))
# 3. LoRA config 로드
with open(adapter_path / "adapter_config.json") as f:
adapter_config = json.load(f)
target_modules = adapter_config["target_modules"]
print(f"Target modules: {target_modules}")
# 4. 각 타겟 모듈에 대해 LoRA 병합
merged_weights = dict(base_weights)
for module_name in target_modules:
# LoRA A, B 매트릭스 찾기
lora_a_keys = [k for k in adapter_weights if f"{module_name}.lora_A" in k]
lora_b_keys = [k for k in adapter_weights if f"{module_name}.lora_B" in k]
for lora_a_key in lora_a_keys:
# 대응하는 B 매트릭스와 원본 가중치 키 추출
base_key = lora_a_key.replace(".lora_A.weight", ".weight")
base_key = base_key.replace("base_model.model.", "")
lora_b_key = lora_a_key.replace("lora_A", "lora_B")
if base_key not in merged_weights:
print(f"⚠️ Base key not found: {base_key}")
continue
if lora_b_key not in adapter_weights:
print(f"⚠️ LoRA B key not found: {lora_b_key}")
continue
# LoRA 매트릭스 로드
lora_a = adapter_weights[lora_a_key] # [r, in_features]
lora_b = adapter_weights[lora_b_key] # [out_features, r]
base_weight = merged_weights[base_key]
print(f"Merging {base_key}:")
print(f" Base shape: {base_weight.shape}")
print(f" LoRA A shape: {lora_a.shape}, LoRA B shape: {lora_b.shape}")
# 5. ★ 병합 공식 적용
if "embed_tokens" in base_key or "lm_head" in base_key:
# Embedding 레이어: W' = W + scaling * (B @ A).T
delta = scaling * (lora_b @ lora_a).T
else:
# Linear 레이어: W' = W + scaling * (B @ A)
delta = scaling * (lora_b @ lora_a)
merged_weights[base_key] = base_weight + delta.to(base_weight.dtype)
print(f" ✓ Merged with delta norm: {delta.norm().item():.4f}")
# 6. 병합된 가중치 저장
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
# 가중치 키 변환 (Transformers → vLLM 형식)
vllm_weights = convert_to_vllm_format(merged_weights)
save_file(vllm_weights, str(output_path / "model.safetensors"))
print(f"✓ Saved merged model to {output_path}")
# 7. config.json 수정 (tie_word_embeddings=False)
config_path = base_path / "config.json"
with open(config_path) as f:
config = json.load(f)
config["tie_word_embeddings"] = False
if "text_config" in config:
config["text_config"]["tie_word_embeddings"] = False
with open(output_path / "config.json", "w") as f:
json.dump(config, f, indent=2)
print(f"✓ Updated config with tie_word_embeddings=False")
return output_path
def convert_to_vllm_format(weights: dict) -> dict:
"""
Convert Transformers weight keys to vLLM format.
Transformers: model.language_model.layers.0.self_attn.q_proj.weight
vLLM: model.layers.0.self_attn.q_proj.weight
"""
vllm_weights = {}
key_mapping = {
"model.language_model.embed_tokens": "model.embed_tokens",
"model.language_model.layers": "model.layers",
"model.language_model.norm": "model.norm",
}
for key, value in weights.items():
new_key = key
for old_prefix, new_prefix in key_mapping.items():
if key.startswith(old_prefix):
new_key = key.replace(old_prefix, new_prefix, 1)
break
vllm_weights[new_key] = value
# lm_head 별도 복제 (embed_tokens와 분리)
if "model.embed_tokens.weight" in vllm_weights:
if "lm_head.weight" not in vllm_weights:
# Untied embeddings를 위해 lm_head 별도 생성
vllm_weights["lm_head.weight"] = vllm_weights["model.embed_tokens.weight"].clone()
print("✓ Created separate lm_head.weight for untied embeddings")
return vllm_weights
if __name__ == "__main__":
merge_lora_weights(
base_model_path="/home/slime/models/gemma-3-4b-it",
adapter_path="/home/slime/SLM/finetuned/h4mg_lora_v3_h4mg_v3_20251209_151017",
output_path="/home/slime/SLM/finetuned/h4mg_merged_v3_vllm",
lora_alpha=64,
lora_r=32,
)
6.3 추론 시 주의사항
병합된 모델을 Transformers로 로드할 때도 동일하게 tie_word_embeddings=False로 설정해야 한다:
# 추론 시 올바른 로드 방법
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# 1. Config 수정
config = AutoConfig.from_pretrained(merged_model_path)
config.tie_word_embeddings = False # ★ 필수!
# 2. 모델 로드
model = AutoModelForCausalLM.from_pretrained(
merged_model_path,
config=config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# 3. embed_tokens → lm_head 가중치 복사 (안전을 위해)
with torch.no_grad():
model.get_output_embeddings().weight.data = \\\\
model.get_input_embeddings().weight.data.clone()
# 4. 토크나이저 로드 및 임베딩 리사이즈
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model.resize_token_embeddings(len(tokenizer))
vLLM으로 로드할 때는 이제 자연스럽게 작동한다. :
from vllm import LLM, SamplingParams
# vLLM 로드 - 이제 tie_word_embeddings=False도 지원!
llm = LLM(
model="/home/slime/SLM/finetuned/h4mg_merged_v3_vllm",
trust_remote_code=True,
dtype="bfloat16",
gpu_memory_utilization=0.9,
)
# 추론
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
outputs = llm.generate(["Packet Analysis: modbus.fc=3..."], sampling_params)
print(outputs[0].outputs[0].text)
7. 결론
7.1 기술적 교훈
- 모델 아키텍처 이해의 중요성: tie_word_embeddings라는 단일 설정이 전체 학습 파이프라인을 망가뜨릴 수 있다.
- Vocab Size 관리: 커스텀 토크나이저를 만들 때는 반드시 base model embedding size보다 크게 설정해야 한다.
- Type Safety: vLLM 같은 시스템에서는 클래스 상속 구조를 이해하고 올바른 타입을 사용해야 한다.
7.2 오픈소스 기여 교훈
- 메인테이너와의 소통: 코드 리뷰 과정에서 더 나은 구현 방식을 배울 수 있다.
- 일관성 존중: 기존 코드베이스의 패턴을 따르는 것이 중요하다 (tie_weights() 패턴 사용).
7.3 프로젝트 성과
항목 v2 (문제 발생) v3 (해결 후)
| 출력 품질 | 깨진 토큰 반복 | 정상 출력 |
| Eval Loss | 0.0756 | 0.0739 |
| Token Accuracy | 측정 불가 | 97.20% |
| vLLM 호환성 | ❌ 불가 | ✅ 완벽 지원 |
| 오픈소스 기여 | - | vLLM PR Merged |
7.4 관련 링크
- vLLM PR: #30827 - Gemma3: Support untied word embeddings
- 관련 Issue: #16555 - tie_word_embeddings=False loading issue
- PEFT Bug: #2018 - Tied embeddings merge issue
이 글이 Gemma-3 + LoRA + vLLM 조합에서 비슷한 문제를 겪는 분들께 도움이 되길 바랍니다.