maze_python/SourceCollector.py

254 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
from collections import deque
class TreeNode:
    """One reachable maze cell in the BFS spanning tree built by SourceCollector.

    Attributes:
        fa: Parent node in the tree, or None for the root.
        children: Child nodes, in the order they were discovered.
        pos: (row, col) of the cell this node represents.
        final_pos: Cell where the walk planned for this subtree ends.
        val: Value of the cell itself (positive for 'g' cells, negative
            for 't' cells, 0 otherwise).
        id: Sequence number assigned when the node is created; the root is 0.
        dp: Value collected by the walk planned for this subtree.
        path: Nodes visited, in order, by the walk planned for this subtree.
    """

    def __init__(self):
        self.fa = None
        self.children = []
        self.pos = None
        self.final_pos = None
        self.val = 0
        self.id = 0
        self.dp = 0
        # Declared here so the attribute exists even before a traversal
        # fills it in.
        self.path = []
class SourceCollector:
    """Plan a walk through a grid maze that collects positive-value cells.

    The maze is a list of rows of strings. Cell conventions used by this
    class: '1' is a wall, 's' is the start, 'e' is the end, 'g<n>' is a
    coin worth n and 't<n>' is a trap costing n. Every other cell is
    treated as a free corridor.
    """

    def __init__(self, filename=None, maze=None):
        """Load the maze and locate its start and end cells.

        Args:
            filename: Path of a CSV file holding the maze. When given it
                takes precedence over ``maze``; cells whose text starts
                with 'b' or 'l' are replaced by '0' while loading.
            maze: Already-parsed maze (list of rows of strings), used when
                no filename is given.
        """
        self.filename = filename
        self.maze = maze
        self.start_pos = None
        self.end_pos = None
        self.path = []
        self.node_path = []
        if self.filename:
            # newline='' lets the csv module do its own line-ending handling.
            with open(str(self.filename), 'r', newline='') as f:
                self.maze = [
                    ['0' if cell.startswith(('b', 'l')) else cell
                     for cell in row]
                    for row in csv.reader(f)
                ]
        self.rowNums = len(self.maze)
        self.colNums = len(self.maze[0])
        for i in range(self.rowNums):
            for j in range(self.colNums):
                if self.maze[i][j] == 's':
                    self.start_pos = (i, j)
                if self.maze[i][j] == 'e':
                    self.end_pos = (i, j)

    def dfs_show(self, u):
        """Print the subtree rooted at ``u`` in pre-order (debug helper)."""
        if u.id != 0:
            print(f"id: {u.id} , fa:{u.fa.id} , val:{u.val} , pos:{u.pos}")
        else:
            print(f"id: {u.id} , val:{u.val} , pos:{u.pos}")
        for child in u.children:
            self.dfs_show(child)

    def build_a_tree(self):
        """Build a BFS spanning tree of all cells reachable from the start.

        Returns:
            The root TreeNode, positioned at ``self.start_pos``. Each node
            carries the value of its cell: +n for 'g<n>', -n for 't<n>'.
        """
        cnt = 0
        root = TreeNode()
        root.pos = self.start_pos
        root.id = 0
        root.val = 0
        root.fa = None
        queue = deque([(self.start_pos[0], self.start_pos[1], root)])
        seen = [[False] * self.colNums for _ in range(self.rowNums)]
        seen[self.start_pos[0]][self.start_pos[1]] = True
        dx = [-1, 0, 1, 0]
        dy = [0, -1, 0, 1]
        while queue:
            x, y, parent = queue.popleft()
            for i in range(4):
                nx, ny = x + dx[i], y + dy[i]
                if self.outofmap(nx, ny) or seen[nx][ny]:
                    continue
                cell = self.maze[nx][ny]
                if cell == '1':
                    continue
                seen[nx][ny] = True
                new_node = TreeNode()
                new_node.pos = (nx, ny)
                new_node.fa = parent
                cnt += 1
                new_node.id = cnt
                if cell.startswith('g'):
                    new_node.val = int(cell[1:])
                elif cell.startswith('t'):
                    new_node.val = -int(cell[1:])
                parent.children.append(new_node)
                queue.append((nx, ny, new_node))
        return root

    def outofmap(self, x, y):
        """Return True when (x, y) is not a valid (row, col) of the maze."""
        # Valid indices are 0..rowNums-1 and 0..colNums-1, so the upper
        # bounds themselves are already outside the grid.
        return x < 0 or y < 0 or x >= self.rowNums or y >= self.colNums

    def getlca(self, u, v):
        """Return the tree walk from ``u`` towards ``v`` via their LCA.

        The result starts at ``u``, climbs to the lowest common ancestor
        and descends to the parent of ``v``; ``v`` itself is not included.
        An empty list is returned when the nodes share no ancestor.
        """
        def chain_from_root(node):
            chain = []
            while node is not None:
                chain.append(node)
                node = node.fa
            chain.reverse()
            return chain

        lca = None
        for a, b in zip(chain_from_root(u), chain_from_root(v)):
            if a is not b:
                break
            lca = a
        if lca is None:
            return []

        # Climb from u up to, but not including, the LCA.
        climb = []
        node = u
        while node is not lca:
            climb.append(node)
            node = node.fa

        # Collect v .. child-of-LCA, then flip into LCA .. v order.
        descent = []
        node = v
        while node is not lca:
            descent.append(node)
            node = node.fa
        descent.append(lca)
        descent.reverse()

        # Drop v: the caller appends v's own path right afterwards.
        return climb + descent[:-1]

    def dfs(self, sn):
        """Greedily plan the walk for the subtree rooted at ``sn``.

        Sets ``sn.dp`` (value collected), ``sn.path`` (nodes visited in
        order) and ``sn.final_pos`` (cell where the walk ends). Only child
        subtrees whose own ``dp`` is positive are entered.
        """
        sn.dp = sn.val
        sn.final_pos = sn.pos
        sn.path = [sn]

        def child_priority(child):
            # Manhattan distance between the child and this node.
            dist = (abs(child.pos[0] - sn.pos[0])
                    + abs(child.pos[1] - sn.pos[1]))
            cell = self.maze[child.pos[0]][child.pos[1]]
            if cell.startswith('g'):
                return (0, dist)  # coins first, nearer ones before farther
            if cell.startswith('t'):
                return (2, dist)  # traps last
            return (1, dist)  # plain corridor

        # Visit children by greedy priority: coin cells, corridors, traps.
        children = sorted(sn.children, key=child_priority)
        cur = None
        for child in children:
            self.dfs(child)
            if child.dp > 0:
                sn.dp += child.dp
                if cur is not None:
                    # Walk back from where the previous subtree ended to
                    # the entrance of this one.
                    sn.path.extend(self.getlca(sn.path[-1], child))
                sn.path.extend(child.path)
                cur = child
                sn.final_pos = cur.final_pos

    def get_path(self):
        """Return the list of (row, col) cells computed by ``run``."""
        return self.path

    def bfs_path(self, start, end):
        """Return the shortest path from start to end, both ends included.

        Returns an empty list when ``end`` cannot be reached.
        """
        n, m = self.rowNums, self.colNums
        visited = [[False] * m for _ in range(n)]
        prev = [[None] * m for _ in range(n)]
        q = deque([start])
        visited[start[0]][start[1]] = True
        dx = [-1, 0, 1, 0]
        dy = [0, -1, 0, 1]
        while q:
            x, y = q.popleft()
            if (x, y) == end:
                break
            for i in range(4):
                nx, ny = x + dx[i], y + dy[i]
                if 0 <= nx < n and 0 <= ny < m and not visited[nx][ny]:
                    if self.maze[nx][ny] != '1':
                        visited[nx][ny] = True
                        prev[nx][ny] = (x, y)
                        q.append((nx, ny))
        # Walk the predecessor links back from the end.
        path = []
        cur = end
        while cur is not None and cur != start:
            path.append(cur)
            cur = prev[cur[0]][cur[1]]
        if cur == start:
            path.append(start)
            path.reverse()
            return path
        return []

    @staticmethod
    def _drop_consecutive_duplicates(cells):
        """Return ``cells`` with every run of equal neighbours collapsed."""
        deduped = []
        for cell in cells:
            if not deduped or deduped[-1] != cell:
                deduped.append(cell)
        return deduped

    def run(self):
        """Compute the collecting walk and store it in ``self.path``.

        The walk starts at the start cell; when it does not already finish
        on the end cell, the shortest route to the end is appended.
        """
        root = self.build_a_tree()
        # self.dfs_show(root)
        self.dfs(root)
        # Build a new list instead of deleting while iterating, which
        # would skip elements and leave longer duplicate runs behind.
        self.path = self._drop_consecutive_duplicates(
            [node.pos for node in root.path]
        )
        if (self.path and self.end_pos is not None
                and self.path[-1] != self.end_pos):
            bfs_tail = self.bfs_path(self.path[-1], self.end_pos)
            if bfs_tail:
                self.path.extend(bfs_tail[1:])

    def output_list(self):
        """Return a copy of the maze annotated with the step index of the path.

        Visited cells become 'p<idx>'; coin and trap cells keep their text
        and get 'p<idx>' appended; start and end cells are left as they
        are. ``self.maze`` itself is not modified.
        """
        # Copy row by row so annotating never corrupts the maze that the
        # other methods keep reading.
        copy_maze = [list(row) for row in self.maze]
        for idx, (y, x) in enumerate(self.path):
            cell = copy_maze[y][x]
            if cell.startswith(('s', 'e')):
                continue
            if cell.startswith(('g', 't')):
                copy_maze[y][x] = f"{cell}p{idx}"
            else:
                copy_maze[y][x] = f"p{idx}"
        return copy_maze
if __name__ == '__main__':
    # Solve the maze stored in maze.csv and print every cell of the
    # resulting walk, one (row, col) tuple per line.
    collector = SourceCollector(filename="maze.csv")
    collector.run()
    for cell in collector.get_path():
        print(cell)
# print(sn.pos)
# pre = sn.pos
# for _ in sn.dp_path:
# dx,dy = _[0] - pre[0],_[1]-pre[1]
# if dx > 0:
# print("down")
# elif dx < 0:
# print("up")
# elif dy > 0:
# print("right")
# elif dy < 0:
# print("left")
# pre = _