개인프로젝트

Temporal Cycle Consistency Loss 구현체 최적화

백악기작은펭귄 2024. 11. 23.
반응형

벡터화 연산을 활용한 Temporal Cycle Consistency Loss 최적화

최근 연구실에서 골프 모션 가이던스를 위한 시스템을 제작하고 있다. 이때, 사용자와 전문가 모션의 차이를 추출하기 위해 2022년 IEEE Access에 게재된 'AI Golf: Golf Swing Analysis Tool for Self-Training'라는 논문에서 제시한 S-TCC 네트워크를 이용하고 있는데, 여기서 S-TCC 인코더 학습에 사용된 Temporal Cycle Consistency Loss의 계산을 개선해보았다.


골프 스윙 분석을 위한 S-TCC 네트워크

개선 결과의 설명에 앞서, 논문에서 제시한 S-TCC(Skeleton-Based Temporal Cycle-Consistency) 네트워크에 대해 먼저 간략하게 소개를 해보겠다. 해당 논문에서 제시한 골프 스윙 분석 시스템에서 핵심 역할을 하는 S-TCC 네트워크는 Skeleton 데이터를 활용해 동작을 임베딩하고 두 모션 간의 차이가 생기는 지점, 즉 Discrepancy를 detect하는 역할을 한다.

 

이 네트워크의 주요 구성 요소는 인코더(Encoder) 디코더(Decoder)로 나뉘며, 각각 동작 데이터를 처리하여 임베딩을 생성하고, 이를 다시 스켈레톤 시퀀스로 복원한다.

AI Golf에서 제시한 Discrepancy detection. Cycle Consistency Loss를 이용해서 두 시퀀스 사이 가까운 노드를 찾아냄으로써, 같은 모션은 비슷한 임베딩을 생성하도록 인코더를 학습시킨다.

 

S-TCC는 Self-supervised Learning을 통해 시공간적 표현(spatiotemporal representations)을 학습한다. 이 논문은 Temporal Cycle Consistency를 기반으로, 동작 데이터의 잠재 공간(latent space) 표현을 학습하여 시퀀스 간의 관계를 이해하는 데 초점을 맞추고 있다. 

 

1) S-TCC 인코더

이 글에서 다루고 있는 TCC Loss를 사용하여 임베딩 방법을 학습하는 S-TCC의 인코더는 Skeleton 데이터를 입력받아 고차원의 임베딩 벡터로 변환한다. 이 논문에서는 16개 Joint의 Position 데이터(x, y, z)를 1×48 벡터로 변환한 뒤, 시간적 정보를 집계하기 위해 3D 컨볼루션 Max Pooling을 적용한다. 이를 통해 최종적으로 128차원의 임베딩 벡터를 생성하며, 이 벡터는 Skeleton 동작의 특징을 압축적으로 표현한다.

 

2) S-TCC 디코더

S-TCC 디코더는 인코더가 생성한 임베딩 벡터를 입력받아 원래의 Skeleton 데이터를 복원한다. 이 과정에서 Fully-Connected 레이어를 사용하며, 복원된 데이터와 원본 데이터 간의 차이를 최소화하도록 학습된다. Mean Squared Error(MSE) 손실 함수를 사용해 복원 과정의 정확도를 높인다.

 

3) Loss

결론적으로, 인코더는 TCC Loss를 이용해 임베딩 벡터 변환방법을 학습하고, 디코더는 MSE Loss를 이용해 이 임베딩을 원래의 Skeleton 데이터로 복원하는 방법을 학습한다.

TCC Loss:
두 동작 시퀀스를 정렬하고 시간적 일관성을 유지하기 위해 사용된다. Cycle Consistency 개념을 기반으로, 한 동작의 프레임을 다른 동작의 프레임과 비교하며 손실을 계산한다.

MSE Loss:
디코더가 복원한 Skeleton 데이터와 원본 데이터 간의 차이를 측정한다. 관절 위치의 차이를 최소화하여 복원된 Skeleton 데이터의 품질을 보장한다.

 

Temporal Cycle Consistency Loss 개선

이제, 기존 구현체에서의 비효율성을 개선한 부분을 다뤄보겠다. TCC Loss는 두 동작 시퀀스 간의 순환적 일관성을 측정하기 위해 모든 쌍의 거리를 계산한다. 이를 위해 기존에 사용하고 있던 코드는 for 문을 통해 전체 노드를 순환하면서 Cycle Consistency를 계산하였다.

# 원본

def original_temporal_cycle_consistency_loss(L1, L2):
  device = L1.device
  
  loss = torch.tensor(0.0, device=device, requires_grad=False)
  count = 0
  
  for i in range(len(L1)):
    nearest_L2 = min(range(len(L2)), key=lambda j: torch.dist(L1[i], L2[j]))
    cycle_L1 = min(range(len(L1)), key=lambda k: torch.dist(L2[nearest_L2], L1[k]))
    if cycle_L1 != i:
      loss = loss + torch.dist(L1[i], L2[nearest_L2])
      count += 1
          
  return loss / count if count > 0 else loss

 

하지만 이는 $O(N^2)$의 시간복잡도를 가지며 loss를 매번 업데이트하기 때문에 작은 부동소수점 누적 오차가 발생할 가능성이 있다. 또한, 파이썬의 반복문은 GPU의 병렬 처리 성능을 제대로 활용하지 못한다는 단점이 존재한다.

 

 

개선된 구현

이러한 문제점을 해결하기 위해, 벡터화 연산을 활용하여 TCC Loss를 최적화하였다. PyTorch의 torch.cdist를 활용하여 두 텐서 간의 모든 쌍의 거리를 한 번에 계산함으로써, 반복문을 제거하였다.

# 개선
def improved_temporal_cycle_consistency_loss(L1, L2):
  device = L1.device
  
  # Compute pairwise distances between all vectors in L1 and L2
  pairwise_distances = torch.cdist(L1, L2, p=2)  # (seq_len1, seq_len2)

  # Step 1: Find nearest neighbors from L1 to L2
  nearest_L2_indices = torch.argmin(pairwise_distances, dim=1)  # (seq_len1,)
  nearest_L2 = L2[nearest_L2_indices]  # Nearest neighbors in L2 for each L1

  # Step 2: Find nearest neighbors from nearest_L2 back to L1
  # Calculate distances between nearest_L2 and L1
  pairwise_distances_cycle = torch.cdist(nearest_L2, L1, p=2)  # (seq_len1, seq_len1)
  nearest_L1_indices = torch.argmin(pairwise_distances_cycle, dim=1)  # (seq_len1,)

  # Step 3: Calculate the cycle consistency loss
  # Compare original L1 and cycle_L1 (mapped back from L2)
  cycle_L1 = L1[nearest_L1_indices]  # Reconstructed points in L1
  cycle_loss = torch.mean(torch.norm(L1 - cycle_L1, dim=1))  # Average Euclidean distance

  return cycle_loss

 

개선된 TCC Loss 구현은 여러 면에서 기존 방식에 비해 효율적이다.

 

첫째, 기존 구현은 for 문을 사용해 순차적으로 계산했으나, 개선된 방식은 torch.cdist를 활용해 모든 계산을 벡터화하고 병렬 처리를 가능하게 한다. 이를 통해 GPU와 CPU 모두에서 연산의 효율성이 극대화된다.

 

둘째, 메모리 효율성에서도 개선이 이루어졌다. 모든 거리를 한 번에 계산함으로써 불필요한 개별 텐서 생성과 업데이트를 제거했으며, 최적화된 텐서 연산을 통해 실행 속도가 향상되었다.

 

마지막으로, 수치적 안정성 또한 강화되었다. 병렬 계산은 손실 값의 부동소수점 연산에서 발생할 수 있는 누적 오차를 줄이는 데 기여하며, torch.mean을 활용해 평균 손실을 안정적이고 정확하게 계산할 수 있다. 이는 특히 대규모 데이터와 복잡한 모델에서 기존 방식 대비 더 나은 성능과 안정성을 제공한다.

 

결과

다음은 1,000의 길이를 가진 동작 시퀀스 11,200의 길이를 가진 동작 시퀀스 2 사이 TCC Loss를 계산한 결과이다.

기존 함수는  for 문으로 인한  torch.dist 의 반복호출을 하고 있는데, 파이썬 레벨의 단일 연산은 GPU보다 CPU가 효율적이다.

 

기존 함수에서는 부동소수점 누적으로 인해 개선된 구현체와 loss 값이 약간 차이가 있지만, 두 구현은 거의 유사한 값을 가진다. 기존의 구현체는 CPU에서는 34.929038초, GPU에서는 129.197484초(!)가 소요된 반면, 개선된 함수는 병렬 연산을 사용했기 때문에 CPU에서 0.019590초, GPU에서 0.001790초가 소요되며, 소요 시간이 크게 개선된 것을 확인할 수 있다.

반응형

'개인프로젝트' 카테고리의 다른 글

테크포임팩트: B-Peach Lab을 시작하며  (5) 2024.10.12

댓글