Matrix chain multiplication using recursion

The Matrix Chain Multiplication problem is a classic optimization problem in computer science and mathematics. It asks for the most efficient way to multiply a sequence of matrices. Because matrix multiplication is associative, every way of parenthesizing the chain produces the same product, but the number of scalar multiplications can differ greatly. The objective is to choose the parenthesization that minimizes this total.

Introduction

Given a chain of matrices A1, A2, A3, ..., An, where matrix Ai has dimensions dims[i-1] x dims[i] (so the integer array dims has n + 1 entries, dims[0..n]), the goal is to determine the order of multiplication that minimizes the overall cost. Multiplying a p x q matrix by a q x r matrix takes p x q x r scalar multiplications and produces a p x r matrix. This problem is usually solved with dynamic programming; this article first explores the plain recursive approach and then shows the dynamic programming version.
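To make the cost rule concrete, the short Python sketch below multiplies only the first three matrices of the example that follows, A (10 x 16), B (16 x 12) and C (12 x 6), in both possible orders. The helper name pair_cost is illustrative and not part of the programs later in this article.

```python
def pair_cost(p, q, r):
    """Scalar multiplications for a (p x q) matrix times a (q x r) matrix.

    The result has p * r entries, and each entry is a dot product of length q.
    """
    return p * q * r


# First three matrices of the article's example: A = 10x16, B = 16x12, C = 12x6
dims = [10, 16, 12, 6]

# (AB)C: AB is 10x12, then (10x12)(12x6)
left_first = pair_cost(dims[0], dims[1], dims[2]) + pair_cost(dims[0], dims[2], dims[3])

# A(BC): BC is 16x6, then (10x16)(16x6)
right_first = pair_cost(dims[1], dims[2], dims[3]) + pair_cost(dims[0], dims[1], dims[3])

print("(AB)C:", left_first)   # (AB)C: 2640
print("A(BC):", right_first)  # A(BC): 2112
```

Both orders produce the same 10 x 6 result, but A(BC) needs 20% fewer scalar multiplications.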

Problem Statement

Consider a scenario where we have four matrices with dimensions (10 x 16), (16 x 12), (12 x 6), and (6 x 14). The goal is to find the optimal way to multiply these matrices, resulting in the minimum number of scalar multiplications. The order of matrix multiplication can significantly impact the total number of operations, so we need to find an arrangement that minimizes the cost of computation.

Example

Let's take a specific example to understand the problem better:

  Matrices: A, B, C, D
  Dimensions: 
  A: 10 x 16
  B: 16 x 12
  C: 12 x 6
  D: 6 x 14
  

We want to compute the product ABCD with the fewest scalar multiplications. For these dimensions, the optimal order turns out to be (A(BC))D.
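Working through the (A(BC))D order step by step, with the shape of each intermediate result in parentheses:

```latex
\begin{aligned}
BC &: 16 \cdot 12 \cdot 6 = 1152 && (16 \times 6)\\
A(BC) &: 10 \cdot 16 \cdot 6 = 960 && (10 \times 6)\\
(A(BC))D &: 10 \cdot 6 \cdot 14 = 840 && (10 \times 14)\\
\text{total} &: 1152 + 960 + 840 = 2952
\end{aligned}
```

For comparison, the left-to-right order ((AB)C)D costs 1920 + 720 + 840 = 3480.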

Algorithm and Pseudocode

To solve the matrix chain multiplication problem using recursion, we'll implement a function that takes the following parameters:

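  • dims: An array of n + 1 integers, where matrix Ak has dimensions dims[k-1] x dims[k].
  • i: The index in dims where the current subchain starts; the subchain begins with matrix A(i+1), whose row count is dims[i].
  • j: The index in dims where the current subchain ends; the subchain ends with matrix Aj, whose column count is dims[j].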

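Here is the recursive function. It returns the minimum cost of multiplying matrices A(i+1) through Aj, so the whole chain is solved by calling it with i = 0 and j = n (the last index of dims):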

  
  function matrixChainMultiplication(dims[], i, j):
      // The subchain A(i+1)..Aj has j - i matrices; one matrix needs no multiplication
      if (j <= i + 1):
          return 0
      minValue = INT_MAX
      // Split as (A(i+1)..Ak)(A(k+1)..Aj) and keep the cheapest split
      for k = i + 1 to j - 1:
          cost = matrixChainMultiplication(dims, i, k)
          cost = cost + matrixChainMultiplication(dims, k, j)
          cost = cost + dims[i] * dims[k] * dims[j]
          if (cost < minValue):
              minValue = cost
      return minValue
  
  

Program Solution

/*
    C program for
    Matrix chain multiplication using recursion
*/
#include <stdio.h>
#include <limits.h>

int matrixChainMultiplication(int dims[], int i, int j)
{
	if (j <= i + 1)
	{
		return 0;
	}
	int cost = 0;
	int minValue = INT_MAX;
	for (int k = i + 1; k < j; k++)
	{
		cost = matrixChainMultiplication(dims, i, k);
		cost = cost + matrixChainMultiplication(dims, k, j);
		// Add the cost of multiplying the two partial products
		cost = cost + dims[i] *dims[k] *dims[j];
		if (cost < minValue)
		{
			// Get new minimum value
			minValue = cost;
		}
	}
	return minValue;
}
int main(int argc, char const *argv[])
{
	int dims1[] = {
		10 , 16 , 12 , 6 , 14
	};
	int n = sizeof(dims1) / sizeof(dims1[0]);
	/*

	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D

	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	printf("\n %d", matrixChainMultiplication(dims1, 0, n - 1));
	int dims2[] = {
		8 , 20 , 16 , 10 , 6
	};
	n = sizeof(dims2) / sizeof(dims2[0]);
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6

	    A(B(CD)) =  3840

	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840

	*/
	printf("\n %d", matrixChainMultiplication(dims2, 0, n - 1));
	return 0;
}

Output

 2952
 3840
/*
  Java program for
  Matrix chain multiplication using recursion
*/
public class Multiplication
{
	public int matrixChainMultiplication(int[] dims, int i, int j)
	{
		if (j <= i + 1)
		{
			return 0;
		}
		int cost = 0;
		int minValue = Integer.MAX_VALUE;
		for (int k = i + 1; k < j; k++)
		{
			cost = matrixChainMultiplication(dims, i, k);
			cost = cost + matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
		}
		return minValue;
	}
	public static void main(String[] args)
	{
		Multiplication task = new Multiplication();
		int[] dims1 = {
			10 , 16 , 12 , 6 , 14
		};
		int n = dims1.length;
		/*
		    matrix A = 10 X 16 
		    matrix B = 16 X 12
		    matrix C = 12 X 6
		    matrix D = 6 X  14
		    --------------------
		    (A(BC))D

		    (16*12*6) + (10*16*6) + (10*6*14)
		     =  2952  
		*/
		System.out.print("\n " + 
                         task.matrixChainMultiplication(dims1, 0, n - 1));
		int[] dims2 = {
			8 , 20 , 16 , 10 , 6
		};
		n = dims2.length;
		/*
		    matrix A = 8 X 20
		    matrix B = 20 X 16
		    matrix C = 16 X 10
		    matrix D = 10 X 6

		    A(B(CD)) =  3840

		    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840

		*/
		System.out.print("\n " + 
                         task.matrixChainMultiplication(dims2, 0, n - 1));
	}
}

Output

 2952
 3840
// Include header file
#include <iostream>

#include <limits.h>

using namespace std;
/*
  C++ program for
  Matrix chain multiplication using recursion
*/
class Multiplication
{
	public: int matrixChainMultiplication(int dims[], int i, int j)
	{
		if (j <= i + 1)
		{
			return 0;
		}
		int cost = 0;
		int minValue = INT_MAX;
		for (int k = i + 1; k < j; k++)
		{
			cost = this->matrixChainMultiplication(dims, i, k);
			cost = cost + this->matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] *dims[k] *dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
		}
		return minValue;
	}
};
int main()
{
	Multiplication *task = new Multiplication();
	int dims1[] = {
		10 , 16 , 12 , 6 , 14
	};
	int n = sizeof(dims1) / sizeof(dims1[0]);
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	cout << "\n " << task->matrixChainMultiplication(dims1, 0, n - 1);
	int dims2[] = {
		8 , 20 , 16 , 10 , 6
	};
	n = sizeof(dims2) / sizeof(dims2[0]);
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	cout << "\n " << task->matrixChainMultiplication(dims2, 0, n - 1);
	return 0;
}

Output

 2952
 3840
// Include namespace system
using System;
/*
  C# program for
  Matrix chain multiplication using recursion
*/
public class Multiplication
{
	public int matrixChainMultiplication(int[] dims, int i, int j)
	{
		if (j <= i + 1)
		{
			return 0;
		}
		int cost = 0;
		int minValue = int.MaxValue;
		for (int k = i + 1; k < j; k++)
		{
			cost = this.matrixChainMultiplication(dims, i, k);
			cost = cost + this.matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
		}
		return minValue;
	}
	public static void Main(String[] args)
	{
		Multiplication task = new Multiplication();
		int[] dims1 = {
			10 , 16 , 12 , 6 , 14
		};
		int n = dims1.Length;
		/*
		    matrix A = 10 X 16 
		    matrix B = 16 X 12
		    matrix C = 12 X 6
		    matrix D = 6 X  14
		    --------------------
		    (A(BC))D
		    (16*12*6) + (10*16*6) + (10*6*14)
		     =  2952  
		*/
		Console.Write("\n " + 
                      task.matrixChainMultiplication(dims1, 0, n - 1));
		int[] dims2 = {
			8 , 20 , 16 , 10 , 6
		};
		n = dims2.Length;
		/*
		    matrix A = 8 X 20
		    matrix B = 20 X 16
		    matrix C = 16 X 10
		    matrix D = 10 X 6
		    A(B(CD)) =  3840
		    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
		*/
		Console.Write("\n " + 
                      task.matrixChainMultiplication(dims2, 0, n - 1));
	}
}

Output

 2952
 3840
package main
import "math"
import "fmt"
/*
  Go program for
  Matrix chain multiplication using recursion
*/

func matrixChainMultiplication(dims []int, i int, j int) int {
	if j <= i+1 {
		return 0
	}
	var cost int = 0
	// math.MaxInt (Go 1.17+) matches the platform's int size
	var minValue int = math.MaxInt
	for k := i + 1; k < j; k++ {
		cost = matrixChainMultiplication(dims, i, k)
		cost = cost + matrixChainMultiplication(dims, k, j)
		// Add the cost of multiplying the two partial products
		cost = cost + dims[i]*dims[k]*dims[j]
		if cost < minValue {
			// Get new minimum value
			minValue = cost
		}
	}
	return minValue
}
func main() {

	var dims1 = []int{10, 16, 12, 6, 14}
	var n int = len(dims1)
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	fmt.Print("\n ", 
		matrixChainMultiplication(dims1, 0, n - 1))
	var dims2 = []int{8, 20, 16, 10, 6}
	n = len(dims2)
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	fmt.Print("\n ", 
		matrixChainMultiplication(dims2, 0, n - 1))
}

Output

 2952
 3840
<?php
/*
  PHP program for
  Matrix chain multiplication using recursion
*/
class Multiplication
{
	public	function matrixChainMultiplication($dims, $i, $j)
	{
		if ($j <= $i + 1)
		{
			return 0;
		}
		$cost = 0;
		$minValue = PHP_INT_MAX;
		for ($k = $i + 1; $k < $j; $k++)
		{
			$cost = $this->matrixChainMultiplication($dims, $i, $k);
			$cost = $cost + $this->matrixChainMultiplication($dims, $k, $j);
			// Add the cost of multiplying the two partial products
			$cost = $cost + $dims[$i] * $dims[$k] * $dims[$j];
			if ($cost < $minValue)
			{
				// Get new minimum value
				$minValue = $cost;
			}
		}
		return $minValue;
	}
}

function main()
{
	$task = new Multiplication();
	$dims1 = array(10, 16, 12, 6, 14);
	$n = count($dims1);
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	echo("\n ".$task->matrixChainMultiplication($dims1, 0, $n - 1));
	$dims2 = array(8, 20, 16, 10, 6);
	$n = count($dims2);
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	echo("\n ".$task->matrixChainMultiplication($dims2, 0, $n - 1));
}
main();

Output

 2952
 3840
/*
  Node JS program for
  Matrix chain multiplication using recursion
*/
class Multiplication
{
	matrixChainMultiplication(dims, i, j)
	{
		if (j <= i + 1)
		{
			return 0;
		}
		var cost = 0;
		var minValue = Number.MAX_VALUE;
		for (var k = i + 1; k < j; k++)
		{
			cost = this.matrixChainMultiplication(dims, i, k);
			cost = cost + this.matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
		}
		return minValue;
	}
}

function main()
{
	var task = new Multiplication();
	var dims1 = [10, 16, 12, 6, 14];
	var n = dims1.length;
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	console.log(task.matrixChainMultiplication(dims1, 0, n - 1));
	var dims2 = [8, 20, 16, 10, 6];
	n = dims2.length;
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	console.log(task.matrixChainMultiplication(dims2, 0, n - 1));
}
main();

Output

2952
3840
import sys
#  Python 3 program for
#  Matrix chain multiplication using recursion
class Multiplication :
	def matrixChainMultiplication(self, dims, i, j) :
		if (j <= i + 1) :
			return 0
		
		cost = 0
		minValue = sys.maxsize
		k = i + 1
		while (k < j) :
			cost = self.matrixChainMultiplication(dims, i, k)
			cost = cost + self.matrixChainMultiplication(dims, k, j)
			#  Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j]
			if (cost < minValue) :
				#  Get new minimum value
				minValue = cost
			
			k += 1
		
		return minValue
	

def main() :
	task = Multiplication()
	dims1 = [10, 16, 12, 6, 14]
	n = len(dims1)
	#    matrix A = 10 X 16 
	#    matrix B = 16 X 12
	#    matrix C = 12 X 6
	#    matrix D = 6 X  14
	#    --------------------
	#    (A(BC))D
	#    (16*12*6) + (10*16*6) + (10*6*14)
	#     =  2952  
	print( task.matrixChainMultiplication(dims1, 0, n - 1))
	dims2 = [8, 20, 16, 10, 6]
	n = len(dims2)
	#    matrix A = 8 X 20
	#    matrix B = 20 X 16
	#    matrix C = 16 X 10
	#    matrix D = 10 X 6
	#    A(B(CD)) =  3840
	#    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	print(task.matrixChainMultiplication(dims2, 0, n - 1))

if __name__ == "__main__": main()

Output

2952
3840
#  Ruby program for
#  Matrix chain multiplication using recursion
class Multiplication 
	def matrixChainMultiplication(dims, i, j) 
		if (j <= i + 1) 
			return 0
		end

		cost = 0
		minValue = Float::INFINITY # larger than any possible cost
		k = i + 1
		while (k < j) 
			cost = self.matrixChainMultiplication(dims, i, k)
			cost = cost + self.matrixChainMultiplication(dims, k, j)
			#  Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j]
			if (cost < minValue) 
				#  Get new minimum value
				minValue = cost
			end

			k += 1
		end

		return minValue
	end

end

def main() 
	task = Multiplication.new()
	dims1 = [10, 16, 12, 6, 14]
	n = dims1.length
	#    matrix A = 10 X 16 
	#    matrix B = 16 X 12
	#    matrix C = 12 X 6
	#    matrix D = 6 X  14
	#    --------------------
	#    (A(BC))D
	#    (16*12*6) + (10*16*6) + (10*6*14)
	#     =  2952  
	print("\n ", task.matrixChainMultiplication(dims1, 0, n - 1))
	dims2 = [8, 20, 16, 10, 6]
	n = dims2.length
	#    matrix A = 8 X 20
	#    matrix B = 20 X 16
	#    matrix C = 16 X 10
	#    matrix D = 10 X 6
	#    A(B(CD)) =  3840
	#    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	print("\n ", task.matrixChainMultiplication(dims2, 0, n - 1))
end

main()

Output

 2952
 3840
/*
  Scala program for
  Matrix chain multiplication using recursion
*/
class Multiplication()
{
	def matrixChainMultiplication(
      dims: Array[Int], 
      i: Int, j: Int): Int = {
		if (j <= i + 1)
		{
			return 0;
		}
		var cost: Int = 0;
		var minValue: Int = Int.MaxValue;
		var k: Int = i + 1;
		while (k < j)
		{
			cost = matrixChainMultiplication(dims, i, k);
			cost = cost + matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims(i) * dims(k) * dims(j);
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
			k += 1;
		}
		return minValue;
	}
}
object Main
{
	def main(args: Array[String]): Unit = {
		var task: Multiplication = new Multiplication();
		var dims1: Array[Int] = Array(10, 16, 12, 6, 14);
		var n: Int = dims1.length;
		/*
		    matrix A = 10 X 16 
		    matrix B = 16 X 12
		    matrix C = 12 X 6
		    matrix D = 6 X  14
		    --------------------
		    (A(BC))D
		    (16*12*6) + (10*16*6) + (10*6*14)
		     =  2952  
		*/
		print("\n " + task.matrixChainMultiplication(dims1, 0, n - 1));
		var dims2: Array[Int] = Array(8, 20, 16, 10, 6);
		n = dims2.length;
		/*
		    matrix A = 8 X 20
		    matrix B = 20 X 16
		    matrix C = 16 X 10
		    matrix D = 10 X 6
		    A(B(CD)) =  3840
		    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
		*/
		print("\n " + task.matrixChainMultiplication(dims2, 0, n - 1));
	}
}

Output

 2952
 3840
import Foundation;
/*
  Swift 4 program for
  Matrix chain multiplication using recursion
*/
class Multiplication
{
	func matrixChainMultiplication(_ dims: [Int], _ i: Int, _ j: Int) -> Int
	{
		if (j <= i + 1)
		{
			return 0;
		}
		var cost: Int = 0;
		var minValue: Int = Int.max;
		var k: Int = i + 1;
		while (k < j)
		{
			cost = self.matrixChainMultiplication(dims, i, k);
			cost = cost + self.matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
			k += 1;
		}
		return minValue;
	}
}
func main()
{
	let task: Multiplication = Multiplication();
	let dims1: [Int] = [10, 16, 12, 6, 14];
	var n: Int = dims1.count;
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	print("\n ", 
          task.matrixChainMultiplication(dims1, 0, n - 1), 
          terminator: "");
	let dims2: [Int] = [8, 20, 16, 10, 6];
	n = dims2.count;
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	print("\n ", 
          task.matrixChainMultiplication(dims2, 0, n - 1), 
          terminator: "");
}
main();

Output

  2952
  3840
/*
  Kotlin program for
  Matrix chain multiplication using recursion
*/
class Multiplication
{
	fun matrixChainMultiplication(dims: Array < Int > , 
                                   i: Int, j: Int): Int
	{
		if (j <= i + 1)
		{
			return 0;
		}
		var cost: Int;
		var minValue: Int = Int.MAX_VALUE;
		var k: Int = i + 1;
		while (k < j)
		{
			cost = this.matrixChainMultiplication(dims, i, k);
			cost = cost + this.matrixChainMultiplication(dims, k, j);
			// Add the cost of multiplying the two partial products
			cost = cost + dims[i] * dims[k] * dims[j];
			if (cost < minValue)
			{
				// Get new minimum value
				minValue = cost;
			}
			k += 1;
		}
		return minValue;
	}
}
fun main(args: Array < String > ): Unit
{
	val task: Multiplication = Multiplication();
	val dims1: Array < Int > = arrayOf(10, 16, 12, 6, 14);
	var n: Int = dims1.count();
	/*
	    matrix A = 10 X 16 
	    matrix B = 16 X 12
	    matrix C = 12 X 6
	    matrix D = 6 X  14
	    --------------------
	    (A(BC))D
	    (16*12*6) + (10*16*6) + (10*6*14)
	     =  2952  
	*/
	print("\n " + task.matrixChainMultiplication(dims1, 0, n - 1));
	val dims2: Array < Int > = arrayOf(8, 20, 16, 10, 6);
	n = dims2.count();
	/*
	    matrix A = 8 X 20
	    matrix B = 20 X 16
	    matrix C = 16 X 10
	    matrix D = 10 X 6
	    A(B(CD)) =  3840
	    (16 X 10 X 6) + (20 X 16 X 6 ) + ( 8 X 20 X 6 ) = 3840
	*/
	print("\n " + task.matrixChainMultiplication(dims2, 0, n - 1));
}

Output

 2952
 3840

Explanation

The recursive function first checks whether the current subchain contains at most one matrix (j <= i + 1). A single matrix needs no multiplication, so the function returns 0. Otherwise, it initializes the minimum cost (minValue) to INT_MAX, or the language's equivalent largest value.

Next, the function tries every split point k between i and j. For each k, it recursively computes the minimum cost of the left subchain (i to k) and the right subchain (k to j), then adds dims[i] * dims[k] * dims[j], the cost of multiplying the resulting dims[i] x dims[k] and dims[k] x dims[j] matrices.

The function keeps track of the minimum cost encountered and returns it as the result.
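The split loop can be seen directly by evaluating each top-level split of the first example separately. In this Python sketch, mcm mirrors the recursive function above, and top_level_splits is a hypothetical helper added only for illustration. Here k is the index in dims after which the chain is cut, so k = 1, 2, 3 correspond to A | BCD, AB | CD and ABC | D.

```python
def mcm(dims, i, j):
    """Python copy of the recursive function: minimum cost for matrices i+1..j."""
    if j <= i + 1:  # at most one matrix: nothing to multiply
        return 0
    return min(mcm(dims, i, k) + mcm(dims, k, j) + dims[i] * dims[k] * dims[j]
               for k in range(i + 1, j))


def top_level_splits(dims):
    """Best total cost for each choice of the last multiplication (split index k)."""
    i, j = 0, len(dims) - 1
    return {k: mcm(dims, i, k) + mcm(dims, k, j) + dims[i] * dims[k] * dims[j]
            for k in range(i + 1, j)}


print(top_level_splits([10, 16, 12, 6, 14]))  # {1: 4736, 2: 4608, 3: 2952}
```

The function returns the smallest of these, 2952, which comes from cutting just before D, i.e. the (A(BC))D order.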

Output Explanation

For the given matrices (A, B, C, D), the recursive function finds the minimum number of scalar multiplications. For the first example the result is 2952, achieved by (A(BC))D; for the second example it is 3840, achieved by A(B(CD)).
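The second example (A = 8 x 20, B = 20 x 16, C = 16 x 10, D = 10 x 6) works out as follows for the order A(B(CD)), with intermediate shapes in parentheses:

```latex
\begin{aligned}
CD &: 16 \cdot 10 \cdot 6 = 960 && (16 \times 6)\\
B(CD) &: 20 \cdot 16 \cdot 6 = 1920 && (20 \times 6)\\
A(B(CD)) &: 8 \cdot 20 \cdot 6 = 960 && (8 \times 6)\\
\text{total} &: 960 + 1920 + 960 = 3840
\end{aligned}
```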

Time Complexity

The recursive solution is slow because of overlapping subproblems: it solves the same subchains again and again. Its running time is exponential in n, the number of matrices in the chain. In fact, for a chain of n matrices the function makes exactly 3^(n-1) calls, so its running time is Θ(3^n); the Ω(2^n) bound often quoted for this recursion is a lower bound, not an upper bound.
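One way to see the exponential growth is to count calls. The sketch below is a Python copy of the recursive function with an added counter (count_calls is an illustrative name, not from the programs above). For a chain of n matrices the count C(n) satisfies C(1) = 1 and C(n) = 1 + 2(C(1) + ... + C(n-1)), which gives C(n) = 3C(n-1) and therefore C(n) = 3^(n-1). The count depends only on the number of matrices, so the dims values are placeholders.

```python
def count_calls(dims):
    """Run the plain recursion on dims and return how many calls it made."""
    calls = 0

    def mcm(i, j):
        nonlocal calls
        calls += 1
        if j <= i + 1:  # at most one matrix
            return 0
        best = float("inf")
        for k in range(i + 1, j):
            cost = mcm(i, k) + mcm(k, j) + dims[i] * dims[k] * dims[j]
            best = min(best, cost)
        return best

    mcm(0, len(dims) - 1)
    return calls


# n matrices need n + 1 dimensions; the values themselves do not affect the count
for n in range(1, 8):
    print(n, count_calls([2] * (n + 1)), 3 ** (n - 1))
```

The last two columns agree for every n, and the count triples with each extra matrix.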

Naive Approach

A naive approach is to consider every possible parenthesization, calculate the number of scalar multiplications for each, and choose the cheapest; the recursive function above is effectively this exhaustive search. The number of parenthesizations of a chain of n matrices is the Catalan number C(n-1), which is 5 for four matrices but grows exponentially with n, so this approach is impractical for long chains.
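The exhaustive search can be written out explicitly. This Python generator (parenthesizations is a hypothetical name) lists every way to parenthesize matrices i+1..j together with its cost, using the same dims convention as the recursive function. For the four matrices of the first example it produces five orders.

```python
def parenthesizations(dims, i, j, names):
    """Yield (expression, cost) for every way to multiply matrices i+1..j.

    Matrix t (1-based) has shape dims[t-1] x dims[t]; names[t-1] is its label.
    """
    if j == i + 1:  # a single matrix: no multiplication needed
        yield names[i], 0
        return
    for k in range(i + 1, j):  # the last multiplication is (i+1..k) times (k+1..j)
        for left, left_cost in parenthesizations(dims, i, k, names):
            for right, right_cost in parenthesizations(dims, k, j, names):
                cost = left_cost + right_cost + dims[i] * dims[k] * dims[j]
                yield f"({left}{right})", cost


dims = [10, 16, 12, 6, 14]
options = list(parenthesizations(dims, 0, len(dims) - 1, "ABCD"))
for expr, cost in options:
    print(expr, cost)
print("best:", min(options, key=lambda option: option[1]))  # best: ('((A(BC))D)', 2952)
```

Taking the minimum of this list gives the same answer as the recursive function; the recursion simply avoids building the expression strings.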

Dynamic Programming Approach

The dynamic programming technique is used to solve this problem efficiently by avoiding redundant calculations. We can build a table to store the minimum number of scalar multiplications needed for the subproblems. The table is filled in a bottom-up manner, starting from smaller subproblems to larger ones. This way, we avoid recalculating the same subproblems multiple times.
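A small step between the plain recursion and the bottom-up table is memoization: keep the same recursive function but cache the result for each (i, j) pair. This is top-down dynamic programming, a swapped-in technique sketched here in Python with functools.lru_cache; matrix_chain_memo is an illustrative name. There are only O(n^2) distinct (i, j) pairs, each solved once with an O(n) loop, so the running time drops to O(n^3).

```python
from functools import lru_cache


def matrix_chain_memo(dims):
    """Minimum scalar multiplications for the chain described by dims (top-down DP)."""

    @lru_cache(maxsize=None)  # cache one result per (i, j) subchain
    def best(i, j):
        if j <= i + 1:  # at most one matrix: nothing to multiply
            return 0
        return min(best(i, k) + best(k, j) + dims[i] * dims[k] * dims[j]
                   for k in range(i + 1, j))

    return best(0, len(dims) - 1)


print(matrix_chain_memo([10, 16, 12, 6, 14]))  # 2952
print(matrix_chain_memo([8, 20, 16, 10, 6]))   # 3840
```

The bottom-up table described next computes the same values without recursion, filling shorter chains before longer ones.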

Algorithm

  1. Create a table m, where m[i][j] will hold the minimum number of scalar multiplications needed to compute Ai...Aj (n is the number of matrices, so dims has n + 1 entries).
  2. Initialize the diagonal elements m[i][i] to 0, since a single matrix requires no multiplication.
  3. For each chain length l from 2 to n:
     a. For each i from 1 to n - l + 1:
       • Set j = i + l - 1.
       • Initialize m[i][j] to a very large value (representing positive infinity).
       • For k from i to j - 1, calculate cost = m[i][k] + m[k+1][j] + dims[i-1] * dims[k] * dims[j].
       • If cost is less than m[i][j], update m[i][j] with cost.
  4. The final result is m[1][n], the minimum number of scalar multiplications for the entire chain.

Pseudocode

function matrixChainMultiplication(dims[], n):   // n = number of matrices; dims has n + 1 entries
    Create a 2D table m[1..n][1..n]
    for i from 1 to n:
        m[i][i] = 0

    for l from 2 to n:
        for i from 1 to n-l+1:
            j = i + l - 1
            m[i][j] = INFINITY
            for k from i to j-1:
                cost = m[i][k] + m[k+1][j] + dims[i-1] * dims[k] * dims[j]
                if cost < m[i][j]:
                    m[i][j] = cost

    return m[1][n]

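Here is the C code. Note that its parameter n is the length of dims, which is one more than the number of matrices, so the loop bounds are written as i < n rather than i <= n: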

#include <stdio.h>
#include <limits.h>

int matrixChainMultiplication(int dims[], int n)
{
    // Create a 2D table to store the minimum number of scalar multiplications
    int m[n][n];
    
    // Initialize the diagonal elements to 0 since multiplying a single matrix requires no multiplication
    for (int i = 1; i < n; i++)
        m[i][i] = 0;
    
    // Compute the minimum number of scalar multiplications for chains of length l
    for (int l = 2; l < n; l++)
    {
        for (int i = 1; i < n - l + 1; i++)
        {
            int j = i + l - 1;
            m[i][j] = INT_MAX; // Initialize to a very large value (positive infinity)
            
            // Find the minimum number of scalar multiplications for chain (i to j)
            for (int k = i; k <= j - 1; k++)
            {
                int cost = m[i][k] + m[k + 1][j] + dims[i - 1] * dims[k] * dims[j];
                if (cost < m[i][j])
                    m[i][j] = cost;
            }
        }
    }
    
    return m[1][n - 1]; // Final result is stored in m[1][n-1]
}

int main()
{
    int dims1[] = {10, 16, 12, 6, 14};
    int n1 = sizeof(dims1) / sizeof(dims1[0]);
    
    printf("Minimum number of scalar multiplications for sequence 1: %d\n", matrixChainMultiplication(dims1, n1));
    
    int dims2[] = {8, 20, 16, 10, 6};
    int n2 = sizeof(dims2) / sizeof(dims2[0]);
    
    printf("Minimum number of scalar multiplications for sequence 2: %d\n", matrixChainMultiplication(dims2, n2));
    
    return 0;
}
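The C program returns only the minimum cost. A common extension, sketched here in Python under the same 1-based indexing as the pseudocode, stores the best split point in a second table s and rebuilds the parenthesization from it. The names matrix_chain_order, parenthesize and s are illustrative, not taken from the program above.

```python
def matrix_chain_order(dims):
    """Bottom-up DP returning the cost table m and the split table s.

    There are n = len(dims) - 1 matrices; matrix t (1-based) is dims[t-1] x dims[t].
    m[i][j] is the minimum cost for matrices i..j; s[i][j] is the k that achieves it.
    """
    n = len(dims) - 1
    m = [[0] * (n + 1) for _ in range(n + 1)]
    s = [[0] * (n + 1) for _ in range(n + 1)]
    for length in range(2, n + 1):
        for i in range(1, n - length + 2):
            j = i + length - 1
            m[i][j] = float("inf")
            for k in range(i, j):
                cost = m[i][k] + m[k + 1][j] + dims[i - 1] * dims[k] * dims[j]
                if cost < m[i][j]:
                    m[i][j] = cost
                    s[i][j] = k  # remember where the last multiplication splits
    return m, s


def parenthesize(s, i, j, names):
    """Rebuild the optimal order for matrices i..j from the split table."""
    if i == j:
        return names[i - 1]
    k = s[i][j]
    return "(" + parenthesize(s, i, k, names) + parenthesize(s, k + 1, j, names) + ")"


for dims in ([10, 16, 12, 6, 14], [8, 20, 16, 10, 6]):
    m, s = matrix_chain_order(dims)
    n = len(dims) - 1
    print(m[1][n], parenthesize(s, 1, n, "ABCD"))
# 2952 ((A(BC))D)
# 3840 (A(B(CD)))
```

Recording s costs one extra table of the same size and no extra time, since the split is saved at the moment m[i][j] improves.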

Time Complexity

The time complexity of the dynamic programming approach to the matrix chain multiplication problem is O(n^3): the table has about n^2/2 cells to fill, and each cell takes O(n) time to try every split point k. The table itself uses O(n^2) space.
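The O(n^3) bound can be made exact by counting how many times the innermost loop body runs: for chain length l there are n - l + 1 starting positions i and l - 1 split points k. Substituting m = l - 1:

```latex
\sum_{l=2}^{n} (n-l+1)(l-1)
  = \sum_{m=1}^{n-1} m\,(n-m)
  = \frac{(n-1)\,n\,(n+1)}{6}
  = \frac{n^{3}-n}{6}
  = \Theta(n^{3})
```

For the four-matrix examples (n = 4) this is 3 + 4 + 3 = 10 evaluations of the cost formula.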

Resultant Output Explanation

Using the dynamic programming approach, the provided code calculates the minimum number of scalar multiplications for two sequences of matrices: [10, 16, 12, 6, 14] and [8, 20, 16, 10, 6].

For the first sequence, the result is 2952. This corresponds to the optimal parenthesization ((A * (B * C)) * D), where:

  • A: 10 x 16
  • B: 16 x 12
  • C: 12 x 6
  • D: 6 x 14

For the second sequence, the result is 3840. This corresponds to the optimal parenthesization (A * (B * (C * D))), where:

  • A: 8 x 20
  • B: 20 x 16
  • C: 16 x 10
  • D: 10 x 6

In both cases, the provided code correctly calculates the minimum number of scalar multiplications needed to compute the product of the given matrices.

The dynamic programming approach makes it efficient to solve this problem even for larger sequences of matrices, as it avoids redundant calculations and optimally solves the subproblems.

Conclusion

Matrix Chain Multiplication is an essential optimization problem with practical applications. The dynamic programming approach provides an efficient solution by breaking down the problem into smaller subproblems and avoiding redundant calculations. The provided code demonstrates this approach and correctly finds the minimum number of scalar multiplications needed for two sequences of matrices. With a time complexity of O(n^3), the dynamic programming technique can handle even larger sequences efficiently, making it a valuable algorithm in various real-world scenarios.




