#!/usr/bin/env python3
import json, os, sys

# On Chinese-locale Windows, stdout defaults to GBK (e.g. when piped); force
# UTF-8 so the model name and the bar glyphs always encode.
_stdout_enc = (getattr(sys.stdout, 'encoding', None) or '').lower().replace('_', '-')
if _stdout_enc and _stdout_enc not in ('utf-8', 'utf8'):
    try:
        # reconfigure() (Python 3.7+) switches the encoding in place on the
        # existing wrapper, so no second object ends up owning fd 1.
        sys.stdout.reconfigure(encoding='utf-8', line_buffering=True)
    except AttributeError:
        # Not a TextIOWrapper (or very old Python): re-open the fd, but do not
        # take ownership of it (closefd=False) so it is not closed twice.
        sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8',
                          buffering=1, closefd=False)

# ── Read input: prefer the _CC_INPUT env var, fall back to stdin ─────
raw = os.environ.get('_CC_INPUT', '').strip()
if not raw:
    try:
        raw = sys.stdin.read().strip()
    except (AttributeError, OSError, ValueError):
        # stdin may be None (no console), closed/unreadable, or not valid text;
        # treat all of these as "no input" instead of failing the status line.
        raw = ''
if not raw:
    raw = '{}'

try:
    d = json.loads(raw)
except ValueError:  # json.JSONDecodeError is a ValueError subclass
    d = {}
# Everything downstream calls d.get(...); a valid JSON payload that is not an
# object (list, number, string) would otherwise crash with AttributeError.
if not isinstance(d, dict):
    d = {}

def _as_dict(value):
    """Return value if it is a dict, else an empty dict (tolerates null/odd JSON)."""
    return value if isinstance(value, dict) else {}


def _as_count(value):
    """Coerce a JSON number or numeric string to a non-negative int; junk -> 0."""
    try:
        return max(0, int(value or 0))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float('inf')), which json.loads can produce.
        return 0


cw         = _as_dict(d.get('context_window'))
co         = _as_dict(d.get('cost'))
ws         = _as_dict(d.get('workspace'))
# Fall back to 'Claude' also when display_name is present but null/empty.
model      = _as_dict(d.get('model')).get('display_name') or 'Claude'
transcript = d.get('transcript_path') or ''
if not isinstance(transcript, str):
    # The path is string-processed later; ignore any non-string value.
    transcript = ''
api_used   = _as_count(cw.get('total_input_tokens'))
ctx_max    = _as_count(cw.get('context_window_size'))
cur_dir    = ws.get('current_dir') or ''

# ── Estimate token usage locally from the transcript file ───────────
def _text_len(value):
    """Length of value if it is a string, else 0 (tolerates null/missing)."""
    return len(value) if isinstance(value, str) else 0


def _record_chars(obj):
    """Approximate character count of one decoded transcript JSONL record.

    Counts the record's message content: a plain string, or a list of blocks
    whose 'text', 'input' (as JSON) and nested 'content' text are summed.
    Anything with an unexpected shape contributes 0.
    """
    if not isinstance(obj, dict):
        return 0
    msg = obj.get('message', obj)
    if not isinstance(msg, dict):
        return 0
    content = msg.get('content', '')
    if isinstance(content, str):
        return len(content)
    if not isinstance(content, list):
        return 0
    total = 0
    for blk in content:
        if not isinstance(blk, dict):
            continue
        total += _text_len(blk.get('text'))
        inp = blk.get('input')
        if isinstance(inp, dict):
            # ensure_ascii=False so non-ASCII text counts one char per char,
            # consistent with the plain-text counting above (not 6 for \uXXXX).
            total += len(json.dumps(inp, ensure_ascii=False))
        rc = blk.get('content')
        if isinstance(rc, str):
            total += len(rc)
        elif isinstance(rc, list):
            total += sum(_text_len(rb.get('text')) for rb in rc if isinstance(rb, dict))
    return total


local_tokens = 0
if isinstance(transcript, str) and transcript:
    # Normalise Windows separators and expand a leading "~".
    transcript = os.path.expanduser(transcript.replace('\\', '/'))
else:
    transcript = ''

if transcript and os.path.isfile(transcript):
    total_chars = 0
    try:
        with open(transcript, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    total_chars += _record_chars(json.loads(line))
                except ValueError:
                    # Corrupt or partially written line: skip it.
                    continue
        # Rough heuristic: about 3.5 characters per token.
        local_tokens = int(total_chars / 3.5)
    except OSError:
        # Unreadable transcript: leave the estimate at 0 (best effort).
        pass

# ── Decide which values to display ───────────────────────────
# Show whichever count is larger, and prefix "~" whenever the shown number is
# the local character-based estimate rather than the API-reported count.
display_tokens = max(api_used, local_tokens)
token_source   = '~' if local_tokens > api_used else ''
if ctx_max <= 0:
    # Missing or invalid window size: default to 262144 (= 256 * 1024) tokens.
    ctx_max = 262144

# Integer math avoids float rounding; clamp to 0..100 for the bar.
pct = max(0, min(display_tokens * 100 // ctx_max, 100))

# ── ANSI colours (SGR escapes; 38;5;N selects a 256-colour palette entry) ──
R = '\033[0m'           # reset
B = '\033[1m'           # bold
DM = '\033[2m'          # dim (not used below)
GR = '\033[38;5;114m'   # green
YL = '\033[38;5;221m'   # yellow
RD = '\033[38;5;203m'   # red
CY = '\033[38;5;117m'   # cyan
GY = '\033[38;5;245m'   # grey

# Bar/percentage colour escalates with usage: red >= 90%, yellow >= 70%.
if pct >= 90:
    bc = RD
elif pct >= 70:
    bc = YL
else:
    bc = GR

W = 16                   # bar width in character cells
filled = pct * W // 100  # number of solid cells
bar = f"{bc}{'▓' * filled}{GY}{'░' * (W - filled)}{R}"

# Token counts are shown in thousands, truncated.
tok_k = display_tokens // 1000
max_k = ctx_max // 1000

out = ''.join((
    f'{CY}{B}[{model}]{R} ',
    f'{bar} ',
    f'{bc}{B}{pct}%{R}{GY}({token_source}{tok_k}k/{max_k}k){R}',
))
print(out)
