본문 바로가기
예바의 스터디/개념 정리

MPNN(Message Passing Neural Network)

by 예바두비두밥바 2024. 8. 15.

MPNN(Message Passing Neural Network)와 MLFF 분야에서의 활용

1. MPNN이란?

그래프 데이터를 처리하기 위해 설계된 딥러닝 모델이다. GNN의 한 종류로, Message Passing을 통해 각 노드가 이웃 노드와 정보를 주고 받으며 학습하고 예측하는 데에 사용된다. MPNN은 크게 Message Passing PhaseReadout Phase로 2가지 단계로 구성된다. 

 

2. MPNN의 2가지 단계: Message Passing & Readout Phase

#Phase 1. Message Passing

: 노드를 표현하는 Feature를 이웃 노드의 Feature를 토대로 업데이트하는 단계

 Message Function : 각 노드가 이웃 노드들로부터 받는(혹은 보내는) 메시지를 계산하는데 사용하는 함수

이 수식은 노드 v에서 노드 u로 보내는 메시지를 나타낸다. 이에 대해 간단히 들여다 보면,  시간 t에서의 Message 함수 M(t)의 입력 값은 시간 t-1에서의 노드 v의 Embedding(상태), 노드 u의 Embedding(상태), 노드 v와 u 사이의 엣지 Feature로 3가지이다.

 

 Update Function : 이웃 노드들로부터 받는 메시지를 기반으로 자신의 상태를 업데이트 하는데 사용하는 함수

이 수식은 시간 t에서 업데이트된 노드 u의 Embedding(상태)를 나타낸다. 시간 t-1에서의 자신의 Embedding(상태)집계(Aggregate)된 이웃 노드로부터의 메시지를 Update 함수 U(t)의 입력값으로 사용한다.

 

* 메시지 : 각 노드의 이웃으로부터 받는 정보

(여러 노드로부터 받은 메세지를 집계하여 노드의 새로운 상태를 계산하는데 사용)


노드 u의 계산된 상태를 Update 함수로 기존의 상태에서 새로운 상태로 업데이트한다. 노드 u의 새로운 상태를 계산하기 위해서 v를 포함한 여러 이웃 노드로부터 메시지를 받으며 이 메세지를 구성하기 위해 Message 함수가 사용된다. Message 함수는 t-1 시간(상태를 업데이트하기 이전)의 자기 자신과 해당 이웃 노드의 상태와 노드 간의 엣지 Feature를 결합하여 메시지를 구성한다.

 

상단의 예시를 통해 한번 더 이해해보자.

 

현재의 Target 노드는 A로, 노드 A의 새로운 상태를 계산해보자.

A의 새로운 상태를 계산하려면 A의 이웃 노드 B, C, D로부터 전달 받은 메시지를 집계해야한다. B, C, D 각각의 메시지는 Message 함수를 통해 생성된다.

  • B의 메시지 구성을 위해 사용되는 정보는 현재 A와 B의 노드 Embedding, A와 B 사이의 엣지 Feature이다.
  • C의 메시지 구성을 위해 사용되는 정보는 현재 A와 C의 노드 Embedding, A와 C 사이의 엣지 Feature이다.
  • D의 메시지 구성을 위해 사용되는 정보는 현재 A와 D의 노드 Embedding, A와 D 사이의 엣지 Feature이다.

결국, A는 B, C, D로부터 전달 받은 메시지를 집계(Aggregate)하여 자신의 상태를 업데이트한다.

 

A는 자신의 이웃 노드의 상태(Embedding)를 기반으로 업데이트 되었지만, A의 이웃 노드의 상태(Embedding)도 각각(B, C, D)의 이웃 노드의 상태(Embedding)을 기반으로 업데이트된 것이기에 단순히 A의 이웃 노드의 상태(Embedding)만을 포함하지 않는다.

 

A에게 메시지를 전달하기 이전에 B와 C, D 모두 자신의 Embedding을 가지기 위해 이와 비슷한 방식으로 자신의 이웃노드로부터의 메시지를 합하여 자신의 상태를 업데이트하였다. (좌측 그림의 왼쪽 부분을 통해 간단히 살펴 볼 수 있다.)

 

#Phase 2. Readout

: 각 노드가 얻은 상태를 기반으로, 그래프 전체 또는 특정 노드에 대한 최종 출력을 생성하는 단계

  • 노드 임베딩의 집계(Aggregation of Node Embeddings) : 각 노드의 임베딩을 결합해 하나의 벡터로 표현
  • 그래프 레벨 예측(Graph-level Prediction) : 학습한 state를 통해 그래프의 label 도출
  • 노드 레벨 예측(Node-level Prediction) : 학습한 state를 통해 노드의 label 도출

 

3. MPNN의 활용 in MLFF(Machine Learning Force Fields)

MLFF란, MD(분자 동역학) 시뮬레이션에서 분자의 물리적 특성인 힘과 에너지 등을 머신러닝을 통해 예측하는 분야이다. MLFF 모델은 크게 GNN 기반, Descriptor 기반으로 2가지로 나누어지는데, 현재 GNN 기반 모델이 더 높은 성능을 보이고 있다. GNN 기반 모델은 분자를 그래프 형태로 변환하여 MD 시뮬레이션 내에서 힘과 에너지를 예측한다. 이에 대한 예시로, MACE 모델이 있으며 해당 모델은  2가지 노드 간의 메시지를 패싱하던 기존 MPNN을 발전시켜 3가지 혹은 그 이상의 노드 간의 메시지를 패싱하여 높은 성능을 보이고 있다. 이를 통해, MLFF 분야에서 물리적 특성을 예측하기 위해 2개의 원자 간의 상호작용 보다는 여러 개의 원자 간의 상호작용을 종합적으로 구별해야함을 시사한다.

 

이에 대한 자세한 내용은 아래의 링크에 있는 논문을 참고해주세요!

 

MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields

Requests for name changes in the electronic proceedings will be accepted with no questions asked. However name changes may cause bibliographic tracking issues. Authors are asked to consider this carefully and discuss it with their co-authors prior to reque

proceedings.neurips.cc

 


프로젝트를 진행하며 모르는 내용을 따로 찾아보며 정리한 글이며

주관적인 내용 및 부족한 부분이 있습니다.

해당 자료에 관한 문의 및 피드백은 연락바랍니다 :]

시간에 따라 더 나은 자료를 위해 수정이 가능합니다!