# Viewer header captured with this paste: 255 lines, 8.0 KiB, Python.
import csv
from collections import deque


class TreeNode:
    """A node of the spanning tree that SourceCollector builds over open cells."""

    def __init__(self):
        self.fa = None          # parent node; None for the root
        self.children = []      # child nodes, in the order they were discovered
        self.pos = None         # (row, col) of the cell this node stands for
        self.final_pos = None   # cell where the planned walk of this subtree ends
        self.val = 0            # cell value: +N for 'g<N>', -N for 't<N>', else 0
        self.id = 0             # discovery index; the root uses 0
        self.dp = 0             # value collected by the planned walk of this subtree
        # Planned walk through this subtree as a list of nodes. SourceCollector.dfs
        # assigns it; declaring it here means the attribute always exists.
        self.path = []


class SourceCollector:
    """Plan a walk through a grid maze that collects positive-valued cells.

    Cells are strings: ``'1'`` is a wall, ``'s'`` the start, ``'e'`` the end,
    ``'g<N>'`` is worth ``+N`` and ``'t<N>'`` is worth ``-N``; every other cell
    is passable with value 0. Positions are ``(row, col)`` tuples.
    """

    # Neighbour order used by the searches: up, left, down, right.
    _STEPS = ((-1, 0), (0, -1), (1, 0), (0, 1))

    def __init__(self, filename=None, maze=None):
        """Load the maze and locate its start and end cells.

        Args:
            filename: Optional CSV path. When truthy it is read instead of
                *maze*; cells starting with ``'b'`` or ``'l'`` become ``'0'``
                and blank lines are skipped.
            maze: List of rows of cell strings, used as-is (not copied) when
                *filename* is falsy.
        """
        self.filename = filename
        self.maze = maze
        self.start_pos = None   # (row, col) of the last 's' in row-major order
        self.end_pos = None     # (row, col) of the last 'e' in row-major order
        self.path = []          # walk produced by run(), list of (row, col)
        self.node_path = []     # kept for compatibility; unused in this class

        if self.filename:
            self.maze = self._load_csv(self.filename)
        else:
            self.maze = maze

        self.rowNums = len(self.maze)
        # Width comes from the first row; an empty maze has width 0.
        self.colNums = len(self.maze[0]) if self.maze else 0

        for i in range(self.rowNums):
            for j in range(self.colNums):
                if self.maze[i][j] == 's':
                    self.start_pos = (i, j)
                if self.maze[i][j] == 'e':
                    self.end_pos = (i, j)

    @staticmethod
    def _load_csv(filename):
        """Read a maze CSV, mapping 'b*'/'l*' cells to '0' and skipping blank lines."""
        rows = []
        # newline='' is how the csv module expects its input files to be opened.
        with open(str(filename), 'r', newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # csv yields [] for a blank line; such a row would break the
                    # fixed-width scans below.
                    continue
                rows.append(['0' if cell.startswith(('b', 'l')) else cell
                             for cell in row])
        return rows

    def dfs_show(self, u):
        """Print *u* and, recursively, every node below it (debug helper)."""
        if u.id != 0:
            print(f"id: {u.id} , fa:{u.fa.id} , val:{u.val} , pos:{u.pos}")
        else:
            print(f"id: {u.id} , val:{u.val} , pos:{u.pos}")
        for child in u.children:
            self.dfs_show(child)

    def build_a_tree(self):
        """Build a BFS spanning tree of the open cells reachable from the start.

        Returns:
            The root ``TreeNode`` at ``start_pos`` (id 0). Other nodes get ids in
            BFS discovery order, and ``val`` is ``+N`` for ``'g<N>'``, ``-N`` for
            ``'t<N>'``, else 0.
        """
        root = TreeNode()
        root.pos = self.start_pos
        root.id = 0
        root.val = 0
        root.fa = None

        sx, sy = self.start_pos
        seen = [[False] * self.colNums for _ in range(self.rowNums)]
        seen[sx][sy] = True
        queue = deque([(sx, sy, root)])
        next_id = 0

        while queue:
            x, y, parent = queue.popleft()
            for ddx, ddy in self._STEPS:
                nx, ny = x + ddx, y + ddy
                if self.outofmap(nx, ny) or seen[nx][ny]:
                    continue
                cell = self.maze[nx][ny]
                if cell == '1':
                    continue
                seen[nx][ny] = True

                next_id += 1
                child = TreeNode()
                child.pos = (nx, ny)
                child.fa = parent
                child.id = next_id
                if cell.startswith('g'):
                    child.val = int(cell[1:])
                elif cell.startswith('t'):
                    child.val = -1 * int(cell[1:])
                parent.children.append(child)
                queue.append((nx, ny, child))
        return root

    def outofmap(self, x, y):
        """Return True if ``(x, y)`` lies outside the ``rowNums x colNums`` grid."""
        # Valid indices are 0..rowNums-1 and 0..colNums-1 (upper bounds exclusive).
        return not (0 <= x < self.rowNums and 0 <= y < self.colNums)

    def getlca(self, u, v):
        """Return the tree walk from *u* to *v*, excluding *v* itself.

        The walk climbs from *u* to the lowest common ancestor and then descends
        toward *v*, stopping at *v*'s parent. Returns ``[]`` when ``u is v`` or
        when the two nodes share no ancestor.
        """
        v_chain = []            # v, v.fa, ..., root
        node = v
        while node is not None:
            v_chain.append(node)
            node = node.fa
        # Distance of each ancestor of v above v (nodes compare by identity).
        height = {id(n): k for k, n in enumerate(v_chain)}

        up = []                 # u, u.fa, ... up to (not including) the LCA
        node = u
        while node is not None and id(node) not in height:
            up.append(node)
            node = node.fa
        if node is None:
            return []

        k = height[id(node)]
        # v_chain[k] is the LCA; v_chain[k], ..., v_chain[1] lead down to v.
        return up + v_chain[k:0:-1]

    def dfs(self, sn):
        """Plan the collecting walk for every node of *sn*'s subtree.

        Sets on each node:
            dp: its own ``val`` plus the ``dp`` of every child with ``dp > 0``.
            path: list of nodes that enters those children's subtrees in
                priority order (gold first, then plain cells, then traps; nearer
                first within a kind), walking back through the tree between
                consecutive subtrees. The node a walk-back starts from appears
                twice in a row.
            final_pos: the position where that walk ends.

        Implemented with an explicit stack instead of recursion, so deep trees
        do not hit Python's recursion limit.
        """
        order = []
        stack = [sn]
        while stack:
            node = stack.pop()
            order.append(node)
            stack.extend(node.children)
        # Reversed pre-order lists every node after all of its descendants, so
        # each node's children are finished before the node itself.
        for node in reversed(order):
            self._plan_node(node)

    def _child_priority(self, parent, child):
        """Sort key for *child*: gold cells first, plain cells next, traps last."""
        dist = abs(child.pos[0] - parent.pos[0]) + abs(child.pos[1] - parent.pos[1])
        cell = self.maze[child.pos[0]][child.pos[1]]
        if cell.startswith('g'):
            return (0, dist)
        if cell.startswith('t'):
            return (2, dist)
        return (1, dist)

    def _plan_node(self, sn):
        """Combine the already-planned children of *sn* into sn's dp/path/final_pos."""
        sn.dp = sn.val
        sn.final_pos = sn.pos
        sn.path = [sn]
        entered_any = False
        for child in sorted(sn.children, key=lambda c: self._child_priority(sn, c)):
            if child.dp <= 0:
                continue        # this subtree would not pay for itself
            sn.dp += child.dp
            if entered_any:
                # Walk back from where the previous excursion ended to sn.
                sn.path.extend(self.getlca(sn.path[-1], child))
            sn.path.extend(child.path)
            sn.final_pos = child.final_pos
            entered_any = True

    def get_path(self):
        """Return the walk computed by run() as a list of (row, col) positions."""
        return self.path

    def bfs_path(self, start, end):
        """Return the shortest path from *start* to *end*, both inclusive.

        Moves are 4-directional through cells other than ``'1'``. Returns ``[]``
        when *end* cannot be reached.
        """
        start, end = tuple(start), tuple(end)
        n, m = self.rowNums, self.colNums
        visited = [[False] * m for _ in range(n)]
        prev = [[None] * m for _ in range(n)]
        visited[start[0]][start[1]] = True
        q = deque([start])

        while q:
            x, y = q.popleft()
            if (x, y) == end:
                break
            for ddx, ddy in self._STEPS:
                nx, ny = x + ddx, y + ddy
                if self.outofmap(nx, ny) or visited[nx][ny] or self.maze[nx][ny] == '1':
                    continue
                visited[nx][ny] = True
                prev[nx][ny] = (x, y)
                q.append((nx, ny))

        # Follow the predecessor links back from end.
        path = []
        cur = end
        while cur is not None and cur != start:
            path.append(cur)
            cur = prev[cur[0]][cur[1]]
        if cur != start:
            return []
        path.append(start)
        path.reverse()
        return path

    def run(self, connect_to_end=False):
        """Build the tree, plan the collecting walk and store it in ``self.path``.

        Args:
            connect_to_end: When true, also append the shortest path from the
                walk's last cell to ``end_pos`` (if reachable). Defaults to
                False, in which case the walk ends where its last excursion ends.
        """
        root = self.build_a_tree()
        self.dfs(root)

        # Walk-backs repeat the node they start from; collapse consecutive
        # duplicate positions in a single pass (no deletion while iterating).
        self.path = []
        for node in root.path:
            if not self.path or self.path[-1] != node.pos:
                self.path.append(node.pos)

        if connect_to_end and self.path and self.end_pos and self.path[-1] != self.end_pos:
            tail = self.bfs_path(self.path[-1], self.end_pos)
            if tail:
                self.path.extend(tail[1:])

    def output_list(self):
        """Return a copy of the maze with the walk in ``self.path`` written in.

        The ``idx``-th step is marked ``'p<idx>'``. Gold/trap cells keep their
        code and get ``'p<idx>'`` appended, once per visit. Plain cells keep only
        their latest step number. Start/end cells are left unchanged.
        ``self.maze`` itself is not modified.
        """
        marked = [list(row) for row in self.maze]
        for idx, (y, x) in enumerate(self.path):
            cell = marked[y][x]
            if cell.startswith(('s', 'e')):
                continue
            if cell.startswith(('g', 't')):
                marked[y][x] = f"{cell}p{idx}"
            else:
                marked[y][x] = f"p{idx}"
        return marked


if __name__ == '__main__':
    # Plan a walk through maze.csv (opened relative to the working directory)
    # and print each visited (row, col) position on its own line.
    obj = SourceCollector(filename="maze.csv")
    obj.run()
    path = obj.get_path()
    for position in path:
        print(position)