본문 바로가기
Algorithm

[프로그래머스] Lv3 - 최적의 행렬 곱셈

by sangyunpark99 2025. 2. 25.

풀이 방법

단순 무식하게 브루트 포스 방법으로 구현하니, 시간 초과가 발생했습니다.

왜 시간 초과가 발생할까요?

 

행렬이 1개일때부터 한개씩 늘리면서 몇개의 경우가 생기는지 알아보겠습니다.

 

행렬이 1개인 경우

A

 

행렬을 곱할 수 없습니다.

 

 

행렬이 2개인 경우는 1가지 입니다.

A x B

 

행렬이 3개인 경우엔 2가지로 나뉩니다.

(A x B) x C
 A x (B x C)

 

행렬이 4개인 경우엔 5가지로 나뉩니다.

(A x B) x (C x D)
((A x B) x C) x D
(A x (B x C)) x D
 A x ((B x C) x D)
 A x (B x (C x D))

 

N개의 행렬이 있는 경우엔 아래와 같은 점화식을 따르게 됩니다.

 

문제에서 주어진 n의 최대 갯수는 200입니다.

(400)! / (201)! x 200!을 하는 경우 숫자가 너무 커집니다. 

 

브루트 포스 방법으로는 시간초과가 발생하게 됩니다.

 

어떤 알고리즘을 선택해서 풀어야 할까요?

 

행렬을 곱하는 과정에서 같은 부분의 행렬 곱셈을 여러번 계산하는 경우가 생기기 때문에, DP를 사용해서 이미 계산한 결과를 저장하고 재사용하면, 불필요한 계산을 줄일 수 있습니다. 

 

이 문제에 어떻게 불필요한 계산이 발생하나요?

 

4개의 행렬이 있을때 행렬의 곱셈을 계산하는 방식은 A x B → A x B x C → A x B x C x D 의 순서대로 합니다.

이중 A x B는 3번 중복 계산이 발생합니다. 중복 계산은 곧 불필요한 계산입니다.

A x B x C x D

 

 

 

이 문제는 큰 문제를 해결하기 위해 작은 문제를 해결해야 하는 방식으로, 작은 문제들의 최적해를 이용해 큰 문제의 최적해를 구하는 방식으로 풀어야 합니다. 이러한 성질은 동적 계획법(DP)로 해결할 수 있습니다.

 

이 문제는 작은 문제의 최적해를 조합해서 큰 문제의 최적해를 구할 수 있기 때문에, 최적 부분 구조를 가지게 됩니다.

 

이 문제에 어떻게 최적 부분 구조가 적용이 되나요?
A x B x C x D

 

4개의 행렬이 있을 때, 최소 연산 횟수를 구하는 과정은 작은 부분 문제부터 시작해 점점 큰 문제로 확장되는 구조를 가지게 됩니다.

A x B
A x B x C
A x B x C x D

 

쉽게 말해, A x B부터 시작해서 A x B x C, 그리고 A x B x C x D로 점점 확장되게 됩니다.

 

중복 되는 계산 부분최적 부분 구조로 인해 DP 알고리즘을 사용하는 것이 가장 적합합니다.

for(int k = start; k < end; k++) {
      int value = go(start, k) + go(k+1, end)
         + matrix[start][0] * matrix[k][1] * matrix[end][1];

 

이 문제는 "어디서 나눠야 할까"가 핵심이고, 이를 위해서 k라는 중간 지점을 정해서 구간을 계속 분할하는 방식으로 DP를 구현합니다.

  • go(start, k) : start부터 k구간 까지의 최소 연산 횟수
  • go(k+1, end) : k+1부터 end까지의 최소 연산 횟수
  • matrix[start][0] * matrix[k][1] * matrix[end][1] : 두 부분을 하나로 합칠때 필요한 연산 횟수의 값을 의미합니다.
왜 이렇게 점화식이 세워질까요?

 

4개의 행렬을 예시로 들어보겠습니다.

 

A x B x C x D

 

k = 1인 경우엔 다음과 같이 나뉩니다.

(A x B) x (C x D)

 

점화식과 비교해보면, 다음과 같습니다.

  • go(start,k) :  A구간부터 B구간 까지의 최소 연산 횟수 (A x B)
  • go(k+1, end) : C구간부터 D구간 까지의 최소 연산 횟수 (C x D)
  • 두 부분을 하나로 합치는 연산 횟수의 값  : (A x B) x (C x D)

 

풀이 코드 

import java.util.*;

class Solution {
    
    private int[][] dp;
    private int[][] matrix;
    
    public int solution(int[][] matrix_sizes) {
        int n = matrix_sizes.length;
        
        matrix = matrix_sizes;
        dp = new int[n][n];
        
        for(int i = 0; i < n; i++) {
            Arrays.fill(dp[i], -1);
        }
        
        return go(0, n - 1);
    }
    
    private int go(int start, int end) {        
        if(start == end) {
            return 0;
        }
        
        if(dp[start][end] != -1) return dp[start][end];
        
        int minValue = Integer.MAX_VALUE;
        
        for(int k = start; k < end; k++) {
            int value = go(start, k) + go(k+1, end)
                + matrix[start][0] * matrix[k][1] * matrix[end][1];
            
            minValue = Math.min(minValue, value);
        }
        
        return dp[start][end] = minValue;
    }
}

 

출력 결과

 

'Algorithm' 카테고리의 다른 글

[개념 정리] BFS(너비 우선 탐색)  (0) 2025.02.27
[개념 정리] DFS(깊이 우선 탐색)  (0) 2025.02.26
[프로그래머스] Lv 3. 블록 이동하기  (0) 2025.02.24
펜윅 트리  (0) 2025.02.23
최대 증가 부분 수열  (2) 2025.02.22