JUINTINATION
백준 1967번: 트리의 지름 본문
문제
https://www.acmicpc.net/problem/1967
풀이
입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이인 트리의 지름을 구하는 문제입니다.
접근
트리의 간선들에 가중치가 있기 때문에 단순히 가장 왼쪽에 있는 노드와 가장 오른쪽에 있는 노드 사이의 거리를 구하는 문제가 아닙니다. 처음에 이 문제를 보고 루트 노드에서 가장 먼 노드와 두 번째로 먼 노드 사이의 거리를 구해볼까 했는데 이는 두 노드가 같은 경로에 있을 가능성이 있습니다. 그래서 루트 노드에서 가장 먼 노드를 찾고 그 노드에서부터 가장 먼 노드 사이의 거리를 구하는 방식으로 문제를 해결했습니다. 각 노드에서 가장 먼 노드를 찾기 위해 dfs 알고리즘을 사용하였습니다.
코드
자바
import java.io.IOException;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
class node {
int data, cost;
public node(int data, int cost) {
this.data = data;
this.cost = cost;
}
}
public class Main {
static int tmp = 1, max = 0;
static boolean[] visited;
static ArrayList<node>[] list;
public static void dfs(int idx, int len) {
if (len > max) { // (2)
max = len;
tmp = idx;
}
visited[idx] = true;
if (list[idx] != null) { // (1)
for (node e : list[idx]) {
if (!visited[e.data]) {
dfs(e.data, len + e.cost);
}
}
}
}
@SuppressWarnings("unchecked")
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine());
list = new ArrayList[n + 1];
visited = new boolean[n + 1];
for (int i = 1; i <= n; i++) {
list[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine(), " ");
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
list[parent].add(new node(child, cost));
list[child].add(new node(parent, cost));
}
dfs(1, 0);
Arrays.fill(visited, false);
dfs(tmp, 0); // (3)
System.out.println(max);
}
}
ArrayList[] 배열 list는 임의의 수 x번 노드에 연결된 다른 노드들의 번호와 해당 노드와의 거리를 저장해야 합니다. 그래서 트리의 노드 클래스 node는 list[x]와 연결된 노드의 번호 data와 해당 노드와의 거리 cost를 포함합니다.
모든 입력이 끝난 후 dfs(1, 0)을 실행하여 루트 노드인 1번 노드와 가장 먼 노드를 찾습니다. dfs(int idx, int len)에서 idx는 현재 탐색중인 노드의 번호를, len은 dfs 메소드를 실행했을 때 시작했던 노드와의 거리를 의미합니다.
(1) idx번 노드의 방문 여부를 true로 초기화한 뒤에 idx번 노드와 연결된 다른 노드를 확인합니다.
list[idx]가 null이 아니라면(idx번 노드에 연결된 노드가 1개라도 있을 경우) foreach문을 통해 list[idx]를 탐색합니다. 탐색중인 노드 e를 방문하지 않았다면(!visited[e.data]) dfs(e.data, len + e.cost)를 실행하여 다음 노드를 탐색합니다.
(2) 현재 탐색중인 노드까지의 길이가 max보다 큰지 확인합니다.
len > max라면 max를 len으로 초기화하고 루트 노드인 1번 노드와 가장 먼 노드의 번호를 의미하는 tmp를 idx로 초기화합니다.
(3) 위의 과정이 끝났다면 visited 배열을 모두 false로 초기화한 후 dfs(tmp, 0)을 실행합니다.
(1)번 과정과 (2)번 과정을 반복하면서 tmp번 노드와 가장 먼 노드를 탐색합니다.
모든 과정을 종료한 후에 max를 출력합니다.
C언어
#include <stdio.h>
#include <stdlib.h>
int n, tmp = 1, max = 0, *visited;
typedef struct node {
int data, cost;
struct node* next;
} node;
void add(node* target, int data, int cost) {
node* now = (node*)malloc(sizeof(node));
now->data = data;
now->cost = cost;
now->next = target->next;
target->next = now;
return;
}
node* list[10001];
void dfs(int idx, int len) {
if (len > max) {
max = len;
tmp = idx;
}
visited[idx] = 1;
node* curr = list[idx]->next;
while (curr != NULL) {
if (!visited[curr->data]) {
dfs(curr->data, len + curr->cost);
}
curr = curr->next;
}
}
main() {
scanf("%d", &n);
visited = (int*)malloc(sizeof(int) * (n + 1));
for (int i = 1; i <= n; i++) {
list[i] = (node*)malloc(sizeof(node));
list[i]->next = NULL;
visited[i] = 0;
}
for (int i = 0; i < n - 1; i++) {
int idx, data, cost;
scanf("%d %d %d", &idx, &data, &cost);
add(list[idx], data, cost);
add(list[data], idx, cost);
}
dfs(1, 0);
for (int i = 1; i <= n; i++) {
visited[i] = 0;
}
dfs(tmp, 0);
printf("%d", max);
}
자바를 이용한 풀이와 동일합니다. ArrayList 배열은 연결리스트를 이용하여 구현했습니다.
결론
이 문제를 풀고 정리하면서 다시 생각해보니 위에 (1)번 과정에서 list[idx]가 null인 경우가 존재하지 않을 것 같다는 생각이 들었습니다. 그래서 if (list[idx] != null) 조건문을 생략하고 제출해보니 런타임 에러 (NullPointer)가 발생했습니다. 그 이유는 n이 1일 경우에 1번 노드에 연결된 다른 노드가 존재하지 않기 때문이었습니다. 이 경우밖에 없다보니 main 메소드 안에서 조건문 하나만 추가하면 해결할 수 있을 것 같았습니다.
import java.io.IOException;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
class node {
int data, cost;
public node(int data, int cost) {
this.data = data;
this.cost = cost;
}
}
public class Main {
static int tmp, max = 0;
static boolean[] visited;
static ArrayList<node>[] list;
public static void dfs(int idx, int len) {
if (len > max) {
max = len;
tmp = idx;
}
visited[idx] = true;
for (node e : list[idx]) {
if (!visited[e.data]) {
dfs(e.data, len + e.cost);
}
}
}
@SuppressWarnings("unchecked")
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine());
if (n == 1) {
System.out.println(0);
} else {
list = new ArrayList[n + 1];
visited = new boolean[n + 1];
for (int i = 1; i <= n; i++) {
list[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine(), " ");
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
list[parent].add(new node(child, cost));
list[child].add(new node(parent, cost));
}
dfs(1, 0);
Arrays.fill(visited, false);
dfs(tmp, 0);
System.out.println(max);
}
}
}
풀이는 동일하지만 (1)번 과정에서 list[idx]가 null인지 확인하지 않으며 main 메소드에서 n이 1일 경우에 0을 출력하고 프로그램을 종료합니다.
import java.io.IOException;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
class node {
int data, cost;
public node(int data, int cost) {
this.data = data;
this.cost = cost;
}
}
public class Main {
static int tmp, max = 0;
static boolean[] visited;
static ArrayList<node>[] list;
public static void dfs(int idx, int len) {
if (len > max) {
max = len;
tmp = idx;
}
visited[idx] = true;
for (node e : list[idx]) {
if (!visited[e.data]) {
dfs(e.data, len + e.cost);
}
}
}
@SuppressWarnings("unchecked")
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine());
try {
list = new ArrayList[n + 1];
visited = new boolean[n + 1];
for (int i = 1; i <= n; i++) {
list[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine(), " ");
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
list[parent].add(new node(child, cost));
list[child].add(new node(parent, cost));
}
dfs(1, 0);
Arrays.fill(visited, false);
dfs(tmp, 0);
System.out.println(max);
} catch (NullPointerException e) {
System.out.println(0);
}
}
}
이 코드는 try-catch문을 이용한 코드입니다. NullPointerException이 발생한다면 0을 출력한 후 프로그램을 종료합니다.
'백준 알고리즘 > 트리' 카테고리의 다른 글
백준 1167번: 트리의 지름 (0) | 2023.02.08 |
---|---|
백준 5639번: 이진 검색 트리 (0) | 2023.01.29 |
백준 2250번: 트리의 높이와 너비 (0) | 2023.01.24 |