Mistral 7B

|

Mistral 7B abstract

1. Model Size:

  • 73억 개의 파라미터를 가진 언어 모델

2. Performance Comparison:

  • 모든 benchmark에서 LLaMa2 13B보다 우수한 성능을 보임.
  • 다수의 benchmark에서 LLaMa1 34B보다 우수한 성능을 보임.
  • CodeLLaMa 7B와 유사한 코드 퍼포먼스를 내면서 영어 task 처리 성능을 유지.

3. Architectural Features:

  • Grouped-Query Attention (GQA) 빠른 추론을 위해 Grouped-query attention (GQA)를 사용. (LLaMa2에도 사용되었다.)
  • Sliding Window Attention (SWA) 더 낮은 계산 비용으로 더 긴 시퀀스를 처리하기 위해 Sliding Window Attention (SWA)를 사용.

Performance in details

아래 그래프는 LLaMa2 family 중 7B와 13B 그리고 LLaMa1 34B와 Mistral 7B의 성능을 각종benchmark를 통해 비교한 자료이다.

성능 평가 결과, Mistral 7B는 모든 metric에서 LLaMa2 13B의 성능을 뛰어 넘었고, LLaMa1 34B와는 유사한 성능을 보이는 것으로 확인되었다. (34B 모델로 LLaMa2가 아닌 LLaMa1 이 채택된 이유는 단순히 LLaMa2 34B가 공개되지 않았기 때문이다. LLaMa 개발팀이 34B 성능이 다른 사이즈들보다 좋지 않아 공개하지 않았다고 LLaMa2 논문에서 밝혔다.)

성능 측정에 사용된 benchmark의 범주는 아래와 같이 정리할 수 있다:

  • Commonsense Reasoning: Hellaswag, Winogrande, PIQA, SIQA, OpenbookQA, ARC-Easy, ARC-Challenge 그리고 CommonsenseQA의 0-shot 평균
  • World Knowledge: NaturalQuestions과 TriviaQA의 5-shot 평균
  • Reading Comprehension: BoolQ와 QuAC의 0-shot 평균
  • Math: GSM8K with maj@8의 8-shot과 MATH with maj@4의 4-shot 평균
  • Code: Humaneval과 3-shot MBPP의 0-shot 평균
  • Popular aggregated results: 5-shot MMLU, 3-shot BBH 그리고 3-5-shot AGI Eval (English multiple-choice questions only)

모델간 cost 대비 performance를 비교하기 위해 “equivalent model sizes” 계산 metric을 제안하였다. 즉, 두 모델이 유사한 성능을 내기 위해 model parameter(size)를 각각 어디까지 scaling up해야 하는지 계산하는 metric을 제안하였다. 여기에는 아래 그래프와 같이 MMLU, Commonsense Reasoning, World Knowledge 그리고 Reading comprehension benchmark가 사용되었다. 실험 결과, reasoning, comprehension and STEM reasoning (MMLU)에서 Mistral 7B보다 3배 이상의 파라미터가 있어야 LLaMa2 모델이 동등한 성능을 낸다는 것이 밝혀졌다. 이를 통해 Mistral이 LLaMa2보다 메모리 절약과 데이터 전송 처리에 효과적인 모델임을 알 수 있었다고 한다.

equivalent model sizes 실험에 사용된 benchmark 중 World Knowledg를 제외한 모든 범주에서 Mistral 7B가 LLaMa2 13B보다 좋은 성능을 보였다. (이와 같은 결과는 Mistral의 제한적인 parameter 수로 인해 압축할 수 있는 지식의 양이 제한되었기 때문으로 추측된다고 한다.)

Mistral과 LLaMa2가 각각의 논문에서 수행한 성능 평가에서 주요하게 다른 점은 아래와 같다:

  • MBPP의 경우, Mistral에서는 hand-verified subset을 사용하였다.
  • Trivia QA의 경우, Mistral에서는 Wikipedia context를 제공하지 않았다.

Sliding Window Attention (SWA) mechanism

Mistral 7B는 각 layer가 이전 4096개의 hidden staet를 참조하는 Sliding Window Attention (SWA) mechanism을 사용한다. 이 방법론은 O(sliding_window.seq_len)의 선형 계산 비용을 가지기 때문에 고안되었다. 실제로 해당 방법론 적용으로 FlashAttention과 XFormers는 sequence length 16k, window size 4k일 때 2배 정도의 속도 향상이 있었다고 한다.

Sliding Window Attention은 현재 time step에 대한 window size 이상의 과거 context를 참조하기 위해 transformer의 stacked layer들을 사용한다: layer $k$의 token $i$는 layer $k-1$에 있는 $[i-sliding_window, i]$에 집중한다. 그리고 $k-1$에 있는 이 token들은 $[i-2*sliding_window, i]$에 집중한다. higher layer일수록 attention pattern이 수반하는 것보다 더 멀리 있는 정보에 접근할 수 있다.

Rolling buffer cache

1. Vanilla attention

Attention은 sequence 내 토근 간의 정보를 공유하는 방법이다. Vanilla Transformer의 attention은 causal mask 방식을 채택하여 sequence의 각 token이 자신을 포함한 과거의 모든 token에 attention을 하였다. 이와 같은 방법은 미래의 token 예측에 과거의 정보만 사용할 수 있다. fixed attention span은 rotating buffer를 사용하여 cache 크기를 sliding_window token size로 제한할 수 있는 것을 의미한다. 이를 통해 모델 품질에는 영향을 미치지 않으면서 추론 시 8192 sequence length에 대한 cache memory를 절반 정도 줄일 수 있다.

2. Sliding window to speed-up inference and reduce memory pressure

attention 연산 횟수는 sequence length에 대해 이차적(quadratic)이고 memory pressure은 sequence length에 대해 선형적이다. 이로 인해 추론 시 cache 가용성이 감소하기 때문에 데이터 처리량은 줄고 시간은 더 지연된다. 이 문제를 완화하기 위해 Mistal은 각 token이 window size만큼의 과거 토큰에만 attention을 하는 Sliding Window Attention을 사용한다. 아래 그림은 window size가 3인 경우이다.

해당 방법론의 특이 사항은 sliding window 범위 밖에 있는 token도 next word prediction에 영향을 미친다는 점이다. 각 attention layer에서 정보는 최대 window size 만큼의 token만 전달될 수 있다. 따라서 2개의 attention layer를 거친 후에는 2*window size만큼의 정보가 전달될 수 있다. 예를 들어 sequence length가 16k이고 sliding window size가 4k일 때 4 layer를 거친 후에는 정보가 전체 sequence length로 전달된다.

이와 같이 해당 방법론이 sliding window 밖의 정보도 전달하지만 sequence 길이가 너무 길어지면 모델이 더 이상 full context를 전부 사용하진 않는 현상이 관찰되었다고 한다

3. Rolling buffer cache

Mistral에는 rolling buffer cache가 적용되었다. cache는 W(window size)로 고정되고 (key, value)값을 cache position(i%W) 내의 position i에 저장한다. position i가 W보다 크면 과거 cache의 value들이 덮어씌워진다.

4. Pre-fill and chunking

sequence generating 시 각 token은 이전 token에 의존하기 때문에 token을 하나씩 예측해야 한다. 하지만 prompt는 이미 앞서 나와 있으니까 이에 대한 (k, v) cache를 미리 채울 수 있다(pre-fill). Prompt가 아주 큰 경우에는 작은 조각(chunk)으로 나누어 각 chunk로 pre-fill을 할 수 있다. 이때 chunk size는 window size로 선택 가능하다. 따라서 attention 연산 시 각 chunk에 대해 연산을 수행하면 된다.

Fine-tuning Mistral 7B for chat

Mistral 7B의 일반화 능력 검증을 위해 Mistral 팀은 Huggingface에 공개된 instruction dtat로 fine-tuning된 chat model을 만들었다. (data에 다른 trick이나 독점 데이터를 추가하지 않았다고 한다.) 그 결과, MT-Bench에서 모든 7B 모델보다 좋은 성능을 보였고, 13B chat model들과는 유사한 성능을 보였다고 한다.

Reference

https://mistral.ai/news/announcing-mistral-7b/

https://github.com/mistralai/mistral-src