Entropix : Entropy Based Sampling and Parallel CoT Decoding

|

🟢 Entropix

최근 많이 사용되는 LLM들은 문맥에 맞는 일관된 텍스트를 잘 생성하지만 복잡한 추론이 필요한 작업에 대한 hallucination 및 shallow reasoning 문제가 있다. 즉, 복잡한 추론이 요구되는 문제에 대해 잘못된 정보를 말하거나 얕은 추론만 하는 문제가 있다. 이와 같은 문제를 해결하기 위해 최근에는 entropy 기반 sampling이 많이 연구 되고 있다. entropy 기반 sampling은 decoding 과정에서 entropy를 측정하여 모델이 불확실하게 생성하는 부분을 잘 파악하고 모델이 token을 보다 효과적으로 선택하도록 하는 방법론이다.

Entropix는 모델의 entropy와 varentropy(variance of entropy)를 통합하여 모델이 불확실성을 더 잘 인식하도록하여 더 정확하고 일관된 출력을 생성하도록 하는 방법론을 제안한다. 이 방법론은 모델의 신뢰도 수준(confidence level)에 따라 샘플링 전략을 조정함으로써 추론 중 연쇄 사고를 시뮬레이션하는 것을 목표로 한다. 간단히 말하자면, 모델이 token을 생성할 때 entropy가 높으면 모델에 불확실하다는 신호를 주어 sample parameter를 조정하여 보다 정확한 추론을 유도한다는 것이다. 이는 Chain-of-Thought (CoT)방식의 추론처럼 진행된다.

Chain-of-Thought (CoT): 복잡한 질문에 대해 단계별로 논리적인 사고를 하도록 하는 기술

**Entropy**

- LLM을 통해 다음 token을 예측할 때 불확실성(uncertainty) 혹은 무작위성(randomness)을 측정하여 다음 token에 대한 확률분포가 얼마나 넒게 퍼져있는지 정량화한 것으로, 모델이 다음 token을 예측할 때 얼마나 불확실한지 측정하는 지표가 된다.
- entropy가 높을수록 모델이 불확실한 상태이다.

**Varentropy**

- entropy의 변동성을 측정한 것. 즉, entropy의 분산이다.
- 이는 모델이 주어진 context에서 모델의 불확실성과 예측의 다양성을 갖는지 나타내며, 모델이 token을 예측하는 과정이 얼마나 안정적인지 보여주는 지표이다.
- varentropy가 낮을수록 불확실성이 고르게 분포하는 것이고, 높을수록 불확실성의 변동 폭이 크다는 것이다.
- 예를 들어 어떤 구간에서는 매우 확신을 갖는데 어떤 구간에서는 매우 불확실하게 token을 예측하면 varentropy가 높게 측정된다.

**Varentropy 계산**
1. 확률과 log probability 계산
  softmax를 사용하여 현재 위치에서 예측 결과로 반환 가능한 token에 대한 각 확률을 계산하고, 그들의 log probability를 산출한다.
2. entropy 계산
  현재 위치에서의 확률 분포에 대한 entropy를 계산한다.
3. varentropy 계산
  각 token에 대해 정보량(negative log probability)와 평균 정보량(entropy)을 제곱하고, 이를 해당 token의 확률에 더한다.

1. Main Components

  1. Language Model: LLM(Entropox에서는 LLaMa 3.1 사용)
  2. KV-Cache: 추론 최적화를 위해 key와 value tensor를 저장하는 cache
  3. Metric Calculator: entropy, varentropy, attention-based metric을 계산하는 도구들
  4. Sampling Strategies: sampling 조정 전략
  5. Adaptive Sampler: sampling 전략을 선택하고 적용하는 모듈

2. Method

Entropix는 entropy와 varentropy를 기반으로 모델이 다음 token을 예측하는 방법을 조정한다.

  • 모델의 entropy와 varentropy가 낮을 때는 예측에 대한 확신이 있다는 의미이기 때문에 일반적인 방식대로 진행한다.
  • 모델의 entropy와 varentropy가 높을 때는 예측에 대한 확신이 낮다는 의미이기 때문에 다른 가능성을 더 많이 탐색하거나 다른 추론 방식을 고려한다.

이와 같은 방법은 모델이 확신이 없을 때 더 깊이 생각하는 것처럼 행동하게 만들어 결과적으로 더 정확하고 일관된 출력을 만들어낸다.

3. Data Flow and Decision-Making Process

아래는 Entropix의 진행 과정을 도식화한 것이다.

Illustrated by the author


Step 1: Token Generation

모델이 입력 토큰을 처리하여 logit과 attention score를 생성

Step 2: Metric Calculation

Metric Calculator가 모델의 출력(logit과 attention score)을 기반으로 entropy, varentropy, attention entropy, attention agreement, interaction strength를 계산

Step 3: Strategy Selection

Adaptive Sampler가 계산된 metric을 분석하여 가장 적합한 sampling strategy를 선택

Step 4: Parameter Adjustment

선택된 strategy와 metric에 따라 sampling parameter를 동적으로 조정

Step 5: Token Sampling

선택된 sampling strategy가 적용되어 다음 token 예측

Step 6: Iteration

현재 time step에서 예측된 token을 포함하여 1단계부터 위 과정을 반복하여 전체 sequence 생성.

🟢 Sampling strategy

1. Main keyword

  • Entropy-based decision making

    logits의 entropy와 varentropy로 모델의 불확실성을 평가하고 이에 따라 sampling 전략을 조정한다.

    logits은 모델이 vocab 내 각 token에 대해 출력하는 정규화되지 않은 값이다. 이는 sequence에서 다음 token이 될 각 token에 대한 모델의 “신뢰도”를 나타낸다.

  • Attention-aware sampling

    attention pattern에서 도출된 entropy나 일치율(agreement) 등의 지표들을 활용해 sampling parameter 전략을 세운다.

    Entropix에서는 attention score를 활용하여 attention entropy와 agreement와 같은 지표를 계산하며, 이는 sampling 전략에 중요한 정보를 제공한다.

  • Dynamic parameter adjustment

    temperature, top-k, top-p와 같은 sampling parameter가 현재의 contet와 모델 상태에 따라 동적으로 조정된다.

  • Adaptive multi-sample approach

    중간 정도의 불확실성을 가질 때에는 여러 샘플을 생성하고 이를 평가하여 가장 적합한 토큰을 선택한다.

2. Impact on LLM Inference

Entropix의 sampling 조정 전략은 LLM 추론에 아래와 같은 영향을 미친다.

  • Improved coherence

    모델 추론 과정에 불확실성 측정 결과를 반영함으로써 일관성을 유지하는 데 도움을 준다.

  • Enhanced context sensitivity

    attention-aware sampling을 통해 생성 과정 전반에 걸쳐 context를 더 잘 보존하는 데에 도움을 준다.

  • Reduced hallucinations

    동적으로 sampling parameter를 조정하여 불확실성이 높은 상황에서 hallucination을 줄이는 데 도움을 준다.

  • Flexible generation

    현재 context에 따라 다양한 sampling 전략을 유연하게 전환하여 더 세밀하고 적절한 text 생성을 가능하게 한다.

3. Entropy, Varentropy

  • Entropy
    • LLM을 통해 다음 token을 예측할 때 불확실성(uncertainty) 혹은 무작위성(randomness)을 측정하여 다음 token에 대한 확률분포가 얼마나 넒게 퍼져있는지 정량화한 것으로, 모델이 다음 token을 예측할 때 얼마나 불확실한지 측정하는 지표가 된다.
    • entropy가 높을수록 모델이 불확실한 상태이다.
  • Varentropy
    • entropy의 변동성을 측정한 것. 즉, entropy의 분산이다.
    • 이는 모델이 주어진 context에서 모델의 불확실성과 예측의 다양성을 갖는지 나타내며, 모델이 token을 예측하는 과정이 얼마나 안정적인지 보여주는 지표이다.
    • varentropy가 낮을수록 불확실성이 고르게 분포하는 것이고, 높을수록 불확실성의 변동 폭이 크다는 것이다.
    • 예를 들어 어떤 구간에서는 매우 확신을 갖는데 어떤 구간에서는 매우 불확실하게 token을 예측하면 varentropy가 높게 측정된다.

3.1 Sampling

https://southbridge-research.notion.site/


3.1.1 Adaptive Sampling

Adaptive Sampling은 위 그림에서도 확인할 수 있듯이 entropy와 varentropy가 극단적이지 않을 때 적용된다.

  • Step 1: Calculate Metrics

    logit과 attention score로부터 아래 metric들을 계산

    • Logit의 entropy와 varentropy
    • Attention의 entropy와 varentropy
    • Attention agreement
    • Interaction strength
  • Step 2: Adjust Sampling Parameters

    위 metric들을 기반으로 sampling parameter를 동적으로 조정

    • Temperature(무작위성):
      • 불확실성과 agreement에 따라 조정
        temperature = base_temp * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * metrics["agreement"])
        
    • Top-p (nucleus sampling threshold)
      • Attention Varentropy에 따라 조정
    • Top-k
      • interaction strength과 agreement에 따라 조정
    • Minimum probability threshold (min_p)
      • logit의 불확실성에 따라 조정
  • Step 3: Generate Multiple Samples

    조정된 parameter를 사용하여 여러 후보 token들을 생성. (본 연구에서는 기본 값으로 12개의 후보 token이 생성됨.)

  • Step 4: Score Samples

    각 후보들은 다음 두 가지 요인을 기반으로 점수가 매겨진다.:

    1. model의 logit에서 나온 Log probability
    2. 계산된 metric에서 파생된 신뢰도 점수
  • Step 5: Select Best Sample

    가장 높은 최종 점수를 가진 token이 최종적으로 선택됨.

3.1.2 Specialized Sampling

Adaptive Sampling와는 달리 entropy와 varentropy가 극단적일 경우 사용된다.

Entropix에서는 네 가지 주요 Specialized Sampling 기법이 있다.

    varentropy  
    low high
entropy low 확신이 높은 상태.
일관성 있는 확신.
greedy sampling 사용
예측 token의 다양성이 높은 상태.
탐색적 sampling을 통해 다양한 가능성을 조사한다.
high 일관적으로 불확실한 상태
설명 삽입이나 탐색 범위를 늘려야 한다.
높은 불확실성, 일관성 매우 부족
parameter를 조정하여 높은 불확실성 sampling을 유도한다.

1. Greedy Sampling: entropy ↓ varentropy ↓

Greedy Sampling은 불확실성이 낮을 때, 즉 entropy와 varentropy가 모두 낮을 때 사용한다. 이 경우는 모델이 스스로의 예측에 대해 매우 높은 신뢰도를 보이고 있는 상태이다.

  • Step 1: entropy와 varentropy의 thresholds 확인 (thresholds는 일반적으로 0.1)

  • Step 2: thresholds를 만족한다면 가장 높은 확률을 지닌 token 선택

https://southbridge-research.notion.site/


2. Clarification Insertion: entropy ↑ varentropy ↓

Clarification Insertion은 entropy는 높고 varentropy는 낮을 때 사용한다. 이는 모델이 불확실성을 가지고 있지만 그 불확실성이 전체 sequence에 대해 일관된 상태이다.

  • Step 1: entropy는 높고 varentropy는 낮은 조건 확인

    entropy가 thresholds(0.3)이상이고, varentropy가 thresholds(0.1)이하인지 확인

  • Step 2: Insert clarification token

    위 조건이 충족되고, 설명 token이 사용되지 않았다면, 미리 정의한 “clarification question” token을 삽입

  • Step 3: 후속 sampling 조정

    설명 token이 삽입된 상태에서는 다음 token에 대해 temperature를 약간 높인다.

3. Exploration Sampling: entropy ↓ varentropy ↑

Exploration Sampling은 entropy는 낮고 varentropy는 높을 때 사용한다. 이는 모델이 자신감 있게 token을 예측하지만 많은 token 후보에 대해 높은 가능성을 할당하고 있는 상태이다.

  • Step 1: Adjust temperature

    interaction strength metric에 따라 temperature를 조정한다.

  • Step 2: Modify top-k

    attention agreement metric을 기반으로 top-k를 조정한다.

  • Step 3: Sample with adjusted parameters

    위 두 단계로 조정된 parameter를 사용하여 예측을 진행한다.

4. High Uncertainty Sampling: entropy ↑ varentropy ↑

High Uncertainty Sampling은 entropy와 varentropy 모두 높을 때 사용한다. 이는 모델이 매우 불확실한 상태를 의미한다.

  • Step 1: Significantly increase temperature

    attention varentropy에 따라 temperature를 대폭 증가시킨다.

  • Step 2: Modify top-p

    attention entropy가 높으면 top-p를 낮춘다.

  • Step 3: Sample with adjusted parameters

    위 두 단계로 조정된 parameter를 사용하여 예측을 진행한다.

3.2 Attention Entropy

Attention Entropy는 Entropix에서 attentino이 다양한 token에 걸쳐 얼마나 분산되어 있는지, 즉 얼마나 불확실한지 정량화하는 지표이다. 이는 attention probability를 활용하 계산된다.

Attention Entropy가 높으면 불확실성이 높은 거. (모델의 attention이 많은 token들에 걸쳐 분산되어 있다는 거니까.)

Attention Entropy가 낮으면 불확실성이 낮은 거. (모델이 특정 token들이 집중적으로 attention하고 있다는 거니까.)

attention_probs = jax.nn.softmax(attention_scores, axis=-1)
attn_entropy = -jnp.sum(attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1)

Attention은 transformer model에서 일반적으로 multi-head attention으로 구현된다. multi-head의 의미는 attention mechanism이 병력적으로 적용된다는 것이다. 이와 같이 attnetion을 병렬적으로 수행함으로써 모델이 입력을 다양한 측면으로 분석할 수 있게 한다.
Multi-head Attention을 구성하는 Attention head는 self-attention을 수행하는 유닛으로, 입력된 token과 다른 token들 간의 연관성을 계산하고 그 연관성을 기반으로 중요한 정보를 추출하는 역할을 한다.

decoder 기반 모델에서 사용되는 masked attention의 경우, 다음 token을 예측하는 task를 수행하기 때문에 현재 step 이후의 token에 대한 정보는 차단한다. 차단 방법은 attention score를 매우 큰 음수 혹은 무한대로 설정하여 softmax 연산에서 해당 값들이 거의 0에 가깝도록 처리한다. 예를 들어 현재 step이 세 번째일때, 앞의 두 token만 참고하여 세 번째 token을 예측한다. -> 이거 학습 시의 multi head attention 작동 방식이다.
학습과 추론 시의 masked attention 작동 방법이 살짝 다르다.

학습

  • 병렬처리: 전체 sequence가 이미 주어졌기 때문에 한번에 sequence 전체를 처리할 수 있도록 병렬처리를 한다.

추론

  • 순차적 처리: 모델이 아직 없는 token을 새롭게 생성해야 하기 때문에 순차적으로 처리되어야 한다. (따라서 한 문장에 대한 처리가 학습할 때보다 느리다.)
  • masking X: 현재 time step에서 아직 생성되지 않은 이후의 token들은 없기 때문에 masking할 것이 없다. 따라서 매 단계마다 masking할 것은 따로 없고 그냥 이전 token들을 참조하여 self-attention만 수행하면 된다.

3.2.1 Sampling

Attention Entropy가 높으면 sampling 탐색 범위를 넓혀야 한다.

3.3 Attention Agreement

Attention Agreement는 서로 다른 attention head 간의 attention pattern이 얼마나 일관성있는지 측정한 것이다. 각 head의 attention distribution을 평균 attention distribution과 비교하여 계산된다.

낮은 agreement는 각 head들이 각각 다른 측면에 집중하고 있음을 나타낸다. 이것은 모델이 context를 모호하게 받아들이고 있음을 의미한다.

mean_attention = jnp.mean(attention_probs, axis=1)
agreement = jnp.mean(jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2))

3.3.1 Sampling

Attention Agreement가 낮으면 temperature나 top-k paramerter를 조정해야 한다.

3.4 Interaction Strength

Interaction Strength는 token 간의 결속력/token 관계의 강도 등을 나타낸다. Interaction Strength은 transformer model의 layer, head, position에서 attention score의 절대값 평균이다. 산출 방법은 아래와 같다:

  1. Attention Score 추출

    transformer model의 모든 layer와 head에서 raw attention score롤 추출한다.

  2. 절대값 적용

    attention score 값 자체에 집중하기 위해 score의 절대값을 구한다.

  3. 평균 계산

    이 절대값들의 평균을 모든 dimension(layers, heads, positions)에 대해서 계산한다.

interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3))

3.4.1 Sampling

  • Interaction Strength가 높을 때
    • token 간의 관계가 복잡하다
    • context를 애매하게 이어 나갈 가능성 높음
    • 더 넓은 token 후보에 대한 신중한 선택이 필요한 상황. top-k 증가 필요
    • 다양한 출력을 유도하기 위해 temperature가 증가 필요
    • 복잡한 token간 관계를 포착하기 위해 sampling 전략으로 더 많은 탐색적으로 parameter 조정할 필요 있음. 예를 들어, temperature를 높이거나 top-p와 같은 파라미터의 범위를 넓힘으로써 더 다양하고 창의적인 결과를 생성할 수 있도록 유도.
  • Interaction Strength가 낮을 때
    • 단순한 sequence, token 간의 관계 단순
    • 더 집중적이거나 결정적인 선택이 가능한 상태
    • 가능성이 낮은 후보들을 더 적극적으로 제거하면 좋음.

Reference

https://github.com/xjdr-alt/entropix
https://github.com/hallucinomeny/hyperobject/blob/main/hyperobject.py
https://southbridge-research.notion.site/Entropixplained-11e5fec70db180b6bfafe878433c2104