@clawhub-hanjing5024064-87aabcf3ee
团队周报助手 — 自动收集团队工作进展,生成结构化周报/月报
---
name: team-weekly
description: 团队周报助手 — 自动收集团队工作进展,生成结构化周报/月报
version: 1.0.0
metadata:
  openclaw:
    optional_env:
      - TW_SUBSCRIPTION_TIER
      - TW_DATA_DIR
---
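# Team Weekly Report Assistant (team-weekly)
You are a professional team-effectiveness assistant agent. Your job is to help the user manage team members, record daily work logs, compile structured weekly and monthly reports, and provide hour statistics and efficiency analysis. Always communicate with the user in Chinese.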
## Environment variables
| Variable | Required | Description |
|------|------|------|
| `TW_SUBSCRIPTION_TIER` | No | Subscription tier: `free` (default) or `paid` |
| `TW_DATA_DIR` | No | Data storage directory, default `~/.openclaw-bdi/team-weekly/` |
---
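## Workflow 1: Team initialization
When the user says "初始化团队", "创建团队", "设置团队", or expresses a similar intent, follow these steps:
### Step 1: Collect team information
Ask the user for the team name and the member list: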
```
请提供以下信息:
1. 团队名称(如"产品研发部")
2. 成员名单(姓名、角色、参与项目)
示例:
- 张三,前端开发,项目:官网改版、管理后台
- 李四,后端开发,项目:用户系统、支付模块
- 王五,UI设计,项目:官网改版
```
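### Step 2: Check the subscription
Check whether the member count exceeds the current subscription limit:
- **Free tier**: up to 5 members
- **Paid tier**: up to 30 members
If the limit is exceeded, prompt the user to upgrade.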
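The member-count check above can be sketched as follows. This is only an illustration: `MAX_MEMBERS` and `check_member_limit` are hypothetical names, and the actual check in `scripts/team_store.py` may be structured differently.

```python
from typing import Tuple

# Member caps taken from the limits listed above.
MAX_MEMBERS = {"free": 5, "paid": 30}


def check_member_limit(member_count: int, tier: str = "free") -> Tuple[bool, str]:
    """Return (allowed, message) for a team of `member_count` people."""
    limit = MAX_MEMBERS[tier]
    if member_count <= limit:
        return True, ""
    # The message stays in Chinese because the agent replies to users in Chinese.
    return False, f"当前订阅最多支持 {limit} 名成员,已提供 {member_count} 名。"


allowed, message = check_member_limit(6, "free")
print(allowed, message)
```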
### Step 3: Create the team
```bash
python3 scripts/team_store.py --action init --data '{"name": "<团队名>", "members": [{"name": "张三", "role": "前端开发", "projects": ["官网改版"]}]}'
```
### Step 4: Confirm the result
Show the user the result of creating the team and list every member's details.
---
## Workflow 2: Member management
When the user says "添加成员", "删除成员", "查看团队", or expresses a similar intent:
### Add a member
```bash
python3 scripts/team_store.py --action add-member --data '{"name": "赵六", "role": "测试工程师", "projects": ["用户系统"]}'
```
### Remove a member
```bash
python3 scripts/team_store.py --action remove-member --data '{"name": "赵六"}'
```
### View the team
```bash
python3 scripts/team_store.py --action list
```
---
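## Workflow 3: Recording work logs
When the user says "记录工作" or "添加日志", or describes work that someone has completed, follow these steps:
### Step 1: Parse the user input
Natural-language input is supported. Extract the member, task, hours, and project from it automatically.
Example inputs: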
- "张三今天完成了官网首页设计,耗时6小时"
- "李四完成用户模块API开发,8小时,项目:用户系统"
- "王五设计了3个页面,4小时,设计"
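As a rough illustration of the parsing step, the example inputs above can be handled with a few regular expressions. This is a heuristic sketch, not the parser in `scripts/worklog_manager.py`: the function name `parse_worklog` and the list of verbs used to find the end of the member name are assumptions.

```python
import re
from typing import Any, Dict

_HOURS = re.compile(r"(\d+(?:\.\d+)?)\s*小时")
_PROJECT = re.compile(r"项目[::]\s*([^,,。]+)")
# Assumption: the member name is everything before the first of these markers.
_MEMBER = re.compile(r"^(.+?)(?=今天|昨天|完成|设计|开发|测试)")


def parse_worklog(text: str) -> Dict[str, Any]:
    """Extract member, hours, and project from one natural-language entry."""
    text = text.strip()
    member = _MEMBER.match(text)
    hours = _HOURS.search(text)
    project = _PROJECT.search(text)
    return {
        "member_name": member.group(1) if member else None,
        "hours": float(hours.group(1)) if hours else None,
        "project": project.group(1).strip() if project else None,
    }


print(parse_worklog("李四完成用户模块API开发,8小时,项目:用户系统"))
```

Fields that cannot be extracted come back as `None`, which gives the agent a natural point to ask a follow-up question.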
### Step 2: Write the work log
```bash
python3 scripts/worklog_manager.py --action add --data '{"member_name": "张三", "task_description": "官网首页设计", "hours": 6, "project": "官网改版", "category": "设计", "date": "2024-01-15"}'
```
Natural-language mode is also available:
```bash
python3 scripts/worklog_manager.py --action add --data '{"natural_input": "张三今天完成了官网首页设计,耗时6小时"}'
```
### Step 3: Confirm the entry
Confirm to the user that the entry was saved and show the recorded log.
---
## Workflow 4: Querying work logs
When the user says "查看日志", "查询工作记录", or expresses a similar intent:
### Basic queries
```bash
python3 scripts/worklog_manager.py --action list --data '{"member_name": "张三"}'
python3 scripts/worklog_manager.py --action list --data '{"date": "2024-01-15"}'
python3 scripts/worklog_manager.py --action list --data '{"date_from": "2024-01-08", "date_to": "2024-01-14"}'
```
### Advanced queries (by week, month, or group)
```bash
python3 scripts/worklog_manager.py --action query --data '{"week": "this", "group_by": "member"}'
python3 scripts/worklog_manager.py --action query --data '{"month": "2024-01", "group_by": "project"}'
```
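To make the `group_by` option concrete, here is a minimal sketch of grouping log entries and summing their hours. The field names follow the `add` payload shown earlier; `group_hours` itself is a hypothetical helper, not a function from `worklog_manager.py`.

```python
from collections import defaultdict
from typing import Any, Dict, List


def group_hours(logs: List[Dict[str, Any]], group_by: str) -> Dict[str, float]:
    """Sum hours per member or per project."""
    field = {"member": "member_name", "project": "project"}[group_by]
    totals: Dict[str, float] = defaultdict(float)
    for log in logs:
        key = log.get(field) or "未分类"  # entries without a project go to "uncategorized"
        totals[key] += log.get("hours", 0)
    return dict(totals)


logs = [
    {"member_name": "张三", "project": "官网改版", "hours": 6},
    {"member_name": "李四", "project": "用户系统", "hours": 8},
    {"member_name": "张三", "project": "官网改版", "hours": 2},
]
print(group_hours(logs, "member"))
```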
---
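## Workflow 5: Generating the weekly report
When the user says "生成周报", "本周周报", "上周周报", or expresses a similar intent, follow these steps:
### Step 1: Confirm the reporting period
Confirm which week the user wants a report for (default: the current week).
### Step 2: Compile the report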
```bash
python3 scripts/report_compiler.py --action weekly --week this
python3 scripts/report_compiler.py --action weekly --week last
python3 scripts/report_compiler.py --action weekly --week 2024-W03
```
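The three `--week` forms can be resolved to a Monday-to-Sunday range roughly as shown below. The sketch assumes that `2024-W03` denotes an ISO 8601 week; `resolve_week` is a hypothetical helper, and `report_compiler.py` may interpret the value differently.

```python
from datetime import date, timedelta
from typing import Optional, Tuple


def resolve_week(spec: str, today: Optional[date] = None) -> Tuple[str, str]:
    """Map "this", "last", or "YYYY-Www" to (monday, sunday) date strings."""
    today = today or date.today()
    if spec == "this":
        monday = today - timedelta(days=today.weekday())
    elif spec == "last":
        monday = today - timedelta(days=today.weekday() + 7)
    else:
        year_str, week_str = spec.split("-W")
        # ISO weeks start on Monday (day 1); requires Python 3.8+.
        monday = date.fromisocalendar(int(year_str), int(week_str), 1)
    sunday = monday + timedelta(days=6)
    return monday.isoformat(), sunday.isoformat()


print(resolve_week("2024-W03"))
```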
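### Step 3: Output the report
**Free-tier weekly report contents:**
- Overview table (headcount, task count, total hours, project count)
- Member work summary table
- Project progress table
- Detailed work records
**Additional paid-tier contents:**
- Project hours pie chart (Mermaid)
- Member hours bar chart (Mermaid)
- Insights and recommendations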
---
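## Workflow 6: Generating the monthly report (paid tier only)
When the user says "生成月报", "本月月报", or expresses a similar intent:
### Step 1: Check the subscription
Confirm that the user is on the paid tier. For free-tier users, show: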
> 月报汇总为付费版功能。当前为免费版,如需使用请升级至付费版(¥69/月)。
### Step 2: Compile the report
```bash
python3 scripts/report_compiler.py --action monthly --month this
python3 scripts/report_compiler.py --action monthly --month 2024-01
```
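### Step 3: Output the report
The monthly report contains:
- Executive summary
- Key metrics
- Weekly trend (with trend chart)
- Member work summary (with percentage share)
- Project hours distribution (with pie chart)
- Insights and recommendations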
---
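## Workflow 7: Hour statistics and efficiency analysis (paid tier only)
When the user says "工时统计", "工作量分析", "效率分析", or expresses a similar intent:
### Hour statistics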
```bash
python3 scripts/workload_analyzer.py --action workload --data '{"date_from": "2024-01-01", "date_to": "2024-01-31"}'
python3 scripts/workload_analyzer.py --action workload --member 张三
```
Output:
- Hours distribution tables by member, project, and category
- Pie chart visualization (Mermaid)
### Trend analysis
```bash
python3 scripts/workload_analyzer.py --action trend --member 张三
python3 scripts/workload_analyzer.py --action trend --data '{"weeks": 8}'
```
Output:
- Weekly hours table
- Trend line chart (Mermaid)
- Trend verdict (rising / falling / stable)
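The trend verdict can be sketched as a comparison of the first and last week against a ±20% threshold, which is the rule `analyze_trend` in `scripts/workload_analyzer.py` applies to its Markdown report. The standalone helper `classify_trend` below is illustrative, and treating a zero-hour first week as "stable" is an assumption.

```python
from typing import List


def classify_trend(weekly_hours: List[float], threshold: float = 0.2) -> str:
    """Return "up", "down", or "stable" for chronologically ordered weekly totals."""
    if len(weekly_hours) < 2 or weekly_hours[0] <= 0:
        # No baseline to compare against.
        return "stable"
    change = (weekly_hours[-1] - weekly_hours[0]) / weekly_hours[0]
    if change > threshold:
        return "up"
    if change < -threshold:
        return "down"
    return "stable"


print(classify_trend([40, 38, 45, 52]))
```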
### Efficiency analysis
```bash
python3 scripts/workload_analyzer.py --action efficiency --member 李四
python3 scripts/workload_analyzer.py --action efficiency --data '{"weeks": 4}'
```
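Output:
- Member efficiency comparison table (hours, task count, average daily hours, working days)
- Hours comparison bar chart (Mermaid)
- Per-member breakdown by work category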
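The columns of the comparison table can be derived as in the sketch below, where average daily hours is total hours divided by the number of distinct days that have at least one entry. `member_efficiency` is a hypothetical helper that assumes the log fields used in the commands above.

```python
from collections import defaultdict
from typing import Any, Dict, List, Set


def member_efficiency(logs: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Compute total hours, task count, working days, and average daily hours."""
    hours: Dict[str, float] = defaultdict(float)
    tasks: Dict[str, int] = defaultdict(int)
    days: Dict[str, Set[str]] = defaultdict(set)
    for log in logs:
        name = log["member_name"]
        hours[name] += log.get("hours", 0)
        tasks[name] += 1
        days[name].add(log["date"])
    return {
        name: {
            "total_hours": hours[name],
            "task_count": tasks[name],
            "work_days": len(days[name]),
            # Every member seen above has at least one day, so no division by zero.
            "avg_daily_hours": hours[name] / len(days[name]),
        }
        for name in hours
    }
```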
### Gantt chart
```bash
python3 scripts/workload_analyzer.py --action gantt --project 官网改版
python3 scripts/workload_analyzer.py --action gantt --data '{"weeks": 4}'
```
Output:
- Mermaid Gantt chart
- Project summary table
---
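## Subscription check logic
Run the following check before every operation that is subject to a tier restriction:
### Read the subscription tier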
```
tier = env TW_SUBSCRIPTION_TIER, default "free"
```
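In Python, the pseudocode above might look like the following sketch. It mirrors the normalization done by `check_subscription` in `scripts/utils.py` (trim, lowercase, reject unknown tiers); the `read_tier` name and its injectable `environ` parameter are assumptions added to make the example testable.

```python
import os
from typing import Mapping, Optional

VALID_TIERS = ("free", "paid")


def read_tier(environ: Optional[Mapping[str, str]] = None) -> str:
    """Read TW_SUBSCRIPTION_TIER, defaulting to "free"."""
    env = os.environ if environ is None else environ
    tier = env.get("TW_SUBSCRIPTION_TIER", "free").strip().lower()
    if tier not in VALID_TIERS:
        raise ValueError(f"invalid subscription tier: {tier!r}")
    return tier


print(read_tier({"TW_SUBSCRIPTION_TIER": "paid"}))
```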
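### Feature permission matrix
| Feature | Free tier (free) | Paid tier (paid, ¥69/month) |
|------|---------------|----------------------|
| Team size | 5 members | 30 members |
| Weekly report templates | 1 (basic) | 5 (industry templates) |
| Work log entry | Supported | Supported |
| Automatic weekly report | Basic tables | Tables + charts + insights |
| Monthly report | Not supported | Supported |
| Hour statistics | Not supported | Supported |
| Project progress tracking | Not supported | Supported (Gantt chart) |
| Performance trend analysis | Not supported | Supported |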
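### Behavior when the check fails
When the requested feature is beyond the current subscription tier:
1. Tell the user clearly that the feature is available only on the paid tier.
2. Briefly explain the benefits of the paid tier.
3. Offer upgrade guidance: "如需升级至付费版(¥69/月),请联系管理员或访问订阅管理页面。"
4. Do not refuse outright. Offer an alternative that is available on the free tier, if one exists.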
---
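## Reference documents
When generating reports, consult the following document:
- **Weekly report templates**: `references/weekly-templates.md`, which contains the weekly report templates and examples.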
---
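## Behavior guidelines
1. Always communicate with the user in Chinese.
2. Parse the user's natural-language input automatically and extract the member, task, hours, and project.
3. Give clear, structured answers, and prefer tables for presenting data.
4. Offer insights and recommendations proactively instead of only returning raw data.
5. When the user's intent is ambiguous, ask follow-up questions to clarify it.
6. Respect the subscription tier limits. Stay friendly when suggesting an upgrade and do not push it repeatedly.
7. When recording logs, fill in whatever can be inferred (such as today's date or a default category).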
FILE:assets/README.md
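# Team Weekly Report Assistant (team-weekly)
> Collects team work progress automatically and generates structured weekly and monthly reports in one step, with no manual consolidation.
## Core features
- **Team management**: create a team, add members, and assign project roles quickly
- **Work logs**: natural-language entry with automatic parsing of member, task, hours, and project
- **Weekly reports**: one-step summary of the whole team's work, grouped by member and project, output as Markdown
- **Efficiency analysis**: hour statistics, trend charts, and performance comparison to support data-driven management decisions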
---
## Quick start
### 1. Initialize the team
```
初始化团队:产品研发部
成员:
- 张三,前端开发,项目:官网改版
- 李四,后端开发,项目:用户系统
- 王五,UI设计,项目:官网改版
```
### 2. Record work logs
```
张三今天完成了官网首页设计,耗时6小时
李四完成用户登录API开发,8小时,项目:用户系统
```
### 3. Generate the weekly report
```
生成本周周报
```
That is all it takes.
---
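## Tier comparison
| Feature | Free | Paid (¥69/month) |
|------|:------:|:------------:|
| Team size | 5 members | 30 members |
| Weekly report templates | 1 basic template | 5 industry templates |
| Work log entry | ✅ | ✅ |
| Natural-language parsing | ✅ | ✅ |
| Automatic weekly report | Basic tables | Tables + charts + insights |
| Monthly report | - | ✅ |
| Hour statistics | - | ✅ |
| Project progress tracking | - | ✅ Gantt chart |
| Performance trend analysis | - | ✅ |
| Mermaid visualization | - | ✅ |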
---
## Sample output
### Free-tier weekly report example
```
# 产品研发部 周报
**报告周期**: 2024-01-15 ~ 2024-01-21
## 概览
| 指标 | 数值 |
|------|------|
| 参与人数 | 3 |
| 任务总数 | 12 |
| 总工时 | 84h |
| 涉及项目 | 2 |
## 成员工作汇总
| 成员 | 任务数 | 工时 | 涉及项目 |
|------|--------|------|----------|
| 张三 | 5 | 32h | 官网改版 |
| 李四 | 4 | 28h | 用户系统 |
| 王五 | 3 | 24h | 官网改版 |
```
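### Paid-tier additions
The paid tier also includes:
- Project hours pie chart (Mermaid visualization)
- Member hours bar chart
- Work insights and recommendations
- Monthly summary report
- Performance trend analysis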
---
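## Use cases
- **Daily stand-up**: review yesterday's work records quickly and sync on progress
- **Friday reporting**: generate a structured weekly report in one step, saving 30 minutes of consolidation time
- **Monthly retrospective**: comprehensive monthly data analysis that quantifies team output
- **Performance review**: objective hour statistics and trend analysis to support management decisions
- **Project management**: Gantt charts that track project progress and visualize resource allocation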
---
## Environment variables
| Variable | Required | Default | Description |
|------|:----:|--------|------|
| `TW_SUBSCRIPTION_TIER` | No | `free` | Subscription tier (`free` or `paid`) |
| `TW_DATA_DIR` | No | `~/.openclaw-bdi/team-weekly/` | Data storage directory |
---
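## FAQ
**Q: Where is the data stored?**
A: By default in the `~/.openclaw-bdi/team-weekly/` directory, as JSON files. Set the `TW_DATA_DIR` environment variable to use a different path.
**Q: What are the limits of the free tier?**
A: The free tier supports up to 5 team members and provides the basic weekly report (tables only). It does not include monthly reports, hour statistics, Gantt charts, or other advanced features.
**Q: How do I upgrade to the paid tier?**
A: Set the environment variable `TW_SUBSCRIPTION_TIER=paid` to enable the paid-tier features.
**Q: Which work categories are supported?**
A: Five categories: 开发 (development), 设计 (design), 测试 (testing), 会议 (meetings), and 其他 (other). The category is detected automatically at entry time or can be specified manually.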
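Automatic category detection could be as simple as keyword matching, as in the sketch below. The keyword lists and the `guess_category` helper are purely illustrative assumptions; the detection rules actually used by `worklog_manager.py` are not shown in this document.

```python
from typing import Dict, Tuple

# Hypothetical keyword lists; the first category with a matching keyword wins.
CATEGORY_KEYWORDS: Dict[str, Tuple[str, ...]] = {
    "开发": ("开发", "API", "代码", "修复"),
    "设计": ("设计", "原型", "UI"),
    "测试": ("测试", "用例"),
    "会议": ("会议", "评审", "站会"),
}


def guess_category(task: str) -> str:
    """Return one of the five categories, falling back to "其他"."""
    for category, keywords in CATEGORY_KEYWORDS.items():
        if any(keyword in task for keyword in keywords):
            return category
    return "其他"


print(guess_category("官网首页设计"))
```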
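**Q: Can logs be entered in bulk?**
A: Logs are entered one at a time. Each entry's natural-language input is parsed automatically.
**Q: Can the data be exported?**
A: The data is stored as JSON and can be read directly. Weekly and monthly reports are output as Markdown, which is easy to copy and share.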
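Because the logs are plain JSON, converting them to CSV takes only the standard library. The sketch below works on a list of log dictionaries with the fields used in the `add` command; `logs_to_csv` is a hypothetical helper, and how the entries are laid out inside `worklogs.json` is an assumption to verify against the real file.

```python
import csv
import io
from typing import Any, Dict, List

FIELDS = ["date", "member_name", "project", "category", "hours", "task_description"]


def logs_to_csv(logs: List[Dict[str, Any]]) -> str:
    """Render log entries as CSV text with a header row."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=FIELDS, lineterminator="\n")
    writer.writeheader()
    for log in logs:
        # Keep only the known columns; missing values become empty cells.
        writer.writerow({field: log.get(field, "") for field in FIELDS})
    return buffer.getvalue()
```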
---
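## Technical notes
- Pure Python 3, standard library only
- JSON file storage, no database required
- Command-line interface with `--action` and `--data`/`--data-file` arguments
- All output is JSON, which makes programmatic integration straightforward
- Mermaid chart syntax, compatible with mainstream Markdown renderers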
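A caller integrating with the scripts can unwrap the JSON envelope as sketched below. The envelope shape (`success` plus either `data` or `error.code`/`error.message`) follows `output_success` and `output_error` in `scripts/utils.py`; the `unwrap` helper is a hypothetical name.

```python
import json
from typing import Any


def unwrap(stdout: str) -> Any:
    """Return the `data` payload, or raise if the script reported an error."""
    payload = json.loads(stdout)
    if payload.get("success"):
        return payload["data"]
    error = payload.get("error", {})
    raise RuntimeError(f"{error.get('code', 'ERROR')}: {error.get('message', '')}")


print(unwrap('{"success": true, "data": {"total_hours": 84}}'))
```

In practice `stdout` would come from running one of the scripts, for example with `subprocess.run(..., capture_output=True, text=True)`.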
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
team-weekly 共享工具模块
提供订阅校验、数据目录管理、日期处理、格式化等通用功能。
"""
import argparse
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
# ============================================================
# 常量与配置
# ============================================================
ENV_SUBSCRIPTION_TIER = "TW_SUBSCRIPTION_TIER"
ENV_DATA_DIR = "TW_DATA_DIR"
DEFAULT_DATA_DIR = os.path.expanduser("~/.openclaw-bdi/team-weekly")
# 订阅等级配置
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_members": 5,
"templates": ["basic"],
"features": [
"worklog_crud",
"basic_weekly_report",
],
"price": "免费",
},
"paid": {
"tier": "paid",
"max_members": 30,
"templates": ["basic", "tech", "marketing", "sales", "design"],
"features": [
"worklog_crud",
"basic_weekly_report",
"enhanced_weekly_report",
"monthly_report",
"workload_analysis",
"project_tracking",
"performance_trend",
],
"price": "¥69/月",
},
}
# ============================================================
# 数据目录管理
# ============================================================
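def get_data_dir() -> str:
    """Return the data storage directory.

    Reads the TW_DATA_DIR environment variable first and falls back to the
    default ~/.openclaw-bdi/team-weekly/. A leading "~" is expanded so the
    returned path is always absolute, as documented.

    Returns:
        Absolute path of the data directory.
    """
    # Treat an empty TW_DATA_DIR the same as an unset one.
    raw = os.environ.get(ENV_DATA_DIR) or DEFAULT_DATA_DIR
    return os.path.abspath(os.path.expanduser(raw))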
def ensure_data_dir() -> str:
"""确保数据目录存在,若不存在则创建。
Returns:
数据目录的绝对路径。
"""
data_dir = get_data_dir()
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_team_file() -> str:
"""获取团队配置文件路径。"""
return os.path.join(get_data_dir(), "team.json")
def get_worklog_file() -> str:
"""获取工作日志文件路径。"""
return os.path.join(get_data_dir(), "worklogs.json")
# ============================================================
# 订阅校验
# ============================================================
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 TW_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get(ENV_SUBSCRIPTION_TIER, "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid(feature_name: str) -> None:
"""校验当前订阅是否为付费版,否则抛出错误。
Args:
feature_name: 功能名称,用于错误提示。
Raises:
PermissionError: 当前为免费版时抛出。
"""
sub = check_subscription()
if sub["tier"] != "paid":
raise PermissionError(
f"「{feature_name}」为付费版功能。"
f"当前为免费版,如需使用请升级至付费版(¥69/月)。"
)
def check_feature(feature: str) -> bool:
"""检查当前订阅等级是否支持指定功能。
Args:
feature: 功能标识符。
Returns:
是否支持该功能。
"""
sub = check_subscription()
return feature in sub["features"]
# ============================================================
# JSON 输入输出
# ============================================================
def read_input_data(args: argparse.Namespace) -> Dict[str, Any]:
"""从命令行参数或文件读取 JSON 输入数据。
支持两种方式:
- --data: 直接传入 JSON 字符串
- --data-file: 从文件读取 JSON
Args:
args: 解析后的命令行参数,需包含 data 和 data_file 属性。
Returns:
解析后的字典。
Raises:
ValueError: 当输入无效或 JSON 格式错误时抛出。
"""
raw = None
if hasattr(args, "data") and args.data:
raw = args.data
elif hasattr(args, "data_file") and args.data_file:
try:
with open(args.data_file, "r", encoding="utf-8") as f:
raw = f.read()
except FileNotFoundError:
raise ValueError(f"数据文件不存在: {args.data_file}")
except IOError as e:
raise ValueError(f"读取数据文件失败: {e}")
if not raw:
raise ValueError("请通过 --data 或 --data-file 提供输入数据")
try:
data = json.loads(raw)
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if not isinstance(data, dict):
raise ValueError(f"期望输入为 JSON 对象,实际类型为 {type(data).__name__}")
return data
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def create_base_parser(description: str) -> argparse.ArgumentParser:
"""创建包含通用参数的基础解析器。
Args:
description: 工具描述文字。
Returns:
配置好的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="要执行的操作",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的输入数据",
)
parser.add_argument(
"--data-file",
default=None,
help="包含 JSON 数据的文件路径",
)
return parser
# ============================================================
# 日期与时间工具
# ============================================================
def today_str() -> str:
"""返回今天的日期字符串(YYYY-MM-DD)。"""
return datetime.now().strftime("%Y-%m-%d")
def now_str() -> str:
"""返回当前时间字符串(YYYY-MM-DD HH:MM:SS)。"""
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def parse_date(date_str: str) -> datetime:
"""解析日期字符串为 datetime 对象。
支持格式:YYYY-MM-DD
Args:
date_str: 日期字符串。
Returns:
datetime 对象。
Raises:
ValueError: 日期格式不正确时抛出。
"""
try:
return datetime.strptime(date_str, "%Y-%m-%d")
except ValueError:
raise ValueError(f"日期格式不正确: {date_str!r},请使用 YYYY-MM-DD 格式")
def week_range_str(date: Optional[datetime] = None) -> Tuple[str, str]:
"""获取指定日期所在周的起止日期字符串(周一至周日)。
Args:
date: 指定日期,默认为今天。
Returns:
(周一日期, 周日日期) 的字符串元组。
"""
if date is None:
date = datetime.now()
monday = date - timedelta(days=date.weekday())
sunday = monday + timedelta(days=6)
return monday.strftime("%Y-%m-%d"), sunday.strftime("%Y-%m-%d")
def month_range_str(year: int, month: int) -> Tuple[str, str]:
"""获取指定月份的起止日期字符串。
Args:
year: 年份。
month: 月份(1-12)。
Returns:
(月初日期, 月末日期) 的字符串元组。
"""
import calendar
_, last_day = calendar.monthrange(year, month)
start = f"{year:04d}-{month:02d}-01"
end = f"{year:04d}-{month:02d}-{last_day:02d}"
return start, end
def parse_date_range(range_str: str) -> Tuple[str, str]:
"""解析日期范围字符串。
支持格式:
- "2024-01-01~2024-01-07"
- "2024-01-01 to 2024-01-07"
- "2024-01-01,2024-01-07"
Args:
range_str: 日期范围字符串。
Returns:
(起始日期, 结束日期) 的字符串元组。
Raises:
ValueError: 格式不正确时抛出。
"""
import re
parts = re.split(r'[~,]|\s+to\s+', range_str.strip())
parts = [p.strip() for p in parts if p.strip()]
if len(parts) != 2:
raise ValueError(
f"日期范围格式不正确: {range_str!r},"
f"请使用 'YYYY-MM-DD~YYYY-MM-DD' 格式"
)
# 验证两个日期格式
parse_date(parts[0])
parse_date(parts[1])
return parts[0], parts[1]
# ============================================================
# 格式化工具
# ============================================================
def format_hours(hours: float) -> str:
"""将小时数格式化为可读字符串。
Args:
hours: 小时数。
Returns:
格式化后的字符串,例如 8.5 → "8.5h",8.0 → "8h"
"""
if hours == int(hours):
return f"{int(hours)}h"
return f"{hours:.1f}h"
def format_percentage(value: float, decimals: int = 1) -> str:
"""将小数格式化为百分比字符串。
Args:
value: 待格式化的小数值(0.156 表示 15.6%)。
decimals: 百分比小数位数,默认为 1。
Returns:
百分比字符串,例如 0.156 → "15.6%"
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
pct = num * 100
return f"{pct:.{decimals}f}%"
def generate_id() -> str:
"""生成唯一 ID(基于时间戳 + 随机后缀)。
Returns:
唯一 ID 字符串。
"""
import random
import string
ts = datetime.now().strftime("%Y%m%d%H%M%S")
suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=6))
return f"{ts}_{suffix}"
# ============================================================
# 数据存储工具
# ============================================================
def load_json_file(filepath: str) -> Any:
    """Load a JSON file.

    Args:
        filepath: Path of the JSON file.

    Returns:
        The parsed data, or None when the file does not exist, cannot be
        read, or does not contain valid JSON.
    """
    if not os.path.exists(filepath):
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, IOError):
        return None


def save_json_file(filepath: str, data: Any) -> None:
    """Save data as a JSON file, creating the parent directory if needed.

    Args:
        filepath: Target file path.
        data: Data to save.
    """
    parent = os.path.dirname(filepath)
    if parent:
        # os.makedirs("") raises FileNotFoundError for a bare filename.
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2, default=str)
# ============================================================
# Markdown 表格生成
# ============================================================
def build_markdown_table(headers: List[str], rows: List[List[str]]) -> str:
"""生成 Markdown 格式表格。
Args:
headers: 表头列表。
rows: 数据行列表,每行为字符串列表。
Returns:
Markdown 表格字符串。
"""
if not headers:
return ""
lines = []
# 表头
lines.append("| " + " | ".join(str(h) for h in headers) + " |")
# 分隔线
lines.append("| " + " | ".join("---" for _ in headers) + " |")
# 数据行
for row in rows:
# 确保行的列数与表头一致
padded = list(row) + [""] * (len(headers) - len(row))
lines.append("| " + " | ".join(str(c) for c in padded[:len(headers)]) + " |")
return "\n".join(lines)
FILE:scripts/workload_analyzer.py
#!/usr/bin/env python3
"""
team-weekly 工时统计与效能分析模块(付费版功能)
提供工时统计、趋势分析、效率评估和甘特图生成。
"""
import calendar
import sys
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
build_markdown_table,
check_subscription,
create_base_parser,
format_hours,
format_percentage,
now_str,
output_error,
output_success,
read_input_data,
require_paid,
week_range_str,
)
from worklog_manager import _load_worklogs
from team_store import _load_team
# ============================================================
# 工时统计
# ============================================================
def analyze_workload(data: Dict[str, Any]) -> Dict[str, Any]:
"""工时统计分析。
Args:
data: 包含以下字段的字典:
- date_from (str, optional): 起始日期
- date_to (str, optional): 结束日期
- member (str, optional): 指定成员
- project (str, optional): 指定项目
Returns:
工时统计结果和 Markdown 报告。
"""
require_paid("工时统计")
logs = _load_worklogs()
date_from = data.get("date_from")
date_to = data.get("date_to")
member = data.get("member")
project = data.get("project")
if date_from:
logs = [l for l in logs if l["date"] >= date_from]
if date_to:
logs = [l for l in logs if l["date"] <= date_to]
if member:
logs = [l for l in logs if l["member_name"] == member]
if project:
logs = [l for l in logs if l.get("project") == project]
total_hours = sum(l.get("hours", 0) for l in logs)
# 按成员统计
by_member: Dict[str, float] = defaultdict(float)
for log in logs:
by_member[log["member_name"]] += log.get("hours", 0)
# 按项目统计
by_project: Dict[str, float] = defaultdict(float)
for log in logs:
proj = log.get("project", "未分类") or "未分类"
by_project[proj] += log.get("hours", 0)
# 按分类统计
by_category: Dict[str, float] = defaultdict(float)
for log in logs:
cat = log.get("category", "其他")
by_category[cat] += log.get("hours", 0)
# 按日统计
by_date: Dict[str, float] = defaultdict(float)
for log in logs:
by_date[log["date"]] += log.get("hours", 0)
# 生成 Markdown
period_str = ""
if date_from and date_to:
period_str = f"{date_from} ~ {date_to}"
elif date_from:
period_str = f"{date_from} 至今"
elif date_to:
period_str = f"截至 {date_to}"
else:
period_str = "全部"
md_lines = []
md_lines.append("# 工时统计分析报告")
md_lines.append("")
md_lines.append(f"**统计范围**: {period_str}")
if member:
md_lines.append(f"**指定成员**: {member}")
if project:
md_lines.append(f"**指定项目**: {project}")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
# 总览
md_lines.append("## 总览")
md_lines.append("")
md_lines.append(f"- 总工时: **{format_hours(total_hours)}**")
md_lines.append(f"- 涉及成员: **{len(by_member)}** 人")
md_lines.append(f"- 涉及项目: **{len(by_project)}** 个")
md_lines.append(f"- 工作天数: **{len(by_date)}** 天")
md_lines.append("")
# 成员工时
if not member:
md_lines.append("## 成员工时分布")
md_lines.append("")
headers = ["成员", "工时", "占比"]
rows = []
for name, hours in sorted(by_member.items(), key=lambda x: x[1], reverse=True):
pct = hours / total_hours if total_hours > 0 else 0
rows.append([name, format_hours(hours), format_percentage(pct)])
md_lines.append(build_markdown_table(headers, rows))
md_lines.append("")
# 饼图
md_lines.append("```mermaid")
md_lines.append("pie title 成员工时分布")
for name, hours in sorted(by_member.items(), key=lambda x: x[1], reverse=True):
md_lines.append(f' "{name}" : {hours:.1f}')
md_lines.append("```")
md_lines.append("")
# 项目工时
if not project:
md_lines.append("## 项目工时分布")
md_lines.append("")
headers = ["项目", "工时", "占比"]
rows = []
for proj, hours in sorted(by_project.items(), key=lambda x: x[1], reverse=True):
pct = hours / total_hours if total_hours > 0 else 0
rows.append([proj, format_hours(hours), format_percentage(pct)])
md_lines.append(build_markdown_table(headers, rows))
md_lines.append("")
# 分类工时
md_lines.append("## 工作分类分布")
md_lines.append("")
cat_headers = ["分类", "工时", "占比"]
cat_rows = []
for cat, hours in sorted(by_category.items(), key=lambda x: x[1], reverse=True):
pct = hours / total_hours if total_hours > 0 else 0
cat_rows.append([cat, format_hours(hours), format_percentage(pct)])
md_lines.append(build_markdown_table(cat_headers, cat_rows))
md_lines.append("")
md_lines.append("```mermaid")
md_lines.append("pie title 工作分类分布")
for cat, hours in sorted(by_category.items(), key=lambda x: x[1], reverse=True):
md_lines.append(f' "{cat}" : {hours:.1f}')
md_lines.append("```")
md_lines.append("")
report_md = "\n".join(md_lines)
return {
"total_hours": total_hours,
"by_member": dict(by_member),
"by_project": dict(by_project),
"by_category": dict(by_category),
"report_markdown": report_md,
}
# ============================================================
# 趋势分析
# ============================================================
def analyze_trend(data: Dict[str, Any]) -> Dict[str, Any]:
"""工时趋势分析。
Args:
data: 包含以下字段的字典:
- member (str, optional): 指定成员
- project (str, optional): 指定项目
- period (str, optional): 分析周期 (week/month),默认 week
- weeks (int, optional): 回溯周数,默认 4
Returns:
趋势分析结果和 Markdown 报告。
"""
require_paid("趋势分析")
logs = _load_worklogs()
member = data.get("member")
project = data.get("project")
weeks_back = int(data.get("weeks", 4))
if member:
logs = [l for l in logs if l["member_name"] == member]
if project:
logs = [l for l in logs if l.get("project") == project]
now = datetime.now()
# 按周聚合
weekly_stats: Dict[str, Dict[str, Any]] = {}
for i in range(weeks_back):
target = now - timedelta(weeks=i)
w_start, w_end = week_range_str(target)
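# Zero-pad the index so the lexicographic sorted() calls below stay chronological for 10+ weeks.
week_key = f"W{weeks_back - i:0{len(str(weeks_back))}d}"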
week_label = f"{w_start}~{w_end}"
week_logs = [l for l in logs if w_start <= l["date"] <= w_end]
weekly_stats[week_key] = {
"label": week_label,
"hours": sum(l.get("hours", 0) for l in week_logs),
"tasks": len(week_logs),
}
# 生成 Markdown
md_lines = []
title = "工时趋势分析"
if member:
title = f"{member} 工时趋势分析"
if project:
title += f"(项目: {project})"
md_lines.append(f"# {title}")
md_lines.append("")
md_lines.append(f"**分析周期**: 近 {weeks_back} 周")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
# 数据表格
md_lines.append("## 周度数据")
md_lines.append("")
headers = ["周", "日期范围", "任务数", "工时"]
rows = []
for key in sorted(weekly_stats.keys()):
ws = weekly_stats[key]
rows.append([key, ws["label"], str(ws["tasks"]), format_hours(ws["hours"])])
md_lines.append(build_markdown_table(headers, rows))
md_lines.append("")
# 趋势图
sorted_keys = sorted(weekly_stats.keys())
md_lines.append("## 趋势图")
md_lines.append("")
md_lines.append("```mermaid")
md_lines.append("xychart-beta")
md_lines.append(f' title "工时趋势(近{weeks_back}周)"')
md_lines.append(' x-axis [' + ", ".join(f'"{k}"' for k in sorted_keys) + ']')
md_lines.append(' y-axis "工时(小时)"')
md_lines.append(' bar [' + ", ".join(str(weekly_stats[k]["hours"]) for k in sorted_keys) + ']')
md_lines.append(' line [' + ", ".join(str(weekly_stats[k]["hours"]) for k in sorted_keys) + ']')
md_lines.append("```")
md_lines.append("")
# 趋势判断
hours_list = [weekly_stats[k]["hours"] for k in sorted_keys]
md_lines.append("## 趋势判断")
md_lines.append("")
if len(hours_list) >= 2 and hours_list[0] > 0:
change_pct = (hours_list[-1] - hours_list[0]) / hours_list[0]
if change_pct > 0.2:
md_lines.append(f"- 工时呈**上升**趋势,增幅 {format_percentage(change_pct)}。")
elif change_pct < -0.2:
md_lines.append(f"- 工时呈**下降**趋势,降幅 {format_percentage(abs(change_pct))}。")
else:
md_lines.append("- 工时保持**稳定**。")
avg_hours = sum(hours_list) / len(hours_list) if hours_list else 0
md_lines.append(f"- 周均工时: **{format_hours(avg_hours)}**")
md_lines.append("")
report_md = "\n".join(md_lines)
return {
"weekly_stats": weekly_stats,
"trend": "up" if hours_list and len(hours_list) >= 2 and hours_list[-1] > hours_list[0] * 1.2
else ("down" if hours_list and len(hours_list) >= 2 and hours_list[-1] < hours_list[0] * 0.8
else "stable"),
"report_markdown": report_md,
}
# ============================================================
# 效率分析
# ============================================================
def analyze_efficiency(data: Dict[str, Any]) -> Dict[str, Any]:
"""成员效率分析。
Args:
data: 包含以下字段的字典:
- member (str, optional): 指定成员(不指定则分析所有成员)
- weeks (int, optional): 回溯周数,默认 4
Returns:
效率分析结果和 Markdown 报告。
"""
require_paid("效率分析")
logs = _load_worklogs()
member = data.get("member")
weeks_back = int(data.get("weeks", 4))
now = datetime.now()
cutoff = now - timedelta(weeks=weeks_back)
cutoff_str = cutoff.strftime("%Y-%m-%d")
logs = [l for l in logs if l["date"] >= cutoff_str]
if member:
logs = [l for l in logs if l["member_name"] == member]
# 按成员聚合
member_stats: Dict[str, Dict[str, Any]] = {}
for log in logs:
name = log["member_name"]
if name not in member_stats:
member_stats[name] = {
"total_hours": 0,
"task_count": 0,
"projects": set(),
"categories": defaultdict(float),
"daily_hours": defaultdict(float),
}
ms = member_stats[name]
hours = log.get("hours", 0)
ms["total_hours"] += hours
ms["task_count"] += 1
if log.get("project"):
ms["projects"].add(log["project"])
ms["categories"][log.get("category", "其他")] += hours
ms["daily_hours"][log["date"]] += hours
# 生成 Markdown
md_lines = []
title = "团队效率分析" if not member else f"{member} 效率分析"
md_lines.append(f"# {title}")
md_lines.append("")
md_lines.append(f"**分析周期**: 近 {weeks_back} 周")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
if not member_stats:
md_lines.append("暂无工作记录。")
return {"member_stats": {}, "report_markdown": "\n".join(md_lines)}
# 效率对比表
md_lines.append("## 成员效率对比")
md_lines.append("")
headers = ["成员", "总工时", "任务数", "日均工时", "工作天数", "项目数"]
rows = []
for name, ms in sorted(member_stats.items(), key=lambda x: x[1]["total_hours"], reverse=True):
work_days = len(ms["daily_hours"])
avg_daily = ms["total_hours"] / work_days if work_days > 0 else 0
rows.append([
name,
format_hours(ms["total_hours"]),
str(ms["task_count"]),
format_hours(avg_daily),
str(work_days),
str(len(ms["projects"])),
])
md_lines.append(build_markdown_table(headers, rows))
md_lines.append("")
# 工时对比柱状图
if len(member_stats) > 1:
md_lines.append("```mermaid")
md_lines.append("xychart-beta")
md_lines.append(f' title "成员工时对比(近{weeks_back}周)"')
sorted_names = sorted(member_stats.keys())
md_lines.append(' x-axis [' + ", ".join(f'"{n}"' for n in sorted_names) + ']')
md_lines.append(' y-axis "工时(小时)"')
md_lines.append(' bar [' + ", ".join(
str(member_stats[n]["total_hours"]) for n in sorted_names
) + ']')
md_lines.append("```")
md_lines.append("")
# 个人详细分析
for name, ms in sorted(member_stats.items()):
md_lines.append(f"### {name} 详细分析")
md_lines.append("")
work_days = len(ms["daily_hours"])
avg_daily = ms["total_hours"] / work_days if work_days > 0 else 0
md_lines.append(f"- 总工时: {format_hours(ms['total_hours'])}")
md_lines.append(f"- 任务数: {ms['task_count']}")
md_lines.append(f"- 工作天数: {work_days}")
md_lines.append(f"- 日均工时: {format_hours(avg_daily)}")
md_lines.append(f"- 涉及项目: {', '.join(ms['projects']) if ms['projects'] else '-'}")
md_lines.append("")
# 分类占比
if ms["categories"]:
md_lines.append("工作分类分布:")
md_lines.append("")
for cat, hours in sorted(ms["categories"].items(), key=lambda x: x[1], reverse=True):
pct = hours / ms["total_hours"] if ms["total_hours"] > 0 else 0
md_lines.append(f" - {cat}: {format_hours(hours)} ({format_percentage(pct)})")
md_lines.append("")
report_md = "\n".join(md_lines)
# 序列化 set 和 defaultdict
serializable_stats = {}
for name, ms in member_stats.items():
serializable_stats[name] = {
"total_hours": ms["total_hours"],
"task_count": ms["task_count"],
"projects": list(ms["projects"]),
"categories": dict(ms["categories"]),
"work_days": len(ms["daily_hours"]),
}
return {
"member_stats": serializable_stats,
"report_markdown": report_md,
}
# ============================================================
# 甘特图生成
# ============================================================
def generate_gantt(data: Dict[str, Any]) -> Dict[str, Any]:
"""生成项目甘特图。
Args:
data: 包含以下字段的字典:
- project (str, optional): 指定项目
- weeks (int, optional): 回溯周数,默认 4
Returns:
甘特图 Mermaid 代码和结构化数据。
"""
require_paid("甘特图")
logs = _load_worklogs()
project = data.get("project")
weeks_back = int(data.get("weeks", 4))
now = datetime.now()
cutoff = now - timedelta(weeks=weeks_back)
cutoff_str = cutoff.strftime("%Y-%m-%d")
logs = [l for l in logs if l["date"] >= cutoff_str]
if project:
logs = [l for l in logs if l.get("project") == project]
# 按项目-成员-日期聚合
project_tasks: Dict[str, Dict[str, Dict[str, Any]]] = {}
for log in logs:
proj = log.get("project", "未分类") or "未分类"
member = log["member_name"]
if proj not in project_tasks:
project_tasks[proj] = {}
if member not in project_tasks[proj]:
project_tasks[proj][member] = {
"start": log["date"],
"end": log["date"],
"hours": 0,
"tasks": 0,
}
pt = project_tasks[proj][member]
pt["start"] = min(pt["start"], log["date"])
pt["end"] = max(pt["end"], log["date"])
pt["hours"] += log.get("hours", 0)
pt["tasks"] += 1
# 生成 Mermaid 甘特图
md_lines = []
md_lines.append("# 项目进度甘特图")
md_lines.append("")
md_lines.append(f"**统计范围**: 近 {weeks_back} 周")
if project:
md_lines.append(f"**指定项目**: {project}")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
md_lines.append("```mermaid")
md_lines.append("gantt")
md_lines.append(" title 项目进度甘特图")
md_lines.append(" dateFormat YYYY-MM-DD")
md_lines.append("")
for proj, members in sorted(project_tasks.items()):
md_lines.append(f" section {proj}")
for member, info in sorted(members.items()):
task_name = f"{member}({format_hours(info['hours'])})"
md_lines.append(f" {task_name} : {info['start']}, {info['end']}")
md_lines.append("```")
md_lines.append("")
# 项目统计表
md_lines.append("## 项目汇总")
md_lines.append("")
headers = ["项目", "参与成员", "总工时", "起止日期"]
rows = []
for proj, members in sorted(project_tasks.items()):
total_h = sum(m["hours"] for m in members.values())
all_starts = [m["start"] for m in members.values()]
all_ends = [m["end"] for m in members.values()]
rows.append([
proj,
str(len(members)),
format_hours(total_h),
f"{min(all_starts)} ~ {max(all_ends)}",
])
md_lines.append(build_markdown_table(headers, rows))
md_lines.append("")
report_md = "\n".join(md_lines)
return {
"project_tasks": {
proj: {member: info for member, info in members.items()}
for proj, members in project_tasks.items()
},
"report_markdown": report_md,
}
# ============================================================
# 命令行入口
# ============================================================
def main() -> None:
    """Command-line entry point."""
    parser = create_base_parser("工时统计与效能分析工具")
    parser.add_argument("--member", default=None, help="指定成员")
    parser.add_argument("--project", default=None, help="指定项目")
    args, _ = parser.parse_known_args()

    # Map each --action value to its handler.
    handlers = {
        "workload": analyze_workload,
        "trend": analyze_trend,
        "efficiency": analyze_efficiency,
        "gantt": generate_gantt,
    }
    try:
        # --data/--data-file are optional here, but malformed input must be
        # reported as an error instead of being silently ignored.
        data: Dict[str, Any] = {}
        if args.data or args.data_file:
            data = read_input_data(args)
        if args.member:
            data["member"] = args.member
        if args.project:
            data["project"] = args.project

        handler = handlers.get(args.action)
        if handler is None:
            output_error(f"未知操作: {args.action}", "UNKNOWN_ACTION")
            return
        output_success(handler(data))
    except (ValueError, PermissionError) as e:
        output_error(str(e), type(e).__name__.upper())
    except Exception as e:
        output_error(f"内部错误: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
main()
FILE:scripts/report_compiler.py
#!/usr/bin/env python3
"""
team-weekly 周报/月报汇总生成模块
从工作日志数据中汇总生成结构化的周报和月报。
"""
import calendar
import sys
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from utils import (
build_markdown_table,
check_feature,
check_subscription,
create_base_parser,
format_hours,
format_percentage,
month_range_str,
now_str,
output_error,
output_success,
read_input_data,
week_range_str,
)
from worklog_manager import _load_worklogs
from team_store import _load_team
# ============================================================
# 数据聚合
# ============================================================
def _filter_logs_by_range(
logs: List[Dict[str, Any]],
date_from: str,
date_to: str,
) -> List[Dict[str, Any]]:
"""按日期范围过滤日志。"""
return [l for l in logs if date_from <= l["date"] <= date_to]
def _aggregate_by_member(logs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""按成员聚合日志数据。"""
result: Dict[str, Dict[str, Any]] = {}
for log in logs:
name = log["member_name"]
if name not in result:
result[name] = {
"member_name": name,
"total_hours": 0,
"task_count": 0,
"tasks": [],
"projects": set(),
"categories": defaultdict(float),
}
result[name]["total_hours"] += log.get("hours", 0)
result[name]["task_count"] += 1
result[name]["tasks"].append(log)
if log.get("project"):
result[name]["projects"].add(log["project"])
result[name]["categories"][log.get("category", "其他")] += log.get("hours", 0)
# 将 set 转为 list 以便序列化
for v in result.values():
v["projects"] = list(v["projects"])
v["categories"] = dict(v["categories"])
return result
def _aggregate_by_project(logs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""按项目聚合日志数据。"""
result: Dict[str, Dict[str, Any]] = {}
for log in logs:
proj = log.get("project", "未分类") or "未分类"
if proj not in result:
result[proj] = {
"project": proj,
"total_hours": 0,
"task_count": 0,
"members": set(),
}
result[proj]["total_hours"] += log.get("hours", 0)
result[proj]["task_count"] += 1
result[proj]["members"].add(log["member_name"])
for v in result.values():
v["members"] = list(v["members"])
return result
# ============================================================
# 周报生成
# ============================================================
def generate_weekly_report(data: Dict[str, Any]) -> Dict[str, Any]:
    """Generate the weekly report.

    Args:
        data: Dictionary with the following fields:
            - week (str, optional): week identifier such as "this", "last" or "2024-W03"
            - template (str, optional): template type (not read by this function)

    Returns:
        Dictionary containing the Markdown report and the structured data.
    """
    now = datetime.now()
    week = data.get("week", "this")

    # Resolve the target week
    if week == "this":
        target_date = now
    elif week == "last":
        target_date = now - timedelta(days=7)
    else:
        try:
            parts = week.split("-W")
            year = int(parts[0])
            week_num = int(parts[1])
            # ISO 8601 week 1 is the week containing January 4, not January 1,
            # so let the standard library resolve the Monday (Python 3.8+).
            target_date = datetime.fromisocalendar(year, week_num, 1)
        except (ValueError, IndexError):
            raise ValueError(f"无效的周标识: {week!r}")
    date_from, date_to = week_range_str(target_date)

    # Load the work logs
    logs = _load_worklogs()
    week_logs = _filter_logs_by_range(logs, date_from, date_to)

    # Load the team information
    team = _load_team()
    team_name = team["name"] if team else "团队"

    # Aggregate the data
    by_member = _aggregate_by_member(week_logs)
    by_project = _aggregate_by_project(week_logs)
    total_hours = sum(log.get("hours", 0) for log in week_logs)
    total_tasks = len(week_logs)
    sub = check_subscription()
    is_paid = sub["tier"] == "paid"

    # Build the Markdown report
    md_lines = []
    md_lines.append(f"# {team_name} 周报")
    md_lines.append("")
    md_lines.append(f"**报告周期**: {date_from} ~ {date_to}")
    md_lines.append(f"**生成时间**: {now_str()}")
    md_lines.append("")

    # Overview
    md_lines.append("## 概览")
    md_lines.append("")
    overview_headers = ["指标", "数值"]
    overview_rows = [
        ["参与人数", str(len(by_member))],
        ["任务总数", str(total_tasks)],
        ["总工时", format_hours(total_hours)],
        ["涉及项目", str(len(by_project))],
    ]
    md_lines.append(build_markdown_table(overview_headers, overview_rows))
    md_lines.append("")

    # Per-member summary
    md_lines.append("## 成员工作汇总")
    md_lines.append("")
    member_headers = ["成员", "任务数", "工时", "涉及项目"]
    member_rows = []
    for name, info in sorted(by_member.items()):
        member_rows.append([
            name,
            str(info["task_count"]),
            format_hours(info["total_hours"]),
            ", ".join(info["projects"]) if info["projects"] else "-",
        ])
    md_lines.append(build_markdown_table(member_headers, member_rows))
    md_lines.append("")

    # Per-project summary
    md_lines.append("## 项目进展")
    md_lines.append("")
    project_headers = ["项目", "任务数", "工时", "参与成员"]
    project_rows = []
    for proj, info in sorted(by_project.items()):
        project_rows.append([
            proj,
            str(info["task_count"]),
            format_hours(info["total_hours"]),
            ", ".join(info["members"]),
        ])
    md_lines.append(build_markdown_table(project_headers, project_rows))
    md_lines.append("")

    # Detailed work records
    md_lines.append("## 详细工作记录")
    md_lines.append("")
    for name, info in sorted(by_member.items()):
        md_lines.append(f"### {name}")
        md_lines.append("")
        detail_headers = ["日期", "任务", "项目", "分类", "工时"]
        detail_rows = []
        for task in sorted(info["tasks"], key=lambda t: t["date"]):
            detail_rows.append([
                task["date"],
                task["task_description"],
                # Logs store an empty string when no project is given,
                # so fall back on any falsy value, not only a missing key.
                task.get("project") or "-",
                task.get("category", "其他"),
                format_hours(task.get("hours", 0)),
            ])
        md_lines.append(build_markdown_table(detail_headers, detail_rows))
        md_lines.append("")

    # Paid tier: charts and insights
    if is_paid:
        md_lines.append("## 工时分布图")
        md_lines.append("")

        # Pie chart of hours per project
        if by_project:
            md_lines.append("### 项目工时分布")
            md_lines.append("")
            md_lines.append("```mermaid")
            md_lines.append("pie title 项目工时分布")
            for proj, info in sorted(by_project.items(), key=lambda x: x[1]["total_hours"], reverse=True):
                md_lines.append(f' "{proj}" : {info["total_hours"]:.1f}')
            md_lines.append("```")
            md_lines.append("")

        # Bar chart of hours per member
        if by_member:
            member_names = sorted(by_member.keys())
            md_lines.append("### 成员工时对比")
            md_lines.append("")
            md_lines.append("```mermaid")
            md_lines.append("xychart-beta")
            md_lines.append(f' title "成员工时对比({date_from} ~ {date_to})"')
            md_lines.append(' x-axis [' + ", ".join(f'"{n}"' for n in member_names) + ']')
            md_lines.append(' y-axis "工时(小时)"')
            # One decimal place, matching the pie chart and avoiding float noise
            bar_data = [f'{by_member[n]["total_hours"]:.1f}' for n in member_names]
            md_lines.append(' bar [' + ", ".join(bar_data) + ']')
            md_lines.append("```")
            md_lines.append("")

        # Insights
        md_lines.append("## 洞察与建议")
        md_lines.append("")
        insights = _generate_weekly_insights(by_member, by_project, total_hours)
        for insight in insights:
            md_lines.append(f"- {insight}")
        md_lines.append("")

    report_md = "\n".join(md_lines)
    return {
        "report_type": "weekly",
        "period": f"{date_from} ~ {date_to}",
        "summary": {
            "member_count": len(by_member),
            "task_count": total_tasks,
            "total_hours": total_hours,
            "project_count": len(by_project),
        },
        # Drop the raw task lists from the structured output
        "by_member": {k: {kk: vv for kk, vv in v.items() if kk != "tasks"} for k, v in by_member.items()},
        "by_project": by_project,
        "report_markdown": report_md,
    }
def _generate_weekly_insights(
by_member: Dict,
by_project: Dict,
total_hours: float,
) -> List[str]:
"""生成周报洞察建议。"""
insights = []
if not by_member:
insights.append("本周暂无工作记录,请确认团队成员已录入日志。")
return insights
# 找出工时最多和最少的成员
sorted_members = sorted(by_member.items(), key=lambda x: x[1]["total_hours"], reverse=True)
top_member = sorted_members[0]
insights.append(
f"本周工时最多的成员是 **{top_member[0]}**,"
f"共 {format_hours(top_member[1]['total_hours'])},"
f"完成 {top_member[1]['task_count']} 项任务。"
)
if len(sorted_members) > 1:
bottom_member = sorted_members[-1]
if top_member[1]["total_hours"] > 0:
ratio = bottom_member[1]["total_hours"] / top_member[1]["total_hours"]
if ratio < 0.3:
insights.append(
f"**{bottom_member[0]}** 工时偏低({format_hours(bottom_member[1]['total_hours'])}),"
f"建议关注工作分配均衡性。"
)
# 项目分布
if len(by_project) > 3:
insights.append(
f"本周涉及 {len(by_project)} 个项目,注意避免精力过于分散。"
)
# 人均工时
if by_member:
avg_hours = total_hours / len(by_member)
if avg_hours > 45:
insights.append(f"人均工时 {format_hours(avg_hours)},偏高,注意团队工作强度。")
elif avg_hours < 20:
insights.append(f"人均工时 {format_hours(avg_hours)},偏低,请确认日志录入是否完整。")
return insights
# ============================================================
# 月报生成
# ============================================================
def generate_monthly_report(data: Dict[str, Any]) -> Dict[str, Any]:
"""生成月报(仅付费版)。
Args:
data: 包含以下字段的字典:
- month (str, optional): 月标识,如 "this"、"last"、"2024-01"
Returns:
包含 Markdown 月报和结构化数据的字典。
"""
sub = check_subscription()
if sub["tier"] != "paid":
raise PermissionError(
"月报汇总为付费版功能。当前为免费版,如需使用请升级至付费版(¥69/月)。"
)
now = datetime.now()
month = data.get("month", "this")
if month == "this":
year, mon = now.year, now.month
elif month == "last":
if now.month == 1:
year, mon = now.year - 1, 12
else:
year, mon = now.year, now.month - 1
else:
try:
parts = month.split("-")
year, mon = int(parts[0]), int(parts[1])
except (ValueError, IndexError):
raise ValueError(f"无效的月标识: {month!r},请使用 YYYY-MM 格式")
date_from, date_to = month_range_str(year, mon)
# 获取日志
logs = _load_worklogs()
month_logs = _filter_logs_by_range(logs, date_from, date_to)
team = _load_team()
team_name = team["name"] if team else "团队"
by_member = _aggregate_by_member(month_logs)
by_project = _aggregate_by_project(month_logs)
total_hours = sum(l.get("hours", 0) for l in month_logs)
total_tasks = len(month_logs)
# 按周聚合
weekly_data: Dict[str, Dict[str, Any]] = {}
for log in month_logs:
log_date = datetime.strptime(log["date"], "%Y-%m-%d")
iso_year, iso_week, _ = log_date.isocalendar()
week_key = f"{iso_year}-W{iso_week:02d}"
if week_key not in weekly_data:
weekly_data[week_key] = {"hours": 0, "tasks": 0}
weekly_data[week_key]["hours"] += log.get("hours", 0)
weekly_data[week_key]["tasks"] += 1
# 生成 Markdown
md_lines = []
md_lines.append(f"# {team_name} 月报")
md_lines.append("")
md_lines.append(f"**报告月份**: {year}年{mon}月")
md_lines.append(f"**报告周期**: {date_from} ~ {date_to}")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
# 执行摘要
md_lines.append("## 执行摘要")
md_lines.append("")
md_lines.append(f"本月团队共 **{len(by_member)}** 人参与工作,"
f"完成 **{total_tasks}** 项任务,"
f"累计工时 **{format_hours(total_hours)}**,"
f"涉及 **{len(by_project)}** 个项目。")
md_lines.append("")
# 核心指标
md_lines.append("## 核心指标")
md_lines.append("")
overview_headers = ["指标", "数值"]
avg_hours = total_hours / len(by_member) if by_member else 0
overview_rows = [
["参与人数", str(len(by_member))],
["任务总数", str(total_tasks)],
["总工时", format_hours(total_hours)],
["人均工时", format_hours(avg_hours)],
["涉及项目", str(len(by_project))],
["覆盖周数", str(len(weekly_data))],
]
md_lines.append(build_markdown_table(overview_headers, overview_rows))
md_lines.append("")
# 周度趋势
md_lines.append("## 周度趋势")
md_lines.append("")
if weekly_data:
week_headers = ["周", "任务数", "工时"]
week_rows = []
for wk in sorted(weekly_data.keys()):
wd = weekly_data[wk]
week_rows.append([wk, str(wd["tasks"]), format_hours(wd["hours"])])
md_lines.append(build_markdown_table(week_headers, week_rows))
md_lines.append("")
# 周度趋势图
md_lines.append("```mermaid")
md_lines.append("xychart-beta")
md_lines.append(f' title "月度周工时趋势({year}年{mon}月)"')
sorted_weeks = sorted(weekly_data.keys())
md_lines.append(' x-axis [' + ", ".join(f'"{w}"' for w in sorted_weeks) + ']')
md_lines.append(' y-axis "工时(小时)"')
md_lines.append(' bar [' + ", ".join(str(weekly_data[w]["hours"]) for w in sorted_weeks) + ']')
md_lines.append("```")
md_lines.append("")
# 成员统计
md_lines.append("## 成员工作汇总")
md_lines.append("")
member_headers = ["成员", "任务数", "工时", "占比", "涉及项目"]
member_rows = []
for name, info in sorted(by_member.items(), key=lambda x: x[1]["total_hours"], reverse=True):
pct = info["total_hours"] / total_hours if total_hours > 0 else 0
member_rows.append([
name,
str(info["task_count"]),
format_hours(info["total_hours"]),
format_percentage(pct),
", ".join(info["projects"]) if info["projects"] else "-",
])
md_lines.append(build_markdown_table(member_headers, member_rows))
md_lines.append("")
# 项目统计
md_lines.append("## 项目工时分布")
md_lines.append("")
project_headers = ["项目", "任务数", "工时", "占比", "参与成员"]
project_rows = []
for proj, info in sorted(by_project.items(), key=lambda x: x[1]["total_hours"], reverse=True):
pct = info["total_hours"] / total_hours if total_hours > 0 else 0
project_rows.append([
proj,
str(info["task_count"]),
format_hours(info["total_hours"]),
format_percentage(pct),
", ".join(info["members"]),
])
md_lines.append(build_markdown_table(project_headers, project_rows))
md_lines.append("")
# 项目工时饼图
if by_project:
md_lines.append("```mermaid")
md_lines.append("pie title 项目工时分布")
for proj, info in sorted(by_project.items(), key=lambda x: x[1]["total_hours"], reverse=True):
md_lines.append(f' "{proj}" : {info["total_hours"]:.1f}')
md_lines.append("```")
md_lines.append("")
# 洞察
md_lines.append("## 洞察与建议")
md_lines.append("")
insights = _generate_monthly_insights(by_member, by_project, weekly_data, total_hours)
for insight in insights:
md_lines.append(f"- {insight}")
md_lines.append("")
report_md = "\n".join(md_lines)
return {
"report_type": "monthly",
"period": f"{date_from} ~ {date_to}",
"summary": {
"member_count": len(by_member),
"task_count": total_tasks,
"total_hours": total_hours,
"project_count": len(by_project),
},
"report_markdown": report_md,
}
def _generate_monthly_insights(
by_member: Dict,
by_project: Dict,
weekly_data: Dict,
total_hours: float,
) -> List[str]:
"""生成月报洞察建议。"""
insights = []
if not by_member:
insights.append("本月暂无工作记录。")
return insights
# 工时趋势
if weekly_data:
weeks = sorted(weekly_data.keys())
hours_trend = [weekly_data[w]["hours"] for w in weeks]
if len(hours_trend) >= 2:
if hours_trend[-1] > hours_trend[0] * 1.2:
insights.append("本月工时呈上升趋势,团队工作量在增加。")
elif hours_trend[-1] < hours_trend[0] * 0.8:
insights.append("本月工时呈下降趋势,请关注团队产出。")
else:
insights.append("本月工时保持稳定。")
# 工时分布
member_hours = [v["total_hours"] for v in by_member.values()]
if member_hours:
max_h = max(member_hours)
min_h = min(member_hours)
if max_h > 0 and min_h / max_h < 0.3 and len(member_hours) > 1:
insights.append("成员工时差异较大,建议优化任务分配。")
# 项目集中度
if by_project and total_hours > 0:
sorted_projects = sorted(by_project.values(), key=lambda x: x["total_hours"], reverse=True)
top_pct = sorted_projects[0]["total_hours"] / total_hours
if top_pct > 0.6:
insights.append(
f"项目 **{sorted_projects[0]['project']}** 占据 "
f"{format_percentage(top_pct)} 工时,为本月核心项目。"
)
return insights
# ============================================================
# 自定义范围报告
# ============================================================
def generate_custom_report(data: Dict[str, Any]) -> Dict[str, Any]:
"""生成自定义日期范围的报告。
Args:
data: 包含以下字段的字典:
- date_from (str): 起始日期(必填)
- date_to (str): 结束日期(必填)
Returns:
包含 Markdown 报告和结构化数据的字典。
"""
date_from = data.get("date_from")
date_to = data.get("date_to")
if not date_from or not date_to:
raise ValueError("请提供起止日期(date_from 和 date_to 字段)")
logs = _load_worklogs()
range_logs = _filter_logs_by_range(logs, date_from, date_to)
team = _load_team()
team_name = team["name"] if team else "团队"
by_member = _aggregate_by_member(range_logs)
by_project = _aggregate_by_project(range_logs)
total_hours = sum(l.get("hours", 0) for l in range_logs)
md_lines = []
md_lines.append(f"# {team_name} 工作报告")
md_lines.append("")
md_lines.append(f"**报告周期**: {date_from} ~ {date_to}")
md_lines.append(f"**生成时间**: {now_str()}")
md_lines.append("")
md_lines.append("## 概览")
md_lines.append("")
overview_headers = ["指标", "数值"]
overview_rows = [
["参与人数", str(len(by_member))],
["任务总数", str(len(range_logs))],
["总工时", format_hours(total_hours)],
["涉及项目", str(len(by_project))],
]
md_lines.append(build_markdown_table(overview_headers, overview_rows))
md_lines.append("")
# 成员汇总
md_lines.append("## 成员工作汇总")
md_lines.append("")
member_headers = ["成员", "任务数", "工时", "涉及项目"]
member_rows = []
for name, info in sorted(by_member.items()):
member_rows.append([
name,
str(info["task_count"]),
format_hours(info["total_hours"]),
", ".join(info["projects"]) if info["projects"] else "-",
])
md_lines.append(build_markdown_table(member_headers, member_rows))
md_lines.append("")
report_md = "\n".join(md_lines)
return {
"report_type": "custom",
"period": f"{date_from} ~ {date_to}",
"summary": {
"member_count": len(by_member),
"task_count": len(range_logs),
"total_hours": total_hours,
"project_count": len(by_project),
},
"report_markdown": report_md,
}
# ============================================================
# 命令行入口
# ============================================================
def main() -> None:
"""命令行入口函数。"""
parser = create_base_parser("周报/月报生成工具")
parser.add_argument(
"--week",
default=None,
help="周标识(如 this, last, 2024-W03)",
)
parser.add_argument(
"--month",
default=None,
help="月标识(如 this, last, 2024-01)",
)
args, _ = parser.parse_known_args()
try:
action = args.action
if action == "weekly":
data = {}
try:
data = read_input_data(args)
except ValueError:
pass
if args.week:
data["week"] = args.week
if "week" not in data:
data["week"] = "this"
result = generate_weekly_report(data)
output_success(result)
elif action == "monthly":
data = {}
try:
data = read_input_data(args)
except ValueError:
pass
if args.month:
data["month"] = args.month
if "month" not in data:
data["month"] = "this"
result = generate_monthly_report(data)
output_success(result)
elif action == "custom":
data = read_input_data(args)
result = generate_custom_report(data)
output_success(result)
else:
output_error(f"未知操作: {action}", "UNKNOWN_ACTION")
except (ValueError, PermissionError) as e:
output_error(str(e), type(e).__name__.upper())
except Exception as e:
output_error(f"内部错误: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
main()
FILE:scripts/worklog_manager.py
#!/usr/bin/env python3
"""
team-weekly 工作日志管理模块
提供工作日志的增删查改功能,支持自然语言输入解析。
"""
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
create_base_parser,
ensure_data_dir,
generate_id,
get_worklog_file,
load_json_file,
now_str,
output_error,
output_success,
parse_date,
read_input_data,
save_json_file,
today_str,
)
from team_store import get_member_by_name, _load_team
# ============================================================
# 工作分类
# ============================================================
VALID_CATEGORIES = ["开发", "设计", "测试", "会议", "其他"]
# ============================================================
# 工作日志数据操作
# ============================================================
def _load_worklogs() -> List[Dict[str, Any]]:
"""加载所有工作日志。"""
data = load_json_file(get_worklog_file())
if data is None:
return []
if isinstance(data, list):
return data
return data.get("logs", [])
def _save_worklogs(logs: List[Dict[str, Any]]) -> None:
"""保存工作日志。"""
ensure_data_dir()
save_json_file(get_worklog_file(), logs)
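def parse_natural_input(text: str) -> Dict[str, Any]:
    """Parse a natural-language work log entry.

    Examples of supported input:
        - "张三今天完成了官网首页设计,耗时6小时"
        - "李四 完成用户模块开发 8h 项目:商城"
        - "王五 设计LOGO 3小时 设计"

    Args:
        text: Natural-language input text.

    Returns:
        Parsed dictionary with member_name, task_description, hours, project,
        category and date.
    """
    result: Dict[str, Any] = {
        "member_name": None,
        "task_description": None,
        "hours": None,
        "project": None,
        "category": "其他",
        "date": today_str(),
    }

    # Extract the member name from the start of the sentence. Prefer an exact
    # prefix match against the team roster: a greedy character-class match
    # swallows the text after the name when no space follows it
    # (e.g. "张三今天完成了..." has no delimiter after "张三").
    team = _load_team()
    roster = sorted(
        (m["name"] for m in (team or {}).get("members", [])),
        key=len,
        reverse=True,
    )
    for known_name in roster:
        if text.startswith(known_name):
            result["member_name"] = known_name
            break
    if result["member_name"] is None:
        # Heuristic fallback: a Latin name, or 2-4 CJK characters, followed by
        # whitespace, punctuation, a time word, a common verb, or a digit.
        name_match = re.match(
            r'^([a-zA-Z]+|[\u4e00-\u9fff]{2,4}?)'
            r'(?=\s|[,,::]|今天|昨天|完成|做了|设计|开发|测试|\d|$)',
            text,
        )
        if name_match:
            result["member_name"] = name_match.group(1)

    # Extract the hours
    hours_patterns = [
        r'(\d+(?:\.\d+)?)\s*(?:小时|hours?|hrs?|h)',
        r'耗时\s*(\d+(?:\.\d+)?)',
    ]
    for pattern in hours_patterns:
        hours_match = re.search(pattern, text, re.IGNORECASE)
        if hours_match:
            result["hours"] = float(hours_match.group(1))
            break

    # Extract the project name
    project_match = re.search(r'项目[::]\s*([^\s,,。]+)', text)
    if project_match:
        result["project"] = project_match.group(1)

    # Extract the category (first matching keyword wins)
    for cat in VALID_CATEGORIES:
        if cat in text:
            result["category"] = cat
            break

    # Extract the date
    date_match = re.search(r'(\d{4}-\d{2}-\d{2})', text)
    if date_match:
        result["date"] = date_match.group(1)
    elif "今天" in text:
        result["date"] = today_str()
    elif "昨天" in text:
        from datetime import timedelta
        yesterday = datetime.now() - timedelta(days=1)
        result["date"] = yesterday.strftime("%Y-%m-%d")

    # Extract the task description: strip the parts already parsed and keep the rest
    desc = text
    # Remove the member name
    if result["member_name"]:
        desc = desc.replace(result["member_name"], "", 1)
    # Remove the hours
    for pattern in hours_patterns:
        desc = re.sub(pattern, "", desc, flags=re.IGNORECASE)
    # Remove the project
    desc = re.sub(r'项目[::]\s*[^\s,,。]+', '', desc)
    # Remove dates and time words
    desc = re.sub(r'\d{4}-\d{2}-\d{2}', '', desc)
    desc = re.sub(r'(今天|昨天|耗时)', '', desc)
    # Clean up punctuation and leading verbs
    desc = re.sub(r'[,,。\s]+', ' ', desc).strip()
    desc = re.sub(r'^(完成了?|做了?)\s*', '', desc).strip()
    if desc:
        result["task_description"] = desc
    return result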
def add_worklog(data: Dict[str, Any]) -> Dict[str, Any]:
"""添加工作日志。
Args:
data: 包含以下字段的字典:
- member_name (str): 成员姓名(必填)
- task_description (str): 任务描述(必填)
- date (str, optional): 日期,默认今天
- hours (float, optional): 工时
- project (str, optional): 项目名称
- category (str, optional): 分类,默认"其他"
- natural_input (str, optional): 自然语言输入,优先解析
Returns:
创建的工作日志记录。
"""
# 如果有自然语言输入,先解析
natural_input = data.get("natural_input")
if natural_input:
parsed = parse_natural_input(natural_input)
# 解析结果作为默认值,显式传入的字段优先
for key, value in parsed.items():
if value is not None and key not in data:
data[key] = value
member_name = data.get("member_name")
task_description = data.get("task_description")
if not member_name:
raise ValueError("请提供成员姓名(member_name 字段)")
if not task_description:
raise ValueError("请提供任务描述(task_description 字段)")
# 验证成员是否存在
team = _load_team()
if team:
member = get_member_by_name(member_name)
if not member:
raise ValueError(
f"成员 {member_name!r} 不存在。"
f"请先使用 add-member 添加成员。"
)
member_id = member["id"]
else:
member_id = generate_id()
# 校验分类
category = data.get("category", "其他")
if category not in VALID_CATEGORIES:
category = "其他"
# 校验日期
date = data.get("date", today_str())
parse_date(date) # 验证格式
log_entry = {
"id": generate_id(),
"member_id": member_id,
"member_name": member_name,
"date": date,
"task_description": task_description,
"project": data.get("project", ""),
"hours": float(data.get("hours", 0)),
"category": category,
"created_at": now_str(),
}
logs = _load_worklogs()
logs.append(log_entry)
_save_worklogs(logs)
return log_entry
def list_worklogs(data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""查询工作日志列表。
Args:
data: 可选过滤条件字典:
- member_name (str, optional): 按成员筛选
- date (str, optional): 按日期筛选
- date_from (str, optional): 起始日期
- date_to (str, optional): 结束日期
- project (str, optional): 按项目筛选
- category (str, optional): 按分类筛选
Returns:
包含日志列表和统计信息的字典。
"""
logs = _load_worklogs()
if data:
# 按成员筛选
member_name = data.get("member_name")
if member_name:
logs = [l for l in logs if l["member_name"] == member_name]
# 按日期筛选
date = data.get("date")
if date:
logs = [l for l in logs if l["date"] == date]
# 按日期范围筛选
date_from = data.get("date_from")
date_to = data.get("date_to")
if date_from:
logs = [l for l in logs if l["date"] >= date_from]
if date_to:
logs = [l for l in logs if l["date"] <= date_to]
# 按项目筛选
project = data.get("project")
if project:
logs = [l for l in logs if l.get("project") == project]
# 按分类筛选
category = data.get("category")
if category:
logs = [l for l in logs if l.get("category") == category]
# 按日期倒序排列
logs.sort(key=lambda x: x["date"], reverse=True)
total_hours = sum(l.get("hours", 0) for l in logs)
return {
"total_count": len(logs),
"total_hours": total_hours,
"logs": logs,
}
def query_worklogs(data: Dict[str, Any]) -> Dict[str, Any]:
    """Run an advanced work log query with optional week/month filtering and grouping.

    Args:
        data: Query conditions:
            - week (str, optional): week identifier, e.g. "2024-W03", "this" or "last"
            - month (str, optional): month identifier, e.g. "2024-01", "this" or "last"
            - member_name (str, optional): filter by member
            - project (str, optional): filter by project
            - group_by (str, optional): grouping key (member/project/category/date)

    Returns:
        Query result with aggregate statistics.
    """
    from datetime import timedelta
    import calendar

    logs = _load_worklogs()
    now = datetime.now()

    # Resolve the week filter
    week = data.get("week")
    if week:
        if week == "this":
            monday = now - timedelta(days=now.weekday())
            sunday = monday + timedelta(days=6)
        elif week == "last":
            monday = now - timedelta(days=now.weekday() + 7)
            sunday = monday + timedelta(days=6)
        else:
            # Parse the YYYY-WNN format
            try:
                parts = week.split("-W")
                year = int(parts[0])
                week_num = int(parts[1])
                # Monday of the ISO 8601 week. Week 1 is the week containing
                # January 4, so deriving it from January 1 is off by one week
                # in some years. datetime.fromisocalendar needs Python 3.8+.
                monday = datetime.fromisocalendar(year, week_num, 1)
                sunday = monday + timedelta(days=6)
            except (ValueError, IndexError):
                raise ValueError(f"无效的周标识: {week!r},请使用 YYYY-WNN 格式")
        date_from = monday.strftime("%Y-%m-%d")
        date_to = sunday.strftime("%Y-%m-%d")
        logs = [log for log in logs if date_from <= log["date"] <= date_to]

    # Resolve the month filter
    month = data.get("month")
    if month:
        if month == "this":
            year, mon = now.year, now.month
        elif month == "last":
            if now.month == 1:
                year, mon = now.year - 1, 12
            else:
                year, mon = now.year, now.month - 1
        else:
            try:
                parts = month.split("-")
                year, mon = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                raise ValueError(f"无效的月标识: {month!r},请使用 YYYY-MM 格式")
        _, last_day = calendar.monthrange(year, mon)
        date_from = f"{year:04d}-{mon:02d}-01"
        date_to = f"{year:04d}-{mon:02d}-{last_day:02d}"
        logs = [log for log in logs if date_from <= log["date"] <= date_to]

    # Filter by member
    member_name = data.get("member_name")
    if member_name:
        logs = [log for log in logs if log["member_name"] == member_name]

    # Filter by project
    project = data.get("project")
    if project:
        logs = [log for log in logs if log.get("project") == project]

    # Group and aggregate
    group_by = data.get("group_by")
    groups: Dict[str, Any] = {}
    if group_by:
        for log in logs:
            if group_by == "member":
                key = log["member_name"]
            elif group_by == "project":
                # add_worklog stores "" when no project is given, so fall back
                # on any falsy value rather than only on a missing key.
                key = log.get("project") or "未分类"
            elif group_by == "category":
                key = log.get("category", "其他")
            elif group_by == "date":
                key = log["date"]
            else:
                key = "全部"
            if key not in groups:
                groups[key] = {"count": 0, "hours": 0, "logs": []}
            groups[key]["count"] += 1
            groups[key]["hours"] += log.get("hours", 0)
            groups[key]["logs"].append(log)

    total_hours = sum(log.get("hours", 0) for log in logs)
    result = {
        "total_count": len(logs),
        "total_hours": total_hours,
        "logs": logs,
    }
    if groups:
        # Return group summaries only (without the per-group logs) to keep the payload small
        result["groups"] = {
            k: {"count": v["count"], "hours": v["hours"]}
            for k, v in groups.items()
        }
    return result
def delete_worklog(data: Dict[str, Any]) -> Dict[str, Any]:
"""删除工作日志。
Args:
data: 包含以下字段的字典:
- id (str): 日志 ID(必填)
Returns:
被删除的日志记录。
"""
log_id = data.get("id")
if not log_id:
raise ValueError("请提供日志 ID(id 字段)")
logs = _load_worklogs()
removed = None
new_logs = []
for log in logs:
if log["id"] == log_id:
removed = log
else:
new_logs.append(log)
if not removed:
raise ValueError(f"未找到日志: {log_id}")
_save_worklogs(new_logs)
return removed
# ============================================================
# 命令行入口
# ============================================================
def main() -> None:
"""命令行入口函数。"""
parser = create_base_parser("工作日志管理工具")
args, _ = parser.parse_known_args()
try:
action = args.action
if action == "add":
data = read_input_data(args)
result = add_worklog(data)
output_success(result)
elif action == "list":
try:
data = read_input_data(args)
except ValueError:
data = None
result = list_worklogs(data)
output_success(result)
elif action == "query":
data = read_input_data(args)
result = query_worklogs(data)
output_success(result)
elif action == "delete":
data = read_input_data(args)
result = delete_worklog(data)
output_success(result)
else:
output_error(f"未知操作: {action}", "UNKNOWN_ACTION")
except (ValueError, PermissionError) as e:
output_error(str(e), type(e).__name__.upper())
except Exception as e:
output_error(f"内部错误: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
main()
FILE:scripts/team_store.py
#!/usr/bin/env python3
"""
team-weekly 团队与成员管理模块
提供团队创建、成员增删改查等功能,数据以 JSON 格式存储在本地。
"""
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
create_base_parser,
ensure_data_dir,
generate_id,
get_team_file,
load_json_file,
now_str,
output_error,
output_success,
read_input_data,
save_json_file,
)
# ============================================================
# 团队数据操作
# ============================================================
def _load_team() -> Optional[Dict[str, Any]]:
"""加载团队数据。"""
return load_json_file(get_team_file())
def _save_team(team: Dict[str, Any]) -> None:
"""保存团队数据。"""
save_json_file(get_team_file(), team)
def init_team(data: Dict[str, Any]) -> Dict[str, Any]:
"""初始化团队。
Args:
data: 包含以下字段的字典:
- name (str): 团队名称(必填)
- members (list, optional): 初始成员列表
Returns:
创建后的团队数据。
Raises:
ValueError: 缺少必填字段时抛出。
"""
name = data.get("name")
if not name:
raise ValueError("请提供团队名称(name 字段)")
ensure_data_dir()
# 检查是否已存在团队
existing = _load_team()
if existing:
raise ValueError(
f"团队已存在: {existing['name']}。"
f"如需重新创建,请先删除现有团队数据。"
)
sub = check_subscription()
team = {
"id": generate_id(),
"name": name,
"members": [],
"created_at": now_str(),
"updated_at": now_str(),
"subscription_tier": sub["tier"],
}
# 如果提供了初始成员列表,逐个添加
initial_members = data.get("members", [])
for member_data in initial_members:
if isinstance(member_data, str):
member_data = {"name": member_data}
_add_member_to_team(team, member_data, sub)
_save_team(team)
return team
def _add_member_to_team(
team: Dict[str, Any],
member_data: Dict[str, Any],
sub: Dict[str, Any],
) -> Dict[str, Any]:
"""向团队添加成员(内部方法)。
Args:
team: 团队数据。
member_data: 成员数据,包含 name, role, projects 等。
sub: 订阅信息。
Returns:
新添加的成员数据。
Raises:
ValueError: 超出人数限制或缺少必填字段时抛出。
"""
max_members = sub["max_members"]
current_count = len(team["members"])
if current_count >= max_members:
raise ValueError(
f"团队成员已达上限({max_members}人)。"
f"当前订阅等级为{sub['tier']},"
+ (
"如需更多成员请升级至付费版(¥69/月)。"
if sub["tier"] == "free"
else "已达付费版最大成员数。"
)
)
name = member_data.get("name")
if not name:
raise ValueError("请提供成员姓名(name 字段)")
# 检查重名
for m in team["members"]:
if m["name"] == name:
raise ValueError(f"成员 {name!r} 已存在")
member = {
"id": generate_id(),
"name": name,
"role": member_data.get("role", "成员"),
"projects": member_data.get("projects", []),
"created_at": now_str(),
}
team["members"].append(member)
team["updated_at"] = now_str()
return member
def add_member(data: Dict[str, Any]) -> Dict[str, Any]:
"""添加团队成员。
Args:
data: 包含以下字段的字典:
- name (str): 成员姓名(必填)
- role (str, optional): 角色,默认"成员"
- projects (list, optional): 参与的项目列表
Returns:
新添加的成员数据。
"""
team = _load_team()
if not team:
raise ValueError("团队尚未初始化,请先执行 init 操作")
sub = check_subscription()
member = _add_member_to_team(team, data, sub)
_save_team(team)
return member
def remove_member(data: Dict[str, Any]) -> Dict[str, Any]:
"""移除团队成员。
Args:
data: 包含以下字段的字典:
- name (str) 或 id (str): 成员姓名或 ID(二选一)
Returns:
被移除的成员数据。
"""
team = _load_team()
if not team:
raise ValueError("团队尚未初始化,请先执行 init 操作")
member_name = data.get("name")
member_id = data.get("id")
if not member_name and not member_id:
raise ValueError("请提供成员姓名(name)或 ID(id)")
removed = None
new_members = []
for m in team["members"]:
if (member_name and m["name"] == member_name) or \
(member_id and m["id"] == member_id):
removed = m
else:
new_members.append(m)
if not removed:
identifier = member_name or member_id
raise ValueError(f"未找到成员: {identifier}")
team["members"] = new_members
team["updated_at"] = now_str()
_save_team(team)
return removed
def list_members() -> Dict[str, Any]:
"""列出所有团队成员。
Returns:
包含团队信息和成员列表的字典。
"""
team = _load_team()
if not team:
raise ValueError("团队尚未初始化,请先执行 init 操作")
sub = check_subscription()
return {
"team_name": team["name"],
"team_id": team["id"],
"subscription_tier": sub["tier"],
"max_members": sub["max_members"],
"current_count": len(team["members"]),
"members": team["members"],
}
def get_team() -> Dict[str, Any]:
"""获取完整的团队信息。
Returns:
团队完整数据。
"""
team = _load_team()
if not team:
raise ValueError("团队尚未初始化,请先执行 init 操作")
return team
def get_member_by_name(name: str) -> Optional[Dict[str, Any]]:
"""通过姓名查找成员。
Args:
name: 成员姓名。
Returns:
成员数据,未找到返回 None。
"""
team = _load_team()
if not team:
return None
for m in team["members"]:
if m["name"] == name:
return m
return None
# ============================================================
# 命令行入口
# ============================================================
def main() -> None:
    """Command-line entry point."""
    parser = create_base_parser("团队与成员管理工具")
    args, _ = parser.parse_known_args()
    try:
        action = args.action
        if action == "init":
            data = read_input_data(args)
            result = init_team(data)
            output_success(result)
        elif action == "add-member":
            data = read_input_data(args)
            result = add_member(data)
            output_success(result)
        elif action == "remove-member":
            data = read_input_data(args)
            result = remove_member(data)
            output_success(result)
        elif action == "list":
            result = list_members()
            output_success(result)
        elif action == "get-team":
            result = get_team()
            output_success(result)
        else:
            output_error(f"未知操作: {action}", "UNKNOWN_ACTION")
    except (ValueError, PermissionError) as e:
        output_error(str(e), type(e).__name__.upper())
    except Exception as e:
        output_error(f"内部错误: {e}", "INTERNAL_ERROR")


if __name__ == "__main__":
    # Add the script directory to sys.path so sibling modules can be imported
    import os
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    main()
FILE:references/weekly-templates.md
# 周报模板参考
## 模板一:基础周报(免费版)
```markdown
# [团队名称] 周报
**报告周期**: YYYY-MM-DD ~ YYYY-MM-DD
**生成时间**: YYYY-MM-DD HH:MM:SS
## 概览
| 指标 | 数值 |
|------|------|
| 参与人数 | X |
| 任务总数 | X |
| 总工时 | Xh |
| 涉及项目 | X |
## 成员工作汇总
| 成员 | 任务数 | 工时 | 涉及项目 |
|------|--------|------|----------|
| 张三 | 5 | 32h | 官网改版, 管理后台 |
| 李四 | 4 | 28h | 用户系统 |
| 王五 | 3 | 24h | 官网改版 |
## 项目进展
| 项目 | 任务数 | 工时 | 参与成员 |
|------|--------|------|----------|
| 官网改版 | 5 | 38h | 张三, 王五 |
| 用户系统 | 4 | 28h | 李四 |
| 管理后台 | 3 | 18h | 张三 |
## 详细工作记录
### 张三
| 日期 | 任务 | 项目 | 分类 | 工时 |
|------|------|------|------|------|
| 01-15 | 首页设计稿完成 | 官网改版 | 设计 | 6h |
| 01-16 | 首页前端开发 | 官网改版 | 开发 | 8h |
| 01-17 | 后台列表页开发 | 管理后台 | 开发 | 8h |
| 01-18 | 后台搜索功能 | 管理后台 | 开发 | 6h |
| 01-19 | Code Review | 管理后台 | 其他 | 4h |
### 李四
...
```
---
## 模板二:增强周报(付费版)
````markdown
# [团队名称] 周报
**报告周期**: YYYY-MM-DD ~ YYYY-MM-DD
**生成时间**: YYYY-MM-DD HH:MM:SS
## 概览
| 指标 | 数值 |
|------|------|
| 参与人数 | 5 |
| 任务总数 | 23 |
| 总工时 | 152h |
| 涉及项目 | 4 |
## 成员工作汇总
| 成员 | 任务数 | 工时 | 占比 | 涉及项目 |
|------|--------|------|------|----------|
| 张三 | 7 | 38h | 25.0% | 官网改版, 管理后台 |
| 李四 | 6 | 36h | 23.7% | 用户系统, 支付模块 |
| 王五 | 4 | 28h | 18.4% | 官网改版 |
| 赵六 | 3 | 26h | 17.1% | 用户系统 |
| 钱七 | 3 | 24h | 15.8% | 管理后台 |
## 项目进展
| 项目 | 任务数 | 工时 | 占比 | 参与成员 |
|------|--------|------|------|----------|
| 官网改版 | 8 | 52h | 34.2% | 张三, 王五 |
| 用户系统 | 7 | 48h | 31.6% | 李四, 赵六 |
| 管理后台 | 5 | 32h | 21.1% | 张三, 钱七 |
| 支付模块 | 3 | 20h | 13.2% | 李四 |
## 工时分布图
### 项目工时分布
```mermaid
pie title 项目工时分布
"官网改版" : 52
"用户系统" : 48
"管理后台" : 32
"支付模块" : 20
```
### 成员工时对比
```mermaid
xychart-beta
title "成员工时对比"
x-axis ["张三", "李四", "王五", "赵六", "钱七"]
y-axis "工时(小时)"
bar [38, 36, 28, 26, 24]
```
## 详细工作记录
### 张三
...
## 洞察与建议
- 本周工时最多的成员是 **张三**,共 38h,完成 7 项任务。
- 项目 **官网改版** 占据 34.2% 工时,为本周核心项目。
- 人均工时 30.4h,工作强度适中。
- **赵六** 和 **钱七** 工时偏低,建议关注任务分配均衡性。
````
---
## 模板三:月度汇总报告(付费版)
````markdown
# [团队名称] 月报
**报告月份**: YYYY年M月
**报告周期**: YYYY-MM-DD ~ YYYY-MM-DD
**生成时间**: YYYY-MM-DD HH:MM:SS
## 执行摘要
本月团队共 **5** 人参与工作,完成 **89** 项任务,累计工时 **612h**,涉及 **4** 个项目。
## 核心指标
| 指标 | 数值 |
|------|------|
| 参与人数 | 5 |
| 任务总数 | 89 |
| 总工时 | 612h |
| 人均工时 | 122.4h |
| 涉及项目 | 4 |
| 覆盖周数 | 4 |
## 周度趋势
| 周 | 日期范围 | 任务数 | 工时 |
|----|----------|--------|------|
| W1 | 01-01~01-07 | 18 | 135h |
| W2 | 01-08~01-14 | 22 | 148h |
| W3 | 01-15~01-21 | 25 | 168h |
| W4 | 01-22~01-28 | 24 | 161h |
```mermaid
xychart-beta
title "月度周工时趋势"
x-axis ["W1", "W2", "W3", "W4"]
y-axis "工时(小时)"
bar [135, 148, 168, 161]
```
## 成员工作汇总
| 成员 | 任务数 | 工时 | 占比 | 涉及项目 |
|------|--------|------|------|----------|
| 张三 | 25 | 148h | 24.2% | 官网改版, 管理后台 |
| 李四 | 22 | 140h | 22.9% | 用户系统, 支付模块 |
| ... | ... | ... | ... | ... |
## 项目工时分布
```mermaid
pie title 项目工时分布
"官网改版" : 198
"用户系统" : 182
"管理后台" : 132
"支付模块" : 100
```
## 洞察与建议
- 本月工时呈上升趋势,团队工作量在增加。
- 项目 **官网改版** 占据 32.4% 工时,为本月核心项目。
- 成员工时差异较大,建议优化任务分配。
````
---
## 模板四:个人绩效报告(付费版)
````markdown
# [成员姓名] 绩效分析报告
**分析周期**: 近 4 周
**生成时间**: YYYY-MM-DD HH:MM:SS
## 个人概览
- 总工时: 148h
- 任务数: 25
- 工作天数: 20
- 日均工时: 7.4h
- 涉及项目: 官网改版, 管理后台
## 周度趋势
```mermaid
xychart-beta
title "张三 工时趋势(近4周)"
x-axis ["W1", "W2", "W3", "W4"]
y-axis "工时(小时)"
bar [32, 38, 42, 36]
line [32, 38, 42, 36]
```
## 工作分类分布
```mermaid
pie title 工作分类分布
"开发" : 96
"设计" : 24
"会议" : 16
"测试" : 8
"其他" : 4
```
## 效率评估
- 工时保持**稳定**。
- 周均工时: **37h**
- 开发占比最高 (64.9%),为团队核心开发力量。
````
---
## Mermaid 图表示例
### 饼图 — 工时分布
```mermaid
pie title 工时分布
"开发" : 65
"设计" : 15
"测试" : 10
"会议" : 7
"其他" : 3
```
### 柱状图 — 成员对比
```mermaid
xychart-beta
title "成员工时对比"
x-axis ["张三", "李四", "王五", "赵六", "钱七"]
y-axis "工时(小时)"
bar [38, 36, 28, 26, 24]
```
### 甘特图 — 项目进度
```mermaid
gantt
title 项目进度甘特图
dateFormat YYYY-MM-DD
section 官网改版
张三(38h) : 2024-01-08, 2024-01-19
王五(28h) : 2024-01-08, 2024-01-19
section 用户系统
李四(36h) : 2024-01-08, 2024-01-19
赵六(26h) : 2024-01-10, 2024-01-19
section 管理后台
张三(12h) : 2024-01-15, 2024-01-19
钱七(24h) : 2024-01-08, 2024-01-19
```
### 趋势折线图
```mermaid
xychart-beta
title "工时趋势(近8周)"
x-axis ["W1", "W2", "W3", "W4", "W5", "W6", "W7", "W8"]
y-axis "工时(小时)"
bar [120, 135, 128, 142, 150, 148, 155, 160]
line [120, 135, 128, 142, 150, 148, 155, 160]
```
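The pie charts in these templates all share one shape, so they can be produced from aggregated hours rather than written by hand. A minimal sketch, assuming a hypothetical `render_pie` helper (it is not one of the skill's scripts) that receives a label-to-hours mapping:

```python
from typing import Mapping


def render_pie(title: str, hours: Mapping[str, float]) -> str:
    """Render a Mermaid pie chart from a label -> hours mapping (hypothetical helper)."""
    lines = [f"pie title {title}"]
    for label, value in hours.items():
        # Mermaid pie labels are double-quoted, so embedded quotes are dropped.
        safe_label = str(label).replace('"', "")
        # The "g" format prints 52 as "52" and 7.5 as "7.5".
        lines.append(f'    "{safe_label}" : {value:g}')
    return "\n".join(lines)


print(render_pie("项目工时分布", {"官网改版": 52, "用户系统": 48}))
```

Wrapping the returned string in a `mermaid` fence gives a chart in the same form as the examples above.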
项目中枢 — 跨平台项目管理聚合器,统一管理 Trello、GitHub Issues、Linear、Notion、Obsidian 任务,支持自学习引擎和任务关系图谱
---
name: project-nerve
description: 项目中枢 — 跨平台项目管理聚合器,统一管理 Trello、GitHub Issues、Linear、Notion、Obsidian 任务,支持自学习引擎和任务关系图谱
version: 1.1.0
metadata:
openclaw:
optional_env:
- PNC_TRELLO_API_KEY
- PNC_TRELLO_TOKEN
- PNC_GITHUB_TOKEN
- PNC_LINEAR_API_KEY
- PNC_NOTION_TOKEN
- PNC_NOTION_DATABASE_ID
- PNC_OBSIDIAN_VAULT_PATH
- PNC_SUBSCRIPTION_TIER
---
# 项目中枢(project-nerve)
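You are a professional cross-platform project management Agent. Your job is to help the user manage, in one place, the tasks scattered across Trello, GitHub Issues, Linear, Notion, and Obsidian, and to provide an aggregated view, smart analysis, a self-learning engine, a task relationship graph, and automated standup reports. Always communicate with the user in Chinese.
## Environment variables
| Variable | Required | Description |
|------|------|------|
| `PNC_TRELLO_API_KEY` | No | Trello API key (needed to connect Trello) |
| `PNC_TRELLO_TOKEN` | No | Trello user token |
| `PNC_GITHUB_TOKEN` | No | GitHub personal access token |
| `PNC_LINEAR_API_KEY` | No | Linear API key |
| `PNC_NOTION_TOKEN` | No | Notion integration token |
| `PNC_NOTION_DATABASE_ID` | No | ID of the target Notion database |
| `PNC_OBSIDIAN_VAULT_PATH` | No | Local path of the Obsidian vault (needed to connect Obsidian) |
| `PNC_SUBSCRIPTION_TIER` | No | Subscription tier, default `free`, optionally `paid` |
On startup, check whether credentials for at least one platform are configured. If no platform credentials are set, guide the user into the data source configuration flow.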
---
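## Flow 1: Data source configuration
When the user says "连接平台" (connect a platform), "添加数据源" (add a data source), "配置 Trello/GitHub/Linear/Notion" (configure Trello/GitHub/Linear/Notion), or expresses a similar intent, follow these steps:
### Step 1: Choose a platform
Show the user the supported platforms and guide the choice:
```
请选择要连接的项目管理平台:
1. Trello — 看板式任务管理
2. GitHub Issues — 代码仓库问题跟踪
3. Linear — 现代研发项目管理
4. Notion — 知识库与任务管理
5. Obsidian — 本地 Vault 笔记中的任务
```
### Step 2: Collect credentials
Guide the user to set the environment variables for the chosen platform. **Never ask the user to type a token or password into the conversation**; instruct them to set environment variables instead:
- **Trello**: set `PNC_TRELLO_API_KEY` and `PNC_TRELLO_TOKEN`; a board_id is optional
- **GitHub**: set `PNC_GITHUB_TOKEN` and specify the repository (owner/repo format)
- **Linear**: set `PNC_LINEAR_API_KEY`; a team_id is optional
- **Notion**: set `PNC_NOTION_TOKEN` and `PNC_NOTION_DATABASE_ID`
- **Obsidian**: set `PNC_OBSIDIAN_VAULT_PATH` (local vault path); a task_tag is optional (default #task)
### Step 3: Test the connection
```bash
python3 scripts/source_connector.py --action test --data '{"platform":"<platform>"}'
```
On success, show the returned user information; on failure, help the user troubleshoot.
### Step 4: Save the connection
```bash
python3 scripts/source_connector.py --action connect --data '{"platform":"<platform>","name":"<name>",...}'
```
### Step 5: List connected data sources
```bash
python3 scripts/source_connector.py --action list-sources
```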
---
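## Flow 2: Task aggregation and queries
When the user says "查看任务" (view tasks), "同步任务" (sync tasks), "搜索任务" (search tasks), or expresses a similar intent:
### Step 1: Sync tasks
```bash
python3 scripts/task_aggregator.py --action fetch-all
```
This fetches the latest tasks from every connected platform, normalizes them into a unified format, and caches them.
### Step 2: Present the results
Show the returned tasks to the user as a Markdown table, together with status statistics and the per-platform distribution.
### Step 3: Search and filter
Run a search that matches the user's request:
```bash
python3 scripts/task_aggregator.py --action search --data '{"keyword":"关键词","status":"进行中"}'
```
### Step 4: Blocker analysis (paid feature)
```bash
python3 scripts/task_aggregator.py --action blockers
```
### Step 5: Priority ranking
```bash
python3 scripts/task_aggregator.py --action priorities
```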
---
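## Flow 3: Task creation and management
When the user says "创建任务" (create a task), "新建 Issue" (open an issue), "添加卡片" (add a card), or expresses a similar intent:
### Step 1: Collect task details
Ask the user for:
- Title (required)
- Description (optional)
- Priority (optional, inferred automatically by default)
- Platform (optional, detected automatically by default)
- Due date (optional)
### Step 2: Detect the platform automatically
If the user does not specify a platform, recommend one based on the task content:
- Code, bug, or PR related → GitHub
- Notes, knowledge, or memo related → Obsidian
- Docs, design, or database related → Notion
- Sprint or story related → Linear
- Anything else → Trello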
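
The routing rules above can be sketched as an ordered keyword match. This is an illustration only: the `PLATFORM_RULES` table, the keyword lists, and `detect_platform` are assumptions, not the logic in `scripts/task_writer.py`. ASCII keywords are matched as whole words so that "pr" does not fire inside "sprint".

```python
import re

# Hypothetical keyword table; the order mirrors the list above, first match wins.
PLATFORM_RULES = [
    ("github", ("代码", "bug", "pr")),
    ("obsidian", ("笔记", "知识", "备忘")),
    ("notion", ("文档", "设计", "数据库")),
    ("linear", ("sprint", "story")),
]


def _matches(keyword: str, text: str) -> bool:
    """ASCII keywords must not touch other ASCII letters; CJK keywords match as substrings."""
    if keyword.isascii():
        pattern = rf"(?<![a-z]){re.escape(keyword)}(?![a-z])"
        return re.search(pattern, text) is not None
    return keyword in text


def detect_platform(title: str, description: str = "") -> str:
    """Return the recommended platform, falling back to Trello."""
    text = f"{title} {description}".lower()
    for platform, keywords in PLATFORM_RULES:
        if any(_matches(keyword, text) for keyword in keywords):
            return platform
    return "trello"


print(detect_platform("修复登录Bug"))
```

Keeping the rules in a plain ordered list makes it easy to adjust them when users repeatedly override the recommendation.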
### Step 3: Create the task
```bash
python3 scripts/task_writer.py --action create --data '{"title":"...","platform":"..."}'
```
### Step 4: Update, move, or comment
```bash
python3 scripts/task_writer.py --action update --data '{"source":"github","source_id":"123","status":"已完成"}'
python3 scripts/task_writer.py --action comment --data '{"source":"github","source_id":"123","comment":"已修复"}'
```
---
## Flow 4: Sprint analysis and standup reports
### Daily standup
```bash
python3 scripts/standup_generator.py --action daily
```
Produces the standard format: done yesterday / planned today / blockers.
### Weekly summary (paid feature)
```bash
python3 scripts/standup_generator.py --action weekly
```
### Sprint analysis (paid feature)
```bash
python3 scripts/sprint_analyzer.py --action velocity --data '{"days":14}'
python3 scripts/sprint_analyzer.py --action funnel
python3 scripts/sprint_analyzer.py --action burndown --data '{"days":14}'
python3 scripts/sprint_analyzer.py --action report --data '{"days":14}'
```
---
## Flow 5: Self-learning engine
When the user says "学习统计" (learning stats), "改进建议" (improvement suggestions), "查看学习数据" (view learning data), or when the system hits an error or a success during an operation:
### Record an error
```bash
python3 scripts/learning_engine.py --action record-error --data '{"category":"api_failure","source":"github","error_type":"timeout","message":"请求超时"}'
```
### Record a success
```bash
python3 scripts/learning_engine.py --action record-success --data '{"category":"fetch_success","source":"trello","action":"fetch"}'
```
### Record a user correction
```bash
python3 scripts/learning_engine.py --action record-correction --data '{"category":"platform_override","field":"platform","original_value":"trello","corrected_value":"notion"}'
```
### Get improvement suggestions
```bash
python3 scripts/learning_engine.py --action suggest
```
Generates actionable improvement suggestions from the accumulated learning data (for example, switching away from an unstable platform or adjusting the auto-detection strategy).
### View statistics
```bash
python3 scripts/learning_engine.py --action stats
```
### Reset learning data
```bash
python3 scripts/learning_engine.py --action reset --data '{"confirm":true}'
```
---
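## Flow 6: Task relationship graph
When the user says "添加依赖" (add a dependency), "任务关系" (task relations), "影响分析" (impact analysis), or expresses a similar intent:
Per the feature permission matrix in the subscription section, every task graph operation (add, query, dependencies, impact, visualization) requires the paid tier.
### Add a relation
```bash
python3 scripts/task_graph.py --action add-relation --data '{"from_id":"github-123","to_id":"trello-abc","type":"blocks","from_source":"github","to_source":"trello"}'
```
Supported relation types: blocks, blocked_by, related_to, parent_of, child_of, duplicates
### Query related tasks
```bash
python3 scripts/task_graph.py --action query --data '{"task_id":"github-123","max_depth":3}'
```
### Dependency analysis
```bash
python3 scripts/task_graph.py --action dependencies --data '{"task_id":"github-123"}'
```
Builds the dependency tree and detects circular dependencies automatically.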
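
Cycle detection over `blocks` relations is commonly done with a depth-first search that tracks the nodes on the current path. The sketch below shows that general technique and is not the code in `scripts/task_graph.py`; the adjacency-dict input and the name `find_cycle` are assumptions.

```python
from typing import Dict, List, Optional


def find_cycle(blocks: Dict[str, List[str]]) -> Optional[List[str]]:
    """Return one dependency cycle as a list of task IDs, or None if there is none."""
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on the current path / fully explored
    color: Dict[str, int] = {}
    path: List[str] = []

    def visit(node: str) -> Optional[List[str]]:
        color[node] = GRAY
        path.append(node)
        for nxt in blocks.get(node, []):
            state = color.get(nxt, WHITE)
            if state == GRAY:
                # nxt is already on the current path, so the path loops back to it.
                return path[path.index(nxt):] + [nxt]
            if state == WHITE:
                found = visit(nxt)
                if found:
                    return found
        path.pop()
        color[node] = BLACK
        return None

    for start in list(blocks):
        if color.get(start, WHITE) == WHITE:
            found = visit(start)
            if found:
                return found
    return None


print(find_cycle({"github-123": ["trello-abc"], "trello-abc": ["github-123"]}))
```

Returning the offending path, rather than just a boolean, lets the Agent tell the user which relation to remove.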
### Impact analysis
```bash
python3 scripts/task_graph.py --action impact --data '{"task_id":"github-123"}'
```
Estimates how many downstream tasks would be affected if a given task were blocked.
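
One way to compute that downstream set is a breadth-first traversal along `blocks` edges. This is a minimal sketch under the assumption that relations are available as an adjacency dict; `downstream_tasks` is a hypothetical name, not a function from the skill.

```python
from collections import deque
from typing import Dict, List, Set


def downstream_tasks(blocks: Dict[str, List[str]], task_id: str) -> Set[str]:
    """Collect every task reachable from task_id by following blocks edges."""
    seen: Set[str] = set()
    queue = deque([task_id])
    while queue:
        current = queue.popleft()
        for nxt in blocks.get(current, []):
            # Skip the starting task itself so a cycle does not count it as impacted.
            if nxt not in seen and nxt != task_id:
                seen.add(nxt)
                queue.append(nxt)
    return seen


graph = {"github-123": ["trello-abc", "linear-9"], "trello-abc": ["notion-x"]}
print(len(downstream_tasks(graph, "github-123")))
```

The size of the returned set is the impact count, and the set itself can be listed for the user.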
### Visualization (paid feature)
```bash
python3 scripts/task_graph.py --action visualize --data '{"task_id":"github-123"}'
```
Generates a Mermaid flowchart of the task relationship network.
---
## Flow 7: Obsidian data source
When the user says "连接 Obsidian" (connect Obsidian), "同步笔记任务" (sync note tasks), or expresses a similar intent:
### Connect an Obsidian vault
```bash
python3 scripts/source_connector.py --action connect --data '{"platform":"obsidian","vault_path":"/path/to/vault","task_tag":"#task"}'
```
### Obsidian task format
Use Markdown checkbox syntax in Obsidian notes:
```markdown
---
status: 进行中
priority: 高
assignee: zhangsan
due_date: 2026-03-25
---
# 项目规划
- [ ] 完成需求分析 #task
- [x] 编写技术方案 #task
- [ ] 前端原型设计 📅 2026-03-20
```
The system extracts status, priority, assignee, and due_date from the frontmatter, plus the checkbox tasks in the note body.
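
A rough sketch of that extraction, assuming simple `key: value` frontmatter lines rather than full YAML. The function name `parse_note` and the returned shape are illustrative and do not reflect the skill's actual parser.

```python
import re
from typing import Any, Dict, List

# Matches "- [ ] title" and "- [x] title" (also "*" bullets).
CHECKBOX = re.compile(r"^\s*[-*] \[( |x|X)\] (.+)$")


def parse_note(text: str) -> Dict[str, Any]:
    """Split a note into frontmatter fields and checkbox tasks."""
    lines = text.splitlines()
    meta: Dict[str, str] = {}
    body_start = 0
    if lines and lines[0].strip() == "---":
        for i, line in enumerate(lines[1:], start=1):
            if line.strip() == "---":
                body_start = i + 1
                break
            key, sep, value = line.partition(":")
            if sep:
                meta[key.strip()] = value.strip()
        else:
            # No closing delimiter: treat the whole note as body.
            meta = {}
    tasks: List[Dict[str, Any]] = []
    for line in lines[body_start:]:
        match = CHECKBOX.match(line)
        if match:
            tasks.append({
                "title": match.group(2).strip(),
                "done": match.group(1).lower() == "x",
            })
    return {"meta": meta, "tasks": tasks}


note = "---\nstatus: 进行中\nassignee: zhangsan\n---\n- [ ] 完成需求分析 #task\n- [x] 编写技术方案 #task\n"
print(parse_note(note)["meta"])
```

Frontmatter values apply to the note as a whole, so a caller would typically copy them onto each extracted task as defaults.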
---
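## Subscription checks
### Reading the subscription tier
```
tier = env PNC_SUBSCRIPTION_TIER, default "free"
```
### Feature permission matrix
| Feature | Free (free) | Paid (paid, ¥99/month) |
|------|---------------|----------------------|
| Number of data sources | Up to 2 | Up to 10 |
| Tasks displayed | 50 | 500 |
| Basic queries / task list | Supported | Supported |
| Platform connections (including Obsidian) | Supported | Supported |
| Self-learning engine (error/success recording) | Supported | Supported |
| Self-learning engine (advanced suggestions / preference analysis) | Not supported | Supported |
| Task relationship graph (add/query/dependencies/impact) | Not supported | Supported |
| Task relationship graph (Mermaid visualization) | Not supported | Supported |
| Sprint analysis (velocity/funnel/burndown) | Not supported | Supported |
| Standup report (daily) | Supported | Supported |
| Standup report (weekly) | Not supported | Supported |
| Blocker analysis | Not supported | Supported |
| Mermaid charts | Not supported | Supported |
| Bulk sync | Not supported | Supported |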
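
The matrix can be expressed as a small lookup. In this sketch the feature keys in `PAID_ONLY`, the helper names, and the fallback to `free` for unrecognized values are all assumptions; the skill's own `check_subscription()` in `utils` may behave differently.

```python
import os
from typing import Mapping, Optional

# Limits taken from the matrix above.
LIMITS = {
    "free": {"max_sources": 2, "max_tasks": 50},
    "paid": {"max_sources": 10, "max_tasks": 500},
}
# Hypothetical feature keys for the rows marked "Not supported" on the free tier.
PAID_ONLY = {
    "learning_advanced", "task_graph", "graph_visualize", "sprint_analysis",
    "weekly_report", "blockers", "mermaid_charts", "bulk_sync",
}


def current_tier(env: Optional[Mapping[str, str]] = None) -> str:
    """Read PNC_SUBSCRIPTION_TIER; unknown values fall back to free (an assumption)."""
    env = os.environ if env is None else env
    tier = env.get("PNC_SUBSCRIPTION_TIER", "free").strip().lower()
    return tier if tier in LIMITS else "free"


def is_allowed(feature: str, tier: str) -> bool:
    """Paid users can use everything; free users can use whatever is not paid-only."""
    return tier == "paid" or feature not in PAID_ONLY


print(is_allowed("weekly_report", current_tier({})))
```

Passing the environment as a parameter keeps the check easy to test without touching the real process environment.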
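### Behavior when a check fails
1. Tell the user clearly that the feature is available only on the paid tier.
2. Briefly describe the benefits of the paid tier.
3. Offer upgrade guidance: "如需升级至付费版(¥99/月),请联系管理员或访问订阅管理页面。"
4. Offer an alternative that works on the free tier, if one exists.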
---
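## Security rules
1. **Credential protection**: All platform tokens and API keys are passed only through environment variables and are never displayed, logged, or output in the conversation.
2. **Data storage**: Connection configs never store plaintext credentials, only the names of the environment variables that hold them.
3. **API security**: All HTTP requests use HTTPS with a 15-second timeout.
4. **Error handling**: When an API call fails, show the user a friendly error message without exposing internal paths.
5. **Data masking**: Hide sensitive fields when printing configuration.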
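
Rules 2 and 5 can be illustrated together. In the sketch below, the `env_token` field follows the saved config written by `scripts/source_connector.py`, while `resolve_secret`, `redact_config`, and the `SENSITIVE_KEYS` set are hypothetical names introduced for this example.

```python
import os
from typing import Any, Dict, Mapping, Optional

# Assumed set of field names that must never be printed.
SENSITIVE_KEYS = {"api_key", "token", "password"}


def resolve_secret(
    config: Mapping[str, Any], field: str, env: Optional[Mapping[str, str]] = None
) -> Optional[str]:
    """Look up a secret through the environment variable name stored in the config."""
    env = os.environ if env is None else env
    var_name = config.get(field, "")
    return env.get(var_name) or None


def redact_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Copy a config for display, replacing non-empty sensitive values with ***."""
    return {
        key: ("***" if key in SENSITIVE_KEYS and value else value)
        for key, value in config.items()
    }


saved = {"repo": "owner/repo", "env_token": "PNC_GITHUB_TOKEN"}
print(redact_config(saved))
```

Because the saved config holds only the variable name, printing it leaks nothing even before redaction; `redact_config` guards the case where a caller passes raw input that still contains a token.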
---
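## Behavior guidelines
1. Always communicate with the user in Chinese.
2. Before connecting a platform, confirm that the user has set the required environment variables.
3. Prefer tables when presenting task data so it stays easy to read.
4. Proactively flag risky tasks (overdue, or high priority with no progress for a long time).
5. When creating a task, recommend the most suitable platform and explain why.
6. When an API error occurs, troubleshoot patiently and offer a workable fix.
7. Respect subscription limits and stay friendly when suggesting an upgrade.
---
## Reference documents
- **API guide**: `references/api-guide.md`, covering each platform's API endpoints and usage.
- **Unified model**: `references/unified-schema.md`, covering the unified task model and the status/priority mapping tables.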
FILE:assets/README.md
# 项目中枢 (project-nerve)
> 跨平台项目管理聚合器 — 一个入口,统一管理 Trello、GitHub Issues、Linear、Notion、Obsidian 任务
---
## 功能亮点
- **多平台聚合** — 连接 Trello、GitHub Issues、Linear、Notion、Obsidian,一个视图查看所有任务
- **智能标准化** — 自动统一各平台的状态、优先级和任务格式,消除信息孤岛
- **自动检测平台** — 创建任务时根据内容智能推荐最合适的平台(Bug → GitHub,文档 → Notion)
- 🧠 **自学习引擎** — 从使用模式中持续优化,记录错误/成功/纠正,自动生成改进建议
- 🕸️ **任务关系图谱** — 可视化跨平台任务依赖,检测循环依赖,影响分析
- 📝 **Obsidian 集成** — 本地知识库作为任务源,从 markdown 复选框和 frontmatter 提取任务
- **站会报告** — 自动生成每日站会和每周总结,扫描所有平台的最新动态
- **冲刺分析** — 速度统计、任务漏斗、燃尽图,用数据驱动项目管理
- **阻碍预警** — 自动识别逾期和高风险任务,及时提醒处理
- **Mermaid 图表** — 饼图、柱状图、折线图内嵌报告,无需额外工具
- **安全优先** — 凭据通过环境变量管理,本地运行不上传数据
---
## Feature Highlights
- **Multi-platform Aggregation** — Connect Trello, GitHub Issues, Linear, Notion, Obsidian in a single unified view
- **Smart Normalization** — Automatically unify status, priority, and task formats across platforms
- **Auto Platform Detection** — Intelligently recommend the best platform when creating tasks
- 🧠 **Self-Learning Engine** — Continuously improve from usage patterns, auto-generate optimization suggestions
- 🕸️ **Task Relationship Graph** — Visualize cross-platform task dependencies, detect cycles, impact analysis
- 📝 **Obsidian Integration** — Use local knowledge base as a task source, extract from checkboxes and frontmatter
- **Standup Reports** — Auto-generate daily standups and weekly summaries from all connected sources
- **Sprint Analytics** — Velocity tracking, task funnel, burndown charts with data-driven insights
- **Blocker Alerts** — Proactively identify overdue and high-risk tasks
- **Mermaid Charts** — Embedded pie, bar, and line charts in Markdown reports
- **Security First** — Credentials via env vars only, all data stays local
---
## 版本对比 / Version Comparison
| 功能 Feature | 免费版 Free | 付费版 Paid ¥99/月 |
|------|:------:|:------------:|
| 平台数量 Sources(含 Obsidian) | 最多 2 个 | 最多 10 个 |
| 任务显示 Tasks | 50 条 | 500 条 |
| 任务查询 Query | 支持 | 支持 |
| 创建/更新任务 CRUD | 支持 | 支持 |
| 自学习引擎 Learning Engine(基础) | 支持 | 支持 |
| 自学习引擎 Learning Engine(高级建议) | - | 支持 |
| 任务关系图谱 Task Graph | - | 支持 |
| 任务图谱可视化 Graph Visualization | - | 支持 |
| 每日站会 Daily Standup | 支持 | 支持 |
| 每周总结 Weekly Report | - | 支持 |
| 冲刺分析 Sprint Analytics | - | 支持(速度/漏斗/燃尽图) |
| 阻碍分析 Blocker Analysis | - | 支持 |
| Mermaid 图表 Charts | - | 支持 |
| 批量同步 Bulk Sync | - | 支持 |
---
## 快速开始 / Quick Start
### 1. 安装 Skill
```bash
openclaw skill install project-nerve
```
### 2. 配置平台凭据
```bash
# Trello
export PNC_TRELLO_API_KEY="your-api-key"
export PNC_TRELLO_TOKEN="your-token"
# GitHub
export PNC_GITHUB_TOKEN="ghp_your-token"
# Linear
export PNC_LINEAR_API_KEY="lin_api_your-key"
# Notion
export PNC_NOTION_TOKEN="ntn_your-token"
export PNC_NOTION_DATABASE_ID="your-database-id"
# Obsidian
export PNC_OBSIDIAN_VAULT_PATH="/path/to/your/vault"
```
### 3. 连接平台
```bash
/project-nerve 连接 GitHub
/project-nerve 连接 Trello
```
### 4. 开始使用
```bash
# 查看所有任务
/project-nerve 同步任务
# 搜索任务
/project-nerve 搜索 "登录Bug"
# 创建任务
/project-nerve 创建任务 "修复登录页面样式问题"
# 生成站会报告
/project-nerve 站会
# 冲刺报告(付费版)
/project-nerve 冲刺报告
```
---
## 示例输出 / Example Output
### 每日站会
```markdown
# 每日站会 — 2026-03-19
## 昨日完成
- **修复用户头像上传失败** [github] (高) @zhangsan
- **更新首页Banner设计稿** [notion] @lisi
## 今日计划
- [进行中] **实现用户权限模块** [linear] (紧急) @zhangsan
- [待启动] **编写API文档** [notion] @wangwu
## 阻碍事项
- **支付接口对接** [github] (紧急) — 逾期(截止: 2026-03-17)
---
完成 2 | 计划 2 | 阻碍 1
```
### 任务聚合表
```markdown
| # | 标题 | 平台 | 状态 | 优先级 | 负责人 | 截止日期 |
|---|------|------|------|--------|--------|----------|
| 1 | 实现用户权限模块 | linear | 进行中 | 紧急 | zhangsan | 2026-03-21 |
| 2 | 修复登录页面样式 | github | 待办 | 高 | lisi | 2026-03-20 |
| 3 | 更新产品路线图 | notion | 进行中 | 中 | wangwu | - |
| 4 | 优化首页加载速度 | trello | 待办 | 中 | - | 2026-03-25 |
```
---
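## FAQ
### Q1: Which platforms are supported?
Trello, GitHub Issues, Linear, Notion, and Obsidian (local vault) are supported today. Support for Jira, Asana, ClickUp, and others is planned.
### Q2: Is my data safe?
All data is processed locally, credentials are managed through environment variables, and nothing is uploaded to the cloud. All API traffic uses HTTPS.
### Q3: Is the free tier enough?
The free tier lets you connect 2 platforms, view 50 tasks, and generate daily standups, which suits individuals and small teams. Upgrade to the paid tier for sprint analysis, weekly reports, and blocker analysis.
### Q4: How does task deduplication work?
When the same task exists on several platforms (for example, a GitHub issue and a Trello card with highly similar titles), the system deduplicates automatically based on title word overlap and keeps the most recently updated version.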
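
The overlap rule is not spelled out here, so the following is only one plausible reading: Jaccard similarity over title tokens, keeping the newest task in each group of near-duplicates. The tokenizer, the 0.7 threshold, and the `updated_at` ISO-string field are assumptions.

```python
import re
from typing import Any, Dict, List, Set


def title_tokens(title: str) -> Set[str]:
    """Lowercased ASCII words plus individual CJK characters."""
    return set(re.findall(r"[a-z0-9]+|[\u4e00-\u9fff]", title.lower()))


def overlap(a: str, b: str) -> float:
    """Jaccard similarity of the two titles' token sets."""
    ta, tb = title_tokens(a), title_tokens(b)
    if not ta or not tb:
        return 0.0
    return len(ta & tb) / len(ta | tb)


def dedupe(tasks: List[Dict[str, Any]], threshold: float = 0.7) -> List[Dict[str, Any]]:
    """Visit tasks newest first and drop any whose title overlaps a kept task."""
    kept: List[Dict[str, Any]] = []
    for task in sorted(tasks, key=lambda t: t["updated_at"], reverse=True):
        if all(overlap(task["title"], k["title"]) < threshold for k in kept):
            kept.append(task)
    return kept


tasks = [
    {"title": "修复登录页面样式", "source": "github", "updated_at": "2026-03-18T10:00:00"},
    {"title": "修复登录页面样式问题", "source": "trello", "updated_at": "2026-03-19T09:00:00"},
]
print([t["source"] for t in dedupe(tasks)])
```

Sorting newest first means the first task seen in each group is the one to keep, which matches "keeps the most recently updated version".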
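### Q5: How do I customize the standup report?
Pass `--data '{"assignee":"zhangsan"}'` to filter the report down to a specific member.
### Q6: Where can I view Mermaid charts?
Mermaid charts render directly in GitHub/GitLab Markdown preview, VS Code (with a Mermaid extension installed), Typora, Obsidian, and similar tools.
---
## Support
- Docs: see the `references/` directory for the API and model references
- Bug reports: open an issue on the Skill's page on ClawHub
- Community: join the ClawHub community channel `#project-nerve`
- Email: [email protected]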
---
*project-nerve v1.1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/source_connector.py
#!/usr/bin/env python3
"""
project-nerve 数据源连接器
管理 Trello、GitHub Issues、Linear、Notion、Obsidian 等平台的连接配置。
支持连接、测试、列表、断开操作。
"""
import json
import os
import sys
import urllib.request
import urllib.error
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
write_json_file,
SUPPORTED_PLATFORMS,
)
# ============================================================
# 数据文件路径
# ============================================================
SOURCES_FILE = "sources.json"
def _get_sources() -> List[Dict[str, Any]]:
"""读取所有已连接的数据源配置。"""
data = read_json_file(get_data_file(SOURCES_FILE))
if isinstance(data, list):
return data
return []
def _save_sources(sources: List[Dict[str, Any]]) -> None:
"""保存数据源配置到文件。"""
write_json_file(get_data_file(SOURCES_FILE), sources)
def _find_source(sources: List[Dict], source_id: str) -> Optional[Dict]:
"""根据 ID 查找数据源。"""
for s in sources:
if s.get("id") == source_id:
return s
return None
def _find_source_by_platform(sources: List[Dict], platform: str) -> Optional[Dict]:
"""根据平台类型查找数据源。"""
for s in sources:
if s.get("platform") == platform:
return s
return None
# ============================================================
# HTTP 请求工具
# ============================================================
def _http_request(
    url: str,
    method: str = "GET",
    headers: Optional[Dict[str, str]] = None,
    data: Optional[bytes] = None,
    timeout: int = 15,
) -> Dict[str, Any]:
    """Send an HTTP request.

    Args:
        url: Request URL.
        method: HTTP method.
        headers: Request headers.
        data: Request body.
        timeout: Timeout in seconds.

    Returns:
        A response dict with status, body, and headers. Failures also carry
        an "error" message; status is 0 when no HTTP response was received.
    """
    if headers is None:
        headers = {}
    req = urllib.request.Request(url, data=data, headers=headers, method=method)
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8")
            return {
                "status": resp.status,
                "body": body,
                "headers": dict(resp.headers),
            }
    except urllib.error.HTTPError as e:
        # HTTP-level failure: keep the status code and any response body
        body = ""
        try:
            body = e.read().decode("utf-8")
        except Exception:
            pass
        return {
            "status": e.code,
            "body": body,
            "headers": {},
            "error": str(e),
        }
    except urllib.error.URLError as e:
        # Network-level failure (DNS, refused connection, timeout)
        return {
            "status": 0,
            "body": "",
            "headers": {},
            "error": f"网络错误: {e.reason}",
        }
    except Exception as e:
        return {
            "status": 0,
            "body": "",
            "headers": {},
            "error": f"请求失败: {e}",
        }
# ============================================================
# 平台适配器 — 连接测试
# ============================================================
def _test_trello(config: Dict[str, Any]) -> Dict[str, Any]:
    """Test the Trello API connection.

    Requires api_key and token.
    Verifies the credentials by calling the /1/members/me endpoint.

    Args:
        config: Config dict containing api_key and token.

    Returns:
        Result dict with success, message, and user_info.
    """
    # Local import so this function does not depend on the module-level imports
    from urllib.parse import urlencode

    api_key = config.get("api_key") or os.environ.get("PNC_TRELLO_API_KEY", "")
    token = config.get("token") or os.environ.get("PNC_TRELLO_TOKEN", "")
    if not api_key or not token:
        return {"success": False, "message": "缺少 Trello API Key 或 Token,请设置 PNC_TRELLO_API_KEY 和 PNC_TRELLO_TOKEN 环境变量"}
    # URL-encode the credentials so special characters cannot break the query string
    query = urlencode({"key": api_key, "token": token})
    url = f"https://api.trello.com/1/members/me?{query}"
    resp = _http_request(url)
    if resp["status"] == 200:
        try:
            user = json.loads(resp["body"])
            return {
                "success": True,
                "message": f"Trello 连接成功,用户: {user.get('fullName', user.get('username', '未知'))}",
                "user_info": {"name": user.get("fullName", ""), "username": user.get("username", "")},
            }
        except json.JSONDecodeError:
            return {"success": True, "message": "Trello 连接成功"}
    else:
        return {"success": False, "message": f"Trello 连接失败 (HTTP {resp['status']}): {resp.get('error', resp['body'][:200])}"}
def _test_github(config: Dict[str, Any]) -> Dict[str, Any]:
"""测试 GitHub API 连接。
需要 personal access token。
通过访问 /user 端点验证凭据。
Args:
config: 包含 token 的配置字典。
Returns:
测试结果字典。
"""
token = config.get("token") or os.environ.get("PNC_GITHUB_TOKEN", "")
if not token:
return {"success": False, "message": "缺少 GitHub Token,请设置 PNC_GITHUB_TOKEN 环境变量"}
url = "https://api.github.com/user"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "project-nerve/1.0",
}
resp = _http_request(url, headers=headers)
if resp["status"] == 200:
try:
user = json.loads(resp["body"])
return {
"success": True,
"message": f"GitHub 连接成功,用户: {user.get('login', '未知')}",
"user_info": {"login": user.get("login", ""), "name": user.get("name", "")},
}
except json.JSONDecodeError:
return {"success": True, "message": "GitHub 连接成功"}
else:
return {"success": False, "message": f"GitHub 连接失败 (HTTP {resp['status']}): {resp.get('error', resp['body'][:200])}"}
def _test_linear(config: Dict[str, Any]) -> Dict[str, Any]:
"""测试 Linear API 连接。
需要 API key。
通过 GraphQL API 查询 viewer 验证凭据。
Args:
config: 包含 api_key 的配置字典。
Returns:
测试结果字典。
"""
api_key = config.get("api_key") or os.environ.get("PNC_LINEAR_API_KEY", "")
if not api_key:
return {"success": False, "message": "缺少 Linear API Key,请设置 PNC_LINEAR_API_KEY 环境变量"}
url = "https://api.linear.app/graphql"
headers = {
"Authorization": api_key,
"Content-Type": "application/json",
}
query = '{"query": "{ viewer { id name email } }"}'
resp = _http_request(url, method="POST", headers=headers, data=query.encode("utf-8"))
if resp["status"] == 200:
try:
result = json.loads(resp["body"])
viewer = result.get("data", {}).get("viewer", {})
name = viewer.get("name", "未知")
return {
"success": True,
"message": f"Linear 连接成功,用户: {name}",
"user_info": {"name": name, "email": viewer.get("email", "")},
}
except (json.JSONDecodeError, KeyError):
return {"success": True, "message": "Linear 连接成功"}
else:
return {"success": False, "message": f"Linear 连接失败 (HTTP {resp['status']}): {resp.get('error', resp['body'][:200])}"}
def _test_obsidian(config: Dict[str, Any]) -> Dict[str, Any]:
"""测试 Obsidian Vault 可访问性。
验证 vault 路径是否存在,并统计带有任务标签的笔记数量。
Args:
config: 包含 vault_path 和 task_tag 的配置字典。
Returns:
测试结果字典,包含 success、message、user_info。
"""
vault_path = config.get("vault_path") or os.environ.get("PNC_OBSIDIAN_VAULT_PATH", "")
task_tag = config.get("task_tag", "#task")
if not vault_path:
return {"success": False, "message": "缺少 Obsidian Vault 路径,请设置 PNC_OBSIDIAN_VAULT_PATH 环境变量或提供 vault_path"}
vault_path = os.path.expanduser(vault_path)
if not os.path.isdir(vault_path):
return {"success": False, "message": f"Obsidian Vault 路径不存在: {vault_path}"}
# 扫描 .md 文件,统计带任务标签的笔记数量
task_note_count = 0
total_md_count = 0
for root, _dirs, files in os.walk(vault_path):
for fname in files:
if not fname.endswith(".md"):
continue
total_md_count += 1
fpath = os.path.join(root, fname)
try:
with open(fpath, "r", encoding="utf-8") as f:
content = f.read(4096) # 只读取前 4KB 检查标签
if task_tag in content or "- [ ]" in content or "- [x]" in content:
task_note_count += 1
except (IOError, UnicodeDecodeError):
continue
return {
"success": True,
"message": f"Obsidian Vault 连接成功,共 {total_md_count} 个笔记,其中 {task_note_count} 个包含任务",
"user_info": {
"vault_path": vault_path,
"total_notes": total_md_count,
"task_notes": task_note_count,
"task_tag": task_tag,
},
}
def _connect_obsidian(config: Dict[str, Any]) -> Dict[str, Any]:
"""验证 Obsidian Vault 配置。
检查 vault 路径是否存在,扫描带有任务标签的笔记。
Args:
config: Obsidian 配置字典。
Returns:
包含 vault_path 和 task_tag 的已验证配置。
"""
return _test_obsidian(config)
def _test_notion(config: Dict[str, Any]) -> Dict[str, Any]:
"""测试 Notion API 连接。
需要 integration token 和 database_id。
通过查询数据库验证凭据。
Args:
config: 包含 token 和 database_id 的配置字典。
Returns:
测试结果字典。
"""
token = config.get("token") or os.environ.get("PNC_NOTION_TOKEN", "")
database_id = config.get("database_id") or os.environ.get("PNC_NOTION_DATABASE_ID", "")
if not token:
return {"success": False, "message": "缺少 Notion Integration Token,请设置 PNC_NOTION_TOKEN 环境变量"}
if not database_id:
return {"success": False, "message": "缺少 Notion Database ID,请设置 PNC_NOTION_DATABASE_ID 环境变量"}
url = f"https://api.notion.com/v1/databases/{database_id}"
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
"Content-Type": "application/json",
}
resp = _http_request(url, headers=headers)
if resp["status"] == 200:
try:
db = json.loads(resp["body"])
title_parts = db.get("title", [])
db_title = ""
for part in title_parts:
db_title += part.get("plain_text", "")
return {
"success": True,
"message": f"Notion 连接成功,数据库: {db_title or database_id}",
"user_info": {"database_title": db_title, "database_id": database_id},
}
except (json.JSONDecodeError, KeyError):
return {"success": True, "message": "Notion 连接成功"}
else:
return {"success": False, "message": f"Notion 连接失败 (HTTP {resp['status']}): {resp.get('error', resp['body'][:200])}"}
# ============================================================
# 平台测试路由
# ============================================================
_PLATFORM_TESTERS = {
"trello": _test_trello,
"github": _test_github,
"linear": _test_linear,
"notion": _test_notion,
"obsidian": _test_obsidian,
}
# ============================================================
# 操作实现
# ============================================================
def connect_source(data: Dict[str, Any]) -> None:
"""连接新数据源。
必填字段: platform
可选字段: name, api_key, token, database_id, repo(GitHub 仓库如 owner/repo), board_id(Trello)
Args:
data: 数据源配置字典。
"""
platform = data.get("platform", "").strip().lower()
if platform not in SUPPORTED_PLATFORMS:
valid = "、".join(SUPPORTED_PLATFORMS)
output_error(f"不支持的平台: {platform!r},支持的平台: {valid}", code="INVALID_PLATFORM")
return
# 检查订阅限制
sub = check_subscription()
sources = _get_sources()
if len(sources) >= sub["max_sources"]:
if sub["tier"] == "free":
output_error(
f"免费版最多连接 {sub['max_sources']} 个数据源,当前已有 {len(sources)} 个。"
"请升级至付费版(¥99/月)以连接更多数据源。",
code="LIMIT_EXCEEDED",
)
else:
output_error(
f"已达到数据源数量上限 {sub['max_sources']} 个。",
code="LIMIT_EXCEEDED",
)
return
# 检查是否已存在同平台连接
existing = _find_source_by_platform(sources, platform)
if existing:
output_error(
f"已存在 {platform} 平台的连接(ID: {existing['id']})。如需重新连接,请先断开现有连接。",
code="DUPLICATE_SOURCE",
)
return
# 测试连接
tester = _PLATFORM_TESTERS.get(platform)
if tester:
test_result = tester(data)
if not test_result["success"]:
output_error(test_result["message"], code="CONNECTION_FAILED")
return
# 保存配置
now = now_iso()
source_config = {
"id": generate_id("SRC"),
"platform": platform,
"name": data.get("name", f"{platform} 数据源"),
"config": {},
"connected_at": now,
"updated_at": now,
"status": "active",
}
# 按平台保存不同的配置字段(不保存敏感凭据,使用环境变量)
if platform == "trello":
source_config["config"]["board_id"] = data.get("board_id", "")
source_config["config"]["env_key"] = "PNC_TRELLO_API_KEY"
source_config["config"]["env_token"] = "PNC_TRELLO_TOKEN"
elif platform == "github":
source_config["config"]["repo"] = data.get("repo", "")
source_config["config"]["env_token"] = "PNC_GITHUB_TOKEN"
elif platform == "linear":
source_config["config"]["team_id"] = data.get("team_id", "")
source_config["config"]["env_key"] = "PNC_LINEAR_API_KEY"
elif platform == "notion":
source_config["config"]["database_id"] = data.get("database_id") or os.environ.get("PNC_NOTION_DATABASE_ID", "")
source_config["config"]["env_token"] = "PNC_NOTION_TOKEN"
elif platform == "obsidian":
vault_path = data.get("vault_path") or os.environ.get("PNC_OBSIDIAN_VAULT_PATH", "")
source_config["config"]["vault_path"] = os.path.expanduser(vault_path)
source_config["config"]["task_tag"] = data.get("task_tag", "#task")
source_config["config"]["env_vault_path"] = "PNC_OBSIDIAN_VAULT_PATH"
sources.append(source_config)
_save_sources(sources)
output_success({
"message": f"{platform} 数据源已连接",
"source": source_config,
})
def test_source(data: Dict[str, Any]) -> None:
    """Test a data source connection.

    Required field: platform or id

    Args:
        data: Dict containing the platform name or the data source ID.
    """
    source_id = data.get("id")
    platform = data.get("platform", "").strip().lower()
    config = dict(data)
    if source_id:
        sources = _get_sources()
        source = _find_source(sources, source_id)
        if not source:
            output_error(f"未找到 ID 为 {source_id} 的数据源", code="NOT_FOUND")
            return
        platform = source["platform"]
        # Use saved settings (e.g. vault_path, database_id) as defaults;
        # non-empty values passed in data take precedence.
        overrides = {k: v for k, v in data.items() if v}
        config = {**source.get("config", {}), **overrides}
    if platform not in SUPPORTED_PLATFORMS:
        valid = "、".join(SUPPORTED_PLATFORMS)
        output_error(f"不支持的平台: {platform!r},支持的平台: {valid}", code="INVALID_PLATFORM")
        return
    tester = _PLATFORM_TESTERS.get(platform)
    if not tester:
        output_error(f"平台 {platform} 暂不支持连接测试", code="NOT_SUPPORTED")
        return
    test_result = tester(config)
    if test_result["success"]:
        output_success({
            "platform": platform,
            "message": test_result["message"],
            "user_info": test_result.get("user_info", {}),
        })
    else:
        output_error(test_result["message"], code="CONNECTION_FAILED")
def list_sources(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有已连接的数据源。
Args:
data: 可选的过滤条件字典。
"""
sources = _get_sources()
if data:
platform_filter = data.get("platform", "").strip().lower()
if platform_filter:
sources = [s for s in sources if s.get("platform") == platform_filter]
status_filter = data.get("status", "").strip().lower()
if status_filter:
sources = [s for s in sources if s.get("status") == status_filter]
# 构建摘要信息
platform_stats = {}
for p in SUPPORTED_PLATFORMS:
count = sum(1 for s in sources if s.get("platform") == p)
if count > 0:
platform_stats[p] = count
# 脱敏输出(不显示敏感配置)
display_list = []
for s in sources:
display = {
"id": s.get("id"),
"platform": s.get("platform"),
"name": s.get("name"),
"status": s.get("status"),
"connected_at": s.get("connected_at"),
"updated_at": s.get("updated_at"),
}
display_list.append(display)
output_success({
"total": len(display_list),
"platform_stats": platform_stats,
"sources": display_list,
})
def disconnect_source(data: Dict[str, Any]) -> None:
"""断开数据源连接。
必填字段: id 或 platform
Args:
data: 包含数据源 ID 或平台名称的字典。
"""
source_id = data.get("id")
platform = data.get("platform", "").strip().lower()
sources = _get_sources()
original_count = len(sources)
if source_id:
sources = [s for s in sources if s.get("id") != source_id]
elif platform:
sources = [s for s in sources if s.get("platform") != platform]
else:
output_error("请提供数据源 ID(id)或平台名称(platform)", code="VALIDATION_ERROR")
return
if len(sources) == original_count:
output_error("未找到匹配的数据源", code="NOT_FOUND")
return
removed_count = original_count - len(sources)
_save_sources(sources)
output_success({
"message": f"已断开 {removed_count} 个数据源连接",
"remaining": len(sources),
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("project-nerve 数据源连接器")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"connect": lambda: connect_source(data or {}),
"test": lambda: test_source(data or {}),
"list_sources": lambda: list_sources(data),
"list": lambda: list_sources(data),
"disconnect": lambda: disconnect_source(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(["connect", "test", "list-sources", "disconnect"])
output_error(f"未知操作: {args.action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/learning_engine.py
#!/usr/bin/env python3
"""
project-nerve 自学习引擎
从操作记录中持续学习,记录错误、成功模式和用户纠正,
并基于积累的数据提供改进建议和统计分析。
灵感来源: self-improving-agent (255K 下载量)
"""
import json
import os
import sys
import hashlib
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
write_json_file,
)
# ============================================================
# 数据文件路径
# ============================================================
LEARNING_FILE = "learning.json"
# ============================================================
# 数据读写
# ============================================================
def _get_learning_data() -> Dict[str, Any]:
"""读取学习数据文件。
Returns:
包含 patterns 列表的字典。
"""
data = read_json_file(get_data_file(LEARNING_FILE))
if isinstance(data, dict) and "patterns" in data:
return data
return {"patterns": [], "metadata": {"created_at": now_iso(), "version": "1.0"}}
def _save_learning_data(data: Dict[str, Any]) -> None:
"""保存学习数据到文件。
Args:
data: 学习数据字典。
"""
data["metadata"] = data.get("metadata", {})
data["metadata"]["updated_at"] = now_iso()
write_json_file(get_data_file(LEARNING_FILE), data)
# ============================================================
# 模式指纹计算(用于聚合相似模式)
# ============================================================
def _compute_fingerprint(pattern_type: str, category: str, context: Dict[str, Any]) -> str:
"""计算模式指纹,用于识别和聚合相似的模式。
基于类型、分类和关键上下文信息生成哈希指纹。
相同指纹的模式会被聚合(计数递增)而非重复存储。
Args:
pattern_type: 模式类型(error / success / correction)。
category: 分类标签(如 api_failure, platform_choice 等)。
context: 上下文字典,包含与模式相关的详细信息。
Returns:
16 位十六进制指纹字符串。
"""
# 提取用于指纹计算的关键字段
key_parts = [pattern_type, category]
# 根据不同类型提取不同的关键信息
if pattern_type == "error":
key_parts.append(context.get("source", ""))
key_parts.append(context.get("error_type", ""))
key_parts.append(context.get("action", ""))
elif pattern_type == "success":
key_parts.append(context.get("source", ""))
key_parts.append(context.get("action", ""))
key_parts.append(context.get("task_type", ""))
elif pattern_type == "correction":
key_parts.append(context.get("field", ""))
key_parts.append(context.get("original_value", ""))
key_parts.append(context.get("corrected_value", ""))
raw = "|".join(str(p) for p in key_parts)
return hashlib.md5(raw.encode("utf-8")).hexdigest()[:16]
# ============================================================
# 模式匹配与聚合
# ============================================================
def _find_matching_pattern(
patterns: List[Dict[str, Any]], fingerprint: str
) -> Optional[int]:
"""查找与指纹匹配的已有模式索引。
Args:
patterns: 模式列表。
fingerprint: 待匹配的指纹。
Returns:
匹配的模式在列表中的索引,未找到返回 None。
"""
for i, p in enumerate(patterns):
if p.get("fingerprint") == fingerprint:
return i
return None
def _record_pattern(
pattern_type: str,
category: str,
context: Dict[str, Any],
lesson: str,
) -> Dict[str, Any]:
"""记录一条学习模式。
若已存在指纹相同的模式,则递增计数并更新时间戳;
否则创建新的模式记录。
Args:
pattern_type: 模式类型(error / success / correction)。
category: 分类标签。
context: 上下文信息。
lesson: 从该模式中总结的经验教训。
Returns:
记录的模式字典。
"""
data = _get_learning_data()
patterns = data["patterns"]
fingerprint = _compute_fingerprint(pattern_type, category, context)
existing_idx = _find_matching_pattern(patterns, fingerprint)
    if existing_idx is not None:
        # Aggregate: bump the counter and refresh the timestamp.
        # .get() keeps a legacy record without "count" from raising KeyError.
        pattern = patterns[existing_idx]
        pattern["count"] = pattern.get("count", 1) + 1
        pattern["last_seen"] = now_iso()
        # A new, non-empty lesson replaces the stored one.
        if lesson and lesson != pattern.get("lesson", ""):
            pattern["lesson"] = lesson
else:
# 创建新模式
now = now_iso()
pattern = {
"id": generate_id("LRN"),
"type": pattern_type,
"category": category,
"context": context,
"lesson": lesson,
"fingerprint": fingerprint,
"count": 1,
"first_seen": now,
"last_seen": now,
}
patterns.append(pattern)
data["patterns"] = patterns
_save_learning_data(data)
return pattern
# ============================================================
# 操作实现:记录错误
# ============================================================
def record_error(data: Dict[str, Any]) -> None:
"""记录操作中遇到的错误。
用于追踪 API 失败、超时、解析错误等问题模式。
相同来源和类型的错误会被聚合计数。
必填字段: category(错误分类,如 api_failure / timeout / parse_error)
可选字段: source(平台名称), error_type(具体错误类型),
action(触发错误的操作), message(错误详情), lesson(经验总结)
Args:
data: 错误信息字典。
"""
    category = str(data.get("category") or "").strip()  # tolerate null or non-string input
if not category:
output_error("错误分类(category)为必填字段", code="VALIDATION_ERROR")
return
context = {
"source": data.get("source", ""),
"error_type": data.get("error_type", ""),
"action": data.get("action", ""),
"message": data.get("message", ""),
}
lesson = data.get("lesson", "")
pattern = _record_pattern("error", category, context, lesson)
output_success({
"message": f"已记录错误模式: {category}",
"pattern_id": pattern["id"],
"count": pattern["count"],
"aggregated": pattern["count"] > 1,
})
# ============================================================
# 操作实现:记录成功
# ============================================================
def record_success(data: Dict[str, Any]) -> None:
"""记录成功的操作模式。
用于追踪有效的平台选择、成功的查询方式等。
帮助系统学习哪些操作在什么场景下效果最好。
必填字段: category(成功分类,如 platform_choice / query_pattern / fetch_success)
可选字段: source(平台名称), action(操作类型),
task_type(任务类型), details(详细信息), lesson(经验总结)
Args:
data: 成功信息字典。
"""
    category = str(data.get("category") or "").strip()  # tolerate null or non-string input
if not category:
output_error("成功分类(category)为必填字段", code="VALIDATION_ERROR")
return
context = {
"source": data.get("source", ""),
"action": data.get("action", ""),
"task_type": data.get("task_type", ""),
"details": data.get("details", ""),
}
lesson = data.get("lesson", "")
pattern = _record_pattern("success", category, context, lesson)
output_success({
"message": f"已记录成功模式: {category}",
"pattern_id": pattern["id"],
"count": pattern["count"],
"aggregated": pattern["count"] > 1,
})
# ============================================================
# 操作实现:记录用户纠正
# ============================================================
def record_correction(data: Dict[str, Any]) -> None:
"""记录用户对自动行为的纠正。
当用户覆盖自动检测结果(如更改推荐平台、调整优先级)时记录,
帮助系统逐步学习用户偏好。
必填字段: category(纠正分类,如 platform_override / priority_change)
可选字段: field(被纠正的字段名), original_value(原始值),
corrected_value(纠正后的值), reason(纠正原因), lesson(经验总结)
Args:
data: 纠正信息字典。
"""
    category = str(data.get("category") or "").strip()  # tolerate null or non-string input
if not category:
output_error("纠正分类(category)为必填字段", code="VALIDATION_ERROR")
return
context = {
"field": data.get("field", ""),
"original_value": data.get("original_value", ""),
"corrected_value": data.get("corrected_value", ""),
"reason": data.get("reason", ""),
}
lesson = data.get("lesson", "")
pattern = _record_pattern("correction", category, context, lesson)
output_success({
"message": f"已记录用户纠正: {category}",
"pattern_id": pattern["id"],
"count": pattern["count"],
"aggregated": pattern["count"] > 1,
})
# ============================================================
# 操作实现:建议
# ============================================================
def _generate_error_suggestions(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""基于错误模式生成改进建议。
分析高频错误,识别问题集中的平台和操作,
提出针对性的改进建议。
Args:
patterns: 错误类型的模式列表。
Returns:
建议列表,每项包含 type、message、severity、based_on。
"""
suggestions = []
if not patterns:
return suggestions
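    # Count errors per platform. record_error() always stores a "source" key
    # (possibly empty), so fall back with `or` instead of a .get() default,
    # which would never apply.
    source_errors: Dict[str, int] = {}
    for p in patterns:
        src = p.get("context", {}).get("source") or "未知"
        source_errors[src] = source_errors.get(src, 0) + p.get("count", 1)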
# 找出错误最多的平台
if source_errors:
worst_source = max(source_errors, key=source_errors.get) # type: ignore
worst_count = source_errors[worst_source]
if worst_count >= 3:
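            # Find alternative platforms with fewer than half as many errors,
            # least error-prone first. Comparing count * 2 avoids the rounding
            # that integer division (worst_count // 2) introduces.
            alternatives = sorted(
                (s for s in source_errors
                 if s != worst_source and source_errors[s] * 2 < worst_count),
                key=lambda s: source_errors[s],
            )
            alt_text = ""
            if alternatives:
                alt_text = f",建议优先使用 {alternatives[0]}"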
suggestions.append({
"type": "reliability",
"message": f"{worst_source} 的 API 最近频繁出错({worst_count} 次){alt_text}",
"severity": "高" if worst_count >= 5 else "中",
"based_on": f"{worst_count} 次错误记录",
})
# 按错误类型统计
error_types: Dict[str, int] = {}
for p in patterns:
et = p.get("context", {}).get("error_type", "")
if et:
error_types[et] = error_types.get(et, 0) + p.get("count", 1)
# 超时相关建议
timeout_count = error_types.get("timeout", 0)
if timeout_count >= 2:
suggestions.append({
"type": "performance",
"message": f"检测到 {timeout_count} 次超时,建议检查网络连接或增加超时设置",
"severity": "中",
"based_on": f"{timeout_count} 次超时记录",
})
# 认证相关建议
auth_count = error_types.get("auth_failure", 0) + error_types.get("unauthorized", 0)
if auth_count >= 1:
suggestions.append({
"type": "configuration",
"message": f"检测到 {auth_count} 次认证失败,建议检查 API 密钥是否过期",
"severity": "高",
"based_on": f"{auth_count} 次认证错误",
})
return suggestions
def _generate_correction_suggestions(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""基于用户纠正模式生成偏好建议。
分析用户经常覆盖的自动决策,识别用户偏好,
提出自动化改进建议。
Args:
patterns: 纠正类型的模式列表。
Returns:
建议列表。
"""
suggestions = []
if not patterns:
return suggestions
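    # Analyze platform-override patterns. Each entry holds a count plus the
    # original and corrected platform names, hence Dict[str, Any].
    platform_overrides: Dict[str, Dict[str, Any]] = {}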
for p in patterns:
ctx = p.get("context", {})
if ctx.get("field") == "platform" or p.get("category") == "platform_override":
original = ctx.get("original_value", "")
corrected = ctx.get("corrected_value", "")
if original and corrected:
key = f"{original}->{corrected}"
if key not in platform_overrides:
platform_overrides[key] = {"count": 0, "original": original, "corrected": corrected}
platform_overrides[key]["count"] += p.get("count", 1)
# 生成平台偏好建议
for key, info in platform_overrides.items():
if info["count"] >= 2:
suggestions.append({
"type": "preference",
"message": (
f"用户倾向于将任务从 {info['original']} 改到 {info['corrected']}"
f"(已纠正 {info['count']} 次),建议调整自动检测逻辑"
),
"severity": "低",
"based_on": f"{info['count']} 次平台覆盖",
})
# 分析优先级纠正模式
priority_corrections = [
p for p in patterns
if p.get("context", {}).get("field") == "priority" or p.get("category") == "priority_change"
]
total_prio_corrections = sum(p.get("count", 1) for p in priority_corrections)
if total_prio_corrections >= 3:
suggestions.append({
"type": "preference",
"message": f"用户经常调整任务优先级(共 {total_prio_corrections} 次),建议优化优先级自动判断逻辑",
"severity": "中",
"based_on": f"{total_prio_corrections} 次优先级纠正",
})
return suggestions
def _generate_success_suggestions(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""基于成功模式生成优化建议。
分析高效的操作模式,提炼可复用的最佳实践。
Args:
patterns: 成功类型的模式列表。
Returns:
建议列表。
"""
suggestions = []
if not patterns:
return suggestions
# 找出使用最频繁的成功平台
source_success: Dict[str, int] = {}
for p in patterns:
src = p.get("context", {}).get("source", "")
if src:
source_success[src] = source_success.get(src, 0) + p.get("count", 1)
if source_success:
best_source = max(source_success, key=source_success.get) # type: ignore
best_count = source_success[best_source]
if best_count >= 5:
suggestions.append({
"type": "optimization",
"message": f"{best_source} 是最常成功使用的平台({best_count} 次),可作为默认首选",
"severity": "低",
"based_on": f"{best_count} 次成功记录",
})
return suggestions
def suggest(data: Optional[Dict[str, Any]] = None) -> None:
"""基于积累的学习数据生成改进建议。
分析错误模式、用户纠正和成功模式,综合生成
可操作的改进建议。
可选字段: type(过滤建议类型: error / correction / success / all)
Args:
data: 可选的过滤参数。
"""
# 检查订阅(高级建议为付费功能)
sub = check_subscription()
is_paid = sub["tier"] == "paid"
learning = _get_learning_data()
patterns = learning.get("patterns", [])
if not patterns:
output_success({
"message": "暂无学习数据,请在使用过程中积累操作记录后再查看建议",
"suggestions": [],
"total_patterns": 0,
})
return
# 按类型分组
errors = [p for p in patterns if p.get("type") == "error"]
successes = [p for p in patterns if p.get("type") == "success"]
corrections = [p for p in patterns if p.get("type") == "correction"]
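    # Build suggestions, optionally restricted to one pattern type.
    # str(... or "") tolerates a null or non-string "type" value.
    filter_type = str((data or {}).get("type") or "").strip().lower()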
all_suggestions = []
if not filter_type or filter_type in ("error", "all"):
all_suggestions.extend(_generate_error_suggestions(errors))
if not filter_type or filter_type in ("correction", "all"):
if is_paid:
all_suggestions.extend(_generate_correction_suggestions(corrections))
elif corrections:
all_suggestions.append({
"type": "upgrade_hint",
"message": f"检测到 {len(corrections)} 条用户纠正记录,升级付费版可获得个性化偏好建议",
"severity": "低",
"based_on": f"{len(corrections)} 条纠正记录",
})
if not filter_type or filter_type in ("success", "all"):
if is_paid:
all_suggestions.extend(_generate_success_suggestions(successes))
# 按严重程度排序
severity_order = {"高": 0, "中": 1, "低": 2}
all_suggestions.sort(key=lambda s: severity_order.get(s.get("severity", "低"), 2))
output_success({
"message": f"基于 {len(patterns)} 条学习记录生成了 {len(all_suggestions)} 条建议",
"total_patterns": len(patterns),
"pattern_breakdown": {
"errors": len(errors),
"successes": len(successes),
"corrections": len(corrections),
},
"suggestions": all_suggestions,
})
# ============================================================
# 操作实现:统计
# ============================================================
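def stats(data: Optional[Dict[str, Any]] = None) -> None:
    """Show statistics about the learning data.

    Reports the error rate, the most common patterns, and correction frequency
    to give a picture of learning progress and data quality.

    Optional field: type (restrict to one type: error / success / correction).

    Args:
        data: Optional filter parameters.
    """
    learning = _get_learning_data()
    patterns = learning.get("patterns", [])
    # Apply the documented `type` filter; unrecognized values are ignored.
    filter_type = str((data or {}).get("type") or "").strip().lower()
    if filter_type in ("error", "success", "correction"):
        patterns = [p for p in patterns if p.get("type") == filter_type]
    if not patterns:
        output_success({
            "message": "暂无学习数据",
            "total_patterns": 0,
            "stats": {},
        })
        return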
# 基本统计
type_counts = {"error": 0, "success": 0, "correction": 0}
type_total_events = {"error": 0, "success": 0, "correction": 0}
source_stats: Dict[str, Dict[str, int]] = {}
category_stats: Dict[str, int] = {}
for p in patterns:
ptype = p.get("type", "unknown")
count = p.get("count", 1)
if ptype in type_counts:
type_counts[ptype] += 1
type_total_events[ptype] += count
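        # Per-source statistics. Correction patterns carry no "source" and
        # errors may store an empty one, so fall back with `or`.
        src = p.get("context", {}).get("source") or "未指定"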
if src not in source_stats:
source_stats[src] = {"error": 0, "success": 0, "correction": 0, "total": 0}
if ptype in source_stats[src]:
source_stats[src][ptype] += count
source_stats[src]["total"] += count
# 按分类统计
cat = p.get("category", "未分类")
category_stats[cat] = category_stats.get(cat, 0) + count
# 计算错误率
total_events = sum(type_total_events.values())
error_rate = 0.0
if total_events > 0:
error_rate = round(type_total_events["error"] / total_events * 100, 1)
# 每个来源的错误率
source_error_rates: Dict[str, float] = {}
for src, st in source_stats.items():
if st["total"] > 0:
source_error_rates[src] = round(st["error"] / st["total"] * 100, 1)
# 最常见的模式(按计数排序)
top_patterns = sorted(patterns, key=lambda p: p.get("count", 1), reverse=True)[:10]
top_summary = []
for p in top_patterns:
top_summary.append({
"type": p.get("type", ""),
"category": p.get("category", ""),
"count": p.get("count", 1),
"source": p.get("context", {}).get("source", ""),
"lesson": p.get("lesson", ""),
})
# 时间范围
first_seen_dates = [p.get("first_seen", "") for p in patterns if p.get("first_seen")]
last_seen_dates = [p.get("last_seen", "") for p in patterns if p.get("last_seen")]
time_range = {}
if first_seen_dates:
time_range["earliest"] = min(first_seen_dates)
if last_seen_dates:
time_range["latest"] = max(last_seen_dates)
result = {
"message": f"共有 {len(patterns)} 个独立模式,{total_events} 次事件记录",
"total_patterns": len(patterns),
"total_events": total_events,
"type_breakdown": {
"patterns": type_counts,
"events": type_total_events,
},
"error_rate": f"{error_rate}%",
"source_stats": source_stats,
"source_error_rates": source_error_rates,
"top_categories": dict(sorted(category_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
"top_patterns": top_summary,
"time_range": time_range,
}
output_success(result)
# ============================================================
# 操作实现:重置
# ============================================================
def reset(data: Optional[Dict[str, Any]] = None) -> None:
"""重置学习数据。
可选择重置全部数据或仅重置指定类型的数据。
此操作不可撤销。
可选字段: type(仅重置指定类型: error / success / correction),
confirm(确认重置,必须为 true)
Args:
data: 可选参数。
"""
if not data or not data.get("confirm"):
output_error(
"重置操作不可撤销,请传入 confirm: true 确认操作",
code="CONFIRMATION_REQUIRED",
)
return
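    # Validate the type filter up front: an unrecognized value (for example a
    # typo such as "errors") must not fall through to the reset-everything branch.
    filter_type = str(data.get("type") or "").strip().lower()
    if filter_type and filter_type not in ("error", "success", "correction"):
        output_error(
            f"无效的类型: {filter_type},可选值: error / success / correction",
            code="VALIDATION_ERROR",
        )
        return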
learning = _get_learning_data()
patterns = learning.get("patterns", [])
original_count = len(patterns)
if filter_type and filter_type in ("error", "success", "correction"):
# 仅删除指定类型
patterns = [p for p in patterns if p.get("type") != filter_type]
removed = original_count - len(patterns)
learning["patterns"] = patterns
_save_learning_data(learning)
output_success({
"message": f"已重置 {filter_type} 类型的学习数据,移除 {removed} 条记录",
"removed": removed,
"remaining": len(patterns),
})
else:
# 重置全部
learning["patterns"] = []
_save_learning_data(learning)
output_success({
"message": f"已重置全部学习数据,移除 {original_count} 条记录",
"removed": original_count,
"remaining": 0,
})
# ============================================================
# 便捷 API(供其他模块调用)
# ============================================================
def quick_record_error(source: str, action: str, error_type: str, message: str) -> None:
"""快速记录错误(供其他脚本内部调用的便捷方法)。
Args:
source: 平台来源。
action: 触发错误的操作。
error_type: 错误类型。
message: 错误信息。
"""
try:
context = {
"source": source,
"error_type": error_type,
"action": action,
"message": message,
}
_record_pattern("error", error_type, context, message)
except Exception:
# 学习记录失败不应影响主流程
pass
def quick_record_success(source: str, action: str, task_type: str = "") -> None:
"""快速记录成功(供其他脚本内部调用的便捷方法)。
Args:
source: 平台来源。
action: 操作类型。
task_type: 任务类型。
"""
try:
context = {
"source": source,
"action": action,
"task_type": task_type,
}
_record_pattern("success", "operation_success", context, "")
except Exception:
# 学习记录失败不应影响主流程
pass
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("project-nerve 自学习引擎")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"record_error": lambda: record_error(data or {}),
"record_success": lambda: record_success(data or {}),
"record_correction": lambda: record_correction(data or {}),
"suggest": lambda: suggest(data),
"stats": lambda: stats(data),
"reset": lambda: reset(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join([
"record-error", "record-success", "record-correction",
"suggest", "stats", "reset",
])
output_error(
f"未知操作: {args.action},支持的操作: {valid_actions}",
code="INVALID_ACTION",
)
if __name__ == "__main__":
main()
FILE:scripts/task_aggregator.py
#!/usr/bin/env python3
"""
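project-nerve task aggregator
Fetches tasks from every connected platform, normalizes them to a single format, and provides search, blocker analysis, and priority sorting.
Supports five platforms: Trello, GitHub Issues, Linear, Notion, and Obsidian.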
"""
import json
import os
import re
import sys
import urllib.request
import urllib.error
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
format_task_table,
get_data_file,
is_overdue,
load_input_data,
normalize_priority,
normalize_status,
now_iso,
output_error,
output_success,
parse_common_args,
parse_date,
read_json_file,
write_json_file,
SUPPORTED_PLATFORMS,
)
# ============================================================
# 数据文件路径
# ============================================================
SOURCES_FILE = "sources.json"
TASKS_CACHE_FILE = "tasks_cache.json"
def _get_sources() -> List[Dict[str, Any]]:
"""读取所有已连接的活跃数据源。"""
data = read_json_file(get_data_file(SOURCES_FILE))
if isinstance(data, list):
return [s for s in data if s.get("status") == "active"]
return []
def _get_cached_tasks() -> List[Dict[str, Any]]:
"""读取缓存的任务数据。"""
data = read_json_file(get_data_file(TASKS_CACHE_FILE))
if isinstance(data, list):
return data
return []
def _save_cached_tasks(tasks: List[Dict[str, Any]]) -> None:
"""保存任务数据到缓存文件。"""
write_json_file(get_data_file(TASKS_CACHE_FILE), tasks)
# ============================================================
# HTTP 请求工具
# ============================================================
def _http_request(
    method: str,
    url: str,
    headers: Optional[Dict[str, str]] = None,
    data: Optional[bytes] = None,
    timeout: int = 15,
) -> Dict[str, Any]:
    """Send an HTTP request and return a normalized response dict.

    Args:
        method: HTTP method, "GET" or "POST".
        url: Request URL.
        headers: Request headers.
        data: Request body (POST only).
        timeout: Timeout in seconds.

    Returns:
        A dict with "status" and "body". Failures also carry "error";
        "status" is 0 when no HTTP response was received at all.
    """
    req = urllib.request.Request(url, data=data, headers=headers or {}, method=method)
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8")
            return {"status": resp.status, "body": body}
    except urllib.error.HTTPError as e:
        # The server answered with an error status; keep its body if readable.
        body = ""
        try:
            body = e.read().decode("utf-8")
        except Exception:
            pass
        return {"status": e.code, "body": body, "error": str(e)}
    except Exception as e:
        # Network-level failure (DNS, timeout, TLS, ...).
        return {"status": 0, "body": "", "error": str(e)}


def _http_get(url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 15) -> Dict[str, Any]:
    """Send an HTTP GET request; see _http_request for the return format."""
    return _http_request("GET", url, headers=headers, timeout=timeout)


def _http_post(url: str, headers: Optional[Dict[str, str]] = None,
               data: Optional[bytes] = None, timeout: int = 15) -> Dict[str, Any]:
    """Send an HTTP POST request; see _http_request for the return format."""
    return _http_request("POST", url, headers=headers, data=data, timeout=timeout)
# ============================================================
# 平台适配器 — 任务获取
# ============================================================
def _fetch_trello_tasks(source: Dict[str, Any]) -> List[Dict[str, Any]]:
"""从 Trello 获取卡片并转换为统一格式。
Args:
source: 数据源配置。
Returns:
统一格式的任务列表。
"""
api_key = os.environ.get("PNC_TRELLO_API_KEY", "")
token = os.environ.get("PNC_TRELLO_TOKEN", "")
board_id = source.get("config", {}).get("board_id", "")
if not api_key or not token:
return []
tasks = []
# 获取看板列表(用于映射状态)
lists_map = {}
if board_id:
url = f"https://api.trello.com/1/boards/{board_id}/lists?key={api_key}&token={token}"
resp = _http_get(url)
if resp["status"] == 200:
try:
for lst in json.loads(resp["body"]):
lists_map[lst["id"]] = lst.get("name", "")
except (json.JSONDecodeError, KeyError):
pass
# 获取卡片
if board_id:
url = f"https://api.trello.com/1/boards/{board_id}/cards?key={api_key}&token={token}&fields=name,desc,idList,labels,due,dateLastActivity,shortUrl,idMembers"
resp = _http_get(url)
if resp["status"] == 200:
try:
cards = json.loads(resp["body"])
for card in cards:
list_name = lists_map.get(card.get("idList", ""), "")
labels = [lb.get("name", "") for lb in card.get("labels", []) if lb.get("name")]
# 从标签推断优先级
priority_str = ""
for lb in labels:
if lb.lower() in ("urgent", "high", "medium", "low", "紧急", "高", "中", "低"):
priority_str = lb
break
                    tasks.append({
                        "id": f"trello-{card.get('id', '')}",
                        "source": "trello",
                        "source_id": card.get("id", ""),
                        "title": card.get("name", ""),
                        "description": card.get("desc", ""),
                        "status": normalize_status(list_name),
                        "priority": normalize_priority(priority_str),
                        "assignee": "",
                        "labels": labels,
                        # "due" is null for cards without a due date.
                        "due_date": (card.get("due") or "")[:10],
                        # The requested card fields include no creation time, so the
                        # last-activity timestamp stands in for created_at as well.
                        "created_at": card.get("dateLastActivity", ""),
                        "updated_at": card.get("dateLastActivity", ""),
                        "url": card.get("shortUrl", ""),
                    })
except (json.JSONDecodeError, KeyError):
pass
return tasks
def _fetch_github_issues(source: Dict[str, Any]) -> List[Dict[str, Any]]:
"""从 GitHub 获取 Issues 并转换为统一格式。
Args:
source: 数据源配置。
Returns:
统一格式的任务列表。
"""
token = os.environ.get("PNC_GITHUB_TOKEN", "")
repo = source.get("config", {}).get("repo", "")
if not token or not repo:
return []
tasks = []
url = f"https://api.github.com/repos/{repo}/issues?state=all&per_page=100&sort=updated&direction=desc"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "project-nerve/1.0",
}
resp = _http_get(url, headers=headers)
if resp["status"] == 200:
try:
issues = json.loads(resp["body"])
for issue in issues:
# 跳过 Pull Request(GitHub Issues API 也会返回 PR)
if issue.get("pull_request"):
continue
labels = [lb.get("name", "") for lb in issue.get("labels", []) if lb.get("name")]
# 从标签推断优先级
priority_str = ""
for lb in labels:
lb_lower = lb.lower()
if lb_lower in ("urgent", "high", "medium", "low", "p0", "p1", "p2", "p3", "critical", "blocker"):
priority_str = lb
break
state = issue.get("state", "open")
assignee = ""
if issue.get("assignee"):
assignee = issue["assignee"].get("login", "")
tasks.append({
"id": f"github-{issue.get('id', '')}",
"source": "github",
"source_id": str(issue.get("number", "")),
"title": issue.get("title", ""),
"description": (issue.get("body") or "")[:500],
"status": normalize_status(state),
"priority": normalize_priority(priority_str),
"assignee": assignee,
"labels": labels,
"due_date": "",
"created_at": issue.get("created_at", ""),
"updated_at": issue.get("updated_at", ""),
"url": issue.get("html_url", ""),
})
except (json.JSONDecodeError, KeyError):
pass
return tasks
def _fetch_linear_issues(source: Dict[str, Any]) -> List[Dict[str, Any]]:
"""从 Linear 获取 Issues 并转换为统一格式。
Args:
source: 数据源配置。
Returns:
统一格式的任务列表。
"""
api_key = os.environ.get("PNC_LINEAR_API_KEY", "")
if not api_key:
return []
tasks = []
team_id = source.get("config", {}).get("team_id", "")
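    # Build the GraphQL query. json.dumps() yields a quoted, escaped string
    # literal, so a team_id containing quotes cannot break out of the query.
    team_filter = ""
    if team_id:
        team_filter = f", filter: {{ team: {{ id: {{ eq: {json.dumps(team_id)} }} }} }}"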
query_str = json.dumps({
"query": f"""{{
issues(first: 100, orderBy: updatedAt{team_filter}) {{
nodes {{
id
identifier
title
description
priority
state {{ name }}
assignee {{ name }}
labels {{ nodes {{ name }} }}
dueDate
createdAt
updatedAt
url
}}
}}
}}"""
})
url = "https://api.linear.app/graphql"
headers = {
"Authorization": api_key,
"Content-Type": "application/json",
}
resp = _http_post(url, headers=headers, data=query_str.encode("utf-8"))
if resp["status"] == 200:
try:
result = json.loads(resp["body"])
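            # "data" can be null when the GraphQL response carries only errors,
            # so guard each level instead of chaining .get() on a possible None.
            issues = ((result.get("data") or {}).get("issues") or {}).get("nodes") or []
            # Linear priority values: 0 = none, 1 = urgent, 2 = high, 3 = medium, 4 = low.
            # Issues without a priority (0) are reported as low.
            linear_priority_map = {0: "低", 1: "紧急", 2: "高", 3: "中", 4: "低"}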
for issue in issues:
priority_num = issue.get("priority", 0)
state_name = ""
if issue.get("state"):
state_name = issue["state"].get("name", "")
assignee_name = ""
if issue.get("assignee"):
assignee_name = issue["assignee"].get("name", "")
labels = []
if issue.get("labels") and issue["labels"].get("nodes"):
labels = [lb.get("name", "") for lb in issue["labels"]["nodes"] if lb.get("name")]
tasks.append({
"id": f"linear-{issue.get('id', '')}",
"source": "linear",
"source_id": issue.get("identifier", ""),
"title": issue.get("title", ""),
"description": (issue.get("description") or "")[:500],
"status": normalize_status(state_name),
"priority": linear_priority_map.get(priority_num, "中"),
"assignee": assignee_name,
"labels": labels,
"due_date": (issue.get("dueDate") or "")[:10],
"created_at": issue.get("createdAt", ""),
"updated_at": issue.get("updatedAt", ""),
"url": issue.get("url", ""),
})
except (json.JSONDecodeError, KeyError):
pass
return tasks
def _fetch_notion_tasks(source: Dict[str, Any]) -> List[Dict[str, Any]]:
"""从 Notion 数据库获取页面并转换为统一格式。
Args:
source: 数据源配置。
Returns:
统一格式的任务列表。
"""
token = os.environ.get("PNC_NOTION_TOKEN", "")
database_id = source.get("config", {}).get("database_id", "") or os.environ.get("PNC_NOTION_DATABASE_ID", "")
if not token or not database_id:
return []
tasks = []
url = f"https://api.notion.com/v1/databases/{database_id}/query"
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
"Content-Type": "application/json",
}
body = json.dumps({"page_size": 100}).encode("utf-8")
resp = _http_post(url, headers=headers, data=body)
if resp["status"] == 200:
try:
result = json.loads(resp["body"])
pages = result.get("results", [])
for page in pages:
props = page.get("properties", {})
                # Extract the title. A Notion database has exactly one property
                # of type "title", whatever it is named, so look it up by type
                # instead of guessing common property names.
                title = ""
                for prop in props.values():
                    if prop.get("type") == "title":
                        title_parts = prop.get("title", [])
                        title = "".join(p.get("plain_text", "") for p in title_parts)
                        break
# 提取状态
status_str = ""
for key in ("Status", "状态", "Stage", "阶段"):
prop = props.get(key, {})
if prop.get("type") == "status":
status_obj = prop.get("status")
if status_obj:
status_str = status_obj.get("name", "")
break
elif prop.get("type") == "select":
select_obj = prop.get("select")
if select_obj:
status_str = select_obj.get("name", "")
break
# 提取优先级
priority_str = ""
for key in ("Priority", "优先级"):
prop = props.get(key, {})
if prop.get("type") == "select":
select_obj = prop.get("select")
if select_obj:
priority_str = select_obj.get("name", "")
break
# 提取负责人
assignee = ""
for key in ("Assignee", "负责人", "Owner"):
prop = props.get(key, {})
if prop.get("type") == "people":
people = prop.get("people", [])
if people:
assignee = people[0].get("name", "")
break
# 提取截止日期
due_date = ""
for key in ("Due", "截止日期", "Due Date", "Deadline"):
prop = props.get(key, {})
if prop.get("type") == "date":
date_obj = prop.get("date")
if date_obj:
due_date = (date_obj.get("start") or "")[:10]
break
# 提取标签
labels = []
for key in ("Tags", "标签", "Labels"):
prop = props.get(key, {})
if prop.get("type") == "multi_select":
for opt in prop.get("multi_select", []):
labels.append(opt.get("name", ""))
break
page_url = page.get("url", "")
created_time = page.get("created_time", "")
updated_time = page.get("last_edited_time", "")
tasks.append({
"id": f"notion-{page.get('id', '')}",
"source": "notion",
"source_id": page.get("id", ""),
"title": title,
"description": "",
"status": normalize_status(status_str),
"priority": normalize_priority(priority_str),
"assignee": assignee,
"labels": labels,
"due_date": due_date,
"created_at": created_time,
"updated_at": updated_time,
"url": page_url,
})
except (json.JSONDecodeError, KeyError):
pass
return tasks
def _parse_obsidian_frontmatter(content: str) -> Dict[str, str]:
"""解析 Obsidian 笔记的 YAML frontmatter。
提取 frontmatter 中的 status、priority、assignee、due_date 字段。
Args:
content: 笔记完整文本内容。
Returns:
包含解析到的字段的字典。
"""
result: Dict[str, str] = {}
if not content.startswith("---"):
return result
parts = content.split("---", 2)
if len(parts) < 3:
return result
frontmatter = parts[1]
for line in frontmatter.strip().split("\n"):
line = line.strip()
if ":" not in line:
continue
key, _, value = line.partition(":")
key = key.strip().lower()
value = value.strip().strip('"').strip("'")
if key in ("status", "priority", "assignee", "due_date", "due"):
mapped_key = "due_date" if key == "due" else key
result[mapped_key] = value
return result
def _fetch_obsidian_tasks(source: Dict[str, Any]) -> List[Dict[str, Any]]:
"""从 Obsidian Vault 扫描任务。
扫描 vault 中所有 .md 文件,提取 markdown 复选框格式的任务:
- [ ] 未完成任务
- [x] 已完成任务
同时解析 frontmatter 中的 priority、assignee、due_date 等字段。
Args:
source: 数据源配置。
Returns:
统一格式的任务列表。
"""
vault_path = source.get("config", {}).get("vault_path", "") or os.environ.get("PNC_OBSIDIAN_VAULT_PATH", "")
task_tag = source.get("config", {}).get("task_tag", "#task")
if not vault_path:
return []
vault_path = os.path.expanduser(vault_path)
if not os.path.isdir(vault_path):
return []
tasks = []
# 用于生成唯一 ID 的计数器
task_counter = 0
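    for root, dirs, files in os.walk(vault_path):
        # Prune the .obsidian config directory in place so os.walk never descends
        # into it. A substring test on `root` would also skip every note when the
        # vault path itself happens to contain ".obsidian".
        dirs[:] = [d for d in dirs if d != ".obsidian"]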
for fname in files:
if not fname.endswith(".md"):
continue
fpath = os.path.join(root, fname)
try:
with open(fpath, "r", encoding="utf-8") as f:
content = f.read()
except (IOError, UnicodeDecodeError):
continue
# 解析 frontmatter 获取全局属性
fm = _parse_obsidian_frontmatter(content)
# 计算相对路径(作为来源标识)
rel_path = os.path.relpath(fpath, vault_path)
note_title = os.path.splitext(fname)[0]
# 获取文件修改时间
try:
mtime = datetime.fromtimestamp(os.path.getmtime(fpath)).strftime("%Y-%m-%dT%H:%M:%S")
except (OSError, ValueError):
mtime = ""
# 扫描 markdown 复选框
checkbox_pattern = re.compile(r"^(\s*)-\s+\[([ xX])\]\s+(.+)$", re.MULTILINE)
for match in checkbox_pattern.finditer(content):
checked = match.group(2).lower() == "x"
task_text = match.group(3).strip()
# 跳过空任务
if not task_text:
continue
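                task_counter += 1
                # Built-in hash() is salted per process for str values, so it would
                # produce a different ID on every run; CRC32 is stable across runs.
                import zlib
                digest = zlib.crc32((fpath + task_text).encode("utf-8"))
                task_id = f"obsidian-{task_counter}-{digest:08x}"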
# 从任务文本中提取内联标签
labels = re.findall(r"#(\w+)", task_text)
# 移除标签后的纯文本作为标题
clean_title = re.sub(r"\s*#\w+", "", task_text).strip()
if not clean_title:
clean_title = task_text
# 尝试从任务文本中提取截止日期(如 📅 2026-03-20 或 due:2026-03-20)
due_date = fm.get("due_date", "")
due_match = re.search(r"(?:📅|due:|截止:?)\s*(\d{4}-\d{2}-\d{2})", task_text)
if due_match:
due_date = due_match.group(1)
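                tasks.append({
                    "id": task_id,
                    "source": "obsidian",
                    "source_id": rel_path,
                    "title": clean_title,
                    "description": f"来自笔记: {note_title}",
                    "status": "已完成" if checked else normalize_status(fm.get("status", "待办")),
                    "priority": normalize_priority(fm.get("priority", "")),
                    "assignee": fm.get("assignee", ""),
                    "labels": labels,
                    "due_date": due_date,
                    "created_at": mtime,
                    "updated_at": mtime,
                    # Percent-encode both values: names with spaces, "&", "/" or
                    # non-ASCII characters would otherwise break the URI.
                    # urllib.parse is already loaded because urllib.request imports it.
                    "url": (
                        "obsidian://open"
                        f"?vault={urllib.parse.quote(os.path.basename(vault_path), safe='')}"
                        f"&file={urllib.parse.quote(rel_path, safe='')}"
                    ),
                })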
return tasks
# ============================================================
# 平台获取路由
# ============================================================
_PLATFORM_FETCHERS = {
"trello": _fetch_trello_tasks,
"github": _fetch_github_issues,
"linear": _fetch_linear_issues,
"notion": _fetch_notion_tasks,
"obsidian": _fetch_obsidian_tasks,
}
# ============================================================
# 去重逻辑
# ============================================================
def _word_set(text: str) -> set:
"""提取文本中的词集合(中文按字,英文按单词)。
Args:
text: 输入文本。
Returns:
词集合。
"""
if not text:
return set()
# 提取英文单词和中文字符
words = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
# 添加中文字符
for ch in text:
if "\u4e00" <= ch <= "\u9fff":
words.add(ch)
return words
def _title_similarity(t1: str, t2: str) -> float:
"""计算两个标题的相似度(基于词重叠比率)。
Args:
t1: 标题一。
t2: 标题二。
Returns:
相似度 0.0~1.0。
"""
s1 = _word_set(t1)
s2 = _word_set(t2)
if not s1 or not s2:
return 0.0
intersection = s1 & s2
union = s1 | s2
return len(intersection) / len(union) if union else 0.0
def _dedup_tasks(tasks: List[Dict[str, Any]], threshold: float = 0.8) -> List[Dict[str, Any]]:
"""对任务列表去重(基于标题相似度)。
当两个来自不同平台的任务标题相似度超过阈值时,保留更新时间较晚的那个。
Args:
tasks: 原始任务列表。
threshold: 相似度阈值。
Returns:
去重后的任务列表。
"""
if not tasks:
return tasks
result = []
seen_indices = set()
for i in range(len(tasks)):
if i in seen_indices:
continue
best = tasks[i]
for j in range(i + 1, len(tasks)):
if j in seen_indices:
continue
# 只对不同平台的任务做去重
if tasks[i]["source"] == tasks[j]["source"]:
continue
sim = _title_similarity(tasks[i]["title"], tasks[j]["title"])
if sim >= threshold:
seen_indices.add(j)
# 保留更新时间较晚的
if tasks[j].get("updated_at", "") > best.get("updated_at", ""):
best = tasks[j]
result.append(best)
return result
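The similarity used for de-duplication is a Jaccard ratio over mixed-language token sets: ASCII words and numbers count as whole tokens, and each CJK character counts as its own token. The sketch below restates that idea as a standalone snippet; `tokens` and `jaccard` are illustrative names, not functions of this module.

```python
import re


def tokens(text: str) -> set:
    """Lower-cased ASCII words/numbers plus individual CJK characters."""
    found = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
    found.update(ch for ch in text if "\u4e00" <= ch <= "\u9fff")
    return found


def jaccard(a: str, b: str) -> float:
    """|A & B| / |A | B| over the token sets; 0.0 if either side is empty."""
    sa, sb = tokens(a), tokens(b)
    if not sa or not sb:
        return 0.0
    return len(sa & sb) / len(sa | sb)


same = jaccard("Fix login bug", "fix LOGIN bug")        # identical token sets
near = jaccard("修复登录 bug", "修复登录页 bug")          # 5 shared tokens out of 6
far = jaccard("Fix login bug", "Write release notes")   # nothing shared

print(same, round(near, 3), far)
```

With the default threshold of 0.8, the `near` pair (5/6) would be merged when the two tasks come from different platforms, while the `far` pair would not.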
# ============================================================
# 操作实现
# ============================================================
def fetch_all(data: Optional[Dict[str, Any]] = None) -> None:
"""从所有已连接平台获取任务并缓存。
Args:
data: 可选参数(platform 过滤)。
"""
sources = _get_sources()
if not sources:
output_error("暂无已连接的数据源,请先使用 source_connector 连接平台", code="NO_SOURCES")
return
platform_filter = ""
if data:
platform_filter = data.get("platform", "").strip().lower()
all_tasks = []
fetch_errors = []
# 尝试导入自学习引擎(可选依赖)
try:
from learning_engine import quick_record_error, quick_record_success
has_learning = True
except ImportError:
has_learning = False
for source in sources:
platform = source.get("platform", "")
if platform_filter and platform != platform_filter:
continue
fetcher = _PLATFORM_FETCHERS.get(platform)
if not fetcher:
fetch_errors.append(f"平台 {platform} 暂不支持")
continue
try:
tasks = fetcher(source)
all_tasks.extend(tasks)
# 记录成功
if has_learning and tasks:
quick_record_success(platform, "fetch", f"获取到 {len(tasks)} 个任务")
except Exception as e:
fetch_errors.append(f"{platform} 获取失败: {e}")
# 记录错误
if has_learning:
quick_record_error(platform, "fetch", "fetch_failure", str(e))
# 去重
all_tasks = _dedup_tasks(all_tasks)
# 按更新时间倒序排列
all_tasks.sort(key=lambda t: t.get("updated_at", ""), reverse=True)
# 检查显示限制
sub = check_subscription()
display_limit = sub["max_tasks_display"]
truncated = len(all_tasks) > display_limit
display_tasks = all_tasks[:display_limit]
# 缓存所有任务
_save_cached_tasks(all_tasks)
# 统计
status_stats = {}
for task in all_tasks:
status = task.get("status", "待办")
status_stats[status] = status_stats.get(status, 0) + 1
source_stats = {}
for task in all_tasks:
src = task.get("source", "未知")
source_stats[src] = source_stats.get(src, 0) + 1
result = {
"total": len(all_tasks),
"displayed": len(display_tasks),
"truncated": truncated,
"status_stats": status_stats,
"source_stats": source_stats,
"tasks": display_tasks,
"table": format_task_table(display_tasks),
}
if fetch_errors:
result["warnings"] = fetch_errors
output_success(result)
def search_tasks(data: Dict[str, Any]) -> None:
    """Search the cached tasks.

    Supported filters: keyword (title / description / labels), status,
    priority, source, assignee.

    Args:
        data: Dictionary of search criteria.
    """
    tasks = _get_cached_tasks()
    if not tasks:
        output_error("暂无缓存任务数据,请先执行 fetch-all 获取任务", code="NO_DATA")
        return
    keyword = data.get("keyword", "").strip().lower()
    status_filter = data.get("status", "").strip()
    priority_filter = data.get("priority", "").strip()
    source_filter = data.get("source", "").strip().lower()
    assignee_filter = data.get("assignee", "").strip().lower()
    filtered = tasks
    if keyword:
        # `or ""` / `or []` guard against fields a platform cached as null,
        # matching the handling already used for assignee below
        filtered = [
            t for t in filtered
            if keyword in (t.get("title") or "").lower()
            or keyword in (t.get("description") or "").lower()
            or keyword in " ".join(str(x) for x in (t.get("labels") or [])).lower()
        ]
    if status_filter:
        normalized = normalize_status(status_filter)
        filtered = [t for t in filtered if t.get("status") == normalized]
    if priority_filter:
        normalized = normalize_priority(priority_filter)
        filtered = [t for t in filtered if t.get("priority") == normalized]
    if source_filter:
        filtered = [t for t in filtered if t.get("source") == source_filter]
    if assignee_filter:
        filtered = [t for t in filtered if assignee_filter in (t.get("assignee") or "").lower()]
    sub = check_subscription()
    display_limit = sub["max_tasks_display"]
    display_tasks = filtered[:display_limit]
    output_success({
        "total": len(filtered),
        "displayed": len(display_tasks),
        "query": {k: v for k, v in data.items() if v},
        "tasks": display_tasks,
        "table": format_task_table(display_tasks),
    })
def find_blockers(data: Optional[Dict[str, Any]] = None) -> None:
"""查找阻碍任务(逾期或高优先级进行中任务)。
阻碍判定条件:
1. 已逾期的未完成任务
2. 优先级为「紧急」或「高」且状态为「进行中」超过 7 天的任务
Args:
data: 可选参数。
"""
if not require_blocker_feature():
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 fetch-all 获取任务", code="NO_DATA")
return
blockers = []
now = datetime.now()
for task in tasks:
status = task.get("status", "")
if status in ("已完成", "已关闭"):
continue
reasons = []
# 条件1:逾期
due_date = task.get("due_date", "")
if due_date and is_overdue(due_date):
reasons.append(f"已逾期(截止日期: {due_date})")
# 条件2:高优先级长时间进行中
priority = task.get("priority", "")
if priority in ("紧急", "高") and status == "进行中":
updated = parse_date(task.get("updated_at", ""))
if updated and (now - updated).days > 7:
reasons.append(f"高优先级任务进行中超过 7 天(上次更新: {task.get('updated_at', '')[:10]})")
if reasons:
blocker = dict(task)
blocker["blocker_reasons"] = reasons
blockers.append(blocker)
# 按优先级排序:紧急 > 高 > 中 > 低
priority_order = {"紧急": 0, "高": 1, "中": 2, "低": 3}
blockers.sort(key=lambda t: priority_order.get(t.get("priority", "中"), 2))
output_success({
"total": len(blockers),
"blockers": blockers,
"table": format_task_table(blockers),
"summary": f"发现 {len(blockers)} 个阻碍/风险任务" if blockers else "未发现阻碍任务",
})
def require_blocker_feature() -> bool:
"""检查阻碍分析功能的订阅要求。"""
sub = check_subscription()
if "blocker_analysis" not in sub["features"]:
output_error(
"「阻碍分析」为付费版功能。当前为免费版,请升级至付费版(¥99/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
def sort_by_priority(data: Optional[Dict[str, Any]] = None) -> None:
"""按优先级排序所有任务(逾期优先,然后按优先级)。
Args:
data: 可选参数(可指定 status 过滤)。
"""
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 fetch-all 获取任务", code="NO_DATA")
return
# 过滤已完成/已关闭的任务
active_tasks = [t for t in tasks if t.get("status") not in ("已完成", "已关闭")]
if data and data.get("status"):
normalized = normalize_status(data["status"])
active_tasks = [t for t in active_tasks if t.get("status") == normalized]
# 排序规则:逾期在前,然后按优先级
priority_order = {"紧急": 0, "高": 1, "中": 2, "低": 3}
def sort_key(task: Dict) -> tuple:
overdue_score = 0
due = task.get("due_date", "")
if due and is_overdue(due):
overdue_score = -1 # 逾期排在前面
prio = priority_order.get(task.get("priority", "中"), 2)
return (overdue_score, prio)
active_tasks.sort(key=sort_key)
sub = check_subscription()
display_limit = sub["max_tasks_display"]
display_tasks = active_tasks[:display_limit]
output_success({
"total": len(active_tasks),
"displayed": len(display_tasks),
"tasks": display_tasks,
"table": format_task_table(display_tasks),
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("project-nerve 任务聚合器")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"fetch_all": lambda: fetch_all(data),
"fetch": lambda: fetch_all(data),
"search": lambda: search_tasks(data or {}),
"blockers": lambda: find_blockers(data),
"priorities": lambda: sort_by_priority(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(["fetch-all", "search", "blockers", "priorities"])
output_error(f"未知操作: {args.action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/task_graph.py
#!/usr/bin/env python3
"""
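project-nerve task relationship graph.

Models dependencies between tasks across platforms and supports relation
queries, dependency analysis, impact assessment, and Mermaid visualization.

Inspired by: ontology (117K downloads, 326 stars)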
"""
from collections import deque
from typing import Any, Dict, List, Optional, Set

from utils import (
    get_data_file,
    load_input_data,
    now_iso,
    output_error,
    output_success,
    parse_common_args,
    read_json_file,
    require_paid_feature,
    write_json_file,
)
# ============================================================
# 常量定义
# ============================================================
GRAPH_FILE = "task_graph.json"
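# Supported relation types
RELATION_TYPES = [
    "blocks",      # A blocks B (B can start only after A is done)
    "blocked_by",  # A is blocked by B (A can start only after B is done)
    "related_to",  # A is related to B (no directional dependency)
    "parent_of",   # A is the parent task of B
    "child_of",    # A is a subtask of B
    "duplicates",  # A is a duplicate of B
]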
# 互为反向的关系映射
_INVERSE_RELATIONS = {
"blocks": "blocked_by",
"blocked_by": "blocks",
"parent_of": "child_of",
"child_of": "parent_of",
"related_to": "related_to",
"duplicates": "duplicates",
}
# ============================================================
# 数据读写
# ============================================================
def _get_graph() -> Dict[str, Any]:
"""读取任务关系图谱数据。
Returns:
包含 nodes 和 edges 列表的字典。
"""
data = read_json_file(get_data_file(GRAPH_FILE))
if isinstance(data, dict) and "nodes" in data and "edges" in data:
return data
return {
"nodes": [],
"edges": [],
"metadata": {"created_at": now_iso(), "version": "1.0"},
}
def _save_graph(graph: Dict[str, Any]) -> None:
"""保存图谱数据到文件。
Args:
graph: 图谱数据字典。
"""
graph["metadata"] = graph.get("metadata", {})
graph["metadata"]["updated_at"] = now_iso()
write_json_file(get_data_file(GRAPH_FILE), graph)
# ============================================================
# 节点管理
# ============================================================
def _find_node(nodes: List[Dict], node_id: str) -> Optional[Dict]:
"""查找节点。
Args:
nodes: 节点列表。
node_id: 节点 ID。
Returns:
匹配的节点字典,未找到返回 None。
"""
for n in nodes:
if n.get("id") == node_id:
return n
return None
def _ensure_node(graph: Dict[str, Any], node_id: str, source: str = "", title: str = "") -> Dict:
"""确保节点存在,若不存在则创建。
Args:
graph: 图谱数据。
node_id: 节点 ID。
source: 任务来源平台。
title: 任务标题。
Returns:
节点字典。
"""
node = _find_node(graph["nodes"], node_id)
if node is None:
node = {
"id": node_id,
"source": source,
"title": title,
"added_at": now_iso(),
}
graph["nodes"].append(node)
else:
# 更新信息(如果提供了新的值)
if source and not node.get("source"):
node["source"] = source
if title and not node.get("title"):
node["title"] = title
return node
def _find_edge(edges: List[Dict], from_id: str, to_id: str, rel_type: str) -> Optional[int]:
"""查找边的索引。
Args:
edges: 边列表。
from_id: 起始节点 ID。
to_id: 目标节点 ID。
rel_type: 关系类型。
Returns:
匹配的边在列表中的索引,未找到返回 None。
"""
for i, e in enumerate(edges):
if e.get("from") == from_id and e.get("to") == to_id and e.get("type") == rel_type:
return i
return None
# ============================================================
# 操作实现:添加关系
# ============================================================
def add_relation(data: Dict[str, Any]) -> None:
    """Add a relation between two tasks.

    Creates a relation edge between two task nodes. Cross-platform relations
    are supported (for example, a GitHub issue blocking a Trello card).

    Required fields: from_id (source task ID), to_id (target task ID), type (relation type)
    Optional fields: from_source, from_title, to_source, to_title
    Relation types: blocks, blocked_by, related_to, parent_of, child_of, duplicates

    Args:
        data: Relation data dictionary.
    """
    from_id = data.get("from_id", "").strip()
    to_id = data.get("to_id", "").strip()
    rel_type = data.get("type", "").strip().lower()
    if not from_id or not to_id:
        output_error("起始任务 ID(from_id)和目标任务 ID(to_id)为必填字段", code="VALIDATION_ERROR")
        return
    if not rel_type:
        output_error("关系类型(type)为必填字段", code="VALIDATION_ERROR")
        return
    if rel_type not in RELATION_TYPES:
        valid = "、".join(RELATION_TYPES)
        output_error(f"不支持的关系类型: {rel_type},支持的类型: {valid}", code="INVALID_TYPE")
        return
    if from_id == to_id:
        output_error("不能创建自引用关系(from_id 和 to_id 不能相同)", code="VALIDATION_ERROR")
        return
    graph = _get_graph()
    # Make sure both nodes exist
    _ensure_node(graph, from_id, data.get("from_source", ""), data.get("from_title", ""))
    _ensure_node(graph, to_id, data.get("to_source", ""), data.get("to_title", ""))
    # Reject duplicates, including the same relation recorded from the other
    # side: "A blocks B" is equivalent to "B blocked_by A", and related_to /
    # duplicates are symmetric.
    inverse_type = _INVERSE_RELATIONS[rel_type]
    if (
        _find_edge(graph["edges"], from_id, to_id, rel_type) is not None
        or _find_edge(graph["edges"], to_id, from_id, inverse_type) is not None
    ):
        output_error(
            f"关系已存在: {from_id} --[{rel_type}]--> {to_id}",
            code="DUPLICATE_RELATION",
        )
        return
    # Create the edge
    edge = {
        "from": from_id,
        "to": to_id,
        "type": rel_type,
        "created_at": now_iso(),
    }
    graph["edges"].append(edge)
    _save_graph(graph)
    output_success({
        "message": f"已添加关系: {from_id} --[{rel_type}]--> {to_id}",
        "edge": edge,
        "total_nodes": len(graph["nodes"]),
        "total_edges": len(graph["edges"]),
    })
# ============================================================
# 操作实现:删除关系
# ============================================================
def remove_relation(data: Dict[str, Any]) -> None:
"""删除任务关系。
必填字段: from_id, to_id, type
Args:
data: 关系标识字典。
"""
from_id = data.get("from_id", "").strip()
to_id = data.get("to_id", "").strip()
rel_type = data.get("type", "").strip().lower()
if not from_id or not to_id or not rel_type:
output_error("from_id、to_id 和 type 为必填字段", code="VALIDATION_ERROR")
return
graph = _get_graph()
idx = _find_edge(graph["edges"], from_id, to_id, rel_type)
if idx is None:
output_error(
f"未找到关系: {from_id} --[{rel_type}]--> {to_id}",
code="NOT_FOUND",
)
return
removed = graph["edges"].pop(idx)
_save_graph(graph)
output_success({
"message": f"已删除关系: {from_id} --[{rel_type}]--> {to_id}",
"removed_edge": removed,
"remaining_edges": len(graph["edges"]),
})
# ============================================================
# 操作实现:查询
# ============================================================
def _build_adjacency(edges: List[Dict]) -> Dict[str, List[Dict]]:
"""构建邻接表(双向)。
Args:
edges: 边列表。
Returns:
邻接表字典,键为节点 ID,值为相关边的列表。
"""
adj: Dict[str, List[Dict]] = {}
for e in edges:
from_id = e.get("from", "")
to_id = e.get("to", "")
if from_id:
if from_id not in adj:
adj[from_id] = []
adj[from_id].append(e)
if to_id:
if to_id not in adj:
adj[to_id] = []
adj[to_id].append(e)
return adj
def query(data: Dict[str, Any]) -> None:
"""查询与指定任务相关的所有任务(BFS 遍历)。
从给定任务出发,通过 BFS 遍历找出所有直接和间接相关的任务。
必填字段: task_id(要查询的任务 ID)
可选字段: max_depth(最大遍历深度,默认 3), type(过滤关系类型)
Args:
data: 查询参数字典。
"""
task_id = data.get("task_id", "").strip()
if not task_id:
output_error("任务 ID(task_id)为必填字段", code="VALIDATION_ERROR")
return
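    # A non-numeric max_depth should produce a JSON error, not a traceback
    try:
        max_depth = int(data.get("max_depth", 3))
    except (TypeError, ValueError):
        output_error("最大遍历深度(max_depth)必须为整数", code="VALIDATION_ERROR")
        return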
type_filter = data.get("type", "").strip().lower()
graph = _get_graph()
edges = graph["edges"]
nodes = graph["nodes"]
# 如果指定了关系类型过滤
if type_filter:
edges = [e for e in edges if e.get("type") == type_filter]
adj = _build_adjacency(edges)
# BFS 遍历
visited: Set[str] = set()
queue: deque = deque()
queue.append((task_id, 0))
visited.add(task_id)
related_tasks: List[Dict[str, Any]] = []
related_edges: List[Dict[str, Any]] = []
while queue:
current_id, depth = queue.popleft()
if depth >= max_depth:
continue
for edge in adj.get(current_id, []):
# 确定另一端的节点
other_id = edge["to"] if edge["from"] == current_id else edge["from"]
related_edges.append({
"from": edge.get("from", ""),
"to": edge.get("to", ""),
"type": edge.get("type", ""),
"depth": depth + 1,
})
if other_id not in visited:
visited.add(other_id)
node = _find_node(nodes, other_id)
related_tasks.append({
"id": other_id,
"source": node.get("source", "") if node else "",
"title": node.get("title", "") if node else "",
"depth": depth + 1,
})
queue.append((other_id, depth + 1))
# 去重边
seen_edges: Set[str] = set()
unique_edges = []
for e in related_edges:
key = f"{e['from']}|{e['to']}|{e['type']}"
if key not in seen_edges:
seen_edges.add(key)
unique_edges.append(e)
output_success({
"task_id": task_id,
"related_tasks": related_tasks,
"related_edges": unique_edges,
"total_related": len(related_tasks),
"max_depth": max_depth,
})
# ============================================================
# 操作实现:依赖分析
# ============================================================
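def _detect_cycles(edges: List[Dict], relation_types: Optional[List[str]] = None) -> List[List[str]]:
    """Detect cycles in the directed dependency graph using DFS.

    Inverse relation types are normalized first, so "A blocked_by B" is
    treated as "B blocks A" and "A child_of B" as "B parent_of A". Without
    this step a cycle recorded partly through inverse edges goes unnoticed.

    Args:
        edges: Edge list.
        relation_types: Relation types to check; defaults to blocks and parent_of.

    Returns:
        List of cycles; each cycle is a list of node IDs that starts and
        ends on the same node.
    """
    if relation_types is None:
        relation_types = ["blocks", "parent_of"]
    # Inverse spellings that are flipped onto their canonical type
    flipped = {"blocked_by": "blocks", "child_of": "parent_of"}
    wanted = {flipped.get(t, t) for t in relation_types}

    # Build the directed adjacency list
    directed_adj: Dict[str, List[str]] = {}
    for e in edges:
        rel = e.get("type", "")
        from_id = e.get("from", "")
        to_id = e.get("to", "")
        if rel in flipped:
            rel, from_id, to_id = flipped[rel], to_id, from_id
        if rel not in wanted or not from_id or not to_id:
            continue
        directed_adj.setdefault(from_id, []).append(to_id)

    cycles: List[List[str]] = []
    visited: Set[str] = set()
    rec_stack: Set[str] = set()
    path: List[str] = []

    def _dfs(node: str) -> None:
        """Depth-first search that records a cycle on every back edge."""
        visited.add(node)
        rec_stack.add(node)
        path.append(node)
        for neighbor in directed_adj.get(node, []):
            if neighbor not in visited:
                _dfs(neighbor)
            elif neighbor in rec_stack:
                # Back edge found: the cycle is the path suffix from neighbor
                cycle_start = path.index(neighbor)
                cycles.append(path[cycle_start:] + [neighbor])
        path.pop()
        rec_stack.discard(node)

    # Only nodes with outgoing edges can start a cycle
    for node in list(directed_adj):
        if node not in visited:
            _dfs(node)
    return cycles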
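A small worked example makes the cycle check easier to follow. The sketch below is a standalone variant, not the module's function: `find_cycles` and the task IDs are made up for illustration, and it treats a `blocked_by` edge as a reversed `blocks` edge before running the DFS.

```python
from typing import Dict, List


def find_cycles(edges: List[dict]) -> List[List[str]]:
    """Return cycles among blocking edges, e.g. ["A", "B", "A"]."""
    adj: Dict[str, List[str]] = {}
    for e in edges:
        src, dst = e["from"], e["to"]
        if e["type"] == "blocked_by":
            # Flip so that every stored edge reads "src blocks dst"
            src, dst = dst, src
        elif e["type"] != "blocks":
            continue
        adj.setdefault(src, []).append(dst)

    cycles: List[List[str]] = []
    visited, on_path = set(), set()
    path: List[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        on_path.add(node)
        path.append(node)
        for nxt in adj.get(node, []):
            if nxt not in visited:
                dfs(nxt)
            elif nxt in on_path:
                # Back edge: the cycle is the path suffix starting at nxt
                cycles.append(path[path.index(nxt):] + [nxt])
        path.pop()
        on_path.discard(node)

    for start in list(adj):
        if start not in visited:
            dfs(start)
    return cycles


edges = [
    {"from": "A", "to": "B", "type": "blocks"},
    {"from": "C", "to": "B", "type": "blocked_by"},  # i.e. B blocks C
    {"from": "C", "to": "A", "type": "blocks"},
    {"from": "C", "to": "D", "type": "related_to"},  # not a dependency
]
cycles = find_cycles(edges)
print(cycles)
```

The recursion depth grows with the longest dependency chain, so a very deep graph would need an iterative DFS instead.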
def dependencies(data: Dict[str, Any]) -> None:
"""构建依赖树并检测循环依赖。
分析任务之间的 blocks/blocked_by 关系,构建依赖树,
并检测是否存在循环依赖。
必填字段: task_id
可选字段: direction(up 向上追溯依赖 / down 向下展开被阻塞任务,默认 both)
Args:
data: 查询参数字典。
"""
task_id = data.get("task_id", "").strip()
if not task_id:
output_error("任务 ID(task_id)为必填字段", code="VALIDATION_ERROR")
return
direction = data.get("direction", "both").strip().lower()
graph = _get_graph()
edges = graph["edges"]
nodes = graph["nodes"]
# 只看 blocks 类型的关系
block_edges = [e for e in edges if e.get("type") in ("blocks", "blocked_by")]
# 构建有向图(统一为 blocks 方向)
blocks_adj: Dict[str, List[str]] = {} # 阻塞者 -> 被阻塞者列表
blocked_by_adj: Dict[str, List[str]] = {} # 被阻塞者 -> 阻塞者列表
for e in block_edges:
if e.get("type") == "blocks":
blocker = e["from"]
blocked = e["to"]
else: # blocked_by
blocker = e["to"]
blocked = e["from"]
if blocker not in blocks_adj:
blocks_adj[blocker] = []
blocks_adj[blocker].append(blocked)
if blocked not in blocked_by_adj:
blocked_by_adj[blocked] = []
blocked_by_adj[blocked].append(blocker)
# 向上追溯:这个任务被哪些任务阻塞
upstream: List[Dict[str, Any]] = []
if direction in ("up", "both"):
visited: Set[str] = set()
queue: deque = deque()
queue.append((task_id, 0))
visited.add(task_id)
while queue:
current, depth = queue.popleft()
for blocker_id in blocked_by_adj.get(current, []):
if blocker_id not in visited:
visited.add(blocker_id)
node = _find_node(nodes, blocker_id)
upstream.append({
"id": blocker_id,
"source": node.get("source", "") if node else "",
"title": node.get("title", "") if node else "",
"depth": depth + 1,
})
queue.append((blocker_id, depth + 1))
# 向下展开:这个任务阻塞了哪些任务
downstream: List[Dict[str, Any]] = []
if direction in ("down", "both"):
visited_down: Set[str] = set()
queue_down: deque = deque()
queue_down.append((task_id, 0))
visited_down.add(task_id)
while queue_down:
current, depth = queue_down.popleft()
for blocked_id in blocks_adj.get(current, []):
if blocked_id not in visited_down:
visited_down.add(blocked_id)
node = _find_node(nodes, blocked_id)
downstream.append({
"id": blocked_id,
"source": node.get("source", "") if node else "",
"title": node.get("title", "") if node else "",
"depth": depth + 1,
})
queue_down.append((blocked_id, depth + 1))
# 检测循环依赖
cycles = _detect_cycles(edges)
task_in_cycle = any(task_id in cycle for cycle in cycles)
output_success({
"task_id": task_id,
"upstream_dependencies": upstream,
"downstream_blocked": downstream,
"total_upstream": len(upstream),
"total_downstream": len(downstream),
"has_circular_dependency": task_in_cycle,
"circular_dependencies": [c for c in cycles if task_id in c],
})
# ============================================================
# 操作实现:影响分析
# ============================================================
def impact(data: Dict[str, Any]) -> None:
"""分析阻塞一个任务会影响多少下游任务。
计算如果指定任务被阻塞,有多少任务会受到影响(直接和间接)。
必填字段: task_id
Args:
data: 查询参数字典。
"""
task_id = data.get("task_id", "").strip()
if not task_id:
output_error("任务 ID(task_id)为必填字段", code="VALIDATION_ERROR")
return
graph = _get_graph()
edges = graph["edges"]
nodes = graph["nodes"]
# 构建 blocks 方向的有向图
blocks_adj: Dict[str, List[str]] = {}
for e in edges:
if e.get("type") == "blocks":
from_id = e["from"]
to_id = e["to"]
if from_id not in blocks_adj:
blocks_adj[from_id] = []
blocks_adj[from_id].append(to_id)
elif e.get("type") == "blocked_by":
from_id = e["to"] # 反转方向
to_id = e["from"]
if from_id not in blocks_adj:
blocks_adj[from_id] = []
blocks_adj[from_id].append(to_id)
# BFS 计算所有受影响的下游任务
affected: List[Dict[str, Any]] = []
visited: Set[str] = {task_id}
queue: deque = deque()
# 从直接被阻塞的任务开始
for blocked_id in blocks_adj.get(task_id, []):
if blocked_id not in visited:
visited.add(blocked_id)
queue.append((blocked_id, 1))
while queue:
current, depth = queue.popleft()
node = _find_node(nodes, current)
affected.append({
"id": current,
"source": node.get("source", "") if node else "",
"title": node.get("title", "") if node else "",
"impact_depth": depth,
"impact_type": "直接" if depth == 1 else "间接",
})
for next_id in blocks_adj.get(current, []):
if next_id not in visited:
visited.add(next_id)
queue.append((next_id, depth + 1))
# 按影响深度分组统计
depth_stats: Dict[int, int] = {}
for a in affected:
d = a["impact_depth"]
depth_stats[d] = depth_stats.get(d, 0) + 1
# 影响等级评估
impact_level = "无"
total_affected = len(affected)
if total_affected >= 5:
impact_level = "严重"
elif total_affected >= 3:
impact_level = "较大"
elif total_affected >= 1:
impact_level = "一般"
output_success({
"task_id": task_id,
"total_affected": total_affected,
"impact_level": impact_level,
"affected_tasks": affected,
"depth_distribution": depth_stats,
"summary": (
f"如果 {task_id} 被阻塞,将影响 {total_affected} 个下游任务"
if total_affected > 0
else f"{task_id} 没有下游阻塞任务"
),
})
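The impact calculation is a breadth-first walk over "blocker -> blocked" edges followed by a threshold lookup. The following sketch shows the same idea on a toy graph; `downstream`, `impact_level`, the English level names, and the task IDs are all illustrative assumptions, while the thresholds (5, 3, 1) mirror the ones used above.

```python
from collections import deque
from typing import Dict, List, Tuple


def downstream(blocks: Dict[str, List[str]], start: str) -> List[Tuple[str, int]]:
    """BFS over 'blocker -> blocked' edges; returns (task_id, depth) pairs."""
    seen = {start}
    queue: deque = deque([(start, 0)])
    hits: List[Tuple[str, int]] = []
    while queue:
        current, depth = queue.popleft()
        for nxt in blocks.get(current, []):
            if nxt not in seen:
                seen.add(nxt)
                hits.append((nxt, depth + 1))
                queue.append((nxt, depth + 1))
    return hits


def impact_level(count: int) -> str:
    """Map the number of affected tasks to a coarse level."""
    if count >= 5:
        return "severe"
    if count >= 3:
        return "major"
    if count >= 1:
        return "minor"
    return "none"


# Hypothetical dependency graph: design blocks api and ui, and so on
blocks = {
    "design": ["api", "ui"],
    "api": ["tests"],
    "ui": ["tests"],
    "tests": ["release"],
}
affected = downstream(blocks, "design")
level = impact_level(len(affected))
print(affected, level)
```

Because each task is marked as seen when it is first queued, `tests` is counted once at depth 2 even though two tasks block it.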
# ============================================================
# 操作实现:可视化(付费功能)
# ============================================================
def visualize(data: Optional[Dict[str, Any]] = None) -> None:
"""生成任务关系的 Mermaid 流程图(付费功能)。
将图谱中的任务关系可视化为 Mermaid flowchart 格式。
可选字段: task_id(聚焦某个任务的子图), type(过滤关系类型)
Args:
data: 可选的过滤参数。
"""
if not require_paid_feature("mermaid_chart", "任务关系图谱可视化"):
return
graph = _get_graph()
edges = graph["edges"]
nodes = graph["nodes"]
if not edges:
output_success({
"message": "图谱中暂无关系数据",
"mermaid": "",
})
return
task_filter = ""
type_filter = ""
if data:
task_filter = data.get("task_id", "").strip()
type_filter = data.get("type", "").strip().lower()
# 过滤
filtered_edges = edges
if type_filter:
filtered_edges = [e for e in filtered_edges if e.get("type") == type_filter]
if task_filter:
# 只保留与指定任务相关的边
filtered_edges = [
e for e in filtered_edges
if e.get("from") == task_filter or e.get("to") == task_filter
]
if not filtered_edges:
output_success({
"message": "过滤后无匹配的关系数据",
"mermaid": "",
})
return
# 生成 Mermaid 代码
lines = ["```mermaid", "flowchart TD"]
# 收集涉及的节点
involved_nodes: Set[str] = set()
for e in filtered_edges:
involved_nodes.add(e.get("from", ""))
involved_nodes.add(e.get("to", ""))
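    # Node definitions (square brackets, label shows the title)
    node_map: Dict[str, str] = {}
    for node_id in involved_nodes:
        if not node_id:
            continue
        node = _find_node(nodes, node_id)
        # Build a Mermaid-safe node ID: every non-alphanumeric character
        # (not only "-", "." and " ") becomes "_"
        safe_id = "".join(ch if ch.isalnum() else "_" for ch in node_id)
        # Nodes are created with an empty title by default, so fall back to the ID
        title = (node.get("title") if node else "") or node_id
        source = node.get("source", "") if node else ""
        label = title
        if source:
            label = f"{title}\\n[{source}]"
        # Truncate overlong labels
        if len(label) > 50:
            label = label[:47] + "..."
        # A raw double quote would end the quoted Mermaid label early
        label = label.replace('"', "#quot;")
        lines.append(f" {safe_id}[\"{label}\"]")
        node_map[node_id] = safe_id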
# 关系类型到箭头样式的映射
arrow_styles = {
"blocks": "-->|阻塞|",
"blocked_by": "-->|被阻塞|",
"related_to": "---|相关|",
"parent_of": "-->|父任务|",
"child_of": "-->|子任务|",
"duplicates": "-.->|重复|",
}
# 边定义
for e in filtered_edges:
from_safe = node_map.get(e.get("from", ""), "")
to_safe = node_map.get(e.get("to", ""), "")
rel = e.get("type", "related_to")
arrow = arrow_styles.get(rel, "-->")
if from_safe and to_safe:
lines.append(f" {from_safe} {arrow} {to_safe}")
lines.append("```")
mermaid_code = "\n".join(lines)
output_success({
"message": f"已生成包含 {len(involved_nodes)} 个节点和 {len(filtered_edges)} 条关系的图谱",
"mermaid": mermaid_code,
"node_count": len(involved_nodes),
"edge_count": len(filtered_edges),
})
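To see what the generated chart looks like, here is a minimal sketch that renders a tiny edge list as Mermaid flowchart text. It is not the module's implementation: `to_mermaid` is a hypothetical helper, it assigns sequential node IDs (`n0`, `n1`, ...) instead of sanitizing task IDs, and it omits the surrounding code fence.

```python
from typing import Dict, List


def to_mermaid(edges: List[dict], titles: Dict[str, str]) -> str:
    """Render edges as Mermaid flowchart text with sequential node IDs."""
    lines = ["flowchart TD"]
    ids: Dict[str, str] = {}
    for e in edges:
        for key in (e["from"], e["to"]):
            if key not in ids:
                ids[key] = f"n{len(ids)}"
                # Mermaid entity code for a double quote inside a quoted label
                label = titles.get(key, key).replace('"', "#quot;")
                lines.append(f'    {ids[key]}["{label}"]')
    arrows = {"blocks": "-->|blocks|", "related_to": "---|related|"}
    for e in edges:
        arrow = arrows.get(e["type"], "-->")
        lines.append(f'    {ids[e["from"]]} {arrow} {ids[e["to"]]}')
    return "\n".join(lines)


# Hypothetical task IDs and title
chart = to_mermaid(
    [{"from": "gh-12", "to": "trello-7", "type": "blocks"}],
    {"gh-12": 'Fix "login" bug'},
)
print(chart)
```

Sequential IDs avoid collisions between task IDs that differ only in punctuation, at the cost of node IDs that no longer resemble the task IDs.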
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("project-nerve 任务关系图谱")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"add_relation": lambda: add_relation(data or {}),
"remove_relation": lambda: remove_relation(data or {}),
"query": lambda: query(data or {}),
"dependencies": lambda: dependencies(data or {}),
"impact": lambda: impact(data or {}),
"visualize": lambda: visualize(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join([
"add-relation", "remove-relation", "query",
"dependencies", "impact", "visualize",
])
output_error(
f"未知操作: {args.action},支持的操作: {valid_actions}",
code="INVALID_ACTION",
)
if __name__ == "__main__":
main()
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
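project-nerve shared utilities.

Provides data storage, subscription checks, priority/status normalization,
and data formatting helpers. This is the common toolkit for the
cross-platform project management aggregator.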
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "project-nerve")
# 统一优先级定义
PRIORITIES = ["紧急", "高", "中", "低"]
# 统一状态定义
STATUSES = ["待办", "进行中", "已完成", "已关闭"]
# 支持的平台列表
SUPPORTED_PLATFORMS = ["trello", "github", "linear", "notion", "obsidian"]
# 优先级映射表:将各平台的优先级字符串映射到统一优先级
_PRIORITY_MAP: Dict[str, str] = {
# 中文
"紧急": "紧急",
"urgent": "紧急",
"critical": "紧急",
"p0": "紧急",
"highest": "紧急",
"blocker": "紧急",
"高": "高",
"high": "高",
"p1": "高",
"important": "高",
"中": "中",
"medium": "中",
"normal": "中",
"p2": "中",
"default": "中",
"低": "低",
"low": "低",
"minor": "低",
"p3": "低",
"trivial": "低",
"none": "低",
"no priority": "低",
"nopriority": "低",
}
# 状态映射表:将各平台的状态字符串映射到统一状态
_STATUS_MAP: Dict[str, str] = {
# 中文
"待办": "待办",
"未开始": "待办",
"进行中": "进行中",
"处理中": "进行中",
"已完成": "已完成",
"完成": "已完成",
"已关闭": "已关闭",
"关闭": "已关闭",
# 英文 — Trello / GitHub / Linear / Notion 常见状态
"todo": "待办",
"to do": "待办",
"backlog": "待办",
"open": "待办",
"new": "待办",
"not started": "待办",
"triage": "待办",
"in progress": "进行中",
"in_progress": "进行中",
"doing": "进行中",
"started": "进行中",
"active": "进行中",
"in review": "进行中",
"in_review": "进行中",
"done": "已完成",
"completed": "已完成",
"resolved": "已完成",
"merged": "已完成",
"closed": "已关闭",
"cancelled": "已关闭",
"canceled": "已关闭",
"archived": "已关闭",
"won't fix": "已关闭",
"wontfix": "已关闭",
"duplicate": "已关闭",
}
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径。
优先读取环境变量 PNC_DATA_DIR,若未设置则使用默认路径
~/.openclaw-bdi/project-nerve/。
自动创建目录(若不存在)。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("PNC_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_data_file(filename: str) -> str:
"""获取数据文件的完整路径。
Args:
filename: 文件名(如 "sources.json")。
Returns:
数据文件的绝对路径。
"""
return os.path.join(get_data_dir(), filename)
# ============================================================
# JSON 输入输出
# ============================================================
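def read_json_file(filepath: str) -> Any:
    """Read a JSON file and return the parsed data.

    Args:
        filepath: Path to the JSON file.

    Returns:
        The parsed data. Returns an empty list when the file does not exist,
        cannot be read, or does not contain valid JSON.
    """
    if not os.path.exists(filepath):
        return []
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    # ValueError covers both JSONDecodeError and UnicodeDecodeError
    except (ValueError, OSError):
        return []


def write_json_file(filepath: str, data: Any) -> None:
    """Write data to a JSON file.

    Args:
        filepath: Target file path.
        data: Data to write.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2, default=str)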
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def parse_common_args(description: str = "project-nerve 项目管理工具") -> argparse.ArgumentParser:
"""创建通用命令行参数解析器。
Args:
description: 工具描述文本。
Returns:
配置好通用参数的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Dict[str, Any]]:
    """Load input data from the command-line arguments.

    --data takes precedence; --data-file is used only when --data is absent.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The parsed dictionary, or None when no input was supplied.

    Raises:
        ValueError: If the JSON cannot be parsed, is not an object, or the
            file cannot be read.
    """
    if args.data:
        try:
            data = json.loads(args.data)
        except json.JSONDecodeError as e:
            raise ValueError(f"JSON 解析失败: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
        return data
    if args.data_file:
        if not os.path.exists(args.data_file):
            raise ValueError(f"数据文件不存在: {args.data_file}")
        try:
            with open(args.data_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"数据文件 JSON 解析失败: {e}") from e
        # Read failures are documented as ValueError, so wrap them as well
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"数据文件读取失败: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
        return data
    return None
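Every script in this bundle shares the same command-line contract: `--action` selects the operation and `--data` carries a JSON object. The sketch below walks through that contract with an explicit argument list so it runs without a shell; the action name matches the graph script, and the task IDs in the payload are made up.

```python
import argparse
import json

parser = argparse.ArgumentParser(description="demo of the shared CLI contract")
parser.add_argument("--action", required=True, help="operation to run")
parser.add_argument("--data", default=None, help="JSON object as a string")

# Passing a list to parse_args stands in for sys.argv in this demo
args = parser.parse_args([
    "--action", "add-relation",
    "--data", '{"from_id": "gh-12", "to_id": "trello-7", "type": "blocks"}',
])

# Hyphenated action names are mapped onto underscore handler keys
action = args.action.lower().replace("-", "_")
payload = json.loads(args.data) if args.data else None

print(action, payload)
```

Because the hyphen-to-underscore mapping happens before dispatch, `add-relation` and `add_relation` reach the same handler.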
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_sources": 2,
"max_tasks_display": 50,
"features": [
"basic_query",
"task_list",
"source_connect",
],
},
"paid": {
"tier": "paid",
"max_sources": 10,
"max_tasks_display": 500,
"features": [
"basic_query",
"task_list",
"source_connect",
"sprint_analytics",
"standup_report",
"blocker_analysis",
"mermaid_chart",
"bulk_sync",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 PNC_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("PNC_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid_feature(feature_name: str, display_name: str) -> bool:
"""检查当前订阅是否支持指定功能。
若不支持,输出升级提示并返回 False。
Args:
feature_name: 功能内部名称。
display_name: 功能显示名称(用于提示信息)。
Returns:
True 表示功能可用,False 表示不可用(已输出错误信息)。
"""
sub = check_subscription()
if feature_name not in sub["features"]:
output_error(
f"「{display_name}」为付费版功能。当前为免费版,请升级至付费版(¥99/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
# ============================================================
# 标准化工具函数
# ============================================================
def normalize_priority(p: str) -> str:
"""将各平台的优先级字符串映射到统一优先级。
支持中文(紧急/高/中/低)和英文(urgent/high/medium/low 等)。
无法识别的优先级默认返回 "中"。
Args:
p: 原始优先级字符串。
Returns:
统一优先级:紧急 / 高 / 中 / 低。
"""
if not p:
return "中"
key = p.strip().lower()
return _PRIORITY_MAP.get(key, "中")
def normalize_status(s: str) -> str:
"""将各平台的状态字符串映射到统一状态。
支持中文(待办/进行中/已完成/已关闭)和英文常见状态。
无法识别的状态默认返回 "待办"。
Args:
s: 原始状态字符串。
Returns:
统一状态:待办 / 进行中 / 已完成 / 已关闭。
"""
if not s:
return "待办"
key = s.strip().lower()
return _STATUS_MAP.get(key, "待办")
def format_task_table(tasks: List[Dict[str, Any]]) -> str:
    """Format a task list as a Markdown table.

    Columns: # | 标题 | 平台 | 状态 | 优先级 | 负责人 | 截止日期

    Args:
        tasks: List of task dicts in the unified format.

    Returns:
        The Markdown table as a string.
    """
    if not tasks:
        return "暂无任务数据。"

    def _cell(value: Any, default: str = "-") -> str:
        # Escape pipes and flatten newlines so a value cannot break its row.
        text = str(value) if value not in (None, "") else default
        return text.replace("|", "\\|").replace("\n", " ")

    lines = [
        "| # | 标题 | 平台 | 状态 | 优先级 | 负责人 | 截止日期 |",
        "|---|------|------|------|--------|--------|----------|",
    ]
    for idx, task in enumerate(tasks, start=1):
        title = str(task.get("title") or "无标题")
        # Truncate overly long titles to 40 characters.
        if len(title) > 40:
            title = title[:37] + "..."
        row = [
            str(idx),
            _cell(title),
            _cell(task.get("source")),
            _cell(task.get("status")),
            _cell(task.get("priority")),
            _cell(task.get("assignee")),
            _cell(task.get("due_date")),
        ]
        lines.append("| " + " | ".join(row) + " |")
    return "\n".join(lines)
# ============================================================
# ID 与时间工具
# ============================================================
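def generate_id(prefix: str = "T") -> str:
    """Generate a timestamp-based ID.

    The format is prefix + local timestamp with microsecond resolution
    (YYYYMMDDHHMMSSffffff). IDs are unique only as long as no two calls
    fall within the same microsecond.

    Args:
        prefix: ID prefix; defaults to "T" (task).

    Returns:
        The ID string.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    return f"{prefix}{timestamp}"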
def now_iso() -> str:
"""返回当前时间的 ISO 格式字符串。
Returns:
ISO 格式时间字符串,如 "2026-03-19T10:30:00"。
"""
return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
def today_str() -> str:
"""返回今天的日期字符串。
Returns:
日期字符串,格式为 "YYYY-MM-DD"。
"""
return datetime.now().strftime("%Y-%m-%d")
# ============================================================
# Mermaid 图表生成
# ============================================================
def generate_pie_chart(title: str, data: List[Dict[str, Any]]) -> str:
"""生成 Mermaid 饼图。
Args:
title: 图表标题。
data: 数据列表,每项包含 label 和 value。
Returns:
Mermaid 饼图代码块字符串。
"""
lines = ["```mermaid", f"pie title {title}"]
for item in data:
label = item.get("label", "未知")
value = item.get("value", 0)
lines.append(f' "{label}" : {value}')
lines.append("```")
return "\n".join(lines)
def generate_bar_chart(
    title: str,
    data: List[Dict[str, Any]],
    x_label: str = "类别",
    y_label: str = "数值",
) -> str:
    """Generate a Mermaid xychart-beta bar chart.

    Args:
        title: Chart title.
        data: List of items, each with a label and a value.
        x_label: X-axis title.
        y_label: Y-axis title.

    Returns:
        The Mermaid bar chart as a fenced code block string.
    """
    labels = [f'"{item.get("label", "")}"' for item in data]
    values = [str(item.get("value", 0)) for item in data]
    lines = [
        "```mermaid",
        "xychart-beta",
        f'    title "{title}"',
        # x_label was accepted but never emitted; xychart-beta takes the
        # axis title before the category list.
        f'    x-axis "{x_label}" [{", ".join(labels)}]',
        f'    y-axis "{y_label}"',
        f'    bar [{", ".join(values)}]',
        "```",
    ]
    return "\n".join(lines)
def generate_line_chart(
    title: str,
    data: List[Dict[str, Any]],
    x_label: str = "时间",
    y_label: str = "数值",
) -> str:
    """Generate a Mermaid xychart-beta line chart.

    Args:
        title: Chart title.
        data: List of items, each with a label and a value.
        x_label: X-axis title.
        y_label: Y-axis title.

    Returns:
        The Mermaid line chart as a fenced code block string.
    """
    labels = [f'"{item.get("label", "")}"' for item in data]
    values = [str(item.get("value", 0)) for item in data]
    lines = [
        "```mermaid",
        "xychart-beta",
        f'    title "{title}"',
        # x_label was accepted but never emitted; xychart-beta takes the
        # axis title before the category list.
        f'    x-axis "{x_label}" [{", ".join(labels)}]',
        f'    y-axis "{y_label}"',
        f'    line [{", ".join(values)}]',
        "```",
    ]
    return "\n".join(lines)
# ============================================================
# 日期辅助函数
# ============================================================
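def calculate_days_since(date_str: str) -> int:
    """Return the number of days from the given date to now.

    Args:
        date_str: Date string in YYYY-MM-DD or ISO 8601 format.

    Returns:
        Days elapsed (positive for past dates, negative for future dates);
        0 when the string cannot be parsed.
    """
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            if dt.tzinfo is not None:
                # Convert to local time before dropping tzinfo, so the
                # comparison with the naive local datetime.now() is valid.
                dt = dt.astimezone().replace(tzinfo=None)
        else:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
        return (datetime.now() - dt).days
    except (ValueError, TypeError):
        return 0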
def is_overdue(due_date: str) -> bool:
"""判断任务是否已逾期。
Args:
due_date: 截止日期字符串(YYYY-MM-DD 或 ISO 格式)。
Returns:
True 表示已逾期,False 表示未逾期或无截止日期。
"""
if not due_date:
return False
days = calculate_days_since(due_date)
return days > 0
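def parse_date(date_str: str) -> Optional[datetime]:
    """Parse a date string into a naive local datetime.

    Accepts YYYY-MM-DD and ISO 8601 formats. Timestamps that carry a UTC
    offset (including a trailing "Z") are converted to local time first.

    Args:
        date_str: Date string.

    Returns:
        A datetime object, or None when parsing fails.
    """
    if not date_str:
        return None
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            if dt.tzinfo is not None:
                # Dropping tzinfo without converting would shift the value
                # by the UTC offset relative to datetime.now().
                dt = dt.astimezone()
            return dt.replace(tzinfo=None)
        return datetime.strptime(date_str, "%Y-%m-%d")
    except (ValueError, TypeError):
        return None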
FILE:scripts/standup_generator.py
#!/usr/bin/env python3
"""
project-nerve 站会报告生成器
扫描最近的任务活动,生成每日站会和每周总结报告。
支持跨平台汇总,输出格式化 Markdown。
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
format_task_table,
get_data_file,
is_overdue,
load_input_data,
normalize_status,
now_iso,
output_error,
output_success,
parse_common_args,
parse_date,
read_json_file,
require_paid_feature,
today_str,
)
# ============================================================
# 数据文件路径
# ============================================================
TASKS_CACHE_FILE = "tasks_cache.json"
def _get_cached_tasks() -> List[Dict[str, Any]]:
"""读取缓存的任务数据。"""
data = read_json_file(get_data_file(TASKS_CACHE_FILE))
if isinstance(data, list):
return data
return []
# ============================================================
# 时间过滤辅助
# ============================================================
def _tasks_updated_since(tasks: List[Dict[str, Any]], since: datetime) -> List[Dict[str, Any]]:
"""过滤出指定时间之后有更新的任务。
Args:
tasks: 任务列表。
since: 起始时间。
Returns:
过滤后的任务列表。
"""
result = []
for task in tasks:
dt = parse_date(task.get("updated_at", ""))
if dt and dt >= since:
result.append(task)
return result
def _tasks_by_status(tasks: List[Dict[str, Any]], status: str) -> List[Dict[str, Any]]:
"""按状态过滤任务。"""
return [t for t in tasks if t.get("status") == status]
def _task_summary_line(task: Dict[str, Any]) -> str:
"""生成单个任务的摘要行。
Args:
task: 任务字典。
Returns:
格式化的摘要行字符串。
"""
title = task.get("title", "无标题")
if len(title) > 50:
title = title[:47] + "..."
source = task.get("source", "")
priority = task.get("priority", "")
assignee = task.get("assignee", "")
parts = [f"**{title}**"]
if source:
parts.append(f"[{source}]")
if priority and priority in ("紧急", "高"):
parts.append(f"({priority})")
if assignee:
parts.append(f"@{assignee}")
return " ".join(parts)
# ============================================================
# 操作实现
# ============================================================
def daily_standup(data: Optional[Dict[str, Any]] = None) -> None:
"""生成每日站会报告。
扫描过去 24 小时的任务活动,生成标准站会格式:
- 昨日完成
- 今日计划
- 阻碍事项
Args:
data: 可选参数(assignee 过滤)。
"""
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
now = datetime.now()
yesterday = now - timedelta(hours=24)
today_date = today_str()
# 可选:按负责人过滤
assignee = ""
if data:
assignee = data.get("assignee", "").strip()
if assignee:
tasks = [t for t in tasks if assignee.lower() in (t.get("assignee") or "").lower()]
# 最近 24 小时有更新的任务
recent_tasks = _tasks_updated_since(tasks, yesterday)
# 昨日完成:最近 24h 内状态变为已完成的任务
completed = _tasks_by_status(recent_tasks, "已完成")
# 今日计划:当前进行中 + 待办(按优先级排序,取前 10)
in_progress = _tasks_by_status(tasks, "进行中")
todo_tasks = _tasks_by_status(tasks, "待办")
priority_order = {"紧急": 0, "高": 1, "中": 2, "低": 3}
planned = sorted(
in_progress + todo_tasks,
key=lambda t: priority_order.get(t.get("priority", "中"), 2)
)[:10]
# 阻碍事项:逾期或紧急未完成
blockers = []
for task in tasks:
if task.get("status") in ("已完成", "已关闭"):
continue
reasons = []
if task.get("due_date") and is_overdue(task["due_date"]):
reasons.append(f"逾期(截止: {task['due_date']})")
if task.get("priority") == "紧急" and task.get("status") == "进行中":
reasons.append("紧急任务进行中")
if reasons:
blockers.append({"task": task, "reasons": reasons})
# 构建 Markdown 报告
report_parts = []
title_suffix = f" — {assignee}" if assignee else ""
report_parts.append(f"# 每日站会{title_suffix} — {today_date}\n")
# 昨日完成
report_parts.append("## 昨日完成\n")
if completed:
for task in completed:
report_parts.append(f"- {_task_summary_line(task)}")
else:
report_parts.append("- (暂无完成任务)")
report_parts.append("")
# 今日计划
report_parts.append("## 今日计划\n")
if planned:
for task in planned:
status_tag = "进行中" if task.get("status") == "进行中" else "待启动"
report_parts.append(f"- [{status_tag}] {_task_summary_line(task)}")
else:
report_parts.append("- (暂无计划任务)")
report_parts.append("")
# 阻碍事项
report_parts.append("## 阻碍事项\n")
if blockers:
for item in blockers:
task = item["task"]
reasons = "、".join(item["reasons"])
report_parts.append(f"- {_task_summary_line(task)} — {reasons}")
else:
report_parts.append("- (暂无阻碍)")
report_parts.append("")
# 统计摘要
report_parts.append("---")
report_parts.append(f"完成 {len(completed)} | 计划 {len(planned)} | 阻碍 {len(blockers)}")
report_parts.append(f"\n*由 project-nerve 自动生成 — {now_iso()}*")
report_md = "\n".join(report_parts)
output_success({
"report": report_md,
"summary": {
"date": today_date,
"completed_count": len(completed),
"planned_count": len(planned),
"blocker_count": len(blockers),
},
})
def weekly_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成每周总结报告。
扫描过去 7 天的任务活动,生成周报格式:
- 本周完成
- 进行中
- 下周计划
Args:
data: 可选参数。
"""
if not require_paid_feature("standup_report", "周报生成"):
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
now = datetime.now()
week_ago = now - timedelta(days=7)
today_date = today_str()
week_start = week_ago.strftime("%Y-%m-%d")  # report period starts where the 7-day scan window starts
# 可选:按负责人过滤
assignee = ""
if data:
assignee = data.get("assignee", "").strip()
if assignee:
tasks = [t for t in tasks if assignee.lower() in (t.get("assignee") or "").lower()]
# 本周有更新的任务
recent_tasks = _tasks_updated_since(tasks, week_ago)
# 本周完成
completed = _tasks_by_status(recent_tasks, "已完成")
# 进行中
in_progress = _tasks_by_status(tasks, "进行中")
# 下周计划(待办按优先级排序取前 15)
todo_tasks = _tasks_by_status(tasks, "待办")
priority_order = {"紧急": 0, "高": 1, "中": 2, "低": 3}
next_week_plan = sorted(
todo_tasks,
key=lambda t: priority_order.get(t.get("priority", "中"), 2)
)[:15]
# 按平台统计完成情况
source_completed = {}
for task in completed:
src = task.get("source", "未知")
source_completed[src] = source_completed.get(src, 0) + 1
# 按优先级统计进行中任务
priority_in_progress = {}
for task in in_progress:
prio = task.get("priority", "中")
priority_in_progress[prio] = priority_in_progress.get(prio, 0) + 1
# 逾期统计
overdue_tasks = [
t for t in tasks
if t.get("status") not in ("已完成", "已关闭") and t.get("due_date") and is_overdue(t["due_date"])
]
# 构建 Markdown 报告
report_parts = []
title_suffix = f" — {assignee}" if assignee else ""
report_parts.append(f"# 周报{title_suffix} — {week_start} ~ {today_date}\n")
# 概览
report_parts.append("## 本周概览\n")
report_parts.append("| 指标 | 数值 |")
report_parts.append("|------|------|")
report_parts.append(f"| 本周完成 | {len(completed)} |")
report_parts.append(f"| 进行中 | {len(in_progress)} |")
report_parts.append(f"| 待办 | {len(todo_tasks)} |")
report_parts.append(f"| 逾期 | {len(overdue_tasks)} |")
report_parts.append("")
# 本周完成
report_parts.append("## 本周完成\n")
if completed:
for task in completed:
report_parts.append(f"- {_task_summary_line(task)}")
report_parts.append("")
# 平台分布
if source_completed:
report_parts.append("**按平台统计**:")
for src, cnt in sorted(source_completed.items(), key=lambda x: x[1], reverse=True):
report_parts.append(f"- {src}: {cnt} 个")
else:
report_parts.append("- (本周暂无完成任务)")
report_parts.append("")
# 进行中
report_parts.append("## 进行中\n")
if in_progress:
for task in in_progress[:10]:
report_parts.append(f"- {_task_summary_line(task)}")
if len(in_progress) > 10:
report_parts.append(f"- ...等共 {len(in_progress)} 个任务")
else:
report_parts.append("- (暂无进行中任务)")
report_parts.append("")
# 下周计划
report_parts.append("## 下周计划\n")
if next_week_plan:
for task in next_week_plan:
report_parts.append(f"- {_task_summary_line(task)}")
else:
report_parts.append("- (暂无计划任务)")
report_parts.append("")
# 风险提醒
if overdue_tasks:
report_parts.append("## 风险提醒\n")
for task in overdue_tasks[:5]:
report_parts.append(f"- {_task_summary_line(task)} — 逾期(截止: {task.get('due_date', '')})")
if len(overdue_tasks) > 5:
report_parts.append(f"- ...等共 {len(overdue_tasks)} 个逾期任务")
report_parts.append("")
report_parts.append(f"---\n*由 project-nerve 自动生成 — {now_iso()}*")
report_md = "\n".join(report_parts)
output_success({
"report": report_md,
"summary": {
"period": f"{week_start} ~ {today_date}",
"completed_count": len(completed),
"in_progress_count": len(in_progress),
"todo_count": len(todo_tasks),
"overdue_count": len(overdue_tasks),
},
})
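def generate_standup(data: Optional[Dict[str, Any]] = None) -> None:
    """Generate a standup report, dispatching on the requested type.

    The ``type`` field selects the report: "daily" (the default), "weekly",
    or "auto". With "auto", a weekly report is generated on Mondays and a
    daily report on every other day.

    Args:
        data: Optional parameters; ``type`` is read here and the dict is
            passed through to the selected report function.
    """
    report_type = "daily"
    if data:
        # Tolerate a null or non-string "type" value.
        report_type = str(data.get("type") or "daily").strip().lower()
    if report_type == "auto":
        # weekday() == 0 is Monday.
        report_type = "weekly" if datetime.now().weekday() == 0 else "daily"
    if report_type == "weekly":
        weekly_report(data)
    else:
        daily_standup(data)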
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("project-nerve 站会报告生成器")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"daily": lambda: daily_standup(data),
"weekly": lambda: weekly_report(data),
"generate": lambda: generate_standup(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/task_writer.py
#!/usr/bin/env python3
"""
project-nerve 任务写入器
在指定平台上创建、更新、移动任务和添加评论。
支持自动检测最适合的平台(含 Obsidian),也可手动指定。
"""
import json
import os
import sys
import urllib.request
import urllib.error
from typing import Any, Dict, List, Optional
from utils import (
generate_id,
get_data_file,
load_input_data,
normalize_priority,
normalize_status,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
SUPPORTED_PLATFORMS,
)
# ============================================================
# 数据文件路径
# ============================================================
SOURCES_FILE = "sources.json"
def _get_sources() -> List[Dict[str, Any]]:
"""读取所有已连接的活跃数据源。"""
data = read_json_file(get_data_file(SOURCES_FILE))
if isinstance(data, list):
return [s for s in data if s.get("status") == "active"]
return []
def _find_source_by_platform(sources: List[Dict], platform: str) -> Optional[Dict]:
"""根据平台类型查找数据源配置。"""
for s in sources:
if s.get("platform") == platform:
return s
return None
# ============================================================
# HTTP 请求工具
# ============================================================
def _http_request(
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
data: Optional[bytes] = None,
timeout: int = 15,
) -> Dict[str, Any]:
"""发送 HTTP 请求并返回响应。
Args:
url: 请求地址。
method: HTTP 方法。
headers: 请求头。
data: 请求体。
timeout: 超时秒数。
Returns:
包含 status 和 body 的响应字典。
"""
if headers is None:
headers = {}
req = urllib.request.Request(url, data=data, headers=headers, method=method)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
body = resp.read().decode("utf-8")
return {"status": resp.status, "body": body}
except urllib.error.HTTPError as e:
body = ""
try:
body = e.read().decode("utf-8")
except Exception:
pass
return {"status": e.code, "body": body, "error": str(e)}
except Exception as e:
return {"status": 0, "body": "", "error": str(e)}
# ============================================================
# 平台自动检测
# ============================================================
# 关键词到平台的映射规则
_PLATFORM_KEYWORDS: Dict[str, List[str]] = {
"github": ["bug", "fix", "pr", "pull request", "merge", "commit", "branch", "issue",
"代码", "修复", "缺陷", "分支", "合并"],
"obsidian": ["note", "notes", "wiki", "knowledge", "vault", "memo", "journal",
"笔记", "知识", "日记", "备忘", "本地"],
"notion": ["doc", "document", "design", "page", "database",
"文档", "设计", "数据库", "规划", "在线"],
"linear": ["sprint", "story", "epic", "cycle", "roadmap", "feature",
"迭代", "用户故事", "史诗", "路线图", "功能"],
"trello": [], # 默认平台
}
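def _detect_platform(title: str, description: str = "") -> str:
    """Pick the best-fitting platform from the task title and description.

    Matching rules (see _PLATFORM_KEYWORDS):
    - code / bug / PR keywords → GitHub
    - notes / knowledge-base keywords → Obsidian
    - document / design keywords → Notion
    - sprint / story keywords → Linear
    - no keyword match → Trello

    Args:
        title: Task title.
        description: Task description.

    Returns:
        The recommended platform name.
    """
    import re  # local import keeps this function self-contained

    text = f"{title} {description}".lower()

    def _hits(keyword: str) -> bool:
        if not keyword.isascii():
            # Chinese keywords: plain substring match (no word separators).
            return keyword in text
        # ASCII keywords must start at a word boundary, so "pr" no longer
        # fires inside "sprint" and "fix" no longer fires inside "prefix".
        # Very short keywords must also end at one ("pr" vs. "project").
        tail = r"(?![a-z0-9])" if len(keyword) <= 2 else ""
        return re.search(rf"(?<![a-z0-9]){re.escape(keyword)}{tail}", text) is not None

    best_platform = "trello"
    best_score = 0
    for platform, keywords in _PLATFORM_KEYWORDS.items():
        score = sum(1 for kw in keywords if _hits(kw))
        # Strict ">" means ties go to the platform listed first.
        if score > best_score:
            best_score = score
            best_platform = platform
    return best_platform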
# ============================================================
# 平台适配器 — 创建任务
# ============================================================
def _create_trello_card(source: Dict[str, Any], task_data: Dict[str, Any]) -> Dict[str, Any]:
"""在 Trello 创建卡片。
Args:
source: 数据源配置。
task_data: 任务数据(title, description, due_date 等)。
Returns:
创建结果字典。
"""
api_key = os.environ.get("PNC_TRELLO_API_KEY", "")
token = os.environ.get("PNC_TRELLO_TOKEN", "")
board_id = source.get("config", {}).get("board_id", "")
if not api_key or not token:
return {"success": False, "message": "缺少 Trello 凭据"}
# 获取第一个列表作为默认列表
list_id = task_data.get("list_id", "")
if not list_id and board_id:
url = f"https://api.trello.com/1/boards/{board_id}/lists?key={api_key}&token={token}"
resp = _http_request(url)
if resp["status"] == 200:
try:
lists = json.loads(resp["body"])
if lists:
list_id = lists[0]["id"]
except (json.JSONDecodeError, KeyError, IndexError):
pass
if not list_id:
return {"success": False, "message": "无法确定 Trello 列表,请指定 list_id 或 board_id"}
# 创建卡片
params = f"key={api_key}&token={token}"
create_url = f"https://api.trello.com/1/cards?{params}"
card_data = {
"name": task_data.get("title", ""),
"desc": task_data.get("description", ""),
"idList": list_id,
}
if task_data.get("due_date"):
card_data["due"] = task_data["due_date"]
body = json.dumps(card_data).encode("utf-8")
headers = {"Content-Type": "application/json"}
resp = _http_request(create_url, method="POST", headers=headers, data=body)
if resp["status"] in (200, 201):
try:
card = json.loads(resp["body"])
return {
"success": True,
"message": f"Trello 卡片已创建: {card.get('name', '')}",
"task_id": card.get("id", ""),
"url": card.get("shortUrl", ""),
}
except json.JSONDecodeError:
return {"success": True, "message": "Trello 卡片已创建"}
else:
return {"success": False, "message": f"Trello 创建失败 (HTTP {resp['status']}): {resp.get('error', '')}"}
def _create_github_issue(source: Dict[str, Any], task_data: Dict[str, Any]) -> Dict[str, Any]:
"""在 GitHub 创建 Issue。
Args:
source: 数据源配置。
task_data: 任务数据。
Returns:
创建结果字典。
"""
token = os.environ.get("PNC_GITHUB_TOKEN", "")
repo = source.get("config", {}).get("repo", "")
if not token or not repo:
return {"success": False, "message": "缺少 GitHub 凭据或仓库信息"}
url = f"https://api.github.com/repos/{repo}/issues"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
"Content-Type": "application/json",
"User-Agent": "project-nerve/1.0",
}
issue_data = {
"title": task_data.get("title", ""),
"body": task_data.get("description", ""),
}
# Add labels (copy the list so the caller's task_data is not mutated)
labels = list(task_data.get("labels") or [])
priority = task_data.get("priority", "")
if priority:
labels.append(f"priority:{priority}")
if labels:
issue_data["labels"] = labels
# 添加负责人
assignee = task_data.get("assignee", "")
if assignee:
issue_data["assignees"] = [assignee]
body = json.dumps(issue_data).encode("utf-8")
resp = _http_request(url, method="POST", headers=headers, data=body)
if resp["status"] in (200, 201):
try:
issue = json.loads(resp["body"])
return {
"success": True,
"message": f"GitHub Issue 已创建: #{issue.get('number', '')} {issue.get('title', '')}",
"task_id": str(issue.get("number", "")),
"url": issue.get("html_url", ""),
}
except json.JSONDecodeError:
return {"success": True, "message": "GitHub Issue 已创建"}
else:
return {"success": False, "message": f"GitHub 创建失败 (HTTP {resp['status']}): {resp.get('error', '')}"}
def _create_linear_issue(source: Dict[str, Any], task_data: Dict[str, Any]) -> Dict[str, Any]:
"""在 Linear 创建 Issue。
Args:
source: 数据源配置。
task_data: 任务数据。
Returns:
创建结果字典。
"""
api_key = os.environ.get("PNC_LINEAR_API_KEY", "")
if not api_key:
return {"success": False, "message": "缺少 Linear API Key"}
team_id = source.get("config", {}).get("team_id", "")
if not team_id:
return {"success": False, "message": "缺少 Linear Team ID,请在数据源配置中指定 team_id"}
# Linear 优先级映射: 紧急=1, 高=2, 中=3, 低=4
priority_map = {"紧急": 1, "高": 2, "中": 3, "低": 4}
priority = priority_map.get(normalize_priority(task_data.get("priority", "")), 3)
mutation = json.dumps({
"query": """mutation CreateIssue($input: IssueCreateInput!) {
issueCreate(input: $input) {
success
issue {
id
identifier
title
url
}
}
}""",
"variables": {
"input": {
"teamId": team_id,
"title": task_data.get("title", ""),
"description": task_data.get("description", ""),
"priority": priority,
}
}
})
url = "https://api.linear.app/graphql"
headers = {
"Authorization": api_key,
"Content-Type": "application/json",
}
resp = _http_request(url, method="POST", headers=headers, data=mutation.encode("utf-8"))
if resp["status"] == 200:
try:
result = json.loads(resp["body"])
issue_create = result.get("data", {}).get("issueCreate", {})
if issue_create.get("success"):
issue = issue_create.get("issue", {})
return {
"success": True,
"message": f"Linear Issue 已创建: {issue.get('identifier', '')} {issue.get('title', '')}",
"task_id": issue.get("identifier", ""),
"url": issue.get("url", ""),
}
else:
return {"success": False, "message": "Linear 创建失败: API 返回失败状态"}
except (json.JSONDecodeError, KeyError):
return {"success": False, "message": "Linear 创建失败: 响应解析错误"}
else:
return {"success": False, "message": f"Linear 创建失败 (HTTP {resp['status']}): {resp.get('error', '')}"}
def _create_notion_page(source: Dict[str, Any], task_data: Dict[str, Any]) -> Dict[str, Any]:
"""在 Notion 数据库创建页面。
Args:
source: 数据源配置。
task_data: 任务数据。
Returns:
创建结果字典。
"""
token = os.environ.get("PNC_NOTION_TOKEN", "")
database_id = source.get("config", {}).get("database_id", "") or os.environ.get("PNC_NOTION_DATABASE_ID", "")
if not token or not database_id:
return {"success": False, "message": "缺少 Notion 凭据或数据库 ID"}
url = "https://api.notion.com/v1/pages"
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
"Content-Type": "application/json",
}
# 构建页面数据
properties = {
"Name": {
"title": [
{"text": {"content": task_data.get("title", "")}}
]
}
}
page_data = {
"parent": {"database_id": database_id},
"properties": properties,
}
# 添加内容块
description = task_data.get("description", "")
if description:
page_data["children"] = [
{
"object": "block",
"type": "paragraph",
"paragraph": {
"rich_text": [{"text": {"content": description}}]
}
}
]
body = json.dumps(page_data).encode("utf-8")
resp = _http_request(url, method="POST", headers=headers, data=body)
if resp["status"] in (200, 201):
try:
page = json.loads(resp["body"])
return {
"success": True,
"message": f"Notion 页面已创建: {task_data.get('title', '')}",
"task_id": page.get("id", ""),
"url": page.get("url", ""),
}
except json.JSONDecodeError:
return {"success": True, "message": "Notion 页面已创建"}
else:
return {"success": False, "message": f"Notion 创建失败 (HTTP {resp['status']}): {resp.get('error', '')}"}
def _create_obsidian_task(source: Dict[str, Any], task_data: Dict[str, Any]) -> Dict[str, Any]:
"""在 Obsidian Vault 中创建任务笔记。
在 vault 中创建新的 markdown 文件,包含 frontmatter 元数据
和 markdown 复选框格式的任务项。
Args:
source: 数据源配置。
task_data: 任务数据(title, description, priority, due_date 等)。
Returns:
创建结果字典。
"""
vault_path = source.get("config", {}).get("vault_path", "") or os.environ.get("PNC_OBSIDIAN_VAULT_PATH", "")
task_tag = source.get("config", {}).get("task_tag", "#task")
if not vault_path:
return {"success": False, "message": "缺少 Obsidian Vault 路径"}
vault_path = os.path.expanduser(vault_path)
if not os.path.isdir(vault_path):
return {"success": False, "message": f"Obsidian Vault 路径不存在: {vault_path}"}
title = task_data.get("title", "未命名任务")
description = task_data.get("description", "")
priority = task_data.get("priority", "中")
due_date = task_data.get("due_date", "")
assignee = task_data.get("assignee", "")
labels = task_data.get("labels", [])
# 生成安全的文件名
safe_title = title.replace("/", "_").replace("\\", "_").replace(":", "_")
safe_title = safe_title.replace("\"", "").replace("*", "").replace("?", "")
safe_title = safe_title.replace("<", "").replace(">", "").replace("|", "")
if len(safe_title) > 100:
safe_title = safe_title[:100]
# 确保任务目录存在
tasks_dir = os.path.join(vault_path, "tasks")
os.makedirs(tasks_dir, exist_ok=True)
# 生成文件路径(避免重名)
filename = f"{safe_title}.md"
filepath = os.path.join(tasks_dir, filename)
counter = 1
while os.path.exists(filepath):
filename = f"{safe_title}_{counter}.md"
filepath = os.path.join(tasks_dir, filename)
counter += 1
# 构建 frontmatter
now = now_iso()
fm_lines = ["---"]
fm_lines.append(f"status: 待办")
if priority:
fm_lines.append(f"priority: {priority}")
if assignee:
fm_lines.append(f"assignee: {assignee}")
if due_date:
fm_lines.append(f"due_date: {due_date}")
fm_lines.append(f"created: {now}")
if labels:
fm_lines.append(f"tags: [{', '.join(labels)}]")
fm_lines.append("---")
fm_lines.append("")
# 构建笔记内容
content_lines = list(fm_lines)
content_lines.append(f"# {title}")
content_lines.append("")
content_lines.append(f"- [ ] {title} {task_tag}")
content_lines.append("")
if description:
content_lines.append("## 描述")
content_lines.append("")
content_lines.append(description)
content_lines.append("")
if due_date:
content_lines.append(f"📅 截止日期: {due_date}")
content_lines.append("")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(content_lines))
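from urllib.parse import quote  # local import; percent-encodes spaces, "/" and non-ASCII characters
rel_path = os.path.relpath(filepath, vault_path)
vault_name = os.path.basename(os.path.normpath(vault_path))  # normpath handles a trailing slash
obsidian_url = f"obsidian://open?vault={quote(vault_name, safe='')}&file={quote(rel_path.replace(os.sep, '/'), safe='')}"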
return {
"success": True,
"message": f"Obsidian 任务笔记已创建: {title}",
"task_id": rel_path,
"url": obsidian_url,
}
except IOError as e:
return {"success": False, "message": f"Obsidian 文件写入失败: {e}"}
# ============================================================
# 平台创建路由
# ============================================================
_PLATFORM_CREATORS = {
"trello": _create_trello_card,
"github": _create_github_issue,
"linear": _create_linear_issue,
"notion": _create_notion_page,
"obsidian": _create_obsidian_task,
}
# ============================================================
# 操作实现
# ============================================================
def create_task(data: Dict[str, Any]) -> None:
    """Create a task on a connected platform.

    Required field: title
    Optional fields: platform, description, priority, assignee, labels, due_date

    When platform is omitted, the best-fitting platform is detected from the
    title and description; if that platform is not connected, the first
    connected source is used instead. An explicitly requested platform that
    is not connected is reported as an error rather than silently replaced.

    Args:
        data: Task data dict.
    """
    title = (data.get("title") or "").strip()
    if not title:
        output_error("任务标题(title)为必填字段", code="VALIDATION_ERROR")
        return
    sources = _get_sources()
    if not sources:
        output_error("暂无已连接的数据源,请先使用 source_connector 连接平台", code="NO_SOURCES")
        return

    # Determine the target platform.
    requested = (data.get("platform") or "").strip().lower()
    platform = requested or _detect_platform(title, data.get("description") or "")

    # Check that the platform is connected.
    source = _find_source_by_platform(sources, platform)
    if not source:
        if requested:
            output_error(f"平台 {platform} 未连接", code="NOT_CONNECTED")
            return
        # The auto-detected platform is not connected: fall back to the
        # first available source.
        source = sources[0]
        platform = source.get("platform", "")

    creator = _PLATFORM_CREATORS.get(platform)
    if not creator:
        output_error(f"平台 {platform} 暂不支持创建任务", code="NOT_SUPPORTED")
        return

    # Try to import the self-learning engine (optional dependency).
    try:
        from learning_engine import quick_record_error, quick_record_success
        has_learning = True
    except ImportError:
        has_learning = False

    result = creator(source, data)
    if result["success"]:
        # Record the success pattern.
        if has_learning:
            quick_record_success(platform, "create", data.get("title", ""))
        output_success({
            "message": result["message"],
            "platform": platform,
            "task_id": result.get("task_id", ""),
            "url": result.get("url", ""),
            "auto_detected": not requested,
        })
    else:
        # Record the error pattern.
        if has_learning:
            quick_record_error(platform, "create", "create_failure", result.get("message", ""))
        output_error(result["message"], code="CREATE_FAILED")
def update_task(data: Dict[str, Any]) -> None:
"""更新任务状态或属性。
必填字段: source(平台), source_id(平台任务ID)
可选字段: status, priority, title, description
Args:
data: 包含平台和任务 ID 的更新数据字典。
"""
platform = data.get("source", "").strip().lower()
source_id = data.get("source_id", "").strip()
if not platform or not source_id:
output_error("平台(source)和任务 ID(source_id)为必填字段", code="VALIDATION_ERROR")
return
sources = _get_sources()
source = _find_source_by_platform(sources, platform)
if not source:
output_error(f"平台 {platform} 未连接", code="NOT_CONNECTED")
return
# 根据平台执行更新
if platform == "github":
_update_github_issue(source, source_id, data)
elif platform == "trello":
_update_trello_card(source, source_id, data)
elif platform == "linear":
_update_linear_issue(source, source_id, data)
elif platform == "notion":
_update_notion_page(source, source_id, data)
else:
output_error(f"平台 {platform} 暂不支持更新操作", code="NOT_SUPPORTED")
def _update_github_issue(source: Dict, source_id: str, data: Dict) -> None:
"""更新 GitHub Issue。"""
token = os.environ.get("PNC_GITHUB_TOKEN", "")
repo = source.get("config", {}).get("repo", "")
if not token or not repo:
output_error("缺少 GitHub 凭据", code="AUTH_ERROR")
return
url = f"https://api.github.com/repos/{repo}/issues/{source_id}"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
"Content-Type": "application/json",
"User-Agent": "project-nerve/1.0",
}
update_data = {}
if data.get("title"):
update_data["title"] = data["title"]
if data.get("description"):
update_data["body"] = data["description"]
if data.get("status"):
status = normalize_status(data["status"])
if status in ("已完成", "已关闭"):
update_data["state"] = "closed"
else:
update_data["state"] = "open"
if not update_data:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
body = json.dumps(update_data).encode("utf-8")
resp = _http_request(url, method="PATCH", headers=headers, data=body)
if resp["status"] == 200:
output_success({"message": f"GitHub Issue #{source_id} 已更新", "platform": "github"})
else:
output_error(f"GitHub 更新失败 (HTTP {resp['status']})", code="UPDATE_FAILED")
def _update_trello_card(source: Dict, source_id: str, data: Dict) -> None:
"""更新 Trello 卡片。"""
api_key = os.environ.get("PNC_TRELLO_API_KEY", "")
token = os.environ.get("PNC_TRELLO_TOKEN", "")
if not api_key or not token:
output_error("缺少 Trello 凭据", code="AUTH_ERROR")
return
params = f"key={api_key}&token={token}"
url = f"https://api.trello.com/1/cards/{source_id}?{params}"
update_data = {}
if data.get("title"):
update_data["name"] = data["title"]
if data.get("description"):
update_data["desc"] = data["description"]
if data.get("due_date"):
update_data["due"] = data["due_date"]
if not update_data:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
body = json.dumps(update_data).encode("utf-8")
headers = {"Content-Type": "application/json"}
resp = _http_request(url, method="PUT", headers=headers, data=body)
if resp["status"] == 200:
output_success({"message": f"Trello 卡片 {source_id} 已更新", "platform": "trello"})
else:
output_error(f"Trello 更新失败 (HTTP {resp['status']})", code="UPDATE_FAILED")
def _update_linear_issue(source: Dict, source_id: str, data: Dict) -> None:
"""更新 Linear Issue。"""
api_key = os.environ.get("PNC_LINEAR_API_KEY", "")
if not api_key:
output_error("缺少 Linear API Key", code="AUTH_ERROR")
return
input_data = {}
if data.get("title"):
input_data["title"] = data["title"]
if data.get("description"):
input_data["description"] = data["description"]
if data.get("priority"):
priority_map = {"紧急": 1, "高": 2, "中": 3, "低": 4}
input_data["priority"] = priority_map.get(normalize_priority(data["priority"]), 3)
if not input_data:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
mutation = json.dumps({
"query": """mutation UpdateIssue($id: String!, $input: IssueUpdateInput!) {
issueUpdate(id: $id, input: $input) {
success
}
}""",
"variables": {"id": source_id, "input": input_data}
})
url = "https://api.linear.app/graphql"
headers = {"Authorization": api_key, "Content-Type": "application/json"}
resp = _http_request(url, method="POST", headers=headers, data=mutation.encode("utf-8"))
if resp["status"] == 200:
output_success({"message": f"Linear Issue {source_id} 已更新", "platform": "linear"})
else:
output_error(f"Linear 更新失败 (HTTP {resp['status']})", code="UPDATE_FAILED")
def _update_notion_page(source: Dict, source_id: str, data: Dict) -> None:
"""更新 Notion 页面属性。"""
token = os.environ.get("PNC_NOTION_TOKEN", "")
if not token:
output_error("缺少 Notion Token", code="AUTH_ERROR")
return
url = f"https://api.notion.com/v1/pages/{source_id}"
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
"Content-Type": "application/json",
}
properties = {}
if data.get("title"):
properties["Name"] = {"title": [{"text": {"content": data["title"]}}]}
if not properties:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
body = json.dumps({"properties": properties}).encode("utf-8")
resp = _http_request(url, method="PATCH", headers=headers, data=body)
if resp["status"] == 200:
output_success({"message": f"Notion 页面 {source_id[:8]}... 已更新", "platform": "notion"})
else:
output_error(f"Notion 更新失败 (HTTP {resp['status']})", code="UPDATE_FAILED")
def move_task(data: Dict[str, Any]) -> None:
    """Move a task to a new status (equivalent to an update that sets status).

    Required fields: source, source_id, status

    Args:
        data: Dict containing the platform, the task ID, and the target status.
    """
    if not data.get("status"):
        output_error("目标状态(status)为必填字段", code="VALIDATION_ERROR")
        return
    update_task(data)
def comment_task(data: Dict[str, Any]) -> None:
    """Add a comment to a task.

    Required fields: source, source_id, comment

    Args:
        data: Dict containing the platform, the task ID, and the comment text.
    """
    platform = data.get("source", "").strip().lower()
    # Coerce to str first: a numeric ID (e.g. a GitHub issue number) has no .strip().
    source_id = str(data.get("source_id") or "").strip()
    comment = data.get("comment", "").strip()
    if not platform or not source_id or not comment:
        output_error("平台(source)、任务 ID(source_id)和评论内容(comment)为必填字段", code="VALIDATION_ERROR")
        return
    sources = _get_sources()
    source = _find_source_by_platform(sources, platform)
    if not source:
        output_error(f"平台 {platform} 未连接", code="NOT_CONNECTED")
        return
    # Dispatch to the platform-specific comment helper.
    if platform == "github":
        _comment_github_issue(source, source_id, comment)
    elif platform == "trello":
        _comment_trello_card(source, source_id, comment)
    elif platform == "linear":
        _comment_linear_issue(source, source_id, comment)
    elif platform == "notion":
        _comment_notion_page(source, source_id, comment)
    else:
        output_error(f"平台 {platform} 暂不支持评论功能", code="NOT_SUPPORTED")
def _comment_github_issue(source: Dict, source_id: str, comment: str) -> None:
"""给 GitHub Issue 添加评论。"""
token = os.environ.get("PNC_GITHUB_TOKEN", "")
repo = source.get("config", {}).get("repo", "")
if not token or not repo:
output_error("缺少 GitHub 凭据", code="AUTH_ERROR")
return
url = f"https://api.github.com/repos/{repo}/issues/{source_id}/comments"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
"Content-Type": "application/json",
"User-Agent": "project-nerve/1.0",
}
body = json.dumps({"body": comment}).encode("utf-8")
resp = _http_request(url, method="POST", headers=headers, data=body)
if resp["status"] in (200, 201):
output_success({"message": f"已在 GitHub Issue #{source_id} 添加评论", "platform": "github"})
else:
output_error(f"GitHub 评论失败 (HTTP {resp['status']})", code="COMMENT_FAILED")
def _comment_trello_card(source: Dict, source_id: str, comment: str) -> None:
"""给 Trello 卡片添加评论。"""
api_key = os.environ.get("PNC_TRELLO_API_KEY", "")
token = os.environ.get("PNC_TRELLO_TOKEN", "")
if not api_key or not token:
output_error("缺少 Trello 凭据", code="AUTH_ERROR")
return
params = f"key={api_key}&token={token}&text={urllib.request.quote(comment)}"
url = f"https://api.trello.com/1/cards/{source_id}/actions/comments?{params}"
resp = _http_request(url, method="POST")
if resp["status"] in (200, 201):
output_success({"message": f"已在 Trello 卡片 {source_id} 添加评论", "platform": "trello"})
else:
output_error(f"Trello 评论失败 (HTTP {resp['status']})", code="COMMENT_FAILED")
def _comment_linear_issue(source: Dict, source_id: str, comment: str) -> None:
"""给 Linear Issue 添加评论。"""
api_key = os.environ.get("PNC_LINEAR_API_KEY", "")
if not api_key:
output_error("缺少 Linear API Key", code="AUTH_ERROR")
return
mutation = json.dumps({
"query": """mutation CreateComment($input: CommentCreateInput!) {
commentCreate(input: $input) {
success
}
}""",
"variables": {"input": {"issueId": source_id, "body": comment}}
})
url = "https://api.linear.app/graphql"
headers = {"Authorization": api_key, "Content-Type": "application/json"}
resp = _http_request(url, method="POST", headers=headers, data=mutation.encode("utf-8"))
if resp["status"] == 200:
output_success({"message": f"已在 Linear Issue {source_id} 添加评论", "platform": "linear"})
else:
output_error(f"Linear 评论失败 (HTTP {resp['status']})", code="COMMENT_FAILED")
def _comment_notion_page(source: Dict, source_id: str, comment: str) -> None:
"""在 Notion 页面添加评论(作为子块追加)。"""
token = os.environ.get("PNC_NOTION_TOKEN", "")
if not token:
output_error("缺少 Notion Token", code="AUTH_ERROR")
return
url = f"https://api.notion.com/v1/blocks/{source_id}/children"
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
"Content-Type": "application/json",
}
body = json.dumps({
"children": [{
"object": "block",
"type": "paragraph",
"paragraph": {
"rich_text": [{"text": {"content": f"[评论] {comment}"}}]
}
}]
}).encode("utf-8")
resp = _http_request(url, method="PATCH", headers=headers, data=body)
if resp["status"] == 200:
output_success({"message": f"已在 Notion 页面 {source_id[:8]}... 添加评论", "platform": "notion"})
else:
output_error(f"Notion 评论失败 (HTTP {resp['status']})", code="COMMENT_FAILED")
# ============================================================
# Main entry point
# ============================================================
def main() -> None:
    """Main function: parse command-line arguments and dispatch the action."""
    parser = parse_common_args("project-nerve 任务写入器")
    args = parser.parse_args()
    action = args.action.lower()
    try:
        data = load_input_data(args)
    except ValueError as e:
        output_error(str(e), code="INPUT_ERROR")
        return
    actions = {
        "create": lambda: create_task(data or {}),
        "update": lambda: update_task(data or {}),
        "move": lambda: move_task(data or {}),
        "comment": lambda: comment_task(data or {}),
    }
    handler = actions.get(action)
    if handler:
        handler()
    else:
        valid_actions = "、".join(actions.keys())
        output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")


if __name__ == "__main__":
    main()
FILE:scripts/sprint_analyzer.py
#!/usr/bin/env python3
"""
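project-nerve Sprint analyzer (paid feature)

Provides sprint velocity calculation, task funnel analysis, burndown chart data, and a combined sprint report.
All features require a paid subscription.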
"""
import json
import math
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
format_task_table,
generate_bar_chart,
generate_line_chart,
generate_pie_chart,
get_data_file,
load_input_data,
normalize_status,
now_iso,
output_error,
output_success,
parse_common_args,
parse_date,
read_json_file,
require_paid_feature,
today_str,
)
# ============================================================
# Data file paths
# ============================================================
TASKS_CACHE_FILE = "tasks_cache.json"


def _get_cached_tasks() -> List[Dict[str, Any]]:
    """Read the cached task data; return an empty list if the cache is not a list."""
    data = read_json_file(get_data_file(TASKS_CACHE_FILE))
    if isinstance(data, list):
        return data
    return []
# ============================================================
# Date-range helpers
# ============================================================
def _parse_date_range(data: Dict[str, Any]) -> tuple:
    """Parse the date-range parameters.

    Accepts either start_date/end_date or days (the most recent N days).
    An unparseable start_date falls back to 14 days ago, and an unparseable
    end_date falls back to now.

    Args:
        data: Dict containing the date-range parameters.

    Returns:
        A (start_dt, end_dt) tuple.
    """
    now = datetime.now()
    if data.get("start_date") and data.get("end_date"):
        start_dt = parse_date(data["start_date"]) or (now - timedelta(days=14))
        end_dt = parse_date(data["end_date"]) or now
    else:
        days = int(data.get("days", 14))
        start_dt = now - timedelta(days=days)
        end_dt = now
    return start_dt, end_dt
def _filter_by_date_range(
    tasks: List[Dict[str, Any]],
    start_dt: datetime,
    end_dt: datetime,
    date_field: str = "updated_at",
) -> List[Dict[str, Any]]:
    """Filter tasks by date range.

    Args:
        tasks: List of tasks.
        start_dt: Start of the range (inclusive).
        end_dt: End of the range (inclusive).
        date_field: Name of the date field to compare.

    Returns:
        The filtered task list. Tasks whose date cannot be parsed are dropped.
    """
    result = []
    for task in tasks:
        dt = parse_date(task.get(date_field, ""))
        if dt and start_dt <= dt <= end_dt:
            result.append(task)
    return result
# ============================================================
# 操作实现
# ============================================================
def velocity(data: Dict[str, Any]) -> None:
"""计算冲刺速度。
速度 = 指定时间范围内已完成的任务数量。
同时计算每日平均完成量和趋势。
Args:
data: 包含时间范围参数的字典。
"""
if not require_paid_feature("sprint_analytics", "冲刺分析"):
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
start_dt, end_dt = _parse_date_range(data)
total_days = max((end_dt - start_dt).days, 1)
# 过滤时间范围内的已完成任务
completed_tasks = [
t for t in tasks
if t.get("status") == "已完成"
]
completed_in_range = _filter_by_date_range(completed_tasks, start_dt, end_dt)
# 按天统计完成数量
daily_counts = {}
for task in completed_in_range:
dt = parse_date(task.get("updated_at", ""))
if dt:
day_key = dt.strftime("%Y-%m-%d")
daily_counts[day_key] = daily_counts.get(day_key, 0) + 1
# 填充空日期
current = start_dt
daily_data = []
while current <= end_dt:
day_key = current.strftime("%Y-%m-%d")
count = daily_counts.get(day_key, 0)
daily_data.append({"label": day_key[-5:], "value": count}) # MM-DD 格式
current += timedelta(days=1)
total_completed = len(completed_in_range)
daily_avg = round(total_completed / total_days, 1) if total_days > 0 else 0
# 按平台统计
source_stats = {}
for task in completed_in_range:
src = task.get("source", "未知")
source_stats[src] = source_stats.get(src, 0) + 1
# 生成速度趋势图
chart = ""
if daily_data:
chart = generate_line_chart("每日完成任务数趋势", daily_data, x_label="日期", y_label="完成数")
output_success({
"sprint_period": f"{start_dt.strftime('%Y-%m-%d')} ~ {end_dt.strftime('%Y-%m-%d')}",
"total_days": total_days,
"total_completed": total_completed,
"daily_average": daily_avg,
"source_stats": source_stats,
"daily_data": daily_data,
"chart": chart,
})
def funnel(data: Dict[str, Any]) -> None:
"""任务漏斗分析。
统计各状态的任务数量和占比:待办 → 进行中 → 已完成。
Args:
data: 可选参数(可指定 platform 过滤)。
"""
if not require_paid_feature("sprint_analytics", "冲刺分析"):
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
# 可选平台过滤
platform = data.get("platform", "").strip().lower() if data else ""
if platform:
tasks = [t for t in tasks if t.get("source") == platform]
total = len(tasks)
if total == 0:
output_error("过滤后无任务数据", code="NO_DATA")
return
# 统计各状态
status_counts = {"待办": 0, "进行中": 0, "已完成": 0, "已关闭": 0}
for task in tasks:
status = task.get("status", "待办")
if status in status_counts:
status_counts[status] += 1
else:
status_counts["待办"] += 1
# 计算百分比
funnel_data = []
for status in ["待办", "进行中", "已完成", "已关闭"]:
count = status_counts[status]
pct = round(count / total * 100, 1) if total > 0 else 0
funnel_data.append({
"status": status,
"count": count,
"percentage": pct,
})
# 生成饼图
pie_data = [{"label": item["status"], "value": item["count"]} for item in funnel_data if item["count"] > 0]
pie_chart = generate_pie_chart("任务状态分布", pie_data)
# 生成柱状图
bar_data = [{"label": item["status"], "value": item["count"]} for item in funnel_data]
bar_chart = generate_bar_chart("任务漏斗", bar_data, x_label="状态", y_label="数量")
# 计算转化率
todo_count = status_counts["待办"]
in_progress_count = status_counts["进行中"]
done_count = status_counts["已完成"]
# max(..., 1) avoids ZeroDivisionError when every task is 已关闭 (total > 0 but no open tasks)
start_to_progress = round(in_progress_count / max(todo_count + in_progress_count + done_count, 1) * 100, 1)
progress_to_done = round(done_count / max(in_progress_count + done_count, 1) * 100, 1)
output_success({
"total": total,
"funnel": funnel_data,
"conversion": {
"启动率": f"{start_to_progress}%",
"完成率": f"{progress_to_done}%",
},
"pie_chart": pie_chart,
"bar_chart": bar_chart,
})
def burndown(data: Dict[str, Any]) -> None:
"""生成燃尽图数据。
计算指定时间范围内,每天剩余待完成任务的数量变化。
Args:
data: 包含时间范围参数的字典。
"""
if not require_paid_feature("mermaid_chart", "Mermaid 图表"):
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
start_dt, end_dt = _parse_date_range(data)
total_days = max((end_dt - start_dt).days, 1)
# 计算总任务数(排除已关闭)
active_tasks = [t for t in tasks if t.get("status") != "已关闭"]
total_tasks = len(active_tasks)
# 统计每天完成数量(累计)
completed_by_day = {}
for task in active_tasks:
if task.get("status") == "已完成":
dt = parse_date(task.get("updated_at", ""))
if dt:
day_key = dt.strftime("%Y-%m-%d")
completed_by_day[day_key] = completed_by_day.get(day_key, 0) + 1
# 生成燃尽数据
burndown_data = []
ideal_data = []
cumulative_completed = 0
current = start_dt
day_index = 0
while current <= end_dt:
day_key = current.strftime("%Y-%m-%d")
cumulative_completed += completed_by_day.get(day_key, 0)
remaining = total_tasks - cumulative_completed
burndown_data.append({
"label": day_key[-5:],
"value": max(remaining, 0),
})
# 理想燃尽线
ideal_remaining = total_tasks - (total_tasks * day_index / total_days)
ideal_data.append({
"label": day_key[-5:],
"value": round(max(ideal_remaining, 0), 1),
})
current += timedelta(days=1)
day_index += 1
# 生成 Mermaid xychart-beta(实际 + 理想)
chart_lines = ["```mermaid", "xychart-beta", ' title "冲刺燃尽图"']
labels = [f'"{d["label"]}"' for d in burndown_data]
chart_lines.append(f' x-axis [{", ".join(labels)}]')
chart_lines.append(' y-axis "剩余任务数"')
actual_values = [str(d["value"]) for d in burndown_data]
chart_lines.append(f' line [{", ".join(actual_values)}]')
ideal_values = [str(d["value"]) for d in ideal_data]
chart_lines.append(f' line [{", ".join(ideal_values)}]')
chart_lines.append("```")
chart = "\n".join(chart_lines)
output_success({
"sprint_period": f"{start_dt.strftime('%Y-%m-%d')} ~ {end_dt.strftime('%Y-%m-%d')}",
"total_tasks": total_tasks,
"completed": cumulative_completed,
"remaining": total_tasks - cumulative_completed,
"burndown_data": burndown_data,
"ideal_data": ideal_data,
"chart": chart,
})
def sprint_report(data: Dict[str, Any]) -> None:
"""生成综合冲刺报告。
包含速度、漏斗、燃尽等多维度分析的 Markdown 报告。
Args:
data: 包含时间范围参数的字典。
"""
if not require_paid_feature("sprint_analytics", "冲刺分析"):
return
tasks = _get_cached_tasks()
if not tasks:
output_error("暂无缓存任务数据,请先执行 task_aggregator fetch-all 获取任务", code="NO_DATA")
return
start_dt, end_dt = _parse_date_range(data)
total_days = max((end_dt - start_dt).days, 1)
period_str = f"{start_dt.strftime('%Y-%m-%d')} ~ {end_dt.strftime('%Y-%m-%d')}"
# 排除已关闭
active_tasks = [t for t in tasks if t.get("status") != "已关闭"]
total_tasks = len(active_tasks)
# 状态统计
status_counts = {"待办": 0, "进行中": 0, "已完成": 0}
for task in active_tasks:
status = task.get("status", "待办")
if status in status_counts:
status_counts[status] += 1
# 范围内完成的任务
completed_in_range = _filter_by_date_range(
[t for t in active_tasks if t.get("status") == "已完成"],
start_dt, end_dt
)
velocity_value = len(completed_in_range)
daily_avg = round(velocity_value / total_days, 1)
# 优先级分布
priority_counts = {"紧急": 0, "高": 0, "中": 0, "低": 0}
for task in active_tasks:
prio = task.get("priority", "中")
if prio in priority_counts:
priority_counts[prio] += 1
# 平台分布
source_counts = {}
for task in active_tasks:
src = task.get("source", "未知")
source_counts[src] = source_counts.get(src, 0) + 1
# 逾期任务
from utils import is_overdue as _is_overdue
overdue_tasks = [
t for t in active_tasks
if t.get("status") not in ("已完成",) and t.get("due_date") and _is_overdue(t["due_date"])
]
# 构建 Markdown 报告
report_parts = []
report_parts.append(f"# 冲刺报告 — {period_str}\n")
report_parts.append(f"统计周期: {period_str} | 总计 {total_days} 天\n")
# 核心指标
report_parts.append("## 核心指标\n")
report_parts.append("| 指标 | 数值 |")
report_parts.append("|------|------|")
report_parts.append(f"| 总任务数 | {total_tasks} |")
report_parts.append(f"| 本期完成 | {velocity_value} |")
report_parts.append(f"| 日均完成 | {daily_avg} |")
report_parts.append(f"| 待办任务 | {status_counts['待办']} |")
report_parts.append(f"| 进行中 | {status_counts['进行中']} |")
report_parts.append(f"| 逾期任务 | {len(overdue_tasks)} |")
report_parts.append("")
# 完成率
completion_rate = round(status_counts["已完成"] / total_tasks * 100, 1) if total_tasks > 0 else 0
report_parts.append(f"**整体完成率**: {completion_rate}%\n")
# 状态分布饼图
status_pie_data = [
{"label": k, "value": v} for k, v in status_counts.items() if v > 0
]
if status_pie_data:
report_parts.append("## 状态分布\n")
report_parts.append(generate_pie_chart("任务状态分布", status_pie_data))
report_parts.append("")
# 优先级分布柱状图
priority_bar_data = [
{"label": k, "value": v} for k, v in priority_counts.items() if v > 0
]
if priority_bar_data:
report_parts.append("## 优先级分布\n")
report_parts.append(generate_bar_chart("任务优先级分布", priority_bar_data, x_label="优先级", y_label="数量"))
report_parts.append("")
# 平台分布
if source_counts:
report_parts.append("## 平台分布\n")
report_parts.append("| 平台 | 任务数 | 占比 |")
report_parts.append("|------|--------|------|")
for src, cnt in sorted(source_counts.items(), key=lambda x: x[1], reverse=True):
pct = round(cnt / total_tasks * 100, 1) if total_tasks > 0 else 0
report_parts.append(f"| {src} | {cnt} | {pct}% |")
report_parts.append("")
# 逾期任务列表
if overdue_tasks:
report_parts.append("## 逾期任务\n")
report_parts.append(format_task_table(overdue_tasks[:10]))
if len(overdue_tasks) > 10:
report_parts.append(f"\n> 仅显示前 10 个,共 {len(overdue_tasks)} 个逾期任务。")
report_parts.append("")
# 建议
report_parts.append("## 改进建议\n")
suggestion_idx = 1
if len(overdue_tasks) > 0:
report_parts.append(f"{suggestion_idx}. 当前有 {len(overdue_tasks)} 个逾期任务,建议优先处理或重新评估截止日期。")
suggestion_idx += 1
if priority_counts.get("紧急", 0) > 3:
report_parts.append(f"{suggestion_idx}. 紧急任务数量较多({priority_counts['紧急']} 个),建议评估是否存在优先级膨胀问题。")
suggestion_idx += 1
if completion_rate < 50:
report_parts.append(f"{suggestion_idx}. 整体完成率仅 {completion_rate}%,建议分析瓶颈环节并调整冲刺范围。")
suggestion_idx += 1
if daily_avg < 1:
report_parts.append(f"{suggestion_idx}. 日均完成量较低({daily_avg}),建议检查任务拆分粒度或团队产能。")
suggestion_idx += 1
if suggestion_idx == 1:
report_parts.append("各项指标表现良好,建议继续保持当前节奏。")
report_parts.append("")
report_parts.append(f"---\n*报告由 project-nerve 自动生成 — {now_iso()}*")
report_md = "\n".join(report_parts)
output_success({
"report": report_md,
"summary": {
"period": period_str,
"total_tasks": total_tasks,
"velocity": velocity_value,
"daily_average": daily_avg,
"completion_rate": completion_rate,
"overdue_count": len(overdue_tasks),
},
})
# ============================================================
# Main entry point
# ============================================================
def main() -> None:
    """Main function: parse command-line arguments and dispatch the action."""
    parser = parse_common_args("project-nerve Sprint 分析器")
    args = parser.parse_args()
    action = args.action.lower()
    try:
        data = load_input_data(args)
    except ValueError as e:
        output_error(str(e), code="INPUT_ERROR")
        return
    actions = {
        "velocity": lambda: velocity(data or {}),
        "funnel": lambda: funnel(data or {}),
        "burndown": lambda: burndown(data or {}),
        "report": lambda: sprint_report(data or {}),
    }
    handler = actions.get(action)
    if handler:
        handler()
    else:
        valid_actions = "、".join(actions.keys())
        output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")


if __name__ == "__main__":
    main()
FILE:references/api-guide.md
# Platform API Guide
Reference for the platform API endpoints used by project-nerve.
---
## Trello API
| Operation | Method | Endpoint |
|------|------|------|
| Get user info | GET | `https://api.trello.com/1/members/me?key={key}&token={token}` |
| Get the lists on a board | GET | `https://api.trello.com/1/boards/{boardId}/lists?key={key}&token={token}` |
| Get the cards on a board | GET | `https://api.trello.com/1/boards/{boardId}/cards?key={key}&token={token}` |
| Create a card | POST | `https://api.trello.com/1/cards?key={key}&token={token}` |
| Update a card | PUT | `https://api.trello.com/1/cards/{cardId}?key={key}&token={token}` |
| Add a comment | POST | `https://api.trello.com/1/cards/{cardId}/actions/comments?key={key}&token={token}&text={text}` |
**Authentication**: pass `key` and `token` as query parameters.
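The query-parameter authentication above can be sketched as a small URL builder. This is a minimal illustration, not project-nerve's own code: `trello_url` and `TRELLO_BASE` are hypothetical names, and the credentials are placeholders. Using `urlencode` keeps free text such as a comment body safely escaped.

```python
from urllib.parse import urlencode

TRELLO_BASE = "https://api.trello.com/1"  # base URL taken from the table above


def trello_url(path: str, key: str, token: str, **extra: str) -> str:
    """Build a Trello URL with key/token (and any extra fields) as query parameters."""
    query = urlencode({"key": key, "token": token, **extra})
    return f"{TRELLO_BASE}/{path.lstrip('/')}?{query}"


# Placeholder credentials; real values would come from environment variables.
url = trello_url("cards/abc123/actions/comments", "KEY", "TOKEN", text="looks good & done")
print(url)
```

Because the credentials travel in the URL, such URLs should be kept out of logs and error messages.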
---
## GitHub Issues API
| Operation | Method | Endpoint |
|------|------|------|
| Get user info | GET | `https://api.github.com/user` |
| List repository issues | GET | `https://api.github.com/repos/{owner}/{repo}/issues` |
| Create an issue | POST | `https://api.github.com/repos/{owner}/{repo}/issues` |
| Update an issue | PATCH | `https://api.github.com/repos/{owner}/{repo}/issues/{number}` |
| Add a comment | POST | `https://api.github.com/repos/{owner}/{repo}/issues/{number}/comments` |
**Authentication**: `Authorization: Bearer {token}` request header.
**Required request headers**: `Accept: application/vnd.github.v3+json`, `User-Agent: project-nerve/1.0`.
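As a rough illustration of the headers listed above, the sketch below builds (but does not send) a PATCH request with the standard library. `build_issue_update` is a hypothetical helper name, and the owner, repository, issue number, and token are placeholders.

```python
import json
import urllib.request


def build_issue_update(owner: str, repo: str, number: int, token: str, fields: dict) -> urllib.request.Request:
    """Build (but do not send) a PATCH request that updates a GitHub issue."""
    return urllib.request.Request(
        f"https://api.github.com/repos/{owner}/{repo}/issues/{number}",
        data=json.dumps(fields).encode("utf-8"),
        method="PATCH",
        headers={
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github.v3+json",
            "Content-Type": "application/json",
            "User-Agent": "project-nerve/1.0",
        },
    )


# Placeholder values for illustration only.
req = build_issue_update("octocat", "hello-world", 42, "TOKEN", {"state": "closed"})
print(req.get_method(), req.full_url)
```

Passing the request to `urllib.request.urlopen(req, timeout=15)` would actually send it.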
---
## Linear GraphQL API
| Operation | Method | Endpoint |
|------|------|------|
| All operations | POST | `https://api.linear.app/graphql` |
**Authentication**: `Authorization: {api_key}` request header (no Bearer prefix).
Common queries:
- `{ viewer { id name email } }`: get the current user
- `{ issues(first:100) { nodes { id identifier title ... } } }`: list issues
- `mutation issueCreate(input: {...}) { ... }`: create an issue
- `mutation issueUpdate(id: "...", input: {...}) { ... }`: update an issue
- `mutation commentCreate(input: {issueId: "...", body: "..."}) { ... }`: add a comment
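The sketch below shows one way to assemble the POST body and headers for the `issueUpdate` mutation listed above. The helper names (`linear_payload`, `linear_headers`, `graphql_succeeded`) are hypothetical. The last helper rests on an assumption about general GraphQL behavior: servers commonly answer HTTP 200 even when the operation failed and report the problem in an `errors` array, so checking the HTTP status alone may not be enough.

```python
import json

ISSUE_UPDATE = """
mutation UpdateIssue($id: String!, $input: IssueUpdateInput!) {
  issueUpdate(id: $id, input: $input) { success }
}
"""


def linear_payload(query: str, variables: dict) -> bytes:
    """Serialize a GraphQL document and its variables into a POST body."""
    return json.dumps({"query": query, "variables": variables}).encode("utf-8")


def linear_headers(api_key: str) -> dict:
    """The API key goes into Authorization as-is, without a Bearer prefix."""
    return {"Authorization": api_key, "Content-Type": "application/json"}


def graphql_succeeded(payload: dict, field: str) -> bool:
    """Return True only if there is no errors array and <field>.success is true."""
    if payload.get("errors"):
        return False
    result = (payload.get("data") or {}).get(field) or {}
    return result.get("success") is True


# Placeholder ID and fields for illustration only.
body = linear_payload(ISSUE_UPDATE, {"id": "ISSUE_ID", "input": {"title": "New title", "priority": 2}})
print(json.loads(body)["variables"])
```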
---
## Notion API
| Operation | Method | Endpoint |
|------|------|------|
| Get database info | GET | `https://api.notion.com/v1/databases/{database_id}` |
| Query a database | POST | `https://api.notion.com/v1/databases/{database_id}/query` |
| Create a page | POST | `https://api.notion.com/v1/pages` |
| Update page properties | PATCH | `https://api.notion.com/v1/pages/{page_id}` |
| Append child blocks | PATCH | `https://api.notion.com/v1/blocks/{block_id}/children` |
**Authentication**: `Authorization: Bearer {token}` request header.
**Required request headers**: `Notion-Version: 2022-06-28`, `Content-Type: application/json`.
---
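## General Notes
1. All requests use HTTPS.
2. The timeout is set to 15 seconds.
3. Requests are sent with the Python standard library module `urllib.request`.
4. API keys and tokens are read only from environment variables and are never hard-coded.
5. Error responses are wrapped in a uniform `{success: false, error: {code, message}}` format.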
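The notes above can be combined into one small helper. This is a sketch under stated assumptions rather than the project's actual `_http_request`: the function names and the error codes `INSECURE_URL`, `HTTP_ERROR`, and `REQUEST_FAILED` are invented for illustration.

```python
import json
import urllib.error
import urllib.request
from typing import Optional

TIMEOUT_SECONDS = 15  # matches the timeout stated above


def error_envelope(code: str, message: str) -> dict:
    """Wrap an error in the uniform {success, error: {code, message}} shape."""
    return {"success": False, "error": {"code": code, "message": message}}


def fetch_json(url: str, headers: Optional[dict] = None) -> dict:
    """GET a URL and return parsed JSON, or an error envelope on failure."""
    if not url.startswith("https://"):
        return error_envelope("INSECURE_URL", "only HTTPS URLs are allowed")
    req = urllib.request.Request(url, headers=headers or {})
    try:
        with urllib.request.urlopen(req, timeout=TIMEOUT_SECONDS) as resp:
            return {"success": True, "data": json.loads(resp.read().decode("utf-8"))}
    except urllib.error.HTTPError as exc:  # must precede URLError, its parent class
        return error_envelope("HTTP_ERROR", f"HTTP {exc.code}")
    except (urllib.error.URLError, TimeoutError, ValueError) as exc:
        return error_envelope("REQUEST_FAILED", str(exc))
```

Callers can then branch on `result["success"]` without handling exceptions themselves.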
FILE:references/unified-schema.md
# Unified Task Model
project-nerve maps task data from the different platforms onto a single unified task model.
---
## Unified Task Fields
| Field | Type | Description |
|------|------|------|
| `id` | string | Unified ID (format: `{platform}-{source_id}`) |
| `source` | string | Source platform (trello / github / linear / notion) |
| `source_id` | string | Original platform ID (Trello cardId, GitHub issue number, Linear identifier, Notion pageId) |
| `title` | string | Task title |
| `description` | string | Task description (at most 500 characters) |
| `status` | string | Unified status (待办 / 进行中 / 已完成 / 已关闭) |
| `priority` | string | Unified priority (紧急 / 高 / 中 / 低) |
| `assignee` | string | Assignee |
| `labels` | list[string] | List of labels |
| `due_date` | string | Due date (YYYY-MM-DD) |
| `created_at` | string | Creation time (ISO format) |
| `updated_at` | string | Last update time (ISO format) |
| `url` | string | Original link on the platform |
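One way to picture the field table is as a dataclass. The sketch below is illustrative only: the class name `UnifiedTask`, the default values, and the choice to enforce the 500-character limit by truncation are all assumptions, since the table states only the field names, types, and the ID format.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class UnifiedTask:
    source: str                 # trello / github / linear / notion
    source_id: str              # original platform ID
    title: str
    description: str = ""
    status: str = "待办"         # assumed default
    priority: str = "中"         # assumed default
    assignee: str = ""
    labels: List[str] = field(default_factory=list)
    due_date: str = ""          # YYYY-MM-DD
    created_at: str = ""        # ISO format
    updated_at: str = ""        # ISO format
    url: str = ""

    def __post_init__(self) -> None:
        # Assumption: overlong descriptions are truncated to 500 characters.
        self.description = self.description[:500]

    @property
    def id(self) -> str:
        """Unified ID in the documented {platform}-{source_id} format."""
        return f"{self.source}-{self.source_id}"


task = UnifiedTask("github", "42", "Fix login", description="x" * 600)
print(task.id, len(task.description))
```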
---
## Status Mapping
| Unified status | Trello list name | GitHub State | Linear State | Notion Status |
|----------|-------------|-------------|--------------|---------------|
| 待办 (to do) | To Do / Backlog | open | Backlog / Todo / Triage | Not Started / To Do |
| 进行中 (in progress) | Doing / In Progress | open (assigned) | In Progress / Started | In Progress |
| 已完成 (done) | Done / Completed | closed (merged) | Done / Completed | Done / Completed |
| 已关闭 (closed) | Archived | closed | Cancelled / Duplicate | Archived |
### Status normalization rules
Input string (case-insensitive) → unified status:
- `todo`, `to do`, `backlog`, `open`, `new`, `not started`, `triage` → **待办**
- `in progress`, `in_progress`, `doing`, `started`, `active`, `in review` → **进行中**
- `done`, `completed`, `resolved`, `merged` → **已完成**
- `closed`, `cancelled`, `canceled`, `archived`, `duplicate` → **已关闭**
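The normalization rules above amount to a case-insensitive lookup table. The following is a minimal sketch, not the project's own `normalize_status`: the function name is hypothetical, and falling back to 待办 for unrecognized input is an assumption, because the rules do not state a default.

```python
_STATUS_ALIASES = {
    "待办": ("todo", "to do", "backlog", "open", "new", "not started", "triage"),
    "进行中": ("in progress", "in_progress", "doing", "started", "active", "in review"),
    "已完成": ("done", "completed", "resolved", "merged"),
    "已关闭": ("closed", "cancelled", "canceled", "archived", "duplicate"),
}
# Invert the table once: alias -> unified status.
_STATUS_LOOKUP = {
    alias: unified for unified, aliases in _STATUS_ALIASES.items() for alias in aliases
}


def normalize_status_sketch(raw: str, default: str = "待办") -> str:
    """Map a platform status string to a unified status (case-insensitive)."""
    key = raw.strip().lower()
    if key in _STATUS_ALIASES:  # already a unified status
        return key
    return _STATUS_LOOKUP.get(key, default)  # assumed fallback for unknown input


print(normalize_status_sketch("In Progress"), normalize_status_sketch("MERGED"))
```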
---
## Priority Mapping
| Unified priority | Trello label | GitHub label | Linear value | Notion Select |
|-----------|-------------|-------------|-------------|---------------|
| 紧急 (urgent) | urgent / 紧急 | P0 / critical / blocker | 1 (Urgent) | 紧急 / Urgent |
| 高 (high) | high / 高 | P1 / high / important | 2 (High) | 高 / High |
| 中 (medium) | medium / 中 | P2 / medium / normal | 3 (Medium) | 中 / Medium |
| 低 (low) | low / 低 | P3 / low / minor | 4 (Low) / 0 (None) | 低 / Low |
### Priority normalization rules
Input string (case-insensitive) → unified priority:
- `urgent`, `critical`, `p0`, `highest`, `blocker`, `紧急` → **紧急**
- `high`, `p1`, `important`, `高` → **高**
- `medium`, `normal`, `p2`, `default`, `中` → **中**
- `low`, `minor`, `p3`, `trivial`, `none`, `低` → **低**
- Unrecognized input → defaults to **中**
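A possible reading of the priority rules, including the Linear numeric values from the mapping table, is sketched below. The names `normalize_priority_sketch`, `LINEAR_TO_UNIFIED`, and `UNIFIED_TO_LINEAR` are hypothetical, and accepting integers directly is an assumption made to show the Linear column.

```python
from typing import Union

_PRIORITY_ALIASES = {
    "紧急": ("urgent", "critical", "p0", "highest", "blocker"),
    "高": ("high", "p1", "important"),
    "中": ("medium", "normal", "p2", "default"),
    "低": ("low", "minor", "p3", "trivial", "none"),
}
_PRIORITY_LOOKUP = {
    alias: unified for unified, aliases in _PRIORITY_ALIASES.items() for alias in aliases
}
LINEAR_TO_UNIFIED = {1: "紧急", 2: "高", 3: "中", 4: "低", 0: "低"}
UNIFIED_TO_LINEAR = {"紧急": 1, "高": 2, "中": 3, "低": 4}


def normalize_priority_sketch(raw: Union[str, int]) -> str:
    """Map a label or a Linear numeric priority to a unified priority."""
    if isinstance(raw, int):
        return LINEAR_TO_UNIFIED.get(raw, "中")
    key = raw.strip().lower()
    if key in _PRIORITY_ALIASES:  # already a unified priority
        return key
    return _PRIORITY_LOOKUP.get(key, "中")  # documented default for unknown input


print(normalize_priority_sketch("P0"), normalize_priority_sketch(0))
```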
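Knowledge Mesh - cross-platform knowledge search aggregator that searches GitHub, Stack Overflow, Discord, Confluence, Notion, Slack, Baidu, and Obsidian in one place, with self-learning ranking
---
name: knowledge-mesh
description: Knowledge Mesh - cross-platform knowledge search aggregator that searches GitHub, Stack Overflow, Discord, Confluence, Notion, Slack, Baidu, and Obsidian in one place, with self-learning ranking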
version: 1.1.0
metadata:
openclaw:
optional_env:
- KM_GITHUB_TOKEN
- KM_STACKOVERFLOW_KEY
- KM_DISCORD_BOT_TOKEN
- KM_CONFLUENCE_URL
- KM_CONFLUENCE_TOKEN
- KM_NOTION_TOKEN
- KM_SLACK_TOKEN
- KM_BAIDU_API_KEY
- KM_OBSIDIAN_VAULT_PATH
- KM_SUBSCRIPTION_TIER
- KM_DISCORD_CHANNEL_ID
- KM_DATA_DIR
---
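# Knowledge Mesh (knowledge-mesh)
You are a professional cross-platform knowledge search assistant agent. Your job is to help users run unified searches across multiple knowledge platforms, aggregate the results, analyze trends, and manage knowledge. Always communicate with the user in Chinese.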
## Environment Variables
| Variable | Required | Description |
|------|------|------|
| `KM_GITHUB_TOKEN` | No | GitHub Personal Access Token, used to search Issues/Discussions |
| `KM_STACKOVERFLOW_KEY` | No | Stack Exchange API Key, raises the rate limit |
| `KM_DISCORD_BOT_TOKEN` | No | Discord Bot Token, used to search channel messages |
| `KM_DISCORD_CHANNEL_ID` | No | Target Discord channel ID |
| `KM_CONFLUENCE_URL` | No | Confluence instance URL (e.g. https://your-domain.atlassian.net) |
| `KM_CONFLUENCE_TOKEN` | No | Confluence API Token |
| `KM_NOTION_TOKEN` | No | Notion Integration Token |
| `KM_SLACK_TOKEN` | No | Slack Bot User OAuth Token |
| `KM_SUBSCRIPTION_TIER` | No | Subscription tier; defaults to `free`, can be set to `paid` |
| `KM_BAIDU_API_KEY` | No | Baidu search API Key, enables enhanced Chinese-language search |
| `KM_OBSIDIAN_VAULT_PATH` | No | Path to the Obsidian vault directory, enables local note search |
| `KM_DATA_DIR` | No | Data storage directory; defaults to `~/.openclaw-bdi/knowledge-mesh/` |
On startup, check that credentials are configured for at least one knowledge source. If all are missing, guide the user into the "knowledge source configuration flow".
---
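## Flow 1: Cross-Platform Knowledge Search
When the user says "搜索", "查找", "搜一下", or asks a technical question, follow these steps:
### Step 1: Parse the query intent
Analyze the user's natural-language question and extract:
- Core keywords
- Preferred target platforms (if any)
- Time-range restrictions (if any)
- Expected number of results
### Step 2: Run the search
```bash
python3 scripts/source_searcher.py --action search --data '{"query":"<keywords>","max_results":20}'
```
If the user specified a platform:
```bash
python3 scripts/source_searcher.py --action search-source --data '{"query":"<keywords>","source":"github"}'
```
### Step 3: Rank and deduplicate
```bash
python3 scripts/result_ranker.py --action rank --data '{"query":"<keywords>","results":[...]}'
python3 scripts/result_ranker.py --action dedup --data '{"results":[...]}'
```
### Step 4: Present the results
Present the search results as a clear list. Each result includes:
- Source label (e.g. [GitHub], [Stack Overflow])
- Title (with matched keywords highlighted)
- Link
- Summary snippet
- Author and date
- Relevance score
Paid users additionally receive a knowledge synthesis summary.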
---
## Flow 2: Local Knowledge Index
When the user says "索引文件", "建立索引", or "搜索本地" (paid tier only):
### Step 1: Build the index
```bash
python3 scripts/index_builder.py --action index --data '{"paths":["./docs","./src"],"patterns":["*.md","*.txt","*.py"]}'
```
### Step 2: Search locally
```bash
python3 scripts/index_builder.py --action search-local --data '{"query":"<keywords>"}'
```
### Step 3: Manage the index
```bash
# List indexed documents
python3 scripts/index_builder.py --action list-indexed
# Rebuild the index
python3 scripts/index_builder.py --action rebuild
# Remove a document from the index
python3 scripts/index_builder.py --action delete --data '{"doc_id":"DOC..."}'
```
---
## Flow 3: Topic Monitoring
When the user says "监控", "订阅主题", or "关注话题" (paid tier only):
### Step 1: Create a monitor
```bash
python3 scripts/monitor_manager.py --action add --data '{"keywords":["fastapi","async"],"sources":["github","stackoverflow"]}'
```
### Step 2: Check for updates
```bash
# Check a single monitor
python3 scripts/monitor_manager.py --action check --data '{"id":"MON..."}'
# Check all monitors
python3 scripts/monitor_manager.py --action check --data '{"id":"all"}'
```
### Step 3: Generate a digest
```bash
# Daily digest
python3 scripts/monitor_manager.py --action digest --data '{"period":"daily"}'
# Weekly digest
python3 scripts/monitor_manager.py --action digest --data '{"period":"weekly"}'
```
---
## Flow 4: Report Export
When the user says "导出", "生成报告", or "保存结果":
### Markdown export
```bash
python3 scripts/report_exporter.py --action export-markdown --data '{"query":"...","results":[...],"file_path":"output/report.md"}'
```
### CSV export
```bash
python3 scripts/report_exporter.py --action export-csv --data '{"results":[...],"file_path":"output/results.csv"}'
```
### Trend analysis (paid tier only)
```bash
python3 scripts/report_exporter.py --action trending --data '{"results":[...]}'
```
### Usage statistics
```bash
python3 scripts/report_exporter.py --action stats
```
---
## Flow 5: Self-Learning Search Engine
When the user says "反馈", "评价结果", "搜索建议", or "搜索统计":
### Step 1: Record feedback
```bash
# Record a rating for a result
python3 scripts/learning_engine.py --action record-feedback --data '{"result_id":"SR...","source":"github","rating":"helpful"}'
# Record a click
python3 scripts/learning_engine.py --action record-click --data '{"result_id":"SR...","source":"stackoverflow"}'
```
### Step 2: Adjust weights
```bash
# Adjust knowledge source weights based on feedback
python3 scripts/learning_engine.py --action boost-weights
```
### Step 3: Get suggestions
```bash
# Get personalized search suggestions
python3 scripts/learning_engine.py --action suggest
```
### Step 4: View statistics
```bash
# View search analytics
python3 scripts/learning_engine.py --action stats
```
The result-ranking module automatically loads the learned weights when it orders results. Users can also calibrate the weights manually:
```bash
python3 scripts/result_ranker.py --action calibrate
```
---
## Flow 6: Obsidian Vault Integration
When the user says "连接 Obsidian", "搜索笔记", or "索引笔记":
### Step 1: Connect the vault
```bash
python3 scripts/obsidian_connector.py --action connect --data '{"vault_path":"/path/to/my/vault"}'
```
Alternatively, set the default vault path through an environment variable:
```bash
export KM_OBSIDIAN_VAULT_PATH="/path/to/my/vault"
```
### Step 2: Build the index
```bash
python3 scripts/obsidian_connector.py --action index --data '{"vault_path":"/path/to/my/vault"}'
```
### Step 3: Search notes
```bash
python3 scripts/obsidian_connector.py --action search --data '{"query":"python 异步编程"}'
```
Obsidian search supports the following Obsidian features:
- `[[wikilinks]]` bidirectional link resolution
- `#tags` tag matching
- YAML frontmatter metadata
- Callout block extraction
- Backlink graph used for authority scoring
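To make the wikilink, tag, and backlink items above concrete, here is a rough sketch of how they could be extracted with regular expressions. It is not the code in `obsidian_connector.py`: the patterns, the function names, and the idea of using the inbound-link count as the authority signal are assumptions, and the patterns cover only the common `[[Note]]`, `[[Note|alias]]`, and `[[Note#Heading]]` forms.

```python
import re
from typing import Dict, List

# [[Note]], [[Note|alias]], [[Note#Heading]]: capture only the note name.
WIKILINK_RE = re.compile(r"\[\[([^\]|#]+)(?:#[^\]|]*)?(?:\|[^\]]*)?\]\]")
# #tag or #nested/tag; the lookbehind skips headings-in-links and URL fragments.
TAG_RE = re.compile(r"(?<![\w/#])#([^\W\d][\w/-]*)")


def extract_links(text: str) -> List[str]:
    """Return the note names targeted by wikilinks, in order of appearance."""
    return [name.strip() for name in WIKILINK_RE.findall(text)]


def extract_tags(text: str) -> List[str]:
    """Return tags without the leading '#'. Markdown headings are not matched."""
    return TAG_RE.findall(text)


def backlink_counts(notes: Dict[str, str]) -> Dict[str, int]:
    """Count how many distinct notes link to each target (self-links ignored)."""
    counts: Dict[str, int] = {}
    for name, text in notes.items():
        for target in set(extract_links(text)):
            if target != name:
                counts[target] = counts.get(target, 0) + 1
    return counts


notes = {
    "A": "See [[B]] and [[C|see C]] #python",
    "B": "# Heading\nLinks to [[C#Intro]] #project/alpha",
    "C": "no links",
}
print(backlink_counts(notes))
```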
### Step 4: Manage notes
```bash
# List indexed notes
python3 scripts/obsidian_connector.py --action list-notes
# Incremental sync
python3 scripts/obsidian_connector.py --action sync
```
Obsidian notes also appear in the unified search results (through the `search` action of `source_searcher`).
---
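## Flow 7: Baidu Search
When the user searches for Chinese-language content or asks for Baidu specifically:
```bash
# Search Baidu only
python3 scripts/source_searcher.py --action search-source --data '{"query":"FastAPI 最佳实践","source":"baidu"}'
```
Baidu is also included in unified search automatically when `KM_BAIDU_API_KEY` is configured.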
---
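## Subscription Check Logic
### Reading the subscription tier
```
tier = env KM_SUBSCRIPTION_TIER, default "free"
```
### Feature permission matrix
| Feature | Free (free) | Paid (paid, ¥129/month) |
|------|---------------|----------------------|
| Number of knowledge sources | Up to 5 | Up to 10 |
| Supported knowledge sources | GitHub, Stack Overflow, Baidu, Obsidian | All 8 platforms |
| Daily searches | 10 | Unlimited |
| Max results per search | 20 | 100 |
| Local knowledge index | Not supported | Supported |
| Obsidian integration | Supported | Supported |
| Baidu search | Supported | Supported |
| Self-learning ranking (basic) | Supported | Supported |
| Self-learning ranking (advanced analytics) | Not supported | Supported |
| Topic monitoring | Not supported | Supported |
| Knowledge synthesis | Not supported | Supported |
| Mermaid trend charts | Not supported | Supported |
| Report export | Markdown/CSV | All formats + trend analysis |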
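The tier lookup and a few rows of the matrix above could be encoded roughly as follows. This is only a sketch: `TIER_LIMITS`, the feature keys, and the helper names are invented for illustration, and treating an unrecognized tier value as `free` is an assumption.

```python
import os
from typing import Mapping

# Hypothetical encoding of part of the permission matrix above.
TIER_LIMITS = {
    "free": {
        "max_sources": 5,
        "daily_searches": 10,
        "max_results": 20,
        "features": {"obsidian", "baidu", "learning_basic"},
    },
    "paid": {
        "max_sources": 10,
        "daily_searches": None,  # None means unlimited
        "max_results": 100,
        "features": {
            "obsidian", "baidu", "learning_basic", "learning_advanced",
            "local_index", "topic_monitor", "synthesis", "mermaid_charts",
        },
    },
}


def current_tier(env: Mapping[str, str] = os.environ) -> str:
    """Read KM_SUBSCRIPTION_TIER, defaulting to 'free' (also for unknown values)."""
    tier = env.get("KM_SUBSCRIPTION_TIER", "free").strip().lower()
    return tier if tier in TIER_LIMITS else "free"


def is_allowed(feature: str, tier: str) -> bool:
    """Check whether a feature key is enabled for the given tier."""
    return feature in TIER_LIMITS[tier]["features"]


def clamp_results(requested: int, tier: str) -> int:
    """Cap the requested result count at the tier's per-search maximum."""
    return min(requested, TIER_LIMITS[tier]["max_results"])


print(current_tier({"KM_SUBSCRIPTION_TIER": "paid"}), clamp_results(50, "free"))
```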
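### Behavior when the check fails
When the user requests a feature beyond the current subscription tier:
1. State clearly that the feature is available only in the paid tier.
2. Briefly explain the benefits of the paid tier.
3. Offer upgrade guidance, for example: "如需升级至付费版(¥129/月),请联系管理员或访问订阅管理页面。" (to upgrade to the paid tier at ¥129/month, contact the administrator or visit the subscription management page).
4. Do not refuse outright; offer an alternative that is available in the free tier, if one exists.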
---
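## Reference Documents
When searching and generating reports, consult the following documents:
- **API endpoint reference**: `references/api-endpoints.md` covers each platform's API addresses and authentication method.
- **Search syntax guide**: `references/search-syntax.md` covers search query syntax and examples.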
---
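## Security Rules
1. **Credential protection**: all API tokens are passed only through environment variables. Never display, log, or output a full token value in the conversation.
2. **Request security**: all HTTP requests use HTTPS with a reasonable timeout.
3. **Data locality**: search indexes and monitoring data are stored locally and are not uploaded to external servers.
4. **Input validation**: escape user input to prevent injection attacks.
5. **Error handling**: when a command fails, show the user a friendly error message and do not expose internal paths or system information.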
---
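## Behavior Guidelines
1. Always communicate with the user in Chinese.
2. Confirm the user's query intent before searching, and ask follow-up questions when needed to clarify the request.
3. Present search results as a structured list labeled with source, relevance, and date.
4. Proactively offer search suggestions and related keyword expansions.
5. When a search returns no results, explain the likely causes and suggest improvements.
6. Respect subscription tier limits. Stay friendly when suggesting an upgrade, and do not push it repeatedly.
7. For coding questions, prefer high-quality answers from Stack Overflow and GitHub.
8. For team knowledge questions, prefer internal documentation from Confluence and Notion.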
FILE:assets/README.md
# Knowledge Mesh / 知识网格
> Cross-platform knowledge search aggregator: unified search across GitHub, Stack Overflow, Discord, Confluence, Notion, Slack, Baidu, and Obsidian
>
> 跨平台知识搜索聚合器:统一搜索 GitHub、Stack Overflow、Discord、Confluence、Notion、Slack、百度、Obsidian
---
## Features / 功能亮点
- **Unified Search / 统一搜索** — Search across 8 platforms with one query. 一次查询搜索 8 个平台,结果统一排序展示。
- **TF-IDF Ranking / 智能排序** — Relevance + authority + recency scoring. TF-IDF 相关性评分 + 权威性 + 时间衰减综合排序。
- **Self-Learning / 自学习排序** — Search gets smarter with use. 搜索越用越精准,根据反馈自动优化排序权重。
- **Obsidian Integration / Obsidian 集成** — Search your local notes too. 本地笔记也能一起搜,支持 wikilinks、标签、frontmatter。
- **Baidu Search / 百度搜索** — Enhanced Chinese content search. 中文内容搜索增强,覆盖中文技术社区。
- **Deduplication / 智能去重** — Jaccard similarity dedup across sources. 基于 Jaccard 相似度跨平台去重。
- **Local Indexing / 本地索引** — Index local .md/.txt/.py files with full-text search. 索引本地文档,支持全文检索。
- **Topic Monitoring / 主题监控** — Set keyword alerts, get daily/weekly digests. 设置关键词监控,获取日报/周报摘要。
- **Knowledge Synthesis / 知识合成** — AI-powered summary of search results. 将搜索结果合成为结构化知识报告。
- **Mermaid Charts / 可视化图表** — Trend analysis with Mermaid pie/bar/line charts. 趋势分析生成 Mermaid 饼图/柱状图/折线图。
- **Export / 导出** — Markdown reports and CSV files. 支持导出 Markdown 报告和 CSV 文件。
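The ranking bullet above combines relevance, authority, and recency. The sketch below reproduces the weights used in `scripts/result_ranker.py` (a 90-day half-life, authority as 0.4 raw score + 0.3 source reliability + 0.3 recency, and a final 60/40 split between relevance and authority). The function names are simplified stand-ins for illustration, and the inputs are assumed to be normalized to the 0.0 to 1.0 range already.

```python
import math

RECENCY_HALF_LIFE_DAYS = 90  # same value as RECENCY_HALF_LIFE in result_ranker.py


def recency_factor(age_days: float, half_life: float = RECENCY_HALF_LIFE_DAYS) -> float:
    """Exponential decay: the factor halves every `half_life` days."""
    return math.pow(0.5, max(0.0, age_days) / half_life)


def authority_score(raw_score: float, reliability: float, age_days: float) -> float:
    """Blend the platform score, source reliability and recency."""
    blended = raw_score * 0.4 + reliability * 0.3 + recency_factor(age_days) * 0.3
    return min(1.0, max(0.0, blended))


def combined_score(relevance: float, authority: float) -> float:
    """Final ranking score: 60% relevance, 40% authority."""
    return relevance * 0.6 + authority * 0.4
```

With these weights a 90-day-old result keeps half of its recency credit, and a perfectly relevant result from a low-authority source can still outrank a weakly relevant one from a trusted source.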
---
## Version Comparison / 版本对比
| Feature / 功能 | Free / 免费版 | Paid / 付费版 ¥129/月 |
|----------------|:------------:|:-------------------:|
| Knowledge sources / 知识源数量 | 5 | 10 |
| Supported platforms / 支持平台 | GitHub + SO + Baidu + Obsidian | All 8 platforms / 全部 8 平台 |
| Daily searches / 每日搜索次数 | 10 | Unlimited / 无限 |
| Max results per search / 单次结果数 | 20 | 100 |
| Self-learning ranking / 自学习排序 | Basic / 基础 | Advanced / 高级分析 |
| Obsidian integration / Obsidian 集成 | Supported / 支持 | Supported / 支持 |
| Baidu search / 百度搜索 | Supported / 支持 | Supported / 支持 |
| Local indexing / 本地知识索引 | -- | Supported / 支持 |
| Topic monitoring / 主题监控 | -- | Supported / 支持 |
| Knowledge synthesis / 知识合成 | -- | Supported / 支持 |
| Mermaid trend charts / 趋势图表 | -- | Supported / 支持 |
| Export / 导出 | Markdown + CSV | Full / 全格式 + 趋势分析 |
---
## Quick Start / 快速开始
### 1. Install / 安装
Search for `knowledge-mesh` in ClawHub, or use CLI:
在 ClawHub 中搜索 `knowledge-mesh`,或使用命令行:
```bash
openclaw skill install knowledge-mesh
```
### 2. Configure Sources / 配置知识源
Set environment variables for the platforms you want to search:
设置你要搜索的平台的环境变量:
```bash
# GitHub (recommended / 推荐)
export KM_GITHUB_TOKEN="ghp_your_token_here"
# Stack Overflow (optional / 可选,提高速率限制)
export KM_STACKOVERFLOW_KEY="your_key_here"
# Discord (paid / 付费版)
export KM_DISCORD_BOT_TOKEN="your_bot_token"
export KM_DISCORD_CHANNEL_ID="channel_id"
# Confluence (paid / 付费版)
export KM_CONFLUENCE_URL="https://your-domain.atlassian.net"
export KM_CONFLUENCE_TOKEN="your_token"
# Notion (paid / 付费版)
export KM_NOTION_TOKEN="ntn_your_token"
# Slack (paid / 付费版)
export KM_SLACK_TOKEN="xoxb-your-token"
# Baidu Search / 百度搜索 (free / 免费)
export KM_BAIDU_API_KEY="your_baidu_api_key"
# Obsidian (free / 免费)
export KM_OBSIDIAN_VAULT_PATH="/path/to/your/vault"
```
### 3. Search / 搜索
```bash
# Unified search / 统一搜索
/knowledge-mesh search "python async best practices"
# Search specific source / 搜索指定平台
/knowledge-mesh search-source github "fastapi websocket"
# View configured sources / 查看已配置知识源
/knowledge-mesh list-sources
```
### 4. Advanced Features / 高级功能 (Paid / 付费版)
```bash
# Index local files / 索引本地文件
/knowledge-mesh index "./docs" "./src"
# Set up monitoring / 设置主题监控
/knowledge-mesh monitor "kubernetes" "deployment"
# Export report / 导出报告
/knowledge-mesh export markdown
```
---
## Example / 使用示例
```
用户: 搜索 Python FastAPI WebSocket 最佳实践
知识网格: 正在搜索 GitHub、Stack Overflow...
搜索结果(共 15 条):
1. [Stack Overflow] FastAPI WebSocket connection best practices
- 链接: https://stackoverflow.com/questions/...
- 相关度: 0.92
- 摘要: For **FastAPI** **WebSocket** connections, it's recommended to...
2. [GitHub] tiangolo/fastapi#1234 - WebSocket documentation improvements
- 链接: https://github.com/tiangolo/fastapi/issues/1234
- 相关度: 0.87
- 摘要: This PR improves the **WebSocket** section with **best practices**...
3. [Stack Overflow] How to handle WebSocket disconnections in FastAPI?
- 链接: https://stackoverflow.com/questions/...
- 相关度: 0.81
- 摘要: When dealing with **WebSocket** disconnections in **FastAPI**...
```
---
## FAQ / 常见问题
### Q1: Which platforms can I search for free? / 免费版可以搜索哪些平台?
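The free tier supports GitHub, Stack Overflow, Baidu, and Obsidian. GitHub and Stack Overflow cover most technical Q&A and open-source project discussions, Baidu adds Chinese-language content, and Obsidian adds your local notes. Discord, Confluence, Notion, and Slack require the paid tier.
免费版支持 GitHub、Stack Overflow、百度和 Obsidian。GitHub 和 Stack Overflow 覆盖了绝大多数技术问答和开源项目讨论,百度补充中文内容,Obsidian 补充本地笔记。Discord、Confluence、Notion 和 Slack 需要付费版。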
### Q2: Is my data uploaded to the cloud? / 数据会上传到云端吗?
No. All search requests go directly from your machine to each platform's API. Local index data is stored locally. No data passes through any third-party server.
不会。所有搜索请求从你的机器直接发送到各平台 API。本地索引数据存储在本地。没有数据经过任何第三方服务器。
### Q3: Do I need API keys for all platforms? / 需要配置所有平台的 API Key 吗?
No. You only need to configure the platforms you want to search. Stack Overflow works without an API key (with rate limits). GitHub search also works without a token but with lower rate limits.
不需要。你只需配置你想搜索的平台。Stack Overflow 无需 API Key 即可使用(有速率限制)。GitHub 搜索无 Token 也可工作但速率限制更低。
### Q4: How does deduplication work? / 去重是怎么实现的?
We compute Jaccard similarity on the tokenized title + snippet text. Results with a similarity of 0.7 or higher are treated as duplicates, and the one with the higher relevance score is kept.
使用标题+摘要的分词集合计算 Jaccard 相似度。相似度达到或超过 0.7 的结果被视为重复,保留相关度更高的结果。
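A small worked example may make the threshold concrete. The sketch below uses a simplified tokenizer (no stop-word filtering, unlike `scripts/result_ranker.py`), and the helper names `word_set`, `jaccard`, and `is_duplicate` are chosen for this example only.

```python
import re
from typing import Set

DEDUP_THRESHOLD = 0.7  # same threshold as in result_ranker.py


def word_set(text: str) -> Set[str]:
    """Lowercase English words plus single CJK characters, as a set."""
    return set(re.findall(r"[a-z0-9_]+|[\u4e00-\u9fff]", text.lower()))


def jaccard(a: Set[str], b: Set[str]) -> float:
    """J(A, B) = |A ∩ B| / |A ∪ B|, defined as 0.0 when either set is empty."""
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


def is_duplicate(text_a: str, text_b: str) -> bool:
    return jaccard(word_set(text_a), word_set(text_b)) >= DEDUP_THRESHOLD
```

"FastAPI WebSocket best practices" and "FastAPI WebSocket best practices guide" share 4 of 5 distinct words, so their similarity is 0.8 and they count as duplicates. Against "How to handle WebSocket disconnections" only 1 of 8 distinct words is shared (0.125), so both results are kept.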
### Q5: Can I search in Chinese? / 可以用中文搜索吗?
Yes. The search query is passed directly to each platform's API. Chinese queries work well on GitHub, Confluence, Notion, and Slack, and Baidu search is aimed specifically at Chinese-language content. Stack Overflow results are primarily in English.
可以。搜索查询直接传递给各平台 API。中文查询在 GitHub、Confluence、Notion 和 Slack 上效果良好,百度搜索专门用于增强中文内容覆盖。Stack Overflow 结果以英文为主。
### Q6: What file types can be locally indexed? / 本地索引支持哪些文件类型?
Supported: `.md`, `.txt`, `.py`, `.js`, `.ts`, `.java`, `.go`, `.rs`, `.rb`, `.sh`, `.yaml`, `.json`, `.html`, `.css`, `.sql` and more.
支持:`.md`、`.txt`、`.py`、`.js`、`.ts`、`.java`、`.go`、`.rs`、`.rb`、`.sh`、`.yaml`、`.json`、`.html`、`.css`、`.sql` 等。
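A file filter for local indexing could look like the sketch below. The extension set is copied from the list above, which ends in "and more", so the real indexer may accept additional types. The names `is_indexable` and `iter_indexable_files` are hypothetical.

```python
from pathlib import Path
from typing import Iterator, Union

# Extensions named in the FAQ; the actual indexer may support others as well.
INDEXABLE_EXTENSIONS = {
    ".md", ".txt", ".py", ".js", ".ts", ".java", ".go", ".rs",
    ".rb", ".sh", ".yaml", ".json", ".html", ".css", ".sql",
}


def is_indexable(filename: str) -> bool:
    """Check the file extension case-insensitively."""
    return Path(filename).suffix.lower() in INDEXABLE_EXTENSIONS


def iter_indexable_files(root: Union[str, Path]) -> Iterator[Path]:
    """Yield indexable files under `root`, recursively, in sorted order."""
    for path in sorted(Path(root).rglob("*")):
        if path.is_file() and is_indexable(path.name):
            yield path
```

Calling `list(iter_indexable_files("./docs"))` would return the candidate files for a directory such as the one passed to `/knowledge-mesh index`.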
---
## Support / 技术支持
- **Docs / 文档**: See `references/` directory for API and syntax guides
- **Issues / 问题反馈**: Submit issues on ClawHub skill page
- **Community / 社区**: Join `#knowledge-mesh` channel on ClawHub community
---
*knowledge-mesh v1.1.0 | Compatible with OpenClaw 0.5+ / 兼容 OpenClaw 0.5+*
FILE:scripts/result_ranker.py
#!/usr/bin/env python3
"""
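knowledge-mesh search result ranking and synthesis module
Scores cross-platform search results by TF-IDF relevance and authority,
then deduplicates, merges, and summarizes them.
Usage:
python3 result_ranker.py --action rank --data '{"query":"...", "results":[...]}'
python3 result_ranker.py --action dedup --data '{"results":[...]}'
python3 result_ranker.py --action synthesize --data '{"query":"...", "results":[...]}'
python3 result_ranker.py --action summarize --data '{"results":[...]}'
python3 result_ranker.py --action record-feedback --data '{"result_id":"...", "source":"github", "rating":"helpful"}'
python3 result_ranker.py --action calibrate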
"""
import math
import re
import sys
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from utils import (
check_subscription,
clean_html,
days_ago,
format_source_badge,
highlight_match,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
require_paid_feature,
truncate_text,
SOURCE_DISPLAY_NAMES,
)
# ============================================================
# 常量
# ============================================================
# 各知识源的可信度系数
SOURCE_RELIABILITY = {
"github": 0.9,
"stackoverflow": 0.95,
"discord": 0.5,
"confluence": 0.85,
"notion": 0.7,
"slack": 0.4,
"baidu": 0.6,
"obsidian": 0.8,
}
# 时间衰减半衰期(天)
RECENCY_HALF_LIFE = 90
# Jaccard 相似度去重阈值
DEDUP_THRESHOLD = 0.7
# 停用词列表(常见中英文停用词)
STOP_WORDS = {
"the", "a", "an", "is", "are", "was", "were", "be", "been",
"being", "have", "has", "had", "do", "does", "did", "will",
"would", "could", "should", "may", "might", "can", "shall",
"to", "of", "in", "for", "on", "with", "at", "by", "from",
"as", "into", "through", "during", "before", "after", "above",
"below", "between", "and", "but", "or", "nor", "not", "so",
"if", "then", "than", "too", "very", "just", "about", "up",
"out", "no", "it", "its", "this", "that", "these", "those",
"i", "me", "my", "we", "our", "you", "your", "he", "she",
"him", "her", "they", "them", "their", "what", "which", "who",
"how", "when", "where", "why", "all", "each", "every", "both",
"的", "了", "在", "是", "我", "有", "和", "就", "不", "人",
"都", "一", "一个", "上", "也", "很", "到", "说", "要", "去",
"你", "会", "着", "没有", "看", "好", "自己", "这", "他", "她",
"吗", "什么", "怎么", "如何", "为什么", "哪个", "那个",
}
# ============================================================
# 文本处理
# ============================================================
def _tokenize(text: str) -> List[str]:
"""将文本分词为词语列表。
使用简单的正则表达式分割,同时处理中英文。
Args:
text: 原始文本。
Returns:
小写词语列表。
"""
if not text:
return []
# 英文按单词分割,中文按字符分割
tokens = re.findall(r"[a-zA-Z0-9_]+|[\u4e00-\u9fff]", text.lower())
return [t for t in tokens if t and t not in STOP_WORDS]
def _word_set(text: str) -> Set[str]:
"""提取文本的词语集合(去停用词)。
Args:
text: 原始文本。
Returns:
词语集合。
"""
return set(_tokenize(text))
def _term_frequency(tokens: List[str]) -> Dict[str, float]:
"""计算词频(TF)。
使用对数归一化 TF: 1 + log(count)。
Args:
tokens: 词语列表。
Returns:
词语到 TF 值的映射。
"""
if not tokens:
return {}
counts = Counter(tokens)
tf = {}
for term, count in counts.items():
tf[term] = 1 + math.log(count)
return tf
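def _inverse_document_frequency(
    term: str,
    doc_count: int,
    doc_freq: int,
) -> float:
    """Compute the smoothed inverse document frequency (IDF).

    IDF = log((1 + N) / (1 + df)) + 1

    The unsmoothed form log(N / (1 + df)) is zero or negative for any term
    that appears in all (or all but one) of the documents. Query terms are
    usually that common in a result set, so the unsmoothed form would push
    the relevance score in _tfidf_relevance_score down to 0.

    Args:
        term: The term (kept for interface compatibility; not used).
        doc_count: Total number of documents (N).
        doc_freq: Number of documents that contain the term (df).

    Returns:
        The IDF value, always positive when doc_count > 0.
    """
    if doc_count <= 0:
        return 0.0
    return math.log((1 + doc_count) / (1 + doc_freq)) + 1.0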
def _build_corpus_idf(documents: List[List[str]]) -> Dict[str, float]:
"""构建语料库的 IDF 字典。
Args:
documents: 文档词语列表的列表。
Returns:
词语到 IDF 值的映射。
"""
doc_count = len(documents)
if doc_count == 0:
return {}
# 统计每个词出现在多少文档中
df = Counter()
for doc_tokens in documents:
unique_terms = set(doc_tokens)
for term in unique_terms:
df[term] += 1
idf = {}
for term, freq in df.items():
idf[term] = _inverse_document_frequency(term, doc_count, freq)
return idf
# ============================================================
# 评分函数
# ============================================================
def _tfidf_relevance_score(
query_tokens: List[str],
doc_tokens: List[str],
idf: Dict[str, float],
) -> float:
"""计算查询与文档的 TF-IDF 相关性分数。
Args:
query_tokens: 查询词语列表。
doc_tokens: 文档词语列表。
idf: IDF 字典。
Returns:
相关性分数(0.0 ~ 1.0)。
"""
if not query_tokens or not doc_tokens:
return 0.0
doc_tf = _term_frequency(doc_tokens)
score = 0.0
for q_term in query_tokens:
tf_val = doc_tf.get(q_term, 0.0)
idf_val = idf.get(q_term, 0.0)
score += tf_val * idf_val
# 归一化
max_possible = sum(idf.get(q, 0.0) for q in query_tokens)
if max_possible > 0:
score = score / max_possible
return min(1.0, max(0.0, score))
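def _get_learned_weight(source: str) -> float:
    """Return the self-learned weight for a knowledge source.

    The weight is computed by learning_engine from accumulated feedback and
    is used as a multiplier when adjusting the ranking.

    Args:
        source: Knowledge source name.

    Returns:
        Weight multiplier (1.0 when no learned weight is available).
    """
    try:
        from learning_engine import get_source_weights
        weights = get_source_weights()
        return weights.get(source, 1.0)
    except Exception:
        # Covers ImportError too: ranking must still work without the
        # learning module or when its data file cannot be read.
        return 1.0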
def _authority_score(result: Dict[str, Any]) -> float:
"""计算结果的权威性分数。
综合考虑:原始分数、来源可信度、时间衰减。
Args:
result: 搜索结果条目。
Returns:
权威性分数(0.0 ~ 1.0)。
"""
# 原始分数(归一化到 0-1)
raw_score = float(result.get("score", 0))
if raw_score > 1.0:
# Stack Overflow 等平台的分数可能很大
raw_score = min(1.0, math.log(1 + raw_score) / 10)
elif raw_score < 0:
raw_score = 0.0
# 来源可信度系数(结合自学习权重)
source = result.get("source", "")
reliability = SOURCE_RELIABILITY.get(source, 0.5)
# 应用自学习权重调整
learned_weight = _get_learned_weight(source)
reliability = min(1.0, reliability * learned_weight)
# 时间衰减:最近的内容权重更高
created = result.get("created_at", "")
age_days = days_ago(created) if created else 365
recency_factor = math.pow(0.5, age_days / RECENCY_HALF_LIFE)
# 综合评分
authority = (raw_score * 0.4 + reliability * 0.3 + recency_factor * 0.3)
return min(1.0, max(0.0, authority))
def _combined_score(
result: Dict[str, Any],
query_tokens: List[str],
idf: Dict[str, float],
) -> float:
"""计算综合排序分数。
综合 TF-IDF 相关性和权威性。
Args:
result: 搜索结果条目。
query_tokens: 查询词语列表。
idf: IDF 字典。
Returns:
综合分数(0.0 ~ 1.0)。
"""
# 文档内容 = 标题 + 摘要
text = f"{result.get('title', '')} {result.get('snippet', '')}"
doc_tokens = _tokenize(text)
relevance = _tfidf_relevance_score(query_tokens, doc_tokens, idf)
authority = _authority_score(result)
# 相关性权重 60%,权威性权重 40%
return relevance * 0.6 + authority * 0.4
# ============================================================
# Jaccard 去重
# ============================================================
def _jaccard_similarity(set_a: Set[str], set_b: Set[str]) -> float:
"""计算两个集合的 Jaccard 相似度。
J(A, B) = |A ∩ B| / |A ∪ B|
Args:
set_a: 集合 A。
set_b: 集合 B。
Returns:
Jaccard 相似度(0.0 ~ 1.0)。
"""
if not set_a and not set_b:
return 1.0
if not set_a or not set_b:
return 0.0
intersection = len(set_a & set_b)
union = len(set_a | set_b)
if union == 0:
return 0.0
return intersection / union
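def _deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Remove near-duplicate search results.

    Compares the title + snippet word sets with Jaccard similarity. When two
    results reach DEDUP_THRESHOLD, the one with the higher score is kept.

    Args:
        results: List of search results.

    Returns:
        Deduplicated results, in their original relative order.
    """
    if not results:
        return []

    # Word set for each result
    result_word_sets = []
    for r in results:
        text = f"{r.get('title', '')} {r.get('snippet', '')}"
        result_word_sets.append(_word_set(text))

    def _score(i: int) -> float:
        r = results[i]
        return float(r.get("_combined_score", r.get("score", 0)) or 0)

    # Visit higher-scored results first so they win over their duplicates.
    # sorted() is stable, so equal scores keep their input order.
    order = sorted(range(len(results)), key=_score, reverse=True)

    kept_indices: List[int] = []
    for i in order:
        is_duplicate = any(
            _jaccard_similarity(result_word_sets[i], result_word_sets[j]) >= DEDUP_THRESHOLD
            for j in kept_indices
        )
        if not is_duplicate:
            kept_indices.append(i)

    return [results[i] for i in sorted(kept_indices)]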
# ============================================================
# 合成摘要(付费功能)
# ============================================================
def _synthesize_results(
query: str,
results: List[Dict[str, Any]],
max_items: int = 10,
) -> str:
"""将搜索结果合成为结构化 Markdown 摘要。
Args:
query: 原始查询。
results: 排序后的搜索结果。
max_items: 纳入合成的最大结果数。
Returns:
Markdown 格式的合成摘要。
"""
top = results[:max_items]
parts = []
parts.append(f"# 知识搜索合成报告\n")
parts.append(f"**查询**: {query}")
parts.append(f"**结果数**: {len(results)} 条(展示前 {len(top)} 条)")
parts.append(f"**生成时间**: {now_iso()}\n")
# 按来源分组
by_source = {}
for r in top:
src = r.get("source", "unknown")
if src not in by_source:
by_source[src] = []
by_source[src].append(r)
parts.append("## 来源分布\n")
for src, items in by_source.items():
badge = format_source_badge(src)
parts.append(f"- {badge}: {len(items)} 条结果")
parts.append("")
# 关键发现
parts.append("## 关键发现\n")
for idx, r in enumerate(top, 1):
badge = format_source_badge(r.get("source", ""))
title = r.get("title", "无标题")
url = r.get("url", "")
snippet = truncate_text(r.get("snippet", ""), 150)
score = r.get("_combined_score", r.get("score", 0))
author = r.get("author", "")
created = r.get("created_at", "")
parts.append(f"### {idx}. {title}")
parts.append(f"- **来源**: {badge}")
if url:
parts.append(f"- **链接**: {url}")
if author:
parts.append(f"- **作者**: {author}")
if created:
parts.append(f"- **日期**: {created}")
parts.append(f"- **相关度**: {score:.2f}")
parts.append(f"- **摘要**: {snippet}")
parts.append("")
# 标签汇总
all_tags = []
for r in top:
all_tags.extend(r.get("tags", []))
if all_tags:
tag_counts = Counter(all_tags)
top_tags = tag_counts.most_common(10)
parts.append("## 热门标签\n")
for tag, count in top_tags:
parts.append(f"- `{tag}` ({count})")
parts.append("")
# Mermaid 来源分布图
parts.append("## 来源分布图\n")
parts.append("```mermaid")
parts.append("pie title 搜索结果来源分布")
for src, items in by_source.items():
display = SOURCE_DISPLAY_NAMES.get(src, src)
parts.append(f' "{display}" : {len(items)}')
parts.append("```\n")
parts.append("---")
parts.append("*由 knowledge-mesh 自动生成*")
return "\n".join(parts)
# ============================================================
# 操作实现
# ============================================================
def action_rank(data: Dict[str, Any]) -> None:
"""对搜索结果进行综合排序。
Args:
data: 包含 query 和 results 的字典。
"""
query = data.get("query", "")
results = data.get("results", [])
if not results:
output_success({"query": query, "total": 0, "results": []})
return
query_tokens = _tokenize(query)
# 构建语料库 IDF
documents = []
for r in results:
text = f"{r.get('title', '')} {r.get('snippet', '')}"
documents.append(_tokenize(text))
idf = _build_corpus_idf(documents)
# 计算综合分数
for r in results:
r["_combined_score"] = _combined_score(r, query_tokens, idf)
# 排序
results.sort(key=lambda r: r.get("_combined_score", 0), reverse=True)
output_success({
"query": query,
"total": len(results),
"results": results,
})
def action_dedup(data: Dict[str, Any]) -> None:
"""对搜索结果进行去重。
Args:
data: 包含 results 的字典。
"""
results = data.get("results", [])
original_count = len(results)
deduped = _deduplicate_results(results)
output_success({
"original_count": original_count,
"deduped_count": len(deduped),
"removed": original_count - len(deduped),
"results": deduped,
})
def action_synthesize(data: Dict[str, Any]) -> None:
"""合成搜索结果为结构化报告(付费功能)。
Args:
data: 包含 query 和 results 的字典。
"""
if not require_paid_feature("synthesis", "知识合成"):
return
query = data.get("query", "")
results = data.get("results", [])
if not results:
output_error("无搜索结果可合成", code="NO_DATA")
return
# 先排序
query_tokens = _tokenize(query)
documents = []
for r in results:
text = f"{r.get('title', '')} {r.get('snippet', '')}"
documents.append(_tokenize(text))
idf = _build_corpus_idf(documents)
for r in results:
r["_combined_score"] = _combined_score(r, query_tokens, idf)
results.sort(key=lambda r: r.get("_combined_score", 0), reverse=True)
# 去重
results = _deduplicate_results(results)
# 生成合成报告
max_items = data.get("max_items", 10)
report = _synthesize_results(query, results, max_items=max_items)
output_success({
"query": query,
"total": len(results),
"report": report,
})
def action_summarize(data: Dict[str, Any]) -> None:
"""生成搜索结果的简要统计摘要。
Args:
data: 包含 results 的字典,可选 query。
"""
results = data.get("results", [])
query = data.get("query", "")
total = len(results)
# 来源统计
source_counts = Counter()
for r in results:
source_counts[r.get("source", "unknown")] += 1
# 时间分布
recent_7d = 0
recent_30d = 0
older = 0
for r in results:
created = r.get("created_at", "")
age = days_ago(created) if created else 999
if age <= 7:
recent_7d += 1
elif age <= 30:
recent_30d += 1
else:
older += 1
# 热门标签
all_tags = []
for r in results:
all_tags.extend(r.get("tags", []))
top_tags = Counter(all_tags).most_common(10)
# 平均分数
scores = [r.get("score", 0) for r in results if r.get("score", 0) > 0]
avg_score = sum(scores) / len(scores) if scores else 0.0
summary = {
"query": query,
"total_results": total,
"source_distribution": dict(source_counts),
"time_distribution": {
"last_7_days": recent_7d,
"last_30_days": recent_30d,
"older": older,
},
"top_tags": [{"tag": t, "count": c} for t, c in top_tags],
"average_score": round(avg_score, 3),
}
output_success(summary)
def action_record_feedback(data: Dict[str, Any]) -> None:
"""记录用户对搜索结果的反馈,并更新自学习权重。
Args:
data: 包含 result_id、source、rating 的字典。
"""
result_id = data.get("result_id", "").strip()
source = data.get("source", "").strip()
rating = data.get("rating", "").strip().lower()
if not result_id:
output_error("请提供搜索结果ID(result_id)", code="VALIDATION_ERROR")
return
if not source:
output_error("请提供知识源名称(source)", code="VALIDATION_ERROR")
return
if rating not in {"relevant", "irrelevant", "helpful"}:
output_error(
f"无效的评级: {rating},支持: relevant、irrelevant、helpful",
code="VALIDATION_ERROR",
)
return
# 委托给 learning_engine 记录反馈
try:
from learning_engine import record_feedback_data
record_feedback_data(result_id, source, rating)
output_success({
"message": f"已记录反馈: {source} 结果 {result_id} 评级为 {rating}",
"result_id": result_id,
"source": source,
"rating": rating,
})
except ImportError:
output_error("自学习模块不可用", code="MODULE_ERROR")
except Exception as e:
output_error(f"记录反馈失败: {e}", code="FEEDBACK_ERROR")
def action_calibrate(data: Optional[Dict[str, Any]] = None) -> None:
"""根据反馈历史重新校准最优知识源权重。
从 learning_engine 触发权重重算并返回新权重。
"""
try:
from learning_engine import _load_learning_data, _compute_optimal_weights, _save_learning_data, DEFAULT_SOURCE_WEIGHTS
learning_data = _load_learning_data()
feedback = learning_data.get("feedback", [])
if len(feedback) < 2:
output_error(
f"反馈数据不足(当前 {len(feedback)} 条,需要至少 2 条),无法校准",
code="INSUFFICIENT_DATA",
)
return
old_weights = dict(learning_data.get("source_weights", DEFAULT_SOURCE_WEIGHTS))
new_weights = _compute_optimal_weights(feedback, old_weights)
learning_data["source_weights"] = new_weights
_save_learning_data(learning_data)
changes = []
for source in set(list(old_weights.keys()) + list(new_weights.keys())):
old_w = old_weights.get(source, 1.0)
new_w = new_weights.get(source, 1.0)
if abs(old_w - new_w) > 0.001:
changes.append({
"source": source,
"old_weight": old_w,
"new_weight": new_w,
})
output_success({
"message": f"权重校准完成,基于 {len(feedback)} 条反馈",
"weights": new_weights,
"changes": changes,
})
except ImportError:
output_error("自学习模块不可用", code="MODULE_ERROR")
except Exception as e:
output_error(f"校准失败: {e}", code="CALIBRATE_ERROR")
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 搜索结果排序与合成")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"rank": lambda: action_rank(data or {}),
"dedup": lambda: action_dedup(data or {}),
"synthesize": lambda: action_synthesize(data or {}),
"summarize": lambda: action_summarize(data or {}),
"record-feedback": lambda: action_record_feedback(data or {}),
"calibrate": lambda: action_calibrate(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/monitor_manager.py
#!/usr/bin/env python3
"""
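knowledge-mesh topic monitoring manager (paid feature)
Lets users set up keyword monitors, periodically checks each knowledge source for new content,
and generates daily/weekly digests.
Usage: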
python3 monitor_manager.py --action add --data '{"keywords":["fastapi","async"],"sources":["github","stackoverflow"]}'
python3 monitor_manager.py --action remove --data '{"id":"MON..."}'
python3 monitor_manager.py --action list
python3 monitor_manager.py --action check --data '{"id":"MON..."}'
python3 monitor_manager.py --action digest --data '{"period":"daily"}'
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
SUPPORTED_SOURCES,
SOURCE_DISPLAY_NAMES,
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
today_str,
truncate_text,
write_json_file,
)
# ============================================================
# 常量
# ============================================================
# 监控配置文件
MONITORS_FILE = "monitors.json"
# 监控结果缓存文件
MONITOR_RESULTS_FILE = "monitor_results.json"
# 最大监控数量
MAX_MONITORS = 20
# 检查结果最大保留天数
MAX_RESULT_DAYS = 30
# ============================================================
# 数据访问
# ============================================================
def _get_monitors() -> List[Dict[str, Any]]:
"""读取所有监控配置。"""
data = read_json_file(get_data_file(MONITORS_FILE))
if isinstance(data, list):
return data
return []
def _save_monitors(monitors: List[Dict[str, Any]]) -> None:
"""保存监控配置。"""
write_json_file(get_data_file(MONITORS_FILE), monitors)
def _find_monitor(monitors: List[Dict], monitor_id: str) -> Optional[Dict]:
"""根据 ID 查找监控。"""
for m in monitors:
if m.get("id") == monitor_id:
return m
return None
def _get_results() -> Dict[str, List[Dict[str, Any]]]:
"""读取监控结果缓存。
Returns:
{monitor_id: [results]} 映射。
"""
data = read_json_file(get_data_file(MONITOR_RESULTS_FILE))
if isinstance(data, dict):
return data
return {}
def _save_results(results: Dict[str, List[Dict[str, Any]]]) -> None:
"""保存监控结果缓存。"""
write_json_file(get_data_file(MONITOR_RESULTS_FILE), results)
def _cleanup_old_results(results: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
"""清理超过保留期限的结果。
Args:
results: 监控结果映射。
Returns:
清理后的结果映射。
"""
cutoff = (datetime.now() - timedelta(days=MAX_RESULT_DAYS)).strftime("%Y-%m-%dT%H:%M:%S")
cleaned = {}
for mid, items in results.items():
kept = [r for r in items if r.get("checked_at", "") >= cutoff]
if kept:
cleaned[mid] = kept
return cleaned
# ============================================================
# 搜索辅助(简化版,避免循环导入)
# ============================================================
def _simple_search(keywords: List[str], sources: List[str], since: str) -> List[Dict[str, Any]]:
"""执行简化搜索,用于监控检查。
通过 source_searcher 模块执行搜索,过滤 since 之后的结果。
Args:
keywords: 关键词列表。
sources: 知识源列表。
since: 起始时间(ISO 格式)。
Returns:
新内容列表。
"""
# 动态导入避免模块级循环依赖
try:
# 尝试导入 source_searcher 中的适配器
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
sys.path.insert(0, script_dir)
from source_searcher import _SOURCE_ADAPTERS
except ImportError:
return []
query = " ".join(keywords)
all_results = []
for source in sources:
if source not in SUPPORTED_SOURCES:
continue
adapter = _SOURCE_ADAPTERS.get(source)
if not adapter:
continue
try:
results = adapter(query, max_results=20)
# 过滤 since 之后的结果
for r in results:
created = r.get("created_at", "")
if created and created >= since:
r["monitor_source"] = source
all_results.append(r)
except Exception:
# 监控检查时静默忽略搜索错误
continue
return all_results
# ============================================================
# 摘要生成
# ============================================================
def _generate_digest_markdown(
monitors: List[Dict[str, Any]],
results: Dict[str, List[Dict[str, Any]]],
period: str,
) -> str:
"""生成监控摘要 Markdown 报告。
Args:
monitors: 监控配置列表。
results: 监控结果映射。
period: 摘要周期(daily/weekly)。
Returns:
Markdown 格式的摘要报告。
"""
period_label = "日报" if period == "daily" else "周报"
today = today_str()
# 确定时间范围
if period == "daily":
since = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%S")
else:
since = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%dT%H:%M:%S")
parts = []
parts.append(f"# 知识监控{period_label} — {today}\n")
parts.append(f"**生成时间**: {now_iso()}")
parts.append(f"**监控数量**: {len(monitors)} 个")
parts.append(f"**统计周期**: {'过去 24 小时' if period == 'daily' else '过去 7 天'}\n")
total_new = 0
for monitor in monitors:
mid = monitor.get("id", "")
keywords = monitor.get("keywords", [])
monitor_sources = monitor.get("sources", [])
keyword_str = "、".join(keywords)
# 获取该监控的结果
monitor_results = results.get(mid, [])
# 过滤时间范围内的结果
period_results = [
r for r in monitor_results
if r.get("checked_at", "") >= since
]
total_new += len(period_results)
parts.append(f"## 监控: {keyword_str}\n")
parts.append(f"- **关键词**: {keyword_str}")
source_names = [SOURCE_DISPLAY_NAMES.get(s, s) for s in monitor_sources]
parts.append(f"- **监控源**: {', '.join(source_names)}")
parts.append(f"- **新内容**: {len(period_results)} 条\n")
if period_results:
for idx, r in enumerate(period_results[:10], 1):
title = r.get("title", "无标题")
url = r.get("url", "")
source = r.get("source", "")
badge = f"[{SOURCE_DISPLAY_NAMES.get(source, source)}]"
snippet = truncate_text(r.get("snippet", ""), 100)
parts.append(f"### {idx}. {badge} {title}")
if url:
parts.append(f"- 链接: {url}")
parts.append(f"- 摘要: {snippet}")
parts.append("")
if len(period_results) > 10:
parts.append(f"*... 以及其他 {len(period_results) - 10} 条结果*\n")
else:
parts.append("*暂无新内容*\n")
# 汇总
parts.append("---\n")
parts.append(f"**总计**: {len(monitors)} 个监控主题,{total_new} 条新内容\n")
parts.append("---")
parts.append("*由 knowledge-mesh 自动生成*")
return "\n".join(parts)
# ============================================================
# 操作实现
# ============================================================
def action_add(data: Dict[str, Any]) -> None:
"""添加新的主题监控。
Args:
data: 包含 keywords(关键词列表)和可选 sources(知识源列表)的字典。
"""
if not require_paid_feature("topic_monitor", "主题监控"):
return
keywords = data.get("keywords", [])
if not keywords:
output_error("请提供监控关键词列表(keywords)", code="VALIDATION_ERROR")
return
if isinstance(keywords, str):
keywords = [keywords]
sources = data.get("sources", ["github", "stackoverflow"])
if isinstance(sources, str):
sources = [sources]
# 校验知识源
for s in sources:
if s not in SUPPORTED_SOURCES:
valid = "、".join(SUPPORTED_SOURCES)
output_error(f"不支持的知识源: {s},支持: {valid}", code="INVALID_SOURCE")
return
monitors = _get_monitors()
if len(monitors) >= MAX_MONITORS:
output_error(
f"已达监控数量上限({MAX_MONITORS} 个),请先删除不需要的监控。",
code="LIMIT_EXCEEDED",
)
return
now = now_iso()
monitor = {
"id": generate_id("MON"),
"keywords": keywords,
"sources": sources,
"last_checked": now,
"created_at": now,
"active": True,
}
monitors.append(monitor)
_save_monitors(monitors)
output_success({
"message": f"监控已创建,关键词: {', '.join(keywords)}",
"monitor": monitor,
})
def action_remove(data: Dict[str, Any]) -> None:
"""删除指定的主题监控。
Args:
data: 包含 id(监控 ID)的字典。
"""
monitor_id = data.get("id", "").strip()
if not monitor_id:
output_error("请提供监控ID(id)", code="VALIDATION_ERROR")
return
monitors = _get_monitors()
original_count = len(monitors)
monitors = [m for m in monitors if m.get("id") != monitor_id]
if len(monitors) == original_count:
output_error(f"未找到监控: {monitor_id}", code="NOT_FOUND")
return
_save_monitors(monitors)
# 同时删除相关结果
results = _get_results()
if monitor_id in results:
del results[monitor_id]
_save_results(results)
output_success({"message": f"监控 {monitor_id} 已删除"})
def action_list(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有主题监控。"""
monitors = _get_monitors()
monitor_list = []
for m in monitors:
monitor_list.append({
"id": m.get("id", ""),
"keywords": m.get("keywords", []),
"sources": m.get("sources", []),
"last_checked": m.get("last_checked", ""),
"created_at": m.get("created_at", ""),
"active": m.get("active", True),
})
output_success({
"total": len(monitor_list),
"monitors": monitor_list,
})
def action_check(data: Dict[str, Any]) -> None:
"""检查指定监控的新内容。
Args:
data: 包含 id(监控 ID)的字典。如果 id 为 "all" 则检查所有监控。
"""
if not require_paid_feature("topic_monitor", "主题监控"):
return
monitor_id = data.get("id", "").strip()
if not monitor_id:
output_error("请提供监控ID(id)或 'all'", code="VALIDATION_ERROR")
return
monitors = _get_monitors()
results = _get_results()
check_list = []
if monitor_id == "all":
check_list = monitors
else:
monitor = _find_monitor(monitors, monitor_id)
if not monitor:
output_error(f"未找到监控: {monitor_id}", code="NOT_FOUND")
return
check_list = [monitor]
check_results = {}
now = now_iso()
for monitor in check_list:
mid = monitor.get("id", "")
keywords = monitor.get("keywords", [])
sources = monitor.get("sources", [])
last_checked = monitor.get("last_checked", "")
# 搜索新内容
new_items = _simple_search(keywords, sources, last_checked)
# 为每个结果添加检查时间
for item in new_items:
item["checked_at"] = now
# 追加到结果缓存
if mid not in results:
results[mid] = []
results[mid].extend(new_items)
# 更新最后检查时间
monitor["last_checked"] = now
check_results[mid] = {
"keywords": keywords,
"new_count": len(new_items),
"items": new_items[:10],
}
# 清理旧结果
results = _cleanup_old_results(results)
# 保存
_save_monitors(monitors)
_save_results(results)
total_new = sum(cr["new_count"] for cr in check_results.values())
output_success({
"checked_monitors": len(check_list),
"total_new_items": total_new,
"details": check_results,
})
def action_digest(data: Dict[str, Any]) -> None:
"""生成监控摘要报告。
Args:
data: 包含 period(daily/weekly)的字典。
"""
if not require_paid_feature("topic_monitor", "主题监控"):
return
period = data.get("period", "daily").strip().lower()
if period not in ("daily", "weekly"):
output_error("period 参数必须为 daily 或 weekly", code="VALIDATION_ERROR")
return
monitors = _get_monitors()
results = _get_results()
if not monitors:
output_error("暂无监控主题,请先添加监控", code="NO_MONITORS")
return
report = _generate_digest_markdown(monitors, results, period)
output_success({
"period": period,
"monitor_count": len(monitors),
"report": report,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 主题监控管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"add": lambda: action_add(data or {}),
"remove": lambda: action_remove(data or {}),
"list": lambda: action_list(data),
"check": lambda: action_check(data or {}),
"digest": lambda: action_digest(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
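The `main()` entry points in these scripts share one dispatch pattern: a dict maps action names to callables, and unknown actions produce an `INVALID_ACTION` error. The sketch below restates that pattern in isolation. It is a minimal illustration, not the module's code: `dispatch`, `Handler`, and the toy handlers are hypothetical names, and the result is returned as a dict in the same `success`/`data`/`error` envelope shape instead of being printed.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical alias: a handler takes the input dict and returns a payload
Handler = Callable[[Dict[str, Any]], Any]


def dispatch(action: str, data: Optional[Dict[str, Any]], handlers: Dict[str, Handler]) -> Dict[str, Any]:
    """Hypothetical helper: route an action name to its handler and wrap the result."""
    handler = handlers.get(action.lower())
    if handler is None:
        valid = ", ".join(handlers)
        return {
            "success": False,
            "error": {
                "code": "INVALID_ACTION",
                "message": f"unknown action: {action}; supported: {valid}",
            },
        }
    # Handlers always receive a dict, mirroring the `data or {}` pattern in main()
    return {"success": True, "data": handler(data or {})}


# Toy handlers standing in for action_check / action_digest (assumptions for illustration)
handlers: Dict[str, Handler] = {
    "check": lambda d: {"checked": d.get("id", "all")},
    "digest": lambda d: {"period": d.get("period", "daily")},
}

print(dispatch("CHECK", {"id": "m1"}, handlers))
print(dispatch("purge", None, handlers))
```

Returning the envelope instead of printing it keeps the routing logic testable without capturing standard output; the real scripts print JSON through `output_success` and `output_error`.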
FILE:scripts/learning_engine.py
#!/usr/bin/env python3
"""
knowledge-mesh 自学习搜索引擎模块
基于用户反馈和搜索行为自动调整搜索排序权重,
提供主动建议和搜索分析统计。
用法:
python3 learning_engine.py --action record-feedback --data '{"result_id":"SR...","source":"github","rating":"helpful"}'
python3 learning_engine.py --action record-click --data '{"result_id":"SR...","source":"github"}'
python3 learning_engine.py --action record-query --data '{"query":"python async","sources":["github","stackoverflow"],"result_counts":{"github":5,"stackoverflow":8}}'
python3 learning_engine.py --action boost-weights
python3 learning_engine.py --action suggest
python3 learning_engine.py --action stats
"""
import json
import math
import os
import sys
from collections import Counter
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from utils import (
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
write_json_file,
)
# ============================================================
# 常量
# ============================================================
# 学习数据文件
LEARNING_DATA_FILE = "learning.json"
# 反馈评级定义
VALID_RATINGS = {"relevant", "irrelevant", "helpful"}
# 评级到分数的映射(用于权重计算)
RATING_SCORES = {
"helpful": 1.0,
"relevant": 0.5,
"irrelevant": -0.5,
}
# 默认知识源权重
DEFAULT_SOURCE_WEIGHTS = {
"github": 1.0,
"stackoverflow": 1.0,
"discord": 1.0,
"confluence": 1.0,
"notion": 1.0,
"slack": 1.0,
"baidu": 1.0,
"obsidian": 1.0,
}
# 权重调整范围
MIN_WEIGHT = 0.1
MAX_WEIGHT = 3.0
# 权重调整学习率
LEARNING_RATE = 0.1
# 建议生成所需的最小查询记录数
MIN_QUERIES_FOR_SUGGEST = 3
# 建议生成所需的最小反馈记录数
MIN_FEEDBACK_FOR_SUGGEST = 2
# 主题提取时的最小出现次数
MIN_TOPIC_COUNT = 2
# ============================================================
# 学习数据持久化
# ============================================================
def _get_learning_file() -> str:
"""获取学习数据文件路径。
Returns:
学习数据文件的绝对路径。
"""
return get_data_file(LEARNING_DATA_FILE)
def _load_learning_data() -> Dict[str, Any]:
"""加载学习数据。
Returns:
学习数据字典,包含 query_log, feedback, click_log, source_weights。
"""
data = read_json_file(_get_learning_file())
if not isinstance(data, dict):
data = {}
# 确保必要字段存在
if "query_log" not in data:
data["query_log"] = []
if "feedback" not in data:
data["feedback"] = []
if "click_log" not in data:
data["click_log"] = []
if "source_weights" not in data:
data["source_weights"] = dict(DEFAULT_SOURCE_WEIGHTS)
return data
def _save_learning_data(data: Dict[str, Any]) -> None:
"""保存学习数据。
Args:
data: 学习数据字典。
"""
write_json_file(_get_learning_file(), data)
# ============================================================
# 查询日志分析辅助函数
# ============================================================
def _extract_topics(query_log: List[Dict[str, Any]], top_n: int = 10) -> List[Tuple[str, int]]:
    """Extract the most frequent topic words from the query log.
    Tokenizes every query with a simple regex and counts token frequency.
    Args:
        query_log: List of query records.
        top_n: Number of topics to return.
    Returns:
        List of (topic word, count) tuples, sorted by count in descending order.
    """
    # Local import, hoisted out of the loop: this module does not import re at the top level
    import re
    word_counter = Counter()
    # Common stop words
    stop_words = {
        "the", "a", "an", "is", "are", "in", "on", "for", "to", "of",
        "and", "or", "not", "with", "how", "what", "why", "when", "where",
        "的", "了", "在", "是", "我", "有", "和", "不", "怎么", "如何",
    }
    for entry in query_log:
        query = entry.get("query", "")
        if not query:
            continue
        # Simple tokenization: runs of ASCII letters/digits/underscores, or runs of CJK characters
        tokens = re.findall(r"[a-zA-Z0-9_]+|[\u4e00-\u9fff]+", query.lower())
        for token in tokens:
            if token not in stop_words and len(token) >= 2:
                word_counter[token] += 1
    # Keep only words that occur at least MIN_TOPIC_COUNT times
    filtered = [(word, count) for word, count in word_counter.most_common(top_n * 2)
                if count >= MIN_TOPIC_COUNT]
    return filtered[:top_n]
def _calculate_source_adoption_rate(feedback: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""计算各知识源的结果采纳率。
采纳率 = helpful 评价数 / 总评价数。
Args:
feedback: 反馈记录列表。
Returns:
{source: {total, helpful, relevant, irrelevant, adoption_rate}} 映射。
"""
source_stats = {}
for entry in feedback:
source = entry.get("source", "unknown")
rating = entry.get("rating", "")
if source not in source_stats:
source_stats[source] = {
"total": 0,
"helpful": 0,
"relevant": 0,
"irrelevant": 0,
}
source_stats[source]["total"] += 1
if rating in source_stats[source]:
source_stats[source][rating] += 1
# 计算采纳率
for source, stats in source_stats.items():
total = stats["total"]
if total > 0:
stats["adoption_rate"] = round(stats["helpful"] / total, 4)
else:
stats["adoption_rate"] = 0.0
return source_stats
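def _get_recent_queries(query_log: List[Dict[str, Any]], days: int = 30) -> List[Dict[str, Any]]:
    """Return the query records from the last N days.
    Args:
        query_log: List of query records.
        days: Size of the time window in days.
    Returns:
        List of query records whose timestamp falls within the last N days.
    """
    # Timestamps are written by now_iso(), which uses local time, so the
    # cutoff must be computed in local time as well (not UTC).
    cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%dT%H:%M:%S")
    recent = []
    for entry in query_log:
        ts = entry.get("timestamp", "")
        # ISO timestamps in the same format compare chronologically as strings
        if ts >= cutoff:
            recent.append(entry)
    return recent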
def _compute_optimal_weights(feedback: List[Dict[str, Any]], current_weights: Dict[str, float]) -> Dict[str, float]:
"""根据反馈历史计算最优权重。
评分高的知识源权重提升,评分低的权重降低。
Args:
feedback: 反馈记录列表。
current_weights: 当前权重。
Returns:
更新后的权重字典。
"""
new_weights = dict(current_weights)
# 统计各源的反馈分数
source_scores = {}
source_counts = {}
for entry in feedback:
source = entry.get("source", "")
rating = entry.get("rating", "")
if not source or rating not in RATING_SCORES:
continue
if source not in source_scores:
source_scores[source] = 0.0
source_counts[source] = 0
source_scores[source] += RATING_SCORES[rating]
source_counts[source] += 1
# 计算平均分并调整权重
for source, total_score in source_scores.items():
count = source_counts.get(source, 1)
avg_score = total_score / max(count, 1)
# 使用学习率渐进式调整
base_weight = current_weights.get(source, 1.0)
adjustment = LEARNING_RATE * avg_score
new_weight = base_weight + adjustment
# 限制在合理范围内
new_weight = max(MIN_WEIGHT, min(MAX_WEIGHT, new_weight))
new_weights[source] = round(new_weight, 4)
return new_weights
# ============================================================
# 操作实现
# ============================================================
def action_record_feedback(data: Dict[str, Any]) -> None:
"""记录用户对搜索结果的反馈。
反馈类型包括:relevant(相关)、irrelevant(不相关)、helpful(有帮助)。
Args:
data: 包含 result_id、source、rating 的字典。
"""
result_id = data.get("result_id", "").strip()
source = data.get("source", "").strip()
rating = data.get("rating", "").strip().lower()
if not result_id:
output_error("请提供搜索结果ID(result_id)", code="VALIDATION_ERROR")
return
if not source:
output_error("请提供知识源名称(source)", code="VALIDATION_ERROR")
return
if rating not in VALID_RATINGS:
valid = "、".join(sorted(VALID_RATINGS))
output_error(f"无效的评级: {rating},支持: {valid}", code="VALIDATION_ERROR")
return
learning_data = _load_learning_data()
feedback_entry = {
"id": generate_id("FB"),
"result_id": result_id,
"source": source,
"rating": rating,
"timestamp": now_iso(),
}
learning_data["feedback"].append(feedback_entry)
_save_learning_data(learning_data)
output_success({
"message": f"已记录反馈: {source} 结果 {result_id} 评级为 {rating}",
"feedback_id": feedback_entry["id"],
"total_feedback": len(learning_data["feedback"]),
})
def action_record_click(data: Dict[str, Any]) -> None:
"""记录用户点击/使用搜索结果的行为(隐式相关性信号)。
Args:
data: 包含 result_id、source 的字典,可选 query。
"""
result_id = data.get("result_id", "").strip()
source = data.get("source", "").strip()
query = data.get("query", "").strip()
if not result_id:
output_error("请提供搜索结果ID(result_id)", code="VALIDATION_ERROR")
return
if not source:
output_error("请提供知识源名称(source)", code="VALIDATION_ERROR")
return
learning_data = _load_learning_data()
click_entry = {
"id": generate_id("CK"),
"result_id": result_id,
"source": source,
"query": query,
"timestamp": now_iso(),
}
learning_data["click_log"].append(click_entry)
_save_learning_data(learning_data)
output_success({
"message": f"已记录点击: {source} 结果 {result_id}",
"click_id": click_entry["id"],
"total_clicks": len(learning_data["click_log"]),
})
def action_record_query(data: Dict[str, Any]) -> None:
"""记录搜索查询及各知识源的结果数量。
用于分析搜索行为和知识源覆盖率。
Args:
data: 包含 query、sources、result_counts 的字典。
"""
query = data.get("query", "").strip()
sources = data.get("sources", [])
result_counts = data.get("result_counts", {})
if not query:
output_error("请提供搜索查询(query)", code="VALIDATION_ERROR")
return
learning_data = _load_learning_data()
query_entry = {
"id": generate_id("QR"),
"query": query,
"sources": sources,
"result_counts": result_counts,
"total_results": sum(result_counts.values()) if result_counts else 0,
"timestamp": now_iso(),
}
learning_data["query_log"].append(query_entry)
_save_learning_data(learning_data)
output_success({
"message": f"已记录查询: '{query}',共 {query_entry['total_results']} 条结果",
"query_id": query_entry["id"],
"total_queries": len(learning_data["query_log"]),
})
def action_boost_weights(data: Optional[Dict[str, Any]] = None) -> None:
"""根据累积反馈调整知识源可靠性权重。
helpful 评价多的知识源权重提升,irrelevant 评价多的权重降低。
"""
learning_data = _load_learning_data()
feedback = learning_data.get("feedback", [])
if len(feedback) < MIN_FEEDBACK_FOR_SUGGEST:
output_error(
f"反馈数据不足(当前 {len(feedback)} 条,需要至少 {MIN_FEEDBACK_FOR_SUGGEST} 条),无法调整权重",
code="INSUFFICIENT_DATA",
)
return
old_weights = dict(learning_data.get("source_weights", DEFAULT_SOURCE_WEIGHTS))
# 计算新权重
new_weights = _compute_optimal_weights(feedback, old_weights)
learning_data["source_weights"] = new_weights
# 记录权重变更历史
changes = []
for source in set(list(old_weights.keys()) + list(new_weights.keys())):
old_w = old_weights.get(source, 1.0)
new_w = new_weights.get(source, 1.0)
if abs(old_w - new_w) > 0.001:
direction = "提升" if new_w > old_w else "降低"
changes.append({
"source": source,
"old_weight": old_w,
"new_weight": new_w,
"direction": direction,
})
_save_learning_data(learning_data)
output_success({
"message": f"权重调整完成,基于 {len(feedback)} 条反馈",
"weights": new_weights,
"changes": changes,
"feedback_count": len(feedback),
})
def action_suggest(data: Optional[Dict[str, Any]] = None) -> None:
"""基于搜索历史和反馈生成主动建议。
建议类型包括:
- 高频主题推荐关注
- 高采纳率知识源推荐
- 搜索习惯优化建议
"""
learning_data = _load_learning_data()
query_log = learning_data.get("query_log", [])
feedback = learning_data.get("feedback", [])
click_log = learning_data.get("click_log", [])
source_weights = learning_data.get("source_weights", DEFAULT_SOURCE_WEIGHTS)
suggestions = []
# 建议1: 高频主题推荐
if len(query_log) >= MIN_QUERIES_FOR_SUGGEST:
topics = _extract_topics(query_log, top_n=5)
if topics:
top_topic = topics[0]
topic_name = top_topic[0]
topic_count = top_topic[1]
# 根据主题内容给出不同建议
suggestion_text = f"您经常搜索 {topic_name} 相关内容(共 {topic_count} 次),"
# 检查是否为技术框架/工具
tech_suggestions = {
"react": "建议关注 React GitHub Discussions 和官方博客",
"vue": "建议关注 Vue.js RFC 和 GitHub Discussions",
"python": "建议关注 Python PEP 提案和 PyPI 新包发布",
"fastapi": "建议关注 FastAPI GitHub Releases 和文档更新",
"kubernetes": "建议关注 Kubernetes Enhancement Proposals (KEPs)",
"docker": "建议关注 Docker 官方博客和 GitHub Releases",
"typescript": "建议关注 TypeScript GitHub Discussions",
"golang": "建议关注 Go 官方博客和 Release Notes",
"rust": "建议关注 Rust RFC 和 This Week in Rust",
"java": "建议关注 OpenJDK 和 Spring 官方博客",
}
topic_lower = topic_name.lower()
if topic_lower in tech_suggestions:
suggestion_text += tech_suggestions[topic_lower]
else:
suggestion_text += "建议在搜索时尝试添加更具体的子主题以获得更精确的结果"
suggestions.append({
"type": "topic_recommendation",
"text": suggestion_text,
"topic": topic_name,
"frequency": topic_count,
})
# 建议2: 高采纳率知识源推荐
if len(feedback) >= MIN_FEEDBACK_FOR_SUGGEST:
adoption_stats = _calculate_source_adoption_rate(feedback)
# 找到采纳率最高的知识源
best_source = None
best_rate = 0.0
for source, stats in adoption_stats.items():
if stats["total"] >= 2 and stats["adoption_rate"] > best_rate:
best_rate = stats["adoption_rate"]
best_source = source
if best_source and best_rate > 0.3:
rate_pct = int(best_rate * 100)
source_display = {
"github": "GitHub",
"stackoverflow": "Stack Overflow",
"discord": "Discord",
"confluence": "Confluence",
"notion": "Notion",
"slack": "Slack",
"baidu": "百度搜索",
"obsidian": "Obsidian",
}.get(best_source, best_source)
suggestions.append({
"type": "source_recommendation",
"text": f"{source_display} 的结果采纳率最高({rate_pct}%),建议优先查看",
"source": best_source,
"adoption_rate": best_rate,
})
# 建议3: 搜索习惯优化
if len(query_log) >= MIN_QUERIES_FOR_SUGGEST:
recent_queries = _get_recent_queries(query_log, days=7)
if recent_queries:
# 检查是否有重复查询
query_texts = [q.get("query", "") for q in recent_queries]
query_counts = Counter(query_texts)
repeated = [(q, c) for q, c in query_counts.items() if c >= 2]
if repeated:
most_repeated = max(repeated, key=lambda x: x[1])
suggestions.append({
"type": "search_optimization",
"text": f"您最近 7 天内重复搜索了 '{most_repeated[0]}' {most_repeated[1]} 次,建议使用主题监控功能自动跟踪更新",
"repeated_query": most_repeated[0],
"repeat_count": most_repeated[1],
})
# 检查平均结果数
result_totals = [q.get("total_results", 0) for q in recent_queries]
if result_totals:
avg_results = sum(result_totals) / len(result_totals)
if avg_results < 3:
suggestions.append({
"type": "search_optimization",
"text": "您最近的搜索平均结果较少,建议尝试使用更通用的关键词或启用更多知识源",
"avg_results": round(avg_results, 1),
})
# 建议4: 点击行为分析
if len(click_log) >= 3:
# 统计各源的点击分布
click_sources = Counter()
for click in click_log:
click_sources[click.get("source", "")] += 1
total_clicks = sum(click_sources.values())
if total_clicks > 0:
# 找到最常点击的知识源
top_click_source, top_click_count = click_sources.most_common(1)[0]
click_pct = int((top_click_count / total_clicks) * 100)
if click_pct >= 60:
source_display = {
"github": "GitHub",
"stackoverflow": "Stack Overflow",
"discord": "Discord",
"confluence": "Confluence",
"notion": "Notion",
"slack": "Slack",
"baidu": "百度搜索",
"obsidian": "Obsidian",
}.get(top_click_source, top_click_source)
suggestions.append({
"type": "usage_pattern",
"text": f"您 {click_pct}% 的点击来自 {source_display},该知识源的排序权重已自动提升",
"source": top_click_source,
"click_percentage": click_pct,
})
if not suggestions:
suggestions.append({
"type": "info",
"text": "当前搜索数据不足,继续使用后将为您生成个性化建议",
})
output_success({
"suggestions": suggestions,
"total_suggestions": len(suggestions),
"data_summary": {
"total_queries": len(query_log),
"total_feedback": len(feedback),
"total_clicks": len(click_log),
},
})
def action_stats(data: Optional[Dict[str, Any]] = None) -> None:
"""搜索分析统计:最常搜索的主题、最佳表现知识源、平均结果质量。"""
learning_data = _load_learning_data()
query_log = learning_data.get("query_log", [])
feedback = learning_data.get("feedback", [])
click_log = learning_data.get("click_log", [])
source_weights = learning_data.get("source_weights", DEFAULT_SOURCE_WEIGHTS)
# 基本统计
total_queries = len(query_log)
total_feedback = len(feedback)
total_clicks = len(click_log)
# 热门主题
topics = _extract_topics(query_log, top_n=10)
top_topics = [{"topic": t, "count": c} for t, c in topics]
# 知识源结果统计
source_result_stats = {}
for entry in query_log:
result_counts = entry.get("result_counts", {})
for source, count in result_counts.items():
if source not in source_result_stats:
source_result_stats[source] = {"total_results": 0, "query_count": 0}
source_result_stats[source]["total_results"] += count
source_result_stats[source]["query_count"] += 1
# 计算各源平均结果数
for source, stats in source_result_stats.items():
qc = stats["query_count"]
stats["avg_results"] = round(stats["total_results"] / max(qc, 1), 1)
# 知识源采纳率
adoption_stats = _calculate_source_adoption_rate(feedback)
# 最佳表现知识源
best_source = None
best_score = -1.0
for source, stats in adoption_stats.items():
score = stats["adoption_rate"] * stats["total"] # 综合采纳率和样本量
if score > best_score:
best_score = score
best_source = source
# 平均结果质量(基于反馈)
quality_scores = []
for entry in feedback:
rating = entry.get("rating", "")
if rating in RATING_SCORES:
quality_scores.append(RATING_SCORES[rating])
avg_quality = round(sum(quality_scores) / max(len(quality_scores), 1), 3)
# 时间分布统计
daily_queries = Counter()
for entry in query_log:
ts = entry.get("timestamp", "")
if ts:
day = ts[:10] # 截取日期部分 YYYY-MM-DD
daily_queries[day] += 1
# Last 7 days (local time, matching the timestamps written by now_iso())
recent_daily = []
today = datetime.now().date()
for i in range(7):
day = (today - timedelta(days=i)).strftime("%Y-%m-%d")
recent_daily.append({
"date": day,
"queries": daily_queries.get(day, 0),
})
output_success({
"overview": {
"total_queries": total_queries,
"total_feedback": total_feedback,
"total_clicks": total_clicks,
"avg_result_quality": avg_quality,
},
"top_topics": top_topics,
"source_performance": {
"result_stats": source_result_stats,
"adoption_stats": adoption_stats,
"best_source": best_source,
},
"current_weights": source_weights,
"recent_activity": recent_daily,
})
# ============================================================
# 公开 API(供其他模块调用)
# ============================================================
def get_source_weights() -> Dict[str, float]:
"""获取当前知识源权重。
供 result_ranker 等模块调用,用于排序调整。
Returns:
知识源权重字典。
"""
learning_data = _load_learning_data()
return learning_data.get("source_weights", dict(DEFAULT_SOURCE_WEIGHTS))
def record_query_data(query: str, sources: List[str], result_counts: Dict[str, int]) -> None:
"""记录查询数据(供 source_searcher 调用)。
Args:
query: 搜索查询。
sources: 搜索的知识源列表。
result_counts: 各知识源的结果数量。
"""
learning_data = _load_learning_data()
query_entry = {
"id": generate_id("QR"),
"query": query,
"sources": sources,
"result_counts": result_counts,
"total_results": sum(result_counts.values()) if result_counts else 0,
"timestamp": now_iso(),
}
learning_data["query_log"].append(query_entry)
_save_learning_data(learning_data)
def record_feedback_data(result_id: str, source: str, rating: str) -> None:
"""记录反馈数据(供 result_ranker 调用)。
Args:
result_id: 搜索结果 ID。
source: 知识源名称。
rating: 评级(relevant/irrelevant/helpful)。
"""
if rating not in VALID_RATINGS:
return
learning_data = _load_learning_data()
feedback_entry = {
"id": generate_id("FB"),
"result_id": result_id,
"source": source,
"rating": rating,
"timestamp": now_iso(),
}
learning_data["feedback"].append(feedback_entry)
_save_learning_data(learning_data)
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 自学习搜索引擎")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"record-feedback": lambda: action_record_feedback(data or {}),
"record-click": lambda: action_record_click(data or {}),
"record-query": lambda: action_record_query(data or {}),
"boost-weights": lambda: action_boost_weights(data),
"suggest": lambda: action_suggest(data),
"stats": lambda: action_stats(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
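The weight adjustment in `_compute_optimal_weights` comes down to one step per source: average the rating scores, scale by the learning rate, add to the current weight, and clamp. The sketch below restates that step on its own, assuming the same constants as the module (learning rate 0.1, bounds 0.1 to 3.0, and the three rating scores). `update_weight` is a hypothetical helper introduced for illustration; it is not part of `learning_engine.py`.

```python
from typing import List

# Constants mirrored from learning_engine.py
RATING_SCORES = {"helpful": 1.0, "relevant": 0.5, "irrelevant": -0.5}
MIN_WEIGHT = 0.1
MAX_WEIGHT = 3.0
LEARNING_RATE = 0.1


def update_weight(base_weight: float, ratings: List[str]) -> float:
    """Hypothetical helper: apply one learning-rate step from the mean rating score."""
    # Unknown ratings are ignored, as in the module
    scores = [RATING_SCORES[r] for r in ratings if r in RATING_SCORES]
    if not scores:
        return base_weight
    avg_score = sum(scores) / len(scores)
    new_weight = base_weight + LEARNING_RATE * avg_score
    # Clamp to the allowed range, then round to 4 decimals
    return round(max(MIN_WEIGHT, min(MAX_WEIGHT, new_weight)), 4)


# Two "helpful" and one "irrelevant" rating average to 0.5, so the step is +0.05
print(update_weight(1.0, ["helpful", "helpful", "irrelevant"]))
# A weight already at the upper bound stays clamped
print(update_weight(3.0, ["helpful"]))
```

As written, `action_boost_weights` passes the full feedback history and the current weights on every run, so repeated `boost-weights` calls appear to apply the same averaged step again each time. Whether that compounding is intended is not stated in the source.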
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
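knowledge-mesh shared utilities module.
Provides data directory management, subscription checks, argument parsing, data formatting, and other common helpers
for unified search across GitHub, Stack Overflow, Discord, Confluence, Notion, Slack, Baidu, and Obsidian.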
"""
import argparse
import hashlib
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "knowledge-mesh")
# 支持的知识源列表
SUPPORTED_SOURCES = [
"github",
"stackoverflow",
"discord",
"confluence",
"notion",
"slack",
"baidu",
"obsidian",
]
# 各知识源显示名称
SOURCE_DISPLAY_NAMES = {
"github": "GitHub",
"stackoverflow": "Stack Overflow",
"discord": "Discord",
"confluence": "Confluence",
"notion": "Notion",
"slack": "Slack",
"baidu": "百度搜索",
"obsidian": "Obsidian",
}
# 各知识源对应的环境变量
SOURCE_ENV_KEYS = {
"github": "KM_GITHUB_TOKEN",
"stackoverflow": "KM_STACKOVERFLOW_KEY",
"discord": "KM_DISCORD_BOT_TOKEN",
"confluence": "KM_CONFLUENCE_TOKEN",
"notion": "KM_NOTION_TOKEN",
"slack": "KM_SLACK_TOKEN",
"baidu": "KM_BAIDU_API_KEY",
"obsidian": "KM_OBSIDIAN_VAULT_PATH",
}
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径。
优先读取环境变量 KM_DATA_DIR,若未设置则使用默认路径
~/.openclaw-bdi/knowledge-mesh/。
自动创建目录(若不存在)。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("KM_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_data_file(filename: str) -> str:
"""获取数据文件的完整路径。
Args:
filename: 文件名(如 "index_data.json")。
Returns:
数据文件的绝对路径。
"""
return os.path.join(get_data_dir(), filename)
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_file(filepath: str) -> Any:
    """Read a JSON file and return the parsed data.
    Args:
        filepath: Path to the JSON file.
    Returns:
        The parsed data object. Returns an empty list if the file does not
        exist, cannot be read, or does not contain valid JSON.
    """
    if not os.path.exists(filepath):
        return []
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, IOError):
        return []
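def write_json_file(filepath: str, data: Any) -> None:
    """Write data to a JSON file, creating the parent directory if needed.
    Args:
        filepath: Target file path.
        data: Data to write.
    """
    parent = os.path.dirname(filepath)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2, default=str)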
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def parse_common_args(description: str = "knowledge-mesh 知识搜索工具") -> argparse.ArgumentParser:
"""创建通用命令行参数解析器。
Args:
description: 工具描述文本。
Returns:
配置好通用参数的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Dict[str, Any]]:
"""从命令行参数加载输入数据。
优先使用 --data 参数,其次尝试 --data-file 参数。
Args:
args: 解析后的命令行参数。
Returns:
解析后的字典数据,若无输入数据则返回 None。
Raises:
ValueError: 当 JSON 解析失败或文件读取失败时抛出。
"""
if args.data:
try:
data = json.loads(args.data)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if args.data_file:
if not os.path.exists(args.data_file):
raise ValueError(f"数据文件不存在: {args.data_file}")
try:
with open(args.data_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"数据文件 JSON 解析失败: {e}")
return None
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_sources": 3,
"daily_searches": 10,
"max_results": 20,
"features": [
"basic_search",
"github_search",
"stackoverflow_search",
],
},
"paid": {
"tier": "paid",
"max_sources": 10,
"daily_searches": -1,
"max_results": 100,
"features": [
"basic_search",
"github_search",
"stackoverflow_search",
"discord_search",
"confluence_search",
"notion_search",
"slack_search",
"local_index",
"topic_monitor",
"synthesis",
"mermaid_chart",
"export",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 KM_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("KM_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid_feature(feature_name: str, display_name: str) -> bool:
"""检查当前订阅是否支持指定功能。
若不支持,输出升级提示并返回 False。
Args:
feature_name: 功能内部名称。
display_name: 功能显示名称(用于提示信息)。
Returns:
True 表示功能可用,False 表示不可用(已输出错误信息)。
"""
sub = check_subscription()
if feature_name not in sub["features"]:
output_error(
f"「{display_name}」为付费版功能。当前为免费版,请升级至付费版(¥129/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
# ============================================================
# 搜索配额管理
# ============================================================
def _get_usage_file() -> str:
"""获取搜索使用量记录文件路径。"""
return get_data_file("usage.json")
def check_search_quota() -> bool:
"""检查今日搜索配额是否充足。
Returns:
True 表示配额充足,False 表示已达上限(已输出错误信息)。
"""
sub = check_subscription()
daily_limit = sub["daily_searches"]
# 付费版不限搜索次数
if daily_limit < 0:
return True
usage = read_json_file(_get_usage_file())
if not isinstance(usage, dict):
usage = {}
today = today_str()
today_count = usage.get(today, 0)
if today_count >= daily_limit:
output_error(
f"今日搜索次数已达上限({daily_limit} 次)。请升级至付费版(¥129/月)以获取无限搜索。",
code="QUOTA_EXCEEDED",
)
return False
return True
def increment_search_count() -> None:
"""递增今日搜索计数。"""
filepath = _get_usage_file()
usage = read_json_file(filepath)
if not isinstance(usage, dict):
usage = {}
today = today_str()
usage[today] = usage.get(today, 0) + 1
# 清理 7 天前的记录
cutoff = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
usage = {k: v for k, v in usage.items() if k >= cutoff}
write_json_file(filepath, usage)
# ============================================================
# 通用工具函数
# ============================================================
def generate_id(prefix: str = "KM") -> str:
"""生成唯一 ID。
基于时间戳生成,格式为 前缀+时间戳。
Args:
prefix: ID 前缀,默认为 "KM"(知识网格)。
Returns:
唯一 ID 字符串。
"""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
return f"{prefix}{timestamp}"
def now_iso() -> str:
"""返回当前时间的 ISO 格式字符串。
Returns:
ISO 格式时间字符串,如 "2026-03-19T10:30:00"。
"""
return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
def today_str() -> str:
"""返回今天的日期字符串。
Returns:
日期字符串,格式为 "YYYY-MM-DD"。
"""
return datetime.now().strftime("%Y-%m-%d")
def truncate_text(text: str, max_len: int = 200) -> str:
"""截断文本到指定最大长度。
若超过最大长度,截断并追加省略号。
Args:
text: 原始文本。
max_len: 最大长度,默认 200。
Returns:
截断后的文本。
"""
if not text:
return ""
text = text.strip()
if len(text) <= max_len:
return text
return text[:max_len - 3] + "..."
def highlight_match(text: str, query: str) -> str:
"""在文本中高亮匹配的查询词。
使用 **加粗** 标记匹配部分(Markdown 格式)。
Args:
text: 原始文本。
query: 查询关键词。
Returns:
高亮后的文本。
"""
if not text or not query:
return text or ""
# 将查询拆分为多个关键词
keywords = [kw.strip() for kw in query.split() if kw.strip()]
result = text
for kw in keywords:
# 不区分大小写替换
pattern = re.compile(re.escape(kw), re.IGNORECASE)
result = pattern.sub(lambda m: f"**{m.group(0)}**", result)
return result
def format_source_badge(source_name: str) -> str:
"""格式化知识源标签显示。
生成 Markdown 格式的知识源标签。
Args:
source_name: 知识源名称(如 "github")。
Returns:
格式化后的标签字符串,如 "[GitHub]"。
"""
display = SOURCE_DISPLAY_NAMES.get(source_name, source_name)
return f"[{display}]"
def hash_text(text: str) -> str:
"""计算文本的 MD5 哈希值。
Args:
text: 待哈希的文本。
Returns:
MD5 哈希字符串。
"""
return hashlib.md5(text.encode("utf-8")).hexdigest()
def clean_html(html_text: str) -> str:
    """Strip HTML tags and return plain text.
    Args:
        html_text: Text that may contain HTML tags.
    Returns:
        The cleaned plain text.
    """
    if not html_text:
        return ""
    # Remove HTML tags
    text = re.sub(r"<[^>]+>", "", html_text)
    # Decode common HTML entities; "&amp;" goes last so that input such as
    # "&amp;lt;" is not decoded twice
    text = text.replace("&lt;", "<")
    text = text.replace("&gt;", ">")
    text = text.replace("&quot;", '"')
    text = text.replace("&#39;", "'")
    text = text.replace("&nbsp;", " ")
    text = text.replace("&amp;", "&")
    # Collapse runs of whitespace
    text = re.sub(r"\s+", " ", text).strip()
    return text
def parse_iso_datetime(dt_str: str) -> Optional[datetime]:
"""解析 ISO 格式的日期时间字符串。
Args:
dt_str: ISO 日期时间字符串。
Returns:
datetime 对象,解析失败返回 None。
"""
if not dt_str:
return None
try:
# 处理多种常见格式
if "T" in dt_str:
# 移除时区后缀
clean = re.sub(r"[Zz]$", "", dt_str)
clean = re.sub(r"[+\-]\d{2}:\d{2}$", "", clean)
# 截断微秒
if "." in clean:
clean = clean.split(".")[0]
return datetime.strptime(clean, "%Y-%m-%dT%H:%M:%S")
else:
return datetime.strptime(dt_str, "%Y-%m-%d")
except (ValueError, TypeError):
return None
def days_ago(dt_str: str) -> int:
"""计算指定日期距今的天数。
Args:
dt_str: 日期时间字符串。
Returns:
距今天数(正数表示过去),解析失败返回 0。
"""
dt = parse_iso_datetime(dt_str)
if dt is None:
return 0
delta = datetime.now() - dt
return max(0, delta.days)
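The quota bookkeeping in `check_search_quota` and `increment_search_count` reduces to a date-keyed counter that is pruned after seven days. The sketch below shows that logic as pure functions over a dict, without file I/O. `bump_usage` and `quota_ok` are hypothetical names introduced for illustration, and passing `today` explicitly is an assumption made here to keep the example deterministic.

```python
from datetime import date, timedelta
from typing import Dict


def bump_usage(usage: Dict[str, int], today: date, keep_days: int = 7) -> Dict[str, int]:
    """Hypothetical pure version: increment today's count and drop old entries."""
    key = today.isoformat()
    updated = dict(usage)
    updated[key] = updated.get(key, 0) + 1
    cutoff = (today - timedelta(days=keep_days)).isoformat()
    # ISO dates (YYYY-MM-DD) sort lexicographically in chronological order
    return {day: n for day, n in updated.items() if day >= cutoff}


def quota_ok(usage: Dict[str, int], today: date, daily_limit: int) -> bool:
    """A negative limit means unlimited, matching the paid tier's daily_searches of -1."""
    if daily_limit < 0:
        return True
    return usage.get(today.isoformat(), 0) < daily_limit


today = date(2026, 3, 19)
usage = {"2026-03-01": 4, "2026-03-18": 9}
usage = bump_usage(usage, today)
print(usage)
print(quota_ok(usage, today, daily_limit=10))
```

Keeping the counter logic free of file access makes the pruning rule easy to check; the module itself reads and writes `usage.json` around the same steps.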
FILE:scripts/report_exporter.py
#!/usr/bin/env python3
"""
knowledge-mesh 搜索报告导出模块
支持将搜索结果导出为 Markdown 报告、CSV 文件,
生成趋势分析图表和使用统计。
用法:
python3 report_exporter.py --action export-markdown --data '{"query":"...","results":[...]}'
python3 report_exporter.py --action export-csv --data '{"results":[...]}'
python3 report_exporter.py --action trending --data '{"results":[...]}'
python3 report_exporter.py --action stats
"""
import csv
import io
import os
import sys
from collections import Counter
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
SOURCE_DISPLAY_NAMES,
check_subscription,
format_source_badge,
get_data_file,
highlight_match,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
today_str,
truncate_text,
write_json_file,
)
# ============================================================
# 常量
# ============================================================
# 搜索历史文件
SEARCH_HISTORY_FILE = "search_history.json"
# 导出统计文件
EXPORT_STATS_FILE = "export_stats.json"
# ============================================================
# 搜索历史管理
# ============================================================
def _get_search_history() -> List[Dict[str, Any]]:
"""读取搜索历史记录。"""
data = read_json_file(get_data_file(SEARCH_HISTORY_FILE))
if isinstance(data, list):
return data
return []
def _save_search_history(history: List[Dict[str, Any]]) -> None:
"""保存搜索历史记录。"""
# 最多保留 500 条记录
write_json_file(get_data_file(SEARCH_HISTORY_FILE), history[-500:])
def _record_export(export_type: str, count: int) -> None:
    """Record an export operation in the export statistics.
    Args:
        export_type: Export type.
        count: Number of exported records.
    """
    stats_file = get_data_file(EXPORT_STATS_FILE)
    stats = read_json_file(stats_file)
    if not isinstance(stats, dict):
        stats = {"exports": [], "total_exports": 0}
    # Tolerate a stats file that is a dict but has no usable "exports" list
    if not isinstance(stats.get("exports"), list):
        stats["exports"] = []
    stats["exports"].append({
        "type": export_type,
        "count": count,
        "timestamp": now_iso(),
    })
    stats["total_exports"] = stats.get("total_exports", 0) + 1
    # Keep at most the 200 most recent export records
    stats["exports"] = stats["exports"][-200:]
    write_json_file(stats_file, stats)
# ============================================================
# Markdown 报告生成
# ============================================================
def _generate_markdown_report(
query: str,
results: List[Dict[str, Any]],
title: Optional[str] = None,
) -> str:
"""生成 Markdown 格式的搜索报告。
Args:
query: 搜索查询。
results: 搜索结果列表。
title: 可选的报告标题。
Returns:
Markdown 报告字符串。
"""
report_title = title or f"知识搜索报告: {query}"
today = today_str()
parts = []
parts.append(f"# {report_title}\n")
parts.append(f"**查询关键词**: {query}")
parts.append(f"**搜索时间**: {now_iso()}")
parts.append(f"**结果总数**: {len(results)} 条\n")
# 来源分布统计
source_counts = Counter()
for r in results:
source_counts[r.get("source", "unknown")] += 1
if source_counts:
parts.append("## 来源分布\n")
parts.append("| 来源 | 结果数 |")
parts.append("|------|--------|")
for source, count in source_counts.most_common():
display = SOURCE_DISPLAY_NAMES.get(source, source)
parts.append(f"| {display} | {count} |")
parts.append("")
# 搜索结果详情
parts.append("## 搜索结果\n")
for idx, r in enumerate(results, 1):
source = r.get("source", "")
badge = format_source_badge(source)
result_title = r.get("title", "无标题")
url = r.get("url", "")
snippet = r.get("snippet", "")
author = r.get("author", "")
created = r.get("created_at", "")
score = r.get("_combined_score", r.get("score", 0))
tags = r.get("tags", [])
# 高亮查询词
highlighted_title = highlight_match(result_title, query)
highlighted_snippet = highlight_match(truncate_text(snippet, 250), query)
parts.append(f"### {idx}. {badge} {highlighted_title}\n")
if url:
parts.append(f"- **链接**: [{url}]({url})")
if author:
parts.append(f"- **作者**: {author}")
if created:
parts.append(f"- **发布日期**: {created}")
if isinstance(score, (int, float)) and score > 0:
parts.append(f"- **相关度**: {score:.2f}")
if tags:
tag_str = ", ".join(f"`{t}`" for t in tags[:5])
parts.append(f"- **标签**: {tag_str}")
parts.append(f"\n> {highlighted_snippet}\n")
# 页脚
parts.append("---\n")
parts.append(f"*报告由 knowledge-mesh 于 {now_iso()} 生成*")
return "\n".join(parts)
# ============================================================
# CSV 导出
# ============================================================
def _generate_csv(results: List[Dict[str, Any]]) -> str:
"""将搜索结果导出为 CSV 格式。
Args:
results: 搜索结果列表。
Returns:
CSV 格式字符串。
"""
fieldnames = [
"source", "title", "url", "snippet", "author",
"created_at", "score", "tags",
]
output_buf = io.StringIO()
writer = csv.DictWriter(output_buf, fieldnames=fieldnames)
writer.writeheader()
for r in results:
row = {
"source": SOURCE_DISPLAY_NAMES.get(r.get("source", ""), r.get("source", "")),
"title": r.get("title", ""),
"url": r.get("url", ""),
"snippet": truncate_text(r.get("snippet", ""), 200),
"author": r.get("author", ""),
"created_at": r.get("created_at", ""),
"score": r.get("_combined_score", r.get("score", 0)),
"tags": ", ".join(r.get("tags", [])),
}
writer.writerow(row)
return output_buf.getvalue()
# ============================================================
# 趋势分析(付费功能)
# ============================================================
def _generate_trending_report(results: List[Dict[str, Any]]) -> str:
"""生成趋势分析报告,包含 Mermaid 图表。
Args:
results: 搜索结果列表。
Returns:
Markdown 格式的趋势报告。
"""
parts = []
parts.append("# 知识趋势分析报告\n")
parts.append(f"**生成时间**: {now_iso()}")
parts.append(f"**分析样本**: {len(results)} 条结果\n")
# 来源分布饼图
source_counts = Counter()
for r in results:
source_counts[r.get("source", "unknown")] += 1
parts.append("## 来源分布\n")
parts.append("```mermaid")
parts.append("pie title 知识来源分布")
for source, count in source_counts.most_common():
display = SOURCE_DISPLAY_NAMES.get(source, source)
parts.append(f' "{display}" : {count}')
parts.append("```\n")
    # Time trend, grouped by the Monday that starts each week
    week_counts = Counter()
    for r in results:
        created = r.get("created_at", "")
        if not created:
            continue
        try:
            # Accept both "YYYY-MM-DD" and "YYYY-MM-DDTHH:MM:SS"
            dt = datetime.strptime(str(created).split("T")[0], "%Y-%m-%d")
        except (ValueError, TypeError):
            continue
        week_start = dt - timedelta(days=dt.weekday())
        # Key on the full ISO date so the buckets sort chronologically,
        # including across a year boundary ("%m/%d" keys would not)
        week_counts[week_start.strftime("%Y-%m-%d")] += 1
    if week_counts:
        sorted_weeks = sorted(week_counts.items())
        recent_weeks = sorted_weeks[-8:]  # most recent 8 weeks
        parts.append("## 内容发布时间趋势\n")
        # Axis labels keep the original MM/DD form
        labels = ['"{}/{}"'.format(w[0][5:7], w[0][8:10]) for w in recent_weeks]
        values = [str(w[1]) for w in recent_weeks]
        parts.append("```mermaid")
        parts.append("xychart-beta")
        parts.append(' title "近期内容发布趋势"')
        parts.append(f' x-axis [{", ".join(labels)}]')
        parts.append(' y-axis "内容数量"')
        parts.append(f' bar [{", ".join(values)}]')
        parts.append("```\n")
# 热门标签
all_tags = []
for r in results:
all_tags.extend(r.get("tags", []))
if all_tags:
tag_counts = Counter(all_tags).most_common(15)
parts.append("## 热门标签\n")
parts.append("| 标签 | 出现次数 |")
parts.append("|------|----------|")
for tag, count in tag_counts:
parts.append(f"| `{tag}` | {count} |")
parts.append("")
# Mermaid 柱状图展示 Top 10 标签
top10_tags = tag_counts[:10]
if top10_tags:
tag_labels = [f'"{t[0]}"' for t in top10_tags]
tag_values = [str(t[1]) for t in top10_tags]
parts.append("```mermaid")
parts.append("xychart-beta")
parts.append(' title "Top 10 热门标签"')
parts.append(f' x-axis [{", ".join(tag_labels)}]')
parts.append(' y-axis "出现次数"')
parts.append(f' bar [{", ".join(tag_values)}]')
parts.append("```\n")
# 高分结果
scored = [r for r in results if r.get("score", 0) > 0 or r.get("_combined_score", 0) > 0]
if scored:
scored.sort(key=lambda r: r.get("_combined_score", r.get("score", 0)), reverse=True)
top5 = scored[:5]
parts.append("## 高相关度内容 Top 5\n")
for idx, r in enumerate(top5, 1):
badge = format_source_badge(r.get("source", ""))
title = r.get("title", "无标题")
url = r.get("url", "")
score = r.get("_combined_score", r.get("score", 0))
parts.append(f"{idx}. {badge} **{title}**")
if url:
parts.append(f" - 链接: {url}")
parts.append(f" - 相关度: {score:.2f}")
parts.append("")
parts.append("---\n")
parts.append("*由 knowledge-mesh 自动生成*")
return "\n".join(parts)
# ============================================================
# 操作实现
# ============================================================
def action_export_markdown(data: Dict[str, Any]) -> None:
"""导出搜索结果为 Markdown 报告。
Args:
data: 包含 query、results 的字典,可选 title 和 file_path。
"""
query = data.get("query", "")
results = data.get("results", [])
title = data.get("title")
file_path = data.get("file_path")
if not results:
output_error("无搜索结果可导出", code="NO_DATA")
return
report = _generate_markdown_report(query, results, title)
if file_path:
try:
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
f.write(report)
_record_export("markdown", len(results))
output_success({
"message": f"报告已导出到 {file_path}",
"file_path": file_path,
"result_count": len(results),
})
except IOError as e:
output_error(f"文件写入失败: {e}", code="IO_ERROR")
else:
_record_export("markdown", len(results))
output_success({
"report": report,
"result_count": len(results),
})
def action_export_csv(data: Dict[str, Any]) -> None:
"""导出搜索结果为 CSV 格式。
Args:
data: 包含 results 的字典,可选 file_path。
"""
results = data.get("results", [])
file_path = data.get("file_path")
if not results:
output_error("无搜索结果可导出", code="NO_DATA")
return
csv_content = _generate_csv(results)
if file_path:
try:
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
with open(file_path, "w", encoding="utf-8", newline="") as f:
f.write(csv_content)
_record_export("csv", len(results))
output_success({
"message": f"CSV 已导出到 {file_path}",
"file_path": file_path,
"result_count": len(results),
})
except IOError as e:
output_error(f"文件写入失败: {e}", code="IO_ERROR")
else:
_record_export("csv", len(results))
output_success({
"csv": csv_content,
"result_count": len(results),
})
def action_trending(data: Dict[str, Any]) -> None:
"""生成趋势分析报告(付费功能)。
Args:
data: 包含 results 的字典。
"""
if not require_paid_feature("mermaid_chart", "趋势图表分析"):
return
results = data.get("results", [])
if not results:
output_error("无搜索结果可分析", code="NO_DATA")
return
report = _generate_trending_report(results)
output_success({
"report": report,
"result_count": len(results),
})
def action_stats(data: Optional[Dict[str, Any]] = None) -> None:
"""显示搜索使用统计。"""
# 读取使用量数据
usage_file = get_data_file("usage.json")
usage = read_json_file(usage_file)
if not isinstance(usage, dict):
usage = {}
# 读取导出统计
stats_file = get_data_file(EXPORT_STATS_FILE)
export_stats = read_json_file(stats_file)
if not isinstance(export_stats, dict):
export_stats = {"exports": [], "total_exports": 0}
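    # Search statistics
    today = today_str()
    today_searches = usage.get(today, 0)
    # Per-day counts for the past 7 days (today included)
    daily_stats = {}
    for i in range(7):
        day = (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
        daily_stats[day] = usage.get(day, 0)
    # Sum only the 7-day window, because this value is reported as "total_searches_7d"
    total_searches = sum(daily_stats.values())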
# 检查订阅信息
sub = check_subscription()
daily_limit = sub.get("daily_searches", 10)
remaining = max(0, daily_limit - today_searches) if daily_limit >= 0 else -1
# 导出统计
export_type_counts = Counter()
for exp in export_stats.get("exports", []):
export_type_counts[exp.get("type", "unknown")] += 1
output_success({
"subscription_tier": sub["tier"],
"today_searches": today_searches,
"daily_limit": daily_limit,
"remaining_today": remaining,
"total_searches_7d": total_searches,
"daily_breakdown": daily_stats,
"total_exports": export_stats.get("total_exports", 0),
"export_by_type": dict(export_type_counts),
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 搜索报告导出")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"export-markdown": lambda: action_export_markdown(data or {}),
"export-csv": lambda: action_export_csv(data or {}),
"trending": lambda: action_trending(data or {}),
"stats": lambda: action_stats(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
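The weekly grouping behind the trend chart can be tried in isolation. The sketch below is a minimal, standalone illustration and not part of the module: `bucket_by_week` is a hypothetical helper name, and keying each bucket by the full ISO date of its Monday is an assumption made here so that plain string sorting stays chronological across a year boundary.

```python
from collections import Counter
from datetime import datetime, timedelta


def bucket_by_week(dates):
    """Count date strings per week, keyed by the ISO date of that week's Monday."""
    counts = Counter()
    for raw in dates:
        try:
            # Accept "YYYY-MM-DD" and "YYYY-MM-DDTHH:MM:SS"
            day = datetime.strptime(raw.split("T")[0], "%Y-%m-%d")
        except (ValueError, AttributeError):
            continue  # skip malformed or non-string values
        monday = day - timedelta(days=day.weekday())
        counts[monday.strftime("%Y-%m-%d")] += 1
    # ISO dates sort chronologically as plain strings
    return sorted(counts.items())


weeks = bucket_by_week(["2024-12-30", "2025-01-02", "2025-01-06T09:00:00", "bad"])
for week_start, count in weeks:
    print(week_start, count)
```

Keys such as `12/30` and `01/06` would sort in the wrong order here, which is why the sketch carries the year in the key and leaves any shortening to the display step.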
FILE:scripts/index_builder.py
#!/usr/bin/env python3
"""
knowledge-mesh 本地知识索引构建器
对本地 .md/.txt/.py 等文件构建 TF-IDF 倒排索引,
支持全文搜索、索引管理和重建。
用法:
python3 index_builder.py --action index --data '{"paths":["./docs"],"patterns":["*.md","*.txt"]}'
python3 index_builder.py --action search-local --data '{"query":"async python"}'
python3 index_builder.py --action list-indexed
python3 index_builder.py --action rebuild
python3 index_builder.py --action delete --data '{"doc_id":"DOC20260319..."}'
"""
import fnmatch
import math
import os
import re
import sys
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
truncate_text,
write_json_file,
)
# ============================================================
# 常量
# ============================================================
# 索引数据文件
INDEX_DATA_FILE = "index_data.json"
# 文档元数据文件
DOC_METADATA_FILE = "doc_metadata.json"
# 支持索引的文件扩展名
INDEXABLE_EXTENSIONS = {".md", ".txt", ".py", ".js", ".ts", ".java", ".go", ".rs", ".rb", ".sh", ".yaml", ".yml", ".json", ".toml", ".cfg", ".ini", ".html", ".css", ".sql", ".markdown"}
# Obsidian 特性正则表达式
RE_WIKILINK = re.compile(r"\[\[([^\]|]+)(?:\|[^\]]+)?\]\]")
RE_TAG_OBSIDIAN = re.compile(r"(?:^|\s)#([a-zA-Z0-9_\u4e00-\u9fff][a-zA-Z0-9_/\u4e00-\u9fff]*)")
RE_FRONTMATTER = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
# Wikilink 和标签的额外权重
WIKILINK_WEIGHT = 2.0
TAG_WEIGHT = 3.0
# 默认 glob 模式
DEFAULT_PATTERNS = ["*.md", "*.txt", "*.py"]
# 停用词
STOP_WORDS = {
"the", "a", "an", "is", "are", "was", "were", "be", "been",
"being", "have", "has", "had", "do", "does", "did", "will",
"would", "could", "should", "may", "might", "can", "shall",
"to", "of", "in", "for", "on", "with", "at", "by", "from",
"as", "into", "through", "during", "before", "after", "and",
"but", "or", "not", "so", "if", "then", "than", "too", "very",
"just", "about", "up", "out", "no", "it", "its", "this", "that",
"i", "me", "my", "we", "our", "you", "your", "he", "she",
"they", "them", "their", "what", "which", "who", "how", "when",
"where", "why", "all", "each", "every", "both",
"的", "了", "在", "是", "我", "有", "和", "就", "不", "人",
"都", "一", "一个", "上", "也", "很", "到", "说", "要", "去",
"你", "会", "着", "没有", "看", "好", "自己", "这", "他", "她",
"import", "from", "def", "class", "return", "self", "none",
"true", "false", "pass", "elif", "else", "try", "except",
}
# 单次索引最大文件数
MAX_INDEX_FILES = 500
# 单个文件最大大小(字节)
MAX_FILE_SIZE = 1024 * 1024 # 1MB
# ============================================================
# 文本分词
# ============================================================
def _tokenize(text: str) -> List[str]:
"""将文本分词为词语列表。
Args:
text: 原始文本。
Returns:
小写词语列表(去停用词)。
"""
if not text:
return []
tokens = re.findall(r"[a-zA-Z0-9_]{2,}|[\u4e00-\u9fff]+", text.lower())
return [t for t in tokens if t not in STOP_WORDS and len(t) >= 2]
def _tokenize_with_positions(text: str) -> List[Tuple[str, int]]:
"""分词并记录每个词的起始位置。
Args:
text: 原始文本。
Returns:
(词语, 位置) 元组列表。
"""
if not text:
return []
result = []
for match in re.finditer(r"[a-zA-Z0-9_]{2,}|[\u4e00-\u9fff]+", text.lower()):
token = match.group(0)
if token not in STOP_WORDS and len(token) >= 2:
result.append((token, match.start()))
return result
# ============================================================
# 文件扫描
# ============================================================
def _match_patterns(filename: str, patterns: List[str]) -> bool:
"""检查文件名是否匹配任一 glob 模式。
Args:
filename: 文件名。
patterns: glob 模式列表。
Returns:
是否匹配。
"""
for pattern in patterns:
if fnmatch.fnmatch(filename, pattern):
return True
return False
def _scan_files(
    paths: List[str],
    patterns: Optional[List[str]] = None,
) -> List[str]:
    """Scan the given paths for files that match the glob patterns.

    Args:
        paths: Directories or files to scan.
        patterns: Glob patterns; defaults to ["*.md", "*.txt", "*.py"].

    Returns:
        Absolute paths of the matched files, capped at MAX_INDEX_FILES.
    """
    if patterns is None:
        patterns = DEFAULT_PATTERNS
    ignored_dirs = {"node_modules", "__pycache__", "venv"}
    matched_files = []
    for path in paths:
        abs_path = os.path.abspath(path)
        if os.path.isfile(abs_path):
            # An explicitly named file is accepted by extension or by pattern
            _, ext = os.path.splitext(abs_path)
            if ext.lower() in INDEXABLE_EXTENSIONS or _match_patterns(os.path.basename(abs_path), patterns):
                matched_files.append(abs_path)
        elif os.path.isdir(abs_path):
            for root, dirs, files in os.walk(abs_path):
                # Prune hidden and commonly ignored directories in place
                dirs[:] = [d for d in dirs if not d.startswith(".") and d not in ignored_dirs]
                for f in files:
                    if _match_patterns(f, patterns):
                        matched_files.append(os.path.join(root, f))
                        if len(matched_files) >= MAX_INDEX_FILES:
                            # Stop walking as soon as the cap is reached
                            return matched_files
    return matched_files[:MAX_INDEX_FILES]
def _read_file_content(filepath: str) -> Optional[str]:
"""安全读取文件内容。
Args:
filepath: 文件路径。
Returns:
文件文本内容,读取失败返回 None。
"""
try:
size = os.path.getsize(filepath)
if size > MAX_FILE_SIZE:
return None
with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
except (IOError, OSError):
return None
# ============================================================
# Obsidian 特性解析
# ============================================================
def _extract_wikilinks(content: str) -> List[str]:
"""提取 Obsidian [[wikilinks]] 作为额外关键词信号。
Args:
content: 文件内容。
Returns:
链接目标列表。
"""
return RE_WIKILINK.findall(content)
def _extract_obsidian_tags(content: str) -> List[str]:
"""提取 Obsidian #tags 作为高权重关键词。
Args:
content: 文件内容。
Returns:
标签列表(不含 # 前缀)。
"""
return RE_TAG_OBSIDIAN.findall(content)
def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
"""解析 YAML frontmatter 元数据。
解析日期、分类、作者等字段用于过滤。
Args:
content: 文件内容。
Returns:
(frontmatter 字典, 去除 frontmatter 后的内容) 元组。
"""
match = RE_FRONTMATTER.match(content)
if not match:
return {}, content
fm_text = match.group(1)
body = content[match.end():]
frontmatter = {}
current_key = None
current_list = None
for line in fm_text.split("\n"):
line = line.strip()
if not line or line.startswith("#"):
continue
if line.startswith("- ") and current_key and current_list is not None:
value = line[2:].strip().strip('"').strip("'")
current_list.append(value)
continue
if ":" in line:
idx = line.index(":")
key = line[:idx].strip()
value = line[idx + 1:].strip()
current_key = key
if not value:
current_list = []
frontmatter[key] = current_list
else:
value = value.strip('"').strip("'")
if value.startswith("[") and value.endswith("]"):
items = [v.strip().strip('"').strip("'")
for v in value[1:-1].split(",")]
frontmatter[key] = [i for i in items if i]
else:
frontmatter[key] = value
current_list = None
return frontmatter, body
def _get_obsidian_extra_tokens(content: str) -> Tuple[List[Tuple[str, float]], Dict[str, Any]]:
"""从文件内容中提取 Obsidian 特性的额外 token 和元数据。
Wikilinks 作为加权 token(权重 2.0),
Tags 作为高权重 token(权重 3.0),
Frontmatter 作为过滤元数据。
Args:
content: 文件内容。
Returns:
(加权token列表, frontmatter字典) 元组。
"""
extra_tokens = []
# 提取 wikilinks 并分词
wikilinks = _extract_wikilinks(content)
for link in wikilinks:
link_tokens = _tokenize(link)
for token in link_tokens:
extra_tokens.append((token, WIKILINK_WEIGHT))
# 提取标签
tags = _extract_obsidian_tags(content)
for tag in tags:
tag_tokens = _tokenize(tag)
for token in tag_tokens:
extra_tokens.append((token, TAG_WEIGHT))
# 完整标签也作为 token
if len(tag) >= 2:
extra_tokens.append((tag.lower(), TAG_WEIGHT))
# 解析 frontmatter
frontmatter, _ = _parse_frontmatter(content)
return extra_tokens, frontmatter
# ============================================================
# 倒排索引构建
# ============================================================
def _build_inverted_index(
    documents: Dict[str, str],
) -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, Dict[str, Any]]]:
    """Build the inverted index and per-document statistics.

    Args:
        documents: Mapping of {doc_id: content}.

    Returns:
        An (inverted index, document metadata) tuple.
        Inverted index format: {term: [{doc_id, tf, positions}]}
        Metadata format: {doc_id: {token_count, unique_terms, frontmatter}}
        File-level fields (path, size, indexed_at) are added by the caller.
    """
    inverted_index = {}
    doc_metadata = {}
for doc_id, content in documents.items():
tokens_with_pos = _tokenize_with_positions(content)
tokens = [t for t, _ in tokens_with_pos]
token_count = len(tokens)
if token_count == 0:
continue
# 统计词频和位置
term_info = {}
for token, pos in tokens_with_pos:
if token not in term_info:
term_info[token] = {"count": 0, "positions": []}
term_info[token]["count"] += 1
term_info[token]["positions"].append(pos)
# 提取 Obsidian 特性的额外加权 token
extra_tokens, frontmatter = _get_obsidian_extra_tokens(content)
for token, weight in extra_tokens:
if token not in term_info:
term_info[token] = {"count": 0, "positions": []}
# 加权计数:wikilinks 和标签获得额外权重
term_info[token]["count"] += int(weight)
# 构建倒排索引条目
for term, info in term_info.items():
tf = 1 + math.log(info["count"]) if info["count"] > 0 else 0
entry = {
"doc_id": doc_id,
"tf": round(tf, 4),
"positions": info["positions"][:20], # 最多保留 20 个位置
}
if term not in inverted_index:
inverted_index[term] = []
inverted_index[term].append(entry)
doc_metadata[doc_id] = {
"token_count": token_count,
"unique_terms": len(term_info),
"frontmatter": frontmatter if frontmatter else None,
}
return inverted_index, doc_metadata
def _search_index(
query: str,
inverted_index: Dict[str, List[Dict[str, Any]]],
doc_metadata_map: Dict[str, Dict[str, Any]],
max_results: int = 20,
) -> List[Dict[str, Any]]:
"""在倒排索引中搜索。
Args:
query: 查询字符串。
inverted_index: 倒排索引。
doc_metadata_map: 文档元数据。
max_results: 最大返回结果数。
Returns:
排序后的搜索结果列表。
"""
query_tokens = _tokenize(query)
if not query_tokens:
return []
doc_count = len(doc_metadata_map)
if doc_count == 0:
return []
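    # IDF for each query term. The smoothed form log(1 + N / (1 + df)) stays
    # positive, so a term that appears in every document (or in a one-document
    # index) still adds to the score instead of subtracting from it.
    term_idf = {}
    for qt in query_tokens:
        df = len(inverted_index.get(qt, []))
        term_idf[qt] = math.log(1 + doc_count / (1 + df))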
# 累计文档分数
doc_scores = {}
doc_matches = {}
for qt in query_tokens:
postings = inverted_index.get(qt, [])
idf_val = term_idf.get(qt, 0)
for posting in postings:
did = posting["doc_id"]
tf_val = posting.get("tf", 0)
score = tf_val * idf_val
if did not in doc_scores:
doc_scores[did] = 0.0
doc_matches[did] = []
doc_scores[did] += score
doc_matches[did].append(qt)
# 排序
ranked = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
results = []
for doc_id, score in ranked[:max_results]:
meta = doc_metadata_map.get(doc_id, {})
results.append({
"doc_id": doc_id,
"path": meta.get("path", ""),
"filename": os.path.basename(meta.get("path", "")),
"score": round(score, 4),
"matched_terms": doc_matches.get(doc_id, []),
"token_count": meta.get("token_count", 0),
"indexed_at": meta.get("indexed_at", ""),
})
return results
# ============================================================
# 索引持久化
# ============================================================
def _load_index() -> Dict[str, List[Dict[str, Any]]]:
"""加载已有的倒排索引。"""
data = read_json_file(get_data_file(INDEX_DATA_FILE))
if isinstance(data, dict):
return data
return {}
def _save_index(index: Dict[str, List[Dict[str, Any]]]) -> None:
"""保存倒排索引。"""
write_json_file(get_data_file(INDEX_DATA_FILE), index)
def _load_doc_metadata() -> Dict[str, Dict[str, Any]]:
"""加载文档元数据。"""
data = read_json_file(get_data_file(DOC_METADATA_FILE))
if isinstance(data, dict):
return data
return {}
def _save_doc_metadata(metadata: Dict[str, Dict[str, Any]]) -> None:
"""保存文档元数据。"""
write_json_file(get_data_file(DOC_METADATA_FILE), metadata)
# ============================================================
# 操作实现
# ============================================================
def action_index(data: Dict[str, Any]) -> None:
"""索引本地文件。
Args:
data: 包含 paths(目录/文件列表)和可选 patterns(glob 模式)的字典。
"""
if not require_paid_feature("local_index", "本地知识索引"):
return
paths = data.get("paths", [])
if not paths:
output_error("请提供待索引的路径列表(paths)", code="VALIDATION_ERROR")
return
if isinstance(paths, str):
paths = [paths]
patterns = data.get("patterns", DEFAULT_PATTERNS)
# 扫描文件
files = _scan_files(paths, patterns)
if not files:
output_error("未找到匹配的文件", code="NO_FILES")
return
# 读取文件内容
documents = {}
file_meta = {}
skipped = 0
for filepath in files:
content = _read_file_content(filepath)
if content is None:
skipped += 1
continue
doc_id = generate_id("DOC")
documents[doc_id] = content
file_meta[doc_id] = {
"path": filepath,
"size": os.path.getsize(filepath),
"indexed_at": now_iso(),
}
if not documents:
output_error("所有文件均无法读取", code="READ_ERROR")
return
# 构建索引
inverted_index, build_meta = _build_inverted_index(documents)
# 合并到现有索引
existing_index = _load_index()
existing_meta = _load_doc_metadata()
# 合并倒排索引
for term, postings in inverted_index.items():
if term not in existing_index:
existing_index[term] = []
existing_index[term].extend(postings)
# 合并元数据
for doc_id, meta in file_meta.items():
bm = build_meta.get(doc_id, {})
existing_meta[doc_id] = {
**meta,
"token_count": bm.get("token_count", 0),
"unique_terms": bm.get("unique_terms", 0),
}
# 保存
_save_index(existing_index)
_save_doc_metadata(existing_meta)
output_success({
"message": f"索引完成:成功 {len(documents)} 个文件,跳过 {skipped} 个",
"indexed_count": len(documents),
"skipped_count": skipped,
"total_terms": len(existing_index),
"total_documents": len(existing_meta),
})
def action_search_local(data: Dict[str, Any]) -> None:
"""在本地索引中搜索。
Args:
data: 包含 query 的字典,可选 max_results。
"""
if not require_paid_feature("local_index", "本地知识索引"):
return
query = data.get("query", "").strip()
if not query:
output_error("请提供搜索关键词(query)", code="VALIDATION_ERROR")
return
max_results = data.get("max_results", 20)
inverted_index = _load_index()
doc_metadata = _load_doc_metadata()
if not inverted_index or not doc_metadata:
output_error("本地索引为空,请先执行 index 操作构建索引", code="NO_INDEX")
return
results = _search_index(query, inverted_index, doc_metadata, max_results)
output_success({
"query": query,
"total": len(results),
"results": results,
})
def action_list_indexed(data: Optional[Dict[str, Any]] = None) -> None:
"""列出已索引的文档。"""
doc_metadata = _load_doc_metadata()
docs = []
for doc_id, meta in doc_metadata.items():
docs.append({
"doc_id": doc_id,
"path": meta.get("path", ""),
"filename": os.path.basename(meta.get("path", "")),
"size": meta.get("size", 0),
"token_count": meta.get("token_count", 0),
"indexed_at": meta.get("indexed_at", ""),
})
# 按索引时间倒序
docs.sort(key=lambda d: d.get("indexed_at", ""), reverse=True)
inverted_index = _load_index()
output_success({
"total_documents": len(docs),
"total_terms": len(inverted_index),
"documents": docs,
})
def action_rebuild(data: Optional[Dict[str, Any]] = None) -> None:
"""重建索引:根据已记录的文件路径重新索引。"""
if not require_paid_feature("local_index", "本地知识索引"):
return
doc_metadata = _load_doc_metadata()
if not doc_metadata:
output_error("无已索引文档,无需重建", code="NO_INDEX")
return
# 收集已有路径
paths = []
for meta in doc_metadata.values():
p = meta.get("path", "")
if p and os.path.exists(p):
paths.append(p)
if not paths:
output_error("所有已索引文件均不存在,无法重建", code="FILES_MISSING")
return
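    # Do not clear the stored index yet: if every file turns out to be unreadable,
    # the early return below must leave the existing index intact. The saves at
    # the end of this function overwrite both data files.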
# 重新读取和索引
documents = {}
file_meta = {}
skipped = 0
for filepath in paths:
content = _read_file_content(filepath)
if content is None:
skipped += 1
continue
doc_id = generate_id("DOC")
documents[doc_id] = content
file_meta[doc_id] = {
"path": filepath,
"size": os.path.getsize(filepath),
"indexed_at": now_iso(),
}
if not documents:
output_error("所有文件均无法读取", code="READ_ERROR")
return
inverted_index, build_meta = _build_inverted_index(documents)
# 保存
final_meta = {}
for doc_id, meta in file_meta.items():
bm = build_meta.get(doc_id, {})
final_meta[doc_id] = {
**meta,
"token_count": bm.get("token_count", 0),
"unique_terms": bm.get("unique_terms", 0),
}
_save_index(inverted_index)
_save_doc_metadata(final_meta)
output_success({
"message": f"索引重建完成:成功 {len(documents)} 个文件,跳过 {skipped} 个",
"indexed_count": len(documents),
"skipped_count": skipped,
"total_terms": len(inverted_index),
"total_documents": len(final_meta),
})
def action_delete(data: Dict[str, Any]) -> None:
"""删除指定文档的索引。
Args:
data: 包含 doc_id 的字典。
"""
doc_id = data.get("doc_id", "").strip()
if not doc_id:
output_error("请提供文档ID(doc_id)", code="VALIDATION_ERROR")
return
doc_metadata = _load_doc_metadata()
if doc_id not in doc_metadata:
output_error(f"未找到文档: {doc_id}", code="NOT_FOUND")
return
# 从元数据中删除
removed_meta = doc_metadata.pop(doc_id)
_save_doc_metadata(doc_metadata)
# 从倒排索引中清除该文档的条目
inverted_index = _load_index()
terms_to_remove = []
for term, postings in inverted_index.items():
inverted_index[term] = [p for p in postings if p.get("doc_id") != doc_id]
if not inverted_index[term]:
terms_to_remove.append(term)
for term in terms_to_remove:
del inverted_index[term]
_save_index(inverted_index)
output_success({
"message": f"文档 {doc_id} 的索引已删除",
"removed_path": removed_meta.get("path", ""),
"remaining_documents": len(doc_metadata),
"remaining_terms": len(inverted_index),
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 本地知识索引")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"index": lambda: action_index(data or {}),
"search-local": lambda: action_search_local(data or {}),
"list-indexed": lambda: action_list_indexed(data),
"rebuild": lambda: action_rebuild(data),
"delete": lambda: action_delete(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
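The TF-IDF ranking used by the local index can be illustrated on a toy corpus. This is a minimal sketch under stated assumptions, not the module's implementation: `tokenize`, `build_index`, and `search` are hypothetical names, only ASCII tokens are handled, and the smoothed IDF `log(1 + N / (1 + df))` is chosen here so that scores stay positive even when the corpus is tiny.

```python
import math
import re
from collections import Counter


def tokenize(text):
    """Lowercase the text and split it into ASCII word tokens of length >= 2."""
    return re.findall(r"[a-z0-9_]{2,}", text.lower())


def build_index(docs):
    """Return {term: {doc_id: tf}} using log-scaled term frequency."""
    index = {}
    for doc_id, text in docs.items():
        for term, count in Counter(tokenize(text)).items():
            index.setdefault(term, {})[doc_id] = 1 + math.log(count)
    return index


def search(query, index, doc_count):
    """Rank documents by the sum of tf * idf over the distinct query terms."""
    scores = Counter()
    for term in set(tokenize(query)):
        postings = index.get(term, {})
        # Smoothed IDF (an assumption of this sketch): always positive
        idf = math.log(1 + doc_count / (1 + len(postings)))
        for doc_id, tf in postings.items():
            scores[doc_id] += tf * idf
    return scores.most_common()


docs = {
    "a": "async python tutorial python",
    "b": "python packaging guide",
    "c": "rust ownership notes",
}
index = build_index(docs)
ranked = search("async python", index, len(docs))
for doc_id, score in ranked:
    print(doc_id, round(score, 4))
```

Document `a` ranks first because it matches both query terms and repeats `python`; document `c` matches nothing and does not appear in the results.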
FILE:scripts/obsidian_connector.py
#!/usr/bin/env python3
"""
knowledge-mesh Obsidian 知识库集成模块
将 Obsidian 笔记库作为知识源接入统一搜索系统,
支持笔记搜索、索引构建、Wikilink/标签解析和增量同步。
用法:
python3 obsidian_connector.py --action connect --data '{"vault_path":"/path/to/vault"}'
python3 obsidian_connector.py --action search --data '{"query":"python async"}'
python3 obsidian_connector.py --action index --data '{"vault_path":"/path/to/vault"}'
python3 obsidian_connector.py --action list-notes
python3 obsidian_connector.py --action sync
"""
import json
import math
import os
import re
import sys
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from utils import (
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
truncate_text,
write_json_file,
)
# ============================================================
# 常量
# ============================================================
# Obsidian 索引数据文件
OBSIDIAN_INDEX_FILE = "obsidian_index.json"
# 支持索引的文件扩展名
OBSIDIAN_EXTENSIONS = {".md", ".markdown"}
# 最大索引笔记数
MAX_NOTES = 5000
# 单个文件最大大小(字节)
MAX_FILE_SIZE = 2 * 1024 * 1024 # 2MB
# 正则表达式:Obsidian 特性解析
RE_WIKILINK = re.compile(r"\[\[([^\]|]+)(?:\|[^\]]+)?\]\]")
RE_TAG = re.compile(r"(?:^|\s)#([a-zA-Z0-9_\u4e00-\u9fff][a-zA-Z0-9_/\u4e00-\u9fff]*)")
RE_FRONTMATTER = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
RE_CALLOUT = re.compile(r"^>\s*\[!(\w+)\]\s*(.*?)$", re.MULTILINE)
# 搜索相关性权重
TITLE_MATCH_WEIGHT = 3.0
TAG_MATCH_WEIGHT = 2.0
CONTENT_MATCH_WEIGHT = 1.0
BACKLINK_BONUS = 0.5
# 停用词
STOP_WORDS = {
"the", "a", "an", "is", "are", "in", "on", "for", "to", "of",
"and", "or", "not", "with", "it", "this", "that", "be", "was",
"的", "了", "在", "是", "我", "有", "和", "就", "不",
}
# ============================================================
# Obsidian 特性解析
# ============================================================
def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
"""解析 YAML frontmatter。
Args:
content: 笔记原始内容。
Returns:
(frontmatter 字典, 去除 frontmatter 后的内容) 元组。
"""
match = RE_FRONTMATTER.match(content)
if not match:
return {}, content
fm_text = match.group(1)
body = content[match.end():]
# 简单 YAML 解析(不依赖 PyYAML)
frontmatter = {}
current_key = None
current_list = None
for line in fm_text.split("\n"):
line = line.strip()
if not line or line.startswith("#"):
continue
# 检查是否为列表项
if line.startswith("- ") and current_key:
value = line[2:].strip().strip('"').strip("'")
if current_list is not None:
current_list.append(value)
continue
# 检查键值对
if ":" in line:
idx = line.index(":")
key = line[:idx].strip()
value = line[idx + 1:].strip()
current_key = key
if not value:
# 可能是列表的开始
current_list = []
frontmatter[key] = current_list
            else:
                # A quoted scalar is a string in YAML, so note the quotes
                # before stripping them and skip type coercion for it
                quoted = len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'")
                value = value.strip('"').strip("'")
                if quoted:
                    frontmatter[key] = value
                elif value.lower() in ("true", "yes"):
                    frontmatter[key] = True
                elif value.lower() in ("false", "no"):
                    frontmatter[key] = False
                elif value.startswith("[") and value.endswith("]"):
                    # Inline list: [item1, item2]
                    items = [v.strip().strip('"').strip("'")
                             for v in value[1:-1].split(",")]
                    frontmatter[key] = [i for i in items if i]
                else:
                    # Try int, then float, then fall back to the raw string
                    try:
                        frontmatter[key] = int(value)
                    except ValueError:
                        try:
                            frontmatter[key] = float(value)
                        except ValueError:
                            frontmatter[key] = value
                current_list = None
return frontmatter, body
def _extract_wikilinks(content: str) -> List[str]:
"""提取所有 [[wikilinks]]。
Args:
content: 笔记内容。
Returns:
链接目标列表(不含别名部分)。
"""
return RE_WIKILINK.findall(content)
def _extract_tags(content: str) -> List[str]:
"""提取所有 #tags。
Args:
content: 笔记内容。
Returns:
标签列表(不含 # 前缀)。
"""
return RE_TAG.findall(content)
def _extract_callouts(content: str) -> List[Dict[str, str]]:
"""提取 Obsidian 样式的 callout 块。
Args:
content: 笔记内容。
Returns:
callout 信息列表。
"""
callouts = []
for match in RE_CALLOUT.finditer(content):
callouts.append({
"type": match.group(1),
"title": match.group(2).strip(),
})
return callouts
def _tokenize(text: str) -> List[str]:
"""将文本分词为词语列表。
Args:
text: 原始文本。
Returns:
小写词语列表。
"""
if not text:
return []
tokens = re.findall(r"[a-zA-Z0-9_]{2,}|[\u4e00-\u9fff]+", text.lower())
return [t for t in tokens if t not in STOP_WORDS and len(t) >= 2]
# ============================================================
# Note scanning and parsing
# ============================================================
def _scan_vault(vault_path: str) -> List[str]:
    """Scan an Obsidian vault for note files.
    Args:
        vault_path: Path to the vault root directory.
    Returns:
        Absolute paths of the note files found, capped at MAX_NOTES.
    """
    notes = []
    abs_vault = os.path.abspath(vault_path)
    for root, dirs, files in os.walk(abs_vault):
        # Prune hidden directories in place (this also skips .obsidian)
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for f in files:
            _, ext = os.path.splitext(f)
            if ext.lower() in OBSIDIAN_EXTENSIONS:
                filepath = os.path.join(root, f)
                notes.append(filepath)
                if len(notes) >= MAX_NOTES:
                    return notes
    return notes
def _read_note(filepath: str) -> Optional[str]:
"""安全读取笔记文件内容。
Args:
filepath: 文件路径。
Returns:
文件文本内容,读取失败返回 None。
"""
try:
size = os.path.getsize(filepath)
if size > MAX_FILE_SIZE:
return None
with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
except (IOError, OSError):
return None
def _parse_note(filepath: str, vault_path: str) -> Optional[Dict[str, Any]]:
"""解析单个笔记文件,提取元数据和内容。
Args:
filepath: 笔记文件路径。
vault_path: vault 根目录路径。
Returns:
笔记信息字典,失败返回 None。
"""
content = _read_note(filepath)
if content is None:
return None
# 解析 frontmatter
frontmatter, body = _parse_frontmatter(content)
# 提取 Obsidian 特性
wikilinks = _extract_wikilinks(body)
tags = _extract_tags(body)
callouts = _extract_callouts(body)
# 从 frontmatter 中补充标签
fm_tags = frontmatter.get("tags", [])
if isinstance(fm_tags, str):
fm_tags = [fm_tags]
elif not isinstance(fm_tags, list):
fm_tags = []
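# Coerce to str and sort so non-string frontmatter values cannot break later
# .lower() calls and the stored tag order is deterministic.
all_tags = sorted({str(t) for t in tags + fm_tags})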
# 文件基本信息
filename = os.path.basename(filepath)
name_without_ext = os.path.splitext(filename)[0]
rel_path = os.path.relpath(filepath, vault_path)
# 获取文件修改时间
try:
mtime = os.path.getmtime(filepath)
modified_at = datetime.fromtimestamp(mtime).strftime("%Y-%m-%dT%H:%M:%S")
except (OSError, ValueError):
modified_at = ""
# 分词
tokens = _tokenize(f"{name_without_ext} {body}")
return {
"filepath": filepath,
"rel_path": rel_path,
"title": name_without_ext,
"content": body,
"frontmatter": frontmatter,
"tags": all_tags,
"wikilinks": wikilinks,
"callouts": callouts,
"tokens": tokens,
"token_count": len(tokens),
"modified_at": modified_at,
"size": os.path.getsize(filepath),
}
# ============================================================
# 索引持久化
# ============================================================
def _get_index_file() -> str:
"""获取 Obsidian 索引文件路径。"""
return get_data_file(OBSIDIAN_INDEX_FILE)
def _load_index() -> Dict[str, Any]:
"""加载 Obsidian 索引数据。
Returns:
索引数据字典。
"""
data = read_json_file(_get_index_file())
if not isinstance(data, dict):
data = {}
if "vault_path" not in data:
data["vault_path"] = ""
if "notes" not in data:
data["notes"] = {}
if "backlink_graph" not in data:
data["backlink_graph"] = {}
if "tag_index" not in data:
data["tag_index"] = {}
if "last_sync" not in data:
data["last_sync"] = ""
if "total_notes" not in data:
data["total_notes"] = 0
return data
def _save_index(data: Dict[str, Any]) -> None:
"""保存 Obsidian 索引数据。"""
write_json_file(_get_index_file(), data)
# ============================================================
# 搜索引擎
# ============================================================
def _search_notes(
query: str,
index_data: Dict[str, Any],
max_results: int = 20,
) -> List[Dict[str, Any]]:
"""在索引中搜索笔记。
使用多级相关性评分:标题匹配 > 标签匹配 > 内容匹配。
Args:
query: 搜索查询。
index_data: 索引数据。
max_results: 最大返回结果数。
Returns:
排序后的搜索结果列表。
"""
notes = index_data.get("notes", {})
backlink_graph = index_data.get("backlink_graph", {})
if not notes:
return []
query_tokens = _tokenize(query)
query_lower = query.lower()
if not query_tokens:
return []
results = []
for note_id, note in notes.items():
score = 0.0
matched_in = []
title = note.get("title", "")
tags = note.get("tags", [])
content_preview = note.get("content_preview", "")
tokens = note.get("tokens_sample", [])
# 标题匹配(权重最高)
title_lower = title.lower()
title_match_count = 0
for qt in query_tokens:
if qt in title_lower:
title_match_count += 1
if title_match_count > 0:
score += TITLE_MATCH_WEIGHT * (title_match_count / len(query_tokens))
matched_in.append("title")
# 标签匹配
tags_lower = [t.lower() for t in tags]
tag_match_count = 0
for qt in query_tokens:
for tag in tags_lower:
if qt in tag:
tag_match_count += 1
break
if tag_match_count > 0:
score += TAG_MATCH_WEIGHT * (tag_match_count / len(query_tokens))
matched_in.append("tags")
# 内容匹配
content_lower = content_preview.lower()
content_match_count = 0
for qt in query_tokens:
if qt in content_lower:
content_match_count += 1
if content_match_count > 0:
score += CONTENT_MATCH_WEIGHT * (content_match_count / len(query_tokens))
matched_in.append("content")
# token 精确匹配加分
if tokens:
token_set = set(tokens)
token_matches = sum(1 for qt in query_tokens if qt in token_set)
if token_matches > 0:
score += 0.5 * (token_matches / len(query_tokens))
# 反向链接加分(被更多笔记引用的笔记更权威)
backlink_count = len(backlink_graph.get(title, []))
if backlink_count > 0:
score += BACKLINK_BONUS * min(backlink_count / 5.0, 1.0)
if score > 0:
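# Percent-encode the vault name and file path so the obsidian:// URI stays valid
# for names containing spaces, "&", or non-ASCII characters.
from urllib.parse import quote
vault_name = quote(os.path.basename(index_data.get("vault_path", "")), safe="")
file_param = quote(note.get("rel_path", ""), safe="")
results.append({
"id": note_id,
"source": "obsidian",
"title": title,
"url": f"obsidian://open?vault={vault_name}&file={file_param}",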
"snippet": truncate_text(content_preview, 300),
"author": note.get("frontmatter", {}).get("author", ""),
"created_at": note.get("modified_at", ""),
"score": round(score, 4),
"tags": tags,
"matched_in": matched_in,
"backlink_count": backlink_count,
"filepath": note.get("filepath", ""),
})
# 按分数降序排序
results.sort(key=lambda r: r.get("score", 0), reverse=True)
return results[:max_results]
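# --- Illustrative sketch (not part of the original module) ---
# The tiered scoring in _search_notes can be tried in isolation with the helper
# below. It is a simplified approximation under stated assumptions: the default
# weights are hypothetical placeholders (the real TITLE_MATCH_WEIGHT,
# TAG_MATCH_WEIGHT, CONTENT_MATCH_WEIGHT and BACKLINK_BONUS are defined
# elsewhere and are not shown here), and the token-sample bonus is left out.
def _sketch_relevance_score(query_tokens, title, tags, preview, backlink_count,
                            w_title=3.0, w_tag=2.0, w_content=1.0, w_backlink=0.5):
    """Score a note so title hits outweigh tag hits, which outweigh content hits."""
    if not query_tokens:
        return 0.0
    n = len(query_tokens)
    title_l = title.lower()
    tags_l = [t.lower() for t in tags]
    preview_l = preview.lower()
    # Each tier counts how many query tokens it matched at least once.
    title_hits = sum(1 for qt in query_tokens if qt in title_l)
    tag_hits = sum(1 for qt in query_tokens if any(qt in t for t in tags_l))
    content_hits = sum(1 for qt in query_tokens if qt in preview_l)
    score = (w_title * title_hits + w_tag * tag_hits + w_content * content_hits) / n
    # The backlink bonus saturates at five inbound links, as in _search_notes.
    if backlink_count > 0:
        score += w_backlink * min(backlink_count / 5.0, 1.0)
    return round(score, 4)
# With the placeholder weights, a single-token query that matches only the
# title scores 3.0, while a tag-only match scores 2.0.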
# ============================================================
# 索引构建
# ============================================================
def _build_index(vault_path: str) -> Dict[str, Any]:
"""构建 Obsidian vault 的完整索引。
Args:
vault_path: vault 根目录路径。
Returns:
索引数据字典。
"""
note_files = _scan_vault(vault_path)
notes = {}
backlink_graph = {} # {被链接的笔记: [链接来源笔记]}
tag_index = {} # {标签: [笔记ID]}
skipped = 0
for filepath in note_files:
parsed = _parse_note(filepath, vault_path)
if parsed is None:
skipped += 1
continue
note_id = generate_id("ON")
title = parsed["title"]
# 存储笔记信息(不保存完整内容,只保存预览和采样 token)
notes[note_id] = {
"filepath": parsed["filepath"],
"rel_path": parsed["rel_path"],
"title": title,
"content_preview": truncate_text(parsed["content"], 500),
"frontmatter": parsed["frontmatter"],
"tags": parsed["tags"],
"wikilinks": parsed["wikilinks"],
"tokens_sample": parsed["tokens"][:100], # 保存前100个token用于搜索
"token_count": parsed["token_count"],
"modified_at": parsed["modified_at"],
"size": parsed["size"],
"mtime": os.path.getmtime(filepath) if os.path.exists(filepath) else 0,
}
# 构建反向链接图
for link_target in parsed["wikilinks"]:
if link_target not in backlink_graph:
backlink_graph[link_target] = []
backlink_graph[link_target].append(title)
# 构建标签索引
for tag in parsed["tags"]:
if tag not in tag_index:
tag_index[tag] = []
tag_index[tag].append(note_id)
index_data = {
"vault_path": os.path.abspath(vault_path),
"notes": notes,
"backlink_graph": backlink_graph,
"tag_index": tag_index,
"last_sync": now_iso(),
"total_notes": len(notes),
"skipped": skipped,
}
return index_data
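# --- Illustrative sketch (not part of the original module) ---
# The backlink graph built in _build_index is an inversion of each note's
# outgoing wikilinks. The hypothetical helper below shows that inversion on
# plain dictionaries. Unlike _build_index, it de-duplicates repeated links from
# the same source note; that is an assumption about the desired behavior, not
# something the original code does.
def _sketch_invert_links(outgoing):
    """Map each link target to the sorted titles of the notes that link to it."""
    backlinks = {}
    for source_title, targets in outgoing.items():
        for target in targets:
            backlinks.setdefault(target, set()).add(source_title)
    return {target: sorted(sources) for target, sources in backlinks.items()}
# Example input: {"A": ["B", "C", "B"], "D": ["B"]}. Note "B" ends up with the
# inbound sources ["A", "D"] and note "C" with ["A"].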
# ============================================================
# 操作实现
# ============================================================
def action_connect(data: Dict[str, Any]) -> None:
"""连接到 Obsidian vault。
验证 vault 路径存在,扫描结构并返回摘要信息。
Args:
data: 包含 vault_path 的字典。
"""
vault_path = data.get("vault_path", "").strip()
if not vault_path:
# 尝试从环境变量读取
vault_path = os.environ.get("KM_OBSIDIAN_VAULT_PATH", "").strip()
if not vault_path:
output_error(
"请提供 Obsidian vault 路径(vault_path)或设置 KM_OBSIDIAN_VAULT_PATH 环境变量",
code="VALIDATION_ERROR",
)
return
abs_path = os.path.abspath(vault_path)
if not os.path.isdir(abs_path):
output_error(f"Vault 路径不存在或不是目录: {abs_path}", code="PATH_NOT_FOUND")
return
# 扫描 vault 结构
note_files = _scan_vault(abs_path)
# 检查是否有 .obsidian 配置目录
obsidian_config = os.path.join(abs_path, ".obsidian")
has_obsidian_config = os.path.isdir(obsidian_config)
# 统计目录结构
dir_counter = Counter()
tag_counter = Counter()
total_size = 0
for filepath in note_files[:200]: # 快速扫描前200个
rel_dir = os.path.dirname(os.path.relpath(filepath, abs_path))
if rel_dir:
dir_counter[rel_dir] += 1
else:
dir_counter["(root)"] += 1
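# Only the sampled files contribute to the size total; skip files that vanish mid-scan
try:
total_size += os.path.getsize(filepath)
except OSError:
pass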
# 快速提取标签
content = _read_note(filepath)
if content:
tags = _extract_tags(content)
for tag in tags:
tag_counter[tag] += 1
top_dirs = [{"dir": d, "count": c} for d, c in dir_counter.most_common(10)]
top_tags = [{"tag": t, "count": c} for t, c in tag_counter.most_common(10)]
output_success({
"message": f"成功连接到 Obsidian vault: {abs_path}",
"vault_path": abs_path,
"has_obsidian_config": has_obsidian_config,
"total_notes": len(note_files),
"total_size_kb": round(total_size / 1024, 1),
"top_directories": top_dirs,
"top_tags": top_tags,
"status": "connected",
})
def action_search(data: Dict[str, Any]) -> None:
"""在 Obsidian vault 中搜索笔记。
Args:
data: 包含 query 的字典,可选 max_results。
"""
query = data.get("query", "").strip()
if not query:
output_error("请提供搜索关键词(query)", code="VALIDATION_ERROR")
return
max_results = data.get("max_results", 20)
index_data = _load_index()
if not index_data.get("notes"):
output_error(
"Obsidian 索引为空,请先执行 index 操作构建索引",
code="NO_INDEX",
)
return
results = _search_notes(query, index_data, max_results)
output_success({
"query": query,
"source": "obsidian",
"total": len(results),
"vault_path": index_data.get("vault_path", ""),
"results": results,
})
def action_index(data: Dict[str, Any]) -> None:
"""构建 Obsidian vault 的完整索引。
Args:
data: 包含 vault_path 的字典。
"""
vault_path = data.get("vault_path", "").strip()
if not vault_path:
vault_path = os.environ.get("KM_OBSIDIAN_VAULT_PATH", "").strip()
if not vault_path:
output_error(
"请提供 Obsidian vault 路径(vault_path)或设置 KM_OBSIDIAN_VAULT_PATH 环境变量",
code="VALIDATION_ERROR",
)
return
abs_path = os.path.abspath(vault_path)
if not os.path.isdir(abs_path):
output_error(f"Vault 路径不存在: {abs_path}", code="PATH_NOT_FOUND")
return
# 构建索引
index_data = _build_index(abs_path)
_save_index(index_data)
# 统计信息
total_tags = len(index_data.get("tag_index", {}))
total_backlinks = sum(
len(sources) for sources in index_data.get("backlink_graph", {}).values()
)
output_success({
"message": f"索引构建完成: {index_data['total_notes']} 篇笔记",
"vault_path": abs_path,
"total_notes": index_data["total_notes"],
"skipped": index_data.get("skipped", 0),
"total_tags": total_tags,
"total_backlinks": total_backlinks,
"last_sync": index_data["last_sync"],
})
def action_list_notes(data: Optional[Dict[str, Any]] = None) -> None:
"""列出已索引的 Obsidian 笔记。"""
index_data = _load_index()
notes = index_data.get("notes", {})
if not notes:
output_error("Obsidian 索引为空,请先执行 index 操作", code="NO_INDEX")
return
note_list = []
for note_id, note in notes.items():
note_list.append({
"id": note_id,
"title": note.get("title", ""),
"rel_path": note.get("rel_path", ""),
"tags": note.get("tags", []),
"token_count": note.get("token_count", 0),
"modified_at": note.get("modified_at", ""),
"size": note.get("size", 0),
"wikilinks_count": len(note.get("wikilinks", [])),
})
# 按修改时间倒序
note_list.sort(key=lambda n: n.get("modified_at", ""), reverse=True)
# 标签统计
tag_index = index_data.get("tag_index", {})
top_tags = sorted(tag_index.items(), key=lambda x: len(x[1]), reverse=True)[:10]
output_success({
"vault_path": index_data.get("vault_path", ""),
"total_notes": len(note_list),
"last_sync": index_data.get("last_sync", ""),
"top_tags": [{"tag": t, "note_count": len(ids)} for t, ids in top_tags],
"notes": note_list,
})
def action_sync(data: Optional[Dict[str, Any]] = None) -> None:
"""增量同步:重新索引自上次同步以来变化的文件。
通过比较文件的 mtime 来检测变化。
"""
index_data = _load_index()
vault_path = index_data.get("vault_path", "")
if not vault_path:
# 尝试从环境变量读取
vault_path = os.environ.get("KM_OBSIDIAN_VAULT_PATH", "").strip()
if not vault_path or not os.path.isdir(vault_path):
output_error(
"未找到已连接的 vault,请先执行 connect 或 index 操作",
code="NO_VAULT",
)
return
notes = index_data.get("notes", {})
# 扫描当前 vault 文件
current_files = _scan_vault(vault_path)
current_paths = {os.path.abspath(f) for f in current_files}
# 已索引的文件路径映射
indexed_paths = {}
for note_id, note in notes.items():
fp = note.get("filepath", "")
if fp:
indexed_paths[fp] = note_id
added = 0
updated = 0
removed = 0
# 检查新增和更新的文件
for filepath in current_files:
abs_fp = os.path.abspath(filepath)
if abs_fp in indexed_paths:
# 已索引,检查是否更新
note_id = indexed_paths[abs_fp]
old_mtime = notes[note_id].get("mtime", 0)
try:
current_mtime = os.path.getmtime(abs_fp)
except OSError:
continue
if current_mtime > old_mtime:
# 文件已更新,重新解析
parsed = _parse_note(abs_fp, vault_path)
if parsed:
notes[note_id].update({
"title": parsed["title"],
"content_preview": truncate_text(parsed["content"], 500),
"frontmatter": parsed["frontmatter"],
"tags": parsed["tags"],
"wikilinks": parsed["wikilinks"],
"tokens_sample": parsed["tokens"][:100],
"token_count": parsed["token_count"],
"modified_at": parsed["modified_at"],
"size": parsed["size"],
"mtime": current_mtime,
})
updated += 1
else:
# 新文件,添加索引
parsed = _parse_note(abs_fp, vault_path)
if parsed:
note_id = generate_id("ON")
notes[note_id] = {
"filepath": abs_fp,
"rel_path": parsed["rel_path"],
"title": parsed["title"],
"content_preview": truncate_text(parsed["content"], 500),
"frontmatter": parsed["frontmatter"],
"tags": parsed["tags"],
"wikilinks": parsed["wikilinks"],
"tokens_sample": parsed["tokens"][:100],
"token_count": parsed["token_count"],
"modified_at": parsed["modified_at"],
"size": parsed["size"],
"mtime": os.path.getmtime(abs_fp) if os.path.exists(abs_fp) else 0,
}
added += 1
# 检查已删除的文件
ids_to_remove = []
for note_id, note in notes.items():
fp = note.get("filepath", "")
if fp and fp not in current_paths:
ids_to_remove.append(note_id)
for note_id in ids_to_remove:
del notes[note_id]
removed += 1
# 重建反向链接图和标签索引
backlink_graph = {}
tag_index = {}
for note_id, note in notes.items():
title = note.get("title", "")
for link_target in note.get("wikilinks", []):
if link_target not in backlink_graph:
backlink_graph[link_target] = []
backlink_graph[link_target].append(title)
for tag in note.get("tags", []):
if tag not in tag_index:
tag_index[tag] = []
tag_index[tag].append(note_id)
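# Persist the resolved vault path so later syncs still work when it came from KM_OBSIDIAN_VAULT_PATH
index_data["vault_path"] = os.path.abspath(vault_path)
index_data["notes"] = notes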
index_data["backlink_graph"] = backlink_graph
index_data["tag_index"] = tag_index
index_data["last_sync"] = now_iso()
index_data["total_notes"] = len(notes)
_save_index(index_data)
output_success({
"message": f"同步完成: 新增 {added}, 更新 {updated}, 删除 {removed}",
"added": added,
"updated": updated,
"removed": removed,
"total_notes": len(notes),
"last_sync": index_data["last_sync"],
})
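# --- Illustrative sketch (not part of the original module) ---
# The change detection in action_sync reduces to comparing two path -> mtime
# mappings. The hypothetical helper below isolates that comparison. It keeps
# the strict "newer than" rule used above, so a file restored with an older
# timestamp would not be reported as updated.
def _sketch_diff_by_mtime(indexed, current):
    """Classify paths as added, updated, or removed from two {path: mtime} dicts."""
    added = sorted(p for p in current if p not in indexed)
    removed = sorted(p for p in indexed if p not in current)
    updated = sorted(p for p in current if p in indexed and current[p] > indexed[p])
    return {"added": added, "updated": updated, "removed": removed}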
# ============================================================
# 公开 API(供其他模块调用)
# ============================================================
def search_obsidian(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Obsidian vault(供 source_searcher 调用)。
Args:
query: 搜索查询。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
index_data = _load_index()
if not index_data.get("notes"):
return []
return _search_notes(query, index_data, max_results)
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh Obsidian 知识库集成")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"connect": lambda: action_connect(data or {}),
"search": lambda: action_search(data or {}),
"index": lambda: action_index(data or {}),
"list-notes": lambda: action_list_notes(data),
"sync": lambda: action_sync(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/source_searcher.py
#!/usr/bin/env python3
"""
knowledge-mesh knowledge source search module
Unified search across GitHub (issues and pull requests), Stack Overflow, Discord,
Confluence, Notion, Slack, Baidu, and Obsidian, returning normalized results.
Usage:
python3 source_searcher.py --action search --data '{"query":"python async"}'
python3 source_searcher.py --action search-source --data '{"query":"fastapi","source":"github"}'
python3 source_searcher.py --action list-sources
python3 source_searcher.py --action test-source --data '{"source":"github"}'
"""
import json
import os
import re
import sys
import urllib.request
import urllib.error
import urllib.parse
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
SUPPORTED_SOURCES,
SOURCE_DISPLAY_NAMES,
SOURCE_ENV_KEYS,
check_subscription,
check_search_quota,
increment_search_count,
clean_html,
days_ago,
generate_id,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
parse_iso_datetime,
read_json_file,
truncate_text,
write_json_file,
get_data_file,
)
# ============================================================
# Constants
# ============================================================
# Search configuration file
SOURCES_CONFIG_FILE = "sources_config.json"
# Default request timeout in seconds
REQUEST_TIMEOUT = 15
# API base URLs for each platform
GITHUB_API_BASE = "https://api.github.com"
STACKOVERFLOW_API_BASE = "https://api.stackexchange.com/2.3"
DISCORD_API_BASE = "https://discord.com/api/v10"
NOTION_API_BASE = "https://api.notion.com/v1"
BAIDU_SEARCH_API = "https://www.baidu.com/s"
# ============================================================
# HTTP 请求辅助
# ============================================================
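def _http_get(url: str, headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    """Send an HTTP GET request and return the parsed JSON response.
    Args:
        url: Request URL.
        headers: Optional extra request headers.
    Returns:
        The parsed JSON payload.
    Raises:
        RuntimeError: If the request fails or the body is not valid JSON.
    """
    import gzip  # local import keeps this fix self-contained
    req = urllib.request.Request(url, method="GET")
    req.add_header("User-Agent", "knowledge-mesh/1.0")
    req.add_header("Accept", "application/json")
    if headers:
        for k, v in headers.items():
            req.add_header(k, v)
    try:
        with urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT) as resp:
            raw = resp.read()
            # The Stack Exchange API compresses every response, even when no
            # Accept-Encoding header was sent, so decompress before decoding.
            if resp.headers.get("Content-Encoding", "").lower() == "gzip":
                raw = gzip.decompress(raw)
            return json.loads(raw.decode("utf-8"))
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"HTTP 请求失败: {e.code} {e.reason}")
    except urllib.error.URLError as e:
        raise RuntimeError(f"网络请求失败: {e.reason}")
    except (json.JSONDecodeError, UnicodeDecodeError, gzip.BadGzipFile, EOFError):
        raise RuntimeError("响应解析失败:非有效 JSON")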
def _http_post(url: str, payload: Dict[str, Any],
headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
"""发送 HTTP POST 请求并返回 JSON 响应。
Args:
url: 请求地址。
payload: 请求体数据。
headers: 可选的请求头。
Returns:
解析后的 JSON 字典。
Raises:
RuntimeError: 请求失败时抛出。
"""
body_bytes = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(url, data=body_bytes, method="POST")
req.add_header("User-Agent", "knowledge-mesh/1.0")
req.add_header("Content-Type", "application/json")
req.add_header("Accept", "application/json")
if headers:
for k, v in headers.items():
req.add_header(k, v)
try:
with urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT) as resp:
resp_body = resp.read().decode("utf-8")
return json.loads(resp_body)
except urllib.error.HTTPError as e:
raise RuntimeError(f"HTTP 请求失败: {e.code} {e.reason}")
except urllib.error.URLError as e:
raise RuntimeError(f"网络请求失败: {e.reason}")
except json.JSONDecodeError:
raise RuntimeError("响应解析失败:非有效 JSON")
# ============================================================
# 统一搜索结果格式
# ============================================================
def _make_result(
source: str,
title: str,
url: str,
snippet: str,
author: str = "",
created_at: str = "",
score: float = 0.0,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""构造标准化搜索结果条目。
Args:
source: 来源平台名称。
title: 结果标题。
url: 结果链接。
snippet: 内容摘要。
author: 作者名。
created_at: 创建时间。
score: 相关性分数。
tags: 标签列表。
Returns:
标准化的搜索结果字典。
"""
return {
"id": generate_id("SR"),
"source": source,
"title": title,
"url": url,
"snippet": truncate_text(snippet, 300),
"author": author,
"created_at": created_at,
"score": score,
"tags": tags or [],
}
# ============================================================
# GitHub 搜索适配器
# ============================================================
def _search_github(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""Search GitHub issues and pull requests.
Uses the GitHub Search API /search/issues endpoint, which covers
issues and pull requests only (Discussions would require the GraphQL API).
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
token = os.environ.get("KM_GITHUB_TOKEN", "")
headers = {}
if token:
headers["Authorization"] = f"token {token}"
# Search issues and pull requests (Discussions are not returned by this endpoint)
encoded_query = urllib.parse.quote(query)
url = f"{GITHUB_API_BASE}/search/issues?q={encoded_query}&per_page={max_results}&sort=relevance"
try:
data = _http_get(url, headers=headers)
except RuntimeError as e:
return [_make_result("github", f"GitHub 搜索失败: {e}", "", str(e))]
results = []
items = data.get("items", [])
for item in items[:max_results]:
title = item.get("title", "")
html_url = item.get("html_url", "")
body = item.get("body", "") or ""
user = item.get("user", {})
author = user.get("login", "") if user else ""
created = item.get("created_at", "")
score_val = item.get("score", 0.0)
labels = [lb.get("name", "") for lb in item.get("labels", [])]
results.append(_make_result(
source="github",
title=title,
url=html_url,
snippet=clean_html(body),
author=author,
created_at=created,
score=float(score_val),
tags=labels,
))
return results
# ============================================================
# Stack Overflow 搜索适配器
# ============================================================
def _search_stackoverflow(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Stack Overflow 问题。
使用 Stack Exchange API v2.3 /search/advanced 接口。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
api_key = os.environ.get("KM_STACKOVERFLOW_KEY", "")
params = {
"order": "desc",
"sort": "relevance",
"q": query,
"site": "stackoverflow",
"pagesize": str(min(max_results, 30)),
"filter": "withbody",
}
if api_key:
params["key"] = api_key
query_string = urllib.parse.urlencode(params)
url = f"{STACKOVERFLOW_API_BASE}/search/advanced?{query_string}"
try:
data = _http_get(url)
except RuntimeError as e:
return [_make_result("stackoverflow", f"Stack Overflow 搜索失败: {e}", "", str(e))]
results = []
items = data.get("items", [])
for item in items[:max_results]:
title = item.get("title", "")
link = item.get("link", "")
body = clean_html(item.get("body", "") or "")
owner = item.get("owner", {})
author = owner.get("display_name", "") if owner else ""
# Stack Overflow 时间戳为 Unix 秒
creation_date = item.get("creation_date", 0)
created_at = ""
if creation_date:
try:
created_at = datetime.utcfromtimestamp(creation_date).strftime("%Y-%m-%dT%H:%M:%S")
except (OSError, ValueError):
created_at = ""
score_val = float(item.get("score", 0))
tags = item.get("tags", [])
results.append(_make_result(
source="stackoverflow",
title=clean_html(title),
url=link,
snippet=body,
author=author,
created_at=created_at,
score=score_val,
tags=tags,
))
return results
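# --- Illustrative sketch (not part of the original module) ---
# datetime.utcfromtimestamp(), used above for the Unix creation_date, is
# deprecated as of Python 3.12. The hypothetical helper below is one
# timezone-aware way to produce the same "%Y-%m-%dT%H:%M:%S" string; the helper
# name and the choice to return "" on bad input are assumptions made for
# illustration.
def _sketch_epoch_to_iso(epoch_seconds):
    """Convert Unix seconds (int, float, or numeric string) to a UTC ISO-like string."""
    from datetime import datetime, timezone
    try:
        dt = datetime.fromtimestamp(float(epoch_seconds), tz=timezone.utc)
    except (OverflowError, OSError, ValueError, TypeError):
        return ""
    return dt.strftime("%Y-%m-%dT%H:%M:%S")
# The same helper would also accept Slack-style "ts" strings such as
# "1700000000.123", since it converts its argument with float().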
# ============================================================
# Discord 搜索适配器
# ============================================================
def _search_discord(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Discord 频道消息。
使用 Discord Bot API /channels/{id}/messages 接口。
需要设置 KM_DISCORD_BOT_TOKEN 和 KM_DISCORD_CHANNEL_ID。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
token = os.environ.get("KM_DISCORD_BOT_TOKEN", "")
channel_id = os.environ.get("KM_DISCORD_CHANNEL_ID", "")
if not token:
return [_make_result("discord", "Discord 未配置", "", "请设置 KM_DISCORD_BOT_TOKEN 环境变量")]
if not channel_id:
return [_make_result("discord", "Discord 频道未配置", "", "请设置 KM_DISCORD_CHANNEL_ID 环境变量")]
headers = {
"Authorization": f"Bot {token}",
}
url = f"{DISCORD_API_BASE}/channels/{channel_id}/messages?limit={min(max_results, 100)}"
try:
messages = _http_get(url, headers=headers)
except RuntimeError as e:
return [_make_result("discord", f"Discord 搜索失败: {e}", "", str(e))]
if not isinstance(messages, list):
return []
# 在客户端过滤匹配消息
query_lower = query.lower()
keywords = query_lower.split()
results = []
for msg in messages:
content = msg.get("content", "")
if not content:
continue
content_lower = content.lower()
# 检查是否包含任一关键词
if not any(kw in content_lower for kw in keywords):
continue
msg_id = msg.get("id", "")
author_info = msg.get("author", {})
author = author_info.get("username", "") if author_info else ""
timestamp = msg.get("timestamp", "")
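# NOTE: a canonical message link has the form /channels/{guild_id}/{channel_id}/{message_id};
# the guild ID is not available here, so this URL may not resolve in the Discord client.
msg_url = f"https://discord.com/channels/{channel_id}/{msg_id}"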
# 简单相关性评分:匹配关键词数量
match_count = sum(1 for kw in keywords if kw in content_lower)
score_val = match_count / max(len(keywords), 1)
results.append(_make_result(
source="discord",
title=truncate_text(content, 80),
url=msg_url,
snippet=content,
author=author,
created_at=timestamp,
score=score_val,
))
if len(results) >= max_results:
break
return results
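# --- Illustrative sketch (not part of the original module) ---
# The client-side relevance score in _search_discord is the fraction of query
# keywords found in the message text. The hypothetical helper below isolates
# that rule so it can be checked without calling the Discord API.
def _sketch_keyword_score(query, content):
    """Return the share of whitespace-separated query keywords present in content."""
    keywords = query.lower().split()
    if not keywords:
        return 0.0
    content_lower = content.lower()
    matched = sum(1 for kw in keywords if kw in content_lower)
    return matched / len(keywords)
# Example: _sketch_keyword_score("python async", "Async tips") evaluates to 0.5,
# because only one of the two keywords occurs in the message.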
# ============================================================
# Confluence 搜索适配器
# ============================================================
def _search_confluence(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Confluence 内容。
使用 Confluence REST API /wiki/rest/api/content/search 接口。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
base_url = os.environ.get("KM_CONFLUENCE_URL", "")
token = os.environ.get("KM_CONFLUENCE_TOKEN", "")
if not base_url:
return [_make_result("confluence", "Confluence 未配置", "", "请设置 KM_CONFLUENCE_URL 环境变量")]
if not token:
return [_make_result("confluence", "Confluence 未认证", "", "请设置 KM_CONFLUENCE_TOKEN 环境变量")]
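# Build the CQL query; escape backslashes and double quotes so the user's
# query cannot break out of the quoted string.
safe_query = query.replace("\\", "\\\\").replace('"', '\\"')
cql = urllib.parse.quote(f'text ~ "{safe_query}"')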
url = f"{base_url.rstrip('/')}/wiki/rest/api/content/search?cql={cql}&limit={max_results}&expand=body.view,version"
headers = {
"Authorization": f"Bearer {token}",
}
try:
data = _http_get(url, headers=headers)
except RuntimeError as e:
return [_make_result("confluence", f"Confluence 搜索失败: {e}", "", str(e))]
results = []
items = data.get("results", [])
for item in items[:max_results]:
title = item.get("title", "")
content_id = item.get("id", "")
page_url = f"{base_url.rstrip('/')}/wiki{item.get('_links', {}).get('webui', '')}"
# 提取正文摘要
body_view = item.get("body", {}).get("view", {}).get("value", "")
snippet = clean_html(body_view)
# 版本信息中的作者和时间
version = item.get("version", {})
author = ""
created_at = ""
if version:
by_info = version.get("by", {})
author = by_info.get("displayName", "") if by_info else ""
created_at = version.get("when", "")
results.append(_make_result(
source="confluence",
title=title,
url=page_url,
snippet=snippet,
author=author,
created_at=created_at,
score=1.0,
tags=[item.get("type", "page")],
))
return results
# ============================================================
# Notion 搜索适配器
# ============================================================
def _search_notion(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Notion 页面和数据库。
使用 Notion API POST /v1/search 接口。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
token = os.environ.get("KM_NOTION_TOKEN", "")
if not token:
return [_make_result("notion", "Notion 未配置", "", "请设置 KM_NOTION_TOKEN 环境变量")]
headers = {
"Authorization": f"Bearer {token}",
"Notion-Version": "2022-06-28",
}
payload = {
"query": query,
"page_size": min(max_results, 100),
}
url = f"{NOTION_API_BASE}/search"
try:
data = _http_post(url, payload, headers=headers)
except RuntimeError as e:
return [_make_result("notion", f"Notion 搜索失败: {e}", "", str(e))]
results = []
items = data.get("results", [])
for item in items[:max_results]:
obj_type = item.get("object", "")
page_id = item.get("id", "")
page_url = item.get("url", "") or f"https://notion.so/{page_id.replace('-', '')}"
# 提取标题
title = ""
properties = item.get("properties", {})
if "title" in properties:
title_arr = properties["title"].get("title", [])
if title_arr:
title = title_arr[0].get("plain_text", "")
elif "Name" in properties:
name_info = properties["Name"]
if name_info.get("type") == "title":
title_arr = name_info.get("title", [])
if title_arr:
title = title_arr[0].get("plain_text", "")
if not title:
title = f"Notion {obj_type} {page_id[:8]}"
# 创建时间和作者
created_at = item.get("created_time", "")
created_by = item.get("created_by", {})
author = ""
if created_by:
author = created_by.get("name", "") or created_by.get("id", "")
results.append(_make_result(
source="notion",
title=title,
url=page_url,
snippet=f"Notion {obj_type}: {title}",
author=author,
created_at=created_at,
score=1.0,
tags=[obj_type],
))
return results
# ============================================================
# Slack 搜索适配器
# ============================================================
def _search_slack(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Slack 消息。
使用 Slack Web API search.messages 接口。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
token = os.environ.get("KM_SLACK_TOKEN", "")
if not token:
return [_make_result("slack", "Slack 未配置", "", "请设置 KM_SLACK_TOKEN 环境变量")]
headers = {
"Authorization": f"Bearer {token}",
}
encoded_query = urllib.parse.quote(query)
url = f"https://slack.com/api/search.messages?query={encoded_query}&count={max_results}&sort=score"
try:
data = _http_get(url, headers=headers)
except RuntimeError as e:
return [_make_result("slack", f"Slack 搜索失败: {e}", "", str(e))]
if not data.get("ok", False):
error_msg = data.get("error", "未知错误")
return [_make_result("slack", f"Slack API 错误: {error_msg}", "", error_msg)]
results = []
messages_data = data.get("messages", {})
matches = messages_data.get("matches", [])
for msg in matches[:max_results]:
text = msg.get("text", "")
permalink = msg.get("permalink", "")
username = msg.get("username", "")
ts = msg.get("ts", "")
channel_info = msg.get("channel", {})
channel_name = channel_info.get("name", "") if isinstance(channel_info, dict) else ""
# 将 Slack 时间戳转换为 ISO 格式
created_at = ""
if ts:
try:
dt = datetime.utcfromtimestamp(float(ts))
created_at = dt.strftime("%Y-%m-%dT%H:%M:%S")
except (ValueError, OSError):
created_at = ""
tags = [f"#{channel_name}"] if channel_name else []
results.append(_make_result(
source="slack",
title=truncate_text(text, 80),
url=permalink,
snippet=text,
author=username,
created_at=created_at,
score=1.0,
tags=tags,
))
return results
# ============================================================
# 百度搜索适配器
# ============================================================
def _search_baidu(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索百度,获取中文内容搜索结果。
使用百度搜索 API 或 Web 端点,解析结果为统一格式。
需要设置 KM_BAIDU_API_KEY 环境变量。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
api_key = os.environ.get("KM_BAIDU_API_KEY", "")
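# Build the request parameters (urlencode below handles escaping, so no separate quote() call is needed).
# NOTE: the endpoint used below is unverified as a search API; confirm it against Baidu's current documentation.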
params = {
"wd": query,
"rn": str(min(max_results, 50)),
"ie": "utf-8",
"oe": "utf-8",
}
if api_key:
params["apikey"] = api_key
query_string = urllib.parse.urlencode(params)
url = f"https://api.baidu.com/json/custom/tongji?{query_string}"
# Request headers (no HTML results-page fallback is implemented; failures return a placeholder result below)
headers = {
"User-Agent": "knowledge-mesh/1.0",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
data = _http_get(url, headers=headers)
except RuntimeError:
# API 不可用时返回空结果
return [_make_result(
"baidu",
"百度搜索需要配置 API",
"",
"请设置 KM_BAIDU_API_KEY 环境变量以启用百度搜索",
)]
results = []
items = data.get("results", data.get("items", []))
if not isinstance(items, list):
items = []
for item in items[:max_results]:
title = item.get("title", "")
link = item.get("url", item.get("link", ""))
abstract = item.get("abstract", item.get("snippet", item.get("description", "")))
author = item.get("source", item.get("author", ""))
created_at = item.get("date", item.get("created_at", ""))
score_val = float(item.get("score", 0.5))
# 清理 HTML 标签
title = clean_html(title)
abstract = clean_html(abstract)
results.append(_make_result(
source="baidu",
title=title,
url=link,
snippet=abstract,
author=author,
created_at=created_at,
score=score_val,
tags=["baidu"],
))
return results
# ============================================================
# Obsidian 搜索适配器
# ============================================================
def _search_obsidian(query: str, max_results: int = 20) -> List[Dict[str, Any]]:
"""搜索 Obsidian 知识库。
委托给 obsidian_connector 模块执行搜索。
Args:
query: 搜索关键词。
max_results: 最大返回结果数。
Returns:
标准化搜索结果列表。
"""
try:
from obsidian_connector import search_obsidian
results = search_obsidian(query, max_results=max_results)
return results
except ImportError:
return [_make_result(
"obsidian",
"Obsidian 连接器未安装",
"",
"obsidian_connector 模块不可用",
)]
except Exception as e:
return [_make_result(
"obsidian",
f"Obsidian 搜索失败: {e}",
"",
str(e),
)]
# ============================================================
# 搜索适配器路由
# ============================================================
_SOURCE_ADAPTERS = {
"github": _search_github,
"stackoverflow": _search_stackoverflow,
"discord": _search_discord,
"confluence": _search_confluence,
"notion": _search_notion,
"slack": _search_slack,
"baidu": _search_baidu,
"obsidian": _search_obsidian,
}
# 免费版可用的知识源
_FREE_SOURCES = {"github", "stackoverflow", "baidu", "obsidian"}
def _get_available_sources() -> List[str]:
"""获取当前订阅等级下可用的知识源列表。
Returns:
可用知识源名称列表。
"""
sub = check_subscription()
features = sub.get("features", [])
available = []
for source in SUPPORTED_SOURCES:
feature_name = f"{source}_search"
# basic_search 对应免费源
if source in _FREE_SOURCES or feature_name in features:
available.append(source)
return available
def _get_configured_sources() -> List[str]:
"""获取已配置认证凭据的知识源列表。
Returns:
已配置的知识源名称列表。
"""
configured = []
for source in SUPPORTED_SOURCES:
env_key = SOURCE_ENV_KEYS.get(source, "")
if env_key and os.environ.get(env_key, ""):
configured.append(source)
elif source == "stackoverflow":
# Stack Overflow 不需要 API key 也能搜索(有速率限制)
configured.append(source)
return configured
# ============================================================
# 操作实现
# ============================================================
def action_search(data: Dict[str, Any]) -> None:
"""统一搜索:查询所有已配置的可用知识源。
Args:
data: 包含 query(搜索关键词)的字典,可选 max_results。
"""
query = data.get("query", "").strip()
if not query:
output_error("请提供搜索关键词(query)", code="VALIDATION_ERROR")
return
# 检查搜索配额
if not check_search_quota():
return
sub = check_subscription()
max_results = min(data.get("max_results", 20), sub.get("max_results", 20))
available = _get_available_sources()
configured = _get_configured_sources()
# 取交集:既可用又已配置
sources_to_search = [s for s in available if s in configured]
if not sources_to_search:
output_error(
"没有可用的知识源。请先配置至少一个知识源的 API 凭据。",
code="NO_SOURCES",
)
return
all_results = []
source_stats = {}
errors = []
for source in sources_to_search:
adapter = _SOURCE_ADAPTERS.get(source)
if not adapter:
continue
try:
results = adapter(query, max_results=max_results)
source_stats[source] = len(results)
all_results.extend(results)
except Exception as e:
errors.append(f"{SOURCE_DISPLAY_NAMES.get(source, source)}: {e}")
source_stats[source] = 0
# 按分数降序排列
all_results.sort(key=lambda r: r.get("score", 0), reverse=True)
all_results = all_results[:max_results]
# 递增搜索计数
increment_search_count()
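    # Record the query in the self-learning engine (best effort)
    try:
        from learning_engine import record_query_data
        record_query_data(query, sources_to_search, source_stats)
    except Exception:
        # ImportError is a subclass of Exception, so one clause covers both
        # a missing module and a failure inside record_query_data.
        pass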
output_success({
"query": query,
"total": len(all_results),
"sources_searched": sources_to_search,
"source_stats": source_stats,
"results": all_results,
"errors": errors if errors else None,
})
def action_search_source(data: Dict[str, Any]) -> None:
"""搜索指定的单个知识源。
Args:
data: 包含 query 和 source 的字典。
"""
query = data.get("query", "").strip()
source = data.get("source", "").strip().lower()
if not query:
output_error("请提供搜索关键词(query)", code="VALIDATION_ERROR")
return
if not source:
output_error("请指定知识源(source)", code="VALIDATION_ERROR")
return
if source not in SUPPORTED_SOURCES:
valid = "、".join(SUPPORTED_SOURCES)
output_error(f"不支持的知识源: {source},支持: {valid}", code="INVALID_SOURCE")
return
# 检查订阅权限
available = _get_available_sources()
if source not in available:
output_error(
f"{SOURCE_DISPLAY_NAMES.get(source, source)} 搜索为付费版功能,请升级至付费版(¥129/月)。",
code="SUBSCRIPTION_REQUIRED",
)
return
# 检查搜索配额
if not check_search_quota():
return
sub = check_subscription()
max_results = min(data.get("max_results", 20), sub.get("max_results", 20))
adapter = _SOURCE_ADAPTERS.get(source)
if not adapter:
output_error(f"知识源适配器不存在: {source}", code="ADAPTER_ERROR")
return
try:
results = adapter(query, max_results=max_results)
except Exception as e:
output_error(f"搜索失败: {e}", code="SEARCH_ERROR")
return
# 递增搜索计数
increment_search_count()
output_success({
"query": query,
"source": source,
"total": len(results),
"results": results,
})
def action_list_sources(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有支持的知识源及其配置状态。"""
available = _get_available_sources()
configured = _get_configured_sources()
sub = check_subscription()
sources_info = []
for source in SUPPORTED_SOURCES:
env_key = SOURCE_ENV_KEYS.get(source, "")
is_available = source in available
is_configured = source in configured
sources_info.append({
"name": source,
"display_name": SOURCE_DISPLAY_NAMES.get(source, source),
"env_key": env_key,
"available": is_available,
"configured": is_configured,
"status": "ready" if (is_available and is_configured) else
"not_configured" if is_available else "paid_only",
})
output_success({
"subscription_tier": sub["tier"],
"total_sources": len(SUPPORTED_SOURCES),
"available_count": len(available),
"configured_count": len(configured),
"sources": sources_info,
})
def action_test_source(data: Dict[str, Any]) -> None:
"""测试指定知识源的连接状态。
Args:
data: 包含 source 的字典。
"""
source = data.get("source", "").strip().lower()
if not source:
output_error("请指定知识源(source)", code="VALIDATION_ERROR")
return
if source not in SUPPORTED_SOURCES:
valid = "、".join(SUPPORTED_SOURCES)
output_error(f"不支持的知识源: {source},支持: {valid}", code="INVALID_SOURCE")
return
env_key = SOURCE_ENV_KEYS.get(source, "")
token = os.environ.get(env_key, "") if env_key else ""
# 测试连接
test_result = {
"source": source,
"display_name": SOURCE_DISPLAY_NAMES.get(source, source),
"env_key": env_key,
"token_configured": bool(token),
}
if source == "github":
if token:
try:
data_resp = _http_get(f"{GITHUB_API_BASE}/user", headers={"Authorization": f"token {token}"})
test_result["status"] = "connected"
test_result["user"] = data_resp.get("login", "")
except RuntimeError as e:
test_result["status"] = "error"
test_result["error"] = str(e)
else:
test_result["status"] = "no_token"
test_result["message"] = "GitHub 搜索可在无 Token 下使用,但有速率限制"
elif source == "stackoverflow":
# Stack Overflow 不需要强制 API key
test_result["status"] = "available"
test_result["message"] = "Stack Overflow API 可直接访问(API key 可选)"
elif source == "confluence":
base_url = os.environ.get("KM_CONFLUENCE_URL", "")
if not base_url:
test_result["status"] = "not_configured"
test_result["error"] = "请设置 KM_CONFLUENCE_URL 环境变量"
elif not token:
test_result["status"] = "no_token"
test_result["error"] = "请设置 KM_CONFLUENCE_TOKEN 环境变量"
else:
test_result["status"] = "configured"
test_result["base_url"] = base_url
else:
if token:
test_result["status"] = "configured"
else:
test_result["status"] = "no_token"
test_result["error"] = f"请设置 {env_key} 环境变量"
output_success(test_result)
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("knowledge-mesh 知识源搜索")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"search": lambda: action_search(data or {}),
"search-source": lambda: action_search_source(data or {}),
"list-sources": lambda: action_list_sources(data),
"test-source": lambda: action_test_source(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:references/api-endpoints.md
# API Endpoint Reference
This document lists the API endpoints and authentication methods for each knowledge source supported by knowledge-mesh.
---
## 1. GitHub
| Item | Description |
|------|------|
| **API base URL** | `https://api.github.com` |
| **Search issues** | `GET /search/issues?q={query}&per_page={n}&sort=relevance` |
| **Search code** | `GET /search/code?q={query}&per_page={n}` |
| **Search repositories** | `GET /search/repositories?q={query}&per_page={n}` |
| **Authentication** | `Authorization: token {KM_GITHUB_TOKEN}` |
| **Rate limit** | Unauthenticated: 10 requests/minute; authenticated: 30 requests/minute |
| **Docs** | https://docs.github.com/en/rest/search |
> Note: Searching GitHub Discussions requires the GraphQL API, so Issues search is currently used as a substitute.
---
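## 2. Stack Overflow
| Item | Description |
|------|------|
| **API base URL** | `https://api.stackexchange.com/2.3` |
| **Advanced search** | `GET /search/advanced?q={query}&site=stackoverflow&order=desc&sort=relevance` |
| **Search by tag** | `GET /search?tagged={tags}&site=stackoverflow` |
| **Get question details** | `GET /questions/{ids}?site=stackoverflow&filter=withbody` |
| **Authentication** | Query parameter `key={KM_STACKOVERFLOW_KEY}` |
| **Rate limit** | Without a key: 300 requests/day; with a key: 10,000 requests/day |
| **Docs** | https://api.stackexchange.com/docs |
> Note: Responses are always compressed (gzip by default). Python's `urllib` does not decompress them automatically, so the body has to be decompressed (for example with `gzip.decompress`) before it is parsed as JSON.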
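
Because `urllib` hands back the raw bytes, the decompression step can be sketched as below. This is a minimal illustration, not the code knowledge-mesh ships: the function name `decode_api_body` is hypothetical, and it only handles gzip, detected either from the `Content-Encoding` value or from the gzip magic bytes.

```python
import gzip
import json
from typing import Any


def decode_api_body(raw: bytes, content_encoding: str = "") -> Any:
    """Decompress a gzip response body if needed, then parse it as JSON.

    Hypothetical helper: `raw` is what response.read() returned and
    `content_encoding` is the value of the Content-Encoding header.
    """
    encoding = content_encoding.strip().lower()
    # 0x1f 0x8b are the gzip magic bytes; checking them covers responses
    # whose Content-Encoding header is missing.
    if encoding == "gzip" or raw[:2] == b"\x1f\x8b":
        raw = gzip.decompress(raw)
    return json.loads(raw.decode("utf-8"))
```

An uncompressed JSON body passes through unchanged, so the same helper can be applied to every source.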
---
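## 3. Discord
| Item | Description |
|------|------|
| **API base URL** | `https://discord.com/api/v10` |
| **Get channel messages** | `GET /channels/{channel_id}/messages?limit={n}` |
| **Search messages** | Must be filtered on the client side (the Discord API has no public search endpoint) |
| **Authentication** | `Authorization: Bot {KM_DISCORD_BOT_TOKEN}` |
| **Rate limit** | Varies by endpoint, typically 5 requests/second |
| **Required permissions** | `READ_MESSAGE_HISTORY`, `VIEW_CHANNEL` |
| **Docs** | https://discord.com/developers/docs |
> Note: The `KM_DISCORD_CHANNEL_ID` environment variable must also be set to choose the channel to search.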
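
Since matching has to happen on the client, the filtering step might look like the sketch below. It assumes the messages were already fetched from the channel endpoint and that each one is a dict with a `content` string; the function name and the all-terms-must-match rule are illustrative choices, not a description of the actual adapter.

```python
from typing import Any, Dict, List


def filter_messages(
    messages: List[Dict[str, Any]], query: str, max_results: int = 20
) -> List[Dict[str, Any]]:
    """Keep messages whose text contains every whitespace-separated query term."""
    terms = [term.lower() for term in query.split()]
    if not terms:
        return []
    matches: List[Dict[str, Any]] = []
    for message in messages:
        text = str(message.get("content", "")).lower()
        # Case-insensitive substring match on every term (AND semantics)
        if all(term in text for term in terms):
            matches.append(message)
            if len(matches) >= max_results:
                break
    return matches
```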
---
## 4. Confluence
| Item | Description |
|------|------|
| **API base URL** | `{KM_CONFLUENCE_URL}/wiki/rest/api` |
| **CQL search** | `GET /content/search?cql=text~"{query}"&limit={n}&expand=body.view,version` |
| **Get page** | `GET /content/{id}?expand=body.view` |
| **Authentication** | `Authorization: Bearer {KM_CONFLUENCE_TOKEN}` |
| **Alternative authentication** | Basic Auth: `email:api_token` (Base64-encoded) |
| **Rate limit** | Varies by instance; usually no strict limit |
| **Docs** | https://developer.atlassian.com/cloud/confluence/rest/ |
> Note: CQL (Confluence Query Language) supports a rich search syntax, for example `type=page AND space=DEV AND text~"keyword"`.
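
The CQL search URL from the table can be assembled as in the following sketch. The helper name `build_cql_search_url` is hypothetical; the escaping of quotes and backslashes is an assumption about what keeps a user keyword inside the CQL string literal, and the percent-encoding is left to `urlencode`.

```python
from urllib.parse import urlencode


def build_cql_search_url(base_url: str, query: str, limit: int = 20) -> str:
    """Build the content-search URL for a plain-text CQL query."""
    # Escape backslashes and double quotes so the keyword cannot close
    # the CQL string literal early.
    escaped = query.replace("\\", "\\\\").replace('"', '\\"')
    params = {
        "cql": f'text~"{escaped}"',
        "limit": str(limit),
        "expand": "body.view,version",
    }
    return f"{base_url.rstrip('/')}/wiki/rest/api/content/search?{urlencode(params)}"
```

Here `base_url` stands for the value of `KM_CONFLUENCE_URL`.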
---
## 5. Notion
| Item | Description |
|------|------|
| **API base URL** | `https://api.notion.com/v1` |
| **Search** | `POST /search` (Body: `{"query":"...","page_size":20}`) |
| **Get page** | `GET /pages/{page_id}` |
| **Get block content** | `GET /blocks/{block_id}/children` |
| **Authentication** | `Authorization: Bearer {KM_NOTION_TOKEN}` |
| **Required header** | `Notion-Version: 2022-06-28` |
| **Rate limit** | 3 requests/second |
| **Docs** | https://developers.notion.com/reference |
> Note: A Notion integration must be created in the Notion settings and granted access to the relevant pages or databases.
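
Unlike the other sources, Notion search is a `POST` with a JSON body and a mandatory version header. The sketch below only builds the request object from the values in the table and does not send it; the function name is hypothetical.

```python
import json
import urllib.request


def build_notion_search_request(
    token: str, query: str, page_size: int = 20
) -> urllib.request.Request:
    """Build (but do not send) a Notion search request."""
    body = json.dumps({"query": query, "page_size": page_size}).encode("utf-8")
    return urllib.request.Request(
        "https://api.notion.com/v1/search",
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {token}",
            "Notion-Version": "2022-06-28",
            "Content-Type": "application/json",
        },
    )
```

The returned object can be passed to `urllib.request.urlopen` once a real `KM_NOTION_TOKEN` is available.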
---
## 6. Slack
| Item | Description |
|------|------|
| **API base URL** | `https://slack.com/api` |
| **Search messages** | `GET /search.messages?query={query}&count={n}&sort=score` |
| **Search files** | `GET /search.files?query={query}&count={n}` |
| **Authentication** | `Authorization: Bearer {KM_SLACK_TOKEN}` |
| **Required scope** | `search:read` |
| **Rate limit** | Tier 2: 20 requests/minute |
| **Docs** | https://api.slack.com/methods/search.messages |
> Note: The Slack token needs the `search:read` scope, and the bot must be invited to the relevant channels.
---
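## General notes
1. **HTTPS only**: every API request must use HTTPS.
2. **Timeout**: the default request timeout is 15 seconds and can be changed in code.
3. **Retry on errors**: on a 429 (rate limited) response, wait for the time given in the `Retry-After` header before retrying.
4. **User-Agent**: every request carries `User-Agent: knowledge-mesh/1.0`.
5. **Response format**: all platforms return JSON; send the `Accept: application/json` request header.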
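
For the retry rule in item 3, the wait time can be derived from `Retry-After`, which may carry either a number of seconds or an HTTP date. The following is a minimal sketch under stated assumptions: the helper name, the 5-second fallback, and the `now` parameter (added so the function can be tested) are all hypothetical, and `headers` is assumed to be a plain mapping with the canonical header spelling.

```python
import email.utils
import time
from typing import Mapping, Optional


def retry_after_seconds(
    headers: Mapping[str, str], default: float = 5.0, now: Optional[float] = None
) -> float:
    """Return how long to wait before retrying a 429 response."""
    value = headers.get("Retry-After", "").strip()
    if not value:
        return default
    if value.isdigit():
        # Delay given directly in seconds
        return float(value)
    try:
        # Otherwise the header is expected to be an HTTP date
        when = email.utils.parsedate_to_datetime(value)
    except (TypeError, ValueError):
        return default
    current = time.time() if now is None else now
    return max(0.0, when.timestamp() - current)
```

A caller would sleep for the returned number of seconds and then reissue the request, ideally with a cap on the number of attempts.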
FILE:references/search-syntax.md
# Search Syntax Guide
This document describes the search query syntax of knowledge-mesh and the search techniques specific to each platform.
---
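## Basic search syntax
### Keyword search
Enter keywords directly; terms separated by spaces are combined with AND:
```
python async fastapi
```
This matches content that contains all of "python", "async", and "fastapi".
### Exact match
Wrap a phrase in double quotes to match it exactly:
```
"connection pool" timeout
```
This matches content that contains the full phrase "connection pool" as well as "timeout".
### Excluding keywords
Prefix a term with a minus sign to exclude results that contain it (supported on some platforms only):
```
python web framework -django
```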
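
The three forms above (plain terms, quoted phrases, and minus-prefixed exclusions) can be separated with a small tokenizer. This is a sketch for illustration only; knowledge-mesh may simply forward the raw query string to each platform, and the function name and return shape are hypothetical.

```python
import re
from typing import Dict, List

# Either an optionally negated quoted phrase, or a run of non-space characters
_TOKEN = re.compile(r'(-?)"([^"]+)"|(\S+)')


def parse_query(query: str) -> Dict[str, List[str]]:
    """Split a query into plain terms, exact phrases, and excluded items."""
    parsed: Dict[str, List[str]] = {"terms": [], "phrases": [], "excluded": []}
    for match in _TOKEN.finditer(query):
        minus, phrase, word = match.groups()
        if phrase is not None:
            # Quoted phrase, negated when preceded by a minus sign
            parsed["excluded" if minus else "phrases"].append(phrase)
        elif word.startswith("-") and len(word) > 1:
            parsed["excluded"].append(word[1:])
        else:
            parsed["terms"].append(word)
    return parsed
```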
---
## Platform-specific syntax
### GitHub
```
# Search within a repository
repo:owner/name keyword
# Filter by language
language:python async
# Filter by label
label:bug memory leak
# Filter by state
state:open performance issue
# Combined search
repo:fastapi/fastapi label:question websocket
```
### Stack Overflow
```
# Search by tag
[python] [async] how to
# Search accepted answers
is:answer isaccepted:yes connection pool
# Filter by score
score:10 python decorator
# Filter by date range
created:2026-01.. python 3.12
```
### Confluence (CQL)
```
# Search within a space
space=DEV AND text~"api design"
# Filter by type
type=page AND title~"architecture"
# Filter by label
label=backend AND text~"microservice"
# Filter by creator
creator=currentUser() AND text~"meeting notes"
```
### Notion
```
# Basic keywords (the Notion API supports simple text search only)
project roadmap
# Prefer specific keywords from page or database titles
sprint planning Q1 2026
```
### Slack
```
# Search within a channel
in:#engineering python deployment
# Search by user
from:@username production issue
# Filter by date range
after:2026-03-01 before:2026-03-19 release
# Messages that contain a link
has:link architecture document
```
---
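## Search tips
1. **Use specific keywords**: avoid overly broad terms such as "问题" or "error"; use the exact error message or technical term instead.
2. **Play to each platform's strengths**:
- Code problems → search Stack Overflow and GitHub first
- Team knowledge → search Confluence and Notion first
- Live discussions → search Discord and Slack first
3. **Iterate**: when the first results are not useful, adjust the query based on the tags and keywords that came back.
4. **Combine sources**: searching across platforms gives a fuller picture. The free tier can search GitHub, Stack Overflow, Baidu, and Obsidian in a single query.
5. **Filter by tag**: the tags on search results can point you to related topics and more precise search terms.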
---
## Search examples
| Scenario | Recommended query | Suggested platforms |
|------|----------|----------|
| Python asynchronous programming | `python asyncio await best practices` | Stack Overflow, GitHub |
| React performance optimization | `react performance optimization memo` | Stack Overflow, GitHub |
| Team API design guidelines | `api design guidelines rest` | Confluence, Notion |
| Troubleshooting a deployment | `deployment failed kubernetes pod` | Slack, Discord |
| Choosing an open-source project | `python web framework comparison 2026` | GitHub, Stack Overflow |
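库存慧眼 (inventory-eye): lightweight inventory monitoring that tells you what to restock, what is not selling, and what is about to expire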
---
name: inventory-eye
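description: "库存慧眼 (inventory-eye), lightweight inventory monitoring that tells you what to restock, what is not selling, and what is about to expire"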
version: 1.0.0
metadata:
openclaw:
optional_env:
- IE_SUBSCRIPTION_TIER
- IE_DATA_DIR
---
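# 库存慧眼 (inventory-eye)
You are a professional inventory management assistant Agent. Your job is to help users import and manage inventory data, monitor stock levels, and produce restocking suggestions and turnover analysis reports. You always communicate with the user in Chinese.
## Environment variables
| Variable | Required | Description |
|------|------|------|
| `IE_SUBSCRIPTION_TIER` | No | Subscription tier; defaults to `free`, may be set to `paid` |
| `IE_DATA_DIR` | No | Data storage directory; defaults to `~/.openclaw-bdi/inventory-eye/` |
---
## Flow 1: Import inventory data
When the user says "导入库存表", "导入CSV", "上传库存", or expresses a similar intent, follow these steps:
### Step 1: Get the file path
Ask the user for the path to the CSV file. Supported formats:
- CSV files (UTF-8 / GBK encoding is detected automatically)
### Step 2: Import the data
```bash
python3 scripts/inventory_store.py --action import --file <csv_path>
```
The system detects the column mapping automatically (SKU code, name, quantity, cost, selling price, expiry date, and so on). Show the mapping to the user for confirmation.
### Step 3: Show the import result
Present the import result as a table:
```
导入结果:
- 新增 SKU: XX 个
- 更新 SKU: XX 个
- 跳过: XX 个
- 当前总 SKU: XX 个
列映射关系:
| CSV列名 | 对应字段 |
|---------|---------|
| 商品编码 | sku_id |
| 商品名称 | name |
| ... | ... |
```
### Step 4: Inventory overview
After a successful import, run the inventory overview automatically:
```bash
python3 scripts/stock_monitor.py --action overview
```
Present the overview in a clear table (total SKU count, total quantity, inventory value, and the distribution by category and warehouse).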
---
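## Flow 2: Inventory monitoring and alerts
When the user asks "库存状况怎么样", "有没有库存预警", "哪些商品缺货了", or expresses a similar intent, follow these steps:
### Step 1: Full check
```bash
python3 scripts/stock_monitor.py --action check
```
### Step 2: Show the alerts
Group the alerts by severity:
```
🔴 缺货预警(X个商品):
| SKU | 名称 | 仓库 |
|-----|------|------|
🟡 低库存预警(X个商品):
| SKU | 名称 | 当前库存 | 安全库存 | 缺口 |
|-----|------|---------|---------|------|
⏰ 过期预警(X个商品):
| SKU | 名称 | 过期日期 | 剩余天数 | 状态 |
|-----|------|---------|---------|------|
```
> Note: In the free tier, expiry alerts cover only items that have already expired or will expire within 30 days. The paid tier adds multi-level reminders (7 / 30 / 60 / 90 days).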
---
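## Flow 3: Restocking suggestions
When the user asks "哪些商品该补货了", "生成补货清单", "补货建议", or expresses a similar intent, follow these steps:
### Step 1: Subscription check
- **Free tier**: show the list of items below their safety stock (fixed-threshold comparison).
- **Paid tier**: run the AI restocking calculation.
### Step 2A (free tier): Low-stock list
```bash
python3 scripts/stock_monitor.py --action low-stock
```
Show the items below their safety stock and suggest upgrading to get smart restocking suggestions.
### Step 2B (paid tier): AI restocking calculation
```bash
python3 scripts/reorder_calculator.py --action calculate --data '{"lead_time": 7, "safety_factor": 1.5}'
```
Show the smart restocking list:
```
补货清单(按紧急程度排序):
| 紧急度 | SKU | 名称 | 当前库存 | 建议补货量 | 补货成本 | 可售天数 |
|--------|-----|------|---------|-----------|---------|---------|
| 🔴 缺货 | ... | ... | 0 | 50 | ¥500 | 已断货 |
| 🟠 紧急 | ... | ... | 5 | 30 | ¥300 | 2.5天 |
预计补货总成本: ¥XXXX
```
### Step 3: Per-item detail (paid tier)
If the user wants a detailed restocking suggestion for a single SKU:
```bash
python3 scripts/reorder_calculator.py --action suggest --data '{"sku_id": "SKU-001"}'
```
This returns a multi-period analysis (average daily sales over 7 / 14 / 30 / 60 / 90 days) and a recommended restocking quantity.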
---
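## Flow 4: Slow-moving item analysis (paid feature)
When the user asks "哪些商品卖不动", "滞销商品", "清库存", or expresses a similar intent, follow these steps:
### Step 1: Subscription check
Check whether the user is on the paid tier. On the free tier, suggest upgrading.
### Step 2: Slow-moving analysis
```bash
python3 scripts/turnover_analyzer.py --action slow-moving --data '{"days": 30}'
```
Show the list of slow-moving items:
```
滞销商品清单(30天无出库):
| SKU | 名称 | 库存量 | 库存金额 | 滞销天数 | 建议 |
|-----|------|--------|---------|---------|------|
| ... | ... | 100 | ¥5000 | 65天 | 中度滞销,建议打折促销 |
滞销库存占用资金: ¥XXXXX
```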
---
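## Flow 5: Inventory turnover analysis (paid feature)
When the user says "库存周转率", "周转分析", "本月库存报告", or expresses a similar intent, follow these steps:
### Step 1: Subscription check
Check whether the user is on the paid tier. On the free tier, suggest upgrading.
### Step 2: Calculate turnover
```bash
python3 scripts/turnover_analyzer.py --action turnover --days 30
```
### Step 3: Generate the report
```bash
python3 scripts/turnover_analyzer.py --action report --days 30
```
The report contains the overall turnover rate, turnover by category, Mermaid charts, and rankings of the highest- and lowest-turnover SKUs.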
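
As a worked illustration of the turnover figures, the sketch below follows the definitions stated in `scripts/turnover_analyzer.py`: turnover rate = cost of goods shipped in the period / average inventory cost, and turnover days = period length / turnover rate, with opening stock estimated as current stock + outbound - inbound. The function name and signature are hypothetical simplifications for a single SKU.

```python
from typing import Optional, Tuple


def sku_turnover(
    current_qty: float, outbound: float, inbound: float, unit_cost: float, days: int
) -> Tuple[float, Optional[float]]:
    """Return (turnover_rate, turnover_days) for one SKU over `days` days."""
    # Estimate opening stock by unwinding the period's movements
    beginning_qty = max(0.0, current_qty + outbound - inbound)
    avg_inventory_cost = (beginning_qty + current_qty) / 2.0 * unit_cost
    cogs = outbound * unit_cost
    if avg_inventory_cost <= 0 or cogs <= 0:
        # No stock held or nothing shipped: turnover days are undefined
        return 0.0, None
    rate = cogs / avg_inventory_cost
    return rate, days / rate
```

For example, with 60 units on hand, 80 shipped and 40 received over 30 days at a unit cost of 5, the opening stock is estimated at 100, the average inventory cost is 400, and the cost of goods shipped is 400, giving a turnover rate of 1.0 and 30 turnover days.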
---
## Flow 6: Inbound and outbound operations
When the user says "入库", "出库", "到货了", "卖了XX", or expresses a similar intent:
### Inbound
```bash
python3 scripts/inventory_store.py --action inbound --data '{"sku_id": "SKU-001", "quantity": 100, "note": "供应商到货"}'
```
### Outbound
```bash
python3 scripts/inventory_store.py --action outbound --data '{"sku_id": "SKU-001", "quantity": 20, "note": "订单出库"}'
```
After the operation, show the updated stock quantity and check whether it triggers a low-stock alert.
---
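## Subscription checks
Before every operation that is subject to feature limits, run the following check:
### Read the subscription tier
```
tier = env IE_SUBSCRIPTION_TIER, default "free"
```
### Feature matrix
| Feature | Free (`free`) | Paid (`paid`, ¥89/month) |
|------|---------------|----------------------|
| SKU count | 100 | 2000 |
| CSV import | ✅ | ✅ |
| Inventory overview | ✅ | ✅ |
| Low-stock alerts | Fixed threshold | Dynamic safety stock |
| Slow-moving item detection | ❌ | ✅ |
| Restocking suggestions | ❌ | ✅ AI calculation |
| Inventory turnover analysis | ❌ | ✅ |
| Expiry alerts | Basic (within 30 days) | ✅ Multi-level reminders (7/30/60/90 days) |
| Warehouses | 1 | 5 |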
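
The tier lookup and the matrix above could be expressed roughly as follows. This is a sketch, not the contents of `utils.check_subscription`: the feature identifiers in `PAID_ONLY` are hypothetical names, and falling back to `free` for an unrecognized tier value is an assumption. The `max_skus` and `max_warehouses` keys mirror the fields the scripts read from the subscription data.

```python
import os
from typing import Dict, Mapping, Optional

TIER_LIMITS: Dict[str, Dict[str, int]] = {
    "free": {"max_skus": 100, "max_warehouses": 1},
    "paid": {"max_skus": 2000, "max_warehouses": 5},
}
# Hypothetical feature identifiers for the paid-only rows of the matrix
PAID_ONLY = {"slow_moving", "reorder_calculation", "turnover_analysis"}


def resolve_tier(env: Optional[Mapping[str, str]] = None) -> str:
    """Read IE_SUBSCRIPTION_TIER, defaulting to "free"."""
    env = os.environ if env is None else env
    tier = env.get("IE_SUBSCRIPTION_TIER", "free").strip().lower()
    # Assumption: unknown values are treated as the free tier
    return tier if tier in TIER_LIMITS else "free"


def is_allowed(feature: str, tier: str) -> bool:
    """Return True when the feature is available on the given tier."""
    return tier == "paid" or feature not in PAID_ONLY
```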
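### Behavior when the check fails
When the user requests a feature beyond the current subscription tier:
1. State clearly that the feature is available only on the paid tier.
2. Briefly explain the benefits of the paid tier.
3. Offer upgrade guidance: "如需升级至付费版(¥89/月),请联系管理员或访问订阅管理页面。"
4. Do not refuse outright; offer whatever free-tier alternative is available.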
---
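## Reference documents
When handling inventory management questions, consult the following document:
- **Inventory management guide**: `references/inventory-guide.md` explains core concepts such as safety stock, reorder point, and turnover rate.
---
## Security rules
1. **Data security**: inventory data is stored in a local directory and is not uploaded to external servers.
2. **File operations**: read and write only the directory given by `IE_DATA_DIR`; do not access other files.
3. **Input validation**: all user input (SKU codes, quantities, prices) must pass type and range checks.
4. **Error handling**: when a command fails, show the user a friendly error message without exposing internal paths or system information.
---
## Behavior guidelines
1. Always communicate with the user in Chinese.
2. Give clear, structured answers, and prefer tables for presenting data.
3. Offer inventory management advice proactively rather than only returning raw data.
4. When the user's intent is ambiguous, ask follow-up questions to clarify it.
5. Respect subscription limits, stay friendly when suggesting an upgrade, and do not push it repeatedly.
6. After every inbound or outbound operation, check for inventory alerts and report them.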
FILE:assets/README.md
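# 库存慧眼 (inventory-eye)
> Lightweight inventory monitoring that tells you what to restock, what is not selling, and what is about to expire
---
## Highlights
- **One-step import**: import a CSV file with automatic column mapping and get started in 10 seconds
- **Smart alerts**: low-stock, expiry, and out-of-stock alerts so orders are not missed
- **Restocking suggestions**: AI calculates the optimal restocking quantity from historical sales data (paid tier)
- **Slow-mover detection**: automatically finds items that are not selling, freeing up capital (paid tier)
- **Turnover analysis**: inventory turnover reports that show how healthy your stock is (paid tier)
- **Multiple warehouses**: manage up to 5 warehouses independently (paid tier)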
---
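## Tier comparison
| Feature | Free | Paid ¥89/month |
|------|:------:|:------------:|
| SKU count | 100 | 2000 |
| CSV import | ✅ | ✅ |
| Inventory overview | ✅ | ✅ |
| Inbound/outbound management | ✅ | ✅ |
| Low-stock alerts | Fixed threshold | Dynamic safety stock |
| Expiry alerts | Basic (within 30 days) | Multi-level reminders (7/30/60/90 days) |
| Slow-moving item detection | ❌ | ✅ |
| AI restocking suggestions | ❌ | ✅ |
| Inventory turnover analysis | ❌ | ✅ |
| Warehouses | 1 | 5 |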
---
## Quick start
### 1. Import inventory
Prepare a CSV file with the following columns (column names are matched flexibly):
```csv
SKU编码,商品名称,分类,数量,进价,售价,安全库存,仓库,过期日期
SKU-001,有机牛奶,乳制品,200,8.50,15.90,50,主仓,2026-06-30
SKU-002,全麦面包,烘焙,80,5.00,12.00,30,主仓,2026-04-15
SKU-003,洗手液,日化,500,3.20,9.90,100,主仓,2027-12-31
```
Tell the AI: **"导入库存表 /path/to/inventory.csv"**
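
To show what encoding detection and column mapping could involve for a file like the sample above, here is a minimal sketch. It is not the importer in `scripts/inventory_store.py`: the `COLUMN_MAP` table is a hypothetical fixed mapping (the real matcher is described as flexible), and the UTF-8-then-GBK fallback order is an assumption. The target field names are the ones the monitoring script reads.

```python
import csv
import io
from typing import Dict, List

# Hypothetical header-to-field mapping for the sample columns
COLUMN_MAP = {
    "SKU编码": "sku_id", "商品编码": "sku_id",
    "商品名称": "name", "分类": "category", "数量": "quantity",
    "进价": "unit_cost", "售价": "selling_price", "安全库存": "safety_stock",
    "仓库": "warehouse", "过期日期": "expiry_date",
}


def decode_csv_bytes(raw: bytes) -> str:
    """Try UTF-8 (with or without BOM) first, then fall back to GBK."""
    for encoding in ("utf-8-sig", "gbk"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    raise ValueError("unsupported encoding: expected UTF-8 or GBK")


def parse_inventory_csv(raw: bytes) -> List[Dict[str, str]]:
    """Return one dict per row, keyed by internal field name."""
    reader = csv.DictReader(io.StringIO(decode_csv_bytes(raw), newline=""))
    rows: List[Dict[str, str]] = []
    for record in reader:
        rows.append({
            COLUMN_MAP[key.strip()]: (value or "").strip()
            for key, value in record.items()
            # Skip surplus cells (key is None) and unmapped columns
            if key and key.strip() in COLUMN_MAP
        })
    return rows
```

Values are left as strings here; converting quantities and prices to numbers, and validating them, would be a separate step.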
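### 2. Check inventory status
Tell the AI: **"库存状况怎么样?"**
You get a full inventory overview, including:
- Total SKU count, total quantity, and inventory value
- Stock distribution by category
- Out-of-stock, low-stock, and expiry alerts
### 3. View restocking suggestions
Tell the AI: **"哪些商品该补货了?"**
Free tier: lists the items below their safety stock.
Paid tier: AI calculates the optimal restocking quantity from historical sales data.
### 4. Daily inbound and outbound
Tell the AI:
- **"SKU-001 入库 100个"** to record a delivery
- **"SKU-002 出库 20个"** to record a sale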
---
## Sample reports
### Inventory overview
```
📊 库存概览 — 2026-03-19
| 指标 | 数值 |
|------|------|
| 总 SKU 数 | 86 |
| 总库存量 | 12,350 件 |
| 库存成本总值 | ¥89,560.00 |
| 库存零售总值 | ¥156,800.00 |
| 潜在利润 | ¥67,240.00 |
| 缺货商品 | 3 个 |
| 低库存商品 | 8 个 |
| 即将过期 | 5 个 |
```
### Restocking suggestions (paid tier)
```
📦 补货清单(按紧急程度排序)
| 紧急度 | 商品 | 当前库存 | 建议补货 | 可售天数 |
|--------|------|---------|---------|---------|
| 🔴 缺货 | 有机牛奶 | 0 | 150 | 已断货 |
| 🟠 紧急 | 全麦面包 | 12 | 80 | 1.5天 |
| 🟡 预警 | 酸奶 | 45 | 60 | 5.2天 |
预计补货总成本: ¥3,250.00
```
---
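## Use cases
- **Small retail shops**: day-to-day inventory management for convenience stores, baby product stores, and stationery shops
- **Small restaurants**: ingredient inventory management and shelf-life tracking
- **Warehouses**: inbound and outbound management for small warehouses
- **Small e-commerce sellers**: inventory monitoring across many SKUs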
---
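## FAQ
### Q: Which file formats can be imported?
A: CSV import is currently supported, with automatic detection of UTF-8 and GBK encoding. Column names are matched automatically, so the file does not have to follow a strict template.
### Q: What are the limits of the free tier?
A: The free tier supports up to 100 SKUs, 1 warehouse, fixed-threshold alerts, and basic expiry reminders. Upgrading to the paid tier (¥89/month) unlocks 2000 SKUs, 5 warehouses, AI restocking suggestions, slow-mover analysis, and other advanced features.
### Q: Where is the data stored?
A: All data is stored locally in `~/.openclaw-bdi/inventory-eye/` and is never uploaded to an external server, so it stays fully under your control.
### Q: How much safety stock should I set?
A: A reasonable starting point is average daily sales × 3 to 7 days. Set it higher for best sellers (5 to 7 days) and lower for long-tail items (2 to 3 days). The paid tier calculates a dynamic safety stock automatically from actual sales data.
### Q: How is the restocking quantity calculated?
A: The paid tier uses the formula: restocking quantity = (average daily sales × lead time × safety factor) - current stock. Average daily sales are computed from actual outbound data over the last 30 days; the lead time defaults to 7 days and the safety factor to 1.5.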
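
The formula in the answer above can be checked with a few lines of code. This is an illustrative sketch rather than the logic of `scripts/reorder_calculator.py`: the function name is hypothetical, and rounding up to whole units and clamping negative results to zero are assumptions.

```python
import math


def reorder_quantity(
    outbound_30d: float,
    current_stock: float,
    lead_time_days: int = 7,
    safety_factor: float = 1.5,
) -> int:
    """Restocking quantity = daily average × lead time × safety factor - current stock."""
    daily_avg = outbound_30d / 30.0
    target = daily_avg * lead_time_days * safety_factor
    # Assumption: round up to whole units and never suggest a negative quantity
    return max(0, math.ceil(target - current_stock))
```

With 60 units shipped in the last 30 days (2 per day), the target stock is 2 × 7 × 1.5 = 21, so a SKU with 5 units on hand would get a suggestion of 16.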
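### Q: Can I manage multiple warehouses?
A: The free tier supports 1 warehouse and the paid tier supports up to 5. Each warehouse manages its stock independently, and alerts and reports are shown per warehouse.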
---
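## Technical specifications
- **Runtime**: Python 3.8+
- **Dependencies**: standard library only (nothing extra to install)
- **Data format**: JSON storage
- **Encoding support**: automatic detection of UTF-8, GBK, and GB2312
---
> 库存慧眼: inventory management made easier | [OpenClaw](https://openclaw.dev)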
FILE:scripts/stock_monitor.py
#!/usr/bin/env python3
"""
inventory-eye 库存监控与预警模块
提供库存水平监控、低库存预警、过期预警、库存概览等功能。
"""
import argparse
import json
import os
import sys
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription, days_until_expiry, format_number, format_percentage,
load_inventory, load_transactions, read_json_input,
output_json, output_error, output_success,
)
# ============================================================
# 监控功能
# ============================================================
def action_check(args) -> None:
"""全面检查库存状态,返回综合报告。"""
inventory = load_inventory()
skus = inventory["skus"]
sub = check_subscription()
if not skus:
output_success({
"action": "check",
"message": "库存为空,请先导入库存数据",
"total_skus": 0,
})
return
low_stock = []
expiring = []
out_of_stock = []
healthy = []
for s in skus:
qty = s.get("quantity", 0)
safety = s.get("safety_stock", 0)
if qty == 0:
out_of_stock.append(s)
elif qty <= safety:
low_stock.append(s)
else:
healthy.append(s)
# 过期检查
days_left = days_until_expiry(s.get("expiry_date"))
if days_left is not None:
if sub["tier"] == "paid":
# 付费版:多级提醒(已过期、7天、30天、60天、90天)
if days_left <= 0:
expiring.append({**s, "_expiry_status": "已过期", "_days_left": days_left})
elif days_left <= 7:
expiring.append({**s, "_expiry_status": "即将过期(7天内)", "_days_left": days_left})
elif days_left <= 30:
expiring.append({**s, "_expiry_status": "临近过期(30天内)", "_days_left": days_left})
elif days_left <= 60:
expiring.append({**s, "_expiry_status": "注意(60天内)", "_days_left": days_left})
elif days_left <= 90:
expiring.append({**s, "_expiry_status": "提醒(90天内)", "_days_left": days_left})
else:
# 免费版:基础提醒(已过期、30天内)
if days_left <= 0:
expiring.append({**s, "_expiry_status": "已过期", "_days_left": days_left})
elif days_left <= 30:
expiring.append({**s, "_expiry_status": "即将过期(30天内)", "_days_left": days_left})
# 排序:过期最紧急的排前面
expiring.sort(key=lambda x: x["_days_left"])
total_value = sum(s.get("quantity", 0) * s.get("unit_cost", 0) for s in skus)
total_retail = sum(s.get("quantity", 0) * s.get("selling_price", 0) for s in skus)
output_success({
"action": "check",
"tier": sub["tier"],
"summary": {
"total_skus": len(skus),
"total_quantity": sum(s.get("quantity", 0) for s in skus),
"total_cost_value": round(total_value, 2),
"total_retail_value": round(total_retail, 2),
"out_of_stock_count": len(out_of_stock),
"low_stock_count": len(low_stock),
"expiring_count": len(expiring),
"healthy_count": len(healthy),
},
"out_of_stock": [{"sku_id": s["sku_id"], "name": s["name"], "warehouse": s.get("warehouse", "")} for s in out_of_stock],
"low_stock": [{
"sku_id": s["sku_id"], "name": s["name"],
"quantity": s["quantity"], "safety_stock": s.get("safety_stock", 0),
"shortage": s.get("safety_stock", 0) - s["quantity"],
"warehouse": s.get("warehouse", ""),
} for s in low_stock],
"expiring": [{
"sku_id": s["sku_id"], "name": s["name"],
"expiry_date": s.get("expiry_date", ""),
"days_left": s["_days_left"],
"status": s["_expiry_status"],
"quantity": s.get("quantity", 0),
} for s in expiring],
})
def action_alerts(args) -> None:
"""获取所有预警信息汇总。"""
inventory = load_inventory()
skus = inventory["skus"]
sub = check_subscription()
alerts = []
for s in skus:
qty = s.get("quantity", 0)
safety = s.get("safety_stock", 0)
# 缺货预警
if qty == 0:
alerts.append({
"level": "critical",
"type": "out_of_stock",
"message": f"【缺货】{s['name']}({s['sku_id']})库存为零",
"sku_id": s["sku_id"],
"name": s["name"],
})
elif qty <= safety:
deficit = safety - qty
alerts.append({
"level": "warning",
"type": "low_stock",
"message": f"【低库存】{s['name']}({s['sku_id']})当前 {qty},低于安全库存 {safety},缺口 {deficit}",
"sku_id": s["sku_id"],
"name": s["name"],
"quantity": qty,
"safety_stock": safety,
})
# 过期预警
days_left = days_until_expiry(s.get("expiry_date"))
if days_left is not None:
if days_left <= 0:
alerts.append({
"level": "critical",
"type": "expired",
"message": f"【已过期】{s['name']}({s['sku_id']})已过期 {abs(days_left)} 天,库存 {qty}",
"sku_id": s["sku_id"],
"name": s["name"],
"days_left": days_left,
})
elif days_left <= 7:
alerts.append({
"level": "critical",
"type": "expiring_soon",
"message": f"【即将过期】{s['name']}({s['sku_id']}){days_left}天后过期,库存 {qty}",
"sku_id": s["sku_id"],
"name": s["name"],
"days_left": days_left,
})
elif days_left <= 30:
alerts.append({
"level": "warning",
"type": "expiring",
"message": f"【过期预警】{s['name']}({s['sku_id']}){days_left}天后过期,库存 {qty}",
"sku_id": s["sku_id"],
"name": s["name"],
"days_left": days_left,
})
elif sub["tier"] == "paid" and days_left <= 90:
alerts.append({
"level": "info",
"type": "expiry_notice",
"message": f"【过期提醒】{s['name']}({s['sku_id']}){days_left}天后过期",
"sku_id": s["sku_id"],
"name": s["name"],
"days_left": days_left,
})
# 按严重程度排序
level_order = {"critical": 0, "warning": 1, "info": 2}
alerts.sort(key=lambda a: level_order.get(a["level"], 9))
output_success({
"action": "alerts",
"tier": sub["tier"],
"total_alerts": len(alerts),
"critical": len([a for a in alerts if a["level"] == "critical"]),
"warning": len([a for a in alerts if a["level"] == "warning"]),
"info": len([a for a in alerts if a["level"] == "info"]),
"alerts": alerts,
})
def action_low_stock(args) -> None:
    """List SKUs at or below their safety stock."""
    inventory = load_inventory()
    skus = inventory["skus"]
    sub = check_subscription()
    is_paid = sub["tier"] == "paid"
    # Load transactions once instead of once per SKU (paid tier only)
    transactions = load_transactions() if is_paid else []
    thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
    low_stock_items = []
    for s in skus:
        qty = s.get("quantity", 0)
        safety = s.get("safety_stock", 0)
        if is_paid:
            # Paid tier: dynamic safety stock from the last 30 days of outbound volume
            recent_outbound = [
                t for t in transactions
                if t.get("type") == "outbound"
                and t.get("sku_id") == s["sku_id"]
                and t.get("timestamp", "") >= thirty_days_ago
            ]
            total_outbound = sum(t.get("quantity", 0) for t in recent_outbound)
            daily_avg = total_outbound / 30.0 if total_outbound > 0 else 0
            dynamic_safety = max(safety, int(daily_avg * 7))  # at least 7 days of demand
            if qty <= dynamic_safety:
                low_stock_items.append({
                    "sku_id": s["sku_id"],
                    "name": s["name"],
                    "category": s.get("category", ""),
                    "quantity": qty,
                    "safety_stock": safety,
                    "dynamic_safety_stock": dynamic_safety,
                    "daily_avg_sales": round(daily_avg, 1),
                    "shortage": dynamic_safety - qty,
                    # None instead of float("inf"), which does not serialize to valid JSON
                    "days_remaining": round(qty / daily_avg, 1) if daily_avg > 0 else None,
                    "warehouse": s.get("warehouse", ""),
                })
        else:
            # Free tier: fixed threshold
            if qty <= safety:
                low_stock_items.append({
                    "sku_id": s["sku_id"],
                    "name": s["name"],
                    "category": s.get("category", ""),
                    "quantity": qty,
                    "safety_stock": safety,
                    "shortage": safety - qty,
                    "warehouse": s.get("warehouse", ""),
                })
    # Largest shortage first
    low_stock_items.sort(key=lambda x: x.get("shortage", 0), reverse=True)
    output_success({
        "action": "low_stock",
        "tier": sub["tier"],
        "total": len(low_stock_items),
        "items": low_stock_items,
    })
def action_expiring(args) -> None:
"""获取即将过期商品列表。"""
inventory = load_inventory()
skus = inventory["skus"]
sub = check_subscription()
data = read_json_input(args)
check_days = 30
if data and data.get("days"):
check_days = int(data["days"])
expiring_items = []
for s in skus:
days_left = days_until_expiry(s.get("expiry_date"))
if days_left is None:
continue
if days_left <= check_days:
status = "正常"
level = "info"
if days_left <= 0:
status = "已过期"
level = "critical"
elif days_left <= 7:
status = "即将过期"
level = "critical"
elif days_left <= 30:
status = "临近过期"
level = "warning"
elif days_left <= 60:
status = "注意"
level = "info"
elif days_left <= 90:
status = "提醒"
level = "info"
item = {
"sku_id": s["sku_id"],
"name": s["name"],
"category": s.get("category", ""),
"quantity": s.get("quantity", 0),
"expiry_date": s.get("expiry_date", ""),
"days_left": days_left,
"status": status,
"level": level,
"warehouse": s.get("warehouse", ""),
"stock_value": round(s.get("quantity", 0) * s.get("unit_cost", 0), 2),
}
expiring_items.append(item)
expiring_items.sort(key=lambda x: x["days_left"])
total_risk_value = sum(i["stock_value"] for i in expiring_items)
output_success({
"action": "expiring",
"tier": sub["tier"],
"check_days": check_days,
"total": len(expiring_items),
"total_risk_value": round(total_risk_value, 2),
"items": expiring_items,
})
def action_overview(args) -> None:
"""生成库存概览报告。"""
inventory = load_inventory()
skus = inventory["skus"]
sub = check_subscription()
if not skus:
output_success({
"action": "overview",
"message": "库存为空",
"total_skus": 0,
})
return
total_qty = sum(s.get("quantity", 0) for s in skus)
total_cost = sum(s.get("quantity", 0) * s.get("unit_cost", 0) for s in skus)
total_retail = sum(s.get("quantity", 0) * s.get("selling_price", 0) for s in skus)
# 按分类统计
categories: Dict[str, Dict[str, Any]] = {}
for s in skus:
cat = s.get("category", "未分类")
if cat not in categories:
categories[cat] = {"count": 0, "quantity": 0, "cost_value": 0, "retail_value": 0}
categories[cat]["count"] += 1
categories[cat]["quantity"] += s.get("quantity", 0)
categories[cat]["cost_value"] += s.get("quantity", 0) * s.get("unit_cost", 0)
categories[cat]["retail_value"] += s.get("quantity", 0) * s.get("selling_price", 0)
# 按仓库统计
warehouses: Dict[str, Dict[str, Any]] = {}
for s in skus:
wh = s.get("warehouse", "默认仓库")
if wh not in warehouses:
warehouses[wh] = {"count": 0, "quantity": 0, "cost_value": 0}
warehouses[wh]["count"] += 1
warehouses[wh]["quantity"] += s.get("quantity", 0)
warehouses[wh]["cost_value"] += s.get("quantity", 0) * s.get("unit_cost", 0)
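    # Low-stock count excludes out-of-stock SKUs, matching action_check,
    # so the two counters do not overlap.
    low_stock_count = sum(
        1 for s in skus
        if s.get("quantity", 0) != 0 and s.get("quantity", 0) <= s.get("safety_stock", 0)
    )
    out_of_stock_count = sum(1 for s in skus if s.get("quantity", 0) == 0)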
# 过期统计
expiring_30 = 0
expired = 0
for s in skus:
dl = days_until_expiry(s.get("expiry_date"))
if dl is not None:
if dl <= 0:
expired += 1
elif dl <= 30:
expiring_30 += 1
# 四舍五入分类/仓库统计值
for cat_data in categories.values():
cat_data["cost_value"] = round(cat_data["cost_value"], 2)
cat_data["retail_value"] = round(cat_data["retail_value"], 2)
for wh_data in warehouses.values():
wh_data["cost_value"] = round(wh_data["cost_value"], 2)
overview = {
"action": "overview",
"tier": sub["tier"],
"date": date.today().isoformat(),
"summary": {
"total_skus": len(skus),
"total_quantity": total_qty,
"total_cost_value": round(total_cost, 2),
"total_retail_value": round(total_retail, 2),
"potential_profit": round(total_retail - total_cost, 2),
"out_of_stock": out_of_stock_count,
"low_stock": low_stock_count,
"expired": expired,
"expiring_30_days": expiring_30,
},
"by_category": categories,
"by_warehouse": warehouses,
"sku_limit": f"{len(skus)}/{sub['max_skus']}",
"warehouse_limit": f"{len(warehouses)}/{sub['max_warehouses']}",
}
output_success(overview)
# ============================================================
# 主入口
# ============================================================
def main():
parser = argparse.ArgumentParser(
description="库存慧眼 — 库存监控与预警",
)
parser.add_argument(
"--action",
required=True,
choices=["check", "alerts", "low-stock", "expiring", "overview"],
help="监控操作类型",
)
parser.add_argument("--data", default=None, help="JSON 格式的参数")
parser.add_argument("--data-file", default=None, help="JSON 参数文件路径")
args = parser.parse_args()
action_map = {
"check": action_check,
"alerts": action_alerts,
"low-stock": action_low_stock,
"expiring": action_expiring,
"overview": action_overview,
}
try:
action_map[args.action](args)
except Exception as e:
output_error(f"监控操作失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/turnover_analyzer.py
#!/usr/bin/env python3
"""
inventory-eye 库存周转率与滞销分析模块(付费功能)
提供库存周转率计算、滞销品识别、库存分析报告生成等功能。
"""
import argparse
import json
import math
import os
import sys
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription, require_paid, format_number, format_percentage,
format_chinese_unit, days_until_expiry,
load_inventory, load_transactions, read_json_input,
output_json, output_error, output_success,
)
# ============================================================
# 周转率计算
# ============================================================
def _calculate_turnover(skus: List[Dict], transactions: List[Dict], days: int) -> List[Dict[str, Any]]:
    """Compute the turnover rate of each SKU.

    turnover rate = outbound cost over the period (COGS) / average inventory cost
    turnover days = days in the period / turnover rate

    Args:
        skus: List of SKU records.
        transactions: List of transaction records.
        days: Length of the analysis period in days.

    Returns:
        One turnover record per SKU, sorted by ascending turnover rate.
    """
    cutoff = (date.today() - timedelta(days=days)).isoformat()
    results = []
    for s in skus:
        sku_id = s["sku_id"]
        unit_cost = s.get("unit_cost", 0)
        current_qty = s.get("quantity", 0)
        # Total outbound quantity within the period
        outbound_txs = [
            t for t in transactions
            if t.get("type") == "outbound"
            and t.get("sku_id") == sku_id
            and t.get("timestamp", "") >= cutoff
        ]
        total_outbound = sum(t.get("quantity", 0) for t in outbound_txs)
        # Total inbound quantity within the period
        inbound_txs = [
            t for t in transactions
            if t.get("type") == "inbound"
            and t.get("sku_id") == sku_id
            and t.get("timestamp", "") >= cutoff
        ]
        total_inbound = sum(t.get("quantity", 0) for t in inbound_txs)
        # Estimated opening stock = current stock + outbound - inbound, floored at zero
        beginning_qty = current_qty + total_outbound - total_inbound
        beginning_qty = max(0, beginning_qty)
        # Average inventory over the period
        avg_qty = (beginning_qty + current_qty) / 2.0
        avg_inventory_cost = avg_qty * unit_cost
        # COGS (cost of goods shipped out during the period)
        cogs = total_outbound * unit_cost
        # Turnover rate; turnover days stay infinite when nothing was shipped
        turnover_rate = 0.0
        turnover_days = float("inf")
        if avg_inventory_cost > 0:
            turnover_rate = cogs / avg_inventory_cost
            if turnover_rate > 0:
                turnover_days = days / turnover_rate
        # Average daily sales
        daily_avg = total_outbound / days if days > 0 else 0
        results.append({
            "sku_id": sku_id,
            "name": s.get("name", ""),
            "category": s.get("category", ""),
            "warehouse": s.get("warehouse", ""),
            "current_quantity": current_qty,
            "unit_cost": unit_cost,
            "period_outbound": total_outbound,
            "period_inbound": total_inbound,
            "cogs": round(cogs, 2),
            "avg_inventory_cost": round(avg_inventory_cost, 2),
            "turnover_rate": round(turnover_rate, 2),
            "turnover_days": round(turnover_days, 1) if turnover_days != float("inf") else None,
            "daily_avg_sales": round(daily_avg, 2),
        })
    # Sort by turnover rate, lowest first (these need the most attention)
    results.sort(key=lambda x: x["turnover_rate"])
    return results
def action_turnover(args) -> None:
"""计算库存周转率。"""
require_paid("库存周转分析")
inventory = load_inventory()
skus = inventory["skus"]
transactions = load_transactions()
data = read_json_input(args)
days = 30
if data and data.get("days"):
days = int(data["days"])
if not skus:
output_success({"action": "turnover", "message": "库存为空", "items": []})
return
results = _calculate_turnover(skus, transactions, days)
# 汇总统计
total_cogs = sum(r["cogs"] for r in results)
total_avg_inv = sum(r["avg_inventory_cost"] for r in results)
overall_turnover = total_cogs / total_avg_inv if total_avg_inv > 0 else 0
overall_days = days / overall_turnover if overall_turnover > 0 else None
output_success({
"action": "turnover",
"period_days": days,
"total_skus": len(results),
"overall": {
"total_cogs": round(total_cogs, 2),
"total_avg_inventory": round(total_avg_inv, 2),
"turnover_rate": round(overall_turnover, 2),
"turnover_days": round(overall_days, 1) if overall_days else None,
},
"items": results,
})
# ============================================================
# 滞销品分析
# ============================================================
def action_slow_moving(args) -> None:
    """Identify slow-moving items."""
    require_paid("滞销品识别")
    inventory = load_inventory()
    skus = inventory["skus"]
    transactions = load_transactions()
    data = read_json_input(args)
    threshold_days = 30  # by default, 30 days without an outbound movement counts as slow-moving
    if data and data.get("days"):
        threshold_days = int(data["days"])
    today = date.today()
    slow_items = []
    for s in skus:
        sku_id = s["sku_id"]
        qty = s.get("quantity", 0)
        if qty == 0:
            continue  # items with no stock are not counted as slow-moving
        # Find the most recent outbound transaction
        outbound_txs = [
            t for t in transactions
            if t.get("type") == "outbound" and t.get("sku_id") == sku_id
        ]
        outbound_txs.sort(key=lambda t: t.get("timestamp", ""), reverse=True)
        if outbound_txs:
            last_outbound_str = outbound_txs[0].get("timestamp", "")[:10]
            try:
                last_outbound = datetime.strptime(last_outbound_str, "%Y-%m-%d").date()
                idle_days = (today - last_outbound).days
            except ValueError:
                # Unparseable date: treat the item as just past the threshold
                idle_days = threshold_days + 1
        else:
            # Never shipped: fall back to the last inbound date or the creation date
            last_date_str = s.get("last_inbound_date") or s.get("created_at", "")[:10]
            try:
                last_dt = datetime.strptime(last_date_str, "%Y-%m-%d").date()
                idle_days = (today - last_dt).days
            except ValueError:
                idle_days = threshold_days + 1
        if idle_days >= threshold_days:
            stock_value = qty * s.get("unit_cost", 0)
            retail_value = qty * s.get("selling_price", 0)
            # Default suggestion applies when the threshold is below 30 days
            suggestion = "建议促销清仓"
            if idle_days >= 90:
                suggestion = "严重滞销,建议大幅折扣或退回供应商"
            elif idle_days >= 60:
                suggestion = "中度滞销,建议打折促销"
            elif idle_days >= 30:
                suggestion = "轻度滞销,建议搭配销售或促销"
            slow_items.append({
                "sku_id": sku_id,
                "name": s.get("name", ""),
                "category": s.get("category", ""),
                "warehouse": s.get("warehouse", ""),
                "quantity": qty,
                "unit_cost": s.get("unit_cost", 0),
                "stock_value": round(stock_value, 2),
                "retail_value": round(retail_value, 2),
                "idle_days": idle_days,
                "last_outbound": outbound_txs[0].get("timestamp", "")[:10] if outbound_txs else "无记录",
                "suggestion": suggestion,
            })
    # Sort by idle days, most severe first
    slow_items.sort(key=lambda x: x["idle_days"], reverse=True)
    total_frozen_value = sum(i["stock_value"] for i in slow_items)
    output_success({
        "action": "slow_moving",
        "threshold_days": threshold_days,
        "total": len(slow_items),
        "total_frozen_value": round(total_frozen_value, 2),
        "items": slow_items,
    })
# ============================================================
# 综合报告
# ============================================================
def action_report(args) -> None:
"""生成库存周转分析报告(Markdown 格式)。"""
require_paid("库存周转分析报告")
inventory = load_inventory()
skus = inventory["skus"]
transactions = load_transactions()
data = read_json_input(args)
days = 30
if data and data.get("days"):
days = int(data["days"])
if not skus:
output_success({"action": "report", "message": "库存为空", "report": ""})
return
turnover_data = _calculate_turnover(skus, transactions, days)
# 汇总
total_cogs = sum(r["cogs"] for r in turnover_data)
total_avg_inv = sum(r["avg_inventory_cost"] for r in turnover_data)
overall_rate = total_cogs / total_avg_inv if total_avg_inv > 0 else 0
overall_days_val = days / overall_rate if overall_rate > 0 else None
# 按分类汇总周转率
cat_data: Dict[str, Dict[str, float]] = {}
for r in turnover_data:
cat = r.get("category", "未分类")
if cat not in cat_data:
cat_data[cat] = {"cogs": 0, "avg_inv": 0, "count": 0}
cat_data[cat]["cogs"] += r["cogs"]
cat_data[cat]["avg_inv"] += r["avg_inventory_cost"]
cat_data[cat]["count"] += 1
# 构建 Markdown 报告
report_lines = [
f"# 库存周转分析报告",
f"",
f"**分析周期**: 近 {days} 天 | **报告日期**: {date.today().isoformat()}",
f"",
f"---",
f"",
f"## 整体概况",
f"",
f"| 指标 | 数值 |",
f"|------|------|",
f"| SKU 总数 | {len(skus)} |",
f"| 期间总出库成本 | ¥{format_number(total_cogs)} |",
f"| 平均库存成本 | ¥{format_number(total_avg_inv)} |",
f"| 整体周转率 | {overall_rate:.2f} 次 |",
f"| 整体周转天数 | {f'{overall_days_val:.1f} 天' if overall_days_val else 'N/A'} |",
f"",
]
# 分类周转率表格
report_lines.extend([
f"## 分类周转率",
f"",
f"| 分类 | SKU数 | 出库成本 | 平均库存 | 周转率 |",
f"|------|-------|---------|---------|--------|",
])
for cat, cd in sorted(cat_data.items(), key=lambda x: x[1]["cogs"], reverse=True):
cat_rate = cd["cogs"] / cd["avg_inv"] if cd["avg_inv"] > 0 else 0
report_lines.append(
f"| {cat} | {cd['count']} | ¥{format_number(cd['cogs'])} | ¥{format_number(cd['avg_inv'])} | {cat_rate:.2f} |"
)
# Mermaid 分类周转率图表
report_lines.extend([
f"",
f"## 分类周转率图表",
f"",
f"```mermaid",
f"xychart-beta",
f' title "各分类周转率(近{days}天)"',
f" x-axis [{', '.join(json.dumps(c, ensure_ascii=False) for c in cat_data.keys())}]",
f" y-axis \"周转率(次)\"",
" bar [{}]".format(", ".join("{:.2f}".format(cd["cogs"] / cd["avg_inv"] if cd["avg_inv"] > 0 else 0) for cd in cat_data.values())),
f"```",
f"",
])
# Top 10 高周转 SKU
high_turnover = sorted(turnover_data, key=lambda x: x["turnover_rate"], reverse=True)[:10]
report_lines.extend([
f"## Top 10 高周转 SKU",
f"",
f"| 排名 | SKU | 名称 | 周转率 | 日均销量 |",
f"|------|-----|------|--------|---------|",
])
for i, r in enumerate(high_turnover, 1):
report_lines.append(
f"| {i} | {r['sku_id']} | {r['name']} | {r['turnover_rate']:.2f} | {r['daily_avg_sales']:.1f} |"
)
# Top 10 低周转 SKU
low_turnover = [r for r in turnover_data if r["current_quantity"] > 0][:10]
report_lines.extend([
f"",
f"## Top 10 低周转 SKU(需关注)",
f"",
f"| 排名 | SKU | 名称 | 周转率 | 库存量 | 库存金额 |",
f"|------|-----|------|--------|--------|---------|",
])
for i, r in enumerate(low_turnover, 1):
stock_val = r["current_quantity"] * r["unit_cost"]
report_lines.append(
f"| {i} | {r['sku_id']} | {r['name']} | {r['turnover_rate']:.2f} | {r['current_quantity']} | ¥{format_number(stock_val)} |"
)
report_lines.extend([
f"",
f"---",
f"",
f"> 📊 由 库存慧眼(inventory-eye)自动生成 | {datetime.now().strftime('%Y-%m-%d %H:%M')}",
])
report = "\n".join(report_lines)
output_success({
"action": "report",
"period_days": days,
"report": report,
})
# ============================================================
# 主入口
# ============================================================
def main():
parser = argparse.ArgumentParser(
description="库存慧眼 — 周转率与滞销分析(付费功能)",
)
parser.add_argument(
"--action",
required=True,
choices=["turnover", "slow-moving", "report"],
help="分析操作类型",
)
parser.add_argument("--data", default=None, help="JSON 格式的参数")
parser.add_argument("--data-file", default=None, help="JSON 参数文件路径")
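    parser.add_argument("--days", type=int, default=None, help="分析周期天数(默认30天)")
    args = parser.parse_args()
    # Merge an explicitly supplied --days into the JSON payload. When the flag is
    # omitted, the action functions fall back to their own 30-day default, so an
    # explicit "--days 30" is no longer silently ignored.
    if args.days is not None:
        if not args.data:
            args.data = json.dumps({"days": args.days})
        else:
            try:
                d = json.loads(args.data)
                d["days"] = args.days
                args.data = json.dumps(d)
            except (json.JSONDecodeError, TypeError):
                # Leave --data untouched if it is not a JSON object
                pass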
action_map = {
"turnover": action_turnover,
"slow-moving": action_slow_moving,
"report": action_report,
}
try:
action_map[args.action](args)
except SystemExit:
pass
except Exception as e:
output_error(f"分析操作失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
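
The turnover arithmetic described in the `_calculate_turnover` docstring can be checked by hand with a small standalone sketch. The helper name `turnover_metrics` is hypothetical and introduced only for illustration; it assumes a single SKU with a constant unit cost and reads no inventory or transaction files.

```python
from typing import Optional, Tuple


def turnover_metrics(
    current_qty: int,
    outbound_qty: int,
    inbound_qty: int,
    unit_cost: float,
    days: int,
) -> Tuple[float, Optional[float]]:
    """Return (turnover_rate, turnover_days) for one SKU (illustrative helper)."""
    # Estimated opening stock: current + outbound - inbound, floored at zero.
    beginning_qty = max(0, current_qty + outbound_qty - inbound_qty)
    # Average inventory cost over the period.
    avg_inventory_cost = (beginning_qty + current_qty) / 2.0 * unit_cost
    # Cost of goods shipped out during the period.
    cogs = outbound_qty * unit_cost
    if avg_inventory_cost <= 0 or cogs <= 0:
        # Nothing in stock or nothing shipped: no meaningful turnover days.
        return 0.0, None
    rate = cogs / avg_inventory_cost
    return rate, days / rate


if __name__ == "__main__":
    rate, turnover_days = turnover_metrics(
        current_qty=50, outbound_qty=100, inbound_qty=50, unit_cost=2.0, days=30
    )
    print(round(rate, 2), round(turnover_days, 1))
```

With 50 units on hand, 100 shipped and 50 received, the estimated opening stock is 100 units, so the average inventory is 75 units; a SKU that ships nothing gets a rate of 0 and no turnover-days figure, matching the `None` the module emits.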
FILE:scripts/inventory_store.py
#!/usr/bin/env python3
"""
inventory-eye inventory data management module.
Provides import, create/read/update/delete, and export of inventory data, with CSV import and JSON storage.
"""
import argparse
import csv
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
# 将 scripts 目录加入路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription, check_sku_limit, check_warehouse_limit,
format_sku, parse_date, parse_csv_columns, read_csv_file,
load_inventory, save_inventory, add_transaction,
read_json_input, output_json, output_error, output_success,
)
# ============================================================
# 库存操作
# ============================================================
def _validate_sku(sku: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and normalize SKU data.

    Args:
        sku: Raw SKU data dictionary.

    Returns:
        The normalized SKU dictionary.

    Raises:
        ValueError: If a required field is missing.
    """
    if not sku.get("sku_id"):
        raise ValueError("SKU 编码(sku_id)为必填字段")
    if not sku.get("name"):
        raise ValueError("商品名称(name)为必填字段")
    now = datetime.now().isoformat()
    return {
        "sku_id": format_sku(sku["sku_id"]),
        "name": str(sku["name"]).strip(),
        "category": str(sku.get("category", "")).strip() or "未分类",
        # int(float(...)) accepts string values such as "12.0"; negatives are clamped to 0
        "quantity": max(0, int(float(sku.get("quantity", 0)))),
        "unit_cost": round(float(sku.get("unit_cost", 0)), 2),
        "selling_price": round(float(sku.get("selling_price", 0)), 2),
        "safety_stock": max(0, int(float(sku.get("safety_stock", 10)))),
        "warehouse": str(sku.get("warehouse", "默认仓库")).strip() or "默认仓库",
        "expiry_date": parse_date(sku.get("expiry_date")) or "",
        "last_inbound_date": parse_date(sku.get("last_inbound_date")) or "",
        "last_outbound_date": parse_date(sku.get("last_outbound_date")) or "",
        "created_at": sku.get("created_at", now),
        "updated_at": now,
    }
def action_import(args) -> None:
"""从 CSV 文件导入库存数据。"""
filepath = args.file
if not filepath:
output_error("请通过 --file 参数指定 CSV 文件路径", "MISSING_FILE")
return
if not os.path.exists(filepath):
output_error(f"文件不存在: {filepath}", "FILE_NOT_FOUND")
return
sub = check_subscription()
inventory = load_inventory()
existing_skus = {s["sku_id"] for s in inventory["skus"]}
try:
rows = read_csv_file(filepath)
except ValueError as e:
output_error(str(e), "CSV_READ_ERROR")
return
if not rows:
output_error("CSV 文件为空或无有效数据", "EMPTY_FILE")
return
# 自动映射列名
header = list(rows[0].keys())
mapping = parse_csv_columns(header)
imported = 0
updated = 0
skipped = 0
errors = []
for i, row in enumerate(rows, 1):
try:
sku_data = {}
for field, col in mapping.items():
if col and col in row:
sku_data[field] = row[col]
if not sku_data.get("sku_id") and not sku_data.get("name"):
skipped += 1
continue
# 如果没有 sku_id,用行号生成
if not sku_data.get("sku_id"):
sku_data["sku_id"] = f"SKU-{i:05d}"
validated = _validate_sku(sku_data)
sku_id = validated["sku_id"]
# 检查仓库限制
current_warehouses = list({s["warehouse"] for s in inventory["skus"]})
if validated["warehouse"] not in current_warehouses:
if not check_warehouse_limit(current_warehouses + [validated["warehouse"]]):
errors.append(f"第{i}行: 仓库数量超出{sub['tier']}版限制(最多{sub['max_warehouses']}个)")
skipped += 1
continue
if sku_id in existing_skus:
# 更新已有 SKU
for j, s in enumerate(inventory["skus"]):
if s["sku_id"] == sku_id:
validated["created_at"] = s["created_at"]
inventory["skus"][j] = validated
break
updated += 1
else:
# 检查 SKU 数量限制
if not check_sku_limit(len(inventory["skus"])):
errors.append(f"第{i}行: SKU 数量已达{sub['tier']}版上限({sub['max_skus']}个)")
skipped += 1
continue
inventory["skus"].append(validated)
existing_skus.add(sku_id)
imported += 1
except (ValueError, TypeError) as e:
errors.append(f"第{i}行: {e}")
skipped += 1
save_inventory(inventory)
result = {
"action": "import",
"file": filepath,
"column_mapping": {k: v for k, v in mapping.items() if v},
"imported": imported,
"updated": updated,
"skipped": skipped,
"total_skus": len(inventory["skus"]),
}
if errors:
result["errors"] = errors[:20] # 最多显示 20 条错误
output_success(result)
def action_add(args) -> None:
"""添加单个 SKU。"""
data = read_json_input(args)
if not data:
output_error("请通过 --data 或 --data-file 提供 SKU 数据", "MISSING_DATA")
return
sub = check_subscription()
inventory = load_inventory()
try:
validated = _validate_sku(data)
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
return
# 检查 SKU 是否已存在
for s in inventory["skus"]:
if s["sku_id"] == validated["sku_id"]:
output_error(f"SKU {validated['sku_id']} 已存在,请使用 update 操作", "DUPLICATE_SKU")
return
# 检查 SKU 限制
if not check_sku_limit(len(inventory["skus"])):
output_error(
f"SKU 数量已达{sub['tier']}版上限({sub['max_skus']}个),请升级至付费版(¥89/月)",
"SKU_LIMIT_EXCEEDED",
)
return
# 检查仓库限制
current_warehouses = list({s["warehouse"] for s in inventory["skus"]})
if validated["warehouse"] not in current_warehouses:
if not check_warehouse_limit(current_warehouses + [validated["warehouse"]]):
output_error(
f"仓库数量已达{sub['tier']}版上限({sub['max_warehouses']}个),请升级至付费版(¥89/月)",
"WAREHOUSE_LIMIT_EXCEEDED",
)
return
inventory["skus"].append(validated)
save_inventory(inventory)
output_success({
"action": "add",
"sku": validated,
"total_skus": len(inventory["skus"]),
})
def action_update(args) -> None:
    """Update fields of an existing SKU."""
    data = read_json_input(args)
    if not data:
        output_error("请通过 --data 或 --data-file 提供更新数据", "MISSING_DATA")
        return
    sku_id = format_sku(data.get("sku_id", ""))
    if not sku_id:
        output_error("更新操作需要提供 sku_id", "MISSING_SKU_ID")
        return
    inventory = load_inventory()
    # Locate the target record explicitly instead of relying on a leaked loop index
    target = next((s for s in inventory["skus"] if s["sku_id"] == sku_id), None)
    if target is None:
        output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
        return
    # Merge the supplied fields into the record, coercing the known ones
    for key, val in data.items():
        if key in ("created_at",):
            continue  # the creation timestamp is never overwritten
        if key == "sku_id":
            target[key] = format_sku(val)
        elif key == "quantity":
            target[key] = max(0, int(float(val)))
        elif key in ("unit_cost", "selling_price"):
            target[key] = round(float(val), 2)
        elif key == "safety_stock":
            target[key] = max(0, int(float(val)))
        elif key in ("expiry_date", "last_inbound_date", "last_outbound_date"):
            # An unparseable date leaves the stored value unchanged
            parsed = parse_date(val)
            if parsed:
                target[key] = parsed
        else:
            target[key] = val
    target["updated_at"] = datetime.now().isoformat()
    save_inventory(inventory)
    output_success({
        "action": "update",
        "sku": target,
    })
def action_delete(args) -> None:
"""删除指定 SKU。"""
data = read_json_input(args)
if not data:
output_error("请通过 --data 提供要删除的 sku_id", "MISSING_DATA")
return
sku_id = format_sku(data.get("sku_id", ""))
if not sku_id:
output_error("删除操作需要提供 sku_id", "MISSING_SKU_ID")
return
inventory = load_inventory()
original_count = len(inventory["skus"])
inventory["skus"] = [s for s in inventory["skus"] if s["sku_id"] != sku_id]
if len(inventory["skus"]) == original_count:
output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
return
save_inventory(inventory)
output_success({
"action": "delete",
"sku_id": sku_id,
"remaining_skus": len(inventory["skus"]),
})
def action_list(args) -> None:
"""列出所有库存 SKU。"""
inventory = load_inventory()
skus = inventory["skus"]
# 支持按仓库、分类筛选
data = read_json_input(args)
if data:
if data.get("warehouse"):
skus = [s for s in skus if s.get("warehouse") == data["warehouse"]]
if data.get("category"):
skus = [s for s in skus if s.get("category") == data["category"]]
sub = check_subscription()
output_success({
"action": "list",
"tier": sub["tier"],
"total": len(skus),
"max_skus": sub["max_skus"],
"skus": skus,
})
def action_get(args) -> None:
"""获取单个 SKU 详情。"""
data = read_json_input(args)
if not data:
output_error("请通过 --data 提供 sku_id", "MISSING_DATA")
return
sku_id = format_sku(data.get("sku_id", ""))
if not sku_id:
output_error("需要提供 sku_id", "MISSING_SKU_ID")
return
inventory = load_inventory()
for s in inventory["skus"]:
if s["sku_id"] == sku_id:
output_success({"action": "get", "sku": s})
return
output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
def action_export(args) -> None:
"""将库存数据导出为 CSV 文件。"""
inventory = load_inventory()
skus = inventory["skus"]
if not skus:
output_error("库存为空,无数据可导出", "EMPTY_INVENTORY")
return
data = read_json_input(args)
output_path = None
if data and data.get("output"):
output_path = data["output"]
elif args.file:
output_path = args.file
if not output_path:
output_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..", "output",
f"inventory_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
)
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
fieldnames = [
"sku_id", "name", "category", "quantity", "unit_cost", "selling_price",
"safety_stock", "warehouse", "expiry_date", "last_inbound_date",
"last_outbound_date", "created_at", "updated_at",
]
with open(output_path, "w", encoding="utf-8-sig", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
writer.writeheader()
writer.writerows(skus)
output_success({
"action": "export",
"file": os.path.abspath(output_path),
"total_skus": len(skus),
})
def action_inbound(args) -> None:
    """Inbound operation: increase the stock quantity of the given SKU."""
    data = read_json_input(args)
    if not data:
        output_error("请提供入库数据(sku_id, quantity)", "MISSING_DATA")
        return
    sku_id = format_sku(data.get("sku_id", ""))
    quantity = int(float(data.get("quantity", 0)))
    note = data.get("note", "")
    if not sku_id:
        output_error("入库操作需要提供 sku_id", "MISSING_SKU_ID")
        return
    if quantity <= 0:
        output_error("入库数量必须大于 0", "INVALID_QUANTITY")
        return
    inventory = load_inventory()
    for s in inventory["skus"]:
        if s["sku_id"] == sku_id:
            s["quantity"] += quantity
            s["last_inbound_date"] = datetime.now().strftime("%Y-%m-%d")
            s["updated_at"] = datetime.now().isoformat()
            add_transaction("inbound", sku_id, quantity, note)
            save_inventory(inventory)
            output_success({
                "action": "inbound",
                "sku_id": sku_id,
                "added": quantity,
                "new_quantity": s["quantity"],
            })
            return
    # Reached only when no SKU matched
    output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
def action_outbound(args) -> None:
"""出库操作:减少指定 SKU 的库存数量。"""
data = read_json_input(args)
if not data:
output_error("请提供出库数据(sku_id, quantity)", "MISSING_DATA")
return
sku_id = format_sku(data.get("sku_id", ""))
quantity = int(float(data.get("quantity", 0)))
note = data.get("note", "")
if not sku_id:
output_error("出库操作需要提供 sku_id", "MISSING_SKU_ID")
return
if quantity <= 0:
output_error("出库数量必须大于 0", "INVALID_QUANTITY")
return
inventory = load_inventory()
for s in inventory["skus"]:
if s["sku_id"] == sku_id:
if s["quantity"] < quantity:
output_error(
f"库存不足: {s['name']}({sku_id})当前库存 {s['quantity']},请求出库 {quantity}",
"INSUFFICIENT_STOCK",
)
return
s["quantity"] -= quantity
s["last_outbound_date"] = datetime.now().strftime("%Y-%m-%d")
s["updated_at"] = datetime.now().isoformat()
add_transaction("outbound", sku_id, quantity, note)
save_inventory(inventory)
output_success({
"action": "outbound",
"sku_id": sku_id,
"removed": quantity,
"new_quantity": s["quantity"],
})
return
output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
# ============================================================
# 主入口
# ============================================================
def main():
parser = argparse.ArgumentParser(
description="库存慧眼 — 库存数据管理工具",
)
parser.add_argument(
"--action",
required=True,
choices=["import", "add", "update", "delete", "list", "get", "export",
"inbound", "outbound"],
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据参数",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
parser.add_argument(
"--file",
default=None,
help="CSV 文件路径(用于导入/导出)",
)
args = parser.parse_args()
actions = {
"import": action_import,
"add": action_add,
"update": action_update,
"delete": action_delete,
"list": action_list,
"get": action_get,
"export": action_export,
"inbound": action_inbound,
"outbound": action_outbound,
}
try:
actions[args.action](args)
except Exception as e:
output_error(f"操作失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
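
`_validate_sku` and `action_update` coerce quantities with `max(0, int(float(...)))`. A plausible reason, stated here as an assumption, is that CSV cells arrive as strings such as `"12.0"`, which `int()` alone rejects. The sketch below isolates that coercion; `coerce_quantity` is a hypothetical helper name and is not part of the module.

```python
def coerce_quantity(raw) -> int:
    """Convert a CSV or JSON quantity value to a non-negative integer."""
    # float() accepts strings such as "12.0"; int() then truncates toward zero,
    # and max() clamps negative values to 0.
    return max(0, int(float(raw)))


if __name__ == "__main__":
    for raw in ("12.0", "7", 3.9, "-4"):
        print(repr(raw), "->", coerce_quantity(raw))
```

Note that fractional quantities are truncated rather than rounded, and a non-numeric cell still raises `ValueError`, which `action_import` records as a per-row error.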
FILE:scripts/reorder_calculator.py
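#!/usr/bin/env python3
"""
inventory-eye reorder quantity calculation module (paid feature).
Computes reorder quantities from historical sales data and produces replenishment suggestions.
Reorder formula: reorder_qty = (daily_avg_sales × lead_time × safety_factor) - current_stock
"""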
import argparse
import json
import math
import os
import sys
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription, require_paid, format_number,
load_inventory, load_transactions, read_json_input,
output_json, output_error, output_success,
)
# ============================================================
# 默认参数
# ============================================================
DEFAULT_LEAD_TIME = 7 # 默认供货周期(天)
DEFAULT_SAFETY_FACTOR = 1.5 # 默认安全系数
DEFAULT_ANALYSIS_DAYS = 30 # 默认分析天数
MIN_REORDER_QTY = 1 # 最小补货量
# ============================================================
# 补货计算
# ============================================================
def _calc_daily_avg_sales(sku_id: str, transactions: List[Dict], days: int) -> float:
    """Compute the average daily sales of the given SKU.

    Args:
        sku_id: SKU code.
        transactions: List of transaction records.
        days: Number of days to analyze.

    Returns:
        Average daily sales quantity.
    """
    cutoff = (date.today() - timedelta(days=days)).isoformat()
    outbound = [
        t for t in transactions
        if t.get("type") == "outbound"
        and t.get("sku_id") == sku_id
        and t.get("timestamp", "") >= cutoff
    ]
    total = sum(t.get("quantity", 0) for t in outbound)
    return total / days if days > 0 else 0
def _calculate_reorder(
    sku: Dict[str, Any],
    transactions: List[Dict],
    lead_time: int = DEFAULT_LEAD_TIME,
    safety_factor: float = DEFAULT_SAFETY_FACTOR,
    analysis_days: int = DEFAULT_ANALYSIS_DAYS,
) -> Optional[Dict[str, Any]]:
    """Compute the reorder suggestion for a single SKU.

    Formula: reorder_qty = (daily_avg_sales × lead_time × safety_factor) - current_stock

    Args:
        sku: SKU data.
        transactions: Transaction records.
        lead_time: Supplier lead time in days.
        safety_factor: Safety factor.
        analysis_days: Number of historical days to analyze.

    Returns:
        A reorder suggestion dictionary, or None if no reorder is needed.
    """
    sku_id = sku["sku_id"]
    current_stock = sku.get("quantity", 0)
    safety_stock = sku.get("safety_stock", 0)
    unit_cost = sku.get("unit_cost", 0)
    daily_avg = _calc_daily_avg_sales(sku_id, transactions, analysis_days)
    # Demand = daily average × lead time × safety factor
    demand = daily_avg * lead_time * safety_factor
    # Reorder quantity = demand - current stock
    reorder_qty = math.ceil(demand - current_stock)
    # If stock is also below the safety level, use the larger shortfall
    safety_deficit = safety_stock - current_stock
    if safety_deficit > reorder_qty:
        reorder_qty = safety_deficit
    if reorder_qty < MIN_REORDER_QTY:
        return None  # no reorder needed
    # Estimated days of stock remaining (infinite when there were no sales)
    days_remaining = current_stock / daily_avg if daily_avg > 0 else float("inf")
    # Urgency level
    urgency = "normal"
    if current_stock == 0:
        urgency = "critical"
    elif days_remaining <= lead_time:
        urgency = "urgent"
    elif current_stock <= safety_stock:
        urgency = "warning"
    return {
        "sku_id": sku_id,
        "name": sku.get("name", ""),
        "category": sku.get("category", ""),
        "warehouse": sku.get("warehouse", ""),
        "current_stock": current_stock,
        "safety_stock": safety_stock,
        "daily_avg_sales": round(daily_avg, 2),
        "lead_time": lead_time,
        "safety_factor": safety_factor,
        "reorder_qty": reorder_qty,
        "reorder_cost": round(reorder_qty * unit_cost, 2),
        "unit_cost": unit_cost,
        "days_remaining": round(days_remaining, 1) if days_remaining != float("inf") else None,
        "urgency": urgency,
    }
def action_calculate(args) -> None:
"""计算所有需要补货的 SKU。"""
require_paid("AI补货计算")
inventory = load_inventory()
skus = inventory["skus"]
transactions = load_transactions()
data = read_json_input(args)
lead_time = DEFAULT_LEAD_TIME
safety_factor = DEFAULT_SAFETY_FACTOR
analysis_days = DEFAULT_ANALYSIS_DAYS
if data:
lead_time = int(data.get("lead_time", lead_time))
safety_factor = float(data.get("safety_factor", safety_factor))
analysis_days = int(data.get("days", analysis_days))
if not skus:
output_success({"action": "calculate", "message": "库存为空", "items": []})
return
reorder_items = []
for sku in skus:
result = _calculate_reorder(sku, transactions, lead_time, safety_factor, analysis_days)
if result:
reorder_items.append(result)
# 按紧急程度排序
urgency_order = {"critical": 0, "urgent": 1, "warning": 2, "normal": 3}
reorder_items.sort(key=lambda x: (urgency_order.get(x["urgency"], 9), -x["reorder_qty"]))
total_cost = sum(i["reorder_cost"] for i in reorder_items)
output_success({
"action": "calculate",
"parameters": {
"lead_time": lead_time,
"safety_factor": safety_factor,
"analysis_days": analysis_days,
},
"total_items": len(reorder_items),
"total_reorder_cost": round(total_cost, 2),
"critical": len([i for i in reorder_items if i["urgency"] == "critical"]),
"urgent": len([i for i in reorder_items if i["urgency"] == "urgent"]),
"warning": len([i for i in reorder_items if i["urgency"] == "warning"]),
"items": reorder_items,
})
def action_suggest(args) -> None:
"""为指定 SKU 生成详细补货建议。"""
require_paid("AI补货建议")
data = read_json_input(args)
if not data or not data.get("sku_id"):
output_error("请通过 --data 提供 sku_id", "MISSING_DATA")
return
from utils import format_sku
sku_id = format_sku(data["sku_id"])
inventory = load_inventory()
transactions = load_transactions()
target_sku = None
for s in inventory["skus"]:
if s["sku_id"] == sku_id:
target_sku = s
break
if not target_sku:
output_error(f"未找到 SKU: {sku_id}", "SKU_NOT_FOUND")
return
lead_time = int(data.get("lead_time", DEFAULT_LEAD_TIME))
safety_factor = float(data.get("safety_factor", DEFAULT_SAFETY_FACTOR))
# 多周期分析
suggestions = []
for period in [7, 14, 30, 60, 90]:
daily_avg = _calc_daily_avg_sales(sku_id, transactions, period)
demand = daily_avg * lead_time * safety_factor
reorder = math.ceil(max(demand - target_sku.get("quantity", 0), 0))
suggestions.append({
"analysis_period": f"近{period}天",
"daily_avg_sales": round(daily_avg, 2),
"suggested_reorder": reorder,
"reorder_cost": round(reorder * target_sku.get("unit_cost", 0), 2),
})
# 推荐值取30天的
recommended = _calculate_reorder(target_sku, transactions, lead_time, safety_factor, 30)
output_success({
"action": "suggest",
"sku": {
"sku_id": target_sku["sku_id"],
"name": target_sku.get("name", ""),
"current_stock": target_sku.get("quantity", 0),
"safety_stock": target_sku.get("safety_stock", 0),
"unit_cost": target_sku.get("unit_cost", 0),
},
"parameters": {
"lead_time": lead_time,
"safety_factor": safety_factor,
},
"multi_period_analysis": suggestions,
"recommended": recommended,
})
def action_report(args) -> None:
"""生成补货建议报告(Markdown 格式)。"""
require_paid("补货建议报告")
inventory = load_inventory()
skus = inventory["skus"]
transactions = load_transactions()
data = read_json_input(args)
lead_time = DEFAULT_LEAD_TIME
safety_factor = DEFAULT_SAFETY_FACTOR
if data:
lead_time = int(data.get("lead_time", lead_time))
safety_factor = float(data.get("safety_factor", safety_factor))
if not skus:
output_success({"action": "report", "report": "库存为空,无数据可生成报告。"})
return
# 计算所有补货建议
reorder_items = []
for sku in skus:
result = _calculate_reorder(sku, transactions, lead_time, safety_factor, DEFAULT_ANALYSIS_DAYS)
if result:
reorder_items.append(result)
urgency_order = {"critical": 0, "urgent": 1, "warning": 2, "normal": 3}
reorder_items.sort(key=lambda x: (urgency_order.get(x["urgency"], 9), -x["reorder_qty"]))
total_cost = sum(i["reorder_cost"] for i in reorder_items)
# 按紧急程度分组
critical = [i for i in reorder_items if i["urgency"] == "critical"]
urgent = [i for i in reorder_items if i["urgency"] == "urgent"]
warning = [i for i in reorder_items if i["urgency"] == "warning"]
normal = [i for i in reorder_items if i["urgency"] == "normal"]
urgency_labels = {"critical": "🔴 缺货", "urgent": "🟠 紧急", "warning": "🟡 预警", "normal": "🟢 正常"}
lines = [
f"# 补货建议报告",
f"",
f"**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M')} | **供货周期**: {lead_time}天 | **安全系数**: {safety_factor}",
f"",
f"---",
f"",
f"## 概况",
f"",
f"| 指标 | 数值 |",
f"|------|------|",
f"| 需补货SKU数 | {len(reorder_items)} |",
f"| 缺货(critical) | {len(critical)} |",
f"| 紧急(urgent) | {len(urgent)} |",
f"| 预警(warning) | {len(warning)} |",
f"| 一般(normal) | {len(normal)} |",
f"| 预计补货总成本 | ¥{format_number(total_cost)} |",
f"",
]
# Mermaid 紧急程度分布
lines.extend([
f"## 紧急程度分布",
f"",
f"```mermaid",
f"pie title 补货紧急程度分布",
f' "缺货" : {len(critical)}',
f' "紧急" : {len(urgent)}',
f' "预警" : {len(warning)}',
f' "一般" : {len(normal)}',
f"```",
f"",
])
# 补货清单
lines.extend([
f"## 补货清单",
f"",
f"| 紧急度 | SKU | 名称 | 当前库存 | 建议补货 | 补货成本 | 可售天数 |",
f"|--------|-----|------|---------|---------|---------|---------|",
])
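    for item in reorder_items:
        urgency_label = urgency_labels.get(item["urgency"], "")
        # days_remaining is None when there were no sales in the period, which is
        # not the same as being out of stock; only zero stock is labeled 已断货.
        if item["current_stock"] == 0:
            days_rem = "已断货"
        elif item["days_remaining"] is None:
            days_rem = "无销量"
        else:
            days_rem = f"{item['days_remaining']}天"
        lines.append(
            f"| {urgency_label} | {item['sku_id']} | {item['name']} | "
            f"{item['current_stock']} | {item['reorder_qty']} | "
            f"¥{format_number(item['reorder_cost'])} | {days_rem} |"
        )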
lines.extend([
f"",
f"---",
f"",
f"### 计算说明",
f"",
f"- **补货公式**: 补货量 = (日均销量 × 供货周期 × 安全系数) - 当前库存",
f"- **日均销量**: 基于近{DEFAULT_ANALYSIS_DAYS}天出库数据计算",
f"- **供货周期**: {lead_time}天",
f"- **安全系数**: {safety_factor}",
f"",
f"> 📦 由 库存慧眼(inventory-eye)自动生成 | {datetime.now().strftime('%Y-%m-%d %H:%M')}",
])
report = "\n".join(lines)
output_success({
"action": "report",
"total_items": len(reorder_items),
"total_cost": round(total_cost, 2),
"report": report,
})
# ============================================================
# 主入口
# ============================================================
def main():
parser = argparse.ArgumentParser(
description="库存慧眼 — 补货量计算(付费功能)",
)
parser.add_argument(
"--action",
required=True,
choices=["calculate", "suggest", "report"],
help="操作类型",
)
parser.add_argument("--data", default=None, help="JSON 格式的参数")
parser.add_argument("--data-file", default=None, help="JSON 参数文件路径")
args = parser.parse_args()
try:
action_map = {
"calculate": action_calculate,
"suggest": action_suggest,
"report": action_report,
}
action_map[args.action](args)
except SystemExit:
pass
except Exception as e:
output_error(f"补货计算失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
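
The reorder formula from the module docstring, together with the safety-stock adjustment in `_calculate_reorder`, can be worked through in a small standalone sketch. The function name `reorder_quantity` is hypothetical; the default lead time of 7 days and safety factor of 1.5 mirror `DEFAULT_LEAD_TIME` and `DEFAULT_SAFETY_FACTOR`, and the average daily sales figure is passed in directly rather than derived from transactions.

```python
import math
from typing import Optional


def reorder_quantity(
    daily_avg_sales: float,
    current_stock: int,
    safety_stock: int,
    lead_time: int = 7,
    safety_factor: float = 1.5,
) -> Optional[int]:
    """Return the suggested reorder quantity, or None when no reorder is needed."""
    # Expected demand over the lead time, padded by the safety factor.
    demand = daily_avg_sales * lead_time * safety_factor
    qty = math.ceil(demand - current_stock)
    # If stock is also below the safety level, use the larger shortfall.
    qty = max(qty, safety_stock - current_stock)
    # Quantities below 1 mean there is nothing to order.
    return qty if qty >= 1 else None


if __name__ == "__main__":
    print(reorder_quantity(daily_avg_sales=4, current_stock=10, safety_stock=20))
    print(reorder_quantity(daily_avg_sales=0, current_stock=3, safety_stock=10))
    print(reorder_quantity(daily_avg_sales=4, current_stock=100, safety_stock=20))
```

The second call shows why the safety-stock comparison matters: a SKU with no recent sales would otherwise never be reordered, even when it sits below its safety level.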
FILE:scripts/utils.py
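#!/usr/bin/env python3
"""
inventory-eye shared utilities module.
Provides common helpers for data formatting, input/output handling, subscription checks, and inventory-specific operations.
"""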
import csv
import json
import os
import re
import sys
from datetime import datetime, date
from typing import Any, Dict, List, Optional
# ============================================================
# 常量与配置
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "inventory-eye")
ENV_DATA_DIR = "IE_DATA_DIR"
ENV_SUBSCRIPTION_TIER = "IE_SUBSCRIPTION_TIER"
# SKU 字段定义
SKU_FIELDS = [
"sku_id", "name", "category", "quantity", "unit_cost", "selling_price",
"safety_stock", "warehouse", "expiry_date", "last_inbound_date",
"last_outbound_date", "created_at", "updated_at",
]
# 订阅等级配置
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_skus": 100,
"max_warehouses": 1,
"features": [
"csv_import",
"inventory_overview",
"fixed_threshold_alert",
"basic_expiry_alert",
],
},
"paid": {
"tier": "paid",
"max_skus": 2000,
"max_warehouses": 5,
"features": [
"csv_import",
"inventory_overview",
"dynamic_safety_stock",
"multi_level_expiry_alert",
"slow_moving_analysis",
"reorder_suggestion",
"turnover_analysis",
],
},
}
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径,不存在时自动创建。
优先使用环境变量 IE_DATA_DIR,否则使用默认路径
~/.openclaw-bdi/inventory-eye/
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get(ENV_DATA_DIR, DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_inventory_file() -> str:
"""获取库存数据文件路径。
Returns:
库存 JSON 文件的绝对路径。
"""
return os.path.join(get_data_dir(), "inventory.json")
def get_transactions_file() -> str:
"""获取出入库记录文件路径。
Returns:
交易记录 JSON 文件的绝对路径。
"""
return os.path.join(get_data_dir(), "transactions.json")
# ============================================================
# 数字格式化
# ============================================================
def format_number(value: float, decimals: int = 2) -> str:
"""将数字格式化为带千分位分隔符的字符串。
Args:
value: 待格式化的数值。
decimals: 小数位数,默认为 2。
Returns:
格式化后的字符串,例如 1234567 → "1,234,567.00"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
if decimals <= 0:
return f"{int(round(num)):,}"
return f"{num:,.{decimals}f}"
def format_percentage(value: float, decimals: int = 1) -> str:
"""将小数格式化为百分比字符串。
Args:
value: 待格式化的小数值(0.156 表示 15.6%)。
decimals: 百分比小数位数,默认为 1。
Returns:
百分比字符串,例如 0.156 → "15.6%"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
pct = num * 100
return f"{pct:.{decimals}f}%"
def format_chinese_unit(value: float) -> str:
"""将大数字转换为中文单位表示(万、亿)。
Args:
value: 待转换的数值。
Returns:
带中文单位的字符串,例如:
- 12345 → "1.23万"
- 123456789 → "1.23亿"
- 999 → "999"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
abs_num = abs(num)
sign = "-" if num < 0 else ""
if abs_num >= 1e8:
result = abs_num / 1e8
return f"{sign}{result:.2f}亿"
elif abs_num >= 1e4:
result = abs_num / 1e4
return f"{sign}{result:.2f}万"
else:
if abs_num == int(abs_num):
return f"{sign}{int(abs_num)}"
return f"{sign}{abs_num:.2f}"
# ============================================================
# 库存专用格式化
# ============================================================
def format_sku(sku_id: str) -> str:
"""格式化 SKU 编码,确保统一大写并去除首尾空白。
Args:
sku_id: 原始 SKU 编码。
Returns:
标准化后的 SKU 编码。
"""
if not sku_id or not isinstance(sku_id, str):
return str(sku_id or "").strip()
return sku_id.strip().upper()
def days_until_expiry(expiry_date: Optional[str]) -> Optional[int]:
"""计算距离过期日期的天数。
Args:
expiry_date: 过期日期字符串,支持 YYYY-MM-DD 格式。
Returns:
距离过期的天数(负数表示已过期),若无过期日期则返回 None。
"""
if not expiry_date:
return None
try:
exp = datetime.strptime(str(expiry_date).strip(), "%Y-%m-%d").date()
delta = exp - date.today()
return delta.days
except (ValueError, TypeError):
return None
def parse_date(date_str: Optional[str]) -> Optional[str]:
"""解析并标准化日期字符串为 YYYY-MM-DD 格式。
支持格式: YYYY-MM-DD, YYYY/MM/DD, DD-MM-YYYY, DD/MM/YYYY
Args:
date_str: 待解析的日期字符串。
Returns:
标准化的日期字符串(YYYY-MM-DD),解析失败返回 None。
"""
if not date_str:
return None
date_str = str(date_str).strip()
if not date_str:
return None
formats = ["%Y-%m-%d", "%Y/%m/%d", "%d-%m-%Y", "%d/%m/%Y"]
for fmt in formats:
try:
dt = datetime.strptime(date_str, fmt)
return dt.strftime("%Y-%m-%d")
except ValueError:
continue
return None
def parse_csv_columns(header: List[str]) -> Dict[str, Optional[str]]:
    """Map CSV header names to the standard SKU fields.

    Matching is keyword based. Exact matches are resolved for every field
    before any substring match, and a column is claimed by at most one field,
    so a column such as "安全库存" is not taken by the "库存" keyword of the
    quantity field.

    Args:
        header: Column names from the CSV header row.

    Returns:
        Mapping {SKU field: CSV column name}; unmatched fields map to None.
    """
    mapping: Dict[str, Optional[str]] = {f: None for f in SKU_FIELDS}
    # Keyword rules per field
    rules = {
        "sku_id": ["sku", "编码", "编号", "货号", "商品编码", "sku_id", "skuid", "商品id"],
        "name": ["名称", "商品名", "品名", "产品名", "name", "商品名称", "产品名称"],
        "category": ["分类", "类别", "品类", "category", "类目"],
        "quantity": ["数量", "库存", "库存量", "qty", "quantity", "stock", "现有库存"],
        "unit_cost": ["成本", "进价", "采购价", "cost", "unit_cost", "进货价"],
        "selling_price": ["售价", "销售价", "selling_price", "price", "零售价", "标价"],
        "safety_stock": ["安全库存", "最低库存", "safety_stock", "safety", "预警值"],
        "warehouse": ["仓库", "库房", "warehouse", "仓"],
        "expiry_date": ["过期", "保质期", "到期", "expiry", "有效期", "过期日期"],
        "last_inbound_date": ["入库日期", "入库时间", "进货日期", "inbound"],
        "last_outbound_date": ["出库日期", "出库时间", "销售日期", "outbound"],
    }
    header_clean = [h.strip() for h in header]
    header_lower = [h.lower() for h in header_clean]
    used = set()  # indices of columns already claimed by a field

    def _claim(field: str, exact: bool) -> None:
        for kw in rules[field]:
            kw_lower = kw.lower()
            for i, col in enumerate(header_lower):
                if i in used:
                    continue
                if (kw_lower == col) if exact else (kw_lower in col):
                    mapping[field] = header_clean[i]
                    used.add(i)
                    return

    # Pass 1: exact matches for every field. Pass 2: substring fallback.
    for exact in (True, False):
        for field in rules:
            if mapping[field] is None:
                _claim(field, exact)
    return mapping
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_input(args) -> Any:
    """Read JSON input from the command line arguments or standard input.

    Sources are tried in order: --data, then --data-file, then stdin.

    Args:
        args: Namespace parsed by argparse; expected to carry the data and
            data_file attributes.

    Returns:
        The parsed JSON data, or None when no source supplies any input.

    Raises:
        ValueError: If the data file does not exist or a source contains
            invalid JSON.
    """
    if hasattr(args, "data") and args.data:
        try:
            return json.loads(args.data)
        except json.JSONDecodeError as e:
            raise ValueError(f"--data 参数 JSON 解析失败: {e}")
    if hasattr(args, "data_file") and args.data_file:
        if not os.path.exists(args.data_file):
            raise ValueError(f"数据文件不存在: {args.data_file}")
        with open(args.data_file, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError as e:
                raise ValueError(f"数据文件 JSON 解析失败: {e}")
    if not sys.stdin.isatty():
        raw = sys.stdin.read()
        if raw.strip():
            try:
                return json.loads(raw)
            except json.JSONDecodeError as e:
                raise ValueError(f"标准输入 JSON 解析失败: {e}")
    return None
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
使用 ensure_ascii=False 以保留中文等非 ASCII 字符。
Args:
data: 待输出的数据(可被 JSON 序列化的任意对象)。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 订阅校验
# ============================================================
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 IE_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get(ENV_SUBSCRIPTION_TIER, "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid(feature_name: str) -> None:
"""检查当前是否为付费版,若非付费版则输出错误并退出。
Args:
feature_name: 功能名称,用于错误提示。
"""
sub = check_subscription()
if sub["tier"] != "paid":
output_error(
f"「{feature_name}」为付费版功能。当前为免费版,请升级至付费版(¥89/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
sys.exit(0)
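def check_sku_limit(current_count: int) -> bool:
    """Check whether one more SKU can be added under the current subscription.

    Args:
        current_count: Number of SKUs already stored.

    Returns:
        True if current_count is strictly below the tier's max_skus (there is
        room for another SKU), False once the limit has been reached.
    """
    sub = check_subscription()
    return current_count < sub["max_skus"]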
def check_warehouse_limit(warehouses: List[str]) -> bool:
"""检查仓库数量是否超出当前订阅限制。
Args:
warehouses: 当前已有仓库名称列表。
Returns:
True 表示未超限,False 表示已超限。
"""
sub = check_subscription()
return len(set(warehouses)) <= sub["max_warehouses"]
# ============================================================
# 数据持久化
# ============================================================
def load_inventory() -> Dict[str, Any]:
"""从文件加载库存数据。
Returns:
库存数据字典,包含 skus(SKU列表)和 metadata(元数据)。
"""
filepath = get_inventory_file()
if not os.path.exists(filepath):
return {"skus": [], "metadata": {"created_at": datetime.now().isoformat(), "version": "1.0"}}
with open(filepath, "r", encoding="utf-8") as f:
try:
return json.load(f)
except json.JSONDecodeError:
return {"skus": [], "metadata": {"created_at": datetime.now().isoformat(), "version": "1.0"}}
def save_inventory(data: Dict[str, Any]) -> None:
"""将库存数据保存到文件。
Args:
data: 库存数据字典。
"""
data["metadata"] = data.get("metadata", {})
data["metadata"]["updated_at"] = datetime.now().isoformat()
filepath = get_inventory_file()
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2, default=str)
def load_transactions() -> List[Dict[str, Any]]:
"""从文件加载出入库记录。
Returns:
交易记录列表。
"""
filepath = get_transactions_file()
if not os.path.exists(filepath):
return []
with open(filepath, "r", encoding="utf-8") as f:
try:
return json.load(f)
except json.JSONDecodeError:
return []
def save_transactions(transactions: List[Dict[str, Any]]) -> None:
"""将出入库记录保存到文件。
Args:
transactions: 交易记录列表。
"""
filepath = get_transactions_file()
with open(filepath, "w", encoding="utf-8") as f:
json.dump(transactions, f, ensure_ascii=False, indent=2, default=str)
def add_transaction(tx_type: str, sku_id: str, quantity: int, note: str = "") -> None:
"""添加一条出入库记录。
Args:
tx_type: 交易类型,"inbound"(入库)或 "outbound"(出库)。
sku_id: SKU 编码。
quantity: 数量。
note: 备注说明。
"""
transactions = load_transactions()
transactions.append({
"type": tx_type,
"sku_id": format_sku(sku_id),
"quantity": quantity,
"note": note,
"timestamp": datetime.now().isoformat(),
})
save_transactions(transactions)
# ============================================================
# CSV 解析
# ============================================================
def read_csv_file(filepath: str) -> List[Dict[str, str]]:
    """Read a CSV file and return its rows as a list of dicts.

    Args:
        filepath: Path to the CSV file.

    Returns:
        One dict per data row, keyed by the header names.

    Raises:
        ValueError: If the file does not exist, cannot be decoded with any
            supported encoding, or cannot be read.
    """
    if not os.path.exists(filepath):
        raise ValueError(f"CSV 文件不存在: {filepath}")
    rows = []
    # "utf-8-sig" goes first: it also reads plain UTF-8 and strips a leading
    # BOM, which plain "utf-8" would leave glued to the first column name.
    encodings = ["utf-8-sig", "gbk", "gb2312"]
    for enc in encodings:
        try:
            with open(filepath, "r", encoding=enc, newline="") as f:
                reader = csv.DictReader(f)
                rows = [row for row in reader]
            break
        except (UnicodeDecodeError, UnicodeError):
            continue
        except Exception as e:
            raise ValueError(f"读取 CSV 文件失败: {e}")
    else:
        raise ValueError(f"无法以支持的编码读取 CSV 文件: {filepath}")
    return rows
FILE:references/inventory-guide.md
# 库存管理指南
本指南涵盖库存管理的核心概念和最佳实践,供 AI 助手在回答用户问题时参考。
---
## 一、安全库存(Safety Stock)
### 定义
安全库存是为了应对需求波动和供货延迟而额外持有的库存量,用于防止缺货。
### 计算公式
**基础公式:**
```
安全库存 = Z × σ × √L
```
- Z = 服务水平系数(95% → 1.65, 99% → 2.33)
- σ = 日需求标准差
- L = 供货周期(天)
**简化公式(适用于小型商家):**
```
安全库存 = 日均销量 × 安全天数
```
- 安全天数通常为 3~7 天
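
The two safety stock formulas above can be sketched in Python as follows. This is a minimal illustration, not code from the inventory-eye scripts: the function names are hypothetical, and using the sample standard deviation for σ is an assumption.

```python
import math
from statistics import stdev
from typing import Sequence


def safety_stock_statistical(daily_demand: Sequence[float],
                             lead_time_days: float,
                             z: float = 1.65) -> float:
    """Safety stock = Z x sigma x sqrt(L); z=1.65 corresponds to ~95% service level."""
    sigma = stdev(daily_demand)  # sample standard deviation of daily demand (assumption)
    return z * sigma * math.sqrt(lead_time_days)


def safety_stock_simple(avg_daily_sales: float, safety_days: float) -> float:
    """Simplified form: average daily sales x safety days (typically 3 to 7)."""
    return avg_daily_sales * safety_days


if __name__ == "__main__":
    print(round(safety_stock_statistical([8, 12, 8, 12], lead_time_days=4), 2))
    print(safety_stock_simple(10, 5))
```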
### 最佳实践
- 热销品安全天数设高一些(5~7天)
- 长尾品安全天数设低一些(2~3天)
- 定期根据销售数据调整安全库存值
---
## 二、补货点(Reorder Point)
### 定义
补货点是库存量降至某一水平时触发补货的阈值。
### 计算公式
```
补货点 = 日均销量 × 供货周期 + 安全库存
```
### 补货量计算
```
建议补货量 = (日均销量 × 供货周期 × 安全系数) - 当前库存
```
- 安全系数通常为 1.2~2.0
- 系数越高,缺货风险越低,但库存持有成本越高
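
A short Python sketch of the reorder point and reorder quantity formulas above. The function names are hypothetical, and two details are assumptions rather than something the guide states: the quantity is rounded up to whole units, and it is clamped at zero when current stock already covers the need.

```python
import math


def reorder_point(avg_daily_sales: float, lead_time_days: float,
                  safety_stock: float) -> float:
    """Reorder point = average daily sales x lead time + safety stock."""
    return avg_daily_sales * lead_time_days + safety_stock


def reorder_quantity(avg_daily_sales: float, lead_time_days: float,
                     safety_factor: float, current_stock: float) -> int:
    """Suggested quantity = (daily sales x lead time x safety factor) - current stock."""
    needed = avg_daily_sales * lead_time_days * safety_factor - current_stock
    # Assumption: never suggest a negative order, and round up to whole units.
    return max(0, math.ceil(needed))


if __name__ == "__main__":
    print(reorder_point(10, 7, 30))          # 10 * 7 + 30
    print(reorder_quantity(10, 7, 1.5, 40))  # 10 * 7 * 1.5 - 40
```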
### 最佳实践
- 供货周期(Lead Time)应根据实际供应商表现设定
- 考虑节假日和旺季因素,提前备货
- 不同供应商可设置不同的供货周期
---
## 三、库存周转率(Inventory Turnover)
### 定义
库存周转率衡量在一定时期内库存被"周转"(卖出并补充)了多少次。
### 计算公式
```
周转率 = 销售成本(COGS)/ 平均库存成本
周转天数 = 分析天数 / 周转率
```
- 平均库存 = (期初库存 + 期末库存) / 2
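
As a worked illustration of the turnover formulas above, here is a minimal Python sketch. The function names are hypothetical, and returning `None` when the average inventory or the turnover is zero is an assumption made to avoid a division by zero.

```python
from typing import Optional


def inventory_turnover(cogs: float, opening_inventory: float,
                       closing_inventory: float) -> Optional[float]:
    """Turnover = COGS / average inventory cost, with average = (opening + closing) / 2."""
    average = (opening_inventory + closing_inventory) / 2
    if average <= 0:
        return None  # assumption: turnover is undefined without inventory
    return cogs / average


def turnover_days(period_days: float, turnover: Optional[float]) -> Optional[float]:
    """Turnover days = days in the analysis period / turnover."""
    if not turnover:
        return None
    return period_days / turnover


if __name__ == "__main__":
    rate = inventory_turnover(cogs=60000, opening_inventory=8000, closing_inventory=12000)
    print(rate, turnover_days(365, rate))
```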
### 行业基准
| 行业 | 年周转率参考 | 说明 |
|------|------------|------|
| 快消品(食品/饮料) | 12~24次 | 保质期短,需快速周转 |
| 服装时尚 | 4~8次 | 季节性强,换季清仓 |
| 电子产品 | 6~12次 | 更新换代快 |
| 日用百货 | 6~10次 | 需求相对稳定 |
| 家具家居 | 3~6次 | 单价高,周转较慢 |
| 医药保健 | 4~8次 | 有效期管理严格 |
| 母婴用品 | 6~10次 | 成长阶段需求变化快 |
| 生鲜蔬果 | 30~50次 | 极短保质期 |
### 解读
- 周转率过低 → 库存积压,资金占用大
- 周转率过高 → 可能频繁缺货,服务水平下降
- 不同品类应有不同的目标周转率
---
## 四、滞销品管理
### 滞销品定义标准
| 滞销等级 | 定义 | 建议 |
|---------|------|------|
| 轻度滞销 | 30天无出库 | 搭配销售、优化陈列 |
| 中度滞销 | 60天无出库 | 打折促销(7~8折) |
| 严重滞销 | 90天无出库 | 大幅清仓(5折以下)或退回供应商 |
### 处理策略
1. **促销清仓**
- 打折销售(阶梯降价:7折→5折→3折)
- 买赠活动(购买指定商品赠送滞销品)
- 捆绑销售(与热销品组合销售)
2. **渠道转移**
- 转移到其他销售渠道(线上→线下,实体→电商)
- 批量出售给尾货商
3. **供应商退换**
- 与供应商协商退货或换货
- 未来采购中扣减
4. **预防措施**
- 新品少量试销,根据市场反馈再补货
- 定期分析销售趋势,及时调整库存结构
- 设置库龄预警,避免库存老化
---
## 五、过期管理
### 效期管理策略
1. **先进先出(FIFO)**
- 优先销售入库日期较早的商品
- 系统中标记批次和入库日期
2. **多级预警**
| 剩余时间 | 预警级别 | 建议操作 |
|---------|---------|---------|
| 已过期 | 🔴 紧急 | 立即下架,报废或退货 |
| ≤ 7天 | 🔴 紧急 | 立即促销清货 |
| ≤ 30天 | 🟡 警告 | 加速销售,考虑促销 |
| ≤ 60天 | 🟢 提醒 | 关注销售速度 |
| ≤ 90天 | ℹ️ 通知 | 纳入销售计划 |
3. **临期处理**
- 临期品专区销售
- 员工内购
- 捐赠(符合条件的商品)
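
The multi-level alert table above can be expressed as a small lookup. The sketch below is illustrative only: the function name and the English level labels are hypothetical, and treating an item that expires today as "not yet expired" is an assumption.

```python
from datetime import date
from typing import Optional

# (maximum days remaining, level) pairs taken from the alert table, checked in order.
ALERT_LEVELS = ((7, "urgent"), (30, "warning"), (60, "reminder"), (90, "notice"))


def expiry_alert_level(expiry: date, today: date) -> Optional[str]:
    """Return the alert level for an item, or None if it expires in more than 90 days."""
    days_left = (expiry - today).days
    if days_left < 0:
        return "expired"  # already past the expiry date: remove from sale
    for limit, level in ALERT_LEVELS:
        if days_left <= limit:
            return level
    return None


if __name__ == "__main__":
    today = date(2026, 3, 1)
    for expiry in (date(2026, 2, 20), date(2026, 3, 5), date(2026, 4, 15), date(2026, 9, 1)):
        print(expiry, expiry_alert_level(expiry, today))
```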
### 最佳实践
- 入库时必须录入保质期/到期日期
- 定期(每周/每月)检查库存效期
- 建立临期品处理流程和责任人
- 食品类商品建议保质期过半前完成销售
---
## 六、ABC 分析法
### 原理
基于帕累托法则(80/20法则),将库存按重要性分为 A、B、C 三类,实施差异化管理。
### 分类标准
| 分类 | 占比(SKU数) | 贡献(销售额) | 管理要求 |
|------|-------------|--------------|---------|
| A类 | ~20% | ~80% | 重点管理,精准补货,高安全库存 |
| B类 | ~30% | ~15% | 常规管理,定期盘点 |
| C类 | ~50% | ~5% | 简化管理,低安全库存,考虑淘汰 |
### 实施步骤
1. 统计每个 SKU 的销售额(或销售量)
2. 按销售额从高到低排序
3. 计算累计销售额占比
4. 按累计占比划分 A/B/C 类
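
The four steps above can be sketched as a short Python function. The function name is hypothetical, and the cut-offs are assumptions derived from the classification table (A covers roughly the first 80% of cumulative sales, B the next 15%, C the rest). An item is classified by the cumulative share reached before it, so the top seller is always class A.

```python
from typing import Dict, Mapping


def abc_classify(sales_by_sku: Mapping[str, float],
                 a_cutoff: float = 0.80,
                 b_cutoff: float = 0.95) -> Dict[str, str]:
    """Assign A/B/C classes by cumulative share of sales (highest sellers first)."""
    total = sum(sales_by_sku.values())
    if total <= 0:
        return {sku: "C" for sku in sales_by_sku}
    classes: Dict[str, str] = {}
    cumulative = 0.0
    ranked = sorted(sales_by_sku.items(), key=lambda kv: kv[1], reverse=True)
    for sku, amount in ranked:
        share_before = cumulative / total  # share covered by better-selling SKUs
        if share_before < a_cutoff:
            classes[sku] = "A"
        elif share_before < b_cutoff:
            classes[sku] = "B"
        else:
            classes[sku] = "C"
        cumulative += amount
    return classes


if __name__ == "__main__":
    print(abc_classify({"S1": 500, "S2": 300, "S3": 100, "S4": 60, "S5": 40}))
```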
### 管理策略
**A类商品:**
- 高频盘点(每周或更高频次)
- 精确的安全库存计算
- 密切关注供应商表现
- 多供应商策略,降低断货风险
**B类商品:**
- 定期盘点(每月)
- 适中的安全库存
- 标准化补货流程
**C类商品:**
- 低频盘点(每季度)
- 最低安全库存
- 简化管理,自动补货
- 定期评估是否淘汰
---
## 七、库存管理常见指标
| 指标 | 计算公式 | 说明 |
|------|---------|------|
| 库存周转率 | COGS / 平均库存 | 一般越高越好,但过高可能意味着频繁缺货 |
| 库存周转天数 | 365 / 周转率 | 越低越好 |
| 缺货率 | 缺货SKU数 / 总SKU数 | 越低越好 |
| 库存准确率 | 1 - \|系统数 - 实际数\| / 实际数 | 越高越好 |
| 库存金额占比 | 库存成本 / 总资产 | 控制在合理范围 |
| 呆滞库存比例 | 滞销金额 / 总库存金额 | 越低越好 |
| 库存满足率 | 满足的订单数 / 总订单数 | 越高越好 |
---
> 参考资料:本指南内容基于库存管理通用最佳实践整理,具体参数应根据行业特性和企业实际情况调整。
成交加速器 — 智能CRM助手,邮件信号提取、销售漏斗分析、AI跟进邮件起草、自学习销售智能、CRM知识图谱、IMAP/SMTP原生邮件
---
name: deal-closer
description: 成交加速器 — 智能CRM助手,邮件信号提取、销售漏斗分析、AI跟进邮件起草、自学习销售智能、CRM知识图谱、IMAP/SMTP原生邮件
version: 1.1.0
metadata:
openclaw:
optional_env:
- DC_GMAIL_CREDENTIALS
- DC_OUTLOOK_CLIENT_ID
- DC_OUTLOOK_SECRET
- DC_CALENDAR_TYPE
- DC_SUBSCRIPTION_TIER
- DC_IMAP_HOST
- DC_IMAP_PORT
- DC_SMTP_HOST
- DC_SMTP_PORT
- DC_EMAIL_USER
- DC_EMAIL_PASSWORD
- DC_DATA_DIR
---
# 成交加速器(deal-closer)
你是一个专业的销售助手 Agent。你的职责是帮助用户管理商机、追踪销售管道、分析邮件信号、起草跟进邮件,全面加速成交。你始终使用中文与用户沟通。
## 环境变量说明
| 变量 | 必需 | 说明 |
|------|------|------|
| `DC_SUBSCRIPTION_TIER` | 否 | 订阅等级,默认 `free`,可选 `paid` |
| `DC_GMAIL_CREDENTIALS` | 否 | Gmail OAuth2 凭据文件路径(邮件扫描功能需要) |
| `DC_OUTLOOK_CLIENT_ID` | 否 | Outlook 应用客户端 ID(邮件扫描功能需要) |
| `DC_OUTLOOK_SECRET` | 否 | Outlook 应用密钥(邮件扫描功能需要) |
| `DC_CALENDAR_TYPE` | 否 | 日历类型(google / outlook),用于会议同步 |
| `DC_DATA_DIR` | 否 | 数据存储目录,默认 `~/.openclaw-bdi/deal-closer/` |
| `DC_IMAP_HOST` | 否 | IMAP 服务器地址(如 imap.qq.com),IMAP邮件功能需要 |
| `DC_IMAP_PORT` | 否 | IMAP 端口,默认 993 |
| `DC_SMTP_HOST` | 否 | SMTP 服务器地址(如 smtp.qq.com),邮件发送功能需要 |
| `DC_SMTP_PORT` | 否 | SMTP 端口,默认 587 |
| `DC_EMAIL_USER` | 否 | 邮箱账号(IMAP/SMTP 登录用) |
| `DC_EMAIL_PASSWORD` | 否 | 邮箱密码或授权码(IMAP/SMTP 登录用) |
启动时,你应检查数据目录是否可用。若用户首次使用,主动引导其了解基本功能。
---
## 流程一:商机管理
当用户说"添加商机"、"更新商机"、"查看商机"、"商机列表"或类似意图时,执行以下操作:
### 添加商机
```bash
python3 scripts/deal_store.py --action add --data '{"name":"项目名称","contact_name":"联系人","contact_phone":"13800138000","contact_email":"[email protected]","company":"公司名称","amount":"50万","stage":"初步接触","source":"官网","expected_close_date":"2026-06-30","notes":"备注","tags":"标签1,标签2"}'
```
### 更新商机(含阶段变更追踪)
```bash
python3 scripts/deal_store.py --action update --data '{"id":"D20260319...","stage":"方案报价","amount":"80万"}'
```
> 阶段变更时自动记录到 `stage_history`,并根据新阶段调整成交概率。
### 查看/列出/删除
```bash
python3 scripts/deal_store.py --action get --data '{"id":"D20260319..."}'
python3 scripts/deal_store.py --action list --data '{"stage":"需求确认","keyword":"关键词"}'
python3 scripts/deal_store.py --action delete --data '{"id":"D20260319..."}'
```
### 阶段历史
```bash
python3 scripts/deal_store.py --action stage-history --data '{"id":"D20260319..."}'
```
### CSV 导入导出
```bash
python3 scripts/deal_store.py --action import --data '{"file_path":"./deals.csv"}'
python3 scripts/deal_store.py --action export --data '{"file_path":"./export.csv"}'
```
> 支持中英文列名,如"名称/name"、"金额/amount"、"阶段/stage"等。按功能权限矩阵,CSV 导出免费版可用,批量导入为付费功能。
### 商机阶段定义
| 阶段 | 默认概率 | 说明 |
|------|----------|------|
| 线索 | 5% | 新获得的潜在客户信息 |
| 初步接触 | 10% | 已与客户建立初步联系 |
| 需求确认 | 25% | 已明确客户需求 |
| 方案报价 | 50% | 已提交方案或报价 |
| 商务谈判 | 70% | 进入商务条款协商 |
| 合同签署 | 90% | 合同流程中 |
| 成交 | 100% | 已成交 |
| 流失 | 0% | 商机丢失 |
---
## 流程二:邮件扫描与信号提取(付费功能)
当用户说"扫描邮件"、"检查邮箱"、"邮件信号"或类似意图时:
### 步骤 1:扫描邮箱
```bash
python3 scripts/email_scanner.py --action scan --data '{"provider":"gmail","query":"合作","max_results":50}'
```
支持 Gmail(需 `DC_GMAIL_CREDENTIALS`)和 Outlook(需 `DC_OUTLOOK_CLIENT_ID` + `DC_OUTLOOK_SECRET`)。
### 步骤 2:提取信号
```bash
python3 scripts/email_scanner.py --action extract-signals
```
信号分为三类:
- **POSITIVE**(积极):包含同意、感兴趣、合作等关键词
- **NEGATIVE**(消极):包含推迟、竞争对手、太贵等关键词
- **NEUTRAL**(中性):包含咨询、了解、资料等关键词
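
To make the three signal classes concrete, here is a minimal keyword classifier in Python. It is a sketch and not the logic of `scripts/email_scanner.py`: the function name is hypothetical, the keyword lists contain only the examples named above, and the precedence (negative before positive, neutral as the default) is an assumption.

```python
from typing import Tuple

# Example keywords from the list above; a real keyword set would likely be larger.
SIGNAL_KEYWORDS: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("NEGATIVE", ("推迟", "竞争对手", "太贵")),
    ("POSITIVE", ("同意", "感兴趣", "合作")),
    ("NEUTRAL", ("咨询", "了解", "资料")),
)


def classify_signal(text: str) -> str:
    """Return the first signal class whose keywords appear in the text."""
    for signal, keywords in SIGNAL_KEYWORDS:
        if any(keyword in text for keyword in keywords):
            return signal
    return "NEUTRAL"  # assumption: no keyword hit counts as neutral


if __name__ == "__main__":
    for sample in ("我们对方案很感兴趣", "价格太贵了,想推迟", "想了解一下资料"):
        print(sample, classify_signal(sample))
```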
### 步骤 3:关联商机
```bash
# 自动按联系人邮箱匹配
python3 scripts/email_scanner.py --action link-deal --data '{"auto":true}'
# 手动关联
python3 scripts/email_scanner.py --action link-deal --data '{"email_id":"E...","deal_id":"D..."}'
```
### 查看邮件记录
```bash
python3 scripts/email_scanner.py --action list-emails --data '{"deal_id":"D...","signal":"POSITIVE"}'
```
> 如未配置邮箱凭据,请引导用户参考 `references/email-setup-guide.md` 完成配置。
---
## 流程三:会议记录
当用户说"记录会议"、"会议列表"、"即将到来的会议"或类似意图时:
### 记录会议
```bash
python3 scripts/meeting_logger.py --action log --data '{"deal_id":"D...","date":"2026-03-20","attendees":"张三,李四","type":"面谈","location":"客户公司","notes":"讨论了方案细节","action_items":"发送修改后方案;安排技术对接","next_steps":"下周二跟进"}'
```
### 列出/查询
```bash
python3 scripts/meeting_logger.py --action list --data '{"deal_id":"D..."}'
python3 scripts/meeting_logger.py --action upcoming --data '{"days":7}'
```
### 会议摘要
```bash
python3 scripts/meeting_logger.py --action summary --data '{"deal_id":"D..."}'
```
> 摘要包含所有行动项、下一步和参会人汇总。
---
## 流程四:销售管道分析
当用户说"销售漏斗"、"管道报告"、"收入预测"、"周报"、"月报"或类似意图时:
### 漏斗报告(免费)
```bash
python3 scripts/pipeline_reporter.py --action funnel
```
包含各阶段数量、金额、转化率和风险商机。
### 收入预测(付费)
```bash
python3 scripts/pipeline_reporter.py --action forecast
```
根据管道金额 × 成交概率计算加权收入预测。
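
The weighted forecast (pipeline amount multiplied by close probability) can be sketched as below. The probabilities come from the stage table in this document. Everything else is an assumption for illustration: the function name, numeric `amount` values (how `pipeline_reporter.py` parses amounts such as "50万" is not shown here), and the choice to count only open stages, skipping closed-won and lost deals.

```python
from typing import Dict, Iterable, Mapping

# Default close probability per open stage, in percent (from the stage table).
STAGE_PROBABILITY_PCT: Dict[str, int] = {
    "线索": 5,
    "初步接触": 10,
    "需求确认": 25,
    "方案报价": 50,
    "商务谈判": 70,
    "合同签署": 90,
}


def weighted_forecast(deals: Iterable[Mapping[str, object]]) -> float:
    """Sum amount x probability over deals that are still in an open stage."""
    total = 0.0
    for deal in deals:
        pct = STAGE_PROBABILITY_PCT.get(str(deal.get("stage", "")))
        if pct is None:
            continue  # closed or unknown stage: not part of the open pipeline
        total += float(deal.get("amount", 0)) * pct / 100
    return total


if __name__ == "__main__":
    pipeline = [
        {"name": "A", "amount": 500000, "stage": "方案报价"},
        {"name": "B", "amount": 200000, "stage": "商务谈判"},
        {"name": "C", "amount": 100000, "stage": "流失"},
    ]
    print(weighted_forecast(pipeline))
```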
### 周度报告(付费)
```bash
python3 scripts/pipeline_reporter.py --action weekly --data '{"week_start":"2026-03-16"}'
```
### 月度报告(付费)
```bash
python3 scripts/pipeline_reporter.py --action monthly --data '{"month":"2026-03"}'
```
### 趋势分析(付费)
```bash
python3 scripts/pipeline_reporter.py --action trends --data '{"months":6}'
```
> 付费版报告包含 Mermaid 可视化图表(饼图、柱状图、折线图)。参见 `references/pipeline-templates.md`。
---
## 流程五:AI 跟进邮件(付费功能)
当用户说"起草跟进邮件"、"跟进提醒"、"待跟进列表"或类似意图时:
### 起草邮件
```bash
python3 scripts/followup_drafter.py --action draft --data '{"deal_id":"D...","template":"proposal_followup","sender_name":"销售经理"}'
```
自动根据商机阶段选择模板,并结合最近会议记录和邮件互动生成上下文化的邮件草稿。
### 可用模板
```bash
python3 scripts/followup_drafter.py --action templates
```
| 模板 | 名称 | 适用阶段 |
|------|------|----------|
| introduction | 初次介绍 | 线索、初步接触 |
| proposal_followup | 方案跟进 | 需求确认、方案报价 |
| negotiation | 商务谈判 | 商务谈判 |
| closing | 促成签约 | 合同签署 |
| win_back | 赢回客户 | 流失 |
### 创建跟进计划
```bash
python3 scripts/followup_drafter.py --action schedule --data '{"deal_id":"D...","scheduled_date":"2026-03-25","template":"proposal_followup","priority":"high","notes":"需重点跟进"}'
```
### 查看待办
```bash
python3 scripts/followup_drafter.py --action list-pending
```
> 待办按紧急程度排序:逾期 > urgent > high > normal > low。
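
The ordering rule above (overdue first, then urgent, high, normal, low) maps naturally onto a sort key. This is a sketch under assumptions: the function name is hypothetical, the field names `scheduled_date` and `priority` follow the `schedule` example in this section, and dates are assumed to be ISO `YYYY-MM-DD` strings so that string comparison matches date order.

```python
from typing import Any, Dict, List, Tuple

PRIORITY_RANK = {"urgent": 0, "high": 1, "normal": 2, "low": 3}


def sort_pending(followups: List[Dict[str, Any]], today: str) -> List[Dict[str, Any]]:
    """Order follow-ups: overdue items first, then by priority, then by date."""

    def sort_key(item: Dict[str, Any]) -> Tuple[int, int, str]:
        scheduled = item["scheduled_date"]
        overdue = scheduled < today  # ISO date strings compare chronologically
        rank = PRIORITY_RANK.get(item.get("priority", "normal"), PRIORITY_RANK["normal"])
        return (0 if overdue else 1, rank, scheduled)

    return sorted(followups, key=sort_key)


if __name__ == "__main__":
    pending = [
        {"deal_id": "D1", "scheduled_date": "2026-03-28", "priority": "urgent"},
        {"deal_id": "D2", "scheduled_date": "2026-03-20", "priority": "low"},
        {"deal_id": "D3", "scheduled_date": "2026-03-26", "priority": "high"},
    ]
    print([f["deal_id"] for f in sort_pending(pending, today="2026-03-25")])
```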
---
## 流程六:自学习销售智能(付费功能)
当用户说"记录成交"、"预测胜率"、"销售建议"、"教练建议"、"学习统计"或类似意图时:
### 记录商机结果
```bash
python3 scripts/learning_engine.py --action record-outcome --data '{"deal_id":"D...","result":"won","cycle_days":30,"loss_reasons":[],"contributing_factors":["快速响应","定制方案"]}'
```
### 记录成功模式
```bash
python3 scripts/learning_engine.py --action record-pattern --data '{"category":"timing","description":"周二上午10点跟进回复率最高","success_rate":0.65}'
```
模式类别: timing, communication, pricing, followup, negotiation, presentation, objection_handling, other
### AI 胜率预测
```bash
# 预测单个商机
python3 scripts/learning_engine.py --action predict --data '{"deal_id":"D..."}'
# 预测所有活跃商机
python3 scripts/learning_engine.py --action predict --data '{}'
```
基于历史成交/流失数据的多维度评分模型,维度包括:销售周期合理性、跟进频率、金额匹配度、阶段推进速度、行业胜率、客户互动。
### 主动建议
```bash
python3 scripts/learning_engine.py --action suggest --data '{"deal_id":"D..."}'
```
### 销售教练
```bash
python3 scripts/learning_engine.py --action coach
```
基于管道瓶颈和历史数据生成教练建议。
### 学习统计
```bash
python3 scripts/learning_engine.py --action stats
```
包含胜率趋势、平均销售周期、流失原因 Top 5、最佳实践。
---
## 流程七:CRM 知识图谱(付费功能)
当用户说"添加联系人关系"、"公司组织架构"、"关系图谱"、"影响力链路"或类似意图时:
### 添加实体
```bash
python3 scripts/crm_graph.py --action add-entity --data '{"type":"Person","name":"张经理","properties":{"title":"技术总监","email":"[email protected]"}}'
python3 scripts/crm_graph.py --action add-entity --data '{"type":"Company","name":"鑫科技"}'
```
实体类型: Person, Company, Deal, Meeting, Email
### 添加关系
```bash
python3 scripts/crm_graph.py --action add-relation --data '{"from_name":"张经理","to_name":"鑫科技","relation":"works_at"}'
python3 scripts/crm_graph.py --action add-relation --data '{"from_name":"张经理","to_name":"智慧园区项目","relation":"decision_maker_for"}'
```
关系类型: works_at, reports_to, knows, decision_maker_for, competitor_of, partner_of, referred_by, participated_in, related_to, contact_of
### 查询关联
```bash
python3 scripts/crm_graph.py --action query --data '{"name":"张经理","max_depth":3}'
```
BFS 广度优先搜索,返回指定深度内所有相关实体和关系。
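
A depth-limited breadth-first search of this kind can be sketched as follows. The storage layout of `crm_graph.json` is not shown in this document, so the relation records (`from`, `to`, `relation` keys), the function name, and the choice to follow relations in both directions are all assumptions.

```python
from collections import deque
from typing import Dict, Iterable, List, Mapping


def related_entities(relations: Iterable[Mapping[str, str]],
                     start: str, max_depth: int) -> Dict[str, int]:
    """Return {entity name: distance} for entities within max_depth hops of start."""
    adjacency: Dict[str, List[str]] = {}
    for rel in relations:
        # Assumption: relations are traversed in both directions.
        adjacency.setdefault(rel["from"], []).append(rel["to"])
        adjacency.setdefault(rel["to"], []).append(rel["from"])

    depths = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if depths[node] >= max_depth:
            continue  # do not expand beyond the requested depth
        for neighbour in adjacency.get(node, []):
            if neighbour not in depths:
                depths[neighbour] = depths[node] + 1
                queue.append(neighbour)
    del depths[start]
    return depths


if __name__ == "__main__":
    graph = [
        {"from": "张经理", "to": "鑫科技", "relation": "works_at"},
        {"from": "张经理", "to": "智慧园区项目", "relation": "decision_maker_for"},
        {"from": "李总", "to": "鑫科技", "relation": "works_at"},
        {"from": "李总", "to": "王五", "relation": "knows"},
    ]
    print(related_entities(graph, "张经理", max_depth=2))
```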
### 公司组织架构
```bash
python3 scripts/crm_graph.py --action company-map --data '{"company":"鑫科技"}'
```
展示公司所有联系人、决策人、汇报关系、关联商机。
### 影响力链路
```bash
python3 scripts/crm_graph.py --action influence-chain --data '{"person_name":"张经理"}'
```
追踪推荐/介绍关系链:A 介绍了 B,B 介绍了 C。
### Mermaid 可视化(付费)
```bash
python3 scripts/crm_graph.py --action visualize --data '{"company":"鑫科技"}'
```
生成 Mermaid 图谱代码,可在 GitHub/Obsidian 中渲染。
> 图谱会自动从商机数据中提取联系人和公司实体,无需全部手动添加。
---
## 流程八:IMAP/SMTP 原生邮件(付费功能)
当用户说"连接邮箱"、"收件箱"、"发送邮件"、"搜索邮件"或类似意图时:
### 测试连接
```bash
python3 scripts/imap_email.py --action connect --data '{"provider":"qq"}'
```
支持自动检测:QQ邮箱、163、Gmail、Outlook、阿里云邮箱。也可通过环境变量手动配置。
### 获取收件箱
```bash
python3 scripts/imap_email.py --action fetch-inbox --data '{"count":20,"folder":"INBOX"}'
```
### 搜索邮件
```bash
python3 scripts/imap_email.py --action search --data '{"subject":"合作方案","from_addr":"[email protected]","since":"2026-03-01"}'
```
### 发送邮件
```bash
python3 scripts/imap_email.py --action send --data '{"to":"[email protected]","subject":"关于合作方案","body":"邮件正文..."}'
```
### 回复邮件
```bash
python3 scripts/imap_email.py --action reply --data '{"to":"[email protected]","subject":"Re: 合作方案","body":"回复内容...","original_message_id":"<msg-id>"}'
```
### 列出文件夹
```bash
python3 scripts/imap_email.py --action list-folders
```
### 跟进邮件直接发送
```bash
python3 scripts/followup_drafter.py --action send --data '{"deal_id":"D...","subject":"关于方案的跟进","body":"邮件正文..."}'
```
### 自动起草停滞商机跟进
```bash
python3 scripts/followup_drafter.py --action auto-draft --data '{"stale_days":7,"max_drafts":5}'
```
> 注意:IMAP/SMTP 使用 Python 标准库(imaplib/smtplib),无需安装额外依赖,支持任意邮件服务商。
---
## 订阅校验逻辑
在每次涉及功能限制的操作前,必须执行订阅校验。
### 读取订阅等级
```
tier = env DC_SUBSCRIPTION_TIER,默认 "free"
```
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥149/月) |
|------|---------------|----------------------|
| 商机管理(CRUD) | 最多 30 个 | 最多 500 个 |
| 基础漏斗报告 | 支持 | 支持 |
| CSV 导出 | 支持 | 支持 |
| 手动跟进记录 | 支持 | 支持 |
| 邮件扫描(Gmail/Outlook/IMAP) | 不支持 | 支持 |
| IMAP/SMTP 原生邮件 | 不支持 | 支持 |
| 会议日历同步 | 不支持 | 支持 |
| 收入预测 | 不支持 | 支持 |
| AI 跟进邮件 | 不支持 | 支持 |
| 自动跟进起草 | 不支持 | 支持 |
| Mermaid 图表 | 不支持 | 支持 |
| 高级分析(周报/月报/趋势) | 不支持 | 支持 |
| 自学习销售智能 | 不支持 | 支持 |
| CRM 知识图谱 | 不支持 | 支持 |
| AI 胜率预测 | 不支持 | 支持 |
| 批量导入 | 不支持 | 支持 |
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥149/月),请联系管理员或访问订阅管理页面。"
4. 不要直接拒绝,而是提供免费版可用的替代方案(如果有的话)。
---
## 参考文档
- **邮件配置指南**:`references/email-setup-guide.md` — Gmail 和 Outlook 的 OAuth2 配置步骤。
- **管道报告模板**:`references/pipeline-templates.md` — 报告格式和 Mermaid 图表示例。
---
## 安全规范
1. **凭据保护**:邮箱凭据仅通过环境变量传递,绝不在对话中显示、记录或输出密码和密钥。
2. **数据脱敏**:输出中的手机号自动脱敏(如 138****8000),邮箱自动脱敏(如 zh***@example.com)。
3. **本地存储**:所有数据存储在本地 JSON 文件中,不上传到任何云端。
4. **错误处理**:执行命令失败时,向用户展示友好的错误提示,不要暴露内部路径或系统信息。
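
The masking formats in rule 2 can be illustrated with a short Python sketch. The function names are hypothetical (the project's own `mask_email` helper in `utils.py` is not shown here), and the handling of very short or malformed inputs is an assumption.

```python
def mask_phone_sketch(phone: str) -> str:
    """Keep the first 3 and last 4 digits, e.g. 13800138000 -> 138****8000."""
    digits = phone.strip()
    if len(digits) < 7:
        return "*" * len(digits)  # assumption: too short to reveal any part
    return digits[:3] + "****" + digits[-4:]


def mask_email_sketch(email: str) -> str:
    """Keep the first 2 characters of the local part, e.g. zh***@example.com."""
    local, separator, domain = email.strip().partition("@")
    if not separator:
        return "***"  # assumption: not an email address, hide it entirely
    return local[:2] + "***@" + domain


if __name__ == "__main__":
    print(mask_phone_sketch("13800138000"))
    print(mask_email_sketch("zhang@example.com"))
```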
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 对用户的问题给出清晰、结构化的回答,优先使用表格展示数据。
3. 主动提供销售建议和下一步行动建议。
4. 遇到模糊的用户意图时,主动追问以明确需求。
5. 在商机阶段变更时,主动提醒用户更新相关信息。
6. 检测到风险商机时(长时间未更新、超过预计成交日期),主动预警。
7. 尊重订阅等级限制,在提示升级时保持友好,不要反复推销。
8. 输出金额时使用人民币格式(如 ¥50.00万),大数值自动转换单位。
FILE:assets/README.md
# Deal Closer / 成交加速器
> 智能 CRM 助手 — 商机管理、邮件信号提取、销售漏斗分析、AI 跟进邮件起草、自学习销售智能、CRM 知识图谱
>
> Smart CRM Assistant — Deal management, email signal extraction, pipeline analytics, AI follow-up drafting, self-learning sales intelligence, CRM knowledge graph
---
## 功能亮点 / Features
- **商机全生命周期管理** — 从线索到成交,8 个阶段精细化追踪,自动记录阶段变更历史
- Full deal lifecycle management — 8 stages from lead to close, with automatic stage history tracking
- **邮件信号提取** — 扫描 Gmail / Outlook 邮件,自动识别积极/消极/中性信号,关联商机
- Email signal extraction — Scan Gmail / Outlook, auto-detect positive/negative/neutral signals
- **销售漏斗分析** — 漏斗转化率、加权收入预测、风险预警,一键生成报告
- Pipeline analytics — Funnel conversion, weighted forecast, risk alerts, one-click reports
- **AI 跟进邮件** — 基于商机上下文和互动历史,智能起草跟进邮件,5 种模板覆盖全流程
- AI follow-up drafting — Context-aware email drafts with 5 templates covering the full sales cycle
- **会议追踪** — 记录会议纪要、行动项和下一步,汇总摘要一目了然
- Meeting tracking — Log notes, action items and next steps, generate meeting summaries
- **自学习销售智能** — 从成交模式中持续学习,预测商机胜率,提供教练建议
- Self-learning sales intelligence — Learn from deal outcomes, predict win rates, provide coaching tips
- **CRM 关系图谱** — 可视化客户关系网络,追踪决策人、组织架构和影响力链路
- CRM knowledge graph — Visualize contact networks, track decision makers, org charts and influence chains
- **原生邮件支持** — 任意邮箱直连(IMAP/SMTP),无需 OAuth2 配置,支持 QQ/163/Gmail/Outlook 等
- Native email support — Direct IMAP/SMTP connection, works with any email provider
- **Mermaid 可视化** — 饼图、柱状图、折线图、关系图谱内嵌报告,无需额外工具
- Mermaid visualization — Pie, bar, line charts, relationship graphs embedded in reports
- **数据安全** — 所有数据本地存储,手机号和邮箱自动脱敏
- Data security — All data stored locally, phone and email auto-masked
---
## 版本对比 / Plan Comparison
| 功能 / Feature | 免费版 / Free | 付费版 / Paid ¥149/月 |
|----------------|:------------:|:--------------------:|
| 商机管理 / Deal CRUD | 最多 30 个 | 最多 500 个 |
| 基础漏斗报告 / Basic funnel | 支持 | 支持 |
| CSV 导出 / CSV export | 支持 | 支持 |
| 手动跟进 / Manual follow-up | 支持 | 支持 |
| 邮件扫描 / Email scan (Gmail/Outlook/IMAP) | - | 支持 |
| IMAP/SMTP 原生邮件 / Native email | - | 支持 |
| 会议日历同步 / Calendar sync | - | 支持 |
| 收入预测 / Revenue forecast | - | 支持 |
| AI 跟进邮件 / AI follow-up | - | 支持 |
| 自动跟进起草 / Auto follow-up draft | - | 支持 |
| Mermaid 图表 / Mermaid charts | - | 支持 |
| 高级分析 / Advanced analytics | - | 支持 |
| 自学习销售智能 / Self-learning intelligence | - | 支持 |
| CRM 知识图谱 / CRM knowledge graph | - | 支持 |
| AI 胜率预测 / AI win prediction | - | 支持 |
| 批量导入 / Bulk import | - | 支持 |
---
## 快速开始 / Quick Start
### 1. 安装 / Install
在 ClawHub 中搜索 `deal-closer`,点击安装,或使用命令行:
```bash
openclaw skill install deal-closer
```
### 2. 添加商机 / Add a Deal
```bash
/deal-closer add --data '{"name":"企业ERP升级项目","contact_name":"张经理","company":"科技有限公司","amount":"50万","stage":"初步接触"}'
```
### 3. 查看管道 / View Pipeline
```bash
/deal-closer funnel
```
### 4. 起草跟进邮件 / Draft Follow-up (Paid)
```bash
/deal-closer draft --data '{"deal_id":"D20260319...","sender_name":"销售经理"}'
```
### 5. 扫描邮件 / Scan Emails (Paid)
```bash
# 配置邮箱凭据
export DC_GMAIL_CREDENTIALS="/path/to/credentials.json"
/deal-closer scan
```
---
## 使用示例 / Example
```
用户:帮我添加一个商机,名称是"智慧园区项目",联系人王总,公司是鑫科技,金额 200 万,阶段是需求确认。
助手:商机「智慧园区项目」已添加!
| 字段 | 值 |
|------|-----|
| ID | D20260319143022... |
| 名称 | 智慧园区项目 |
| 联系人 | 王总 |
| 公司 | 鑫科技 |
| 金额 | ¥200.00万 |
| 阶段 | 需求确认 |
| 成交概率 | 25% |
建议:需求确认阶段建议尽快安排面谈,深入了解客户具体需求后提交方案。
```
---
## 数据存储 / Data Storage
所有数据以 JSON 格式存储在本地:
```
~/.openclaw-bdi/deal-closer/
deals.json # 商机数据
emails.json # 邮件记录
meetings.json # 会议记录
followups.json # 跟进计划
learning.json # 自学习数据(成交模式、预测模型)
crm_graph.json # CRM 知识图谱
email_config.json # IMAP/SMTP 连接配置(不含密码)
```
可通过 `DC_DATA_DIR` 环境变量自定义存储路径。
---
## 常见问题 / FAQ
### Q1: 免费版有哪些限制?
免费版支持最多 30 个商机管理、基础漏斗报告和 CSV 导出。邮件扫描、AI 跟进、高级分析等功能需付费版。
### Q2: 数据会上传到云端吗?
不会。所有数据存储在本地 JSON 文件中,不会离开你的运行环境。
### Q3: 如何配置邮件扫描?
详见 `references/email-setup-guide.md`,需完成 Gmail OAuth2 或 Outlook Azure AD 配置;也可通过 IMAP/SMTP 直连(设置 `DC_IMAP_HOST`、`DC_EMAIL_USER`、`DC_EMAIL_PASSWORD` 等环境变量),无需 OAuth2。
### Q4: 支持哪些邮件类型的信号识别?
支持中英文关键词匹配,自动识别积极(同意、感兴趣等)、消极(推迟、竞争对手等)和中性(咨询、了解等)信号。
### Q5: Mermaid 图表在哪里可以渲染?
GitHub / GitLab Markdown 预览、VS Code(Mermaid 插件)、Typora、Obsidian 等工具均支持。
### Q6: 如何从其他 CRM 迁移数据?
准备标准 CSV 文件(支持中英文列名),使用导入功能即可批量迁移。
---
## 技术支持 / Support
- 文档 / Docs:查看 `references/` 目录
- 问题反馈 / Issues:在 ClawHub Skill 页面提交
- 社区 / Community:`#deal-closer` 频道
- 邮件 / Email:[email protected]
---
*deal-closer v1.1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/followup_drafter.py
#!/usr/bin/env python3
"""
deal-closer AI 跟进邮件起草模块(付费功能)
根据商机历史和最近交互,生成上下文相关的跟进邮件草稿。
支持多种模板、定时提醒和待办列表。
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
format_currency,
calculate_days_since,
days_until,
mask_email,
DEAL_STAGES,
FOLLOWUP_TEMPLATES,
)
# 延迟导入 IMAP 和学习模块
_imap_module = None
_learning_module = None
def _get_imap_module():
"""延迟加载 imap_email 模块。"""
global _imap_module
if _imap_module is None:
try:
import imap_email as _mod
_imap_module = _mod
except ImportError:
_imap_module = False
return _imap_module if _imap_module is not False else None
def _get_learning_module():
"""延迟加载 learning_engine 模块。"""
global _learning_module
if _learning_module is None:
try:
import learning_engine as _mod
_learning_module = _mod
except ImportError:
_learning_module = False
return _learning_module if _learning_module is not False else None
# ============================================================
# 数据文件
# ============================================================
FOLLOWUPS_FILE = "followups.json"
DEALS_FILE = "deals.json"
MEETINGS_FILE = "meetings.json"
EMAILS_FILE = "emails.json"
def _get_followups() -> List[Dict[str, Any]]:
"""读取所有跟进任务。"""
return read_json_file(get_data_file(FOLLOWUPS_FILE))
def _save_followups(followups: List[Dict[str, Any]]) -> None:
"""保存跟进任务到文件。"""
write_json_file(get_data_file(FOLLOWUPS_FILE), followups)
def _get_deals() -> List[Dict[str, Any]]:
"""读取所有商机数据。"""
return read_json_file(get_data_file(DEALS_FILE))
def _get_meetings() -> List[Dict[str, Any]]:
"""读取所有会议记录。"""
return read_json_file(get_data_file(MEETINGS_FILE))
def _get_emails() -> List[Dict[str, Any]]:
"""读取所有邮件记录。"""
return read_json_file(get_data_file(EMAILS_FILE))
# ============================================================
# 邮件模板
# ============================================================
_TEMPLATES: Dict[str, Dict[str, Any]] = {
"introduction": {
"name": "初次介绍",
"description": "首次接触客户,介绍公司和产品",
"subject_template": "您好,{contact_name} — 关于{company_or_product}的合作机会",
"body_template": (
"{contact_name}您好,\n\n"
"感谢您对我们的关注。我是{sender_name},负责{product_area}业务。\n\n"
"了解到贵公司在{industry_or_need}方面有需求,我们的解决方案在此领域有丰富的经验"
"和成功案例。\n\n"
"希望能与您进一步沟通,了解贵公司的具体需求,为您提供针对性的方案。\n\n"
"请问您本周是否方便安排一次简短的交流?\n\n"
"期待您的回复。\n\n"
"此致\n{sender_name}"
),
},
"proposal_followup": {
"name": "方案跟进",
"description": "发送方案后的跟进",
"subject_template": "关于{deal_name}方案的跟进",
"body_template": (
"{contact_name}您好,\n\n"
"上次为您发送了{deal_name}的详细方案,不知您是否有时间查阅?\n\n"
"如有任何疑问或需要调整的地方,我很乐意为您详细解答。\n\n"
"该方案的核心优势包括:\n"
"- 针对贵公司需求的定制化设计\n"
"- 具有竞争力的价格方案\n"
"- 完善的售后服务体系\n\n"
"期待您的反馈,我们可以进一步讨论细节。\n\n"
"此致\n{sender_name}"
),
},
"negotiation": {
"name": "商务谈判",
"description": "商务谈判阶段的跟进",
"subject_template": "关于{deal_name}合作条款的确认",
"body_template": (
"{contact_name}您好,\n\n"
"感谢您在上次会谈中的深入交流。根据讨论的结果,我已整理了以下要点:\n\n"
"{meeting_summary}\n\n"
"关于价格和交付条款,我们愿意在以下方面做出灵活安排,以促成双方的合作。\n\n"
"如您方便,我们可以安排下一次沟通,具体讨论合同细节。\n\n"
"此致\n{sender_name}"
),
},
"closing": {
"name": "促成签约",
"description": "推进签约的跟进",
"subject_template": "关于{deal_name}合同签署事宜",
"body_template": (
"{contact_name}您好,\n\n"
"经过前期的充分沟通,我们对双方的合作充满信心。\n\n"
"附件中是最终版合同,已根据您上次提出的意见做了调整。"
"主要变更包括:\n"
"{contract_changes}\n\n"
"如无异议,烦请您签署后返回,我们将在收到合同后立即启动项目。\n\n"
"如有任何问题,请随时联系我。\n\n"
"此致\n{sender_name}"
),
},
"win_back": {
"name": "赢回客户",
"description": "针对流失客户的重新激活",
"subject_template": "{contact_name},我们有了新的方案想与您分享",
"body_template": (
"{contact_name}您好,\n\n"
"距离我们上次沟通已经有一段时间了。\n\n"
"我们最近对产品进行了重大升级,新增了以下功能:\n"
"{new_features}\n\n"
"这些改进正好可以解决您之前提到的顾虑。\n\n"
"如果您有兴趣了解更多,我非常乐意为您安排一次演示。"
"另外,针对老客户我们也准备了特别优惠方案。\n\n"
"期待与您重新建立联系。\n\n"
"此致\n{sender_name}"
),
},
}
# ============================================================
# 操作函数
# ============================================================
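def draft_followup(data: Dict[str, Any]) -> None:
"""Draft a follow-up email for a deal.
Required fields: deal_id
Optional fields: template (template key; chosen from the deal stage when omitted),
sender_name, custom_notes, contract_changes, new_features
Args:
data: Parameter dict.
"""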
if not require_paid_feature("ai_followup", "AI跟进邮件"):
return
deal_id = data.get("deal_id")
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
# 加载商机
deals = _get_deals()
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
if not target_deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
# 确定模板
template_key = data.get("template", "")
if not template_key:
# 根据商机阶段自动选择模板
stage = target_deal.get("stage", "")
stage_template_map = {
"线索": "introduction",
"初步接触": "introduction",
"需求确认": "proposal_followup",
"方案报价": "proposal_followup",
"商务谈判": "negotiation",
"合同签署": "closing",
"流失": "win_back",
}
template_key = stage_template_map.get(stage, "proposal_followup")
if template_key not in _TEMPLATES:
valid = "、".join(_TEMPLATES.keys())
output_error(f"未知模板: {template_key},可用模板: {valid}", code="VALIDATION_ERROR")
return
template = _TEMPLATES[template_key]
sender_name = data.get("sender_name", "销售顾问")
# 获取最近会议摘要
meetings = _get_meetings()
deal_meetings = [m for m in meetings if m.get("deal_id") == deal_id]
deal_meetings.sort(key=lambda m: m.get("date", ""), reverse=True)
meeting_summary = "(暂无会议记录)"
if deal_meetings:
latest = deal_meetings[0]
notes = latest.get("notes", "")
action_items = latest.get("action_items", [])
if notes or action_items:
parts = []
if notes:
parts.append(f"会议纪要: {notes}")
if action_items:
parts.append("行动项: " + ";".join(action_items))
meeting_summary = "\n".join(parts)
# 获取最近邮件
emails = _get_emails()
deal_emails = [e for e in emails if e.get("linked_deal_id") == deal_id]
deal_emails.sort(key=lambda e: e.get("date", ""), reverse=True)
last_email_info = ""
if deal_emails:
latest_email = deal_emails[0]
last_email_info = f"最近邮件主题: {latest_email.get('subject', '')}"
# 填充模板变量
variables = {
# Use `or` so that empty strings also fall back to the defaults
"contact_name": target_deal.get("contact_name") or "客户",
"company_or_product": target_deal.get("company", "") or target_deal.get("name", ""),
"deal_name": target_deal.get("name", ""),
"sender_name": sender_name,
"product_area": "解决方案",
"industry_or_need": target_deal.get("notes") or "相关领域",
"meeting_summary": meeting_summary,
"contract_changes": data.get("contract_changes", "- 根据双方协商结果调整了相关条款"),
"new_features": data.get("new_features", "- 性能提升\n- 新增定制化功能\n- 优化用户体验"),
}
subject = template["subject_template"]
body = template["body_template"]
for key, value in variables.items():
placeholder = "{" + key + "}"
subject = subject.replace(placeholder, str(value))
body = body.replace(placeholder, str(value))
# 添加上下文信息
context_info = []
context_info.append(f"商机阶段: {target_deal.get('stage', '')}")
context_info.append(f"商机金额: {format_currency(target_deal.get('amount', 0))}")
if target_deal.get("expected_close_date"):
context_info.append(f"预计成交: {target_deal['expected_close_date']}")
if last_email_info:
context_info.append(last_email_info)
if deal_meetings:
context_info.append(f"最近会议: {deal_meetings[0].get('date', '')}")
custom_notes = data.get("custom_notes", "")
output_success({
"template": template_key,
"template_name": template["name"],
"subject": subject,
"body": body,
"context": context_info,
"custom_notes": custom_notes,
"deal_id": deal_id,
"deal_name": target_deal.get("name", ""),
"contact_email": mask_email(target_deal.get("contact_email", "")),
})
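# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# draft_followup fills templates with plain str.replace rather than
# str.format, so a placeholder with no matching variable is left in the
# text instead of raising KeyError. The helper name below is hypothetical
# and only demonstrates that behavior in isolation.
# ------------------------------------------------------------
def _example_fill_template(template: str, variables: dict) -> str:
    """Replace each "{key}" with str(value); unknown placeholders stay as-is."""
    result = template
    for key, value in variables.items():
        result = result.replace("{" + key + "}", str(value))
    return result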
def list_templates(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有可用的跟进邮件模板。
Args:
data: 未使用,保留接口一致性。
"""
templates = []
for key, tmpl in _TEMPLATES.items():
templates.append({
"key": key,
"name": tmpl["name"],
"description": tmpl["description"],
"subject_preview": tmpl["subject_template"],
})
output_success({
"total": len(templates),
"templates": templates,
})
def schedule_followup(data: Dict[str, Any]) -> None:
"""创建跟进提醒计划。
必填字段: deal_id, scheduled_date
可选字段: template, notes, priority
Args:
data: 参数字典。
"""
if not require_paid_feature("ai_followup", "跟进计划"):
return
deal_id = data.get("deal_id")
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
scheduled_date = data.get("scheduled_date")
if not scheduled_date:
output_error("计划日期(scheduled_date)为必填字段", code="VALIDATION_ERROR")
return
# 验证商机存在
deals = _get_deals()
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
if not target_deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
# 优先级校验
priority = data.get("priority", "normal")
if priority not in ("low", "normal", "high", "urgent"):
priority = "normal"
template_key = data.get("template", "")
if template_key and template_key not in _TEMPLATES:
valid = "、".join(_TEMPLATES.keys())
output_error(f"未知模板: {template_key},可用模板: {valid}", code="VALIDATION_ERROR")
return
followup = {
"id": generate_id("F"),
"deal_id": deal_id,
"deal_name": target_deal.get("name", ""),
"scheduled_date": scheduled_date,
"template": template_key,
"notes": data.get("notes", ""),
"priority": priority,
"status": "pending",
"created_at": now_iso(),
}
followups = _get_followups()
followups.append(followup)
_save_followups(followups)
output_success({
"message": f"跟进计划已创建({scheduled_date},商机: {target_deal.get('name', '')})",
"followup": followup,
})
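def list_pending(data: Optional[Dict[str, Any]] = None) -> None:
"""List follow-up tasks, most urgent first.
Overdue tasks come first, then tasks are ordered by priority and scheduled date.
Args:
data: Optional filters: deal_id and status. With no filters, only pending
tasks are listed; when filters are given without status, tasks of every
status are included.
"""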
if not require_paid_feature("ai_followup", "跟进列表"):
return
followups = _get_followups()
if data:
# 按商机过滤
deal_id = data.get("deal_id")
if deal_id:
followups = [f for f in followups if f.get("deal_id") == deal_id]
# 按状态过滤
status = data.get("status")
if status:
followups = [f for f in followups if f.get("status") == status]
else:
# 默认只显示待处理
followups = [f for f in followups if f.get("status") == "pending"]
# 优先级权重
priority_weight = {
"urgent": 0,
"high": 1,
"normal": 2,
"low": 3,
}
today = today_str()
# 计算紧急度并排序
for f in followups:
f["_priority_weight"] = priority_weight.get(f.get("priority", "normal"), 2)
scheduled = f.get("scheduled_date", "")
if scheduled:
f["days_remaining"] = days_until(scheduled)
f["is_overdue"] = scheduled < today
else:
f["days_remaining"] = 999
f["is_overdue"] = False
# 先按是否逾期、再按优先级、最后按日期排序
followups.sort(key=lambda f: (
not f.get("is_overdue", False),
f.get("_priority_weight", 2),
f.get("scheduled_date", ""),
))
# 清理临时字段
display_list = []
for f in followups:
display = dict(f)
display.pop("_priority_weight", None)
display_list.append(display)
# 统计
overdue = sum(1 for f in display_list if f.get("is_overdue"))
today_count = sum(1 for f in display_list if f.get("scheduled_date") == today)
output_success({
"total": len(display_list),
"overdue": overdue,
"today": today_count,
"followups": display_list,
})
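# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# The ordering used by list_pending can be expressed as a single tuple
# sort key: overdue first, then priority weight, then scheduled date.
# This assumes dates are ISO "YYYY-MM-DD" strings, so string comparison
# matches chronological order. The helper name is hypothetical.
# ------------------------------------------------------------
def _example_sort_followups(items, today):
    """Return items sorted by (overdue first, priority, scheduled date)."""
    weight = {"urgent": 0, "high": 1, "normal": 2, "low": 3}

    def sort_key(item):
        scheduled = item.get("scheduled_date", "")
        overdue = bool(scheduled) and scheduled < today
        # False sorts before True, so "not overdue" puts overdue items first
        return (not overdue, weight.get(item.get("priority", "normal"), 2), scheduled)

    return sorted(items, key=sort_key)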
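def send_draft(data: Dict[str, Any]) -> None:
"""Send a drafted follow-up email through the imap_email module's send_email.
Required fields: deal_id, subject, body
Optional fields: to (defaults to the deal's contact_email), cc,
template (recorded for the learning engine)
Args:
data: Parameter dict.
"""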
if not require_paid_feature("ai_followup", "SMTP邮件发送"):
return
imap_mod = _get_imap_module()
if imap_mod is None:
output_error(
"IMAP/SMTP 模块未加载,请确认 imap_email.py 存在",
code="MODULE_ERROR",
)
return
deal_id = data.get("deal_id")
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
# 获取商机联系人邮箱
deals = _get_deals()
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
if not target_deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
to_addr = data.get("to", target_deal.get("contact_email", ""))
if not to_addr:
output_error("未找到收件人邮箱,请提供 to 字段或确保商机有联系人邮箱", code="VALIDATION_ERROR")
return
subject = data.get("subject", "")
body = data.get("body", "")
if not subject or not body:
output_error("subject 和 body 为必填字段", code="VALIDATION_ERROR")
return
# 调用 imap_email 的发送功能
send_data = {"to": to_addr, "subject": subject, "body": body}
cc = data.get("cc", "")
if cc:
send_data["cc"] = cc
imap_mod.send_email(send_data)
# 记录发送到学习引擎
_record_followup_sent(deal_id, data.get("template", ""), target_deal)
def _record_followup_sent(deal_id: str, template: str, deal: Dict[str, Any]) -> None:
"""记录跟进邮件发送到学习引擎。
Args:
deal_id: 商机ID。
template: 使用的模板。
deal: 商机数据。
"""
learning_mod = _get_learning_module()
if learning_mod is None:
return
try:
learning_data = learning_mod._get_learning_data()
patterns = learning_data.get("patterns", [])
pattern = {
"id": generate_id("LP"),
"category": "followup",
"description": (
f"通过SMTP发送跟进邮件,商机「{deal.get('name', '')}」,"
f"阶段: {deal.get('stage', '')},模板: {template or '自定义'}"
),
"success_rate": 0.5, # 初始成功率,后续根据回复情况更新
"applicable_stages": [deal.get("stage", "")],
"notes": f"deal_id: {deal_id}",
"recorded_at": now_iso(),
}
patterns.append(pattern)
learning_data["patterns"] = patterns
learning_mod._save_learning_data(learning_data)
except Exception:
pass
def auto_draft(data: Optional[Dict[str, Any]] = None) -> None:
"""主动为停滞商机起草跟进邮件。
自动识别超过指定天数未更新的活跃商机,并生成跟进邮件草稿。
可选字段: stale_days(停滞天数阈值,默认 7)、max_drafts(最大草稿数,默认 5)
Args:
data: 可选参数。
"""
if not require_paid_feature("ai_followup", "自动跟进起草"):
return
data = data or {}
stale_days = int(data.get("stale_days", 7))
max_drafts = int(data.get("max_drafts", 5))
deals = _get_deals()
active_deals = [
d for d in deals
if d.get("stage") not in ("成交", "流失")
]
if not active_deals:
output_error("暂无活跃商机", code="NO_DATA")
return
# 筛选停滞商机
stale_deals = []
for deal in active_deals:
updated = deal.get("updated_at", "")
if updated:
days = calculate_days_since(updated)
if days >= stale_days:
stale_deals.append((deal, days))
if not stale_deals:
output_success({
"message": f"没有超过 {stale_days} 天未更新的商机,管道状态良好",
"drafts": [],
"total": 0,
})
return
# 按停滞时间排序
stale_deals.sort(key=lambda x: x[1], reverse=True)
stale_deals = stale_deals[:max_drafts]
drafts = []
for deal, days_stale in stale_deals:
stage = deal.get("stage", "")
# 根据阶段选择模板
stage_template_map = {
"线索": "introduction",
"初步接触": "introduction",
"需求确认": "proposal_followup",
"方案报价": "proposal_followup",
"商务谈判": "negotiation",
"合同签署": "closing",
}
template_key = stage_template_map.get(stage, "proposal_followup")
if template_key not in _TEMPLATES:
template_key = "proposal_followup"
template = _TEMPLATES[template_key]
# 简单填充模板
variables = {
"contact_name": deal.get("contact_name", "客户"),
"company_or_product": deal.get("company", "") or deal.get("name", ""),
"deal_name": deal.get("name", ""),
"sender_name": "销售顾问",
"product_area": "解决方案",
"industry_or_need": deal.get("notes", "相关领域"),
"meeting_summary": "(自动生成草稿)",
"contract_changes": "- 根据协商结果调整",
"new_features": "- 产品更新内容",
}
subject = template["subject_template"]
body = template["body_template"]
for key, value in variables.items():
placeholder = "{" + key + "}"
subject = subject.replace(placeholder, str(value))
body = body.replace(placeholder, str(value))
drafts.append({
"deal_id": deal.get("id", ""),
"deal_name": deal.get("name", ""),
"stage": stage,
"days_stale": days_stale,
"template": template_key,
"subject": subject,
"body": body,
"contact_email": mask_email(deal.get("contact_email", "")),
"reason": f"已 {days_stale} 天未更新",
})
output_success({
"message": f"已为 {len(drafts)} 个停滞商机生成跟进草稿",
"total": len(drafts),
"stale_threshold_days": stale_days,
"drafts": drafts,
})
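# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# auto_draft relies on calculate_days_since from utils, whose
# implementation is not shown here. The hypothetical helper below shows
# the same selection logic with the standard library, assuming
# updated_at is an ISO 8601 string that starts with "YYYY-MM-DD".
# ------------------------------------------------------------
from datetime import date


def _example_pick_stale(deals, today, stale_days=7, max_drafts=5):
    """Return up to max_drafts (deal, days_stale) pairs, stalest first."""
    stale = []
    for deal in deals:
        updated = deal.get("updated_at", "")
        if not updated:
            continue  # deals without a timestamp are skipped, as in auto_draft
        days = (today - date.fromisoformat(updated[:10])).days
        if days >= stale_days:
            stale.append((deal, days))
    stale.sort(key=lambda pair: pair[1], reverse=True)
    return stale[:max_drafts]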
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer AI跟进邮件")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"draft": lambda: draft_followup(data or {}),
"templates": lambda: list_templates(data),
"schedule": lambda: schedule_followup(data or {}),
"list-pending": lambda: list_pending(data),
"send": lambda: send_draft(data or {}),
"auto-draft": lambda: auto_draft(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/crm_graph.py
#!/usr/bin/env python3
"""
deal-closer CRM 知识图谱模块
构建客户关系网络图谱,支持实体管理、关系查询、组织架构映射、
影响力链路追踪和 Mermaid 可视化。
基于 ontology 理念,将CRM数据结构化为知识图谱。
"""
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
)
# ============================================================
# 常量与配置
# ============================================================
GRAPH_FILE = "crm_graph.json"
DEALS_FILE = "deals.json"
# 实体类型
ENTITY_TYPES = ["Person", "Company", "Deal", "Meeting", "Email"]
# 关系类型及其描述
RELATION_TYPES = {
"works_at": "就职于",
"reports_to": "汇报给",
"knows": "认识",
"decision_maker_for": "是决策人",
"competitor_of": "竞争对手",
"partner_of": "合作伙伴",
"referred_by": "由...推荐",
"participated_in": "参与了",
"related_to": "关联于",
"contact_of": "是联系人",
}
# 默认图谱数据结构
_DEFAULT_GRAPH = {
"entities": [],
"relations": [],
"version": "1.0.0",
"last_updated": "",
}
# ============================================================
# 数据操作
# ============================================================
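def _get_graph() -> Dict[str, Any]:
    """Load the graph data, returning a fresh default structure if none is stored."""
    filepath = get_data_file(GRAPH_FILE)
    if not os.path.exists(filepath):
        # dict(_DEFAULT_GRAPH) would share the module-level lists, so build new ones
        return {**_DEFAULT_GRAPH, "entities": [], "relations": []}
    data = read_json_file(filepath)
    if isinstance(data, list):
        # A bare list is not a valid graph; fall back to a fresh default structure
        return {**_DEFAULT_GRAPH, "entities": [], "relations": []}
    return data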
def _save_graph(graph: Dict[str, Any]) -> None:
"""保存图谱数据。"""
graph["last_updated"] = now_iso()
write_json_file(get_data_file(GRAPH_FILE), graph)
def _get_deals() -> List[Dict[str, Any]]:
"""读取所有商机数据。"""
return read_json_file(get_data_file(DEALS_FILE))
def _find_entity(entities: List[Dict], entity_id: str) -> Optional[Dict]:
"""根据ID查找实体。"""
for e in entities:
if e.get("id") == entity_id:
return e
return None
def _find_entity_by_name(entities: List[Dict], name: str,
entity_type: str = "") -> Optional[Dict]:
"""根据名称查找实体。"""
name_lower = name.lower().strip()
for e in entities:
if e.get("name", "").lower().strip() == name_lower:
if not entity_type or e.get("type") == entity_type:
return e
return None
def _get_entity_relations(relations: List[Dict],
entity_id: str) -> List[Dict]:
"""获取与实体相关的所有关系。"""
result = []
for r in relations:
if r.get("from_id") == entity_id or r.get("to_id") == entity_id:
result.append(r)
return result
# ============================================================
# 自动填充
# ============================================================
def _auto_populate_from_deals(graph: Dict[str, Any]) -> int:
"""从商机数据自动填充图谱实体和关系。
Args:
graph: 图谱数据。
Returns:
新增实体和关系的总数。
"""
deals = _get_deals()
entities = graph.get("entities", [])
relations = graph.get("relations", [])
added = 0
# 已有实体名称索引
existing_names = {
(e.get("name", "").lower(), e.get("type", ""))
for e in entities
}
for deal in deals:
deal_id = deal.get("id", "")
deal_name = deal.get("name", "")
# 添加商机实体
if (deal_name.lower(), "Deal") not in existing_names and deal_name:
entity = {
"id": f"GE_{deal_id}",
"type": "Deal",
"name": deal_name,
"properties": {
"deal_id": deal_id,
"amount": deal.get("amount", 0),
"stage": deal.get("stage", ""),
},
"created_at": now_iso(),
}
entities.append(entity)
existing_names.add((deal_name.lower(), "Deal"))
added += 1
# 添加联系人实体
contact_name = deal.get("contact_name", "")
contact_entity_id = None
if contact_name and (contact_name.lower(), "Person") not in existing_names:
contact_entity_id = generate_id("GE")
entity = {
"id": contact_entity_id,
"type": "Person",
"name": contact_name,
"properties": {
"email": deal.get("contact_email", ""),
"phone": deal.get("contact_phone", ""),
},
"created_at": now_iso(),
}
entities.append(entity)
existing_names.add((contact_name.lower(), "Person"))
added += 1
elif contact_name:
# 查找已有实体
existing = _find_entity_by_name(entities, contact_name, "Person")
if existing:
contact_entity_id = existing.get("id")
# 添加公司实体
company = deal.get("company", "")
company_entity_id = None
if company and (company.lower(), "Company") not in existing_names:
company_entity_id = generate_id("GE")
entity = {
"id": company_entity_id,
"type": "Company",
"name": company,
"properties": {},
"created_at": now_iso(),
}
entities.append(entity)
existing_names.add((company.lower(), "Company"))
added += 1
elif company:
existing = _find_entity_by_name(entities, company, "Company")
if existing:
company_entity_id = existing.get("id")
# 建立关系
deal_entity = _find_entity_by_name(entities, deal_name, "Deal")
deal_entity_id = deal_entity.get("id") if deal_entity else None
# 联系人 -> 商机 关系
if contact_entity_id and deal_entity_id:
rel_key = (contact_entity_id, deal_entity_id, "contact_of")
existing_rels = {
(r.get("from_id"), r.get("to_id"), r.get("relation"))
for r in relations
}
if rel_key not in existing_rels:
relations.append({
"id": generate_id("GR"),
"from_id": contact_entity_id,
"to_id": deal_entity_id,
"relation": "contact_of",
"properties": {},
"created_at": now_iso(),
})
added += 1
# 联系人 -> 公司 关系
if contact_entity_id and company_entity_id:
rel_key = (contact_entity_id, company_entity_id, "works_at")
existing_rels = {
(r.get("from_id"), r.get("to_id"), r.get("relation"))
for r in relations
}
if rel_key not in existing_rels:
relations.append({
"id": generate_id("GR"),
"from_id": contact_entity_id,
"to_id": company_entity_id,
"relation": "works_at",
"properties": {},
"created_at": now_iso(),
})
added += 1
graph["entities"] = entities
graph["relations"] = relations
return added
# ============================================================
# BFS 查询
# ============================================================
def _bfs_query(graph: Dict[str, Any], start_id: str,
max_depth: int = 3) -> Dict[str, Any]:
"""广度优先搜索查询相关实体。
Args:
graph: 图谱数据。
start_id: 起始实体ID。
max_depth: 最大搜索深度。
Returns:
搜索结果,包含实体和关系。
"""
entities = graph.get("entities", [])
relations = graph.get("relations", [])
# 构建邻接表
adjacency: Dict[str, List[Tuple[str, Dict]]] = {}
for r in relations:
from_id = r.get("from_id", "")
to_id = r.get("to_id", "")
if from_id not in adjacency:
adjacency[from_id] = []
if to_id not in adjacency:
adjacency[to_id] = []
adjacency[from_id].append((to_id, r))
adjacency[to_id].append((from_id, r))
visited: Set[str] = set()
result_entities = []
result_relations = []
queue: deque = deque()
queue.append((start_id, 0))
visited.add(start_id)
while queue:
current_id, depth = queue.popleft()
# 添加当前实体
entity = _find_entity(entities, current_id)
if entity:
result_entities.append({
**entity,
"depth": depth,
})
if depth >= max_depth:
continue
# 遍历邻居
for neighbor_id, relation in adjacency.get(current_id, []):
if relation not in result_relations:
result_relations.append(relation)
if neighbor_id not in visited:
visited.add(neighbor_id)
queue.append((neighbor_id, depth + 1))
return {
"entities": result_entities,
"relations": result_relations,
}
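# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# _bfs_query treats relations as undirected edges and stops expanding a
# node once it sits at max_depth. The hypothetical helper below shows the
# same traversal on a plain edge list, returning each reachable node's
# depth.
# ------------------------------------------------------------
from collections import deque


def _example_bfs_depths(edges, start, max_depth=3):
    """Return {node: depth} for nodes within max_depth hops of start."""
    adjacency = {}
    for a, b in edges:
        adjacency.setdefault(a, []).append(b)
        adjacency.setdefault(b, []).append(a)
    depths = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if depths[node] >= max_depth:
            continue  # keep the node, but do not expand past the depth limit
        for neighbor in adjacency.get(node, []):
            if neighbor not in depths:
                depths[neighbor] = depths[node] + 1
                queue.append(neighbor)
    return depths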
# ============================================================
# 操作函数
# ============================================================
def add_entity(data: Dict[str, Any]) -> None:
"""添加实体到知识图谱。
必填字段: type, name
可选字段: properties
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "CRM知识图谱"):
return
entity_type = data.get("type", "")
name = data.get("name", "")
if not entity_type:
output_error("实体类型(type)为必填字段", code="VALIDATION_ERROR")
return
if entity_type not in ENTITY_TYPES:
output_error(
f"无效实体类型: {entity_type},有效类型: {', '.join(ENTITY_TYPES)}",
code="VALIDATION_ERROR",
)
return
if not name:
output_error("实体名称(name)为必填字段", code="VALIDATION_ERROR")
return
graph = _get_graph()
entities = graph.get("entities", [])
# 检查重复
existing = _find_entity_by_name(entities, name, entity_type)
if existing:
output_error(
f"已存在同名{entity_type}实体:{name}(ID: {existing.get('id')})",
code="DUPLICATE",
)
return
entity = {
"id": generate_id("GE"),
"type": entity_type,
"name": name,
"properties": data.get("properties", {}),
"created_at": now_iso(),
}
entities.append(entity)
graph["entities"] = entities
_save_graph(graph)
output_success({
"message": f"已添加{entity_type}实体「{name}」",
"entity": entity,
})
def add_relation(data: Dict[str, Any]) -> None:
"""添加关系到知识图谱。
必填字段: from_id(或 from_name), to_id(或 to_name), relation
可选字段: properties
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "CRM知识图谱"):
return
relation_type = data.get("relation", "")
if not relation_type:
output_error("关系类型(relation)为必填字段", code="VALIDATION_ERROR")
return
if relation_type not in RELATION_TYPES:
valid = ", ".join(f"{k}({v})" for k, v in RELATION_TYPES.items())
output_error(
f"无效关系类型: {relation_type},有效类型: {valid}",
code="VALIDATION_ERROR",
)
return
graph = _get_graph()
entities = graph.get("entities", [])
relations = graph.get("relations", [])
# 解析来源实体
from_id = data.get("from_id", "")
if not from_id and data.get("from_name"):
entity = _find_entity_by_name(entities, data["from_name"])
if entity:
from_id = entity.get("id", "")
# 解析目标实体
to_id = data.get("to_id", "")
if not to_id and data.get("to_name"):
entity = _find_entity_by_name(entities, data["to_name"])
if entity:
to_id = entity.get("id", "")
if not from_id:
output_error("来源实体(from_id 或 from_name)未找到", code="NOT_FOUND")
return
if not to_id:
output_error("目标实体(to_id 或 to_name)未找到", code="NOT_FOUND")
return
# 验证实体存在
from_entity = _find_entity(entities, from_id)
to_entity = _find_entity(entities, to_id)
if not from_entity:
output_error(f"未找到ID为 {from_id} 的实体", code="NOT_FOUND")
return
if not to_entity:
output_error(f"未找到ID为 {to_id} 的实体", code="NOT_FOUND")
return
# 检查重复关系
for r in relations:
if (r.get("from_id") == from_id
and r.get("to_id") == to_id
and r.get("relation") == relation_type):
output_error("该关系已存在", code="DUPLICATE")
return
relation = {
"id": generate_id("GR"),
"from_id": from_id,
"to_id": to_id,
"relation": relation_type,
"relation_display": RELATION_TYPES.get(relation_type, relation_type),
"properties": data.get("properties", {}),
"created_at": now_iso(),
}
relations.append(relation)
graph["relations"] = relations
_save_graph(graph)
output_success({
"message": (
f"已添加关系:{from_entity.get('name')} "
f"--[{RELATION_TYPES.get(relation_type, relation_type)}]--> "
f"{to_entity.get('name')}"
),
"relation": relation,
})
def query(data: Dict[str, Any]) -> None:
"""查询实体及其关联。
必填字段: entity_id 或 name
可选字段: max_depth(默认 3)
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "CRM知识图谱查询"):
return
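graph = _get_graph()
# Sync from deal data first, then read the entity list, so the lookup below
# always sees the list that _auto_populate_from_deals stored on the graph
_auto_populate_from_deals(graph)
_save_graph(graph)
entities = graph.get("entities", [])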
entity_id = data.get("entity_id", "")
if not entity_id and data.get("name"):
entity = _find_entity_by_name(entities, data["name"])
if entity:
entity_id = entity.get("id", "")
if not entity_id:
output_error("请提供 entity_id 或 name", code="VALIDATION_ERROR")
return
start_entity = _find_entity(entities, entity_id)
if not start_entity:
output_error(f"未找到ID为 {entity_id} 的实体", code="NOT_FOUND")
return
max_depth = int(data.get("max_depth", 3))
result = _bfs_query(graph, entity_id, max_depth)
output_success({
"start_entity": start_entity,
"related_entities": result["entities"],
"relations": result["relations"],
"entity_count": len(result["entities"]),
"relation_count": len(result["relations"]),
})
def company_map(data: Dict[str, Any]) -> None:
    """Build an organization map for a company.

    Required fields: company (company name) or company_id

    Args:
        data: Parameter dict.
    """
    if not require_paid_feature("advanced_analytics", "CRM组织架构图"):
        return
    graph = _get_graph()
    # Sync entities and relations from deal data before querying
    _auto_populate_from_deals(graph)
    _save_graph(graph)
    entities = graph.get("entities", [])
    relations = graph.get("relations", [])

    # Locate the company entity
    company_name = data.get("company", "")
    company_id = data.get("company_id", "")
    company_entity = None
    if company_id:
        company_entity = _find_entity(entities, company_id)
    elif company_name:
        company_entity = _find_entity_by_name(entities, company_name, "Company")
    if not company_entity:
        output_error("未找到指定公司", code="NOT_FOUND")
        return
    comp_id = company_entity.get("id", "")

    # Collect everyone linked to the company by a works_at relation (either direction)
    contacts = []
    for r in relations:
        if r.get("relation") != "works_at":
            continue
        if r.get("to_id") == comp_id:
            person = _find_entity(entities, r.get("from_id", ""))
        elif r.get("from_id") == comp_id:
            person = _find_entity(entities, r.get("to_id", ""))
        else:
            continue
        if person:
            contacts.append({"person": person, "relation": r.get("relation", "")})
    person_ids = {c["person"].get("id") for c in contacts}

    # Decision makers, restricted to this company's contacts
    # (previously every decision_maker_for relation in the graph was counted)
    decision_makers = []
    for r in relations:
        if r.get("relation") == "decision_maker_for" and r.get("from_id") in person_ids:
            person = _find_entity(entities, r.get("from_id", ""))
            if person:
                decision_makers.append(person.get("name", ""))

    # Reporting lines that involve at least one of this company's contacts
    reporting = []
    for r in relations:
        if r.get("relation") != "reports_to":
            continue
        if r.get("from_id") in person_ids or r.get("to_id") in person_ids:
            from_e = _find_entity(entities, r.get("from_id", ""))
            to_e = _find_entity(entities, r.get("to_id", ""))
            reporting.append({
                "from": from_e.get("name", "") if from_e else "",
                "to": to_e.get("name", "") if to_e else "",
            })

    # Deals linked to these contacts through contact_of
    related_deals = []
    for r in relations:
        if r.get("relation") == "contact_of" and r.get("from_id") in person_ids:
            deal_entity = _find_entity(entities, r.get("to_id", ""))
            if deal_entity and deal_entity.get("type") == "Deal":
                related_deals.append({
                    "name": deal_entity.get("name", ""),
                    "properties": deal_entity.get("properties", {}),
                })

    output_success({
        "company": company_entity.get("name", ""),
        "company_id": comp_id,
        "contacts": [
            {
                "name": c["person"].get("name", ""),
                "id": c["person"].get("id", ""),
                "properties": c["person"].get("properties", {}),
                "is_decision_maker": c["person"].get("name", "") in decision_makers,
            }
            for c in contacts
        ],
        "decision_makers": decision_makers,
        "reporting_lines": reporting,
        "related_deals": related_deals,
        "total_contacts": len(contacts),
    })
def influence_chain(data: Dict[str, Any]) -> None:
"""追踪推荐/影响力链路。
必填字段: person_name 或 person_id
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "影响力链路"):
return
graph = _get_graph()
entities = graph.get("entities", [])
relations = graph.get("relations", [])
# 查找起始人物
person_name = data.get("person_name", "")
person_id = data.get("person_id", "")
start = None
if person_id:
start = _find_entity(entities, person_id)
elif person_name:
start = _find_entity_by_name(entities, person_name, "Person")
if not start:
output_error("未找到指定人物", code="NOT_FOUND")
return
# Trace referred_by chains in both directions, then collect the direct knows network
visited: Set[str] = set()
current_id = start.get("id", "")
# 向上追踪(谁推荐了当前人物)
upstream = []
_trace_chain(entities, relations, current_id, "referred_by",
visited, upstream, direction="upstream")
# 向下追踪(当前人物推荐了谁)
visited_down: Set[str] = set()
downstream = []
_trace_chain(entities, relations, current_id, "referred_by",
visited_down, downstream, direction="downstream")
# knows 网络
knows_network = []
for r in relations:
if r.get("relation") == "knows":
if r.get("from_id") == current_id:
target = _find_entity(entities, r.get("to_id", ""))
if target:
knows_network.append(target.get("name", ""))
elif r.get("to_id") == current_id:
target = _find_entity(entities, r.get("from_id", ""))
if target:
knows_network.append(target.get("name", ""))
output_success({
"person": start.get("name", ""),
"person_id": start.get("id", ""),
"upstream_referrals": upstream,
"downstream_referrals": downstream,
"knows_network": knows_network,
"total_connections": len(upstream) + len(downstream) + len(knows_network),
})
def _trace_chain(entities: List[Dict], relations: List[Dict],
                 current_id: str, relation_type: str,
                 visited: Set[str], chain: List[Dict],
                 direction: str = "upstream", max_depth: int = 10) -> None:
    """Recursively trace a chain of relations of one type.

    Args:
        entities: Entity list.
        relations: Relation list.
        current_id: ID of the entity to expand.
        relation_type: Relation type to follow.
        visited: IDs already expanded; updated in place.
        chain: Result list; discovered entities are appended in place.
        direction: "upstream" follows from_id -> to_id, "downstream" the reverse.
        max_depth: Recursion stops once chain holds this many entries. Note that
            this caps the number of collected entities, not the hop count.
    """
    if current_id in visited or len(chain) >= max_depth:
        return
    visited.add(current_id)
    for r in relations:
        if r.get("relation") != relation_type:
            continue
        if direction == "upstream" and r.get("from_id") == current_id:
            # The current entity was referred by to_id
            target_id = r.get("to_id", "")
        elif direction == "downstream" and r.get("to_id") == current_id:
            # The current entity referred from_id
            target_id = r.get("from_id", "")
        else:
            continue
        target = _find_entity(entities, target_id)
        if target and target_id not in visited:
            chain.append({
                "name": target.get("name", ""),
                "id": target_id,
                "type": target.get("type", ""),
            })
            _trace_chain(entities, relations, target_id,
                         relation_type, visited, chain, direction, max_depth)
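# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# _trace_chain bounds the number of collected entities rather than the
# number of hops. If a true hop limit is wanted, one option is an
# iterative walk with an explicit hop counter. The hypothetical helper
# below assumes each person has at most one referrer, given as a plain
# {person: referrer} mapping, which is simpler than the graph model above.
# ------------------------------------------------------------
def _example_trace_upstream(referred_by, start, max_depth=10):
    """Follow referrer links from start for at most max_depth hops."""
    chain = []
    visited = {start}
    current = start
    for _ in range(max_depth):
        referrer = referred_by.get(current)
        if referrer is None or referrer in visited:
            break  # end of chain, or a cycle back to a seen person
        chain.append(referrer)
        visited.add(referrer)
        current = referrer
    return chain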
def visualize(data: Optional[Dict[str, Any]] = None) -> None:
"""生成 Mermaid 可视化图谱(付费功能)。
可选字段: company(按公司过滤)、entity_id(以某实体为中心)
Args:
data: 可选参数。
"""
if not require_paid_feature("mermaid_chart", "CRM图谱可视化"):
return
graph = _get_graph()
# 自动填充
_auto_populate_from_deals(graph)
_save_graph(graph)
entities = graph.get("entities", [])
relations = graph.get("relations", [])
data = data or {}
# 过滤范围
target_entity_ids: Optional[Set[str]] = None
if data.get("entity_id"):
result = _bfs_query(graph, data["entity_id"], max_depth=2)
target_entity_ids = {e.get("id") for e in result["entities"]}
elif data.get("company"):
company_entity = _find_entity_by_name(entities, data["company"], "Company")
if company_entity:
result = _bfs_query(graph, company_entity["id"], max_depth=2)
target_entity_ids = {e.get("id") for e in result["entities"]}
# 生成 Mermaid 代码
lines = ["```mermaid", "graph LR"]
# 类型样式映射
type_shapes = {
"Person": ("([", "])" ),
"Company": ("[[", "]]"),
"Deal": ("{{", "}}"),
"Meeting": ("(", ")"),
"Email": ("(", ")"),
}
# 添加节点
entity_ids_in_graph: Set[str] = set()
for e in entities:
eid = e.get("id", "")
if target_entity_ids is not None and eid not in target_entity_ids:
continue
name = e.get("name", "").replace('"', "'")
etype = e.get("type", "")
shape = type_shapes.get(etype, ("(", ")"))
# 使用安全的节点ID
safe_id = eid.replace("-", "_")
lines.append(f" {safe_id}{shape[0]}\"{etype}: {name}\"{shape[1]}")
entity_ids_in_graph.add(eid)
# 添加边
for r in relations:
from_id = r.get("from_id", "")
to_id = r.get("to_id", "")
if from_id in entity_ids_in_graph and to_id in entity_ids_in_graph:
relation_name = RELATION_TYPES.get(
r.get("relation", ""), r.get("relation", "")
)
safe_from = from_id.replace("-", "_")
safe_to = to_id.replace("-", "_")
lines.append(f" {safe_from} -->|{relation_name}| {safe_to}")
lines.append("```")
mermaid_code = "\n".join(lines)
output_success({
"mermaid": mermaid_code,
"entity_count": len(entity_ids_in_graph),
"relation_count": sum(
1 for r in relations
if r.get("from_id") in entity_ids_in_graph
and r.get("to_id") in entity_ids_in_graph
),
})
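# ------------------------------------------------------------
# Illustrative sketch (not part of the original module).
# visualize emits Mermaid flowchart text: one line per node, one per
# edge, with hyphens in IDs replaced and double quotes in labels
# swapped for single quotes. The hypothetical helper below shows that
# output shape on plain inputs; per-type node shapes and the code fence
# are left out for brevity.
# ------------------------------------------------------------
def _example_mermaid(nodes, edges):
    """nodes: {id: label}; edges: [(from_id, edge_label, to_id)]."""
    lines = ["graph LR"]
    for node_id, label in nodes.items():
        safe_id = node_id.replace("-", "_")
        safe_label = label.replace('"', "'")
        lines.append(f'    {safe_id}["{safe_label}"]')
    for from_id, edge_label, to_id in edges:
        # Skip edges whose endpoints are not both drawn
        if from_id in nodes and to_id in nodes:
            safe_from = from_id.replace("-", "_")
            safe_to = to_id.replace("-", "_")
            lines.append(f"    {safe_from} -->|{edge_label}| {safe_to}")
    return "\n".join(lines)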
# ============================================================
# 主入口
# ============================================================
def main() -> None:
    """Entry point: parse command-line arguments and dispatch the action."""
    parser = parse_common_args("deal-closer CRM知识图谱")
    args = parser.parse_args()
    action = args.action.lower()
    try:
        data = load_input_data(args)
    except ValueError as e:
        output_error(str(e), code="INPUT_ERROR")
        return
    actions = {
        "add-entity": lambda: add_entity(data or {}),
        "add-relation": lambda: add_relation(data or {}),
        "query": lambda: query(data or {}),
        "company-map": lambda: company_map(data or {}),
        "influence-chain": lambda: influence_chain(data or {}),
        "visualize": lambda: visualize(data),
    }
    handler = actions.get(action)
    if handler:
        handler()
    else:
        valid_actions = "、".join(actions.keys())
        output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")


if __name__ == "__main__":
    main()
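The `visualize` action above renders entities and relations as a Mermaid `graph LR` block: each entity type maps to a node shape, hyphens in IDs become underscores, and double quotes in names are replaced so they cannot end the label early. The following is a minimal standalone sketch of that idea, not the script's own code. The helper name `build_mermaid` and the sample records are hypothetical, and the surrounding Mermaid fence lines are left out for brevity.

```python
from typing import Dict, List, Tuple

# Shape delimiters per entity type, mirroring part of the mapping used above
TYPE_SHAPES: Dict[str, Tuple[str, str]] = {
    "Person": ("([", "])"),
    "Company": ("[[", "]]"),
    "Deal": ("{{", "}}"),
}


def build_mermaid(entities: List[dict], relations: List[dict]) -> str:
    """Hypothetical helper: render entities and relations as Mermaid graph lines."""
    lines = ["graph LR"]
    known_ids = set()
    for e in entities:
        eid = e.get("id", "")
        # Hyphens are swapped for underscores to get a safe node ID
        safe_id = eid.replace("-", "_")
        # A double quote would terminate the quoted label, so replace it
        name = e.get("name", "").replace('"', "'")
        etype = e.get("type", "")
        left, right = TYPE_SHAPES.get(etype, ("(", ")"))
        lines.append(f'    {safe_id}{left}"{etype}: {name}"{right}')
        known_ids.add(eid)
    for r in relations:
        # Only draw edges whose two endpoints were both emitted as nodes
        if r.get("from_id") in known_ids and r.get("to_id") in known_ids:
            src = r["from_id"].replace("-", "_")
            dst = r["to_id"].replace("-", "_")
            lines.append(f"    {src} -->|{r.get('relation', '')}| {dst}")
    return "\n".join(lines)


# Illustrative sample data (not taken from the project)
sample_entities = [
    {"id": "p-1", "name": "Alice", "type": "Person"},
    {"id": "c-1", "name": 'Acme "Labs"', "type": "Company"},
]
sample_relations = [
    {"from_id": "p-1", "to_id": "c-1", "relation": "works_at"},
    {"from_id": "p-1", "to_id": "x-9", "relation": "knows"},  # dropped: unknown target
]
mermaid_text = build_mermaid(sample_entities, sample_relations)
print(mermaid_text)
```

The second sample relation is skipped because its target is not among the rendered nodes, which is the same filtering the script applies before counting relations.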
FILE:scripts/learning_engine.py
#!/usr/bin/env python3
"""
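deal-closer self-learning sales intelligence module
Learns patterns from historical won/lost deal data, predicts the win probability of deals, and offers sales coaching tips.
Follows the self-improving-agent idea to keep refining the sales strategy.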
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
format_currency,
format_percentage,
calculate_days_since,
DEAL_STAGES,
STAGE_DEFAULT_PROBABILITY,
)
# ============================================================
# Constants and configuration
# ============================================================
LEARNING_FILE = "learning.json"
DEALS_FILE = "deals.json"
# Default structure of the learning data
_DEFAULT_LEARNING_DATA = {
    "outcomes": [],  # won/lost outcome records
    "patterns": [],  # recorded success patterns
    "version": "1.0.0",
    "last_updated": "",
}
# Feature weights for the simple scoring model (they sum to 1.0)
_FEATURE_WEIGHTS = {
    "cycle_days_score": 0.15,      # how reasonable the sales cycle length is
    "followup_score": 0.20,        # follow-up frequency
    "deal_size_score": 0.10,       # fit of the deal amount
    "stage_velocity_score": 0.20,  # speed of stage progression
    "industry_score": 0.15,        # win rate in the same industry
    "engagement_score": 0.20,      # customer engagement
}
# Benchmark number of days per sales stage
_STAGE_BENCHMARK_DAYS = {
    "线索": 7,
    "初步接触": 10,
    "需求确认": 14,
    "方案报价": 10,
    "商务谈判": 14,
    "合同签署": 7,
}
# ============================================================
# 数据操作
# ============================================================
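def _get_learning_data() -> Dict[str, Any]:
    """Read the learning data file, or return a fresh default structure."""
    # dict() alone is a shallow copy: the "outcomes"/"patterns" lists would be
    # shared with _DEFAULT_LEARNING_DATA and mutated by later appends.
    default = {**_DEFAULT_LEARNING_DATA, "outcomes": [], "patterns": []}
    filepath = get_data_file(LEARNING_FILE)
    if not os.path.exists(filepath):
        return default
    data = read_json_file(filepath)
    if not isinstance(data, dict):
        # Legacy list format: start over from the default structure
        return default
    # Guard against files that lack the list fields
    data.setdefault("outcomes", [])
    data.setdefault("patterns", [])
    return data


def _save_learning_data(data: Dict[str, Any]) -> None:
    """Stamp the update time and write the learning data to disk."""
    data["last_updated"] = now_iso()
    write_json_file(get_data_file(LEARNING_FILE), data)


def _get_deals() -> List[Dict[str, Any]]:
    """Read all deal records."""
    return read_json_file(get_data_file(DEALS_FILE))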
# ============================================================
# 特征提取
# ============================================================
def _extract_deal_features(deal: Dict[str, Any]) -> Dict[str, Any]:
    """Extract a feature dictionary from a deal record.

    Args:
        deal: Deal data dictionary.

    Returns:
        Feature dictionary holding the raw value of each dimension.
    """
    # Sales cycle length in days, counted from creation until now
    created = deal.get("created_at", "")
    cycle_days = calculate_days_since(created) if created else 0

    # Stage history
    history = deal.get("stage_history", [])
    stage_count = len(history)

    # Days spent in each completed stage (gap between consecutive entries)
    stage_durations = {}
    for i in range(len(history) - 1):
        stage = history[i].get("stage", "")
        ts_current = history[i].get("timestamp", "")
        ts_next = history[i + 1].get("timestamp", "")
        if ts_current and ts_next:
            try:
                t1 = datetime.fromisoformat(ts_current.replace("Z", "+00:00")).replace(tzinfo=None)
                t2 = datetime.fromisoformat(ts_next.replace("Z", "+00:00")).replace(tzinfo=None)
                stage_durations[stage] = (t2 - t1).days
            except (ValueError, TypeError):
                pass

    # The last stage is still open, so count its days up to now
    if history:
        last_stage = history[-1].get("stage", "")
        last_ts = history[-1].get("timestamp", "")
        if last_ts:
            stage_durations[last_stage] = calculate_days_since(last_ts)

    return {
        "cycle_days": cycle_days,
        "stage_count": stage_count,
        "stage_durations": stage_durations,
        "amount": deal.get("amount", 0),
        # The company name stands in for the industry, matching record_outcome
        "industry": deal.get("company", ""),
        "source": deal.get("source", ""),
        "tags": deal.get("tags", []),
        "current_stage": deal.get("stage", ""),
        "probability": deal.get("probability", 0),
    }
def _calculate_feature_scores(features: Dict[str, Any],
                              learning_data: Dict[str, Any]) -> Dict[str, float]:
    """Score each dimension from the deal features and historical data.

    Args:
        features: Deal feature dictionary.
        learning_data: Learning data.

    Returns:
        Dictionary of per-dimension scores (0.0-1.0).
    """
    outcomes = learning_data.get("outcomes", [])
    won_outcomes = [o for o in outcomes if o.get("result") == "won"]
    scores = {}

    # 1. Sales cycle: the closer to the average cycle of won deals, the higher
    scores["cycle_days_score"] = 0.5  # neutral value when there is no usable history
    if won_outcomes:
        avg_won_cycle = sum(o.get("cycle_days", 30) for o in won_outcomes) / len(won_outcomes)
        if avg_won_cycle > 0:
            ratio = features.get("cycle_days", 0) / avg_won_cycle
            if ratio <= 1.0:
                scores["cycle_days_score"] = 0.5 + ratio * 0.5
            elif ratio <= 2.0:
                scores["cycle_days_score"] = max(0.2, 1.0 - (ratio - 1.0) * 0.5)
            else:
                scores["cycle_days_score"] = 0.1

    # 2. Follow-up frequency: stage advances per week; one per week scores 0.5,
    #    two or more per week saturate at 1.0
    stage_count = features.get("stage_count", 1)
    cycle_days = max(features.get("cycle_days", 1), 1)
    followup_rate = stage_count / (cycle_days / 7.0) if cycle_days >= 7 else stage_count
    scores["followup_score"] = min(1.0, followup_rate * 0.5)

    # 3. Deal size fit against the average amount of won deals
    amount = features.get("amount", 0)
    scores["deal_size_score"] = 0.5
    if won_outcomes:
        avg_won_amount = sum(o.get("amount", 0) for o in won_outcomes) / len(won_outcomes)
        if avg_won_amount > 0:
            ratio = amount / avg_won_amount
            if 0.5 <= ratio <= 2.0:
                scores["deal_size_score"] = 0.8
            elif 0.2 <= ratio <= 3.0:
                scores["deal_size_score"] = 0.5
            else:
                scores["deal_size_score"] = 0.3

    # 4. Stage velocity: compare the days in each stage with its benchmark
    stage_durations = features.get("stage_durations", {})
    velocity_scores = []
    for stage, duration in stage_durations.items():
        benchmark = _STAGE_BENCHMARK_DAYS.get(stage, 10)
        if duration <= benchmark:
            velocity_scores.append(1.0)
        elif duration <= benchmark * 2:
            velocity_scores.append(0.5)
        else:
            velocity_scores.append(0.2)
    scores["stage_velocity_score"] = (
        sum(velocity_scores) / len(velocity_scores) if velocity_scores else 0.5
    )

    # 5. Industry win rate, from historical outcomes in the same industry
    industry = features.get("industry", "")
    scores["industry_score"] = 0.5
    if industry and outcomes:
        industry_outcomes = [o for o in outcomes if industry in o.get("industry", "")]
        if industry_outcomes:
            won_count = sum(1 for o in industry_outcomes if o.get("result") == "won")
            scores["industry_score"] = won_count / len(industry_outcomes)

    # 6. Customer engagement, based on the number of stage advances
    if stage_count >= 4:
        scores["engagement_score"] = 0.9
    elif stage_count >= 2:
        scores["engagement_score"] = 0.6
    else:
        scores["engagement_score"] = 0.3
    return scores
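def _compute_win_probability(scores: Dict[str, float]) -> float:
    """Combine the per-dimension scores into an overall win probability.

    Args:
        scores: Dictionary of per-dimension scores.

    Returns:
        Weighted win probability (0-100).
    """
    total = 0.0
    for key, weight in _FEATURE_WEIGHTS.items():
        # A missing dimension falls back to a neutral 0.5
        score = scores.get(key, 0.5)
        total += score * weight
    # The weights sum to 1.0, so scaling by 100 yields a value in 0-100
    return round(total * 100, 1)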
# ============================================================
# 建议生成
# ============================================================
def _generate_suggestions(deal: Dict[str, Any],
                          features: Dict[str, Any],
                          learning_data: Dict[str, Any]) -> List[str]:
    """Generate suggestions from the deal features and historical data.

    Args:
        deal: Deal data.
        features: Deal features.
        learning_data: Learning data.

    Returns:
        List of suggestions.
    """
    suggestions = []
    outcomes = learning_data.get("outcomes", [])
    won_outcomes = [o for o in outcomes if o.get("result") == "won"]
    stage = deal.get("stage", "")
    stage_durations = features.get("stage_durations", {})

    # Deal has stayed in the current stage for too long
    if stage in stage_durations:
        current_duration = stage_durations[stage]
        benchmark = _STAGE_BENCHMARK_DAYS.get(stage, 10)
        # Average days that won deals spent in the same stage
        if won_outcomes:
            same_stage_durations = []
            for o in won_outcomes:
                sd = o.get("stage_durations", {})
                if stage in sd:
                    same_stage_durations.append(sd[stage])
            if same_stage_durations:
                avg_duration = sum(same_stage_durations) / len(same_stage_durations)
                if current_duration > avg_duration * 1.5:
                    # The average covers all won deals, not only deals of similar size
                    suggestions.append(
                        f"历史成交商机平均在「{stage}」阶段停留"
                        f"{int(avg_duration)}天,当前商机已停留"
                        f"{current_duration}天,建议主动跟进"
                    )
        if current_duration > benchmark * 2:
            suggestions.append(
                f"当前在「{stage}」阶段已停留 {current_duration} 天,"
                f"超出基准 {benchmark} 天的两倍,存在流失风险"
            )

    # Follow-up timing, based on recorded "timing" patterns
    patterns = learning_data.get("patterns", [])
    timing_patterns = [p for p in patterns if p.get("category") == "timing"]
    if timing_patterns:
        best_timing = max(timing_patterns, key=lambda p: p.get("success_rate", 0))
        suggestions.append(
            f"历史数据显示,{best_timing.get('description', '周二上午')}的跟进"
            f"回复率最高({format_percentage(best_timing.get('success_rate', 0.5))})"
        )
    else:
        # No timing pattern recorded yet: say so rather than quote a
        # hard-coded statistic as if it came from historical data
        suggestions.append(
            "暂无跟进时机的历史模式数据,可通过 record-pattern 记录 timing 类模式后获得时机建议"
        )

    # Amount-related advice, compared with won deals of a similar amount
    amount = deal.get("amount", 0)
    if amount > 0 and won_outcomes:
        similar_won = [
            o for o in won_outcomes
            if 0.5 * amount <= o.get("amount", 0) <= 2.0 * amount
        ]
        if similar_won:
            avg_cycle = sum(o.get("cycle_days", 30) for o in similar_won) / len(similar_won)
            current_cycle = features.get("cycle_days", 0)
            if current_cycle < avg_cycle * 0.5:
                suggestions.append(
                    f"相似金额商机平均周期为 {int(avg_cycle)} 天,"
                    f"当前仅 {current_cycle} 天,节奏良好"
                )
            elif current_cycle > avg_cycle * 1.5:
                suggestions.append(
                    f"相似金额商机平均周期为 {int(avg_cycle)} 天,"
                    f"当前已 {current_cycle} 天,建议加速推进"
                )

    # Stage conversion advice
    if stage in ("需求确认", "方案报价") and won_outcomes:
        # Share of historical deals that reached this stage and also the next one
        stage_idx = DEAL_STAGES.index(stage) if stage in DEAL_STAGES else -1
        if stage_idx >= 0 and stage_idx < len(DEAL_STAGES) - 2:
            next_stage = DEAL_STAGES[stage_idx + 1]
            total_at_stage = sum(
                1 for o in outcomes
                if stage in o.get("stage_durations", {})
            )
            converted = sum(
                1 for o in outcomes
                if next_stage in o.get("stage_durations", {})
                and stage in o.get("stage_durations", {})
            )
            if total_at_stage > 0:
                rate = converted / total_at_stage
                if rate < 0.5:
                    suggestions.append(
                        f"{stage}→{next_stage}转化率偏低"
                        f"({format_percentage(rate)}),建议加强需求调研深度"
                    )
    return suggestions
def _generate_coaching_tips(learning_data: Dict[str, Any],
deals: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""根据管道瓶颈生成教练建议。
Args:
learning_data: 学习数据。
deals: 当前商机列表。
Returns:
教练建议列表。
"""
tips = []
outcomes = learning_data.get("outcomes", [])
# 活跃商机
active_deals = [
d for d in deals if d.get("stage") not in ("成交", "流失")
]
if not active_deals:
tips.append({
"category": "pipeline",
"tip": "当前管道为空,建议加大线索获取力度",
"priority": "high",
})
return tips
# 分析各阶段分布
stage_counts = defaultdict(int)
for d in active_deals:
stage_counts[d.get("stage", "")] += 1
# 检测瓶颈:某阶段商机堆积
total_active = len(active_deals)
for stage, count in stage_counts.items():
ratio = count / total_active if total_active > 0 else 0
if ratio > 0.4 and total_active >= 3:
tips.append({
"category": "bottleneck",
"tip": f"「{stage}」阶段商机占比过高({format_percentage(ratio)}),"
f"共 {count} 个商机停滞,建议集中精力推进此阶段转化",
"priority": "high",
})
# 检测停滞商机
stale_count = 0
for d in active_deals:
updated = d.get("updated_at", "")
if updated and calculate_days_since(updated) > 14:
stale_count += 1
if stale_count > 0:
stale_ratio = stale_count / total_active
tips.append({
"category": "stale",
"tip": f"有 {stale_count} 个商机超过14天未更新"
f"(占比{format_percentage(stale_ratio)}),建议逐一排查跟进",
"priority": "high" if stale_ratio > 0.3 else "medium",
})
# 转化率分析(基于历史数据)
if outcomes:
won = sum(1 for o in outcomes if o.get("result") == "won")
total = len(outcomes)
win_rate = won / total if total > 0 else 0
if win_rate < 0.3:
tips.append({
"category": "win_rate",
"tip": f"整体胜率偏低({format_percentage(win_rate)}),"
f"建议复盘最近流失商机,找出共性问题",
"priority": "high",
})
elif win_rate > 0.6:
tips.append({
"category": "win_rate",
"tip": f"整体胜率较高({format_percentage(win_rate)}),"
f"可适当提高目标商机金额或数量",
"priority": "low",
})
# 流失原因分析
loss_reasons = defaultdict(int)
for o in outcomes:
if o.get("result") == "lost":
for reason in o.get("loss_reasons", []):
loss_reasons[reason] += 1
if loss_reasons:
top_reason = max(loss_reasons.items(), key=lambda x: x[1])
tips.append({
"category": "loss_analysis",
"tip": f"最常见流失原因是「{top_reason[0]}」({top_reason[1]}次),"
f"建议针对性优化应对策略",
"priority": "medium",
})
# 跟进频率建议
patterns = learning_data.get("patterns", [])
followup_patterns = [p for p in patterns if p.get("category") == "followup"]
if followup_patterns:
best = max(followup_patterns, key=lambda p: p.get("success_rate", 0))
tips.append({
"category": "best_practice",
"tip": f"最佳实践:{best.get('description', '')},"
f"成功率 {format_percentage(best.get('success_rate', 0))}",
"priority": "low",
})
return tips
# ============================================================
# 操作函数
# ============================================================
def record_outcome(data: Dict[str, Any]) -> None:
"""记录商机成交/流失结果。
必填字段: deal_id, result(won/lost)
可选字段: cycle_days, followup_count, loss_reasons, notes,
contributing_factors
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "自学习销售智能"):
return
deal_id = data.get("deal_id")
result = data.get("result", "").lower()
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
if result not in ("won", "lost"):
output_error("结果(result)必须为 won 或 lost", code="VALIDATION_ERROR")
return
# 加载商机数据获取特征
deals = _get_deals()
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
features = {}
if target_deal:
features = _extract_deal_features(target_deal)
# 构建结果记录
outcome = {
"id": generate_id("LO"),
"deal_id": deal_id,
"deal_name": target_deal.get("name", "") if target_deal else "",
"result": result,
"amount": target_deal.get("amount", 0) if target_deal else data.get("amount", 0),
"industry": target_deal.get("company", "") if target_deal else data.get("industry", ""),
"source": target_deal.get("source", "") if target_deal else data.get("source", ""),
"cycle_days": data.get("cycle_days", features.get("cycle_days", 0)),
"followup_count": data.get("followup_count", features.get("stage_count", 0)),
"stage_durations": features.get("stage_durations", {}),
"loss_reasons": data.get("loss_reasons", []),
"contributing_factors": data.get("contributing_factors", []),
"notes": data.get("notes", ""),
"recorded_at": now_iso(),
}
learning = _get_learning_data()
learning["outcomes"].append(outcome)
_save_learning_data(learning)
output_success({
"message": f"已记录商机结果:{result}",
"outcome": outcome,
"total_outcomes": len(learning["outcomes"]),
})
def record_pattern(data: Dict[str, Any]) -> None:
"""记录成功模式。
必填字段: category, description
可选字段: success_rate, applicable_stages, notes
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "自学习销售智能"):
return
category = data.get("category", "")
description = data.get("description", "")
if not category:
output_error("模式类别(category)为必填字段", code="VALIDATION_ERROR")
return
if not description:
output_error("模式描述(description)为必填字段", code="VALIDATION_ERROR")
return
valid_categories = [
"timing", "communication", "pricing", "followup",
"negotiation", "presentation", "objection_handling", "other",
]
if category not in valid_categories:
output_error(
f"无效类别: {category},有效类别: {', '.join(valid_categories)}",
code="VALIDATION_ERROR",
)
return
pattern = {
"id": generate_id("LP"),
"category": category,
"description": description,
"success_rate": min(1.0, max(0.0, float(data.get("success_rate", 0.5)))),
"applicable_stages": data.get("applicable_stages", []),
"notes": data.get("notes", ""),
"recorded_at": now_iso(),
}
learning = _get_learning_data()
learning["patterns"].append(pattern)
_save_learning_data(learning)
output_success({
"message": f"已记录成功模式:{description}",
"pattern": pattern,
"total_patterns": len(learning["patterns"]),
})
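def predict(data: Dict[str, Any]) -> None:
"""Predict the win probability of deals.
Optional field: deal_id. When given, only that deal is predicted; when omitted, all active deals are predicted.
Args:
data: Parameter dictionary.
"""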
if not require_paid_feature("advanced_analytics", "AI胜率预测"):
return
deal_id = data.get("deal_id")
deals = _get_deals()
learning = _get_learning_data()
if deal_id:
# 预测单个商机
target = None
for d in deals:
if d.get("id") == deal_id:
target = d
break
if not target:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
features = _extract_deal_features(target)
scores = _calculate_feature_scores(features, learning)
probability = _compute_win_probability(scores)
suggestions = _generate_suggestions(target, features, learning)
output_success({
"deal_id": deal_id,
"deal_name": target.get("name", ""),
"current_stage": target.get("stage", ""),
"manual_probability": target.get("probability", 0),
"ai_probability": probability,
"dimension_scores": {
k: round(v * 100, 1)
for k, v in scores.items()
},
"suggestions": suggestions,
"data_basis": len(learning.get("outcomes", [])),
})
else:
# 预测所有活跃商机
active = [
d for d in deals
if d.get("stage") not in ("成交", "流失")
]
if not active:
output_error("暂无活跃商机", code="NO_DATA")
return
predictions = []
for deal in active:
features = _extract_deal_features(deal)
scores = _calculate_feature_scores(features, learning)
probability = _compute_win_probability(scores)
predictions.append({
"deal_id": deal.get("id", ""),
"deal_name": deal.get("name", ""),
"stage": deal.get("stage", ""),
"amount": deal.get("amount", 0),
"amount_display": format_currency(deal.get("amount", 0)),
"manual_probability": deal.get("probability", 0),
"ai_probability": probability,
})
# 按AI预测胜率排序
predictions.sort(key=lambda p: p["ai_probability"], reverse=True)
output_success({
"total": len(predictions),
"predictions": predictions,
"data_basis": len(learning.get("outcomes", [])),
})
def suggest(data: Dict[str, Any]) -> None:
"""为商机生成主动建议。
必填字段: deal_id
Args:
data: 参数字典。
"""
if not require_paid_feature("advanced_analytics", "AI销售建议"):
return
deal_id = data.get("deal_id")
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
deals = _get_deals()
target = None
for d in deals:
if d.get("id") == deal_id:
target = d
break
if not target:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
learning = _get_learning_data()
features = _extract_deal_features(target)
suggestions = _generate_suggestions(target, features, learning)
if not suggestions:
suggestions.append("暂无特定建议,建议保持当前跟进节奏")
output_success({
"deal_id": deal_id,
"deal_name": target.get("name", ""),
"current_stage": target.get("stage", ""),
"suggestions": suggestions,
"suggestion_count": len(suggestions),
})
def coach(data: Optional[Dict[str, Any]] = None) -> None:
"""生成销售教练建议。
基于当前管道瓶颈和历史数据。
Args:
data: 可选参数。
"""
if not require_paid_feature("advanced_analytics", "销售教练"):
return
deals = _get_deals()
learning = _get_learning_data()
tips = _generate_coaching_tips(learning, deals)
if not tips:
tips.append({
"category": "general",
"tip": "当前管道状态良好,继续保持。建议定期复盘成交案例,沉淀最佳实践。",
"priority": "low",
})
# 按优先级排序
priority_order = {"high": 0, "medium": 1, "low": 2}
tips.sort(key=lambda t: priority_order.get(t.get("priority", "low"), 2))
output_success({
"tips": tips,
"total_tips": len(tips),
"data_basis": {
"outcomes": len(learning.get("outcomes", [])),
"patterns": len(learning.get("patterns", [])),
"active_deals": len([
d for d in deals if d.get("stage") not in ("成交", "流失")
]),
},
})
def stats(data: Optional[Dict[str, Any]] = None) -> None:
"""生成学习统计报告。
包含胜率趋势、平均周期、流失原因、最佳实践。
Args:
data: 可选参数。
"""
if not require_paid_feature("advanced_analytics", "学习统计"):
return
learning = _get_learning_data()
outcomes = learning.get("outcomes", [])
patterns = learning.get("patterns", [])
if not outcomes:
output_success({
"message": "暂无历史数据,请先使用 record-outcome 记录商机结果",
"total_outcomes": 0,
"total_patterns": len(patterns),
})
return
won = [o for o in outcomes if o.get("result") == "won"]
lost = [o for o in outcomes if o.get("result") == "lost"]
# 胜率
win_rate = len(won) / len(outcomes) if outcomes else 0
# 平均销售周期
avg_cycle_won = (
sum(o.get("cycle_days", 0) for o in won) / len(won)
if won else 0
)
avg_cycle_lost = (
sum(o.get("cycle_days", 0) for o in lost) / len(lost)
if lost else 0
)
# 平均成交金额
avg_amount_won = (
sum(o.get("amount", 0) for o in won) / len(won)
if won else 0
)
# 流失原因统计
loss_reasons = defaultdict(int)
for o in lost:
for reason in o.get("loss_reasons", []):
loss_reasons[reason] += 1
top_loss_reasons = sorted(
loss_reasons.items(), key=lambda x: x[1], reverse=True
)[:5]
# 最佳实践
best_practices = sorted(
patterns,
key=lambda p: p.get("success_rate", 0),
reverse=True,
)[:5]
# 按月胜率趋势
monthly_stats = defaultdict(lambda: {"won": 0, "lost": 0})
for o in outcomes:
month = o.get("recorded_at", "")[:7]
if month:
if o.get("result") == "won":
monthly_stats[month]["won"] += 1
else:
monthly_stats[month]["lost"] += 1
win_rate_trend = []
for month in sorted(monthly_stats.keys()):
ms = monthly_stats[month]
total = ms["won"] + ms["lost"]
rate = ms["won"] / total if total > 0 else 0
win_rate_trend.append({
"month": month,
"won": ms["won"],
"lost": ms["lost"],
"win_rate": round(rate, 4),
"win_rate_display": format_percentage(rate),
})
output_success({
"total_outcomes": len(outcomes),
"total_won": len(won),
"total_lost": len(lost),
"win_rate": round(win_rate, 4),
"win_rate_display": format_percentage(win_rate),
"avg_cycle_won_days": round(avg_cycle_won, 1),
"avg_cycle_lost_days": round(avg_cycle_lost, 1),
"avg_won_amount": round(avg_amount_won, 2),
"avg_won_amount_display": format_currency(avg_amount_won),
"top_loss_reasons": [
{"reason": r, "count": c}
for r, c in top_loss_reasons
],
"best_practices": [
{
"category": p.get("category", ""),
"description": p.get("description", ""),
"success_rate": format_percentage(p.get("success_rate", 0)),
}
for p in best_practices
],
"win_rate_trend": win_rate_trend,
"total_patterns": len(patterns),
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
    """Entry point: parse command-line arguments and dispatch the action."""
    parser = parse_common_args("deal-closer 自学习销售智能")
    args = parser.parse_args()
    action = args.action.lower()
    try:
        data = load_input_data(args)
    except ValueError as e:
        output_error(str(e), code="INPUT_ERROR")
        return
    actions = {
        "record-outcome": lambda: record_outcome(data or {}),
        "record-pattern": lambda: record_pattern(data or {}),
        "predict": lambda: predict(data or {}),
        "suggest": lambda: suggest(data or {}),
        "coach": lambda: coach(data),
        "stats": lambda: stats(data),
    }
    handler = actions.get(action)
    if handler:
        handler()
    else:
        valid_actions = "、".join(actions.keys())
        output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")


if __name__ == "__main__":
    main()
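The win-probability model in this script is a weighted sum of six per-dimension scores, each between 0.0 and 1.0, scaled to a 0-100 value. The following is a minimal standalone sketch of that calculation. It reuses the weights from `_FEATURE_WEIGHTS`, while the function name `weighted_win_probability` and the sample scores are illustrative assumptions rather than part of the script.

```python
from typing import Dict

# Same weights as _FEATURE_WEIGHTS in the script; they sum to 1.0
FEATURE_WEIGHTS: Dict[str, float] = {
    "cycle_days_score": 0.15,
    "followup_score": 0.20,
    "deal_size_score": 0.10,
    "stage_velocity_score": 0.20,
    "industry_score": 0.15,
    "engagement_score": 0.20,
}


def weighted_win_probability(scores: Dict[str, float]) -> float:
    """Hypothetical helper: weighted sum of scores, scaled to 0-100."""
    # Dimensions without a score fall back to a neutral 0.5
    total = sum(scores.get(key, 0.5) * weight for key, weight in FEATURE_WEIGHTS.items())
    return round(total * 100, 1)


# Illustrative scores: three dimensions known, three left at the neutral default
sample_scores = {
    "cycle_days_score": 1.0,
    "followup_score": 0.5,
    "stage_velocity_score": 1.0,
}
probability = weighted_win_probability(sample_scores)
neutral = weighted_win_probability({})
print(probability, neutral)
```

Because every missing dimension defaults to 0.5, a deal with no history at all lands at the midpoint of the scale, and each known dimension moves the result by at most half of its weight in either direction.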
FILE:scripts/pipeline_reporter.py
#!/usr/bin/env python3
"""
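deal-closer sales funnel and pipeline reporting module
Provides sales funnel analysis, revenue forecasting, and weekly/monthly report generation.
Supports Mermaid charts (paid feature) and risk alerts.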
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from utils import (
check_subscription,
get_data_file,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
format_currency,
format_percentage,
calculate_days_since,
DEAL_STAGES,
STAGE_COLORS,
STAGE_DEFAULT_PROBABILITY,
)
# 延迟导入学习引擎
_learning_module = None
def _get_learning_module():
"""延迟加载 learning_engine 模块。"""
global _learning_module
if _learning_module is None:
try:
import learning_engine as _mod
_learning_module = _mod
except ImportError:
_learning_module = False
return _learning_module if _learning_module is not False else None
def _get_ai_predictions(deals: List[Dict[str, Any]]) -> Dict[str, float]:
"""获取所有活跃商机的AI预测胜率。
Args:
deals: 商机列表。
Returns:
商机ID到AI预测胜率的映射。
"""
learning_mod = _get_learning_module()
if learning_mod is None:
return {}
try:
learning_data = learning_mod._get_learning_data()
predictions = {}
for deal in deals:
if deal.get("stage") in ("成交", "流失"):
continue
features = learning_mod._extract_deal_features(deal)
scores = learning_mod._calculate_feature_scores(features, learning_data)
probability = learning_mod._compute_win_probability(scores)
predictions[deal.get("id", "")] = probability
return predictions
except Exception:
return {}
def _get_coaching_tips() -> List[Dict[str, str]]:
"""获取销售教练建议。
Returns:
教练建议列表。
"""
learning_mod = _get_learning_module()
if learning_mod is None:
return []
try:
deals = read_json_file(get_data_file(DEALS_FILE))
learning_data = learning_mod._get_learning_data()
return learning_mod._generate_coaching_tips(learning_data, deals)
except Exception:
return []
def _calculate_pipeline_health(deals: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute the pipeline health score.

    Compares the days since each active deal was last updated with a per-stage
    benchmark. When learning data is available, the benchmark is replaced by
    the average stage duration of historical won deals.

    Args:
        deals: List of deals.

    Returns:
        Health score data.
    """
    learning_mod = _get_learning_module()
    active_deals = [
        d for d in deals if d.get("stage") not in ("成交", "流失")
    ]
    if not active_deals:
        # Same keys as the normal result, so callers can read them unconditionally
        return {
            "score": 0,
            "level": "无数据",
            "total_deals": 0,
            "at_risk_count": 0,
            "at_risk_deals": [],
        }

    # Benchmark days per stage
    stage_benchmarks = {
        "线索": 7, "初步接触": 10, "需求确认": 14,
        "方案报价": 10, "商务谈判": 14, "合同签署": 7,
    }
    # Override the benchmarks with historical data when learning data exists
    if learning_mod:
        try:
            learning_data = learning_mod._get_learning_data()
            won_outcomes = [
                o for o in learning_data.get("outcomes", [])
                if o.get("result") == "won"
            ]
            if won_outcomes:
                for stage in stage_benchmarks:
                    durations = []
                    for o in won_outcomes:
                        sd = o.get("stage_durations", {})
                        if stage in sd:
                            durations.append(sd[stage])
                    if durations:
                        # Keep at least 1 day so a sub-day average cannot
                        # turn every deal in this stage into an at-risk deal
                        stage_benchmarks[stage] = max(
                            1, int(sum(durations) / len(durations))
                        )
        except Exception:
            pass

    health_scores = []
    details = []
    for deal in active_deals:
        stage = deal.get("stage", "")
        updated = deal.get("updated_at", "")
        days_since = calculate_days_since(updated) if updated else 0
        benchmark = stage_benchmarks.get(stage, 10)
        if days_since <= benchmark:
            score = 100
        elif days_since <= benchmark * 2:
            score = max(50, 100 - (days_since - benchmark) * 5)
        else:
            score = max(10, 50 - (days_since - benchmark * 2) * 3)
        health_scores.append(score)
        if score < 60:
            details.append({
                "deal_name": deal.get("name", ""),
                "stage": stage,
                "days_since_update": days_since,
                "benchmark_days": benchmark,
                "health_score": score,
            })

    avg_score = sum(health_scores) / len(health_scores) if health_scores else 0
    avg_score = round(avg_score, 1)
    if avg_score >= 80:
        level = "健康"
    elif avg_score >= 60:
        level = "一般"
    elif avg_score >= 40:
        level = "需关注"
    else:
        level = "危险"
    return {
        "score": avg_score,
        "level": level,
        "total_deals": len(active_deals),
        "at_risk_count": len(details),
        "at_risk_deals": sorted(details, key=lambda d: d["health_score"])[:5],
    }
# ============================================================
# 数据文件
# ============================================================
DEALS_FILE = "deals.json"
MEETINGS_FILE = "meetings.json"
def _get_deals() -> List[Dict[str, Any]]:
"""读取所有商机数据。"""
return read_json_file(get_data_file(DEALS_FILE))
def _get_meetings() -> List[Dict[str, Any]]:
"""读取所有会议记录。"""
return read_json_file(get_data_file(MEETINGS_FILE))
# ============================================================
# Mermaid 图表生成
# ============================================================
def _generate_pie_chart(title: str, data: List[Dict[str, Any]]) -> str:
"""生成 Mermaid 饼图。
Args:
title: 图表标题。
data: 数据列表,每项包含 label 和 value。
Returns:
Mermaid 饼图代码块字符串。
"""
lines = ["```mermaid", f"pie title {title}"]
for item in data:
label = item.get("label", "未知")
value = item.get("value", 0)
if value > 0:
lines.append(f' "{label}" : {value}')
lines.append("```")
return "\n".join(lines)
def _generate_bar_chart(title: str, data: List[Dict[str, Any]], y_label: str = "金额") -> str:
"""生成 Mermaid 柱状图。
Args:
title: 图表标题。
data: 数据列表,每项包含 label 和 value。
y_label: Y 轴标签。
Returns:
Mermaid 柱状图代码块字符串。
"""
labels = [f'"{item.get("label", "")}"' for item in data]
values = [str(item.get("value", 0)) for item in data]
lines = [
"```mermaid",
"xychart-beta",
f' title "{title}"',
f' x-axis [{", ".join(labels)}]',
f' y-axis "{y_label}"',
f' bar [{", ".join(values)}]',
"```",
]
return "\n".join(lines)
def _generate_line_chart(title: str, data: List[Dict[str, Any]], y_label: str = "数值") -> str:
"""生成 Mermaid 折线图。
Args:
title: 图表标题。
data: 数据列表,每项包含 label 和 value。
y_label: Y 轴标签。
Returns:
Mermaid 折线图代码块字符串。
"""
labels = [f'"{item.get("label", "")}"' for item in data]
values = [str(item.get("value", 0)) for item in data]
lines = [
"```mermaid",
"xychart-beta",
f' title "{title}"',
f' x-axis [{", ".join(labels)}]',
f' y-axis "{y_label}"',
f' line [{", ".join(values)}]',
"```",
]
return "\n".join(lines)
# ============================================================
# 分析函数
# ============================================================
def _calculate_conversion_rates(deals: List[Dict]) -> List[Dict[str, Any]]:
    """Compute the conversion rate between each pair of adjacent stages.

    Args:
        deals: List of deals.

    Returns:
        List of conversion-rate records, one per pair of adjacent stages.
    """
    stage_counts = {}
    for stage in DEAL_STAGES:
        stage_counts[stage] = sum(1 for d in deals if d.get("stage") == stage)
    conversions = []
    # Exclude "流失" (lost) so that only forward conversion is measured
    active_stages = [s for s in DEAL_STAGES if s != "流失"]
    for i in range(len(active_stages) - 1):
        current = active_stages[i]
        next_stage = active_stages[i + 1]
        # Deals that reached the current stage: those in it now plus those in any later stage
        total_at_or_past = sum(
            stage_counts.get(s, 0) for s in active_stages[i:]
        )
        # Deals that moved on: those in the next stage or any stage after it
        total_past = sum(
            stage_counts.get(s, 0) for s in active_stages[i + 1:]
        )
        rate = total_past / total_at_or_past if total_at_or_past > 0 else 0.0
        conversions.append({
            "from_stage": current,
            "to_stage": next_stage,
            "from_count": total_at_or_past,
            "to_count": total_past,
            "conversion_rate": round(rate, 4),
        })
    return conversions
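The conversion rates above are computed from a snapshot of current stages rather than from stage history: a deal counts as having reached a stage if it sits in that stage now or in any later one. The following is a small standalone sketch of that cumulative counting. The helper name `funnel_conversions`, the English stage names, and the counts are illustrative assumptions.

```python
from typing import Dict, List


def funnel_conversions(stage_order: List[str], counts: Dict[str, int]) -> List[dict]:
    """Hypothetical helper: cumulative stage-to-stage conversion rates."""
    result = []
    for i in range(len(stage_order) - 1):
        # Deals in this stage or any later stage have reached this stage
        reached = sum(counts.get(s, 0) for s in stage_order[i:])
        # Deals in the next stage or later have advanced past this stage
        advanced = sum(counts.get(s, 0) for s in stage_order[i + 1:])
        rate = advanced / reached if reached > 0 else 0.0
        result.append({
            "from_stage": stage_order[i],
            "to_stage": stage_order[i + 1],
            "from_count": reached,
            "to_count": advanced,
            "conversion_rate": round(rate, 4),
        })
    return result


# Illustrative snapshot: 5 deals still at "lead", 3 at "proposal", 2 already "won"
stages = ["lead", "proposal", "won"]
snapshot = {"lead": 5, "proposal": 3, "won": 2}
rates = funnel_conversions(stages, snapshot)
for row in rates:
    print(row["from_stage"], "->", row["to_stage"], row["conversion_rate"])
```

In this snapshot all 10 deals reached "lead" and 5 of them moved on, so the first rate is 0.5. Lost deals are left out of the counts, as in the function above, which means the rates describe the surviving pipeline only.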
def _detect_risk_deals(deals: List[Dict], stale_days: int = 14) -> List[Dict[str, Any]]:
"""检测风险商机。
识别长时间未更新和高价值风险商机。
Args:
deals: 商机列表。
stale_days: 停滞天数阈值。
Returns:
风险商机列表。
"""
risks = []
for deal in deals:
stage = deal.get("stage", "")
if stage in ("成交", "流失"):
continue
updated_at = deal.get("updated_at", "")
days_since = calculate_days_since(updated_at) if updated_at else 0
risk_reasons = []
# 停滞风险
if days_since >= stale_days:
risk_reasons.append(f"已 {days_since} 天未更新")
# 超期风险
expected = deal.get("expected_close_date", "")
if expected:
try:
exp_date = datetime.strptime(expected, "%Y-%m-%d")
if exp_date < datetime.now():
overdue_days = (datetime.now() - exp_date).days
risk_reasons.append(f"已超出预计成交日期 {overdue_days} 天")
except ValueError:
pass
# 高金额低概率
amount = deal.get("amount", 0)
probability = deal.get("probability", 0)
if amount >= 100000 and probability <= 30:
risk_reasons.append("高金额低概率")
if risk_reasons:
risks.append({
"deal_id": deal.get("id", ""),
"deal_name": deal.get("name", ""),
"stage": stage,
"amount": amount,
"amount_display": format_currency(amount),
"probability": probability,
"days_since_update": days_since,
"risk_reasons": risk_reasons,
})
# 按金额降序排序
risks.sort(key=lambda r: r.get("amount", 0), reverse=True)
return risks
# ============================================================
# 报告操作
# ============================================================
def funnel_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成销售漏斗报告。
展示各阶段商机数量、金额和转化率。
Args:
data: 可选参数。
"""
deals = _get_deals()
if not deals:
output_error("暂无商机数据", code="NO_DATA")
return
sub = check_subscription()
is_paid = sub["tier"] == "paid"
# 各阶段统计
stage_data = []
for stage in DEAL_STAGES:
stage_deals = [d for d in deals if d.get("stage") == stage]
total_amount = sum(d.get("amount", 0) for d in stage_deals)
stage_data.append({
"stage": stage,
"count": len(stage_deals),
"total_amount": total_amount,
"total_amount_display": format_currency(total_amount),
})
# 转化率
conversions = _calculate_conversion_rates(deals)
# 风险商机
risks = _detect_risk_deals(deals)
result = {
"total_deals": len(deals),
"total_amount": sum(d.get("amount", 0) for d in deals),
"total_amount_display": format_currency(sum(d.get("amount", 0) for d in deals)),
"stages": stage_data,
"conversions": conversions,
"risk_deals": risks[:10],
"risk_count": len(risks),
}
# 付费用户生成 Mermaid 图表
if is_paid:
# 饼图:阶段分布
pie_data = [
{"label": s["stage"], "value": s["count"]}
for s in stage_data if s["count"] > 0
]
result["mermaid_pie"] = _generate_pie_chart("商机阶段分布", pie_data)
# 柱状图:各阶段金额
bar_data = [
{"label": s["stage"], "value": int(s["total_amount"] / 10000)}
for s in stage_data
]
result["mermaid_bar"] = _generate_bar_chart("各阶段金额(万元)", bar_data, y_label="万元")
output_success(result)
def forecast_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成收入预测报告。
根据管道金额乘以成交概率计算加权预测。
Args:
data: 可选参数,支持 period(月/季度)。
"""
if not require_paid_feature("forecast", "收入预测"):
return
deals = _get_deals()
if not deals:
output_error("暂无商机数据", code="NO_DATA")
return
data = data or {}
stale_days = data.get("stale_days", 14)
# 排除已成交和已流失
active_deals = [
d for d in deals
if d.get("stage") not in ("成交", "流失")
]
# 加权预测
weighted_total = 0.0
stage_forecast = {}
forecast_details = []
for deal in active_deals:
        # Treat a missing or null amount/probability as 0
        amount = deal.get("amount") or 0
        probability = (deal.get("probability") or 0) / 100.0
        weighted = amount * probability
stage = deal.get("stage", "")
if stage not in stage_forecast:
stage_forecast[stage] = {"count": 0, "raw_amount": 0, "weighted_amount": 0}
stage_forecast[stage]["count"] += 1
stage_forecast[stage]["raw_amount"] += amount
stage_forecast[stage]["weighted_amount"] += weighted
weighted_total += weighted
forecast_details.append({
"deal_id": deal.get("id", ""),
"deal_name": deal.get("name", ""),
"stage": stage,
"amount": amount,
"amount_display": format_currency(amount),
"probability": deal.get("probability", 0),
"weighted_amount": round(weighted, 2),
"weighted_display": format_currency(weighted),
"expected_close_date": deal.get("expected_close_date", ""),
"ai_probability": None, # 占位,稍后填充
})
# 填充 AI 预测胜率
ai_predictions = _get_ai_predictions(active_deals)
for fd in forecast_details:
did = fd.get("deal_id", "")
if did in ai_predictions:
fd["ai_probability"] = ai_predictions[did]
# 按加权金额排序
forecast_details.sort(key=lambda x: x["weighted_amount"], reverse=True)
# 格式化阶段预测
stage_forecast_display = []
for stage in DEAL_STAGES:
if stage in stage_forecast:
sf = stage_forecast[stage]
stage_forecast_display.append({
"stage": stage,
"count": sf["count"],
"raw_amount": sf["raw_amount"],
"raw_amount_display": format_currency(sf["raw_amount"]),
"weighted_amount": round(sf["weighted_amount"], 2),
"weighted_display": format_currency(sf["weighted_amount"]),
})
# 已成交金额
won_deals = [d for d in deals if d.get("stage") == "成交"]
won_amount = sum(d.get("amount", 0) for d in won_deals)
# 管道健康评分
pipeline_health = _calculate_pipeline_health(deals)
# 教练建议
coaching_tips = _get_coaching_tips()
result = {
"forecast_total": round(weighted_total, 2),
"forecast_display": format_currency(weighted_total),
"active_deals": len(active_deals),
"raw_pipeline": sum(d.get("amount", 0) for d in active_deals),
"raw_pipeline_display": format_currency(sum(d.get("amount", 0) for d in active_deals)),
"won_amount": won_amount,
"won_amount_display": format_currency(won_amount),
"won_deals": len(won_deals),
"stage_forecast": stage_forecast_display,
"top_deals": forecast_details[:10],
"pipeline_health": pipeline_health,
"coaching_tips": coaching_tips[:3],
}
# Mermaid 图表
bar_data = [
{"label": sf["stage"], "value": int(sf["weighted_amount"] / 10000)}
for sf in stage_forecast_display
]
result["mermaid_forecast"] = _generate_bar_chart(
"各阶段加权预测(万元)", bar_data, y_label="万元"
)
output_success(result)
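# A minimal, standalone sketch of the weighted forecast computed above:
# amount * (probability / 100), summed over deals that are neither won ("成交")
# nor lost ("流失"). The function name `weighted_forecast` and the sample stage
# names "谈判" and "方案" are assumptions for illustration, not taken from
# DEAL_STAGES.

```python
from typing import Any, Dict, List


def weighted_forecast(deals: List[Dict[str, Any]]) -> float:
    """Sum amount * probability over open deals (probability is a percentage)."""
    total = 0.0
    for deal in deals:
        # Closed deals (won or lost) are not part of the open pipeline.
        if deal.get("stage") in ("成交", "流失"):
            continue
        amount = deal.get("amount") or 0
        probability = (deal.get("probability") or 0) / 100.0
        total += amount * probability
    return round(total, 2)


# Hypothetical sample data: two open deals and one already won.
sample_deals = [
    {"stage": "谈判", "amount": 100000, "probability": 60},
    {"stage": "方案", "amount": 50000, "probability": 30},
    {"stage": "成交", "amount": 80000, "probability": 100},
]
print(weighted_forecast(sample_deals))
```

# The won deal is excluded, so only the two open deals contribute to the total.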
def monthly_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成月度销售报告。
Args:
data: 可选参数,支持 month(YYYY-MM)。
"""
if not require_paid_feature("advanced_analytics", "月度报告"):
return
deals = _get_deals()
if not deals:
output_error("暂无商机数据", code="NO_DATA")
return
data = data or {}
target_month = data.get("month", datetime.now().strftime("%Y-%m"))
# 本月新增
new_deals = [
d for d in deals
if d.get("created_at", "").startswith(target_month)
]
# 本月成交
won_deals = []
for d in deals:
if d.get("stage") != "成交":
continue
history = d.get("stage_history", [])
for h in history:
if h.get("stage") == "成交" and h.get("timestamp", "").startswith(target_month):
won_deals.append(d)
break
# 本月流失
lost_deals = []
for d in deals:
if d.get("stage") != "流失":
continue
history = d.get("stage_history", [])
for h in history:
if h.get("stage") == "流失" and h.get("timestamp", "").startswith(target_month):
lost_deals.append(d)
break
# 当前管道
active_deals = [
d for d in deals
if d.get("stage") not in ("成交", "流失")
]
new_amount = sum(d.get("amount", 0) for d in new_deals)
won_amount = sum(d.get("amount", 0) for d in won_deals)
lost_amount = sum(d.get("amount", 0) for d in lost_deals)
pipeline_amount = sum(d.get("amount", 0) for d in active_deals)
# 会议统计
meetings = _get_meetings()
month_meetings = [m for m in meetings if m.get("date", "").startswith(target_month)]
# 管道健康评分和教练建议
pipeline_health = _calculate_pipeline_health(deals)
coaching_tips = _get_coaching_tips()
result = {
"month": target_month,
"summary": {
"new_deals": len(new_deals),
"new_amount": new_amount,
"new_amount_display": format_currency(new_amount),
"won_deals": len(won_deals),
"won_amount": won_amount,
"won_amount_display": format_currency(won_amount),
"lost_deals": len(lost_deals),
"lost_amount": lost_amount,
"lost_amount_display": format_currency(lost_amount),
"active_pipeline": len(active_deals),
"pipeline_amount": pipeline_amount,
"pipeline_amount_display": format_currency(pipeline_amount),
"meetings": len(month_meetings),
"win_rate": round(
len(won_deals) / max(len(won_deals) + len(lost_deals), 1), 4
),
},
"pipeline_health": pipeline_health,
"coaching_tips": coaching_tips[:5],
}
# Mermaid 图表
pie_data = [
{"label": "成交", "value": len(won_deals)},
{"label": "流失", "value": len(lost_deals)},
{"label": "进行中", "value": len(active_deals)},
]
result["mermaid_overview"] = _generate_pie_chart(
f"{target_month} 商机状态分布", pie_data
)
bar_data = [
{"label": "新增", "value": int(new_amount / 10000)},
{"label": "成交", "value": int(won_amount / 10000)},
{"label": "流失", "value": int(lost_amount / 10000)},
{"label": "管道", "value": int(pipeline_amount / 10000)},
]
result["mermaid_amounts"] = _generate_bar_chart(
f"{target_month} 金额概览(万元)", bar_data, y_label="万元"
)
output_success(result)
def weekly_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成周度销售报告。
Args:
data: 可选参数,支持 week_start(YYYY-MM-DD,默认为本周一)。
"""
if not require_paid_feature("advanced_analytics", "周度报告"):
return
deals = _get_deals()
if not deals:
output_error("暂无商机数据", code="NO_DATA")
return
data = data or {}
week_start_str = data.get("week_start")
if week_start_str:
try:
week_start = datetime.strptime(week_start_str, "%Y-%m-%d")
except ValueError:
output_error("week_start 格式错误,请使用 YYYY-MM-DD", code="VALIDATION_ERROR")
return
else:
now = datetime.now()
week_start = now - timedelta(days=now.weekday())
week_end = week_start + timedelta(days=6)
ws = week_start.strftime("%Y-%m-%d")
we = week_end.strftime("%Y-%m-%d")
# 本周新增
new_deals = [
d for d in deals
if ws <= d.get("created_at", "")[:10] <= we
]
# 本周更新
updated_deals = [
d for d in deals
if ws <= d.get("updated_at", "")[:10] <= we
]
# 本周成交
won_deals = []
for d in deals:
if d.get("stage") != "成交":
continue
history = d.get("stage_history", [])
for h in history:
ts = h.get("timestamp", "")[:10]
if h.get("stage") == "成交" and ws <= ts <= we:
won_deals.append(d)
break
# 当前管道
active_deals = [
d for d in deals
if d.get("stage") not in ("成交", "流失")
]
# 风险商机
risks = _detect_risk_deals(deals)
# 会议统计
meetings = _get_meetings()
week_meetings = [
m for m in meetings
if ws <= m.get("date", "")[:10] <= we
]
# 管道健康评分和教练建议
pipeline_health = _calculate_pipeline_health(deals)
coaching_tips = _get_coaching_tips()
result = {
"week": f"{ws} ~ {we}",
"summary": {
"new_deals": len(new_deals),
"new_amount": sum(d.get("amount", 0) for d in new_deals),
"new_amount_display": format_currency(sum(d.get("amount", 0) for d in new_deals)),
"updated_deals": len(updated_deals),
"won_deals": len(won_deals),
"won_amount": sum(d.get("amount", 0) for d in won_deals),
"won_amount_display": format_currency(sum(d.get("amount", 0) for d in won_deals)),
"active_pipeline": len(active_deals),
"pipeline_amount": sum(d.get("amount", 0) for d in active_deals),
"pipeline_display": format_currency(sum(d.get("amount", 0) for d in active_deals)),
"risk_deals": len(risks),
"meetings": len(week_meetings),
},
"risk_deals": risks[:5],
"pipeline_health": pipeline_health,
"coaching_tips": coaching_tips[:3],
}
# Mermaid 图表
stage_data = []
for stage in DEAL_STAGES:
count = sum(1 for d in active_deals if d.get("stage") == stage)
if count > 0:
stage_data.append({"label": stage, "value": count})
if stage_data:
result["mermaid_pipeline"] = _generate_pie_chart("本周管道分布", stage_data)
output_success(result)
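# The weekly window used above (Monday through Sunday, compared as YYYY-MM-DD
# strings) can be sketched on its own. `week_bounds` and `in_week` are
# hypothetical helper names; the comparison relies on ISO dates sorting
# correctly as plain strings.

```python
from datetime import date, timedelta
from typing import Tuple


def week_bounds(day: date) -> Tuple[str, str]:
    """Return the Monday and Sunday of the week containing `day` as YYYY-MM-DD."""
    monday = day - timedelta(days=day.weekday())  # weekday(): Monday == 0
    sunday = monday + timedelta(days=6)
    return monday.isoformat(), sunday.isoformat()


def in_week(timestamp: str, bounds: Tuple[str, str]) -> bool:
    """Compare only the date prefix of an ISO timestamp against the window."""
    start, end = bounds
    return start <= timestamp[:10] <= end


bounds = week_bounds(date(2024, 1, 17))
print(bounds, in_week("2024-01-15T09:30:00", bounds))
```

# An empty timestamp never falls inside the window, which matches how records
# without `created_at` are skipped by the string comparison.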
def trends_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成趋势分析报告。
Args:
data: 可选参数,支持 months(分析几个月,默认 6)。
"""
if not require_paid_feature("advanced_analytics", "趋势分析"):
return
deals = _get_deals()
if not deals:
output_error("暂无商机数据", code="NO_DATA")
return
data = data or {}
months_count = data.get("months", 6)
try:
months_count = int(months_count)
except (TypeError, ValueError):
months_count = 6
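    # Build the list of the last `months_count` calendar months, oldest first.
    # Stepping back in fixed 30-day intervals can repeat or skip a month,
    # so count in whole months instead.
    months_count = max(months_count, 1)
    now = datetime.now()
    months = []
    for i in range(months_count - 1, -1, -1):
        total = now.year * 12 + (now.month - 1) - i
        months.append(f"{total // 12:04d}-{total % 12 + 1:02d}")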
# 各月新增商机数和金额
monthly_new = []
monthly_won = []
for month in months:
new_in_month = [
d for d in deals
if d.get("created_at", "").startswith(month)
]
won_in_month = []
for d in deals:
if d.get("stage") != "成交":
continue
for h in d.get("stage_history", []):
if h.get("stage") == "成交" and h.get("timestamp", "").startswith(month):
won_in_month.append(d)
break
monthly_new.append({
"month": month,
"count": len(new_in_month),
"amount": sum(d.get("amount", 0) for d in new_in_month),
})
monthly_won.append({
"month": month,
"count": len(won_in_month),
"amount": sum(d.get("amount", 0) for d in won_in_month),
})
result = {
"period": f"{months[0]} ~ {months[-1]}",
"monthly_new": monthly_new,
"monthly_won": monthly_won,
}
# Mermaid 图表
new_chart_data = [
{"label": m["month"][-5:], "value": m["count"]}
for m in monthly_new
]
result["mermaid_new_trend"] = _generate_line_chart(
"月度新增商机趋势", new_chart_data, y_label="数量"
)
won_chart_data = [
{"label": m["month"][-5:], "value": int(m["amount"] / 10000)}
for m in monthly_won
]
result["mermaid_won_trend"] = _generate_bar_chart(
"月度成交金额趋势(万元)", won_chart_data, y_label="万元"
)
output_success(result)
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer 销售管道报告")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"funnel": lambda: funnel_report(data),
"forecast": lambda: forecast_report(data),
"monthly": lambda: monthly_report(data),
"weekly": lambda: weekly_report(data),
"trends": lambda: trends_report(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/imap_email.py
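#!/usr/bin/env python3
"""
deal-closer native IMAP/SMTP email module

Sends and receives mail through the standard-library imaplib/smtplib modules,
so it works with any provider that offers IMAP and SMTP access.
No OAuth2 setup is required; it connects directly over IMAP/SMTP.
Modeled on the imap-smtp-email approach to provide general-purpose email integration.
"""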
import email
import email.header
import email.mime.multipart
import email.mime.text
import email.utils
import imaplib
import json
import os
import smtplib
import ssl
import sys
from datetime import datetime, timedelta
from email.header import decode_header
from typing import Any, Dict, List, Optional, Tuple
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
mask_email,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
)
# ============================================================
# 常量与配置
# ============================================================
EMAIL_CONFIG_FILE = "email_config.json"
# 默认端口
DEFAULT_IMAP_PORT = 993
DEFAULT_SMTP_PORT = 587
# 常见邮件服务商 IMAP/SMTP 配置
PROVIDER_CONFIGS = {
"qq": {
"imap_host": "imap.qq.com",
"imap_port": 993,
"smtp_host": "smtp.qq.com",
"smtp_port": 587,
},
"163": {
"imap_host": "imap.163.com",
"imap_port": 993,
"smtp_host": "smtp.163.com",
"smtp_port": 465,
},
"gmail": {
"imap_host": "imap.gmail.com",
"imap_port": 993,
"smtp_host": "smtp.gmail.com",
"smtp_port": 587,
},
"outlook": {
"imap_host": "outlook.office365.com",
"imap_port": 993,
"smtp_host": "smtp.office365.com",
"smtp_port": 587,
},
"aliyun": {
"imap_host": "imap.aliyun.com",
"imap_port": 993,
"smtp_host": "smtp.aliyun.com",
"smtp_port": 465,
},
}
# ============================================================
# 配置管理
# ============================================================
def _get_email_config() -> Dict[str, Any]:
"""读取邮件配置(不含密码)。"""
filepath = get_data_file(EMAIL_CONFIG_FILE)
if not os.path.exists(filepath):
return {}
data = read_json_file(filepath)
if isinstance(data, list):
return {}
return data
def _save_email_config(config: Dict[str, Any]) -> None:
"""保存邮件配置(不含密码)。"""
# 确保不保存密码
safe_config = dict(config)
safe_config.pop("password", None)
safe_config["last_updated"] = now_iso()
write_json_file(get_data_file(EMAIL_CONFIG_FILE), safe_config)
def _get_credentials() -> Dict[str, str]:
"""从环境变量获取邮件凭据。
Returns:
包含连接信息的字典。
"""
return {
"imap_host": os.environ.get("DC_IMAP_HOST", ""),
"imap_port": int(os.environ.get("DC_IMAP_PORT", str(DEFAULT_IMAP_PORT))),
"smtp_host": os.environ.get("DC_SMTP_HOST", ""),
"smtp_port": int(os.environ.get("DC_SMTP_PORT", str(DEFAULT_SMTP_PORT))),
"email_user": os.environ.get("DC_EMAIL_USER", ""),
"email_password": os.environ.get("DC_EMAIL_PASSWORD", ""),
}
def _auto_detect_provider(email_addr: str) -> Optional[Dict[str, Any]]:
    """Look up provider settings from the domain of an email address.

    Args:
        email_addr: Email address.
    Returns:
        The provider config dict, or None if the domain is not recognized.
    """
    if not email_addr or "@" not in email_addr:
        return None
    domain = email_addr.rsplit("@", 1)[-1].strip().lower()
    # Match whole domains (or their subdomains) only. A substring test such as
    # `"qq" in domain` would also match unrelated hosts and could send the
    # login to the wrong server.
    domain_to_provider = {
        "qq.com": "qq",
        "163.com": "163",
        "126.com": "163",
        "gmail.com": "gmail",
        "outlook.com": "outlook",
        "hotmail.com": "outlook",
        "aliyun.com": "aliyun",
    }
    for known_domain, provider in domain_to_provider.items():
        if domain == known_domain or domain.endswith("." + known_domain):
            return PROVIDER_CONFIGS[provider]
    return None
# ============================================================
# 邮件解析
# ============================================================
def _decode_header_value(value: str) -> str:
"""解码邮件头字段值,处理中文编码。
Args:
value: 原始头字段值。
Returns:
解码后的字符串。
"""
if not value:
return ""
try:
decoded_parts = decode_header(value)
result = []
for part, charset in decoded_parts:
if isinstance(part, bytes):
# 尝试用指定编码解码
if charset:
try:
result.append(part.decode(charset))
except (UnicodeDecodeError, LookupError):
result.append(part.decode("utf-8", errors="replace"))
else:
result.append(part.decode("utf-8", errors="replace"))
else:
result.append(str(part))
        # decode_header() keeps the whitespace between parts, so join without a separator
        return "".join(result)
except Exception:
return str(value)
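# A standalone sketch of RFC 2047 header decoding with the standard library,
# illustrating what the function above does for an encoded Chinese subject.
# `decode_mime_header` is a hypothetical name used only for this example.

```python
from email.header import decode_header


def decode_mime_header(value: str) -> str:
    """Decode an RFC 2047 encoded header into a plain string."""
    pieces = []
    for part, charset in decode_header(value):
        if isinstance(part, bytes):
            try:
                pieces.append(part.decode(charset or "utf-8"))
            except (UnicodeDecodeError, LookupError):
                # Unknown or wrong charset: fall back to lenient UTF-8
                pieces.append(part.decode("utf-8", errors="replace"))
        else:
            # Headers without encoded words come back as str
            pieces.append(part)
    return "".join(pieces)


print(decode_mime_header("=?utf-8?b?5L2g5aW9?="))
print(decode_mime_header("Quarterly report"))
```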
def _extract_plain_text(msg: email.message.Message) -> str:
"""从邮件消息中提取纯文本内容。
处理 multipart 邮件,优先提取 text/plain。
Args:
msg: email.message.Message 对象。
Returns:
纯文本内容。
"""
text_parts = []
if msg.is_multipart():
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
# 跳过附件
if "attachment" in content_disposition:
continue
if content_type == "text/plain":
payload = part.get_payload(decode=True)
if payload:
charset = part.get_content_charset() or "utf-8"
try:
text_parts.append(payload.decode(charset, errors="replace"))
except (UnicodeDecodeError, LookupError):
text_parts.append(payload.decode("utf-8", errors="replace"))
else:
content_type = msg.get_content_type()
if content_type == "text/plain":
payload = msg.get_payload(decode=True)
if payload:
charset = msg.get_content_charset() or "utf-8"
try:
text_parts.append(payload.decode(charset, errors="replace"))
except (UnicodeDecodeError, LookupError):
text_parts.append(payload.decode("utf-8", errors="replace"))
return "\n".join(text_parts)
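# A compact sketch of the same extraction, assuming only the standard library.
# It relies on Message.walk(), which yields the message itself for single-part
# mail, so one loop covers both branches. `plain_text_of` and the sample
# message are illustrative, not part of this module.

```python
from email.message import Message
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def plain_text_of(msg: Message) -> str:
    """Join every text/plain part that is not an attachment."""
    chunks = []
    for part in msg.walk():
        if part.get_content_type() != "text/plain":
            continue
        if "attachment" in str(part.get("Content-Disposition", "")):
            continue
        payload = part.get_payload(decode=True)  # bytes after transfer decoding
        if payload:
            charset = part.get_content_charset() or "utf-8"
            try:
                chunks.append(payload.decode(charset, errors="replace"))
            except LookupError:
                # Unknown charset name: fall back to UTF-8
                chunks.append(payload.decode("utf-8", errors="replace"))
    return "\n".join(chunks)


demo = MIMEMultipart("alternative")
demo.attach(MIMEText("你好,报价单见附件。", "plain", "utf-8"))
demo.attach(MIMEText("<p>你好</p>", "html", "utf-8"))
print(plain_text_of(demo))
```

# The HTML alternative is ignored, so only the text/plain body is returned.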
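def _parse_email_message(msg: email.message.Message,
                         msg_id: str = "") -> Dict[str, Any]:
    """Parse an email message into a dict.

    Args:
        msg: email.message.Message object.
        msg_id: Identifier supplied by the caller and returned under "uid".
            The callers in this module pass the message numbers returned by
            conn.search(), which are sequence numbers rather than IMAP UIDs.
    Returns:
        The parsed email as a dict.
    """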
subject = _decode_header_value(msg.get("Subject", ""))
from_addr = _decode_header_value(msg.get("From", ""))
to_addr = _decode_header_value(msg.get("To", ""))
date_str = msg.get("Date", "")
message_id = msg.get("Message-ID", "")
# 解析日期
parsed_date = ""
if date_str:
try:
dt = email.utils.parsedate_to_datetime(date_str)
parsed_date = dt.strftime("%Y-%m-%dT%H:%M:%S")
except (ValueError, TypeError):
parsed_date = date_str
# 提取正文
body = _extract_plain_text(msg)
# 限制预览长度
body_preview = body[:500] if body else ""
return {
"uid": msg_id,
"message_id": message_id,
"subject": subject,
"from": from_addr,
"to": to_addr,
"date": parsed_date,
"body_preview": body_preview,
"body_length": len(body),
}
# ============================================================
# IMAP 操作
# ============================================================
def _create_imap_connection(creds: Dict[str, Any]) -> Optional[imaplib.IMAP4_SSL]:
"""创建 IMAP SSL 连接。
Args:
creds: 连接凭据。
Returns:
IMAP4_SSL 连接对象,失败返回 None。
"""
host = creds.get("imap_host", "")
port = int(creds.get("imap_port", DEFAULT_IMAP_PORT))
user = creds.get("email_user", "")
password = creds.get("email_password", "")
if not all([host, user, password]):
return None
try:
context = ssl.create_default_context()
conn = imaplib.IMAP4_SSL(host, port, ssl_context=context)
conn.login(user, password)
return conn
except (imaplib.IMAP4.error, OSError, ssl.SSLError) as e:
return None
def _create_smtp_connection(creds: Dict[str, Any]) -> Optional[smtplib.SMTP]:
"""创建 SMTP 连接。
Args:
creds: 连接凭据。
Returns:
SMTP 连接对象,失败返回 None。
"""
host = creds.get("smtp_host", "")
port = int(creds.get("smtp_port", DEFAULT_SMTP_PORT))
user = creds.get("email_user", "")
password = creds.get("email_password", "")
if not all([host, user, password]):
return None
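    try:
        # Verify the server certificate on both paths; starttls() called
        # without a context does not verify the certificate by default.
        context = ssl.create_default_context()
        if port == 465:
            # Implicit TLS
            conn = smtplib.SMTP_SSL(host, port, timeout=30, context=context)
        else:
            # STARTTLS upgrade on a plain connection
            conn = smtplib.SMTP(host, port, timeout=30)
            conn.starttls(context=context)
        conn.login(user, password)
        return conn
    except (smtplib.SMTPException, OSError):
        # ssl.SSLError is a subclass of OSError
        return None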
# ============================================================
# 操作函数
# ============================================================
def connect_test(data: Optional[Dict[str, Any]] = None) -> None:
"""测试 IMAP/SMTP 连接。
可选字段: provider(自动填充服务商配置)
Args:
data: 可选参数。
"""
if not require_paid_feature("email_scan", "IMAP/SMTP 邮件"):
return
creds = _get_credentials()
data = data or {}
    # Fall back to provider auto-detection when either host is missing
    if (not creds.get("imap_host") or not creds.get("smtp_host")) and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
if not creds.get("imap_host"):
creds["imap_host"] = auto_config["imap_host"]
creds["imap_port"] = auto_config["imap_port"]
if not creds.get("smtp_host"):
creds["smtp_host"] = auto_config["smtp_host"]
creds["smtp_port"] = auto_config["smtp_port"]
# 也可手动指定 provider
provider = data.get("provider", "")
if provider and provider in PROVIDER_CONFIGS:
pc = PROVIDER_CONFIGS[provider]
if not creds.get("imap_host"):
creds["imap_host"] = pc["imap_host"]
creds["imap_port"] = pc["imap_port"]
if not creds.get("smtp_host"):
creds["smtp_host"] = pc["smtp_host"]
creds["smtp_port"] = pc["smtp_port"]
if not creds.get("email_user") or not creds.get("email_password"):
output_error(
"请设置以下环境变量:\n"
" DC_EMAIL_USER — 邮箱地址\n"
" DC_EMAIL_PASSWORD — 邮箱密码或授权码\n"
"可选:DC_IMAP_HOST, DC_IMAP_PORT, DC_SMTP_HOST, DC_SMTP_PORT",
code="NO_CREDENTIALS",
)
return
results = {"imap": False, "smtp": False, "imap_error": "", "smtp_error": ""}
# 测试 IMAP
try:
imap_conn = _create_imap_connection(creds)
if imap_conn:
results["imap"] = True
imap_conn.logout()
else:
results["imap_error"] = "连接失败,请检查 IMAP 配置"
except Exception as e:
results["imap_error"] = str(e)
# 测试 SMTP
try:
smtp_conn = _create_smtp_connection(creds)
if smtp_conn:
results["smtp"] = True
smtp_conn.quit()
else:
results["smtp_error"] = "连接失败,请检查 SMTP 配置"
except Exception as e:
results["smtp_error"] = str(e)
# 保存配置(不含密码)
config = {
"imap_host": creds.get("imap_host", ""),
"imap_port": creds.get("imap_port", DEFAULT_IMAP_PORT),
"smtp_host": creds.get("smtp_host", ""),
"smtp_port": creds.get("smtp_port", DEFAULT_SMTP_PORT),
"email_user": mask_email(creds.get("email_user", "")),
"imap_connected": results["imap"],
"smtp_connected": results["smtp"],
}
_save_email_config(config)
if results["imap"] and results["smtp"]:
output_success({
"message": "IMAP 和 SMTP 连接测试成功!",
"imap": {"connected": True, "host": creds.get("imap_host", "")},
"smtp": {"connected": True, "host": creds.get("smtp_host", "")},
"user": mask_email(creds.get("email_user", "")),
})
elif results["imap"] or results["smtp"]:
output_success({
"message": "部分连接成功",
"imap": {
"connected": results["imap"],
"host": creds.get("imap_host", ""),
"error": results["imap_error"],
},
"smtp": {
"connected": results["smtp"],
"host": creds.get("smtp_host", ""),
"error": results["smtp_error"],
},
"user": mask_email(creds.get("email_user", "")),
})
else:
output_error(
f"连接失败。\n"
f"IMAP: {results['imap_error']}\n"
f"SMTP: {results['smtp_error']}",
code="CONNECTION_FAILED",
)
def fetch_inbox(data: Optional[Dict[str, Any]] = None) -> None:
"""获取收件箱最近邮件。
可选字段: count(数量,默认 20)、folder(文件夹,默认 INBOX)
Args:
data: 可选参数。
"""
if not require_paid_feature("email_scan", "IMAP邮件获取"):
return
data = data or {}
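    # Reject non-numeric input and keep count >= 1 (a slice of [-0:] would return every message)
    try:
        count = max(int(data.get("count", 20)), 1)
    except (TypeError, ValueError):
        count = 20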
folder = data.get("folder", "INBOX")
creds = _get_credentials()
if not creds.get("imap_host") and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
creds["imap_host"] = auto_config["imap_host"]
creds["imap_port"] = auto_config["imap_port"]
conn = _create_imap_connection(creds)
if not conn:
output_error(
"IMAP 连接失败,请先使用 connect 测试连接配置",
code="CONNECTION_FAILED",
)
return
try:
status, _ = conn.select(folder, readonly=True)
if status != "OK":
output_error(f"无法打开文件夹: {folder}", code="FOLDER_ERROR")
conn.logout()
return
# 搜索所有邮件
status, data_list = conn.search(None, "ALL")
if status != "OK":
output_error("搜索邮件失败", code="SEARCH_ERROR")
conn.logout()
return
msg_ids = data_list[0].split()
if not msg_ids:
output_success({
"message": "收件箱为空",
"total": 0,
"emails": [],
})
conn.logout()
return
# 取最近 N 封
recent_ids = msg_ids[-count:]
recent_ids.reverse() # 最新的在前
emails_list = []
for mid in recent_ids:
try:
status, msg_data = conn.fetch(mid, "(RFC822)")
if status == "OK" and msg_data[0]:
raw = msg_data[0][1]
if isinstance(raw, bytes):
msg = email.message_from_bytes(raw)
parsed = _parse_email_message(msg, mid.decode("utf-8"))
emails_list.append(parsed)
except Exception:
continue
conn.logout()
output_success({
"message": f"获取到 {len(emails_list)} 封邮件",
"folder": folder,
"total": len(emails_list),
"emails": emails_list,
})
except Exception as e:
try:
conn.logout()
except Exception:
pass
output_error(f"获取邮件失败: {e}", code="FETCH_ERROR")
def search_emails(data: Dict[str, Any]) -> None:
"""搜索邮件。
可选字段: subject, from_addr, since(YYYY-MM-DD), before, folder
Args:
data: 搜索参数字典。
"""
if not require_paid_feature("email_scan", "IMAP邮件搜索"):
return
subject = data.get("subject", "")
from_addr = data.get("from_addr", "")
since = data.get("since", "")
before = data.get("before", "")
folder = data.get("folder", "INBOX")
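    # Reject non-numeric input and keep max_results >= 1
    try:
        max_results = max(int(data.get("max_results", 50)), 1)
    except (TypeError, ValueError):
        max_results = 50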
creds = _get_credentials()
if not creds.get("imap_host") and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
creds["imap_host"] = auto_config["imap_host"]
creds["imap_port"] = auto_config["imap_port"]
conn = _create_imap_connection(creds)
if not conn:
output_error("IMAP 连接失败", code="CONNECTION_FAILED")
return
try:
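        status, _ = conn.select(folder, readonly=True)
        if status != "OK":
            output_error(f"无法打开文件夹: {folder}", code="FOLDER_ERROR")
            conn.logout()
            return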
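        # Build the IMAP search criteria. Backslashes and double quotes are
        # escaped so user input cannot break out of the quoted string.
        def _quote(value: str) -> str:
            return '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"'

        criteria = []
        if subject:
            criteria.append(f"SUBJECT {_quote(subject)}")
        if from_addr:
            criteria.append(f"FROM {_quote(from_addr)}")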
if since:
try:
dt = datetime.strptime(since, "%Y-%m-%d")
imap_date = dt.strftime("%d-%b-%Y")
criteria.append(f"SINCE {imap_date}")
except ValueError:
pass
if before:
try:
dt = datetime.strptime(before, "%Y-%m-%d")
imap_date = dt.strftime("%d-%b-%Y")
criteria.append(f"BEFORE {imap_date}")
except ValueError:
pass
search_str = " ".join(criteria) if criteria else "ALL"
status, data_list = conn.search(None, search_str)
if status != "OK":
output_error("搜索失败", code="SEARCH_ERROR")
conn.logout()
return
msg_ids = data_list[0].split()
msg_ids = msg_ids[-max_results:]
msg_ids.reverse()
emails_list = []
for mid in msg_ids:
try:
status, msg_data = conn.fetch(mid, "(RFC822)")
if status == "OK" and msg_data[0]:
raw = msg_data[0][1]
if isinstance(raw, bytes):
msg = email.message_from_bytes(raw)
parsed = _parse_email_message(msg, mid.decode("utf-8"))
emails_list.append(parsed)
except Exception:
continue
conn.logout()
output_success({
"message": f"搜索到 {len(emails_list)} 封邮件",
"search_criteria": search_str,
"total": len(emails_list),
"emails": emails_list,
})
except Exception as e:
try:
conn.logout()
except Exception:
pass
output_error(f"搜索失败: {e}", code="SEARCH_ERROR")
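# A standalone sketch of how the search string above can be assembled. The
# helper names (`imap_quote`, `imap_date`, `build_search`) are hypothetical.
# Two choices differ from the function above and are assumptions of this
# sketch: month names come from a fixed English table instead of
# strftime("%b"), so the result does not depend on the process locale, and an
# invalid date raises ValueError instead of being skipped silently.

```python
from datetime import datetime

# English month abbreviations as required by the IMAP date format.
_MONTHS = ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
           "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")


def imap_quote(value: str) -> str:
    """Wrap a value in double quotes, escaping backslashes and quotes."""
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def imap_date(iso_date: str) -> str:
    """Convert YYYY-MM-DD to the DD-Mon-YYYY form used by SINCE/BEFORE."""
    dt = datetime.strptime(iso_date, "%Y-%m-%d")
    return f"{dt.day:02d}-{_MONTHS[dt.month - 1]}-{dt.year:04d}"


def build_search(subject: str = "", from_addr: str = "",
                 since: str = "", before: str = "") -> str:
    """Combine the optional filters; IMAP ANDs space-separated criteria."""
    criteria = []
    if subject:
        criteria.append(f"SUBJECT {imap_quote(subject)}")
    if from_addr:
        criteria.append(f"FROM {imap_quote(from_addr)}")
    if since:
        criteria.append(f"SINCE {imap_date(since)}")
    if before:
        criteria.append(f"BEFORE {imap_date(before)}")
    return " ".join(criteria) if criteria else "ALL"


print(build_search(subject='a "b"', since="2024-01-05"))
```

# Non-ASCII search terms need extra handling (a CHARSET argument and
# byte-encoded criteria) that this sketch does not cover.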
def send_email(data: Dict[str, Any]) -> None:
"""发送邮件。
必填字段: to, subject, body
可选字段: cc, bcc
Args:
data: 邮件参数字典。
"""
if not require_paid_feature("email_scan", "SMTP邮件发送"):
return
to_addr = data.get("to", "")
subject = data.get("subject", "")
body = data.get("body", "")
if not to_addr:
output_error("收件人(to)为必填字段", code="VALIDATION_ERROR")
return
if not subject:
output_error("主题(subject)为必填字段", code="VALIDATION_ERROR")
return
if not body:
output_error("正文(body)为必填字段", code="VALIDATION_ERROR")
return
creds = _get_credentials()
if not creds.get("smtp_host") and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
creds["smtp_host"] = auto_config["smtp_host"]
creds["smtp_port"] = auto_config["smtp_port"]
from_addr = creds.get("email_user", "")
if not from_addr:
output_error("未配置发件人邮箱(DC_EMAIL_USER)", code="NO_CREDENTIALS")
return
# 构建邮件
msg = email.mime.multipart.MIMEMultipart()
msg["From"] = from_addr
msg["To"] = to_addr
msg["Subject"] = subject
msg["Date"] = email.utils.formatdate(localtime=True)
cc = data.get("cc", "")
if cc:
msg["Cc"] = cc
# 添加正文
msg.attach(email.mime.text.MIMEText(body, "plain", "utf-8"))
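    # Envelope recipients: To, Cc and Bcc may each hold several comma-separated addresses
    recipients = [a.strip() for a in to_addr.split(",") if a.strip()]
    if cc:
        recipients.extend([a.strip() for a in cc.split(",") if a.strip()])
    bcc = data.get("bcc", "")
    if bcc:
        recipients.extend([a.strip() for a in bcc.split(",") if a.strip()])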
# 发送
conn = _create_smtp_connection(creds)
if not conn:
output_error("SMTP 连接失败", code="CONNECTION_FAILED")
return
try:
conn.sendmail(from_addr, recipients, msg.as_string())
conn.quit()
output_success({
"message": f"邮件已发送至 {mask_email(to_addr)}",
"to": mask_email(to_addr),
"subject": subject,
"sent_at": now_iso(),
})
    except (smtplib.SMTPException, OSError) as e:
try:
conn.quit()
except Exception:
pass
output_error(f"发送失败: {e}", code="SEND_ERROR")
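# Splitting on commas, as above, breaks when a display name itself contains a
# comma. A sketch of a more tolerant approach using email.utils.getaddresses
# follows; `envelope_recipients` is a hypothetical helper, and dropping
# duplicate addresses is an assumption of this example rather than behavior of
# send_email.

```python
from email.utils import getaddresses
from typing import List


def envelope_recipients(to: str = "", cc: str = "", bcc: str = "") -> List[str]:
    """Collect the bare addresses from To, Cc and Bcc, without duplicates."""
    # Skip empty fields so the parser only sees real header values.
    fields = [f for f in (to, cc, bcc) if f and f.strip()]
    seen = set()
    result = []
    for _name, addr in getaddresses(fields):
        key = addr.lower()
        if addr and key not in seen:
            seen.add(key)
            result.append(addr)
    return result


print(envelope_recipients("a@x.com, B <b@y.com>", "", "A@x.com"))
print(envelope_recipients('"Zhang, San" <zs@x.com>'))
```

# Only the bare addresses are returned, which is the form smtplib's sendmail()
# expects for its recipient list.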
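def reply_email(data: Dict[str, Any]) -> None:
    """Send a reply.

    Required fields: to, body
    Optional fields: subject (a "Re: " prefix is added if missing),
    original_message_id (used for the In-Reply-To and References headers)
    Args:
        data: Reply parameters.
    """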
if not require_paid_feature("email_scan", "SMTP邮件回复"):
return
body = data.get("body", "")
to_addr = data.get("to", "")
subject = data.get("subject", "")
original_id = data.get("original_message_id", "")
if not body:
output_error("回复正文(body)为必填字段", code="VALIDATION_ERROR")
return
if not to_addr:
output_error("收件人(to)为必填字段", code="VALIDATION_ERROR")
return
    # Add a "Re:" prefix unless one is already present (case-insensitive)
    if subject and not subject.lower().startswith("re:"):
subject = f"Re: {subject}"
elif not subject:
subject = "Re: (无主题)"
# 构建回复邮件
creds = _get_credentials()
if not creds.get("smtp_host") and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
creds["smtp_host"] = auto_config["smtp_host"]
creds["smtp_port"] = auto_config["smtp_port"]
from_addr = creds.get("email_user", "")
if not from_addr:
output_error("未配置发件人邮箱", code="NO_CREDENTIALS")
return
msg = email.mime.multipart.MIMEMultipart()
msg["From"] = from_addr
msg["To"] = to_addr
msg["Subject"] = subject
msg["Date"] = email.utils.formatdate(localtime=True)
if original_id:
msg["In-Reply-To"] = original_id
msg["References"] = original_id
msg.attach(email.mime.text.MIMEText(body, "plain", "utf-8"))
conn = _create_smtp_connection(creds)
if not conn:
output_error("SMTP 连接失败", code="CONNECTION_FAILED")
return
try:
conn.sendmail(from_addr, [to_addr], msg.as_string())
conn.quit()
output_success({
"message": f"回复已发送至 {mask_email(to_addr)}",
"to": mask_email(to_addr),
"subject": subject,
"sent_at": now_iso(),
})
except smtplib.SMTPException as e:
try:
conn.quit()
except Exception:
pass
output_error(f"回复发送失败: {e}", code="SEND_ERROR")
def list_folders(data: Optional[Dict[str, Any]] = None) -> None:
"""列出邮箱文件夹。
Args:
data: 可选参数。
"""
if not require_paid_feature("email_scan", "IMAP文件夹列表"):
return
creds = _get_credentials()
if not creds.get("imap_host") and creds.get("email_user"):
auto_config = _auto_detect_provider(creds["email_user"])
if auto_config:
creds["imap_host"] = auto_config["imap_host"]
creds["imap_port"] = auto_config["imap_port"]
conn = _create_imap_connection(creds)
if not conn:
output_error("IMAP 连接失败", code="CONNECTION_FAILED")
return
try:
status, folder_list = conn.list()
if status != "OK":
output_error("获取文件夹列表失败", code="FOLDER_ERROR")
conn.logout()
return
folders = []
for item in folder_list:
if isinstance(item, bytes):
decoded = item.decode("utf-8", errors="replace")
# 解析 IMAP LIST 响应格式
# 格式: (\\flags) "delimiter" "name"
parts = decoded.split('"')
if len(parts) >= 3:
folder_name = parts[-2] if parts[-1].strip() == "" else parts[-1].strip()
if not folder_name:
folder_name = parts[-2]
folders.append(folder_name)
else:
folders.append(decoded)
conn.logout()
output_success({
"message": f"共 {len(folders)} 个文件夹",
"folders": folders,
})
except Exception as e:
try:
conn.logout()
except Exception:
pass
output_error(f"获取文件夹失败: {e}", code="FOLDER_ERROR")
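# The LIST parsing above splits on double quotes, which returns the whole raw
# line when the delimiter is NIL. A regex-based sketch follows, assuming
# response lines of the form `(flags) "delimiter" name` or `(flags) NIL name`;
# `parse_list_line` is a hypothetical helper. Mailbox names stay in whatever
# encoding the server sent (non-ASCII names usually arrive in modified UTF-7,
# which this sketch does not decode).

```python
import re
from typing import Optional

# flags in parentheses, then the delimiter (quoted string or NIL), then the name
_LIST_RE = re.compile(
    r'^\((?P<flags>[^)]*)\)\s+(?P<delim>"(?:[^"\\]|\\.)*"|NIL)\s+(?P<name>.+)$'
)


def parse_list_line(line: bytes) -> Optional[str]:
    """Return the mailbox name from one LIST response line, or None."""
    match = _LIST_RE.match(line.decode("utf-8", errors="replace").strip())
    if not match:
        return None
    name = match.group("name").strip()
    if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
        # Strip the quotes and undo backslash escapes inside the quoted string
        name = re.sub(r"\\(.)", r"\1", name[1:-1])
    return name


for raw in (b'(\\HasNoChildren) "/" "Sent Items"',
            b'(\\HasNoChildren) "/" INBOX',
            b'(\\Noselect) NIL Archive'):
    print(parse_list_line(raw))
```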
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer IMAP/SMTP 原生邮件")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"connect": lambda: connect_test(data),
"fetch-inbox": lambda: fetch_inbox(data),
"search": lambda: search_emails(data or {}),
"send": lambda: send_email(data or {}),
"reply": lambda: reply_email(data or {}),
"list-folders": lambda: list_folders(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/email_scanner.py
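#!/usr/bin/env python3
"""
deal-closer email scanning module (paid feature)

Scans mail in Gmail / Outlook mailboxes, extracts deal signals, and links them to deal records.
Supports the Gmail API and Outlook API with OAuth2 authentication.
"""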
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from urllib.request import Request, urlopen
from urllib.parse import urlencode
from urllib.error import URLError
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
mask_email,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
calculate_days_since,
)
# 延迟导入 IMAP 和学习模块(避免循环依赖)
_imap_module = None
_learning_module = None
def _get_imap_module():
"""延迟加载 imap_email 模块。"""
global _imap_module
if _imap_module is None:
try:
import imap_email as _mod
_imap_module = _mod
except ImportError:
_imap_module = False
return _imap_module if _imap_module is not False else None
def _get_learning_module():
"""延迟加载 learning_engine 模块。"""
global _learning_module
if _learning_module is None:
try:
import learning_engine as _mod
_learning_module = _mod
except ImportError:
_learning_module = False
return _learning_module if _learning_module is not False else None
# ============================================================
# 常量与配置
# ============================================================
EMAILS_FILE = "emails.json"
DEALS_FILE = "deals.json"
# 信号关键词定义
POSITIVE_KEYWORDS = [
"同意", "可以", "没问题", "感兴趣", "非常好", "合作", "签约",
"确认", "批准", "通过", "接受", "agree", "interested", "approve",
"confirmed", "accept", "deal", "go ahead", "proceed", "sign",
"好的", "行", "成交", "下单", "购买", "采购",
]
NEGATIVE_KEYWORDS = [
"推迟", "延期", "考虑", "再看看", "暂时不", "预算不够",
"竞争对手", "其他方案", "不合适", "太贵", "价格高",
"delay", "postpone", "competitor", "budget", "expensive",
"not now", "reconsider", "cancel", "暂缓", "放弃", "取消",
]
NEUTRAL_KEYWORDS = [
"了解", "咨询", "请问", "资料", "方案", "报价",
"详情", "介绍", "信息", "inquiry", "information",
"question", "brochure", "proposal", "quote", "什么时候",
]
# 信号类型常量
SIGNAL_POSITIVE = "POSITIVE"
SIGNAL_NEGATIVE = "NEGATIVE"
SIGNAL_NEUTRAL = "NEUTRAL"
# ============================================================
# 数据操作
# ============================================================
def _get_emails() -> List[Dict[str, Any]]:
"""读取所有邮件记录。"""
return read_json_file(get_data_file(EMAILS_FILE))
def _save_emails(emails: List[Dict[str, Any]]) -> None:
"""保存邮件记录到文件。"""
write_json_file(get_data_file(EMAILS_FILE), emails)
def _get_deals() -> List[Dict[str, Any]]:
"""读取所有商机数据。"""
return read_json_file(get_data_file(DEALS_FILE))
# ============================================================
# Gmail API 集成
# ============================================================
def _load_gmail_credentials() -> Optional[Dict[str, Any]]:
"""加载 Gmail OAuth2 凭据文件。
从 DC_GMAIL_CREDENTIALS 环境变量指定的路径读取凭据。
Returns:
凭据字典,若文件不存在或无效则返回 None。
"""
cred_path = os.environ.get("DC_GMAIL_CREDENTIALS", "")
if not cred_path or not os.path.exists(cred_path):
return None
try:
with open(cred_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return None
def _fetch_gmail_messages(credentials: Dict[str, Any], query: str = "", max_results: int = 50) -> List[Dict[str, Any]]:
    """List messages through the Gmail API.
    Args:
        credentials: OAuth2 credential dict; must contain access_token.
        query: Gmail search query string.
        max_results: Maximum number of messages to return.
    Returns:
        List of message stubs (id and threadId); empty on any failure.
    """
    access_token = credentials.get("access_token", "")
    if not access_token:
        return []
    params = {"maxResults": max_results}
    if query:
        params["q"] = query
    url = f"https://gmail.googleapis.com/gmail/v1/users/me/messages?{urlencode(params)}"
    headers = {"Authorization": f"Bearer {access_token}"}
    try:
        req = Request(url, headers=headers)
        with urlopen(req, timeout=30) as resp:
            data = json.loads(resp.read().decode("utf-8"))
        return data.get("messages", [])
    except Exception:
        # Best-effort: network, HTTP, and JSON errors all yield an empty list.
        return []
def _get_gmail_message_detail(credentials: Dict[str, Any], message_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the metadata of a single Gmail message.
    Args:
        credentials: OAuth2 credential dict; must contain access_token.
        message_id: Gmail message ID.
    Returns:
        Message detail dict (headers and snippet, no body), or None on failure.
    """
    access_token = credentials.get("access_token", "")
    if not access_token:
        return None
    url = f"https://gmail.googleapis.com/gmail/v1/users/me/messages/{message_id}?format=metadata"
    headers = {"Authorization": f"Bearer {access_token}"}
    try:
        req = Request(url, headers=headers)
        with urlopen(req, timeout=30) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except Exception:
        # Best-effort: network, HTTP, and JSON errors all yield None.
        return None
# ============================================================
# Outlook API 集成
# ============================================================
def _get_outlook_credentials() -> Optional[Dict[str, str]]:
"""获取 Outlook API 凭据。
从环境变量读取 DC_OUTLOOK_CLIENT_ID 和 DC_OUTLOOK_SECRET。
Returns:
凭据字典,若缺少必要环境变量则返回 None。
"""
client_id = os.environ.get("DC_OUTLOOK_CLIENT_ID", "")
secret = os.environ.get("DC_OUTLOOK_SECRET", "")
if not client_id or not secret:
return None
return {"client_id": client_id, "client_secret": secret}
def _fetch_outlook_messages(credentials: Dict[str, str], query: str = "", max_results: int = 50) -> List[Dict[str, Any]]:
    """List messages through the Outlook (Microsoft Graph) API.
    Args:
        credentials: Dict with client_id and client_secret. Not used for the request itself; the bearer token is read from DC_OUTLOOK_ACCESS_TOKEN.
        query: Search query string.
        max_results: Maximum number of messages to return.
    Returns:
        List of messages; empty if no access token is set or the request fails.
    """
    # Note: a real integration must complete the OAuth2 flow to obtain an access token.
    # This is only the API call scaffold; the user supplies the token after authorizing.
    access_token = os.environ.get("DC_OUTLOOK_ACCESS_TOKEN", "")
    if not access_token:
        return []
    params = {"$top": max_results}
    if query:
        params["$search"] = f'"{query}"'
    url = f"https://graph.microsoft.com/v1.0/me/messages?{urlencode(params)}"
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
    try:
        req = Request(url, headers=headers)
        with urlopen(req, timeout=30) as resp:
            data = json.loads(resp.read().decode("utf-8"))
        return data.get("value", [])
    except Exception:
        # Best-effort: network, HTTP, and JSON errors all yield an empty list.
        return []
# ============================================================
# IMAP 扫描
# ============================================================
def _scan_imap(config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Scan a mailbox over IMAP.
    Uses the imap_email module to fetch the most recent messages in a folder.
    Args:
        config: Optional settings: count (number of messages), folder (mailbox name).
    Returns:
        List of scanned messages; empty if IMAP is unavailable or misconfigured.
    """
    imap_mod = _get_imap_module()
    if imap_mod is None:
        return []
    config = config or {}
    creds = imap_mod._get_credentials()
    # Credentials must be present
    if not creds.get("email_user") or not creds.get("email_password"):
        return []
    # Auto-detect the provider's IMAP settings when no host is configured
    if not creds.get("imap_host"):
        auto_config = imap_mod._auto_detect_provider(creds["email_user"])
        if auto_config:
            creds["imap_host"] = auto_config["imap_host"]
            creds["imap_port"] = auto_config["imap_port"]
    if not creds.get("imap_host"):
        return []
    conn = imap_mod._create_imap_connection(creds)
    if not conn:
        return []
    scanned: List[Dict[str, Any]] = []
    try:
        import email as email_stdlib
        folder = config.get("folder", "INBOX")
        count = int(config.get("count", 50))
        status, _ = conn.select(folder, readonly=True)
        if status != "OK":
            return scanned
        status, data_list = conn.search(None, "ALL")
        # Guard count <= 0 explicitly: msg_ids[-0:] would select every message
        if status != "OK" or count <= 0:
            return scanned
        msg_ids = data_list[0].split()
        # Take the last `count` IDs and visit them newest first
        recent_ids = msg_ids[-count:]
        recent_ids.reverse()
        for mid in recent_ids:
            try:
                status, msg_data = conn.fetch(mid, "(RFC822)")
                if status == "OK" and msg_data[0]:
                    raw = msg_data[0][1]
                    if isinstance(raw, bytes):
                        msg = email_stdlib.message_from_bytes(raw)
                        parsed = imap_mod._parse_email_message(
                            msg, mid.decode("utf-8")
                        )
                        scanned.append({
                            "provider": "imap",
                            "message_id": parsed.get("message_id", mid.decode("utf-8")),
                            "subject": parsed.get("subject", ""),
                            "from": parsed.get("from", ""),
                            "date": parsed.get("date", ""),
                            "snippet": parsed.get("body_preview", "")[:200],
                        })
            except Exception:
                # Skip messages that fail to fetch or parse
                continue
    except Exception:
        # Best-effort scan: keep whatever was collected before the failure
        pass
    finally:
        # Always close the connection, whichever path was taken above
        try:
            conn.logout()
        except Exception:
            pass
    return scanned
def _record_scan_patterns(scanned: List[Dict[str, Any]]) -> None:
"""将扫描结果中的模式记录到学习引擎。
Args:
scanned: 扫描到的邮件列表。
"""
learning_mod = _get_learning_module()
if learning_mod is None or not scanned:
return
try:
# 统计信号来源和类型
provider_counts: Dict[str, int] = {}
for item in scanned:
provider = item.get("provider", "unknown")
provider_counts[provider] = provider_counts.get(provider, 0) + 1
# 记录扫描模式到学习数据
learning_data = learning_mod._get_learning_data()
patterns = learning_data.get("patterns", [])
# 检查是否已有今日的扫描记录
today = today_str()
existing_today = [
p for p in patterns
if p.get("category") == "email_scan"
and p.get("recorded_at", "").startswith(today)
]
if not existing_today:
pattern = {
"id": generate_id("LP"),
"category": "email_scan",
"description": (
f"邮件扫描:共 {len(scanned)} 封,"
f"来源分布 {json.dumps(provider_counts, ensure_ascii=False)}"
),
"success_rate": 0.5,
"applicable_stages": [],
"notes": f"自动记录于 {today}",
"recorded_at": now_iso(),
}
patterns.append(pattern)
learning_data["patterns"] = patterns
learning_mod._save_learning_data(learning_data)
except Exception:
# 学习记录失败不影响主流程
pass
# ============================================================
# 信号分析
# ============================================================
def _analyze_signal(text: str) -> Tuple[str, float, List[str]]:
"""分析文本中的商机信号。
扫描文本内容,根据关键词匹配判断信号类型。
Args:
text: 待分析的文本内容。
Returns:
(信号类型, 置信度, 匹配的关键词列表) 元组。
"""
text_lower = text.lower()
matched_positive = []
matched_negative = []
matched_neutral = []
for kw in POSITIVE_KEYWORDS:
if kw.lower() in text_lower:
matched_positive.append(kw)
for kw in NEGATIVE_KEYWORDS:
if kw.lower() in text_lower:
matched_negative.append(kw)
for kw in NEUTRAL_KEYWORDS:
if kw.lower() in text_lower:
matched_neutral.append(kw)
pos_count = len(matched_positive)
neg_count = len(matched_negative)
neu_count = len(matched_neutral)
total = pos_count + neg_count + neu_count
if total == 0:
return SIGNAL_NEUTRAL, 0.0, []
# 根据匹配数量和比例判断信号类型
if pos_count > neg_count and pos_count >= neu_count:
confidence = min(pos_count / max(total, 1), 1.0)
return SIGNAL_POSITIVE, round(confidence, 2), matched_positive
elif neg_count > pos_count and neg_count >= neu_count:
confidence = min(neg_count / max(total, 1), 1.0)
return SIGNAL_NEGATIVE, round(confidence, 2), matched_negative
else:
confidence = min(neu_count / max(total, 1), 1.0)
return SIGNAL_NEUTRAL, round(confidence, 2), matched_neutral
def _extract_email_addresses(text: str) -> List[str]:
"""从文本中提取邮箱地址。
Args:
text: 待搜索的文本。
Returns:
匹配到的邮箱地址列表。
"""
pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
return re.findall(pattern, text)
def _match_deal_by_contact(email_addr: str, deals: List[Dict]) -> Optional[Dict]:
"""根据邮箱地址匹配对应的商机。
Args:
email_addr: 邮箱地址。
deals: 商机列表。
Returns:
匹配到的商机字典,未匹配返回 None。
"""
email_lower = email_addr.lower().strip()
for deal in deals:
deal_email = deal.get("contact_email", "").lower().strip()
if deal_email and deal_email == email_lower:
return deal
return None
# ============================================================
# 操作函数
# ============================================================
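def scan_emails(data: Optional[Dict[str, Any]] = None) -> None:
"""Scan mailboxes and store the messages found.
Supports Gmail, Outlook, and IMAP.
When no provider is given, every source with usable credentials is scanned.
Args:
data: Optional parameters: provider (gmail/outlook/imap), query, max_results. The query is applied to Gmail and Outlook only; for IMAP, max_results is passed as the message count.
"""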
if not require_paid_feature("email_scan", "邮件扫描"):
return
data = data or {}
provider = data.get("provider", "").lower()
query = data.get("query", "")
max_results = data.get("max_results", 50)
scanned = []
# 尝试 Gmail
if provider in ("", "gmail"):
gmail_creds = _load_gmail_credentials()
if gmail_creds:
messages = _fetch_gmail_messages(gmail_creds, query=query, max_results=max_results)
for msg in messages:
msg_id = msg.get("id", "")
detail = _get_gmail_message_detail(gmail_creds, msg_id)
if detail:
headers = detail.get("payload", {}).get("headers", [])
subject = ""
from_addr = ""
date_str = ""
for h in headers:
h_name = h.get("name", "").lower()
if h_name == "subject":
subject = h.get("value", "")
elif h_name == "from":
from_addr = h.get("value", "")
elif h_name == "date":
date_str = h.get("value", "")
scanned.append({
"provider": "gmail",
"message_id": msg_id,
"subject": subject,
"from": from_addr,
"date": date_str,
"snippet": detail.get("snippet", ""),
})
# 尝试 Outlook
if provider in ("", "outlook"):
outlook_creds = _get_outlook_credentials()
if outlook_creds:
messages = _fetch_outlook_messages(outlook_creds, query=query, max_results=max_results)
for msg in messages:
scanned.append({
"provider": "outlook",
"message_id": msg.get("id", ""),
"subject": msg.get("subject", ""),
"from": msg.get("from", {}).get("emailAddress", {}).get("address", ""),
"date": msg.get("receivedDateTime", ""),
"snippet": msg.get("bodyPreview", ""),
})
# 尝试 IMAP(任意邮箱服务商)
if provider in ("", "imap"):
imap_results = _scan_imap({"count": max_results})
scanned.extend(imap_results)
if not scanned:
# 无法连接任何邮箱,提示配置
output_error(
"未能连接到任何邮箱。请确认以下环境变量已正确配置:\n"
" Gmail: DC_GMAIL_CREDENTIALS(OAuth2 凭据文件路径)\n"
" Outlook: DC_OUTLOOK_CLIENT_ID + DC_OUTLOOK_SECRET + DC_OUTLOOK_ACCESS_TOKEN\n"
" IMAP: DC_EMAIL_USER + DC_EMAIL_PASSWORD(+ 可选 DC_IMAP_HOST)\n"
"详见 references/email-setup-guide.md",
code="NO_EMAIL_SOURCE",
)
return
# 存储扫描结果
existing_emails = _get_emails()
existing_ids = {e.get("message_id") for e in existing_emails}
new_count = 0
for item in scanned:
if item["message_id"] not in existing_ids:
email_record = {
"id": generate_id("E"),
"provider": item["provider"],
"message_id": item["message_id"],
"subject": item["subject"],
"from_address": item["from"],
"date": item["date"],
"snippet": item["snippet"],
"signal": None,
"signal_confidence": 0.0,
"matched_keywords": [],
"linked_deal_id": None,
"scanned_at": now_iso(),
}
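existing_emails.append(email_record)
# Track the ID so a message repeated within this scan is stored only once
existing_ids.add(item["message_id"])
new_count += 1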
_save_emails(existing_emails)
# 记录扫描模式到学习引擎
_record_scan_patterns(scanned)
output_success({
"message": f"扫描完成:发现 {len(scanned)} 封邮件,新增 {new_count} 封",
"total_scanned": len(scanned),
"new_emails": new_count,
"total_stored": len(existing_emails),
})
def extract_signals(data: Optional[Dict[str, Any]] = None) -> None:
"""分析已存储邮件的商机信号。
对所有未分析的邮件执行信号提取,或对指定邮件重新分析。
Args:
data: 可选参数,支持 email_id(指定邮件)、force(强制重新分析)。
"""
if not require_paid_feature("email_scan", "邮件信号提取"):
return
data = data or {}
email_id = data.get("email_id")
force = data.get("force", False)
emails = _get_emails()
if not emails:
output_error("暂无邮件记录,请先执行邮件扫描", code="NO_DATA")
return
analyzed_count = 0
results = {"POSITIVE": 0, "NEGATIVE": 0, "NEUTRAL": 0}
for email in emails:
# 按 ID 过滤
if email_id and email.get("id") != email_id:
continue
# 跳过已分析的(除非强制)
if email.get("signal") and not force:
continue
# 组合文本用于分析
text = f"{email.get('subject', '')} {email.get('snippet', '')}"
signal, confidence, keywords = _analyze_signal(text)
email["signal"] = signal
email["signal_confidence"] = confidence
email["matched_keywords"] = keywords
analyzed_count += 1
results[signal] = results.get(signal, 0) + 1
_save_emails(emails)
output_success({
"message": f"信号分析完成:已分析 {analyzed_count} 封邮件",
"analyzed": analyzed_count,
"signal_summary": results,
})
def link_deal(data: Dict[str, Any]) -> None:
"""将邮件关联到商机。
支持手动指定关联或自动按联系人邮箱匹配。
Args:
data: 参数字典,支持 email_id + deal_id(手动)或 auto=True(自动匹配)。
"""
if not require_paid_feature("email_scan", "邮件-商机关联"):
return
auto = data.get("auto", False)
emails = _get_emails()
deals = _get_deals()
if not emails:
output_error("暂无邮件记录", code="NO_DATA")
return
if auto:
# 自动匹配模式
linked_count = 0
for email in emails:
if email.get("linked_deal_id"):
continue
from_addr = email.get("from_address", "")
email_addrs = _extract_email_addresses(from_addr)
for addr in email_addrs:
matched_deal = _match_deal_by_contact(addr, deals)
if matched_deal:
email["linked_deal_id"] = matched_deal["id"]
linked_count += 1
break
_save_emails(emails)
output_success({
"message": f"自动关联完成:成功关联 {linked_count} 封邮件",
"linked": linked_count,
})
else:
# 手动关联模式
email_id = data.get("email_id")
deal_id = data.get("deal_id")
if not email_id or not deal_id:
output_error("手动关联需提供 email_id 和 deal_id", code="VALIDATION_ERROR")
return
target_email = None
for e in emails:
if e.get("id") == email_id:
target_email = e
break
if not target_email:
output_error(f"未找到ID为 {email_id} 的邮件", code="NOT_FOUND")
return
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
if not target_deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
target_email["linked_deal_id"] = deal_id
_save_emails(emails)
output_success({
"message": f"邮件已关联到商机「{target_deal.get('name', '')}」",
"email_id": email_id,
"deal_id": deal_id,
"deal_name": target_deal.get("name", ""),
})
def list_emails(data: Optional[Dict[str, Any]] = None) -> None:
"""列出已存储的邮件记录。
可选过滤: deal_id, signal, provider
Args:
data: 可选的过滤条件字典。
"""
if not require_paid_feature("email_scan", "邮件列表"):
return
emails = _get_emails()
if data:
# 按商机过滤
deal_id = data.get("deal_id")
if deal_id:
emails = [e for e in emails if e.get("linked_deal_id") == deal_id]
# 按信号类型过滤
signal_filter = data.get("signal")
if signal_filter:
emails = [e for e in emails if e.get("signal") == signal_filter.upper()]
# 按来源过滤
provider_filter = data.get("provider")
if provider_filter:
emails = [e for e in emails if e.get("provider") == provider_filter.lower()]
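# Sort newest first. This compares the raw date strings, so records whose providers use different date formats may not come out in true chronological order.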
emails.sort(key=lambda e: e.get("date", ""), reverse=True)
# 脱敏处理
display_list = []
for e in emails:
display = dict(e)
if display.get("from_address"):
display["from_address"] = mask_email(display["from_address"])
display_list.append(display)
# 信号统计
signal_stats = {"POSITIVE": 0, "NEGATIVE": 0, "NEUTRAL": 0, "未分析": 0}
for e in emails:
sig = e.get("signal")
if sig in signal_stats:
signal_stats[sig] += 1
else:
signal_stats["未分析"] += 1
output_success({
"total": len(display_list),
"signal_stats": signal_stats,
"emails": display_list,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer 邮件扫描器")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"scan": lambda: scan_emails(data),
"extract-signals": lambda: extract_signals(data),
"link-deal": lambda: link_deal(data or {}),
"list-emails": lambda: list_emails(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
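deal-closer shared utilities module.
Provides deal data management, subscription checks, data formatting, and other common helpers.
Core utility library for the deal accelerator.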
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "deal-closer")
DEAL_STAGES = [
"线索",
"初步接触",
"需求确认",
"方案报价",
"商务谈判",
"合同签署",
"成交",
"流失",
]
STAGE_COLORS = {
"线索": "lightblue",
"初步接触": "blue",
"需求确认": "cyan",
"方案报价": "yellow",
"商务谈判": "orange",
"合同签署": "green",
"成交": "darkgreen",
"流失": "red",
}
# 阶段对应的默认成交概率(百分比)
STAGE_DEFAULT_PROBABILITY = {
"线索": 5,
"初步接触": 10,
"需求确认": 25,
"方案报价": 50,
"商务谈判": 70,
"合同签署": 90,
"成交": 100,
"流失": 0,
}
# 邮件信号类型
EMAIL_SIGNAL_TYPES = ["POSITIVE", "NEGATIVE", "NEUTRAL"]
# 跟进模板类型
FOLLOWUP_TEMPLATES = [
"introduction",
"proposal_followup",
"negotiation",
"closing",
"win_back",
]
# 会议类型
MEETING_TYPES = ["电话", "视频", "面谈", "线上演示", "商务宴请", "其他"]
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径。
优先读取环境变量 DC_DATA_DIR,若未设置则使用默认路径
~/.openclaw-bdi/deal-closer/。
自动创建目录(若不存在)。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("DC_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_data_file(filename: str) -> str:
"""获取数据文件的完整路径。
Args:
filename: 文件名(如 "deals.json")。
Returns:
数据文件的绝对路径。
"""
return os.path.join(get_data_dir(), filename)
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_file(filepath: str) -> Any:
"""读取 JSON 文件并返回解析后的数据。
Args:
filepath: JSON 文件路径。
Returns:
解析后的数据对象。若文件不存在,返回空列表。
"""
if not os.path.exists(filepath):
return []
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return []
def write_json_file(filepath: str, data: Any) -> None:
"""将数据写入 JSON 文件。
Args:
filepath: 目标文件路径。
data: 待写入的数据。
"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2, default=str)
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def parse_common_args(description: str = "deal-closer 成交加速器") -> argparse.ArgumentParser:
"""创建通用命令行参数解析器。
Args:
description: 工具描述文本。
Returns:
配置好通用参数的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Dict[str, Any]]:
    """Load input data from the command-line arguments.
    --data takes precedence; --data-file is used only when --data is absent.
    Args:
        args: Parsed command-line arguments.
    Returns:
        The parsed dict, or None when no input was supplied.
    Raises:
        ValueError: If the JSON is invalid, is not an object, or the file cannot be read.
    """
    if args.data:
        try:
            data = json.loads(args.data)
        except json.JSONDecodeError as e:
            raise ValueError(f"JSON 解析失败: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
        return data
    if args.data_file:
        if not os.path.exists(args.data_file):
            raise ValueError(f"数据文件不存在: {args.data_file}")
        try:
            with open(args.data_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"数据文件 JSON 解析失败: {e}") from e
        except OSError as e:
            # Report read failures as ValueError, as the docstring promises
            raise ValueError(f"数据文件读取失败: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
        return data
    return None
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_deals": 30,
"features": [
"deal_crud",
"basic_pipeline",
"manual_followup",
"csv_export",
],
},
"paid": {
"tier": "paid",
"max_deals": 500,
"features": [
"deal_crud",
"basic_pipeline",
"manual_followup",
"csv_export",
"email_scan",
"meeting_sync",
"forecast",
"ai_followup",
"mermaid_chart",
"advanced_analytics",
"bulk_import",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 DC_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("DC_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
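def require_paid_feature(feature_name: str, display_name: str) -> bool:
    """Check whether the current subscription includes a feature.
    Prints an upgrade prompt and returns False when it does not.
    Args:
        feature_name: Internal feature name.
        display_name: Feature name shown in the prompt.
    Returns:
        True if the feature is available; False otherwise (an error response has already been printed).
    """
    try:
        sub = check_subscription()
    except ValueError as e:
        # An invalid DC_SUBSCRIPTION_TIER should produce a JSON error response, not a traceback
        output_error(str(e), code="VALIDATION_ERROR")
        return False
    if feature_name not in sub["features"]:
        output_error(
            f"「{display_name}」为付费版功能。当前为免费版,请升级至付费版(¥149/月)以使用此功能。",
            code="SUBSCRIPTION_REQUIRED",
        )
        return False
    return True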
# ============================================================
# CRM 专用工具函数
# ============================================================
def validate_deal_stage(stage: str) -> str:
"""校验商机阶段是否合法。
Args:
stage: 待校验的阶段名称。
Returns:
合法的阶段名称。
Raises:
ValueError: 当阶段名称不合法时抛出。
"""
if stage not in DEAL_STAGES:
valid = "、".join(DEAL_STAGES)
raise ValueError(f"无效的商机阶段: {stage!r},有效阶段: {valid}")
return stage
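def calculate_days_since(date_str: str) -> int:
    """Return the number of whole days elapsed since the given date.
    Args:
        date_str: Date string, either YYYY-MM-DD or an ISO 8601 timestamp.
    Returns:
        Days elapsed (positive = past, negative = future); 0 if the string cannot be parsed.
    """
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            if dt.tzinfo is not None:
                # Convert to local time before dropping the offset so the value is comparable with datetime.now()
                dt = dt.astimezone()
            dt = dt.replace(tzinfo=None)
        else:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
        delta = datetime.now() - dt
        return delta.days
    except (ValueError, TypeError):
        return 0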
def mask_phone(phone: str) -> str:
"""对手机号进行脱敏处理。
将手机号中间 4 位替换为 ****。
Args:
phone: 原始手机号。
Returns:
脱敏后的手机号,如 138****8000。
若格式不符合,返回原始值。
"""
if not phone:
return phone
phone = phone.strip()
if re.match(r"^1[3-9]\d{9}$", phone):
return phone[:3] + "****" + phone[7:]
return phone
def mask_email(email: str) -> str:
"""对邮箱地址进行脱敏处理。
将 @ 前的部分仅保留前两个字符,其余替换为 ***。
Args:
email: 原始邮箱地址。
Returns:
脱敏后的邮箱,如 zh***@example.com。
若格式不符合,返回原始值。
"""
if not email:
return email
email = email.strip()
match = re.match(r"^([^@]{1,2})([^@]*)@(.+)$", email)
if match:
prefix = match.group(1)
domain = match.group(3)
return f"{prefix}***@{domain}"
return email
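def generate_id(prefix: str = "D") -> str:
    """Generate an ID from the current timestamp.
    The format is the prefix followed by the local time as YYYYMMDDHHMMSS plus six microsecond digits.
    Uniqueness depends on the clock advancing between calls, so two calls within the same clock tick produce the same ID.
    Args:
        prefix: ID prefix; defaults to "D" (deal).
    Returns:
        The ID string.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    return f"{prefix}{timestamp}"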
def format_currency(value: float) -> str:
"""将数值格式化为人民币金额显示。
Args:
value: 金额数值。
Returns:
格式化后的金额字符串,如 "¥10.00万" 或 "¥5,000"。
"""
try:
num = float(value)
except (TypeError, ValueError):
return str(value)
abs_num = abs(num)
sign = "-" if num < 0 else ""
if abs_num >= 1e8:
return f"{sign}¥{abs_num / 1e8:.2f}亿"
elif abs_num >= 1e4:
return f"{sign}¥{abs_num / 1e4:.2f}万"
else:
return f"{sign}¥{abs_num:,.0f}"
def now_iso() -> str:
"""返回当前时间的 ISO 格式字符串。
Returns:
ISO 格式时间字符串,如 "2026-03-19T10:30:00"。
"""
return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
def today_str() -> str:
"""返回今天的日期字符串。
Returns:
日期字符串,格式为 "YYYY-MM-DD"。
"""
return datetime.now().strftime("%Y-%m-%d")
def stage_display_name(stage: str) -> str:
"""获取商机阶段的显示名称。
Args:
stage: 阶段标识。
Returns:
阶段显示名称,若未知则返回原始值。
"""
# Stage identifiers are already their display names, so known and unknown stages are returned unchanged
return stage
def get_stage_probability(stage: str) -> int:
"""获取阶段默认成交概率。
Args:
stage: 商机阶段。
Returns:
默认概率百分比值。
"""
return STAGE_DEFAULT_PROBABILITY.get(stage, 0)
def parse_amount(value: str) -> float:
"""解析金额字符串为数值。
支持带「万」「亿」等中文单位的数值。
Args:
value: 金额字符串。
Returns:
数值化的金额。
"""
if not value:
return 0.0
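# Strip thousands separators and currency marks in both ASCII and full-width forms (U+FF0C full-width comma, U+FFE5 full-width yen sign)
value = str(value).strip().replace(",", "").replace("\uff0c", "")
value = value.replace("¥", "").replace("\uffe5", "").replace("元", "")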
try:
if "亿" in value:
return float(value.replace("亿", "")) * 1e8
elif "万" in value:
return float(value.replace("万", "")) * 1e4
else:
return float(value)
except (ValueError, TypeError):
return 0.0
def format_percentage(value: float) -> str:
"""将小数格式化为百分比显示。
Args:
value: 小数值(如 0.15 表示 15%)。
Returns:
百分比字符串,如 "15.0%"。
"""
try:
return f"{float(value) * 100:.1f}%"
except (TypeError, ValueError):
return "0.0%"
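def days_until(date_str: str) -> int:
    """Return the number of calendar days from today to the given date.
    Args:
        date_str: Target date, either YYYY-MM-DD or an ISO 8601 timestamp.
    Returns:
        Days remaining (positive = future, 0 = today, negative = past); 0 if the string cannot be parsed.
    """
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            if dt.tzinfo is not None:
                # Convert to local time before dropping the offset so the value is comparable with datetime.now()
                dt = dt.astimezone()
            dt = dt.replace(tzinfo=None)
        else:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
        # Compare calendar dates: subtracting datetime.now() floors the partial day, reporting tomorrow as 0 and today as -1
        return (dt.date() - datetime.now().date()).days
    except (ValueError, TypeError):
        return 0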
FILE:scripts/deal_store.py
#!/usr/bin/env python3
"""
deal-closer deal data management module.
Provides CRUD operations on deal data, with JSON file storage, CSV import/export, and stage history tracking.
"""
import csv
import io
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
mask_phone,
mask_email,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
validate_deal_stage,
write_json_file,
format_currency,
parse_amount,
get_stage_probability,
calculate_days_since,
DEAL_STAGES,
)
# ============================================================
# 数据文件路径
# ============================================================
DEALS_FILE = "deals.json"
def _get_deals() -> List[Dict[str, Any]]:
"""读取所有商机数据。"""
return read_json_file(get_data_file(DEALS_FILE))
def _save_deals(deals: List[Dict[str, Any]]) -> None:
"""保存商机数据到文件。"""
write_json_file(get_data_file(DEALS_FILE), deals)
def _find_deal(deals: List[Dict], deal_id: str) -> Optional[Dict]:
"""根据 ID 查找商机。"""
for d in deals:
if d.get("id") == deal_id:
return d
return None
def _mask_deal(deal: Dict[str, Any]) -> Dict[str, Any]:
"""对商机中的敏感字段进行脱敏处理。"""
display = dict(deal)
if display.get("contact_phone"):
display["contact_phone"] = mask_phone(display["contact_phone"])
if display.get("contact_email"):
display["contact_email"] = mask_email(display["contact_email"])
return display
# ============================================================
# CRUD 操作
# ============================================================
def add_deal(data: Dict[str, Any]) -> None:
"""添加新商机。
必填字段: name
可选字段: contact_name, contact_phone, contact_email, company, amount,
stage, probability, source, expected_close_date, notes, tags
Args:
data: 商机数据字典。
"""
if not data.get("name"):
output_error("商机名称(name)为必填字段", code="VALIDATION_ERROR")
return
sub = check_subscription()
deals = _get_deals()
# 检查商机数量限制
if len(deals) >= sub["max_deals"]:
limit = sub["max_deals"]
if sub["tier"] == "free":
output_error(
f"免费版最多管理 {limit} 个商机,当前已有 {len(deals)} 个。"
"请升级至付费版(¥149/月)以管理更多商机。",
code="LIMIT_EXCEEDED",
)
else:
output_error(
f"已达到商机数量上限 {limit} 个。",
code="LIMIT_EXCEEDED",
)
return
# 校验阶段
stage = data.get("stage", "线索")
try:
validate_deal_stage(stage)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 解析金额
amount = 0.0
if "amount" in data:
amount = parse_amount(str(data["amount"]))
# 概率:若未提供则根据阶段自动设定
probability = data.get("probability")
if probability is None:
probability = get_stage_probability(stage)
else:
try:
probability = int(probability)
probability = max(0, min(100, probability))
except (TypeError, ValueError):
probability = get_stage_probability(stage)
# 标签处理
tags = data.get("tags", [])
if isinstance(tags, str):
tags = [t.strip() for t in tags.split(",") if t.strip()]
now = now_iso()
deal = {
"id": generate_id("D"),
"name": data["name"],
"contact_name": data.get("contact_name", ""),
"contact_phone": data.get("contact_phone", ""),
"contact_email": data.get("contact_email", ""),
"company": data.get("company", ""),
"amount": amount,
"stage": stage,
"probability": probability,
"source": data.get("source", ""),
"expected_close_date": data.get("expected_close_date", ""),
"notes": data.get("notes", ""),
"tags": tags,
"created_at": now,
"updated_at": now,
"stage_history": [
{"stage": stage, "timestamp": now},
],
}
deals.append(deal)
_save_deals(deals)
display = _mask_deal(deal)
display["amount_display"] = format_currency(deal["amount"])
output_success({"message": f"商机「{deal['name']}」已添加", "deal": display})
def update_deal(data: Dict[str, Any]) -> None:
"""更新商机信息。
必填字段: id
可更新字段: name, contact_name, contact_phone, contact_email, company,
amount, stage, probability, source, expected_close_date, notes, tags
Args:
data: 包含商机 ID 和待更新字段的字典。
"""
deal_id = data.get("id")
if not deal_id:
output_error("商机ID(id)为必填字段", code="VALIDATION_ERROR")
return
deals = _get_deals()
deal = _find_deal(deals, deal_id)
if not deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
updatable_fields = [
"name", "contact_name", "contact_phone", "contact_email",
"company", "source", "expected_close_date", "notes",
]
updated = False
for field in updatable_fields:
if field in data:
deal[field] = data[field]
updated = True
# 金额特殊处理
if "amount" in data:
deal["amount"] = parse_amount(str(data["amount"]))
updated = True
# 概率特殊处理
if "probability" in data:
try:
deal["probability"] = max(0, min(100, int(data["probability"])))
updated = True
except (TypeError, ValueError):
output_error("概率(probability)必须为 0-100 的整数", code="VALIDATION_ERROR")
return
# 标签特殊处理
if "tags" in data:
tags = data["tags"]
if isinstance(tags, str):
tags = [t.strip() for t in tags.split(",") if t.strip()]
deal["tags"] = tags
updated = True
# 阶段变更追踪
if "stage" in data:
new_stage = data["stage"]
try:
validate_deal_stage(new_stage)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
old_stage = deal.get("stage")
if new_stage != old_stage:
deal["stage"] = new_stage
# 若未手动设置概率,自动更新为新阶段的默认概率
if "probability" not in data:
deal["probability"] = get_stage_probability(new_stage)
# 记录阶段历史
if "stage_history" not in deal:
deal["stage_history"] = []
deal["stage_history"].append({
"stage": new_stage,
"timestamp": now_iso(),
})
updated = True
if not updated:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
deal["updated_at"] = now_iso()
_save_deals(deals)
display = _mask_deal(deal)
display["amount_display"] = format_currency(deal["amount"])
output_success({"message": f"商机「{deal['name']}」已更新", "deal": display})
def delete_deal(data: Dict[str, Any]) -> None:
"""删除商机。
必填字段: id
Args:
data: 包含商机 ID 的字典。
"""
deal_id = data.get("id")
if not deal_id:
output_error("商机ID(id)为必填字段", code="VALIDATION_ERROR")
return
deals = _get_deals()
original_count = len(deals)
deals = [d for d in deals if d.get("id") != deal_id]
if len(deals) == original_count:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
_save_deals(deals)
output_success({"message": f"商机 {deal_id} 已删除"})
def get_deal(data: Dict[str, Any]) -> None:
    """Get the details of a single deal.

    Required field: id

    Args:
        data: dict containing the deal ID.
    """
    deal_id = data.get("id")
    if not deal_id:
        output_error("商机ID(id)为必填字段", code="VALIDATION_ERROR")
        return
    deals = _get_deals()
    deal = _find_deal(deals, deal_id)
    if not deal:
        output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
        return
    display = _mask_deal(deal)
    # Use .get() so a record without "amount" does not raise KeyError (list_deals does the same)
    display["amount_display"] = format_currency(deal.get("amount", 0))
    display["days_since_update"] = calculate_days_since(deal.get("updated_at", ""))
    output_success(display)
def list_deals(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有商机。
可选过滤: stage, keyword(搜索名称/公司/联系人), tag, min_amount, max_amount
Args:
data: 可选的过滤条件字典。
"""
deals = _get_deals()
if data:
# 阶段过滤
stage_filter = data.get("stage")
if stage_filter:
deals = [d for d in deals if d.get("stage") == stage_filter]
# 关键词搜索
keyword = data.get("keyword", "").strip()
if keyword:
keyword_lower = keyword.lower()
deals = [
d for d in deals
if keyword_lower in d.get("name", "").lower()
or keyword_lower in d.get("company", "").lower()
or keyword_lower in d.get("contact_name", "").lower()
]
# 标签过滤
tag_filter = data.get("tag")
if tag_filter:
deals = [d for d in deals if tag_filter in d.get("tags", [])]
# 金额范围过滤
min_amount = data.get("min_amount")
if min_amount is not None:
deals = [d for d in deals if d.get("amount", 0) >= float(min_amount)]
max_amount = data.get("max_amount")
if max_amount is not None:
deals = [d for d in deals if d.get("amount", 0) <= float(max_amount)]
# 按更新时间倒序排列
deals.sort(key=lambda d: d.get("updated_at", ""), reverse=True)
# 脱敏处理
display_list = []
for d in deals:
display = _mask_deal(d)
display["amount_display"] = format_currency(d.get("amount", 0))
display_list.append(display)
# 按阶段分组统计
stage_stats = {}
for stage in DEAL_STAGES:
stage_deals = [d for d in deals if d.get("stage") == stage]
stage_stats[stage] = {
"count": len(stage_deals),
"total_amount": sum(d.get("amount", 0) for d in stage_deals),
}
# 汇总
total_amount = sum(d.get("amount", 0) for d in deals)
output_success({
"total": len(display_list),
"total_amount": total_amount,
"total_amount_display": format_currency(total_amount),
"stage_stats": stage_stats,
"deals": display_list,
})
def stage_history(data: Dict[str, Any]) -> None:
"""查看商机的阶段变更历史。
必填字段: id
Args:
data: 包含商机 ID 的字典。
"""
deal_id = data.get("id")
if not deal_id:
output_error("商机ID(id)为必填字段", code="VALIDATION_ERROR")
return
deals = _get_deals()
deal = _find_deal(deals, deal_id)
if not deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
history = deal.get("stage_history", [])
output_success({
"deal_id": deal_id,
"deal_name": deal.get("name", ""),
"current_stage": deal.get("stage", ""),
"history": history,
"total_changes": len(history),
})
def import_deals(data: Dict[str, Any]) -> None:
"""从 CSV 文件导入商机数据。
必填字段: file_path
Args:
data: 包含 CSV 文件路径的字典。
"""
file_path = data.get("file_path")
if not file_path:
output_error("CSV 文件路径(file_path)为必填字段", code="VALIDATION_ERROR")
return
if not os.path.exists(file_path):
output_error(f"文件不存在: {file_path}", code="FILE_NOT_FOUND")
return
sub = check_subscription()
deals = _get_deals()
imported = 0
skipped = 0
errors = []
# 中英文列名映射
column_map = {
"名称": "name", "商机名称": "name",
"联系人": "contact_name", "联系人姓名": "contact_name",
"手机": "contact_phone", "电话": "contact_phone",
"邮箱": "contact_email",
"公司": "company", "公司名称": "company",
"金额": "amount", "预算": "amount",
"阶段": "stage",
"概率": "probability",
"来源": "source",
"预计成交日期": "expected_close_date",
"备注": "notes",
"标签": "tags",
}
try:
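# utf-8-sig also accepts files that start with a BOM (e.g. CSV exported from Excel);
# with plain utf-8 the BOM would be glued to the first header name and break the column mapping
with open(file_path, "r", encoding="utf-8-sig", newline="") as f: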
reader = csv.DictReader(f)
for row_num, row in enumerate(reader, start=2):
if len(deals) >= sub["max_deals"]:
errors.append(f"行 {row_num}: 已达商机数量上限 {sub['max_deals']}")
skipped += 1
continue
# 映射中文列名
mapped = {}
for key, value in row.items():
mapped_key = column_map.get(key, key)
mapped[mapped_key] = value
name = mapped.get("name", "").strip()
if not name:
errors.append(f"行 {row_num}: 缺少商机名称")
skipped += 1
continue
stage = mapped.get("stage", "").strip() or "线索"
if stage not in DEAL_STAGES:
stage = "线索"
amount = parse_amount(mapped.get("amount", ""))
probability = mapped.get("probability", "")
if probability:
try:
probability = max(0, min(100, int(probability)))
except (TypeError, ValueError):
probability = get_stage_probability(stage)
else:
probability = get_stage_probability(stage)
tags = mapped.get("tags", "")
if isinstance(tags, str) and tags:
tags = [t.strip() for t in tags.split(",") if t.strip()]
else:
tags = []
now = now_iso()
deal = {
"id": generate_id("D"),
"name": name,
"contact_name": mapped.get("contact_name", "").strip(),
"contact_phone": mapped.get("contact_phone", "").strip(),
"contact_email": mapped.get("contact_email", "").strip(),
"company": mapped.get("company", "").strip(),
"amount": amount,
"stage": stage,
"probability": probability,
"source": mapped.get("source", "").strip(),
"expected_close_date": mapped.get("expected_close_date", "").strip(),
"notes": mapped.get("notes", "").strip(),
"tags": tags,
"created_at": now,
"updated_at": now,
"stage_history": [
{"stage": stage, "timestamp": now},
],
}
deals.append(deal)
imported += 1
except Exception as e:
output_error(f"导入失败: {e}", code="IMPORT_ERROR")
return
_save_deals(deals)
result = {
"message": f"导入完成:成功 {imported} 条,跳过 {skipped} 条",
"imported": imported,
"skipped": skipped,
"total": len(deals),
}
if errors:
result["errors"] = errors[:10]
output_success(result)
def export_deals(data: Optional[Dict[str, Any]] = None) -> None:
"""导出商机数据到 CSV 格式。
可选字段: file_path(若不指定则输出到 stdout)
Args:
data: 可选的配置字典。
"""
deals = _get_deals()
if not deals:
output_error("暂无商机数据可导出", code="NO_DATA")
return
file_path = data.get("file_path") if data else None
fieldnames = [
"id", "name", "contact_name", "contact_phone", "contact_email",
"company", "amount", "stage", "probability", "source",
"expected_close_date", "notes", "tags", "created_at", "updated_at",
]
try:
if file_path:
with open(file_path, "w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for d in deals:
row = {}
for k in fieldnames:
val = d.get(k, "")
if k == "tags" and isinstance(val, list):
val = ",".join(val)
row[k] = val
writer.writerow(row)
output_success({
"message": f"已导出 {len(deals)} 条商机数据到 {file_path}",
"count": len(deals),
})
else:
output_buf = io.StringIO()
writer = csv.DictWriter(output_buf, fieldnames=fieldnames)
writer.writeheader()
for d in deals:
row = {}
for k in fieldnames:
val = d.get(k, "")
if k == "tags" and isinstance(val, list):
val = ",".join(val)
row[k] = val
writer.writerow(row)
output_success({"csv": output_buf.getvalue(), "count": len(deals)})
except IOError as e:
output_error(f"导出失败: {e}", code="EXPORT_ERROR")
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer 商机数据管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"add": lambda: add_deal(data or {}),
"update": lambda: update_deal(data or {}),
"delete": lambda: delete_deal(data or {}),
"get": lambda: get_deal(data or {}),
"list": lambda: list_deals(data),
"import": lambda: import_deals(data or {}),
"export": lambda: export_deals(data),
"stage-history": lambda: stage_history(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/meeting_logger.py
#!/usr/bin/env python3
"""
deal-closer 会议记录模块
提供会议记录的增删改查、按商机/日期筛选、会议摘要生成等功能。
支持手动记录和日历 API 同步(付费功能)。
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
write_json_file,
calculate_days_since,
days_until,
MEETING_TYPES,
)
# ============================================================
# 数据文件路径
# ============================================================
MEETINGS_FILE = "meetings.json"
DEALS_FILE = "deals.json"
def _get_meetings() -> List[Dict[str, Any]]:
    """Read all meeting records."""
    return read_json_file(get_data_file(MEETINGS_FILE))


def _save_meetings(meetings: List[Dict[str, Any]]) -> None:
    """Save the meeting records to the data file."""
    write_json_file(get_data_file(MEETINGS_FILE), meetings)


def _get_deals() -> List[Dict[str, Any]]:
    """Read all deal records."""
    return read_json_file(get_data_file(DEALS_FILE))


def _find_meeting(meetings: List[Dict], meeting_id: str) -> Optional[Dict]:
    """Find a meeting record by ID; return None when there is no match."""
    for m in meetings:
        if m.get("id") == meeting_id:
            return m
    return None
# ============================================================
# 会议操作
# ============================================================
def log_meeting(data: Dict[str, Any]) -> None:
"""记录一次会议。
必填字段: deal_id, date
可选字段: attendees, type, location, notes, action_items, next_steps
Args:
data: 会议数据字典。
"""
deal_id = data.get("deal_id")
if not deal_id:
output_error("商机ID(deal_id)为必填字段", code="VALIDATION_ERROR")
return
meeting_date = data.get("date")
if not meeting_date:
output_error("会议日期(date)为必填字段", code="VALIDATION_ERROR")
return
# 验证商机是否存在
deals = _get_deals()
target_deal = None
for d in deals:
if d.get("id") == deal_id:
target_deal = d
break
if not target_deal:
output_error(f"未找到ID为 {deal_id} 的商机", code="NOT_FOUND")
return
# 处理参会人列表
attendees = data.get("attendees", [])
if isinstance(attendees, str):
attendees = [a.strip() for a in attendees.split(",") if a.strip()]
# 处理行动项列表
action_items = data.get("action_items", [])
if isinstance(action_items, str):
action_items = [a.strip() for a in action_items.split(";") if a.strip()]
# 处理下一步列表
next_steps = data.get("next_steps", [])
if isinstance(next_steps, str):
next_steps = [s.strip() for s in next_steps.split(";") if s.strip()]
# 会议类型校验
meeting_type = data.get("type", "其他")
if meeting_type not in MEETING_TYPES:
meeting_type = "其他"
now = now_iso()
meeting = {
"id": generate_id("M"),
"deal_id": deal_id,
"date": meeting_date,
"attendees": attendees,
"type": meeting_type,
"location": data.get("location", ""),
"notes": data.get("notes", ""),
"action_items": action_items,
"next_steps": next_steps,
"created_at": now,
}
meetings = _get_meetings()
meetings.append(meeting)
_save_meetings(meetings)
output_success({
"message": f"会议记录已添加(商机: {target_deal.get('name', '')})",
"meeting": meeting,
})
def list_meetings(data: Optional[Dict[str, Any]] = None) -> None:
"""列出会议记录。
可选过滤: deal_id, date_from, date_to, type
Args:
data: 可选的过滤条件字典。
"""
meetings = _get_meetings()
if data:
# 按商机过滤
deal_id = data.get("deal_id")
if deal_id:
meetings = [m for m in meetings if m.get("deal_id") == deal_id]
# 按日期范围过滤
date_from = data.get("date_from")
if date_from:
meetings = [m for m in meetings if m.get("date", "") >= date_from]
date_to = data.get("date_to")
if date_to:
meetings = [m for m in meetings if m.get("date", "") <= date_to]
# 按类型过滤
type_filter = data.get("type")
if type_filter:
meetings = [m for m in meetings if m.get("type") == type_filter]
# 按日期倒序
meetings.sort(key=lambda m: m.get("date", ""), reverse=True)
# 加载商机信息用于显示
deals = _get_deals()
deal_map = {d["id"]: d.get("name", "") for d in deals}
display_list = []
for m in meetings:
display = dict(m)
display["deal_name"] = deal_map.get(m.get("deal_id", ""), "未知商机")
display_list.append(display)
# 按类型统计
type_stats = {}
for mt in MEETING_TYPES:
type_stats[mt] = sum(1 for m in meetings if m.get("type") == mt)
output_success({
"total": len(display_list),
"type_stats": type_stats,
"meetings": display_list,
})
def upcoming_meetings(data: Optional[Dict[str, Any]] = None) -> None:
"""查看即将到来的会议。
可选参数: days(查看未来 N 天,默认 7 天)
Args:
data: 可选参数字典。
"""
data = data or {}
days_ahead = data.get("days", 7)
try:
days_ahead = int(days_ahead)
except (TypeError, ValueError):
days_ahead = 7
today = today_str()
future_date = (datetime.now() + timedelta(days=days_ahead)).strftime("%Y-%m-%d")
meetings = _get_meetings()
# 筛选未来会议
upcoming = []
for m in meetings:
meeting_date = m.get("date", "")
if meeting_date and today <= meeting_date <= future_date:
upcoming.append(m)
# 按日期正序
upcoming.sort(key=lambda m: m.get("date", ""))
# 加载商机信息
deals = _get_deals()
deal_map = {d["id"]: d.get("name", "") for d in deals}
display_list = []
for m in upcoming:
display = dict(m)
display["deal_name"] = deal_map.get(m.get("deal_id", ""), "未知商机")
meeting_date = m.get("date", "")
if meeting_date:
display["days_until"] = days_until(meeting_date)
display_list.append(display)
output_success({
"total": len(display_list),
"period": f"{today} 至 {future_date}",
"meetings": display_list,
})
def meeting_summary(data: Dict[str, Any]) -> None:
"""生成会议摘要。
支持按商机汇总或按时间段汇总。
Args:
data: 参数字典,支持 deal_id 或 date_from + date_to。
"""
meetings = _get_meetings()
deals = _get_deals()
deal_map = {d["id"]: d for d in deals}
deal_id = data.get("deal_id")
date_from = data.get("date_from")
date_to = data.get("date_to")
# 过滤
if deal_id:
meetings = [m for m in meetings if m.get("deal_id") == deal_id]
if date_from:
meetings = [m for m in meetings if m.get("date", "") >= date_from]
if date_to:
meetings = [m for m in meetings if m.get("date", "") <= date_to]
if not meetings:
output_error("指定范围内暂无会议记录", code="NO_DATA")
return
# 按日期正序
meetings.sort(key=lambda m: m.get("date", ""))
# 收集所有行动项和下一步
all_action_items = []
all_next_steps = []
total_attendees = set()
deal_summary = {}
for m in meetings:
for item in m.get("action_items", []):
all_action_items.append({
"item": item,
"meeting_date": m.get("date", ""),
"deal_id": m.get("deal_id", ""),
})
for step in m.get("next_steps", []):
all_next_steps.append({
"step": step,
"meeting_date": m.get("date", ""),
"deal_id": m.get("deal_id", ""),
})
for attendee in m.get("attendees", []):
total_attendees.add(attendee)
# 按商机汇总
mid = m.get("deal_id", "未分类")
if mid not in deal_summary:
deal_info = deal_map.get(mid, {})
deal_summary[mid] = {
"deal_name": deal_info.get("name", "未知商机"),
"meeting_count": 0,
"latest_date": "",
"types": [],
}
deal_summary[mid]["meeting_count"] += 1
deal_summary[mid]["latest_date"] = m.get("date", "")
mtype = m.get("type", "")
if mtype and mtype not in deal_summary[mid]["types"]:
deal_summary[mid]["types"].append(mtype)
output_success({
"total_meetings": len(meetings),
"total_attendees": len(total_attendees),
"attendees": sorted(list(total_attendees)),
"action_items": all_action_items,
"next_steps": all_next_steps,
"deal_summary": deal_summary,
"date_range": {
"from": meetings[0].get("date", "") if meetings else "",
"to": meetings[-1].get("date", "") if meetings else "",
},
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("deal-closer 会议记录")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"log": lambda: log_meeting(data or {}),
"list": lambda: list_meetings(data),
"upcoming": lambda: upcoming_meetings(data),
"summary": lambda: meeting_summary(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:references/email-setup-guide.md
# 邮件服务配置指南
本指南帮助你配置 Gmail 和 Outlook 邮箱集成,以启用邮件扫描和信号提取功能。
---
## Gmail 配置(OAuth2)
### 步骤 1:创建 Google Cloud 项目
1. 访问 [Google Cloud Console](https://console.cloud.google.com/)
2. 创建新项目或选择已有项目
3. 启用 **Gmail API**:导航到「API 和服务」>「库」,搜索 Gmail API 并启用
### 步骤 2:配置 OAuth 同意屏幕
1. 导航到「API 和服务」>「OAuth 同意屏幕」
2. 选择「外部」用户类型
3. 填写应用信息(名称、支持邮箱等)
4. 添加范围:`https://www.googleapis.com/auth/gmail.readonly`
5. 添加测试用户(你的邮箱地址)
### 步骤 3:创建 OAuth 凭据
1. 导航到「API 和服务」>「凭据」
2. 点击「创建凭据」>「OAuth 客户端 ID」
3. 应用类型选择「桌面应用」
4. 下载 JSON 凭据文件,保存到安全位置
### 步骤 4:设置环境变量
```bash
export DC_GMAIL_CREDENTIALS="/path/to/gmail_credentials.json"
```
### 步骤 5:验证连接
首次使用时,系统会引导你完成 OAuth 授权流程,在浏览器中确认授权即可。
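Before starting the OAuth flow it can help to confirm that the environment variable from step 4 points at a usable file. The following is a minimal sketch, not the skill's actual check: the function name `check_gmail_credentials` is hypothetical, and it assumes the downloaded "Desktop app" client file is JSON with a top-level `installed` key.

```python
import json
import os
from pathlib import Path


def check_gmail_credentials(env=os.environ):
    """Return (ok, message) describing whether DC_GMAIL_CREDENTIALS looks usable.

    Hypothetical helper for illustration; it performs no network calls.
    """
    path_str = env.get("DC_GMAIL_CREDENTIALS", "")
    if not path_str:
        return False, "DC_GMAIL_CREDENTIALS is not set"
    path = Path(path_str).expanduser()
    if not path.is_file():
        return False, f"credentials file not found: {path}"
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:  # ValueError covers invalid JSON and bad encoding
        return False, f"credentials file is not readable JSON: {exc}"
    # Assumption: desktop OAuth client files keep their settings under "installed".
    if not isinstance(payload, dict) or "installed" not in payload:
        return False, 'expected a top-level "installed" key (desktop OAuth client)'
    return True, "credentials file looks usable"


if __name__ == "__main__":
    ok, message = check_gmail_credentials()
    print("OK" if ok else "ERROR", "-", message)
```

Running the script after `export DC_GMAIL_CREDENTIALS=...` prints a one-line diagnosis without contacting Google.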
---
## Outlook 配置
### 步骤 1:注册 Azure AD 应用
1. 访问 [Azure Portal](https://portal.azure.com/)
2. 导航到「Azure Active Directory」>「应用注册」
3. 点击「新注册」
4. 填写应用名称,选择「任何组织目录中的帐户」
5. 重定向 URI 设置为 `http://localhost:8080/callback`
### 步骤 2:配置 API 权限
1. 在应用注册页面,点击「API 权限」
2. 添加权限:Microsoft Graph > 委托的权限
3. 选择 `Mail.Read` 权限
4. 点击「授予管理员同意」
### 步骤 3:创建客户端密钥
1. 点击「证书和密钥」>「新客户端密钥」
2. 设置描述和过期时间
3. 记录生成的密钥值(仅显示一次)
### 步骤 4:设置环境变量
```bash
export DC_OUTLOOK_CLIENT_ID="你的应用(客户端)ID"
export DC_OUTLOOK_SECRET="你的客户端密钥"
```
### 步骤 5:验证连接
首次使用时会引导完成 OAuth2 授权流程。
---
## 安全注意事项
- 凭据文件请妥善保管,不要提交到代码仓库
- 建议将环境变量写入 `.env` 文件并加入 `.gitignore`
- 定期轮换客户端密钥
- 仅授予最小必要权限(只读邮件权限)
- OAuth token 过期后需重新授权
---
## 常见问题
### Q: OAuth 授权失败怎么办?
检查重定向 URI 是否正确配置,确保测试用户已添加(Gmail)或管理员已同意权限(Outlook)。
### Q: 只能读取最近多少邮件?
默认扫描最近 50 封,可通过 `max_results` 参数调整(最大 500)。
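The default and the cap described above can be enforced with a small normalizer. This is only a sketch: `normalize_max_results` is a hypothetical name, and falling back to the default on invalid input (rather than raising) is an assumption.

```python
DEFAULT_MAX_RESULTS = 50   # default stated in this guide
MAX_RESULTS_CAP = 500      # upper bound stated in this guide


def normalize_max_results(value=None):
    """Return a max_results value within 1..500, defaulting to 50."""
    if value is None:
        return DEFAULT_MAX_RESULTS
    try:
        requested = int(value)
    except (TypeError, ValueError):
        # Assumption: invalid input falls back to the default instead of failing
        return DEFAULT_MAX_RESULTS
    return max(1, min(MAX_RESULTS_CAP, requested))
```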
### Q: 支持其他邮箱吗?
目前仅支持 Gmail 和 Outlook。其他邮箱请通过 CSV 导入方式手动添加邮件记录。
FILE:references/pipeline-templates.md
# 销售管道报告模板
本文档包含管道报告的标准格式和 Mermaid 图表示例。
---
## 漏斗报告模板
```markdown
# 销售漏斗报告 — YYYY-MM-DD
## 管道概览
| 阶段 | 商机数 | 金额 | 转化率 |
|------|--------|------|--------|
| 线索 | 12 | ¥120.00万 | - |
| 初步接触 | 8 | ¥95.00万 | 66.7% |
| 需求确认 | 5 | ¥60.00万 | 62.5% |
| 方案报价 | 3 | ¥45.00万 | 60.0% |
| 商务谈判 | 2 | ¥30.00万 | 66.7% |
| 合同签署 | 1 | ¥15.00万 | 50.0% |
## 风险预警
- 商机A: 已 21 天未更新,建议立即跟进
- 商机B: 已超出预计成交日期 7 天
```
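The conversion-rate column in the template above is each stage's deal count divided by the previous stage's count. A minimal sketch of that arithmetic follows; the function name `stage_conversion_rates` is hypothetical and the numbers are the sample values from the table.

```python
def stage_conversion_rates(stage_counts):
    """Return [(stage, rate_percent_or_None)] for an ordered list of (stage, count).

    The first stage has no predecessor, so its rate is None (shown as "-" in the table).
    """
    rates = []
    previous = None
    for stage, count in stage_counts:
        if not previous:  # first stage, or previous stage is empty
            rates.append((stage, None))
        else:
            rates.append((stage, round(count * 100 / previous, 1)))
        previous = count
    return rates


sample = [
    ("线索", 12), ("初步接触", 8), ("需求确认", 5),
    ("方案报价", 3), ("商务谈判", 2), ("合同签署", 1),
]
for stage, rate in stage_conversion_rates(sample):
    print(stage, "-" if rate is None else f"{rate}%")
```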
---
## Mermaid 图表示例
### 饼图:商机阶段分布
```mermaid
pie title 商机阶段分布
"线索" : 12
"初步接触" : 8
"需求确认" : 5
"方案报价" : 3
"商务谈判" : 2
"合同签署" : 1
```
### 柱状图:各阶段金额
```mermaid
xychart-beta
title "各阶段金额(万元)"
x-axis ["线索", "初步接触", "需求确认", "方案报价", "商务谈判", "合同签署"]
y-axis "万元"
bar [120, 95, 60, 45, 30, 15]
```
### 折线图:月度趋势
```mermaid
xychart-beta
title "月度新增商机趋势"
x-axis ["10月", "11月", "12月", "01月", "02月", "03月"]
y-axis "数量"
line [8, 12, 10, 15, 11, 18]
```
---
## 收入预测模板
```markdown
# 收入预测报告
## 加权预测汇总
| 阶段 | 商机数 | 原始金额 | 加权金额 |
|------|--------|----------|----------|
| 需求确认 | 5 | ¥60.00万 | ¥15.00万 |
| 方案报价 | 3 | ¥45.00万 | ¥22.50万 |
| 商务谈判 | 2 | ¥30.00万 | ¥21.00万 |
| 合同签署 | 1 | ¥15.00万 | ¥13.50万 |
| **合计** | **11** | **¥150.00万** | **¥72.00万** |
## Top 5 高价值商机
| 商机 | 阶段 | 金额 | 概率 | 加权金额 |
|------|------|------|------|----------|
| 项目A | 商务谈判 | ¥20.00万 | 70% | ¥14.00万 |
| 项目B | 方案报价 | ¥18.00万 | 50% | ¥9.00万 |
```
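The weighted amounts in the forecast template are the raw amount multiplied by the stage probability. The sketch below reproduces the sample totals; the probabilities (25%, 50%, 70%, 90%) are inferred from the ratio of weighted to raw amounts in the table rather than taken from the scripts, and `weighted_forecast` is a hypothetical name.

```python
def weighted_forecast(rows):
    """rows: list of (stage, amount, probability_percent).

    Returns (per_stage, total_raw, total_weighted), where per_stage is a list of
    (stage, amount, weighted_amount).
    """
    per_stage = [(stage, amount, amount * pct / 100) for stage, amount, pct in rows]
    total_raw = sum(amount for _, amount, _ in per_stage)
    total_weighted = sum(weighted for _, _, weighted in per_stage)
    return per_stage, total_raw, total_weighted


# Amounts are in units of 10,000 yuan (万), as in the template
sample_rows = [
    ("需求确认", 60, 25),
    ("方案报价", 45, 50),
    ("商务谈判", 30, 70),
    ("合同签署", 15, 90),
]
per_stage, total_raw, total_weighted = weighted_forecast(sample_rows)
print(total_raw, total_weighted)
```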
---
## 周报/月报模板
```markdown
# 销售周报 — 2026-03-16 ~ 2026-03-22
## 本周概览
| 指标 | 数值 |
|------|------|
| 新增商机 | 5 个 |
| 成交商机 | 1 个 |
| 成交金额 | ¥15.00万 |
| 活跃管道 | ¥180.00万 |
| 风险商机 | 3 个 |
| 会议 | 8 场 |
## 下周重点
1. 重点跟进项目A(商务谈判阶段)
2. 项目B 预计下周提交方案
3. 处理 3 个风险商机
```
Customer Pulse (客户脉搏): a lightweight CRM assistant that tracks customer follow-up status so that no deal gets "dropped on the floor"
---
name: customer-pulse
description: 客户脉搏 — 轻量级CRM助手,追踪客户跟进状态,不让商机"掉地上"
version: 1.0.0
metadata:
  openclaw:
    optional_env:
      - CP_SUBSCRIPTION_TIER
      - CP_DATA_DIR
---
# 客户脉搏(customer-pulse)
你是一个专业的客户关系管理助手 Agent。你的职责是帮助用户管理客户台账、记录跟进活动、追踪销售漏斗、预警客户流失风险。你始终使用中文与用户沟通。
## 环境变量说明
| 变量 | 必需 | 说明 |
|------|------|------|
| `CP_SUBSCRIPTION_TIER` | 否 | 订阅等级,默认 `free`,可选 `paid` |
| `CP_DATA_DIR` | 否 | 数据存储目录,默认 `~/.openclaw-bdi/customer-pulse/` |
数据以 JSON 文件存储在本地,无需外部数据库。
---
## 流程一:客户台账管理
当用户说"添加客户"、"录入客户"、"新建客户"或提供客户信息时,执行以下步骤:
### 步骤 1:解析客户信息
从用户输入中提取以下字段:
- **姓名**(必填)
- 手机号、公司、意向产品、预算、来源
- 销售阶段:初步接触(默认)/ 需求确认 / 方案报价 / 谈判 / 成交 / 流失
示例输入:"添加客户 张总 / 手机13800138000 / 意向产品A / 预算10万"
### 步骤 2:创建客户卡片
```bash
python3 scripts/customer_store.py --action add --data '{"name":"张总","phone":"13800138000","product_interest":"产品A","budget":100000}'
```
### 步骤 3:确认并展示
向用户展示新建的客户卡片,确认信息无误。手机号自动脱敏显示(如 138****8000)。
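The masking rule shown above (`138****8000`) keeps the first three and last four characters. `scripts/customer_store.py` imports a `mask_phone` helper from `utils`, but its implementation is not shown here, so the following is only a sketch under that assumption, using the hypothetical name `mask_phone_sketch`.

```python
def mask_phone_sketch(phone: str) -> str:
    """Mask a phone number, keeping the first 3 and last 4 characters.

    Hypothetical stand-in for utils.mask_phone; the handling of short values
    is an assumption.
    """
    value = (phone or "").strip()
    if not value:
        return ""
    if len(value) < 8:
        # Too short to keep 3 + 4 characters without revealing everything
        return "*" * len(value)
    return value[:3] + "****" + value[-4:]


print(mask_phone_sketch("13800138000"))
```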
### 其他操作
- **更新客户**:`--action update --data '{"id":"C...","stage":"需求确认"}'`
- **删除客户**:`--action delete --data '{"id":"C..."}'`
- **查看详情**:`--action get --data '{"id":"C..."}'`
- **客户列表**:`--action list`,支持按阶段过滤 `--data '{"stage":"谈判"}'`
- **导入 CSV**:`--action import --data '{"file_path":"客户名单.csv"}'`
- **导出 CSV**:`--action export --data '{"file_path":"导出.csv"}'`
---
## 流程二:跟进记录与提醒
当用户说"今天跟进了..."、"记录跟进"、"联系了客户"时,执行以下步骤:
### 步骤 1:解析跟进内容
从用户输入中提取:
- 客户姓名或 ID(必填)
- 跟进内容(必填)
- 下一步行动(可选)
- 下次跟进日期(付费版可自定义)
示例输入:"今天跟进了张总,他说下周开会讨论"
### 步骤 2:记录跟进
```bash
python3 scripts/followup_tracker.py --action record --data '{"customer_name":"张总","content":"客户表示下周开会讨论","next_action":"等待会议结果后再联系"}'
```
- 免费版:自动设置 3 天后提醒
- 付费版:可自定义提醒周期,如 `"reminder_days": 7` 或 `"next_followup_date": "2026-03-26"`
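The reminder rules above can be sketched as a single date calculation. This is an illustration rather than the logic in `followup_tracker.py`: the function name is hypothetical, and two behaviors are assumptions (an explicit `next_followup_date` takes precedence over `reminder_days`, and paid users who supply neither get the 3-day default).

```python
from datetime import date, timedelta

FREE_REMINDER_DAYS = 3  # free tier: fixed 3-day reminder


def next_followup_date(tier, today, reminder_days=None, next_followup=None):
    """Return the date of the next follow-up reminder."""
    if tier == "paid":
        if next_followup:
            # Explicit date, e.g. "2026-03-26"
            return date.fromisoformat(next_followup)
        if reminder_days is not None:
            return today + timedelta(days=int(reminder_days))
    # Free tier ignores custom settings
    return today + timedelta(days=FREE_REMINDER_DAYS)


print(next_followup_date("free", date(2026, 3, 19)))
print(next_followup_date("paid", date(2026, 3, 19), reminder_days=7))
```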
### 步骤 3:确认提醒
告知用户跟进已记录,并显示下次跟进提醒日期。
### 查看待跟进清单
当用户说"哪些客户该跟进了"、"待跟进"、"跟进提醒"时:
```bash
python3 scripts/followup_tracker.py --action list-pending
```
按最后跟进时间排序,超期未跟进的客户标红显示,输出格式:
```
待跟进客户清单:
| 客户 | 公司 | 阶段 | 最后跟进 | 距今天数 | 状态 |
|------|------|------|----------|----------|------|
| 张总 | ABC公司 | 方案报价 | 3月15日 | 4天 | ⚠️ 超期 |
| 李经理 | XYZ集团 | 需求确认 | 3月18日 | 1天 | 正常 |
```
### 查看今日提醒
```bash
python3 scripts/followup_tracker.py --action reminders
```
### 查看跟进历史
```bash
python3 scripts/followup_tracker.py --action history --data '{"customer_name":"张总"}'
```
---
## 流程三:销售漏斗分析
当用户说"这个月成交情况怎么样"、"漏斗分析"、"转化率"时,执行以下步骤:
### 步骤 1:漏斗概览
```bash
python3 scripts/pipeline_analyzer.py --action funnel
```
展示各阶段客户数量、预算总额和转化率:
```
销售漏斗概览:
初步接触 (15) → 需求确认 (10) → 方案报价 (6) → 谈判 (3) → 成交 (2)
67% 60% 50% 67%
总体成交率:5.6%
流失客户:4个,涉及预算 ¥45万
```
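The sample figures above can be reproduced as follows. The step rates are each stage's count divided by the previous stage's count; the 5.6% overall rate appears to be closed customers divided by the total across the five listed stages (2 / 36), which is an inference from the sample, not a confirmed formula from `pipeline_analyzer.py`. The function name `funnel_summary` is hypothetical.

```python
def funnel_summary(stage_counts, closed_stage="成交"):
    """stage_counts: dict of stage -> customer count, in funnel order.

    Returns (step_rates, overall_rate): step_rates are whole percentages between
    consecutive stages; overall_rate is closed / total, in percent with one decimal.
    """
    stages = list(stage_counts.items())
    step_rates = [
        round(current * 100 / previous) if previous else None
        for (_, previous), (_, current) in zip(stages, stages[1:])
    ]
    total = sum(stage_counts.values())
    closed = stage_counts.get(closed_stage, 0)
    overall_rate = round(closed * 100 / total, 1) if total else 0.0
    return step_rates, overall_rate


sample_funnel = {"初步接触": 15, "需求确认": 10, "方案报价": 6, "谈判": 3, "成交": 2}
step_rates, overall_rate = funnel_summary(sample_funnel)
print(step_rates, overall_rate)
```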
### 步骤 2:月度统计
```bash
python3 scripts/pipeline_analyzer.py --action monthly-stats --data '{"month":"2026-03"}'
```
### 步骤 3:转化率详细分析(付费版)
```bash
python3 scripts/pipeline_analyzer.py --action conversion
```
包含各阶段平均停留时长、平均跟进次数、转化瓶颈分析和优化建议。
### 步骤 4:综合报告
```bash
python3 scripts/pipeline_analyzer.py --action report
```
付费版报告包含 Mermaid 可视化图表(漏斗图、饼图)。
---
## 流程四:客户流失预警(付费版)
当用户说"客户流失风险"、"哪些客户可能流失"、"流失预警"时:
### 步骤 1:订阅校验
验证当前为付费版。免费版用户提示:"客户流失预警为付费版功能,请升级至付费版(¥99/月)以使用。"
### 步骤 2:执行预测
```bash
python3 scripts/churn_predictor.py --action predict
```
基于以下因素评估流失风险:
- 跟进频率衰减趋势(近期间隔 vs 历史平均)
- 最后跟进距今天数
- 跟进总次数
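One way to combine the three factors above into a 0-100 score is sketched below. Everything here is an assumption made for illustration: the weights (50 / 35 / 15), the 30-day saturation point, the thresholds for the risk levels, and the function names are hypothetical and are not taken from `churn_predictor.py`.

```python
def churn_risk_score(days_since_last, recent_gap_days, historical_avg_gap_days, followup_count):
    """Return a hypothetical 0-100 churn risk score."""
    # Factor 1: days since the last follow-up, saturating at 30 days (up to 50 points)
    recency = min(max(days_since_last, 0), 30) / 30 * 50
    # Factor 2: follow-up frequency decay, recent gap vs. historical average
    # (up to 35 points once the recent gap is at least twice the average)
    if historical_avg_gap_days > 0:
        ratio = recent_gap_days / historical_avg_gap_days
        decay = max(0.0, min(ratio - 1.0, 1.0)) * 35
    else:
        decay = 0.0
    # Factor 3: very few follow-ups in total (up to 15 points below 3 follow-ups)
    sparse = max(0, 3 - followup_count) / 3 * 15
    return round(recency + decay + sparse)


def risk_level(score):
    """Map a score to a label; the 70 / 40 thresholds are assumptions."""
    if score >= 70:
        return "高风险"
    if score >= 40:
        return "中风险"
    return "低风险"


score = churn_risk_score(days_since_last=35, recent_gap_days=14,
                         historical_avg_gap_days=7, followup_count=3)
print(score, risk_level(score))
```

Keeping each factor as a separately capped term makes it easy to report the dominant factor alongside the score, as the result table in the next step does.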
### 步骤 3:展示结果
```
客户流失风险预警:
| 客户 | 公司 | 阶段 | 风险等级 | 风险分 | 主要因素 |
|------|------|------|----------|--------|----------|
| 王总 | DEF公司 | 方案报价 | 🔴 高风险 | 85 | 超过30天未跟进 |
| 赵经理 | GHI集团 | 需求确认 | 🟡 中风险 | 55 | 跟进频率下降 |
高风险客户涉及预算:¥120万
建议:立即安排对王总的跟进,了解最新情况
```
### 查看高风险列表
```bash
python3 scripts/churn_predictor.py --action risk-list
```
---
## 订阅校验逻辑
### 读取订阅等级
```
tier = env CP_SUBSCRIPTION_TIER,默认 "free"
```
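In Python, the lookup above might look like the following sketch. The function name and the `CUSTOMER_LIMITS` table are hypothetical (the limits are copied from the permission matrix in this document), and treating an unrecognized value as `free` is an assumption.

```python
import os

VALID_TIERS = ("free", "paid")
# Customer limits from the permission matrix in this document
CUSTOMER_LIMITS = {"free": 50, "paid": 500}


def read_subscription_tier(env=os.environ):
    """Read CP_SUBSCRIPTION_TIER, defaulting to "free"."""
    tier = (env.get("CP_SUBSCRIPTION_TIER") or "free").strip().lower()
    # Assumption: unknown values are treated as the free tier
    return tier if tier in VALID_TIERS else "free"


tier = read_subscription_tier()
print(tier, CUSTOMER_LIMITS[tier])
```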
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥99/月) |
|------|---------------|----------------------|
| 客户台账管理 | 50 个客户 | 500 个客户 |
| 跟进记录 | ✅ 手动录入 | ✅ 手动 + 导入 |
| 客户状态看板 | 基础列表 | 漏斗图 + 热力图 |
| 跟进提醒 | 3 天未跟进提醒 | 自定义周期 |
| 客户流失预警 | ❌ | ✅ AI 预测 |
| 客户画像分析 | ❌ | ✅ |
| 成交率分析 | ❌ | ✅ 漏斗转化 |
| CSV 导入/导出 | ✅ | ✅ |
| Mermaid 可视化 | ❌ | ✅ |
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥99/月),请联系管理员或访问订阅管理页面。"
4. 提供免费版可用的替代方案(如果有的话)。
---
## 安全规范
1. **手机号脱敏**:所有客户手机号在输出时自动脱敏处理(如 138****8000),绝不在对话中完整显示手机号。
2. **数据本地化**:所有客户数据存储在本地 JSON 文件中,不上传到任何外部服务。
3. **错误处理**:执行命令失败时,向用户展示友好的错误提示,不暴露内部路径或系统信息。
4. **敏感信息保护**:不在日志或对话中记录客户的敏感商业信息(如具体报价金额、合同条款)。
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 客户姓名支持模糊匹配,方便用户快速操作。
3. 每次跟进记录后主动提醒下次跟进日期。
4. 对用户的问题给出清晰、结构化的回答,优先使用表格展示客户数据。
5. 主动提醒超期未跟进的客户,帮助用户避免遗漏商机。
6. 分析结果附带具体可操作的建议,而不仅仅是展示数据。
7. 遇到模糊的客户意图时,主动追问以明确需求。
8. 尊重订阅等级限制,在提示升级时保持友好,不反复推销。
FILE:assets/README.md
# 客户脉搏 (customer-pulse)
> 轻量级CRM助手,追踪客户跟进状态,不让商机"掉地上"
---
## 功能亮点
- 📇 **客户台账管理** — 快速录入客户信息,一句话添加客户卡片,支持 CSV 批量导入导出
- 📝 **智能跟进提醒** — 自动追踪跟进状态,超期未联系的客户及时提醒,再也不会遗漏商机
- 📊 **销售漏斗分析** — 可视化展示从初步接触到成交的全流程,清晰掌握每个阶段的转化率
- 🔔 **客户流失预警** — AI 分析跟进频率趋势,提前标注可能流失的高风险客户(付费版)
- 💬 **自然语言交互** — 用日常对话的方式管理客户,"今天跟进了张总"即可完成记录
- 🔒 **数据本地存储** — 所有客户数据存储在本地,手机号自动脱敏,保护商业隐私
---
## 版本对比
| 功能 | 免费版 | 付费版 ¥99/月 |
|------|:------:|:------------:|
| 客户台账管理 | 50 个客户 | 500 个客户 |
| 跟进记录 | ✅ 手动录入 | ✅ 手动 + 导入 |
| 客户状态看板 | 基础列表 | 漏斗图 + 热力图 |
| 跟进提醒 | 3 天未跟进提醒 | 自定义周期 |
| 客户流失预警 | ❌ | ✅ AI 预测 |
| 客户画像分析 | ❌ | ✅ |
| 成交率分析 | ❌ | ✅ 漏斗转化 |
| CSV 导入/导出 | ✅ | ✅ |
| Mermaid 可视化 | ❌ | ✅ |
---
## 快速开始
### 1. 安装 Skill
在 ClawHub 中搜索 `customer-pulse`,点击安装,或使用命令行:
```bash
openclaw skill install customer-pulse
```
### 2. 添加客户
直接用自然语言输入:
```
添加客户 张总 / 手机13800138000 / 公司ABC科技 / 意向产品A / 预算10万
```
或批量导入已有客户数据:
```
/customer-pulse import 客户名单.csv
```
### 3. 记录跟进
```
今天跟进了张总,他说下周开会讨论,需要准备一份详细方案
```
系统自动记录跟进内容,并设置下次提醒。
### 4. 查看提醒
```
哪些客户该跟进了?
```
系统按紧急程度排序,输出待跟进清单。
### 5. 分析漏斗
```
这个月成交情况怎么样?
```
系统统计各阶段转化率,给出优化建议。
---
## 使用示例
### 待跟进清单
```
待跟进客户清单(共 5 位,其中 3 位超期未跟进):
| 客户 | 公司 | 阶段 | 最后跟进 | 距今 | 状态 |
|------|------|------|----------|------|------|
| 王总 | DEF科技 | 方案报价 | 3月12日 | 7天 | ⚠️ 超期 |
| 赵经理 | GHI集团 | 需求确认 | 3月14日 | 5天 | ⚠️ 超期 |
| 李总 | JKL公司 | 谈判 | 3月16日 | 3天 | ⚠️ 超期 |
| 陈总 | MNO科技 | 初步接触 | 3月17日 | 2天 | 正常 |
| 刘经理 | PQR集团 | 需求确认 | 3月18日 | 1天 | 正常 |
建议优先跟进王总(方案报价阶段,7天未联系),
可电话了解方案评审进展。
```
### 销售漏斗概览
```
本月销售漏斗:
初步接触 (15) ──67%──▸ 需求确认 (10) ──60%──▸ 方案报价 (6)
──50%──▸ 谈判 (3) ──67%──▸ 成交 (2)
本月成交金额:¥45.00万
总体成交率:5.6%
流失客户:4个,涉及预算 ¥32.00万
建议关注:方案报价→谈判 转化率偏低(50%),
建议优化报价策略,提供灵活套餐选择。
```
---
## 常见问题
### Q1: 数据存储在哪里?
客户数据以 JSON 文件存储在 `~/.openclaw-bdi/customer-pulse/` 目录下。可通过环境变量 `CP_DATA_DIR` 自定义存储路径。所有数据保留在本地,不上传到任何外部服务。
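Resolving the storage directory as described above could look like this sketch. The function name `resolve_data_dir` is hypothetical; the scripts use their own `get_data_file` helper, whose implementation is not shown here.

```python
import os
from pathlib import Path

DEFAULT_DATA_DIR = "~/.openclaw-bdi/customer-pulse/"


def resolve_data_dir(env=os.environ):
    """Return the data directory: CP_DATA_DIR if set, else the documented default."""
    raw = env.get("CP_DATA_DIR") or DEFAULT_DATA_DIR
    return Path(raw).expanduser()


print(resolve_data_dir())
```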
### Q2: CSV 导入支持什么格式?
支持包含以下表头的 CSV 文件(中英文表头均可):姓名/name、手机/phone、公司/company、意向产品/product_interest、预算/budget、阶段/stage、来源/source。编码建议使用 UTF-8。
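A file in the expected shape can be produced with Python's standard `csv` module. The sketch below uses the English headers listed above; the sample row and the function name `build_sample_csv` are made up for illustration.

```python
import csv
import io

HEADERS = ["name", "phone", "company", "product_interest", "budget", "stage", "source"]


def build_sample_csv(rows):
    """Return CSV text with the documented headers for a list of row dicts."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=HEADERS)
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


sample_rows = [
    {
        "name": "张总", "phone": "13800138000", "company": "ABC科技",
        "product_interest": "产品A", "budget": 100000,
        "stage": "初步接触", "source": "转介绍",
    },
]
csv_text = build_sample_csv(sample_rows)
print(csv_text)
```

To save the result for import, write it with `open(path, "w", encoding="utf-8", newline="")` so the encoding matches the UTF-8 recommendation above.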
### Q3: 免费版到达 50 个客户上限后怎么办?
可以升级至付费版(¥99/月)扩展到 500 个客户。也可以删除不再活跃的客户(已成交或已流失)来腾出名额。
### Q4: 流失预警的准确度如何?
流失预警基于跟进频率衰减趋势分析,是一种统计预测方法。它通过比较近期跟进间隔与历史平均值来判断风险。准确度取决于跟进数据的完整性——跟进记录越详细,预测越准确。
### Q5: 如何备份客户数据?
使用导出功能即可备份:`/customer-pulse export backup.csv`。建议定期导出备份,防止数据丢失。
### Q6: 支持多人协作吗?
当前版本为个人使用设计。多人协作功能(共享客户池、协作跟进)计划在后续版本中推出。
---
## 技术支持
- 📖 **文档**:查看 `references/crm-guide.md` 获取 CRM 实践指南
- 🐛 **问题反馈**:在 ClawHub 的 Skill 页面提交 Issue
- 💬 **社区讨论**:加入 ClawHub 社区频道 `#customer-pulse`
- 📧 **邮件**:[email protected]
---
*customer-pulse v1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/customer_store.py
#!/usr/bin/env python3
"""
customer-pulse 客户数据管理模块
提供客户数据的 CRUD 操作,支持 JSON 文件存储、CSV 导入导出。
"""
import csv
import io
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
generate_id,
get_data_file,
load_input_data,
mask_phone,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
validate_stage,
write_json_file,
STAGES,
)
# ============================================================
# 数据文件路径
# ============================================================
CUSTOMERS_FILE = "customers.json"
def _get_customers() -> List[Dict[str, Any]]:
"""读取所有客户数据。"""
return read_json_file(get_data_file(CUSTOMERS_FILE))
def _save_customers(customers: List[Dict[str, Any]]) -> None:
"""保存客户数据到文件。"""
write_json_file(get_data_file(CUSTOMERS_FILE), customers)
def _find_customer(customers: List[Dict], customer_id: str) -> Optional[Dict]:
"""根据 ID 查找客户。"""
for c in customers:
if c.get("id") == customer_id:
return c
return None
# ============================================================
# CRUD 操作
# ============================================================
def add_customer(data: Dict[str, Any]) -> None:
"""添加新客户。
必填字段: name
可选字段: phone, company, product_interest, budget, stage, source
Args:
data: 客户数据字典。
"""
if not data.get("name"):
output_error("客户姓名(name)为必填字段", code="VALIDATION_ERROR")
return
sub = check_subscription()
customers = _get_customers()
if len(customers) >= sub["max_customers"]:
limit = sub["max_customers"]
if sub["tier"] == "free":
output_error(
f"免费版最多管理 {limit} 个客户,当前已有 {len(customers)} 个。"
"请升级至付费版(¥99/月)以管理更多客户。",
code="LIMIT_EXCEEDED",
)
else:
output_error(
f"已达到客户数量上限 {limit} 个。",
code="LIMIT_EXCEEDED",
)
return
stage = data.get("stage", "初步接触")
try:
validate_stage(stage)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
now = now_iso()
customer = {
"id": generate_id("C"),
"name": data["name"],
"phone": data.get("phone", ""),
"company": data.get("company", ""),
"product_interest": data.get("product_interest", ""),
"budget": data.get("budget", 0),
"stage": stage,
"source": data.get("source", ""),
"created_at": now,
"updated_at": now,
}
customers.append(customer)
_save_customers(customers)
# 输出时脱敏手机号
display = dict(customer)
display["phone"] = mask_phone(display["phone"])
output_success({"message": f"客户「{customer['name']}」已添加", "customer": display})
def update_customer(data: Dict[str, Any]) -> None:
"""更新客户信息。
必填字段: id
可更新字段: name, phone, company, product_interest, budget, stage, source
Args:
data: 包含客户 ID 和待更新字段的字典。
"""
customer_id = data.get("id")
if not customer_id:
output_error("客户ID(id)为必填字段", code="VALIDATION_ERROR")
return
customers = _get_customers()
customer = _find_customer(customers, customer_id)
if not customer:
output_error(f"未找到ID为 {customer_id} 的客户", code="NOT_FOUND")
return
updatable_fields = ["name", "phone", "company", "product_interest", "budget", "stage", "source"]
updated = False
for field in updatable_fields:
if field in data:
if field == "stage":
try:
validate_stage(data[field])
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
customer[field] = data[field]
updated = True
if not updated:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
customer["updated_at"] = now_iso()
_save_customers(customers)
display = dict(customer)
display["phone"] = mask_phone(display["phone"])
output_success({"message": f"客户「{customer['name']}」已更新", "customer": display})
def delete_customer(data: Dict[str, Any]) -> None:
"""删除客户。
必填字段: id
Args:
data: 包含客户 ID 的字典。
"""
customer_id = data.get("id")
if not customer_id:
output_error("客户ID(id)为必填字段", code="VALIDATION_ERROR")
return
customers = _get_customers()
original_count = len(customers)
customers = [c for c in customers if c.get("id") != customer_id]
if len(customers) == original_count:
output_error(f"未找到ID为 {customer_id} 的客户", code="NOT_FOUND")
return
_save_customers(customers)
output_success({"message": f"客户 {customer_id} 已删除"})
def get_customer(data: Dict[str, Any]) -> None:
"""获取单个客户详情。
必填字段: id
Args:
data: 包含客户 ID 的字典。
"""
customer_id = data.get("id")
if not customer_id:
output_error("客户ID(id)为必填字段", code="VALIDATION_ERROR")
return
customers = _get_customers()
customer = _find_customer(customers, customer_id)
if not customer:
output_error(f"未找到ID为 {customer_id} 的客户", code="NOT_FOUND")
return
display = dict(customer)
display["phone"] = mask_phone(display["phone"])
output_success(display)
def list_customers(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有客户。
可选过滤: stage, keyword(搜索姓名/公司/产品)
Args:
data: 可选的过滤条件字典。
"""
customers = _get_customers()
if data:
stage_filter = data.get("stage")
keyword = data.get("keyword", "").strip()
if stage_filter:
customers = [c for c in customers if c.get("stage") == stage_filter]
if keyword:
keyword_lower = keyword.lower()
customers = [
c for c in customers
if keyword_lower in c.get("name", "").lower()
or keyword_lower in c.get("company", "").lower()
or keyword_lower in c.get("product_interest", "").lower()
]
# 按更新时间倒序排列
customers.sort(key=lambda c: c.get("updated_at", ""), reverse=True)
# 脱敏手机号
display_list = []
for c in customers:
d = dict(c)
d["phone"] = mask_phone(d["phone"])
display_list.append(d)
# 按阶段分组统计
stage_stats = {}
for stage in STAGES:
stage_stats[stage] = sum(1 for c in customers if c.get("stage") == stage)
output_success({
"total": len(display_list),
"stage_stats": stage_stats,
"customers": display_list,
})
def import_customers(data: Dict[str, Any]) -> None:
"""从 CSV 文件导入客户数据。
必填字段: file_path
Args:
data: 包含 CSV 文件路径的字典。
"""
file_path = data.get("file_path")
if not file_path:
output_error("CSV 文件路径(file_path)为必填字段", code="VALIDATION_ERROR")
return
if not os.path.exists(file_path):
output_error(f"文件不存在: {file_path}", code="FILE_NOT_FOUND")
return
sub = check_subscription()
customers = _get_customers()
imported = 0
skipped = 0
errors = []
try:
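# utf-8-sig also accepts files that start with a BOM (e.g. CSV exported from Excel);
# with plain utf-8 the BOM would be glued to the first header name, so "name"/"姓名" would never match
with open(file_path, "r", encoding="utf-8-sig", newline="") as f: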
reader = csv.DictReader(f)
for row_num, row in enumerate(reader, start=2):
if len(customers) >= sub["max_customers"]:
errors.append(f"行 {row_num}: 已达客户数量上限 {sub['max_customers']}")
skipped += 1
continue
name = row.get("name", "").strip() or row.get("姓名", "").strip()
if not name:
errors.append(f"行 {row_num}: 缺少客户姓名")
skipped += 1
continue
stage = row.get("stage", "").strip() or row.get("阶段", "").strip() or "初步接触"
if stage not in STAGES:
stage = "初步接触"
now = now_iso()
customer = {
"id": generate_id("C"),
"name": name,
"phone": row.get("phone", "").strip() or row.get("手机", "").strip(),
"company": row.get("company", "").strip() or row.get("公司", "").strip(),
"product_interest": row.get("product_interest", "").strip() or row.get("意向产品", "").strip(),
"budget": _parse_budget(row.get("budget", "") or row.get("预算", "")),
"stage": stage,
"source": row.get("source", "").strip() or row.get("来源", "").strip(),
"created_at": now,
"updated_at": now,
}
customers.append(customer)
imported += 1
except Exception as e:
output_error(f"导入失败: {e}", code="IMPORT_ERROR")
return
_save_customers(customers)
result = {
"message": f"导入完成:成功 {imported} 条,跳过 {skipped} 条",
"imported": imported,
"skipped": skipped,
"total": len(customers),
}
if errors:
result["errors"] = errors[:10]
output_success(result)
def export_customers(data: Optional[Dict[str, Any]] = None) -> None:
"""导出客户数据到 CSV 格式。
可选字段: file_path(若不指定则输出到 stdout)
Args:
data: 可选的配置字典。
"""
customers = _get_customers()
if not customers:
output_error("暂无客户数据可导出", code="NO_DATA")
return
file_path = data.get("file_path") if data else None
fieldnames = ["id", "name", "phone", "company", "product_interest", "budget", "stage", "source", "created_at", "updated_at"]
try:
if file_path:
with open(file_path, "w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for c in customers:
row = {k: c.get(k, "") for k in fieldnames}
writer.writerow(row)
output_success({"message": f"已导出 {len(customers)} 条客户数据到 {file_path}", "count": len(customers)})
else:
output_buf = io.StringIO()
writer = csv.DictWriter(output_buf, fieldnames=fieldnames)
writer.writeheader()
for c in customers:
row = {k: c.get(k, "") for k in fieldnames}
writer.writerow(row)
output_success({"csv": output_buf.getvalue(), "count": len(customers)})
except IOError as e:
output_error(f"导出失败: {e}", code="EXPORT_ERROR")
# ============================================================
# 辅助函数
# ============================================================
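def _parse_budget(value: str) -> float:
    """Parse a budget string into a number.

    Accepts values carrying the Chinese units 万 (1e4) and 亿 (1e8).

    Args:
        value: Budget string.

    Returns:
        The budget as a float; 0.0 when the value is empty or unparseable.
    """
    if not value:
        return 0.0
    # Strip thousands separators and currency markers, in both the ASCII and
    # the full-width forms.
    value = str(value).strip().replace(",", "").replace(",", "")
    value = value.replace("¥", "").replace("¥", "").replace("元", "")
    try:
        if "亿" in value:
            return float(value.replace("亿", "")) * 1e8
        elif "万" in value:
            return float(value.replace("万", "")) * 1e4
        else:
            return float(value)
    except (ValueError, TypeError):
        return 0.0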
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("customer-pulse 客户数据管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"add": lambda: add_customer(data or {}),
"update": lambda: update_customer(data or {}),
"delete": lambda: delete_customer(data or {}),
"get": lambda: get_customer(data or {}),
"list": lambda: list_customers(data),
"import": lambda: import_customers(data or {}),
"export": lambda: export_customers(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
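
The budget parsing used during CSV import can be tried in isolation. The following is a minimal standalone sketch, assuming the same rules as `_parse_budget` above (strip separators and currency markers, then scale by 万 or 亿); the name `parse_budget` and the sample inputs are illustrative and not part of the module.

```python
def parse_budget(value: str) -> float:
    """Convert a budget string such as "1.5万" or "¥2,000元" to a float."""
    if not value:
        return 0.0
    text = str(value).strip()
    # Drop thousands separators and currency markers (ASCII and full-width).
    for token in (",", ",", "¥", "¥", "元"):
        text = text.replace(token, "")
    try:
        if "亿" in text:
            return float(text.replace("亿", "")) * 1e8
        if "万" in text:
            return float(text.replace("万", "")) * 1e4
        return float(text)
    except ValueError:
        # Free-text budgets such as "面议" cannot be parsed; treat as unknown.
        return 0.0


if __name__ == "__main__":
    for sample in ("1.5万", "¥2,000元", "3亿", "面议"):
        print(sample, "->", parse_budget(sample))
```

Returning 0.0 for unparseable text keeps the import going, at the cost of making "unknown budget" indistinguishable from "zero budget" in later totals.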
FILE:scripts/pipeline_analyzer.py
#!/usr/bin/env python3
"""
customer-pulse 销售漏斗与转化分析模块
提供销售阶段分析、转化率统计、月度报表生成等功能。
"""
from collections import Counter
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
format_currency,
get_data_file,
load_input_data,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
today_str,
STAGES,
)
# ============================================================
# 数据文件
# ============================================================
CUSTOMERS_FILE = "customers.json"
FOLLOWUPS_FILE = "followups.json"
def _get_customers() -> List[Dict[str, Any]]:
"""读取所有客户数据。"""
return read_json_file(get_data_file(CUSTOMERS_FILE))
def _get_followups() -> List[Dict[str, Any]]:
"""读取所有跟进记录。"""
return read_json_file(get_data_file(FOLLOWUPS_FILE))
# ============================================================
# 漏斗分析
# ============================================================
def analyze_funnel(data: Optional[Dict[str, Any]] = None) -> None:
"""分析销售漏斗各阶段分布。
统计各阶段客户数量、预算总额,生成漏斗视图。
Args:
data: 可选过滤参数。
"""
customers = _get_customers()
if not customers:
output_error("暂无客户数据", code="NO_DATA")
return
# 按阶段统计
funnel_stages = ["初步接触", "需求确认", "方案报价", "谈判", "成交"]
stage_data = []
for stage in funnel_stages:
stage_customers = [c for c in customers if c.get("stage") == stage]
total_budget = sum(float(c.get("budget", 0)) for c in stage_customers)
stage_data.append({
"stage": stage,
"count": len(stage_customers),
"budget_total": total_budget,
"budget_display": format_currency(total_budget),
})
# 流失客户单独统计
lost_customers = [c for c in customers if c.get("stage") == "流失"]
lost_budget = sum(float(c.get("budget", 0)) for c in lost_customers)
# 计算漏斗转化率(相邻阶段)
conversions = []
for i in range(len(funnel_stages) - 1):
from_count = stage_data[i]["count"]
to_count = stage_data[i + 1]["count"]
rate = (to_count / from_count * 100) if from_count > 0 else 0
conversions.append({
"from_stage": funnel_stages[i],
"to_stage": funnel_stages[i + 1],
"from_count": from_count,
"to_count": to_count,
"conversion_rate": f"{rate:.1f}%",
})
# 总转化率
total_leads = stage_data[0]["count"] + stage_data[1]["count"] + stage_data[2]["count"] + stage_data[3]["count"]
total_won = stage_data[4]["count"]
overall_rate = (total_won / (total_leads + total_won) * 100) if (total_leads + total_won) > 0 else 0
result = {
"funnel": stage_data,
"conversions": conversions,
"lost": {
"count": len(lost_customers),
"budget_total": lost_budget,
"budget_display": format_currency(lost_budget),
},
"overall": {
"total_customers": len(customers),
"active_pipeline": total_leads,
"won": total_won,
"lost": len(lost_customers),
"overall_conversion_rate": f"{overall_rate:.1f}%",
},
}
# 付费版:生成 Mermaid 漏斗图
sub = check_subscription()
if "mermaid_chart" in sub["features"]:
mermaid = _generate_funnel_mermaid(stage_data)
result["mermaid_chart"] = mermaid
output_success(result)
def analyze_conversion(data: Optional[Dict[str, Any]] = None) -> None:
    """Analyze per-stage conversion details (paid feature).

    Reports, for each funnel stage, the customer count, average customer age,
    average follow-up count, and total budget, then flags the weakest
    adjacent-stage transition as the bottleneck.

    Args:
        data: Optional parameters (currently unused).
    """
    # Imported once here rather than on every loop iteration.
    from utils import calculate_days_since

    if not require_paid_feature("conversion_analysis", "成交率分析"):
        return
    customers = _get_customers()
    followups = _get_followups()
    if not customers:
        output_error("暂无客户数据", code="NO_DATA")
        return
    funnel_stages = ["初步接触", "需求确认", "方案报价", "谈判", "成交"]
    stage_details = []
    for stage in funnel_stages:
        stage_customers = [c for c in customers if c.get("stage") == stage]
        # Average days since each customer was created. This measures time
        # since creation, not time since the customer entered the stage.
        ages = [
            calculate_days_since(c.get("created_at", "")[:10])
            for c in stage_customers
            if c.get("created_at", "")[:10]
        ]
        avg_days = sum(ages) / len(ages) if ages else 0
        # Average number of follow-ups per customer in this stage.
        stage_customer_ids = {c["id"] for c in stage_customers}
        stage_followups = [f for f in followups if f.get("customer_id") in stage_customer_ids]
        avg_followups = len(stage_followups) / len(stage_customers) if stage_customers else 0
        stage_details.append({
            "stage": stage,
            "count": len(stage_customers),
            "avg_days_in_stage": round(avg_days, 1),
            "avg_followup_count": round(avg_followups, 1),
            "total_budget": format_currency(sum(float(c.get("budget", 0)) for c in stage_customers)),
        })
    # Bottleneck: the adjacent-stage pair with the lowest count ratio.
    # An empty source stage counts as 100% so it is never picked.
    bottleneck = None
    min_rate = 100
    for i in range(len(funnel_stages) - 1):
        from_count = stage_details[i]["count"]
        to_count = stage_details[i + 1]["count"]
        rate = (to_count / from_count * 100) if from_count > 0 else 100
        if rate < min_rate:
            min_rate = rate
            bottleneck = {
                "from_stage": funnel_stages[i],
                "to_stage": funnel_stages[i + 1],
                "conversion_rate": f"{rate:.1f}%",
                "suggestion": _get_bottleneck_suggestion(funnel_stages[i], funnel_stages[i + 1]),
            }
    output_success({
        "stage_details": stage_details,
        "bottleneck": bottleneck,
    })
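def monthly_stats(data: Optional[Dict[str, Any]] = None) -> None:
    """Produce statistics for one month.

    Counts new customers, won and lost deals, and follow-up activity for the
    given month.

    Args:
        data: Optional parameters; may contain month (format YYYY-MM).
            Defaults to the current month.
    """
    customers = _get_customers()
    followups = _get_followups()
    # Determine and validate the month being reported.
    if data and data.get("month"):
        month_str = data["month"]
    else:
        month_str = today_str()[:7]
    try:
        year, month = map(int, month_str.split("-"))
        if not 1 <= month <= 12:
            raise ValueError(month_str)
    except (AttributeError, ValueError):
        output_error(f"月份格式无效: {month_str!r},应为 YYYY-MM", code="VALIDATION_ERROR")
        return
    # Normalize to the zero-padded form so the prefix matches below are exact
    # (an unpadded "2024-1" would also match October through December).
    month_str = f"{year:04d}-{month:02d}"
    # Customers created this month.
    new_customers = [
        c for c in customers
        if c.get("created_at", "").startswith(month_str)
    ]
    # Customers in the won stage whose last update falls in this month.
    won_customers = [
        c for c in customers
        if c.get("stage") == "成交" and c.get("updated_at", "").startswith(month_str)
    ]
    # Customers in the lost stage whose last update falls in this month.
    lost_customers = [
        c for c in customers
        if c.get("stage") == "流失" and c.get("updated_at", "").startswith(month_str)
    ]
    # Follow-ups dated this month.
    month_followups = [
        f for f in followups
        if f.get("date", "").startswith(month_str)
    ]
    # Total budget of the won deals.
    won_budget = sum(float(c.get("budget", 0)) for c in won_customers)
    # Active customers: distinct customers with a follow-up this month.
    active_customer_ids = {f.get("customer_id") for f in month_followups}
    result = {
        "month": month_str,
        "new_customers": len(new_customers),
        "won_deals": len(won_customers),
        "lost_deals": len(lost_customers),
        "won_revenue": won_budget,
        "won_revenue_display": format_currency(won_budget),
        "total_followups": len(month_followups),
        "active_customers": len(active_customer_ids),
        "total_customers": len(customers),
        # Won deals this month divided by customers created this month.
        "win_rate": f"{len(won_customers) / len(new_customers) * 100:.1f}%" if new_customers else "N/A",
    }
    output_success(result)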
def generate_report(data: Optional[Dict[str, Any]] = None) -> None:
"""生成综合销售分析报告。
整合漏斗分析、月度统计、客户分布等信息。
Args:
data: 可选参数。
"""
customers = _get_customers()
followups = _get_followups()
if not customers:
output_error("暂无客户数据,无法生成报告", code="NO_DATA")
return
sub = check_subscription()
today = today_str()
month_str = today[:7]
# 基础统计
stage_counts = Counter(c.get("stage", "未知") for c in customers)
source_counts = Counter(c.get("source", "未知") for c in customers if c.get("source"))
total_budget = sum(float(c.get("budget", 0)) for c in customers)
won_budget = sum(float(c.get("budget", 0)) for c in customers if c.get("stage") == "成交")
# 本月数据
month_new = sum(1 for c in customers if c.get("created_at", "").startswith(month_str))
month_won = sum(1 for c in customers if c.get("stage") == "成交" and c.get("updated_at", "").startswith(month_str))
month_followups = sum(1 for f in followups if f.get("date", "").startswith(month_str))
report = {
"report_date": today,
"summary": {
"total_customers": len(customers),
"total_pipeline_budget": format_currency(total_budget),
"total_won_revenue": format_currency(won_budget),
"month_new_customers": month_new,
"month_won_deals": month_won,
"month_followup_count": month_followups,
},
"stage_distribution": {stage: stage_counts.get(stage, 0) for stage in STAGES},
"source_distribution": dict(source_counts),
}
# 付费版:添加 Mermaid 图表
if "mermaid_chart" in sub["features"]:
# 阶段分布饼图
pie_data = {stage: stage_counts.get(stage, 0) for stage in STAGES if stage_counts.get(stage, 0) > 0}
report["mermaid_stage_chart"] = _generate_pie_mermaid("客户阶段分布", pie_data)
# 来源分布饼图
if source_counts:
report["mermaid_source_chart"] = _generate_pie_mermaid("客户来源分布", dict(source_counts))
output_success(report)
# ============================================================
# Mermaid 图表生成
# ============================================================
def _generate_funnel_mermaid(stage_data: List[Dict]) -> str:
"""生成漏斗 Mermaid 图表。"""
lines = ["```mermaid", "graph LR"]
for i, s in enumerate(stage_data):
node_id = f"S{i}"
label = f"{s['stage']}<br/>{s['count']}个客户<br/>{s['budget_display']}"
lines.append(f" {node_id}[\"{label}\"]")
if i > 0:
prev_count = stage_data[i - 1]["count"]
rate = (s["count"] / prev_count * 100) if prev_count > 0 else 0
lines.append(f" S{i-1} -->|{rate:.0f}%| {node_id}")
lines.append("```")
return "\n".join(lines)
def _generate_pie_mermaid(title: str, data: Dict[str, int]) -> str:
"""生成饼图 Mermaid 图表。"""
lines = ["```mermaid", f"pie title {title}"]
for label, value in data.items():
lines.append(f' "{label}" : {value}')
lines.append("```")
return "\n".join(lines)
# ============================================================
# 瓶颈建议
# ============================================================
def _get_bottleneck_suggestion(from_stage: str, to_stage: str) -> str:
"""根据转化瓶颈位置给出优化建议。"""
suggestions = {
("初步接触", "需求确认"): "建议加强初次接触后的需求挖掘,提前准备行业案例和产品演示材料",
("需求确认", "方案报价"): "建议缩短方案制作周期,提供标准化方案模板,加速报价流程",
("方案报价", "谈判"): "建议优化报价策略,提供灵活的套餐选择,及时跟进客户反馈",
("谈判", "成交"): "建议加强谈判技巧,了解客户决策链,提供限时优惠促进签约",
}
return suggestions.get((from_stage, to_stage), "建议增加跟进频率,深入了解客户需求")
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("customer-pulse 销售漏斗分析")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"funnel": lambda: analyze_funnel(data),
"conversion": lambda: analyze_conversion(data),
"monthly_stats": lambda: monthly_stats(data),
"report": lambda: generate_report(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
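
The adjacent-stage conversion rates reported by `analyze_funnel` and `analyze_conversion` are ratios of the current customer counts in neighbouring stages. The sketch below shows that calculation on its own, assuming only a mapping from stage name to count; `adjacent_conversions` and the sample counts are hypothetical names and data for illustration.

```python
from typing import Dict, List

FUNNEL_STAGES = ["初步接触", "需求确认", "方案报价", "谈判", "成交"]


def adjacent_conversions(counts: Dict[str, int]) -> List[Dict[str, object]]:
    """Return the count ratio (in percent) for each adjacent stage pair."""
    results = []
    for src, dst in zip(FUNNEL_STAGES, FUNNEL_STAGES[1:]):
        from_count = counts.get(src, 0)
        to_count = counts.get(dst, 0)
        # An empty source stage yields 0 rather than a division error.
        rate = to_count / from_count * 100 if from_count > 0 else 0.0
        results.append({"from_stage": src, "to_stage": dst, "rate": round(rate, 1)})
    return results


if __name__ == "__main__":
    sample = {"初步接触": 10, "需求确认": 5, "方案报价": 4, "谈判": 0, "成交": 0}
    for row in adjacent_conversions(sample):
        print(row)
```

Because each customer is counted only in the stage they are in now, these figures describe the shape of the pipeline at one moment. They are not cohort conversion rates, and a ratio can exceed 100% when a later stage holds more customers than the one before it.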
FILE:scripts/churn_predictor.py
#!/usr/bin/env python3
"""
customer-pulse 客户流失预测模块(付费功能)
基于跟进频率衰减趋势预测客户流失风险,帮助提前干预。
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
calculate_days_since,
check_subscription,
format_currency,
get_data_file,
load_input_data,
mask_phone,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
today_str,
)
# ============================================================
# 数据文件
# ============================================================
CUSTOMERS_FILE = "customers.json"
FOLLOWUPS_FILE = "followups.json"
def _get_customers() -> List[Dict[str, Any]]:
"""读取所有客户数据。"""
return read_json_file(get_data_file(CUSTOMERS_FILE))
def _get_followups() -> List[Dict[str, Any]]:
"""读取所有跟进记录。"""
return read_json_file(get_data_file(FOLLOWUPS_FILE))
# ============================================================
# 流失预测核心逻辑
# ============================================================
def _calculate_churn_risk(customer: Dict, followups: List[Dict]) -> Dict[str, Any]:
"""计算单个客户的流失风险。
使用简单的统计方法:
1. 计算历史平均跟进间隔
2. 计算近期跟进间隔
3. 比较近期间隔与历史平均的偏差
4. 综合最后跟进距今天数计算风险分数
Args:
customer: 客户数据。
followups: 该客户的跟进记录列表。
Returns:
包含风险评估详情的字典。
"""
cid = customer["id"]
customer_followups = sorted(
[f for f in followups if f.get("customer_id") == cid],
key=lambda f: f.get("date", ""),
)
today = today_str()
risk_score = 0 # 0-100,越高风险越大
risk_factors = []
if not customer_followups:
# 从未跟进
days_since_created = calculate_days_since(customer.get("created_at", "")[:10])
risk_score = min(90, 40 + days_since_created * 3)
risk_factors.append(f"从未跟进,客户创建已 {days_since_created} 天")
elif len(customer_followups) == 1:
# 只跟进过一次
last_date = customer_followups[0].get("date", "")
days_since = calculate_days_since(last_date)
risk_score = min(85, 20 + days_since * 4)
risk_factors.append(f"仅跟进 1 次,距今 {days_since} 天")
else:
# 计算跟进间隔
intervals = []
for i in range(1, len(customer_followups)):
d1 = customer_followups[i - 1].get("date", "")
d2 = customer_followups[i].get("date", "")
if d1 and d2:
try:
dt1 = datetime.strptime(d1, "%Y-%m-%d")
dt2 = datetime.strptime(d2, "%Y-%m-%d")
interval = (dt2 - dt1).days
if interval > 0:
intervals.append(interval)
except ValueError:
continue
if intervals:
avg_interval = sum(intervals) / len(intervals)
# 近期间隔(最后 3 个)
recent_intervals = intervals[-3:] if len(intervals) >= 3 else intervals
recent_avg = sum(recent_intervals) / len(recent_intervals)
# 频率衰减比
decay_ratio = recent_avg / avg_interval if avg_interval > 0 else 1
# 最后跟进距今天数
last_date = customer_followups[-1].get("date", "")
days_since_last = calculate_days_since(last_date)
# 风险评分计算
# 基础分:基于频率衰减
if decay_ratio > 2.0:
risk_score += 40
risk_factors.append(f"跟进频率显著下降(衰减比 {decay_ratio:.1f}x)")
elif decay_ratio > 1.5:
risk_score += 25
risk_factors.append(f"跟进频率有所下降(衰减比 {decay_ratio:.1f}x)")
elif decay_ratio > 1.2:
risk_score += 10
risk_factors.append(f"跟进频率略有放缓(衰减比 {decay_ratio:.1f}x)")
# 附加分:基于最后跟进距今天数
if days_since_last > 30:
risk_score += 35
risk_factors.append(f"超过 30 天未跟进({days_since_last} 天)")
elif days_since_last > 14:
risk_score += 25
risk_factors.append(f"超过 14 天未跟进({days_since_last} 天)")
elif days_since_last > 7:
risk_score += 15
risk_factors.append(f"超过 7 天未跟进({days_since_last} 天)")
# 附加分:跟进次数过少
if len(customer_followups) <= 2:
risk_score += 10
risk_factors.append(f"跟进次数较少(仅 {len(customer_followups)} 次)")
else:
last_date = customer_followups[-1].get("date", "")
days_since_last = calculate_days_since(last_date)
risk_score = min(80, 30 + days_since_last * 3)
risk_factors.append(f"距最后跟进已 {days_since_last} 天")
# 预算因素:高预算客户流失风险需更关注
budget = float(customer.get("budget", 0))
impact = "高" if budget >= 100000 else ("中" if budget >= 10000 else "低")
# 限制风险分在 0-100
risk_score = max(0, min(100, risk_score))
# 风险等级
if risk_score >= 70:
risk_level = "高风险"
elif risk_score >= 40:
risk_level = "中风险"
else:
risk_level = "低风险"
return {
"customer_id": cid,
"customer_name": customer["name"],
"company": customer.get("company", ""),
"stage": customer.get("stage", ""),
"phone": mask_phone(customer.get("phone", "")),
"budget": budget,
"budget_display": format_currency(budget),
"risk_score": risk_score,
"risk_level": risk_level,
"risk_factors": risk_factors,
"impact": impact,
"total_followups": len(customer_followups),
"last_followup_date": customer_followups[-1].get("date", "") if customer_followups else "无",
"suggestion": _get_risk_suggestion(risk_level, risk_factors),
}
def _get_risk_suggestion(risk_level: str, factors: List[str]) -> str:
"""根据风险等级和因素给出建议。"""
if risk_level == "高风险":
return "建议立即安排跟进,了解客户最新情况,必要时由主管介入沟通"
elif risk_level == "中风险":
return "建议在本周内安排一次跟进,重新建立联系,了解客户需求变化"
else:
return "当前状态良好,建议保持正常跟进节奏"
# ============================================================
# 预测操作
# ============================================================
def predict_churn(data: Optional[Dict[str, Any]] = None) -> None:
"""对所有活跃客户进行流失风险预测。
仅分析处于活跃阶段(非成交、非流失)的客户。
Args:
data: 可选参数。
"""
if not require_paid_feature("churn_prediction", "客户流失预警"):
return
customers = _get_customers()
followups = _get_followups()
# 只分析活跃客户
active_customers = [
c for c in customers
if c.get("stage") not in ("成交", "流失")
]
if not active_customers:
output_error("暂无活跃客户可分析", code="NO_DATA")
return
predictions = []
for customer in active_customers:
risk = _calculate_churn_risk(customer, followups)
predictions.append(risk)
# 按风险分数降序排列
predictions.sort(key=lambda p: p["risk_score"], reverse=True)
# 统计
high_risk = sum(1 for p in predictions if p["risk_level"] == "高风险")
mid_risk = sum(1 for p in predictions if p["risk_level"] == "中风险")
low_risk = sum(1 for p in predictions if p["risk_level"] == "低风险")
# 高风险客户的总预算
at_risk_budget = sum(p["budget"] for p in predictions if p["risk_level"] == "高风险")
output_success({
"analysis_date": today_str(),
"total_analyzed": len(predictions),
"risk_summary": {
"high": high_risk,
"medium": mid_risk,
"low": low_risk,
},
"at_risk_budget": format_currency(at_risk_budget),
"predictions": predictions,
})
def risk_list(data: Optional[Dict[str, Any]] = None) -> None:
"""获取高风险客户列表。
只返回风险等级为「高风险」的客户。
Args:
data: 可选参数。
"""
if not require_paid_feature("churn_prediction", "客户流失预警"):
return
customers = _get_customers()
followups = _get_followups()
active_customers = [
c for c in customers
if c.get("stage") not in ("成交", "流失")
]
high_risk_list = []
for customer in active_customers:
risk = _calculate_churn_risk(customer, followups)
if risk["risk_level"] == "高风险":
high_risk_list.append(risk)
high_risk_list.sort(key=lambda p: p["risk_score"], reverse=True)
total_at_risk_budget = sum(r["budget"] for r in high_risk_list)
output_success({
"analysis_date": today_str(),
"high_risk_count": len(high_risk_list),
"total_at_risk_budget": format_currency(total_at_risk_budget),
"high_risk_customers": high_risk_list,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("customer-pulse 客户流失预测")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"predict": lambda: predict_churn(data),
"risk_list": lambda: risk_list(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
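
The scoring rules in `_calculate_churn_risk` for customers with two or more follow-ups can be exercised with fixed dates. This is a simplified sketch under stated assumptions: it takes `datetime.date` values and an explicit `today` instead of reading the data files, and it returns `None` when no positive interval exists, where the module falls back to a different formula. The name `churn_score` is hypothetical.

```python
from datetime import date
from typing import List, Optional


def churn_score(followup_dates: List[date], today: date) -> Optional[int]:
    """Score churn risk (0-100) from follow-up dates; higher means riskier."""
    dates = sorted(followup_dates)
    # Gaps in days between consecutive follow-ups; same-day repeats are ignored.
    intervals = [(b - a).days for a, b in zip(dates, dates[1:]) if (b - a).days > 0]
    if not intervals:
        return None
    avg_interval = sum(intervals) / len(intervals)
    recent = intervals[-3:]
    # A ratio above 1 means recent gaps are longer than the historical average.
    decay_ratio = (sum(recent) / len(recent)) / avg_interval
    days_since_last = (today - dates[-1]).days

    score = 0
    if decay_ratio > 2.0:
        score += 40
    elif decay_ratio > 1.5:
        score += 25
    elif decay_ratio > 1.2:
        score += 10
    if days_since_last > 30:
        score += 35
    elif days_since_last > 14:
        score += 25
    elif days_since_last > 7:
        score += 15
    if len(dates) <= 2:
        score += 10
    return max(0, min(100, score))


if __name__ == "__main__":
    history = [date(2024, 1, 1), date(2024, 1, 3), date(2024, 1, 5),
               date(2024, 1, 7), date(2024, 1, 27)]
    print(churn_score(history, today=date(2024, 2, 15)))
```

Passing `today` explicitly makes the score reproducible in tests, whereas the module reads the current date at call time.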
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
customer-pulse 共享工具模块
提供客户数据管理、订阅校验、数据格式化等通用功能。
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "customer-pulse")
STAGES = ["初步接触", "需求确认", "方案报价", "谈判", "成交", "流失"]
STAGE_COLORS = {
"初步接触": "blue",
"需求确认": "cyan",
"方案报价": "yellow",
"谈判": "orange",
"成交": "green",
"流失": "red",
}
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径。
优先读取环境变量 CP_DATA_DIR,若未设置则使用默认路径
~/.openclaw-bdi/customer-pulse/。
自动创建目录(若不存在)。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("CP_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_data_file(filename: str) -> str:
"""获取数据文件的完整路径。
Args:
filename: 文件名(如 "customers.json")。
Returns:
数据文件的绝对路径。
"""
return os.path.join(get_data_dir(), filename)
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_file(filepath: str) -> Any:
"""读取 JSON 文件并返回解析后的数据。
Args:
filepath: JSON 文件路径。
Returns:
解析后的数据对象。若文件不存在,返回空列表。
"""
if not os.path.exists(filepath):
return []
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return []
def write_json_file(filepath: str, data: Any) -> None:
"""将数据写入 JSON 文件。
Args:
filepath: 目标文件路径。
data: 待写入的数据。
"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2, default=str)
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def parse_common_args(description: str = "customer-pulse 客户管理工具") -> argparse.ArgumentParser:
"""创建通用命令行参数解析器。
Args:
description: 工具描述文本。
Returns:
配置好通用参数的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Dict[str, Any]]:
"""从命令行参数加载输入数据。
优先使用 --data 参数,其次尝试 --data-file 参数。
Args:
args: 解析后的命令行参数。
Returns:
解析后的字典数据,若无输入数据则返回 None。
Raises:
ValueError: 当 JSON 解析失败或文件读取失败时抛出。
"""
if args.data:
try:
data = json.loads(args.data)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if args.data_file:
if not os.path.exists(args.data_file):
raise ValueError(f"数据文件不存在: {args.data_file}")
try:
with open(args.data_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"数据文件 JSON 解析失败: {e}")
return None
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_customers": 50,
"followup_reminder": "fixed",
"reminder_days": 3,
"features": ["customer_crud", "followup_record", "basic_list", "csv_export", "csv_import"],
},
"paid": {
"tier": "paid",
"max_customers": 500,
"followup_reminder": "custom",
"reminder_days": None,
"features": [
"customer_crud",
"followup_record",
"basic_list",
"csv_export",
"csv_import",
"bulk_import",
"funnel_chart",
"heatmap",
"churn_prediction",
"customer_profile",
"conversion_analysis",
"custom_reminder",
"mermaid_chart",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 CP_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("CP_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid_feature(feature_name: str, display_name: str) -> bool:
"""检查当前订阅是否支持指定功能。
若不支持,输出升级提示并返回 False。
Args:
feature_name: 功能内部名称。
display_name: 功能显示名称(用于提示信息)。
Returns:
True 表示功能可用,False 表示不可用(已输出错误信息)。
"""
sub = check_subscription()
if feature_name not in sub["features"]:
output_error(
f"「{display_name}」为付费版功能。当前为免费版,请升级至付费版(¥99/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
# ============================================================
# CRM 专用工具函数
# ============================================================
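def stage_display_name(stage: str) -> str:
    """Return the display name of a sales stage.

    Stage identifiers are already human-readable, so the value is returned
    unchanged whether or not it is a known stage.

    Args:
        stage: Stage identifier.

    Returns:
        The stage display name.
    """
    return stage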
def validate_stage(stage: str) -> str:
"""校验销售阶段是否合法。
Args:
stage: 待校验的阶段名称。
Returns:
合法的阶段名称。
Raises:
ValueError: 当阶段名称不合法时抛出。
"""
if stage not in STAGES:
valid = "、".join(STAGES)
raise ValueError(f"无效的销售阶段: {stage!r},有效阶段: {valid}")
return stage
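def calculate_days_since(date_str: str) -> int:
    """Return the number of whole days from the given date until now.

    Args:
        date_str: Date string in YYYY-MM-DD or ISO 8601 format.

    Returns:
        Days elapsed (positive for past dates, negative for future dates).
        Returns 0 when the string cannot be parsed.
    """
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            if dt.tzinfo is not None:
                # Convert to local time before dropping the offset so the
                # comparison with the naive local datetime.now() is valid.
                dt = dt.astimezone().replace(tzinfo=None)
        else:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
        delta = datetime.now() - dt
        return delta.days
    except (ValueError, TypeError):
        return 0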
def mask_phone(phone: str) -> str:
"""对手机号进行脱敏处理。
将手机号中间 4 位替换为 ****。
Args:
phone: 原始手机号。
Returns:
脱敏后的手机号,如 138****8000。
若格式不符合,返回原始值。
"""
if not phone:
return phone
phone = phone.strip()
if re.match(r"^1[3-9]\d{9}$", phone):
return phone[:3] + "****" + phone[7:]
return phone
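def generate_id(prefix: str = "C") -> str:
    """Generate a timestamp-based ID.

    The ID is the prefix followed by the current local time down to the
    microsecond. Two calls that observe the same timestamp produce the same
    ID, so uniqueness is not guaranteed in tight loops.

    Args:
        prefix: ID prefix, "C" (customer) by default.

    Returns:
        The ID string.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    return f"{prefix}{timestamp}"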
def format_currency(value: float) -> str:
"""将数值格式化为人民币金额显示。
Args:
value: 金额数值。
Returns:
格式化后的金额字符串,如 "¥10.00万" 或 "¥5,000"。
"""
try:
num = float(value)
except (TypeError, ValueError):
return str(value)
abs_num = abs(num)
sign = "-" if num < 0 else ""
if abs_num >= 1e8:
return f"{sign}¥{abs_num / 1e8:.2f}亿"
elif abs_num >= 1e4:
return f"{sign}¥{abs_num / 1e4:.2f}万"
else:
return f"{sign}¥{abs_num:,.0f}"
def now_iso() -> str:
"""返回当前时间的 ISO 格式字符串。
Returns:
ISO 格式时间字符串,如 "2026-03-19T10:30:00"。
"""
return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
def today_str() -> str:
"""返回今天的日期字符串。
Returns:
日期字符串,格式为 "YYYY-MM-DD"。
"""
return datetime.now().strftime("%Y-%m-%d")
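
The two display helpers above, `mask_phone` and `format_currency`, are easiest to understand from concrete inputs. The sketch below restates them in standalone form, assuming the same rules as the functions in this file, and prints a few sample values chosen for illustration.

```python
import re


def mask_phone(phone: str) -> str:
    """Replace the middle four digits of an 11-digit mainland mobile number."""
    if not phone:
        return phone
    phone = phone.strip()
    if re.match(r"^1[3-9]\d{9}$", phone):
        return phone[:3] + "****" + phone[7:]
    # Anything that is not an 11-digit mobile number is returned unmasked.
    return phone


def format_currency(value: float) -> str:
    """Format an amount in RMB, switching to 万 / 亿 units for large values."""
    num = float(value)
    abs_num = abs(num)
    sign = "-" if num < 0 else ""
    if abs_num >= 1e8:
        return f"{sign}¥{abs_num / 1e8:.2f}亿"
    if abs_num >= 1e4:
        return f"{sign}¥{abs_num / 1e4:.2f}万"
    return f"{sign}¥{abs_num:,.0f}"


if __name__ == "__main__":
    print(mask_phone("13800138000"))
    print(mask_phone("010-12345678"))
    print(format_currency(5000))
    print(format_currency(100000))
    print(format_currency(-250000000))
```

Landline numbers and numbers with a country prefix pass through `mask_phone` unchanged, so callers that need stricter privacy would have to normalize the input first.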
FILE:scripts/followup_tracker.py
#!/usr/bin/env python3
"""
customer-pulse 跟进记录与提醒模块
记录客户跟进活动,计算下次跟进日期,生成待跟进提醒清单。
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
calculate_days_since,
check_subscription,
generate_id,
get_data_file,
load_input_data,
mask_phone,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
today_str,
write_json_file,
)
# ============================================================
# 数据文件
# ============================================================
FOLLOWUPS_FILE = "followups.json"
CUSTOMERS_FILE = "customers.json"
def _get_followups() -> List[Dict[str, Any]]:
"""读取所有跟进记录。"""
return read_json_file(get_data_file(FOLLOWUPS_FILE))
def _save_followups(followups: List[Dict[str, Any]]) -> None:
"""保存跟进记录。"""
write_json_file(get_data_file(FOLLOWUPS_FILE), followups)
def _get_customers() -> List[Dict[str, Any]]:
"""读取所有客户数据。"""
return read_json_file(get_data_file(CUSTOMERS_FILE))
def _find_customer(customers: List[Dict], customer_id: str) -> Optional[Dict]:
"""根据 ID 查找客户。"""
for c in customers:
if c.get("id") == customer_id:
return c
return None
def _find_customer_by_name(customers: List[Dict], name: str) -> Optional[Dict]:
"""根据姓名查找客户(模糊匹配)。"""
name = name.strip()
for c in customers:
if c.get("name", "").strip() == name:
return c
for c in customers:
if name in c.get("name", ""):
return c
return None
# ============================================================
# 跟进操作
# ============================================================
def record_followup(data: Dict[str, Any]) -> None:
    """Record one follow-up activity.

    Required fields: customer_id or customer_name, content
    Optional fields: date, next_action, next_followup_date, reminder_days

    Args:
        data: Follow-up data dict.
    """
    customers = _get_customers()
    # The customer can be looked up by ID or by name.
    customer_id = data.get("customer_id")
    customer_name = data.get("customer_name")
    customer = None
    if customer_id:
        customer = _find_customer(customers, customer_id)
    elif customer_name:
        customer = _find_customer_by_name(customers, customer_name)
    if not customer:
        identifier = customer_id or customer_name or "未指定"
        output_error(f"未找到客户: {identifier}", code="NOT_FOUND")
        return
    content = (data.get("content") or "").strip()
    if not content:
        output_error("跟进内容(content)为必填字段", code="VALIDATION_ERROR")
        return
    sub = check_subscription()
    followup_date = data.get("date") or today_str()
    # Reject malformed dates up front instead of raising an uncaught ValueError.
    try:
        base_dt = datetime.strptime(followup_date, "%Y-%m-%d")
    except (TypeError, ValueError):
        output_error(f"日期格式无效: {followup_date!r},应为 YYYY-MM-DD", code="VALIDATION_ERROR")
        return
    # Work out the next follow-up date.
    if sub["tier"] == "paid" and data.get("next_followup_date"):
        next_followup_date = data["next_followup_date"]
    else:
        # Fixed 3-day reminder on the free tier, and the paid-tier default.
        days = 3
        if sub["tier"] == "paid" and data.get("reminder_days"):
            try:
                days = int(data["reminder_days"])
            except (TypeError, ValueError):
                output_error("提醒天数(reminder_days)必须为整数", code="VALIDATION_ERROR")
                return
        next_followup_date = (base_dt + timedelta(days=days)).strftime("%Y-%m-%d")
    followup = {
        "id": generate_id("F"),
        "customer_id": customer["id"],
        "customer_name": customer["name"],
        "date": followup_date,
        "content": content,
        "next_action": data.get("next_action", ""),
        "next_followup_date": next_followup_date,
        "created_at": now_iso(),
    }
    followups = _get_followups()
    followups.append(followup)
    _save_followups(followups)
    output_success({
        "message": f"已记录对「{customer['name']}」的跟进",
        "followup": followup,
        "reminder": f"下次跟进提醒: {next_followup_date}",
    })
def list_pending(data: Optional[Dict[str, Any]] = None) -> None:
"""列出待跟进客户清单。
按最后跟进时间排序,标注超期未跟进的客户。
Args:
data: 可选过滤参数。
"""
customers = _get_customers()
followups = _get_followups()
# 排除已成交和已流失的客户
active_customers = [
c for c in customers
if c.get("stage") not in ("成交", "流失")
]
# 构建每个客户的最后跟进信息
today = today_str()
pending_list = []
for customer in active_customers:
cid = customer["id"]
customer_followups = [f for f in followups if f.get("customer_id") == cid]
customer_followups.sort(key=lambda f: f.get("date", ""), reverse=True)
if customer_followups:
last_followup = customer_followups[0]
last_date = last_followup.get("date", "")
next_date = last_followup.get("next_followup_date", "")
days_since = calculate_days_since(last_date)
is_overdue = next_date <= today if next_date else days_since >= 3
else:
last_date = customer.get("created_at", "")[:10]
next_date = ""
days_since = calculate_days_since(last_date)
is_overdue = days_since >= 3
pending_list.append({
"customer_id": cid,
"customer_name": customer["name"],
"company": customer.get("company", ""),
"stage": customer.get("stage", ""),
"phone": mask_phone(customer.get("phone", "")),
"last_followup_date": last_date,
"next_followup_date": next_date,
"days_since_last": days_since,
"is_overdue": is_overdue,
"last_content": customer_followups[0].get("content", "") if customer_followups else "尚未跟进",
})
# 超期的排在前面,然后按最后跟进时间升序(最久未跟进的在前)
pending_list.sort(key=lambda x: (not x["is_overdue"], -x["days_since_last"]))
overdue_count = sum(1 for p in pending_list if p["is_overdue"])
output_success({
"total": len(pending_list),
"overdue_count": overdue_count,
"pending": pending_list,
})
def get_reminders(data: Optional[Dict[str, Any]] = None) -> None:
    """Return today's follow-up reminders.

    Lists active customers whose planned follow-up date is today or already
    past, plus customers that have never been followed up.

    Args:
        data: Optional parameters (currently unused).
    """
    customers = _get_customers()
    followups = _get_followups()
    today = today_str()
    closed_stages = ("成交", "流失")

    # Only the most recent follow-up of each customer decides whether a
    # reminder is due. An older record with a past planned date must not
    # trigger a reminder once a newer follow-up exists.
    latest: Dict[str, Dict[str, Any]] = {}
    for f in followups:
        cid = f.get("customer_id", "")
        if cid and (cid not in latest or f.get("date", "") >= latest[cid].get("date", "")):
            latest[cid] = f

    reminders = []
    for cid, f in latest.items():
        next_date = f.get("next_followup_date", "")
        if not next_date or next_date > today:
            continue
        customer = _find_customer(customers, cid)
        if not customer or customer.get("stage") in closed_stages:
            continue
        days_overdue = calculate_days_since(next_date)
        reminders.append({
            "customer_id": cid,
            "customer_name": customer["name"],
            "company": customer.get("company", ""),
            "stage": customer.get("stage", ""),
            "phone": mask_phone(customer.get("phone", "")),
            "planned_date": next_date,
            "days_overdue": days_overdue,
            "last_content": f.get("content", ""),
            "next_action": f.get("next_action", ""),
            "urgency": "高" if days_overdue >= 7 else ("中" if days_overdue >= 3 else "低"),
        })

    # Customers that have never been followed up are reminded once they are
    # at least 3 days old.
    for customer in customers:
        cid = customer["id"]
        if cid in latest or customer.get("stage") in closed_stages:
            continue
        days = calculate_days_since(customer.get("created_at", "")[:10])
        if days >= 3:
            reminders.append({
                "customer_id": cid,
                "customer_name": customer["name"],
                "company": customer.get("company", ""),
                "stage": customer.get("stage", ""),
                "phone": mask_phone(customer.get("phone", "")),
                "planned_date": "",
                "days_overdue": days,
                "last_content": "尚未跟进",
                "next_action": "首次跟进",
                "urgency": "高" if days >= 7 else "中",
            })

    # Most overdue first
    reminders.sort(key=lambda r: -r["days_overdue"])
    output_success({
        "date": today,
        "total_reminders": len(reminders),
        "urgent_count": sum(1 for r in reminders if r["urgency"] == "高"),
        "reminders": reminders,
    })
def get_history(data: Dict[str, Any]) -> None:
"""获取指定客户的跟进历史。
必填字段: customer_id 或 customer_name
Args:
data: 包含客户标识的字典。
"""
customers = _get_customers()
customer_id = data.get("customer_id")
customer_name = data.get("customer_name")
customer = None
if customer_id:
customer = _find_customer(customers, customer_id)
elif customer_name:
customer = _find_customer_by_name(customers, customer_name)
if not customer:
identifier = customer_id or customer_name or "未指定"
output_error(f"未找到客户: {identifier}", code="NOT_FOUND")
return
followups = _get_followups()
history = [f for f in followups if f.get("customer_id") == customer["id"]]
history.sort(key=lambda f: f.get("date", ""), reverse=True)
customer_display = dict(customer)
customer_display["phone"] = mask_phone(customer_display.get("phone", ""))
output_success({
"customer": customer_display,
"total_followups": len(history),
"history": history,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("customer-pulse 跟进记录管理")
args = parser.parse_args()
action = args.action.lower().replace("-", "_")
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"record": lambda: record_followup(data or {}),
"list_pending": lambda: list_pending(data),
"reminders": lambda: get_reminders(data),
"history": lambda: get_history(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:references/crm-guide.md
# CRM 实践指南
本指南涵盖销售阶段定义、跟进最佳实践、客户评分标准和销售漏斗管理技巧。
---
## 一、销售阶段定义
customer-pulse 采用六阶段销售流程模型:
### 1. 初步接触
**定义**:首次与潜在客户建立联系,尚未深入了解其需求。
**关键动作**:
- 收集客户基本信息(姓名、公司、联系方式)
- 初步了解客户背景和行业
- 建立基本信任关系
- 确认客户是否为目标客群
**进入下一阶段的标志**:客户明确表达了具体需求或兴趣点。
### 2. 需求确认
**定义**:已了解客户的核心需求,正在评估匹配度。
**关键动作**:
- 深入沟通客户痛点和期望
- 了解客户预算范围和决策流程
- 识别关键决策人
- 评估产品/服务与客户需求的匹配度
**进入下一阶段的标志**:客户要求提供具体方案或报价。
### 3. 方案报价
**定义**:已向客户提交解决方案和商业报价。
**关键动作**:
- 制作针对性方案文档
- 提供明确报价
- 安排方案演示或试用
- 处理客户的初步反馈
**进入下一阶段的标志**:客户对方案基本认可,进入商务条件谈判。
### 4. 谈判
**定义**:双方就商务条件进行协商。
**关键动作**:
- 协商价格、付款方式、交付周期
- 处理合同条款细节
- 应对竞争对手比价
- 推动客户做出决定
**进入下一阶段的标志**:双方达成一致,准备签约。
### 5. 成交
**定义**:客户已签约或确认合作。
**关键动作**:
- 完成签约和付款
- 安排项目启动或产品交付
- 建立售后服务对接
- 记录合作细节用于后续复盘
### 6. 流失
**定义**:客户明确放弃合作或长期失去联系。
**常见原因**:
- 客户选择了竞品
- 客户预算被削减或项目取消
- 需求与产品不匹配
- 跟进不及时导致客户冷却
---
## 二、跟进最佳实践
### 跟进频率建议
| 客户阶段 | 建议跟进频率 | 说明 |
|----------|------------|------|
| 初步接触 | 3-5 天/次 | 趁热打铁,快速推进 |
| 需求确认 | 2-3 天/次 | 持续深挖需求,保持热度 |
| 方案报价 | 1-2 天/次 | 方案提交后密切关注反馈 |
| 谈判 | 1-2 天/次 | 高频推进,把握成交窗口 |
| 成交 | 7-14 天/次 | 售后维护,培养口碑 |
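The table above can be turned into a simple due check. The following is a minimal sketch, not the customer-pulse implementation: `MAX_INTERVAL_DAYS` and `is_followup_due` are hypothetical names, and the upper bound of each suggested range is assumed to be the point at which a follow-up becomes due.

```python
from datetime import date
from typing import Optional

# Upper bound (in days) of the suggested interval for each stage, from the table above.
MAX_INTERVAL_DAYS = {
    "初步接触": 5,
    "需求确认": 3,
    "方案报价": 2,
    "谈判": 2,
    "成交": 14,
}


def is_followup_due(stage: str, last_followup: date, today: Optional[date] = None) -> bool:
    """Return True when the stage's suggested interval has elapsed."""
    today = today or date.today()
    limit = MAX_INTERVAL_DAYS.get(stage)
    if limit is None:
        # Stages without a suggested frequency (e.g. 流失) never become due.
        return False
    return (today - last_followup).days >= limit
```

Passing `today` explicitly keeps the check easy to test; in normal use it can be left out.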
### 跟进内容记录要点
每次跟进应记录以下信息:
1. **沟通摘要**:简要记录本次沟通的核心内容(2-3 句话)
2. **客户反馈**:客户表达的态度、疑虑或新需求
3. **下一步行动**:明确后续需要做的事情及责任人
4. **下次跟进计划**:约定或计划的下次联系时间
### 跟进注意事项
- **及时性**:承诺的事项务必在约定时间内完成
- **持续性**:避免长时间间断,保持客户对你的记忆
- **价值传递**:每次跟进都应为客户提供一些价值(行业信息、解决方案补充等)
- **记录完整性**:每次跟进后立即记录,避免遗忘关键信息
---
## 三、客户评分标准
customer-pulse 使用流失风险评分(0-100分)评估客户状态:
### 风险因素权重
| 因素 | 权重说明 |
|------|----------|
| 跟进频率衰减 | 近期跟进间隔 vs 历史平均,衰减越大风险越高 |
| 最后跟进距今 | 超过 7 天 +15分,超过 14 天 +25分,超过 30 天 +35分 |
| 总跟进次数 | 跟进次数 ≤2 时额外 +10分 |
| 从未跟进 | 基础 40 分 + 每天 3 分递增 |
### 风险等级划分
| 等级 | 分数范围 | 含义 | 建议操作 |
|------|----------|------|----------|
| 低风险 | 0-39 | 客户状态良好 | 保持正常跟进节奏 |
| 中风险 | 40-69 | 需要关注 | 本周内安排跟进 |
| 高风险 | 70-100 | 可能流失 | 立即安排跟进,主管介入 |
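The scoring rules and risk levels above can be sketched as follows. This is an illustration under stated assumptions rather than the actual scoring code: the function names are hypothetical, the frequency-decay factor is omitted because the table gives no numeric weight for it, "超过 N 天" is read as strictly more than N days, only the highest matching recency tier is applied, and the score is capped at 100.

```python
from typing import Optional


def churn_risk_score(days_since_last: Optional[int], total_followups: int,
                     days_since_created: int = 0) -> int:
    """Score churn risk on a 0-100 scale from the documented rules (sketch)."""
    if total_followups == 0 or days_since_last is None:
        # Never followed up: base 40 plus 3 points per day since creation.
        return min(100, 40 + 3 * max(0, days_since_created))

    score = 0
    # Recency: assumed to be tiered, so only the highest matching tier counts.
    if days_since_last > 30:
        score += 35
    elif days_since_last > 14:
        score += 25
    elif days_since_last > 7:
        score += 15
    # Two or fewer follow-ups so far adds an extra 10 points.
    if total_followups <= 2:
        score += 10
    return min(100, score)


def risk_level(score: int) -> str:
    """Map a score to the documented risk level."""
    if score >= 70:
        return "高风险"
    if score >= 40:
        return "中风险"
    return "低风险"
```

Because the decay factor is left out, a customer with follow-ups cannot exceed 45 points in this sketch; the high-risk band is only reachable here for customers that were never followed up.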
---
## 四、销售漏斗管理技巧
### 漏斗健康度指标
一个健康的销售漏斗应满足:
1. **各阶段客户数量呈倒三角分布**:初步接触 > 需求确认 > 方案报价 > 谈判 > 成交
2. **各阶段转化率在合理范围内**:
- 初步接触 → 需求确认:30%-50%
- 需求确认 → 方案报价:50%-70%
- 方案报价 → 谈判:40%-60%
- 谈判 → 成交:50%-80%
3. **各阶段停留时间不超过预期**:避免客户在某个阶段滞留过久
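The conversion ranges above can be checked mechanically. The sketch below is hypothetical (the names `STAGES`, `EXPECTED` and `funnel_report` are not part of customer-pulse) and assumes the input counts how many customers ever reached each stage, not how many currently sit in it.

```python
from typing import Dict, List, Tuple

STAGES: List[str] = ["初步接触", "需求确认", "方案报价", "谈判", "成交"]

# Expected conversion range (low, high) per transition, from the list above.
EXPECTED: Dict[Tuple[str, str], Tuple[float, float]] = {
    ("初步接触", "需求确认"): (0.30, 0.50),
    ("需求确认", "方案报价"): (0.50, 0.70),
    ("方案报价", "谈判"): (0.40, 0.60),
    ("谈判", "成交"): (0.50, 0.80),
}


def funnel_report(reached: Dict[str, int]) -> List[dict]:
    """Compute the conversion rate of each transition and flag outliers."""
    report = []
    for src, dst in zip(STAGES, STAGES[1:]):
        entered = reached.get(src, 0)
        # Guard against division by zero when nobody reached the source stage.
        rate = reached.get(dst, 0) / entered if entered else 0.0
        low, high = EXPECTED[(src, dst)]
        if rate < low:
            status = "below"
        elif rate > high:
            status = "above"
        else:
            status = "ok"
        report.append({"from": src, "to": dst, "rate": round(rate, 3), "status": status})
    return report
```

A transition flagged `below` points at the conversion bottleneck described in the next table.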
### 常见漏斗问题及解决方案
| 问题 | 表现 | 解决方案 |
|------|------|----------|
| 入口不足 | 初步接触客户数过少 | 增加获客渠道,扩大营销投入 |
| 转化瓶颈 | 某阶段转化率过低 | 针对性优化该阶段的销售策略 |
| 滞留过多 | 大量客户停留在同一阶段 | 集中清理,推动或淘汰 |
| 流失过高 | 总体成交率低于预期 | 复盘流失原因,优化产品和服务 |
### 每日漏斗检查清单
1. 查看今日跟进提醒,优先处理超期客户
2. 检查各阶段客户数量变化
3. 关注高风险客户(付费版),及时干预
4. 更新客户阶段状态,保持数据准确
5. 记录当日所有跟进内容
---
## 五、数据管理建议
### CSV 导入格式
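When importing customer data, the CSV file needs one header row, written either in English or in Chinese (`姓名,手机,公司,意向产品,预算,阶段,来源`). Example with English headers:
```csv
name,phone,company,product_interest,budget,stage,source
张总,13800138000,ABC公司,产品A,100000,初步接触,网站咨询
李经理,13900139000,XYZ集团,产品B,500000,需求确认,展会
```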
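Accepting both the English and the Chinese header names listed above comes down to normalising the column names before use. The following is a minimal sketch using only the standard library; `HEADER_ALIASES` and `read_customers` are hypothetical names, and the real importer in `customer_store.py` may work differently.

```python
import csv
import io
from typing import Dict, List

# Chinese header -> English field name, taken from the header names above.
HEADER_ALIASES: Dict[str, str] = {
    "姓名": "name",
    "手机": "phone",
    "公司": "company",
    "意向产品": "product_interest",
    "预算": "budget",
    "阶段": "stage",
    "来源": "source",
}


def read_customers(csv_text: str) -> List[Dict[str, str]]:
    """Parse CSV text into customer dicts keyed by the English field names."""
    reader = csv.DictReader(io.StringIO(csv_text))
    customers = []
    for row in reader:
        customers.append({
            # Unknown headers are kept as-is; missing values become "".
            HEADER_ALIASES.get(key.strip(), key.strip()): (value or "").strip()
            for key, value in row.items()
            if key is not None  # extra unnamed columns are ignored
        })
    return customers
```

The sketch expects exactly one header row; a second header line would be read as a data record.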
### 数据备份
建议定期导出客户数据作为备份:
```bash
python3 scripts/customer_store.py --action export --data '{"file_path":"backup_20260319.csv"}'
```
---
*本指南基于 customer-pulse v1.0 编写,适用于各行业的 B2B 销售场景。*
合同卫士 — AI合同审查助手,识别风险条款、提取关键信息、追踪到期日
---
name: contract-guardian
description: 合同卫士 — AI合同审查助手,识别风险条款、提取关键信息、追踪到期日
version: 1.0.0
metadata:
openclaw:
optional_env:
- CG_SUBSCRIPTION_TIER
- CG_DATA_DIR
---
# 合同卫士(contract-guardian)
你是一个专业的AI合同审查助手 Agent。你的职责是帮助用户审查合同文件、识别风险条款、提取关键信息、追踪合同到期日。你始终使用中文与用户沟通。
## 环境变量说明
| 变量 | 必需 | 说明 |
|------|------|------|
| `CG_SUBSCRIPTION_TIER` | 否 | 订阅等级,默认 `free`,可选 `paid` |
| `CG_DATA_DIR` | 否 | 数据存储目录,默认 `~/.openclaw-bdi/contract-guardian/` |
---
## 流程一:合同审查
当用户说"审查这份合同"、"帮我看看这份合同"、"检查合同"或类似意图时,执行以下步骤:
### 步骤 1:解析合同文件
确认用户提供的合同文件路径和格式:
```bash
python3 scripts/contract_parser.py --action parse --file <文件路径>
```
- 免费版支持 TXT/MD 格式。
- 付费版额外支持 PDF/DOCX 格式。
- 若文件格式不支持,提示用户升级或转换格式。
### 步骤 2:提取关键信息
从合同文本中提取甲乙方、金额、期限等关键信息:
```bash
python3 scripts/key_info_extractor.py --action summary --text-file <文件路径>
```
将提取结果以表格形式展示给用户:
```
| 项目 | 内容 |
|------|------|
| 甲方 | XX有限公司 |
| 乙方 | YY科技有限公司 |
| 合同金额 | ¥50.00万 |
| 合同期限 | 2026-01-01 至 2026-12-31 |
| 签订日期 | 2026-01-01 |
```
### 步骤 3:识别风险条款
对合同文本进行风险分析:
```bash
python3 scripts/risk_analyzer.py --action full-report --text-file <文件路径>
```
根据订阅等级输出不同范围的风险报告:
**免费版输出:**
- 检查 3 类基础风险(单方解约权、违约金、付款条件)
- 输出风险等级和简要建议
**付费版输出:**
- 检查全部 12 类风险
- 输出安全评分(0-100)
- 详细的风险分析和改进建议
### 步骤 4:生成审查报告
将关键信息和风险分析整合为完整的审查报告,格式如下:
```
# 合同审查报告
## 基本信息
(关键信息表格)
## 风险评估
- 安全评分:XX/100
- 风险等级:低风险/中等风险/较高风险/高风险
## 风险条款详情
(逐条列出发现的风险,包含严重程度、说明和建议)
## 审查建议
(综合建议)
```
---
## 流程二:风险专项分析
当用户说"这份合同有什么风险"、"检查风险条款"、"分析合同风险"或类似意图时,执行以下步骤:
### 步骤 1:获取合同文本
若用户提供文件路径,使用解析工具读取:
```bash
python3 scripts/contract_parser.py --action extract-text --file <文件路径>
```
若用户直接粘贴合同文本,直接使用。
### 步骤 2:执行风险分析
```bash
python3 scripts/risk_analyzer.py --action full-report --text-file <文件路径>
```
### 步骤 3:展示风险报告
按严重程度分组展示风险:
**高风险条款**(需要立即关注):
- 列出所有高风险项,标注匹配的原文片段和建议
**中等风险条款**(建议审慎评估):
- 列出所有中等风险项
**低风险条款**(供参考):
- 列出所有低风险项
免费版用户提示可升级查看完整 12 类风险分析。
---
## 流程三:合同到期提醒
当用户说"合同到期提醒"、"哪些合同快到期了"、"检查到期合同"或类似意图时,执行以下步骤:
### 步骤 1:查询到期合同
```bash
python3 scripts/contract_store.py --action expiring
```
### 步骤 2:展示到期清单
按紧急程度分组展示:
```
## 合同到期提醒
### 30天内到期(紧急)
| 合同 | 甲方 | 乙方 | 到期日 | 剩余天数 |
|------|------|------|--------|----------|
### 60天内到期(注意)
(同上格式)
### 90天内到期(预警)
(同上格式)
```
### 步骤 3:给出建议
对即将到期的合同给出处理建议:续约、终止或重新谈判。
---
## 流程四:合同存档
当用户说"存档这份合同"、"保存合同信息"或类似意图时,执行以下步骤:
### 步骤 1:提取合同信息
若用户提供合同文件,先提取关键信息:
```bash
python3 scripts/key_info_extractor.py --action extract --text-file <文件路径>
```
### 步骤 2:确认信息
向用户展示提取的信息并确认:
- 合同标题
- 甲乙方
- 起止日期
- 合同金额
### 步骤 3:存档
```bash
python3 scripts/contract_store.py --action archive --data '<JSON数据>'
```
确认存档成功并告知合同编号。
---
## 流程五:合同对比
当用户说"对比这两份合同"、"合同差异分析"或类似意图时,执行以下步骤:
> 注意:此功能仅限付费版用户。免费版用户请提示升级。
### 步骤 1:订阅校验
确认当前为付费版。若为免费版,提示:
"合同对比功能为付费功能。升级至付费版(¥129/月)即可使用逐条对比功能。"
### 步骤 2:读取两份合同
确认用户提供的两份合同文件路径。
### 步骤 3:执行对比
```bash
python3 scripts/contract_comparator.py --action diff-report --file1 <路径1> --file2 <路径2>
```
### 步骤 4:展示对比报告
输出 Markdown 格式的对比报告,包含:
- 整体相似度
- 条款差异汇总表
- 详细差异内容(新增/删除/修改)
- 风险增减分析
---
## 流程六:合同模板
当用户说"合同模板"、"给我一个合同模板"或类似意图时:
### 免费版
提供 3 个基础模板(采购合同、服务合同、合作协议),引导用户查看 `references/contract-templates.md`。
### 付费版
提供 20+ 行业模板,包括但不限于:技术开发、劳动合同、租赁合同、保密协议、代理协议等。
---
## 订阅校验逻辑
在每次涉及功能限制的操作前,必须执行订阅校验:
### 读取订阅等级
```
tier = env CG_SUBSCRIPTION_TIER,默认 "free"
```
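In Python, the pseudocode above could look like the sketch below. `read_tier` is a hypothetical helper; falling back to `free` for unrecognised values is an assumption, not documented behaviour.

```python
import os

VALID_TIERS = ("free", "paid")


def read_tier() -> str:
    """Return the subscription tier from CG_SUBSCRIPTION_TIER, defaulting to "free"."""
    tier = os.environ.get("CG_SUBSCRIPTION_TIER", "free").strip().lower()
    # Assumption: any unrecognised value falls back to the free tier.
    return tier if tier in VALID_TIERS else "free"
```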
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥129/月) |
|------|---------------|----------------------|
| 合同关键信息提取 | 1份/天 | 无限 |
| 风险条款识别 | 基础3类 | 完整12类 |
| 合同到期提醒 | 3份 | 无限 |
| 合同对比 | 不支持 | 逐条对比 |
| 合同模板库 | 3个基础模板 | 20+行业模板 |
| 多文件格式支持 | TXT/MD | TXT/MD/PDF/DOCX |
| 历史合同检索 | 不支持 | 支持 |
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥129/月),请联系管理员或访问订阅管理页面。"
4. 不要直接拒绝,而是提供免费版可用的替代方案(如果有的话)。
---
## 参考文档
在进行合同审查时,请参考以下文档:
- **风险条款清单**:`references/risk-checklist.md` — 包含 12 类风险条款的详细说明和识别方法。
- **合同模板**:`references/contract-templates.md` — 包含基础合同模板和付费模板列表。
---
## 安全规范
1. **隐私保护**:绝不在日志或输出中暴露完整合同原文。仅展示必要的关键片段。
2. **敏感信息脱敏**:合同中出现的身份证号、手机号、银行卡号等敏感信息,在展示时自动脱敏处理。
3. **数据安全**:合同数据仅存储在本地 `CG_DATA_DIR` 目录,不会上传至任何外部服务。
4. **错误处理**:执行命令失败时,向用户展示友好的错误提示,不要暴露内部路径或系统信息。
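Rule 2 above (masking of sensitive values) can be approximated with a few regular expressions. This is a hypothetical sketch, not the `mask_sensitive_info` helper shipped in `scripts/utils.py`; the number shapes (18-digit ID, 16-19 digit bank card, 11-digit mobile number) are assumptions about mainland China formats.

```python
import re

# Order matters: the ID number shape is masked before the shorter shapes.
_PATTERNS = [
    # 18-character ID number (the last character may be X): keep 6 + 4.
    (re.compile(r"(?<!\d)(\d{6})\d{8}(\d{3}[\dXx])(?!\d)"), r"\1********\2"),
    # 16-19 digit bank card number: keep the first and last 4 digits.
    (re.compile(r"(?<!\d)(\d{4})\d{8,11}(\d{4})(?!\d)"), r"\1****\2"),
    # 11-digit mobile number: keep the first 3 and last 4 digits.
    (re.compile(r"(?<!\d)(1[3-9]\d)\d{4}(\d{4})(?!\d)"), r"\1****\2"),
]


def mask_sensitive_sketch(text: str) -> str:
    """Replace the middle of ID, bank card and mobile numbers with asterisks."""
    for pattern, replacement in _PATTERNS:
        text = pattern.sub(replacement, text)
    return text
```

The lookarounds `(?<!\d)` and `(?!\d)` keep a pattern from matching inside a longer digit run, so a bank card number is not partially masked as a phone number.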
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 审查合同时保持客观、专业的态度,提供法律风险提示但不构成法律建议。
3. 对发现的风险条款给出明确的严重程度评级和改进建议。
4. 遇到模糊的合同条款时,主动指出并建议用户咨询专业律师。
5. 输出结构化、易读的审查报告,优先使用表格展示关键信息。
6. 尊重订阅等级限制,在提示升级时保持友好,不要反复推销。
7. 在输出中注明"本报告由AI生成,仅供参考,不构成法律建议"。
FILE:assets/README.md
# 合同卫士 (contract-guardian)
> AI驱动的合同审查助手,帮你识别风险条款、提取关键信息、追踪到期日
---
## 功能亮点
- **智能风险识别** — 基于 12 类风险模型自动扫描合同条款,标注风险等级并给出改进建议
- **关键信息提取** — 自动识别甲乙方、合同金额、期限、违约金等核心要素,秒级生成摘要
- **到期提醒** — 自动追踪合同到期日,30/60/90 天分级预警,不再遗漏续约时间
- **合同对比** — 逐条对比两份合同差异,高亮变更内容,分析风险增减
- **多格式支持** — 支持 TXT、MD、PDF、DOCX 多种合同文件格式
- **本地安全** — 所有数据在本地处理和存储,合同内容不会上传到任何外部服务
---
## 版本对比
| 功能 | 免费版 | 付费版 ¥129/月 |
|------|:------:|:------------:|
| 合同关键信息提取 | 1份/天 | 无限 |
| 风险条款识别 | 基础3类 | 完整12类 |
| 合同到期提醒 | 3份 | 无限 |
| 合同对比 | 不支持 | 逐条对比 |
| 合同模板库 | 3个基础模板 | 20+行业模板 |
| 文件格式 | TXT/MD | TXT/MD/PDF/DOCX |
| 历史合同检索 | 不支持 | 支持 |
---
## 快速开始
### 1. 安装 Skill
在 ClawHub 中搜索 `contract-guardian`,点击安装,或使用命令行:
```bash
openclaw skill install contract-guardian
```
### 2. 审查合同
```bash
# 审查一份合同
/contract-guardian review --file ./my-contract.txt
# 提取关键信息
/contract-guardian extract --file ./my-contract.md
# 检查风险条款
/contract-guardian risk-check --file ./my-contract.txt
```
### 3. 到期提醒
```bash
# 查看即将到期的合同
/contract-guardian expiring
# 存档新合同
/contract-guardian archive --file ./new-contract.txt
```
### 4. 对比合同(付费版)
```bash
# 对比两份合同
/contract-guardian compare --file1 ./contract-v1.txt --file2 ./contract-v2.txt
```
---
## 审查报告示例
以下是一份自动生成的合同审查报告样例:
```markdown
# 合同审查报告
## 基本信息
| 项目 | 内容 |
|------|------|
| 甲方 | 北京XX科技有限公司 |
| 乙方 | 上海YY信息技术有限公司 |
| 合同金额 | ¥50.00万 |
| 合同期限 | 2026-01-01 至 2026-12-31 |
| 签订日期 | 2025-12-15 |
## 风险评估
- 安全评分:65/100
- 风险等级:中等风险
- 检查范围:12类风险条款(付费版)
## 风险条款详情
### 高风险
1. **单方解约权** — 第八条约定甲方可单方终止合同,但未约定对等条件
> 建议:增加解约条件限制,明确双方解约权的对等性
2. **违约金不对等** — 乙方违约金为合同总额20%,甲方仅5%
> 建议:调整为双方对等的违约金比例
### 中等风险
3. **付款条件** — 付款以甲方验收为前提,账期60个工作日
> 建议:缩短账期至30天,明确验收标准和时限
## 审查建议
本合同存在2项高风险和1项中等风险条款,建议重点关注单方解约权和
违约金不对等问题。建议与对方协商修改后再签署。
---
*本报告由AI生成,仅供参考,不构成法律建议*
```
---
## 12类风险条款
合同卫士覆盖以下 12 类常见合同风险:
| 类别 | 风险等级 | 免费版 |
|------|----------|:------:|
| 单方解约权 | 高 | 支持 |
| 违约金 | 高 | 支持 |
| 付款条件 | 中 | 支持 |
| 自动续约 | 中 | 付费 |
| 无限责任 | 高 | 付费 |
| 竞业限制 | 高 | 付费 |
| 知识产权归属 | 高 | 付费 |
| 管辖地/仲裁 | 中 | 付费 |
| 保密条款 | 中 | 付费 |
| 验收标准 | 中 | 付费 |
| 担保条款 | 高 | 付费 |
| 不可抗力 | 低 | 付费 |
---
## 常见问题
### Q1: 支持哪些文件格式?
免费版支持 TXT 和 Markdown 格式。付费版额外支持 PDF 和 DOCX 格式。PDF 解析需要安装 `pdfplumber`,DOCX 解析需要安装 `python-docx`。
### Q2: 合同数据存储在哪里?
所有合同数据存储在本地 `~/.openclaw-bdi/contract-guardian/` 目录下,不会上传到任何外部服务。你可以通过 `CG_DATA_DIR` 环境变量自定义存储路径。
### Q3: 风险识别准确率如何?
合同卫士基于关键词和正则模式匹配识别风险条款,对常见的风险表述有较好的识别率。但合同审查涉及法律专业判断,建议将审查报告作为参考,重要合同仍需咨询专业律师。
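To illustrate what keyword and regex matching means in practice, here is a minimal sketch. The three rules are hypothetical examples written for this illustration; the actual rule set lives in `scripts/risk_analyzer.py` and is not reproduced here.

```python
import re
from typing import Dict, List

# Hypothetical patterns for three of the twelve risk categories.
RISK_RULES: Dict[str, "re.Pattern[str]"] = {
    "单方解约权": re.compile(r"(?:有权|可)(?:随时|单方面?)(?:解除|终止)"),
    "违约金": re.compile(r"违约金"),
    "自动续约": re.compile(r"自动(?:续约|续期|延期)"),
}


def scan_risks(text: str, context: int = 10) -> List[Dict[str, str]]:
    """Return one finding per pattern match, with a short excerpt around it."""
    findings = []
    for category, pattern in RISK_RULES.items():
        for match in pattern.finditer(text):
            start = max(0, match.start() - context)
            end = min(len(text), match.end() + context)
            findings.append({"category": category, "excerpt": text[start:end]})
    return findings
```

Matching of this kind only catches wording that the patterns anticipate, which is why unusual phrasing can slip through and a lawyer's review remains advisable for important contracts.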
### Q4: 免费版和付费版的主要区别是什么?
免费版提供基础的合同审查能力(3类风险识别、1份/天、TXT/MD格式)。付费版(¥129/月)解锁全部12类风险识别、无限次审查、PDF/DOCX支持、合同对比和历史检索等高级功能。
### Q5: 如何升级到付费版?
设置环境变量 `CG_SUBSCRIPTION_TIER=paid` 即可激活付费版功能。或联系管理员通过订阅管理页面升级。
### Q6: AI审查报告能替代律师吗?
不能。合同卫士提供的审查报告基于模式匹配生成,仅供参考,不构成法律建议。对于重大合同和复杂法律问题,建议咨询专业律师。
---
## 技术支持
- 查看 `references/` 目录获取风险条款清单和合同模板
- 在 ClawHub 的 Skill 页面提交 Issue
- 加入 ClawHub 社区频道 `#contract-guardian`
- 邮件:[email protected]
---
*contract-guardian v1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/contract_parser.py
#!/usr/bin/env python3
"""
contract-guardian 合同文件解析模块
支持 TXT、MD、PDF、DOCX 格式的合同文件解析。
免费版仅支持 TXT/MD,付费版支持全部格式。
"""
import argparse
import os
import sys
# 将 scripts 目录加入路径以便导入 utils
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
mask_sensitive_info,
output_error,
output_success,
read_text_file,
)
def parse_txt(file_path: str) -> str:
"""解析 TXT 文件。
Args:
file_path: 文件路径。
Returns:
文件文本内容。
"""
return read_text_file(file_path)
def parse_md(file_path: str) -> str:
"""解析 Markdown 文件。
Args:
file_path: 文件路径。
Returns:
文件文本内容。
"""
return read_text_file(file_path)
def parse_pdf(file_path: str) -> str:
"""解析 PDF 文件。
优先使用 pdfplumber 库,若不可用则尝试基础文本提取。
Args:
file_path: 文件路径。
Returns:
提取的文本内容。
Raises:
ImportError: 当 PDF 解析库不可用时抛出。
"""
try:
import pdfplumber
text_parts = []
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
page_text = page.extract_text()
if page_text:
text_parts.append(page_text)
if text_parts:
return "\n\n".join(text_parts)
return ""
except ImportError:
pass
# 基础 fallback:尝试读取 PDF 中的文本流
try:
with open(file_path, "rb") as f:
content = f.read()
# 尝试提取 PDF 文本流中的内容
import re
text_parts = []
# 匹配 PDF 文本对象中的字符串
for match in re.finditer(rb"\(([^)]+)\)", content):
try:
decoded = match.group(1).decode("utf-8", errors="ignore")
if len(decoded.strip()) > 1:
text_parts.append(decoded.strip())
except Exception:
continue
if text_parts:
return "\n".join(text_parts)
raise ImportError(
"无法解析 PDF 文件。请安装 pdfplumber: pip install pdfplumber"
)
except ImportError:
raise
except Exception:
raise ImportError(
"无法解析 PDF 文件。请安装 pdfplumber: pip install pdfplumber"
)
def parse_docx(file_path: str) -> str:
"""解析 DOCX 文件。
使用 python-docx 库解析 Word 文档。
Args:
file_path: 文件路径。
Returns:
提取的文本内容。
Raises:
ImportError: 当 python-docx 库不可用时抛出。
"""
try:
from docx import Document
doc = Document(file_path)
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
return "\n\n".join(paragraphs)
except ImportError:
raise ImportError(
"无法解析 DOCX 文件。请安装 python-docx: pip install python-docx"
)
def detect_format(file_path: str) -> str:
"""根据文件扩展名检测格式。
Args:
file_path: 文件路径。
Returns:
格式字符串: txt, md, pdf, docx。
"""
ext = os.path.splitext(file_path)[1].lower()
format_map = {
".txt": "txt",
".text": "txt",
".md": "md",
".markdown": "md",
".pdf": "pdf",
".docx": "docx",
".doc": "docx",
}
return format_map.get(ext, "txt")
def parse_contract(file_path: str, file_format: str = None) -> dict:
"""解析合同文件并返回结构化结果。
Args:
file_path: 合同文件路径。
file_format: 文件格式,若为 None 则自动检测。
Returns:
包含解析结果的字典:
{
"file_path": str,
"format": str,
"text": str,
"char_count": int,
"line_count": int
}
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"合同文件不存在: {file_path}")
if file_format is None:
file_format = detect_format(file_path)
file_format = file_format.lower()
# 订阅校验
sub = check_subscription()
supported = sub.get("supported_formats", ["txt", "md"])
if file_format not in supported:
raise ValueError(
f"当前订阅等级不支持 {file_format.upper()} 格式。"
f"免费版仅支持 TXT/MD,付费版支持 TXT/MD/PDF/DOCX。"
f"如需升级至付费版(¥129/月),请联系管理员。"
)
# 解析文件
parsers = {
"txt": parse_txt,
"md": parse_md,
"pdf": parse_pdf,
"docx": parse_docx,
}
parser_func = parsers.get(file_format)
if parser_func is None:
raise ValueError(f"不支持的文件格式: {file_format}")
text = parser_func(file_path)
return {
"file_path": os.path.abspath(file_path),
"format": file_format,
"text": text,
"char_count": len(text),
"line_count": len(text.splitlines()),
}
def extract_text(file_path: str, file_format: str = None) -> str:
"""从合同文件中提取纯文本内容。
Args:
file_path: 合同文件路径。
file_format: 文件格式,若为 None 则自动检测。
Returns:
提取的纯文本内容。
"""
result = parse_contract(file_path, file_format)
return result["text"]
def main():
"""命令行入口。"""
parser = argparse.ArgumentParser(
description="合同文件解析工具 — 支持 TXT/MD/PDF/DOCX 格式",
)
parser.add_argument(
"--action",
required=True,
choices=["parse", "extract-text"],
help="操作类型: parse(完整解析), extract-text(提取纯文本)",
)
parser.add_argument(
"--file",
required=True,
help="合同文件路径",
)
parser.add_argument(
"--format",
default=None,
choices=["txt", "md", "pdf", "docx"],
help="文件格式(可选,默认自动检测)",
)
args = parser.parse_args()
try:
if args.action == "parse":
result = parse_contract(args.file, args.format)
# 脱敏处理后输出
result["text"] = mask_sensitive_info(result["text"])
output_success(result)
elif args.action == "extract-text":
text = extract_text(args.file, args.format)
text = mask_sensitive_info(text)
output_success({"text": text})
except FileNotFoundError as e:
output_error(str(e), "FILE_NOT_FOUND")
except ImportError as e:
output_error(str(e), "DEPENDENCY_MISSING")
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
except Exception as e:
output_error(f"解析失败: {e}", "PARSE_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/contract_store.py
#!/usr/bin/env python3
"""
contract-guardian 合同存档与到期提醒模块
支持合同元数据的存档、检索和到期提醒功能。
免费版最多追踪 3 份合同到期,付费版无限制。
"""
import argparse
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
ensure_data_dir,
get_data_file,
output_error,
output_success,
)
STORE_FILE = "contracts.json"
def _load_store() -> List[Dict[str, Any]]:
"""加载合同存档数据。"""
store_path = get_data_file(STORE_FILE)
if not os.path.exists(store_path):
return []
try:
with open(store_path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return data
return []
except (json.JSONDecodeError, IOError):
return []
def _save_store(contracts: List[Dict[str, Any]]) -> None:
"""保存合同存档数据。"""
store_path = get_data_file(STORE_FILE)
ensure_data_dir()
with open(store_path, "w", encoding="utf-8") as f:
json.dump(contracts, f, ensure_ascii=False, indent=2, default=str)
def _generate_id(contracts: List[Dict[str, Any]]) -> str:
"""生成合同 ID。"""
if not contracts:
return "CG-001"
max_id = 0
for c in contracts:
cid = c.get("id", "CG-000")
try:
num = int(cid.split("-")[1])
max_id = max(max_id, num)
except (IndexError, ValueError):
continue
return f"CG-{max_id + 1:03d}"
def archive_contract(data: Dict[str, Any]) -> Dict[str, Any]:
"""存档一份合同。
Args:
data: 合同元数据,应包含:
- title: 合同标题
- party_a: 甲方
- party_b: 乙方
- start_date: 开始日期
- end_date: 结束日期
- amount: 合同金额(可选)
- file_path: 合同文件路径(可选)
- notes: 备注(可选)
Returns:
存档结果,包含生成的合同 ID。
"""
contracts = _load_store()
sub = check_subscription()
# 检查存档数量限制
limit = sub.get("expiry_tracking_limit", 3)
if limit != -1 and len(contracts) >= limit:
raise ValueError(
f"当前免费版最多存档 {limit} 份合同。"
f"已存档 {len(contracts)} 份,无法继续添加。"
f"升级至付费版(¥129/月)可无限存档。"
)
contract_id = _generate_id(contracts)
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
record = {
"id": contract_id,
"title": data.get("title", "未命名合同"),
"party_a": data.get("party_a"),
"party_b": data.get("party_b"),
"start_date": data.get("start_date"),
"end_date": data.get("end_date"),
"amount": data.get("amount"),
"file_path": data.get("file_path"),
"notes": data.get("notes"),
"archived_at": now,
"status": "active",
}
contracts.append(record)
_save_store(contracts)
return {
"id": contract_id,
"message": f"合同「{record['title']}」已成功存档",
"record": record,
}
def list_contracts(status: str = None) -> Dict[str, Any]:
"""列出所有存档合同。
Args:
status: 筛选状态(active/expired/all),默认 all。
Returns:
合同列表和统计信息。
"""
contracts = _load_store()
# 更新过期状态
today = datetime.now().strftime("%Y-%m-%d")
for c in contracts:
end_date = c.get("end_date")
if end_date and end_date < today and c.get("status") == "active":
c["status"] = "expired"
_save_store(contracts)
# 筛选
if status and status != "all":
filtered = [c for c in contracts if c.get("status") == status]
else:
filtered = contracts
return {
"total": len(filtered),
"active": sum(1 for c in contracts if c.get("status") == "active"),
"expired": sum(1 for c in contracts if c.get("status") == "expired"),
"contracts": filtered,
}
def search_contracts(keyword: str) -> Dict[str, Any]:
"""搜索合同。
Args:
keyword: 搜索关键词,在标题、甲乙方、备注中匹配。
Returns:
搜索结果。
"""
sub = check_subscription()
if not sub.get("history_search"):
raise ValueError(
"历史合同检索为付费功能。升级至付费版(¥129/月)即可使用。"
)
contracts = _load_store()
keyword_lower = keyword.lower()
matched = []
for c in contracts:
searchable = " ".join(
str(v) for v in [
c.get("title", ""),
c.get("party_a", ""),
c.get("party_b", ""),
c.get("notes", ""),
]
).lower()
if keyword_lower in searchable:
matched.append(c)
return {
"keyword": keyword,
"total": len(matched),
"contracts": matched,
}
def get_expiring_contracts(days: int = 30) -> Dict[str, Any]:
    """Return contracts that are about to expire.

    Active contracts are grouped into three fixed buckets by remaining days:
    0-30, 31-60 and 61-90.

    Args:
        days: Expiry threshold passed from the CLI. It is accepted for
            compatibility but not applied; the buckets are always 30/60/90.
    Returns:
        The expiring contracts, each bucket sorted by remaining days.
    """
    contracts = _load_store()
    # Compare calendar dates. With a datetime that carries the current time,
    # a contract ending today would yield -1 days and be skipped as expired.
    today = datetime.now().date()

    expiring_30 = []
    expiring_60 = []
    expiring_90 = []
    for c in contracts:
        if c.get("status") != "active":
            continue
        end_date_str = c.get("end_date")
        if not end_date_str:
            continue
        try:
            end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
        except (TypeError, ValueError):
            # Skip records whose end date is not a YYYY-MM-DD string
            continue
        remaining = (end_date - today).days
        if remaining < 0:
            continue  # already expired
        elif remaining <= 30:
            expiring_30.append({**c, "remaining_days": remaining})
        elif remaining <= 60:
            expiring_60.append({**c, "remaining_days": remaining})
        elif remaining <= 90:
            expiring_90.append({**c, "remaining_days": remaining})

    # Soonest expiry first within each bucket
    expiring_30.sort(key=lambda x: x["remaining_days"])
    expiring_60.sort(key=lambda x: x["remaining_days"])
    expiring_90.sort(key=lambda x: x["remaining_days"])

    return {
        "check_date": today.strftime("%Y-%m-%d"),
        "expiring_30_days": {
            "count": len(expiring_30),
            "contracts": expiring_30,
        },
        "expiring_60_days": {
            "count": len(expiring_60),
            "contracts": expiring_60,
        },
        "expiring_90_days": {
            "count": len(expiring_90),
            "contracts": expiring_90,
        },
        "total_expiring": len(expiring_30) + len(expiring_60) + len(expiring_90),
    }
def main():
"""命令行入口。"""
parser = argparse.ArgumentParser(
description="合同存档与到期提醒工具 — 管理合同存档,追踪到期日",
)
parser.add_argument(
"--action",
required=True,
choices=["archive", "list", "search", "expiring"],
help="操作类型: archive(存档), list(列表), search(搜索), expiring(到期提醒)",
)
parser.add_argument(
"--data",
default=None,
help="合同数据(JSON 格式字符串)",
)
parser.add_argument(
"--data-file",
default=None,
help="合同数据文件路径(JSON 文件)",
)
parser.add_argument(
"--keyword",
default=None,
help="搜索关键词(用于 search 操作)",
)
parser.add_argument(
"--status",
default="all",
choices=["active", "expired", "all"],
help="合同状态筛选(用于 list 操作)",
)
parser.add_argument(
"--days",
type=int,
default=30,
help="到期天数阈值(用于 expiring 操作)",
)
args = parser.parse_args()
try:
if args.action == "archive":
data = None
if args.data:
data = json.loads(args.data)
elif args.data_file:
with open(args.data_file, "r", encoding="utf-8") as f:
data = json.load(f)
else:
raise ValueError("请通过 --data 或 --data-file 提供合同数据")
result = archive_contract(data)
output_success(result)
elif args.action == "list":
result = list_contracts(args.status)
output_success(result)
elif args.action == "search":
if not args.keyword:
raise ValueError("请通过 --keyword 提供搜索关键词")
result = search_contracts(args.keyword)
output_success(result)
elif args.action == "expiring":
result = get_expiring_contracts(args.days)
output_success(result)
except json.JSONDecodeError as e:
output_error(f"JSON 格式错误: {e}", "JSON_ERROR")
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
except FileNotFoundError as e:
output_error(str(e), "FILE_NOT_FOUND")
except Exception as e:
output_error(f"操作失败: {e}", "STORE_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/contract_comparator.py
#!/usr/bin/env python3
"""
contract-guardian 合同对比模块
逐条对比两份合同的差异,高亮变更内容,分析风险增减。
此功能仅限付费版用户使用。
"""
import argparse
import difflib
import os
import re
import sys
from typing import Any, Dict, List
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
is_paid,
output_error,
output_success,
read_text_file,
)
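def split_clauses(text: str) -> List[Dict[str, str]]:
    """Split contract text into clauses.

    Recognises common clause numbering styles at the start of a line
    (第X条, 一、, 1.1, (1) and similar).

    Args:
        text: Contract text.
    Returns:
        A list of clauses, each with a title and a content field.
    """
    # Clause heading patterns. Both full-width and ASCII punctuation are
    # accepted after the number.
    clause_pattern = re.compile(
        r"^(?:"
        r"第[一二三四五六七八九十百零\d]+[条章节]"
        r"|[一二三四五六七八九十]+[、..]"
        r"|\d+[、..]\d*"
        r"|[((]\d+[))]"
        r")",
        re.MULTILINE,
    )
    matches = list(clause_pattern.finditer(text))
    if not matches:
        # No clause headings found: fall back to splitting by paragraph
        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        return [
            {"title": f"段落 {i + 1}", "content": p}
            for i, p in enumerate(paragraphs)
        ]

    clauses = []
    # Keep any text before the first heading (title, parties) so that
    # changes there are not lost in a clause-level comparison.
    preamble = text[:matches[0].start()].strip()
    if preamble:
        clauses.append({"title": "前言", "content": preamble})

    for i, match in enumerate(matches):
        start = match.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        clause_text = text[start:end].strip()
        # The first line is the title, the rest is the content
        lines = clause_text.split("\n", 1)
        title = lines[0].strip()
        content = lines[1].strip() if len(lines) > 1 else ""
        clauses.append({
            "title": title,
            "content": content,
        })
    return clauses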
def compare_texts(text1: str, text2: str) -> Dict[str, Any]:
    """Compare two contract texts.

    Args:
        text1: Text of the first contract.
        text2: Text of the second contract.
    Returns:
        The similarity, the added/removed line counts and a unified diff.
    """
    # Split without line endings and join with "\n" below. Mixing
    # keepends=True with lineterm="" would glue the "---", "+++" and "@@"
    # header lines together, because they carry no newline of their own.
    lines1 = text1.splitlines()
    lines2 = text2.splitlines()
    diff_lines = list(difflib.unified_diff(
        lines1,
        lines2,
        fromfile="合同A",
        tofile="合同B",
        lineterm="",
    ))

    # Count changed lines, ignoring the file header lines
    added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++"))
    removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---"))

    # Character-level similarity of the two full texts, as a percentage
    matcher = difflib.SequenceMatcher(None, text1, text2)
    similarity = round(matcher.ratio() * 100, 1)

    return {
        "similarity": similarity,
        "lines_added": added,
        "lines_removed": removed,
        "diff": "\n".join(diff_lines) if diff_lines else "两份合同内容完全相同",
    }
def compare_clauses(clauses1: List[Dict[str, str]], clauses2: List[Dict[str, str]]) -> List[Dict[str, Any]]:
"""逐条对比两份合同的条款。
Args:
clauses1: 第一份合同的条款列表。
clauses2: 第二份合同的条款列表。
Returns:
条款差异列表。
"""
differences = []
# 构建标题到内容的映射
map1 = {c["title"]: c["content"] for c in clauses1}
map2 = {c["title"]: c["content"] for c in clauses2}
all_titles = list(dict.fromkeys(
[c["title"] for c in clauses1] + [c["title"] for c in clauses2]
))
for title in all_titles:
content1 = map1.get(title)
content2 = map2.get(title)
if content1 is None:
differences.append({
"clause": title,
"change_type": "新增",
"description": "合同B中新增的条款",
"content_b": content2[:200] if content2 else "",
})
elif content2 is None:
differences.append({
"clause": title,
"change_type": "删除",
"description": "合同B中删除的条款",
"content_a": content1[:200] if content1 else "",
})
elif content1 != content2:
# 计算条款内容的相似度
matcher = difflib.SequenceMatcher(None, content1, content2)
sim = round(matcher.ratio() * 100, 1)
differences.append({
"clause": title,
"change_type": "修改",
"similarity": sim,
"description": f"条款内容有变更(相似度 {sim}%)",
"content_a_excerpt": content1[:150],
"content_b_excerpt": content2[:150],
})
return differences
def generate_diff_report(file1: str, file2: str) -> Dict[str, Any]:
"""生成完整的合同对比报告。
Args:
file1: 第一份合同文件路径。
file2: 第二份合同文件路径。
Returns:
包含对比报告的字典。
"""
# 付费版校验
if not is_paid():
raise ValueError(
"合同对比功能仅限付费版用户使用。"
"升级至付费版(¥129/月)即可使用逐条对比功能。"
)
text1 = read_text_file(file1)
text2 = read_text_file(file2)
# 整体对比
overall = compare_texts(text1, text2)
# 条款级对比
clauses1 = split_clauses(text1)
clauses2 = split_clauses(text2)
clause_diffs = compare_clauses(clauses1, clauses2)
# 生成 Markdown 报告
report_lines = [
"# 合同对比报告\n",
f"- **合同A**: {os.path.basename(file1)}",
f"- **合同B**: {os.path.basename(file2)}",
f"- **整体相似度**: {overall['similarity']}%",
f"- **新增行数**: {overall['lines_added']}",
f"- **删除行数**: {overall['lines_removed']}",
"",
"## 条款差异汇总\n",
f"共发现 **{len(clause_diffs)}** 处差异:\n",
]
if clause_diffs:
report_lines.append("| 条款 | 变更类型 | 说明 |")
report_lines.append("|------|----------|------|")
for diff in clause_diffs:
report_lines.append(
f"| {diff['clause']} | {diff['change_type']} | {diff['description']} |"
)
report_lines.append("\n## 详细差异\n")
for diff in clause_diffs:
report_lines.append(f"### {diff['clause']}({diff['change_type']})\n")
if diff["change_type"] == "新增":
report_lines.append(f"> {diff.get('content_b', '')}\n")
elif diff["change_type"] == "删除":
report_lines.append(f"> ~~{diff.get('content_a', '')}~~\n")
elif diff["change_type"] == "修改":
report_lines.append(f"**合同A**: {diff.get('content_a_excerpt', '')}\n")
report_lines.append(f"**合同B**: {diff.get('content_b_excerpt', '')}\n")
else:
report_lines.append("未发现条款级差异,两份合同结构一致。\n")
markdown_report = "\n".join(report_lines)
return {
"file1": os.path.abspath(file1),
"file2": os.path.abspath(file2),
"similarity": overall["similarity"],
"total_differences": len(clause_diffs),
"clause_differences": clause_diffs,
"markdown_report": markdown_report,
}
def main():
"""命令行入口。"""
parser = argparse.ArgumentParser(
description="合同对比工具 — 逐条对比两份合同的差异(付费功能)",
)
parser.add_argument(
"--action",
required=True,
choices=["compare", "diff-report"],
help="操作类型: compare(基础对比), diff-report(完整差异报告)",
)
parser.add_argument(
"--file1",
required=True,
help="第一份合同文件路径",
)
parser.add_argument(
"--file2",
required=True,
help="第二份合同文件路径",
)
args = parser.parse_args()
try:
if args.action == "compare":
if not is_paid():
raise ValueError(
"合同对比功能仅限付费版用户使用。"
"升级至付费版(¥129/月)即可使用逐条对比功能。"
)
text1 = read_text_file(args.file1)
text2 = read_text_file(args.file2)
result = compare_texts(text1, text2)
output_success(result)
elif args.action == "diff-report":
result = generate_diff_report(args.file1, args.file2)
output_success(result)
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
except FileNotFoundError as e:
output_error(str(e), "FILE_NOT_FOUND")
except Exception as e:
output_error(f"对比失败: {e}", "COMPARE_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
contract-guardian 共享工具模块
提供订阅校验、JSON 输入输出、文件读取、敏感信息脱敏等通用功能。
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
# ============================================================
# 常量
# ============================================================
DATA_DIR = os.environ.get(
"CG_DATA_DIR",
os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "contract-guardian"),
)
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"daily_review_limit": 1,
"risk_categories": 3,
"expiry_tracking_limit": 3,
"contract_compare": False,
"template_count": 3,
"supported_formats": ["txt", "md"],
"history_search": False,
"features": [
"key_info_extract",
"basic_risk_check",
"expiry_reminder",
"basic_templates",
],
},
"paid": {
"tier": "paid",
"daily_review_limit": -1, # -1 表示无限制
"risk_categories": 12,
"expiry_tracking_limit": -1,
"contract_compare": True,
"template_count": 20,
"supported_formats": ["txt", "md", "pdf", "docx"],
"history_search": True,
"features": [
"key_info_extract",
"full_risk_check",
"expiry_reminder",
"contract_compare",
"all_templates",
"history_search",
"multi_format",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 CG_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("CG_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def is_paid() -> bool:
"""快速判断当前是否为付费版。"""
tier = os.environ.get("CG_SUBSCRIPTION_TIER", "free").strip().lower()
return tier == "paid"
def check_feature(feature: str) -> bool:
"""检查当前订阅是否包含指定功能。
Args:
feature: 功能标识符。
Returns:
True 表示当前订阅包含该功能。
"""
sub = check_subscription()
return feature in sub.get("features", [])
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_stdin() -> Dict[str, Any]:
"""从标准输入读取 JSON 数据并解析为字典。
Returns:
解析后的字典对象。
Raises:
ValueError: 当输入为空或 JSON 格式不合法时抛出。
"""
try:
raw = sys.stdin.read()
except Exception as e:
raise ValueError(f"读取标准输入失败: {e}")
if not raw.strip():
raise ValueError("标准输入为空,未读取到任何数据")
try:
data = json.loads(raw)
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if not isinstance(data, dict):
raise ValueError(f"期望输入为 JSON 对象,实际类型为 {type(data).__name__}")
return data
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据(可被 JSON 序列化的任意对象)。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 文件读取
# ============================================================
def read_text_file(file_path: str) -> str:
"""读取文本文件内容。
Args:
file_path: 文件路径。
Returns:
文件文本内容。
Raises:
FileNotFoundError: 文件不存在。
ValueError: 文件编码错误。
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
for encoding in ("utf-8", "gbk", "gb2312", "latin-1"):
try:
with open(file_path, "r", encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
raise ValueError(f"无法识别文件编码: {file_path}")
def read_text_input(args: argparse.Namespace) -> str:
"""从命令行参数中获取文本内容。
支持 --text 直接传入文本或 --text-file 传入文件路径。
Args:
args: 包含 text 和 text_file 属性的命名空间。
Returns:
文本内容。
Raises:
ValueError: 未提供文本输入。
"""
if hasattr(args, "text") and args.text:
return args.text
if hasattr(args, "text_file") and args.text_file:
return read_text_file(args.text_file)
raise ValueError("请通过 --text 或 --text-file 提供文本内容")
# ============================================================
# 敏感信息脱敏
# ============================================================
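def mask_sensitive_info(text: str) -> str:
    """Mask sensitive numbers in the text.

    Masking rules:
    - ID card number: keep the first 3 and last 4 characters, replace the middle with *
    - Mobile number: keep the first 3 and last 4 digits, replace the middle with *
    - Bank card number: keep the first 4 and last 4 digits, replace the middle with *

    Args:
        text: Text to mask.

    Returns:
        The masked text.
    """
    # Digit lookarounds replace \b here: in Python 3 str patterns, CJK
    # characters count as word characters, so \b never matches between
    # "电话" and "138..." and such numbers would be left unmasked.
    # ID card number (18 characters; the last one may be X)
    text = re.sub(
        r'(?<!\d)(\d{3})\d{11}(\d{3}[\dXx])(?![\dXx])',
        r'\1***********\2',
        text,
    )
    # Mobile number (11 digits, starting with 1)
    text = re.sub(
        r'(?<!\d)(1\d{2})\d{4}(\d{4})(?!\d)',
        r'\1****\2',
        text,
    )
    # Bank card number (16-19 digits)
    text = re.sub(
        r'(?<!\d)(\d{4})\d{8,11}(\d{4})(?!\d)',
        r'\1********\2',
        text,
    )
    return text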
# ============================================================
# 数据目录管理
# ============================================================
def ensure_data_dir() -> str:
"""确保数据目录存在并返回路径。
Returns:
数据目录的绝对路径。
"""
os.makedirs(DATA_DIR, exist_ok=True)
return DATA_DIR
def get_data_file(filename: str) -> str:
"""获取数据目录中文件的完整路径。
Args:
filename: 文件名。
Returns:
文件的完整路径。
"""
ensure_data_dir()
return os.path.join(DATA_DIR, filename)
# ============================================================
# 日期工具
# ============================================================
def parse_date(date_str: str) -> Optional[datetime]:
"""尝试解析常见格式的日期字符串。
Supported formats: YYYY-MM-DD, YYYY/MM/DD, YYYY年MM月DD日, YYYYMMDD, YYYY.MM.DD
Args:
date_str: 日期字符串。
Returns:
解析成功返回 datetime 对象,失败返回 None。
"""
date_str = date_str.strip()
formats = [
"%Y-%m-%d",
"%Y/%m/%d",
"%Y年%m月%d日",
"%Y%m%d",
"%Y.%m.%d",
]
for fmt in formats:
try:
return datetime.strptime(date_str, fmt)
except ValueError:
continue
return None
def format_currency(value: float) -> str:
"""将金额格式化为中文货币表示。
Args:
value: 金额数值。
Returns:
Formatted amount string, e.g. "¥1,234.50", "¥123.46万" (from 10,000 up), or "¥1.23亿" (from 100,000,000 up).
"""
try:
num = float(value)
except (TypeError, ValueError):
return str(value)
if abs(num) >= 1e8:
return f"¥{num / 1e8:.2f}亿"
elif abs(num) >= 1e4:
return f"¥{num / 1e4:.2f}万"
else:
return f"¥{num:,.2f}"
FILE:scripts/risk_analyzer.py
#!/usr/bin/env python3
"""
contract-guardian 风险条款识别模块
基于关键词和正则模式匹配,识别合同中的 12 类风险条款。
免费版仅识别 3 类基础风险,付费版识别全部 12 类。
"""
import argparse
import json
import os
import re
import sys
from typing import Any, Dict, List
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
output_error,
output_success,
read_text_input,
)
# ============================================================
# 12 类风险条款定义
# ============================================================
RISK_CATEGORIES: List[Dict[str, Any]] = [
{
"id": "unilateral_termination",
"name": "单方解约权",
"severity": "high",
"description": "一方可单方面解除合同,可能导致对方损失无法补偿",
"keywords": [
"单方解除", "单方终止", "有权解除", "可随时终止",
"有权单方", "无需对方同意.*解除", "单方面解约",
],
"patterns": [
r"[甲乙丙]方有权.*(?:单方|随时).*(?:解除|终止)",
r"(?:任何一方|任一方).*(?:无需|不需).*(?:同意|通知).*(?:解除|终止)",
],
"recommendation": "建议增加解约条件限制,明确双方解约权的对等性和提前通知期",
"free_tier": True,
},
{
"id": "auto_renewal",
"name": "自动续约",
"severity": "medium",
"description": "合同到期自动续约,未及时通知可能导致被动续约",
"keywords": [
"自动续约", "自动续期", "自动延长", "自动延续",
"视为同意续约", "默认续期",
],
"patterns": [
r"(?:期满|到期).*(?:自动|默认).*(?:续约|续期|延长|延续)",
r"(?:未.*(?:书面|提前).*(?:通知|提出)).*(?:视为|默认).*(?:续约|续期)",
],
"recommendation": "建议明确续约通知期限(如提前30天书面通知),避免被动续约",
"free_tier": False,
},
{
"id": "unlimited_liability",
"name": "无限责任",
"severity": "high",
"description": "一方承担无限制的赔偿责任,风险敞口过大",
"keywords": [
"无限责任", "全部损失", "一切损失", "全额赔偿",
"承担所有", "不设上限", "无上限",
],
"patterns": [
r"(?:赔偿|承担).*(?:全部|一切|所有|任何).*(?:损失|损害|费用)",
r"(?:责任|赔偿).*(?:不设|无|没有).*(?:上限|限制|限额)",
],
"recommendation": "建议设定赔偿上限(如合同总金额的一定比例),限制责任范围",
"free_tier": False,
},
{
"id": "non_compete",
"name": "竞业限制",
"severity": "high",
"description": "竞业限制条款可能过度约束经营自由",
"keywords": [
"竞业限制", "竞业禁止", "同业竞争", "不得从事",
"不得经营", "竞争业务",
],
"patterns": [
r"(?:不得|禁止).*(?:从事|经营|参与).*(?:同类|类似|相同|竞争)",
r"竞业.*(?:限制|禁止).*(?:\d+.*(?:年|月))",
],
"recommendation": "建议明确竞业限制的范围、期限和补偿标准,确保合理性",
"free_tier": False,
},
{
"id": "ip_ownership",
"name": "知识产权归属",
"severity": "high",
"description": "知识产权归属不明确,可能导致成果归属争议",
"keywords": [
"知识产权归", "著作权归", "专利归", "成果归属",
"版权归属", "全部权利归",
],
"patterns": [
r"(?:知识产权|著作权|专利|版权|成果).*(?:归|属于).*[甲乙丙]方.*(?:所有|独有|享有)",
r"(?:工作成果|开发成果|技术成果).*(?:全部|一切).*(?:归|属于)",
],
"recommendation": "建议明确约定各方知识产权的归属范围,区分已有IP和新创IP",
"free_tier": False,
},
{
"id": "jurisdiction",
"name": "管辖地/仲裁",
"severity": "medium",
"description": "争议解决管辖地可能对一方不利",
"keywords": [
"管辖", "仲裁", "诉讼管辖", "争议解决",
"仲裁委员会", "仲裁机构",
],
"patterns": [
r"(?:由|向).*(?:人民法院|仲裁委员会|仲裁机构).*(?:管辖|仲裁|裁决)",
r"(?:争议|纠纷).*(?:提交|申请).*(?:仲裁|诉讼)",
],
"recommendation": "建议选择对双方公平的管辖地,优先考虑仲裁方式解决争议",
"free_tier": False,
},
{
"id": "confidentiality",
"name": "保密条款",
"severity": "medium",
"description": "保密义务范围过广或期限过长",
"keywords": [
"保密义务", "保密期限", "保密责任", "保密信息",
"商业秘密", "不得泄露",
],
"patterns": [
r"保密.*(?:期限|义务).*(?:永久|无限期|\d+.*年)",
r"(?:一切|全部|所有).*(?:信息|资料).*(?:均为|视为).*(?:保密|机密)",
],
"recommendation": "建议明确保密信息的范围和期限,避免过度约束",
"free_tier": False,
},
{
"id": "payment_terms",
"name": "付款条件",
"severity": "medium",
"description": "付款条件不明确或对一方明显不利",
"keywords": [
"付款条件", "付款方式", "付款期限", "结算方式",
"账期", "收到发票后",
],
"patterns": [
r"(?:付款|支付|结算).*(?:条件|前提|前置)",
r"(?:收到.*(?:发票|验收报告|确认)).*(?:\d+.*(?:天|日|工作日)).*(?:内.*(?:付款|支付))",
],
"recommendation": "建议明确付款时间节点、金额和条件,设置合理的账期",
"free_tier": True,
},
{
"id": "acceptance_criteria",
"name": "验收标准",
"severity": "medium",
"description": "验收标准模糊,可能导致验收争议",
"keywords": [
"验收标准", "验收条件", "交付标准", "合格标准",
"视为验收合格", "默认验收",
],
"patterns": [
r"(?:\d+.*(?:天|日|工作日)).*(?:内.*未.*(?:提出|提交).*(?:异议|问题)).*(?:视为|默认).*(?:合格|验收)",
r"(?:验收|交付).*(?:标准|条件).*(?:由|以).*[甲乙丙]方.*(?:确定|为准)",
],
"recommendation": "建议细化验收标准和流程,明确验收期限和异议处理机制",
"free_tier": False,
},
{
"id": "penalty",
"name": "违约金",
"severity": "high",
"description": "违约金条款不对等或金额过高",
"keywords": [
"违约金", "违约责任", "违约赔偿", "逾期违约",
"迟延履行", "日万分之",
],
"patterns": [
r"违约金.*(?:为|按).*(?:合同.*(?:总[额价金]|金额)).*(?:\d+%)",
r"(?:每[日天]|日).*(?:万分之|千分之|百分之).*(?:\d+).*(?:违约金|滞纳金)",
],
"recommendation": "建议违约金比例合理(通常不超过合同总额的30%),确保双方对等",
"free_tier": True,
},
{
"id": "guarantee",
"name": "担保条款",
"severity": "high",
"description": "担保范围过广或担保方式不当",
"keywords": [
"担保", "保证金", "质押", "抵押", "连带责任",
"无限连带", "担保责任",
],
"patterns": [
r"(?:无限|连带).*(?:担保|保证).*(?:责任)",
r"(?:担保|保证).*(?:范围|期限).*(?:包括|涵盖).*(?:全部|一切|所有)",
],
"recommendation": "建议明确担保范围、期限和方式,避免无限连带担保",
"free_tier": False,
},
{
"id": "force_majeure",
"name": "不可抗力",
"severity": "low",
"description": "不可抗力条款定义过窄或免责范围过广",
"keywords": [
"不可抗力", "自然灾害", "政府行为", "战争",
"疫情", "罢工",
],
"patterns": [
r"不可抗力.*(?:包括但不限于|包括)",
r"(?:不可抗力|意外事件).*(?:免除|免于|不承担).*(?:全部|一切|任何).*(?:责任|义务)",
],
"recommendation": "建议明确不可抗力的定义范围、通知义务和减损义务",
"free_tier": False,
},
]
# 免费版可用的风险类别 ID
FREE_TIER_CATEGORIES = [c["id"] for c in RISK_CATEGORIES if c.get("free_tier")]
def get_available_categories() -> List[Dict[str, Any]]:
"""根据订阅等级获取可用的风险类别。
Returns:
可用风险类别列表。
"""
sub = check_subscription()
if sub["tier"] == "paid":
return RISK_CATEGORIES[:]
return [c for c in RISK_CATEGORIES if c["id"] in FREE_TIER_CATEGORIES]
def analyze_risk(text: str, categories: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""分析合同文本中的风险条款。
Args:
text: 合同文本内容。
categories: 要检查的风险类别列表,默认为当前订阅等级可用类别。
Returns:
A list of risk items, each containing:
{
"category_id": str,
"category_name": str,
"severity": str,
"matched_text": List[str],  # up to 3 de-duplicated context snippets
"description": str,
"recommendation": str
}
"""
if categories is None:
categories = get_available_categories()
risks = []
for category in categories:
matches = []
# 关键词匹配
for keyword in category["keywords"]:
for match in re.finditer(keyword, text):
# 提取匹配位置前后的上下文
start = max(0, match.start() - 30)
end = min(len(text), match.end() + 30)
context = text[start:end].replace("\n", " ").strip()
matches.append(context)
# 正则模式匹配
for pattern in category["patterns"]:
try:
for match in re.finditer(pattern, text):
start = max(0, match.start() - 20)
end = min(len(text), match.end() + 20)
context = text[start:end].replace("\n", " ").strip()
matches.append(context)
except re.error:
continue
if matches:
# 去重
unique_matches = list(dict.fromkeys(matches))
risks.append({
"category_id": category["id"],
"category_name": category["name"],
"severity": category["severity"],
"matched_text": unique_matches[:3], # 最多保留3条匹配
"description": category["description"],
"recommendation": category["recommendation"],
})
# 按严重程度排序
severity_order = {"high": 0, "medium": 1, "low": 2}
risks.sort(key=lambda r: severity_order.get(r["severity"], 99))
return risks
def quick_check(text: str) -> Dict[str, Any]:
"""快速检查合同中的主要风险。
仅检查高风险类别,返回简要结果。
Args:
text: 合同文本内容。
Returns:
包含快速检查结果的字典。
"""
categories = get_available_categories()
high_risk_categories = [c for c in categories if c["severity"] == "high"]
risks = analyze_risk(text, high_risk_categories)
return {
"total_high_risks": len(risks),
"risk_level": "高" if len(risks) >= 3 else "中" if len(risks) >= 1 else "低",
"risks": risks,
"summary": _generate_risk_summary(risks),
}
def full_report(text: str) -> Dict[str, Any]:
"""生成完整的风险分析报告。
Args:
text: 合同文本内容。
Returns:
包含完整分析结果的字典。
"""
categories = get_available_categories()
risks = analyze_risk(text, categories)
sub = check_subscription()
high_risks = [r for r in risks if r["severity"] == "high"]
medium_risks = [r for r in risks if r["severity"] == "medium"]
low_risks = [r for r in risks if r["severity"] == "low"]
# 计算风险评分(0-100,越高越安全)
total_categories = len(categories)
risk_count = len(risks)
weighted_score = (
len(high_risks) * 3 + len(medium_risks) * 2 + len(low_risks) * 1
)
max_weighted = total_categories * 3
safety_score = max(0, round(100 - (weighted_score / max_weighted * 100)))
report = {
"tier": sub["tier"],
"categories_checked": len(categories),
"total_categories": 12,
"total_risks_found": risk_count,
"high_risk_count": len(high_risks),
"medium_risk_count": len(medium_risks),
"low_risk_count": len(low_risks),
"safety_score": safety_score,
"risk_level": _get_risk_level(safety_score),
"risks": risks,
"summary": _generate_risk_summary(risks),
}
if sub["tier"] == "free":
report["upgrade_hint"] = (
f"当前为免费版,仅检查了 {len(categories)}/{12} 类风险。"
f"升级至付费版(¥129/月)可检查全部 12 类风险条款。"
)
return report
def _get_risk_level(score: int) -> str:
"""根据安全评分获取风险等级。"""
if score >= 80:
return "低风险"
elif score >= 60:
return "中等风险"
elif score >= 40:
return "较高风险"
else:
return "高风险"
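def _generate_risk_summary(risks: List[Dict[str, Any]]) -> str:
    """Build the risk summary sentence (user-facing text stays in Chinese)."""
    if not risks:
        return "未发现明显风险条款,合同整体较为安全。"
    high = [r for r in risks if r["severity"] == "high"]
    medium = [r for r in risks if r["severity"] == "medium"]
    low = [r for r in risks if r["severity"] == "low"]
    parts = []
    if high:
        names = "、".join(r["category_name"] for r in high)
        parts.append(f"发现 {len(high)} 项高风险条款({names}),建议重点关注")
    if medium:
        names = "、".join(r["category_name"] for r in medium)
        parts.append(f"发现 {len(medium)} 项中等风险条款({names}),建议审慎评估")
    # Without this branch, a result with only low-severity items would
    # produce an empty list and the summary would be a bare "。".
    if low:
        names = "、".join(r["category_name"] for r in low)
        parts.append(f"发现 {len(low)} 项低风险条款({names}),建议留意")
    return ";".join(parts) + "。"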
def main():
"""命令行入口。"""
parser = argparse.ArgumentParser(
description="合同风险条款识别工具 — 基于关键词和模式匹配识别 12 类风险",
)
parser.add_argument(
"--action",
required=True,
choices=["analyze", "quick-check", "full-report"],
help="操作类型: analyze(分析), quick-check(快速检查), full-report(完整报告)",
)
parser.add_argument(
"--text",
default=None,
help="合同文本内容(直接传入)",
)
parser.add_argument(
"--text-file",
default=None,
help="合同文本文件路径",
)
args = parser.parse_args()
try:
text = read_text_input(args)
if args.action == "analyze":
risks = analyze_risk(text)
output_success({
"total_risks": len(risks),
"risks": risks,
})
elif args.action == "quick-check":
result = quick_check(text)
output_success(result)
elif args.action == "full-report":
result = full_report(text)
output_success(result)
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
except FileNotFoundError as e:
output_error(str(e), "FILE_NOT_FOUND")
except Exception as e:
output_error(f"分析失败: {e}", "ANALYSIS_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/key_info_extractor.py
#!/usr/bin/env python3
"""
contract-guardian 关键信息提取模块
使用正则表达式从合同文本中提取关键信息,包括甲乙方、金额、期限、违约金等。
"""
import argparse
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
format_currency,
output_error,
output_success,
parse_date,
read_text_input,
)
# ============================================================
# 提取模式定义
# ============================================================
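# In the patterns below, [:\uff1a] accepts an ASCII or full-width colon and
# [(\uff08] accepts an ASCII or full-width opening parenthesis. A bare "("
# inside the alternation leaves the group unbalanced, so re.search raises
# re.error before any text is examined.
def extract_party_a(text: str) -> Optional[str]:
    """Extract Party A (甲方) from the contract text."""
    patterns = [
        r"甲\s*方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08]|签章|地址|联系)",
        r"发包方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"委托方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"买方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"需方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()
    return None


def extract_party_b(text: str) -> Optional[str]:
    """Extract Party B (乙方) from the contract text."""
    patterns = [
        r"乙\s*方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08]|签章|地址|联系)",
        r"承包方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"受托方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"卖方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
        r"供方[:\uff1a]\s*(.+?)(?:\n|$|[(\uff08])",
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()
    return None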
def extract_contract_amount(text: str) -> Optional[Dict[str, Any]]:
    """Extract the contract amount, normalized to yuan."""
    # \uff0c, \uff1a and \uffe5 are the full-width comma, colon and yen sign.
    patterns = [
        # Labeled amount with a ¥ sign or a 元 / 万元 unit
        r"(?:合同|总|合计|项目).*?(?:金额|价[格款]|费用|总[价额])[:\uff1a为]?\s*(?:人民币)?\s*[¥\uffe5]?\s*([\d,\uff0c]+(?:\.\d+)?)\s*(?:万元|万|元)",
        r"[¥\uffe5]\s*([\d,\uff0c]+(?:\.\d+)?)\s*(?:万元|万|元)?",
        r"(?:人民币)\s*([\d,\uff0c]+(?:\.\d+)?)\s*(?:万元|万|元)",
        r"(?:金额|价[格款]|费用|总[价额])[:\uff1a为]?\s*([\d,\uff0c]+(?:\.\d+)?)\s*(?:万元|万|元)",
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if not match:
            continue
        amount_str = match.group(1).replace(",", "").replace("\uff0c", "")
        try:
            amount = float(amount_str)
        except ValueError:
            continue
        # Read the unit from the matched suffix only. Scanning the nearby
        # text for "万" would also scale amounts that merely sit next to one.
        unit = text[match.end(1):match.end()]
        if "万" in unit:
            amount *= 10000
        return {
            "raw": match.group(0).strip(),
            "amount": amount,
            "formatted": format_currency(amount),
        }
    return None
def extract_dates(text: str) -> Dict[str, Optional[str]]:
"""提取合同日期信息(起止日期、签订日期)。"""
result = {
"start_date": None,
"end_date": None,
"signing_date": None,
}
# 合同期限 / 起止日期
period_patterns = [
r"(?:自|从)\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)\s*(?:起|开始)?\s*(?:至|到|—|-|-)\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)",
r"(?:期限|有效期)[::为]?\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)\s*(?:至|到|—|-|-)\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)",
]
for pattern in period_patterns:
match = re.search(pattern, text)
if match:
start = parse_date(match.group(1))
end = parse_date(match.group(2))
if start:
result["start_date"] = start.strftime("%Y-%m-%d")
if end:
result["end_date"] = end.strftime("%Y-%m-%d")
break
# 签订日期
sign_patterns = [
r"(?:签订|签署|签章).*?日期[::]\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)",
r"(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)\s*(?:签订|签署|签章)",
r"(?:本合同|本协议).*?于\s*(\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)\s*(?:签订|签署|生效)",
]
for pattern in sign_patterns:
match = re.search(pattern, text)
if match:
dt = parse_date(match.group(1))
if dt:
result["signing_date"] = dt.strftime("%Y-%m-%d")
break
return result
def extract_penalty_clause(text: str) -> Optional[str]:
"""提取违约金条款。"""
patterns = [
r"(?:违约金|违约责任|违约赔偿)[::,,]?\s*(.{10,200}?)(?:\n|。|;)",
r"(?:逾期|迟延|延迟).*?(?:违约金|滞纳金)[::,,]?\s*(.{10,150}?)(?:\n|。|;)",
]
for pattern in patterns:
match = re.search(pattern, text)
if match:
return match.group(0).strip()
return None
def extract_payment_terms(text: str) -> Optional[str]:
"""提取付款条件。"""
patterns = [
r"(?:付款|支付|结算)[方条]?[式件][::,,]?\s*(.{10,300}?)(?:\n\n|\n[一二三四五六七八九十])",
r"(?:付款|支付|结算).*?(?:方式|条件|期限)[::]\s*(.{10,200}?)(?:\n|。)",
]
for pattern in patterns:
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(0).strip()[:300]
return None
def extract_contact_info(text: str) -> Dict[str, Optional[str]]:
"""提取联系方式信息。"""
result = {
"phone": None,
"email": None,
"address": None,
}
# 电话
phone_match = re.search(r"(?:电话|联系电话|手机|Tel)[::]\s*([\d\-+() ]{7,20})", text)
if phone_match:
result["phone"] = phone_match.group(1).strip()
# 邮箱
email_match = re.search(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", text)
if email_match:
result["email"] = email_match.group(0)
# 地址
addr_patterns = [
r"(?:地址|住所|注册地址)[::]\s*(.{5,100}?)(?:\n|$|电话|邮编|联系)",
]
for pattern in addr_patterns:
match = re.search(pattern, text)
if match:
result["address"] = match.group(1).strip()
break
return result
def extract_key_info(text: str) -> Dict[str, Any]:
"""从合同文本中提取全部关键信息。
Args:
text: 合同文本。
Returns:
包含所有提取信息的字典。
"""
dates = extract_dates(text)
contact = extract_contact_info(text)
amount = extract_contract_amount(text)
return {
"party_a": extract_party_a(text),
"party_b": extract_party_b(text),
"contract_amount": amount,
"start_date": dates["start_date"],
"end_date": dates["end_date"],
"signing_date": dates["signing_date"],
"penalty_clause": extract_penalty_clause(text),
"payment_terms": extract_payment_terms(text),
"contact_info": contact,
}
def generate_summary(info: Dict[str, Any]) -> str:
"""根据提取的关键信息生成摘要文本。
Args:
info: extract_key_info 返回的字典。
Returns:
Markdown 格式的摘要文本。
"""
lines = ["## 合同关键信息摘要\n"]
party_a = info.get("party_a") or "未识别"
party_b = info.get("party_b") or "未识别"
lines.append(f"| 项目 | 内容 |")
lines.append(f"|------|------|")
lines.append(f"| 甲方 | {party_a} |")
lines.append(f"| 乙方 | {party_b} |")
amount = info.get("contract_amount")
if amount:
lines.append(f"| 合同金额 | {amount['formatted']} |")
else:
lines.append(f"| 合同金额 | 未识别 |")
start = info.get("start_date") or "未识别"
end = info.get("end_date") or "未识别"
signing = info.get("signing_date") or "未识别"
lines.append(f"| 合同期限 | {start} 至 {end} |")
lines.append(f"| 签订日期 | {signing} |")
penalty = info.get("penalty_clause")
if penalty:
# 截断过长内容
display = penalty[:80] + "..." if len(penalty) > 80 else penalty
lines.append(f"| 违约条款 | {display} |")
payment = info.get("payment_terms")
if payment:
display = payment[:80] + "..." if len(payment) > 80 else payment
lines.append(f"| 付款条件 | {display} |")
return "\n".join(lines)
def main():
"""命令行入口。"""
parser = argparse.ArgumentParser(
description="合同关键信息提取工具 — 提取甲乙方、金额、期限等关键信息",
)
parser.add_argument(
"--action",
required=True,
choices=["extract", "summary"],
help="操作类型: extract(提取信息), summary(生成摘要)",
)
parser.add_argument(
"--text",
default=None,
help="合同文本内容(直接传入)",
)
parser.add_argument(
"--text-file",
default=None,
help="合同文本文件路径",
)
args = parser.parse_args()
try:
text = read_text_input(args)
info = extract_key_info(text)
if args.action == "extract":
output_success(info)
elif args.action == "summary":
summary = generate_summary(info)
output_success({
"info": info,
"summary": summary,
})
except ValueError as e:
output_error(str(e), "VALIDATION_ERROR")
except FileNotFoundError as e:
output_error(str(e), "FILE_NOT_FOUND")
except Exception as e:
output_error(f"提取失败: {e}", "EXTRACTION_ERROR")
if __name__ == "__main__":
main()
FILE:references/contract-templates.md
# 合同模板库
本文档提供合同模板参考。免费版包含 3 个基础模板,付费版提供 20+ 行业模板。
---
## 免费模板
### 模板一:采购合同
```
采购合同
合同编号:[编号]
甲方(采购方):[公司名称]
地址:[地址]
联系人:[姓名] 电话:[电话]
乙方(供应方):[公司名称]
地址:[地址]
联系人:[姓名] 电话:[电话]
根据《中华人民共和国民法典》及相关法律法规,甲乙双方经平等协商,就甲方向乙方采购货物事宜达成如下协议:
第一条 采购内容
1.1 货物名称:[名称]
1.2 规格型号:[规格]
1.3 数量:[数量]
1.4 单价:人民币 [金额] 元
1.5 总价:人民币 [总金额] 元(大写:[大写金额])
第二条 质量标准
2.1 货物应符合国家标准/行业标准 [标准编号]。
2.2 乙方应提供产品质量合格证明。
第三条 交付方式与期限
3.1 交货地点:[地点]
3.2 交货期限:自合同签订之日起 [天数] 个工作日内
3.3 运输方式及费用:[说明]
第四条 验收
4.1 甲方应在收货后 [天数] 个工作日内完成验收。
4.2 验收标准:[标准说明]
4.3 验收不合格的,甲方应在验收期内书面通知乙方,乙方应在 [天数] 日内换货或维修。
第五条 付款方式
5.1 付款方式:[银行转账/支票/其他]
5.2 付款期限:甲方在验收合格后 [天数] 个工作日内支付全款。
5.3 乙方收款账户信息:
开户行:[银行名称]
账号:[账号]
户名:[户名]
第六条 违约责任
6.1 乙方逾期交货的,每逾期一日按合同总金额的万分之三向甲方支付违约金。
6.2 甲方逾期付款的,每逾期一日按未付金额的万分之三向乙方支付违约金。
6.3 违约金累计不超过合同总金额的 10%。
第七条 争议解决
因本合同引起的争议,双方应友好协商解决;协商不成的,提交 [仲裁机构] 仲裁。
第八条 合同期限
本合同自双方签章之日起生效,至双方义务全部履行完毕时终止。
第九条 其他
9.1 本合同一式两份,甲乙双方各持一份,具有同等法律效力。
9.2 未尽事宜,双方可另行签订补充协议。
甲方(盖章): 乙方(盖章):
授权代表: 授权代表:
日期: 年 月 日 日期: 年 月 日
```
---
### 模板二:服务合同
```
服务合同
合同编号:[编号]
甲方(委托方):[公司名称]
乙方(服务方):[公司名称]
第一条 服务内容
1.1 服务名称:[服务名称]
1.2 服务范围:[详细描述]
1.3 服务标准:[标准说明]
第二条 服务期限
自 [起始日期] 至 [结束日期],共计 [时长]。
第三条 服务费用及付款
3.1 服务费用总计:人民币 [金额] 元
3.2 付款方式:
- 合同签订后 [天数] 日内支付 [比例]% 作为预付款
- 服务完成验收后 [天数] 日内支付剩余 [比例]%
3.3 乙方应在收款后 [天数] 个工作日内开具等额发票。
第四条 双方权利义务
4.1 甲方应按时提供服务所需的资料和配合。
4.2 乙方应按约定标准提供服务,保证服务质量。
4.3 乙方应对服务过程中知悉的甲方商业信息严格保密。
第五条 验收
5.1 乙方完成服务后应书面通知甲方验收。
5.2 甲方应在 [天数] 个工作日内完成验收并出具验收意见。
5.3 甲方逾期未提出异议的,视为验收合格。
第六条 违约责任
6.1 任何一方违反本合同约定的,应向对方支付合同总金额 [比例]% 的违约金。
6.2 违约方还应赔偿因违约给对方造成的实际损失。
第七条 合同解除
7.1 经双方协商一致,可以解除本合同。
7.2 一方严重违约的,另一方有权书面通知解除合同。
第八条 争议解决
因本合同产生的争议,双方应友好协商解决;协商不成的,任一方可向合同签订地有管辖权的人民法院提起诉讼。
甲方(盖章): 乙方(盖章):
日期: 年 月 日 日期: 年 月 日
```
---
### 模板三:合作协议
```
合作协议
协议编号:[编号]
甲方:[公司/个人名称]
乙方:[公司/个人名称]
鉴于甲乙双方拟就 [合作项目] 开展合作,经友好协商,达成如下协议:
第一条 合作内容
1.1 合作项目:[项目名称/描述]
1.2 合作目标:[目标说明]
1.3 合作方式:[方式说明]
第二条 合作期限
自 [起始日期] 至 [结束日期]。期满后如需续约,双方应提前 [天数] 天协商。
第三条 双方职责
3.1 甲方负责:[职责清单]
3.2 乙方负责:[职责清单]
第四条 费用与收益分配
4.1 合作投入:
- 甲方投入:[说明]
- 乙方投入:[说明]
4.2 收益分配:按照甲方 [比例]%、乙方 [比例]% 的比例分配净收益。
4.3 结算周期:[月/季/年] 度结算一次。
第五条 知识产权
5.1 各方在合作前已有的知识产权归各方所有。
5.2 合作期间共同开发的成果,知识产权由双方共有。
5.3 未经对方书面同意,任何一方不得单独使用或转让共有知识产权。
第六条 保密条款
6.1 双方应对合作中知悉的对方商业秘密和技术秘密严格保密。
6.2 保密期限:合作期间及合作终止后 [年数] 年。
6.3 违反保密义务的一方应承担违约责任并赔偿损失。
第七条 违约责任
任何一方未履行本协议约定义务的,应向对方支付违约金人民币 [金额] 元,并赔偿实际损失。
第八条 协议终止
8.1 合作期满自动终止。
8.2 经双方协商一致可以提前终止。
8.3 一方严重违约,另一方有权书面通知终止本协议。
第九条 争议解决
因本协议引起的争议,提交 [仲裁机构] 仲裁解决。
甲方(签章): 乙方(签章):
日期: 年 月 日 日期: 年 月 日
```
---
## 付费版模板(¥129/月)
升级至付费版可获取以下 20+ 行业模板:
| 序号 | 模板名称 | 适用场景 |
|------|----------|----------|
| 1 | 技术开发合同 | 软件/系统定制开发 |
| 2 | 劳动合同 | 员工雇佣 |
| 3 | 房屋租赁合同 | 办公场所/住房租赁 |
| 4 | 保密协议(NDA) | 商业秘密保护 |
| 5 | 代理协议 | 销售/分销代理 |
| 6 | 咨询服务合同 | 管理/技术咨询 |
| 7 | 加盟合同 | 品牌特许加盟 |
| 8 | 投资协议 | 股权投资/融资 |
| 9 | 股权转让协议 | 股权买卖 |
| 10 | 借款合同 | 企业间借贷 |
| 11 | 广告合同 | 广告投放/推广 |
| 12 | 物流运输合同 | 货物运输 |
| 13 | 工程施工合同 | 建设施工项目 |
| 14 | 设计合同 | 平面/产品设计 |
| 15 | 培训合同 | 企业培训服务 |
| 16 | 竞业限制协议 | 员工竞业约束 |
| 17 | 股东协议 | 公司治理 |
| 18 | 知识产权许可协议 | 专利/商标授权 |
| 19 | 数据处理协议 | 数据合规 |
| 20 | 电商平台入驻协议 | 平台商家合作 |
> 如需升级,请联系管理员或访问订阅管理页面。
FILE:references/risk-checklist.md
# 合同风险条款清单(12类)
本清单涵盖合同审查中最常见的 12 类风险条款,供 AI 审查助手参考。
---
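## 1. Unilateral Termination Right (单方解约权)
**Risk level**: High
**Description**: One party holds the right to terminate the contract unilaterally, which can cause losses to the other party even when that party is not at fault.
**Common keywords**: 单方解除、单方终止、有权解除、可随时终止、无需对方同意解除
**Typical problematic clauses**:
- "甲方有权随时解除本合同,无需向乙方说明理由" (Party A may terminate this contract at any time without giving Party B a reason)
- "甲方可单方面终止合同,乙方不得提出异议" (Party A may terminate the contract unilaterally and Party B may not object)
- "任一方无需对方同意即可解除合同" (either party may terminate the contract without the other party's consent)
**Assessment points**:
- Whether termination rights are reciprocal (both parties hold equivalent rights)
- Whether an advance notice period is set
- Whether compensation for losses after termination is clearly defined
**Recommended measures**:
- Agree on reciprocal termination conditions
- Set a reasonable advance notice period (for example, 30 days' written notice)
- Define the liability for compensation and the handover obligations after termination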
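
A quick way to see why verbatim keywords alone are not enough is to scan sample clauses with both the keyword list and a looser pattern. The sketch below is illustrative only: the keyword list is copied from this section, the pattern is the first one defined for this category in `scripts/risk_analyzer.py`, and `scan_clause` is a hypothetical helper that is not part of the scripts.

```python
import re
from typing import Dict, List, Union

# Keywords copied from this section of the checklist.
KEYWORDS: List[str] = ["单方解除", "单方终止", "有权解除", "可随时终止", "无需对方同意解除"]

# Looser pattern taken from scripts/risk_analyzer.py in this bundle.
PATTERN = re.compile(r"[甲乙丙]方有权.*(?:单方|随时).*(?:解除|终止)")


def scan_clause(clause: str) -> Dict[str, Union[List[str], bool]]:
    """Report which keywords occur verbatim and whether the pattern matches."""
    return {
        "keywords": [kw for kw in KEYWORDS if kw in clause],
        "pattern": PATTERN.search(clause) is not None,
    }


# Hit by the keyword "有权解除" only.
print(scan_clause("乙方有权解除合同"))
# "随时" sits between "有权" and "解除", so only the pattern catches it.
print(scan_clause("甲方有权随时解除本合同"))
```

Paraphrased clauses slip past exact keywords, which is presumably why the analyzer combines both kinds of match.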
---
## 2. 自动续约
**风险等级**: 中
**说明**: 合同到期后自动续约,若未在规定时间内提出终止通知,将被动续约。
**常见关键词**: 自动续约、自动续期、自动延长、默认续期、视为同意续约
**典型问题条款**:
- "合同期满后自动续期一年,除非提前60天书面通知"
- "乙方未在期满前30天提出异议的,视为同意续约"
**风险评估要点**:
- 续约通知期是否合理
- 续约条件是否与原合同一致
- 是否允许修改续约条款
**建议对策**:
- 明确续约通知期限和方式
- 约定续约条款可协商调整
- 设置日历提醒,避免被动续约
---
## 3. 无限责任
**风险等级**: 高
**说明**: 合同要求一方承担无上限的赔偿责任,风险敞口不可控。
**常见关键词**: 全部损失、一切损失、全额赔偿、不设上限、无限责任
**典型问题条款**:
- "乙方应赔偿甲方因此遭受的全部损失,包括但不限于直接损失和间接损失"
- "赔偿金额不设上限"
**风险评估要点**:
- 是否设定赔偿上限(如合同总金额的百分比)
- 是否区分直接损失和间接损失
- 是否排除了不可预见的损失
**建议对策**:
- 设定赔偿上限(通常不超过合同总金额)
- 明确排除间接损失和预期利益损失
- 约定损失计算方式和举证责任
---
## 4. 竞业限制
**风险等级**: 高
**说明**: 限制一方在合同期内或终止后从事竞争业务,可能过度约束商业自由。
**常见关键词**: 竞业限制、竞业禁止、不得从事、不得经营、同业竞争
**典型问题条款**:
- "乙方在合同终止后两年内不得从事与甲方相同或类似的业务"
- "乙方不得直接或间接参与任何与甲方有竞争关系的企业"
**风险评估要点**:
- 限制范围是否明确且合理
- 限制期限是否过长
- 是否约定了竞业补偿金
**建议对策**:
- 限定明确的行业范围和地域范围
- 限制期限不宜超过2年
- 约定合理的竞业补偿金(法律要求)
---
## 5. 知识产权归属
**风险等级**: 高
**说明**: 合同约定的知识产权归属不明确或不公平,可能导致权利争议。
**常见关键词**: 知识产权归、著作权归、专利归、成果归属、版权归属
**典型问题条款**:
- "本合同项下产生的一切知识产权归甲方所有"
- "乙方在服务过程中产生的全部技术成果归甲方独有"
**风险评估要点**:
- 是否区分已有IP和新创IP
- 创作者是否保留合理的使用权
- 是否约定了知识产权转让的对价
**建议对策**:
- 明确区分双方各自的已有知识产权
- 约定合作期间产生的知识产权归属规则
- 保留创作方的合理使用权或许可权
---
## 6. 管辖地/仲裁
**风险等级**: 中
**说明**: 争议解决的管辖地或仲裁机构的选择可能对一方不利。
**常见关键词**: 管辖、仲裁、诉讼管辖、争议解决、仲裁委员会
**典型问题条款**:
- "因本合同引起的争议由甲方所在地人民法院管辖"
- "争议提交北京仲裁委员会仲裁"
**风险评估要点**:
- 管辖地是否对双方公平
- 仲裁vs诉讼的优劣比较
- 仲裁机构的选择是否合理
**建议对策**:
- 选择双方均可接受的中立管辖地
- 优先考虑仲裁(效率更高、一裁终局)
- 选择专业、知名的仲裁机构
---
## 7. 保密条款
**风险等级**: 中
**说明**: 保密义务的范围过广或期限过长,可能过度约束信息使用。
**常见关键词**: 保密义务、保密期限、保密信息、商业秘密、不得泄露
**典型问题条款**:
- "双方交换的一切信息均为保密信息,保密期限为永久"
- "乙方对甲方提供的所有资料承担永久保密义务"
**风险评估要点**:
- 保密信息的定义是否明确
- 保密期限是否合理
- 是否约定了保密信息的例外情形
**建议对策**:
- 明确保密信息的范围和标记方式
- 约定合理的保密期限(通常2-5年)
- 列明保密义务的例外情形(公开信息、法律要求等)
---
## 8. 付款条件
**风险等级**: 中
**说明**: 付款条件不明确或对一方不利,可能导致资金风险。
**常见关键词**: 付款条件、付款方式、付款期限、结算方式、账期
**典型问题条款**:
- "甲方在收到发票后90个工作日内付款"
- "付款以甲方最终验收合格为前提"
**风险评估要点**:
- 付款时间节点是否明确
- 账期是否合理
- 付款条件是否设置了过多前置条件
**建议对策**:
- 明确付款时间节点和金额
- 设置合理的账期(建议不超过30天)
- 约定逾期付款的违约责任
---
## 9. 验收标准
**风险等级**: 中
**说明**: 验收标准模糊或验收权完全由一方掌控,可能导致验收争议。
**常见关键词**: 验收标准、验收条件、交付标准、视为验收合格、默认验收
**典型问题条款**:
- "甲方对交付成果有最终验收决定权"
- "乙方交付后7天内甲方未提出异议视为验收合格"
**风险评估要点**:
- 验收标准是否客观、可量化
- 验收流程和时限是否明确
- 异议处理机制是否完善
**建议对策**:
- 制定详细、可量化的验收标准
- 明确验收期限和反馈流程
- 约定第三方鉴定机制处理争议
---
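## 10. Liquidated Damages (违约金)
**Risk level**: High
**Description**: The liquidated damages clause is not reciprocal, or the amount or ratio is too high.
**Common keywords**: 违约金、违约责任、违约赔偿、逾期违约、日万分之
**Typical problematic clauses**:
- "乙方违约应支付合同总金额50%的违约金" (on breach, Party B pays liquidated damages of 50% of the contract total)
- "逾期每日按合同总金额万分之五计算违约金" (for each overdue day, liquidated damages accrue at 0.05% of the contract total)
**Assessment points**:
- Whether the liquidated damages ratio is reasonable (the commonly cited legal benchmark is no more than 30%)
- Whether the liquidated damages clauses are reciprocal for both parties
- Whether both liquidated damages and compensation for losses are stipulated (this may lead to double counting)
**Recommended measures**:
- Keep the liquidated damages ratio within a reasonable range
- Make sure both parties carry equivalent liability for breach
- Avoid claiming liquidated damages and compensation for the same loss twice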
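
A daily rate such as 万分之五 looks small, but it accumulates. The following sketch works through the arithmetic under two assumptions: the penalty accrues linearly on the contract total, and the 30% figure mentioned above is used as the cap. The function names are hypothetical.

```python
import math
from fractions import Fraction


def days_to_cap(daily_rate: Fraction, cap: Fraction) -> int:
    """Overdue days after which the accrued penalty ratio reaches the cap."""
    return math.ceil(cap / daily_rate)


def accrued_penalty(contract_total: int, daily_rate: Fraction, days: int, cap: Fraction) -> float:
    """Penalty in yuan after `days` overdue days, limited by the cap ratio."""
    ratio = min(daily_rate * days, cap)
    return float(contract_total * ratio)


rate = Fraction(5, 10_000)  # 万分之五 per day
cap = Fraction(30, 100)     # the 30% figure used in this checklist

print(days_to_cap(rate, cap))                      # days until the cap is reached
print(accrued_penalty(1_000_000, rate, 100, cap))  # penalty after 100 overdue days
```

At 0.05% per day the accrued penalty reaches 30% of the contract total after 600 days, so a clause without an explicit cap deserves attention on long-running contracts.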
---
## 11. 担保条款
**风险等级**: 高
**说明**: 担保范围过广、方式不当,可能带来超出预期的担保责任。
**常见关键词**: 担保、保证金、质押、抵押、连带责任、无限连带
**典型问题条款**:
- "第三方对乙方的全部义务承担无限连带担保责任"
- "担保范围包括主债权、利息、违约金、实现债权的费用等"
**风险评估要点**:
- 担保方式是否合理(一般保证vs连带保证)
- 担保范围是否明确且可控
- 担保期限是否合理
**建议对策**:
- 优先选择一般保证而非连带保证
- 明确担保金额上限
- 约定合理的担保期限
---
## 12. 不可抗力
**风险等级**: 低
**说明**: 不可抗力条款定义过窄(可能遗漏重要情形)或免责范围过广。
**常见关键词**: 不可抗力、自然灾害、政府行为、战争、疫情
**典型问题条款**:
- "不可抗力事件发生后,受影响方免除全部责任"
- "不可抗力仅限于自然灾害"(定义过窄)
**风险评估要点**:
- 不可抗力的定义是否全面
- 通知义务和减损义务是否明确
- 不可抗力的法律后果是否合理
**建议对策**:
- 扩展不可抗力的定义范围(包括疫情、政策变化等)
- 约定发生不可抗力后的通知时限和方式
- 明确减损义务和部分履行的安排
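Content Engine: a cross-platform content creation and distribution tool with self-learning optimization, Obsidian integration, and AI image prompts
---
name: content-engine
description: Content Engine, a cross-platform content creation and distribution tool with self-learning optimization, Obsidian integration, and AI image prompts
version: 1.1.0
metadata:
  openclaw:
    optional_env:
      - CE_TWITTER_BEARER_TOKEN
      - CE_LINKEDIN_ACCESS_TOKEN
      - CE_WECHAT_APPID
      - CE_WECHAT_SECRET
      - CE_MEDIUM_TOKEN
      - CE_BLOG_TYPE
      - CE_BLOG_PATH
      - CE_SUBSCRIPTION_TIER
      - CE_OBSIDIAN_VAULT_PATH
---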
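# Content Engine (content-engine)
You are a professional cross-platform content operations Agent. Your job is to help users create, manage, adapt, and distribute content to Twitter, LinkedIn, WeChat Official Accounts, blogs (Hugo/Jekyll/Hexo), Medium, and similar platforms. You always communicate with the user in Chinese.
This Skill is the first content distribution tool on ClawHub to support WeChat Official Account integration.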
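## Environment Variables
| Variable | Required | Description |
|------|------|------|
| `CE_TWITTER_BEARER_TOKEN` | No | Twitter API v2 Bearer Token |
| `CE_LINKEDIN_ACCESS_TOKEN` | No | LinkedIn API Access Token |
| `CE_WECHAT_APPID` | No | WeChat Official Account AppID |
| `CE_WECHAT_SECRET` | No | WeChat Official Account AppSecret |
| `CE_MEDIUM_TOKEN` | No | Medium Integration Token |
| `CE_BLOG_TYPE` | No | Blog engine type (hugo / jekyll / hexo), default hugo |
| `CE_BLOG_PATH` | No | Root directory of the blog project |
| `CE_SUBSCRIPTION_TIER` | No | Subscription tier, default `free`; `paid` is also accepted |
| `CE_OBSIDIAN_VAULT_PATH` | No | Path to the Obsidian vault, used for draft import and export |
On startup, check which platform tokens are configured. If the user tries to publish to a platform whose token is not configured, prompt them to set the corresponding environment variable.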
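
The startup check could look roughly like the sketch below. The variable names come from the table; the helper `missing_env_for`, the requirement that WeChat needs both AppID and AppSecret, and treating the blog as configured when `CE_BLOG_PATH` is set are assumptions made for illustration, not behavior confirmed by the Skill's scripts.

```python
import os
from typing import Dict, List, Mapping

# Platform -> environment variables it needs (names taken from the table).
# The grouping itself is an assumption of this sketch.
PLATFORM_ENV: Dict[str, List[str]] = {
    "twitter": ["CE_TWITTER_BEARER_TOKEN"],
    "linkedin": ["CE_LINKEDIN_ACCESS_TOKEN"],
    "wechat": ["CE_WECHAT_APPID", "CE_WECHAT_SECRET"],
    "medium": ["CE_MEDIUM_TOKEN"],
    "blog": ["CE_BLOG_PATH"],
}


def missing_env_for(platform: str, env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the variables that are unset or empty for the given platform."""
    if platform not in PLATFORM_ENV:
        raise ValueError(f"unknown platform: {platform!r}")
    return [name for name in PLATFORM_ENV[platform] if not env.get(name, "").strip()]


# Example: report what is still missing before attempting to publish.
for platform in PLATFORM_ENV:
    missing = missing_env_for(platform)
    if missing:
        print(f"{platform}: please set {', '.join(missing)}")
```

Passing the environment as a parameter keeps the check easy to exercise with a plain dictionary.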
---
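## Flow 1: Content Management
When the user says "创建内容" (create content), "管理文章" (manage articles), "编辑内容" (edit content), or expresses a similar intent, perform the following operations:
### Create content
Guide the user to provide the following information:
- **Title** (required)
- **Body** (required, Markdown supported)
- **Summary** (optional, used for the WeChat Official Account description and social media previews)
- **Tags** (optional, used to generate hashtags and categories)
- **Target platforms** (optional; supported: twitter, linkedin, wechat, blog, medium)
- **Author** (optional)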
```bash
python3 scripts/content_store.py --action create --data '{"title":"...", "body":"...", "tags":["..."], "platforms":["twitter","wechat"]}'
```
### 管理内容
```bash
# 查看内容列表
python3 scripts/content_store.py --action list
# 按状态过滤
python3 scripts/content_store.py --action list --data '{"status":"草稿"}'
# 获取内容详情
python3 scripts/content_store.py --action get --data '{"id":"CT..."}'
# 更新内容
python3 scripts/content_store.py --action update --data '{"id":"CT...", "title":"新标题"}'
# 删除内容
python3 scripts/content_store.py --action delete --data '{"id":"CT..."}'
```
### 导入导出
```bash
# 从 Markdown 文件导入(支持 YAML frontmatter)
python3 scripts/content_store.py --action import --data '{"file_path":"./article.md"}'
# 导出为 Markdown
python3 scripts/content_store.py --action export --data '{"id":"CT...", "file_path":"./output.md"}'
```
### 内容状态流转
```
草稿 → 待审核 → 已排期 → 已发布 → 已归档
```
每个状态只能按规则流转,不可跳跃。
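
One way to read the "no skipping" rule is that a status may only advance to the one that immediately follows it. The sketch below encodes that reading. It is an assumption: the real rules live in `validate_status_transition` in `utils.py`, which is not shown here and may allow other moves (for example, returning a rejected item to draft). `STATUS_FLOW` and `can_transition` are hypothetical names.

```python
from typing import List

# Order taken from the flow diagram above.
STATUS_FLOW: List[str] = ["草稿", "待审核", "已排期", "已发布", "已归档"]


def can_transition(current: str, target: str) -> bool:
    """Allow a move only to the status that immediately follows `current`."""
    if current not in STATUS_FLOW or target not in STATUS_FLOW:
        return False
    return STATUS_FLOW.index(target) == STATUS_FLOW.index(current) + 1
```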
---
## 流程二:适配与发布
当用户说"发布到 Twitter"、"适配微信"、"分发内容"或类似意图时,执行以下步骤:
### 步骤 1:内容适配
将通用内容转换为目标平台格式:
```bash
# 适配到单个平台
python3 scripts/platform_adapter.py --action adapt --data '{"id":"CT...", "platform":"twitter"}'
# 预览适配效果
python3 scripts/platform_adapter.py --action preview --data '{"id":"CT...", "platform":"wechat"}'
# 校验内容是否符合平台要求
python3 scripts/platform_adapter.py --action validate --data '{"id":"CT...", "platform":"linkedin"}'
# 批量适配到多个平台(付费版)
python3 scripts/platform_adapter.py --action batch-adapt --data '{"id":"CT..."}'
```
各平台适配规则参考 `references/platform-specs.md`。
### 步骤 2:发布(付费版)
```bash
# 发布到指定平台
python3 scripts/publisher.py --action publish --data '{"id":"CT...", "platform":"twitter"}'
# 定时发布
python3 scripts/publisher.py --action schedule --data '{"id":"CT...", "platform":"wechat", "scheduled_at":"2026-03-20T18:00:00"}'
# 查看发布历史
python3 scripts/publisher.py --action list-published
# 撤回内容(标记归档)
python3 scripts/publisher.py --action unpublish --data '{"id":"CT..."}'
```
微信公众号发布流程参考 `references/wechat-guide.md`。
---
## 流程三:数据指标(付费版)
当用户说"查看数据"、"内容表现"、"指标报告"或类似意图时:
```bash
# 采集指标
python3 scripts/metrics_collector.py --action collect --data '{"content_id":"CT..."}'
# 生成报告
python3 scripts/metrics_collector.py --action report --data '{"content_id":"CT..."}'
# 对比多条内容
python3 scripts/metrics_collector.py --action compare --data '{"content_ids":["CT1","CT2"]}'
# 查看热门内容趋势(含 Mermaid 图表)
python3 scripts/metrics_collector.py --action trending
```
各平台采集的指标:
- **Twitter**: 点赞、转发、回复、曝光
- **LinkedIn**: 点赞、评论、分享、浏览
- **微信公众号**: 阅读、分享、收藏
- **Medium**: 阅读、鼓掌、回应
---
## 流程四:内容日历(付费版)
当用户说"规划日历"、"排期管理"、"发布计划"或类似意图时:
```bash
# 添加发布计划
python3 scripts/calendar_manager.py --action plan --data '{"content_id":"CT...", "platform":"twitter", "date":"2026-03-20"}'
# 查看周日历
python3 scripts/calendar_manager.py --action view --data '{"view":"week"}'
# 查看月日历
python3 scripts/calendar_manager.py --action view --data '{"view":"month"}'
# 获取最佳发布时间建议
python3 scripts/calendar_manager.py --action suggest --data '{"platform":"wechat", "date":"2026-03-20"}'
# 导出日历(Markdown 或 CSV)
python3 scripts/calendar_manager.py --action export --data '{"format":"markdown", "file_path":"./calendar.md"}'
```
付费版日历导出包含 Mermaid Gantt 时间线图。
---
## 流程五:自学习内容智能
当用户说"分析内容表现"、"推荐话题"、"什么时候发布最好"或类似意图时:
```bash
# 记录内容表现数据
python3 scripts/learning_engine.py --action record-performance --data '{"content_id":"CT...", "platform":"twitter", "metrics":{"likes":128,"retweets":45}, "topic":"AI编程", "posting_time":"2026-03-19T10:00:00"}'
# 记录用户偏好
python3 scripts/learning_engine.py --action record-preference --data '{"add_topic":"AI Agent", "add_style":"技术深度"}'
# 分析历史表现(哪些话题/时间/格式效果最好)
python3 scripts/learning_engine.py --action analyze
# 推荐下一个内容话题
python3 scripts/learning_engine.py --action suggest-topic --data '{"platform":"linkedin", "count":5}'
# 推荐最佳发布时间
python3 scripts/learning_engine.py --action suggest-timing --data '{"platform":"twitter"}'
# 查看内容表现统计面板
python3 scripts/learning_engine.py --action stats
```
学习引擎会自动:
- 在指标采集后记录到学习数据库
- 根据历史表现计算互动得分和互动率
- 识别高表现话题、时段和格式
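
The source does not define how the engagement rate is computed. A common definition, assumed here purely for illustration, is total interactions divided by impressions. The function name and the `impressions` key are hypothetical and may not match what `learning_engine.py` stores.

```python
from typing import Mapping


def engagement_rate(metrics: Mapping[str, int], impressions_key: str = "impressions") -> float:
    """Interactions divided by impressions; 0.0 when impressions are missing or zero."""
    impressions = metrics.get(impressions_key, 0)
    if impressions <= 0:
        return 0.0
    # Every metric other than impressions is treated as an interaction.
    interactions = sum(v for k, v in metrics.items() if k != impressions_key)
    return interactions / impressions
```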
---
## 流程六:Obsidian 草稿工作流
当用户说"从 Obsidian 导入"、"笔记转内容"、"同步笔记"或类似意图时:
```bash
# 连接到 Obsidian 笔记库
python3 scripts/obsidian_sync.py --action connect --data '{"vault_path":"~/MyVault"}'
# 列出标记为草稿的笔记(带 #content 或 #draft 标签)
python3 scripts/obsidian_sync.py --action list-drafts
# 导入一篇笔记为内容草稿
python3 scripts/obsidian_sync.py --action import-draft --data '{"file":"drafts/my-article.md"}'
# 导出内容回 Obsidian 笔记库
python3 scripts/obsidian_sync.py --action export-draft --data '{"title":"...", "body":"...", "ce_id":"CT...", "ce_status":"已发布"}'
# 双向同步(检测新草稿和已变更文件)
python3 scripts/obsidian_sync.py --action sync
# 也可通过 content_store 直接导入 Obsidian 笔记
python3 scripts/content_store.py --action import-obsidian --data '{"file":"drafts/my-article.md"}'
```
Obsidian 格式支持:
- `[[wikilinks]]` 自动转换为纯文本或 Markdown 链接
- `#tags` 提取为内容标签
- YAML frontmatter 解析为内容元数据
---
## 流程七:AI 配图助手
当用户说"生成配图"、"图片建议"、"配图提示词"或类似意图时:
```bash
# 根据内容生成 AI 图片提示词(Midjourney/DALL-E/SD 风格)
python3 scripts/image_prompter.py --action generate-prompt --data '{"text":"文章内容...", "title":"文章标题", "platform":"twitter"}'
# 分析内容,建议配图位置和内容
python3 scripts/image_prompter.py --action suggest-images --data '{"text":"完整文章...", "title":"标题"}'
# 生成 SEO 友好的 alt text
python3 scripts/image_prompter.py --action format-alt-text --data '{"description":"图片描述", "keywords":["AI","编程"]}'
# 创建完整的视觉内容规划(hero、章节、缩略图)
python3 scripts/image_prompter.py --action image-plan --data '{"text":"文章正文...", "title":"标题", "platforms":["blog","twitter","wechat"]}'
```
各平台推荐图片尺寸:
- **Twitter Card**: 1200x675 (16:9)
- **LinkedIn Post**: 1200x627 (1.91:1)
- **微信封面**: 900x383 (2.35:1)
- **博客 Hero**: 1200x630 (1.91:1)
- **Medium Feature**: 1400x788 (16:9)
提示词同时生成中英文版本,适配全球化内容需求。
---
## 订阅校验逻辑
### 读取订阅等级
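```python
import os

# Read the subscription tier; defaults to "free" when the variable is unset.
tier = os.environ.get("CE_SUBSCRIPTION_TIER", "free")
```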
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥99/月) |
|------|---------------|----------------------|
| 内容数量上限 | 20 条 | 500 条 |
| 平台数量上限 | 2 个 | 5 个(全部) |
| 内容创建/编辑/删除 | 支持 | 支持 |
| 基础适配(预览) | 支持 | 支持 |
| 手动发布 | 支持 | 支持 |
| Markdown 导入导出 | 支持 | 支持 |
| AI 配图提示词生成 | 支持 | 支持 |
| Obsidian 笔记导入 | 支持 | 支持 |
| 自动发布到平台 | 不支持 | 支持 |
| 微信公众号 | 不支持 | 支持 |
| 批量适配 | 不支持 | 支持 |
| 定时发布 | 不支持 | 支持 |
| 数据指标采集 | 不支持 | 支持 |
| 自学习内容智能 | 不支持 | 支持 |
| 学习洞察分析 | 不支持 | 支持 |
| Obsidian 双向同步 | 不支持 | 支持 |
| 内容日历 | 不支持 | 支持 |
| Mermaid 图表 | 不支持 | 支持 |
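
The matrix above can be expressed as a small lookup, sketched below. The numeric limits are copied from the table, and the `max_content` / `max_platforms` keys mirror those read from `check_subscription()` in `content_store.py`. Everything else is an assumption made for illustration: the feature keys in `PAID_ONLY` are hypothetical labels for the paid-only rows, and falling back to `free` for an unrecognized tier value is a guess, not documented behavior.

```python
import os
from typing import Dict, Mapping, Optional

# Limits copied from the matrix above.
LIMITS: Dict[str, Dict[str, int]] = {
    "free": {"max_content": 20, "max_platforms": 2},
    "paid": {"max_content": 500, "max_platforms": 5},
}

# Hypothetical feature keys for the rows marked paid-only.
PAID_ONLY = {
    "auto_publish", "wechat", "batch_adapt", "schedule", "metrics",
    "learning", "insights", "obsidian_sync", "calendar", "mermaid",
}


def current_tier(env: Optional[Mapping[str, str]] = None) -> str:
    """Read CE_SUBSCRIPTION_TIER; unknown or missing values fall back to "free"."""
    env = os.environ if env is None else env
    tier = env.get("CE_SUBSCRIPTION_TIER", "free").strip().lower()
    return tier if tier in LIMITS else "free"


def is_allowed(feature: str, tier: str) -> bool:
    """Paid users may use everything; free users may use anything not paid-only."""
    return tier == "paid" or feature not in PAID_ONLY
```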
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥99/月),请联系管理员或访问订阅管理页面。"
4. 不要直接拒绝,而是提供免费版可用的替代方案(如果有的话)。
---
## 安全规范
1. **Token 保护**:所有平台 API Token 仅通过环境变量传递,绝不在对话中显示、记录或输出。
2. **HTML 安全**:微信公众号文章 HTML 经过清理,移除 script、iframe 等危险标签和 on* 事件属性。
3. **内容安全**:发布前校验内容格式,防止意外发布不完整或格式错误的内容。
4. **错误处理**:API 调用失败时,向用户展示友好的错误提示,不暴露内部路径或 Token 信息。
5. **数据安全**:所有内容数据存储在本地,不上传到云端。
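
The HTML cleanup in rule 2 could look roughly like the sketch below, which uses only the standard library. It is not the skill's implementation: `sanitize_html` and `BLOCKED_TAGS` are hypothetical names, the blocked set lists only the two tags the rule names, and the sketch does not cover other vectors such as `javascript:` URLs. A production sanitizer would more likely use an allow-list.

```python
from html import escape
from html.parser import HTMLParser
from typing import List, Optional, Tuple

BLOCKED_TAGS = {"script", "iframe"}  # extend as needed


class _Sanitizer(HTMLParser):
    def __init__(self) -> None:
        super().__init__(convert_charrefs=True)
        self.parts: List[str] = []
        self._blocked_depth = 0  # > 0 while inside a blocked element

    def _render(self, tag: str, attrs: List[Tuple[str, Optional[str]]], closing: str) -> None:
        # Drop on* event-handler attributes and re-escape the remaining values.
        kept = [(k, v) for k, v in attrs if not k.lower().startswith("on")]
        rendered = "".join(
            f" {k}" if v is None else f' {k}="{escape(v, quote=True)}"' for k, v in kept
        )
        self.parts.append(f"<{tag}{rendered}{closing}>")

    def handle_starttag(self, tag, attrs):
        if tag in BLOCKED_TAGS:
            self._blocked_depth += 1
        elif not self._blocked_depth:
            self._render(tag, attrs, "")

    def handle_startendtag(self, tag, attrs):
        if tag not in BLOCKED_TAGS and not self._blocked_depth:
            self._render(tag, attrs, " /")

    def handle_endtag(self, tag):
        if tag in BLOCKED_TAGS:
            self._blocked_depth = max(0, self._blocked_depth - 1)
        elif not self._blocked_depth:
            self.parts.append(f"</{tag}>")

    def handle_data(self, data):
        if not self._blocked_depth:
            self.parts.append(escape(data, quote=False))


def sanitize_html(html: str) -> str:
    """Return `html` without blocked elements (and their content) or on* attributes."""
    parser = _Sanitizer()
    parser.feed(html)
    parser.close()
    return "".join(parser.parts)
```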
---
## 参考文档
在进行平台适配和发布时,请参考以下文档:
- **平台规格**:`references/platform-specs.md` — 各平台的字符限制、图片要求和格式规则。
- **微信指南**:`references/wechat-guide.md` — 微信公众号 API 配置和使用指南。
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 在发布内容前,先向用户展示适配后的预览并获得确认。
3. 对用户的内容给出改进建议,帮助提升各平台的传播效果。
4. 主动提醒不同平台的最佳实践(如 Twitter 的 hashtag 策略、LinkedIn 的专业语气)。
5. 遇到模糊的用户意图时,主动追问以明确需求。
6. 发布出错时,耐心排查并给出可行的解决方案。
7. 尊重订阅等级限制,在提示升级时保持友好,不反复推销。
8. 涉及微信公众号操作时,提醒用户参考 `references/wechat-guide.md` 完成前置配置。
FILE:assets/README.md
# Content Engine / 内容引擎
> Cross-platform content creation and distribution tool — the FIRST WeChat Official Account integration on ClawHub!
>
> 跨平台内容创作与分发工具 — ClawHub 首个支持微信公众号集成的内容引擎!
---
## Features / 功能亮点
- **Multi-platform Distribution / 多平台分发** — Write once, publish everywhere: Twitter, LinkedIn, WeChat Official Account, Blog (Hugo/Jekyll/Hexo), Medium
- **WeChat OA Integration / 微信公众号集成** — The FIRST skill on ClawHub to support WeChat Official Account publishing, including article formatting, author cards, and rich text
- **Smart Adaptation / 智能适配** — Automatically adapts content to each platform's format, character limits, and best practices
- **Self-Learning Engine / 自学习引擎** — Learn from content performance to continuously improve: topic suggestions, optimal posting times, engagement analysis
- **Obsidian Integration / Obsidian 集成** — Import drafts directly from your Obsidian vault, bidirectional sync, wikilinks and tags conversion
- **AI Image Prompter / AI 配图助手** — Generate AI image prompts (Midjourney/DALL-E/SD style) for your content, with platform-specific sizing and SEO alt text
- **Content Calendar / 内容日历** — Plan, schedule, and visualize your content pipeline with Mermaid Gantt charts
- **Performance Metrics / 数据指标** — Collect and compare engagement metrics across all platforms with learning-based insights
- **Markdown Native / Markdown 原生** — Create content in Markdown, import/export with YAML frontmatter support
- **Local Data / 本地数据** — All data stored locally, nothing leaves your environment
---
## Version Comparison / 版本对比
| Feature / 功能 | Free / 免费版 | Paid / 付费版 ¥99/月 |
|------|:------:|:------------:|
| Content limit / 内容上限 | 20 | 500 |
| Platforms / 平台数 | 2 | 5 (all / 全部) |
| Create & Edit / 创建编辑 | ✅ | ✅ |
| Basic Adaptation / 基础适配 | ✅ | ✅ |
| Markdown Import/Export / 导入导出 | ✅ | ✅ |
| AI Image Prompts / AI 配图提示词 | ✅ | ✅ |
| Obsidian Import / Obsidian 导入 | ✅ | ✅ |
| Auto Publish / 自动发布 | ❌ | ✅ |
| WeChat OA / 微信公众号 | ❌ | ✅ |
| Batch Adapt / 批量适配 | ❌ | ✅ |
| Scheduled Publish / 定时发布 | ❌ | ✅ |
| Metrics Collection / 指标采集 | ❌ | ✅ |
| Self-Learning Engine / 自学习引擎 | ❌ | ✅ |
| Learning Insights / 学习洞察 | ❌ | ✅ |
| Obsidian Bidirectional Sync / 双向同步 | ❌ | ✅ |
| Content Calendar / 内容日历 | ❌ | ✅ |
| Mermaid Charts / 可视化图表 | ❌ | ✅ |
---
## Quick Start / 快速开始
### 1. Install / 安装
Search `content-engine` on ClawHub and install, or use CLI:
在 ClawHub 搜索 `content-engine` 并安装,或使用命令行:
```bash
openclaw skill install content-engine
```
### 2. Configure Platforms / 配置平台
Set environment variables for the platforms you want to use:
配置你要使用的平台环境变量:
```bash
# Twitter
export CE_TWITTER_BEARER_TOKEN="your-bearer-token"
# LinkedIn
export CE_LINKEDIN_ACCESS_TOKEN="your-access-token"
# WeChat Official Account / 微信公众号
export CE_WECHAT_APPID="your-appid"
export CE_WECHAT_SECRET="your-appsecret"
# Medium
export CE_MEDIUM_TOKEN="your-integration-token"
# Blog (Hugo/Jekyll/Hexo)
export CE_BLOG_TYPE="hugo"
export CE_BLOG_PATH="/path/to/your/blog"
# Obsidian Vault / Obsidian 笔记库
export CE_OBSIDIAN_VAULT_PATH="~/MyVault"
# Subscription / 订阅等级
export CE_SUBSCRIPTION_TIER="paid"
```
### 3. Create & Publish / 创建并发布
```bash
# Create content / 创建内容
/content-engine create "My Article Title" --platforms twitter,linkedin,wechat
# Adapt to platforms / 适配到各平台
/content-engine adapt CT... --platform twitter
# Preview before publishing / 发布前预览
/content-engine preview CT... --platform wechat
# Publish / 发布
/content-engine publish CT... --platform twitter
```
### 4. Track Performance / 追踪表现
```bash
# Collect metrics / 采集指标
/content-engine metrics CT...
# View content calendar / 查看内容日历
/content-engine calendar week
```
---
## Example / 使用示例
### Content Creation Workflow / 内容创作工作流
```
User: 帮我创建一篇关于 AI 编程助手的文章,发布到 Twitter 和微信公众号
Agent: 好的,我来帮你创建内容...
[创建内容,生成 Twitter thread 和微信公众号 HTML 文章]
[展示各平台预览]
[获得确认后发布]
User: 查看这篇文章在各平台的表现
Agent: 正在采集指标数据...
Twitter: 128 点赞, 45 转发, 12 回复, 5,230 曝光
微信公众号: 2,340 阅读, 89 分享, 156 收藏
[生成对比图表]
```
---
## Supported Platforms / 支持的平台
| Platform / 平台 | Format / 格式 | Key Feature / 核心特性 |
|------|------|------|
| Twitter / X | Thread (280 chars/tweet) | Auto-split, hashtags, CJK char counting |
| LinkedIn | Professional post (3000 chars) | Professional tone, structured format |
| WeChat OA / 微信公众号 | HTML article | Rich text, author card, image refs |
| Blog | Markdown + frontmatter | Hugo, Jekyll, Hexo support |
| Medium | Markdown | Medium-compatible format, 5 tags |
---
## FAQ / 常见问题
### Q1: Is this really the first WeChat integration on ClawHub? / 这真的是 ClawHub 首个微信集成吗?
Yes! content-engine is the first OpenClaw skill to support WeChat Official Account API integration, enabling direct article publishing from the CLI.
是的!content-engine 是首个支持微信公众号 API 集成的 OpenClaw Skill,可以直接从命令行发布文章到公众号。
### Q2: Do I need all platform tokens configured? / 需要配置所有平台的 Token 吗?
No. Only configure the platforms you plan to use. Content creation and adaptation work without any tokens — tokens are only needed for publishing.
不需要。只配置你计划使用的平台即可。内容创建和适配不需要任何 Token — Token 仅在发布时需要。
### Q3: Is my data uploaded to the cloud? / 数据会上传到云端吗?
No. All content data is stored locally in `~/.openclaw-bdi/content-engine/`. API calls are made directly from your machine to each platform.
不会。所有内容数据存储在本地 `~/.openclaw-bdi/content-engine/`。API 调用从你的机器直接发送到各平台。
### Q4: How does the Chinese character counting work for Twitter? / Twitter 的中文字符计数怎么算?
Twitter counts CJK characters as 2 character positions. content-engine automatically handles this when splitting content into threads.
Twitter 将中日韩字符计为 2 个字符位。content-engine 在拆分 thread 时会自动处理这个规则。
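
A simplified sketch of weighted counting and thread splitting is shown below. The single-weight code point ranges follow the publicly documented twitter-text weighting and are an assumption here; the sketch ignores URL shortening, emoji sequences, Unicode normalization, and sentence boundaries, all of which a real splitter would need to handle. `weighted_length` and `split_thread` are hypothetical names, not the functions in `platform_adapter.py`.

```python
from typing import List

# Code points in these ranges count as one unit; everything else counts as two.
_SINGLE_WEIGHT_RANGES = [(0, 4351), (8192, 8205), (8208, 8223), (8242, 8247)]
TWEET_LIMIT = 280


def weighted_length(text: str) -> int:
    """Sum per-character weights: 1 for the ranges above, 2 otherwise (e.g. CJK)."""
    total = 0
    for ch in text:
        cp = ord(ch)
        single = any(lo <= cp <= hi for lo, hi in _SINGLE_WEIGHT_RANGES)
        total += 1 if single else 2
    return total


def split_thread(text: str, limit: int = TWEET_LIMIT) -> List[str]:
    """Greedily cut `text` into pieces whose weighted length stays within `limit`."""
    tweets: List[str] = []
    current: List[str] = []
    used = 0
    for ch in text:
        weight = weighted_length(ch)
        if current and used + weight > limit:
            tweets.append("".join(current))
            current, used = [], 0
        current.append(ch)
        used += weight
    if current:
        tweets.append("".join(current))
    return tweets
```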
### Q5: Can I use the free version for WeChat? / 免费版能用微信公众号吗?
WeChat Official Account integration is a paid feature. Free users can create and adapt content, but publishing to WeChat requires the paid tier (¥99/month).
微信公众号集成是付费功能。免费用户可以创建和适配内容,但发布到微信需要付费版(¥99/月)。
### Q6: What blog engines are supported? / 支持哪些博客引擎?
Hugo, Jekyll, and Hexo. Set `CE_BLOG_TYPE` to your engine type, and `CE_BLOG_PATH` to your blog project root. content-engine generates properly formatted Markdown with the correct frontmatter.
支持 Hugo、Jekyll 和 Hexo。设置 `CE_BLOG_TYPE` 为你的引擎类型,`CE_BLOG_PATH` 为博客项目根目录。content-engine 会生成正确格式的 Markdown 和 frontmatter。
---
## Technical Support / 技术支持
- **Docs / 文档**: See `references/` directory for platform specs and WeChat guide
- **Issues / 问题反馈**: Submit on ClawHub skill page
- **Community / 社区讨论**: Join `#content-engine` channel on ClawHub
- **Email / 邮件**: [email protected]
---
*content-engine v1.1.0 | Compatible with OpenClaw 0.5+ | First WeChat OA integration on ClawHub*
FILE:scripts/obsidian_sync.py
#!/usr/bin/env python3
"""
content-engine Obsidian vault integration module.

Connects to an Obsidian vault, imports and exports drafts, and detects
changes on the vault side, covering the workflow from notes to published content.
"""

import json
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from utils import (
    generate_id,
    get_data_file,
    load_input_data,
    now_iso,
    output_error,
    output_success,
    parse_common_args,
    read_json_file,
    write_json_file,
)
# ============================================================
# Constants and configuration
# ============================================================

SYNC_STATE_FILE = "obsidian_sync_state.json"

# Tags that mark an Obsidian note as a content draft
DRAFT_TAGS = {"#content", "#draft", "#内容", "#草稿"}

# Frontmatter keys recognized in Obsidian notes, mapped to content fields
FRONTMATTER_FIELD_MAP = {
    "title": "title",
    "标题": "title",
    "tags": "tags",
    "标签": "tags",
    "platforms": "platforms",
    "平台": "platforms",
    "author": "author",
    "作者": "author",
    "summary": "summary",
    "description": "summary",
    "摘要": "summary",
    "status": "status",
    "状态": "status",
    "ce_id": "ce_id",
    "ce_status": "ce_status",
    "ce_published_at": "ce_published_at",
}
# ============================================================
# Vault path management
# ============================================================

def _get_vault_path() -> Optional[str]:
    """Return the Obsidian vault path from CE_OBSIDIAN_VAULT_PATH.

    Returns:
        The expanded vault path, or None when the variable is unset or
        does not point to an existing directory.
    """
    vault = os.environ.get("CE_OBSIDIAN_VAULT_PATH", "")
    if not vault:
        return None
    vault = os.path.expanduser(vault)
    if not os.path.isdir(vault):
        return None
    return vault


def _get_sync_state() -> Dict[str, Any]:
    """Load the sync state, filling in defaults for missing keys."""
    filepath = get_data_file(SYNC_STATE_FILE)
    data = read_json_file(filepath)
    # read_json_file may return a list (or anything else) for a fresh file
    if not isinstance(data, dict):
        data = {}
    data.setdefault("vault_path", "")
    data.setdefault("synced_files", {})  # {relative path: {ce_id, last_sync, hash}}
    data.setdefault("last_sync_at", "")
    return data


def _save_sync_state(state: Dict[str, Any]) -> None:
    """Persist the sync state."""
    write_json_file(get_data_file(SYNC_STATE_FILE), state)
# ============================================================
# Obsidian format parsing
# ============================================================

def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
    """Parse the YAML frontmatter of an Obsidian note.

    Args:
        content: Raw note content.

    Returns:
        A (metadata dict, body text) tuple.
    """
    metadata: Dict[str, Any] = {}
    body = content
    match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", content, re.DOTALL)
    if match:
        fm_str = match.group(1)
        body = match.group(2).strip()
        # Minimal YAML parsing (avoids a pyyaml dependency)
        current_key = ""
        current_list: Optional[List[str]] = None
        for line in fm_str.split("\n"):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            # List item ("- value") belonging to the current key
            if stripped.startswith("- ") and current_key and current_list is not None:
                val = stripped[2:].strip().strip("\"'")
                current_list.append(val)
                continue
            # key: value pair
            if ":" in stripped:
                idx = stripped.index(":")
                raw_key = stripped[:idx].strip().lower()
                raw_val = stripped[idx + 1:].strip()
                # Map the key to its content field name
                mapped_key = FRONTMATTER_FIELD_MAP.get(raw_key, raw_key)
                if raw_val.startswith("[") and raw_val.endswith("]"):
                    # Inline array
                    items = [v.strip().strip("\"'") for v in raw_val[1:-1].split(",") if v.strip()]
                    metadata[mapped_key] = items
                    current_key = ""
                    current_list = None
                elif not raw_val:
                    # Possibly the start of a multi-line list
                    current_key = mapped_key
                    current_list = []
                    metadata[mapped_key] = current_list
                else:
                    metadata[mapped_key] = raw_val.strip("\"'")
                    current_key = ""
                    current_list = None
    return metadata, body
def _convert_wikilinks(text: str, mode: str = "plain") -> str:
    """Convert Obsidian [[wikilinks]] to the target format.

    Args:
        text: Text containing wikilinks.
        mode: Conversion mode:
            - "plain": convert to plain text
            - "markdown": convert to a Markdown link (the note name is used
              as the link target; no real URL is resolved)

    Returns:
        The converted text.
    """
    # [[target|display text]] form (Obsidian puts the target first)
    text = re.sub(r"\[\[([^|\]]+)\|([^\]]+)\]\]", _wikilink_replace(mode), text)
    # [[target]] form
    text = re.sub(r"\[\[([^\]]+)\]\]", _simple_wikilink_replace(mode), text)
    return text


def _wikilink_replace(mode: str):
    """Return the replacement function for [[target|display]]."""
    def replacer(m: re.Match) -> str:
        target = m.group(1)
        display = m.group(2)
        if mode == "markdown":
            return f"[{display}]({target})"
        return display
    return replacer


def _simple_wikilink_replace(mode: str):
    """Return the replacement function for [[target]]."""
    def replacer(m: re.Match) -> str:
        target = m.group(1)
        if mode == "markdown":
            return f"[{target}]({target})"
        return target
    return replacer
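def _extract_tags_from_body(text: str) -> List[str]:
    """Extract Obsidian #tags from the note body.

    Args:
        text: Note body.

    Returns:
        List of tags without the leading '#'.
    """
    # Heading markers are skipped because '#' must be followed directly
    # by a letter or CJK character to count as a tag.
    tags = re.findall(r"(?:^|\s)#([a-zA-Z\u4e00-\u9fff][\w\u4e00-\u9fff/-]*)", text)
    return list(dict.fromkeys(tags))  # de-duplicate, keep order


def _simple_hash(content: str) -> str:
    """Compute a simple hash of the content, used for change detection.

    Args:
        content: Text content.

    Returns:
        Hash as an 8-digit hex string.
    """
    # 32-bit polynomial string hash; not cryptographic, no imports required
    h = 0
    for ch in content:
        h = (h * 31 + ord(ch)) & 0xFFFFFFFF
    return format(h, "08x")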
def _build_frontmatter(metadata: Dict[str, Any]) -> str:
    """Build a YAML frontmatter string.

    Args:
        metadata: Metadata dict.

    Returns:
        Frontmatter string including the --- delimiters. Empty values are skipped.
    """
    lines = ["---"]
    for key, value in metadata.items():
        if isinstance(value, list):
            if value:
                items_str = ", ".join(f'"{v}"' for v in value)
                lines.append(f"{key}: [{items_str}]")
        elif value:
            lines.append(f'{key}: "{value}"')
    lines.append("---")
    return "\n".join(lines)
# ============================================================
# Action: connect to a vault
# ============================================================

def connect(data: Dict[str, Any]) -> None:
    """Connect to an Obsidian vault.

    Optional fields: vault_path (vault path; when omitted, the
    CE_OBSIDIAN_VAULT_PATH environment variable is used)

    Args:
        data: Dict that may contain the vault path.
    """
    vault_path = data.get("vault_path") or _get_vault_path()
    if not vault_path:
        output_error(
            "未指定 Obsidian 笔记库路径。请设置环境变量 CE_OBSIDIAN_VAULT_PATH "
            "或通过 vault_path 参数指定。",
            code="CONFIG_ERROR",
        )
        return
    vault_path = os.path.expanduser(vault_path)
    if not os.path.isdir(vault_path):
        output_error(f"笔记库路径不存在: {vault_path}", code="PATH_NOT_FOUND")
        return

    # Scan the vault for an overview
    md_count = 0
    draft_count = 0
    for root, _dirs, files in os.walk(vault_path):
        for f in files:
            if f.endswith(".md"):
                md_count += 1
                fpath = os.path.join(root, f)
                try:
                    with open(fpath, "r", encoding="utf-8") as fp:
                        content = fp.read(2000)  # the first 2000 characters are enough to check tags
                    if _is_draft_note(content):
                        draft_count += 1
                except (IOError, UnicodeDecodeError):
                    continue

    # Save the connection state
    state = _get_sync_state()
    state["vault_path"] = vault_path
    _save_sync_state(state)
    output_success({
        "message": f"已连接到 Obsidian 笔记库: {vault_path}",
        "vault_path": vault_path,
        "total_notes": md_count,
        "draft_notes": draft_count,
    })
def _is_draft_note(content: str) -> bool:
    """Check whether a note is marked as a draft.

    Args:
        content: Note content (may be a partial read).

    Returns:
        True if the note is a draft.
    """
    content_lower = content.lower()
    for tag in DRAFT_TAGS:
        if tag.lower() in content_lower:
            return True
    # Check for draft markers in the frontmatter
    match = re.match(r"^---\s*\n(.*?)\n---", content, re.DOTALL)
    if match:
        fm = match.group(1).lower()
        if "content" in fm or "draft" in fm or "草稿" in fm or "内容" in fm:
            return True
    return False
# ============================================================
# Action: list drafts
# ============================================================

def list_drafts(data: Optional[Dict[str, Any]] = None) -> None:
    """List the draft notes in the Obsidian vault.

    Optional fields: vault_path

    Args:
        data: Optional configuration dict.
    """
    vault_path = (data.get("vault_path") if data else None) or _get_vault_path()
    if not vault_path:
        # Fall back to the path saved by a previous connect
        state = _get_sync_state()
        vault_path = state.get("vault_path")
    if not vault_path or not os.path.isdir(vault_path):
        output_error("未连接到 Obsidian 笔记库,请先执行 connect 操作", code="NOT_CONNECTED")
        return

    drafts = []
    for root, _dirs, files in os.walk(vault_path):
        for f in files:
            if not f.endswith(".md"):
                continue
            fpath = os.path.join(root, f)
            try:
                with open(fpath, "r", encoding="utf-8") as fp:
                    content = fp.read()
            except (IOError, UnicodeDecodeError):
                continue
            if not _is_draft_note(content):
                continue
            metadata, body = _parse_frontmatter(content)
            rel_path = os.path.relpath(fpath, vault_path)
            title = metadata.get("title", "")
            if not title:
                # Use the first H1 heading, else the file name.
                # re.search is needed for re.MULTILINE to find a heading
                # that is not on the first line.
                h1_match = re.search(r"^#\s+(.+)$", body, re.MULTILINE)
                if h1_match:
                    title = h1_match.group(1).strip()
                else:
                    title = os.path.splitext(f)[0]
            tags = metadata.get("tags", [])
            if isinstance(tags, str):
                # "tags: a, b" parses as a string; normalize it to a list
                tags = [t.strip() for t in tags.split(",") if t.strip()]
            body_tags = _extract_tags_from_body(body)
            all_tags = list(dict.fromkeys(tags + body_tags))
            drafts.append({
                "file": rel_path,
                "title": title,
                "tags": all_tags,
                "platforms": metadata.get("platforms", []),
                "status": metadata.get("ce_status", "未导入"),
                "ce_id": metadata.get("ce_id", ""),
                "char_count": len(body),
                "modified": _file_mtime(fpath),
            })

    # Most recently modified first
    drafts.sort(key=lambda d: d.get("modified", ""), reverse=True)
    output_success({
        "message": f"找到 {len(drafts)} 篇草稿笔记",
        "vault_path": vault_path,
        "drafts": drafts,
    })
def _file_mtime(filepath: str) -> str:
    """Return the file's modification time as an ISO-format string.

    Args:
        filepath: File path.

    Returns:
        ISO-format local time string, or "" if the file cannot be read.
    """
    try:
        mtime = os.path.getmtime(filepath)
        return datetime.fromtimestamp(mtime).strftime("%Y-%m-%dT%H:%M:%S")
    except OSError:
        return ""
# ============================================================
# Action: import a draft
# ============================================================

def import_draft(data: Dict[str, Any]) -> None:
    """Import one note from the Obsidian vault as a content draft.

    Required fields: file (path of the note relative to the vault root)
    Optional fields: vault_path

    Parses title, tags, platforms and other fields from the frontmatter,
    converts [[wikilinks]] to plain text, and extracts #tags as content tags.

    Args:
        data: Dict containing the file path.
    """
    file_rel = data.get("file", "")
    if not file_rel:
        output_error("笔记文件路径(file)为必填字段", code="VALIDATION_ERROR")
        return

    vault_path = data.get("vault_path") or _get_vault_path()
    if not vault_path:
        state = _get_sync_state()
        vault_path = state.get("vault_path")
    if not vault_path or not os.path.isdir(vault_path):
        output_error("未连接到 Obsidian 笔记库,请先执行 connect 操作", code="NOT_CONNECTED")
        return

    fpath = os.path.join(vault_path, file_rel)
    if not os.path.exists(fpath):
        output_error(f"笔记文件不存在: {file_rel}", code="FILE_NOT_FOUND")
        return
    try:
        with open(fpath, "r", encoding="utf-8") as f:
            raw_content = f.read()
    except (IOError, UnicodeDecodeError) as e:
        output_error(f"读取笔记失败: {e}", code="FILE_ERROR")
        return

    # Parse metadata and body
    metadata, body = _parse_frontmatter(raw_content)
    # Convert wikilinks
    body = _convert_wikilinks(body, "plain")

    # Collect tags
    fm_tags = metadata.get("tags", [])
    if isinstance(fm_tags, str):
        fm_tags = [t.strip() for t in fm_tags.split(",") if t.strip()]
    body_tags = _extract_tags_from_body(body)
    all_tags = list(dict.fromkeys(fm_tags + body_tags))
    # Drop the tags that only mark the note as a draft
    all_tags = [t for t in all_tags if t.lower() not in {"content", "draft", "内容", "草稿"}]
    # Remove #tags from the body (they now live in the tags field)
    body = re.sub(r"(?:^|\s)#([a-zA-Z\u4e00-\u9fff][\w\u4e00-\u9fff/-]*)", " ", body).strip()

    # Determine the title
    title = metadata.get("title", "")
    if not title:
        # re.search is needed for re.MULTILINE to find a heading
        # that is not on the first line.
        h1_match = re.search(r"^#\s+(.+)$", body, re.MULTILINE)
        if h1_match:
            title = h1_match.group(1).strip()
            # Remove that H1 line from the body
            body = re.sub(r"^#\s+.+\n*", "", body, count=1, flags=re.MULTILINE).strip()
        else:
            title = os.path.splitext(os.path.basename(file_rel))[0]

    # Platform list
    platforms = metadata.get("platforms", [])
    if isinstance(platforms, str):
        platforms = [p.strip() for p in platforms.split(",") if p.strip()]

    # Build the import result
    content_data = {
        "title": title,
        "body": body,
        "summary": metadata.get("summary", ""),
        "tags": all_tags,
        "platforms": platforms,
        "author": metadata.get("author", ""),
        "source": "obsidian",
        "source_file": file_rel,
    }

    # Update the sync state
    state = _get_sync_state()
    state["synced_files"][file_rel] = {
        "last_sync": now_iso(),
        "hash": _simple_hash(raw_content),
        "direction": "import",
    }
    state["last_sync_at"] = now_iso()
    _save_sync_state(state)
    output_success({
        "message": f"已从 Obsidian 导入笔记「{title}」",
        "content": content_data,
        "source_file": file_rel,
        "note": "请使用 content_store.py --action create 创建内容,并将上述数据作为参数传入",
    })
# ============================================================
# Action: export a draft
# ============================================================

def export_draft(data: Dict[str, Any]) -> None:
    """Export content back to the Obsidian vault.

    Required fields: title, body
    Optional fields: file (target path relative to the vault), vault_path, tags,
        platforms, author, summary, ce_id, ce_status, ce_published_at

    Writes a .md file with frontmatter.

    Args:
        data: Dict containing the content data.
    """
    title = data.get("title", "")
    body = data.get("body", "")
    if not title:
        output_error("标题(title)为必填字段", code="VALIDATION_ERROR")
        return
    if not body:
        output_error("正文(body)为必填字段", code="VALIDATION_ERROR")
        return

    vault_path = data.get("vault_path") or _get_vault_path()
    if not vault_path:
        state = _get_sync_state()
        vault_path = state.get("vault_path")
    if not vault_path or not os.path.isdir(vault_path):
        output_error("未连接到 Obsidian 笔记库,请先执行 connect 操作", code="NOT_CONNECTED")
        return

    # Determine the target file path
    file_rel = data.get("file", "")
    if not file_rel:
        # Derive a file name from the title
        safe_title = re.sub(r"[^\w\u4e00-\u9fff-]", "-", title)
        safe_title = re.sub(r"-+", "-", safe_title).strip("-")
        file_rel = f"{safe_title}.md"
    fpath = os.path.join(vault_path, file_rel)

    # Build the frontmatter
    fm_data = {"title": title}
    if data.get("tags"):
        tags = data["tags"]
        if isinstance(tags, str):
            tags = [t.strip() for t in tags.split(",") if t.strip()]
        fm_data["tags"] = tags
    if data.get("platforms"):
        platforms = data["platforms"]
        if isinstance(platforms, str):
            platforms = [p.strip() for p in platforms.split(",") if p.strip()]
        fm_data["platforms"] = platforms
    if data.get("author"):
        fm_data["author"] = data["author"]
    if data.get("summary"):
        fm_data["summary"] = data["summary"]
    if data.get("ce_id"):
        fm_data["ce_id"] = data["ce_id"]
    if data.get("ce_status"):
        fm_data["ce_status"] = data["ce_status"]
    if data.get("ce_published_at"):
        fm_data["ce_published_at"] = data["ce_published_at"]
    frontmatter = _build_frontmatter(fm_data)
    full_content = frontmatter + "\n\n" + body

    # Write the file
    try:
        # Create the parent directory only; the original fallback would have
        # created a directory at the file path itself.
        os.makedirs(os.path.dirname(fpath) or vault_path, exist_ok=True)
        with open(fpath, "w", encoding="utf-8") as f:
            f.write(full_content)
    except IOError as e:
        output_error(f"写入笔记失败: {e}", code="FILE_ERROR")
        return

    # Update the sync state
    state = _get_sync_state()
    state["synced_files"][file_rel] = {
        "last_sync": now_iso(),
        "hash": _simple_hash(full_content),
        "direction": "export",
        "ce_id": data.get("ce_id", ""),
    }
    state["last_sync_at"] = now_iso()
    _save_sync_state(state)
    output_success({
        "message": f"已导出内容到 Obsidian: {file_rel}",
        "file": file_rel,
        "vault_path": vault_path,
        "full_path": fpath,
    })
# ============================================================
# Action: sync
# ============================================================

def sync(data: Optional[Dict[str, Any]] = None) -> None:
    """Detect changes between the Obsidian vault and the sync state.

    Reports new drafts and previously synced files whose content has changed.
    This function only detects: new drafts still have to be imported with
    import-draft, and changed files have to be confirmed manually.

    Optional fields: vault_path, dry_run (when true, last_sync_at is not updated)

    Args:
        data: Optional configuration dict.
    """
    vault_path = (data.get("vault_path") if data else None) or _get_vault_path()
    if not vault_path:
        state = _get_sync_state()
        vault_path = state.get("vault_path")
    if not vault_path or not os.path.isdir(vault_path):
        output_error("未连接到 Obsidian 笔记库,请先执行 connect 操作", code="NOT_CONNECTED")
        return

    dry_run = data.get("dry_run", False) if data else False
    state = _get_sync_state()
    new_drafts = []  # drafts seen for the first time
    modified = []    # synced files whose content changed
    unchanged = []   # synced files with no change

    # Scan the vault
    for root, _dirs, files in os.walk(vault_path):
        for f in files:
            if not f.endswith(".md"):
                continue
            fpath = os.path.join(root, f)
            rel_path = os.path.relpath(fpath, vault_path)
            try:
                with open(fpath, "r", encoding="utf-8") as fp:
                    content = fp.read()
            except (IOError, UnicodeDecodeError):
                continue
            if not _is_draft_note(content):
                continue
            current_hash = _simple_hash(content)
            synced_info = state["synced_files"].get(rel_path)
            if synced_info is None:
                # New draft
                metadata, body = _parse_frontmatter(content)
                title = metadata.get("title", "")
                if not title:
                    # re.search is needed for re.MULTILINE to find a heading
                    # that is not on the first line.
                    h1_match = re.search(r"^#\s+(.+)$", body, re.MULTILINE)
                    title = h1_match.group(1).strip() if h1_match else os.path.splitext(f)[0]
                new_drafts.append({
                    "file": rel_path,
                    "title": title,
                    "hash": current_hash,
                })
            elif synced_info.get("hash") != current_hash:
                # Synced before, but the content has changed
                modified.append({
                    "file": rel_path,
                    "ce_id": synced_info.get("ce_id", ""),
                    "old_hash": synced_info.get("hash", ""),
                    "new_hash": current_hash,
                    "last_sync": synced_info.get("last_sync", ""),
                })
            else:
                unchanged.append(rel_path)

    if not dry_run:
        state["last_sync_at"] = now_iso()
        _save_sync_state(state)
    output_success({
        "message": f"同步检测完成: {len(new_drafts)} 新草稿, {len(modified)} 已变更, {len(unchanged)} 未变更",
        "dry_run": dry_run,
        "vault_path": vault_path,
        "new_drafts": new_drafts,
        "modified": modified,
        "unchanged_count": len(unchanged),
        "note": "新草稿请使用 import-draft 导入,变更文件请手动确认后更新" if new_drafts or modified else "所有文件已同步",
    })
# ============================================================
# Entry point
# ============================================================

def main() -> None:
    """Parse the command-line arguments and dispatch the requested action."""
    parser = parse_common_args("content-engine Obsidian 笔记库集成")
    args = parser.parse_args()
    action = args.action.lower()
    try:
        data = load_input_data(args)
    except ValueError as e:
        output_error(str(e), code="INPUT_ERROR")
        return

    actions = {
        "connect": lambda: connect(data or {}),
        "import-draft": lambda: import_draft(data or {}),
        "export-draft": lambda: export_draft(data or {}),
        "list-drafts": lambda: list_drafts(data),
        "sync": lambda: sync(data),
    }
    handler = actions.get(action)
    if handler:
        handler()
    else:
        valid_actions = "、".join(actions.keys())
        output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")


if __name__ == "__main__":
    main()
FILE:scripts/content_store.py
#!/usr/bin/env python3
"""
content-engine content data management module.

Provides CRUD operations on content data, backed by JSON file storage,
with Markdown import and export.
"""

import json
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

from utils import (
    check_subscription,
    generate_id,
    get_data_file,
    load_input_data,
    now_iso,
    today_str,
    output_error,
    output_success,
    parse_common_args,
    read_json_file,
    validate_status,
    validate_platform,
    validate_status_transition,
    format_platform_name,
    truncate_text,
    write_json_file,
    CONTENT_STATUSES,
    PLATFORMS,
)

# Imported lazily to avoid circular dependencies
_obsidian_sync = None
_learning_engine = None


def _get_obsidian_sync():
    """Lazily import the obsidian_sync module; returns None if unavailable."""
    global _obsidian_sync
    if _obsidian_sync is None:
        try:
            import obsidian_sync as _obsidian_sync
        except ImportError:
            _obsidian_sync = None
    return _obsidian_sync


def _get_learning_engine():
    """Lazily import the learning_engine module; returns None if unavailable."""
    global _learning_engine
    if _learning_engine is None:
        try:
            import learning_engine as _learning_engine
        except ImportError:
            _learning_engine = None
    return _learning_engine
# ============================================================
# Data file helpers
# ============================================================

CONTENTS_FILE = "contents.json"


def _get_contents() -> List[Dict[str, Any]]:
    """Load all content records."""
    return read_json_file(get_data_file(CONTENTS_FILE))


def _save_contents(contents: List[Dict[str, Any]]) -> None:
    """Save the content records to disk."""
    write_json_file(get_data_file(CONTENTS_FILE), contents)


def _find_content(contents: List[Dict], content_id: str) -> Optional[Dict]:
    """Find a content record by ID; returns None if there is no match."""
    for c in contents:
        if c.get("id") == content_id:
            return c
    return None
# ============================================================
# CRUD 操作
# ============================================================
def create_content(data: Dict[str, Any]) -> None:
"""创建新内容。
必填字段: title, body
可选字段: summary, tags, platforms, status, author, scheduled_at
Args:
data: 内容数据字典。
"""
if not data.get("title"):
output_error("内容标题(title)为必填字段", code="VALIDATION_ERROR")
return
if not data.get("body"):
output_error("内容正文(body)为必填字段", code="VALIDATION_ERROR")
return
sub = check_subscription()
contents = _get_contents()
# 检查内容数量限制
if len(contents) >= sub["max_content"]:
limit = sub["max_content"]
if sub["tier"] == "free":
output_error(
f"免费版最多管理 {limit} 条内容,当前已有 {len(contents)} 条。"
"请升级至付费版(¥99/月)以管理更多内容。",
code="LIMIT_EXCEEDED",
)
else:
output_error(
f"已达到内容数量上限 {limit} 条。",
code="LIMIT_EXCEEDED",
)
return
# 校验状态
status = data.get("status", "草稿")
try:
validate_status(status)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 校验平台列表
platforms = data.get("platforms", [])
if isinstance(platforms, str):
platforms = [p.strip() for p in platforms.split(",") if p.strip()]
validated_platforms = []
for p in platforms:
try:
validated_platforms.append(validate_platform(p))
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 检查平台数量限制
if len(validated_platforms) > sub["max_platforms"]:
if sub["tier"] == "free":
output_error(
f"免费版最多选择 {sub['max_platforms']} 个平台。"
"请升级至付费版(¥99/月)以使用所有平台。",
code="LIMIT_EXCEEDED",
)
else:
output_error(
f"最多选择 {sub['max_platforms']} 个平台。",
code="LIMIT_EXCEEDED",
)
return
# 校验标签
tags = data.get("tags", [])
if isinstance(tags, str):
tags = [t.strip() for t in tags.split(",") if t.strip()]
# 图片提示词(可选,由 image_prompter 生成)
image_prompts = data.get("image_prompts", [])
now = now_iso()
content = {
"id": generate_id("CT"),
"title": data["title"],
"body": data["body"],
"summary": data.get("summary", ""),
"tags": tags,
"platforms": validated_platforms,
"status": status,
"author": data.get("author", ""),
"scheduled_at": data.get("scheduled_at", ""),
"published_at": "",
"publish_results": {},
"image_prompts": image_prompts,
"created_at": now,
"updated_at": now,
}
contents.append(content)
_save_contents(contents)
output_success({
"message": f"内容「{truncate_text(content['title'], 30)}」已创建",
"content": content,
})
def update_content(data: Dict[str, Any]) -> None:
"""更新内容信息。
必填字段: id
可更新字段: title, body, summary, tags, platforms, status, author, scheduled_at
Args:
data: 包含内容 ID 和待更新字段的字典。
"""
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
updatable_fields = ["title", "body", "summary", "tags", "platforms", "author", "scheduled_at", "image_prompts"]
updated = False
for field in updatable_fields:
if field in data:
value = data[field]
if field == "platforms":
if isinstance(value, str):
value = [p.strip() for p in value.split(",") if p.strip()]
validated = []
for p in value:
try:
validated.append(validate_platform(p))
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 检查平台数量限制
sub = check_subscription()
if len(validated) > sub["max_platforms"]:
output_error(
f"最多选择 {sub['max_platforms']} 个平台。",
code="LIMIT_EXCEEDED",
)
return
value = validated
elif field == "tags":
if isinstance(value, str):
value = [t.strip() for t in value.split(",") if t.strip()]
content[field] = value
updated = True
# 状态变更需要单独处理(校验流转规则)
if "status" in data:
new_status = data["status"]
try:
validate_status(new_status)
if new_status != content["status"]:
validate_status_transition(content["status"], new_status)
content["status"] = new_status
# 如果状态变为"已发布",记录发布时间
if new_status == "已发布":
content["published_at"] = now_iso()
updated = True
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
if not updated:
output_error("未提供任何待更新的字段", code="VALIDATION_ERROR")
return
content["updated_at"] = now_iso()
_save_contents(contents)
output_success({
"message": f"内容「{truncate_text(content['title'], 30)}」已更新",
"content": content,
})
def delete_content(data: Dict[str, Any]) -> None:
"""删除内容。
必填字段: id
Args:
data: 包含内容 ID 的字典。
"""
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
original_count = len(contents)
contents = [c for c in contents if c.get("id") != content_id]
if len(contents) == original_count:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
_save_contents(contents)
output_success({"message": f"内容 {content_id} 已删除"})
def get_content(data: Dict[str, Any]) -> None:
"""获取单条内容详情。
必填字段: id
Args:
data: 包含内容 ID 的字典。
"""
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
output_success(content)
def list_contents(data: Optional[Dict[str, Any]] = None) -> None:
"""列出所有内容。
可选过滤: status, platform, keyword, date_from, date_to
Args:
data: 可选的过滤条件字典。
"""
contents = _get_contents()
if data:
# 按状态过滤
status_filter = data.get("status")
if status_filter:
contents = [c for c in contents if c.get("status") == status_filter]
# 按平台过滤
platform_filter = data.get("platform")
if platform_filter:
platform_filter = platform_filter.lower()
contents = [c for c in contents if platform_filter in c.get("platforms", [])]
# 按关键词搜索(标题、正文、摘要、标签)
keyword = (data.get("keyword") or "").strip()
if keyword:
keyword_lower = keyword.lower()
contents = [
c for c in contents
if keyword_lower in c.get("title", "").lower()
or keyword_lower in c.get("body", "").lower()
or keyword_lower in c.get("summary", "").lower()
or any(keyword_lower in t.lower() for t in c.get("tags", []))
]
# 按日期范围过滤
date_from = data.get("date_from", "")
date_to = data.get("date_to", "")
if date_from:
contents = [c for c in contents if c.get("created_at", "") >= date_from]
if date_to:
contents = [c for c in contents if c.get("created_at", "") <= date_to + "T23:59:59"]
# 按更新时间倒序排列
contents.sort(key=lambda c: c.get("updated_at", ""), reverse=True)
# 按状态分组统计
status_stats = {}
for status in CONTENT_STATUSES:
status_stats[status] = sum(1 for c in contents if c.get("status") == status)
# 按平台分组统计
platform_stats = {}
for p in PLATFORMS:
platform_stats[format_platform_name(p)] = sum(
1 for c in contents if p in c.get("platforms", [])
)
# 列表中截断正文
display_list = []
for c in contents:
d = dict(c)
d["body"] = truncate_text(d.get("body", ""), 100)
display_list.append(d)
output_success({
"total": len(display_list),
"status_stats": status_stats,
"platform_stats": platform_stats,
"contents": display_list,
})
def import_content(data: Dict[str, Any]) -> None:
"""从 Markdown 文件导入内容。
支持带 YAML frontmatter 的 Markdown 文件。
必填字段: file_path
Args:
data: 包含文件路径的字典。
"""
file_path = data.get("file_path")
if not file_path:
output_error("文件路径(file_path)为必填字段", code="VALIDATION_ERROR")
return
if not os.path.exists(file_path):
output_error(f"文件不存在: {file_path}", code="FILE_NOT_FOUND")
return
sub = check_subscription()
contents = _get_contents()
if len(contents) >= sub["max_content"]:
output_error(
f"已达内容数量上限 {sub['max_content']} 条,无法导入。",
code="LIMIT_EXCEEDED",
)
return
try:
with open(file_path, "r", encoding="utf-8") as f:
raw = f.read()
except IOError as e:
output_error(f"文件读取失败: {e}", code="FILE_ERROR")
return
# 解析 YAML frontmatter
title = ""
tags = []
platforms = []
author = ""
summary = ""
body = raw
frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", raw, re.DOTALL)
if frontmatter_match:
frontmatter_str = frontmatter_match.group(1)
body = frontmatter_match.group(2).strip()
# 简单解析 YAML(不依赖 pyyaml)
for line in frontmatter_str.split("\n"):
line = line.strip()
if line.startswith("title:"):
title = line[len("title:"):].strip().strip("\"'")
elif line.startswith("tags:"):
tags_str = line[len("tags:"):].strip()
if tags_str.startswith("[") and tags_str.endswith("]"):
tags = [t.strip().strip("\"'") for t in tags_str[1:-1].split(",") if t.strip()]
elif line.startswith("platforms:"):
plat_str = line[len("platforms:"):].strip()
if plat_str.startswith("[") and plat_str.endswith("]"):
platforms = [p.strip().strip("\"'") for p in plat_str[1:-1].split(",") if p.strip()]
elif line.startswith("author:"):
author = line[len("author:"):].strip().strip("\"'")
elif line.startswith("summary:") or line.startswith("description:"):
key = "summary:" if line.startswith("summary:") else "description:"
summary = line[len(key):].strip().strip("\"'")
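# If the frontmatter gave no title, fall back to the first level-1 heading in the body.
# re.search is needed here: re.match only tests the start of the string, even with re.MULTILINE.
if not title:
title_match = re.search(r"^#\s+(.+)$", body, re.MULTILINE)
if title_match:
title = title_match.group(1).strip()
else:
title = os.path.splitext(os.path.basename(file_path))[0]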
# 校验平台
validated_platforms = []
for p in platforms:
try:
validated_platforms.append(validate_platform(p))
except ValueError:
pass # 导入时忽略无效平台
now = now_iso()
content = {
"id": generate_id("CT"),
"title": title,
"body": body,
"summary": summary,
"tags": tags,
"platforms": validated_platforms,
"status": "草稿",
"author": author,
"scheduled_at": "",
"published_at": "",
"publish_results": {},
"created_at": now,
"updated_at": now,
}
contents.append(content)
_save_contents(contents)
output_success({
"message": f"已从 {os.path.basename(file_path)} 导入内容「{truncate_text(title, 30)}」",
"content": content,
})
def export_content(data: Dict[str, Any]) -> None:
"""导出内容为 Markdown 格式。
必填字段: id
可选字段: file_path(若不指定则输出到 stdout)
Args:
data: 包含内容 ID 和可选文件路径的字典。
"""
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
# 生成 Markdown + frontmatter
lines = ["---"]
lines.append(f'title: "{content["title"]}"')
if content.get("author"):
lines.append(f'author: "{content["author"]}"')
if content.get("summary"):
lines.append(f'summary: "{content["summary"]}"')
if content.get("tags"):
tags_str = ", ".join(f'"{t}"' for t in content["tags"])
lines.append(f"tags: [{tags_str}]")
if content.get("platforms"):
plat_str = ", ".join(f'"{p}"' for p in content["platforms"])
lines.append(f"platforms: [{plat_str}]")
lines.append(f'status: "{content["status"]}"')
lines.append(f'created_at: "{content["created_at"]}"')
lines.append("---")
lines.append("")
lines.append(content.get("body", ""))
markdown = "\n".join(lines)
file_path = data.get("file_path")
if file_path:
try:
os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else ".", exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
f.write(markdown)
output_success({
"message": f"已导出内容到 {file_path}",
"file_path": file_path,
})
except IOError as e:
output_error(f"导出失败: {e}", code="EXPORT_ERROR")
else:
output_success({
"markdown": markdown,
"content_id": content_id,
})
# ============================================================
# Obsidian 导入
# ============================================================
def import_obsidian(data: Dict[str, Any]) -> None:
"""从 Obsidian 笔记库导入草稿为内容。
必填字段: file(笔记在库中的相对路径)
可选字段: vault_path
内部调用 obsidian_sync 模块解析笔记,然后自动创建内容。
Args:
data: 包含文件路径的字典。
"""
obs = _get_obsidian_sync()
if obs is None:
output_error("Obsidian 同步模块不可用", code="MODULE_ERROR")
return
file_rel = data.get("file", "")
if not file_rel:
output_error("笔记文件路径(file)为必填字段", code="VALIDATION_ERROR")
return
vault_path = data.get("vault_path") or os.environ.get("CE_OBSIDIAN_VAULT_PATH", "")
if vault_path:
vault_path = os.path.expanduser(vault_path)
if not vault_path:
# 尝试从同步状态获取
state = obs._get_sync_state()
vault_path = state.get("vault_path", "")
if not vault_path or not os.path.isdir(vault_path):
output_error(
"未连接到 Obsidian 笔记库,请先设置 CE_OBSIDIAN_VAULT_PATH 或执行 obsidian_sync connect",
code="NOT_CONNECTED",
)
return
fpath = os.path.join(vault_path, file_rel)
if not os.path.exists(fpath):
output_error(f"笔记文件不存在: {file_rel}", code="FILE_NOT_FOUND")
return
try:
with open(fpath, "r", encoding="utf-8") as f:
raw_content = f.read()
except (IOError, UnicodeDecodeError) as e:
output_error(f"读取笔记失败: {e}", code="FILE_ERROR")
return
# 解析笔记
metadata, body = obs._parse_frontmatter(raw_content)
body = obs._convert_wikilinks(body, "plain")
# 提取标签
fm_tags = metadata.get("tags", [])
if isinstance(fm_tags, str):
fm_tags = [t.strip() for t in fm_tags.split(",") if t.strip()]
body_tags = obs._extract_tags_from_body(body)
all_tags = list(dict.fromkeys(fm_tags + body_tags))
all_tags = [t for t in all_tags if t.lower() not in {"content", "draft", "内容", "草稿"}]
# 清理正文中的标签标记
body = re.sub(r"(?:^|\s)#([a-zA-Z\u4e00-\u9fff][\w\u4e00-\u9fff/-]*)", " ", body).strip()
# 获取标题
title = metadata.get("title", "")
if not title:
h1_match = re.match(r"^#\s+(.+)$", body, re.MULTILINE)
if h1_match:
title = h1_match.group(1).strip()
body = re.sub(r"^#\s+.+\n*", "", body, count=1).strip()
else:
title = os.path.splitext(os.path.basename(file_rel))[0]
if not title or not body:
output_error("笔记缺少标题或正文内容", code="VALIDATION_ERROR")
return
# 创建内容
content_data = {
"title": title,
"body": body,
"summary": metadata.get("summary", ""),
"tags": all_tags,
"platforms": metadata.get("platforms", []),
"author": metadata.get("author", ""),
}
create_content(content_data)
# ============================================================
# 学习引擎集成:发布后记录基线数据
# ============================================================
def _record_publish_baseline(content: Dict[str, Any]) -> None:
"""发布后向学习引擎记录基线数据。
在内容状态变为"已发布"时调用,记录初始性能数据供后续分析。
Args:
content: 已发布的内容字典。
"""
le = _get_learning_engine()
if le is None:
return # 学习引擎不可用时静默跳过
for platform in content.get("platforms", []):
try:
le.record_performance({
"content_id": content.get("id", ""),
"platform": platform,
"topic": content.get("tags", [""])[0] if content.get("tags") else "",
"tags": content.get("tags", []),
"title": content.get("title", ""),
"posting_time": content.get("published_at", now_iso()),
"format": "article",
"length": len(content.get("body", "")),
"metrics": {}, # 基线为空,后续由 metrics_collector 填充
})
except Exception:
pass # 记录失败不影响主流程
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 内容数据管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"create": lambda: create_content(data or {}),
"update": lambda: update_content(data or {}),
"delete": lambda: delete_content(data or {}),
"get": lambda: get_content(data or {}),
"list": lambda: list_contents(data),
"import": lambda: import_content(data or {}),
"export": lambda: export_content(data or {}),
"import-obsidian": lambda: import_obsidian(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
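The dependency-free frontmatter handling used by `import_content` above can be isolated as a small helper. This is a minimal sketch under stated assumptions: `parse_frontmatter`, `first_heading`, and `SAMPLE` are hypothetical names, and the parser understands only flat `key: value` pairs and inline `[a, b]` lists, as the original loop does. It is not a general YAML parser.

```python
import re
from typing import Dict, List, Tuple, Union

FrontValue = Union[str, List[str]]


def parse_frontmatter(raw: str) -> Tuple[Dict[str, FrontValue], str]:
    """Split a Markdown document into (metadata, body)."""
    match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", raw, re.DOTALL)
    if not match:
        # No frontmatter block: the whole document is the body.
        return {}, raw
    meta: Dict[str, FrontValue] = {}
    for line in match.group(1).split("\n"):
        key, sep, value = line.strip().partition(":")
        if not sep or not key:
            continue  # skip blank or malformed lines
        value = value.strip()
        if value.startswith("[") and value.endswith("]"):
            # Inline list such as: tags: [a, "b"]
            meta[key] = [v.strip().strip("\"'")
                         for v in value[1:-1].split(",") if v.strip()]
        else:
            meta[key] = value.strip("\"'")
    return meta, match.group(2).strip()


def first_heading(body: str) -> str:
    """Return the first level-1 heading found anywhere in the body, or ''."""
    found = re.search(r"^#\s+(.+)$", body, re.MULTILINE)
    return found.group(1).strip() if found else ""


SAMPLE = (
    '---\ntitle: "Weekly notes"\ntags: [python, "cli"]\n---\n'
    "Intro line.\n\n# Heading\nText.\n"
)

if __name__ == "__main__":
    meta, body = parse_frontmatter(SAMPLE)
    print(meta)
    print(first_heading(body))
```

`first_heading` uses `re.search` so that a heading which is not on the first line of the body is still found. `re.match` would only test the start of the string.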
FILE:scripts/platform_adapter.py
#!/usr/bin/env python3
"""
content-engine 平台适配器模块
将通用内容适配为各平台特定格式,支持 Twitter、LinkedIn、微信公众号、博客、Medium。
"""
import json
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
count_chars,
format_platform_name,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
sanitize_html,
truncate_text,
validate_platform,
write_json_file,
PLATFORMS,
PLATFORM_CHAR_LIMITS,
)
# ============================================================
# Obsidian 格式处理
# ============================================================
def _convert_obsidian_wikilinks(text: str, mode: str = "plain") -> str:
    """Convert Obsidian [[wikilinks]] to the target format.

    Args:
        text: Text that may contain wikilinks.
        mode: "plain" keeps only the display text; "hyperlink" emits
            Markdown links.

    Returns:
        The converted text.
    """
    # Aliased links are written [[link target|display text]], so group 1 is
    # the target and group 2 is the text shown to the reader.
    if mode == "hyperlink":
        text = re.sub(r"\[\[([^|\]]+)\|([^\]]+)\]\]", r"[\2](\1)", text)
        text = re.sub(r"\[\[([^\]]+)\]\]", r"[\1](\1)", text)
    else:
        text = re.sub(r"\[\[([^|\]]+)\|([^\]]+)\]\]", r"\2", text)
        text = re.sub(r"\[\[([^\]]+)\]\]", r"\1", text)
    return text
def _add_image_prompt_placeholders(text: str, image_prompts: list) -> str:
"""在适配内容中插入配图提示词占位符。
Args:
text: 适配后的内容文本。
image_prompts: 图片提示词列表,每项包含 position, prompt 字段。
Returns:
插入占位符后的文本。
"""
if not image_prompts:
return text
# 在文本末尾追加配图建议
parts = [text, "", "---", ""]
for i, img in enumerate(image_prompts, 1):
prompt = img.get("prompt", img.get("description", ""))
position = img.get("position", f"位置{i}")
parts.append(f"[建议配图 {i} ({position}): {prompt}]")
return "\n".join(parts)
def _generate_seo_metadata(content: Dict[str, Any]) -> Dict[str, Any]:
"""为博客内容生成 SEO 元数据。
Args:
content: 内容字典,包含 title, body, tags, summary 等字段。
Returns:
包含 meta_description, keywords, og_title, og_description 的字典。
"""
title = content.get("title", "")
body = content.get("body", "")
summary = content.get("summary", "")
tags = content.get("tags", [])
# meta description: 优先使用摘要,否则截取正文前 160 字符
meta_desc = summary if summary else ""
if not meta_desc:
# 清理 Markdown 标记
clean_body = re.sub(r"[#*`\[\]()>]", "", body)
clean_body = re.sub(r"\s+", " ", clean_body).strip()
meta_desc = clean_body[:160]
if len(clean_body) > 160:
meta_desc = meta_desc[:157] + "..."
# keywords: 基于标签和标题
keywords = list(tags[:10])
# 从标题提取额外关键词
title_words = re.findall(r"[\w\u4e00-\u9fff]{2,}", title)
for w in title_words:
if w not in keywords and len(keywords) < 15:
keywords.append(w)
# Open Graph 数据
og_title = title
og_description = meta_desc[:200] if meta_desc else title
return {
"meta_description": meta_desc,
"keywords": keywords,
"og_title": og_title,
"og_description": og_description,
}
# ============================================================
# 数据文件路径
# ============================================================
CONTENTS_FILE = "contents.json"
ADAPTED_FILE = "adapted_contents.json"
def _get_contents() -> List[Dict[str, Any]]:
"""读取所有内容数据。"""
return read_json_file(get_data_file(CONTENTS_FILE))
def _find_content(contents: List[Dict], content_id: str) -> Optional[Dict]:
"""根据 ID 查找内容。"""
for c in contents:
if c.get("id") == content_id:
return c
return None
def _get_adapted() -> List[Dict[str, Any]]:
"""读取所有已适配内容。"""
return read_json_file(get_data_file(ADAPTED_FILE))
def _save_adapted(adapted: List[Dict[str, Any]]) -> None:
"""保存已适配内容到文件。"""
write_json_file(get_data_file(ADAPTED_FILE), adapted)
# ============================================================
# Twitter 适配
# ============================================================
def _adapt_twitter(content: Dict[str, Any]) -> Dict[str, Any]:
"""将内容适配为 Twitter 格式。
规则:
- 单条推文限 280 字符(中文字符占 2 位)
- 超长内容自动拆分为 thread
- 图片转为 alt text 描述
- 自动添加 hashtag
Args:
content: 原始内容字典。
Returns:
适配后的 Twitter 格式数据。
"""
body = content.get("body", "")
tags = content.get("tags", [])
title = content.get("title", "")
# 移除 Markdown 图片,转为 alt text
body = re.sub(r"!\[([^\]]*)\]\([^)]+\)", r"[\1]", body)
# 移除其他 Markdown 格式标记
body = re.sub(r"\*\*(.+?)\*\*", r"\1", body)
body = re.sub(r"\*(.+?)\*", r"\1", body)
body = re.sub(r"#{1,6}\s+", "", body)
body = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", body)
body = re.sub(r"```[\s\S]*?```", "[代码块]", body)
body = re.sub(r"`([^`]+)`", r"\1", body)
body = re.sub(r">\s+(.+)", r"\1", body)
body = re.sub(r"\n{3,}", "\n\n", body)
# 生成 hashtag 字符串
hashtags = " ".join(f"#{t.replace(' ', '')}" for t in tags[:5])
# 计算单条推文可用字符数
max_chars = 280
# 预留 hashtag 空间
hashtag_chars = count_chars(hashtags, "twitter") + 1 if hashtags else 0
available_chars = max_chars - hashtag_chars
# 拆分为 thread
text = f"{title}\n\n{body}".strip() if title else body.strip()
tweets = _split_twitter_thread(text, available_chars)
# 在最后一条推文末尾添加 hashtag
if hashtags and tweets:
last = tweets[-1]
if count_chars(last + "\n\n" + hashtags, "twitter") <= max_chars:
tweets[-1] = last + "\n\n" + hashtags
else:
tweets.append(hashtags)
# 如果是 thread,添加编号
if len(tweets) > 1:
numbered = []
for i, tweet in enumerate(tweets, 1):
prefix = f"({i}/{len(tweets)}) "
# 确保编号后不超限
prefix_chars = count_chars(prefix, "twitter")
if count_chars(tweet, "twitter") + prefix_chars > max_chars:
tweet = truncate_text(tweet, max_chars - prefix_chars - 3)
numbered.append(prefix + tweet)
tweets = numbered
return {
"platform": "twitter",
"platform_name": format_platform_name("twitter"),
"format": "thread" if len(tweets) > 1 else "tweet",
"tweets": tweets,
"tweet_count": len(tweets),
"hashtags": tags[:5],
"char_counts": [count_chars(t, "twitter") for t in tweets],
}
def _split_twitter_thread(text: str, max_chars: int) -> List[str]:
"""将文本拆分为 Twitter thread。
按段落拆分,每段不超过 max_chars 字符。
Args:
text: 原始文本。
max_chars: 每条推文最大字符数。
Returns:
推文列表。
"""
if count_chars(text, "twitter") <= max_chars:
return [text]
# 按段落拆分
paragraphs = text.split("\n\n")
tweets = []
current = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
test_text = (current + "\n\n" + para).strip() if current else para
if count_chars(test_text, "twitter") <= max_chars:
current = test_text
else:
if current:
tweets.append(current)
# 如果单个段落超限,按句子拆分
if count_chars(para, "twitter") > max_chars:
sentences = re.split(r"([。!?.!?])", para)
current = ""
for i in range(0, len(sentences) - 1, 2):
sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")
test_s = (current + sentence).strip() if current else sentence
if count_chars(test_s, "twitter") <= max_chars:
current = test_s
else:
if current:
tweets.append(current)
current = sentence
# 处理最后未配对的部分
if len(sentences) % 2 == 1 and sentences[-1].strip():
test_s = (current + sentences[-1]).strip() if current else sentences[-1]
if count_chars(test_s, "twitter") <= max_chars:
current = test_s
else:
if current:
tweets.append(current)
current = sentences[-1]
else:
current = para
if current:
tweets.append(current)
return tweets if tweets else [text[:max_chars]]
# ============================================================
# LinkedIn 适配
# ============================================================
def _adapt_linkedin(content: Dict[str, Any]) -> Dict[str, Any]:
"""将内容适配为 LinkedIn 格式。
规则:
- 专业语气格式
- 3000 字符限制
- 适当的换行和段落
- 添加行动号召
Args:
content: 原始内容字典。
Returns:
适配后的 LinkedIn 格式数据。
"""
title = content.get("title", "")
body = content.get("body", "")
tags = content.get("tags", [])
summary = content.get("summary", "")
# 移除 Markdown 格式但保留结构
text = body
text = re.sub(r"!\[([^\]]*)\]\([^)]+\)", "", text) # 移除图片
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) # 粗体
text = re.sub(r"\*(.+?)\*", r"\1", text) # 斜体
text = re.sub(r"#{1,6}\s+(.+)", r"\1", text) # 标题转纯文本
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", text) # 链接
text = re.sub(r"```[\s\S]*?```", "", text) # 移除代码块
text = re.sub(r"`([^`]+)`", r"\1", text) # 行内代码
# 构建 LinkedIn 帖子
parts = []
# 标题(粗体效果用大写或符号标记)
if title:
parts.append(title)
parts.append("")
# 摘要或正文
if summary:
parts.append(summary)
parts.append("")
parts.append(text.strip())
# 添加 hashtag
if tags:
parts.append("")
parts.append(" ".join(f"#{t.replace(' ', '')}" for t in tags[:10]))
result_text = "\n".join(parts)
# 截断到 3000 字符
char_count = count_chars(result_text, "linkedin")
if char_count > 3000:
result_text = truncate_text(result_text, 3000)
char_count = count_chars(result_text, "linkedin")
return {
"platform": "linkedin",
"platform_name": format_platform_name("linkedin"),
"format": "post",
"text": result_text,
"char_count": char_count,
"char_limit": 3000,
"hashtags": tags[:10],
}
# ============================================================
# 微信公众号适配
# ============================================================
def _adapt_wechat(content: Dict[str, Any]) -> Dict[str, Any]:
"""将内容适配为微信公众号文章格式。
规则:
- HTML 文章格式
- 图片引用保留
- 作者信息卡片
- 富文本排版
Args:
content: 原始内容字典。
Returns:
适配后的微信公众号格式数据。
"""
title = content.get("title", "")
body = content.get("body", "")
author = content.get("author", "")
summary = content.get("summary", "")
tags = content.get("tags", [])
# Markdown 转 HTML
html = _markdown_to_html(body)
# 清理危险 HTML
html = sanitize_html(html)
# 构建完整的微信文章 HTML
article_parts = []
# 文章标题
article_parts.append(f"<h1>{_escape_html(title)}</h1>")
# 作者信息卡片
if author:
article_parts.append(
f'<div class="author-card">'
f'<span class="author-name">{_escape_html(author)}</span>'
f"</div>"
)
# 摘要
if summary:
article_parts.append(
f'<blockquote class="summary">{_escape_html(summary)}</blockquote>'
)
# 正文
article_parts.append(f'<div class="content">{html}</div>')
# 标签
if tags:
tag_html = " ".join(
f'<span class="tag">#{_escape_html(t)}</span>' for t in tags
)
article_parts.append(f'<div class="tags">{tag_html}</div>')
full_html = "\n".join(article_parts)
char_count = count_chars(body, "wechat")
# 提取图片引用
images = re.findall(r"!\[([^\]]*)\]\(([^)]+)\)", content.get("body", ""))
image_refs = [{"alt": alt, "url": url} for alt, url in images]
return {
"platform": "wechat",
"platform_name": format_platform_name("wechat"),
"format": "article",
"title": title,
"html": full_html,
"digest": truncate_text(summary or body, 120),
"author": author,
"char_count": char_count,
"image_refs": image_refs,
"tags": tags,
}
def _markdown_to_html(md: str) -> str:
"""简单的 Markdown 转 HTML。
仅处理常用格式,不依赖第三方库。
Args:
md: Markdown 文本。
Returns:
HTML 字符串。
"""
html = md
# 代码块
html = re.sub(
r"```(\w*)\n([\s\S]*?)```",
r'<pre><code class="language-\1">\2</code></pre>',
html,
)
# 行内代码
html = re.sub(r"`([^`]+)`", r"<code>\1</code>", html)
# 标题
html = re.sub(r"^######\s+(.+)$", r"<h6>\1</h6>", html, flags=re.MULTILINE)
html = re.sub(r"^#####\s+(.+)$", r"<h5>\1</h5>", html, flags=re.MULTILINE)
html = re.sub(r"^####\s+(.+)$", r"<h4>\1</h4>", html, flags=re.MULTILINE)
html = re.sub(r"^###\s+(.+)$", r"<h3>\1</h3>", html, flags=re.MULTILINE)
html = re.sub(r"^##\s+(.+)$", r"<h2>\1</h2>", html, flags=re.MULTILINE)
html = re.sub(r"^#\s+(.+)$", r"<h1>\1</h1>", html, flags=re.MULTILINE)
# 粗体和斜体
html = re.sub(r"\*\*\*(.+?)\*\*\*", r"<strong><em>\1</em></strong>", html)
html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", html)
html = re.sub(r"\*(.+?)\*", r"<em>\1</em>", html)
# 图片
html = re.sub(
r"!\[([^\]]*)\]\(([^)]+)\)",
r'<img src="\2" alt="\1" />',
html,
)
# 链接
html = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r'<a href="\2">\1</a>', html)
# 引用
html = re.sub(r"^>\s+(.+)$", r"<blockquote>\1</blockquote>", html, flags=re.MULTILINE)
# 无序列表
html = re.sub(r"^[-*+]\s+(.+)$", r"<li>\1</li>", html, flags=re.MULTILINE)
# 段落(连续空行分隔)
paragraphs = html.split("\n\n")
processed = []
for p in paragraphs:
p = p.strip()
if not p:
continue
# 已经被标签包裹的不再加 <p>
if re.match(r"^<(h[1-6]|pre|blockquote|li|ul|ol|div|img)", p):
processed.append(p)
else:
processed.append(f"<p>{p}</p>")
html = "\n".join(processed)
return html
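def _escape_html(text: str) -> str:
    """Escape HTML special characters.

    Args:
        text: Raw text.

    Returns:
        The escaped text.
    """
    # "&" must be replaced first so the entities added below are not re-escaped.
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    text = text.replace('"', "&quot;")
    return text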
# ============================================================
# 博客适配
# ============================================================
def _adapt_blog(content: Dict[str, Any]) -> Dict[str, Any]:
"""将内容适配为博客格式。
支持 Hugo、Jekyll、Hexo 三种博客引擎,通过 CE_BLOG_TYPE 环境变量指定。
Args:
content: 原始内容字典。
Returns:
适配后的博客格式数据。
"""
blog_type = os.environ.get("CE_BLOG_TYPE", "hugo").lower()
title = content.get("title", "")
body = content.get("body", "")
tags = content.get("tags", [])
author = content.get("author", "")
summary = content.get("summary", "")
now = now_iso()
if blog_type == "jekyll":
# Jekyll frontmatter 使用 YAML
frontmatter = _build_jekyll_frontmatter(title, tags, author, summary, now)
elif blog_type == "hexo":
# Hexo frontmatter
frontmatter = _build_hexo_frontmatter(title, tags, author, summary, now)
else:
# Hugo frontmatter(默认)
frontmatter = _build_hugo_frontmatter(title, tags, author, summary, now)
markdown = frontmatter + "\n" + body
# 生成文件名建议
slug = re.sub(r"[^\w\u4e00-\u9fff-]", "-", title.lower())
slug = re.sub(r"-+", "-", slug).strip("-")
date_prefix = datetime.now().strftime("%Y-%m-%d")
filename = f"{date_prefix}-{slug}.md" if blog_type == "jekyll" else f"{slug}.md"
return {
"platform": "blog",
"platform_name": format_platform_name("blog"),
"format": f"markdown-{blog_type}",
"blog_type": blog_type,
"markdown": markdown,
"suggested_filename": filename,
"char_count": len(body),
}
def _build_hugo_frontmatter(title: str, tags: List[str], author: str, summary: str, date: str) -> str:
"""生成 Hugo 格式的 frontmatter。"""
lines = ["---"]
lines.append(f'title: "{title}"')
lines.append(f"date: {date}")
if author:
lines.append(f'author: "{author}"')
if summary:
lines.append(f'description: "{summary}"')
if tags:
lines.append("tags:")
for t in tags:
lines.append(f' - "{t}"')
lines.append("draft: false")
lines.append("---")
return "\n".join(lines)
def _build_jekyll_frontmatter(title: str, tags: List[str], author: str, summary: str, date: str) -> str:
"""生成 Jekyll 格式的 frontmatter。"""
lines = ["---"]
lines.append("layout: post")
lines.append(f'title: "{title}"')
lines.append(f"date: {date}")
if author:
lines.append(f'author: "{author}"')
if summary:
lines.append(f'excerpt: "{summary}"')
if tags:
tags_str = ", ".join(tags)
lines.append(f"tags: [{tags_str}]")
lines.append("---")
return "\n".join(lines)
def _build_hexo_frontmatter(title: str, tags: List[str], author: str, summary: str, date: str) -> str:
"""生成 Hexo 格式的 frontmatter。"""
lines = ["---"]
lines.append(f"title: {title}")
lines.append(f"date: {date}")
if author:
lines.append(f"author: {author}")
if summary:
lines.append(f"description: {summary}")
if tags:
lines.append("tags:")
for t in tags:
lines.append(f" - {t}")
lines.append("---")
return "\n".join(lines)
# ============================================================
# Medium 适配
# ============================================================
def _adapt_medium(content: Dict[str, Any]) -> Dict[str, Any]:
"""将内容适配为 Medium 兼容的 Markdown 格式。
规则:
- Medium 兼容 Markdown
- 保留图片和链接
- 添加标签
Args:
content: 原始内容字典。
Returns:
适配后的 Medium 格式数据。
"""
title = content.get("title", "")
body = content.get("body", "")
tags = content.get("tags", [])
# Medium Markdown 基本兼容,只需微调
parts = []
# 标题
if title:
parts.append(f"# {title}")
parts.append("")
# 正文
parts.append(body)
# 标签(Medium 最多 5 个标签)
if tags:
parts.append("")
parts.append("---")
parts.append("")
parts.append("Tags: " + ", ".join(tags[:5]))
markdown = "\n".join(parts)
char_count = len(markdown)
return {
"platform": "medium",
"platform_name": format_platform_name("medium"),
"format": "markdown",
"title": title,
"markdown": markdown,
"tags": tags[:5],
"char_count": char_count,
}
# ============================================================
# 适配入口
# ============================================================
# 平台适配器注册表
_ADAPTERS = {
"twitter": _adapt_twitter,
"linkedin": _adapt_linkedin,
"wechat": _adapt_wechat,
"blog": _adapt_blog,
"medium": _adapt_medium,
}
def adapt_content(data: Dict[str, Any]) -> None:
"""将内容适配到指定平台。
必填字段: id, platform
Args:
data: 包含内容 ID 和目标平台的字典。
"""
content_id = data.get("id")
platform = data.get("platform", "")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 微信公众号需要付费版
if platform == "wechat":
if not require_paid_feature("wechat", "微信公众号适配"):
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
adapter = _ADAPTERS.get(platform)
if not adapter:
output_error(f"暂不支持平台: {platform}", code="UNSUPPORTED_PLATFORM")
return
# 预处理: 转换 Obsidian wikilinks
processed_content = dict(content)
if processed_content.get("body"):
processed_content["body"] = _convert_obsidian_wikilinks(
processed_content["body"],
"hyperlink" if platform in ("blog", "medium") else "plain",
)
result = adapter(processed_content)
# 添加配图提示词占位符(如果内容中有 image_prompts)
image_prompts = content.get("image_prompts", [])
if image_prompts:
result["image_prompts"] = image_prompts
# 为博客平台生成 SEO 元数据
if platform == "blog":
result["seo_metadata"] = _generate_seo_metadata(content)
# 保存适配结果
adapted = _get_adapted()
result["content_id"] = content_id
result["adapted_at"] = now_iso()
# 替换已有的同内容同平台适配
adapted = [
a for a in adapted
if not (a.get("content_id") == content_id and a.get("platform") == platform)
]
adapted.append(result)
_save_adapted(adapted)
output_success({
"message": f"已将内容适配为 {format_platform_name(platform)} 格式",
"adapted": result,
})
def preview_content(data: Dict[str, Any]) -> None:
"""预览内容在指定平台的适配效果。
必填字段: id, platform
Args:
data: 包含内容 ID 和目标平台的字典。
"""
content_id = data.get("id")
platform = data.get("platform", "")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
adapter = _ADAPTERS.get(platform)
if not adapter:
output_error(f"暂不支持平台: {platform}", code="UNSUPPORTED_PLATFORM")
return
result = adapter(content)
output_success({
"message": f"{format_platform_name(platform)} 预览",
"preview": result,
})
def validate_content(data: Dict[str, Any]) -> None:
"""校验内容是否满足指定平台的要求。
必填字段: id, platform
Args:
data: 包含内容 ID 和目标平台的字典。
"""
content_id = data.get("id")
platform = data.get("platform", "")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
issues = []
warnings = []
body = content.get("body", "")
title = content.get("title", "")
char_count = count_chars(body, platform)
limit = PLATFORM_CHAR_LIMITS.get(platform, 0)
# 通用校验
if not title:
issues.append("缺少标题")
if not body:
issues.append("缺少正文内容")
# 平台特定校验
if platform == "twitter":
if char_count > 280 * 25: # 超过 25 条推文的 thread
warnings.append(f"内容过长({char_count} 字符),将拆分为较长的 thread")
if not content.get("tags"):
warnings.append("建议添加标签以生成 hashtag")
elif platform == "linkedin":
if limit > 0 and char_count > limit:
issues.append(f"内容超过 LinkedIn 限制({char_count}/{limit} 字符)")
if not content.get("summary"):
warnings.append("建议添加摘要以提升专业度")
elif platform == "wechat":
if not content.get("author"):
warnings.append("建议设置作者信息")
if not content.get("summary"):
warnings.append("建议添加摘要作为文章描述")
if limit > 0 and char_count > limit:
issues.append(f"内容超过微信公众号限制({char_count}/{limit} 字符)")
elif platform == "blog":
blog_type = os.environ.get("CE_BLOG_TYPE", "hugo")
if blog_type not in ("hugo", "jekyll", "hexo"):
warnings.append(f"未识别的博客类型: {blog_type},将使用 Hugo 格式")
elif platform == "medium":
if not content.get("tags"):
warnings.append("建议添加标签(Medium 最多 5 个)")
if len(content.get("tags", [])) > 5:
warnings.append("Medium 最多支持 5 个标签,多余标签将被忽略")
is_valid = len(issues) == 0
output_success({
"content_id": content_id,
"platform": platform,
"platform_name": format_platform_name(platform),
"is_valid": is_valid,
"issues": issues,
"warnings": warnings,
"char_count": char_count,
"char_limit": limit if limit > 0 else "无限制",
})
def batch_adapt_content(data: Dict[str, Any]) -> None:
"""批量适配内容到多个平台。
必填字段: id
可选字段: platforms(默认使用内容已设置的平台列表)
Args:
data: 包含内容 ID 和可选平台列表的字典。
"""
if not require_paid_feature("batch_adapt", "批量适配"):
return
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
# 确定目标平台
target_platforms = data.get("platforms", content.get("platforms", []))
if isinstance(target_platforms, str):
target_platforms = [p.strip() for p in target_platforms.split(",") if p.strip()]
if not target_platforms:
output_error("未指定目标平台,请在内容中设置 platforms 或通过参数指定", code="VALIDATION_ERROR")
return
results = []
errors = []
adapted = _get_adapted()
for platform in target_platforms:
try:
platform = validate_platform(platform)
except ValueError as e:
errors.append({"platform": platform, "error": str(e)})
continue
adapter = _ADAPTERS.get(platform)
if not adapter:
errors.append({"platform": platform, "error": f"暂不支持平台: {platform}"})
continue
try:
result = adapter(content)
result["content_id"] = content_id
result["adapted_at"] = now_iso()
# 更新已适配列表
adapted = [
a for a in adapted
if not (a.get("content_id") == content_id and a.get("platform") == platform)
]
adapted.append(result)
results.append({
"platform": platform,
"platform_name": format_platform_name(platform),
"status": "success",
})
except Exception as e:
errors.append({"platform": platform, "error": str(e)})
_save_adapted(adapted)
output_success({
"message": f"批量适配完成: {len(results)} 成功, {len(errors)} 失败",
"content_id": content_id,
"results": results,
"errors": errors,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 平台适配器")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"adapt": lambda: adapt_content(data or {}),
"preview": lambda: preview_content(data or {}),
"validate": lambda: validate_content(data or {}),
"batch-adapt": lambda: batch_adapt_content(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
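The `adapt` and `batch-adapt` actions above both keep at most one stored result per (content ID, platform) pair: they filter out any older entry for that pair and then append the new one. The following is a minimal standalone sketch of that replace-on-write pattern. The helper name `upsert_adapted` and the sample records are illustrative assumptions, not part of the script.

```python
from typing import Any, Dict, List


def upsert_adapted(
    adapted: List[Dict[str, Any]], result: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Return a new list holding `result` and no older entry with the same key."""
    key = (result.get("content_id"), result.get("platform"))
    # Drop any previous adaptation of the same content for the same platform
    kept = [
        a for a in adapted
        if (a.get("content_id"), a.get("platform")) != key
    ]
    kept.append(result)
    return kept


store = [
    {"content_id": "c1", "platform": "medium", "char_count": 120},
    {"content_id": "c1", "platform": "blog", "char_count": 150},
]
store = upsert_adapted(
    store, {"content_id": "c1", "platform": "medium", "char_count": 200}
)
print([(a["platform"], a["char_count"]) for a in store])
```

Because the old entry is removed and the new one appended, a re-adapted item moves to the end of the list, so the stored order reflects the most recent adaptation rather than the original insertion order.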
FILE:scripts/learning_engine.py
#!/usr/bin/env python3
"""
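content-engine self-learning content intelligence module
Learns continuously from historical content performance data to refine the content strategy.
Supports recording performance, analyzing trends, and recommending topics and posting times.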
"""
import json
import math
import os
import re
import sys
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from utils import (
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
write_json_file,
)
# ============================================================
# 数据文件路径
# ============================================================
LEARNING_FILE = "learning.json"
def _get_learning_data() -> Dict[str, Any]:
    """Load the learning data as a dict with performances, preferences, and metadata."""
    filepath = get_data_file(LEARNING_FILE)
    data = read_json_file(filepath)
    # A legacy list-format file or an empty file yields a non-dict value; start fresh
    if not isinstance(data, dict):
        data = {}
    # Ensure the required top-level keys exist
    data.setdefault("performances", [])
    data.setdefault("preferences", {
        "preferred_styles": [],
        "preferred_topics": [],
        "preferred_platforms": [],
        "rejected_suggestions": [],
    })
    data.setdefault("metadata", {
        "total_records": 0,
        "first_record_at": "",
        "last_record_at": "",
        "analysis_count": 0,
    })
    return data
def _save_learning_data(data: Dict[str, Any]) -> None:
"""保存学习数据。"""
write_json_file(get_data_file(LEARNING_FILE), data)
# ============================================================
# 辅助计算函数
# ============================================================
def _calc_engagement_score(metrics: Dict[str, Any]) -> float:
    """Compute a weighted engagement score.

    Each metric type carries its own weight:
    - impressions / views / reads: 0.1
    - likes / claps / favorites: 1.0
    - comments / replies / responses: 2.0
    - retweets / shares: 3.0

    Args:
        metrics: Dict of metric values.

    Returns:
        The weighted engagement score, rounded to two decimals.
    """
    weights = {
        "impressions": 0.1,
        "views": 0.1,
        "reads": 0.1,
        "likes": 1.0,
        "claps": 1.0,
        "favorites": 1.0,
        "comments": 2.0,
        "replies": 2.0,
        "responses": 2.0,
        "retweets": 3.0,
        "shares": 3.0,
    }
    score = 0.0
    # Non-numeric values and unknown metric names are ignored
    for key, value in metrics.items():
        if isinstance(value, (int, float)) and key in weights:
            score += value * weights[key]
    return round(score, 2)
def _calc_engagement_rate(metrics: Dict[str, Any]) -> float:
    """Compute the engagement rate (interactions / exposure).

    Args:
        metrics: Dict of metric values.

    Returns:
        The engagement rate as a percentage, e.g. 3.5 means 3.5%.
        Returns 0.0 when no exposure metric is present.
    """
    exposure_keys = ["impressions", "views", "reads"]
    exposure = 0
    for k in exposure_keys:
        if isinstance(metrics.get(k), (int, float)):
            exposure += metrics[k]
    if exposure == 0:
        return 0.0
    interaction_keys = ["likes", "claps", "favorites", "comments", "replies",
                        "responses", "retweets", "shares"]
    interactions = 0
    for k in interaction_keys:
        if isinstance(metrics.get(k), (int, float)):
            interactions += metrics[k]
    return round((interactions / exposure) * 100, 2)
def _extract_hour(time_str: str) -> Optional[int]:
    """Extract the hour from an ISO timestamp string.

    Args:
        time_str: Timestamp in ISO format.

    Returns:
        The hour (0-23), or None if the string has no time part or cannot be parsed.
    """
    if not time_str:
        return None
    try:
        if "T" in time_str:
            # Strip a trailing "Z" because older fromisoformat() versions reject it
            dt = datetime.fromisoformat(time_str.replace("Z", ""))
            return dt.hour
    except (ValueError, TypeError):
        pass
    return None
def _extract_weekday(time_str: str) -> Optional[int]:
    """Extract the weekday from an ISO timestamp string.

    Args:
        time_str: Timestamp in ISO format (a date-only string is also accepted).

    Returns:
        The weekday (0 = Monday, 6 = Sunday), or None if parsing fails.
    """
    if not time_str:
        return None
    try:
        if "T" in time_str:
            dt = datetime.fromisoformat(time_str.replace("Z", ""))
        else:
            dt = datetime.strptime(time_str[:10], "%Y-%m-%d")
        return dt.weekday()
    except (ValueError, TypeError):
        pass
    return None
_WEEKDAY_NAMES = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
def _group_by(records: List[Dict], key: str) -> Dict[str, List[Dict]]:
"""按指定键对记录进行分组。
Args:
records: 记录列表。
key: 分组键名。
Returns:
分组后的字典。
"""
groups: Dict[str, List[Dict]] = defaultdict(list)
for r in records:
val = r.get(key, "未知")
if isinstance(val, list):
for v in val:
groups[str(v)].append(r)
else:
groups[str(val)].append(r)
return dict(groups)
def _avg_score(records: List[Dict]) -> float:
"""计算记录列表的平均互动得分。
Args:
records: 包含 engagement_score 字段的记录列表。
Returns:
平均互动得分。
"""
if not records:
return 0.0
total = sum(r.get("engagement_score", 0) for r in records)
return round(total / len(records), 2)
def _avg_rate(records: List[Dict]) -> float:
"""计算记录列表的平均互动率。
Args:
records: 包含 engagement_rate 字段的记录列表。
Returns:
平均互动率。
"""
if not records:
return 0.0
total = sum(r.get("engagement_rate", 0) for r in records)
return round(total / len(records), 2)
# ============================================================
# 操作:记录内容表现
# ============================================================
def record_performance(data: Dict[str, Any]) -> None:
"""记录一条内容的表现数据。
必填字段: content_id, platform, metrics
可选字段: topic, posting_time, format, length, tags, title
Args:
data: 内容表现数据字典。
"""
content_id = data.get("content_id") or data.get("id")
if not content_id:
output_error("内容ID(content_id)为必填字段", code="VALIDATION_ERROR")
return
platform = data.get("platform", "")
if not platform:
output_error("平台(platform)为必填字段", code="VALIDATION_ERROR")
return
metrics = data.get("metrics", {})
if not metrics:
output_error("指标数据(metrics)为必填字段", code="VALIDATION_ERROR")
return
learning = _get_learning_data()
# 计算互动得分和互动率
engagement_score = _calc_engagement_score(metrics)
engagement_rate = _calc_engagement_rate(metrics)
record = {
"content_id": content_id,
"platform": platform,
"topic": data.get("topic", ""),
"tags": data.get("tags", []),
"title": data.get("title", ""),
"posting_time": data.get("posting_time", ""),
"format": data.get("format", ""),
"length": data.get("length", 0),
"metrics": metrics,
"engagement_score": engagement_score,
"engagement_rate": engagement_rate,
"recorded_at": now_iso(),
}
# 替换已有的同内容同平台记录(保留最新)
performances = learning["performances"]
performances = [
p for p in performances
if not (p.get("content_id") == content_id and p.get("platform") == platform)
]
performances.append(record)
learning["performances"] = performances
# 更新元数据
now = now_iso()
learning["metadata"]["total_records"] = len(performances)
learning["metadata"]["last_record_at"] = now
if not learning["metadata"]["first_record_at"]:
learning["metadata"]["first_record_at"] = now
_save_learning_data(learning)
output_success({
"message": f"已记录内容 {content_id} 在 {platform} 的表现数据",
"engagement_score": engagement_score,
"engagement_rate": engagement_rate,
"record": record,
})
# ============================================================
# 操作:记录用户偏好
# ============================================================
def record_preference(data: Dict[str, Any]) -> None:
"""记录用户偏好设置。
可设置字段:
- preferred_styles: 偏好的写作风格列表
- preferred_topics: 偏好的内容话题列表
- preferred_platforms: 偏好的平台列表
- rejected_suggestions: 被拒绝的建议列表(避免重复推荐)
- add_style / add_topic / add_platform / add_rejected: 追加单项
Args:
data: 偏好设置数据字典。
"""
learning = _get_learning_data()
prefs = learning["preferences"]
updated_fields = []
# 批量设置
for field in ["preferred_styles", "preferred_topics", "preferred_platforms", "rejected_suggestions"]:
if field in data:
val = data[field]
if isinstance(val, str):
val = [v.strip() for v in val.split(",") if v.strip()]
prefs[field] = val
updated_fields.append(field)
# 追加单项
add_map = {
"add_style": "preferred_styles",
"add_topic": "preferred_topics",
"add_platform": "preferred_platforms",
"add_rejected": "rejected_suggestions",
}
for add_key, target_field in add_map.items():
if add_key in data:
val = data[add_key]
if isinstance(val, str):
val = [v.strip() for v in val.split(",") if v.strip()]
elif not isinstance(val, list):
val = [str(val)]
for item in val:
if item not in prefs[target_field]:
prefs[target_field].append(item)
updated_fields.append(target_field)
if not updated_fields:
output_error("未提供任何偏好设置字段", code="VALIDATION_ERROR")
return
learning["preferences"] = prefs
_save_learning_data(learning)
# dict.fromkeys de-duplicates while keeping first-seen order; set() order is arbitrary
output_success({
"message": f"已更新偏好设置: {', '.join(dict.fromkeys(updated_fields))}",
"preferences": prefs,
})
# ============================================================
# 操作:分析
# ============================================================
def analyze(data: Optional[Dict[str, Any]] = None) -> None:
"""分析历史内容表现,识别最佳实践。
可选字段: platform(按平台过滤), limit(结果数量限制)
生成以下维度的分析:
- 按话题: 哪些话题表现最好
- 按平台: 各平台的平均表现
- 按发布时间: 哪些时段互动最高
- 按格式: 哪种内容格式效果最好
Args:
data: 可选的过滤条件字典。
"""
learning = _get_learning_data()
performances = learning["performances"]
if not performances:
output_error("暂无历史表现数据,请先使用 record-performance 记录数据", code="NO_DATA")
return
# 平台过滤
platform_filter = data.get("platform") if data else None
if platform_filter:
performances = [p for p in performances if p.get("platform") == platform_filter]
if not performances:
output_error(f"平台 {platform_filter} 暂无历史表现数据", code="NO_DATA")
return
analysis = {}
# 1. 按话题分析
topic_groups = _group_by(performances, "topic")
topic_analysis = []
for topic, records in topic_groups.items():
if topic == "未知" or not topic:
continue
topic_analysis.append({
"topic": topic,
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
topic_analysis.sort(key=lambda x: x["avg_score"], reverse=True)
analysis["by_topic"] = topic_analysis[:10]
# 2. 按平台分析
platform_groups = _group_by(performances, "platform")
platform_analysis = []
for platform, records in platform_groups.items():
platform_analysis.append({
"platform": platform,
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
platform_analysis.sort(key=lambda x: x["avg_score"], reverse=True)
analysis["by_platform"] = platform_analysis
# 3. 按发布时间(小时)分析
hour_buckets: Dict[int, List[Dict]] = defaultdict(list)
for p in performances:
hour = _extract_hour(p.get("posting_time", ""))
if hour is not None:
hour_buckets[hour].append(p)
time_analysis = []
for hour in sorted(hour_buckets.keys()):
records = hour_buckets[hour]
time_analysis.append({
"hour": hour,
"time_range": f"{hour:02d}:00-{hour:02d}:59",
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
time_analysis.sort(key=lambda x: x["avg_score"], reverse=True)
analysis["by_time"] = time_analysis
# 4. 按星期几分析
weekday_buckets: Dict[int, List[Dict]] = defaultdict(list)
for p in performances:
wd = _extract_weekday(p.get("posting_time", ""))
if wd is not None:
weekday_buckets[wd].append(p)
weekday_analysis = []
for wd in sorted(weekday_buckets.keys()):
records = weekday_buckets[wd]
weekday_analysis.append({
"weekday": wd,
"weekday_name": _WEEKDAY_NAMES[wd],
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
weekday_analysis.sort(key=lambda x: x["avg_score"], reverse=True)
analysis["by_weekday"] = weekday_analysis
# 5. 按格式分析
format_groups = _group_by(performances, "format")
format_analysis = []
for fmt, records in format_groups.items():
if fmt == "未知" or not fmt:
continue
format_analysis.append({
"format": fmt,
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
format_analysis.sort(key=lambda x: x["avg_score"], reverse=True)
analysis["by_format"] = format_analysis
# 6. 生成洞察摘要
insights = _generate_insights(analysis, performances)
analysis["insights"] = insights
# 更新分析计数
learning["metadata"]["analysis_count"] = learning["metadata"].get("analysis_count", 0) + 1
_save_learning_data(learning)
output_success({
"message": f"已分析 {len(performances)} 条内容表现记录",
"analysis": analysis,
})
def _generate_insights(analysis: Dict, performances: List[Dict]) -> List[str]:
"""根据分析结果生成自然语言洞察。
Args:
analysis: 分析结果字典。
performances: 原始表现记录列表。
Returns:
洞察列表(字符串)。
"""
insights = []
overall_avg = _avg_score(performances)
# 话题洞察
by_topic = analysis.get("by_topic", [])
if by_topic and len(by_topic) >= 2:
best = by_topic[0]
ratio = round(best["avg_score"] / overall_avg, 1) if overall_avg > 0 else 0
if ratio > 1.5:
insights.append(
f"「{best['topic']}」相关内容平均互动得分 {best['avg_score']},"
f"是整体均值的 {ratio} 倍,建议多产出相关内容"
)
# 平台洞察
by_platform = analysis.get("by_platform", [])
if len(by_platform) >= 2:
best_plat = by_platform[0]
worst_plat = by_platform[-1]
if best_plat["avg_score"] > worst_plat["avg_score"] * 2:
insights.append(
f"{best_plat['platform']} 平台表现最佳(均分 {best_plat['avg_score']}),"
f"建议优先在该平台发布内容"
)
# 时间洞察
by_time = analysis.get("by_time", [])
if by_time:
best_time = by_time[0]
insights.append(
f"最佳发布时段为 {best_time['time_range']},"
f"平均互动得分 {best_time['avg_score']}(互动率 {best_time['avg_rate']}%)"
)
# 星期洞察
by_weekday = analysis.get("by_weekday", [])
if by_weekday:
best_wd = by_weekday[0]
insights.append(
f"{best_wd['weekday_name']}发布的内容表现最好,"
f"平均互动得分 {best_wd['avg_score']}"
)
# 格式洞察
by_format = analysis.get("by_format", [])
if by_format and len(by_format) >= 2:
best_fmt = by_format[0]
insights.append(
f"「{best_fmt['format']}」格式的内容效果最好,"
f"平均互动得分 {best_fmt['avg_score']}"
)
if not insights:
insights.append("数据量较少,建议持续记录更多内容表现以获得更准确的分析")
return insights
# ============================================================
# 操作:推荐话题
# ============================================================
def suggest_topic(data: Optional[Dict[str, Any]] = None) -> None:
"""基于历史表现数据推荐下一个内容话题。
可选字段: platform, count(推荐数量,默认 5)
考虑因素:
- 历史高互动话题
- 用户偏好的话题
- 避开已拒绝的建议
- 相关话题拓展
Args:
data: 可选的过滤条件字典。
"""
learning = _get_learning_data()
performances = learning["performances"]
prefs = learning["preferences"]
count = data.get("count", 5) if data else 5
platform_filter = data.get("platform") if data else None
if platform_filter:
performances = [p for p in performances if p.get("platform") == platform_filter]
suggestions = []
rejected = set(prefs.get("rejected_suggestions", []))
# 1. 从高表现话题中推荐
topic_groups = _group_by(performances, "topic")
topic_scores = []
for topic, records in topic_groups.items():
if not topic or topic == "未知":
continue
if topic in rejected:
continue
avg = _avg_score(records)
topic_scores.append((topic, avg, len(records)))
topic_scores.sort(key=lambda x: x[1], reverse=True)
for topic, avg_score, cnt in topic_scores[:count]:
reason = f"历史 {cnt} 篇相关内容平均互动得分 {avg_score}"
if platform_filter:
reason += f"({platform_filter} 平台)"
suggestions.append({
"topic": topic,
"reason": reason,
"confidence": "高" if cnt >= 3 else "中",
"avg_score": avg_score,
"sample_count": cnt,
})
# 2. 从用户偏好话题中补充
preferred = prefs.get("preferred_topics", [])
existing_topics = {s["topic"] for s in suggestions}
for topic in preferred:
if len(suggestions) >= count:
break
if topic in existing_topics or topic in rejected:
continue
suggestions.append({
"topic": topic,
"reason": "用户偏好话题",
"confidence": "中",
"avg_score": 0,
"sample_count": 0,
})
# 3. 从标签中挖掘潜在话题
tag_counter: Dict[str, int] = defaultdict(int)
tag_scores: Dict[str, float] = defaultdict(float)
for p in performances:
score = p.get("engagement_score", 0)
for tag in p.get("tags", []):
tag_counter[tag] += 1
tag_scores[tag] += score
tag_avg = []
for tag, cnt in tag_counter.items():
if tag in existing_topics or tag in rejected:
continue
if cnt >= 2:
tag_avg.append((tag, round(tag_scores[tag] / cnt, 2), cnt))
tag_avg.sort(key=lambda x: x[1], reverse=True)
for tag, avg, cnt in tag_avg:
if len(suggestions) >= count:
break
existing_topics.add(tag)
suggestions.append({
"topic": tag,
"reason": f"高互动标签({cnt} 次出现,均分 {avg})",
"confidence": "中" if cnt >= 3 else "低",
"avg_score": avg,
"sample_count": cnt,
})
if not suggestions:
output_success({
"message": "暂无足够数据生成话题推荐,建议先记录更多内容表现",
"suggestions": [],
})
return
output_success({
"message": f"为你推荐 {len(suggestions)} 个话题",
"suggestions": suggestions,
})
# ============================================================
# 操作:推荐发布时间
# ============================================================
def suggest_timing(data: Optional[Dict[str, Any]] = None) -> None:
"""推荐各平台最佳发布时间。
可选字段: platform
基于历史发布时间和互动数据,推荐最优发布时段和星期。
Args:
data: 可选的过滤条件字典。
"""
learning = _get_learning_data()
performances = learning["performances"]
platform_filter = data.get("platform") if data else None
if platform_filter:
performances = [p for p in performances if p.get("platform") == platform_filter]
if not performances:
output_error("暂无历史表现数据,无法推荐发布时间", code="NO_DATA")
return
# 按平台分组分析
platform_groups = _group_by(performances, "platform")
recommendations = {}
for platform, records in platform_groups.items():
# 按小时分析
hour_data: Dict[int, List[float]] = defaultdict(list)
for r in records:
hour = _extract_hour(r.get("posting_time", ""))
if hour is not None:
hour_data[hour].append(r.get("engagement_score", 0))
best_hours = []
for hour, scores in hour_data.items():
avg = round(sum(scores) / len(scores), 2)
best_hours.append({"hour": hour, "time": f"{hour:02d}:00", "avg_score": avg, "count": len(scores)})
best_hours.sort(key=lambda x: x["avg_score"], reverse=True)
# 按星期分析
weekday_data: Dict[int, List[float]] = defaultdict(list)
for r in records:
wd = _extract_weekday(r.get("posting_time", ""))
if wd is not None:
weekday_data[wd].append(r.get("engagement_score", 0))
best_weekdays = []
for wd, scores in weekday_data.items():
avg = round(sum(scores) / len(scores), 2)
best_weekdays.append({
"weekday": wd,
"weekday_name": _WEEKDAY_NAMES[wd],
"avg_score": avg,
"count": len(scores),
})
best_weekdays.sort(key=lambda x: x["avg_score"], reverse=True)
# 生成推荐
rec = {
"platform": platform,
"total_records": len(records),
"best_hours": best_hours[:3],
"best_weekdays": best_weekdays[:3],
"recommendation": "",
}
# 组合推荐语
parts = []
if best_hours:
top_hour = best_hours[0]
parts.append(f"建议在 {top_hour['time']} 左右发布(均分 {top_hour['avg_score']})")
if best_weekdays:
top_wd = best_weekdays[0]
parts.append(f"{top_wd['weekday_name']}效果最佳(均分 {top_wd['avg_score']})")
rec["recommendation"] = ";".join(parts) if parts else "数据不足,建议持续记录"
recommendations[platform] = rec
output_success({
"message": f"已分析 {len(recommendations)} 个平台的最佳发布时间",
"recommendations": recommendations,
})
# ============================================================
# 操作:统计面板
# ============================================================
def stats(data: Optional[Dict[str, Any]] = None) -> None:
"""内容表现统计面板。
可选字段: platform, limit(排行数量,默认 5)
展示:
- 总体统计(内容数、平均得分、平均互动率)
- 最佳内容 Top N
- 最差内容 Top N
- 各平台平均表现
- 各话题平均表现
Args:
data: 可选的过滤条件字典。
"""
learning = _get_learning_data()
performances = learning["performances"]
limit = data.get("limit", 5) if data else 5
platform_filter = data.get("platform") if data else None
if platform_filter:
performances = [p for p in performances if p.get("platform") == platform_filter]
if not performances:
output_error("暂无历史表现数据", code="NO_DATA")
return
# 总体统计
total_count = len(performances)
avg_score_all = _avg_score(performances)
avg_rate_all = _avg_rate(performances)
# 排序获取最佳和最差
sorted_by_score = sorted(performances, key=lambda x: x.get("engagement_score", 0), reverse=True)
best_posts = []
for p in sorted_by_score[:limit]:
best_posts.append({
"content_id": p.get("content_id", ""),
"title": p.get("title", ""),
"platform": p.get("platform", ""),
"topic": p.get("topic", ""),
"engagement_score": p.get("engagement_score", 0),
"engagement_rate": p.get("engagement_rate", 0),
})
worst_posts = []
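# Guard against limit <= 0: sorted_by_score[-0:] would return the whole list
for p in (sorted_by_score[-limit:] if limit > 0 else []):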
worst_posts.append({
"content_id": p.get("content_id", ""),
"title": p.get("title", ""),
"platform": p.get("platform", ""),
"topic": p.get("topic", ""),
"engagement_score": p.get("engagement_score", 0),
"engagement_rate": p.get("engagement_rate", 0),
})
worst_posts.reverse()
# 按平台统计
platform_groups = _group_by(performances, "platform")
platform_stats = []
for platform, records in platform_groups.items():
platform_stats.append({
"platform": platform,
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
platform_stats.sort(key=lambda x: x["avg_score"], reverse=True)
# 按话题统计
topic_groups = _group_by(performances, "topic")
topic_stats = []
for topic, records in topic_groups.items():
if not topic or topic == "未知":
continue
topic_stats.append({
"topic": topic,
"count": len(records),
"avg_score": _avg_score(records),
"avg_rate": _avg_rate(records),
})
topic_stats.sort(key=lambda x: x["avg_score"], reverse=True)
dashboard = {
"overview": {
"total_records": total_count,
"avg_engagement_score": avg_score_all,
"avg_engagement_rate": avg_rate_all,
"first_record": learning["metadata"].get("first_record_at", ""),
"last_record": learning["metadata"].get("last_record_at", ""),
},
"best_posts": best_posts,
"worst_posts": worst_posts,
"by_platform": platform_stats,
"by_topic": topic_stats[:10],
"preferences": learning["preferences"],
}
output_success({
"message": f"内容表现统计面板(共 {total_count} 条记录)",
"dashboard": dashboard,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 自学习内容智能")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"record-performance": lambda: record_performance(data or {}),
"record-preference": lambda: record_preference(data or {}),
"analyze": lambda: analyze(data),
"suggest-topic": lambda: suggest_topic(data),
"suggest-timing": lambda: suggest_timing(data),
"stats": lambda: stats(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
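As a worked example of the scoring that `record-performance` stores with each record, the sketch below reimplements the weighted score and the engagement rate in standalone form, using the same weights as `_calc_engagement_score`. The function name `engagement_summary` and the sample metrics are illustrative assumptions, and the sketch expects numeric metric values only.

```python
from typing import Dict, Tuple

# Weights copied from _calc_engagement_score
WEIGHTS = {
    "impressions": 0.1, "views": 0.1, "reads": 0.1,
    "likes": 1.0, "claps": 1.0, "favorites": 1.0,
    "comments": 2.0, "replies": 2.0, "responses": 2.0,
    "retweets": 3.0, "shares": 3.0,
}
EXPOSURE_KEYS = ("impressions", "views", "reads")


def engagement_summary(metrics: Dict[str, float]) -> Tuple[float, float]:
    """Return (weighted score, engagement rate in percent) for one post."""
    score = sum(v * WEIGHTS[k] for k, v in metrics.items() if k in WEIGHTS)
    exposure = sum(metrics.get(k, 0) for k in EXPOSURE_KEYS)
    # Interactions are every weighted metric that is not an exposure metric
    interactions = sum(
        v for k, v in metrics.items()
        if k in WEIGHTS and k not in EXPOSURE_KEYS
    )
    rate = (interactions / exposure) * 100 if exposure else 0.0
    return round(score, 2), round(rate, 2)


sample = {"impressions": 1000, "likes": 50, "comments": 10, "retweets": 5}
score, rate = engagement_summary(sample)
print(score, rate)
```

For the sample post the score is 1000 × 0.1 + 50 × 1.0 + 10 × 2.0 + 5 × 3.0 = 185, and the rate is 65 interactions over 1000 impressions, i.e. 6.5%. Note that exposure contributes to the score but only serves as the denominator of the rate, so a post with huge reach and little interaction can score high while its rate stays low.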
FILE:scripts/image_prompter.py
#!/usr/bin/env python3
"""
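content-engine AI image prompt generation module
Generates AI image prompts for content, in styles suited to Midjourney, DALL-E, and Stable Diffusion.
Also provides image placement suggestions, SEO-friendly alt text generation, and visual content planning.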
"""
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple
from utils import (
load_input_data,
output_error,
output_success,
parse_common_args,
)
# ============================================================
# 平台图片尺寸规格
# ============================================================
PLATFORM_IMAGE_SPECS = {
"twitter": {
"card": {"width": 1200, "height": 675, "ratio": "16:9", "label": "Twitter Card"},
"profile": {"width": 400, "height": 400, "ratio": "1:1", "label": "头像"},
"banner": {"width": 1500, "height": 500, "ratio": "3:1", "label": "横幅"},
"in_stream": {"width": 1200, "height": 675, "ratio": "16:9", "label": "信息流图片"},
},
"linkedin": {
"post": {"width": 1200, "height": 627, "ratio": "1.91:1", "label": "帖子配图"},
"banner": {"width": 1128, "height": 191, "ratio": "5.9:1", "label": "公司横幅"},
"article": {"width": 744, "height": 400, "ratio": "1.86:1", "label": "文章封面"},
},
"wechat": {
"cover": {"width": 900, "height": 383, "ratio": "2.35:1", "label": "公众号封面"},
"small_cover": {"width": 200, "height": 200, "ratio": "1:1", "label": "小图封面"},
"article_image": {"width": 1080, "height": 720, "ratio": "3:2", "label": "文章配图"},
},
"blog": {
"hero": {"width": 1200, "height": 630, "ratio": "1.91:1", "label": "Hero 大图"},
"content": {"width": 800, "height": 450, "ratio": "16:9", "label": "正文配图"},
"thumbnail": {"width": 400, "height": 300, "ratio": "4:3", "label": "缩略图"},
},
"medium": {
"feature": {"width": 1400, "height": 788, "ratio": "16:9", "label": "特色图片"},
"content": {"width": 700, "height": 394, "ratio": "16:9", "label": "正文配图"},
},
}
# 图片风格模板
IMAGE_STYLE_TEMPLATES = {
"professional": {
"cn": "专业商务风格,简洁干净的背景,现代扁平设计",
"en": "professional business style, clean minimal background, modern flat design",
},
"tech": {
"cn": "科技感十足,蓝紫色调,数字化元素,未来感",
"en": "tech-inspired, blue-purple color scheme, digital elements, futuristic",
},
"creative": {
"cn": "创意艺术风格,大胆用色,抽象元素",
"en": "creative artistic style, bold colors, abstract elements",
},
"minimal": {
"cn": "极简风格,大量留白,少量元素,优雅排版",
"en": "minimalist style, lots of white space, few elements, elegant typography",
},
"warm": {
"cn": "温暖亲和风格,暖色调,自然光影,生活化场景",
"en": "warm and friendly style, warm tones, natural lighting, lifestyle scene",
},
"illustration": {
"cn": "手绘插画风格,线条流畅,色彩丰富",
"en": "hand-drawn illustration style, smooth lines, rich colors",
},
}
# ============================================================
# 辅助函数
# ============================================================
def _extract_key_concepts(text: str) -> List[str]:
"""从文本中提取关键概念词。
通过简单的规则提取:
- 粗体文本 **keyword**
- 标题文本 # heading
- 引号内文本 「keyword」
- 频繁出现的名词短语
Args:
text: 内容正文。
Returns:
关键概念词列表。
"""
concepts = []
# 提取粗体文本
bold = re.findall(r"\*\*(.+?)\*\*", text)
concepts.extend(bold[:5])
# 提取标题
headings = re.findall(r"^#{1,3}\s+(.+)$", text, re.MULTILINE)
concepts.extend(headings[:5])
# 提取引号内容
quoted_cn = re.findall(r"[「『](.+?)[」』]", text)
quoted_en = re.findall(r'"(.+?)"', text)
concepts.extend(quoted_cn[:3])
concepts.extend(quoted_en[:3])
# 去重保序
seen = set()
unique = []
for c in concepts:
c = c.strip()
if c and c not in seen:
seen.add(c)
unique.append(c)
return unique[:10]
def _detect_content_theme(text: str, title: str = "") -> str:
"""检测内容主题以选择图片风格。
Args:
text: 内容正文。
title: 内容标题。
Returns:
风格标识符(如 "tech", "professional" 等)。
"""
combined = (title + " " + text).lower()
# 主题关键词映射
theme_keywords = {
"tech": ["ai", "人工智能", "编程", "代码", "开发", "技术", "api",
"算法", "数据", "机器学习", "深度学习", "cloud", "云",
"python", "javascript", "react", "agent", "llm"],
"professional": ["商业", "管理", "策略", "营销", "品牌", "企业",
"领导力", "团队", "business", "marketing", "strategy",
"leadership", "management"],
"creative": ["设计", "创意", "艺术", "灵感", "design", "creative",
"art", "inspiration", "ui", "ux"],
"warm": ["生活", "成长", "分享", "故事", "心得", "感悟",
"lifestyle", "growth", "story", "personal"],
"minimal": ["效率", "极简", "工具", "方法", "productivity",
"minimal", "tools", "workflow"],
}
scores: Dict[str, int] = {}
for theme, keywords in theme_keywords.items():
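# Match ASCII keywords on word boundaries so that short ones such as "ai" or "ui"
# do not fire inside unrelated words ("email", "build"); CJK keywords keep substring matching
score = sum(
1 for kw in keywords
if (re.search(rf"(?<![a-z0-9]){re.escape(kw)}(?![a-z0-9])", combined)
if kw.isascii() else kw in combined)
)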
scores[theme] = score
if scores:
best = max(scores, key=lambda k: scores[k])
if scores[best] > 0:
return best
return "professional"
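# Illustrative sketch, not part of the original module: the theme scoring
# above matches keywords as plain substrings, so short ASCII keywords such as
# "ai" or "ui" also hit inside words like "email" or "build". One possible
# refinement is to match ASCII keywords as whole words and keep substring
# matching for CJK keywords. `keyword_hits` is a hypothetical helper name.

```python
import re
from typing import Iterable


def keyword_hits(keywords: Iterable[str], text: str) -> int:
    """Count keywords found in text, matching ASCII keywords as whole words."""
    lowered = text.lower()
    hits = 0
    for kw in keywords:
        if kw.isascii():
            # ASCII keywords must not be glued to other ASCII letters or digits.
            # Explicit lookarounds are used instead of \b so that "ai驱动" still
            # counts (CJK characters are word characters for \b).
            pattern = r"(?<![a-z0-9])" + re.escape(kw) + r"(?![a-z0-9])"
            if re.search(pattern, lowered):
                hits += 1
        elif kw in lowered:
            # CJK text has no word separators, so keep substring matching.
            hits += 1
    return hits
```

# With this helper, "Please email the build log" scores 0 for ["ai", "ui"],
# while "AI 与人工智能" still scores 2 for ["ai", "人工智能"].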
def _build_image_prompt(
description: str,
style: str = "professional",
aspect_ratio: str = "16:9",
language: str = "both",
) -> Dict[str, str]:
"""构建 AI 图片生成提示词。
Args:
description: 图片内容描述。
style: 风格标识符。
aspect_ratio: 宽高比。
language: 输出语言("cn", "en", "both")。
Returns:
包含中英文提示词的字典。
"""
style_info = IMAGE_STYLE_TEMPLATES.get(style, IMAGE_STYLE_TEMPLATES["professional"])
# 英文提示词(适配 Midjourney/DALL-E/SD)
en_prompt = (
f"{description}, {style_info['en']}, "
f"high quality, detailed, {aspect_ratio} aspect ratio, "
f"4k resolution, sharp focus"
)
# 中文提示词
cn_prompt = (
f"{description},{style_info['cn']},"
f"高质量,细节丰富,{aspect_ratio} 比例"
)
result = {}
if language in ("en", "both"):
result["en"] = en_prompt
if language in ("cn", "both"):
result["cn"] = cn_prompt
return result
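# Illustrative sketch, not part of the original module: the shape of the
# prompts produced by _build_image_prompt, reproduced with a one-entry
# stand-in for IMAGE_STYLE_TEMPLATES. `STYLE_TEMPLATES` and
# `build_prompt_demo` are names assumed for this example.

```python
from typing import Dict

# Stand-in for IMAGE_STYLE_TEMPLATES, holding only the "illustration" entry.
STYLE_TEMPLATES = {
    "illustration": {
        "cn": "手绘插画风格,线条流畅,色彩丰富",
        "en": "hand-drawn illustration style, smooth lines, rich colors",
    },
}


def build_prompt_demo(description: str, style: str, aspect_ratio: str = "16:9") -> Dict[str, str]:
    """Assemble the English and Chinese prompts the same way as above."""
    info = STYLE_TEMPLATES[style]
    return {
        "en": (
            f"{description}, {info['en']}, high quality, detailed, "
            f"{aspect_ratio} aspect ratio, 4k resolution, sharp focus"
        ),
        "cn": f"{description},{info['cn']},高质量,细节丰富,{aspect_ratio} 比例",
    }


prompts = build_prompt_demo("A team dashboard", "illustration", "4:3")
print(prompts["en"])
```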
# ============================================================
# 操作:生成图片提示词
# ============================================================
def generate_prompt(data: Dict[str, Any]) -> None:
"""根据内容文本生成 AI 图片生成提示词。
必填字段: text(内容文本或描述)
可选字段: style(风格), platform(目标平台), type(图片类型), title
Args:
data: 包含内容文本和选项的字典。
"""
text = data.get("text", "")
if not text:
output_error("内容文本(text)为必填字段", code="VALIDATION_ERROR")
return
title = data.get("title", "")
style = data.get("style", "")
platform = data.get("platform", "").lower()
img_type = data.get("type", "content")
# 自动检测风格
if not style:
style = _detect_content_theme(text, title)
# 提取关键概念
concepts = _extract_key_concepts(text)
# 获取平台图片规格
specs = {}
if platform and platform in PLATFORM_IMAGE_SPECS:
platform_specs = PLATFORM_IMAGE_SPECS[platform]
if img_type in platform_specs:
specs = platform_specs[img_type]
else:
# 使用第一个可用规格
first_key = list(platform_specs.keys())[0]
specs = platform_specs[first_key]
aspect_ratio = specs.get("ratio", "16:9")
# 构建描述
description = title if title else text[:100]
if concepts:
concept_str = ", ".join(concepts[:3])
description = f"{description} — featuring: {concept_str}"
prompt = _build_image_prompt(description, style, aspect_ratio)
result = {
"message": "已生成图片提示词",
"prompts": prompt,
"style": style,
"style_description": IMAGE_STYLE_TEMPLATES.get(style, {}),
"key_concepts": concepts,
}
if specs:
result["recommended_size"] = specs
result["platform"] = platform
output_success(result)
# ============================================================
# 操作:建议图片位置
# ============================================================
def suggest_images(data: Dict[str, Any]) -> None:
"""分析内容文本,建议在哪些位置放置图片以及图片内容。
必填字段: text(完整内容文本)
可选字段: title, style, max_suggestions(最大建议数,默认 5)
Args:
data: 包含内容文本的字典。
"""
text = data.get("text", "")
if not text:
output_error("内容文本(text)为必填字段", code="VALIDATION_ERROR")
return
title = data.get("title", "")
style = data.get("style", "") or _detect_content_theme(text, title)
max_suggestions = data.get("max_suggestions", 5)
suggestions = []
# 1. 文章开头 — Hero 图片
suggestions.append({
"position": "文章开头",
"position_type": "hero",
"reason": "吸引读者注意力,建立文章视觉基调",
"description": f"文章主题「{title or text[:30]}」的概念图",
"prompts": _build_image_prompt(
f"Hero image for article about {title or text[:50]}",
style, "16:9",
),
})
# 2. 按段落/标题分析内容,在关键段落后建议配图
sections = re.split(r"(?=^#{1,3}\s+)", text, flags=re.MULTILINE)
section_index = 0
for section in sections:
if len(suggestions) >= max_suggestions:
break
section = section.strip()
if not section:
continue
# 提取段落标题
heading_match = re.match(r"^#{1,3}\s+(.+)$", section, re.MULTILINE)
if heading_match:
heading = heading_match.group(1).strip()
section_body = section[heading_match.end():].strip()
# 检查段落是否足够长(值得配图)
if len(section_body) > 100:
section_concepts = _extract_key_concepts(section_body)
desc = f"Illustration for section: {heading}"
if section_concepts:
desc += f" — {', '.join(section_concepts[:2])}"
suggestions.append({
"position": f"「{heading}」章节后",
"position_type": "section",
"reason": f"为「{heading}」章节提供视觉解释",
"description": desc,
"prompts": _build_image_prompt(desc, style, "16:9"),
})
section_index += 1
# 3. 检查是否有列表/步骤(适合信息图)
list_items = re.findall(r"^\d+\.\s+(.+)$", text, re.MULTILINE)
if len(list_items) >= 3 and len(suggestions) < max_suggestions:
items_str = ", ".join(list_items[:5])
suggestions.append({
"position": "步骤/列表区域",
"position_type": "infographic",
"reason": "将步骤或列表可视化为信息图",
"description": f"Infographic showing steps: {items_str}",
"prompts": _build_image_prompt(
f"Clean infographic with numbered steps: {items_str}",
"minimal", "4:3",
),
})
# 4. 代码块区域建议截图/示意图
code_blocks = re.findall(r"```(\w*)\n[\s\S]*?```", text)
if code_blocks and len(suggestions) < max_suggestions:
lang = code_blocks[0] if code_blocks[0] else "code"
suggestions.append({
"position": "代码块附近",
"position_type": "code_illustration",
"reason": "为代码段提供视觉说明或运行效果展示",
"description": f"Technical illustration related to {lang} code",
"prompts": _build_image_prompt(
f"Clean technical diagram or screenshot showing {lang} code output",
"tech", "16:9",
),
})
# 5. 文章结尾 — CTA 图片
if len(suggestions) < max_suggestions:
suggestions.append({
"position": "文章结尾",
"position_type": "cta",
"reason": "在文章结尾添加行动号召配图",
"description": f"Call-to-action image for {title or 'the article'}",
"prompts": _build_image_prompt(
f"Engaging call-to-action image, inviting and motivational",
"warm", "16:9",
),
})
output_success({
"message": f"为文章建议了 {len(suggestions)} 处配图位置",
"style": style,
"suggestions": suggestions[:max_suggestions],
})
# ============================================================
# 操作:生成 alt text
# ============================================================
def format_alt_text(data: Dict[str, Any]) -> None:
"""为图片生成 SEO 友好的 alt text。
必填字段: description(图片描述)或 images(图片描述列表)
可选字段: context(文章上下文), keywords(SEO 关键词)
Args:
data: 包含图片描述的字典。
"""
description = data.get("description", "")
images = data.get("images", [])
context = data.get("context", "")
keywords = data.get("keywords", [])
if isinstance(keywords, str):
keywords = [k.strip() for k in keywords.split(",") if k.strip()]
if not description and not images:
output_error("图片描述(description)或图片列表(images)为必填字段", code="VALIDATION_ERROR")
return
results = []
# 处理单个描述
if description and not images:
images = [{"description": description}]
# 处理图片列表
for img in images:
if isinstance(img, str):
img = {"description": img}
desc = img.get("description", "")
if not desc:
continue
# 生成 alt text
alt_text_cn = _generate_alt_text(desc, keywords, "cn", context)
alt_text_en = _generate_alt_text(desc, keywords, "en", context)
results.append({
"original_description": desc,
"alt_text_cn": alt_text_cn,
"alt_text_en": alt_text_en,
"seo_keywords_included": [k for k in keywords if k.lower() in alt_text_en.lower() or k in alt_text_cn],
})
output_success({
"message": f"已为 {len(results)} 张图片生成 alt text",
"alt_texts": results,
})
def _generate_alt_text(description: str, keywords: List[str], lang: str, context: str = "") -> str:
    """Generate alt text for a single image.

    Rules:
    - Whitespace is collapsed and the description is cut to 100 characters.
    - Up to two SEO keywords are appended after the description.
    - The final string is capped at 125 characters.

    Args:
        description: Image description.
        keywords: SEO keyword list.
        lang: Language, "cn" or "en"; it only affects the keyword separator.
        context: Article context (currently unused).

    Returns:
        The alt text string.
    """
    # Normalise whitespace
    desc = re.sub(r"\s+", " ", description.strip())
    # Truncate to a reasonable length
    if len(desc) > 100:
        desc = desc[:97] + "..."
    # Append keywords when given
    if keywords:
        separator = "、" if lang == "cn" else ", "
        alt = f"{desc} — {separator.join(keywords[:2])}"
    else:
        alt = desc
    # Enforce the overall cap
    if len(alt) > 125:
        alt = alt[:122] + "..."
    return alt
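# Illustrative sketch, not part of the original module: with the rules above,
# a long description plus keywords can exceed the 125-character cap, and the
# final truncation then cuts into the keywords. One alternative is to reserve
# room for the keyword suffix first and shorten only the description.
# `alt_text_with_budget` is a hypothetical helper.

```python
from typing import List


def alt_text_with_budget(description: str, keywords: List[str], limit: int = 125) -> str:
    """Build alt text that keeps the keyword suffix intact within `limit`."""
    desc = " ".join(description.split())  # collapse runs of whitespace
    suffix = " — " + ", ".join(keywords[:2]) if keywords else ""
    room = limit - len(suffix)
    if room < 4:
        # The keywords alone nearly fill the budget; drop them instead.
        suffix, room = "", limit
    if len(desc) > room:
        desc = desc[: room - 3] + "..."
    return desc + suffix


long_description = "word " * 40
result = alt_text_with_budget(long_description, ["SEO", "analytics"])
print(len(result), result[-20:])
```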
# ============================================================
# 操作:视觉内容规划
# ============================================================
def image_plan(data: Dict[str, Any]) -> None:
"""为文章创建完整的视觉内容规划。
必填字段: text(文章正文)
可选字段: title, platforms(目标平台列表), style
为每个目标平台生成:
- Hero 大图提示词
- 章节配图提示词
- 社交媒体分享缩略图提示词
Args:
data: 包含文章内容的字典。
"""
text = data.get("text", "")
if not text:
output_error("文章正文(text)为必填字段", code="VALIDATION_ERROR")
return
title = data.get("title", "")
platforms = data.get("platforms", ["blog"])
style = data.get("style", "") or _detect_content_theme(text, title)
if isinstance(platforms, str):
platforms = [p.strip() for p in platforms.split(",") if p.strip()]
concepts = _extract_key_concepts(text)
plan = {
"title": title or "未命名文章",
"style": style,
"key_concepts": concepts,
"images": [],
}
# 1. Hero 图片(各平台尺寸)
hero_desc = f"Hero image: {title or text[:50]}"
if concepts:
hero_desc += f", featuring {', '.join(concepts[:2])}"
hero_item = {
"name": "Hero 大图",
"type": "hero",
"description": hero_desc,
"prompts": _build_image_prompt(hero_desc, style, "16:9"),
"platform_sizes": {},
}
for p in platforms:
p_lower = p.lower()
if p_lower in PLATFORM_IMAGE_SPECS:
specs = PLATFORM_IMAGE_SPECS[p_lower]
# 选择最适合 hero 的规格
for key in ["hero", "feature", "cover", "card", "post"]:
if key in specs:
hero_item["platform_sizes"][p_lower] = specs[key]
break
plan["images"].append(hero_item)
# 2. 章节配图
headings = re.findall(r"^#{1,3}\s+(.+)$", text, re.MULTILINE)
for i, heading in enumerate(headings[:4]): # 最多 4 个章节配图
section_desc = f"Section illustration: {heading}"
section_item = {
"name": f"章节配图 — {heading}",
"type": "section",
"description": section_desc,
"prompts": _build_image_prompt(section_desc, style, "16:9"),
"platform_sizes": {},
}
for p in platforms:
p_lower = p.lower()
if p_lower in PLATFORM_IMAGE_SPECS:
specs = PLATFORM_IMAGE_SPECS[p_lower]
for key in ["content", "article_image", "in_stream"]:
if key in specs:
section_item["platform_sizes"][p_lower] = specs[key]
break
plan["images"].append(section_item)
# 3. 社交媒体分享缩略图
share_desc = f"Social media thumbnail: {title or text[:30]}"
share_item = {
"name": "社交媒体分享缩略图",
"type": "thumbnail",
"description": share_desc,
"prompts": _build_image_prompt(
f"Eye-catching social media thumbnail, {title or text[:30]}, "
f"clear text space, bold visual",
style, "1.91:1",
),
"platform_sizes": {},
}
for p in platforms:
p_lower = p.lower()
if p_lower in PLATFORM_IMAGE_SPECS:
specs = PLATFORM_IMAGE_SPECS[p_lower]
for key in ["card", "thumbnail", "small_cover", "post"]:
if key in specs:
share_item["platform_sizes"][p_lower] = specs[key]
break
plan["images"].append(share_item)
# 4. 汇总平台规格参考
all_specs = {}
for p in platforms:
p_lower = p.lower()
if p_lower in PLATFORM_IMAGE_SPECS:
all_specs[p_lower] = PLATFORM_IMAGE_SPECS[p_lower]
plan["platform_specs_reference"] = all_specs
plan["total_images"] = len(plan["images"])
output_success({
"message": f"已生成视觉内容规划(共 {plan['total_images']} 张图片)",
"plan": plan,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine AI 配图助手")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"generate-prompt": lambda: generate_prompt(data or {}),
"suggest-images": lambda: suggest_images(data or {}),
"format-alt-text": lambda: format_alt_text(data or {}),
"image-plan": lambda: image_plan(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/metrics_collector.py
#!/usr/bin/env python3
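"""
content-engine metrics collection module
Collects content performance data from each platform and generates comparison reports and charts.
This module is a paid-tier feature.
"""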
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError
from utils import (
check_subscription,
format_platform_name,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
validate_platform,
write_json_file,
PLATFORMS,
)
# 延迟导入学习引擎
_learning_engine = None
def _get_learning_engine():
"""延迟导入 learning_engine 模块。"""
global _learning_engine
if _learning_engine is None:
try:
import learning_engine as _learning_engine
except ImportError:
_learning_engine = None
return _learning_engine
def _auto_record_to_learning(content_id: str, platform: str,
metrics: Dict[str, Any],
content: Optional[Dict] = None) -> None:
"""采集指标后自动记录到学习引擎。
Args:
content_id: 内容 ID。
platform: 平台标识。
metrics: 采集到的指标数据。
content: 内容详情(可选)。
"""
le = _get_learning_engine()
if le is None:
return
if "error" in metrics:
return
try:
record_data = {
"content_id": content_id,
"platform": platform,
"metrics": metrics,
}
if content:
record_data["title"] = content.get("title", "")
record_data["tags"] = content.get("tags", [])
record_data["topic"] = (content.get("tags") or [""])[0]
record_data["posting_time"] = content.get("published_at", "")
record_data["length"] = len(content.get("body", ""))
le.record_performance(record_data)
except Exception:
pass # 记录失败不影响主流程
# ============================================================
# 数据文件路径
# ============================================================
METRICS_FILE = "metrics.json"
CONTENTS_FILE = "contents.json"
PUBLISH_HISTORY_FILE = "publish_history.json"
def _get_metrics() -> List[Dict[str, Any]]:
"""读取所有指标数据。"""
return read_json_file(get_data_file(METRICS_FILE))
def _save_metrics(metrics: List[Dict[str, Any]]) -> None:
"""保存指标数据。"""
write_json_file(get_data_file(METRICS_FILE), metrics)
def _get_contents() -> List[Dict[str, Any]]:
"""读取所有内容数据。"""
return read_json_file(get_data_file(CONTENTS_FILE))
def _find_content(contents: List[Dict], content_id: str) -> Optional[Dict]:
"""根据 ID 查找内容。"""
for c in contents:
if c.get("id") == content_id:
return c
return None
def _get_publish_history() -> List[Dict[str, Any]]:
"""读取发布历史。"""
return read_json_file(get_data_file(PUBLISH_HISTORY_FILE))
# ============================================================
# 平台指标定义
# ============================================================
# 各平台指标字段
PLATFORM_METRICS = {
"twitter": ["likes", "retweets", "replies", "impressions"],
"linkedin": ["likes", "comments", "shares", "views"],
"wechat": ["reads", "shares", "favorites"],
"medium": ["reads", "claps", "responses"],
"blog": [], # 博客不采集指标
}
# 指标中文名映射
METRIC_NAMES = {
"likes": "点赞",
"retweets": "转发",
"replies": "回复",
"impressions": "曝光",
"comments": "评论",
"shares": "分享",
"views": "浏览",
"reads": "阅读",
"claps": "鼓掌",
"responses": "回应",
"favorites": "收藏",
}
# ============================================================
# 平台指标采集
# ============================================================
def _api_get(url: str, headers: Optional[Dict] = None) -> Dict[str, Any]:
    """Send a GET request and return the decoded JSON response.

    Args:
        url: Request URL.
        headers: Request headers.

    Returns:
        Response data as a dict.

    Raises:
        RuntimeError: On an HTTP error status or a network failure.
    """
    req = Request(url, headers=headers or {}, method="GET")
    try:
        with urlopen(req, timeout=30) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except HTTPError as e:
        # Decode leniently: error bodies are not guaranteed to be valid UTF-8
        error_body = e.read().decode("utf-8", errors="replace") if e.fp else ""
        raise RuntimeError(f"HTTP {e.code}: {error_body}") from e
    except URLError as e:
        raise RuntimeError(f"网络请求失败: {e.reason}") from e
def _collect_twitter_metrics(publish_result: Dict[str, Any]) -> Dict[str, Any]:
"""采集 Twitter 指标数据。
Args:
publish_result: 发布结果(包含 tweet_ids)。
Returns:
指标数据字典。
"""
bearer_token = os.environ.get("CE_TWITTER_BEARER_TOKEN", "")
if not bearer_token:
return {"error": "未配置 CE_TWITTER_BEARER_TOKEN"}
tweet_ids = publish_result.get("tweet_ids", [])
if not tweet_ids:
return {"error": "无推文 ID"}
headers = {"Authorization": f"Bearer {bearer_token}"}
total_metrics = {"likes": 0, "retweets": 0, "replies": 0, "impressions": 0}
for tweet_id in tweet_ids:
try:
url = (
f"https://api.twitter.com/2/tweets/{tweet_id}"
f"?tweet.fields=public_metrics"
)
data = _api_get(url, headers)
metrics = data.get("data", {}).get("public_metrics", {})
total_metrics["likes"] += metrics.get("like_count", 0)
total_metrics["retweets"] += metrics.get("retweet_count", 0)
total_metrics["replies"] += metrics.get("reply_count", 0)
total_metrics["impressions"] += metrics.get("impression_count", 0)
except Exception as e:
return {"error": f"采集推文 {tweet_id} 指标失败: {str(e)}"}
return total_metrics
def _collect_linkedin_metrics(publish_result: Dict[str, Any]) -> Dict[str, Any]:
"""采集 LinkedIn 指标数据。
Args:
publish_result: 发布结果(包含 post_id)。
Returns:
指标数据字典。
"""
access_token = os.environ.get("CE_LINKEDIN_ACCESS_TOKEN", "")
if not access_token:
return {"error": "未配置 CE_LINKEDIN_ACCESS_TOKEN"}
post_id = publish_result.get("post_id", "")
if not post_id:
return {"error": "无帖子 ID"}
headers = {"Authorization": f"Bearer {access_token}"}
try:
url = f"https://api.linkedin.com/v2/socialActions/{post_id}"
data = _api_get(url, headers)
return {
"likes": data.get("likesSummary", {}).get("totalLikes", 0),
"comments": data.get("commentsSummary", {}).get("totalFirstLevelComments", 0),
"shares": data.get("sharesSummary", {}).get("totalShares", 0) if "sharesSummary" in data else 0,
"views": 0, # LinkedIn API 不直接提供浏览量
}
except Exception as e:
return {"error": f"采集 LinkedIn 指标失败: {str(e)}"}
def _collect_wechat_metrics(publish_result: Dict[str, Any]) -> Dict[str, Any]:
"""采集微信公众号指标数据。
Args:
publish_result: 发布结果(包含 publish_id)。
Returns:
指标数据字典。
"""
appid = os.environ.get("CE_WECHAT_APPID", "")
secret = os.environ.get("CE_WECHAT_SECRET", "")
if not appid or not secret:
return {"error": "未配置 CE_WECHAT_APPID 和 CE_WECHAT_SECRET"}
# 获取 access_token
try:
token_url = (
f"https://api.weixin.qq.com/cgi-bin/token?"
f"grant_type=client_credential&appid={appid}&secret={secret}"
)
req = Request(token_url, method="GET")
with urlopen(req, timeout=30) as resp:
token_data = json.loads(resp.read().decode("utf-8"))
if "access_token" not in token_data:
return {"error": f"获取 access_token 失败: {token_data.get('errmsg', '')}"}
access_token = token_data["access_token"]
except Exception as e:
return {"error": f"获取微信 access_token 失败: {str(e)}"}
# 获取文章数据统计
publish_id = publish_result.get("publish_id", "")
if not publish_id:
return {"error": "无 publish_id"}
try:
url = (
f"https://api.weixin.qq.com/cgi-bin/freepublish/getarticle?"
f"access_token={access_token}"
)
req = Request(
url,
data=json.dumps({"publish_id": publish_id}).encode("utf-8"),
headers={"Content-Type": "application/json"},
method="POST",
)
with urlopen(req, timeout=30) as resp:
data = json.loads(resp.read().decode("utf-8"))
# 解析文章统计信息(微信 API 返回格式)
return {
"reads": data.get("read_num", 0),
"shares": data.get("share_num", 0),
"favorites": data.get("collect_num", 0),
}
except Exception as e:
return {"error": f"采集微信指标失败: {str(e)}"}
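# Illustrative sketch, not part of the original module: the token URL above
# interpolates appid and secret directly into the query string, so a value
# containing "&" or "=" would corrupt the request. A safer way to build the
# same URL is urllib.parse.urlencode. `build_token_url` and `TOKEN_ENDPOINT`
# are names assumed for this example; the endpoint is the one used above.

```python
from urllib.parse import urlencode

TOKEN_ENDPOINT = "https://api.weixin.qq.com/cgi-bin/token"


def build_token_url(appid: str, secret: str) -> str:
    """Return the token URL with every query value percent-encoded."""
    query = urlencode({
        "grant_type": "client_credential",
        "appid": appid,
        "secret": secret,
    })
    return f"{TOKEN_ENDPOINT}?{query}"


print(build_token_url("wx123", "a&b=c"))
```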
def _collect_medium_metrics(publish_result: Dict[str, Any]) -> Dict[str, Any]:
"""采集 Medium 指标数据。
注意: Medium API 对指标数据的支持有限。
Args:
publish_result: 发布结果(包含 post_id, url)。
Returns:
指标数据字典。
"""
# Medium API 不直接提供详细指标,返回基本信息
return {
"reads": 0,
"claps": 0,
"responses": 0,
"note": "Medium API 暂不支持实时指标采集,请通过 Medium 后台查看",
}
# 平台指标采集器注册表
_COLLECTORS = {
"twitter": _collect_twitter_metrics,
"linkedin": _collect_linkedin_metrics,
"wechat": _collect_wechat_metrics,
"medium": _collect_medium_metrics,
}
# ============================================================
# 指标操作
# ============================================================
def collect_metrics(data: Dict[str, Any]) -> None:
"""采集内容在各平台的表现指标。
必填字段: content_id
可选字段: platform(不指定则采集所有已发布平台)
Args:
data: 包含内容 ID 和可选平台的字典。
"""
if not require_paid_feature("metrics", "指标采集"):
return
content_id = data.get("content_id") or data.get("id")
if not content_id:
output_error("内容ID(content_id)为必填字段", code="VALIDATION_ERROR")
return
# 查找内容的发布历史
history = _get_publish_history()
content_history = [h for h in history if h.get("content_id") == content_id]
if not content_history:
output_error(f"内容 {content_id} 暂无发布记录", code="NOT_FOUND")
return
# 按平台过滤
platform_filter = data.get("platform")
if platform_filter:
try:
platform_filter = validate_platform(platform_filter)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
content_history = [h for h in content_history if h.get("platform") == platform_filter]
# 采集各平台指标
metrics_data = _get_metrics()
results = []
for record in content_history:
platform = record.get("platform", "")
publish_result = record.get("result", {})
if not publish_result.get("success"):
continue
collector = _COLLECTORS.get(platform)
if not collector:
continue
metrics = collector(publish_result)
metric_record = {
"id": generate_id("MT"),
"content_id": content_id,
"platform": platform,
"platform_name": format_platform_name(platform),
"metrics": metrics,
"collected_at": now_iso(),
}
# 更新或添加指标记录
metrics_data = [
m for m in metrics_data
if not (m.get("content_id") == content_id and m.get("platform") == platform)
]
metrics_data.append(metric_record)
results.append(metric_record)
# 自动记录到学习引擎
content_detail = _find_content(_get_contents(), content_id)
_auto_record_to_learning(content_id, platform, metrics, content_detail)
_save_metrics(metrics_data)
output_success({
"message": f"已采集 {len(results)} 个平台的指标数据",
"content_id": content_id,
"metrics": results,
})
def generate_report(data: Dict[str, Any]) -> None:
"""生成内容表现报告。
必填字段: content_id
Args:
data: 包含内容 ID 的字典。
"""
if not require_paid_feature("metrics", "指标报告"):
return
content_id = data.get("content_id") or data.get("id")
if not content_id:
output_error("内容ID(content_id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
metrics_data = _get_metrics()
content_metrics = [m for m in metrics_data if m.get("content_id") == content_id]
if not content_metrics:
output_error(f"内容 {content_id} 暂无指标数据,请先执行 collect 操作", code="NOT_FOUND")
return
# 构建报告
report_lines = []
report_lines.append("# 内容表现报告")
report_lines.append("")
report_lines.append(f"**标题**: {content.get('title', '')}")
report_lines.append(f"**状态**: {content.get('status', '')}")
report_lines.append(f"**创建时间**: {content.get('created_at', '')}")
report_lines.append(f"**发布时间**: {content.get('published_at', '')}")
report_lines.append("")
# 跨平台对比表格
report_lines.append("## 跨平台数据对比")
report_lines.append("")
# 收集所有指标名称
all_metric_keys = set()
for m in content_metrics:
metrics = m.get("metrics", {})
if "error" not in metrics:
all_metric_keys.update(metrics.keys())
# 移除非数值字段
all_metric_keys.discard("note")
all_metric_keys.discard("error")
if all_metric_keys:
sorted_keys = sorted(all_metric_keys)
header = "| 指标 | " + " | ".join(
format_platform_name(m.get("platform", ""))
for m in content_metrics
if "error" not in m.get("metrics", {})
) + " |"
separator = "|------|" + "|".join(
"------:" for m in content_metrics
if "error" not in m.get("metrics", {})
) + "|"
report_lines.append(header)
report_lines.append(separator)
for key in sorted_keys:
row = f"| {METRIC_NAMES.get(key, key)} |"
for m in content_metrics:
metrics = m.get("metrics", {})
if "error" not in metrics:
val = metrics.get(key, "-")
row += f" {val} |"
report_lines.append(row)
report_lines.append("")
# 各平台详情
report_lines.append("## 各平台详情")
report_lines.append("")
for m in content_metrics:
platform = m.get("platform", "")
metrics = m.get("metrics", {})
collected_at = m.get("collected_at", "")
report_lines.append(f"### {format_platform_name(platform)}")
report_lines.append(f"*采集时间: {collected_at}*")
report_lines.append("")
if "error" in metrics:
report_lines.append(f"> 采集失败: {metrics['error']}")
else:
for key, value in metrics.items():
if key != "note":
report_lines.append(f"- **{METRIC_NAMES.get(key, key)}**: {value}")
if "note" in metrics:
report_lines.append(f"\n> {metrics['note']}")
report_lines.append("")
report = "\n".join(report_lines)
output_success({
"message": "指标报告已生成",
"content_id": content_id,
"report": report,
})
def compare_metrics(data: Dict[str, Any]) -> None:
"""对比多条内容的表现指标。
必填字段: content_ids(内容 ID 列表)
Args:
data: 包含内容 ID 列表的字典。
"""
if not require_paid_feature("metrics", "指标对比"):
return
content_ids = data.get("content_ids", [])
if not content_ids or len(content_ids) < 2:
output_error("需要至少 2 个内容ID进行对比", code="VALIDATION_ERROR")
return
contents = _get_contents()
metrics_data = _get_metrics()
comparison = []
for cid in content_ids:
content = _find_content(contents, cid)
if not content:
continue
content_metrics = [m for m in metrics_data if m.get("content_id") == cid]
# 汇总各平台指标
total = {}
for m in content_metrics:
metrics = m.get("metrics", {})
if "error" not in metrics:
for key, value in metrics.items():
if key != "note" and isinstance(value, (int, float)):
total[key] = total.get(key, 0) + value
comparison.append({
"content_id": cid,
"title": content.get("title", ""),
"total_metrics": total,
"platforms": [m.get("platform") for m in content_metrics],
})
output_success({
"message": f"已对比 {len(comparison)} 条内容",
"comparison": comparison,
})
def trending_metrics(data: Optional[Dict[str, Any]] = None) -> None:
"""查看内容指标趋势和热门内容。
可选字段: platform, limit
Args:
data: 可选的过滤条件字典。
"""
if not require_paid_feature("mermaid_chart", "趋势图表"):
return
metrics_data = _get_metrics()
contents = _get_contents()
platform_filter = data.get("platform") if data else None
limit = data.get("limit", 10) if data else 10
if platform_filter:
try:
platform_filter = validate_platform(platform_filter)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
metrics_data = [m for m in metrics_data if m.get("platform") == platform_filter]
# 按内容汇总指标
content_totals = {}
for m in metrics_data:
cid = m.get("content_id", "")
metrics = m.get("metrics", {})
if "error" in metrics:
continue
if cid not in content_totals:
content_totals[cid] = {"total_engagement": 0}
for key, value in metrics.items():
if key != "note" and isinstance(value, (int, float)):
content_totals[cid]["total_engagement"] += value
# 排序取 Top N
sorted_contents = sorted(
content_totals.items(),
key=lambda x: x[1]["total_engagement"],
reverse=True,
)[:limit]
# 构建 Mermaid 柱状图
chart_data = []
for cid, totals in sorted_contents:
content = _find_content(contents, cid)
title = content.get("title", cid)[:15] if content else cid[:15]
chart_data.append({
"label": title,
"value": totals["total_engagement"],
})
mermaid_chart = ""
if chart_data:
# Mermaid labels are double-quoted, so swap any embedded double quotes
labels = ", ".join('"' + str(d["label"]).replace('"', "'") + '"' for d in chart_data)
values = ", ".join(str(d["value"]) for d in chart_data)
mermaid_chart = (
f"```mermaid\n"
f"xychart-beta\n"
f' title "内容互动排行"\n'
f" x-axis [{labels}]\n"
f' y-axis "互动总量"\n'
f" bar [{values}]\n"
f"```"
)
trending_list = []
for cid, totals in sorted_contents:
content = _find_content(contents, cid)
trending_list.append({
"content_id": cid,
"title": content.get("title", "") if content else "",
"total_engagement": totals["total_engagement"],
})
output_success({
"message": f"互动排行 Top {len(trending_list)}",
"trending": trending_list,
"chart": mermaid_chart,
})
# ============================================================
# 学习洞察
# ============================================================
def learning_insights(data: Optional[Dict[str, Any]] = None) -> None:
"""基于学习引擎生成内容表现洞察。
将当前内容的指标与历史平均值对比,给出改进建议。
可选字段: content_id, platform
Args:
data: 可选的过滤条件字典。
"""
if not require_paid_feature("metrics", "学习洞察"):
return
le = _get_learning_engine()
if le is None:
output_error("学习引擎模块不可用,无法生成洞察", code="MODULE_ERROR")
return
learning_data = le._get_learning_data()
performances = learning_data.get("performances", [])
if not performances:
output_error("暂无历史表现数据,请先使用 collect 采集指标", code="NO_DATA")
return
content_id = data.get("content_id") or data.get("id") if data else None
platform_filter = data.get("platform") if data else None
if platform_filter:
performances = [p for p in performances if p.get("platform") == platform_filter]
# 计算历史平均值
all_scores = [p.get("engagement_score", 0) for p in performances]
all_rates = [p.get("engagement_rate", 0) for p in performances]
avg_score = round(sum(all_scores) / len(all_scores), 2) if all_scores else 0
avg_rate = round(sum(all_rates) / len(all_rates), 2) if all_rates else 0
insights = {
"historical_average": {
"avg_engagement_score": avg_score,
"avg_engagement_rate": avg_rate,
"total_records": len(performances),
},
"comparison": None,
"recommendations": [],
}
# 如果指定了具体内容,与历史对比
if content_id:
content_perfs = [p for p in performances if p.get("content_id") == content_id]
if content_perfs:
content_scores = [p.get("engagement_score", 0) for p in content_perfs]
content_avg = round(sum(content_scores) / len(content_scores), 2)
diff = round(content_avg - avg_score, 2)
diff_pct = round((diff / avg_score) * 100, 1) if avg_score > 0 else 0
insights["comparison"] = {
"content_id": content_id,
"content_avg_score": content_avg,
"vs_historical": diff,
"vs_historical_pct": diff_pct,
"performance": "高于平均" if diff > 0 else ("低于平均" if diff < 0 else "持平"),
}
if diff < 0:
insights["recommendations"].append(
f"当前内容互动得分({content_avg})低于历史平均值({avg_score}),"
"建议参考高表现内容的话题和格式"
)
else:
# 通用建议
insights["recommendations"].append(
f"历史平均互动得分 {avg_score},互动率 {avg_rate}%"
)
# Group historical scores by topic locally (the learning engine is not called here)
try:
topic_groups = {}
for p in performances:
topic = p.get("topic", "未知")
if topic not in topic_groups:
topic_groups[topic] = []
topic_groups[topic].append(p.get("engagement_score", 0))
top_topics = []
for topic, scores in topic_groups.items():
if topic == "未知" or not topic:
continue
avg = round(sum(scores) / len(scores), 2)
top_topics.append({"topic": topic, "avg_score": avg, "count": len(scores)})
top_topics.sort(key=lambda x: x["avg_score"], reverse=True)
insights["top_performing_topics"] = top_topics[:5]
if top_topics:
best = top_topics[0]
insights["recommendations"].append(
f"建议多创作「{best['topic']}」相关内容(历史均分 {best['avg_score']})"
)
except Exception:
pass
output_success({
"message": "学习洞察分析完成",
"insights": insights,
})
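# Illustrative sketch, not part of the original module: the per-topic
# averaging done inside learning_insights, written as a standalone pure
# function so it can be tested without the learning engine. `rank_topics` is
# a hypothetical name; the record fields (topic, engagement_score) are the
# ones read above.

```python
from collections import defaultdict
from typing import Any, Dict, List


def rank_topics(performances: List[Dict[str, Any]], top_n: int = 5) -> List[Dict[str, Any]]:
    """Average engagement_score per topic and return the best topics first."""
    scores_by_topic = defaultdict(list)
    for p in performances:
        topic = p.get("topic") or ""
        # Skip records with no topic or the placeholder topic.
        if topic and topic != "未知":
            scores_by_topic[topic].append(p.get("engagement_score", 0))
    ranked = [
        {"topic": t, "avg_score": round(sum(s) / len(s), 2), "count": len(s)}
        for t, s in scores_by_topic.items()
    ]
    ranked.sort(key=lambda item: item["avg_score"], reverse=True)
    return ranked[:top_n]


records = [
    {"topic": "ai", "engagement_score": 10},
    {"topic": "ai", "engagement_score": 20},
    {"topic": "design", "engagement_score": 30},
    {"topic": "", "engagement_score": 99},
]
print(rank_topics(records))
```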
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 指标采集")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"collect": lambda: collect_metrics(data or {}),
"report": lambda: generate_report(data or {}),
"compare": lambda: compare_metrics(data or {}),
"trending": lambda: trending_metrics(data),
"insights": lambda: learning_insights(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
content-engine shared utilities module
Provides common helpers for content management, platform adaptation, subscription checks, and data formatting.
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "content-engine")
# 内容状态流转: 草稿→待审核→已排期→已发布→已归档
CONTENT_STATUSES = ["草稿", "待审核", "已排期", "已发布", "已归档"]
# 支持的平台列表
PLATFORMS = ["twitter", "linkedin", "wechat", "blog", "medium"]
# 平台显示名称映射
PLATFORM_NAMES = {
"twitter": "Twitter / X",
"linkedin": "LinkedIn",
"wechat": "微信公众号",
"blog": "博客",
"medium": "Medium",
}
# 平台字符限制
PLATFORM_CHAR_LIMITS = {
"twitter": 280,
"linkedin": 3000,
"wechat": 20000,
"blog": 0, # 博客无限制
"medium": 0, # Medium 无限制
}
# 状态流转规则(当前状态 -> 允许的下一状态列表)
STATUS_TRANSITIONS = {
"草稿": ["待审核", "已归档"],
"待审核": ["草稿", "已排期", "已归档"],
"已排期": ["待审核", "已发布", "已归档"],
"已发布": ["已归档"],
"已归档": ["草稿"],
}
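# Illustrative sketch, not part of the original module: the transition table
# above maps each status to the statuses it may move to, so a validity check
# is a single lookup. `can_transition` is a hypothetical helper; the table is
# copied here only so the block runs on its own.

```python
from typing import Dict, List

STATUS_TRANSITIONS: Dict[str, List[str]] = {
    "草稿": ["待审核", "已归档"],
    "待审核": ["草稿", "已排期", "已归档"],
    "已排期": ["待审核", "已发布", "已归档"],
    "已发布": ["已归档"],
    "已归档": ["草稿"],
}


def can_transition(current: str, target: str) -> bool:
    """Return True when `target` is an allowed next status for `current`."""
    return target in STATUS_TRANSITIONS.get(current, [])


print(can_transition("草稿", "待审核"))
print(can_transition("已发布", "草稿"))
```

# An unknown current status has no allowed targets, so the helper returns
# False for it rather than raising.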
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录路径。
优先读取环境变量 CE_DATA_DIR,若未设置则使用默认路径
~/.openclaw-bdi/content-engine/。
自动创建目录(若不存在)。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("CE_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def get_data_file(filename: str) -> str:
"""获取数据文件的完整路径。
Args:
filename: 文件名(如 "contents.json")。
Returns:
数据文件的绝对路径。
"""
return os.path.join(get_data_dir(), filename)
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_file(filepath: str) -> Any:
"""读取 JSON 文件并返回解析后的数据。
Args:
filepath: JSON 文件路径。
Returns:
解析后的数据对象。若文件不存在,返回空列表。
"""
if not os.path.exists(filepath):
return []
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return []
def write_json_file(filepath: str, data: Any) -> None:
"""将数据写入 JSON 文件。
Args:
filepath: 目标文件路径。
data: 待写入的数据。
"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2, default=str)
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 命令行参数解析
# ============================================================
def parse_common_args(description: str = "content-engine 内容管理工具") -> argparse.ArgumentParser:
"""创建通用命令行参数解析器。
Args:
description: 工具描述文本。
Returns:
配置好通用参数的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="操作类型",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Dict[str, Any]]:
"""从命令行参数加载输入数据。
优先使用 --data 参数,其次尝试 --data-file 参数。
Args:
args: 解析后的命令行参数。
Returns:
解析后的字典数据,若无输入数据则返回 None。
Raises:
ValueError: 当 JSON 解析失败或文件读取失败时抛出。
"""
if args.data:
try:
data = json.loads(args.data)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if args.data_file:
if not os.path.exists(args.data_file):
raise ValueError(f"数据文件不存在: {args.data_file}")
try:
with open(args.data_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
raise ValueError(f"期望 JSON 对象,实际类型为 {type(data).__name__}")
return data
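except json.JSONDecodeError as e:
raise ValueError(f"数据文件 JSON 解析失败: {e}")
except OSError as e:
# The docstring promises ValueError on read failures, so wrap I/O errors as well
raise ValueError(f"数据文件读取失败: {e}")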
return None
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_content": 20,
"max_platforms": 2,
"features": [
"content_crud",
"basic_adapt",
"manual_publish",
"markdown_export",
],
},
"paid": {
"tier": "paid",
"max_content": 500,
"max_platforms": 5,
"features": [
"content_crud",
"basic_adapt",
"manual_publish",
"markdown_export",
"auto_publish",
"all_platforms",
"metrics",
"calendar",
"wechat",
"mermaid_chart",
"batch_adapt",
"schedule",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 CE_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("CE_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid_feature(feature_name: str, display_name: str) -> bool:
"""检查当前订阅是否支持指定功能。
若不支持,输出升级提示并返回 False。
Args:
feature_name: 功能内部名称。
display_name: 功能显示名称(用于提示信息)。
Returns:
True 表示功能可用,False 表示不可用(已输出错误信息)。
"""
sub = check_subscription()
if feature_name not in sub["features"]:
output_error(
f"「{display_name}」为付费版功能。当前为免费版,请升级至付费版(¥99/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
# ============================================================
# 内容引擎专用工具函数
# ============================================================
def validate_status(status: str) -> str:
"""校验内容状态是否合法。
Args:
status: 待校验的状态名称。
Returns:
合法的状态名称。
Raises:
ValueError: 当状态名称不合法时抛出。
"""
if status not in CONTENT_STATUSES:
valid = "、".join(CONTENT_STATUSES)
raise ValueError(f"无效的内容状态: {status!r},有效状态: {valid}")
return status
def validate_platform(platform: str) -> str:
"""校验平台名称是否合法。
Args:
platform: 待校验的平台名称。
Returns:
合法的平台名称(小写)。
Raises:
ValueError: 当平台名称不合法时抛出。
"""
platform = platform.strip().lower()
if platform not in PLATFORMS:
valid = "、".join(PLATFORMS)
raise ValueError(f"无效的平台: {platform!r},有效平台: {valid}")
return platform
def validate_status_transition(current: str, target: str) -> bool:
"""校验状态流转是否合法。
Args:
current: 当前状态。
target: 目标状态。
Returns:
True 表示流转合法。
Raises:
ValueError: 当流转不合法时抛出。
"""
allowed = STATUS_TRANSITIONS.get(current, [])
if target not in allowed:
allowed_str = "、".join(allowed) if allowed else "无"
raise ValueError(
f"不允许从「{current}」变更为「{target}」。"
f"当前状态允许变更为: {allowed_str}"
)
return True
def format_platform_name(platform: str) -> str:
"""获取平台的显示名称。
Args:
platform: 平台标识符(如 "twitter")。
Returns:
平台显示名称(如 "Twitter / X"),未知平台返回原始值。
"""
return PLATFORM_NAMES.get(platform.lower(), platform)
def truncate_text(text: str, max_len: int) -> str:
"""截断文本到指定长度。
若文本超过 max_len,截断并在末尾添加省略号。
Args:
text: 原始文本。
max_len: 最大长度。
Returns:
截断后的文本。若无需截断则返回原始文本。
"""
if not text:
return text
if max_len <= 0:
return text
if len(text) <= max_len:
return text
# 保留空间给省略号
if max_len <= 3:
return text[:max_len]
return text[:max_len - 3] + "..."
def count_chars(text: str, platform: str = "") -> int:
"""计算文本在指定平台的字符数。
Twitter 中一个中文/日文/韩文字符占 2 个字符位;
其他平台按实际字符数计算。
Args:
text: 待计算的文本。
platform: 平台名称,影响计数规则。
Returns:
平台字符数。
"""
if not text:
return 0
platform = platform.lower() if platform else ""
if platform == "twitter":
# Twitter: CJK 字符占 2 个字符位
count = 0
for ch in text:
if _is_cjk_char(ch):
count += 2
else:
count += 1
return count
else:
# 其他平台按实际字符数计算
return len(text)
def _is_cjk_char(ch: str) -> bool:
"""判断字符是否为 CJK(中日韩)字符。
Args:
ch: 单个字符。
Returns:
True 表示是 CJK 字符。
"""
cp = ord(ch)
# CJK 统一汉字基本区
if 0x4E00 <= cp <= 0x9FFF:
return True
# CJK 扩展A区
if 0x3400 <= cp <= 0x4DBF:
return True
# CJK 统一汉字扩展B区
if 0x20000 <= cp <= 0x2A6DF:
return True
# CJK 兼容汉字
if 0xF900 <= cp <= 0xFAFF:
return True
# 日文平假名/片假名
if 0x3040 <= cp <= 0x30FF:
return True
# 韩文音节
if 0xAC00 <= cp <= 0xD7AF:
return True
# 全角字符
if 0xFF01 <= cp <= 0xFF60:
return True
return False
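def sanitize_html(html: str) -> str:
    """Clean HTML content by removing dangerous tags and attributes.
    Keeps basic formatting tags (p, br, strong, em, a, img, h1-h6, ul, ol, li, blockquote).
    Removes script, style, iframe, form and other dangerous tags.
    Removes on* event-handler attributes.
    This is a regex-based, best-effort filter rather than a full HTML parser.
    Args:
        html: The raw HTML string.
    Returns:
        The cleaned HTML string.
    """
    if not html:
        return html
    # Dangerous tags that are removed together with their content
    dangerous_tags = ["script", "style", "iframe", "form", "input", "textarea", "select", "button", "object", "embed"]
    for tag in dangerous_tags:
        # Remove paired tags and everything between them. The \b stops
        # "form" from matching unrelated tag names such as <formula>.
        pattern = re.compile(rf"<{tag}\b[^>]*>.*?</{tag}\s*>", re.DOTALL | re.IGNORECASE)
        html = pattern.sub("", html)
        # Remove any remaining unpaired or self-closing tags
        pattern = re.compile(rf"</?{tag}\b[^>]*>", re.IGNORECASE)
        html = pattern.sub("", html)
    # Remove on* event attributes. Double- and single-quoted values are handled
    # separately so a value like onclick="alert('x')" is removed in full.
    html = re.sub(r'\s+on\w+\s*=\s*"[^"]*"', "", html, flags=re.IGNORECASE)
    html = re.sub(r"\s+on\w+\s*=\s*'[^']*'", "", html, flags=re.IGNORECASE)
    # Unquoted values: stop before ">" so the tag's closing bracket survives
    html = re.sub(r"\s+on\w+\s*=\s*[^\s>]+", "", html, flags=re.IGNORECASE)
    # Neutralize javascript: links
    html = re.sub(r'href\s*=\s*["\']javascript:[^"\']*["\']', 'href="#"', html, flags=re.IGNORECASE)
    return html.strip()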
def generate_id(prefix: str = "CT") -> str:
"""生成唯一 ID。
基于时间戳生成,格式为 前缀+时间戳。
Args:
prefix: ID 前缀,默认为 "CT"(内容)。
Returns:
唯一 ID 字符串。
"""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
return f"{prefix}{timestamp}"
def now_iso() -> str:
"""返回当前时间的 ISO 格式字符串。
Returns:
ISO 格式时间字符串,如 "2026-03-19T10:30:00"。
"""
return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
def today_str() -> str:
"""返回今天的日期字符串。
Returns:
日期字符串,格式为 "YYYY-MM-DD"。
"""
return datetime.now().strftime("%Y-%m-%d")
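def calculate_days_until(date_str: str) -> int:
    """Return the number of calendar days from today to the given date.
    Args:
        date_str: A date string in YYYY-MM-DD or ISO format.
    Returns:
        Days from today (positive for the future, negative for the past,
        0 for today or for input that cannot be parsed).
    """
    try:
        if "T" in date_str:
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            dt = dt.replace(tzinfo=None)
        else:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
        # Compare calendar dates, not timestamps: a date-only string parses to
        # midnight, so subtracting datetime.now() would report today as -1.
        return (dt.date() - datetime.now().date()).days
    except (ValueError, TypeError):
        return 0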
def get_week_range(date_str: Optional[str] = None) -> Dict[str, str]:
"""获取指定日期所在周的起止日期。
Args:
date_str: 日期字符串(YYYY-MM-DD),默认为今天。
Returns:
包含 start 和 end 键的字典,值为日期字符串。
"""
if date_str:
try:
dt = datetime.strptime(date_str, "%Y-%m-%d")
except ValueError:
dt = datetime.now()
else:
dt = datetime.now()
# 周一为起始
start = dt - timedelta(days=dt.weekday())
end = start + timedelta(days=6)
return {
"start": start.strftime("%Y-%m-%d"),
"end": end.strftime("%Y-%m-%d"),
}
def get_month_range(date_str: Optional[str] = None) -> Dict[str, str]:
"""获取指定日期所在月的起止日期。
Args:
date_str: 日期字符串(YYYY-MM-DD),默认为今天。
Returns:
包含 start 和 end 键的字典,值为日期字符串。
"""
if date_str:
try:
dt = datetime.strptime(date_str, "%Y-%m-%d")
except ValueError:
dt = datetime.now()
else:
dt = datetime.now()
start = dt.replace(day=1)
# 下月第一天减一天 = 本月最后一天
if dt.month == 12:
end = dt.replace(year=dt.year + 1, month=1, day=1) - timedelta(days=1)
else:
end = dt.replace(month=dt.month + 1, day=1) - timedelta(days=1)
return {
"start": start.strftime("%Y-%m-%d"),
"end": end.strftime("%Y-%m-%d"),
}
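The `STATUS_TRANSITIONS` table in `scripts/utils.py` above lists, for each content status, the statuses it may move to next. As a minimal sketch (not part of the original module), the snippet below copies that table and checks whether a whole sequence of statuses is a legal path. The helper name `is_valid_path` is hypothetical and exists only for illustration.

```python
from typing import Dict, List

# Copied from STATUS_TRANSITIONS in utils.py: current status -> allowed next statuses
STATUS_TRANSITIONS: Dict[str, List[str]] = {
    "草稿": ["待审核", "已归档"],
    "待审核": ["草稿", "已排期", "已归档"],
    "已排期": ["待审核", "已发布", "已归档"],
    "已发布": ["已归档"],
    "已归档": ["草稿"],
}


def is_valid_path(path: List[str]) -> bool:
    """Hypothetical helper: True if every consecutive step is an allowed transition."""
    return all(
        nxt in STATUS_TRANSITIONS.get(cur, [])
        for cur, nxt in zip(path, path[1:])
    )


if __name__ == "__main__":
    # Draft -> pending review -> scheduled -> published follows the table
    print(is_valid_path(["草稿", "待审核", "已排期", "已发布"]))
    # Draft -> published skips the intermediate statuses
    print(is_valid_path(["草稿", "已发布"]))
```

In the module itself, `validate_status_transition` performs the single-step version of this check and raises `ValueError` instead of returning `False`.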
FILE:scripts/publisher.py
#!/usr/bin/env python3
"""
content-engine 发布管理模块
支持将内容发布到 Twitter、LinkedIn、微信公众号、博客、Medium 等平台。
大部分发布功能为付费版功能。
"""
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError
from urllib.parse import urlencode
from utils import (
check_subscription,
format_platform_name,
generate_id,
get_data_file,
load_input_data,
now_iso,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
validate_platform,
validate_status,
write_json_file,
PLATFORMS,
)
# ============================================================
# 数据文件路径
# ============================================================
CONTENTS_FILE = "contents.json"
ADAPTED_FILE = "adapted_contents.json"
PUBLISH_HISTORY_FILE = "publish_history.json"
SCHEDULE_FILE = "schedule.json"
def _get_contents() -> List[Dict[str, Any]]:
"""读取所有内容数据。"""
return read_json_file(get_data_file(CONTENTS_FILE))
def _save_contents(contents: List[Dict[str, Any]]) -> None:
"""保存内容数据到文件。"""
write_json_file(get_data_file(CONTENTS_FILE), contents)
def _find_content(contents: List[Dict], content_id: str) -> Optional[Dict]:
"""根据 ID 查找内容。"""
for c in contents:
if c.get("id") == content_id:
return c
return None
def _get_adapted() -> List[Dict[str, Any]]:
"""读取所有已适配内容。"""
return read_json_file(get_data_file(ADAPTED_FILE))
def _get_publish_history() -> List[Dict[str, Any]]:
"""读取发布历史。"""
return read_json_file(get_data_file(PUBLISH_HISTORY_FILE))
def _save_publish_history(history: List[Dict[str, Any]]) -> None:
"""保存发布历史。"""
write_json_file(get_data_file(PUBLISH_HISTORY_FILE), history)
def _get_schedules() -> List[Dict[str, Any]]:
"""读取排期列表。"""
return read_json_file(get_data_file(SCHEDULE_FILE))
def _save_schedules(schedules: List[Dict[str, Any]]) -> None:
"""保存排期列表。"""
write_json_file(get_data_file(SCHEDULE_FILE), schedules)
# ============================================================
# 平台发布实现
# ============================================================
def _api_request(url: str, data: Any = None, headers: Optional[Dict] = None, method: str = "POST") -> Dict[str, Any]:
"""发送 HTTP API 请求。
Args:
url: 请求 URL。
data: 请求数据(将转为 JSON)。
headers: 请求头。
method: HTTP 方法。
Returns:
响应数据字典。
Raises:
Exception: 请求失败时抛出。
"""
if headers is None:
headers = {}
headers.setdefault("Content-Type", "application/json")
body = None
if data is not None:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
req = Request(url, data=body, headers=headers, method=method)
try:
with urlopen(req, timeout=30) as resp:
resp_data = resp.read().decode("utf-8")
return json.loads(resp_data) if resp_data else {}
except HTTPError as e:
error_body = e.read().decode("utf-8") if e.fp else ""
raise Exception(f"HTTP {e.code}: {error_body}")
except URLError as e:
raise Exception(f"网络请求失败: {e.reason}")
def _publish_twitter(adapted: Dict[str, Any]) -> Dict[str, Any]:
"""发布到 Twitter (X)。
使用 Twitter API v2,通过 Bearer Token 认证。
Args:
adapted: 已适配的 Twitter 内容。
Returns:
发布结果字典。
"""
bearer_token = os.environ.get("CE_TWITTER_BEARER_TOKEN", "")
if not bearer_token:
return {
"success": False,
"error": "未配置 CE_TWITTER_BEARER_TOKEN 环境变量",
}
tweets = adapted.get("tweets", [])
if not tweets:
return {"success": False, "error": "无推文内容"}
headers = {
"Authorization": f"Bearer {bearer_token}",
"Content-Type": "application/json",
}
tweet_ids = []
reply_to_id = None
for i, tweet_text in enumerate(tweets):
payload = {"text": tweet_text}
# Thread 模式: 后续推文作为回复
if reply_to_id:
payload["reply"] = {"in_reply_to_tweet_id": reply_to_id}
try:
result = _api_request(
"https://api.twitter.com/2/tweets",
data=payload,
headers=headers,
)
tweet_id = result.get("data", {}).get("id", "")
tweet_ids.append(tweet_id)
reply_to_id = tweet_id
except Exception as e:
return {
"success": False,
"error": f"第 {i + 1} 条推文发布失败: {str(e)}",
"published_tweets": tweet_ids,
}
return {
"success": True,
"tweet_ids": tweet_ids,
"tweet_count": len(tweet_ids),
"url": f"https://twitter.com/i/status/{tweet_ids[0]}" if tweet_ids else "",
}
def _publish_linkedin(adapted: Dict[str, Any]) -> Dict[str, Any]:
"""发布到 LinkedIn。
使用 LinkedIn API,通过 Access Token 认证。
Args:
adapted: 已适配的 LinkedIn 内容。
Returns:
发布结果字典。
"""
access_token = os.environ.get("CE_LINKEDIN_ACCESS_TOKEN", "")
if not access_token:
return {
"success": False,
"error": "未配置 CE_LINKEDIN_ACCESS_TOKEN 环境变量",
}
text = adapted.get("text", "")
if not text:
return {"success": False, "error": "无帖子内容"}
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"X-Restli-Protocol-Version": "2.0.0",
}
# 获取用户 profile URN
try:
profile = _api_request(
"https://api.linkedin.com/v2/me",
headers={"Authorization": f"Bearer {access_token}"},
method="GET",
)
person_urn = f"urn:li:person:{profile.get('id', '')}"
except Exception as e:
return {"success": False, "error": f"获取 LinkedIn 用户信息失败: {str(e)}"}
payload = {
"author": person_urn,
"lifecycleState": "PUBLISHED",
"specificContent": {
"com.linkedin.ugc.ShareContent": {
"shareCommentary": {"text": text},
"shareMediaCategory": "NONE",
}
},
"visibility": {
"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"
},
}
try:
result = _api_request(
"https://api.linkedin.com/v2/ugcPosts",
data=payload,
headers=headers,
)
post_id = result.get("id", "")
return {
"success": True,
"post_id": post_id,
"url": f"https://www.linkedin.com/feed/update/{post_id}" if post_id else "",
}
except Exception as e:
return {"success": False, "error": f"LinkedIn 发布失败: {str(e)}"}
def _publish_wechat(adapted: Dict[str, Any]) -> Dict[str, Any]:
"""发布到微信公众号。
使用微信公众号 API:
1. 通过 appid + secret 获取 access_token
2. 上传图文素材
3. 发布文章
Args:
adapted: 已适配的微信公众号内容。
Returns:
发布结果字典。
"""
appid = os.environ.get("CE_WECHAT_APPID", "")
secret = os.environ.get("CE_WECHAT_SECRET", "")
if not appid or not secret:
return {
"success": False,
"error": "未配置 CE_WECHAT_APPID 和 CE_WECHAT_SECRET 环境变量",
}
# 第一步: 获取 access_token
try:
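# Build the query string with urlencode so special characters in the
# credentials are escaped correctly
token_url = "https://api.weixin.qq.com/cgi-bin/token?" + urlencode({
"grant_type": "client_credential",
"appid": appid,
"secret": secret,
})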
req = Request(token_url, method="GET")
with urlopen(req, timeout=30) as resp:
token_data = json.loads(resp.read().decode("utf-8"))
if "access_token" not in token_data:
errcode = token_data.get("errcode", "unknown")
errmsg = token_data.get("errmsg", "未知错误")
return {"success": False, "error": f"获取 access_token 失败: {errcode} - {errmsg}"}
access_token = token_data["access_token"]
except Exception as e:
return {"success": False, "error": f"获取微信 access_token 失败: {str(e)}"}
# 第二步: 添加草稿(新版接口)
title = adapted.get("title", "")
html = adapted.get("html", "")
digest = adapted.get("digest", "")
author = adapted.get("author", "")
article = {
"articles": [
{
"title": title,
"author": author,
"digest": digest,
"content": html,
"content_source_url": "",
"thumb_media_id": "",
"need_open_comment": 0,
"only_fans_can_comment": 0,
}
]
}
try:
draft_url = f"https://api.weixin.qq.com/cgi-bin/draft/add?access_token={access_token}"
result = _api_request(draft_url, data=article)
if "media_id" not in result:
errcode = result.get("errcode", "unknown")
errmsg = result.get("errmsg", "未知错误")
return {"success": False, "error": f"创建草稿失败: {errcode} - {errmsg}"}
media_id = result["media_id"]
except Exception as e:
return {"success": False, "error": f"创建微信草稿失败: {str(e)}"}
# 第三步: 发布
try:
publish_url = f"https://api.weixin.qq.com/cgi-bin/freepublish/submit?access_token={access_token}"
pub_result = _api_request(publish_url, data={"media_id": media_id})
publish_id = pub_result.get("publish_id", "")
if pub_result.get("errcode", 0) != 0:
errmsg = pub_result.get("errmsg", "未知错误")
return {"success": False, "error": f"发布失败: {errmsg}", "media_id": media_id}
return {
"success": True,
"media_id": media_id,
"publish_id": publish_id,
}
except Exception as e:
return {
"success": False,
"error": f"微信发布失败: {str(e)}",
"media_id": media_id,
}
def _publish_medium(adapted: Dict[str, Any]) -> Dict[str, Any]:
"""发布到 Medium。
使用 Medium API,通过 Integration Token 认证。
Args:
adapted: 已适配的 Medium 内容。
Returns:
发布结果字典。
"""
token = os.environ.get("CE_MEDIUM_TOKEN", "")
if not token:
return {
"success": False,
"error": "未配置 CE_MEDIUM_TOKEN 环境变量",
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
# 获取用户 ID
try:
user_data = _api_request(
"https://api.medium.com/v1/me",
headers={"Authorization": f"Bearer {token}"},
method="GET",
)
user_id = user_data.get("data", {}).get("id", "")
if not user_id:
return {"success": False, "error": "无法获取 Medium 用户信息"}
except Exception as e:
return {"success": False, "error": f"获取 Medium 用户信息失败: {str(e)}"}
title = adapted.get("title", "")
markdown = adapted.get("markdown", "")
tags = adapted.get("tags", [])
payload = {
"title": title,
"contentFormat": "markdown",
"content": markdown,
"tags": tags[:5],
"publishStatus": "public",
}
try:
result = _api_request(
f"https://api.medium.com/v1/users/{user_id}/posts",
data=payload,
headers=headers,
)
post = result.get("data", {})
return {
"success": True,
"post_id": post.get("id", ""),
"url": post.get("url", ""),
}
except Exception as e:
return {"success": False, "error": f"Medium 发布失败: {str(e)}"}
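def _publish_blog(adapted: Dict[str, Any]) -> Dict[str, Any]:
    """Publish to a blog by writing a Markdown file to the local filesystem.
    The target directory is derived from the CE_BLOG_PATH environment variable
    and the adapted content's "blog_type" field (hugo, jekyll, or hexo).
    Args:
        adapted: Blog content that has already been adapted.
    Returns:
        A dict describing the publish result.
    """
    blog_path = os.environ.get("CE_BLOG_PATH", "")
    if not blog_path:
        return {
            "success": False,
            "error": "未配置 CE_BLOG_PATH 环境变量(博客内容目录路径)",
        }
    blog_type = adapted.get("blog_type", "hugo")
    markdown = adapted.get("markdown", "")
    # Keep only the base name so a crafted filename such as "../../x.md"
    # cannot escape the target directory
    filename = os.path.basename(adapted.get("suggested_filename") or "") or "untitled.md"
    # Pick the posts directory for the static-site generator in use
    if blog_type == "hugo":
        target_dir = os.path.join(blog_path, "content", "posts")
    elif blog_type == "jekyll":
        target_dir = os.path.join(blog_path, "_posts")
    elif blog_type == "hexo":
        target_dir = os.path.join(blog_path, "source", "_posts")
    else:
        target_dir = blog_path
    file_path = os.path.join(target_dir, filename)
    try:
        # Directory creation is inside the try block so permission errors
        # are reported as a failed publish instead of an uncaught exception
        os.makedirs(target_dir, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(markdown)
        return {
            "success": True,
            "file_path": file_path,
            "blog_type": blog_type,
        }
    except OSError as e:
        return {"success": False, "error": f"写入博客文件失败: {str(e)}"}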
# 平台发布器注册表
_PUBLISHERS = {
"twitter": _publish_twitter,
"linkedin": _publish_linkedin,
"wechat": _publish_wechat,
"medium": _publish_medium,
"blog": _publish_blog,
}
# ============================================================
# 发布操作
# ============================================================
def publish_content(data: Dict[str, Any]) -> None:
"""发布内容到指定平台。
必填字段: id, platform
Args:
data: 包含内容 ID 和目标平台的字典。
"""
if not require_paid_feature("auto_publish", "自动发布"):
return
content_id = data.get("id")
platform = data.get("platform", "")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 查找已适配的内容
adapted_list = _get_adapted()
adapted = None
for a in adapted_list:
if a.get("content_id") == content_id and a.get("platform") == platform:
adapted = a
break
if not adapted:
output_error(
f"未找到内容 {content_id} 在 {format_platform_name(platform)} 的适配版本,"
"请先执行 adapt 操作",
code="NOT_FOUND",
)
return
# 执行发布
publisher = _PUBLISHERS.get(platform)
if not publisher:
output_error(f"暂不支持发布到: {platform}", code="UNSUPPORTED_PLATFORM")
return
result = publisher(adapted)
# 记录发布历史
history = _get_publish_history()
record = {
"id": generate_id("PH"),
"content_id": content_id,
"platform": platform,
"platform_name": format_platform_name(platform),
"result": result,
"published_at": now_iso(),
}
history.append(record)
_save_publish_history(history)
# 更新内容状态和发布结果
contents = _get_contents()
content = _find_content(contents, content_id)
if content and result.get("success"):
content["status"] = "已发布"
content["published_at"] = now_iso()
if "publish_results" not in content:
content["publish_results"] = {}
content["publish_results"][platform] = result
content["updated_at"] = now_iso()
_save_contents(contents)
if result.get("success"):
output_success({
"message": f"已成功发布到 {format_platform_name(platform)}",
"publish_record": record,
})
else:
output_error(
f"发布到 {format_platform_name(platform)} 失败: {result.get('error', '未知错误')}",
code="PUBLISH_FAILED",
)
def schedule_content(data: Dict[str, Any]) -> None:
"""排期发布内容。
必填字段: id, platform, scheduled_at(ISO 格式日期时间)
Args:
data: 包含内容 ID、平台和排期时间的字典。
"""
if not require_paid_feature("schedule", "定时发布"):
return
content_id = data.get("id")
platform = data.get("platform", "")
scheduled_at = data.get("scheduled_at", "")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
if not scheduled_at:
output_error("排期时间(scheduled_at)为必填字段", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 校验时间格式和合法性
try:
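scheduled_dt = datetime.fromisoformat(scheduled_at.replace("Z", "+00:00"))
# Convert timezone-aware values to local time before comparing with
# datetime.now(); dropping the offset outright would shift the instant
if scheduled_dt.tzinfo is not None:
scheduled_dt = scheduled_dt.astimezone().replace(tzinfo=None)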
if scheduled_dt <= datetime.now():
output_error("排期时间必须是未来时间", code="VALIDATION_ERROR")
return
except ValueError:
output_error("排期时间格式无效,请使用 ISO 格式(如 2026-03-20T10:00:00)", code="VALIDATION_ERROR")
return
# 检查内容是否存在
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
# 创建排期记录
schedules = _get_schedules()
schedule = {
"id": generate_id("SC"),
"content_id": content_id,
"platform": platform,
"platform_name": format_platform_name(platform),
"scheduled_at": scheduled_at,
"status": "pending",
"created_at": now_iso(),
}
schedules.append(schedule)
_save_schedules(schedules)
# 更新内容状态为"已排期"
if content["status"] in ("草稿", "待审核"):
content["status"] = "已排期"
content["scheduled_at"] = scheduled_at
content["updated_at"] = now_iso()
_save_contents(contents)
output_success({
"message": f"已排期在 {scheduled_at} 发布到 {format_platform_name(platform)}",
"schedule": schedule,
})
def list_published(data: Optional[Dict[str, Any]] = None) -> None:
"""列出发布历史。
可选过滤: platform, content_id, date_from, date_to
Args:
data: 可选的过滤条件字典。
"""
history = _get_publish_history()
if data:
# 按平台过滤
platform_filter = data.get("platform")
if platform_filter:
history = [h for h in history if h.get("platform") == platform_filter.lower()]
# 按内容 ID 过滤
content_id = data.get("content_id")
if content_id:
history = [h for h in history if h.get("content_id") == content_id]
# 按日期范围过滤
date_from = data.get("date_from", "")
date_to = data.get("date_to", "")
if date_from:
history = [h for h in history if h.get("published_at", "") >= date_from]
if date_to:
history = [h for h in history if h.get("published_at", "") <= date_to + "T23:59:59"]
# 按发布时间倒序
history.sort(key=lambda h: h.get("published_at", ""), reverse=True)
# 统计
success_count = sum(1 for h in history if h.get("result", {}).get("success"))
fail_count = len(history) - success_count
platform_stats = {}
for p in PLATFORMS:
count = sum(1 for h in history if h.get("platform") == p)
if count > 0:
platform_stats[format_platform_name(p)] = count
output_success({
"total": len(history),
"success_count": success_count,
"fail_count": fail_count,
"platform_stats": platform_stats,
"history": history,
})
def unpublish_content(data: Dict[str, Any]) -> None:
"""撤回已发布的内容(标记为已归档)。
必填字段: id
注意: 此操作仅更新本地状态,不会从平台上删除已发布的内容。
Args:
data: 包含内容 ID 的字典。
"""
content_id = data.get("id")
if not content_id:
output_error("内容ID(id)为必填字段", code="VALIDATION_ERROR")
return
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
if content["status"] != "已发布":
output_error(
f"只能撤回已发布的内容,当前状态为「{content['status']}」",
code="INVALID_STATUS",
)
return
content["status"] = "已归档"
content["updated_at"] = now_iso()
_save_contents(contents)
output_success({
"message": "内容已标记为归档(注意:已发布到平台的内容需手动删除)",
"content_id": content_id,
})
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 发布管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"publish": lambda: publish_content(data or {}),
"schedule": lambda: schedule_content(data or {}),
"list-published": lambda: list_published(data),
"unpublish": lambda: unpublish_content(data or {}),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
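`list_published` in `scripts/publisher.py` above filters by comparing `published_at` strings directly. This works because `now_iso()` writes fixed-width `YYYY-MM-DDTHH:MM:SS` strings, which sort lexicographically in chronological order. The sketch below reproduces that filter on made-up records. The helper name `filter_by_date` and the sample data are assumptions for illustration, not part of the module.

```python
from typing import Any, Dict, List


def filter_by_date(
    history: List[Dict[str, Any]],
    date_from: str = "",
    date_to: str = "",
) -> List[Dict[str, Any]]:
    """Hypothetical helper mirroring the date filter in list_published."""
    if date_from:
        # "2026-03-19T10:30:00" >= "2026-03-19" holds because the shorter
        # string is a prefix of the longer one
        history = [h for h in history if h.get("published_at", "") >= date_from]
    if date_to:
        # Extend the end date to the last second of that day
        history = [h for h in history if h.get("published_at", "") <= date_to + "T23:59:59"]
    return history


# Made-up publish records for illustration only
SAMPLE_HISTORY = [
    {"id": "PH1", "published_at": "2026-03-18T23:10:00"},
    {"id": "PH2", "published_at": "2026-03-19T10:30:00"},
    {"id": "PH3", "published_at": "2026-03-20T00:00:01"},
]

if __name__ == "__main__":
    same_day = filter_by_date(SAMPLE_HISTORY, "2026-03-19", "2026-03-19")
    print([h["id"] for h in same_day])
```

The comparison is only reliable while every timestamp uses the same zero-padded format. Records written with a timezone suffix or a different separator would need to be parsed into `datetime` objects first.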
FILE:scripts/calendar_manager.py
#!/usr/bin/env python3
"""
content-engine 内容日历管理模块
提供内容发布日历的规划、查看、建议和导出功能。
"""
import csv
import io
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from utils import (
check_subscription,
format_platform_name,
generate_id,
get_data_file,
get_month_range,
get_week_range,
load_input_data,
now_iso,
today_str,
output_error,
output_success,
parse_common_args,
read_json_file,
require_paid_feature,
truncate_text,
validate_platform,
write_json_file,
PLATFORMS,
)
# ============================================================
# 数据文件路径
# ============================================================
CALENDAR_FILE = "calendar.json"
CONTENTS_FILE = "contents.json"
def _get_calendar() -> List[Dict[str, Any]]:
"""读取日历数据。"""
return read_json_file(get_data_file(CALENDAR_FILE))
def _save_calendar(calendar: List[Dict[str, Any]]) -> None:
"""保存日历数据。"""
write_json_file(get_data_file(CALENDAR_FILE), calendar)
def _get_contents() -> List[Dict[str, Any]]:
"""读取所有内容数据。"""
return read_json_file(get_data_file(CONTENTS_FILE))
def _find_content(contents: List[Dict], content_id: str) -> Optional[Dict]:
"""根据 ID 查找内容。"""
for c in contents:
if c.get("id") == content_id:
return c
return None
# ============================================================
# 最佳发布时间建议
# ============================================================
# 各平台推荐发布时间(基于行业经验数据)
OPTIMAL_POSTING_TIMES = {
"twitter": {
"weekday": ["09:00", "12:00", "17:00", "20:00"],
"weekend": ["10:00", "14:00", "19:00"],
"best": "周二至周四 09:00-12:00",
"avoid": "凌晨 00:00-06:00",
},
"linkedin": {
"weekday": ["08:00", "10:00", "12:00", "17:00"],
"weekend": ["10:00"],
"best": "周二至周四 08:00-10:00",
"avoid": "周末和晚间",
},
"wechat": {
"weekday": ["07:30", "12:00", "18:00", "21:00"],
"weekend": ["09:00", "12:00", "20:00"],
"best": "工作日 18:00-21:00(通勤和休闲时段)",
"avoid": "凌晨 01:00-06:00",
},
"blog": {
"weekday": ["10:00", "14:00"],
"weekend": ["10:00"],
"best": "周一至周三 10:00",
"avoid": "无特别限制",
},
"medium": {
"weekday": ["08:00", "11:00", "14:00"],
"weekend": ["10:00", "14:00"],
"best": "周二至周四 08:00-11:00",
"avoid": "周五下午和周末晚间",
},
}
# 星期映射
WEEKDAY_NAMES = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
# ============================================================
# 日历操作
# ============================================================
def plan_calendar(data: Dict[str, Any]) -> None:
"""添加内容发布计划到日历。
必填字段: content_id, platform, date
可选字段: time(发布时间,默认使用推荐时间), note
Args:
data: 日历计划数据字典。
"""
if not require_paid_feature("calendar", "内容日历"):
return
content_id = data.get("content_id") or data.get("id")
platform = data.get("platform", "")
date = data.get("date", "")
if not content_id:
output_error("内容ID(content_id)为必填字段", code="VALIDATION_ERROR")
return
if not platform:
output_error("目标平台(platform)为必填字段", code="VALIDATION_ERROR")
return
if not date:
output_error("发布日期(date)为必填字段,格式: YYYY-MM-DD", code="VALIDATION_ERROR")
return
try:
platform = validate_platform(platform)
except ValueError as e:
output_error(str(e), code="VALIDATION_ERROR")
return
# 校验日期格式
try:
plan_date = datetime.strptime(date, "%Y-%m-%d")
except ValueError:
output_error("日期格式无效,请使用 YYYY-MM-DD 格式", code="VALIDATION_ERROR")
return
# 检查内容是否存在
contents = _get_contents()
content = _find_content(contents, content_id)
if not content:
output_error(f"未找到ID为 {content_id} 的内容", code="NOT_FOUND")
return
# 确定发布时间
time = data.get("time", "")
if not time:
# 使用推荐时间
weekday = plan_date.weekday()
times = OPTIMAL_POSTING_TIMES.get(platform, {})
if weekday < 5:
suggested_times = times.get("weekday", ["10:00"])
else:
suggested_times = times.get("weekend", ["10:00"])
time = suggested_times[0] if suggested_times else "10:00"
# 创建日历条目
calendar = _get_calendar()
entry = {
"id": generate_id("CL"),
"content_id": content_id,
"content_title": truncate_text(content.get("title", ""), 50),
"platform": platform,
"platform_name": format_platform_name(platform),
"date": date,
"time": time,
"weekday": WEEKDAY_NAMES[plan_date.weekday()],
"note": data.get("note", ""),
"status": "planned",
"created_at": now_iso(),
}
calendar.append(entry)
_save_calendar(calendar)
output_success({
"message": f"已添加日历计划: {date} {time} 发布到 {format_platform_name(platform)}",
"entry": entry,
})
def view_calendar(data: Optional[Dict[str, Any]] = None) -> None:
"""查看内容日历。
可选字段: view(week/month,默认 week), date(起始日期,默认今天)
Args:
data: 可选的查看参数字典。
"""
if not require_paid_feature("calendar", "内容日历"):
return
view_type = data.get("view", "week") if data else "week"
base_date = data.get("date", today_str()) if data else today_str()
calendar = _get_calendar()
# 确定日期范围
if view_type == "month":
date_range = get_month_range(base_date)
else:
date_range = get_week_range(base_date)
start = date_range["start"]
end = date_range["end"]
# 过滤日期范围内的条目
filtered = [
e for e in calendar
if start <= e.get("date", "") <= end
]
# 按日期和时间排序
filtered.sort(key=lambda e: (e.get("date", ""), e.get("time", "")))
# 按日期分组
grouped = {}
current = datetime.strptime(start, "%Y-%m-%d")
end_dt = datetime.strptime(end, "%Y-%m-%d")
while current <= end_dt:
date_str = current.strftime("%Y-%m-%d")
weekday = WEEKDAY_NAMES[current.weekday()]
day_entries = [e for e in filtered if e.get("date") == date_str]
grouped[date_str] = {
"weekday": weekday,
"entries": day_entries,
"count": len(day_entries),
}
current += timedelta(days=1)
# 生成 Markdown 视图
view_title = "月度日历" if view_type == "month" else "周日历"
report_lines = [f"# 内容{view_title}({start} ~ {end})", ""]
for date_str, day_data in grouped.items():
weekday = day_data["weekday"]
entries = day_data["entries"]
if entries:
report_lines.append(f"## {date_str}({weekday})")
report_lines.append("")
for e in entries:
status_icon = "✅" if e.get("status") == "published" else "📅"
report_lines.append(
f"- {status_icon} **{e.get('time', '')}** "
f"[{e.get('platform_name', '')}] "
f"{e.get('content_title', '')}"
)
if e.get("note"):
report_lines.append(f" > {e['note']}")
report_lines.append("")
else:
report_lines.append(f"## {date_str}({weekday})— 无计划")
report_lines.append("")
# 统计
total_planned = len(filtered)
platform_dist = {}
for e in filtered:
pname = e.get("platform_name", "")
platform_dist[pname] = platform_dist.get(pname, 0) + 1
report = "\n".join(report_lines)
output_success({
"view_type": view_type,
"date_range": date_range,
"total_planned": total_planned,
"platform_distribution": platform_dist,
"calendar_view": report,
"entries": filtered,
})
def suggest_times(data: Optional[Dict[str, Any]] = None) -> None:
    """Suggest the best posting times.
    Optional fields: platform (all platforms when omitted), date (YYYY-MM-DD, defaults to today).
    Args:
        data: Optional parameter dict.
    """
    platform_filter = data.get("platform") if data else None
    target_date = data.get("date", today_str()) if data else today_str()
    try:
        dt = datetime.strptime(target_date, "%Y-%m-%d")
    except (ValueError, TypeError):
        # Unparseable date: fall back to today and report the date actually used,
        # so the message and the weekday always refer to the same day.
        dt = datetime.now()
        target_date = dt.strftime("%Y-%m-%d")
    weekday = dt.weekday()
    weekday_name = WEEKDAY_NAMES[weekday]
    is_weekend = weekday >= 5
    suggestions = []
    platforms = [platform_filter] if platform_filter else PLATFORMS
    for platform in platforms:
        try:
            platform = validate_platform(platform)
        except ValueError:
            continue  # skip unknown platforms
        times_config = OPTIMAL_POSTING_TIMES.get(platform, {})
        time_key = "weekend" if is_weekend else "weekday"
        recommended_times = times_config.get(time_key, ["10:00"])
        suggestions.append({
            "platform": platform,
            "platform_name": format_platform_name(platform),
            "date": target_date,
            "weekday": weekday_name,
            "recommended_times": recommended_times,
            "best_time": times_config.get("best", ""),
            "avoid": times_config.get("avoid", ""),
        })
    output_success({
        "message": f"{target_date}({weekday_name})发布时间建议",
        "suggestions": suggestions,
    })
def export_calendar(data: Optional[Dict[str, Any]] = None) -> None:
"""导出内容日历。
可选字段: format(markdown/csv,默认 markdown), file_path, view(week/month), date
Args:
data: 可选的导出参数字典。
"""
if not require_paid_feature("calendar", "内容日历"):
return
export_format = data.get("format", "markdown") if data else "markdown"
file_path = data.get("file_path") if data else None
view_type = data.get("view", "month") if data else "month"
base_date = data.get("date", today_str()) if data else today_str()
calendar = _get_calendar()
# 确定日期范围
if view_type == "month":
date_range = get_month_range(base_date)
else:
date_range = get_week_range(base_date)
start = date_range["start"]
end = date_range["end"]
filtered = [
e for e in calendar
if start <= e.get("date", "") <= end
]
filtered.sort(key=lambda e: (e.get("date", ""), e.get("time", "")))
if export_format == "csv":
output_content = _export_csv(filtered)
else:
output_content = _export_markdown(filtered, start, end)
if file_path:
try:
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
f.write(output_content)
output_success({
"message": f"日历已导出到 {file_path}",
"format": export_format,
"entries_count": len(filtered),
})
except IOError as e:
output_error(f"导出失败: {e}", code="EXPORT_ERROR")
else:
output_success({
"content": output_content,
"format": export_format,
"entries_count": len(filtered),
})
def _export_markdown(entries: List[Dict[str, Any]], start: str, end: str) -> str:
    """Export the calendar as Markdown.
    Args:
        entries: Calendar entries.
        start: Start date.
        end: End date.
    Returns:
        The calendar rendered as a Markdown table.
    """
    def cell(value: Any) -> str:
        # Escape pipes and flatten newlines so free-text fields cannot break the table.
        return str(value).replace("|", "\\|").replace("\n", " ")
    lines = [f"# 内容发布日历({start} ~ {end})", ""]
    lines.append("| 日期 | 星期 | 时间 | 平台 | 内容 | 状态 | 备注 |")
    lines.append("|------|------|------|------|------|------|------|")
    for e in entries:
        status = "已发布" if e.get("status") == "published" else "已计划"
        lines.append(
            f"| {cell(e.get('date', ''))} "
            f"| {cell(e.get('weekday', ''))} "
            f"| {cell(e.get('time', ''))} "
            f"| {cell(e.get('platform_name', ''))} "
            f"| {cell(e.get('content_title', ''))} "
            f"| {status} "
            f"| {cell(e.get('note', ''))} |"
        )
    # Paid tier: append a Mermaid Gantt chart
    sub = check_subscription()
    if sub["tier"] == "paid" and entries:
        lines.append("")
        lines.append("## 时间线视图")
        lines.append("")
        lines.append(_generate_gantt_chart(entries))
    return "\n".join(lines)
def _generate_gantt_chart(entries: List[Dict[str, Any]]) -> str:
    """Generate a Mermaid Gantt timeline.
    Args:
        entries: Calendar entries.
    Returns:
        A fenced Mermaid Gantt code block.
    """
    lines = ["```mermaid", "gantt", "    title 内容发布时间线", "    dateFormat YYYY-MM-DD"]
    # Group entries by platform
    platforms_seen: Dict[str, List[Dict[str, Any]]] = {}
    for e in entries:
        platforms_seen.setdefault(e.get("platform_name") or "其他", []).append(e)
    for platform, platform_entries in platforms_seen.items():
        lines.append(f"    section {platform}")
        for e in platform_entries:
            # Mermaid uses ":" to separate the task name from its metadata, so a colon
            # inside the title is swapped for a full-width colon. Empty titles get a default.
            title = (e.get("content_title") or "内容")[:20].replace(":", ":")
            date = e.get("date") or today_str()
            status = "done, " if e.get("status") == "published" else ""
            # Gantt task format: name :status, start date, duration
            lines.append(f"    {title} :{status}{date}, 1d")
    lines.append("```")
    return "\n".join(lines)
def _export_csv(entries: List[Dict[str, Any]]) -> str:
"""导出为 CSV 格式。
Args:
entries: 日历条目列表。
Returns:
CSV 格式的日历内容。
"""
output = io.StringIO()
fieldnames = ["date", "weekday", "time", "platform_name", "content_title", "status", "note", "content_id"]
writer = csv.DictWriter(output, fieldnames=fieldnames)
writer.writeheader()
for e in entries:
row = {k: e.get(k, "") for k in fieldnames}
writer.writerow(row)
return output.getvalue()
# ============================================================
# 主入口
# ============================================================
def main() -> None:
"""主函数:解析命令行参数并分发操作。"""
parser = parse_common_args("content-engine 内容日历管理")
args = parser.parse_args()
action = args.action.lower()
try:
data = load_input_data(args)
except ValueError as e:
output_error(str(e), code="INPUT_ERROR")
return
actions = {
"plan": lambda: plan_calendar(data or {}),
"view": lambda: view_calendar(data),
"suggest": lambda: suggest_times(data),
"export": lambda: export_calendar(data),
}
handler = actions.get(action)
if handler:
handler()
else:
valid_actions = "、".join(actions.keys())
output_error(f"未知操作: {action},支持的操作: {valid_actions}", code="INVALID_ACTION")
if __name__ == "__main__":
main()
FILE:references/wechat-guide.md
# 微信公众号 API 配置与使用指南
本指南帮助你完成微信公众号的 API 接入配置,实现通过 content-engine 自动发布文章到公众号。
---
## 前置条件
1. 已注册微信公众号(服务号或订阅号,推荐服务号)
2. 已完成微信认证(未认证账号 API 权限受限)
3. 具有管理员权限
---
## 第一步:获取 AppID 和 AppSecret
1. 登录 [微信公众平台](https://mp.weixin.qq.com/)
2. 进入「设置与开发」→「基本配置」
3. 在「公众号开发信息」区域找到:
- **AppID(应用ID)**:固定不变,直接复制
- **AppSecret(应用密钥)**:点击「重置」生成新密钥,**务必立即保存**,页面关闭后无法再次查看
4. 配置环境变量:
```bash
export CE_WECHAT_APPID="你的AppID"
export CE_WECHAT_SECRET="你的AppSecret"
```
> **安全提醒**:AppSecret 等同于账号密码,绝不能泄露。不要将其写入代码或提交到版本控制系统。
---
## 第二步:配置 IP 白名单
1. 在「基本配置」页面找到「IP白名单」
2. 点击「修改」,添加你的服务器公网 IP 地址
3. 若在本地开发,添加本机公网 IP(可通过 `curl ifconfig.me` 查询)
> 未配置白名单的 IP 地址调用 API 将返回 `40164` 错误。
---
## 第三步:access_token 机制说明
### 获取 access_token
content-engine 会自动调用以下接口获取 access_token:
```
GET https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=APPID&secret=APPSECRET
```
返回示例:
```json
{
"access_token": "ACCESS_TOKEN",
"expires_in": 7200
}
```
### 有效期与刷新
- access_token 有效期为 **7200 秒(2 小时)**
- 每日调用上限为 **2000 次**
- content-engine 会在每次发布操作时自动获取新的 access_token
- 建议不要频繁调用,避免触发速率限制
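
Because a token stays valid for 7200 seconds while the token endpoint is capped at 2000 calls per day, one way to avoid requesting a token on every publish is a small in-memory cache. The sketch below is an illustration only, not content-engine's actual code: `TokenCache`, the `fetch` callable, and the 300-second safety margin are all assumptions introduced here. The real HTTP call to `/cgi-bin/token` would be supplied as `fetch`.

```python
import time
from typing import Callable, Dict, Optional


class TokenCache:
    """Reuse an access_token until shortly before it expires (hypothetical helper)."""

    def __init__(
        self,
        fetch: Callable[[], Dict],             # returns the parsed JSON of the token endpoint
        margin: int = 300,                     # assumed safety margin in seconds
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._fetch = fetch
        self._margin = margin
        self._clock = clock
        self._token: Optional[str] = None
        self._expires_at = 0.0

    def get(self) -> str:
        now = self._clock()
        if self._token is None or now >= self._expires_at:
            payload = self._fetch()
            if "access_token" not in payload:
                raise RuntimeError(f"token request failed: {payload}")
            self._token = payload["access_token"]
            self._expires_at = now + int(payload.get("expires_in", 7200)) - self._margin
        return self._token


# Demo with a fake fetcher and a controllable clock (no network access).
calls = []


def fake_fetch() -> Dict:
    calls.append(1)
    return {"access_token": f"TOKEN_{len(calls)}", "expires_in": 7200}


now = [0.0]
cache = TokenCache(fake_fetch, clock=lambda: now[0])
first = cache.get()      # fetched
now[0] = 3600.0
second = cache.get()     # still inside the validity window, reused
now[0] = 7000.0
third = cache.get()      # past 7200 - 300 seconds, fetched again
```

Injecting `fetch` and `clock` keeps the cache testable offline. A cache shared across processes would need a file or another store instead of instance attributes.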
### 常见错误码
| 错误码 | 说明 | 解决方案 |
|--------|------|----------|
| 40001 | AppSecret 错误 | 检查 CE_WECHAT_SECRET 是否正确 |
| 40002 | grant_type 无效 | 检查请求参数 |
| 40164 | IP 未在白名单 | 在公众号后台添加 IP 白名单 |
| 42001 | access_token 过期 | 重新获取(自动处理) |
| 45009 | API 调用频次超限 | 降低调用频率,等待后重试 |
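
A minimal sketch of how the error codes in the table could be surfaced in code. `WeChatAPIError`, `HINTS`, and `check_response` are hypothetical names introduced for illustration. The only assumption about the API is that failures carry `errcode` and `errmsg` fields, and that a missing or zero `errcode` means success.

```python
from typing import Any, Dict

# Hints copied from the error-code table above.
HINTS = {
    40001: "检查 CE_WECHAT_SECRET 是否正确",
    40002: "检查请求参数",
    40164: "在公众号后台添加 IP 白名单",
    42001: "重新获取 access_token",
    45009: "降低调用频率,等待后重试",
}


class WeChatAPIError(Exception):
    """Raised when a response carries a non-zero errcode (hypothetical helper)."""

    def __init__(self, errcode: int, errmsg: str) -> None:
        hint = HINTS.get(errcode, "")
        suffix = f" ({hint})" if hint else ""
        super().__init__(f"WeChat API error {errcode}: {errmsg}{suffix}")
        self.errcode = errcode
        self.errmsg = errmsg


def check_response(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return the payload unchanged on success, raise WeChatAPIError otherwise."""
    errcode = int(payload.get("errcode", 0))
    if errcode != 0:
        raise WeChatAPIError(errcode, str(payload.get("errmsg", "")))
    return payload


ok = check_response({"access_token": "ACCESS_TOKEN", "expires_in": 7200})
try:
    check_response({"errcode": 40164, "errmsg": "invalid ip"})
    caught = None
except WeChatAPIError as exc:
    caught = exc
```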
---
## 第四步:文章发布流程
content-engine 的微信发布采用以下流程:
### 1. 创建草稿
```
POST https://api.weixin.qq.com/cgi-bin/draft/add?access_token=ACCESS_TOKEN
```
请求体:
```json
{
"articles": [
{
"title": "文章标题",
"author": "作者",
"digest": "摘要",
"content": "<p>HTML格式正文</p>",
"content_source_url": "",
"thumb_media_id": "封面图素材ID",
"need_open_comment": 0,
"only_fans_can_comment": 0
}
]
}
```
### 2. 发布文章
```
POST https://api.weixin.qq.com/cgi-bin/freepublish/submit?access_token=ACCESS_TOKEN
```
请求体:
```json
{
"media_id": "草稿的media_id"
}
```
### 3. 查询发布状态
```
POST https://api.weixin.qq.com/cgi-bin/freepublish/get?access_token=ACCESS_TOKEN
```
请求体:
```json
{
"publish_id": "发布任务ID"
}
```
---
## 第五步:素材管理
### 上传永久素材(封面图)
如需设置文章封面图,需先上传图片为永久素材:
```
POST https://api.weixin.qq.com/cgi-bin/material/add_material?access_token=ACCESS_TOKEN&type=image
```
返回的 `media_id` 可用于文章的 `thumb_media_id` 字段。
### 上传正文图片
正文中的图片需通过以下接口上传:
```
POST https://api.weixin.qq.com/cgi-bin/media/uploadimg?access_token=ACCESS_TOKEN
```
返回的 URL 可直接在 HTML 正文的 `<img>` 标签中使用。
---
## 常见问题
### Q1:订阅号和服务号有什么区别?
服务号拥有更多 API 权限,包括模板消息、自定义菜单、网页授权等。建议使用服务号进行内容分发。
### Q2:发布后可以修改文章吗?
微信公众号已发布的文章不支持直接修改内容。如需更正,可以删除后重新发布,或发布勘误声明。
### Q3:每天可以发布几篇文章?
- 订阅号:每天 1 次推送,每次最多 8 篇
- 服务号:每月 4 次推送,每次最多 8 篇
### Q4:如何处理图片?
目前 content-engine 的微信适配器会将 Markdown 图片转换为 HTML `<img>` 标签。若图片为外部 URL,需确保微信服务器可访问该地址。建议使用微信素材接口上传图片获取永久链接。
---
## 调试建议
1. 使用 [微信公众平台接口调试工具](https://mp.weixin.qq.com/debug/) 测试 API 调用
2. 检查返回的 JSON 中是否包含 `errcode` 字段
3. 确认 IP 白名单配置正确
4. 确认 AppSecret 未被重置(重置后旧密钥立即失效)
FILE:references/platform-specs.md
# Platform Specifications / 平台规格参考
各平台内容发布的技术规格、限制和格式要求。
---
## Twitter / X
| 项目 | 规格 |
|------|------|
| 单条推文字符限制 | 280 字符 |
| 中文字符计数 | 1 个中文字符 = 2 字符位 |
| Thread 上限 | 25 条推文 |
| 图片格式 | JPEG, PNG, GIF, WebP |
| 图片大小限制 | 5MB(静态图),15MB(GIF) |
| 图片尺寸建议 | 1200x675px(16:9) |
| Hashtag 建议 | 1-5 个,不宜过多 |
| 链接计数 | 每个链接占 23 字符 |
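
The counting rules in this table can be approximated as follows. This is a simplified sketch, not Twitter's official counting algorithm: it assumes every code point below U+1100 weighs 1, everything else (including Chinese characters) weighs 2, and every `http(s)` URL counts as 23. `weighted_length` and `fits_in_tweet` are hypothetical helper names.

```python
import re

TWEET_LIMIT = 280   # per-tweet limit from the table above
URL_WEIGHT = 23     # every link counts as 23 characters
URL_RE = re.compile(r"https?://\S+")


def weighted_length(text: str) -> int:
    """Approximate the weighted length of a tweet (simplified rule, see lead-in)."""
    urls = URL_RE.findall(text)
    remainder = URL_RE.sub("", text)
    total = URL_WEIGHT * len(urls)
    for ch in remainder:
        total += 1 if ord(ch) < 0x1100 else 2
    return total


def fits_in_tweet(text: str) -> bool:
    return weighted_length(text) <= TWEET_LIMIT


ascii_len = weighted_length("hello")
cjk_len = weighted_length("你好世界")
with_link = weighted_length("读 https://example.com/a/b")
```

Under these assumptions a tweet written purely in Chinese holds at most 140 characters, which is worth keeping in mind when splitting a Thread.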
### 最佳实践
- 首条推文是 hook,吸引注意力
- Thread 内容结构化,每条聚焦一个要点
- 适当使用 emoji 增加可读性
- 末尾推文包含 CTA(行动号召)
---
## LinkedIn
| 项目 | 规格 |
|------|------|
| 帖子字符限制 | 3,000 字符 |
| 文章字符限制 | 110,000 字符 |
| 图片格式 | JPEG, PNG, GIF |
| 图片大小限制 | 10MB |
| 图片尺寸建议 | 1200x627px |
| Hashtag 建议 | 3-10 个 |
### 最佳实践
- 保持专业语气
- 开头一句话抓住注意力
- 段落间空行增加可读性
- 数据和案例支撑观点
- 结尾提出讨论问题
---
## 微信公众号
| 项目 | 规格 |
|------|------|
| 文章正文限制 | 20,000 字符 |
| 标题长度限制 | 64 字符 |
| 摘要长度限制 | 120 字符 |
| 封面图尺寸 | 900x383px(大图)/ 200x200px(小图) |
| 图片格式 | JPEG, PNG, GIF |
| 单张图片大小 | 10MB |
| 正文图片数量 | 最多 20 张 |
| 文章格式 | HTML(富文本) |
### 最佳实践
- 标题控制在 20 字以内
- 摘要简明扼要
- 正文使用小标题分段
- 配图增加阅读体验
- 文末引导关注和转发
---
## 博客(Hugo / Jekyll / Hexo)
| 项目 | Hugo | Jekyll | Hexo |
|------|------|--------|------|
| 文件格式 | Markdown | Markdown | Markdown |
| Frontmatter | YAML / TOML | YAML | YAML |
| 文件位置 | `content/posts/` | `_posts/` | `source/_posts/` |
| 文件命名 | `slug.md` | `YYYY-MM-DD-slug.md` | `slug.md` |
| 图片路径 | `static/images/` | `assets/images/` | `source/images/` |
### Frontmatter 字段
- `title`: 文章标题
- `date`: 发布日期
- `author`: 作者
- `tags`: 标签列表
- `description` / `excerpt`: 摘要
- `draft`: 是否草稿
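
To make the table and field list concrete, here is a hedged sketch that builds the target path and a YAML frontmatter block for each generator. `post_path` and `render_frontmatter` are hypothetical helpers, not part of content-engine. Values are quoted with `json.dumps`, which relies on JSON strings and arrays being valid YAML flow syntax.

```python
import json
from datetime import date
from typing import List

# Post directories taken from the table above.
POST_DIRS = {"hugo": "content/posts", "jekyll": "_posts", "hexo": "source/_posts"}


def post_path(generator: str, slug: str, published: date) -> str:
    """Return the relative file path for a post; Jekyll prefixes the date."""
    if generator not in POST_DIRS:
        raise ValueError(f"unsupported generator: {generator}")
    if generator == "jekyll":
        name = f"{published.isoformat()}-{slug}.md"
    else:
        name = f"{slug}.md"
    return f"{POST_DIRS[generator]}/{name}"


def render_frontmatter(
    title: str,
    published: date,
    author: str,
    tags: List[str],
    description: str,
    draft: bool = False,
) -> str:
    """Render a YAML frontmatter block with the fields listed above."""
    def quote(value) -> str:
        return json.dumps(value, ensure_ascii=False)

    fields = [
        f"title: {quote(title)}",
        f"date: {published.isoformat()}",
        f"author: {quote(author)}",
        f"tags: {quote(tags)}",
        f"description: {quote(description)}",
        f"draft: {'true' if draft else 'false'}",
    ]
    return "---\n" + "\n".join(fields) + "\n---\n"


day = date(2026, 3, 19)
jekyll_path = post_path("jekyll", "hello-world", day)
hugo_path = post_path("hugo", "hello-world", day)
front = render_frontmatter("现金流入门", day, "张三", ["财务", "入门"], "示例摘要")
```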
---
## Medium
| 项目 | 规格 |
|------|------|
| 文章格式 | Markdown / HTML |
| 标签上限 | 5 个 |
| 图片格式 | JPEG, PNG, GIF |
| 图片大小限制 | 25MB |
| 推荐文章长度 | 7 分钟阅读时间(约 1,750 词) |
### 最佳实践
- 使用副标题增加结构感
- 图片和引用提升可读性
- 标签选择高流量话题
- 开头 100 字决定读者是否继续
---
## 跨平台对照表
| 维度 | Twitter | LinkedIn | 微信公众号 | Blog | Medium |
|------|---------|----------|-----------|------|--------|
| 字符限制 | 280/条 | 3,000 | 20,000 | 无限制 | 无限制 |
| 格式 | 纯文本 | 纯文本 | HTML | Markdown | Markdown |
| 图片支持 | 4张/条 | 9张 | 20张 | 无限制 | 无限制 |
| 标签 | Hashtag | Hashtag | 无 | 分类标签 | 5个标签 |
| 最佳长度 | 71-100字 | 150-300字 | 1500-3000字 | 1000-2000字 | 1750词 |
| 语气 | 轻松简洁 | 专业正式 | 亲切通俗 | 深度详实 | 叙事性强 |
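现金流领航 (cashflow-pilot): helps small and mid-sized business owners see "where the money comes from, where it goes, and how much is left"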
---
name: cashflow-pilot
description: 现金流领航 — 帮中小企业老板看清"钱从哪来、花到哪去、还剩多少"
version: 1.0.0
metadata:
openclaw:
optional_env:
- CFP_SUBSCRIPTION_TIER
- CFP_DATA_DIR
---
# 现金流领航(cashflow-pilot)
你是一个专业的现金流管理助手 Agent。你的职责是帮助中小企业老板管理日常收支、导入账单、生成现金流报告、跟踪应收应付账款、预测未来现金流。你始终使用中文与用户沟通。
## 环境变量说明
| 变量 | 必需 | 说明 |
|------|------|------|
| `CFP_SUBSCRIPTION_TIER` | 否 | 订阅等级,默认 `free`,可选 `paid` |
| `CFP_DATA_DIR` | 否 | 数据存储目录,默认 `~/.openclaw-bdi/cashflow-pilot/` |
启动时,检查数据目录是否存在,若不存在则自动创建。向用户打招呼并简要介绍可用功能。
---
## 流程一:导入账单
当用户说"导入账单"、"导入CSV"、"上传账单"或类似意图时,执行以下步骤:
### 步骤 1:确认文件类型
向用户确认要导入的文件类型:
```
请选择导入方式:
1. CSV 文件(免费版支持)
2. Excel 文件(仅付费版)
3. 银行对账单 PDF(仅付费版)
4. 手动录入
```
> 注意:先执行订阅校验,免费版仅支持 CSV 和手动录入。
### 步骤 2:获取文件路径
引导用户提供文件路径。确认文件存在后继续。
### 步骤 3:解析并导入
```bash
python3 scripts/ledger_parser.py --action import --file <path> --format csv
```
### 步骤 4:展示导入结果
将导入结果以清晰的表格形式展示:
```
导入成功!
- 导入记录:52 条
- 收入:28 笔,合计 ¥185,000.00
- 支出:24 笔,合计 ¥132,500.00
- 净现金流:¥52,500.00
自动分类结果:
| 分类 | 笔数 | 金额 |
|------|------|------|
| 销售回款 | 20 | ¥165,000 |
| 服务收入 | 8 | ¥20,000 |
| 人员工资 | 10 | ¥85,000 |
| 房租物业 | 3 | ¥15,000 |
| ... | ... | ... |
```
### 步骤 5:确认分类
让用户检查自动分类结果,如有误分类可手动调整。
---
## 流程二:手动录入收支
当用户说"记一笔"、"录入收支"、"手动录入"或类似意图时,执行以下步骤:
### 步骤 1:收集信息
引导用户提供:
- 日期(默认今天)
- 类型(收入/支出)
- 金额
- 描述/用途
- 分类(可选,系统自动识别)
### 步骤 2:录入
```bash
python3 scripts/ledger_parser.py --action add --data '{"date":"2026-03-19","type":"income","amount":50000,"description":"客户A货款回收"}'
```
### 步骤 3:确认
展示录入结果,告知已成功记录。
---
## 流程三:查看现金流报告
当用户说"这个月现金流怎么样"、"月度报告"、"收支汇总"或类似意图时,执行以下步骤:
### 步骤 1:确认报告周期
默认为当月,用户可指定其他月份或日期范围。
### 步骤 2:生成报告
```bash
python3 scripts/cashflow_analyzer.py --action monthly --year 2026 --month 3
```
### 步骤 3:输出报告
根据订阅等级输出不同格式:
**免费版输出:**
- 核心指标汇总表格(总收入、总支出、净现金流)
- 收支分类明细表格
- 简要文字总结
**付费版输出:**
- 核心指标汇总表格
- 收支分类明细表格
- Mermaid 饼图(收入/支出分类占比)
- 异常支出告警
- 深度洞察与建议
---
## 流程四:应收账款管理
当用户说"哪些客户还没付款"、"应收账款"、"催款清单"或类似意图时,执行以下步骤:
### 步骤 1:生成提醒
```bash
python3 scripts/reminder_generator.py --action generate --type receivable
```
### 步骤 2:展示提醒清单
按逾期天数排序展示,高优先级在前:
```
应收账款提醒(共 8 笔):
| 优先级 | 客户 | 金额 | 到期日 | 状态 |
|--------|------|------|--------|------|
| 高 | 客户A | ¥50,000 | 2026-02-15 | 逾期32天 |
| 高 | 客户B | ¥30,000 | 2026-03-01 | 逾期18天 |
| 中 | 客户C | ¥25,000 | 2026-03-22 | 3天后到期 |
| ... | ... | ... | ... | ... |
```
> 免费版仅显示前 3 条提醒,付费版无限制。
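
The ordering described above can be sketched as below. This illustrates only the sorting and the free-tier cut-off, not how `reminder_generator.py` is actually implemented. The record layout and the reference date of 2026-03-19 are assumptions, the latter inferred from the sample table.

```python
from datetime import date
from typing import Dict, List


def days_overdue(due: str, today: date) -> int:
    """Positive when the due date has passed, negative when it is still ahead."""
    return (today - date.fromisoformat(due)).days


def rank_receivables(items: List[Dict], today: date, tier: str = "free") -> List[Dict]:
    """Sort by days overdue, most overdue first; the free tier sees only 3 reminders."""
    ranked = sorted(items, key=lambda r: days_overdue(r["due"], today), reverse=True)
    return ranked if tier == "paid" else ranked[:3]


today = date(2026, 3, 19)  # assumed reference date
receivables = [
    {"customer": "客户C", "amount": 25000, "due": "2026-03-22"},
    {"customer": "客户B", "amount": 30000, "due": "2026-03-01"},
    {"customer": "客户A", "amount": 50000, "due": "2026-02-15"},
    {"customer": "客户D", "amount": 8000, "due": "2026-04-10"},
]
free_view = rank_receivables(receivables, today)
paid_view = rank_receivables(receivables, today, tier="paid")
overdue = [days_overdue(r["due"], today) for r in free_view]
```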
### 步骤 3:催款通知
用户可选择对某条记录生成催款通知:
```bash
python3 scripts/reminder_generator.py --action notice --id <receivable_id>
```
---
## 流程五:现金流预测
当用户说"预测下季度现金流"、"未来现金流"、"现金流预测"或类似意图时,执行以下步骤:
### 步骤 1:订阅校验
此功能仅限付费版。免费版用户提示升级。
### 步骤 2:执行预测
```bash
python3 scripts/forecast_engine.py --action predict --months 3
```
### 步骤 3:输出预测报告
```
现金流预测(未来3个月):
| 月份 | 预测收入 | 预测支出 | 预测净现金流 |
|------|---------|---------|------------|
| 2026-04 | ¥180,000 | ¥135,000 | ¥45,000 |
| 2026-05 | ¥175,000 | ¥140,000 | ¥35,000 |
| 2026-06 | ¥190,000 | ¥138,000 | ¥52,000 |
[趋势预测图 - Mermaid 折线图]
风险预警:
- [中风险] 支出连续3个月增长,累计增幅 15.2%
```
---
## 流程六:趋势分析
当用户说"收支趋势"、"现金流趋势"、"最近几个月怎么样"或类似意图时:
```bash
python3 scripts/cashflow_analyzer.py --action trend --months 6
```
展示最近 N 个月的收支趋势表格和图表(付费版含 Mermaid 折线图)。
---
## 订阅校验逻辑
在每次涉及功能限制的操作前,必须执行以下校验:
### 读取订阅等级
```
tier = env CFP_SUBSCRIPTION_TIER,默认 "free"
```
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥79/月) |
|------|---------------|----------------------|
| 手动录入收支 | ✅ | ✅ |
| CSV 导入账单 | ✅ | ✅ |
| Excel 导入 | ❌ | ✅ |
| 银行对账单解析 | ❌ | ✅ (PDF/CSV) |
| 月度现金流报告 | 基础表格 | 表格+图表+洞察 |
| 应收账款提醒 | 最多3条 | 无限制 |
| 现金流预测(未来3月) | ❌ | ✅ AI预测 |
| 异常支出告警 | ❌ | ✅ |
| 趋势图表 | ❌ | ✅ Mermaid 图表 |
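
The matrix above can be encoded as a lookup table. The sketch below is one possible shape, not the skill's actual `check_subscription` implementation. The feature keys are hypothetical identifiers, and falling back to `free` for unrecognized values is an assumption.

```python
import os
from typing import Mapping, Optional

# Hypothetical feature keys mirroring the rows of the permission matrix.
FEATURE_TIERS = {
    "manual_entry": {"free", "paid"},
    "csv_import": {"free", "paid"},
    "excel_import": {"paid"},
    "bank_statement": {"paid"},
    "forecast": {"paid"},
    "anomaly_alert": {"paid"},
    "trend_chart": {"paid"},
}
REMINDER_LIMIT = {"free": 3, "paid": None}  # None means unlimited


def current_tier(env: Optional[Mapping[str, str]] = None) -> str:
    """Read CFP_SUBSCRIPTION_TIER, defaulting to "free" (also for unknown values)."""
    env = os.environ if env is None else env
    tier = env.get("CFP_SUBSCRIPTION_TIER", "free").strip().lower()
    return tier if tier in ("free", "paid") else "free"


def is_allowed(feature: str, tier: str) -> bool:
    return tier in FEATURE_TIERS.get(feature, set())


default_tier = current_tier({})
paid_tier = current_tier({"CFP_SUBSCRIPTION_TIER": "PAID"})
```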
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥79/月),请联系管理员或访问订阅管理页面。"
4. 不要直接拒绝,而是提供免费版可用的替代方案(如果有的话)。
---
## 参考文档
在生成报告和图表时,请参考以下文档:
- **报告模板**:`references/report-templates.md` — 包含月度报告、账龄分析、预测报告的标准模板。
---
## 安全规范
1. **数据安全**:所有数据存储在本地 JSON 文件中,数据不会上传到云端。
2. **金额校验**:所有金额输入必须为有效数字,防止注入攻击。
3. **文件访问**:仅允许读取用户明确指定的文件路径,不自动扫描文件系统。
4. **错误处理**:执行命令失败时,向用户展示友好的错误提示,不暴露内部路径或系统信息。
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 用简洁、易懂的语言解释财务数据,避免过于专业的会计术语。
3. 对用户的问题给出清晰、结构化的回答,优先使用表格展示数据。
4. 主动提供现金流洞察和管理建议,而不仅仅是返回原始数据。
5. 遇到模糊的用户意图时,主动追问以明确需求。
6. 操作出错时,耐心排查并给出可行的解决方案。
7. 尊重订阅等级限制,在提示升级时保持友好,不要反复推销。
8. 涉及大金额操作时(如批量导入),先展示预览让用户确认。
FILE:assets/README.md
# 现金流领航 (cashflow-pilot)
> 帮中小企业老板看清"钱从哪来、花到哪去、还剩多少"
---
## 功能亮点
- **一键导入账单** — 支持 CSV / Excel 文件导入,自动识别收支分类,告别手动记账
- **月度现金流报告** — 自动生成收支汇总、分类占比、趋势分析,一页看清全貌
- **应收应付提醒** — 按逾期天数智能排序,自动生成催款通知,不再漏收一笔款
- **AI 现金流预测** — 基于历史数据预测未来3个月现金流,提前发现资金缺口
- **异常支出告警** — 自动检测异常支出,某项费用突然飙升立即提醒
- **数据全部本地存储** — 财务数据不离开你的电脑,安全放心
---
## 版本对比
| 功能 | 免费版 | 付费版 ¥79/月 |
|------|:------:|:------------:|
| 手动录入收支 | ✅ | ✅ |
| CSV 导入账单 | ✅ | ✅ |
| Excel 导入 | ❌ | ✅ |
| 银行对账单解析 | ❌ | ✅ (PDF/CSV) |
| 月度现金流报告 | 基础表格 | 表格+图表+洞察 |
| 应收账款提醒 | 3条 | 无限 |
| 现金流预测(未来3月) | ❌ | ✅ AI预测 |
| 异常支出告警 | ❌ | ✅ |
| 趋势图表 | ❌ | ✅ Mermaid 图表 |
---
## 快速开始
### 1. 安装 Skill
在 ClawHub 中搜索 `cashflow-pilot`,点击安装,或使用命令行:
```bash
openclaw skill install cashflow-pilot
```
### 2. 配置(可选)
设置环境变量以自定义行为:
```bash
# 订阅等级(free 或 paid,默认 free)
export CFP_SUBSCRIPTION_TIER=free
# 数据存储目录(默认 ~/.openclaw-bdi/cashflow-pilot/)
export CFP_DATA_DIR=~/.openclaw-bdi/cashflow-pilot/
```
### 3. 使用
```bash
# 导入 CSV 账单
/cashflow-pilot 导入账单
# 查看本月现金流
/cashflow-pilot 这个月现金流怎么样
# 查看应收账款
/cashflow-pilot 哪些客户还没付款
# 手动记一笔
/cashflow-pilot 记一笔收入,客户A回款5万
# 预测现金流(付费版)
/cashflow-pilot 预测下季度现金流
```
---
## 报告示例
以下是一份自动生成的月度报告样例:
```markdown
# 现金流月报 — 2026年3月
统计记录数:52 条
## 核心指标
| 指标 | 金额 |
|------|-----:|
| 总收入 | ¥185,000.00 |
| 总支出 | ¥132,500.00 |
| 净现金流 | ¥52,500.00 |
| 收入笔数 | 28 |
| 支出笔数 | 24 |
## 收入明细
| 收入分类 | 金额 | 占比 |
|----------|-----:|-----:|
| 销售回款 | ¥150,000.00 | 81.1% |
| 服务收入 | ¥30,000.00 | 16.2% |
| 其他收入 | ¥5,000.00 | 2.7% |
| **合计** | **¥185,000.00** | **100.0%** |
## 支出明细
| 支出分类 | 金额 | 占比 |
|----------|-----:|-----:|
| 人员工资 | ¥85,000.00 | 64.2% |
| 采购成本 | ¥25,000.00 | 18.9% |
| 房租物业 | ¥15,000.00 | 11.3% |
| 其他支出 | ¥7,500.00 | 5.7% |
| **合计** | **¥132,500.00** | **100.0%** |
---
*报告由 cashflow-pilot 自动生成*
```
---
## 常见问题
### Q1: 支持什么格式的 CSV 文件?
支持 UTF-8、GBK、GB2312 等编码的 CSV 文件。表头中需包含"日期"和"金额"(或"收入"/"支出")列,系统会自动识别列结构。
### Q2: 数据存储在哪里?
所有数据以 JSON 格式存储在本地目录(默认 `~/.openclaw-bdi/cashflow-pilot/`),包括账本记录、应收应付数据等。数据不会上传到云端。
### Q3: 免费版和付费版有什么区别?
免费版提供基础的收支记录和报告功能,足够满足简单记账需求。付费版(¥79/月)增加了 Excel 导入、AI 预测、异常检测、无限提醒等高级功能,适合需要深度管理现金流的企业。
### Q4: 可以同时管理多个公司的账本吗?
可以。通过设置不同的 `CFP_DATA_DIR` 环境变量,指向不同的数据目录即可管理多套账本。
### Q5: 现金流预测准确吗?
预测基于历史数据使用移动平均和线性回归算法,仅供参考。建议结合实际业务情况综合判断,不宜作为唯一决策依据。
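
For readers who want to see what the two named techniques look like, here is a minimal sketch of a moving average and a least-squares linear trend over monthly totals. How `forecast_engine.py` combines or weights them is not shown in this document, so both function names and the 3-month window are assumptions.

```python
from typing import List


def moving_average_forecast(history: List[float], window: int = 3) -> float:
    """Forecast the next value as the mean of the last `window` values."""
    if not history:
        raise ValueError("history must not be empty")
    recent = history[-window:]
    return sum(recent) / len(recent)


def linear_trend_forecast(history: List[float], steps: int = 1) -> List[float]:
    """Fit y = intercept + slope * x by least squares and extrapolate `steps` periods."""
    n = len(history)
    if n < 2:
        raise ValueError("at least two data points are required")
    mean_x = (n - 1) / 2
    mean_y = sum(history) / n
    sxx = sum((x - mean_x) ** 2 for x in range(n))
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in enumerate(history))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return [intercept + slope * (n + k) for k in range(steps)]


monthly_income = [100.0, 110.0, 120.0, 130.0]   # illustrative figures only
ma_next = moving_average_forecast(monthly_income)
trend_next = linear_trend_forecast(monthly_income, steps=2)
```

On a steadily rising series the moving average lags behind the trend line (120 versus 140 for the next month here), which is one reason to treat either number as a reference rather than a decision basis.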
### Q6: 导入的分类不准确怎么办?
系统使用关键词匹配进行自动分类,可能存在误分类。导入后可通过手动录入功能修正分类。未来版本将支持自定义分类规则。
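
Keyword-based classification of the kind described above could look like the sketch below. The rule table is purely illustrative: the real keyword lists live in `utils.classify_by_amount`, whose implementation is not shown here. The function only mirrors its call shape (signed amount plus description).

```python
from typing import Dict, List, Tuple

# Hypothetical rules: (category, keywords). The first matching rule wins.
RULES: Dict[str, List[Tuple[str, Tuple[str, ...]]]] = {
    "income": [
        ("销售回款", ("货款", "回款", "销售")),
        ("服务收入", ("服务费", "咨询")),
    ],
    "expense": [
        ("人员工资", ("工资", "薪资", "社保")),
        ("房租物业", ("房租", "租金", "物业")),
        ("采购成本", ("采购", "进货")),
    ],
}
FALLBACK = {"income": "其他收入", "expense": "其他支出"}


def classify(amount: float, description: str) -> Dict[str, str]:
    """Positive amounts are income, negative amounts are expenses."""
    txn_type = "income" if amount > 0 else "expense"
    for category, keywords in RULES[txn_type]:
        if any(k in description for k in keywords):
            return {"type": txn_type, "category": category}
    return {"type": txn_type, "category": FALLBACK[txn_type]}


sale = classify(50000, "客户A货款回收")
rent = classify(-15000, "3月房租")
misc = classify(-10, "杂项")
```

Because the first match wins, a description containing keywords from two categories lands in whichever rule is listed first, which is a typical source of the misclassifications mentioned above.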
---
## 技术支持
- **文档**:查看 `references/` 目录获取报告模板参考
- **问题反馈**:在 ClawHub 的 Skill 页面提交 Issue
- **社区讨论**:加入 ClawHub 社区频道 `#cashflow-pilot`
- **邮件**:[email protected]
---
*cashflow-pilot v1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/ledger_parser.py
#!/usr/bin/env python3
"""
cashflow-pilot 账单导入解析模块
支持 CSV 和 Excel(.xlsx 以 CSV 模式)文件的导入与解析,
自动识别收支分类,输出标准化的账本记录。
"""
import csv
import io
import json
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
# 将 scripts 目录加入路径以支持直接运行
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
classify_by_amount,
create_parser,
format_currency,
format_number,
get_data_dir,
load_input_data,
load_ledger,
output_error,
output_json,
output_success,
parse_date,
save_ledger,
check_subscription,
)
# ============================================================
# CSV 解析
# ============================================================
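def detect_encoding(file_path: str) -> str:
    """Detect the file encoding (simple trial-decoding approach).
    Tries utf-8-sig, utf-8, gbk, gb2312 and other common Chinese encodings in
    order and returns the first one that decodes the whole file. latin-1
    accepts any byte sequence, so it acts as the last-resort fallback.
    Args:
        file_path: Path to the file.
    Returns:
        Name of a usable encoding.
    """
    encodings = ["utf-8-sig", "utf-8", "gbk", "gb2312", "gb18030", "latin-1"]
    for enc in encodings:
        try:
            with open(file_path, "r", encoding=enc) as f:
                # Decode the whole file: checking only the first 1024 characters
                # can accept an encoding that fails further down.
                f.read()
            return enc
        except (UnicodeDecodeError, UnicodeError):
            continue
    return "utf-8"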
def detect_csv_columns(headers: List[str]) -> Dict[str, Optional[int]]:
    """Auto-detect the CSV column mapping.
    Identifies the date, description, amount, income, expense, category and
    note columns from keywords in the header row.
    Args:
        headers: CSV header cells.
    Returns:
        Mapping from field name to column index (None when not found).
    """
    mapping: Dict[str, Optional[int]] = {
        "date": None,
        "description": None,
        "amount": None,
        "income": None,
        "expense": None,
        "category": None,
        "note": None,
    }
    date_keywords = ["日期", "date", "交易日期", "记账日期", "时间"]
    # "备注" is listed under note_keywords only; when no description column exists,
    # the row parser already falls back to the note column.
    desc_keywords = ["描述", "摘要", "说明", "description", "memo", "交易说明", "用途"]
    amount_keywords = ["金额", "amount", "交易金额", "发生额"]
    income_keywords = ["收入", "income", "贷方", "credit", "贷方金额"]
    expense_keywords = ["支出", "expense", "借方", "debit", "借方金额"]
    category_keywords = ["分类", "类别", "category", "类型"]
    note_keywords = ["备注", "note", "附注", "remark"]
    for idx, header in enumerate(headers):
        h = header.strip().lower()
        if mapping["date"] is None and any(k in h for k in date_keywords):
            mapping["date"] = idx
        elif mapping["description"] is None and any(k in h for k in desc_keywords):
            mapping["description"] = idx
        # Income and expense are checked before the generic amount keywords;
        # otherwise headers such as "贷方金额" or "借方金额" would match "金额"
        # and be captured as the amount column.
        elif mapping["income"] is None and any(k in h for k in income_keywords):
            mapping["income"] = idx
        elif mapping["expense"] is None and any(k in h for k in expense_keywords):
            mapping["expense"] = idx
        elif mapping["amount"] is None and any(k in h for k in amount_keywords):
            mapping["amount"] = idx
        elif mapping["category"] is None and any(k in h for k in category_keywords):
            mapping["category"] = idx
        elif mapping["note"] is None and any(k in h for k in note_keywords):
            mapping["note"] = idx
    return mapping
def parse_amount(value: str) -> float:
"""解析金额字符串为浮点数。
处理逗号分隔、货币符号、括号表示负数等情况。
Args:
value: 金额字符串。
Returns:
解析后的浮点数。
"""
if not value or not value.strip():
return 0.0
s = value.strip()
# 处理括号表示负数: (100.00) → -100.00
negative = False
if s.startswith("(") and s.endswith(")"):
negative = True
s = s[1:-1]
# 移除货币符号和空格
s = re.sub(r"[¥$€£\s]", "", s)
# 移除千分位逗号
s = s.replace(",", "")
# 处理负号
if s.startswith("-"):
negative = not negative
s = s[1:]
try:
result = float(s)
except ValueError:
return 0.0
return -result if negative else result
def parse_csv_file(file_path: str) -> List[Dict[str, Any]]:
    """Parse a CSV file into a list of standardized transaction records.
    Args:
        file_path: Path to the CSV file.
    Returns:
        List of standardized transaction records.
    Raises:
        ValueError: The file does not exist, is empty, or its columns cannot be recognized.
    """
    if not os.path.exists(file_path):
        raise ValueError(f"文件不存在: {file_path}")
    encoding = detect_encoding(file_path)
    records = []
    with open(file_path, "r", encoding=encoding, newline="") as f:
        # Sniff the delimiter from a sample of the file
        sample = f.read(4096)
        f.seek(0)
        try:
            dialect = csv.Sniffer().sniff(sample, delimiters=",;\t|")
        except csv.Error:
            dialect = csv.excel
        reader = csv.reader(f, dialect)
        # Read the header row
        try:
            headers = next(reader)
        except StopIteration:
            raise ValueError("CSV 文件为空") from None
        # Detect the column mapping. As the error message states, a date column and
        # at least one of amount / income / expense are required.
        col_map = detect_csv_columns(headers)
        has_amount = any(col_map[k] is not None for k in ("amount", "income", "expense"))
        if col_map["date"] is None or not has_amount:
            raise ValueError(
                f"无法自动识别 CSV 列结构。检测到的表头: {headers}。"
                "请确保包含'日期'和'金额'(或'收入'/'支出')列。"
            )
        # Parse the data rows (row 1 is the header)
        for row_num, row in enumerate(reader, start=2):
            if not row or all(not cell.strip() for cell in row):
                continue  # skip blank rows
            record = _parse_csv_row(row, col_map, row_num)
            if record:
                records.append(record)
    return records
def _parse_csv_row(
row: List[str],
col_map: Dict[str, Optional[int]],
row_num: int,
) -> Optional[Dict[str, Any]]:
"""解析单行 CSV 数据。
Args:
row: CSV 行数据。
col_map: 列映射。
row_num: 行号(用于错误提示)。
Returns:
标准化的交易记录字典,解析失败返回 None。
"""
def get_col(name: str) -> str:
idx = col_map.get(name)
if idx is not None and idx < len(row):
return row[idx].strip()
return ""
# 解析日期
date_str = get_col("date")
date_val = None
if date_str:
from utils import parse_date as pd
date_val = pd(date_str)
if date_val is None:
date_val = datetime.now()
# 解析金额
amount = 0.0
income_val = get_col("income")
expense_val = get_col("expense")
amount_val = get_col("amount")
if income_val and parse_amount(income_val) != 0:
amount = abs(parse_amount(income_val))
elif expense_val and parse_amount(expense_val) != 0:
amount = -abs(parse_amount(expense_val))
elif amount_val:
amount = parse_amount(amount_val)
if amount == 0:
return None # 跳过零金额记录
# 解析描述
description = get_col("description") or get_col("note") or f"第{row_num}行记录"
# 分类
category_str = get_col("category")
if category_str:
txn_type = "income" if amount > 0 else "expense"
classification = {"type": txn_type, "category": category_str}
else:
classification = classify_by_amount(amount, description)
return {
"id": f"imp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{row_num}",
"date": date_val.strftime("%Y-%m-%d"),
"description": description,
"amount": round(abs(amount), 2),
"type": classification["type"],
"category": classification["category"],
"source": "csv_import",
"raw_row": row_num,
}
# ============================================================
# Excel 解析(使用 CSV 兼容模式,无需 openpyxl)
# ============================================================
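def parse_excel_file(file_path: str) -> List[Dict[str, Any]]:
    """Parse an Excel file (paid tier only).
    Note: only the standard library is used, so real Excel workbooks are not
    parsed. The function handles CSV content that was saved with an Excel
    extension; for real workbooks it asks the user to export the sheet as CSV.
    Args:
        file_path: Path to the Excel file.
    Returns:
        List of standardized transaction records.
    """
    sub = check_subscription()
    if sub["tier"] != "paid":
        raise ValueError(
            "Excel 文件导入为付费版功能。请先升级至付费版(¥79/月),"
            "或将 Excel 文件另存为 CSV 格式后使用免费版导入。"
        )
    if not os.path.exists(file_path):
        raise ValueError(f"文件不存在: {file_path}")
    convert_hint = (
        "无法直接解析 Excel (.xlsx) 文件。请将文件另存为 CSV 格式后重新导入。"
        "建议步骤:在 Excel 中选择 文件→另存为→CSV UTF-8 (逗号分隔)(*.csv)"
    )
    # Real .xlsx files are ZIP archives ("PK" signature) and legacy .xls files are
    # OLE containers. Reject both up front; otherwise the latin-1 fallback in
    # detect_encoding would let the CSV reader treat binary data as text.
    with open(file_path, "rb") as f:
        head = f.read(4)
    if head[:2] == b"PK" or head == b"\xd0\xcf\x11\xe0":
        raise ValueError(convert_hint)
    # CSV content saved with an Excel extension can still be read by the csv module
    try:
        return parse_csv_file(file_path)
    except Exception as exc:
        raise ValueError(convert_hint) from exc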
# ============================================================
# 手动录入
# ============================================================
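def add_manual_entry(entry_data: Dict[str, Any]) -> Dict[str, Any]:
    """Add a single income or expense record by hand.
    Args:
        entry_data: Dict with the following fields:
            - date: Date (YYYY-MM-DD)
            - description: Description
            - amount: Amount (positive number)
            - type: income/expense
            - category: Category (optional, detected automatically)
    Returns:
        The complete record after it has been added.
    Raises:
        ValueError: A required field is missing or a value is invalid.
    """
    required_fields = ["date", "description", "amount", "type"]
    for field in required_fields:
        if field not in entry_data:
            raise ValueError(f"缺少必需字段: {field}")
    # Amount check: must be a finite, non-zero number (rejects NaN and infinity)
    try:
        amount = float(entry_data["amount"])
    except (TypeError, ValueError):
        raise ValueError(f"amount 必须为有效数字,当前值: {entry_data['amount']}") from None
    if not 0 < abs(amount) < float("inf"):
        raise ValueError(f"amount 必须为非零的有限数字,当前值: {entry_data['amount']}")
    # Date check: reports filter on YYYY-MM-DD and silently skip anything else
    try:
        date_val = datetime.strptime(str(entry_data["date"]), "%Y-%m-%d")
    except ValueError:
        raise ValueError(f"date 必须为 YYYY-MM-DD 格式,当前值: {entry_data['date']}") from None
    txn_type = entry_data["type"]
    if txn_type not in ("income", "expense"):
        raise ValueError(f"type 必须为 'income' 或 'expense',当前值: {txn_type}")
    # Automatic classification when no category is given
    category = entry_data.get("category", "")
    if not category:
        cls = classify_by_amount(
            abs(amount) if txn_type == "income" else -abs(amount),
            entry_data["description"],
        )
        category = cls["category"]
    record = {
        "id": f"man_{datetime.now().strftime('%Y%m%d%H%M%S%f')}",
        "date": date_val.strftime("%Y-%m-%d"),
        "description": entry_data["description"],
        "amount": round(abs(amount), 2),
        "type": txn_type,
        "category": category,
        "source": "manual",
    }
    # Append to the ledger
    ledger = load_ledger()
    ledger.append(record)
    save_ledger(ledger)
    return record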
# ============================================================
# 主入口
# ============================================================
def main():
"""主入口函数。"""
parser = create_parser("cashflow-pilot 账单导入解析工具")
parser.add_argument(
"--format",
choices=["csv", "excel"],
default="csv",
help="文件格式(csv 或 excel)",
)
try:
args = parser.parse_args()
except SystemExit:
return
try:
if args.action == "import":
# 导入文件
if not args.file:
output_error("请使用 --file 参数指定要导入的文件路径", "MISSING_FILE")
return
if args.format == "excel":
records = parse_excel_file(args.file)
else:
records = parse_csv_file(args.file)
# 追加到账本
ledger = load_ledger()
ledger.extend(records)
save_ledger(ledger)
# 生成导入摘要
income_count = sum(1 for r in records if r["type"] == "income")
expense_count = sum(1 for r in records if r["type"] == "expense")
total_income = sum(r["amount"] for r in records if r["type"] == "income")
total_expense = sum(r["amount"] for r in records if r["type"] == "expense")
output_success({
"message": "账单导入成功",
"imported_count": len(records),
"income_count": income_count,
"expense_count": expense_count,
"total_income": round(total_income, 2),
"total_expense": round(total_expense, 2),
"records": records[:10], # 仅返回前10条预览
"summary": (
f"成功导入 {len(records)} 条记录。"
f"其中收入 {income_count} 笔共 {format_currency(total_income)},"
f"支出 {expense_count} 笔共 {format_currency(total_expense)}。"
),
})
elif args.action == "parse":
# 仅解析不导入
if not args.file:
output_error("请使用 --file 参数指定要解析的文件路径", "MISSING_FILE")
return
if args.format == "excel":
records = parse_excel_file(args.file)
else:
records = parse_csv_file(args.file)
output_success({
"message": "文件解析成功",
"record_count": len(records),
"records": records,
})
elif args.action == "add":
# 手动添加记录
data = load_input_data(args)
if not data:
output_error("请通过 --data 或 --data-file 提供记录数据", "MISSING_DATA")
return
record = add_manual_entry(data)
output_success({
"message": "记录添加成功",
"record": record,
})
elif args.action == "list":
# 列出当前账本记录
ledger = load_ledger()
output_success({
"total_count": len(ledger),
"records": ledger,
})
else:
output_error(
f"不支持的操作: {args.action}。支持的操作: import, parse, add, list",
"INVALID_ACTION",
)
except ValueError as e:
output_error(str(e), "PARSE_ERROR")
except Exception as e:
output_error(f"执行失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
FILE:scripts/cashflow_analyzer.py
#!/usr/bin/env python3
"""
cashflow-pilot 收支分类与汇总分析模块
对账本数据进行多维度汇总分析,生成月度现金流报告、
分类统计、趋势分析等。
"""
import json
import os
import sys
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
create_parser,
format_chinese_unit,
format_currency,
format_number,
format_percentage,
load_input_data,
load_ledger,
output_error,
output_success,
)
# ============================================================
# 数据汇总
# ============================================================
def summarize_records(records: List[Dict[str, Any]]) -> Dict[str, Any]:
"""汇总收支数据。
Args:
records: 账本记录列表。
Returns:
汇总结果字典。
"""
total_income = 0.0
total_expense = 0.0
income_by_category = defaultdict(float)
expense_by_category = defaultdict(float)
income_count = 0
expense_count = 0
for r in records:
amount = float(r.get("amount", 0))
txn_type = r.get("type", "unknown")
category = r.get("category", "未分类")
if txn_type == "income":
total_income += amount
income_by_category[category] += amount
income_count += 1
elif txn_type == "expense":
total_expense += amount
expense_by_category[category] += amount
expense_count += 1
net_cashflow = total_income - total_expense
return {
"total_income": round(total_income, 2),
"total_expense": round(total_expense, 2),
"net_cashflow": round(net_cashflow, 2),
"income_count": income_count,
"expense_count": expense_count,
"total_count": len(records),
"income_by_category": dict(
sorted(income_by_category.items(), key=lambda x: x[1], reverse=True)
),
"expense_by_category": dict(
sorted(expense_by_category.items(), key=lambda x: x[1], reverse=True)
),
"cashflow_status": "正" if net_cashflow >= 0 else "负",
}
def filter_records_by_month(
records: List[Dict[str, Any]], year: int, month: int
) -> List[Dict[str, Any]]:
"""按月份筛选记录。
Args:
records: 全部账本记录。
year: 年份。
month: 月份。
Returns:
筛选后的记录列表。
"""
result = []
for r in records:
date_str = r.get("date", "")
try:
dt = datetime.strptime(date_str, "%Y-%m-%d")
if dt.year == year and dt.month == month:
result.append(r)
except (ValueError, TypeError):
continue
return result
def filter_records_by_range(
records: List[Dict[str, Any]], start_date: str, end_date: str
) -> List[Dict[str, Any]]:
"""按日期范围筛选记录。
Args:
records: 全部账本记录。
start_date: 起始日期 (YYYY-MM-DD)。
end_date: 截止日期 (YYYY-MM-DD)。
Returns:
筛选后的记录列表。
"""
try:
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
return records
result = []
for r in records:
try:
dt = datetime.strptime(r.get("date", ""), "%Y-%m-%d")
if start <= dt <= end:
result.append(r)
except (ValueError, TypeError):
continue
return result
# ============================================================
# 月度报告
# ============================================================
def generate_monthly_report(
records: List[Dict[str, Any]], year: int, month: int
) -> Dict[str, Any]:
"""生成月度现金流报告。
Args:
records: 全部账本记录。
year: 年份。
month: 月份。
Returns:
月度报告数据字典。
"""
monthly_records = filter_records_by_month(records, year, month)
summary = summarize_records(monthly_records)
sub = check_subscription()
# 基础报告表格
report_table = _build_summary_table(summary)
# 分类明细
income_detail = _build_category_table(summary["income_by_category"], "收入")
expense_detail = _build_category_table(summary["expense_by_category"], "支出")
report = {
"period": f"{year}年{month}月",
"year": year,
"month": month,
"summary": summary,
"report_table": report_table,
"income_detail": income_detail,
"expense_detail": expense_detail,
"record_count": len(monthly_records),
}
# 付费版附加内容
if sub["tier"] == "paid":
report["mermaid_pie_income"] = _build_mermaid_pie(
summary["income_by_category"], "收入分类占比"
)
report["mermaid_pie_expense"] = _build_mermaid_pie(
summary["expense_by_category"], "支出分类占比"
)
report["insights"] = _generate_insights(summary, monthly_records)
# 异常检测
anomalies = detect_anomalies(records, year, month)
if anomalies:
report["anomalies"] = anomalies
# 生成 Markdown 报告
report["markdown"] = _render_monthly_markdown(report, sub["tier"])
return report
def _build_summary_table(summary: Dict[str, Any]) -> str:
"""生成汇总表格 Markdown。"""
return (
"| 指标 | 金额 |\n"
"|------|-----:|\n"
f"| 总收入 | {format_currency(summary['total_income'])} |\n"
f"| 总支出 | {format_currency(summary['total_expense'])} |\n"
f"| 净现金流 | {format_currency(summary['net_cashflow'])} |\n"
f"| 收入笔数 | {summary['income_count']} |\n"
f"| 支出笔数 | {summary['expense_count']} |\n"
)
def _build_category_table(category_data: Dict[str, float], label: str) -> str:
"""生成分类明细表格 Markdown。"""
if not category_data:
return f"暂无{label}记录\n"
total = sum(category_data.values())
lines = [f"| {label}分类 | 金额 | 占比 |\n", "|------|-----:|-----:|\n"]
for cat, amount in category_data.items():
pct = amount / total if total > 0 else 0
lines.append(
f"| {cat} | {format_currency(amount)} | {format_percentage(pct)} |\n"
)
lines.append(f"| **合计** | **{format_currency(total)}** | **100.0%** |\n")
return "".join(lines)
def _build_mermaid_pie(category_data: Dict[str, float], title: str) -> str:
"""生成 Mermaid 饼图代码。"""
if not category_data:
return ""
lines = [f"```mermaid\npie title {title}"]
for cat, amount in category_data.items():
lines.append(f' "{cat}" : {round(amount, 2)}')
lines.append("```")
return "\n".join(lines)
def _generate_insights(
summary: Dict[str, Any], records: List[Dict[str, Any]]
) -> List[str]:
"""生成数据洞察建议(付费版)。"""
insights = []
net = summary["net_cashflow"]
if net < 0:
insights.append(
f"本月净现金流为负({format_currency(net)}),支出超过收入,"
"建议关注支出控制或加快应收款回收。"
)
elif net > 0:
ratio = net / summary["total_income"] if summary["total_income"] > 0 else 0
insights.append(
f"本月净现金流为正({format_currency(net)}),"
f"现金留存率 {format_percentage(ratio)}。"
)
# 最大支出类别分析
if summary["expense_by_category"]:
top_cat = max(summary["expense_by_category"], key=summary["expense_by_category"].get)
top_amount = summary["expense_by_category"][top_cat]
total_exp = summary["total_expense"]
if total_exp > 0:
pct = top_amount / total_exp
insights.append(
f"最大支出类别为「{top_cat}」,占总支出 {format_percentage(pct)},"
f"金额 {format_currency(top_amount)}。"
)
# 收入集中度
if summary["income_by_category"]:
top_inc_cat = max(summary["income_by_category"], key=summary["income_by_category"].get)
top_inc = summary["income_by_category"][top_inc_cat]
total_inc = summary["total_income"]
if total_inc > 0 and top_inc / total_inc > 0.7:
insights.append(
f"收入高度集中于「{top_inc_cat}」(占比 {format_percentage(top_inc / total_inc)}),"
"建议拓展收入来源以降低风险。"
)
return insights
def _render_monthly_markdown(report: Dict[str, Any], tier: str) -> str:
"""渲染月度报告为 Markdown 格式。"""
lines = [
f"# 现金流月报 — {report['period']}\n",
f"统计记录数:{report['record_count']} 条\n",
"## 核心指标\n",
report["report_table"],
"\n## 收入明细\n",
report["income_detail"],
"\n## 支出明细\n",
report["expense_detail"],
]
if tier == "paid":
if report.get("mermaid_pie_income"):
lines.extend(["\n## 收入分类图\n", report["mermaid_pie_income"], "\n"])
if report.get("mermaid_pie_expense"):
lines.extend(["\n## 支出分类图\n", report["mermaid_pie_expense"], "\n"])
if report.get("insights"):
lines.append("\n## 洞察与建议\n")
for i, insight in enumerate(report["insights"], 1):
lines.append(f"{i}. {insight}\n")
if report.get("anomalies"):
lines.append("\n## 异常告警\n")
for a in report["anomalies"]:
lines.append(f"- {a['message']}\n")
lines.append("\n---\n*报告由 cashflow-pilot 自动生成*\n")
return "".join(lines)
# ============================================================
# 趋势分析
# ============================================================
def analyze_trend(records: List[Dict[str, Any]], months: int = 6) -> Dict[str, Any]:
"""分析最近 N 个月的现金流趋势。
Args:
records: 全部账本记录。
months: 分析月数,默认 6。
Returns:
趋势分析结果。
"""
now = datetime.now()
monthly_data = []
for i in range(months - 1, -1, -1):
year = now.year
month = now.month - i
while month <= 0:
month += 12
year -= 1
month_records = filter_records_by_month(records, year, month)
summary = summarize_records(month_records)
monthly_data.append({
"period": f"{year}-{month:02d}",
"year": year,
"month": month,
"income": summary["total_income"],
"expense": summary["total_expense"],
"net": summary["net_cashflow"],
"record_count": len(month_records),
})
# 生成趋势 Mermaid 图表(付费版)
sub = check_subscription()
mermaid_chart = ""
if sub["tier"] == "paid" and monthly_data:
mermaid_chart = _build_trend_chart(monthly_data)
return {
"months_analyzed": months,
"monthly_data": monthly_data,
"mermaid_chart": mermaid_chart,
"trend_summary": _summarize_trend(monthly_data),
}
def _build_trend_chart(monthly_data: List[Dict[str, Any]]) -> str:
"""生成趋势折线图 Mermaid 代码。"""
lines = [
"```mermaid",
"xychart-beta",
' title "现金流趋势"',
' x-axis [' + ", ".join(f'"{d["period"]}"' for d in monthly_data) + "]",
' y-axis "金额(元)"',
" line [" + ", ".join(str(d["income"]) for d in monthly_data) + "]",
" line [" + ", ".join(str(d["expense"]) for d in monthly_data) + "]",
"```",
]
return "\n".join(lines)
def _summarize_trend(monthly_data: List[Dict[str, Any]]) -> str:
"""生成趋势文字摘要。"""
if len(monthly_data) < 2:
return "数据不足,无法分析趋势。"
last = monthly_data[-1]
prev = monthly_data[-2]
parts = []
# 收入趋势
if prev["income"] > 0:
inc_change = (last["income"] - prev["income"]) / prev["income"]
direction = "增长" if inc_change >= 0 else "下降"
parts.append(f"收入环比{direction} {format_percentage(abs(inc_change))}")
else:
parts.append(f"本月收入 {format_currency(last['income'])}")
# 支出趋势
if prev["expense"] > 0:
exp_change = (last["expense"] - prev["expense"]) / prev["expense"]
direction = "增长" if exp_change >= 0 else "下降"
parts.append(f"支出环比{direction} {format_percentage(abs(exp_change))}")
# 净现金流
parts.append(f"净现金流 {format_currency(last['net'])}")
return ";".join(parts) + "。"
# ============================================================
# 异常检测(付费版)
# ============================================================
def detect_anomalies(
    records: List[Dict[str, Any]], year: int, month: int
) -> List[Dict[str, Any]]:
    """Detect abnormal monthly spending (paid tier only).

    A category is flagged when its spend in the target month exceeds twice
    its historical mean. The mean is taken over the months, among the
    previous six, whose summary includes the category, and at least two such
    months are required. No standard deviation is involved.

    Args:
        records: All ledger records.
        year: Target year.
        month: Target month.

    Returns:
        List of anomaly alerts; always empty on the free tier.
    """
    sub = check_subscription()
    if sub["tier"] != "paid":
        return []

    # Collect per-category expense totals for the previous six months.
    category_history = defaultdict(list)
    for i in range(1, 7):
        m = month - i
        y = year
        while m <= 0:
            m += 12
            y -= 1
        month_records = filter_records_by_month(records, y, m)
        month_summary = summarize_records(month_records)
        for cat, amount in month_summary["expense_by_category"].items():
            category_history[cat].append(amount)

    # Per-category expense totals for the target month.
    current_records = filter_records_by_month(records, year, month)
    current_summary = summarize_records(current_records)

    anomalies = []
    for cat, current_amount in current_summary["expense_by_category"].items():
        history = category_history.get(cat, [])
        if len(history) < 2:
            # Too little history to judge this category.
            continue
        avg = sum(history) / len(history)
        if avg > 0 and current_amount > avg * 2:
            anomalies.append({
                "category": cat,
                "current": round(current_amount, 2),
                "average": round(avg, 2),
                "ratio": round(current_amount / avg, 2),
                "message": (
                    f"「{cat}」本月支出 {format_currency(current_amount)},"
                    f"是过去6月均值({format_currency(avg)})的 "
                    f"{current_amount / avg:.1f} 倍,请关注。"
                ),
            })
    return anomalies
# ============================================================
# 主入口
# ============================================================
def main():
"""主入口函数。"""
parser = create_parser("cashflow-pilot 收支分类与汇总分析工具")
parser.add_argument("--year", type=int, default=None, help="年份")
parser.add_argument("--month", type=int, default=None, help="月份")
parser.add_argument("--months", type=int, default=6, help="趋势分析月数")
try:
args = parser.parse_args()
except SystemExit:
return
try:
# 加载数据
data = load_input_data(args)
if isinstance(data, list):
records = data
elif isinstance(data, dict) and "records" in data:
records = data["records"]
else:
records = load_ledger()
if args.action == "summary":
# 汇总分析
if args.year and args.month:
records = filter_records_by_month(records, args.year, args.month)
summary = summarize_records(records)
output_success(summary)
elif args.action == "monthly":
# 月度报告
year = args.year or datetime.now().year
month = args.month or datetime.now().month
report = generate_monthly_report(records, year, month)
output_success(report)
elif args.action == "trend":
# 趋势分析
result = analyze_trend(records, args.months)
output_success(result)
elif args.action == "anomaly":
# 异常检测
year = args.year or datetime.now().year
month = args.month or datetime.now().month
anomalies = detect_anomalies(records, year, month)
output_success({
"period": f"{year}年{month}月",
"anomalies": anomalies,
"count": len(anomalies),
})
else:
output_error(
f"不支持的操作: {args.action}。支持的操作: summary, monthly, trend, anomaly",
"INVALID_ACTION",
)
except ValueError as e:
output_error(str(e), "ANALYSIS_ERROR")
except Exception as e:
output_error(f"分析执行失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
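
The trend and anomaly functions above both step backwards through calendar months and compare a month against a historical mean. The following is a minimal, self-contained sketch of those two ideas, assuming per-category totals are already available as plain dicts. The names `months_back` and `flag_anomalies` are hypothetical and are not part of the scripts.

```python
from typing import Dict, List, Tuple


def months_back(year: int, month: int, offset: int) -> Tuple[int, int]:
    """Return the (year, month) that lies `offset` months before the given month."""
    month -= offset
    while month <= 0:
        # Wrap into the previous year, one year at a time.
        month += 12
        year -= 1
    return year, month


def flag_anomalies(
    history: Dict[str, List[float]],
    current: Dict[str, float],
    factor: float = 2.0,
) -> List[str]:
    """Return categories whose current spend exceeds `factor` times the historical mean."""
    flagged = []
    for category, amount in current.items():
        past = history.get(category, [])
        if len(past) < 2:
            # Too little history to judge this category.
            continue
        mean = sum(past) / len(past)
        if mean > 0 and amount > mean * factor:
            flagged.append(category)
    return flagged


if __name__ == "__main__":
    print(months_back(2024, 2, 3))
    print(flag_anomalies({"rent": [100.0, 100.0]}, {"rent": 250.0}))
```

The comparison is strict, so a category spending exactly twice its mean is not flagged, which matches the `current_amount > avg * 2` test in `detect_anomalies`.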
FILE:scripts/reminder_generator.py
#!/usr/bin/env python3
"""
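cashflow-pilot receivables and payables reminder module.

Manages accounts receivable and accounts payable, generates reminder lists
sorted by days overdue, and supports generating collection notices.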
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
create_parser,
format_currency,
format_number,
get_data_dir,
load_input_data,
output_error,
output_success,
)
# ============================================================
# 数据管理
# ============================================================
def get_receivables_file() -> str:
"""获取应收账款数据文件路径。"""
return os.path.join(get_data_dir(), "receivables.json")
def get_payables_file() -> str:
"""获取应付账款数据文件路径。"""
return os.path.join(get_data_dir(), "payables.json")
def load_items(file_path: str) -> List[Dict[str, Any]]:
"""加载应收/应付记录。
Args:
file_path: 数据文件路径。
Returns:
记录列表。
"""
if not os.path.exists(file_path):
return []
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
return data if isinstance(data, list) else []
except (json.JSONDecodeError, IOError):
return []
def save_items(items: List[Dict[str, Any]], file_path: str) -> None:
"""保存应收/应付记录。
Args:
items: 记录列表。
file_path: 数据文件路径。
"""
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(items, f, ensure_ascii=False, indent=2, default=str)
# ============================================================
# 应收账款管理
# ============================================================
def add_receivable(data: Dict[str, Any]) -> Dict[str, Any]:
"""添加一条应收账款记录。
Args:
data: 应收账款信息,包含:
- customer: 客户名称
- amount: 应收金额
- due_date: 到期日期 (YYYY-MM-DD)
- invoice_no: 发票号(可选)
- description: 说明(可选)
Returns:
添加后的完整记录。
"""
required = ["customer", "amount", "due_date"]
for field in required:
if field not in data:
raise ValueError(f"缺少必需字段: {field}")
record = {
"id": f"recv_{datetime.now().strftime('%Y%m%d%H%M%S%f')}",
"customer": data["customer"],
"amount": round(float(data["amount"]), 2),
"due_date": data["due_date"],
"invoice_no": data.get("invoice_no", ""),
"description": data.get("description", ""),
"status": "pending", # pending / paid / overdue
"created_at": datetime.now().strftime("%Y-%m-%d"),
}
items = load_items(get_receivables_file())
items.append(record)
save_items(items, get_receivables_file())
return record
def add_payable(data: Dict[str, Any]) -> Dict[str, Any]:
"""添加一条应付账款记录。
Args:
data: 应付账款信息,包含:
- vendor: 供应商名称
- amount: 应付金额
- due_date: 到期日期 (YYYY-MM-DD)
- invoice_no: 发票号(可选)
- description: 说明(可选)
Returns:
添加后的完整记录。
"""
required = ["vendor", "amount", "due_date"]
for field in required:
if field not in data:
raise ValueError(f"缺少必需字段: {field}")
record = {
"id": f"pay_{datetime.now().strftime('%Y%m%d%H%M%S%f')}",
"vendor": data["vendor"],
"amount": round(float(data["amount"]), 2),
"due_date": data["due_date"],
"invoice_no": data.get("invoice_no", ""),
"description": data.get("description", ""),
"status": "pending",
"created_at": datetime.now().strftime("%Y-%m-%d"),
}
items = load_items(get_payables_file())
items.append(record)
save_items(items, get_payables_file())
return record
# ============================================================
# 查询与排序
# ============================================================
def calculate_overdue_days(due_date_str: str) -> int:
"""计算逾期天数。
Args:
due_date_str: 到期日期字符串 (YYYY-MM-DD)。
Returns:
逾期天数,正数表示已逾期,负数表示未到期。
"""
try:
due = datetime.strptime(due_date_str, "%Y-%m-%d")
delta = datetime.now() - due
return delta.days
except (ValueError, TypeError):
return 0
def enrich_item(item: Dict[str, Any]) -> Dict[str, Any]:
"""为记录添加逾期天数和状态标签。
Args:
item: 原始记录。
Returns:
增强后的记录。
"""
enriched = dict(item)
overdue_days = calculate_overdue_days(item.get("due_date", ""))
enriched["overdue_days"] = overdue_days
if item.get("status") == "paid":
enriched["status_label"] = "已收款"
elif overdue_days > 0:
enriched["status"] = "overdue"
enriched["status_label"] = f"已逾期 {overdue_days} 天"
elif overdue_days == 0:
enriched["status_label"] = "今日到期"
else:
enriched["status_label"] = f"还有 {abs(overdue_days)} 天到期"
return enriched
def list_receivables(
status_filter: Optional[str] = None, sort_by_overdue: bool = True
) -> List[Dict[str, Any]]:
"""列出应收账款。
Args:
status_filter: 状态筛选(pending/paid/overdue),None 表示全部。
sort_by_overdue: 是否按逾期天数排序(逾期最久的在前)。
Returns:
增强后的应收账款列表。
"""
items = load_items(get_receivables_file())
enriched = [enrich_item(item) for item in items]
if status_filter:
enriched = [item for item in enriched if item["status"] == status_filter]
if sort_by_overdue:
enriched.sort(key=lambda x: x["overdue_days"], reverse=True)
return enriched
def list_payables(
status_filter: Optional[str] = None, sort_by_overdue: bool = True
) -> List[Dict[str, Any]]:
"""列出应付账款。
Args:
status_filter: 状态筛选。
sort_by_overdue: 是否按逾期天数排序。
Returns:
增强后的应付账款列表。
"""
items = load_items(get_payables_file())
enriched = [enrich_item(item) for item in items]
if status_filter:
enriched = [item for item in enriched if item["status"] == status_filter]
if sort_by_overdue:
enriched.sort(key=lambda x: x["overdue_days"], reverse=True)
return enriched
# ============================================================
# 提醒生成
# ============================================================
def generate_reminders(item_type: str = "receivable") -> Dict[str, Any]:
"""生成应收/应付提醒清单。
免费版限制:最多显示 3 条提醒。
付费版:无限制。
Args:
item_type: "receivable"(应收)或 "payable"(应付)。
Returns:
提醒清单数据。
"""
sub = check_subscription()
max_reminders = sub["max_reminders"] # -1 表示无限制
if item_type == "receivable":
items = list_receivables(status_filter=None)
# 只显示未收款的
items = [i for i in items if i["status"] != "paid"]
else:
items = list_payables(status_filter=None)
items = [i for i in items if i["status"] != "paid"]
# 分类
overdue = [i for i in items if i["overdue_days"] > 0]
due_soon = [i for i in items if -7 <= i["overdue_days"] <= 0]
upcoming = [i for i in items if i["overdue_days"] < -7]
# 构建提醒列表
reminders = []
for item in overdue:
name_key = "customer" if item_type == "receivable" else "vendor"
reminders.append({
"priority": "高",
"name": item.get(name_key, "未知"),
"amount": item["amount"],
"due_date": item["due_date"],
"overdue_days": item["overdue_days"],
"message": (
f"【逾期】{item.get(name_key, '未知')} — "
f"{format_currency(item['amount'])},已逾期 {item['overdue_days']} 天"
),
})
for item in due_soon:
name_key = "customer" if item_type == "receivable" else "vendor"
reminders.append({
"priority": "中",
"name": item.get(name_key, "未知"),
"amount": item["amount"],
"due_date": item["due_date"],
"overdue_days": item["overdue_days"],
"message": (
f"【即将到期】{item.get(name_key, '未知')} — "
f"{format_currency(item['amount'])},{item['status_label']}"
),
})
for item in upcoming:
name_key = "customer" if item_type == "receivable" else "vendor"
reminders.append({
"priority": "低",
"name": item.get(name_key, "未知"),
"amount": item["amount"],
"due_date": item["due_date"],
"overdue_days": item["overdue_days"],
"message": (
f"【待处理】{item.get(name_key, '未知')} — "
f"{format_currency(item['amount'])},{item['status_label']}"
),
})
# 免费版限制
truncated = False
total_count = len(reminders)
if max_reminders > 0 and len(reminders) > max_reminders:
reminders = reminders[:max_reminders]
truncated = True
# 统计
total_overdue_amount = sum(i["amount"] for i in overdue)
total_due_soon_amount = sum(i["amount"] for i in due_soon)
result = {
"type": "应收账款" if item_type == "receivable" else "应付账款",
"reminders": reminders,
"total_count": total_count,
"shown_count": len(reminders),
"overdue_count": len(overdue),
"due_soon_count": len(due_soon),
"total_overdue_amount": round(total_overdue_amount, 2),
"total_due_soon_amount": round(total_due_soon_amount, 2),
}
if truncated:
result["notice"] = (
f"免费版仅显示前 {max_reminders} 条提醒(共 {total_count} 条)。"
"升级至付费版(¥79/月)可查看全部提醒。"
)
# 生成 Markdown
result["markdown"] = _render_reminders_markdown(result, item_type)
return result
def generate_collection_notice(receivable_id: str) -> Dict[str, Any]:
"""生成单条催款通知。
Args:
receivable_id: 应收账款记录 ID。
Returns:
催款通知内容。
"""
items = load_items(get_receivables_file())
target = None
for item in items:
if item.get("id") == receivable_id:
target = enrich_item(item)
break
if not target:
raise ValueError(f"未找到应收账款记录: {receivable_id}")
customer = target.get("customer", "客户")
amount = format_currency(target["amount"])
due_date = target["due_date"]
overdue_days = target["overdue_days"]
invoice = target.get("invoice_no", "")
if overdue_days > 0:
notice = (
f"尊敬的 {customer}:\n\n"
f"贵公司有一笔款项 {amount}(发票号:{invoice})"
f"已于 {due_date} 到期,目前已逾期 {overdue_days} 天。\n\n"
f"请尽快安排付款,如已付款请忽略此提醒。\n\n"
f"如有疑问,请随时联系我们。\n\n谢谢!"
)
else:
notice = (
f"尊敬的 {customer}:\n\n"
f"温馨提醒,贵公司有一笔款项 {amount}(发票号:{invoice})"
f"将于 {due_date} 到期。\n\n"
f"请提前安排付款,感谢您的合作!"
)
return {
"receivable_id": receivable_id,
"customer": customer,
"amount": target["amount"],
"due_date": due_date,
"overdue_days": overdue_days,
"notice": notice,
}
def _render_reminders_markdown(result: Dict[str, Any], item_type: str) -> str:
"""渲染提醒清单为 Markdown 格式。"""
type_label = result["type"]
lines = [
f"# {type_label}提醒清单\n",
f"生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M')}\n",
]
if not result["reminders"]:
lines.append(f"\n暂无待处理的{type_label}记录。\n")
return "".join(lines)
# 汇总统计
lines.extend([
"\n## 概览\n",
f"- 逾期数量:{result['overdue_count']} 笔,"
f"合计 {format_currency(result['total_overdue_amount'])}\n",
f"- 即将到期:{result['due_soon_count']} 笔,"
f"合计 {format_currency(result['total_due_soon_amount'])}\n",
f"- 总计:{result['total_count']} 笔\n",
])
# 明细表
name_header = "客户" if item_type == "receivable" else "供应商"
lines.extend([
"\n## 明细\n",
f"| 优先级 | {name_header} | 金额 | 到期日 | 状态 |\n",
"|--------|------|-----:|--------|------|\n",
])
for r in result["reminders"]:
lines.append(
f"| {r['priority']} | {r['name']} | {format_currency(r['amount'])} | "
f"{r['due_date']} | {'逾期' + str(r['overdue_days']) + '天' if r['overdue_days'] > 0 else '待处理'} |\n"
)
if result.get("notice"):
lines.extend(["\n---\n", f"> {result['notice']}\n"])
lines.append("\n---\n*由 cashflow-pilot 自动生成*\n")
return "".join(lines)
# ============================================================
# 主入口
# ============================================================
def main():
"""主入口函数。"""
parser = create_parser("cashflow-pilot 应收应付提醒生成工具")
parser.add_argument(
"--type",
choices=["receivable", "payable"],
default="receivable",
help="类型:receivable(应收)或 payable(应付)",
)
parser.add_argument("--id", default=None, help="记录 ID(用于催款通知生成)")
try:
args = parser.parse_args()
except SystemExit:
return
try:
if args.action == "add":
data = load_input_data(args)
if not data:
output_error("请通过 --data 或 --data-file 提供数据", "MISSING_DATA")
return
if args.type == "receivable":
record = add_receivable(data)
else:
record = add_payable(data)
output_success({"message": "记录添加成功", "record": record})
elif args.action == "list":
if args.type == "receivable":
items = list_receivables()
else:
items = list_payables()
output_success({
"type": args.type,
"count": len(items),
"items": items,
})
elif args.action == "overdue":
if args.type == "receivable":
items = list_receivables(status_filter="overdue")
else:
items = list_payables(status_filter="overdue")
total_amount = sum(i["amount"] for i in items)
output_success({
"type": args.type,
"overdue_count": len(items),
"total_overdue_amount": round(total_amount, 2),
"items": items,
})
elif args.action == "generate":
result = generate_reminders(args.type)
output_success(result)
elif args.action == "notice":
if not args.id:
output_error("请使用 --id 参数指定应收账款记录 ID", "MISSING_ID")
return
notice = generate_collection_notice(args.id)
output_success(notice)
else:
output_error(
f"不支持的操作: {args.action}。"
"支持的操作: add, list, overdue, generate, notice",
"INVALID_ACTION",
)
except ValueError as e:
output_error(str(e), "REMINDER_ERROR")
except Exception as e:
output_error(f"执行失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
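
The reminder list above sorts items into three priority groups based on days overdue: overdue (more than 0 days), due soon (0 to 7 days remaining), and upcoming (more than 7 days remaining). The sketch below illustrates that bucketing under one assumption that differs from the script: the current date is passed in as a parameter so the result is reproducible. `overdue_days` and `bucket` are hypothetical names.

```python
from datetime import date


def overdue_days(due_date: str, today: date) -> int:
    """Days past the due date: positive if overdue, negative if not yet due."""
    return (today - date.fromisoformat(due_date)).days


def bucket(days: int) -> str:
    """Map an overdue-day count to one of the three reminder groups."""
    if days > 0:
        return "overdue"
    if days >= -7:
        return "due_soon"
    return "upcoming"


if __name__ == "__main__":
    today = date(2024, 3, 10)
    for due in ("2024-03-01", "2024-03-10", "2024-03-17", "2024-03-18"):
        print(due, bucket(overdue_days(due, today)))
```

Passing `today` explicitly is a design choice for testability only; the script itself reads the clock via `datetime.now()`.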
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
cashflow-pilot shared utilities module.

Provides number formatting, input/output handling, subscription checks,
and cash-flow helper functions.
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
# ============================================================
# 常量定义
# ============================================================
DEFAULT_DATA_DIR = os.path.join(os.path.expanduser("~"), ".openclaw-bdi", "cashflow-pilot")
# 收支分类关键词映射
INCOME_KEYWORDS: Dict[str, List[str]] = {
"销售回款": ["销售", "回款", "货款", "收款", "营收", "revenue", "sales"],
"服务收入": ["服务费", "咨询费", "技术服务", "顾问费", "service"],
"投资收益": ["利息", "分红", "投资收益", "理财", "interest", "dividend"],
"其他收入": ["退款", "补贴", "赔偿", "奖励", "返利"],
}
EXPENSE_KEYWORDS: Dict[str, List[str]] = {
"人员工资": ["工资", "薪资", "社保", "公积金", "奖金", "salary", "payroll"],
"房租物业": ["房租", "租金", "物业", "水电", "rent"],
"采购成本": ["采购", "进货", "原材料", "purchase", "procurement"],
"办公费用": ["办公", "文具", "打印", "耗材", "office"],
"营销推广": ["广告", "推广", "营销", "市场", "marketing"],
"差旅费用": ["差旅", "出差", "机票", "住宿", "交通", "travel"],
"税费": ["税", "增值税", "所得税", "印花税", "tax"],
"其他支出": ["手续费", "快递", "维修", "杂费"],
}
# ============================================================
# 数字格式化
# ============================================================
def format_number(value: float, decimals: int = 2) -> str:
"""将数字格式化为带千分位分隔符的字符串。
Args:
value: 待格式化的数值。
decimals: 小数位数,默认为 2。
Returns:
格式化后的字符串,例如 1234567 → "1,234,567.00"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
if decimals <= 0:
return f"{int(round(num)):,}"
return f"{num:,.{decimals}f}"
def format_percentage(value: float, decimals: int = 1) -> str:
"""将小数格式化为百分比字符串。
Args:
value: 待格式化的小数值(0.156 表示 15.6%)。
decimals: 百分比小数位数,默认为 1。
Returns:
百分比字符串,例如 0.156 → "15.6%"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
pct = num * 100
return f"{pct:.{decimals}f}%"
def format_chinese_unit(value: float) -> str:
"""将大数字转换为中文单位表示(万、亿)。
Args:
value: 待转换的数值。
Returns:
带中文单位的字符串,例如:
- 12345 → "1.23万"
- 123456789 → "1.23亿"
- 999 → "999"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
abs_num = abs(num)
sign = "-" if num < 0 else ""
if abs_num >= 1e8:
result = abs_num / 1e8
return f"{sign}{result:.2f}亿"
elif abs_num >= 1e4:
result = abs_num / 1e4
return f"{sign}{result:.2f}万"
else:
if abs_num == int(abs_num):
return f"{sign}{int(abs_num)}"
return f"{sign}{abs_num:.2f}"
def format_currency(value: float, symbol: str = "¥") -> str:
"""将数字格式化为货币字符串。
Args:
value: 待格式化的金额。
symbol: 货币符号,默认 "¥"。
Returns:
货币字符串,例如 12345.6 → "¥12,345.60"
"""
return f"{symbol}{format_number(value, 2)}"
# ============================================================
# JSON 输入输出
# ============================================================
def read_json_stdin() -> Dict[str, Any]:
"""从标准输入读取 JSON 数据并解析为字典。
Returns:
解析后的字典对象。
Raises:
ValueError: 当输入为空或 JSON 格式不合法时抛出。
"""
try:
raw = sys.stdin.read()
except Exception as e:
raise ValueError(f"读取标准输入失败: {e}")
if not raw.strip():
raise ValueError("标准输入为空,未读取到任何数据")
try:
data = json.loads(raw)
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if not isinstance(data, dict):
raise ValueError(f"期望输入为 JSON 对象,实际类型为 {type(data).__name__}")
return data
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
Args:
data: 待输出的数据(可被 JSON 序列化的任意对象)。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# 数据目录管理
# ============================================================
def get_data_dir() -> str:
"""获取数据存储目录,若不存在则自动创建。
优先读取 CFP_DATA_DIR 环境变量,否则使用默认路径。
Returns:
数据目录的绝对路径。
"""
data_dir = os.environ.get("CFP_DATA_DIR", DEFAULT_DATA_DIR)
os.makedirs(data_dir, exist_ok=True)
return data_dir
def load_ledger(ledger_file: Optional[str] = None) -> List[Dict[str, Any]]:
"""加载账本数据。
Args:
ledger_file: 账本文件路径,默认为数据目录下的 ledger.json。
Returns:
账本记录列表。
"""
if ledger_file is None:
ledger_file = os.path.join(get_data_dir(), "ledger.json")
if not os.path.exists(ledger_file):
return []
try:
with open(ledger_file, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return data
return []
except (json.JSONDecodeError, IOError):
return []
def save_ledger(records: List[Dict[str, Any]], ledger_file: Optional[str] = None) -> str:
"""保存账本数据。
Args:
records: 账本记录列表。
ledger_file: 账本文件路径,默认为数据目录下的 ledger.json。
Returns:
保存的文件路径。
"""
if ledger_file is None:
ledger_file = os.path.join(get_data_dir(), "ledger.json")
os.makedirs(os.path.dirname(ledger_file), exist_ok=True)
with open(ledger_file, "w", encoding="utf-8") as f:
json.dump(records, f, ensure_ascii=False, indent=2, default=str)
return ledger_file
# ============================================================
# 订阅校验
# ============================================================
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_reminders": 3,
"forecast_enabled": False,
"anomaly_detection": False,
"bank_statement_parse": False,
"report_level": "basic",
"features": ["manual_entry", "csv_import", "basic_report", "limited_reminder"],
},
"paid": {
"tier": "paid",
"max_reminders": -1, # -1 表示无限制
"forecast_enabled": True,
"anomaly_detection": True,
"bank_statement_parse": True,
"report_level": "full",
"features": [
"manual_entry",
"csv_import",
"excel_import",
"full_report",
"unlimited_reminder",
"forecast",
"anomaly_detection",
"bank_statement_parse",
],
},
}
def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
"""检查当前订阅等级并返回对应的限制配置。
优先使用传入的 tier 参数,否则读取环境变量 CFP_SUBSCRIPTION_TIER。
若都未设置,默认为 "free" 等级。
Args:
tier: 订阅等级("free" 或 "paid"),可选。
Returns:
包含订阅限制信息的字典。
Raises:
ValueError: 当传入的等级无效时抛出。
"""
if tier is None:
tier = os.environ.get("CFP_SUBSCRIPTION_TIER", "free")
tier = tier.strip().lower()
if tier not in _SUBSCRIPTION_TIERS:
valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
return dict(_SUBSCRIPTION_TIERS[tier])
def require_paid(feature_name: str) -> bool:
"""检查是否为付费用户,若非付费则输出升级提示并返回 False。
Args:
feature_name: 功能名称,用于提示信息。
Returns:
True 表示已付费可使用,False 表示未付费。
"""
sub = check_subscription()
if sub["tier"] != "paid":
output_error(
f"「{feature_name}」为付费版功能。当前为免费版,请升级至付费版(¥79/月)以使用此功能。",
code="SUBSCRIPTION_REQUIRED",
)
return False
return True
# ============================================================
# 收支分类
# ============================================================
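def classify_transaction(description: str) -> Dict[str, str]:
    """Classify a transaction as income or expense from its description.

    Income keywords are checked before expense keywords, so a description
    that matches both is classified as income.

    Args:
        description: Transaction description text.

    Returns:
        Dict with `type` (income/expense/unknown) and `category`.
    """
    desc_lower = description.lower()
    # Check income keywords first.
    for category, keywords in INCOME_KEYWORDS.items():
        for kw in keywords:
            if kw in desc_lower:
                return {"type": "income", "category": category}
    # Then check expense keywords.
    for category, keywords in EXPENSE_KEYWORDS.items():
        for kw in keywords:
            if kw in desc_lower:
                return {"type": "expense", "category": category}
    # No keyword matched: leave the type unknown so that classify_by_amount
    # can fall back to the sign of the amount.
    return {"type": "unknown", "category": "未分类"}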
def classify_by_amount(amount: float, description: str = "") -> Dict[str, str]:
"""根据金额正负和描述综合分类。
Args:
amount: 交易金额,正数为收入,负数为支出。
description: 交易描述,用于细分分类。
Returns:
包含 type 和 category 的字典。
"""
cls = classify_transaction(description)
if cls["type"] == "unknown":
if amount > 0:
cls["type"] = "income"
cls["category"] = "其他收入"
elif amount < 0:
cls["type"] = "expense"
cls["category"] = "其他支出"
return cls
# ============================================================
# 日期辅助
# ============================================================
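def parse_date(date_str: str) -> Optional[datetime]:
    """Try to parse a date string in several common formats.

    Supported formats: YYYY-MM-DD, YYYY/MM/DD, YYYYMMDD, the first two
    optionally followed by HH:MM:SS, then DD/MM/YYYY and MM/DD/YYYY.
    Formats are tried in the order listed, so an ambiguous value such as
    03/04/2024 is read day-first (3 April); month-first is used only when
    day-first parsing fails.

    Args:
        date_str: Date string.

    Returns:
        Parsed datetime object, or None if no format matches.
    """
    formats = [
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y%m%d",
        "%Y-%m-%d %H:%M:%S",
        "%Y/%m/%d %H:%M:%S",
        "%d/%m/%Y",
        "%m/%d/%Y",
    ]
    date_str = date_str.strip()
    for fmt in formats:
        try:
            return datetime.strptime(date_str, fmt)
        except ValueError:
            continue
    return None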
def get_month_range(year: int, month: int) -> tuple:
"""获取指定月份的起止日期。
Args:
year: 年份。
month: 月份(1-12)。
Returns:
(start_date, end_date) 元组。
"""
start = datetime(year, month, 1)
if month == 12:
end = datetime(year + 1, 1, 1) - timedelta(days=1)
else:
end = datetime(year, month + 1, 1) - timedelta(days=1)
return start, end
# ============================================================
# 命令行参数解析
# ============================================================
def create_parser(description: str) -> argparse.ArgumentParser:
"""创建带通用参数的命令行解析器。
Args:
description: 工具描述。
Returns:
配置好的 ArgumentParser 实例。
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--action",
required=True,
help="执行的操作",
)
parser.add_argument(
"--data",
default=None,
help="JSON 格式的输入数据(直接传入字符串)",
)
parser.add_argument(
"--data-file",
default=None,
help="JSON 数据文件路径",
)
parser.add_argument(
"--file",
default=None,
help="待处理的文件路径",
)
return parser
def load_input_data(args: argparse.Namespace) -> Optional[Any]:
    """Load input data from the command-line arguments.

    Sources are tried in order: --data (JSON string), then --data-file
    (path to a JSON file), then standard input if it is not a terminal.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The parsed data, or None if no data was provided.

    Raises:
        ValueError: If --data is not valid JSON or --data-file does not exist.
    """
    if args.data:
        try:
            return json.loads(args.data)
        except json.JSONDecodeError as e:
            raise ValueError(f"--data 参数 JSON 解析失败: {e}")
    if args.data_file:
        if not os.path.exists(args.data_file):
            raise ValueError(f"数据文件不存在: {args.data_file}")
        with open(args.data_file, "r", encoding="utf-8") as f:
            return json.load(f)
    # Fall back to standard input when it is piped or redirected.
    if not sys.stdin.isatty():
        raw = sys.stdin.read().strip()
        if raw:
            return json.loads(raw)
    return None
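
The keyword classifier in this module checks income keywords before expense keywords, which decides the outcome when a description matches both. The sketch below shows that precedence with a deliberately tiny keyword table. The tables `INCOME` and `EXPENSE` and the function `classify` are hypothetical stand-ins, not the module's own `INCOME_KEYWORDS`, `EXPENSE_KEYWORDS`, or `classify_transaction`.

```python
from typing import Dict, List

# Reduced keyword tables, for illustration only.
INCOME: Dict[str, List[str]] = {"销售回款": ["销售", "回款"], "其他收入": ["退款"]}
EXPENSE: Dict[str, List[str]] = {"税费": ["税"], "差旅费用": ["差旅", "机票"]}


def classify(description: str) -> Dict[str, str]:
    """Return the first matching category, checking income before expense."""
    text = description.lower()
    for txn_type, table in (("income", INCOME), ("expense", EXPENSE)):
        for category, keywords in table.items():
            if any(kw in text for kw in keywords):
                return {"type": txn_type, "category": category}
    return {"type": "unknown", "category": "未分类"}


if __name__ == "__main__":
    # "增值税退款" contains both "税" (expense) and "退款" (income).
    print(classify("增值税退款"))
```

Because matching is by substring, short keywords such as "税" match many descriptions; the income-first order is what keeps a tax refund from being booked as a tax expense in this sketch.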
FILE:scripts/forecast_engine.py
#!/usr/bin/env python3
"""
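cashflow-pilot cash-flow forecasting module (paid tier).

Forecasts future cash flow from historical data using moving averages and
linear regression, and provides risk warnings and trend assessment.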
"""
import json
import math
import os
import sys
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import (
check_subscription,
create_parser,
format_currency,
format_percentage,
load_input_data,
load_ledger,
output_error,
output_success,
require_paid,
)
# ============================================================
# 数据准备
# ============================================================
def prepare_monthly_series(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将账本记录汇总为按月排序的时间序列。
Args:
records: 账本记录列表。
Returns:
按月排序的汇总数据列表,每项包含 period, income, expense, net。
"""
monthly = defaultdict(lambda: {"income": 0.0, "expense": 0.0})
for r in records:
date_str = r.get("date", "")
try:
dt = datetime.strptime(date_str, "%Y-%m-%d")
key = f"{dt.year}-{dt.month:02d}"
except (ValueError, TypeError):
continue
amount = float(r.get("amount", 0))
txn_type = r.get("type", "unknown")
if txn_type == "income":
monthly[key]["income"] += amount
elif txn_type == "expense":
monthly[key]["expense"] += amount
# 按时间排序
series = []
for key in sorted(monthly.keys()):
data = monthly[key]
series.append({
"period": key,
"income": round(data["income"], 2),
"expense": round(data["expense"], 2),
"net": round(data["income"] - data["expense"], 2),
})
return series
# ============================================================
# 预测算法
# ============================================================
def moving_average(values: List[float], window: int = 3) -> float:
"""计算移动平均值。
Args:
values: 历史数值序列。
window: 窗口大小,默认 3。
Returns:
移动平均值。
"""
if not values:
return 0.0
window = min(window, len(values))
return sum(values[-window:]) / window
def linear_regression(values: List[float]) -> Tuple[float, float]:
"""简单线性回归,返回斜率和截距。
使用最小二乘法拟合 y = a * x + b。
Args:
values: 历史数值序列(y 值,x 值为 0, 1, 2, ...)。
Returns:
(斜率 a, 截距 b) 元组。
"""
n = len(values)
if n < 2:
return 0.0, values[0] if values else 0.0
x_vals = list(range(n))
x_mean = sum(x_vals) / n
y_mean = sum(values) / n
numerator = sum((x - x_mean) * (y - y_mean) for x, y in zip(x_vals, values))
denominator = sum((x - x_mean) ** 2 for x in x_vals)
if denominator == 0:
return 0.0, y_mean
a = numerator / denominator
b = y_mean - a * x_mean
return a, b
def predict_values(
values: List[float], months_ahead: int = 3, method: str = "combined"
) -> List[Dict[str, Any]]:
"""预测未来 N 个月的数值。
Args:
values: 历史数值序列。
months_ahead: 预测月数,默认 3。
method: 预测方法,"ma"(移动平均)、"lr"(线性回归)、"combined"(加权组合)。
Returns:
预测结果列表。
"""
if not values:
return [{"month_offset": i + 1, "predicted": 0.0, "method": method}
for i in range(months_ahead)]
predictions = []
n = len(values)
for i in range(months_ahead):
if method == "ma":
pred = moving_average(values, window=3)
elif method == "lr":
a, b = linear_regression(values)
pred = a * (n + i) + b
else: # combined
ma_pred = moving_average(values, window=3)
a, b = linear_regression(values)
lr_pred = a * (n + i) + b
# 加权组合:移动平均 40%,线性回归 60%
pred = ma_pred * 0.4 + lr_pred * 0.6
pred = max(pred, 0) # 不允许负值预测
predictions.append({
"month_offset": i + 1,
"predicted": round(pred, 2),
"method": method,
})
# 将预测值加入序列以支持递推预测
values = values + [pred]
return predictions
# ============================================================
# 风险预警
# ============================================================
def assess_risk(
series: List[Dict[str, Any]],
predictions: Dict[str, List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
"""评估现金流风险。
Args:
series: 历史月度数据序列。
predictions: 预测结果(包含 income 和 expense 的预测)。
Returns:
风险预警列表。
"""
warnings = []
# 检查预测净现金流是否为负
income_preds = predictions.get("income", [])
expense_preds = predictions.get("expense", [])
for i in range(min(len(income_preds), len(expense_preds))):
inc = income_preds[i]["predicted"]
exp = expense_preds[i]["predicted"]
net = inc - exp
if net < 0:
warnings.append({
"level": "高",
"month_offset": i + 1,
"message": (
f"预测第 {i + 1} 个月净现金流为负({format_currency(net)}),"
f"预测收入 {format_currency(inc)},支出 {format_currency(exp)}。"
"建议提前做好资金储备或催收应收款。"
),
})
# 检查支出增长趋势
if len(series) >= 3:
recent_expenses = [s["expense"] for s in series[-3:]]
if all(recent_expenses[i] < recent_expenses[i + 1] for i in range(len(recent_expenses) - 1)):
growth_rate = (recent_expenses[-1] - recent_expenses[0]) / recent_expenses[0] if recent_expenses[0] > 0 else 0
if growth_rate > 0.2:
warnings.append({
"level": "中",
"month_offset": 0,
"message": (
f"支出连续3个月增长,累计增幅 {format_percentage(growth_rate)}。"
"建议审视支出结构,控制非必要开支。"
),
})
# 检查收入波动
if len(series) >= 4:
incomes = [s["income"] for s in series[-4:]]
avg_income = sum(incomes) / len(incomes)
if avg_income > 0:
variance = sum((x - avg_income) ** 2 for x in incomes) / len(incomes)
std_dev = math.sqrt(variance)
cv = std_dev / avg_income
if cv > 0.3:
warnings.append({
"level": "中",
"month_offset": 0,
"message": (
f"收入波动较大(变异系数 {cv:.2f}),"
"建议建立至少2个月支出的现金储备。"
),
})
return warnings
# ============================================================
# 预测报告生成
# ============================================================
def generate_forecast(
records: List[Dict[str, Any]], months_ahead: int = 3
) -> Dict[str, Any]:
"""生成完整的现金流预测报告。
Args:
records: 账本记录列表。
months_ahead: 预测月数,默认 3。
Returns:
预测报告数据字典。
"""
series = prepare_monthly_series(records)
if len(series) < 2:
return {
"error": "历史数据不足,至少需要2个月的数据才能进行预测。",
"months_available": len(series),
}
# 分别预测收入和支出
income_values = [s["income"] for s in series]
expense_values = [s["expense"] for s in series]
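income_forecast = predict_values(income_values, months_ahead)
expense_forecast = predict_values(expense_values, months_ahead)
# 净现金流由收入预测减支出预测得出:predict_values() 会把预测值截断为非负,
# 直接预测净值序列会掩盖赤字,并与风险预警中的 inc - exp 不一致。
net_forecast = [
    {
        "month_offset": inc["month_offset"],
        "predicted": round(inc["predicted"] - exp["predicted"], 2),
        "method": inc["method"],
    }
    for inc, exp in zip(income_forecast, expense_forecast)
]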
# 计算预测期间的标签
last_period = series[-1]["period"]
last_year, last_month = map(int, last_period.split("-"))
for forecasts in [income_forecast, expense_forecast, net_forecast]:
for f in forecasts:
m = last_month + f["month_offset"]
y = last_year
while m > 12:
m -= 12
y += 1
f["period"] = f"{y}-{m:02d}"
# 风险评估
risk_warnings = assess_risk(
series,
{"income": income_forecast, "expense": expense_forecast},
)
# 生成 Mermaid 图表
mermaid_chart = _build_forecast_chart(series, income_forecast, expense_forecast)
# 汇总
total_pred_income = sum(f["predicted"] for f in income_forecast)
total_pred_expense = sum(f["predicted"] for f in expense_forecast)
total_pred_net = total_pred_income - total_pred_expense
report = {
"forecast_months": months_ahead,
"historical_months": len(series),
"income_forecast": income_forecast,
"expense_forecast": expense_forecast,
"net_forecast": net_forecast,
"total_predicted_income": round(total_pred_income, 2),
"total_predicted_expense": round(total_pred_expense, 2),
"total_predicted_net": round(total_pred_net, 2),
"risk_warnings": risk_warnings,
"mermaid_chart": mermaid_chart,
"markdown": _render_forecast_markdown(
income_forecast, expense_forecast, net_forecast,
risk_warnings, mermaid_chart, months_ahead,
),
}
return report
def _build_forecast_chart(
series: List[Dict[str, Any]],
income_forecast: List[Dict[str, Any]],
expense_forecast: List[Dict[str, Any]],
) -> str:
"""生成预测趋势图 Mermaid 代码。"""
# 取最近3个月历史 + 预测数据
recent = series[-3:] if len(series) >= 3 else series
all_periods = [s["period"] for s in recent] + [f["period"] for f in income_forecast]
income_data = [s["income"] for s in recent] + [f["predicted"] for f in income_forecast]
expense_data = [s["expense"] for s in recent] + [f["predicted"] for f in expense_forecast]
lines = [
"```mermaid",
"xychart-beta",
' title "现金流预测趋势"',
' x-axis [' + ", ".join(f'"{p}"' for p in all_periods) + "]",
' y-axis "金额(元)"',
" line [" + ", ".join(str(round(v, 0)) for v in income_data) + "]",
" line [" + ", ".join(str(round(v, 0)) for v in expense_data) + "]",
"```",
]
return "\n".join(lines)
def _render_forecast_markdown(
income_forecast: List[Dict[str, Any]],
expense_forecast: List[Dict[str, Any]],
net_forecast: List[Dict[str, Any]],
risk_warnings: List[Dict[str, Any]],
mermaid_chart: str,
months: int,
) -> str:
"""渲染预测报告为 Markdown 格式。"""
lines = [
f"# 现金流预测报告(未来{months}个月)\n",
"## 预测概览\n",
"| 月份 | 预测收入 | 预测支出 | 预测净现金流 |\n",
"|------|--------:|--------:|-----------:|\n",
]
for i in range(len(income_forecast)):
period = income_forecast[i]["period"]
inc = format_currency(income_forecast[i]["predicted"])
exp = format_currency(expense_forecast[i]["predicted"])
net = format_currency(net_forecast[i]["predicted"])
lines.append(f"| {period} | {inc} | {exp} | {net} |\n")
if mermaid_chart:
lines.extend(["\n## 趋势预测图\n", mermaid_chart, "\n"])
if risk_warnings:
lines.append("\n## 风险预警\n")
for w in risk_warnings:
level_icon = {"高": "!!!", "中": "!!", "低": "!"}.get(w["level"], "!")
lines.append(f"- [{level_icon} {w['level']}风险] {w['message']}\n")
else:
lines.append("\n## 风险评估\n\n当前现金流状况良好,未检测到显著风险。\n")
lines.append("\n---\n*预测由 cashflow-pilot 基于历史数据生成,仅供参考*\n")
return "".join(lines)
# ============================================================
# 主入口
# ============================================================
def main():
"""主入口函数。"""
parser = create_parser("cashflow-pilot 现金流预测工具(付费版)")
parser.add_argument("--months", type=int, default=3, help="预测月数,默认 3")
try:
args = parser.parse_args()
except SystemExit:
return
try:
if args.action == "predict":
if not require_paid("现金流预测"):
return
# 加载数据
data = load_input_data(args)
if isinstance(data, list):
records = data
elif isinstance(data, dict) and "records" in data:
records = data["records"]
else:
records = load_ledger()
if not records:
output_error("没有可用的账本数据,请先导入账单。", "NO_DATA")
return
forecast = generate_forecast(records, args.months)
output_success(forecast)
else:
output_error(
f"不支持的操作: {args.action}。支持的操作: predict",
"INVALID_ACTION",
)
except ValueError as e:
output_error(str(e), "FORECAST_ERROR")
except Exception as e:
output_error(f"预测执行失败: {e}", "INTERNAL_ERROR")
if __name__ == "__main__":
main()
FILE:references/report-templates.md
# 现金流报告模板
本文档提供 cashflow-pilot 各类报告的标准模板和示例。
---
## 一、月度现金流报告
### 1.1 基础版(免费版)
```markdown
# 现金流月报 — YYYY年M月
统计记录数:XX 条
## 核心指标
| 指标 | 金额 |
|------|-----:|
| 总收入 | ¥XXX,XXX.XX |
| 总支出 | ¥XXX,XXX.XX |
| 净现金流 | ¥XXX,XXX.XX |
| 收入笔数 | XX |
| 支出笔数 | XX |
## 收入明细
| 收入分类 | 金额 | 占比 |
|----------|-----:|-----:|
| 销售回款 | ¥XXX,XXX.XX | XX.X% |
| 服务收入 | ¥XX,XXX.XX | XX.X% |
| 其他收入 | ¥X,XXX.XX | XX.X% |
| **合计** | **¥XXX,XXX.XX** | **100.0%** |
## 支出明细
| 支出分类 | 金额 | 占比 |
|----------|-----:|-----:|
| 人员工资 | ¥XX,XXX.XX | XX.X% |
| 采购成本 | ¥XX,XXX.XX | XX.X% |
| 房租物业 | ¥XX,XXX.XX | XX.X% |
| 其他支出 | ¥X,XXX.XX | XX.X% |
| **合计** | **¥XXX,XXX.XX** | **100.0%** |
---
*报告由 cashflow-pilot 自动生成*
```
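The income and expense breakdown tables in the basic template can be produced mechanically from per-category totals. The following is a minimal sketch, not the cashflow-pilot implementation: the function name `render_breakdown_table` and the largest-first ordering are assumptions made for illustration.

```python
from typing import Dict, List


def render_breakdown_table(title: str, amounts: Dict[str, float]) -> str:
    """Render a category / amount / share Markdown table with a total row."""
    total = sum(amounts.values())
    lines: List[str] = [
        f"| {title} | 金额 | 占比 |",
        "|----------|-----:|-----:|",
    ]
    # Assumption: list the largest categories first.
    for name, amount in sorted(amounts.items(), key=lambda kv: kv[1], reverse=True):
        share = amount / total * 100 if total else 0.0
        lines.append(f"| {name} | ¥{amount:,.2f} | {share:.1f}% |")
    total_share = 100.0 if total else 0.0
    lines.append(f"| **合计** | **¥{total:,.2f}** | **{total_share:.1f}%** |")
    return "\n".join(lines)


if __name__ == "__main__":
    sample = {"销售回款": 165000, "服务收入": 20000, "其他收入": 5000}
    print(render_breakdown_table("收入分类", sample))
```

The sample amounts are the ones used in the paid-version pie chart further down this file, so the printed shares can be compared against that example.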
### 1.2 完整版(付费版)
````markdown
# 现金流月报 — YYYY年M月
统计记录数:XX 条 | 数据源:XX
## 核心指标
| 指标 | 本月 | 上月 | 环比 |
|------|-----:|-----:|-----:|
| 总收入 | ¥XXX,XXX | ¥XXX,XXX | +X.X% |
| 总支出 | ¥XXX,XXX | ¥XXX,XXX | -X.X% |
| 净现金流 | ¥XX,XXX | ¥XX,XXX | +XX.X% |
## 收入分类占比
(此处插入收入分类占比表格,同基础版)
```mermaid
pie title 收入分类占比
"销售回款" : 165000
"服务收入" : 20000
"其他收入" : 5000
```
## 支出分类占比
(此处插入支出分类占比表格,同基础版)
```mermaid
pie title 支出分类占比
"人员工资" : 85000
"采购成本" : 30000
"房租物业" : 15000
"营销推广" : 8000
"其他支出" : 5000
```
## 现金流趋势(近6个月)
```mermaid
xychart-beta
title "现金流趋势"
x-axis ["10月", "11月", "12月", "1月", "2月", "3月"]
y-axis "金额(元)"
line [150000, 165000, 180000, 170000, 185000, 190000]
line [120000, 125000, 135000, 128000, 130000, 132500]
```
## 异常告警
- 「营销推广」本月支出 ¥8,000.00,是过去6月均值(¥3,500.00)的 2.3 倍,请关注。
## 洞察与建议
1. 本月净现金流为正(¥57,500.00),现金留存率 30.3%。
2. 最大支出类别为「人员工资」,占总支出 64.2%,金额 ¥85,000.00。
3. 收入高度集中于「销售回款」(占比 86.8%),建议拓展收入来源以降低风险。
---
*报告由 cashflow-pilot 自动生成*
````
---
## 二、应收账款账龄分析报告
````markdown
# 应收账款账龄分析 — YYYY-MM-DD
## 概览
| 账龄区间 | 笔数 | 金额 | 占比 |
|----------|-----:|-----:|-----:|
| 未到期 | X | ¥XX,XXX | XX.X% |
| 逾期 1-30 天 | X | ¥XX,XXX | XX.X% |
| 逾期 31-60 天 | X | ¥XX,XXX | XX.X% |
| 逾期 61-90 天 | X | ¥XX,XXX | XX.X% |
| 逾期 90 天以上 | X | ¥XX,XXX | XX.X% |
| **合计** | **XX** | **¥XXX,XXX** | **100.0%** |
## 逾期明细(按逾期天数排序)
| 客户 | 金额 | 发票号 | 到期日 | 逾期天数 |
|------|-----:|--------|--------|--------:|
| 客户A | ¥50,000 | INV-2026-001 | 2026-01-15 | 63 |
| 客户B | ¥30,000 | INV-2026-005 | 2026-02-01 | 46 |
| 客户C | ¥20,000 | INV-2026-012 | 2026-02-28 | 19 |
## 账龄分布图
```mermaid
pie title 应收账款账龄分布
"未到期" : 45000
"1-30天" : 25000
"31-60天" : 30000
"61-90天" : 50000
"90天以上" : 10000
```
## 建议
1. 重点关注逾期超过60天的应收款(共 ¥60,000),建议主动联系客户催收。
2. 考虑对长期逾期客户调整信用政策。
3. 定期更新应收账款状态,保持数据准确。
---
*报告由 cashflow-pilot 自动生成*
````
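The aging buckets in this template follow directly from the number of days between the due date and the report date. Below is a hedged sketch of that bucketing; `aging_bucket` and `summarize_aging` are hypothetical helper names, and treating an invoice due on the report date as not yet overdue is an assumption.

```python
from datetime import date
from typing import Dict, List, Tuple

# (label, largest overdue-day count that still belongs to the bucket)
BUCKETS: List[Tuple[str, int]] = [
    ("未到期", 0),
    ("逾期 1-30 天", 30),
    ("逾期 31-60 天", 60),
    ("逾期 61-90 天", 90),
]
LAST_BUCKET = "逾期 90 天以上"


def aging_bucket(due_date: date, as_of: date) -> str:
    """Map a receivable to its aging bucket as of the report date."""
    overdue_days = (as_of - due_date).days
    for label, upper in BUCKETS:
        if overdue_days <= upper:
            return label
    return LAST_BUCKET


def summarize_aging(receivables: List[dict], as_of: date) -> Dict[str, Dict[str, float]]:
    """Count and total receivables per bucket, keeping empty buckets in the result."""
    labels = [label for label, _ in BUCKETS] + [LAST_BUCKET]
    summary = {label: {"count": 0, "amount": 0.0} for label in labels}
    for item in receivables:
        bucket = summary[aging_bucket(item["due_date"], as_of)]
        bucket["count"] += 1
        bucket["amount"] += item["amount"]
    return summary
```

With a report date of 2026-03-19 (the date implied by the overdue-day column of the sample rows), the three sample invoices land in the 61-90, 31-60, and 1-30 day buckets respectively.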
---
## 三、现金流预测报告(仅付费版)
````markdown
# 现金流预测报告(未来3个月)
基于近 6 个月历史数据,使用移动平均+线性回归加权模型预测。
## 预测概览
| 月份 | 预测收入 | 预测支出 | 预测净现金流 |
|------|--------:|--------:|-----------:|
| YYYY-MM | ¥XXX,XXX | ¥XXX,XXX | ¥XX,XXX |
| YYYY-MM | ¥XXX,XXX | ¥XXX,XXX | ¥XX,XXX |
| YYYY-MM | ¥XXX,XXX | ¥XXX,XXX | ¥XX,XXX |
## 趋势预测图
```mermaid
xychart-beta
title "现金流预测趋势"
x-axis ["1月", "2月", "3月", "4月(预)", "5月(预)", "6月(预)"]
y-axis "金额(元)"
line [170000, 185000, 190000, 180000, 175000, 190000]
line [128000, 130000, 132500, 135000, 140000, 138000]
```
## 风险预警
- [!! 中风险] 支出连续3个月增长,累计增幅 15.2%。建议审视支出结构,控制非必要开支。
- [! 低风险] 收入波动较大(变异系数 0.25),建议建立至少2个月支出的现金储备。
## 风险评估总结
| 风险维度 | 评估 | 说明 |
|----------|------|------|
| 现金流断裂风险 | 低 | 未来3个月预测净现金流均为正 |
| 支出失控风险 | 中 | 支出持续增长,需关注 |
| 收入波动风险 | 中 | 建议拓宽收入渠道 |
## 建议
1. 维持至少 ¥270,000 的现金储备(约2个月支出)。
2. 关注支出增长趋势,特别是「营销推广」类目。
3. 加快应收账款催收,改善现金回流速度。
---
*预测由 cashflow-pilot 基于历史数据生成,仅供参考*
````
---
## 四、Mermaid 图表使用指南
### 饼图(分类占比)
```mermaid
pie title 标题
"分类1" : 数值1
"分类2" : 数值2
"分类3" : 数值3
```
### 折线图(趋势/预测)
```mermaid
xychart-beta
title "图表标题"
x-axis ["标签1", "标签2", "标签3"]
y-axis "Y轴标题"
line [数值1, 数值2, 数值3]
```
### 柱状图(对比)
```mermaid
xychart-beta
title "收支对比"
x-axis ["1月", "2月", "3月"]
y-axis "金额(元)"
bar [150000, 165000, 180000]
bar [120000, 125000, 135000]
```
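The line and bar charts above share the same `xychart-beta` layout, so a single helper can emit either. This is an illustrative sketch only; `build_xychart` is a hypothetical name, and the four-space indentation of the chart body is a formatting choice rather than something this guide prescribes.

```python
from typing import List, Sequence

FENCE = "`" * 3  # built at runtime so this example can itself sit inside a fenced block


def build_xychart(
    title: str,
    x_labels: Sequence[str],
    y_title: str,
    series: Sequence[Sequence[float]],
    kind: str = "line",
) -> str:
    """Return a fenced Mermaid xychart-beta block with one `kind` entry per series."""
    if kind not in ("line", "bar"):
        raise ValueError(f"unsupported chart kind: {kind}")
    for values in series:
        if len(values) != len(x_labels):
            raise ValueError("every series needs one value per x-axis label")
    lines: List[str] = [
        f"{FENCE}mermaid",
        "xychart-beta",
        f'    title "{title}"',
        "    x-axis [" + ", ".join(f'"{label}"' for label in x_labels) + "]",
        f'    y-axis "{y_title}"',
    ]
    for values in series:
        lines.append(f"    {kind} [" + ", ".join(str(v) for v in values) + "]")
    lines.append(FENCE)
    return "\n".join(lines)
```

The length check catches a common mistake, a series with more or fewer points than the x-axis has labels, before the chart reaches a renderer.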
> 注意:Mermaid 图表仅在付费版报告中使用。免费版报告使用纯文本表格展示数据。
商业数据洞察 — 连接业务数据源,自动生成数据分析报告和看板
---
name: biz-data-insight
description: 商业数据洞察 — 连接业务数据源,自动生成数据分析报告和看板
version: 1.0.0
metadata:
openclaw:
requires:
env:
- BDI_DATASOURCE_TYPE
- BDI_DATASOURCE_URI
optional_env:
- BDI_SUBSCRIPTION_TIER
- BDI_DB_PASSWORD
---
# 商业数据洞察(biz-data-insight)
你是一个专业的商业数据分析师 Agent。你的职责是帮助用户连接业务数据源、执行数据查询、生成分析报告和可视化看板。你始终使用中文与用户沟通。
## 环境变量说明
| 变量 | 必需 | 说明 |
|------|------|------|
| `BDI_DATASOURCE_TYPE` | 是 | 数据源类型(mysql / postgresql / csv / excel / json) |
| `BDI_DATASOURCE_URI` | 是 | 数据源连接地址或文件路径 |
| `BDI_SUBSCRIPTION_TIER` | 否 | 订阅等级,默认 `free`,可选 `paid` |
| `BDI_DB_PASSWORD` | 否 | 数据库密码(若连接字符串中未包含) |
启动时,你必须验证 `BDI_DATASOURCE_TYPE` 和 `BDI_DATASOURCE_URI` 已设置。若缺失,立即提示用户并引导进入「数据源配置引导流程」。
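The startup check described above can be sketched as follows. This is a minimal illustration under stated assumptions: `check_datasource_env` is a hypothetical helper (not one of the skill's scripts), and the list of supported types is taken from the table above.

```python
import os
from typing import List, Mapping, Optional

SUPPORTED_TYPES = {"mysql", "postgresql", "csv", "excel", "json"}
REQUIRED_VARS = ("BDI_DATASOURCE_TYPE", "BDI_DATASOURCE_URI")


def check_datasource_env(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return a list of configuration problems; an empty list means the check passed."""
    env = os.environ if env is None else env
    problems = [
        f"{name} is not set"
        for name in REQUIRED_VARS
        if not env.get(name, "").strip()
    ]
    ds_type = env.get("BDI_DATASOURCE_TYPE", "").strip().lower()
    if ds_type and ds_type not in SUPPORTED_TYPES:
        problems.append(f"unsupported BDI_DATASOURCE_TYPE: {ds_type}")
    return problems
```

A non-empty result is the cue to switch into the data source configuration flow instead of attempting a query.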
---
## 流程一:数据源配置引导
当用户说"配置数据源"、"连接数据库"、"设置数据源"或类似意图时,执行以下步骤:
### 步骤 1:选择数据源类型
向用户展示支持的数据源类型,并引导选择:
```
请选择数据源类型:
1. MySQL
2. PostgreSQL(仅限付费版)
3. CSV 文件
4. Excel 文件(仅限付费版)
5. JSON 文件(仅限付费版)
```
> 注意:先执行「订阅校验」确认当前等级是否支持所选类型。免费版仅支持 MySQL 和 CSV。
### 步骤 2:收集连接信息
根据数据源类型,引导用户提供对应的连接信息:
- **数据库类型(MySQL / PostgreSQL)**:收集主机地址、端口、数据库名、用户名。密码从 `BDI_DB_PASSWORD` 环境变量读取,绝对不要让用户在对话中直接输入密码。
- **文件类型(CSV / Excel / JSON)**:收集文件路径或 URL。
### 步骤 3:测试连接
使用以下命令测试连接:
```bash
python3 scripts/connect_datasource.py --type <type> --uri <uri> --action test
```
- 若连接成功,告知用户并进入下一步。
- 若连接失败,显示错误信息,引导用户排查(检查网络、凭据、防火墙等)。
### 步骤 4:探索数据结构
连接成功后,自动探索数据结构:
```bash
python3 scripts/connect_datasource.py --type <type> --uri <uri> --action explore
```
将返回的表结构、字段信息、数据量概览以清晰的表格形式展示给用户。例如:
```
数据源概览:
- 数据库:sales_db
- 表数量:12
- 主要表:
| 表名 | 行数 | 字段数 | 说明 |
|------|------|--------|------|
| orders | 150,000 | 12 | 订单表 |
| products | 3,200 | 8 | 产品表 |
| customers | 45,000 | 10 | 客户表 |
```
### 步骤 5:确认并保存
向用户确认数据源配置正确,提示用户将 `BDI_DATASOURCE_TYPE` 和 `BDI_DATASOURCE_URI` 持久化到环境变量中。
---
## 流程二:交互式数据分析
当用户提出数据分析类问题时(如"上月销售额 Top10 产品"、"各区域客户增长率"、"退货率趋势"等),执行以下步骤:
### 步骤 1:订阅校验与配额检查
1. 读取 `BDI_SUBSCRIPTION_TIER` 环境变量,默认为 `free`。
2. 检查当日已使用的查询次数:
- **免费版**:每日上限 5 次查询。若已达上限,提示用户升级至付费版(¥99/月)。
- **付费版**:不限次数。
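One way to track the daily quota is a small counter file keyed by date. The sketch below assumes such a file; the source does not say where usage is stored, so `consume_query` and the `usage_file` location are hypothetical.

```python
import json
from datetime import date
from pathlib import Path
from typing import Optional

FREE_DAILY_LIMIT = 5  # free-tier limit stated above


def consume_query(usage_file: Path, tier: str = "free", today: Optional[date] = None) -> bool:
    """Record one query; return False when the free daily limit is already used up."""
    if tier == "paid":
        return True  # paid tier is unlimited, nothing to record
    key = (today or date.today()).isoformat()
    usage = {}
    if usage_file.exists():
        usage = json.loads(usage_file.read_text(encoding="utf-8"))
    used = int(usage.get(key, 0))
    if used >= FREE_DAILY_LIMIT:
        return False
    # Keep only today's counter so the file does not grow without bound.
    usage_file.write_text(json.dumps({key: used + 1}), encoding="utf-8")
    return True
```

Because the counter is keyed by the ISO date, the quota resets on the first query of a new day without any scheduled cleanup.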
### 步骤 2:解析用户意图
分析用户的自然语言问题,识别以下要素:
- 目标指标(如销售额、订单量、客户数)
- 维度(如时间、产品、区域、渠道)
- 筛选条件(如时间范围、类目过滤)
- 排序与限制(如 Top10、倒序)
- 分析类型(如趋势、对比、占比、排名)
### 步骤 3:生成安全 SQL
根据解析结果生成 SQL 查询。必须遵守以下安全规则:
- **只允许 SELECT 语句**,严禁 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE。
- 使用参数化查询,防止 SQL 注入。
- 添加 `LIMIT` 约束:免费版最多 100 行,付费版最多 10,000 行。
- 生成 SQL 后,先向用户展示并简要解释,获得确认后再执行。
- 涉及同比/环比分析时,免费版应提示"同比/环比分析为付费功能"并拒绝执行。
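The read-only and `LIMIT` rules above can be approximated with a conservative pre-check. The sketch below is an assumption-laden illustration: `prepare_select` is a hypothetical function, separate from the repository's `validate_sql_readonly` (whose implementation is not shown here), and it deliberately rejects anything it cannot reason about, including `LIMIT ... OFFSET` and a `LIMIT` inside a subquery. It does not cover parameter binding.

```python
import re

FORBIDDEN = ("insert", "update", "delete", "drop", "alter", "truncate")
ROW_LIMITS = {"free": 100, "paid": 10_000}


def prepare_select(sql: str, tier: str = "free") -> str:
    """Validate a single SELECT statement and cap its LIMIT at the tier's row limit."""
    statement = sql.strip().rstrip(";").strip()
    if ";" in statement:
        raise ValueError("multiple statements are not allowed")
    if not re.match(r"(?is)^select\b", statement):
        raise ValueError("only SELECT statements are allowed")
    for keyword in FORBIDDEN:
        # Whole-word match: a column such as update_time is not rejected.
        if re.search(rf"\b{keyword}\b", statement, re.IGNORECASE):
            raise ValueError(f"forbidden keyword: {keyword.upper()}")
    cap = ROW_LIMITS.get(tier, ROW_LIMITS["free"])
    match = re.search(r"(?i)\blimit\s+(\d+)\s*$", statement)
    if match is None:
        if re.search(r"(?i)\blimit\b", statement):
            raise ValueError("unsupported LIMIT form")
        return f"{statement} LIMIT {cap}"
    if int(match.group(1)) > cap:
        return statement[: match.start()] + f"LIMIT {cap}"
    return statement
```

The keyword check also rejects a harmless string literal such as `'delete'`; that false positive is the price of keeping the check simple.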
### 步骤 4:执行查询
```bash
python3 scripts/query_engine.py --type <type> --uri <uri> --sql "<sql>"
```
- 若执行成功,进入格式化步骤。
- 若执行失败,分析错误原因,尝试修正 SQL 后重新执行(最多重试 2 次)。
### 步骤 5:格式化输出
```bash
python3 scripts/report_generator.py --template interactive --data '<json>'
```
根据订阅等级输出不同格式:
**免费版输出:**
- Markdown 表格展示查询结果
- 简要文字总结(1-2 句话)
**付费版输出:**
- Markdown 表格展示查询结果
- Mermaid 可视化图表(柱状图、折线图、饼图等,参考 `references/mermaid-guide.md`)
- 深度洞察分析(趋势解读、异常标注、建议)
---
## 流程三:定期报告生成
当用户说"生成日报"、"生成周报"、"生成月报"或类似意图时,执行以下步骤:
### 步骤 1:确认报告类型与订阅校验
| 报告类型 | 免费版 | 付费版 |
|----------|--------|--------|
| 日报(basic) | 支持 | 支持 |
| 日报(full) | 不支持 | 支持 |
| 周报 | 不支持 | 支持 |
| 月报 | 不支持 | 支持 |
- 免费版用户请求周报或月报时,提示"周报/月报为付费功能,当前为免费版,请升级至付费版(¥99/月)"。
- 免费版日报仅包含关键指标汇总,不含图表和深度分析。
### 步骤 2:查询核心指标
根据报告类型,通过 `query_engine.py` 查询对应的核心指标:
**日报指标:**
- 当日营收、订单量、客户数
- 与昨日对比(付费版额外提供同比数据)
**周报指标(仅付费版):**
- 本周汇总数据
- 周环比变化
- 本周 Top 产品 / 区域
**月报指标(仅付费版):**
- 月度汇总数据
- 月环比 / 同比变化
- 月度趋势图
- 综合经营分析
### 步骤 3:异常检测(仅付费版)
```bash
python3 scripts/anomaly_detector.py --data '<json>' --method both
```
- 检测指标异常波动(如日销售额偏离均值超过 2 个标准差)。
- 将异常信息纳入报告的"异常预警"模块。
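The two-standard-deviation rule mentioned above can be illustrated in a few lines. This is a simplified sketch, not the skill's `anomaly_detector.py`; `flag_outliers` is a hypothetical name and the use of the population standard deviation is an assumption.

```python
from statistics import mean, pstdev
from typing import List, Sequence, Tuple


def flag_outliers(values: Sequence[float], threshold: float = 2.0) -> List[Tuple[int, float]]:
    """Return (index, value) pairs lying more than `threshold` std devs from the mean."""
    if len(values) < 2:
        return []
    center = mean(values)
    spread = pstdev(values)  # population standard deviation
    if spread == 0:
        return []  # all values identical, nothing can be an outlier
    return [
        (i, v)
        for i, v in enumerate(values)
        if abs(v - center) / spread > threshold
    ]


if __name__ == "__main__":
    # Sample series: one value far above the rest.
    print(flag_outliers([100, 102, 98, 105, 250, 99]))
```

Note that a single extreme value inflates the standard deviation itself, so with very short series even a large spike may sit only slightly above the threshold.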
### 步骤 4:生成报告
```bash
python3 scripts/report_generator.py --template <daily|weekly|monthly> --data '<json>'
```
报告格式参考 `references/report-templates.md`,包含以下模块:
**免费版日报结构:**
```
# 数据日报 — YYYY-MM-DD
## 核心指标概览(表格)
## 简要总结(2-3句话)
```
**付费版日报结构:**
```
# 数据日报 — YYYY-MM-DD
## 核心指标概览(表格 + 与昨日/上周同期对比)
## 趋势图(Mermaid 图表)
## 异常预警(如有)
## 深度洞察与建议
```
**付费版周报/月报结构:**
```
# 数据周报/月报 — 周期范围
## 执行摘要
## 核心指标看板(表格 + 环比/同比)
## 趋势分析(Mermaid 图表)
## Top 排名分析
## 异常预警(如有)
## 深度洞察与经营建议
```
### 步骤 5:输出报告
将生成的 Markdown 报告直接输出给用户。
---
## 订阅校验逻辑
在每次涉及功能限制的操作前,必须执行以下校验:
### 读取订阅等级
```python
import os

tier = os.environ.get("BDI_SUBSCRIPTION_TIER", "free")  # 未设置时默认为 "free"
```
### 功能权限矩阵
| 功能 | 免费版(free) | 付费版(paid,¥99/月) |
|------|---------------|----------------------|
| 数据源数量 | 1 个 | 最多 5 个 |
| 支持数据源类型 | MySQL、CSV | MySQL、PostgreSQL、CSV、Excel、JSON |
| 每日查询次数 | 5 次 | 不限 |
| 单次查询行数上限 | 100 行 | 10,000 行 |
| 日报 | 基础版(仅指标汇总) | 完整版(含图表 + 洞察) |
| 周报 | 不支持 | 支持 |
| 月报 | 不支持 | 支持 |
| Mermaid 可视化图表 | 不支持 | 支持 |
| 异常检测 | 不支持 | 支持 |
| 同比/环比分析 | 不支持 | 支持 |
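The permission matrix above maps naturally onto a lookup table. The following sketch encodes it as data; the feature keys, the function names, and the fallback of unknown tier values to `free` are all assumptions for illustration.

```python
import os
from typing import Any, Dict, Mapping, Optional

TIER_LIMITS: Dict[str, Dict[str, Any]] = {
    "free": {
        "max_datasources": 1,
        "source_types": {"mysql", "csv"},
        "daily_queries": 5,
        "max_rows": 100,
        "weekly_report": False,
        "monthly_report": False,
        "charts": False,
        "anomaly_detection": False,
        "period_comparison": False,
    },
    "paid": {
        "max_datasources": 5,
        "source_types": {"mysql", "postgresql", "csv", "excel", "json"},
        "daily_queries": None,  # None means unlimited
        "max_rows": 10_000,
        "weekly_report": True,
        "monthly_report": True,
        "charts": True,
        "anomaly_detection": True,
        "period_comparison": True,
    },
}


def current_tier(env: Optional[Mapping[str, str]] = None) -> str:
    """Read BDI_SUBSCRIPTION_TIER; unknown or missing values fall back to 'free'."""
    env = os.environ if env is None else env
    tier = env.get("BDI_SUBSCRIPTION_TIER", "free").strip().lower()
    return tier if tier in TIER_LIMITS else "free"


def is_allowed(feature: str, tier: str) -> bool:
    """True when a boolean feature flag is enabled for the tier."""
    return bool(TIER_LIMITS[tier].get(feature, False))
```

Keeping the matrix in one structure means a limit only has to be changed in one place when the tiers are revised.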
### 校验失败时的行为
当用户请求的功能超出当前订阅等级时:
1. 明确告知用户当前功能仅限付费版。
2. 简要说明付费版的优势。
3. 提供升级引导:"如需升级至付费版(¥99/月),请联系管理员或访问订阅管理页面。"
4. 不要直接拒绝,而是提供免费版可用的替代方案(如果有的话)。
---
## 参考文档
在生成报告和图表时,请参考以下文档:
- **报告模板**:`references/report-templates.md` — 包含各类报告的标准模板和示例。
- **Mermaid 图表指南**:`references/mermaid-guide.md` — 包含 Mermaid 图表语法和最佳实践。
---
## 安全规范
1. **SQL 安全**:只允许 SELECT 查询,严禁任何写操作。所有用户输入必须转义处理。
2. **凭据保护**:数据库密码仅通过环境变量 `BDI_DB_PASSWORD` 传递,绝不在对话中显示、记录或输出密码。
3. **数据脱敏**:查询结果中的敏感字段(如手机号、身份证号、银行卡号)应自动脱敏处理后再展示。
4. **错误处理**:执行命令失败时,向用户展示友好的错误提示,不要暴露内部路径或系统信息。
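The data-masking rule can be sketched with regular expressions. The patterns below are assumptions rather than the skill's actual rules: they target 11-digit mainland China mobile numbers, 18-character resident ID numbers, and 16 to 19 digit bank card numbers, and all function names are hypothetical.

```python
import re


def mask_phone(text: str) -> str:
    """Keep the first 3 and last 4 digits of an 11-digit mobile number."""
    return re.sub(r"(?<!\d)(1\d{2})\d{4}(\d{4})(?!\d)", r"\1****\2", text)


def mask_id_card(text: str) -> str:
    """Keep the first 6 and last 4 characters of an 18-character ID number."""
    return re.sub(r"(?<!\d)(\d{6})\d{8}(\d{3}[\dXx])(?!\d)", r"\1********\2", text)


def mask_bank_card(text: str) -> str:
    """Keep only the last 4 digits of a 16 to 19 digit card number."""
    return re.sub(
        r"(?<!\d)\d{12,15}(\d{4})(?!\d)",
        lambda m: "*" * (len(m.group(0)) - 4) + m.group(1),
        text,
    )


def mask_sensitive(text: str) -> str:
    """Apply all masks; ID numbers go first because 18 digits also look like a card number."""
    return mask_phone(mask_bank_card(mask_id_card(text)))
```

Pattern-based masking only catches values that look like the expected formats; masking by column name (for example any field called `phone`) would be a complementary approach.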
---
## 行为准则
1. 始终使用中文与用户沟通。
2. 在执行任何查询前,先向用户展示将要执行的 SQL 并获得确认。
3. 对用户的问题给出清晰、结构化的回答,优先使用表格展示数据。
4. 主动提供数据洞察和业务建议,而不仅仅是返回原始数据。
5. 遇到模糊的用户意图时,主动追问以明确需求,而不是猜测执行。
6. 查询出错时,耐心排查并给出可行的解决方案。
7. 尊重订阅等级限制,在提示升级时保持友好,不要反复推销。
FILE:assets/README.md
# 📊 商业数据洞察 (biz-data-insight)
> 连接你的业务数据,AI自动生成分析报告和看板
---
## 功能亮点
- 🔌 **多数据源接入** — 支持 MySQL、PostgreSQL 数据库及 CSV / Excel / JSON 文件,一次配置长期使用
- 📈 **自动生成报告** — 每日、每周、每月报告自动产出,包含关键指标、趋势分析和可视化图表
- 🧠 **AI 智能洞察** — 自动识别数据异常、计算同比环比、提炼业务结论,不只是展示数字
- 💬 **交互式提问** — 用自然语言向数据提问,秒级返回分析结果和图表
- 📊 **Mermaid 可视化** — 报告内嵌饼图、折线图、柱状图,无需额外工具即可在 Markdown 中查看
- 🔒 **数据安全** — 所有查询在本地执行,数据不离开你的环境
---
## 版本对比
| 功能 | 免费版 | 付费版 ¥99/月 |
|------|:------:|:------------:|
| 数据源数量 | 1个 | 最多5个 |
| 数据库类型 | MySQL | MySQL + PostgreSQL |
| 文档类型 | CSV | CSV + Excel + JSON |
| 日报 | 基础表格 | 表格 + Mermaid图表 + 洞察 |
| 周报 | ❌ | ✅ |
| 月报 | ❌ | ✅ 完整多维度分析 |
| 交互式提问 | 5次/天 | 无限 |
| 异常检测 | ❌ | ✅ |
| 同比/环比 | ❌ | ✅ |
| 查询行数限制 | 100行 | 10,000行 |
---
## 快速开始
### 1. 安装 Skill
在 ClawHub 中搜索 `biz-data-insight`,点击安装,或使用命令行:
```bash
openclaw skill install biz-data-insight
```
### 2. 配置数据源
在项目根目录创建 `datasource.yaml` 文件:
```yaml
# 数据库连接示例
datasources:
- name: 业务主库
type: mysql
host: localhost
port: 3306
database: my_business_db
username: reader
password: BDI_DB_PASSWORD # 支持环境变量
# CSV 文件示例
- name: 销售数据
type: csv
path: ./data/sales_2026.csv
encoding: utf-8
```
### 3. 生成报告
```bash
# 生成日报
/biz-data-insight daily
# 生成周报(付费版)
/biz-data-insight weekly
# 生成月报(付费版)
/biz-data-insight monthly
# 交互式提问
/biz-data-insight ask "上周销售额最高的产品是什么?"
```
### 4. 查看产出
报告自动保存到 `output/reports/` 目录,格式为 Markdown 文件,可直接在编辑器或 Git 仓库中查看。
---
## 报告示例
以下是一份自动生成的日报样例:
```markdown
# 📊 业务日报 — 2026-03-19(周四)
数据源:业务主库 | 统计周期:2026-03-19 00:00 ~ 23:59
## 核心指标
| 指标 | 今日 | 昨日 | 日环比 |
|------|-----:|-----:|-------:|
| 订单数 | 1,283 | 1,195 | +7.4% |
| 销售额 | ¥328,450 | ¥301,200 | +9.0% |
| 客单价 | ¥256 | ¥252 | +1.6% |
| 退货率 | 2.1% | 2.3% | -0.2pp |
## 分类销售占比
|饼图将在此处渲染|
## 今日摘要
- 订单数和销售额均高于昨日,主要受"春季促销"活动拉动
- 电子产品类目贡献了 42% 的销售额,为当日最大品类
- 退货率小幅下降,处于正常区间
---
*报告由 biz-data-insight 自动生成*
```
---
## 常见问题
### Q1: 支持哪些数据库版本?
MySQL 5.7+ 和 MySQL 8.x 均已测试通过。付费版额外支持 PostgreSQL 12+。
### Q2: 数据会上传到云端吗?
不会。所有数据查询和报告生成均在本地完成,数据不会离开你的运行环境。AI 分析基于查询结果的聚合数据进行,不会传输原始数据。
### Q3: 免费版可以同时连接多个数据源吗?
免费版限制为 1 个数据源。如果你需要同时分析多个库表或文件,请升级到付费版(支持最多 5 个数据源)。
### Q4: 如何自定义报告中的指标?
在 `datasource.yaml` 中为每个数据源添加 `metrics` 配置项,指定需要统计的字段、聚合方式和显示名称。详见 `references/report-templates.md`。
### Q5: Mermaid 图表在哪些工具中可以渲染?
Mermaid 图表在以下环境可直接渲染:GitHub / GitLab 的 Markdown 预览、VS Code(安装 Mermaid 插件)、Typora、Obsidian 等。大多数现代 Markdown 查看器均已支持。
### Q6: 报告生成速度如何?
取决于数据量和查询复杂度。通常情况下,日报在 5-15 秒内生成,周报约 15-30 秒,月报约 30-60 秒。
---
## 技术支持
- 📖 **文档**:查看 `references/` 目录获取模板和图表语法参考
- 🐛 **问题反馈**:在 ClawHub 的 Skill 页面提交 Issue
- 💬 **社区讨论**:加入 ClawHub 社区频道 `#biz-data-insight`
---
*biz-data-insight v1.0 | 兼容 OpenClaw 0.5+*
FILE:scripts/query_engine.py
#!/usr/bin/env python3
"""
biz-data-insight OpenClaw Skill — 安全 SQL 查询引擎
用法:
python3 query_engine.py --type mysql --uri "mysql://user:pass@host:3306/db" --sql "SELECT ..."
python3 query_engine.py --type csv --uri "./data.csv" --sql "SELECT * FROM data LIMIT 5"
python3 query_engine.py --type mysql --uri "..." --template top_n --params '{"table":"orders","metric":"amount","dimension":"product","n":10}'
"""
import argparse
import json
import sqlite3
import sys
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from utils import (
check_subscription,
get_datasource_connection,
output_error,
output_success,
validate_sql_readonly,
)
# ---------------------------------------------------------------------------
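# 预定义查询模板
# 注意:yoy / mom / trend 使用 MySQL 专有的 DATE_FORMAT(yoy 还用到 CURDATE),
# 在 PostgreSQL 或文件数据源的 sqlite3 内存库上无法直接执行。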
# ---------------------------------------------------------------------------
TEMPLATES: Dict[str, str] = {
"top_n": (
"SELECT {dimension}, SUM({metric}) AS total "
"FROM {table} "
"GROUP BY {dimension} "
"ORDER BY total DESC "
"LIMIT {n}"
),
"yoy": (
"SELECT "
" DATE_FORMAT({date_col}, '%Y-%m') AS period, "
" SUM(CASE WHEN YEAR({date_col}) = YEAR(CURDATE()) THEN {metric} ELSE 0 END) AS current_period, "
" SUM(CASE WHEN YEAR({date_col}) = YEAR(CURDATE()) - 1 THEN {metric} ELSE 0 END) AS last_year_period, "
" ROUND("
" (SUM(CASE WHEN YEAR({date_col}) = YEAR(CURDATE()) THEN {metric} ELSE 0 END) "
" - SUM(CASE WHEN YEAR({date_col}) = YEAR(CURDATE()) - 1 THEN {metric} ELSE 0 END)) "
" / NULLIF(SUM(CASE WHEN YEAR({date_col}) = YEAR(CURDATE()) - 1 THEN {metric} ELSE 0 END), 0) * 100, 2"
" ) AS yoy_pct "
"FROM {table} "
"GROUP BY period "
"ORDER BY period"
),
"mom": (
"SELECT "
" DATE_FORMAT({date_col}, '%Y-%m') AS period, "
" SUM({metric}) AS total, "
" ROUND("
" (SUM({metric}) - LAG(SUM({metric})) OVER (ORDER BY DATE_FORMAT({date_col}, '%Y-%m'))) "
" / NULLIF(LAG(SUM({metric})) OVER (ORDER BY DATE_FORMAT({date_col}, '%Y-%m')), 0) * 100, 2"
" ) AS mom_pct "
"FROM {table} "
"GROUP BY period "
"ORDER BY period"
),
"distribution": (
"SELECT {dimension}, SUM({metric}) AS total, "
"ROUND(SUM({metric}) * 100.0 / (SELECT SUM({metric}) FROM {table}), 2) AS pct "
"FROM {table} "
"GROUP BY {dimension} "
"ORDER BY total DESC"
),
"trend": (
"SELECT DATE_FORMAT({date_col}, '{date_format}') AS period, "
"SUM({metric}) AS total "
"FROM {table} "
"GROUP BY period "
"ORDER BY period"
),
}
# 模板所需参数
TEMPLATE_REQUIRED_PARAMS: Dict[str, List[str]] = {
"top_n": ["table", "metric", "dimension", "n"],
"yoy": ["table", "metric", "date_col"],
"mom": ["table", "metric", "date_col"],
"distribution": ["table", "metric", "dimension"],
"trend": ["table", "metric", "date_col"],
}
def build_sql_from_template(template_name: str, params: Dict[str, Any]) -> str:
"""根据模板名称和参数生成 SQL 语句。"""
if template_name not in TEMPLATES:
raise ValueError(f"未知的查询模板: {template_name},可用模板: {', '.join(TEMPLATES.keys())}")
required = TEMPLATE_REQUIRED_PARAMS[template_name]
missing = [p for p in required if p not in params]
if missing:
raise ValueError(f"模板 '{template_name}' 缺少必需参数: {', '.join(missing)}")
# 为 trend 模板设置默认日期格式
if template_name == "trend" and "date_format" not in params:
params["date_format"] = "%Y-%m"
# 将 n 强制转换为整数,避免非数字内容被拼入 LIMIT 子句
if "n" in params:
params["n"] = int(params["n"])
return TEMPLATES[template_name].format(**params)
# ---------------------------------------------------------------------------
# 数据库执行
# ---------------------------------------------------------------------------
def execute_db_query(
db_type: str,
uri: str,
sql: str,
password: Optional[str] = None,
max_rows: int = 1000,
) -> Tuple[List[str], List[list]]:
"""通过数据库驱动执行 SQL,返回 (columns, rows)。"""
conn = get_datasource_connection(db_type, uri, password=password)
try:
cursor = conn.cursor()
cursor.execute(sql)
columns = [desc[0] for desc in cursor.description] if cursor.description else []
rows = cursor.fetchmany(max_rows)
# 将行转为普通 list(部分驱动返回 tuple)
rows = [list(row) for row in rows]
return columns, rows
finally:
try:
cursor.close()
except Exception:
pass
conn.close()
# ---------------------------------------------------------------------------
# 文件数据源执行(CSV / Excel / JSON → sqlite3 内存库)
# ---------------------------------------------------------------------------
def load_dataframe(file_type: str, uri: str) -> pd.DataFrame:
"""根据文件类型加载 DataFrame。"""
loaders = {
"csv": lambda p: pd.read_csv(p),
"excel": lambda p: pd.read_excel(p),
"json": lambda p: pd.read_json(p),
}
loader = loaders.get(file_type)
if loader is None:
raise ValueError(f"不支持的文件类型: {file_type}")
return loader(uri)
def execute_file_query(
file_type: str,
uri: str,
sql: str,
max_rows: int = 1000,
) -> Tuple[List[str], List[list]]:
"""将文件数据加载到 sqlite3 内存库并执行 SQL。"""
df = load_dataframe(file_type, uri)
conn = sqlite3.connect(":memory:")
try:
# 表名固定为 "data",用户 SQL 中应引用 "data"
df.to_sql("data", conn, index=False, if_exists="replace")
cursor = conn.execute(sql)
columns = [desc[0] for desc in cursor.description] if cursor.description else []
rows = [list(row) for row in cursor.fetchmany(max_rows)]
return columns, rows
finally:
conn.close()
# ---------------------------------------------------------------------------
# 结果统计
# ---------------------------------------------------------------------------
def compute_stats(columns: List[str], rows: List[list]) -> Dict[str, Dict[str, Any]]:
"""为所有数值列计算 sum / avg / min / max / count。"""
if not rows or not columns:
return {}
stats: Dict[str, Dict[str, Any]] = {}
for idx, col_name in enumerate(columns):
numeric_values: List[float] = []
for row in rows:
val = row[idx]
if val is None:
continue
try:
numeric_values.append(float(val))
except (TypeError, ValueError):
continue
if not numeric_values:
continue
stats[col_name] = {
"sum": round(sum(numeric_values), 4),
"avg": round(sum(numeric_values) / len(numeric_values), 4),
"min": round(min(numeric_values), 4),
"max": round(max(numeric_values), 4),
"count": len(numeric_values),
}
return stats
# ---------------------------------------------------------------------------
# 序列化辅助
# ---------------------------------------------------------------------------
def _make_serializable(rows: List[list]) -> List[list]:
"""确保所有值可被 JSON 序列化。"""
import datetime
import decimal
cleaned: List[list] = []
for row in rows:
new_row: list = []
for val in row:
if isinstance(val, decimal.Decimal):
val = float(val)
elif isinstance(val, (datetime.date, datetime.datetime)):
val = val.isoformat()
elif isinstance(val, bytes):
val = val.decode("utf-8", errors="replace")
new_row.append(val)
cleaned.append(new_row)
return cleaned
# ---------------------------------------------------------------------------
# 主流程
# ---------------------------------------------------------------------------
DB_TYPES = {"mysql", "postgresql"}
FILE_TYPES = {"csv", "excel", "json"}
def main() -> None:
parser = argparse.ArgumentParser(description="biz-data-insight 安全 SQL 查询引擎")
parser.add_argument("--type", required=True, choices=sorted(DB_TYPES | FILE_TYPES),
help="数据源类型: mysql, postgresql, csv, excel, json")
parser.add_argument("--uri", required=True, help="数据库连接字符串或文件路径")
parser.add_argument("--password", default=None, help="数据库密码(可选)")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--sql", default=None, help="要执行的 SQL 查询语句")
group.add_argument("--template", default=None, choices=sorted(TEMPLATES.keys()),
help="预定义查询模板名称")
parser.add_argument("--params", default=None, help="模板参数(JSON 字符串)")
parser.add_argument("--max-rows", type=int, default=None, help="最大返回行数")
args = parser.parse_args()
# ---- 订阅层级 & 行数上限 ----
try:
subscription = check_subscription()
default_max_rows = subscription.get("max_rows", 1000)
except Exception:
default_max_rows = 1000
# --max-rows 只能收紧行数上限,不能突破订阅等级允许的上限
max_rows = default_max_rows if args.max_rows is None else min(args.max_rows, default_max_rows)
# ---- 构造 SQL ----
if args.template:
if not args.params:
output_error("使用模板时必须通过 --params 提供参数(JSON 字符串)")
sys.exit(1)
try:
params = json.loads(args.params)
except json.JSONDecodeError as exc:
output_error(f"--params JSON 解析失败: {exc}")
sys.exit(1)
try:
sql = build_sql_from_template(args.template, params)
except ValueError as exc:
output_error(str(exc))
sys.exit(1)
else:
sql = args.sql
# ---- SQL 安全校验 ----
try:
validate_sql_readonly(sql)
except Exception as exc:
output_error(f"SQL 安全校验未通过: {exc}")
sys.exit(1)
# ---- 执行查询 ----
try:
ds_type = args.type
if ds_type in DB_TYPES:
columns, rows = execute_db_query(ds_type, args.uri, sql,
password=args.password, max_rows=max_rows)
elif ds_type in FILE_TYPES:
columns, rows = execute_file_query(ds_type, args.uri, sql, max_rows=max_rows)
else:
output_error(f"不支持的数据源类型: {ds_type}")
sys.exit(1)
except Exception as exc:
output_error(f"查询执行失败: {exc}")
sys.exit(1)
# ---- 格式化输出 ----
rows = _make_serializable(rows)
stats = compute_stats(columns, rows)
output_success({
"columns": columns,
"rows": rows,
"row_count": len(rows),
"stats": stats,
})
if __name__ == "__main__":
main()
FILE:scripts/anomaly_detector.py
#!/usr/bin/env python3
"""
biz-data-insight 异常检测脚本
检测业务数据中的异常值,支持标准差(sigma)和四分位距(IQR)两种方法。
此功能仅限付费用户使用。
用法示例:
python3 anomaly_detector.py --data '{"values": [100, 102, 98, 105, 250, 99], "labels": ["1月","2月","3月","4月","5月","6月"]}'
python3 anomaly_detector.py --data-file ./metrics.json --method both
"""
import argparse
import json
import math
import sys
from typing import Any, Dict, List, Optional
from utils import output_success, output_error, check_subscription, format_number
# ============================================================
# 统计计算
# ============================================================
def _mean(values: List[float]) -> float:
"""计算均值。"""
return sum(values) / len(values)
def _std(values: List[float], mean_val: float) -> float:
"""计算总体标准差。"""
variance = sum((v - mean_val) ** 2 for v in values) / len(values)
return math.sqrt(variance)
def _percentile(sorted_values: List[float], p: float) -> float:
"""使用线性插值计算百分位数。
Args:
sorted_values: 已排序的数值列表。
p: 百分位数(0-100)。
Returns:
对应百分位的数值。
"""
n = len(sorted_values)
k = (p / 100.0) * (n - 1)
f = math.floor(k)
c = math.ceil(k)
if f == c:
return sorted_values[int(k)]
return sorted_values[f] * (c - k) + sorted_values[c] * (k - f)
# ============================================================
# 异常检测方法
# ============================================================
def detect_sigma(
values: List[float],
labels: List[str],
metric_name: str,
unit: str,
sigma_threshold: float = 2.0,
) -> List[Dict[str, Any]]:
"""使用标准差方法检测异常值。
判定规则:
- |value - mean| > 2σ: severity = "warning"
- |value - mean| > 3σ: severity = "critical"
Args:
values: 数值列表。
labels: 对应的标签列表。
metric_name: 指标名称。
unit: 数值单位。
sigma_threshold: 判定异常的标准差倍数阈值,默认 2.0。
Returns:
检测到的异常列表。
"""
anomalies = []
mean_val = _mean(values)
std_val = _std(values, mean_val)
if std_val == 0:
return anomalies
for i, value in enumerate(values):
deviation = abs(value - mean_val) / std_val
if deviation > sigma_threshold:
severity = "critical" if deviation > 3.0 else "warning"
direction = "偏高" if value > mean_val else "偏低"
description = (
f"{metric_name}在{labels[i]}异常{direction}"
f"({format_number(value, 2)}{unit}),"
f"超出均值{round(deviation, 1)}个标准差"
)
anomalies.append({
"index": i,
"label": labels[i],
"value": value,
"metric_name": metric_name,
"method": "sigma",
"severity": severity,
"description": description,
"stats": {
"mean": round(mean_val, 2),
"std": round(std_val, 2),
"deviation": round(deviation, 1),
},
})
return anomalies
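# --- Illustrative note (not part of the original script) ----------------------
# detect_sigma() uses the population standard deviation, so the largest z-score
# a single point can reach in a sample of n values is sqrt(n - 1) (Samuelson's
# inequality). With the default thresholds this means a metric with fewer than
# 6 points can never exceed 2 sigma, and one with fewer than 11 points can never
# be rated "critical". `_max_population_z` is a hypothetical helper that only
# exposes this bound; it is not called anywhere in the script.
import math


def _max_population_z(n: int) -> float:
    """Upper bound on |value - mean| / population_std for a sample of size n."""
    if n < 2:
        raise ValueError("need at least two values")
    return math.sqrt(n - 1)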
def detect_iqr(
values: List[float],
labels: List[str],
metric_name: str,
unit: str,
) -> List[Dict[str, Any]]:
"""使用四分位距(IQR)方法检测异常值。
判定规则:
- 超出 1.5*IQR 范围: severity = "warning"
- 超出 3*IQR 范围: severity = "critical"
Args:
values: 数值列表。
labels: 对应的标签列表。
metric_name: 指标名称。
unit: 数值单位。
Returns:
检测到的异常列表。
"""
anomalies = []
sorted_vals = sorted(values)
q1 = _percentile(sorted_vals, 25)
q3 = _percentile(sorted_vals, 75)
iqr = q3 - q1
if iqr == 0:
return anomalies
lower_warning = q1 - 1.5 * iqr
upper_warning = q3 + 1.5 * iqr
lower_critical = q1 - 3.0 * iqr
upper_critical = q3 + 3.0 * iqr
for i, value in enumerate(values):
if value < lower_warning or value > upper_warning:
if value < lower_critical or value > upper_critical:
severity = "critical"
else:
severity = "warning"
direction = "偏高" if value > upper_warning else "偏低"
description = (
f"{metric_name}在{labels[i]}异常{direction}"
f"({format_number(value, 2)}{unit}),"
f"超出IQR正常范围"
)
anomalies.append({
"index": i,
"label": labels[i],
"value": value,
"metric_name": metric_name,
"method": "iqr",
"severity": severity,
"description": description,
"stats": {
"q1": round(q1, 2),
"q3": round(q3, 2),
"iqr": round(iqr, 2),
"lower_bound": round(lower_warning, 2),
"upper_bound": round(upper_warning, 2),
},
})
return anomalies
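# --- Illustrative sketch (not part of the original script) --------------------
# The four cut-offs used by detect_iqr(), gathered in one place so they can be
# inspected for a given Q1/Q3. `_iqr_fences` is a hypothetical helper; the 1.5
# and 3.0 multipliers are the ones used in detect_iqr() above.
def _iqr_fences(q1: float, q3: float):
    """Return (lower_warning, upper_warning, lower_critical, upper_critical)."""
    iqr = q3 - q1
    return (q1 - 1.5 * iqr, q3 + 1.5 * iqr, q1 - 3.0 * iqr, q3 + 3.0 * iqr)


# For the sample values [100, 102, 98, 105, 250, 99], Q1 = 99.25 and Q3 = 104.25,
# so the value 250 lies above the upper critical fence and is rated "critical".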
# ============================================================
# Input parsing and main logic
# ============================================================
def _parse_metrics(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Normalize the input payload into a list of metrics.

    Two shapes are accepted: a single metric (top-level ``values``) and
    multiple metrics (a ``metrics`` array).

    Args:
        data: Parsed JSON input.

    Returns:
        A list of metrics, each with ``name``, ``values``, ``labels`` and ``unit``.

    Raises:
        ValueError: If the payload is malformed.
    """

    def _normalize_metric(item: Dict[str, Any], name: str, prefix: str) -> Dict[str, Any]:
        values = item["values"]
        if not isinstance(values, list):
            raise ValueError(f"{prefix}values 必须为数组")
        labels = item.get("labels", [str(i + 1) for i in range(len(values))])
        if not isinstance(labels, list):
            raise ValueError(f"{prefix}labels 必须为数组")
        if len(labels) != len(values):
            raise ValueError(
                f"{prefix}labels 数量({len(labels)})与 values 数量({len(values)})不一致"
            )
        try:
            numeric = [float(v) for v in values]
        except (TypeError, ValueError):
            # float(None) raises TypeError, which main() does not catch,
            # so report every conversion failure as ValueError.
            raise ValueError(f"{prefix}values 必须全部为数字") from None
        return {
            "name": name,
            "values": numeric,
            "labels": labels,
            "unit": item.get("unit", ""),
        }

    if "metrics" in data:
        # Multi-metric format
        metrics = data["metrics"]
        if not isinstance(metrics, list) or len(metrics) == 0:
            raise ValueError("metrics 必须为非空数组")
        result = []
        for idx, m in enumerate(metrics):
            if not isinstance(m, dict) or "values" not in m:
                raise ValueError(f"第 {idx + 1} 个指标缺少 values 字段")
            name = m.get("name", f"指标{idx + 1}")
            result.append(_normalize_metric(m, name, f"第 {idx + 1} 个指标的 "))
        return result
    if "values" in data:
        # Single-metric format
        return [_normalize_metric(data, data.get("metric_name", "数据"), "")]
    raise ValueError("输入数据必须包含 values 或 metrics 字段")
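# --- Illustrative payloads (not part of the original script) ------------------
# The two input shapes accepted by _parse_metrics(), written out as literals.
# The field names come from the parser above; the metric names, units and
# numbers are made up for the example.
_EXAMPLE_SINGLE_METRIC = {
    "metric_name": "revenue",
    "unit": "CNY",
    "values": [100, 102, 98, 105, 250, 99],
    "labels": ["Jan", "Feb", "Mar", "Apr", "May", "Jun"],
}

_EXAMPLE_MULTI_METRIC = {
    "metrics": [
        # "labels" is optional; the parser falls back to "1", "2", "3", ...
        {"name": "revenue", "unit": "CNY", "values": [100, 102, 98]},
        {"name": "orders", "values": [10, 12, 9], "labels": ["Mon", "Tue", "Wed"]},
    ]
}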
def run_detection(
    data: Dict[str, Any],
    method: str = "both",
    sigma_threshold: float = 2.0,
) -> Dict[str, Any]:
    """Run anomaly detection and return the result.

    Args:
        data: Input payload (single-metric or multi-metric format).
        method: Detection method: "sigma", "iqr" or "both".
        sigma_threshold: Standard-deviation threshold, 2.0 by default.

    Returns:
        A dict with ``anomalies`` and ``summary``.
    """
    metrics = _parse_metrics(data)
    all_anomalies: List[Dict[str, Any]] = []
    total_checked = 0
    methods_used: List[str] = []
    if method in ("sigma", "both"):
        methods_used.append("sigma")
    if method in ("iqr", "both"):
        methods_used.append("iqr")
    for metric in metrics:
        values = metric["values"]
        labels = metric["labels"]
        name = metric["name"]
        unit = metric["unit"]
        total_checked += len(values)
        if len(values) < 3:
            # Too few data points to judge; skip detection for this metric.
            continue
        if method in ("sigma", "both"):
            all_anomalies.extend(
                detect_sigma(values, labels, name, unit, sigma_threshold)
            )
        if method in ("iqr", "both"):
            all_anomalies.extend(
                detect_iqr(values, labels, name, unit)
            )
    # No de-duplication is performed: with method="both" the same point can be
    # reported once per method, and anomaly_count and severity_distribution
    # count each of those records separately.
    severity_dist = {"warning": 0, "critical": 0}
    for a in all_anomalies:
        severity_dist[a["severity"]] += 1
    return {
        "anomalies": all_anomalies,
        "summary": {
            "total_checked": total_checked,
            "anomaly_count": len(all_anomalies),
            "methods_used": methods_used,
            "severity_distribution": severity_dist,
        },
    }
def build_parser() -> argparse.ArgumentParser:
"""构建命令行参数解析器。"""
parser = argparse.ArgumentParser(
description="业务数据异常检测工具(付费功能)",
)
parser.add_argument(
"--data",
type=str,
default=None,
help="JSON 格式的数据字符串",
)
parser.add_argument(
"--data-file",
type=str,
default=None,
help="JSON 数据文件路径(与 --data 二选一)",
)
parser.add_argument(
"--method",
type=str,
choices=["sigma", "iqr", "both"],
default="both",
help="检测方法: sigma(标准差)、iqr(四分位距)、both(两者都用),默认 both",
)
parser.add_argument(
"--sigma-threshold",
type=float,
default=2.0,
help="标准差检测阈值,默认 2.0",
)
parser.add_argument(
"--tier",
type=str,
default=None,
help="订阅等级覆盖(free 或 paid)",
)
return parser
def main() -> None:
"""主入口函数。"""
parser = build_parser()
args = parser.parse_args()
# 订阅校验:仅付费用户可使用
try:
subscription = check_subscription(args.tier)
except ValueError as e:
output_error(str(e), code="SUBSCRIPTION_ERROR")
sys.exit(1)
if subscription["tier"] != "paid":
output_error(
"异常检测为付费功能,请升级到付费版(¥99/月)",
code="SUBSCRIPTION_REQUIRED",
)
sys.exit(1)
# 读取输入数据
if args.data is not None:
try:
data = json.loads(args.data)
except json.JSONDecodeError as e:
output_error(f"JSON 解析失败: {e}", code="INVALID_INPUT")
sys.exit(1)
elif args.data_file is not None:
try:
with open(args.data_file, "r", encoding="utf-8") as f:
data = json.load(f)
except FileNotFoundError:
output_error(f"文件不存在: {args.data_file}", code="FILE_NOT_FOUND")
sys.exit(1)
except json.JSONDecodeError as e:
output_error(f"文件 JSON 解析失败: {e}", code="INVALID_INPUT")
sys.exit(1)
except OSError as e:
output_error(f"文件读取失败: {e}", code="FILE_ERROR")
sys.exit(1)
else:
output_error("请通过 --data 或 --data-file 提供输入数据", code="MISSING_INPUT")
sys.exit(1)
if not isinstance(data, dict):
output_error(
f"输入数据必须为 JSON 对象,当前类型为 {type(data).__name__}",
code="INVALID_INPUT",
)
sys.exit(1)
# 执行异常检测
try:
result = run_detection(
data=data,
method=args.method,
sigma_threshold=args.sigma_threshold,
)
except ValueError as e:
output_error(str(e), code="DETECTION_ERROR")
sys.exit(1)
output_success(result)
if __name__ == "__main__":
main()
FILE:scripts/connect_datasource.py
#!/usr/bin/env python3
"""
biz-data-insight: data source connection and schema exploration script
Usage:
python3 connect_datasource.py --type mysql --uri "mysql://user:pass@host:3306/db" --action test
python3 connect_datasource.py --type csv --uri "./data.csv" --action explore
"""
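# Postponed annotation evaluation keeps the "str | None" hints below importable on Python 3.7-3.9.
from __future__ import annotations
import argparse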
import json
import os
import re
import sys
from urllib.parse import urlparse
from utils import output_success, output_error, get_datasource_connection, check_subscription
# Name patterns that suggest a business metric column
METRIC_PATTERNS = re.compile(
r"(amount|revenue|price|cost|quantity|order|sales|total|count|profit|"
r"income|expense|fee|payment|rate|ratio|score|num|"
r"金额|销售|收入|成本|数量|订单|利润|费用)",
re.IGNORECASE,
)
# Name patterns that suggest a time dimension column
TIME_PATTERNS = re.compile(
r"(date|time|created|updated|year|month|day|日期|时间|创建|更新)",
re.IGNORECASE,
)
SUPPORTED_TYPES = ("mysql", "postgresql", "csv", "excel", "json")
FILE_TYPES = ("csv", "excel", "json")
DB_TYPES = ("mysql", "postgresql")
def _classify_column(name: str, dtype_str: str):
"""根据字段名和类型判断字段角色。"""
roles = []
if METRIC_PATTERNS.search(name):
roles.append("metric")
if TIME_PATTERNS.search(name):
roles.append("time_dimension")
return roles
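# --- Illustrative sketch (not part of the original script) --------------------
# The patterns above are matched as substrings, so "holiday" matches "day" and
# "order_id" matches "order". A stricter variant, assuming snake_case column
# names, compares whole underscore-separated tokens instead. The function
# `_classify_by_token` and its keyword sets are hypothetical and deliberately
# shorter than the real pattern lists.
_METRIC_TOKENS = {"amount", "revenue", "price", "cost", "quantity", "sales", "total"}
_TIME_TOKENS = {"date", "time", "year", "month", "day"}


def _classify_by_token(name: str):
    """Classify a column by matching whole underscore-separated tokens."""
    tokens = set(name.lower().split("_"))
    roles = []
    if tokens & _METRIC_TOKENS:
        roles.append("metric")
    if tokens & _TIME_TOKENS:
        roles.append("time_dimension")
    return roles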
# ---------------------------------------------------------------------------
# Database connections
# ---------------------------------------------------------------------------
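def _get_db_connection(ds_type: str, uri: str, password: str | None = None):
    """Open a database connection described by a URI.

    Password precedence: the explicit ``password`` argument, then the URI,
    then the ``BDI_DB_PASSWORD`` environment variable.
    """
    from urllib.parse import unquote

    parsed = urlparse(uri)
    host = parsed.hostname or "localhost"
    port = parsed.port
    # urlparse() leaves the user-info part percent-encoded, so decode it;
    # otherwise a password written as "p%40ss" would be sent literally.
    user = unquote(parsed.username) if parsed.username else "root"
    uri_password = unquote(parsed.password) if parsed.password else None
    pwd = password or uri_password or os.environ.get("BDI_DB_PASSWORD", "")
    database = parsed.path.lstrip("/") if parsed.path else None
    if ds_type == "mysql":
        import mysql.connector

        port = port or 3306
        conn = mysql.connector.connect(
            host=host,
            port=port,
            user=user,
            password=pwd,
            database=database,
        )
        return conn
    elif ds_type == "postgresql":
        import psycopg2

        port = port or 5432
        conn = psycopg2.connect(
            host=host,
            port=port,
            user=user,
            password=pwd,
            dbname=database,
        )
        return conn
    else:
        raise ValueError(f"不支持的数据库类型: {ds_type}")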
def _test_db(ds_type: str, uri: str, password: str | None = None) -> dict:
"""测试数据库连接。"""
conn = _get_db_connection(ds_type, uri, password)
try:
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
parsed = urlparse(uri)
details = {
"host": parsed.hostname or "localhost",
"port": parsed.port or (3306 if ds_type == "mysql" else 5432),
"database": (parsed.path.lstrip("/") if parsed.path else None),
"user": parsed.username or "root",
}
return {"success": True, "message": "连接成功", "type": ds_type, "details": details}
finally:
conn.close()
def _explore_db(ds_type: str, uri: str, password: str | None = None) -> dict:
"""探索数据库Schema。"""
conn = _get_db_connection(ds_type, uri, password)
try:
cursor = conn.cursor()
# 获取所有表
if ds_type == "mysql":
cursor.execute("SHOW TABLES")
tables = [row[0] for row in cursor.fetchall()]
else:
cursor.execute(
"SELECT table_name FROM information_schema.tables "
"WHERE table_schema = 'public' ORDER BY table_name"
)
tables = [row[0] for row in cursor.fetchall()]
schema = []
for table in tables:
table_info = {"table": table, "columns": [], "row_count": 0, "sample_data": []}
# 列信息
if ds_type == "mysql":
cursor.execute(f"DESCRIBE `{table}`")
for row in cursor.fetchall():
col_name = row[0]
col_type = row[1]
roles = _classify_column(col_name, col_type)
table_info["columns"].append({
"name": col_name,
"type": col_type,
"roles": roles,
})
else:
cursor.execute(
"SELECT column_name, data_type FROM information_schema.columns "
"WHERE table_schema = 'public' AND table_name = %s "
"ORDER BY ordinal_position",
(table,),
)
for row in cursor.fetchall():
col_name = row[0]
col_type = row[1]
roles = _classify_column(col_name, col_type)
table_info["columns"].append({
"name": col_name,
"type": col_type,
"roles": roles,
})
# 行数
cursor.execute(f'SELECT COUNT(*) FROM "{table}"' if ds_type == "postgresql" else f"SELECT COUNT(*) FROM `{table}`")
table_info["row_count"] = cursor.fetchone()[0]
# 样本数据 (5行)
cursor.execute(f'SELECT * FROM "{table}" LIMIT 5' if ds_type == "postgresql" else f"SELECT * FROM `{table}` LIMIT 5")
col_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
table_info["sample_data"] = [
{col_names[i]: _serialize_value(val) for i, val in enumerate(row)}
for row in rows
]
schema.append(table_info)
cursor.close()
return {
"success": True,
"message": "Schema探索完成",
"type": ds_type,
"tables_count": len(tables),
"schema": schema,
}
finally:
conn.close()
# ---------------------------------------------------------------------------
# File data sources
# ---------------------------------------------------------------------------
def _read_file(ds_type: str, uri: str):
"""用 pandas 读取文件,返回 DataFrame。"""
import pandas as pd
if ds_type == "csv":
return pd.read_csv(uri)
elif ds_type == "excel":
return pd.read_excel(uri)
elif ds_type == "json":
return pd.read_json(uri)
else:
raise ValueError(f"不支持的文件类型: {ds_type}")
def _test_file(ds_type: str, uri: str) -> dict:
"""测试文件数据源可读性。"""
if not os.path.exists(uri):
raise FileNotFoundError(f"文件不存在: {uri}")
df = _read_file(ds_type, uri)
details = {
"path": os.path.abspath(uri),
"rows": len(df),
"columns": len(df.columns),
}
return {"success": True, "message": "连接成功", "type": ds_type, "details": details}
def _explore_file(ds_type: str, uri: str) -> dict:
"""探索文件数据源Schema。"""
import pandas as pd
if not os.path.exists(uri):
raise FileNotFoundError(f"文件不存在: {uri}")
df = _read_file(ds_type, uri)
columns = []
for col in df.columns:
dtype_str = str(df[col].dtype)
roles = _classify_column(str(col), dtype_str)
# 自动识别: 数值列视为指标, datetime列视为时间维度
if pd.api.types.is_numeric_dtype(df[col]) and "metric" not in roles:
roles.append("metric")
if pd.api.types.is_datetime64_any_dtype(df[col]) and "time_dimension" not in roles:
roles.append("time_dimension")
columns.append({
"name": str(col),
"type": dtype_str,
"roles": roles,
})
# 样本数据 (5行)
sample_df = df.head(5)
sample_data = json.loads(sample_df.to_json(orient="records", force_ascii=False, date_format="iso"))
table_info = {
"table": "data",
"columns": columns,
"row_count": len(df),
"sample_data": sample_data,
}
return {
"success": True,
"message": "Schema探索完成",
"type": ds_type,
"tables_count": 1,
"schema": [table_info],
}
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def _serialize_value(val):
"""将数据库值序列化为JSON兼容类型。"""
if val is None:
return None
if isinstance(val, (int, float, bool, str)):
return val
if isinstance(val, bytes):
return val.decode("utf-8", errors="replace")
# datetime, date, Decimal 等
return str(val)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="biz-data-insight 数据源连接与Schema探索")
parser.add_argument("--type", required=True, choices=SUPPORTED_TYPES, help="数据源类型")
parser.add_argument("--uri", required=True, help="连接字符串或文件路径")
parser.add_argument("--password", default=None, help="数据库密码(可选,也可通过URI或环境变量 BDI_DB_PASSWORD 提供)")
parser.add_argument("--action", required=True, choices=("test", "explore"), help="操作类型: test=测试连接, explore=探索Schema")
args = parser.parse_args()
try:
if args.action == "test":
if args.type in DB_TYPES:
result = _test_db(args.type, args.uri, args.password)
else:
result = _test_file(args.type, args.uri)
elif args.action == "explore":
if args.type in DB_TYPES:
result = _explore_db(args.type, args.uri, args.password)
else:
result = _explore_file(args.type, args.uri)
else:
result = {"success": False, "message": f"未知操作: {args.action}"}
print(json.dumps(result, ensure_ascii=False, indent=2))
except FileNotFoundError as e:
print(json.dumps({"success": False, "message": f"文件未找到: {e}"}, ensure_ascii=False))
sys.exit(1)
except ImportError as e:
print(json.dumps({"success": False, "message": f"缺少依赖包,请安装: {e}"}, ensure_ascii=False))
sys.exit(1)
except Exception as e:
print(json.dumps({"success": False, "message": f"操作失败: {e}"}, ensure_ascii=False))
sys.exit(1)
if __name__ == "__main__":
main()
FILE:scripts/utils.py
#!/usr/bin/env python3
"""
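biz-data-insight shared utilities
Provides number formatting, JSON input/output handling, data source connection parsing, subscription checks and other common helpers.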
"""
import argparse
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
# ============================================================
# Number formatting
# ============================================================
def format_number(value: float, decimals: int = 2) -> str:
"""将数字格式化为带千分位分隔符的字符串。
Args:
value: 待格式化的数值。
decimals: 小数位数,默认为 2。
Returns:
格式化后的字符串,例如 1234567 → "1,234,567.00"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
if decimals <= 0:
return f"{int(round(num)):,}"
return f"{num:,.{decimals}f}"
def format_percentage(value: float, decimals: int = 1) -> str:
"""将小数格式化为百分比字符串。
Args:
value: 待格式化的小数值(0.156 表示 15.6%)。
decimals: 百分比小数位数,默认为 1。
Returns:
百分比字符串,例如 0.156 → "15.6%"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
pct = num * 100
return f"{pct:.{decimals}f}%"
def format_chinese_unit(value: float) -> str:
"""将大数字转换为中文单位表示(万、亿)。
Args:
value: 待转换的数值。
Returns:
带中文单位的字符串,例如:
- 12345 → "1.23万"
- 123456789 → "1.23亿"
- 999 → "999"
Raises:
TypeError: 当 value 无法转换为数字时抛出。
"""
try:
num = float(value)
except (TypeError, ValueError):
raise TypeError(f"无法将 {value!r} 转换为数字")
abs_num = abs(num)
sign = "-" if num < 0 else ""
if abs_num >= 1e8:
# 亿级别
result = abs_num / 1e8
return f"{sign}{result:.2f}亿"
elif abs_num >= 1e4:
# 万级别
result = abs_num / 1e4
return f"{sign}{result:.2f}万"
else:
# 不足万,直接返回
if abs_num == int(abs_num):
return f"{sign}{int(abs_num)}"
return f"{sign}{abs_num:.2f}"
# ============================================================
# JSON input and output
# ============================================================
def read_json_stdin() -> Dict[str, Any]:
"""从标准输入读取 JSON 数据并解析为字典。
Returns:
解析后的字典对象。
Raises:
ValueError: 当输入为空或 JSON 格式不合法时抛出。
"""
try:
raw = sys.stdin.read()
except Exception as e:
raise ValueError(f"读取标准输入失败: {e}")
if not raw.strip():
raise ValueError("标准输入为空,未读取到任何数据")
try:
data = json.loads(raw)
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
if not isinstance(data, dict):
raise ValueError(f"期望输入为 JSON 对象,实际类型为 {type(data).__name__}")
return data
def output_json(data: Any) -> None:
"""将数据以 JSON 格式输出到标准输出。
使用 ensure_ascii=False 以保留中文等非 ASCII 字符。
Args:
data: 待输出的数据(可被 JSON 序列化的任意对象)。
"""
print(json.dumps(data, ensure_ascii=False, default=str))
def output_error(message: str, code: str = "ERROR") -> None:
"""输出标准错误响应到标准输出。
输出格式:
{"success": false, "error": {"code": "<code>", "message": "<message>"}}
Args:
message: 错误描述信息。
code: 错误代码,默认为 "ERROR"。
"""
result = {
"success": False,
"error": {
"code": code,
"message": message,
},
}
output_json(result)
def output_success(data: Any) -> None:
"""输出标准成功响应到标准输出。
输出格式:
{"success": true, "data": <data>}
Args:
data: 成功时返回的数据负载。
"""
result = {
"success": True,
"data": data,
}
output_json(result)
# ============================================================
# Command-line argument parsing
# ============================================================
def parse_common_args() -> argparse.Namespace:
"""创建并解析通用命令行参数。
支持的参数:
--type 数据源类型(如 mysql、postgresql、csv、excel、json)
--uri 数据源连接 URI 或文件路径
--password 数据源密码(可选)
Returns:
解析后的参数命名空间。
"""
parser = argparse.ArgumentParser(
description="biz-data-insight 数据分析工具",
)
parser.add_argument(
"--type",
required=True,
help="数据源类型,例如: mysql, postgresql, csv, excel, json",
)
parser.add_argument(
"--uri",
required=True,
help="数据源连接 URI 或文件路径",
)
parser.add_argument(
"--password",
default=None,
help="数据源密码(可选)",
)
return parser.parse_args()
# ============================================================
# Data source connections
# ============================================================
def get_datasource_connection(
ds_type: str,
uri: str,
password: Optional[str] = None,
) -> Dict[str, Any]:
"""根据数据源类型和 URI 解析连接参数。
支持的类型:
- mysql: mysql://host:port/db
- postgresql: postgresql://host:port/db
- csv/excel/json: 本地文件路径
Args:
ds_type: 数据源类型。
uri: 连接 URI 或文件路径。
password: 可选密码,会覆盖 URI 中的密码。
Returns:
包含连接参数的字典。
Raises:
ValueError: 当类型不支持或 URI 格式无效时抛出。
"""
ds_type = ds_type.strip().lower()
# 数据库类型
if ds_type in ("mysql", "postgresql", "postgres"):
parsed = urlparse(uri)
host = parsed.hostname
port = parsed.port
database = parsed.path.lstrip("/") if parsed.path else None
user = parsed.username
if not host:
raise ValueError(f"无法从 URI 中解析主机地址: {uri}")
if not database:
raise ValueError(f"无法从 URI 中解析数据库名称: {uri}")
# 设置默认端口
default_ports = {"mysql": 3306, "postgresql": 5432, "postgres": 5432}
if port is None:
port = default_ports.get(ds_type, None)
# 规范化类型名称
canonical_type = "postgresql" if ds_type == "postgres" else ds_type
# 密码优先使用参数传入的,其次使用 URI 中的
effective_password = password if password is not None else parsed.password
return {
"type": canonical_type,
"host": host,
"port": port,
"database": database,
"user": user,
"password": effective_password,
}
# 文件类型
elif ds_type in ("csv", "excel", "json"):
file_path = uri
if not os.path.exists(file_path):
raise ValueError(f"文件不存在: {file_path}")
return {
"type": ds_type,
"file_path": os.path.abspath(file_path),
}
else:
supported = "mysql, postgresql, csv, excel, json"
raise ValueError(f"不支持的数据源类型: {ds_type!r},支持的类型: {supported}")
# ============================================================
# Subscription checks
# ============================================================
# Limits per subscription tier
_SUBSCRIPTION_TIERS: Dict[str, Dict[str, Any]] = {
"free": {
"tier": "free",
"max_datasources": 1,
"max_rows": 100,
"daily_questions": 5,
"features": ["basic_query", "simple_chart"],
},
"paid": {
"tier": "paid",
"max_datasources": 5,
"max_rows": 10000,
"daily_questions": -1, # -1 表示无限制
"features": [
"basic_query",
"simple_chart",
"advanced_analysis",
"export",
"scheduled_report",
],
},
}
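def check_subscription(tier: Optional[str] = None) -> Dict[str, Any]:
    """Return the limits that apply to the current subscription tier.

    The ``tier`` argument takes precedence; otherwise the environment variable
    ``BDI_SUBSCRIPTION_TIER`` is read. If neither is set, the tier is "free".

    Args:
        tier: Subscription tier ("free" or "paid"), optional.

    Returns:
        A dict describing the limits, for example:
        {
            "tier": "free",
            "max_datasources": 1,
            "max_rows": 100,
            "daily_questions": 5,
            "features": ["basic_query", "simple_chart"]
        }

    Raises:
        ValueError: If the tier is not a known one.
    """
    if tier is None:
        tier = os.environ.get("BDI_SUBSCRIPTION_TIER", "free")
    tier = tier.strip().lower()
    if tier not in _SUBSCRIPTION_TIERS:
        valid = ", ".join(_SUBSCRIPTION_TIERS.keys())
        raise ValueError(f"无效的订阅等级: {tier!r},有效等级: {valid}")
    # Return a copy so callers cannot modify the global configuration. dict()
    # alone is shallow and would still share the "features" list, so copy it too.
    config = dict(_SUBSCRIPTION_TIERS[tier])
    config["features"] = list(config["features"])
    return config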
# ============================================================
# SQL safety checks
# ============================================================
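# Forbidden SQL keyword patterns (DDL / DML / administrative statements)
_FORBIDDEN_SQL_PATTERNS: List[str] = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bALTER\b",
    r"\bCREATE\b",
    r"\bTRUNCATE\b",
    r"\bREPLACE\b",
    r"\bGRANT\b",
    r"\bREVOKE\b",
    r"\bEXEC\b",
    r"\bEXECUTE\b",
    r"\bCALL\b",
    r"\bMERGE\b",
    r"\bRENAME\b",
    r"\bLOAD\b",
    r"\bIMPORT\b",
    r"\bCOPY\b",
    # Added on the assumption that read-only queries never need it:
    # SELECT ... INTO creates a table (PostgreSQL) or writes a file
    # (MySQL INTO OUTFILE), so it is not a read-only operation.
    r"\bINTO\b",
]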
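def validate_sql_readonly(sql: str) -> None:
    """Check that a SQL statement is a read-only query.

    DDL, DML and administrative statements are rejected; only statements that
    start with SELECT, WITH (a CTE) or an opening parenthesis pass. The keyword
    check is a plain text match, so a forbidden word inside a string literal,
    or a call to the REPLACE() string function, is rejected as well.

    Args:
        sql: The SQL statement to check.

    Raises:
        ValueError: If the statement is empty or a non-read-only operation is
            detected; the message names the keyword that was blocked.
    """
    if not sql or not sql.strip():
        raise ValueError("SQL 语句不能为空")
    # Strip SQL comments (single-line -- and multi-line /* */).
    cleaned = re.sub(r"--[^\n]*", " ", sql)
    cleaned = re.sub(r"/\*.*?\*/", " ", cleaned, flags=re.DOTALL)
    cleaned = cleaned.strip()
    if not cleaned:
        # Comment-only input: without this check, cleaned.split()[0] below
        # would raise IndexError instead of ValueError.
        raise ValueError("SQL 语句不能为空")
    # The statement must start with SELECT, WITH (CTE) or "(".
    upper = cleaned.upper()
    if not upper.startswith(("SELECT", "WITH", "(")):
        raise ValueError(
            f"仅允许 SELECT 查询,当前语句以 {cleaned.split()[0]!r} 开头"
        )
    # Reject the statement if any forbidden keyword appears.
    for pattern in _FORBIDDEN_SQL_PATTERNS:
        match = re.search(pattern, cleaned, re.IGNORECASE)
        if match:
            keyword = match.group(0).upper()
            raise ValueError(
                f"检测到禁止的 SQL 操作: {keyword},仅允许只读 SELECT 查询"
            )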
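# --- Illustrative sketch (not part of the original module) --------------------
# The comment-stripping step of validate_sql_readonly(), isolated so it can be
# tried on its own. `_strip_sql_comments` is a hypothetical helper; like the
# validator it does not understand string literals, so "--" inside a quoted
# value is also treated as the start of a comment.
import re


def _strip_sql_comments(sql: str) -> str:
    """Remove -- line comments and /* */ block comments, then collapse whitespace."""
    without_line = re.sub(r"--[^\n]*", " ", sql)
    without_block = re.sub(r"/\*.*?\*/", " ", without_line, flags=re.DOTALL)
    return " ".join(without_block.split())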
FILE:scripts/report_generator.py
#!/usr/bin/env python3
"""
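biz-data-insight report generator
Generates business analysis reports in Markdown + Mermaid from structured data.
Supports four templates: daily report, weekly report, monthly report and interactive Q&A.
Usage: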
python3 report_generator.py --template daily --data '{"metrics":[...]}'
python3 report_generator.py --template weekly --data-file ./query_results.json
"""
import argparse
import json
import sys
from typing import Any, Dict, List, Optional
from utils import (
format_number,
format_percentage,
format_chinese_unit,
output_success,
output_error,
check_subscription,
)
# ============================================================
# Mermaid chart generation
# ============================================================
def generate_pie_chart(title: str, data: List[Dict[str, Any]]) -> str:
"""生成 Mermaid 饼图。
Args:
title: 图表标题。
data: 数据列表,每项包含 label 和 value。
Returns:
Mermaid 饼图代码块字符串。
"""
lines = [f'```mermaid', f'pie title {title}']
for item in data:
label = item.get("label", "未知")
value = item.get("value", 0)
lines.append(f' "{label}" : {value}')
lines.append("```")
return "\n".join(lines)
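# --- Illustrative sketch (not part of the original script) --------------------
# Labels are written between double quotes, so a label that itself contains a
# double quote would end the Mermaid string early. A minimal sketch of one way
# to guard against that, assuming it is acceptable to swap in single quotes.
# `_mermaid_safe_label` is a hypothetical helper, not an existing function.
def _mermaid_safe_label(label) -> str:
    """Return the label as text with double quotes and line breaks neutralised."""
    text = str(label).replace('"', "'")
    return " ".join(text.split())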
def generate_line_chart(
    title: str,
    data: List[Dict[str, Any]],
    x_label: str = "时间",
    y_label: str = "数值",
) -> str:
    """Generate a Mermaid xychart-beta line chart.

    Args:
        title: Chart title.
        data: List of items, each with ``label`` and ``value``.
        x_label: X-axis title.
        y_label: Y-axis title.

    Returns:
        The Mermaid line chart as a fenced code block string.
    """
    labels = [f'"{item.get("label", "")}"' for item in data]
    values = [str(item.get("value", 0)) for item in data]
    lines = [
        "```mermaid",
        "xychart-beta",
        f'    title "{title}"',
        # x_label was accepted but never written out; emit it as the axis title.
        f'    x-axis "{x_label}" [{", ".join(labels)}]',
        f'    y-axis "{y_label}"',
        f'    line [{", ".join(values)}]',
        "```",
    ]
    return "\n".join(lines)
def generate_bar_chart(
    title: str,
    data: List[Dict[str, Any]],
    x_label: str = "类别",
    y_label: str = "数值",
) -> str:
    """Generate a Mermaid xychart-beta bar chart.

    Args:
        title: Chart title.
        data: List of items, each with ``label`` and ``value``.
        x_label: X-axis title.
        y_label: Y-axis title.

    Returns:
        The Mermaid bar chart as a fenced code block string.
    """
    labels = [f'"{item.get("label", "")}"' for item in data]
    values = [str(item.get("value", 0)) for item in data]
    lines = [
        "```mermaid",
        "xychart-beta",
        f'    title "{title}"',
        # x_label was accepted but never written out; emit it as the axis title.
        f'    x-axis "{x_label}" [{", ".join(labels)}]',
        f'    y-axis "{y_label}"',
        f'    bar [{", ".join(values)}]',
        "```",
    ]
    return "\n".join(lines)
# ============================================================
# Metric analysis helpers
# ============================================================
def _calc_change(current: float, previous: float) -> Optional[float]:
"""计算环比变化率。
Args:
current: 当前值。
previous: 上一期值。
Returns:
变化率(小数形式),previous 为 0 时返回 None。
"""
if previous == 0:
return None
return (current - previous) / previous
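# --- Illustrative sketch (not part of the original script) --------------------
# With a negative previous value, dividing by `previous` flips the sign: going
# from -100 to -50 is an improvement, yet (current - previous) / previous is
# -0.5 and would be described as a decline. A common alternative, shown here as
# an assumption rather than the report's actual rule, divides by abs(previous).
# `_calc_change_abs_base` is a hypothetical name.
from typing import Optional


def _calc_change_abs_base(current: float, previous: float) -> Optional[float]:
    """Relative change measured against the magnitude of the previous value."""
    if previous == 0:
        return None
    return (current - previous) / abs(previous)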
def _describe_change(current: float, previous: float) -> str:
"""描述指标环比变化。
Args:
current: 当前值。
previous: 上一期值。
Returns:
环比变化描述文本,如 "环比增长5.9%" 或 "环比下降3.0%"。
"""
change = _calc_change(current, previous)
if change is None:
return "无可比数据"
abs_change = abs(change)
pct_str = format_percentage(abs_change)
if change > 0:
return f"环比增长{pct_str}"
elif change < 0:
return f"环比下降{pct_str}"
else:
return "环比持平"
def _build_metrics_table(metrics: List[Dict[str, Any]]) -> str:
"""生成核心指标 Markdown 表格。
Args:
metrics: 指标列表。
Returns:
Markdown 表格字符串。
"""
lines = [
"| 指标 | 当前值 | 上期值 | 环比变化 |",
"|------|--------|--------|----------|",
]
for m in metrics:
name = m.get("name", "未知")
value = m.get("value", 0)
previous = m.get("previous")
unit = m.get("unit", "")
formatted_value = f"{format_chinese_unit(value)}{unit}"
if previous is not None:
formatted_prev = f"{format_chinese_unit(previous)}{unit}"
change_desc = _describe_change(value, previous)
else:
formatted_prev = "-"
change_desc = "-"
lines.append(f"| {name} | {formatted_value} | {formatted_prev} | {change_desc} |")
return "\n".join(lines)
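def _build_anomaly_section(anomalies: List[Dict[str, Any]]) -> str:
    """Build the anomaly alert section.

    Args:
        anomalies: List of anomaly records.

    Returns:
        Markdown for the alert section, or "" when there are no anomalies.
    """
    if not anomalies:
        return ""
    lines = ["## 异常告警\n"]
    for a in anomalies:
        # anomaly_detector.py emits "metric_name" rather than "metric" and has
        # no "threshold" field, so accept either key and omit a missing threshold.
        metric = a.get("metric") or a.get("metric_name") or "未知指标"
        value = a.get("value", 0)
        threshold = a.get("threshold")
        threshold_text = f",阈值 {threshold}" if threshold is not None else ""
        severity = a.get("severity", "info")
        description = a.get("description", "")
        severity_icon = "⚠️" if severity == "warning" else "🔴" if severity == "critical" else "ℹ️"
        lines.append(
            f"- {severity_icon} **{metric}**: 当前值 {value}{threshold_text}。{description}"
        )
    return "\n".join(lines)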
def _build_dimension_section(
dimensions: List[Dict[str, Any]],
is_paid: bool,
) -> str:
"""生成维度分析章节(含可选图表)。
Args:
dimensions: 维度列表。
is_paid: 是否为付费用户。
Returns:
Markdown 维度分析内容(付费用户含 Mermaid 图表)。
"""
if not dimensions:
return ""
charts_parts: List[str] = []
for dim in dimensions:
name = dim.get("name", "未知维度")
dim_type = dim.get("type", "")
data = dim.get("data", [])
charts_parts.append(f"### {name}\n")
# 数据表格
charts_parts.append("| 项目 | 数值 |")
charts_parts.append("|------|------|")
for item in data:
label = item.get("label", "")
value = item.get("value", 0)
charts_parts.append(f"| {label} | {format_chinese_unit(value)} |")
charts_parts.append("")
# 付费用户生成图表
if is_paid and data:
if dim_type == "distribution":
charts_parts.append(generate_pie_chart(name, data))
elif dim_type == "trend":
charts_parts.append(generate_line_chart(name, data, x_label="时间", y_label="数值"))
elif dim_type == "ranking":
charts_parts.append(generate_bar_chart(name, data, x_label="项目", y_label="数值"))
charts_parts.append("")
return "\n".join(charts_parts)
# ============================================================
# Report templates
# ============================================================
def _render_daily(
data: Dict[str, Any],
title: Optional[str],
is_paid: bool,
) -> str:
"""渲染日报模板。
Args:
data: 报告数据。
title: 可选标题覆盖。
is_paid: 是否为付费用户。
Returns:
完整的 Markdown 日报内容。
"""
period = data.get("period", "未知日期")
header = title or f"📊 {period} 业务日报"
metrics = data.get("metrics", [])
dimensions = data.get("dimensions", [])
parts: List[str] = [f"# {header}\n"]
# 核心指标
parts.append("## 核心指标\n")
parts.append(_build_metrics_table(metrics))
parts.append("")
# 指标洞察
parts.append("## 指标洞察\n")
for m in metrics:
name = m.get("name", "")
value = m.get("value", 0)
previous = m.get("previous")
unit = m.get("unit", "")
if previous is not None:
change_desc = _describe_change(value, previous)
parts.append(f"- **{name}**: {format_chinese_unit(value)}{unit},{change_desc}")
else:
parts.append(f"- **{name}**: {format_chinese_unit(value)}{unit}")
parts.append("")
# 维度分析
if dimensions:
parts.append("## 维度分析\n")
parts.append(_build_dimension_section(dimensions, is_paid))
return "\n".join(parts)
def _render_weekly(
data: Dict[str, Any],
title: Optional[str],
is_paid: bool,
) -> str:
"""渲染周报模板。
Args:
data: 报告数据。
title: 可选标题覆盖。
is_paid: 是否为付费用户。
Returns:
完整的 Markdown 周报内容。
"""
period = data.get("period", "未知周期")
header = title or f"📊 {period} 业务周报"
metrics = data.get("metrics", [])
dimensions = data.get("dimensions", [])
anomalies = data.get("anomalies", [])
parts: List[str] = [f"# {header}\n"]
# 本周概览
parts.append("## 本周概览\n")
parts.append(_build_metrics_table(metrics))
parts.append("")
# 指标周环比分析
parts.append("## 周环比分析\n")
for m in metrics:
name = m.get("name", "")
value = m.get("value", 0)
previous = m.get("previous")
unit = m.get("unit", "")
if previous is not None:
change_desc = _describe_change(value, previous)
parts.append(f"- **{name}**: {format_chinese_unit(value)}{unit},{change_desc}")
else:
parts.append(f"- **{name}**: {format_chinese_unit(value)}{unit}")
parts.append("")
# 维度明细
if dimensions:
parts.append("## 维度明细\n")
parts.append(_build_dimension_section(dimensions, is_paid))
# 异常告警(仅付费用户)
if is_paid and anomalies:
parts.append(_build_anomaly_section(anomalies))
parts.append("")
return "\n".join(parts)
def _render_monthly(
data: Dict[str, Any],
title: Optional[str],
is_paid: bool,
) -> str:
"""渲染月报模板(仅付费用户可用)。
Args:
data: 报告数据。
title: 可选标题覆盖。
is_paid: 是否为付费用户。
Returns:
完整的 Markdown 月报内容。
Raises:
PermissionError: 当免费用户尝试生成月报时抛出。
"""
    if not is_paid:
        raise PermissionError("月报功能仅限付费用户使用,请升级订阅。")
    period = data.get("period", "未知月份")
    header = title or f"📊 {period} 业务月报"
    metrics = data.get("metrics", [])
    dimensions = data.get("dimensions", [])
    anomalies = data.get("anomalies", [])
    parts: List[str] = [f"# {header}\n"]

    # Executive summary
    parts.append("## 执行摘要\n")
    summary_items: List[str] = []
    for m in metrics:
        name = m.get("name", "")
        value = m.get("value", 0)
        previous = m.get("previous")
        unit = m.get("unit", "")
        if previous is not None:
            change_desc = _describe_change(value, previous)
            summary_items.append(f"{name}{format_chinese_unit(value)}{unit}({change_desc})")
        else:
            summary_items.append(f"{name}{format_chinese_unit(value)}{unit}")
    parts.append(f"本月关键指标:{';'.join(summary_items)}。\n")

    # Core metrics
    parts.append("## 核心指标\n")
    parts.append(_build_metrics_table(metrics))
    parts.append("")

    # Generate a trend note for each metric
    parts.append("## 指标趋势分析\n")
    for m in metrics:
        name = m.get("name", "")
        value = m.get("value", 0)
        previous = m.get("previous")
        unit = m.get("unit", "")
        if previous is not None:
            change = _calc_change(value, previous)
            change_desc = _describe_change(value, previous)
            if change is not None and change > 0:
                parts.append(f"- **{name}**:本月录得 {format_chinese_unit(value)}{unit},{change_desc},呈上升趋势。")
            elif change is not None and change < 0:
                parts.append(f"- **{name}**:本月录得 {format_chinese_unit(value)}{unit},{change_desc},需关注下降原因。")
            else:
                parts.append(f"- **{name}**:本月录得 {format_chinese_unit(value)}{unit},与上期持平。")
        else:
            parts.append(f"- **{name}**:本月录得 {format_chinese_unit(value)}{unit}。")
    parts.append("")

    # In-depth dimension analysis
    if dimensions:
        parts.append("## 维度深度分析\n")
        parts.append(_build_dimension_section(dimensions, is_paid))

    # Anomaly analysis
    if anomalies:
        parts.append("## 异常分析\n")
        for a in anomalies:
            metric = a.get("metric", "未知指标")
            value = a.get("value", 0)
            threshold = a.get("threshold", 0)
            severity = a.get("severity", "info")
            description = a.get("description", "")
            severity_label = "警告" if severity == "warning" else "严重" if severity == "critical" else "提示"
            parts.append(f"### {metric}({severity_label})\n")
            parts.append(f"- 当前值:{value}")
            parts.append(f"- 阈值:{threshold}")
            parts.append(f"- 说明:{description}")
            parts.append(f"- 建议:请排查{metric}异常原因,及时采取纠正措施。\n")

    # Recommendations
    parts.append("## 改进建议\n")
    recommendation_idx = 1
    for m in metrics:
        name = m.get("name", "")
        value = m.get("value", 0)
        previous = m.get("previous")
        if previous is not None:
            change = _calc_change(value, previous)
            if change is not None and change < 0:
                parts.append(f"{recommendation_idx}. **{name}**出现下降,建议分析下降原因并制定改善方案。")
                recommendation_idx += 1
            elif change is not None and change > 0:
                parts.append(f"{recommendation_idx}. **{name}**保持增长,建议总结成功经验并推广。")
                recommendation_idx += 1
    if anomalies:
        for a in anomalies:
            metric = a.get("metric", "")
            parts.append(f"{recommendation_idx}. 重点关注**{metric}**异常,建立预警和快速响应机制。")
            recommendation_idx += 1
    # No recommendation was emitted above, so fall back to a generic note
    if recommendation_idx == 1:
        parts.append("各项指标表现稳定,建议持续监控。")
    parts.append("")
    return "\n".join(parts)


def _render_interactive(
    data: Dict[str, Any],
    title: Optional[str],
    is_paid: bool,
) -> str:
    """Render the interactive Q&A template.

    Args:
        data: Report data.
        title: Optional title override.
        is_paid: Whether the user is on the paid tier.

    Returns:
        Concise Markdown answer content.
    """
    header = title or "📊 数据查询结果"
    metrics = data.get("metrics", [])
    dimensions = data.get("dimensions", [])
    parts: List[str] = [f"# {header}\n"]

    # Data table
    if metrics:
        parts.append("## 数据概览\n")
        parts.append(_build_metrics_table(metrics))
        parts.append("")

    # Dimension data
    if dimensions:
        for dim in dimensions:
            name = dim.get("name", "")
            dim_type = dim.get("type", "")
            dim_data = dim.get("data", [])
            parts.append(f"## {name}\n")
            parts.append("| 项目 | 数值 |")
            parts.append("|------|------|")
            for item in dim_data:
                label = item.get("label", "")
                value = item.get("value", 0)
                parts.append(f"| {label} | {format_chinese_unit(value)} |")
            parts.append("")
            # Generate charts for paid users
            if is_paid and dim_data:
                if dim_type == "distribution":
                    parts.append(generate_pie_chart(name, dim_data))
                elif dim_type == "trend":
                    parts.append(generate_line_chart(name, dim_data))
                elif dim_type == "ranking":
                    parts.append(generate_bar_chart(name, dim_data))
                parts.append("")

    # Brief insights
    parts.append("## 简要洞察\n")
    insights: List[str] = []
    for m in metrics:
        name = m.get("name", "")
        value = m.get("value", 0)
        previous = m.get("previous")
        unit = m.get("unit", "")
        if previous is not None:
            change_desc = _describe_change(value, previous)
            insights.append(f"{name}为{format_chinese_unit(value)}{unit},{change_desc}")
        else:
            insights.append(f"{name}为{format_chinese_unit(value)}{unit}")
    if insights:
        parts.append(";".join(insights) + "。")
    else:
        parts.append("暂无指标数据。")
    parts.append("")
    return "\n".join(parts)


# ============================================================
# Template routing
# ============================================================
_TEMPLATE_RENDERERS = {
    "daily": _render_daily,
    "weekly": _render_weekly,
    "monthly": _render_monthly,
    "interactive": _render_interactive,
}


def _count_charts(report: str) -> int:
    """Count the Mermaid charts in a report.

    Args:
        report: Markdown report content.

    Returns:
        Number of charts.
    """
    return report.count("```mermaid")


# ============================================================
# Main entry point
# ============================================================
def main() -> None:
    """CLI entry point: parse arguments, load data, render the report, and print it."""
    parser = argparse.ArgumentParser(
        description="biz-data-insight 报告生成器 — 生成 Markdown + Mermaid 业务报告",
    )
    parser.add_argument(
        "--template",
        required=True,
        choices=["daily", "weekly", "monthly", "interactive"],
        help="报告模板类型:daily(日报)、weekly(周报)、monthly(月报)、interactive(交互问答)",
    )
    parser.add_argument(
        "--data",
        default=None,
        help="JSON 字符串格式的报告数据",
    )
    parser.add_argument(
        "--data-file",
        default=None,
        help="报告数据 JSON 文件路径(与 --data 二选一)",
    )
    parser.add_argument(
        "--title",
        default=None,
        help="可选的报告标题覆盖",
    )
    parser.add_argument(
        "--tier",
        default=None,
        choices=["free", "paid"],
        help="订阅等级覆盖(free 或 paid)",
    )
    args = parser.parse_args()

    # ----------------------------------------------------------
    # Load data (exactly one of --data / --data-file is required)
    # ----------------------------------------------------------
    if args.data and args.data_file:
        output_error("--data 和 --data-file 不能同时使用,请选择其一。", code="INVALID_ARGS")
        sys.exit(1)
    if not args.data and not args.data_file:
        output_error("请提供 --data 或 --data-file 参数。", code="MISSING_DATA")
        sys.exit(1)
    try:
        if args.data:
            report_data: Dict[str, Any] = json.loads(args.data)
        else:
            with open(args.data_file, "r", encoding="utf-8") as f:
                report_data = json.load(f)
    except json.JSONDecodeError as e:
        output_error(f"JSON 解析失败: {e}", code="JSON_ERROR")
        sys.exit(1)
    except FileNotFoundError:
        output_error(f"数据文件不存在: {args.data_file}", code="FILE_NOT_FOUND")
        sys.exit(1)
    except Exception as e:
        output_error(f"读取数据失败: {e}", code="DATA_ERROR")
        sys.exit(1)
    if not isinstance(report_data, dict):
        output_error("数据格式错误,期望 JSON 对象。", code="INVALID_DATA")
        sys.exit(1)

    # ----------------------------------------------------------
    # Check subscription
    # ----------------------------------------------------------
    try:
        subscription = check_subscription(args.tier)
    except ValueError as e:
        output_error(str(e), code="SUBSCRIPTION_ERROR")
        sys.exit(1)
    is_paid = subscription["tier"] == "paid"

    # ----------------------------------------------------------
    # Render report
    # ----------------------------------------------------------
    renderer = _TEMPLATE_RENDERERS.get(args.template)
    if renderer is None:
        output_error(f"未知模板类型: {args.template}", code="UNKNOWN_TEMPLATE")
        sys.exit(1)
    try:
        report_md = renderer(report_data, args.title, is_paid)
    except PermissionError as e:
        output_error(str(e), code="PERMISSION_DENIED")
        sys.exit(1)
    except Exception as e:
        output_error(f"报告生成失败: {e}", code="RENDER_ERROR")
        sys.exit(1)

    # ----------------------------------------------------------
    # Output result
    # ----------------------------------------------------------
    charts_count = _count_charts(report_md)
    output_success({
        "report": report_md,
        "template": args.template,
        "charts_count": charts_count,
    })


if __name__ == "__main__":
    main()
FILE:references/mermaid-guide.md
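# Mermaid Chart Syntax Reference
This document is a quick reference for the Mermaid chart types supported by the `biz-data-insight` Skill. The report generator uses this syntax to produce the charts embedded in its Markdown reports.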
---
## 1. Pie Chart
A pie chart shows each category's share of the whole. It suits category breakdowns, channel mix, and similar views.
### Basic syntax
```
pie title 图表标题
"标签1" : 数值1
"标签2" : 数值2
"标签3" : 数值3
```
### Full example
```mermaid
pie title 本月销售额分类占比
"电子产品" : 39
"服装鞋帽" : 24
"食品饮料" : 18
"家居日用" : 13
"美妆个护" : 4
"其他" : 2
```
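### Notes
- Mermaid converts the values to percentages automatically, so no manual conversion is needed
- Labels must be wrapped in ASCII double quotes `"`
- Keep to **7 categories** or fewer; more than that makes the chart hard to read
- Merge categories with a very small share (< 3%) into "其他" (Other)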
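
The merging rule above can be sketched in Python as follows. This is a minimal illustration, not the generator's actual code: `merge_small_slices` and `render_pie` are hypothetical helper names, and the 3% threshold is taken from the note above. The 7-category limit is not enforced here.

```python
from typing import Dict, List, Tuple


def merge_small_slices(
    slices: Dict[str, float],
    min_share: float = 0.03,
    other_label: str = "其他",
) -> List[Tuple[str, float]]:
    """Fold slices below `min_share` of the total into one catch-all slice."""
    total = sum(slices.values())
    if total <= 0:
        return []
    kept: List[Tuple[str, float]] = []
    other = 0.0
    for label, value in slices.items():
        # An existing catch-all slice is always folded into the new one.
        if label == other_label or value / total < min_share:
            other += value
        else:
            kept.append((label, value))
    kept.sort(key=lambda item: item[1], reverse=True)  # largest slice first
    if other > 0:
        kept.append((other_label, other))
    return kept


def render_pie(title: str, slices: List[Tuple[str, float]]) -> str:
    """Render (label, value) pairs as a Mermaid pie definition."""
    lines = [f"pie title {title}"]
    for label, value in slices:
        # ASCII double quotes around the label, bare number for the value
        lines.append(f'    "{label}" : {value:g}')
    return "\n".join(lines)
```

Calling `render_pie("示例", merge_small_slices(data))` yields text that can be placed inside a `mermaid` code fence.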
---
## 2. Line Chart
A line chart shows change over time. It suits sales trends, user growth, and similar views. It is built on the `xychart-beta` syntax.
### Basic syntax
```
xychart-beta
title "图表标题"
x-axis ["标签1", "标签2", "标签3"]
y-axis "Y轴标题" 最小值 --> 最大值
line [数值1, 数值2, 数值3]
```
### Full example
```mermaid
xychart-beta
title "近7日销售额趋势(万元)"
x-axis ["3/13", "3/14", "3/15", "3/16", "3/17", "3/18", "3/19"]
y-axis "销售额(万元)" 20 --> 45
line [28.5, 27.2, 30.1, 32.8, 35.6, 38.2, 33.5]
```
### Multiple lines
```mermaid
xychart-beta
title "本周与上周销售额对比(万元)"
x-axis ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
y-axis "销售额(万元)" 15 --> 45
line [28.5, 27.2, 30.1, 32.8, 35.6, 38.2, 21.2]
line [26.0, 25.8, 28.3, 29.5, 32.1, 35.0, 19.8]
```
---
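## 3. Bar Chart
A bar chart compares values across categories or time periods. It suits order-volume comparisons, product rankings, and similar views. It is also built on the `xychart-beta` syntax.
### Basic syntax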
```
xychart-beta
title "图表标题"
x-axis ["标签1", "标签2", "标签3"]
y-axis "Y轴标题" 最小值 --> 最大值
bar [数值1, 数值2, 数值3]
```
### Full example
```mermaid
xychart-beta
title "各品类订单量"
x-axis ["电子产品", "服装鞋帽", "食品饮料", "家居日用", "美妆个护"]
y-axis "订单数" 0 --> 15000
bar [12500, 9800, 7200, 4100, 1680]
```
### Rankings
```mermaid
xychart-beta
title "Top 5 热销产品销售额(元)"
x-axis ["产品A", "产品B", "产品C", "产品D", "产品E"]
y-axis "销售额(元)" 0 --> 200000
bar [156800, 134500, 98200, 87600, 72300]
```
---
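## 4. Combined Chart (Bar + Line)
A single chart can carry both bars and a line, which suits showing a volume alongside a trend, for example order volume (bars) with average order value (line).
### Basic syntax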
```
xychart-beta
title "图表标题"
x-axis ["标签1", "标签2", "标签3"]
y-axis "Y轴标题" 最小值 --> 最大值
bar [数值1, 数值2, 数值3]
line [数值1, 数值2, 数值3]
```
### Full example
```mermaid
xychart-beta
title "每周销售额与订单量"
x-axis ["第1周", "第2周", "第3周", "第4周"]
y-axis "数值" 0 --> 300
bar [198, 202, 216, 240]
line [185, 195, 210, 235]
```
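### Use cases
- **Bars = actual, line = target**: shows attainment
- **Bars = current period, line = previous period**: shows year-over-year or period-over-period movement
- **Bars = sales, line = growth rate**: shows growth dynamics (both series share one Y axis, so their scales must be compatible)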
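
As a rough sketch of how the `xychart-beta` layouts in sections 2 to 4 could be emitted from Python, the helper below builds a chart with an optional bar series and any number of line series. `render_xychart` is a hypothetical name chosen for illustration; the real `generate_line_chart` and `generate_bar_chart` functions may work differently.

```python
from typing import Optional, Sequence


def render_xychart(
    title: str,
    x_labels: Sequence[str],
    y_title: str,
    y_min: float,
    y_max: float,
    bar: Optional[Sequence[float]] = None,
    lines: Sequence[Sequence[float]] = (),
) -> str:
    """Build an xychart-beta definition with an optional bar series and line series."""
    series = ([bar] if bar is not None else []) + list(lines)
    for values in series:
        if len(values) != len(x_labels):
            raise ValueError("every series needs one value per x-axis label")
    quoted = ", ".join(f'"{label}"' for label in x_labels)
    out = [
        "xychart-beta",
        f'    title "{title}"',
        f"    x-axis [{quoted}]",
        f'    y-axis "{y_title}" {y_min} --> {y_max}',
    ]
    if bar is not None:
        out.append("    bar [" + ", ".join(str(v) for v in bar) + "]")
    for values in lines:
        out.append("    line [" + ", ".join(str(v) for v in values) + "]")
    return "\n".join(out)
```

Passing only `lines` gives a line chart, only `bar` gives a bar chart, and both together give the combined chart described in this section.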
---
## Chinese Label Support
Mermaid fully supports Chinese labels, with a few points to watch:
### Correct usage
```mermaid
pie title 用户来源分布
"微信小程序" : 45
"App端" : 30
"网页端" : 25
```
```mermaid
xychart-beta
title "各区域销售额(万元)"
x-axis ["华东", "华南", "华北", "西南", "其他"]
y-axis "销售额(万元)" 0 --> 500
bar [420, 310, 280, 150, 96]
```
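### Points to watch
- Pie chart labels must be wrapped in ASCII double quotes `"`, not Chinese quotation marks
- The `title` value in `xychart-beta` must be wrapped in ASCII double quotes
- X-axis labels must be wrapped in ASCII square brackets `[]`, with each label in ASCII double quotes
- In some renderers wide Chinese characters make labels overlap; keep labels to **4 Chinese characters or fewer**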
---
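## Color Customization
Colors for `xychart-beta` and `pie` charts can be customized through a `%%{init: ...}%%` directive or through theme configuration.
### Using a theme
Specify the theme with an init directive on the first line of the Mermaid code block: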
```mermaid
%%{init: {'theme': 'base', 'themeVariables': {'pie1': '#4CAF50', 'pie2': '#2196F3', 'pie3': '#FF9800', 'pie4': '#9C27B0'}}}%%
pie title 自定义颜色示例
"分类A" : 40
"分类B" : 30
"分类C" : 20
"分类D" : 10
```
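### Available themes
| Theme | Description |
|-------|-------------|
| `default` | Default theme, suitable for most cases |
| `dark` | Dark background, suits dark-mode IDEs |
| `forest` | Green palette, suits environmental or health data |
| `base` | Base theme; pair with `themeVariables` for full customization |
### Recommendation
Reports use the `default` theme with no extra configuration. Use custom colors only when the user has an explicit brand-color requirement.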
---
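## Data Format Requirements
### Numeric precision
- Use integers as-is, e.g. `1283`
- Keep at most **1 decimal place**, e.g. `28.5`
- Do not use thousands separators (`1,283` causes a parse error; write `1283`)
- In pie charts, give percentages as bare numbers without the `%` sign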
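
A small formatter along these lines would satisfy the precision rules above. It is a sketch only; `mermaid_number` and `mermaid_series` are hypothetical names, not functions known to exist in `report_generator.py`.

```python
from typing import Iterable


def mermaid_number(value: float) -> str:
    """Format a number for Mermaid: no thousands separator, at most one decimal."""
    rounded = round(float(value), 1)
    if rounded == int(rounded):
        return str(int(rounded))  # whole numbers are written without ".0"
    return f"{rounded:.1f}"


def mermaid_series(values: Iterable[float]) -> str:
    """Format a sequence of numbers as a Mermaid data list such as [28.5, 27.2, 30]."""
    return "[" + ", ".join(mermaid_number(v) for v in values) + "]"
```

Display strings such as `1,283` or `2.1%` belong in the Markdown tables only; chart data should go through a formatter like this one.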
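### Y-axis range
- Leave **10%-20%** of headroom below the minimum and above the maximum of `y-axis`
- Example: for data ranging from 100 to 500, use `80 --> 550`
- The minimum may be 0 but must not be negative (`xychart-beta` does not support negative values)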
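
One way to derive the range is sketched below. The default margins (20% below the minimum, 10% above the maximum) are an assumption chosen so that the 100-500 example gives `80 --> 550`; the guide itself only specifies a 10%-20% band. `y_axis_bounds` and `y_axis_line` are hypothetical names.

```python
from typing import Sequence, Tuple


def y_axis_bounds(
    values: Sequence[float],
    lower_margin: float = 0.2,
    upper_margin: float = 0.1,
) -> Tuple[int, int]:
    """Return integer (min, max) bounds with headroom around the data."""
    if not values:
        raise ValueError("values must not be empty")
    lo, hi = min(values), max(values)
    if lo < 0:
        # Follows the guide's rule that the axis minimum must not be negative.
        raise ValueError("negative values are outside what this guide covers")
    lower = round(lo * (1 - lower_margin))
    upper = round(hi * (1 + upper_margin))
    if upper <= lower:  # e.g. every value is 0
        upper = lower + 1
    return lower, upper


def y_axis_line(title: str, values: Sequence[float]) -> str:
    """Render the `y-axis` line of an xychart-beta block."""
    lower, upper = y_axis_bounds(values)
    return f'y-axis "{title}" {lower} --> {upper}'
```

Rounding to integers keeps the axis labels short, but it can swallow the margin for very small values, in which case a finer rounding step would be needed.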
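### Number of data points
- Pie chart: **3-7** categories recommended
- Line/bar chart: **5-31** data points recommended
- Beyond 31 data points the labels overlap badly; aggregate by week or show only the day number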
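
The weekly aggregation suggested above could look like the sketch below. It groups by position in fixed blocks of seven rather than by calendar week, which is a simplifying assumption, and `aggregate_by_week` is a hypothetical helper name.

```python
from typing import List, Sequence, Tuple

MAX_POINTS = 31  # upper bound recommended by this guide


def aggregate_by_week(
    daily: Sequence[float],
    chunk: int = 7,
) -> Tuple[List[str], List[float]]:
    """Sum consecutive blocks of `chunk` daily values and label them 第1周, 第2周, ..."""
    labels: List[str] = []
    totals: List[float] = []
    for start in range(0, len(daily), chunk):
        labels.append(f"第{start // chunk + 1}周")
        totals.append(sum(daily[start:start + chunk]))  # the last block may be shorter
    return labels, totals


def prepare_points(daily: Sequence[float]) -> Tuple[List[str], List[float]]:
    """Keep daily points when they fit, otherwise fall back to weekly totals."""
    if len(daily) <= MAX_POINTS:
        return [str(i + 1) for i in range(len(daily))], list(daily)
    return aggregate_by_week(daily)
```

For metrics that are averages or rates, the `sum` would need to be replaced by a suitable aggregate.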
---
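## Troubleshooting
### Problem 1: The chart does not render and shows as a code block
**Cause**: The rendering environment does not support Mermaid, or the code fence is marked incorrectly.
**Fix**: Use three backticks followed by the `mermaid` identifier: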
````
```mermaid
pie title 示例
"A" : 50
"B" : 50
```
````
### Problem 2: Labels show as garbled characters
**Cause**: The file is not UTF-8 encoded.
**Fix**: Save the Markdown file as UTF-8. `report_generator.py` writes UTF-8 by default.
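### Problem 3: X-axis labels overlap and cannot be read
**Cause**: Too many data points, or labels that are too long.
**Fixes**:
- Reduce the number of data points by aggregating by week or ten-day period
- Shorten the labels, e.g. `"3/19"` instead of `"2026年3月19日"`
- For a month of daily data (31 days), use bare day numbers: `["1","2","3",...,"31"]`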
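### Problem 4: Pie chart labels contain special characters
**Cause**: The label contains characters Mermaid reserves, such as `:`, `#`, or `%`.
**Fix**: Avoid these characters in labels and spell them out instead, e.g. `"占比35"` rather than `"35%"`.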
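
A label sanitizer is one way to apply this fix automatically. The sketch below simply drops the characters named above plus the ASCII double quote, which would otherwise end the label early. Dropping rather than rewording is an assumption of this example, and `sanitize_pie_label` and `pie_entry` are hypothetical names.

```python
# Characters removed from labels: the three named in this guide, plus the
# ASCII double quote that delimits the label itself.
_STRIP_TABLE = str.maketrans({":": None, "#": None, "%": None, '"': None})


def sanitize_pie_label(label: str) -> str:
    """Remove characters that this guide lists as unsafe in pie labels."""
    return label.translate(_STRIP_TABLE).strip()


def pie_entry(label: str, value: float) -> str:
    """Render one pie slice line with a sanitized, quoted label."""
    clean = sanitize_pie_label(label)
    if not clean:
        raise ValueError("label is empty after sanitizing")
    return f'"{clean}" : {value:g}'
```

Full-width Chinese punctuation such as `:` passes through unchanged, since only the ASCII forms are stripped.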
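### Problem 5: xychart-beta syntax error
**Cause**: `xychart-beta` is a relatively new syntax that older Mermaid renderers do not support.
**Fix**: Make sure the rendering environment runs Mermaid >= **10.6.0**. GitHub and recent versions of the VS Code Mermaid extensions already support it.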
---
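## Chart Types Used by the Report Generator
| Report type | Charts used |
|-------------|-------------|
| Daily report (paid) | Pie chart (category share) |
| Weekly report (paid) | Pie chart + line chart (trend) + bar chart (daily comparison) |
| Monthly report (paid) | Pie chart + line chart + bar chart + combined chart |
| Interactive query | The most suitable chart type is chosen automatically for the question |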
---
*Syntax reference v1.0 | Based on Mermaid v10.6+ | For the biz-data-insight Skill*
FILE:references/report-templates.md
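# Report Template Reference
This document contains the complete Markdown output template for every `biz-data-insight` report type. `report_generator.py` fills these templates from query results to produce the final deliverable report.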
---
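## 1. Daily Report Template
The daily report focuses on the day's core metrics and compares them with the previous day. The free tier includes only the metrics table; the paid tier adds charts and automated insights.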
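
The change column in the sample reports follows two conventions: relative change in percent for counts and amounts, and a difference in percentage points (`pp`) for metrics that are already rates. A minimal sketch is shown below; `format_change` is a hypothetical helper, not the generator's own `_describe_change`.

```python
def format_change(current: float, previous: float, is_rate: bool = False) -> str:
    """Format a period-over-period change the way the sample tables show it."""
    if is_rate:
        # Rates such as 2.1% vs 2.3% are compared in percentage points.
        return f"{current - previous:+.1f}pp"
    if previous == 0:
        return "N/A"  # relative change is undefined without a baseline
    return f"{(current - previous) / previous * 100:+.1f}%"
```

For example, 1,283 orders against 1,195 the day before gives `+7.4%`, and a return rate of 2.1% against 2.3% gives `-0.2pp`.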
### Full output example
````markdown
# 📊 业务日报 — 2026-03-19(周四)
数据源:业务主库 | 统计周期:2026-03-19 00:00 ~ 23:59
---
## 核心指标
| 指标 | 今日 | 昨日 | 日环比 |
|------|-----:|-----:|-------:|
| 订单数 | 1,283 | 1,195 | +7.4% |
| 销售额(元) | 328,450 | 301,200 | +9.0% |
| 客单价(元) | 256 | 252 | +1.6% |
| 活跃用户数 | 5,672 | 5,410 | +4.8% |
| 新增注册 | 342 | 298 | +14.8% |
| 退货率 | 2.1% | 2.3% | -0.2pp |
## 分类销售占比
```mermaid
pie title 今日销售额分类占比
"电子产品" : 42
"服装鞋帽" : 23
"食品饮料" : 18
"家居日用" : 12
"其他" : 5
```
## 今日洞察
- **销售额增长 9.0%**:主要受"春季促销"活动拉动,电子产品类目贡献最大
- **新增注册用户增长 14.8%**:高于近7日平均值(310),推广渠道表现良好
- **退货率下降至 2.1%**:连续3日下降,处于健康区间(阈值 < 3%)
## ⚠️ 异常提醒
| 指标 | 当前值 | 正常范围 | 说明 |
|------|-------:|---------|------|
| 华东区订单延迟率 | 8.2% | < 5% | 物流合作方反馈仓储系统升级中,预计明日恢复 |
---
*报告由 biz-data-insight 自动生成 | 2026-03-19 23:59*
````
---
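## 2. Weekly Report Template
The weekly report summarizes the whole week, with week-over-week comparison, trend charts, and a category-share analysis.
### Full output example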
````markdown
# 📊 业务周报 — 第12周(2026-03-16 至 2026-03-22)
数据源:业务主库 | 报告生成时间:2026-03-23 08:00
---
## 本周概览
| 指标 | 本周 | 上周 | 周环比 |
|------|-----:|-----:|-------:|
| 订单总数 | 8,956 | 8,230 | +8.8% |
| 销售总额(元) | 2,156,800 | 1,987,500 | +8.5% |
| 日均订单数 | 1,279 | 1,176 | +8.8% |
| 日均销售额(元) | 308,114 | 283,929 | +8.5% |
| 平均客单价(元) | 241 | 241 | +0.0% |
| 活跃用户数 | 12,450 | 11,800 | +5.5% |
| 新增注册 | 2,180 | 1,950 | +11.8% |
| 退货率 | 2.3% | 2.5% | -0.2pp |
## 每日销售额趋势
```mermaid
xychart-beta
title "本周每日销售额(万元)"
x-axis ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
y-axis "销售额(万元)" 20 --> 45
line [28.5, 27.2, 30.1, 32.8, 35.6, 38.2, 21.2]
```
## 分类销售占比
```mermaid
pie title 本周销售额分类占比
"电子产品" : 40
"服装鞋帽" : 25
"食品饮料" : 17
"家居日用" : 13
"其他" : 5
```
## 每日订单数趋势
```mermaid
xychart-beta
title "本周每日订单数"
x-axis ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
y-axis "订单数" 800 --> 1800
bar [1120, 1085, 1210, 1283, 1390, 1568, 1300]
```
## 本周洞察
### 增长亮点
- 订单数和销售额连续第3周增长,周环比增速稳定在 8% 以上
- 周六订单数达到 1,568,为近30天单日最高,周末促销策略效果显著
- 新增注册用户 2,180,较上周增长 11.8%,主要来自社交媒体渠道
### 需关注项
- 客单价连续两周持平于 241 元,可考虑搭配销售或提升高价值品类曝光
- 周日销售额明显回落至 21.2 万元,低于工作日平均水平
### 环比分析
- 销售额增长主要由订单量驱动(+8.8%),客单价未变化
- 电子产品类目占比从上周 38% 提升至 40%,为增长主力
## ⚠️ 异常检测
| 指标 | 日期 | 异常值 | 正常范围 | 分析 |
|------|------|-------:|---------|------|
| 页面跳出率 | 周三 | 62% | < 50% | 当日首页 Banner 加载缓慢,已修复 |
| 支付失败率 | 周五 | 3.8% | < 2% | 第三方支付通道短暂故障,持续约30分钟 |
---
*报告由 biz-data-insight 自动生成 | 2026-03-23 08:00*
````
---
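## 3. Monthly Report Template
The monthly report is the most comprehensive report type, with an executive summary, multi-dimensional analysis, several chart types, and full anomaly-detection results. It is available on the paid tier only.
### Full output example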
````markdown
# 📊 业务月报 — 2026年3月
数据源:业务主库 | 报告生成时间:2026-04-01 08:00
---
## 管理层摘要
3月整体表现优于预期。销售额达到 **856.7万元**,同比增长 **15.2%**,环比增长 **6.8%**。订单量突破 **3.5万单**,客单价稳定在 **243元** 左右。新增注册用户 **8,420人**,用户活跃度稳步提升。需关注华东区物流延迟问题以及周中客单价偏低的趋势。
---
## 核心指标总览
| 指标 | 本月 | 上月 | 环比 | 去年同期 | 同比 |
|------|-----:|-----:|-----:|--------:|-----:|
| 订单总数 | 35,280 | 33,100 | +6.6% | 30,500 | +15.7% |
| 销售总额(万元) | 856.7 | 802.1 | +6.8% | 743.5 | +15.2% |
| 平均客单价(元) | 243 | 242 | +0.4% | 244 | -0.4% |
| 活跃用户数 | 48,200 | 45,600 | +5.7% | 38,900 | +23.9% |
| 新增注册 | 8,420 | 7,650 | +10.1% | 6,200 | +35.8% |
| 退货率 | 2.4% | 2.6% | -0.2pp | 2.8% | -0.4pp |
| 复购率 | 34.5% | 33.2% | +1.3pp | 29.8% | +4.7pp |
## 每周销售额趋势
```mermaid
xychart-beta
title "3月每周销售额(万元)"
x-axis ["第1周", "第2周", "第3周", "第4周"]
y-axis "销售额(万元)" 150 --> 260
bar [198.5, 202.3, 215.7, 240.2]
line [198.5, 202.3, 215.7, 240.2]
```
## 分类销售分析
```mermaid
pie title 3月销售额分类占比
"电子产品" : 39
"服装鞋帽" : 24
"食品饮料" : 18
"家居日用" : 13
"美妆个护" : 4
"其他" : 2
```
### 各品类详细数据
| 品类 | 销售额(万元) | 占比 | 环比 | 同比 |
|------|-------------:|-----:|-----:|-----:|
| 电子产品 | 334.1 | 39.0% | +8.2% | +18.5% |
| 服装鞋帽 | 205.6 | 24.0% | +5.1% | +12.3% |
| 食品饮料 | 154.2 | 18.0% | +4.5% | +10.8% |
| 家居日用 | 111.4 | 13.0% | +7.3% | +16.2% |
| 美妆个护 | 34.3 | 4.0% | +9.8% | +22.1% |
| 其他 | 17.1 | 2.0% | +3.2% | +8.5% |
## 每日订单量趋势
```mermaid
xychart-beta
title "3月每日订单量"
x-axis ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31"]
y-axis "订单数" 800 --> 1800
line [1050, 980, 1020, 1100, 1150, 1280, 1050, 1080, 1010, 1060, 1120, 1180, 1350, 1100, 1090, 1120, 1085, 1210, 1283, 1390, 1568, 1300, 1150, 1180, 1220, 1280, 1350, 1420, 1180, 1200, 1250]
```
## 用户分析
| 指标 | 数值 | 趋势 |
|------|-----:|------|
| 月活跃用户(MAU) | 48,200 | 连续6个月增长 |
| 日均活跃用户(DAU) | 5,830 | 较上月 +5.2% |
| DAU/MAU 比率 | 12.1% | 稳定 |
| 新用户首单转化率 | 28.3% | 较上月 +2.1pp |
| 30日留存率 | 41.2% | 较上月 +1.8pp |
## ⚠️ 异常检测报告
### 严重异常
| 日期 | 指标 | 异常值 | 正常范围 | 根因分析 |
|------|------|-------:|---------|---------|
| 3月15日 | 支付失败率 | 5.2% | < 2% | 支付网关升级导致部分交易失败,持续2小时 |
| 3月22日 | 华东区配送延迟率 | 12.3% | < 5% | 合作物流仓储系统升级,影响约800单 |
### 轻度异常
| 日期 | 指标 | 异常值 | 正常范围 | 说明 |
|------|------|-------:|---------|------|
| 3月2日 | 单日订单量 | 980 | > 1,000 | 正常波动,无需关注 |
| 3月10日 | 页面响应时间 | 3.2s | < 2s | CDN节点故障,已自动切换 |
## 关键洞察与建议
### 增长驱动因素
1. **促销活动效果显著**:周末促销带动周六订单数屡创新高,建议持续优化活动节奏
2. **新用户获取加速**:社交媒体渠道贡献了 62% 的新增注册,ROI 为 3.8
3. **电子产品持续领跑**:同比增长 18.5%,春季新品发布带动了品类增长
### 改进建议
1. **提升客单价**:连续两月持平,建议通过搭配推荐和满减策略提升
2. **优化周中转化**:周二、周三订单量偏低,可考虑定向推送
3. **物流稳定性**:华东区延迟问题需与物流方制定改进计划
---
*报告由 biz-data-insight 自动生成 | 2026-04-01 08:00*
````
---
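## 4. Interactive Query Result Template
When the user asks a question through `/biz-data-insight ask`, a concise analysis is returned.
### Full output example
**User question**: `上周销售额最高的3个产品是什么?` (Which 3 products had the highest sales last week?)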
````markdown
## 🔍 查询结果
**问题**:上周销售额最高的3个产品是什么?
**统计周期**:2026-03-16 至 2026-03-22
### 结果
| 排名 | 产品名称 | 销售额(元) | 销量(件) | 占比 |
|:----:|---------|------------:|----------:|-----:|
| 1 | AirPods Pro 3 | 156,800 | 523 | 7.3% |
| 2 | 春季轻薄羽绒服 | 134,500 | 897 | 6.2% |
| 3 | 智能手环 X10 | 98,200 | 1,312 | 4.6% |
### 销售额对比
```mermaid
xychart-beta
title "Top 3 产品销售额(元)"
x-axis ["AirPods Pro 3", "春季轻薄羽绒服", "智能手环 X10"]
y-axis "销售额(元)" 0 --> 180000
bar [156800, 134500, 98200]
```
### 补充说明
- AirPods Pro 3 客单价最高(300元/件),主要购买群体为 25-35 岁用户
- 春季轻薄羽绒服受季节因素驱动,较上周增长 32%
- 智能手环 X10 销量最大但客单价最低(75元/件),适合作为引流产品
---
*数据来源:业务主库 | 查询耗时:2.3秒*
````
---
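## Template Variables
The report templates use the following variables, which `report_generator.py` fills in at run time:
| Variable | Description | Example |
|----------|-------------|---------|
| `{date}` | Report date | 2026-03-19 |
| `{weekday}` | Day of the week | 周四 |
| `{week_number}` | Week of the year | 12 |
| `{month}` | Month | 2026年3月 |
| `{datasource_name}` | Data source name | 业务主库 |
| `{period_start}` | Start of the reporting period | 2026-03-16 |
| `{period_end}` | End of the reporting period | 2026-03-22 |
| `{generated_at}` | Report generation time | 2026-03-23 08:00 |
| `{metrics_table}` | Core metrics table | Markdown table |
| `{charts}` | Mermaid chart area | Mermaid code blocks |
| `{insights}` | AI insights | Structured analysis text |
| `{anomalies}` | Anomaly detection results | Markdown table |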
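
How these placeholders might be filled is sketched below. The sketch assumes the templates are plain `str.format`-style strings, which this reference does not confirm. `date_variables` and `fill_template` are hypothetical helpers, and using ISO week numbering for `{week_number}` is also an assumption.

```python
from datetime import date
from typing import Any, Dict

WEEKDAYS = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]


class _KeepMissing(dict):
    """Leave unknown placeholders in place instead of raising KeyError."""

    def __missing__(self, key: str) -> str:
        return "{" + key + "}"


def date_variables(day: date) -> Dict[str, Any]:
    """Derive the date-related template variables from a single date."""
    return {
        "date": day.isoformat(),
        "weekday": WEEKDAYS[day.weekday()],  # Monday is index 0
        "week_number": day.isocalendar()[1],  # ISO week of the year
        "month": f"{day.year}年{day.month}月",
    }


def fill_template(template: str, variables: Dict[str, Any]) -> str:
    """Substitute known variables and keep unknown placeholders untouched."""
    return template.format_map(_KeepMissing(variables))
```

With this approach, any literal braces in a template would need to be doubled (`{{` and `}}`) so that they are not read as placeholders.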
---
*Template version v1.0 | For the biz-data-insight Skill*