트리의 지름
트리의 지름이라는 것은 1967번 문제에도 나와 있듯이 어떤 두 노드를 선택해서 당겼을 때, 가장 길게 늘어나는 경우의 두 정점 사이의 거리를 말하는 것이다.
즉, 아래 그림의 9번, 12번 노드 사이의 거리는 45로 어떤 두 노드보다 가장 긴 거리를 가지므로, 이 두 노드 사이의 거리가 주어진 트리의 지름이다.
따라서 우리가 이 문제에서 구해야 하는 것은 두 노드 사이의 최대 길이이다.
백준 1967
두 노드 사이의 최대 거리를 구하기 위해서 두 번의 DFS를 수행한다.
- 루트노드(1번 노드)에서 모든 노드까지의 거리 계산
- 1에서 가장 먼 거리를 가지는 노드(max_node)에서 모든 노드까지의 거리 계산
DFS 문제에서 방문 배열로 사용되는 visited를 사용해 거리를 계산할 수 있도록 했다.
max_node를 구하기 위해 visited에서 최댓값을 찾아 index()를 사용했고, 그렇게 찾은 노드 번호로 다시 한번 DFS를 수행하여 트리의 지름을 구할 수 있다.
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n = int(input()) # 정점의 개수
# 그래프 정보
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
parent, child, weight = map(int, input().split()) # 부모, 자식, 가중치
graph[parent].append((child, weight))
graph[child].append((parent, weight))
# DFS 함수
def DFS(x, distance):
for i, w in graph[x]:
# 아직 방문하지 않은 노드이면 현재까지의 거리 + 해당 노드까지의 가중치로 방문 배열 값을 변경
if visited[i] == -1:
visited[i] = distance + w
DFS(i, distance + w)
# 루트 노드에서 각 정점까지의 거리 계산
visited = [-1] * (n+1)
visited[1] = 0 # 루트 노드 거리는 0으로 초기화
DFS(1, 0)
max_distance = max(visited) # 최장 거리
max_node = visited.index(max_distance) # 해당 노드
# max_node에서 시작해 각 정점까지의 거리 계산
visited = [-1] * (n+1)
visited[max_node] = 0
DFS(max_node, 0)
print(max(visited)) # 최장 거리(=트리의 지름) 출력
백준 1167
1967번과 입력 형식만 다르고 풀이 방법은 똑같다.
DFS 코드
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n = int(input()) # 정점의 개수
# 그래프 정보
graph = [[] for _ in range(n+1)]
for _ in range(n):
tree = list(map(int, input().split()))
for i in range(1, len(tree)//2):
graph[tree[0]].append((tree[2*i-1], tree[2*i]))
# 방문 배열
visited = [-1] * (n+1)
visited[1] = 0 # 루트 노드 거리는 0으로 초기화
# DFS 함수
def DFS(x, distance):
for i, w in graph[x]:
# 아직 방문하지 않은 노드이면 현재까지의 거리 + 해당 노드까지의 가중치로 방문 배열 값을 변경
if visited[i] == -1:
visited[i] = distance + w
DFS(i, distance + w)
DFS(1, 0) # 루트 노드(1)에서의 각 정점까지의 거리 계산
max_distance = max(visited) # 최장 거리
max_node = visited.index(max_distance) # 해당 노드
# max_node에서 시작해 각 정점까지의 거리 계산
visited = [-1] * (n+1)
visited[max_node] = 0
DFS(max_node, 0)
print(max(visited))
BFS 코드
import sys
from collections import deque
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n = int(input()) # 정점의 개수
# 그래프 정보
graph = [[] for _ in range(n+1)]
for _ in range(n):
tree = list(map(int, input().split()))
for i in range(1, len(tree)//2):
graph[tree[0]].append((tree[2*i-1], tree[2*i]))
# BFS 함수
def BFS(x):
queue = deque([x])
visited[x] = True
while queue:
now = queue.popleft()
for i, w in graph[now]:
if not visited[i]:
queue.append((i))
visited[i] = True
distance[i] = distance[now] + w
# 루트 노드(1)에서의 각 정점까지의 거리 계산
visited = [False] * (n+1)
distance = [0] * (n+1)
BFS(1)
max_distance = max(distance) # 최장 거리
max_node = distance.index(max_distance) # 해당 노드
# max_node에서 시작해 각 정점까지의 거리 계산
visited = [False] * (n+1)
distance = [0] * (n+1)
BFS(max_node)
print(max(distance))
반응형