@clawhub-anderskev-60f39d7981
---
name: react-flow-architecture
description: Architectural guidance for building node-based UIs with React Flow. Use when designing flow-based applications, making decisions about state management, integration patterns, or evaluating whether React Flow fits a use case.
---
# React Flow Architecture
## When to Use React Flow
### Good Fit
- Visual programming interfaces
- Workflow builders and automation tools
- Diagram editors (flowcharts, org charts)
- Data pipeline visualization
- Mind mapping tools
- Node-based audio/video editors
- Decision tree builders
- State machine designers
### Consider Alternatives
- Simple static diagrams (use SVG or canvas directly)
- Heavy real-time collaboration (may need custom sync layer)
- 3D visualizations (use Three.js, react-three-fiber)
- Graph analysis with 10k+ nodes (use WebGL-based solutions like Sigma.js)
### Decision workflow (gates)
Run this sequence before locking the stack or sprinting implementation. Skip only for throwaway prototypes.
1. **Name the interactions** — List the top user actions (e.g. drag, connect, delete, group). **Pass:** Each action maps to a concrete React Flow callback you will implement (`onNodesChange`, `onConnect`, …).
2. **Classify scale** — Estimate peak nodes (visible canvas or document total). **Pass:** Your range matches a row in [Node Count Guidelines](#node-count-guidelines) and you accept the listed strategy (e.g. `onlyRenderVisibleElements` when that row implies it).
3. **Place state** — Choose local hooks, an external store, or Redux/other. **Pass:** One sentence states where persistence, undo, or cross-surface sync will live, or explicitly “not needed yet.”
4. **Re-check alternatives** — If the use case matches [Consider Alternatives](#consider-alternatives), **Pass:** One sentence explains why React Flow still fits or which listed alternative you chose instead.
## Architecture Patterns
### Package Structure (xyflow)
```
@xyflow/system (vanilla TypeScript)
├── Core algorithms (edge paths, bounds, viewport)
├── xypanzoom (d3-based pan/zoom)
├── xydrag, xyhandle, xyminimap, xyresizer
└── Shared types
@xyflow/react (depends on @xyflow/system)
├── React components and hooks
├── Zustand store for state management
└── Framework-specific integrations
@xyflow/svelte (depends on @xyflow/system)
└── Svelte components and stores
```
**Implication**: Core logic is framework-agnostic. When contributing or debugging, check whether the issue lives in `@xyflow/system` or in the framework-specific package.
### State Management Approaches
#### 1. Local State (Simple Apps)
```tsx
// useNodesState/useEdgesState for prototyping
const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
```
**Pros**: Simple, minimal boilerplate
**Cons**: State isolated to component tree
#### 2. External Store (Production)
```tsx
// Zustand store example
import { create } from 'zustand';
import {
  ReactFlow,
  applyNodeChanges,
  type Node,
  type Edge,
  type OnNodesChange,
} from '@xyflow/react';

interface FlowStore {
  nodes: Node[];
  edges: Edge[];
  setNodes: (nodes: Node[]) => void;
  onNodesChange: OnNodesChange;
}

// initialNodes / initialEdges are assumed to be defined elsewhere
const useFlowStore = create<FlowStore>((set, get) => ({
  nodes: initialNodes,
  edges: initialEdges,
  setNodes: (nodes) => set({ nodes }),
  onNodesChange: (changes) => {
    set({ nodes: applyNodeChanges(changes, get().nodes) });
  },
}));

// In component
function Flow() {
  const { nodes, edges, onNodesChange } = useFlowStore();
  // Pass edges as well; otherwise the stored edges are never rendered
  return <ReactFlow nodes={nodes} edges={edges} onNodesChange={onNodesChange} />;
}
```
**Pros**: State accessible anywhere, easier persistence/sync
**Cons**: More setup, need careful selector optimization
#### 3. Redux/Other State Libraries
```tsx
// Connect via selectors
const nodes = useSelector(selectNodes);
const dispatch = useDispatch();
const onNodesChange = useCallback((changes: NodeChange[]) => {
dispatch(nodesChanged(changes));
}, [dispatch]);
```
### Data Flow Architecture
```
User Input → Change Event → Reducer/Handler → State Update → Re-render
↓
[Drag node] → onNodesChange → applyNodeChanges → setNodes → ReactFlow
↓
[Connect] → onConnect → addEdge → setEdges → ReactFlow
↓
[Delete] → onNodesDelete → deleteElements → setNodes/setEdges → ReactFlow
```
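The drag and delete paths in the diagram can be sketched without React as a pure "changes in, next state out" function. This is a simplified stand-in for `applyNodeChanges`, not the library's implementation: it handles only position and remove changes, and the `SimpleNode` and `SimpleChange` types are hypothetical reductions of the real ones.

```typescript
// Hypothetical, reduced shapes; the library's node and change types carry more fields.
type SimpleNode = { id: string; position: { x: number; y: number } };
type SimpleChange =
  | { type: 'position'; id: string; position: { x: number; y: number } }
  | { type: 'remove'; id: string };

// Pure handler: takes the change events and the current state, returns the next state.
function applySimpleChanges(changes: SimpleChange[], nodes: SimpleNode[]): SimpleNode[] {
  let next = nodes;
  for (const change of changes) {
    if (change.type === 'remove') {
      next = next.filter((n) => n.id !== change.id);
    } else {
      next = next.map((n) =>
        n.id === change.id ? { ...n, position: change.position } : n
      );
    }
  }
  return next;
}

const prevNodes: SimpleNode[] = [
  { id: 'a', position: { x: 0, y: 0 } },
  { id: 'b', position: { x: 50, y: 50 } },
];
const nextNodes = applySimpleChanges(
  [
    { type: 'position', id: 'a', position: { x: 10, y: 20 } },
    { type: 'remove', id: 'b' },
  ],
  prevNodes
);
```

Because the handler never mutates its input, the previous array stays usable for undo history or change comparison.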
### Sub-Flow Pattern (Nested Nodes)
```tsx
// Parent node containing child nodes
const nodes = [
{
id: 'group-1',
type: 'group',
position: { x: 0, y: 0 },
style: { width: 300, height: 200 },
},
{
id: 'child-1',
parentId: 'group-1', // Key: parent reference
extent: 'parent', // Key: constrain to parent
position: { x: 10, y: 30 }, // Relative to parent
data: { label: 'Child' },
},
];
```
**Considerations**:
- Use `extent: 'parent'` to constrain dragging
- Use `expandParent: true` to auto-expand parent
- Parent z-index affects child rendering order
### Viewport Persistence
```tsx
// Save and restore nodes, edges, and viewport
const { toObject, setNodes, setEdges, setViewport } = useReactFlow();
const handleSave = () => {
  const flow = toObject();
  // flow.nodes, flow.edges, flow.viewport
  localStorage.setItem('flow', JSON.stringify(flow));
};
const handleRestore = () => {
  const raw = localStorage.getItem('flow');
  if (!raw) return; // Nothing saved yet
  const flow = JSON.parse(raw);
  setNodes(flow.nodes ?? []);
  setEdges(flow.edges ?? []);
  if (flow.viewport) setViewport(flow.viewport);
};
```
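Restoring from `localStorage` can fail when the key is missing or the stored value is corrupted. A minimal sketch of a defensive parser follows; `parseStoredFlow`, `StoredFlow`, and the default viewport are hypothetical names and choices, not part of React Flow.

```typescript
// Hypothetical shapes for the persisted object; the real toObject() result has richer types.
type StoredViewport = { x: number; y: number; zoom: number };
type StoredFlow = { nodes: unknown[]; edges: unknown[]; viewport: StoredViewport };

// Assumed fallback when no valid viewport was stored
const DEFAULT_VIEWPORT: StoredViewport = { x: 0, y: 0, zoom: 1 };

// Returns null when nothing usable is stored, so the caller can keep its current state.
function parseStoredFlow(raw: string | null): StoredFlow | null {
  if (raw === null) return null;
  let parsed: unknown;
  try {
    parsed = JSON.parse(raw);
  } catch {
    return null; // Corrupted or non-JSON value
  }
  if (typeof parsed !== 'object' || parsed === null) return null;
  const candidate = parsed as Partial<StoredFlow>;
  if (!Array.isArray(candidate.nodes) || !Array.isArray(candidate.edges)) return null;
  const vp = candidate.viewport;
  const viewport =
    vp && typeof vp.x === 'number' && typeof vp.y === 'number' && typeof vp.zoom === 'number'
      ? vp
      : DEFAULT_VIEWPORT;
  return { nodes: candidate.nodes, edges: candidate.edges, viewport };
}
```

A restore handler could call `parseStoredFlow(localStorage.getItem('flow'))` and return early on `null`.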
## Integration Patterns
### With Backend/API
```tsx
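import { useEffect, useMemo } from 'react';
import { debounce } from 'lodash-es'; // Assumed helper; any debounce implementation works

// Load from API
useEffect(() => {
  fetch('/api/flow')
    .then((r) => r.json())
    .then(({ nodes, edges }) => {
      setNodes(nodes);
      setEdges(edges);
    });
}, [setNodes, setEdges]);

// Debounced auto-save
const debouncedSave = useMemo(
  () => debounce((nodes: Node[], edges: Edge[]) => {
    fetch('/api/flow', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ nodes, edges }),
    });
  }, 1000),
  []
);
useEffect(() => {
  debouncedSave(nodes, edges);
}, [nodes, edges, debouncedSave]);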
```
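The auto-save snippet relies on a `debounce` helper. If pulling in a utility library is not desirable, a minimal version can be sketched as follows. `debounceWithFlush` is a hypothetical name, and the `flush`/`cancel` methods are an assumption about what an auto-save needs (for example, saving pending edits on unmount).

```typescript
// Minimal debounce: delays calls to fn until waitMs has passed without a new call.
// flush() runs a pending call immediately; cancel() drops it.
function debounceWithFlush<A extends unknown[]>(fn: (...args: A) => void, waitMs: number) {
  let timer: ReturnType<typeof setTimeout> | undefined;
  let pendingArgs: A | undefined;

  const run = () => {
    timer = undefined;
    if (pendingArgs !== undefined) {
      const args = pendingArgs;
      pendingArgs = undefined;
      fn(...args);
    }
  };

  const debounced = (...args: A) => {
    pendingArgs = args; // Keep only the latest arguments
    if (timer !== undefined) clearTimeout(timer);
    timer = setTimeout(run, waitMs);
  };
  debounced.flush = () => {
    if (timer !== undefined) clearTimeout(timer);
    run();
  };
  debounced.cancel = () => {
    if (timer !== undefined) clearTimeout(timer);
    timer = undefined;
    pendingArgs = undefined;
  };
  return debounced;
}

// Example: three rapid edits produce a single save with the latest payload.
const savedPayloads: string[] = [];
const saveFlow = debounceWithFlush((payload: string) => {
  savedPayloads.push(payload);
}, 1000);
saveFlow('v1');
saveFlow('v2');
saveFlow('v3');
saveFlow.flush(); // e.g. on unmount, so the last edit is not lost
```

Calling `cancel()` in an effect cleanup avoids a save firing after the component has unmounted.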
### With Layout Algorithms
```tsx
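import dagre from 'dagre';
import type { Node, Edge } from '@xyflow/react';

// Fixed node box used for both the dagre input and the position conversion
const NODE_WIDTH = 150;
const NODE_HEIGHT = 50;

function getLayoutedElements(nodes: Node[], edges: Edge[]) {
  const g = new dagre.graphlib.Graph();
  g.setGraph({ rankdir: 'TB' });
  g.setDefaultEdgeLabel(() => ({}));
  nodes.forEach((node) => {
    g.setNode(node.id, { width: NODE_WIDTH, height: NODE_HEIGHT });
  });
  edges.forEach((edge) => {
    g.setEdge(edge.source, edge.target);
  });
  dagre.layout(g);
  return {
    nodes: nodes.map((node) => {
      const pos = g.node(node.id);
      // dagre returns the node center; React Flow positions are top-left
      return {
        ...node,
        position: { x: pos.x - NODE_WIDTH / 2, y: pos.y - NODE_HEIGHT / 2 },
      };
    }),
    edges,
  };
}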
```
## Performance Scaling
### Node Count Guidelines
| Nodes | Strategy |
|-------|----------|
| < 100 | Default settings |
| 100-500 | Enable `onlyRenderVisibleElements` |
| 500-1000 | Simplify custom nodes, reduce DOM elements |
| > 1000 | Consider virtualization, WebGL alternatives |
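For the "Classify scale" gate, the table can be encoded as a small lookup. This is only a sketch: `strategyForNodeCount` and the strategy labels are hypothetical, and because the table's ranges touch at 500 and 1000, assigning those boundary values to the lower row is an assumption.

```typescript
type ScaleStrategy =
  | 'default-settings'
  | 'only-render-visible'
  | 'simplify-custom-nodes'
  | 'consider-webgl-or-virtualization';

// Boundary handling is an assumption: 500 and 1000 fall into the lower row.
function strategyForNodeCount(peakNodes: number): ScaleStrategy {
  if (peakNodes < 100) return 'default-settings';
  if (peakNodes <= 500) return 'only-render-visible';
  if (peakNodes <= 1000) return 'simplify-custom-nodes';
  return 'consider-webgl-or-virtualization';
}

const smallGraph = strategyForNodeCount(50);
const mediumGraph = strategyForNodeCount(300);
const hugeGraph = strategyForNodeCount(5000);
```

The thresholds are guidelines rather than hard limits; heavy custom nodes can move a graph into the next row earlier.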
### Optimization Techniques
```tsx
<ReactFlow
// Only render nodes/edges in viewport
onlyRenderVisibleElements={true}
// Constrain node positions to a bounded area (limits where nodes can be placed or dragged)
nodeExtent={[[-1000, -1000], [1000, 1000]]}
// Disable features not needed
elementsSelectable={false}
panOnDrag={false}
zoomOnScroll={false}
/>
```
## Trade-offs
### Controlled vs Uncontrolled
| Controlled | Uncontrolled |
|------------|--------------|
| More boilerplate | Less code |
| Full state control | Internal state |
| Easy persistence | Need `toObject()` |
| Better for complex apps | Good for prototypes |
### Connection Modes
| Strict (default) | Loose |
|------------------|-------|
| Source → Target only | Any handle → any handle |
| Predictable behavior | More flexible |
| Use for data flows | Use for diagrams |
```tsx
<ReactFlow connectionMode={ConnectionMode.Loose} />
```
### Edge Rendering
| Default edges | Custom edges |
|---------------|--------------|
| Fast rendering | More control |
| Limited styling | Any SVG/HTML |
| Simple use cases | Complex labels |
---
name: react-flow-advanced
description: Advanced React Flow patterns for complex use cases. Use when implementing sub-flows, custom connection lines, programmatic layouts, drag-and-drop, undo/redo, or complex state synchronization.
---
# Advanced React Flow Patterns
## Gates (check before shipping)
Use these as **sequenced** checks—not “I think it works.”
1. **Sub-flows / groups:** **Pass:** Every `parentId` matches an existing node `id`; the parent `type` is registered in `nodeTypes`; child positions are relative to the parent as intended (spot-check one drag inside/outside the group).
2. **Custom connection line:** **Pass:** With a valid/invalid drag, stroke or `connectionStatus` visibly differs; path renders without console errors from `getSmoothStepPath` (invalid coords).
3. **External drag-and-drop:** **Pass:** `onDragOver` always `preventDefault()`; drop position uses `screenToFlowPosition` (not raw `clientX`/`clientY` as flow coords); new node appears under the cursor on the pane.
4. **Undo/redo:** **Pass:** One undo returns to the prior `{ nodes, edges }`; redo restores; rapid changes do not leave `canUndo`/`canRedo` inconsistent with visible graph (exercise add → undo → redo once).
5. **Programmatic layout (dagre):** **Pass:** After `setNodes`, node positions match intended `rankdir`; `fitView` runs after layout (e.g. `requestAnimationFrame`) so the viewport is not stale.
6. **Connect on drop (new node):** **Pass:** Dropping on empty pane creates a node **and** an edge from the source handle; dropping on a valid target does not duplicate nodes (only the invalid-drop path adds a node).
7. **Selectors / store:** **Pass:** Components that `useStore` with objects use `shallow` (or equivalent) so unrelated store updates do not re-render every frame.
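The structural part of gate 1 can be automated. The sketch below uses a hypothetical `GroupableNode` shape and a hypothetical `findSubflowProblems` helper; the list of registered types is supplied by the caller. The ordering check is included because React Flow's sub-flow documentation asks for parent nodes to appear before their children in the nodes array.

```typescript
// Hypothetical minimal node shape for the check; real nodes carry position, data, and more.
type GroupableNode = { id: string; type?: string; parentId?: string };

// Returns human-readable problems; an empty array means the structural checks pass.
function findSubflowProblems(nodes: GroupableNode[], registeredTypes: string[]): string[] {
  const problems: string[] = [];
  const byId = new Map<string, { node: GroupableNode; index: number }>();
  nodes.forEach((node, index) => {
    byId.set(node.id, { node, index });
  });

  nodes.forEach((child, childIndex) => {
    if (child.parentId === undefined) return;
    const parent = byId.get(child.parentId);
    if (parent === undefined) {
      problems.push(`${child.id}: parentId "${child.parentId}" matches no node id`);
      return;
    }
    const parentType = parent.node.type;
    if (parentType !== undefined && !registeredTypes.includes(parentType)) {
      problems.push(`${child.id}: parent type "${parentType}" is not registered`);
    }
    if (parent.index > childIndex) {
      problems.push(`${child.id}: parent "${parent.node.id}" appears after its child in the array`);
    }
  });
  return problems;
}

const subflowProblems = findSubflowProblems(
  [
    { id: 'child-1', parentId: 'group-1' }, // listed before its parent
    { id: 'group-1', type: 'group' },
    { id: 'child-2', parentId: 'missing' }, // dangling reference
  ],
  ['default', 'input', 'output', 'group']
);
```

The drag spot-check from gate 1 still has to be done by hand; this only covers references, types, and ordering.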
## Sub-Flows (Nested Nodes)
```tsx
const nodes = [
// Parent (group) node
{
id: 'group-1',
type: 'group',
position: { x: 0, y: 0 },
style: { width: 400, height: 300, padding: 10 },
data: { label: 'Group' },
},
// Child nodes
{
id: 'child-1',
parentId: 'group-1', // Reference parent
extent: 'parent', // Constrain to parent bounds
expandParent: true, // Auto-expand parent if dragged to edge
position: { x: 20, y: 50 }, // Relative to parent
data: { label: 'Child 1' },
},
{
id: 'child-2',
parentId: 'group-1',
extent: 'parent',
position: { x: 200, y: 50 },
data: { label: 'Child 2' },
},
];
```
### Group Node Component
```tsx
function GroupNode({ data, id }: NodeProps) {
return (
<div className="group-node">
<div className="group-header">{data.label}</div>
{/* Children are rendered automatically by React Flow */}
</div>
);
}
```
## Custom Connection Line
```tsx
import { ConnectionLineComponentProps, getSmoothStepPath } from '@xyflow/react';
function CustomConnectionLine({
fromX, fromY, fromPosition,
toX, toY, toPosition,
connectionStatus,
}: ConnectionLineComponentProps) {
const [path] = getSmoothStepPath({
sourceX: fromX,
sourceY: fromY,
sourcePosition: fromPosition,
targetX: toX,
targetY: toY,
targetPosition: toPosition,
});
return (
<g>
<path
d={path}
fill="none"
stroke={connectionStatus === 'valid' ? '#22c55e' : '#ef4444'}
strokeWidth={2}
strokeDasharray="5 5"
/>
</g>
);
}
<ReactFlow connectionLineComponent={CustomConnectionLine} />
```
## Drag and Drop from External Source
```tsx
import { useCallback, type DragEvent } from 'react';
import { ReactFlow, useReactFlow } from '@xyflow/react';

// Must render inside <ReactFlowProvider> so useReactFlow() has a store to read
function DnDFlow() {
  const { screenToFlowPosition, addNodes } = useReactFlow();

  const onDragOver = useCallback((event: DragEvent) => {
    event.preventDefault();
    event.dataTransfer.dropEffect = 'move';
  }, []);

  const onDrop = useCallback((event: DragEvent) => {
    event.preventDefault();
    const type = event.dataTransfer.getData('application/reactflow');
    if (!type) return;
    // Convert screen position to flow position
    const position = screenToFlowPosition({
      x: event.clientX,
      y: event.clientY,
    });
    const newNode = {
      id: `${Date.now()}`,
      type,
      position,
      data: { label: `${type} node` },
    };
    addNodes(newNode);
  }, [screenToFlowPosition, addNodes]);

  return (
    <div style={{ height: '100%' }}>
      <ReactFlow onDragOver={onDragOver} onDrop={onDrop} />
    </div>
  );
}

// Sidebar component
function Sidebar() {
  const onDragStart = (event: DragEvent, nodeType: string) => {
    event.dataTransfer.setData('application/reactflow', nodeType);
    event.dataTransfer.effectAllowed = 'move';
  };
  return (
    <aside>
      <div draggable onDragStart={(e) => onDragStart(e, 'input')}>
        Input Node
      </div>
      <div draggable onDragStart={(e) => onDragStart(e, 'default')}>
        Default Node
      </div>
    </aside>
  );
}
```
## Undo/Redo
```tsx
import { useCallback, useState } from 'react';
function useUndoRedo<T>(initialState: T) {
const [history, setHistory] = useState<T[]>([initialState]);
const [index, setIndex] = useState(0);
const state = history[index];
const setState = useCallback((newState: T | ((prev: T) => T)) => {
setHistory((prev) => {
const resolved = typeof newState === 'function'
? (newState as (prev: T) => T)(prev[index])
: newState;
// Remove future states and add new state
const newHistory = prev.slice(0, index + 1);
return [...newHistory, resolved];
});
setIndex((i) => i + 1);
}, [index]);
const undo = useCallback(() => {
setIndex((i) => Math.max(0, i - 1));
}, []);
const redo = useCallback(() => {
setIndex((i) => Math.min(history.length - 1, i + 1));
}, [history.length]);
const canUndo = index > 0;
const canRedo = index < history.length - 1;
return { state, setState, undo, redo, canUndo, canRedo };
}
// Usage
function Flow() {
const {
state: { nodes, edges },
setState,
undo, redo, canUndo, canRedo
} = useUndoRedo({ nodes: initialNodes, edges: initialEdges });
// Capture state on significant changes
const onNodesChange = useCallback((changes) => {
const hasPositionChange = changes.some(c => c.type === 'position' && !c.dragging);
if (hasPositionChange) {
setState(prev => ({
nodes: applyNodeChanges(changes, prev.nodes),
edges: prev.edges,
}));
}
}, [setState]);
}
```
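One way to keep `canUndo`/`canRedo` consistent under rapid changes is to hold the whole history in a single value and update it with pure functions, for example inside `useReducer`. This is a sketch of that idea, not the hook above; `UndoHistory`, `record`, `undoStep`, and `redoStep` are hypothetical names.

```typescript
// Single-object history, so past/present/future can never disagree with each other.
type UndoHistory<T> = { past: T[]; present: T; future: T[] };

function createHistory<T>(initial: T): UndoHistory<T> {
  return { past: [], present: initial, future: [] };
}

// Recording a new state clears the redo stack.
function record<T>(h: UndoHistory<T>, next: T): UndoHistory<T> {
  return { past: [...h.past, h.present], present: next, future: [] };
}

function undoStep<T>(h: UndoHistory<T>): UndoHistory<T> {
  if (h.past.length === 0) return h; // Nothing to undo
  const previous = h.past[h.past.length - 1] as T;
  return { past: h.past.slice(0, -1), present: previous, future: [h.present, ...h.future] };
}

function redoStep<T>(h: UndoHistory<T>): UndoHistory<T> {
  if (h.future.length === 0) return h; // Nothing to redo
  const next = h.future[0] as T;
  return { past: [...h.past, h.present], present: next, future: h.future.slice(1) };
}

// Exercise add -> undo -> redo once, as the undo/redo gate suggests.
let graphHistory = createHistory({ nodes: ['a'], edges: [] as string[] });
graphHistory = record(graphHistory, { nodes: ['a', 'b'], edges: [] });
graphHistory = undoStep(graphHistory);
const nodeCountAfterUndo = graphHistory.present.nodes.length;
graphHistory = redoStep(graphHistory);
const nodeCountAfterRedo = graphHistory.present.nodes.length;
```

With this shape, `canUndo` is `past.length > 0` and `canRedo` is `future.length > 0`, both derived from the same value the graph renders from.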
## Programmatic Layout with dagre
```tsx
import dagre from 'dagre';
interface LayoutOptions {
direction: 'TB' | 'BT' | 'LR' | 'RL';
nodeWidth: number;
nodeHeight: number;
}
function getLayoutedElements(
nodes: Node[],
edges: Edge[],
options: LayoutOptions = { direction: 'TB', nodeWidth: 172, nodeHeight: 36 }
) {
const g = new dagre.graphlib.Graph();
g.setGraph({ rankdir: options.direction });
g.setDefaultEdgeLabel(() => ({}));
nodes.forEach((node) => {
g.setNode(node.id, {
width: node.measured?.width ?? options.nodeWidth,
height: node.measured?.height ?? options.nodeHeight,
});
});
edges.forEach((edge) => {
g.setEdge(edge.source, edge.target);
});
dagre.layout(g);
const layoutedNodes = nodes.map((node) => {
const nodeWithPosition = g.node(node.id);
return {
...node,
position: {
x: nodeWithPosition.x - (node.measured?.width ?? options.nodeWidth) / 2,
y: nodeWithPosition.y - (node.measured?.height ?? options.nodeHeight) / 2,
},
};
});
return { nodes: layoutedNodes, edges };
}
// Usage after nodes are measured
function Flow() {
const { fitView } = useReactFlow();
const onLayout = useCallback((direction: 'TB' | 'LR') => {
const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements(
nodes,
edges,
{ direction, nodeWidth: 150, nodeHeight: 50 }
);
setNodes([...layoutedNodes]);
setEdges([...layoutedEdges]);
window.requestAnimationFrame(() => {
fitView({ duration: 500 });
});
}, [nodes, edges, setNodes, setEdges, fitView]);
}
```
## Connection with Edge on Drop
```tsx
function Flow() {
const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
const { screenToFlowPosition } = useReactFlow();
const onConnectEnd = useCallback(
(event: MouseEvent | TouchEvent, connectionState: FinalConnectionState) => {
// Only proceed if dropped on pane (not on a node)
if (!connectionState.isValid && connectionState.fromHandle) {
const id = `${Date.now()}`;
const { clientX, clientY } = 'changedTouches' in event
? event.changedTouches[0]
: event;
const newNode = {
id,
position: screenToFlowPosition({ x: clientX, y: clientY }),
data: { label: 'New Node' },
};
setNodes((nds) => [...nds, newNode]);
setEdges((eds) => [
...eds,
{
id: `e-${connectionState.fromNode?.id}-${id}`,
source: connectionState.fromNode?.id ?? '',
target: id,
},
]);
}
},
[screenToFlowPosition, setNodes, setEdges]
);
return (
<ReactFlow
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnectEnd={onConnectEnd}
/>
);
}
```
## Accessing Node Data from Edges
```tsx
import {
  BaseEdge,
  EdgeLabelRenderer,
  getSmoothStepPath,
  useNodesData,
  type EdgeProps,
} from '@xyflow/react';

function DataEdge({ source, target, ...props }: EdgeProps) {
  // Get data for source and target nodes
  const nodesData = useNodesData([source, target]);
  const sourceData = nodesData[0];
  const targetData = nodesData[1];
  const [path, labelX, labelY] = getSmoothStepPath(props);
  return (
    <>
      <BaseEdge path={path} />
      <EdgeLabelRenderer>
        <div
          style={{
            // Labels inside EdgeLabelRenderer are positioned absolutely in flow coordinates
            position: 'absolute',
            transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
          }}
        >
          {sourceData?.data?.label} → {targetData?.data?.label}
        </div>
      </EdgeLabelRenderer>
    </>
  );
}
```
## Middleware for Node Changes
```tsx
// Filter or modify changes before they're applied
const onNodesChangeMiddleware = useCallback((changes: NodeChange[]) => {
// Example: Prevent deletion of certain nodes
const filteredChanges = changes.filter((change) => {
if (change.type === 'remove') {
const node = nodes.find((n) => n.id === change.id);
return node?.data?.deletable !== false;
}
return true;
});
setNodes((nds) => applyNodeChanges(filteredChanges, nds));
}, [nodes, setNodes]);
```
## Keyboard Shortcuts
```tsx
import { useEffect } from 'react';
import { useKeyPress, useReactFlow } from '@xyflow/react';

function Flow() {
  const { deleteElements, getNodes, getEdges, setNodes, setEdges } = useReactFlow();

  // Ctrl/Cmd + A: Select all
  const selectAllPressed = useKeyPress(['Meta+a', 'Control+a']);
  useEffect(() => {
    if (selectAllPressed) {
      setNodes((nds) => nds.map((n) => ({ ...n, selected: true })));
      setEdges((eds) => eds.map((e) => ({ ...e, selected: true })));
    }
  }, [selectAllPressed, setNodes, setEdges]);

  // Custom delete handler
  const deletePressed = useKeyPress(['Backspace', 'Delete']);
  useEffect(() => {
    if (deletePressed) {
      const selectedNodes = getNodes().filter((n) => n.selected);
      const selectedEdges = getEdges().filter((e) => e.selected);
      deleteElements({ nodes: selectedNodes, edges: selectedEdges });
    }
  }, [deletePressed, getNodes, getEdges, deleteElements]);
}
```
## Performance: Memoizing Selectors
```tsx
import { useStore, type ReactFlowState } from '@xyflow/react';
import { shallow } from 'zustand/shallow';
// Create stable selector outside component
const nodesSelector = (state: ReactFlowState) => state.nodes;
// Or use multiple values with shallow compare
const flowStateSelector = (state: ReactFlowState) => ({
nodes: state.nodes,
edges: state.edges,
viewport: state.transform,
});
function FlowInfo() {
const { nodes, edges, viewport } = useStore(flowStateSelector, shallow);
return <div>Nodes: {nodes.length}, Edges: {edges.length}</div>;
}
```
---
name: dagre-react-flow
description: Automatic graph layout using dagre with React Flow (@xyflow/react). Use when implementing auto-layout, hierarchical layouts, tree structures, or arranging nodes programmatically. Triggers on dagre, auto-layout, automatic layout, getLayoutedElements, rankdir, hierarchical graph.
---
# Dagre with React Flow
Dagre is a JavaScript library for laying out directed graphs. It computes node positions for layered (hierarchical) and tree layouts using heuristics, so results are usually good but not guaranteed optimal. React Flow handles rendering; dagre handles positioning.
## Quick Start
```bash
pnpm add @dagrejs/dagre
```
```typescript
import dagre from '@dagrejs/dagre';
import type { Node, Edge } from '@xyflow/react';
const getLayoutedElements = (
nodes: Node[],
edges: Edge[],
direction: 'TB' | 'LR' = 'TB'
) => {
const g = new dagre.graphlib.Graph();
g.setGraph({ rankdir: direction });
g.setDefaultEdgeLabel(() => ({}));
nodes.forEach((node) => {
g.setNode(node.id, { width: 172, height: 36 });
});
edges.forEach((edge) => {
g.setEdge(edge.source, edge.target);
});
dagre.layout(g);
const layoutedNodes = nodes.map((node) => {
const pos = g.node(node.id);
return {
...node,
position: { x: pos.x - 86, y: pos.y - 18 }, // Center to top-left
};
});
return { nodes: layoutedNodes, edges };
};
```
## Core Concepts
### Coordinate System Difference
**Critical:** Dagre returns center coordinates; React Flow uses top-left.
```typescript
// Dagre output: center of node
const dagrePos = g.node(nodeId); // { x: 100, y: 50 } = center
// React Flow expects: top-left corner
const rfPosition = {
x: dagrePos.x - nodeWidth / 2,
y: dagrePos.y - nodeHeight / 2,
};
```
### Node Dimensions
Dagre requires explicit dimensions. Three approaches:
**1. Fixed dimensions (simplest):**
```typescript
g.setNode(node.id, { width: 172, height: 36 });
```
**2. Per-node dimensions from data:**
```typescript
g.setNode(node.id, {
width: node.data.width ?? 172,
height: node.data.height ?? 36,
});
```
**3. Measured dimensions (most accurate):**
```typescript
// After React Flow measures nodes
g.setNode(node.id, {
width: node.measured?.width ?? 172,
height: node.measured?.height ?? 36,
});
```
### Layout Directions
| Value | Direction | Use Case |
|-------|-----------|----------|
| `TB` | Top to Bottom | Org charts, decision trees |
| `BT` | Bottom to Top | Dependency graphs (deps at bottom) |
| `LR` | Left to Right | Timelines, horizontal flows |
| `RL` | Right to Left | RTL layouts |
```typescript
g.setGraph({ rankdir: 'LR' }); // Horizontal layout
```
## Hard gates
Run these in order before treating layout as correct (each step has an objective pass condition):
1. **Dimensions match conversion** — For every node id, the `width` and `height` given to `g.setNode` for that id are the same numbers used to compute `position.x` / `position.y` from `g.node(id)` (half-width / half-height must match the dagre node box).
2. **Center → top-left** — `position` is `{ x: centerX - width/2, y: centerY - height/2 }`, not raw `g.node(id).x` / `.y` alone.
3. **React Flow state update** — After programmatic layout, `setNodes` / `setEdges` receive a **new** array instance (e.g. `[...layouted]` or `layouted.map(...)`), not the previous reference unchanged.
4. **Optional sanity** — If you use `fitView` after layout, it runs after nodes are committed (e.g. next `requestAnimationFrame` or `setTimeout(0)`), not in the same synchronous tick as `setNodes` with stale measurements.
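Gates 1 and 2 can be expressed as a small check that a position is reproducible from the dagre center using the same box that was passed to `g.setNode`. The helper names (`toTopLeft`, `positionMatchesBox`) and the `Box`/`XY` types are hypothetical.

```typescript
type Box = { width: number; height: number };
type XY = { x: number; y: number };

// Gate 2: convert a center point (as dagre reports it) to a top-left position.
function toTopLeft(center: XY, box: Box): XY {
  return { x: center.x - box.width / 2, y: center.y - box.height / 2 };
}

// Gates 1 and 2 combined: the position must follow from the center
// using exactly the box that was given to the layout engine.
function positionMatchesBox(position: XY, center: XY, layoutBox: Box): boolean {
  const expected = toTopLeft(center, layoutBox);
  return position.x === expected.x && position.y === expected.y;
}

const dagreCenter: XY = { x: 100, y: 50 };
const layoutBox: Box = { width: 172, height: 36 };
const goodPosition = toTopLeft(dagreCenter, layoutBox); // { x: 14, y: 32 }
// Converting with a different box than the one used for layout fails the gate.
const mismatchedPosition = toTopLeft(dagreCenter, { width: 150, height: 50 });
const gatePasses = positionMatchesBox(goodPosition, dagreCenter, layoutBox);
const gateFails = positionMatchesBox(mismatchedPosition, dagreCenter, layoutBox);
```

A mismatch like the second case typically shows up as nodes that are evenly spaced but shifted by a constant offset.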
## Complete Implementation
### Basic Layout Function
```typescript
import dagre from '@dagrejs/dagre';
import type { Node, Edge } from '@xyflow/react';
interface LayoutOptions {
direction?: 'TB' | 'BT' | 'LR' | 'RL';
nodeWidth?: number;
nodeHeight?: number;
nodesep?: number; // Spacing between nodes in the same rank (horizontal for TB/BT)
ranksep?: number; // Spacing between ranks (vertical for TB/BT)
}
export function getLayoutedElements(
nodes: Node[],
edges: Edge[],
options: LayoutOptions = {}
): { nodes: Node[]; edges: Edge[] } {
const {
direction = 'TB',
nodeWidth = 172,
nodeHeight = 36,
nodesep = 50,
ranksep = 50,
} = options;
const g = new dagre.graphlib.Graph();
g.setGraph({ rankdir: direction, nodesep, ranksep });
g.setDefaultEdgeLabel(() => ({}));
nodes.forEach((node) => {
const width = node.measured?.width ?? nodeWidth;
const height = node.measured?.height ?? nodeHeight;
g.setNode(node.id, { width, height });
});
edges.forEach((edge) => {
g.setEdge(edge.source, edge.target);
});
dagre.layout(g);
const layoutedNodes = nodes.map((node) => {
const pos = g.node(node.id);
const width = node.measured?.width ?? nodeWidth;
const height = node.measured?.height ?? nodeHeight;
return {
...node,
position: {
x: pos.x - width / 2,
y: pos.y - height / 2,
},
};
});
return { nodes: layoutedNodes, edges };
}
```
### React Flow Integration
```tsx
import { useCallback } from 'react';
import {
ReactFlow,
useNodesState,
useEdgesState,
useReactFlow,
ReactFlowProvider,
} from '@xyflow/react';
import { getLayoutedElements } from './layout';
const initialNodes = [
{ id: '1', data: { label: 'Start' }, position: { x: 0, y: 0 } },
{ id: '2', data: { label: 'Process' }, position: { x: 0, y: 0 } },
{ id: '3', data: { label: 'End' }, position: { x: 0, y: 0 } },
];
const initialEdges = [
{ id: 'e1-2', source: '1', target: '2' },
{ id: 'e2-3', source: '2', target: '3' },
];
// Apply initial layout
const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements(
initialNodes,
initialEdges,
{ direction: 'TB' }
);
function Flow() {
const [nodes, setNodes, onNodesChange] = useNodesState(layoutedNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(layoutedEdges);
const { fitView } = useReactFlow();
const onLayout = useCallback((direction: 'TB' | 'LR') => {
const { nodes: newNodes, edges: newEdges } = getLayoutedElements(
nodes,
edges,
{ direction }
);
setNodes([...newNodes]);
setEdges([...newEdges]);
// Fit view after layout with animation
window.requestAnimationFrame(() => {
fitView({ duration: 300 });
});
}, [nodes, edges, setNodes, setEdges, fitView]);
return (
<div style={{ width: '100%', height: '100vh' }}>
<div style={{ position: 'absolute', zIndex: 10, padding: 10 }}>
<button onClick={() => onLayout('TB')}>Vertical</button>
<button onClick={() => onLayout('LR')}>Horizontal</button>
</div>
<ReactFlow
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
fitView
/>
</div>
);
}
export default function App() {
return (
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
);
}
```
## useAutoLayout Hook
Reusable hook for automatic layout:
```typescript
import { useCallback, useEffect, useRef } from 'react';
import {
useReactFlow,
useNodesInitialized,
type Node,
type Edge,
} from '@xyflow/react';
import dagre from '@dagrejs/dagre';
interface UseAutoLayoutOptions {
direction?: 'TB' | 'BT' | 'LR' | 'RL';
nodesep?: number;
ranksep?: number;
}
export function useAutoLayout(options: UseAutoLayoutOptions = {}) {
const { direction = 'TB', nodesep = 50, ranksep = 50 } = options;
const { getNodes, getEdges, setNodes, fitView } = useReactFlow();
const nodesInitialized = useNodesInitialized();
const layoutApplied = useRef(false);
const runLayout = useCallback(() => {
const nodes = getNodes();
const edges = getEdges();
const g = new dagre.graphlib.Graph();
g.setGraph({ rankdir: direction, nodesep, ranksep });
g.setDefaultEdgeLabel(() => ({}));
nodes.forEach((node) => {
g.setNode(node.id, {
width: node.measured?.width ?? 172,
height: node.measured?.height ?? 36,
});
});
edges.forEach((edge) => {
g.setEdge(edge.source, edge.target);
});
dagre.layout(g);
const layouted = nodes.map((node) => {
const pos = g.node(node.id);
const width = node.measured?.width ?? 172;
const height = node.measured?.height ?? 36;
return {
...node,
position: { x: pos.x - width / 2, y: pos.y - height / 2 },
};
});
setNodes(layouted);
window.requestAnimationFrame(() => fitView({ duration: 200 }));
}, [direction, nodesep, ranksep, getNodes, getEdges, setNodes, fitView]);
// Auto-layout on initialization
useEffect(() => {
if (nodesInitialized && !layoutApplied.current) {
runLayout();
layoutApplied.current = true;
}
}, [nodesInitialized, runLayout]);
return { runLayout };
}
```
Usage:
```tsx
function Flow() {
const { runLayout } = useAutoLayout({ direction: 'LR', ranksep: 100 });
return (
<>
<button onClick={runLayout}>Re-layout</button>
<ReactFlow ... />
</>
);
}
```
## Edge Options
Control edge routing with weight and minlen:
```typescript
edges.forEach((edge) => {
g.setEdge(edge.source, edge.target, {
weight: edge.data?.priority ?? 1, // Higher = more direct path
minlen: edge.data?.minRanks ?? 1, // Minimum ranks between nodes
});
});
```
**weight**: Higher weight edges are prioritized for shorter, more direct paths.
**minlen**: Forces minimum rank separation between connected nodes.
```typescript
// Force 2 ranks between nodes
g.setEdge('a', 'b', { minlen: 2 });
```
## Common Patterns
### Handle Position Based on Direction
Adjust handles for horizontal vs vertical layouts:
```tsx
function CustomNode({ data }: NodeProps) {
const isHorizontal = data.direction === 'LR' || data.direction === 'RL';
return (
<div>
<Handle
type="target"
position={isHorizontal ? Position.Left : Position.Top}
/>
<div>{data.label}</div>
<Handle
type="source"
position={isHorizontal ? Position.Right : Position.Bottom}
/>
</div>
);
}
```
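The component above treats `LR` and `RL` the same and `TB` and `BT` the same. If edges should enter and leave on the side that faces the flow direction, a lookup over all four directions is one option. This is a suggested extension rather than part of the pattern above; `handleSidesFor` is a hypothetical helper, and the string values are assumed to mirror the `Position` enum (`'top'`, `'bottom'`, `'left'`, `'right'`).

```typescript
type LayoutDirection = 'TB' | 'BT' | 'LR' | 'RL';
type HandleSide = 'top' | 'bottom' | 'left' | 'right';

// Target handle faces the incoming edge; source handle faces the outgoing edge.
function handleSidesFor(direction: LayoutDirection): { target: HandleSide; source: HandleSide } {
  switch (direction) {
    case 'TB':
      return { target: 'top', source: 'bottom' };
    case 'BT':
      return { target: 'bottom', source: 'top' };
    case 'LR':
      return { target: 'left', source: 'right' };
    case 'RL':
      return { target: 'right', source: 'left' };
  }
}

const verticalSides = handleSidesFor('TB');
const rightToLeftSides = handleSidesFor('RL');
```

In a custom node, the returned sides would be passed to the `position` prop of the two `Handle` components.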
### Animated Layout Transitions
Smooth position changes using CSS transitions:
```css
.react-flow__node {
transition: transform 300ms ease-out;
}
```
For programmatic animation, see [reference.md](reference.md#animated-layout-transitions).
### Layout with Node Groups
Exclude group nodes from dagre layout:
```typescript
const layoutWithGroups = (nodes: Node[], edges: Edge[]) => {
// Separate regular nodes from groups
const regularNodes = nodes.filter((n) => n.type !== 'group');
const groupNodes = nodes.filter((n) => n.type === 'group');
// Layout only regular nodes
const { nodes: layouted } = getLayoutedElements(regularNodes, edges);
// Combine back
return { nodes: [...groupNodes, ...layouted], edges };
};
```
## Troubleshooting
### Nodes Overlapping
Increase spacing:
```typescript
g.setGraph({
rankdir: 'TB',
nodesep: 100, // Increase horizontal spacing
ranksep: 100, // Increase vertical spacing
});
```
### Layout Not Updating
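Pass a new array reference to `setNodes`. If `layoutedNodes` is the same array instance already held in state (for example, because it was mutated in place), React skips the update:
```typescript
// Wrong if layoutedNodes is the array already in state (mutated in place)
setNodes(layoutedNodes);
// Correct - always a new reference
setNodes([...layoutedNodes]);
```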
### Nodes at Wrong Position
Check coordinate conversion:
```typescript
// Dagre returns center, React Flow needs top-left
position: {
x: pos.x - width / 2, // Not just pos.x
y: pos.y - height / 2, // Not just pos.y
}
```
### Performance with Large Graphs
- Layout in a Web Worker
- Debounce layout calls
- Memoize the layout result (for example with `useMemo`) so it is recomputed only when nodes, edges, or options change
- Only re-layout changed portions
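One way to avoid redundant layout runs is to key the result on everything that affects the layout and reuse the previous result when the key is unchanged. The sketch below is framework-independent; `layoutSignature`, `memoizeLayout`, and the input shapes are hypothetical, and it caches only the most recent result.

```typescript
type LayoutInputNode = { id: string; width: number; height: number };
type LayoutInputEdge = { source: string; target: string };

// Builds a key from direction, node sizes, and edges. Node positions are
// left out on purpose, so dragging a node does not invalidate the cache.
function layoutSignature(
  nodes: LayoutInputNode[],
  edges: LayoutInputEdge[],
  direction: string
): string {
  const nodePart = nodes.map((n) => `${n.id}:${n.width}x${n.height}`).sort().join('|');
  const edgePart = edges.map((e) => `${e.source}>${e.target}`).sort().join('|');
  return `${direction}#${nodePart}#${edgePart}`;
}

// Wraps any layout function so identical inputs reuse the previous result.
function memoizeLayout<R>(
  layoutFn: (nodes: LayoutInputNode[], edges: LayoutInputEdge[], direction: string) => R
) {
  let lastKey: string | undefined;
  let lastResult: R | undefined;
  return (nodes: LayoutInputNode[], edges: LayoutInputEdge[], direction: string): R => {
    const key = layoutSignature(nodes, edges, direction);
    let result = lastResult;
    if (key !== lastKey || result === undefined) {
      result = layoutFn(nodes, edges, direction);
      lastKey = key;
      lastResult = result;
    }
    return result;
  };
}

// Stand-in layout function that counts how often it actually runs.
let layoutRuns = 0;
const cachedLayout = memoizeLayout((nodes) => {
  layoutRuns += 1;
  return nodes.map((n, i) => ({ id: n.id, x: 0, y: i * 100 }));
});
const sampleNodes = [
  { id: 'a', width: 172, height: 36 },
  { id: 'b', width: 172, height: 36 },
];
const sampleEdges = [{ source: 'a', target: 'b' }];
cachedLayout(sampleNodes, sampleEdges, 'TB');
cachedLayout([...sampleNodes], [...sampleEdges], 'TB'); // same content, new arrays: cache hit
cachedLayout(sampleNodes, sampleEdges, 'LR'); // direction changed: recomputed
```

Building the signature is linear in graph size, which is usually far cheaper than a dagre run, but that trade-off is worth measuring on very large graphs.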
## Configuration Reference
See [reference.md](reference.md) for complete dagre configuration options.
FILE:reference.md
# Dagre Configuration Reference
Complete configuration options for dagre layout algorithm.
## Graph-Level Options
Set via `g.setGraph(options)`:
| Option | Default | Description |
|--------|---------|-------------|
| `rankdir` | `'TB'` | Layout direction: `'TB'` (top-bottom), `'BT'` (bottom-top), `'LR'` (left-right), `'RL'` (right-left) |
| `align` | `undefined` | Node alignment within rank: `'UL'`, `'UR'`, `'DL'`, `'DR'`. `U`=up, `D`=down, `L`=left, `R`=right |
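| `nodesep` | `50` | Spacing between nodes in the same rank, in pixels (horizontal for `TB`/`BT`, vertical for `LR`/`RL`) |
| `edgesep` | `10` | Spacing between edges within a rank, in pixels (horizontal for `TB`/`BT`) |
| `ranksep` | `50` | Spacing between ranks, in pixels (vertical for `TB`/`BT`, horizontal for `LR`/`RL`) |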
| `marginx` | `0` | Horizontal margin around graph (pixels) |
| `marginy` | `0` | Vertical margin around graph (pixels) |
| `acyclicer` | `undefined` | Set to `'greedy'` for greedy cycle removal heuristic |
| `ranker` | `'network-simplex'` | Rank assignment algorithm: `'network-simplex'`, `'tight-tree'`, `'longest-path'` |
### Example
```typescript
g.setGraph({
rankdir: 'LR', // Horizontal layout
align: 'UL', // Align nodes to upper-left
nodesep: 80, // 80px horizontal spacing
ranksep: 100, // 100px between ranks
marginx: 20, // 20px horizontal margin
marginy: 20, // 20px vertical margin
ranker: 'tight-tree', // Faster ranking algorithm
});
```
### Ranker Algorithms
| Algorithm | Speed | Quality | Use Case |
|-----------|-------|---------|----------|
| `network-simplex` | Slower | Best | Default, optimal for most graphs |
| `tight-tree` | Fast | Good | Large graphs where speed matters |
| `longest-path` | Fastest | Acceptable | Very large graphs, quick preview |
## Node-Level Options
Set via `g.setNode(nodeId, options)`:
| Option | Default | Description |
|--------|---------|-------------|
| `width` | `0` | Node width in pixels (required for layout) |
| `height` | `0` | Node height in pixels (required for layout) |
### Output Properties
After `dagre.layout(g)`, each node gains:
| Property | Description |
|----------|-------------|
| `x` | Center x-coordinate |
| `y` | Center y-coordinate |
### Example
```typescript
// Setting node dimensions
g.setNode('node-1', { width: 200, height: 50 });
// After layout, reading position
dagre.layout(g);
const { x, y } = g.node('node-1'); // Center coordinates
```
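dagre reports node centers, while React Flow (with its default node origin) positions a node by its top-left corner, so the coordinates usually need shifting by half the node size. A minimal sketch of that conversion follows; `toTopLeft` and `LaidOutNode` are hypothetical names.

```typescript
interface LaidOutNode {
  x: number; // center x reported by dagre
  y: number; // center y reported by dagre
  width: number;
  height: number;
}

// Shift a center-anchored point to the top-left corner of the node box.
const toTopLeft = (n: LaidOutNode): { x: number; y: number } => ({
  x: n.x - n.width / 2,
  y: n.y - n.height / 2,
});

// A 200x50 node centered at (300, 125) has its top-left corner at (200, 100).
const topLeft = toTopLeft({ x: 300, y: 125, width: 200, height: 50 });
```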
## Edge-Level Options
Set via `g.setEdge(source, target, options)`:
| Option | Default | Description |
|--------|---------|-------------|
| `minlen` | `1` | Minimum number of ranks between source and target |
| `weight` | `1` | Edge weight; higher-weight edges are generally made shorter and straighter |
| `width` | `0` | Edge label width in pixels |
| `height` | `0` | Edge label height in pixels |
| `labelpos` | `'r'` | Label position: `'l'` (left), `'c'` (center), `'r'` (right) |
| `labeloffset` | `10` | Pixels to offset label from edge |
### Output Properties
After `dagre.layout(g)`, each edge gains:
| Property | Description |
|----------|-------------|
| `points` | Array of `{x, y}` control points for edge path |
| `x` | Label center x-coordinate (if label dimensions set) |
| `y` | Label center y-coordinate (if label dimensions set) |
### Example
```typescript
// High priority edge (shorter path)
g.setEdge('a', 'b', { weight: 2 });
// Force separation of 3 ranks
g.setEdge('a', 'c', { minlen: 3 });
// Edge with label
g.setEdge('a', 'd', {
width: 50,
height: 20,
labelpos: 'c',
});
// After layout
dagre.layout(g);
const edge = g.edge('a', 'b');
console.log(edge.points); // [{x: 0, y: 0}, {x: 50, y: 50}, ...]
```
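For custom edge rendering, the `points` array can be turned into an SVG path. The sketch below joins the points with straight segments. dagre does not prescribe how to draw them, so a curve could be substituted, and `pointsToPath` is a hypothetical helper.

```typescript
type EdgePoint = { x: number; y: number };

// Build an SVG path: move to the first point, then draw a line to each next one.
const pointsToPath = (points: EdgePoint[]): string =>
  points
    .map((p, i) => `${i === 0 ? 'M' : 'L'} ${p.x} ${p.y}`)
    .join(' ');

const pathData = pointsToPath([
  { x: 0, y: 0 },
  { x: 50, y: 25 },
  { x: 100, y: 25 },
]);
// pathData is 'M 0 0 L 50 25 L 100 25'
```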
## Graph Methods
### Reading Graph State
```typescript
// Get all node IDs
const nodeIds = g.nodes(); // ['a', 'b', 'c']
// Get all edges
const edges = g.edges(); // [{v: 'a', w: 'b'}, ...]
// Check if node exists
g.hasNode('a'); // true/false
// Check if edge exists
g.hasEdge('a', 'b'); // true/false
// Get node data
g.node('a'); // { width: 100, height: 50, x: 200, y: 100 }
// Get edge data
g.edge('a', 'b'); // { points: [...], weight: 1 }
```
### Modifying Graph
```typescript
// Remove node (also removes connected edges)
g.removeNode('a');
// Remove edge
g.removeEdge('a', 'b');
```
### Traversing Graph
```typescript
// Get predecessors (nodes with edges TO this node)
g.predecessors('b'); // ['a']
// Get successors (nodes with edges FROM this node)
g.successors('a'); // ['b', 'c']
// Get all connected nodes (in + out)
g.neighbors('a'); // ['b', 'c', 'd']
```
## TypeScript Types
```typescript
import dagre from '@dagrejs/dagre';
interface GraphOptions {
rankdir?: 'TB' | 'BT' | 'LR' | 'RL';
align?: 'UL' | 'UR' | 'DL' | 'DR';
nodesep?: number;
edgesep?: number;
ranksep?: number;
marginx?: number;
marginy?: number;
acyclicer?: 'greedy';
ranker?: 'network-simplex' | 'tight-tree' | 'longest-path';
}
interface NodeOptions {
width: number;
height: number;
}
interface NodeOutput extends NodeOptions {
x: number;
y: number;
}
interface EdgeOptions {
minlen?: number;
weight?: number;
width?: number;
height?: number;
labelpos?: 'l' | 'c' | 'r';
labeloffset?: number;
}
interface EdgeOutput extends EdgeOptions {
points: Array<{ x: number; y: number }>;
x?: number; // If label dimensions set
y?: number; // If label dimensions set
}
```
## Performance Considerations
### Graph Size Guidelines
| Nodes | Performance | Recommendation |
|-------|-------------|----------------|
| < 100 | Fast | Use `network-simplex` |
| 100-500 | Moderate | Consider `tight-tree` |
| 500-1000 | Slow | Use `longest-path`, layout in worker |
| > 1000 | Very slow | Virtualize, paginate, or use WebGL renderer |
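The table can be folded into a small helper that picks a ranker from the node count. The thresholds are the ones listed above and are rough guidance rather than measured limits. `pickRanker` is a hypothetical name, and treating exactly 500 nodes as `tight-tree` is an arbitrary choice.

```typescript
type Ranker = 'network-simplex' | 'tight-tree' | 'longest-path';

// Choose the cheapest ranker that the size guidelines suggest.
// Above roughly 1000 nodes the table recommends changing the rendering
// strategy as well; a faster ranker alone will not fix that case.
const pickRanker = (nodeCount: number): Ranker => {
  if (nodeCount < 100) return 'network-simplex';
  if (nodeCount <= 500) return 'tight-tree';
  return 'longest-path';
};
```

The result can be passed as the `ranker` option to `g.setGraph(...)`.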
### Optimization Tips
1. **Reuse graph instance** when only positions change
2. **Layout in Web Worker** for graphs > 200 nodes
3. **Debounce layout calls** during rapid changes
4. **Cache layout results** for static portions
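Tip 4 can be as simple as memoizing on a key that captures everything the layout depends on. The sketch below assumes the layout depends only on node ids, node sizes, edge endpoints, and direction. `layoutKey` and `cachedLayout` are hypothetical helpers, and `compute` stands in for whatever function actually runs dagre.

```typescript
interface SizedNode { id: string; width: number; height: number }
interface LinkLike { source: string; target: string }
type Positions = Record<string, { x: number; y: number }>;

// Build a key from the layout inputs only, so position changes caused by
// dragging do not invalidate the cache.
const layoutKey = (nodes: SizedNode[], edges: LinkLike[], direction: string): string =>
  JSON.stringify([
    direction,
    nodes.map((n) => [n.id, n.width, n.height]),
    edges.map((e) => [e.source, e.target]),
  ]);

const layoutCache = new Map<string, Positions>();

// Return the cached result for identical inputs, otherwise compute and store it.
const cachedLayout = (
  nodes: SizedNode[],
  edges: LinkLike[],
  direction: string,
  compute: () => Positions,
): Positions => {
  const key = layoutKey(nodes, edges, direction);
  const hit = layoutCache.get(key);
  if (hit) return hit;
  const result = compute();
  layoutCache.set(key, result);
  return result;
};
```

The cache here is unbounded; a long-lived editor would likely want to cap its size or clear it when a document closes.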
### Web Worker Example
```typescript
// layout.worker.ts
import dagre from '@dagrejs/dagre';
self.onmessage = (event: MessageEvent) => {
  const { nodes, edges, options } = event.data;
  const g = new dagre.graphlib.Graph();
  g.setGraph(options);
  g.setDefaultEdgeLabel(() => ({}));
  nodes.forEach((n) => g.setNode(n.id, { width: n.width, height: n.height }));
  // Named `edge` so it does not shadow the message event
  edges.forEach((edge) => g.setEdge(edge.source, edge.target));
  dagre.layout(g);
  // dagre reports the center coordinates of each node
  const positions = nodes.map((n) => ({
    id: n.id,
    x: g.node(n.id).x,
    y: g.node(n.id).y,
  }));
  self.postMessage({ positions });
};
```
## Comparison with Alternatives
| Library | Best For | Bundle Size | Async |
|---------|----------|-------------|-------|
| **dagre** | Trees, hierarchies | ~30KB | No |
| **elkjs** | Complex constraints | ~150KB | Yes |
| **d3-hierarchy** | Pure trees only | ~10KB | No |
| **d3-force** | Organic layouts | ~15KB | Iterative |
Choose dagre when:
- Graph is hierarchical/tree-like
- Need simple, fast layouts
- Bundle size matters
- Don't need edge routing around nodes
## Animated Layout Transitions
Programmatic animation for smooth position changes:
```typescript
const animateLayout = (
currentNodes: Node[],
newNodes: Node[],
setNodes: (nodes: Node[]) => void,
duration = 300
) => {
const startPositions = new Map(
currentNodes.map((n) => [n.id, { ...n.position }])
);
const animate = (progress: number) => {
const interpolated = newNodes.map((node) => {
const start = startPositions.get(node.id);
if (!start) return node;
return {
...node,
position: {
x: start.x + (node.position.x - start.x) * progress,
y: start.y + (node.position.y - start.y) * progress,
},
};
});
setNodes(interpolated);
};
const startTime = Date.now();
const tick = () => {
const elapsed = Date.now() - startTime;
const progress = Math.min(elapsed / duration, 1);
// Ease-out curve
const eased = 1 - Math.pow(1 - progress, 3);
animate(eased);
if (progress < 1) requestAnimationFrame(tick);
};
tick();
};
// Usage
const onLayout = (direction: 'TB' | 'LR') => {
const { nodes: layouted } = getLayoutedElements(nodes, edges, { direction });
animateLayout(nodes, layouted, setNodes, 400);
};
```
Vercel AI Elements for workflow UI components. Use when building chat interfaces, displaying tool execution, showing reasoning/thinking, or creating job queu...
---
name: ai-elements
description: Vercel AI Elements for workflow UI components. Use when building chat interfaces, displaying tool execution, showing reasoning/thinking, or creating job queues. Triggers on ai-elements, Queue, Confirmation, Tool, Reasoning, Shimmer, Loader, Message, Conversation, PromptInput.
---
# AI Elements
AI Elements is a comprehensive React component library for building AI-powered user interfaces. The library provides 30+ components specifically designed for chat interfaces, tool execution visualization, reasoning displays, and workflow management.
## Installation
Install via shadcn registry:
```bash
npx shadcn@latest add https://ai-elements.vercel.app/r/[component-name]
```
**Import Pattern**: Components are imported from individual files, not a barrel export:
```tsx
// Correct - import from specific files
import { Conversation } from "@/components/ai-elements/conversation";
import { Message } from "@/components/ai-elements/message";
import { PromptInput } from "@/components/ai-elements/prompt-input";
// Incorrect - no barrel export
import { Conversation, Message } from "@/components/ai-elements";
```
## Gates (before relying on examples)
Use this sequence when adding or wiring AI Elements so setup is checkable, not assumed.
1. **Install each component** — Run `npx shadcn@latest add https://ai-elements.vercel.app/r/[component-name]` for every component you need.
- **Pass:** The command completes successfully and a file for that component exists under your project’s `components/ai-elements/` (or the directory `components.json` uses for those additions).
2. **Align import paths** — Every `import … from "@/components/ai-elements/..."` must match your repo’s actual alias and folder layout.
- **Pass:** Each import resolves to a file on disk (IDE navigation or build/`tsc` shows no “cannot find module” for those paths).
3. **Match `Tool` / `Confirmation` states to your AI SDK** — State strings (including approval-related states) depend on the installed `ai` package major/version.
- **Pass:** The states you pass to `ToolHeader`, `Tool`, or `Confirmation` are listed in the AI SDK version you have installed (docs or exported types), not copied from memory alone.
## Component Categories
### Conversation Components
Components for displaying chat-style interfaces with messages, attachments, and auto-scrolling behavior.
- **Conversation**: Container with auto-scroll capabilities
- **Message**: Individual message display with role-based styling
- **MessageAttachment**: File and image attachments
- **MessageBranch**: Alternative response navigation
See [references/conversation.md](references/conversation.md) for details.
### Prompt Input Components
Advanced text input with file attachments, drag-and-drop, speech input, and state management.
- **PromptInput**: Form container with file handling
- **PromptInputTextarea**: Auto-expanding textarea
- **PromptInputSubmit**: Status-aware submit button
- **PromptInputAttachments**: File attachment display
- **PromptInputProvider**: Global state management
See [references/prompt-input.md](references/prompt-input.md) for details.
### Workflow Components
Components for displaying job queues, tool execution, and approval workflows.
- **Queue**: Job queue container
- **QueueItem**: Individual queue items with status
- **Tool**: Tool execution display with collapsible states
- **Confirmation**: Approval workflow component
- **Reasoning**: Collapsible thinking/reasoning display
See [references/workflow.md](references/workflow.md) for details.
### Visualization Components
ReactFlow-based components for workflow visualization and custom node types.
- **Canvas**: ReactFlow wrapper with preconfigured defaults
- **Node**: Custom node component with handles
- **Edge**: Temporary and Animated edge types
- **Controls, Panel, Toolbar**: Navigation and control elements
See [references/visualization.md](references/visualization.md) for details.
## Integration with shadcn/ui
AI Elements is built on top of shadcn/ui and integrates seamlessly with its theming system:
- Uses shadcn/ui's design tokens (colors, spacing, typography)
- Respects light/dark mode via CSS variables
- Compatible with shadcn/ui components (Button, Card, Collapsible, etc.)
- Follows shadcn/ui's component composition patterns
## Key Design Patterns
### Component Composition
AI Elements follows a composition-first approach where larger components are built from smaller primitives:
```tsx
<Tool>
<ToolHeader title="search" type="tool-call-search" state="output-available" />
<ToolContent>
<ToolInput input={{ query: "AI tools" }} />
<ToolOutput output={results} errorText={undefined} />
</ToolContent>
</Tool>
```
### Context-Based State
Many components use React Context for state management:
- `PromptInputProvider` for global input state
- `MessageBranch` for alternative response navigation
- `Confirmation` for approval workflow state
- `Reasoning` for collapsible thinking state
### Controlled vs Uncontrolled
Components support both controlled and uncontrolled patterns:
```tsx
// Uncontrolled (self-managed state)
<PromptInput onSubmit={handleSubmit} />
// Controlled (external state)
<PromptInputProvider initialInput="">
<PromptInput onSubmit={handleSubmit} />
</PromptInputProvider>
```
## Tool State Machine
The Tool component follows the Vercel AI SDK's state machine:
1. `input-streaming`: Parameters being received
2. `input-available`: Ready to execute
3. `approval-requested`: Awaiting user approval (SDK v6)
4. `approval-responded`: User responded (SDK v6)
5. `output-available`: Execution completed
6. `output-error`: Execution failed
7. `output-denied`: Approval denied
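A UI often only needs to know whether a tool call is still running, waiting on the user, or finished. The sketch below groups the seven states listed above into coarser phases. The `ToolState` union simply mirrors this list and should be checked against the types exported by your installed `ai` version. `toolPhase` is a hypothetical helper, and treating `approval-responded` as still running is an assumption.

```typescript
type ToolState =
  | 'input-streaming'
  | 'input-available'
  | 'approval-requested'
  | 'approval-responded'
  | 'output-available'
  | 'output-error'
  | 'output-denied';

type ToolPhase = 'running' | 'awaiting-approval' | 'done' | 'failed';

// Collapse the fine-grained states into the phases a status badge might show.
const toolPhase = (state: ToolState): ToolPhase => {
  switch (state) {
    case 'input-streaming':
    case 'input-available':
    case 'approval-responded':
      return 'running';
    case 'approval-requested':
      return 'awaiting-approval';
    case 'output-available':
      return 'done';
    case 'output-error':
    case 'output-denied':
      return 'failed';
  }
};
```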
## Queue Patterns
Queue components support hierarchical organization:
```tsx
<Queue>
<QueueSection defaultOpen={true}>
<QueueSectionTrigger>
<QueueSectionLabel count={3} label="tasks" icon={<Icon />} />
</QueueSectionTrigger>
<QueueSectionContent>
<QueueList>
<QueueItem>
<QueueItemIndicator completed={false} />
<QueueItemContent>Task description</QueueItemContent>
</QueueItem>
</QueueList>
</QueueSectionContent>
</QueueSection>
</Queue>
```
## Auto-Scroll Behavior
The Conversation component uses the `use-stick-to-bottom` hook for intelligent auto-scrolling:
- Automatically scrolls to bottom when new messages arrive
- Pauses auto-scroll when user scrolls up
- Provides scroll-to-bottom button when not at bottom
- Supports smooth and instant scroll modes
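The core decision behind this behavior is whether the viewport is close enough to the bottom to keep following new content. The sketch below is a simplified version of that check, not the `use-stick-to-bottom` implementation. The 16px threshold and the `isNearBottom` name are assumptions.

```typescript
interface ScrollMetrics {
  scrollTop: number; // pixels scrolled from the top
  scrollHeight: number; // total content height
  clientHeight: number; // visible viewport height
}

// True when the remaining distance to the bottom is within the threshold,
// meaning new messages should keep the view pinned to the bottom.
const isNearBottom = (m: ScrollMetrics, threshold = 16): boolean =>
  m.scrollHeight - m.scrollTop - m.clientHeight <= threshold;
```

When the check returns `false`, the user has scrolled up, which is the point at which auto-scroll pauses and a scroll-to-bottom button becomes useful.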
## File Attachment Handling
PromptInput provides comprehensive file handling:
- Drag-and-drop support (local or global)
- Paste image/file support
- File type validation (accept prop)
- File size limits (maxFileSize prop)
- Maximum file count (maxFiles prop)
- Preview for images, icons for files
- Automatic blob URL to data URL conversion on submit
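The three validation limits can be pictured as a single check that returns the first violated rule. The sketch below is illustrative rather than the component's internal logic. The error codes match the `onError` type shown in the prompt-input reference, while `FileLike`, `matchesAccept`, and `validateFiles` are hypothetical names.

```typescript
interface FileLike { name: string; type: string; size: number }
type FileErrorCode = 'max_files' | 'max_file_size' | 'accept';
interface FileLimits { accept?: string; maxFiles?: number; maxFileSize?: number }

// Match a MIME type against a comma-separated accept string such as
// "image/*,application/pdf". Extension patterns (".png") are not handled here.
const matchesAccept = (mime: string, accept: string): boolean =>
  accept
    .split(',')
    .map((pattern) => pattern.trim())
    .some((pattern) =>
      pattern.endsWith('/*')
        ? mime.startsWith(pattern.slice(0, -1))
        : mime === pattern,
    );

// Return the first violated limit, or null when every file passes.
const validateFiles = (
  incoming: FileLike[],
  existingCount: number,
  limits: FileLimits,
): FileErrorCode | null => {
  const { accept, maxFiles, maxFileSize } = limits;
  if (accept && incoming.some((f) => !matchesAccept(f.type, accept))) return 'accept';
  if (maxFileSize !== undefined && incoming.some((f) => f.size > maxFileSize)) return 'max_file_size';
  if (maxFiles !== undefined && existingCount + incoming.length > maxFiles) return 'max_files';
  return null;
};
```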
## Speech Input
The PromptInputSpeechButton uses the Web Speech API for voice input:
- Browser-based speech recognition
- Continuous recognition mode
- Interim results support
- Automatic text insertion into textarea
- Visual feedback during recording
## Reasoning Auto-Collapse
The Reasoning component provides auto-collapse behavior:
- Opens automatically when streaming starts
- Closes 1 second after streaming ends
- Tracks thinking duration in seconds
- Displays "Thinking..." with shimmer effect during streaming
- Shows "Thought for N seconds" when complete
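The duration label can be derived from two timestamps. The sketch below shows that arithmetic, assuming millisecond timestamps and rounding up to whole seconds. `thinkingLabel` is a hypothetical helper, not the component's internal code.

```typescript
// While streaming there is no end time yet, so show the in-progress label.
// Once streaming ends, report the elapsed time in whole seconds.
const thinkingLabel = (startedAtMs: number, endedAtMs: number | null): string => {
  if (endedAtMs === null) return 'Thinking...';
  const seconds = Math.ceil((endedAtMs - startedAtMs) / 1000);
  return `Thought for ${seconds} seconds`;
};
```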
## TypeScript Types
All components are fully typed with TypeScript:
```typescript
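import type { ComponentProps, HTMLAttributes } from "react";
import type { ToolUIPart, FileUIPart, UIMessage } from "ai";
import { Collapsible } from "@/components/ui/collapsible";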
type ToolProps = ComponentProps<typeof Collapsible>;
type QueueItemProps = ComponentProps<"li">;
type MessageAttachmentProps = HTMLAttributes<HTMLDivElement> & {
data: FileUIPart;
onRemove?: () => void;
};
```
## Common Use Cases
### Chat Interface
Combine Conversation, Message, and PromptInput for a complete chat UI:
```tsx
import { Conversation, ConversationContent, ConversationScrollButton } from "@/components/ai-elements/conversation";
import { Message, MessageContent, MessageResponse } from "@/components/ai-elements/message";
import {
PromptInput,
PromptInputTextarea,
PromptInputFooter,
PromptInputTools,
PromptInputButton,
PromptInputSubmit
} from "@/components/ai-elements/prompt-input";
<div className="flex flex-col h-screen">
<Conversation>
<ConversationContent>
{messages.map(msg => (
<Message key={msg.id} from={msg.role}>
<MessageContent>
<MessageResponse>{msg.content}</MessageResponse>
</MessageContent>
</Message>
))}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
<PromptInput onSubmit={handleSubmit}>
<PromptInputTextarea />
<PromptInputFooter>
<PromptInputTools>
<PromptInputButton onClick={() => attachments.openFileDialog()}>
<PaperclipIcon />
</PromptInputButton>
</PromptInputTools>
<PromptInputSubmit status={chatStatus} />
</PromptInputFooter>
</PromptInput>
</div>
```
### Tool Execution Display
Show tool execution with expandable details:
```tsx
import { Tool, ToolHeader, ToolContent, ToolInput, ToolOutput } from "@/components/ai-elements/tool";
{toolInvocations.map(tool => (
<Tool key={tool.id}>
<ToolHeader
title={tool.toolName}
type={`tool-call-${tool.toolName}`}
state={tool.state}
/>
<ToolContent>
<ToolInput input={tool.args} />
{(tool.result || tool.error) && (
<ToolOutput output={tool.result} errorText={tool.error} />
)}
</ToolContent>
</Tool>
))}
```
### Approval Workflow
Request user confirmation before executing actions:
```tsx
import {
Confirmation,
ConfirmationTitle,
ConfirmationRequest,
ConfirmationActions,
ConfirmationAction,
ConfirmationAccepted,
ConfirmationRejected
} from "@/components/ai-elements/confirmation";
<Confirmation approval={tool.approval} state={tool.state}>
<ConfirmationTitle>
Approve deletion of {resource}?
</ConfirmationTitle>
<ConfirmationRequest>
<ConfirmationActions>
<ConfirmationAction onClick={approve} variant="default">
Approve
</ConfirmationAction>
<ConfirmationAction onClick={reject} variant="outline">
Reject
</ConfirmationAction>
</ConfirmationActions>
</ConfirmationRequest>
<ConfirmationAccepted>
Action approved and executed.
</ConfirmationAccepted>
<ConfirmationRejected>
Action rejected.
</ConfirmationRejected>
</Confirmation>
```
### Job Queue Management
Display task lists with completion status:
```tsx
import {
Queue,
QueueSection,
QueueSectionTrigger,
QueueSectionLabel,
QueueSectionContent,
QueueList,
QueueItem,
QueueItemIndicator,
QueueItemContent,
QueueItemDescription
} from "@/components/ai-elements/queue";
<Queue>
<QueueSection>
<QueueSectionTrigger>
<QueueSectionLabel count={todos.length} label="todos" />
</QueueSectionTrigger>
<QueueSectionContent>
<QueueList>
{todos.map(todo => (
<QueueItem key={todo.id}>
<QueueItemIndicator completed={todo.status === 'completed'} />
<QueueItemContent completed={todo.status === 'completed'}>
{todo.title}
</QueueItemContent>
{todo.description && (
<QueueItemDescription completed={todo.status === 'completed'}>
{todo.description}
</QueueItemDescription>
)}
</QueueItem>
))}
</QueueList>
</QueueSectionContent>
</QueueSection>
</Queue>
```
## Accessibility
Components include accessibility features:
- ARIA labels and roles
- Keyboard navigation support
- Screen reader announcements
- Focus management
- Semantic HTML elements
## Animation
Many components use Framer Motion for smooth animations:
- Shimmer effect for loading states
- Collapsible content transitions
- Edge animations in Canvas
- Loader spinner rotation
## References
- [Conversation Components](references/conversation.md)
- [Prompt Input Components](references/prompt-input.md)
- [Workflow Components](references/workflow.md)
- [Visualization Components](references/visualization.md)
FILE:references/conversation.md
# Conversation Components
Components for building chat-style interfaces with messages, attachments, and intelligent auto-scrolling.
## Core Components
### Conversation
Container component that wraps the entire conversation area with auto-scroll functionality.
```typescript
type ConversationProps = ComponentProps<typeof StickToBottom>;
```
**Props:**
- `className?: string` - Additional CSS classes
- `initial?: "smooth" | "auto"` - Initial scroll behavior (default: "smooth")
- `resize?: "smooth" | "auto"` - Scroll behavior on resize (default: "smooth")
**Usage:**
```tsx
<Conversation className="flex-1 overflow-y-hidden">
<ConversationContent>
{/* Messages go here */}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
```
**Features:**
- Uses `use-stick-to-bottom` for intelligent scrolling
- Automatically scrolls to bottom when new messages arrive
- Pauses auto-scroll when user scrolls up manually
- Provides context for scroll state to child components
- Sets `role="log"` for accessibility
### ConversationContent
Content area for messages within the conversation.
```typescript
type ConversationContentProps = ComponentProps<typeof StickToBottom.Content>;
```
**Usage:**
```tsx
<ConversationContent className="flex flex-col gap-8 p-4">
{messages.map(message => (
<Message key={message.id} from={message.role}>
{/* Message content */}
</Message>
))}
</ConversationContent>
```
**Default Styling:**
- Flexbox column layout with gap
- Padding for content separation
### ConversationEmptyState
Placeholder shown when there are no messages.
```typescript
type ConversationEmptyStateProps = ComponentProps<"div"> & {
title?: string;
description?: string;
icon?: React.ReactNode;
};
```
**Props:**
- `title?: string` - Heading text (default: "No messages yet")
- `description?: string` - Descriptive text (default: "Start a conversation to see messages here")
- `icon?: React.ReactNode` - Icon to display above text
- `children?: React.ReactNode` - Custom content (overrides default)
**Usage:**
```tsx
{messages.length === 0 ? (
<ConversationEmptyState
title="Welcome!"
description="Ask me anything to get started"
icon={<MessageSquareIcon className="size-12" />}
/>
) : (
<ConversationContent>
{/* Messages */}
</ConversationContent>
)}
```
### ConversationScrollButton
Button that appears when user is not at the bottom of the conversation, allowing quick navigation to latest messages.
```typescript
type ConversationScrollButtonProps = ComponentProps<typeof Button>;
```
**Usage:**
```tsx
<Conversation>
<ConversationContent>
{/* Messages */}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
```
**Behavior:**
- Only visible when `isAtBottom` is false
- Positioned at bottom center of conversation
- Calls `scrollToBottom()` on click
- Uses `ArrowDownIcon` by default
## Message Components
### Message
Container for an individual message with role-based styling.
```typescript
type MessageProps = HTMLAttributes<HTMLDivElement> & {
from: UIMessage["role"]; // "system" | "user" | "assistant"
};
```
**Props:**
- `from: UIMessage["role"]` - Message sender role (`"user"`, `"assistant"`, or `"system"`)
- `className?: string` - Additional CSS classes
- Standard HTML div attributes
**Usage:**
```tsx
<Message from="assistant">
<MessageContent>
<MessageResponse>{content}</MessageResponse>
</MessageContent>
<MessageActions>
<MessageAction tooltip="Copy" onClick={copyToClipboard}>
<CopyIcon />
</MessageAction>
</MessageActions>
</Message>
```
**Styling:**
- User messages: right-aligned, max-width 80%
- Assistant messages: left-aligned, max-width 80%
- Adds `is-user` or `is-assistant` class for context-specific styling
### MessageContent
Content area for message text and media.
```typescript
type MessageContentProps = HTMLAttributes<HTMLDivElement>;
```
**Usage:**
```tsx
<MessageContent>
<MessageResponse>{text}</MessageResponse>
</MessageContent>
```
**Styling:**
- User messages: rounded background with secondary color
- Assistant messages: plain text styling
- Flexbox column layout for multiple content types
### MessageResponse
Renders markdown/text content with streaming support.
```typescript
type MessageResponseProps = ComponentProps<typeof Streamdown>;
```
**Usage:**
```tsx
<MessageResponse>
{message.content}
</MessageResponse>
```
**Features:**
- Uses `Streamdown` for markdown rendering
- Memoized to prevent unnecessary re-renders
- Supports streaming text updates
- Removes default margin from first/last children
### MessageActions
Container for action buttons (copy, edit, regenerate, etc.).
```typescript
type MessageActionsProps = ComponentProps<"div">;
```
**Usage:**
```tsx
<MessageActions>
<MessageAction tooltip="Copy" onClick={handleCopy}>
<CopyIcon />
</MessageAction>
<MessageAction tooltip="Regenerate" onClick={handleRegenerate}>
<RefreshIcon />
</MessageAction>
</MessageActions>
```
### MessageAction
Individual action button with optional tooltip.
```typescript
type MessageActionProps = ComponentProps<typeof Button> & {
tooltip?: string;
label?: string;
};
```
**Props:**
- `tooltip?: string` - Tooltip text shown on hover
- `label?: string` - Accessible label (falls back to tooltip)
- All Button component props
**Usage:**
```tsx
<MessageAction
tooltip="Copy to clipboard"
onClick={handleCopy}
variant="ghost"
size="icon-sm"
>
<CopyIcon className="size-4" />
</MessageAction>
```
## Message Branching
### MessageBranch
Container for managing alternative message responses with navigation.
```typescript
type MessageBranchProps = HTMLAttributes<HTMLDivElement> & {
defaultBranch?: number;
onBranchChange?: (branchIndex: number) => void;
};
```
**Props:**
- `defaultBranch?: number` - Initial branch index (default: 0)
- `onBranchChange?: (index: number) => void` - Callback when branch changes
**Usage:**
```tsx
<MessageBranch defaultBranch={0} onBranchChange={handleBranchChange}>
<MessageBranchContent>
<MessageResponse key="1">{response1}</MessageResponse>
<MessageResponse key="2">{response2}</MessageResponse>
<MessageResponse key="3">{response3}</MessageResponse>
</MessageBranchContent>
<MessageBranchSelector from="assistant">
<MessageBranchPrevious />
<MessageBranchPage />
<MessageBranchNext />
</MessageBranchSelector>
</MessageBranch>
```
**Context:**
Provides context with:
- `currentBranch: number` - Current branch index
- `totalBranches: number` - Total number of branches
- `goToPrevious: () => void` - Navigate to previous branch
- `goToNext: () => void` - Navigate to next branch
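The wrap-around navigation exposed by `goToPrevious` and `goToNext` comes down to modular arithmetic on the branch index. The sketch below assumes zero-based indices and at least one branch. `previousBranch` and `nextBranch` are hypothetical names, not the component's internals.

```typescript
// Previous from the first branch (index 0) lands on the last one.
const previousBranch = (current: number, total: number): number =>
  (current - 1 + total) % total;

// Next from the last branch lands back on the first one (index 0).
const nextBranch = (current: number, total: number): number =>
  (current + 1) % total;
```

Adding `total` before taking the remainder keeps the result non-negative, since `%` in JavaScript preserves the sign of the left operand.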
### MessageBranchContent
Displays the current branch content, hiding others.
```typescript
type MessageBranchContentProps = HTMLAttributes<HTMLDivElement>;
```
**Behavior:**
- Automatically manages branch visibility
- Updates when children change
- Preserves all branches in DOM (display: none for hidden)
### MessageBranchSelector
Container for branch navigation controls.
```typescript
type MessageBranchSelectorProps = HTMLAttributes<HTMLDivElement> & {
from: UIMessage["role"];
};
```
**Behavior:**
- Only renders if `totalBranches > 1`
- Uses ButtonGroup for grouped appearance
### MessageBranchPrevious
Button to navigate to previous branch.
```typescript
type MessageBranchPreviousProps = ComponentProps<typeof Button>;
```
**Behavior:**
- Wraps around (first branch → last branch)
- Disabled if only one branch exists
- Default icon: `ChevronLeftIcon`
### MessageBranchNext
Button to navigate to next branch.
```typescript
type MessageBranchNextProps = ComponentProps<typeof Button>;
```
**Behavior:**
- Wraps around (last branch → first branch)
- Disabled if only one branch exists
- Default icon: `ChevronRightIcon`
### MessageBranchPage
Displays current branch number and total.
```typescript
type MessageBranchPageProps = HTMLAttributes<HTMLSpanElement>;
```
**Display:**
Shows "1 of 3", "2 of 3", etc.
## Attachment Components
### MessageAttachment
Displays a file or image attachment with optional remove button.
```typescript
type MessageAttachmentProps = HTMLAttributes<HTMLDivElement> & {
data: FileUIPart;
className?: string;
onRemove?: () => void;
};
```
**Props:**
- `data: FileUIPart` - Attachment data (url, filename, mediaType)
- `onRemove?: () => void` - Callback to remove attachment
**Usage:**
```tsx
<MessageAttachment
data={{
type: "file",
url: "blob:...",
filename: "document.pdf",
mediaType: "application/pdf"
}}
onRemove={() => removeAttachment(id)}
/>
```
**Behavior:**
- Images: Shows thumbnail preview
- Files: Shows paperclip icon
- Hover: Shows remove button (if onRemove provided)
- Tooltip: Displays filename on non-image files
### MessageAttachments
Container for multiple attachments.
```typescript
type MessageAttachmentsProps = ComponentProps<"div">;
```
**Usage:**
```tsx
<MessageAttachments>
{attachments.map(attachment => (
<MessageAttachment key={attachment.id} data={attachment} />
))}
</MessageAttachments>
```
**Styling:**
- Flexbox wrap layout
- Right-aligned (ml-auto)
- Gap between items
## MessageToolbar
Container for toolbar elements below message content.
```typescript
type MessageToolbarProps = ComponentProps<"div">;
```
**Usage:**
```tsx
<MessageToolbar>
<div className="flex items-center gap-2">
<span className="text-xs text-muted-foreground">{timestamp}</span>
</div>
<MessageActions>
{/* Action buttons */}
</MessageActions>
</MessageToolbar>
```
## Complete Example
```tsx
import {
Conversation,
ConversationContent,
ConversationEmptyState,
ConversationScrollButton,
} from "@/components/ai-elements/conversation";
import {
Message,
MessageContent,
MessageResponse,
MessageActions,
MessageAction,
MessageAttachments,
MessageAttachment,
MessageBranch,
MessageBranchContent,
MessageBranchSelector,
MessageBranchPrevious,
MessageBranchNext,
MessageBranchPage,
} from "@/components/ai-elements/message";
function ChatInterface({ messages }: { messages: UIMessage[] }) {
return (
<Conversation className="flex-1">
{messages.length === 0 ? (
<ConversationEmptyState
title="Start a conversation"
description="Ask me anything!"
/>
) : (
<ConversationContent className="p-4">
{messages.map(message => (
<Message key={message.id} from={message.role}>
{message.attachments && (
<MessageAttachments>
{message.attachments.map(att => (
<MessageAttachment key={att.url} data={att} />
))}
</MessageAttachments>
)}
<MessageContent>
{message.branches ? (
<MessageBranch>
<MessageBranchContent>
{message.branches.map((branch, idx) => (
<MessageResponse key={idx}>{branch}</MessageResponse>
))}
</MessageBranchContent>
<MessageBranchSelector from={message.role}>
<MessageBranchPrevious />
<MessageBranchPage />
<MessageBranchNext />
</MessageBranchSelector>
</MessageBranch>
) : (
<MessageResponse>{message.content}</MessageResponse>
)}
</MessageContent>
<MessageActions>
<MessageAction tooltip="Copy" onClick={() => copy(message)}>
<CopyIcon />
</MessageAction>
<MessageAction tooltip="Regenerate" onClick={() => regenerate(message)}>
<RefreshIcon />
</MessageAction>
</MessageActions>
</Message>
))}
</ConversationContent>
)}
<ConversationScrollButton />
</Conversation>
);
}
```
FILE:references/prompt-input.md
# Prompt Input Components
Advanced text input components with file attachments, drag-and-drop, speech input, and comprehensive state management.
## Table of Contents
- [Core Components](#core-components)
- [State Management](#state-management)
- [Attachment Handling](#attachment-handling)
- [Action Menus](#action-menus)
- [Submit Button](#submit-button)
- [Speech Input](#speech-input)
- [Advanced Features](#advanced-features)
- [Complete Example](#complete-example)
## Core Components
### PromptInput
Main form container that handles text input, file attachments, and submission.
```typescript
type PromptInputProps = Omit<HTMLAttributes<HTMLFormElement>, "onSubmit" | "onError"> & {
accept?: string;
multiple?: boolean;
globalDrop?: boolean;
syncHiddenInput?: boolean;
maxFiles?: number;
maxFileSize?: number;
onError?: (err: { code: "max_files" | "max_file_size" | "accept"; message: string }) => void;
onSubmit: (message: PromptInputMessage, event: FormEvent<HTMLFormElement>) => void | Promise<void>;
};
type PromptInputMessage = {
text: string;
files: FileUIPart[];
};
```
**Props:**
- `accept?: string` - File type filter (e.g., "image/*")
- `multiple?: boolean` - Allow multiple file selection
- `globalDrop?: boolean` - Accept drops anywhere on document (default: false)
- `syncHiddenInput?: boolean` - Keep hidden input in sync (default: false)
- `maxFiles?: number` - Maximum number of files
- `maxFileSize?: number` - Maximum file size in bytes
- `onError?: (err) => void` - Error handler for file validation
- `onSubmit: (message, event) => void | Promise<void>` - Submit handler (required)
**Usage:**
```tsx
<PromptInput
accept="image/*"
multiple
maxFiles={5}
maxFileSize={10 * 1024 * 1024} // 10MB
onError={(err) => toast.error(err.message)}
onSubmit={async (message) => {
await sendMessage(message.text, message.files);
}}
>
<PromptInputAttachments>
{(attachment) => <PromptInputAttachment data={attachment} />}
</PromptInputAttachments>
<PromptInputBody>
<PromptInputTextarea placeholder="Type a message..." />
</PromptInputBody>
<PromptInputFooter>
<PromptInputTools>
<PromptInputButton onClick={() => attachments.openFileDialog()}>
<PaperclipIcon />
</PromptInputButton>
</PromptInputTools>
<PromptInputSubmit status={chatStatus} />
</PromptInputFooter>
</PromptInput>
```
**Features:**
- Dual-mode operation (controlled/uncontrolled)
- Drag-and-drop file handling (local or global)
- Paste image/file support
- File validation (type, size, count)
- Automatic blob URL to data URL conversion
- Async/sync onSubmit support
- Auto-reset on successful submission
### PromptInputBody
Container for the main input area.
```typescript
type PromptInputBodyProps = HTMLAttributes<HTMLDivElement>;
```
**Usage:**
```tsx
<PromptInputBody>
<PromptInputTextarea />
</PromptInputBody>
```
### PromptInputTextarea
Auto-expanding textarea with keyboard shortcuts and paste handling.
```typescript
type PromptInputTextareaProps = ComponentProps<typeof InputGroupTextarea>;
```
**Props:**
- `placeholder?: string` - Placeholder text (default: "What would you like to know?")
- Standard textarea props
**Usage:**
```tsx
<PromptInputTextarea
placeholder="Ask me anything..."
className="max-h-48 min-h-16"
/>
```
**Keyboard Shortcuts:**
- `Enter` - Submit (without Shift)
- `Shift+Enter` - New line
- `Backspace` - Remove last attachment when textarea is empty
**Features:**
- Auto-expands with content (field-sizing-content)
- Paste image/file support
- Composition event handling (for IME)
- Respects submit button disabled state
- Controlled/uncontrolled dual-mode
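The keyboard rules above can be sketched as a pure decision function. This is a sketch under stated assumptions rather than the component's source: `resolveKeyAction`, `KeyState`, and the action names are hypothetical, and the handling of a disabled submit button (ignore `Enter`) is an assumption based on the feature list.

```typescript
// Hypothetical sketch: decide what a keydown should do in the textarea.
type KeyAction = "submit" | "newline" | "remove-last-attachment" | "none";

type KeyState = {
  key: string;
  shiftKey: boolean;
  isComposing: boolean; // true while an IME composition is in progress
  value: string; // current textarea content
  attachmentCount: number;
  submitDisabled: boolean;
};

function resolveKeyAction(state: KeyState): KeyAction {
  if (state.key === "Enter") {
    // Let the IME confirm its candidate instead of submitting.
    if (state.isComposing) return "none";
    if (state.shiftKey) return "newline";
    return state.submitDisabled ? "none" : "submit";
  }
  if (state.key === "Backspace" && state.value === "" && state.attachmentCount > 0) {
    return "remove-last-attachment";
  }
  return "none";
}
```

Keeping the decision separate from the event handler makes the shortcut behavior easy to unit test without rendering the textarea.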
### PromptInputHeader
Header section for additional controls above the textarea.
```typescript
type PromptInputHeaderProps = Omit<ComponentProps<typeof InputGroupAddon>, "align">;
```
**Usage:**
```tsx
<PromptInputHeader>
<PromptInputSelect>
<PromptInputSelectTrigger>
<PromptInputSelectValue placeholder="Select model" />
</PromptInputSelectTrigger>
<PromptInputSelectContent>
<PromptInputSelectItem value="gpt-4">GPT-4</PromptInputSelectItem>
<PromptInputSelectItem value="claude">Claude</PromptInputSelectItem>
</PromptInputSelectContent>
</PromptInputSelect>
</PromptInputHeader>
```
### PromptInputFooter
Footer section for tools and submit button.
```typescript
type PromptInputFooterProps = Omit<ComponentProps<typeof InputGroupAddon>, "align">;
```
**Usage:**
```tsx
<PromptInputFooter>
<PromptInputTools>
{/* Tool buttons */}
</PromptInputTools>
<PromptInputSubmit status="ready" />
</PromptInputFooter>
```
### PromptInputTools
Container for tool buttons in the footer.
```typescript
type PromptInputToolsProps = HTMLAttributes<HTMLDivElement>;
```
**Usage:**
```tsx
<PromptInputTools>
<PromptInputButton onClick={openFileDialog}>
<PaperclipIcon />
</PromptInputButton>
<PromptInputSpeechButton textareaRef={textareaRef} />
</PromptInputTools>
```
### PromptInputButton
Generic button for actions and tools.
```typescript
type PromptInputButtonProps = ComponentProps<typeof InputGroupButton>;
```
**Props:**
- `variant?: ButtonVariant` - Button style (default: "ghost")
- `size?: ButtonSize` - Button size (auto-determined based on children)
**Usage:**
```tsx
<PromptInputButton onClick={handleAction}>
<IconComponent />
</PromptInputButton>
<PromptInputButton onClick={handleAction}>
<IconComponent />
Label
</PromptInputButton>
```
## State Management
### PromptInputProvider
Optional global provider that lifts input and attachment state outside of PromptInput.
```typescript
type PromptInputProviderProps = PropsWithChildren<{
initialInput?: string;
}>;
```
**Usage:**
```tsx
<PromptInputProvider initialInput="">
{/* App content */}
<PromptInput onSubmit={handleSubmit}>
{/* Input content */}
</PromptInput>
{/* External components can access state */}
<ExternalComponent />
</PromptInputProvider>
```
**Provides:**
- `textInput: TextInputContext` - Text state and setters
- `attachments: AttachmentsContext` - File state and methods
- `__registerFileInput` - Internal registration method
### usePromptInputController
Hook to access the provider state.
```typescript
// Shape of the value returned by the hook
declare function usePromptInputController(): {
textInput: {
value: string;
setInput: (v: string) => void;
clear: () => void;
};
attachments: {
files: (FileUIPart & { id: string })[];
add: (files: File[] | FileList) => void;
remove: (id: string) => void;
clear: () => void;
openFileDialog: () => void;
fileInputRef: RefObject<HTMLInputElement | null>;
};
};
```
**Usage:**
```tsx
function ExternalComponent() {
const { textInput, attachments } = usePromptInputController();
return (
<div>
<p>Current input: {textInput.value}</p>
<p>Attachments: {attachments.files.length}</p>
<Button onClick={() => textInput.clear()}>Clear</Button>
</div>
);
}
```
### useProviderAttachments
Hook to access attachment state from provider.
```typescript
declare function useProviderAttachments(): AttachmentsContext;
```
### usePromptInputAttachments
Hook to access attachment state (dual-mode: provider or local).
```typescript
declare function usePromptInputAttachments(): AttachmentsContext;
```
## Attachment Handling
### PromptInputAttachments
Container for rendering attachments.
```typescript
type PromptInputAttachmentsProps = Omit<HTMLAttributes<HTMLDivElement>, "children"> & {
children: (attachment: FileUIPart & { id: string }) => ReactNode;
};
```
**Usage:**
```tsx
<PromptInputAttachments>
{(attachment) => (
<PromptInputAttachment data={attachment} />
)}
</PromptInputAttachments>
```
**Behavior:**
- Only renders if attachments exist
- Uses render prop pattern for flexibility
### PromptInputAttachment
Individual attachment display with preview and remove button.
```typescript
type PromptInputAttachmentProps = HTMLAttributes<HTMLDivElement> & {
data: FileUIPart & { id: string };
className?: string;
};
```
**Usage:**
```tsx
<PromptInputAttachment
data={{
id: "abc123",
type: "file",
url: "blob:...",
filename: "image.png",
mediaType: "image/png"
}}
/>
```
**Features:**
- Image preview for image/* media types
- Paperclip icon for other files
- Hover to reveal remove button
- Hover card with full preview
- Truncated filename display
### PromptInputHoverCard
Hover card for attachment preview.
```typescript
type PromptInputHoverCardProps = ComponentProps<typeof HoverCard>;
```
**Default Delays:**
- `openDelay: 0` - Instant open
- `closeDelay: 0` - Instant close
### PromptInputHoverCardContent
Content area for hover card preview.
```typescript
type PromptInputHoverCardContentProps = ComponentProps<typeof HoverCardContent>;
```
## Action Menus
### PromptInputActionMenu
Dropdown menu for additional actions.
```typescript
type PromptInputActionMenuProps = ComponentProps<typeof DropdownMenu>;
```
**Usage:**
```tsx
<PromptInputActionMenu>
<PromptInputActionMenuTrigger>
<PlusIcon />
</PromptInputActionMenuTrigger>
<PromptInputActionMenuContent>
<PromptInputActionAddAttachments label="Add files" />
<PromptInputActionMenuItem>
<SettingsIcon className="mr-2 size-4" />
Settings
</PromptInputActionMenuItem>
</PromptInputActionMenuContent>
</PromptInputActionMenu>
```
### PromptInputActionMenuTrigger
Trigger button for action menu.
```typescript
type PromptInputActionMenuTriggerProps = PromptInputButtonProps;
```
**Default Icon:** `PlusIcon`
### PromptInputActionMenuContent
Content area for menu items.
```typescript
type PromptInputActionMenuContentProps = ComponentProps<typeof DropdownMenuContent>;
```
**Default Alignment:** `align="start"`
### PromptInputActionMenuItem
Individual menu item.
```typescript
type PromptInputActionMenuItemProps = ComponentProps<typeof DropdownMenuItem>;
```
### PromptInputActionAddAttachments
Pre-built menu item for adding attachments.
```typescript
type PromptInputActionAddAttachmentsProps = ComponentProps<typeof DropdownMenuItem> & {
label?: string;
};
```
**Props:**
- `label?: string` - Button label (default: "Add photos or files")
**Usage:**
```tsx
<PromptInputActionMenuContent>
<PromptInputActionAddAttachments label="Upload images" />
</PromptInputActionMenuContent>
```
## Submit Button
### PromptInputSubmit
Status-aware submit button with dynamic icons.
```typescript
type PromptInputSubmitProps = ComponentProps<typeof InputGroupButton> & {
status?: ChatStatus; // "submitted" | "streaming" | "error" | "ready"
};
```
**Status Icons:**
- `undefined` / `"ready"` - `CornerDownLeftIcon` (enter key)
- `"submitted"` - `Loader2Icon` (spinning)
- `"streaming"` - `SquareIcon` (stop)
- `"error"` - `XIcon` (error)
**Usage:**
```tsx
<PromptInputSubmit status={chatStatus} />
```
**Behavior:**
- `type="submit"` - Triggers form submission
- `aria-label="Submit"` - Accessible label
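The status-to-icon mapping listed above can be expressed as a small lookup. This is a sketch only: `submitIconFor` and `SubmitIconName` are hypothetical names, `ChatStatus` is redeclared locally to keep the block self-contained, and the function returns icon names as strings rather than rendering the icon components.

```typescript
// Local copy of the status union documented above.
type ChatStatus = "submitted" | "streaming" | "error" | "ready";

type SubmitIconName =
  | "CornerDownLeftIcon"
  | "Loader2Icon"
  | "SquareIcon"
  | "XIcon";

// Hypothetical helper: pick the icon name for a given status.
function submitIconFor(status?: ChatStatus): SubmitIconName {
  switch (status) {
    case "submitted":
      return "Loader2Icon"; // spinning loader
    case "streaming":
      return "SquareIcon"; // stop
    case "error":
      return "XIcon";
    default:
      return "CornerDownLeftIcon"; // undefined or "ready"
  }
}
```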
## Speech Input
### PromptInputSpeechButton
Voice input button using Web Speech API.
```typescript
type PromptInputSpeechButtonProps = ComponentProps<typeof PromptInputButton> & {
textareaRef?: RefObject<HTMLTextAreaElement | null>;
onTranscriptionChange?: (text: string) => void;
};
```
**Props:**
- `textareaRef?: RefObject` - Reference to textarea for text insertion
- `onTranscriptionChange?: (text: string) => void` - Callback when text changes
**Usage:**
```tsx
const textareaRef = useRef<HTMLTextAreaElement>(null);
<PromptInputTextarea ref={textareaRef} />
<PromptInputSpeechButton
textareaRef={textareaRef}
onTranscriptionChange={(text) => console.log("Transcribed:", text)}
/>
```
**Features:**
- Browser-based speech recognition
- Continuous recording mode
- Interim results support
- Automatic text insertion
- Visual feedback (pulse animation when listening)
- Disabled if browser doesn't support Speech Recognition
- Error handling
## Advanced Features
### Select Components
For model selection or other dropdowns.
```tsx
<PromptInputSelect value={model} onValueChange={setModel}>
<PromptInputSelectTrigger>
<PromptInputSelectValue placeholder="Select model" />
</PromptInputSelectTrigger>
<PromptInputSelectContent>
<PromptInputSelectItem value="gpt-4">GPT-4</PromptInputSelectItem>
<PromptInputSelectItem value="claude">Claude</PromptInputSelectItem>
</PromptInputSelectContent>
</PromptInputSelect>
```
### Command Components
For slash commands or autocomplete.
```tsx
<PromptInputCommand>
<PromptInputCommandInput placeholder="Search commands..." />
<PromptInputCommandList>
<PromptInputCommandEmpty>No commands found</PromptInputCommandEmpty>
<PromptInputCommandGroup heading="Actions">
<PromptInputCommandItem value="summarize">
Summarize
</PromptInputCommandItem>
<PromptInputCommandItem value="translate">
Translate
</PromptInputCommandItem>
</PromptInputCommandGroup>
<PromptInputCommandSeparator />
<PromptInputCommandGroup heading="Settings">
<PromptInputCommandItem value="preferences">
Preferences
</PromptInputCommandItem>
</PromptInputCommandGroup>
</PromptInputCommandList>
</PromptInputCommand>
```
### Tab Components
For organizing input modes or templates.
```tsx
<PromptInputTabsList>
<PromptInputTab>
<PromptInputTabLabel>Templates</PromptInputTabLabel>
<PromptInputTabBody>
<PromptInputTabItem>Summarize article</PromptInputTabItem>
<PromptInputTabItem>Write email</PromptInputTabItem>
</PromptInputTabBody>
</PromptInputTab>
</PromptInputTabsList>
```
## Complete Example
```tsx
import { useRef, useState } from "react";
import {
PromptInput,
PromptInputProvider,
PromptInputAttachments,
PromptInputAttachment,
PromptInputBody,
PromptInputTextarea,
PromptInputFooter,
PromptInputTools,
PromptInputButton,
PromptInputSpeechButton,
PromptInputSubmit,
PromptInputActionMenu,
PromptInputActionMenuTrigger,
PromptInputActionMenuContent,
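PromptInputActionAddAttachments,
PromptInputActionMenuItem,
usePromptInputAttachments,
type PromptInputMessage,
} from "@/components/ai-elements/prompt-input";
import type { ChatStatus } from "ai";
import { PaperclipIcon, PlusIcon } from "lucide-react";
// `toast` and `sendMessage` are app-level helpers assumed to exist in your project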
function ChatInput() {
const [status, setStatus] = useState<ChatStatus>("ready");
const textareaRef = useRef<HTMLTextAreaElement>(null);
const attachments = usePromptInputAttachments();
const handleSubmit = async (message: PromptInputMessage) => {
setStatus("submitted");
try {
await sendMessage(message.text, message.files);
setStatus("ready");
} catch (error) {
setStatus("error");
}
};
return (
<PromptInput
accept="image/*,.pdf,.doc,.docx"
multiple
maxFiles={10}
maxFileSize={10 * 1024 * 1024}
onError={(err) => toast.error(err.message)}
onSubmit={handleSubmit}
>
<PromptInputAttachments>
{(attachment) => <PromptInputAttachment data={attachment} />}
</PromptInputAttachments>
<PromptInputBody>
<PromptInputTextarea
ref={textareaRef}
placeholder="Type a message..."
/>
</PromptInputBody>
<PromptInputFooter>
<PromptInputTools>
<PromptInputButton onClick={() => attachments.openFileDialog()}>
<PaperclipIcon className="size-4" />
</PromptInputButton>
<PromptInputSpeechButton textareaRef={textareaRef} />
<PromptInputActionMenu>
<PromptInputActionMenuTrigger>
<PlusIcon className="size-4" />
</PromptInputActionMenuTrigger>
<PromptInputActionMenuContent>
<PromptInputActionAddAttachments />
<PromptInputActionMenuItem>
Insert template
</PromptInputActionMenuItem>
</PromptInputActionMenuContent>
</PromptInputActionMenu>
</PromptInputTools>
<PromptInputSubmit status={status} />
</PromptInputFooter>
</PromptInput>
);
}
// With global provider
function App() {
return (
<PromptInputProvider initialInput="">
<ChatInput />
</PromptInputProvider>
);
}
```
FILE:references/visualization.md
# Visualization Components
ReactFlow-based components for workflow visualization, custom nodes, and animated edges.
## Core Components
### Canvas
ReactFlow wrapper with aviation-specific defaults and background.
```typescript
type CanvasProps = ReactFlowProps & {
children?: ReactNode;
};
```
**Usage:**
```tsx
import { Canvas } from "@/components/ai-elements/canvas";
import { Background, Controls, Panel } from "@xyflow/react";
import "@xyflow/react/dist/style.css";
<Canvas
nodes={nodes}
edges={edges}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
>
<Background bgColor="var(--sidebar)" />
<Controls />
<Panel position="top-left">
<h3>Workflow</h3>
</Panel>
</Canvas>
```
**Default Props:**
- `deleteKeyCode: ["Backspace", "Delete"]` - Keys to delete selected elements
- `fitView: true` - Automatically fit content in viewport
- `panOnDrag: false` - Disable pan on drag (use selection instead)
- `panOnScroll: true` - Enable pan on scroll
- `selectionOnDrag: true` - Enable selection box on drag
- `zoomOnDoubleClick: false` - Disable zoom on double click
**Features:**
- Includes Background component with sidebar color
- Accepts all ReactFlow props
- Optimized for workflow visualization
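One way to picture how the defaults combine with caller props is a plain object merge. This is a sketch, not the Canvas source: `CanvasLikeProps`, `canvasDefaults`, and `withCanvasDefaults` are hypothetical names, and the assumption that explicitly passed props override the defaults follows from Canvas accepting all ReactFlow props.

```typescript
// Subset of props covered by the documented defaults.
type CanvasLikeProps = {
  deleteKeyCode?: string[];
  fitView?: boolean;
  panOnDrag?: boolean;
  panOnScroll?: boolean;
  selectionOnDrag?: boolean;
  zoomOnDoubleClick?: boolean;
};

// Defaults as listed in the "Default Props" section.
const canvasDefaults: CanvasLikeProps = {
  deleteKeyCode: ["Backspace", "Delete"],
  fitView: true,
  panOnDrag: false,
  panOnScroll: true,
  selectionOnDrag: true,
  zoomOnDoubleClick: false,
};

// Hypothetical helper: caller props are spread last, so they win over defaults.
function withCanvasDefaults(props: CanvasLikeProps): CanvasLikeProps {
  return { ...canvasDefaults, ...props };
}
```

Under this assumption, passing `panOnDrag` to Canvas would restore drag-to-pan while leaving the other defaults in place.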
## Node Components
Custom node components built on shadcn/ui Card components.
### Node
Main node container with connection handles.
```typescript
type NodeProps = ComponentProps<typeof Card> & {
handles: {
target: boolean;
source: boolean;
};
};
```
**Props:**
- `handles: { target: boolean; source: boolean }` - Which handles to show (required)
- All Card component props
**Usage:**
```tsx
<Node handles={{ target: true, source: true }}>
<NodeHeader>
<NodeTitle>Process Data</NodeTitle>
<NodeDescription>Transform input data</NodeDescription>
</NodeHeader>
<NodeContent>
{/* Node content */}
</NodeContent>
<NodeFooter>
Status: Running
</NodeFooter>
</Node>
```
**Handle Positions:**
- Target: Left side (incoming connections)
- Source: Right side (outgoing connections)
**Default Styling:**
- Relative positioning for handles
- Auto height
- Fixed width (w-sm)
- Rounded corners
### NodeHeader
Header section with border bottom and secondary background.
```typescript
type NodeHeaderProps = ComponentProps<typeof CardHeader>;
```
**Usage:**
```tsx
<NodeHeader>
<NodeTitle>Step 1</NodeTitle>
<NodeDescription>Initial processing</NodeDescription>
<NodeAction onClick={handleEdit}>
<EditIcon />
</NodeAction>
</NodeHeader>
```
**Default Styling:**
- Rounded top corners
- Border bottom
- Secondary background
- Compact padding (p-3)
### NodeTitle
Title text for node header.
```typescript
type NodeTitleProps = ComponentProps<typeof CardTitle>;
```
**Usage:**
```tsx
<NodeTitle>Data Processing</NodeTitle>
```
### NodeDescription
Secondary description text.
```typescript
type NodeDescriptionProps = ComponentProps<typeof CardDescription>;
```
**Usage:**
```tsx
<NodeDescription>
Transform and validate input data
</NodeDescription>
```
### NodeAction
Action button in header (typically edit/delete).
```typescript
type NodeActionProps = ComponentProps<typeof CardAction>;
```
**Usage:**
```tsx
<NodeAction onClick={handleEdit}>
<EditIcon className="size-4" />
</NodeAction>
```
### NodeContent
Main content area of the node.
```typescript
type NodeContentProps = ComponentProps<typeof CardContent>;
```
**Usage:**
```tsx
<NodeContent>
<div className="space-y-2">
<Label>Input</Label>
<Input value={data.input} readOnly />
<Label>Output</Label>
<Input value={data.output} readOnly />
</div>
</NodeContent>
```
**Default Styling:**
- Compact padding (p-3)
### NodeFooter
Footer section with border top and secondary background.
```typescript
type NodeFooterProps = ComponentProps<typeof CardFooter>;
```
**Usage:**
```tsx
<NodeFooter>
<Badge variant="success">Completed</Badge>
<span className="text-xs text-muted-foreground">
Duration: 1.2s
</span>
</NodeFooter>
```
**Default Styling:**
- Rounded bottom corners
- Border top
- Secondary background
- Compact padding (p-3)
## Edge Components
Custom edge types for different connection styles.
### Edge.Temporary
Dashed edge for temporary or preview connections.
```typescript
type TemporaryEdgeProps = EdgeProps;
```
**Usage:**
```tsx
const edgeTypes = {
temporary: Edge.Temporary,
};
<Canvas edges={edges} edgeTypes={edgeTypes} />
```
**Features:**
- Simple Bezier curve
- Dashed stroke (5, 5)
- Ring color stroke
- Stroke width: 1
**Use Cases:**
- Drag preview connections
- Temporary workflow paths
- Suggested connections
### Edge.Animated
Animated edge with moving dot indicator.
```typescript
type AnimatedEdgeProps = EdgeProps;
```
**Usage:**
```tsx
const edgeTypes = {
animated: Edge.Animated,
};
<Canvas edges={edges} edgeTypes={edgeTypes} />
```
**Features:**
- Bezier curve path
- Animated circle following edge path
- 2-second animation duration
- Infinite repeat
- Primary color dot (4px radius)
**Use Cases:**
- Active data flow
- Processing pipelines
- Real-time connections
**Implementation Details:**
- Uses `useInternalNode` to get node positions
- Calculates handle coordinates based on position
- Supports Left (target) and Right (source) handle positions
- Uses `getBezierPath` for smooth curves
## ReactFlow Integration
### Controls
Standard ReactFlow controls (zoom, fit view, etc.).
```typescript
import { Controls } from "@xyflow/react";
```
**Usage:**
```tsx
<Canvas>
<Controls />
</Canvas>
```
### Panel
Panel for custom UI overlays.
```typescript
import { Panel } from "@xyflow/react";
```
**Usage:**
```tsx
<Canvas>
<Panel position="top-left">
<h3>Workflow Name</h3>
<p>Status: Running</p>
</Panel>
<Panel position="bottom-right">
<Button onClick={handleSave}>Save</Button>
</Panel>
</Canvas>
```
**Positions:**
- `top-left`, `top-center`, `top-right`
- `bottom-left`, `bottom-center`, `bottom-right`
### Background
Background pattern for the canvas.
```typescript
import { Background } from "@xyflow/react";
```
**Usage:**
```tsx
<Canvas>
<Background bgColor="var(--sidebar)" />
</Canvas>
```
## Custom Node Types
Example of creating custom node types with aviation-specific styling.
```tsx
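import { Node, NodeHeader, NodeTitle, NodeContent, NodeFooter } from "@/components/ai-elements/node";
// Alias React Flow's Node type so it does not clash with the Node component above
import type { Node as FlowNode, NodeProps } from "@xyflow/react";
type ProcessNodeData = {
label: string;
status: "pending" | "running" | "completed" | "failed";
input?: string;
output?: string;
};
// In @xyflow/react v12, NodeProps takes the full node type rather than the data type
function ProcessNode({ data }: NodeProps<FlowNode<ProcessNodeData>>) {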
const statusColors = {
pending: "bg-gray-500",
running: "bg-blue-500 animate-pulse",
completed: "bg-green-500",
failed: "bg-red-500",
};
return (
<Node handles={{ target: true, source: true }}>
<NodeHeader>
<NodeTitle>{data.label}</NodeTitle>
</NodeHeader>
<NodeContent>
{data.input && (
<div className="space-y-1">
<Label className="text-xs">Input</Label>
<p className="text-xs text-muted-foreground">{data.input}</p>
</div>
)}
{data.output && (
<div className="space-y-1">
<Label className="text-xs">Output</Label>
<p className="text-xs text-muted-foreground">{data.output}</p>
</div>
)}
</NodeContent>
<NodeFooter>
<div className="flex items-center gap-2">
<div className={cn("size-2 rounded-full", statusColors[data.status])} />
<span className="text-xs capitalize">{data.status}</span>
</div>
</NodeFooter>
</Node>
);
}
// Register custom node type
const nodeTypes = {
process: ProcessNode,
};
<Canvas nodeTypes={nodeTypes} />
```
## Complete Example
```tsx
import { Badge } from "@/components/ui/badge"; // shadcn/ui Badge; path assumed, adjust to your project
import { Canvas } from "@/components/ai-elements/canvas";
import {
Node,
NodeHeader,
NodeTitle,
NodeDescription,
NodeContent,
NodeFooter,
} from "@/components/ai-elements/node";
import { Edge } from "@/components/ai-elements/edge";
import {
Background,
Controls,
Panel,
useNodesState,
useEdgesState,
addEdge,
type Connection,
} from "@xyflow/react";
import "@xyflow/react/dist/style.css";
const initialNodes = [
{
id: "1",
type: "custom",
position: { x: 0, y: 0 },
data: { label: "Start", status: "completed" },
},
{
id: "2",
type: "custom",
position: { x: 250, y: 0 },
data: { label: "Process", status: "running" },
},
{
id: "3",
type: "custom",
position: { x: 500, y: 0 },
data: { label: "End", status: "pending" },
},
];
const initialEdges = [
{
id: "e1-2",
source: "1",
target: "2",
type: "animated",
},
{
id: "e2-3",
source: "2",
target: "3",
type: "temporary",
},
];
function CustomNode({ data }: { data: { label: string; status: string } }) {
return (
<Node handles={{ target: true, source: true }}>
<NodeHeader>
<NodeTitle>{data.label}</NodeTitle>
<NodeDescription>Step in workflow</NodeDescription>
</NodeHeader>
<NodeContent>
<p className="text-sm">Status: {data.status}</p>
</NodeContent>
<NodeFooter>
<Badge variant={data.status === "completed" ? "success" : "default"}>
{data.status}
</Badge>
</NodeFooter>
</Node>
);
}
const nodeTypes = {
custom: CustomNode,
};
const edgeTypes = {
temporary: Edge.Temporary,
animated: Edge.Animated,
};
function WorkflowCanvas() {
const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
const onConnect = (connection: Connection) => {
setEdges((eds) => addEdge({ ...connection, type: "animated" }, eds));
};
return (
<div className="h-screen">
<Canvas
nodes={nodes}
edges={edges}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnect={onConnect}
>
<Background bgColor="var(--sidebar)" />
<Controls />
<Panel position="top-left">
<div className="rounded-lg bg-background p-4 shadow-lg">
<h3 className="font-semibold">Workflow</h3>
<p className="text-sm text-muted-foreground">
{nodes.length} nodes, {edges.length} edges
</p>
</div>
</Panel>
</Canvas>
</div>
);
}
```
## Handle Positioning
The edge components use custom handle coordinate calculation:
```typescript
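import { Position, type InternalNode, type Node } from "@xyflow/react";

const getHandleCoordsByPosition = (
node: InternalNode<Node>,
handlePosition: Position
): readonly [number, number] => {
// Left handles are targets; every other position is treated as a source
const handleType = handlePosition === Position.Left ? "target" : "source";
const handle = node.internals.handleBounds?.[handleType]?.find(
(h) => h.position === handlePosition
);
// Handle bounds are unavailable until the node has been measured
if (!handle) return [0, 0];
// Calculate absolute coordinates
const offsetX = handlePosition === Position.Right ? handle.width : 0;
const offsetY = handlePosition === Position.Bottom ? handle.height : 0;
const x = node.internals.positionAbsolute.x + handle.x + offsetX;
const y = node.internals.positionAbsolute.y + handle.y + offsetY;
return [x, y];
};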
```
This ensures edges connect properly to handle centers.
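A worked example of the arithmetic, using plain objects in place of React Flow's internal node type. The types `Side`, `HandleBounds`, and `MeasuredNode`, the helper `handleCoords`, and all numbers are hypothetical and chosen only to make the calculation easy to follow.

```typescript
type Side = "left" | "right" | "top" | "bottom";

type HandleBounds = {
  x: number; // offset from the node's top-left corner
  y: number;
  width: number;
  height: number;
  position: Side;
};

type MeasuredNode = {
  positionAbsolute: { x: number; y: number };
  handles: HandleBounds[];
};

// Same arithmetic as the snippet above, on simplified data.
function handleCoords(node: MeasuredNode, side: Side): [number, number] | null {
  const handle = node.handles.find((h) => h.position === side);
  if (!handle) return null;
  const offsetX = side === "right" ? handle.width : 0;
  const offsetY = side === "bottom" ? handle.height : 0;
  return [
    node.positionAbsolute.x + handle.x + offsetX,
    node.positionAbsolute.y + handle.y + offsetY,
  ];
}

const exampleNode: MeasuredNode = {
  positionAbsolute: { x: 100, y: 50 },
  handles: [
    { x: -4, y: 20, width: 8, height: 8, position: "left" },
    { x: 380, y: 20, width: 8, height: 8, position: "right" },
  ],
};

// Right handle: x = 100 + 380 + 8 = 488, y = 50 + 20 + 0 = 70
const sourcePoint = handleCoords(exampleNode, "right");
// Left handle: x = 100 + (-4) + 0 = 96, y = 50 + 20 + 0 = 70
const targetPoint = handleCoords(exampleNode, "left");
```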
## Styling Tips
### Node Widths
Control node width with Tailwind classes:
```tsx
<Node handles={handles} className="w-64">
{/* 16rem (256px) wide */}
</Node>
<Node handles={handles} className="w-96">
{/* 24rem (384px) wide */}
</Node>
```
### Edge Colors
Customize edge colors via className:
```tsx
// In custom edge component
<BaseEdge
className="stroke-2 stroke-primary"
id={id}
path={edgePath}
/>
```
### Node Status Indicators
Add visual status indicators:
```tsx
<NodeFooter>
<div className="flex items-center gap-2">
<div className={cn(
"size-2 rounded-full",
status === "running" && "bg-blue-500 animate-pulse",
status === "completed" && "bg-green-500",
status === "failed" && "bg-red-500"
)} />
<span className="text-xs">{status}</span>
</div>
</NodeFooter>
```
## Integration with Aviation Nodes
The AI Elements Canvas works with custom aviation-specific node types. The Node component serves as a base for domain-specific extensions:
```tsx
// Aviation-specific node
function FlightPlanNode({ data }) {
return (
<Node handles={{ target: true, source: true }}>
<NodeHeader>
<NodeTitle>{data.flightNumber}</NodeTitle>
<NodeDescription>{data.route}</NodeDescription>
</NodeHeader>
<NodeContent>
<div className="space-y-2 text-xs">
<div>Departure: {data.departure}</div>
<div>Arrival: {data.arrival}</div>
<div>Aircraft: {data.aircraft}</div>
</div>
</NodeContent>
<NodeFooter>
<Badge>{data.status}</Badge>
</NodeFooter>
</Node>
);
}
```
FILE:references/workflow.md
# Workflow Components
Components for displaying job queues, tool execution, approval workflows, and reasoning displays.
## Table of Contents
- [Queue Components](#queue-components)
- [Tool Components](#tool-components)
- [Confirmation Components](#confirmation-components)
- [Reasoning Components](#reasoning-components)
- [Loading Components](#loading-components)
## Queue Components
Components for displaying task lists, job queues, and progress tracking.
### Queue
Main container for queue items and sections.
```typescript
type QueueProps = ComponentProps<"div">;
```
**Usage:**
```tsx
<Queue>
<QueueSection>
{/* Queue items */}
</QueueSection>
</Queue>
```
**Default Styling:**
- Bordered container with rounded corners
- Background with shadow
- Flexbox column layout with gap
### QueueSection
Collapsible section for organizing queue items.
```typescript
type QueueSectionProps = ComponentProps<typeof Collapsible>;
```
**Props:**
- `defaultOpen?: boolean` - Initial open state (default: true)
**Usage:**
```tsx
<QueueSection defaultOpen={true}>
<QueueSectionTrigger>
<QueueSectionLabel count={5} label="pending tasks" />
</QueueSectionTrigger>
<QueueSectionContent>
<QueueList>
{/* Queue items */}
</QueueList>
</QueueSectionContent>
</QueueSection>
```
### QueueSectionTrigger
Clickable header to toggle section visibility.
```typescript
type QueueSectionTriggerProps = ComponentProps<"button">;
```
**Default Styling:**
- Full width button
- Muted background with hover effect
- Flexbox layout for content alignment
### QueueSectionLabel
Label with icon, text, and count display.
```typescript
type QueueSectionLabelProps = ComponentProps<"span"> & {
count?: number;
label: string;
icon?: React.ReactNode;
};
```
**Props:**
- `count?: number` - Item count to display
- `label: string` - Section label (required)
- `icon?: React.ReactNode` - Optional icon
**Usage:**
```tsx
<QueueSectionLabel
count={todos.length}
label="todos"
icon={<CheckSquareIcon className="size-4" />}
/>
```
**Display:**
Shows "{count} {label}" with chevron icon that rotates when expanded.
### QueueSectionContent
Collapsible content area for queue items.
```typescript
type QueueSectionContentProps = ComponentProps<typeof CollapsibleContent>;
```
### QueueList
Scrollable container for queue items.
```typescript
type QueueListProps = ComponentProps<typeof ScrollArea>;
```
**Default Styling:**
- Max height of `max-h-40` (10rem)
- Scrollable when content overflows
- Padding for scroll area
**Usage:**
```tsx
<QueueList>
{items.map(item => (
<QueueItem key={item.id}>
{/* Item content */}
</QueueItem>
))}
</QueueList>
```
### QueueItem
Individual queue item container.
```typescript
type QueueItemProps = ComponentProps<"li">;
```
**Usage:**
```tsx
<QueueItem>
<QueueItemIndicator completed={todo.status === 'completed'} />
<QueueItemContent completed={todo.status === 'completed'}>
{todo.title}
</QueueItemContent>
{todo.description && (
<QueueItemDescription completed={todo.status === 'completed'}>
{todo.description}
</QueueItemDescription>
)}
<QueueItemActions>
<QueueItemAction onClick={handleEdit}>
<EditIcon />
</QueueItemAction>
</QueueItemActions>
</QueueItem>
```
**Default Styling:**
- Flexbox column layout
- Hover background effect
- Grouped item styling
### QueueItemIndicator
Status indicator dot.
```typescript
type QueueItemIndicatorProps = ComponentProps<"span"> & {
completed?: boolean;
};
```
**Props:**
- `completed?: boolean` - Whether item is completed (default: false)
**Styling:**
- Pending: Solid border
- Completed: Muted border and background
### QueueItemContent
Main content text for queue item.
```typescript
type QueueItemContentProps = ComponentProps<"span"> & {
completed?: boolean;
};
```
**Props:**
- `completed?: boolean` - Whether item is completed (default: false)
**Styling:**
- Pending: Normal text
- Completed: Muted text with line-through
### QueueItemDescription
Secondary description text.
```typescript
type QueueItemDescriptionProps = ComponentProps<"div"> & {
completed?: boolean;
};
```
**Props:**
- `completed?: boolean` - Whether item is completed (default: false)
**Styling:**
- Smaller text size
- Indented under main content
- Muted color (more muted if completed)
### QueueItemActions
Container for action buttons.
```typescript
type QueueItemActionsProps = ComponentProps<"div">;
```
**Usage:**
```tsx
<QueueItemActions>
<QueueItemAction onClick={handleEdit}>
<EditIcon />
</QueueItemAction>
<QueueItemAction onClick={handleDelete}>
<TrashIcon />
</QueueItemAction>
</QueueItemActions>
```
### QueueItemAction
Individual action button.
```typescript
type QueueItemActionProps = Omit<ComponentProps<typeof Button>, "variant" | "size">;
```
**Behavior:**
- Hidden by default
- Visible on item hover
- Ghost variant
- Icon size
### QueueItemAttachment
Container for attached files/images.
```typescript
type QueueItemAttachmentProps = ComponentProps<"div">;
```
### QueueItemImage
Image preview thumbnail.
```typescript
type QueueItemImageProps = ComponentProps<"img">;
```
**Default Size:** 32x32px
### QueueItemFile
File attachment display with icon.
```typescript
type QueueItemFileProps = ComponentProps<"span">;
```
**Features:**
- Paperclip icon
- Truncated filename (max 100px)
- Border and background styling
## Tool Components
Components for displaying tool execution with states, parameters, and results.
### Tool
Main container for tool execution display.
```typescript
type ToolProps = ComponentProps<typeof Collapsible>;
```
**Usage:**
```tsx
<Tool>
<ToolHeader
title="search"
type="tool-call-search"
state="output-available"
/>
<ToolContent>
<ToolInput input={{ query: "AI tools" }} />
<ToolOutput output={results} errorText={undefined} />
</ToolContent>
</Tool>
```
**Default Styling:**
- Bordered container
- Rounded corners
- `not-prose` class (opts the content out of typography styles)
### ToolHeader
Collapsible trigger showing tool name and status.
```typescript
type ToolHeaderProps = {
title?: string;
type: ToolUIPart["type"];
state: ToolUIPart["state"];
className?: string;
};
```
**Props:**
- `title?: string` - Display name (defaults to type without "tool-call-" prefix)
- `type: string` - Tool type identifier (required)
- `state: ToolState` - Current execution state (required)
**Tool States:**
- `input-streaming` - Parameters being received (Pending badge)
- `input-available` - Ready to execute (Running badge, pulsing)
- `approval-requested` - Awaiting approval (Awaiting Approval badge, yellow)
- `approval-responded` - User responded (Responded badge, blue)
- `output-available` - Completed (Completed badge, green)
- `output-error` - Failed (Error badge, red)
- `output-denied` - Approval denied (Denied badge, orange)
**Features:**
- Wrench icon
- Color-coded status badge
- Chevron that rotates when expanded
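The state-to-badge mapping and the default title rule can be sketched as a lookup table. This is a sketch, not the component's source: `ToolState`, `BadgeInfo`, `toolStateBadges`, and `displayTitle` are hypothetical names, and colors are given as plain strings rather than class names.

```typescript
type ToolState =
  | "input-streaming"
  | "input-available"
  | "approval-requested"
  | "approval-responded"
  | "output-available"
  | "output-error"
  | "output-denied";

type BadgeInfo = { label: string; color?: string; pulsing?: boolean };

// One entry per documented state.
const toolStateBadges: Record<ToolState, BadgeInfo> = {
  "input-streaming": { label: "Pending" },
  "input-available": { label: "Running", pulsing: true },
  "approval-requested": { label: "Awaiting Approval", color: "yellow" },
  "approval-responded": { label: "Responded", color: "blue" },
  "output-available": { label: "Completed", color: "green" },
  "output-error": { label: "Error", color: "red" },
  "output-denied": { label: "Denied", color: "orange" },
};

// Title falls back to the type with its "tool-call-" prefix removed.
function displayTitle(type: string, title?: string): string {
  return title ?? type.replace(/^tool-call-/, "");
}
```

Typing the table as `Record<ToolState, BadgeInfo>` makes the compiler flag any state that is added to the union without a matching badge.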
### ToolContent
Collapsible content area for parameters and results.
```typescript
type ToolContentProps = ComponentProps<typeof CollapsibleContent>;
```
**Animation:**
- Slide in/out from top
- Fade transition
### ToolInput
Displays tool parameters/arguments.
```typescript
type ToolInputProps = ComponentProps<"div"> & {
input: ToolUIPart["input"];
};
```
**Props:**
- `input: unknown` - Tool parameters (any JSON-serializable value)
**Usage:**
```tsx
<ToolInput
input={{
query: "AI tools",
limit: 10,
filters: ["type:library"]
}}
/>
```
**Features:**
- "PARAMETERS" heading
- JSON syntax highlighting via CodeBlock
- Automatic JSON.stringify with formatting
### ToolOutput
Displays tool results or errors.
```typescript
type ToolOutputProps = ComponentProps<"div"> & {
output: ToolUIPart["output"];
errorText: ToolUIPart["errorText"];
};
```
**Props:**
- `output: unknown` - Tool result (any value, React element, or JSON)
- `errorText?: string` - Error message if execution failed
**Usage:**
```tsx
<ToolOutput
output={{
results: [...],
count: 42
}}
errorText={undefined}
/>
```
**Behavior:**
- Shows "RESULT" heading for success
- Shows "ERROR" heading for errors
- Renders React elements directly
- JSON stringifies objects
- CodeBlock for strings
- Destructive styling for errors
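The branching described above can be sketched as a formatting function. This is a sketch under stated assumptions: `formatToolOutput` and `RenderedOutput` are hypothetical names, the React-element branch is omitted to keep the block free of React imports, and returning `null` when there is nothing to show is an assumption.

```typescript
type RenderedOutput = {
  heading: "RESULT" | "ERROR";
  body: string; // text that would be passed to a code block
  destructive: boolean; // whether error styling applies
};

// Hypothetical helper mirroring the documented behavior for non-element outputs.
function formatToolOutput(output: unknown, errorText?: string): RenderedOutput | null {
  if (errorText) {
    return { heading: "ERROR", body: errorText, destructive: true };
  }
  if (output === undefined || output === null) {
    return null; // assumption: nothing is rendered without output or error
  }
  const body =
    typeof output === "string" ? output : JSON.stringify(output, null, 2);
  return { heading: "RESULT", body, destructive: false };
}
```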
## Confirmation Components
Components for approval workflows requiring user confirmation.
### Confirmation
Container for confirmation UI with conditional rendering based on state.
```typescript
type ConfirmationProps = ComponentProps<typeof Alert> & {
approval?: ToolUIPartApproval;
state: ToolUIPart["state"];
};
type ToolUIPartApproval =
| { id: string; approved?: never; reason?: never }
| { id: string; approved: boolean; reason?: string }
| undefined;
```
**Props:**
- `approval?: ToolUIPartApproval` - Approval data
- `state: ToolUIPart["state"]` - Current tool state (for example `approval-requested`)
**Usage:**
```tsx
<Confirmation approval={tool.approval} state={tool.state}>
<ConfirmationTitle>
Delete {count} files?
</ConfirmationTitle>
<ConfirmationRequest>
<ConfirmationActions>
<ConfirmationAction onClick={handleApprove} variant="default">
Approve
</ConfirmationAction>
<ConfirmationAction onClick={handleReject} variant="outline">
Reject
</ConfirmationAction>
</ConfirmationActions>
</ConfirmationRequest>
<ConfirmationAccepted>
Action approved and executed.
</ConfirmationAccepted>
<ConfirmationRejected>
Action rejected.
</ConfirmationRejected>
</Confirmation>
```
**Context:**
Provides `{ approval, state }` to child components.
### ConfirmationTitle
Title/description of the confirmation request.
```typescript
type ConfirmationTitleProps = ComponentProps<typeof AlertDescription>;
```
### ConfirmationRequest
Container shown during approval-requested state.
```typescript
type ConfirmationRequestProps = { children?: ReactNode };
```
**Visibility:**
Only shown when `state === "approval-requested"`
### ConfirmationActions
Container for approve/reject buttons.
```typescript
type ConfirmationActionsProps = ComponentProps<"div">;
```
**Visibility:**
Only shown when `state === "approval-requested"`
### ConfirmationAction
Individual action button.
```typescript
type ConfirmationActionProps = ComponentProps<typeof Button>;
```
**Default Styling:**
- Small height (h-8)
- Compact padding
### ConfirmationAccepted
Content shown when approval is accepted.
```typescript
type ConfirmationAcceptedProps = { children?: ReactNode };
```
**Visibility:**
Only shown when `approval.approved === true` and the state is a response or output state (for example `approval-responded` or `output-available`).
### ConfirmationRejected
Content shown when approval is rejected.
```typescript
type ConfirmationRejectedProps = { children?: ReactNode };
```
**Visibility:**
Only shown when `approval.approved === false` and the state is a response or output state (for example `approval-responded` or `output-denied`).
### useConfirmation
Hook to access confirmation context.
```typescript
// Call from a component rendered inside <Confirmation>.
const ConfirmationStatus = () => {
  const { approval, state } = useConfirmation();
  // ...
};
```
## Reasoning Components
Components for displaying AI thinking/reasoning with auto-collapse behavior.
### Reasoning
Collapsible container for reasoning content with auto-collapse.
```typescript
type ReasoningProps = ComponentProps<typeof Collapsible> & {
isStreaming?: boolean;
open?: boolean;
defaultOpen?: boolean;
onOpenChange?: (open: boolean) => void;
duration?: number;
};
```
**Props:**
- `isStreaming?: boolean` - Whether content is actively streaming (default: false)
- `open?: boolean` - Controlled open state
- `defaultOpen?: boolean` - Initial open state (default: true)
- `onOpenChange?: (open: boolean) => void` - Callback when open state changes
- `duration?: number` - Thinking duration in seconds
**Usage:**
```tsx
<Reasoning isStreaming={isStreaming} defaultOpen={true}>
<ReasoningTrigger />
<ReasoningContent>
{reasoningText}
</ReasoningContent>
</Reasoning>
```
**Auto-Collapse Behavior:**
1. Opens automatically when streaming starts
2. Tracks duration from start to end of streaming
3. Closes 1 second after streaming ends
4. Auto-close only happens once per component lifecycle
**Context:**
Provides `{ isStreaming, isOpen, setIsOpen, duration }` to children.
### ReasoningTrigger
Trigger button with status message and icon.
```typescript
type ReasoningTriggerProps = ComponentProps<typeof CollapsibleTrigger> & {
getThinkingMessage?: (isStreaming: boolean, duration?: number) => ReactNode;
};
```
**Props:**
- `getThinkingMessage?: (isStreaming, duration) => ReactNode` - Custom message generator
**Default Messages:**
- Streaming: "Thinking..." with shimmer effect
- Duration 0 or undefined: "Thought for a few seconds"
- Duration N: "Thought for N seconds"
**Usage:**
```tsx
<ReasoningTrigger />
// Custom message
<ReasoningTrigger
  getThinkingMessage={(streaming, duration) =>
    streaming ? <Shimmer>Processing...</Shimmer> : `Done in ${duration}s`
  }
/>
```
**Icons:**
- Brain icon
- Rotating chevron (down when closed, up when open)
### ReasoningContent
Collapsible content area with markdown rendering.
```typescript
type ReasoningContentProps = ComponentProps<typeof CollapsibleContent> & {
children: string;
};
```
**Props:**
- `children: string` - Reasoning text (required, string only)
**Features:**
- Uses Streamdown for markdown rendering
- Slide and fade animations
- Muted text color
### useReasoning
Hook to access reasoning context.
```typescript
// Call from a component rendered inside <Reasoning>.
const ReasoningStatus = () => {
  const { isStreaming, isOpen, setIsOpen, duration } = useReasoning();
  // ...
};
```
## Loading Components
### Shimmer
Animated shimmer effect for loading text.
```typescript
type TextShimmerProps = {
children: string;
as?: ElementType;
className?: string;
duration?: number;
spread?: number;
};
```
**Props:**
- `children: string` - Text to shimmer (required)
- `as?: ElementType` - HTML element type (default: "p")
- `className?: string` - Additional classes
- `duration?: number` - Animation duration in seconds (default: 2)
- `spread?: number` - Shimmer spread multiplier (default: 2)
**Usage:**
```tsx
<Shimmer duration={1.5}>Thinking...</Shimmer>
<Shimmer as="span" spread={3}>Loading data...</Shimmer>
```
**Features:**
- Framer Motion animation
- Gradient sweep effect
- Dynamic spread based on text length
- Infinite loop
### Loader
Spinning loader icon.
```typescript
type LoaderProps = HTMLAttributes<HTMLDivElement> & {
size?: number;
};
```
**Props:**
- `size?: number` - Icon size in pixels (default: 16)
**Usage:**
```tsx
<Loader size={24} />
<Loader className="text-primary" />
```
**Features:**
- SVG-based spinner
- CSS animation (spin)
- Respects current color
## Complete Example
```tsx
import {
Queue,
QueueSection,
QueueSectionTrigger,
QueueSectionLabel,
QueueSectionContent,
QueueList,
QueueItem,
QueueItemIndicator,
QueueItemContent,
QueueItemDescription,
} from "@/components/ai-elements/queue";
import {
Tool,
ToolHeader,
ToolContent,
ToolInput,
ToolOutput,
} from "@/components/ai-elements/tool";
import {
Confirmation,
ConfirmationTitle,
ConfirmationRequest,
ConfirmationActions,
ConfirmationAction,
ConfirmationAccepted,
ConfirmationRejected,
} from "@/components/ai-elements/confirmation";
import {
Reasoning,
ReasoningTrigger,
ReasoningContent,
} from "@/components/ai-elements/reasoning";
// approveTool and rejectTool are callbacks supplied by the caller.
function WorkflowDisplay({ todos, tools, reasoning, approveTool, rejectTool }) {
return (
<div className="space-y-4">
{/* Queue */}
<Queue>
<QueueSection>
<QueueSectionTrigger>
<QueueSectionLabel count={todos.length} label="tasks" />
</QueueSectionTrigger>
<QueueSectionContent>
<QueueList>
{todos.map(todo => (
<QueueItem key={todo.id}>
<QueueItemIndicator completed={todo.completed} />
<QueueItemContent completed={todo.completed}>
{todo.title}
</QueueItemContent>
{todo.description && (
<QueueItemDescription completed={todo.completed}>
{todo.description}
</QueueItemDescription>
)}
</QueueItem>
))}
</QueueList>
</QueueSectionContent>
</QueueSection>
</Queue>
{/* Reasoning */}
{reasoning && (
<Reasoning isStreaming={reasoning.isStreaming}>
<ReasoningTrigger />
<ReasoningContent>{reasoning.content}</ReasoningContent>
</Reasoning>
)}
{/* Tools */}
{tools.map(tool => (
<div key={tool.id}>
<Tool>
<ToolHeader
title={tool.name}
type={tool.type}
state={tool.state}
/>
<ToolContent>
<ToolInput input={tool.args} />
<ToolOutput output={tool.result} errorText={tool.error} />
</ToolContent>
</Tool>
{tool.requiresApproval && (
<Confirmation approval={tool.approval} state={tool.state}>
<ConfirmationTitle>
Approve {tool.name}?
</ConfirmationTitle>
<ConfirmationRequest>
<ConfirmationActions>
<ConfirmationAction
onClick={() => approveTool(tool.id)}
variant="default"
>
Approve
</ConfirmationAction>
<ConfirmationAction
onClick={() => rejectTool(tool.id)}
variant="outline"
>
Reject
</ConfirmationAction>
</ConfirmationActions>
</ConfirmationRequest>
<ConfirmationAccepted>
Tool approved and executed.
</ConfirmationAccepted>
<ConfirmationRejected>
Tool execution rejected.
</ConfirmationRejected>
</Confirmation>
)}
</div>
))}
</div>
);
}
```
---
name: sqlalchemy-code-review
description: Reviews SQLAlchemy code for session management, relationships, N+1 queries, and migration patterns. Use when reviewing SQLAlchemy 2.0 code, checking session lifecycle, relationship() usage, or Alembic migrations.
---
# SQLAlchemy Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Session lifecycle, context managers, async sessions | [references/sessions.md](references/sessions.md) |
| relationship(), lazy loading, N+1, joinedload | [references/relationships.md](references/relationships.md) |
| select() vs query(), ORM overhead, bulk ops | [references/queries.md](references/queries.md) |
| Alembic patterns, reversible migrations, data migrations | [references/migrations.md](references/migrations.md) |
## Review Checklist
- [ ] Sessions use context managers (`with`, `async with`)
- [ ] No session sharing across requests or threads
- [ ] Sessions closed/cleaned up properly
- [ ] `relationship()` uses appropriate `lazy` strategy
- [ ] Explicit `joinedload`/`selectinload` to avoid N+1
- [ ] No lazy loading in loops (N+1 queries)
- [ ] Using SQLAlchemy 2.0 `select()` syntax, not legacy `query()`
- [ ] Bulk operations use `insert()`/`update()` statements (or the legacy `bulk_insert_mappings`/`bulk_update_mappings`), not ORM loops
- [ ] Async sessions use proper async context managers
- [ ] Migrations are reversible with `downgrade()`
- [ ] Data migrations use `op.execute()` not ORM models
- [ ] Migration dependencies properly ordered
## Gates (SQLAlchemy-specific)
Run **once per SQLAlchemy-related finding**, after you can anchor **`file:line`** (see [review-verification-protocol](../review-verification-protocol/SKILL.md)) and **before** the finding ships. If a step’s pass condition is not met, **do not** assert the finding as written—gather evidence, withdraw, downgrade severity, or rephrase as a question.
### Gate 1 — Session scope and lifecycle
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 1a | Open the module where the session is created or injected (not from memory). | **`file:line`** for `Session`, `sessionmaker`, `async_session`, or the factory/`Depends()` that yields a session. |
| 1b | If claiming leak, cross-request sharing, or missing cleanup: trace the session’s scope (context manager, `try`/`finally`, middleware). | **Scoped region** cited with a **`file:line` range**, or withdraw if scope is correct after the read. |
### Gate 2 — N+1, lazy loading, eager loads
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 2a | Identify the loop or repeated call site (ORM attribute access, `execute` in a loop). | **`file:line`** for the loop or hot path. |
| 2b | If claiming N+1: name the relationship or query pattern emitted per iteration. | **Relationship or per-iteration SQL pattern** with **`file:line`**, or rephrase as a question if unclear. |
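One way to gather the per-iteration evidence that step 2b asks for is to count the statements a code path emits. The sketch below uses the standard library's `sqlite3` trace callback as a stand-in so it runs without dependencies; in a SQLAlchemy project the comparable hook would be an engine event listener such as `before_cursor_execute`. The schema, the data, and the helper names (`make_db`, `count_selects`, `load_n_plus_one`, `load_joined`) are illustrative assumptions, not part of any reviewed codebase.

```python
import sqlite3


def make_db():
    """Build an in-memory database with 3 users and 2 posts each (hypothetical schema)."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    conn.execute(
        "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER, title TEXT)"
    )
    conn.executemany(
        "INSERT INTO users (id, name) VALUES (?, ?)",
        [(1, "ann"), (2, "bob"), (3, "cy")],
    )
    conn.executemany(
        "INSERT INTO posts (user_id, title) VALUES (?, ?)",
        [(u, f"post-{u}-{n}") for u in (1, 2, 3) for n in (1, 2)],
    )
    conn.commit()
    return conn


def count_selects(conn, work):
    """Run work(conn) and return how many SELECT statements it issued."""
    seen = []
    conn.set_trace_callback(seen.append)
    try:
        work(conn)
    finally:
        conn.set_trace_callback(None)
    return sum(1 for sql in seen if sql.lstrip().upper().startswith("SELECT"))


def load_n_plus_one(conn):
    """One query for the users, then one more query per user."""
    users = conn.execute("SELECT id, name FROM users").fetchall()
    for user_id, _name in users:
        conn.execute(
            "SELECT title FROM posts WHERE user_id = ?", (user_id,)
        ).fetchall()


def load_joined(conn):
    """A single JOIN that fetches users and posts together."""
    conn.execute(
        "SELECT u.id, u.name, p.title "
        "FROM users u LEFT JOIN posts p ON p.user_id = u.id"
    ).fetchall()
```

A statement count that grows with the number of parent rows is the pattern worth citing in a finding; a count that stays constant suggests the claim should be withdrawn or rephrased as a question.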
### Gate 3 — Migrations (Alembic)
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 3a | Open the revision file (e.g. under `versions/`, or the project’s Alembic layout). | **Repo-relative path** + **`file:line`** for `revision` / `upgrade` / `downgrade`. |
| 3b | If claiming broken `downgrade()` or risky data migration: point at the `op.*` / `op.execute()` involved. | **Snippet or line range** in that file for each claimed op, or withdraw. |
## When to Load References
- Reviewing session creation/cleanup → sessions.md
- Reviewing model relationships → relationships.md
- Reviewing database queries → queries.md
- Reviewing Alembic migration files → migrations.md
## Review Questions
1. Are all sessions properly managed with context managers?
2. Are relationships configured to avoid N+1 queries?
3. Are queries using SQLAlchemy 2.0 `select()` syntax?
4. Are all migrations reversible and properly tested?
FILE:references/migrations.md
# Migrations
## Critical Anti-Patterns
### 1. Non-Reversible Migrations
**Problem**: Can't rollback, stuck on failed deploys.
```python
# BAD - no downgrade
"""Add user_role column
Revision ID: abc123
"""
def upgrade():
op.add_column('users', sa.Column('role', sa.String(50)))
def downgrade():
pass # Can't rollback!
# GOOD - reversible migration
def upgrade():
op.add_column('users', sa.Column('role', sa.String(50), nullable=True))
def downgrade():
op.drop_column('users', 'role')
```
### 2. Not Making New Columns Nullable First
**Problem**: Migration fails on existing data.
```python
# BAD - adding non-nullable column to existing table
def upgrade():
# Fails if table has existing rows!
op.add_column('users', sa.Column('email', sa.String(255), nullable=False))
# GOOD - two-step migration
def upgrade():
# Step 1: Add nullable column
op.add_column('users', sa.Column('email', sa.String(255), nullable=True))
# Then in a separate migration after backfilling data:
def upgrade():
# Step 2: Make it non-nullable
op.alter_column('users', 'email', nullable=False)
def downgrade():
op.alter_column('users', 'email', nullable=True)
# BETTER - add with server_default
def upgrade():
op.add_column(
'users',
sa.Column('email', sa.String(255), nullable=False, server_default='')
)
# Remove server_default in next migration after cleanup
```
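The effect of a default on a new `NOT NULL` column can be checked outside Alembic. The sketch below uses the standard library's `sqlite3` module rather than `op.add_column`, so it only illustrates the database-level behavior; note that SQLite is stricter than some databases and rejects a `NOT NULL` column without a default even on an empty table. The helper names and the `users` schema are assumptions for illustration.

```python
import sqlite3


def make_users_table():
    """Create an in-memory users table that already holds one row."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)")
    conn.execute("INSERT INTO users (username) VALUES ('ann')")
    conn.commit()
    return conn


def add_email_column(conn, with_default):
    """Try to add a NOT NULL email column; return True on success, False on failure."""
    ddl = "ALTER TABLE users ADD COLUMN email VARCHAR(255) NOT NULL"
    if with_default:
        ddl += " DEFAULT ''"
    try:
        conn.execute(ddl)
    except sqlite3.OperationalError:
        # The database refused the column because existing rows would violate NOT NULL
        return False
    return True
```

With the default in place, existing rows read back the default value, which is why the follow-up migration that removes the default should only run after real values have been backfilled.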
### 3. Using ORM Models in Migrations
**Problem**: Model changes break old migrations.
```python
# BAD - using ORM models directly
from app.models import User # DON'T!
def upgrade():
session = Session()
users = session.query(User).all() # Model might change!
for user in users:
user.email = f"{user.username}@example.com"
session.commit()
# GOOD - use op.execute with raw SQL
def upgrade():
op.execute(
"""
UPDATE users
SET email = username || '@example.com'
WHERE email IS NULL
"""
)
# BETTER - use Core Table for complex operations
from sqlalchemy import Integer, String, column, select, table, update
def upgrade():
    # Lightweight table definition that is independent of the ORM models
    users_table = table(
        'users',
        column('id', Integer),
        column('username', String),
        column('email', String)
    )
    connection = op.get_bind()
    users = connection.execute(
        select(users_table.c.id, users_table.c.username)
        .where(users_table.c.email.is_(None))
    ).fetchall()
    for user in users:
        connection.execute(
            update(users_table)
            .where(users_table.c.id == user.id)
            .values(email=f"{user.username}@example.com")
        )
```
### 4. Not Handling Concurrent Migrations
**Problem**: Multiple developers create conflicting migrations.
```python
# BAD - no dependency management
"""Add status column
Revision ID: abc123
Revises: xyz789
"""
# Developer B also based on xyz789 - conflict!
"""Add priority column
Revision ID: def456
Revises: xyz789 # Same parent!
"""
# GOOD - use down_revision properly
# Developer A
"""Add status column
Revision ID: abc123
Revises: xyz789
"""
# Developer B rebases
"""Add priority column
Revision ID: def456
Revises: abc123 # Updated after merge
"""
# BETTER - use alembic branches for long-running features (run in a shell):
#   alembic revision -m "feature branch" --branch-label feature_x --depends-on abc123
```
### 5. Dangerous DDL Without Transactions
**Problem**: Partial migrations leave database in broken state.
```python
# BAD - multiple DDL operations without transaction
def upgrade():
    op.create_table('temp_users', ...)
    op.execute("INSERT INTO temp_users SELECT * FROM users")
    op.drop_table('users')
    # If the rename fails without transactional DDL, temp_users exists but users is gone!
    op.rename_table('temp_users', 'users')
# GOOD - use batch operations for SQLite
def upgrade():
with op.batch_alter_table('users') as batch_op:
batch_op.add_column(sa.Column('new_col', sa.String(50)))
batch_op.drop_column('old_col')
# PostgreSQL supports transactional DDL
def upgrade():
# These all happen in a transaction by default
op.add_column('users', sa.Column('new_col', sa.String(50)))
op.drop_column('users', 'old_col')
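# For operations that can't be in a transaction
def upgrade():
    # CREATE INDEX CONCURRENTLY cannot run inside a transaction block
    with op.get_context().autocommit_block():
        op.execute("CREATE INDEX CONCURRENTLY idx_users_email ON users(email)")
def downgrade():
    with op.get_context().autocommit_block():
        op.execute("DROP INDEX CONCURRENTLY idx_users_email")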
```
### 6. Not Testing Migrations
**Problem**: Migrations fail in production.
```python
# BAD - no testing
def upgrade():
# Hope this works in production!
op.add_column('users', sa.Column('role', sa.String(50)))
# GOOD - test migrations in CI
# tests/test_migrations.py
import pytest
from alembic import command
from alembic.config import Config
def test_migration_upgrade_downgrade():
config = Config("alembic.ini")
# Test upgrade
command.upgrade(config, "head")
# Test downgrade
command.downgrade(config, "-1")
# Test re-upgrade
command.upgrade(config, "head")
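# BETTER - test with actual data
from sqlalchemy import create_engine, text
def test_migration_preserves_data():
    config = Config("alembic.ini")
    # Assumes alembic.ini sets sqlalchemy.url and the users table already exists
    engine = create_engine(config.get_main_option("sqlalchemy.url"))
    # Setup test data
    with engine.begin() as connection:
        connection.execute(
            text("INSERT INTO users (username, email) VALUES ('test', 'test@example.com')")
        )
    # Run migration
    command.upgrade(config, "head")
    # Verify data preserved (rowcount is not reliable for SELECT)
    with engine.connect() as connection:
        rows = connection.execute(
            text("SELECT * FROM users WHERE username = 'test'")
        ).fetchall()
    assert len(rows) == 1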
```
### 7. Not Using Batch Operations for SQLite
**Problem**: SQLite doesn't support many ALTER TABLE operations.
```python
# BAD - doesn't work on SQLite
def upgrade():
op.alter_column('users', 'email', type_=sa.String(512)) # Fails on SQLite!
# GOOD - use batch operations
def upgrade():
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.alter_column('email', type_=sa.String(512))
def downgrade():
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.alter_column('email', type_=sa.String(255))
```
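Alembic's batch mode works around SQLite's limited `ALTER TABLE` with a "move and copy" workflow: create a new table with the desired shape, copy the rows, drop the old table, and rename the new one. The sketch below performs those steps by hand with the standard library's `sqlite3` module to show what the batch block automates; it is not Alembic's implementation, and the `_tmp_users` name and two-column schema are assumptions for illustration.

```python
import sqlite3


def make_users_table():
    """Create an in-memory users table with a VARCHAR(255) email column."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email VARCHAR(255))")
    conn.execute("INSERT INTO users (email) VALUES ('ann@example.com')")
    conn.commit()
    return conn


def widen_email_column(conn):
    """Rebuild users with a wider email column inside one transaction."""
    conn.executescript(
        """
        BEGIN;
        CREATE TABLE _tmp_users (id INTEGER PRIMARY KEY, email VARCHAR(512));
        INSERT INTO _tmp_users (id, email) SELECT id, email FROM users;
        DROP TABLE users;
        ALTER TABLE _tmp_users RENAME TO users;
        COMMIT;
        """
    )


def email_column_type(conn):
    """Return the declared type of users.email via PRAGMA table_info."""
    for _cid, name, col_type, *_rest in conn.execute("PRAGMA table_info(users)"):
        if name == "email":
            return col_type
    return None
```

Because the whole table is rewritten, this approach costs time proportional to the table size, which is worth keeping in mind when a batch migration targets a large table.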
### 8. Not Handling Large Data Migrations
**Problem**: Migration times out or locks table.
```python
# BAD - single UPDATE locks entire table
def upgrade():
op.execute(
"UPDATE users SET normalized_email = LOWER(email)"
) # Locks millions of rows!
# GOOD - batch updates
import time
from sqlalchemy import text
def upgrade():
    connection = op.get_bind()
    batch_size = 1000
    # No OFFSET: updated rows drop out of the IS NULL filter on their own
    batch_update = text(
        """
        UPDATE users
        SET normalized_email = LOWER(email)
        WHERE id IN (
            SELECT id FROM users
            WHERE normalized_email IS NULL AND email IS NOT NULL
            ORDER BY id
            LIMIT :batch_size
        )
        """
    )
    while True:
        result = connection.execute(batch_update, {"batch_size": batch_size})
        if result.rowcount == 0:
            break
        # Sleep to avoid overwhelming the database
        time.sleep(0.1)
# BETTER - use queue/background job for very large tables
def upgrade():
# Add column
op.add_column('users', sa.Column('normalized_email', sa.String(255)))
# Create background job to populate
# (Actual backfill happens outside migration)
pass
```
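A batched backfill that selects rows with `normalized_email IS NULL` can let the filter itself advance the loop, since each updated row stops matching; combining that filter with a growing `OFFSET` would skip rows. The sketch below demonstrates the loop with the standard library's `sqlite3` module instead of `op.get_bind()`, so it shows the SQL pattern only. The table contents and the function names are illustrative assumptions.

```python
import sqlite3


def make_users_table():
    """Five users with emails plus one user whose email is NULL."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT, normalized_email TEXT)"
    )
    emails = [("A@x.io",), ("B@x.io",), ("C@x.io",), ("D@x.io",), ("E@x.io",), (None,)]
    conn.executemany("INSERT INTO users (email) VALUES (?)", emails)
    conn.commit()
    return conn


def backfill_normalized_email(conn, batch_size=2):
    """Backfill in small batches; return the number of non-empty batches."""
    batches = 0
    while True:
        cursor = conn.execute(
            """
            UPDATE users
            SET normalized_email = LOWER(email)
            WHERE id IN (
                SELECT id FROM users
                WHERE normalized_email IS NULL AND email IS NOT NULL
                ORDER BY id
                LIMIT ?
            )
            """,
            (batch_size,),
        )
        conn.commit()  # Release locks between batches
        if cursor.rowcount == 0:
            break
        batches += 1
    return batches
```

The `email IS NOT NULL` condition matters: without it, a row whose email is `NULL` would be selected on every pass and the loop would never terminate.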
### 9. Not Managing Indexes Properly
**Problem**: Slow queries after migration, or failed migrations.
```python
# BAD - adding index inline blocks table
def upgrade():
op.add_column('users', sa.Column('email', sa.String(255)))
op.create_index('idx_users_email', 'users', ['email']) # Locks table!
# GOOD - create index concurrently (PostgreSQL)
def upgrade():
    op.add_column('users', sa.Column('email', sa.String(255)))
    # CONCURRENTLY cannot run inside a transaction, so use an autocommit block
    with op.get_context().autocommit_block():
        op.execute("CREATE INDEX CONCURRENTLY idx_users_email ON users(email)")
def downgrade():
    with op.get_context().autocommit_block():
        op.execute("DROP INDEX CONCURRENTLY idx_users_email")
    op.drop_column('users', 'email')
# BETTER - track index creation separately
def upgrade():
op.add_column('users', sa.Column('email', sa.String(255)))
# Create index in a separate migration
```
### 10. Not Documenting Complex Migrations
**Problem**: Team doesn't understand migration purpose or impact.
```python
# BAD - no documentation
"""revision abc123
"""
def upgrade():
op.execute("complex SQL here...")
# GOOD - clear documentation
"""Add normalized_email column for case-insensitive lookups
This migration:
1. Adds a new normalized_email column (nullable initially)
2. Backfills it with lowercase email values
3. Creates a unique index on normalized_email
4. Does NOT make it non-nullable yet (requires follow-up migration)
Expected duration: ~2 minutes for 1M users
Locks: Brief lock during index creation
Rollback safe: Yes
Revision ID: abc123
Revises: xyz789
Create Date: 2024-01-15 10:30:00
"""
def upgrade():
# Step 1: Add column
op.add_column(
'users',
sa.Column('normalized_email', sa.String(255), nullable=True)
)
# Step 2: Backfill in batches
connection = op.get_bind()
batch_size = 1000
# ... batched update logic ...
# Step 3: Create index
op.create_index(
'idx_users_normalized_email',
'users',
['normalized_email'],
unique=True
)
def downgrade():
op.drop_index('idx_users_normalized_email', table_name='users')
op.drop_column('users', 'normalized_email')
```
### 11. Not Using Check Constraints
**Problem**: Invalid data gets inserted.
```python
# BAD - no constraints, rely on application validation
def upgrade():
op.add_column('users', sa.Column('age', sa.Integer))
# GOOD - add check constraints
def upgrade():
op.add_column('users', sa.Column('age', sa.Integer))
op.create_check_constraint(
'ck_users_age_positive',
'users',
'age >= 0 AND age <= 150'
)
def downgrade():
    # type_ is required by some backends (e.g. MySQL) to pick the right DROP syntax
    op.drop_constraint('ck_users_age_positive', 'users', type_='check')
    op.drop_column('users', 'age')
# BETTER - use enum for limited values
from sqlalchemy import Enum
def upgrade():
role_enum = sa.Enum('user', 'admin', 'moderator', name='user_role')
role_enum.create(op.get_bind())
op.add_column(
'users',
sa.Column('role', role_enum, nullable=False, server_default='user')
)
def downgrade():
op.drop_column('users', 'role')
sa.Enum(name='user_role').drop(op.get_bind())
```
## Review Questions
1. Does every migration have a working `downgrade()` function?
2. Are new non-nullable columns added in two steps (nullable first, then constrain)?
3. Are data migrations using `op.execute()` not ORM models?
4. Are large data updates batched to avoid timeouts?
5. Are indexes created with CONCURRENTLY on PostgreSQL?
6. Are complex migrations documented with expected duration and impact?
7. Are constraints (CHECK, UNIQUE, FK) properly created and dropped?
FILE:references/queries.md
# Queries
## Critical Anti-Patterns
### 1. Using Legacy query() Instead of select()
**Problem**: `Query` is the legacy API in SQLAlchemy 2.0; new code should use `select()`.
```python
# BAD - legacy query() API
def get_active_users():
with Session() as session:
users = session.query(User).filter(User.active == True).all()
return users
# GOOD - SQLAlchemy 2.0 select() syntax
from sqlalchemy import select
def get_active_users():
with Session() as session:
result = session.execute(
select(User).where(User.active == True)
)
return result.scalars().all()
# ASYNC version
async def get_active_users():
async with AsyncSession() as session:
result = await session.execute(
select(User).where(User.active == True)
)
return result.scalars().all()
```
### 2. Loading Full Objects When Only Columns Needed
**Problem**: ORM overhead, unnecessary data transfer.
```python
# BAD - loading full ORM objects just for one column
def get_user_emails():
with Session() as session:
users = session.execute(select(User)).scalars().all()
return [user.email for user in users] # Loaded entire object!
# GOOD - select only needed columns
def get_user_emails():
with Session() as session:
result = session.execute(
select(User.email)
)
return result.scalars().all()
# BETTER - multiple columns as tuples
def get_user_info():
with Session() as session:
result = session.execute(
select(User.id, User.name, User.email)
)
return result.all() # Returns list of tuples
```
### 3. Using all() When Only One Result Expected
**Problem**: Confusing API, loads unnecessary data.
```python
# BAD - using all() when expecting one result
def get_user_by_email(email: str):
with Session() as session:
users = session.execute(
select(User).where(User.email == email)
).scalars().all()
return users[0] if users else None # Awkward!
# GOOD - use scalar_one_or_none()
def get_user_by_email(email: str) -> User | None:
with Session() as session:
return session.execute(
select(User).where(User.email == email)
).scalar_one_or_none()
# Use scalar_one() if must exist (raises if not found)
def get_user_by_id(user_id: int) -> User:
with Session() as session:
return session.execute(
select(User).where(User.id == user_id)
).scalar_one() # Raises NoResultFound or MultipleResultsFound
```
### 4. Not Using Bulk Operations
**Problem**: ORM overhead per object, slow inserts/updates.
```python
# BAD - ORM insert in loop
def create_users(user_data: list[dict]):
with Session() as session:
for data in user_data:
user = User(**data)
session.add(user) # Individual ORM overhead per user
session.commit()
# GOOD - bulk insert (legacy method in 2.0; prefer the insert() form below)
def create_users(user_data: list[dict]):
with Session() as session:
session.bulk_insert_mappings(User, user_data)
session.commit()
# BETTER - Core insert for maximum performance
from sqlalchemy import insert
def create_users(user_data: list[dict]):
with Session() as session:
session.execute(
insert(User),
user_data
)
session.commit()
# ASYNC bulk insert
async def create_users(user_data: list[dict]):
async with AsyncSession() as session:
await session.execute(
insert(User),
user_data
)
await session.commit()
```
### 5. Not Using Bulk Updates
**Problem**: ORM overhead, multiple UPDATE statements.
```python
# BAD - update in loop
def deactivate_old_users(cutoff_date):
with Session() as session:
users = session.execute(
select(User).where(User.last_login < cutoff_date)
).scalars().all()
for user in users:
user.active = False # Individual UPDATE per user
session.commit()
# GOOD - single UPDATE statement
from sqlalchemy import update
def deactivate_old_users(cutoff_date):
with Session() as session:
session.execute(
update(User)
.where(User.last_login < cutoff_date)
.values(active=False)
)
session.commit()
# ASYNC version
async def deactivate_old_users(cutoff_date):
async with AsyncSession() as session:
await session.execute(
update(User)
.where(User.last_login < cutoff_date)
.values(active=False)
)
await session.commit()
```
### 6. Not Using exists() for Existence Checks
**Problem**: Loads unnecessary data just to check existence.
```python
# BAD - loading data just to check existence
def user_exists(email: str) -> bool:
with Session() as session:
user = session.execute(
select(User).where(User.email == email)
).scalar_one_or_none()
return user is not None # Loaded entire object!
# GOOD - use exists()
from sqlalchemy import exists, select
def user_exists(email: str) -> bool:
with Session() as session:
return session.execute(
select(exists().where(User.email == email))
).scalar()
# Alternative with count (less efficient but sometimes clearer)
from sqlalchemy import func
def user_exists(email: str) -> bool:
with Session() as session:
count = session.execute(
select(func.count()).select_from(User).where(User.email == email)
).scalar()
return count > 0
```
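At the SQL level, an existence check boils down to `SELECT EXISTS (...)`, which returns a single boolean-like value instead of a row of column data. The sketch below shows that query shape with the standard library's `sqlite3` module; it approximates what the `exists()` construct above emits rather than reproducing SQLAlchemy's exact SQL. The schema and helper names are assumptions for illustration.

```python
import sqlite3


def make_users_table():
    """Create an in-memory users table with a single known email."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)")
    conn.execute("INSERT INTO users (email) VALUES ('ann@example.com')")
    conn.commit()
    return conn


def user_exists(conn, email):
    """Return True if any row matches, without fetching the row itself."""
    row = conn.execute(
        "SELECT EXISTS (SELECT 1 FROM users WHERE email = ?)", (email,)
    ).fetchone()
    return bool(row[0])
```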
### 7. Not Using Pagination
**Problem**: Memory exhaustion on large result sets.
```python
# BAD - loading all results into memory
def get_all_users():
with Session() as session:
users = session.execute(select(User)).scalars().all() # OOM on millions!
return users
# GOOD - use limit/offset for pagination
def get_users_page(page: int = 1, page_size: int = 100):
    with Session() as session:
        offset = (page - 1) * page_size
        users = session.execute(
            select(User)
            .order_by(User.id)  # Stable ordering keeps pages consistent
            .offset(offset)
            .limit(page_size)
        ).scalars().all()
        return users
# BETTER - use keyset pagination for large datasets
def get_users_after(last_id: int | None = None, page_size: int = 100):
    with Session() as session:
        query = select(User).order_by(User.id)
        if last_id is not None:  # Explicit None check so last_id=0 still filters
            query = query.where(User.id > last_id)
        users = session.execute(
            query.limit(page_size)
        ).scalars().all()
        return users
# BEST - stream results for very large datasets
def stream_all_users():
    with Session() as session:
        result = session.execute(
            select(User).execution_options(yield_per=1000)
        )
        for user in result.scalars():  # Fetches rows in batches of 1000
            yield user
```
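Keyset pagination can be sketched end to end: each page asks for rows whose key is greater than the last key seen, so the database can seek through the primary-key index instead of scanning and discarding `OFFSET` rows. The example uses the standard library's `sqlite3` module to show the query pattern on its own; `fetch_page` and `iter_all_users` are hypothetical helper names, and the seven-row table is made up for the demonstration.

```python
import sqlite3


def make_users_table(count=7):
    """Create an in-memory users table with ids 1..count."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    conn.executemany(
        "INSERT INTO users (id, name) VALUES (?, ?)",
        [(i, f"user-{i}") for i in range(1, count + 1)],
    )
    conn.commit()
    return conn


def fetch_page(conn, last_id=None, page_size=3):
    """Return the next page of (id, name) rows after last_id."""
    if last_id is None:
        return conn.execute(
            "SELECT id, name FROM users ORDER BY id LIMIT ?", (page_size,)
        ).fetchall()
    return conn.execute(
        "SELECT id, name FROM users WHERE id > ? ORDER BY id LIMIT ?",
        (last_id, page_size),
    ).fetchall()


def iter_all_users(conn, page_size=3):
    """Walk the whole table one page at a time."""
    last_id = None
    while True:
        page = fetch_page(conn, last_id, page_size)
        if not page:
            return
        yield from page
        last_id = page[-1][0]  # The key of the final row seeds the next page
```

The trade-off is that keyset pagination supports "next page" navigation but not jumping to an arbitrary page number, which offset-based pagination still handles.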
### 8. Not Using with_for_update for Row Locking
**Problem**: Race conditions in concurrent updates.
```python
# BAD - race condition in concurrent requests
def decrement_stock(product_id: int, quantity: int):
with Session() as session:
product = session.execute(
select(Product).where(Product.id == product_id)
).scalar_one()
# Another request could modify stock here!
if product.stock >= quantity:
product.stock -= quantity
session.commit()
else:
raise ValueError("Insufficient stock")
# GOOD - use SELECT FOR UPDATE
def decrement_stock(product_id: int, quantity: int):
with Session() as session:
with session.begin():
product = session.execute(
select(Product)
.where(Product.id == product_id)
.with_for_update() # Row locked until commit
).scalar_one()
if product.stock >= quantity:
product.stock -= quantity
else:
raise ValueError("Insufficient stock")
# ASYNC version
async def decrement_stock(product_id: int, quantity: int):
async with AsyncSession() as session:
async with session.begin():
result = await session.execute(
select(Product)
.where(Product.id == product_id)
.with_for_update()
)
product = result.scalar_one()
if product.stock >= quantity:
product.stock -= quantity
else:
raise ValueError("Insufficient stock")
```
### 9. Using String-Based Filters Instead of Column Objects
**Problem**: No IDE support, error-prone, SQL injection risk.
```python
# BAD - keyword-based filter_by(): attribute names are not checked by IDEs or type checkers
def search_users(name: str):
    with Session() as session:
        users = session.execute(
            select(User).filter_by(name=name)  # Keyword-based
        ).scalars().all()
        return users
# WORSE - string SQL (SQL injection risk!)
def search_users(name: str):
with Session() as session:
users = session.execute(
f"SELECT * FROM users WHERE name = '{name}'" # NEVER DO THIS!
).all()
# GOOD - column object filters
def search_users(name: str):
with Session() as session:
users = session.execute(
select(User).where(User.name == name) # Type-safe
).scalars().all()
return users
# BETTER - build complex filters conditionally from column expressions
def search_users_complex(filters: dict):
    with Session() as session:
        query = select(User)
        if "name" in filters:
            query = query.where(User.name.contains(filters["name"]))
        if "active" in filters:
            query = query.where(User.active == filters["active"])
        users = session.execute(query).scalars().all()
        return users
```
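The injection risk of string-built SQL is easy to reproduce. The sketch below uses the standard library's `sqlite3` module to contrast a query assembled with an f-string against one that passes the value as a bound parameter, which is what column-object filters such as `User.name == name` produce under the hood. The table contents, the payload, and the function names are made up for the demonstration.

```python
import sqlite3


def make_users_table():
    """Create an in-memory users table with three rows."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    conn.executemany(
        "INSERT INTO users (name) VALUES (?)", [("ann",), ("bob",), ("cy",)]
    )
    conn.commit()
    return conn


def find_unsafe(conn, name):
    """Deliberately vulnerable: user input is spliced into the SQL text."""
    return conn.execute(f"SELECT id FROM users WHERE name = '{name}'").fetchall()


def find_safe(conn, name):
    """Bound parameter: the driver treats the input purely as a value."""
    return conn.execute("SELECT id FROM users WHERE name = ?", (name,)).fetchall()


# A classic payload that turns the WHERE clause into a tautology
PAYLOAD = "nobody' OR '1'='1"
```

With the payload, the unsafe version matches every row while the safe version matches none, because no user is literally named `nobody' OR '1'='1`.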
### 10. Not Using Subqueries Efficiently
**Problem**: Multiple queries instead of single subquery.
```python
from datetime import datetime, timedelta
# BAD - multiple queries
def get_users_with_recent_posts():
with Session() as session:
# First query
recent_post_user_ids = session.execute(
select(Post.user_id)
.where(Post.created_at > datetime.now() - timedelta(days=7))
.distinct()
).scalars().all()
# Second query
users = session.execute(
select(User).where(User.id.in_(recent_post_user_ids))
).scalars().all()
return users
# GOOD - single query with subquery
def get_users_with_recent_posts():
with Session() as session:
recent_posts_subq = (
select(Post.user_id)
.where(Post.created_at > datetime.now() - timedelta(days=7))
.distinct()
.subquery()
)
users = session.execute(
select(User).where(User.id.in_(select(recent_posts_subq.c.user_id)))
).scalars().all()
return users
# BETTER - use join
def get_users_with_recent_posts():
with Session() as session:
users = session.execute(
select(User)
.join(Post)
.where(Post.created_at > datetime.now() - timedelta(days=7))
.distinct()
).scalars().all()
return users
```
### 11. Not Using union/union_all
**Problem**: Multiple queries when one combined query would work.
```python
# BAD - multiple queries
def get_all_content():
with Session() as session:
posts = session.execute(select(Post)).scalars().all()
pages = session.execute(select(Page)).scalars().all()
return {"posts": posts, "pages": pages}
# GOOD - union query (if columns match)
from sqlalchemy import literal, select, union_all
def get_all_content_items():
with Session() as session:
posts_query = select(
Post.id,
Post.title,
Post.created_at,
literal("post").label("type")
)
pages_query = select(
Page.id,
Page.title,
Page.created_at,
literal("page").label("type")
)
combined = union_all(posts_query, pages_query)
result = session.execute(combined).all()
return result
```
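To see what a combined query of this shape returns, here is a minimal sketch that uses the standard-library `sqlite3` module instead of SQLAlchemy, so it runs without an engine or mapped classes. The `posts` and `pages` tables and their rows are made up for illustration; the `'post'` and `'page'` string literals play the role of `literal(...).label("type")`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT);
    CREATE TABLE pages (id INTEGER PRIMARY KEY, title TEXT);
    INSERT INTO posts (title) VALUES ('Hello'), ('World');
    INSERT INTO pages (title) VALUES ('About');
    """
)

# One round trip: both SELECTs expose the same columns, plus a literal
# discriminator so callers can tell which table each row came from.
rows = conn.execute(
    """
    SELECT id, title, 'post' AS type FROM posts
    UNION ALL
    SELECT id, title, 'page' AS type FROM pages
    ORDER BY type, id
    """
).fetchall()
```

Each row carries its own `type` value, so the caller does not need a second query to know the source table.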
## Review Questions
1. Are all queries using SQLAlchemy 2.0 `select()` syntax not legacy `query()`?
2. Are bulk operations used for batch inserts/updates?
3. Are only required columns selected when full objects aren't needed?
4. Is `exists()` used instead of loading objects for existence checks?
5. Is pagination implemented for large result sets?
6. Is `with_for_update()` used for concurrent updates?
7. Are column objects used instead of string-based filters?
FILE:references/relationships.md
# Relationships
## Critical Anti-Patterns
### 1. N+1 Query Problem
**Problem**: One query per related object, severe performance degradation.
```python
# BAD - N+1 queries
def get_users_with_posts():
with Session() as session:
users = session.execute(select(User)).scalars().all()
result = []
for user in users:
# Each access triggers a separate query!
posts = user.posts # SELECT * FROM posts WHERE user_id = ?
result.append({"user": user, "posts": posts})
return result
# GOOD - eager load with joinedload
from sqlalchemy.orm import joinedload
def get_users_with_posts():
with Session() as session:
users = session.execute(
select(User).options(joinedload(User.posts))
).unique().scalars().all()
return users
# ASYNC version
async def get_users_with_posts():
async with AsyncSession() as session:
result = await session.execute(
select(User).options(joinedload(User.posts))
)
return result.unique().scalars().all()
```
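The difference in statement counts can be observed directly. The sketch below is an illustration at the SQL level using the standard-library `sqlite3` module rather than SQLAlchemy; the tables, rows, and the `statements` list are assumptions made for the demo. It records every statement sent to the database for the N+1 shape and for a single joined query.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER, title TEXT);
    INSERT INTO users (name) VALUES ('ann'), ('bob'), ('cy');
    INSERT INTO posts (user_id, title) VALUES (1, 'a1'), (1, 'a2'), (2, 'b1');
    """
)

statements = []
conn.set_trace_callback(statements.append)  # record every statement executed

# N+1 shape: one query for the users, then one more per user for posts
user_ids = [row[0] for row in conn.execute("SELECT id FROM users")]
for user_id in user_ids:
    conn.execute(
        "SELECT title FROM posts WHERE user_id = ?", (user_id,)
    ).fetchall()
n_plus_one = len(statements)

statements.clear()

# Eager shape: a single joined query returns users and posts together
conn.execute(
    "SELECT u.id, p.title FROM users u LEFT JOIN posts p ON p.user_id = u.id"
).fetchall()
eager = len(statements)
```

With three users the loop issues four statements while the join issues one; the gap grows linearly with the number of parent rows.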
### 2. Wrong Lazy Loading Strategy
**Problem**: Default lazy loading causes N+1 in most real-world scenarios.
```python
# BAD - default lazy='select' causes N+1
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user") # lazy='select' by default
# GOOD - choose appropriate lazy strategy
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
# Option 1: lazy='joined' - always join
posts = relationship("Post", back_populates="user", lazy="joined")
# Option 2: lazy='selectin' - single extra query
posts = relationship("Post", back_populates="user", lazy="selectin")
# Option 3: lazy='raise' - force explicit loading
posts = relationship("Post", back_populates="user", lazy="raise")
# BEST - use lazy='raise' and explicit loading at query time
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user", lazy="raise")
# Then explicitly load when needed
def get_user_with_posts(user_id: int):
with Session() as session:
user = session.execute(
select(User)
.options(selectinload(User.posts))
.where(User.id == user_id)
).scalar_one()
return user
```
### 3. Missing back_populates
**Problem**: One-way relationship, inconsistent state, bugs.
```python
# BAD - missing back_populates
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post")
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
# No relationship back to User!
# GOOD - bidirectional with back_populates
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user")
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
user = relationship("User", back_populates="posts")
```
### 4. Cascade Not Set Properly
**Problem**: Orphaned records, foreign key violations.
```python
# BAD - no cascade, orphaned posts when user deleted
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user")
# Deleting user leaves orphaned posts or fails with FK constraint
# GOOD - proper cascade for composition
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship(
"Post",
back_populates="user",
cascade="all, delete-orphan" # Delete posts when user deleted
)
# For many-to-many, different cascade
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
groups = relationship(
"Group",
secondary="user_groups",
back_populates="users",
cascade="save-update, merge" # Don't delete groups
)
```
### 5. Using joinedload with Many-to-Many
**Problem**: Cartesian product explosion, duplicate rows.
```python
# BAD - joinedload with many-to-many causes duplicates
def get_users_with_groups_and_posts():
with Session() as session:
users = session.execute(
select(User)
.options(joinedload(User.groups))
.options(joinedload(User.posts))
).scalars().all() # Cartesian product: users × groups × posts!
# GOOD - use selectinload for collections
from sqlalchemy.orm import selectinload
def get_users_with_groups_and_posts():
with Session() as session:
users = session.execute(
select(User)
.options(selectinload(User.groups))
.options(selectinload(User.posts))
).scalars().all() # Two separate IN queries, no cartesian product
```
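A rough row count makes the cost concrete. The sketch below is plain arithmetic, not SQLAlchemy code; the function names and the example of 10 groups and 50 posts are hypothetical. It compares the rows returned for one user when two collections are joined in the same query with the rows returned when each collection is fetched by its own `IN` query.

```python
def joined_rows(n_groups: int, n_posts: int) -> int:
    """Rows for one user when two collections are LEFT JOINed together."""
    # An outer join still yields one row when a collection is empty.
    return max(n_groups, 1) * max(n_posts, 1)


def selectin_rows(n_groups: int, n_posts: int) -> int:
    """Rows for one user across the parent query plus two IN queries."""
    return 1 + n_groups + n_posts


# Hypothetical user with 10 groups and 50 posts
joined = joined_rows(10, 50)
selectin = selectin_rows(10, 50)
```

The joined form multiplies the collection sizes, while the select-in form adds them, which is why the gap widens quickly as collections grow.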
### 6. Not Using contains_eager for Filtered Joins
**Problem**: Inefficient loading when filtering related objects.
```python
# BAD - loads all posts, then filters in Python
def get_users_with_published_posts():
with Session() as session:
users = session.execute(
select(User).options(selectinload(User.posts))
).scalars().all()
# Filters in Python, wasteful
return [
{
"user": user,
"posts": [p for p in user.posts if p.published]
}
for user in users
]
# GOOD - use contains_eager with join filter
from sqlalchemy.orm import contains_eager
def get_users_with_published_posts():
with Session() as session:
users = session.execute(
select(User)
.join(User.posts)
.where(Post.published == True)
.options(contains_eager(User.posts))
).unique().scalars().all()
return users
```
### 7. Circular Eager Loading
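**Problem**: Eager loading is configured on both sides of a bidirectional relationship. SQLAlchemy stops chaining at the cycle rather than recursing forever, but every query still pulls in related rows whether or not they are needed.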
```python
# BAD - circular eager loading
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user", lazy="joined")
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
user = relationship("User", back_populates="posts", lazy="joined")
# Every User query also joins Posts, and every Post query joins User, needed or not
# GOOD - one side lazy, or explicit loading
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
posts = relationship("Post", back_populates="user", lazy="raise")
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
user = relationship("User", back_populates="posts", lazy="raise")
# Explicitly load what you need
def get_user_with_posts(user_id: int):
with Session() as session:
return session.execute(
select(User)
.options(selectinload(User.posts))
.where(User.id == user_id)
).scalar_one()
```
### 8. Not Using Association Object for Rich M2M
**Problem**: Can't store additional attributes on join table.
```python
# BAD - simple secondary table, can't add attributes
user_groups = Table(
"user_groups",
Base.metadata,
Column("user_id", ForeignKey("users.id")),
Column("group_id", ForeignKey("groups.id"))
)
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
groups = relationship("Group", secondary=user_groups)
# Can't store "joined_at" or "role" on the relationship!
# GOOD - association object pattern
class UserGroup(Base):
__tablename__ = "user_groups"
user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
group_id = Column(Integer, ForeignKey("groups.id"), primary_key=True)
joined_at = Column(DateTime, default=datetime.utcnow)
role = Column(String) # "admin", "member", etc.
user = relationship("User", back_populates="group_associations")
group = relationship("Group", back_populates="user_associations")
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
group_associations = relationship("UserGroup", back_populates="user")
# Convenience property
@property
def groups(self):
return [assoc.group for assoc in self.group_associations]
class Group(Base):
__tablename__ = "groups"
id = Column(Integer, primary_key=True)
user_associations = relationship("UserGroup", back_populates="group")
```
### 9. Not Using raiseload for Debugging
**Problem**: N+1 queries slip into production unnoticed.
```python
# BAD - lazy loading hides N+1 issues until production
from sqlalchemy.orm import Session
def get_users():
with Session() as session:
users = session.execute(select(User)).scalars().all()
# Accessing posts triggers lazy load - silent N+1 in production
for user in users:
print(user.posts)
# GOOD - use raiseload in development to catch issues
from sqlalchemy.orm import raiseload
def get_users():
with Session() as session:
users = session.execute(
select(User).options(raiseload("*")) # Raise on any lazy load
).scalars().all()
# This will raise immediately, forcing us to fix it
for user in users:
print(user.posts) # InvalidRequestError!
# FIX - explicit loading
def get_users():
with Session() as session:
users = session.execute(
select(User).options(selectinload(User.posts))
).scalars().all()
for user in users:
print(user.posts) # No lazy load, efficient!
```
## Review Questions
1. Are all relationship queries using explicit eager loading (joinedload, selectinload)?
2. Is `lazy='raise'` used to prevent accidental lazy loading?
3. Do all relationships have proper `back_populates`?
4. Are cascade options set appropriately for composition vs association?
5. Is `selectinload` used instead of `joinedload` for collections?
6. Are association objects used for many-to-many with attributes?
FILE:references/sessions.md
# Sessions
## Critical Anti-Patterns
### 1. Session Not Closed
**Problem**: Connection pool exhaustion, memory leaks.
```python
# BAD - session never closed
def get_user(user_id: int):
session = Session()
user = session.get(User, user_id)
return user # Session leaked!
# GOOD - using context manager
def get_user(user_id: int) -> User | None:
with Session() as session:
user = session.get(User, user_id)
return user
```
### 2. Session Shared Across Requests
**Problem**: Concurrent modifications, race conditions, data corruption.
```python
# BAD - global session shared across requests
session = Session()  # Module-level!

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    user = session.get(User, user_id)  # Multiple requests share session!
    return user

# GOOD - request-scoped session
from fastapi import Depends

async def get_db_session():
    # FastAPI expects a plain async generator here; decorating it with
    # @asynccontextmanager would inject the context manager, not the session.
    async with AsyncSession() as session:  # closes the session on exit
        yield session

@app.get("/users/{user_id}")
async def get_user(user_id: int, session: AsyncSession = Depends(get_db_session)):
    user = await session.get(User, user_id)
    return user
```
### 3. Manual Commit Without Rollback Handling
**Problem**: Partial commits, inconsistent state on errors.
```python
# BAD - no rollback on error
def create_user(name: str, email: str):
session = Session()
user = User(name=name, email=email)
session.add(user)
session.commit() # If this fails, session corrupted
session.close()
return user
# GOOD - proper error handling
def create_user(name: str, email: str) -> User:
with Session() as session:
try:
user = User(name=name, email=email)
session.add(user)
session.commit()
return user
except Exception:
session.rollback()
raise
```
### 4. Using Sync Session in Async Context
**Problem**: Blocks event loop, poor performance.
```python
# BAD - blocking sync session in async
from sqlalchemy.orm import Session
async def get_user(user_id: int):
with Session() as session: # Blocks event loop!
user = session.get(User, user_id)
return user
# GOOD - async session
from sqlalchemy.ext.asyncio import AsyncSession
async def get_user(user_id: int) -> User | None:
async with AsyncSession() as session:
result = await session.execute(
select(User).where(User.id == user_id)
)
return result.scalar_one_or_none()
```
### 5. Session Used After Commit
**Problem**: DetachedInstanceError, expired objects.
```python
# BAD - accessing object state after the session is closed
def get_user_data(user_id: int):
    with Session() as session:
        user = session.get(User, user_id)
    # Raises DetachedInstanceError if the attribute is expired or was never
    # loaded (for example after a commit, or for a lazy relationship)
    return user.email

# GOOD - access data before session closes
def get_user_data(user_id: int) -> str | None:
    with Session() as session:
        user = session.get(User, user_id)
        if user:
            return user.email
        return None

# BETTER - use expunge or eager loading
from sqlalchemy import select
from sqlalchemy.orm import joinedload

def get_user_with_posts(user_id: int) -> User | None:
    with Session() as session:
        user = session.execute(
            select(User)
            .options(joinedload(User.posts))
            .where(User.id == user_id)
        ).unique().scalar_one_or_none()  # unique() is required with joined collections
        if user:
            session.expunge(user)  # Detach from session
        return user
```
### 6. Not Using Session.begin() for Transactions
**Problem**: AutoCommit confusion, no explicit transaction boundaries.
```python
# BAD - implicit transaction boundaries
def transfer_money(from_id: int, to_id: int, amount: float):
with Session() as session:
from_account = session.get(Account, from_id)
to_account = session.get(Account, to_id)
from_account.balance -= amount
session.commit() # First commit
to_account.balance += amount
session.commit() # Second commit - money lost if this fails!
# GOOD - explicit transaction with begin()
def transfer_money(from_id: int, to_id: int, amount: float):
with Session() as session:
with session.begin():
from_account = session.get(Account, from_id)
to_account = session.get(Account, to_id)
if from_account.balance < amount:
raise ValueError("Insufficient funds")
from_account.balance -= amount
to_account.balance += amount
# Both committed together or rolled back together
# ASYNC version
async def transfer_money(from_id: int, to_id: int, amount: float):
async with AsyncSession() as session:
async with session.begin():
result = await session.execute(
select(Account).where(Account.id.in_([from_id, to_id]))
)
accounts = {acc.id: acc for acc in result.scalars()}
from_account = accounts[from_id]
to_account = accounts[to_id]
if from_account.balance < amount:
raise ValueError("Insufficient funds")
from_account.balance -= amount
to_account.balance += amount
```
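The all-or-nothing behaviour of a single transaction can be checked without a full ORM setup. The following sketch uses the standard-library `sqlite3` module as a stand-in for a SQLAlchemy session; the `accounts` table, the starting balances, and the `transfer` and `balances` helpers are assumptions for the demo. In `sqlite3`, `with conn:` commits when the block succeeds and rolls back when it raises, which mirrors the role of `session.begin()`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER)")
conn.execute("INSERT INTO accounts VALUES (1, 100), (2, 0)")
conn.commit()


def transfer(from_id: int, to_id: int, amount: int) -> None:
    # Commit if the block succeeds, roll back everything if it raises
    with conn:
        conn.execute(
            "UPDATE accounts SET balance = balance - ? WHERE id = ?",
            (amount, from_id),
        )
        (balance,) = conn.execute(
            "SELECT balance FROM accounts WHERE id = ?", (from_id,)
        ).fetchone()
        if balance < 0:
            raise ValueError("Insufficient funds")
        conn.execute(
            "UPDATE accounts SET balance = balance + ? WHERE id = ?",
            (amount, to_id),
        )


def balances() -> list:
    return conn.execute("SELECT id, balance FROM accounts ORDER BY id").fetchall()


transfer(1, 2, 30)
after_success = balances()

try:
    transfer(1, 2, 500)  # the debit runs, the check fails, the block rolls back
except ValueError:
    pass
after_failure = balances()
```

The failed transfer leaves both balances exactly as they were after the successful one, because the debit and the credit share one transaction boundary.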
### 7. Session Factory Not Configured Properly
**Problem**: Inconsistent session behavior, connection issues.
```python
# BAD - new engine every time
def get_session():
    engine = create_engine("postgresql://...")  # New engine each call!
    Session = sessionmaker(bind=engine)
    return Session()

# GOOD - reuse engine and session factory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Module level - create once
engine = create_engine(
    "postgresql://...",
    pool_pre_ping=True,  # Verify connections
    pool_size=10,
    max_overflow=20
)
SessionLocal = sessionmaker(
    bind=engine,
    expire_on_commit=False,  # Don't expire objects on commit
    autocommit=False,
    autoflush=False
)

def get_session():
    return SessionLocal()

# ASYNC version
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

async_engine = create_async_engine(
    "postgresql+asyncpg://...",
    pool_pre_ping=True,
    pool_size=10
)
# async_sessionmaker (SQLAlchemy 2.0+) builds AsyncSession objects by default
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    expire_on_commit=False,
    autoflush=False
)

async def get_async_session():
    async with AsyncSessionLocal() as session:
        yield session
```
### 8. Missing Session Refresh After Background Operations
**Problem**: Stale data when session persists across long operations.
```python
# BAD - using stale session data
async def process_order(order_id: int):
async with AsyncSession() as session:
order = await session.get(Order, order_id)
# Long running background task
await process_payment(order.id) # Another process might update order
# order.status might be stale here!
if order.status == "pending":
order.status = "completed"
await session.commit()
# GOOD - refresh after external operations
async def process_order(order_id: int):
async with AsyncSession() as session:
order = await session.get(Order, order_id)
await process_payment(order.id)
# Refresh to get latest state
await session.refresh(order)
if order.status == "pending":
order.status = "completed"
await session.commit()
```
## Review Questions
1. Are all sessions using context managers (`with` or `async with`)?
2. Is each request/thread getting its own session instance?
3. Are transactions using explicit `session.begin()`?
4. Are async contexts using `AsyncSession` not sync `Session`?
5. Are objects accessed before the session closes?
6. Is the session factory configured once and reused?
---
name: python-code-review
description: Reviews Python code for type safety, async patterns, error handling, and common mistakes. Use when reviewing .py files, checking type hints, async/await usage, or exception handling.
---
# Python Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Indentation, line length, whitespace, naming | [references/pep8-style.md](references/pep8-style.md) |
| Missing/wrong type hints, Any usage | [references/type-safety.md](references/type-safety.md) |
| Blocking calls in async, missing await | [references/async-patterns.md](references/async-patterns.md) |
| Bare except, missing context, logging | [references/error-handling.md](references/error-handling.md) |
| Mutable defaults, print statements | [references/common-mistakes.md](references/common-mistakes.md) |
## Review Checklist
### PEP8 Style
- [ ] 4-space indentation (no tabs)
- [ ] Line length ≤79 characters (≤72 for docstrings/comments)
- [ ] Two blank lines around top-level definitions, one within classes
- [ ] Imports grouped: stdlib → third-party → local (blank line between groups)
- [ ] No whitespace inside brackets or before colons/commas
- [ ] Naming: `snake_case` for functions/variables, `CamelCase` for classes, `UPPER_CASE` for constants
- [ ] Inline comments separated by at least two spaces
### Type Safety
- [ ] Type hints on all function parameters and return types
- [ ] No `Any` unless necessary (with comment explaining why)
- [ ] Proper `T | None` syntax (Python 3.10+)
### Async Patterns
- [ ] No blocking calls (`time.sleep`, `requests`) in async functions
- [ ] Proper `await` on all coroutines
### Error Handling
- [ ] No bare `except:` clauses
- [ ] Specific exception types with context
- [ ] `raise ... from` to preserve stack traces
### Common Mistakes
- [ ] No mutable default arguments
- [ ] Using `logger` not `print()` for output
- [ ] f-strings preferred over `.format()` or `%`
## Valid Patterns (Do NOT Flag)
These patterns are intentional and correct - do not report as issues:
- **Type annotation vs type assertion** - Annotations declare types but are not runtime assertions; don't confuse with missing validation
- **Using `Any` when interacting with untyped libraries** - Required when external libraries lack type stubs
- **Empty `__init__.py` files** - Valid for package structure, no code required
- **`noqa` comments** - Valid when linter rule doesn't apply to specific case
- **Using `cast()` after runtime type check** - Correct pattern to inform type checker of narrowed type
## Context-Sensitive Rules
Only flag these issues when the specific conditions apply:
| Issue | Flag ONLY IF |
|-------|--------------|
| Generic exception handling | Specific exception types are available and meaningful |
| Unused variables | Variable lacks `_` prefix AND isn't used in f-strings, logging, or debugging |
## Gates (reporting workflow)
Complete **in order**. Do not advance until each **pass condition** is met.
1. **Scope** — **Pass:** You list every `.py` path (or explicit glob) you inspected this run.
2. **False-positive screen** — **Pass:** For each issue you plan to report, you checked **Valid Patterns** and **Context-Sensitive Rules** above; you drop or narrow the finding if those sections say not to flag it.
3. **Evidence** — **Pass:** Each remaining finding includes **`[FILE:LINE]`** (or a bounded line range). Symbols or short verbatim snippets may supplement the location anchor but do not replace it.
4. **Verification protocol** — **Pass:** You load [review-verification-protocol](../review-verification-protocol/SKILL.md) and complete its mandatory steps **for each reported issue** before the user-facing write-up.
5. **Ship** — **Pass:** The user-visible output matches whatever structure that protocol requires (no issues-only dump that skips its checks).
## When to Load References
- Reviewing code formatting/style → pep8-style.md
- Reviewing function signatures → type-safety.md
- Reviewing `async def` functions → async-patterns.md
- Reviewing try/except blocks → error-handling.md
- General Python review → common-mistakes.md
## Review Questions
1. Does the code follow PEP8 formatting (indentation, line length, whitespace)?
2. Are imports properly grouped (stdlib → third-party → local)?
3. Do names follow conventions (snake_case, CamelCase, UPPER_CASE)?
4. Are all function signatures fully typed?
5. Are async functions truly non-blocking?
6. Do exceptions include meaningful context?
7. Are there any mutable default arguments?
Before reporting: complete **Gates (reporting workflow)** above (especially gate 4).
FILE:references/async-patterns.md
# Async Patterns
## Critical Anti-Patterns
### 1. Blocking Calls in Async Functions
**Problem**: Blocks the event loop, defeats async benefits.
```python
# BAD - blocks event loop
async def fetch_data():
    response = requests.get(url)  # BLOCKING!
    time.sleep(1)  # BLOCKING!
    return response.json()

# GOOD - non-blocking
async def fetch_data():
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        await asyncio.sleep(1)
        return response.json()
```
### 2. Missing await on Coroutines
**Problem**: Coroutine never executes.
```python
# BAD - coroutine created but never awaited
async def process():
    fetch_data()  # Returns coroutine, doesn't execute!

# GOOD
async def process():
    await fetch_data()
```
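A small experiment shows why the missing `await` matters. In this sketch the `calls` list and the `fetch_data` body are stand-ins invented for the demo; it records whether the coroutine body has run before and after it is awaited.

```python
import asyncio

calls = []


async def fetch_data() -> str:
    calls.append("ran")
    return "payload"


async def main() -> tuple:
    coro = fetch_data()   # creates a coroutine object; the body has not run
    before = len(calls)
    await coro            # only now does the body execute
    after = len(calls)
    return before, after


before, after = asyncio.run(main())
```

Calling an `async def` function only builds the coroutine object; nothing inside it executes until it is awaited or scheduled as a task.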
### 3. Sequential Instead of Concurrent
**Problem**: Misses parallelization opportunity.
```python
# BAD - sequential (slow)
async def get_all():
    user = await get_user()
    posts = await get_posts()
    comments = await get_comments()
    return user, posts, comments

# GOOD - concurrent (fast)
async def get_all():
    user, posts, comments = await asyncio.gather(
        get_user(),
        get_posts(),
        get_comments()
    )
    return user, posts, comments
```
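The speed-up can be measured with a stand-in for real I/O. In the sketch below, `fake_io`, its 0.1 second delay, and the `timed` helper are assumptions made for the demo; each call simulates one independent network or database request.

```python
import asyncio
import time


async def fake_io(label: str, delay: float = 0.1) -> str:
    await asyncio.sleep(delay)  # stands in for a network or database call
    return label


async def sequential() -> list:
    return [
        await fake_io("user"),
        await fake_io("posts"),
        await fake_io("comments"),
    ]


async def concurrent() -> list:
    # gather() runs the awaitables concurrently and keeps argument order
    results = await asyncio.gather(
        fake_io("user"), fake_io("posts"), fake_io("comments")
    )
    return list(results)


def timed(coro_fn):
    start = time.perf_counter()
    result = asyncio.run(coro_fn())
    return result, time.perf_counter() - start


seq_result, seq_time = timed(sequential)
con_result, con_time = timed(concurrent)
```

Both versions return the same values in the same order; the sequential one waits for the sum of the delays, while the concurrent one waits roughly for the longest single delay. This only helps when the calls are independent of each other.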
### 4. Missing async with for Async Context Managers
**Problem**: Resource not properly managed.
```python
# BAD
async def query():
    session = aiosqlite.connect(db)  # Not entered!
    return await session.execute(sql)

# GOOD
async def query():
    async with aiosqlite.connect(db) as session:
        return await session.execute(sql)
```
### 5. Sync File I/O in Async Context
**Problem**: File operations block event loop.
```python
# BAD - blocks event loop
async def read_config():
    with open("config.json") as f:
        return json.load(f)

# GOOD - use aiofiles
import aiofiles

async def read_config():
    async with aiofiles.open("config.json") as f:
        content = await f.read()
        return json.loads(content)

# ACCEPTABLE - for small files, run in executor
async def read_config():
    # get_running_loop() is preferred over get_event_loop() inside coroutines;
    # load_config_sync is an ordinary blocking function defined elsewhere
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, load_config_sync)
```
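When adding a dependency such as `aiofiles` is not worth it, the executor approach can be written with `asyncio.to_thread` (Python 3.9+). This is a minimal sketch under that assumption; the `load_config_sync` signature, the temporary file, and its contents are invented for the demo.

```python
import asyncio
import json
import tempfile
from pathlib import Path


def load_config_sync(path: Path) -> dict:
    # Ordinary blocking read; acceptable because it runs in a worker thread
    with path.open() as f:
        return json.load(f)


async def read_config(path: Path) -> dict:
    # to_thread() runs the blocking call in the default thread pool
    return await asyncio.to_thread(load_config_sync, path)


with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "config.json"
    config_path.write_text(json.dumps({"debug": True}))
    config = asyncio.run(read_config(config_path))
```

The event loop stays free to run other tasks while the worker thread performs the read.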
## Review Questions
1. Are there any `requests`, `time.sleep`, or `open()` calls in async functions?
2. Is every coroutine call awaited?
3. Are independent async calls parallelized with `gather()`?
4. Are async context managers used with `async with`?
FILE:references/common-mistakes.md
# Common Mistakes
## Critical Anti-Patterns
### 1. Mutable Default Arguments
**Problem**: Default value is shared across all calls.
```python
# BAD - same list reused!
def add_item(item, items=[]):
    items.append(item)
    return items

add_item("a")  # ["a"]
add_item("b")  # ["a", "b"] - unexpected!

# GOOD
def add_item(item, items=None):
    if items is None:
        items = []
    items.append(item)
    return items

# BETTER - using dataclass
from dataclasses import dataclass, field

@dataclass
class Container:
    items: list = field(default_factory=list)
```
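The shared default can be inspected directly, because Python stores default values on the function object. The two function names below are chosen for this demo only.

```python
def add_item_shared(item, items=[]):  # the default list is created once, at def time
    items.append(item)
    return items


def add_item_fresh(item, items=None):
    if items is None:
        items = []  # a new list on every call
    items.append(item)
    return items


first = add_item_shared("a")
second = add_item_shared("b")
same_object = first is second                      # both names point at the default
stored_default = add_item_shared.__defaults__[0]   # the list kept on the function

fresh_first = add_item_fresh("a")
fresh_second = add_item_fresh("b")
```

`__defaults__` holds the very list that every call mutates, which is why state leaks between calls in the first version and not in the second.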
### 2. Using print() for Logging
**Problem**: No log levels, no timestamps, hard to filter.
```python
# BAD
print(f"Processing {item}")
print(f"Error: {e}")
# GOOD
from loguru import logger
logger.info(f"Processing {item}")
logger.error(f"Error: {e}")
```
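If a third-party logger such as loguru is not available, the standard-library `logging` module gives the same benefits of levels and filtering. This is a minimal sketch; the logger name `example.worker`, the format string, and the in-memory stream are assumptions chosen so the output can be inspected.

```python
import io
import logging

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))

logger = logging.getLogger("example.worker")  # hypothetical logger name
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False  # keep the demo output out of the root logger

logger.debug("Skipped: below the configured level")
logger.info("Processing %s", "item-42")
logger.error("Error: %s", ValueError("bad input"))

output = stream.getvalue().splitlines()
```

With the standard library, passing arguments separately (`"Processing %s", item`) defers string formatting until a handler actually emits the record.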
### 3. String Formatting Inconsistency
**Problem**: Mixing formats reduces readability.
```python
# BAD - mixed formats
msg = "Hello %s" % name
msg = "Hello {}".format(name)
msg = f"Hello {name}"
# GOOD - f-strings consistently
msg = f"Hello {name}"
total = f"Count: {count:,}" # with formatting
path = f"{base}/{sub}/{file}"
```
### 4. Unused Variables
**Problem**: Dead code, confusing to readers.
```python
# BAD
result = process() # never used
# GOOD - use underscore for intentionally ignored
_, second, _ = get_triple()
# Or just don't assign
process() # if result not needed
```
### 5. Import Order
**Problem**: Hard to scan, may cause issues.
```python
# BAD - random order
from myapp.utils import helper
import os
from typing import Optional
import sys
from myapp.models import User

# GOOD - standard order
import os
import sys
from typing import Optional

from myapp.models import User
from myapp.utils import helper
```
### 6. Magic Numbers
**Problem**: Unclear intent, hard to maintain.
```python
# BAD
if len(items) > 100:
    paginate()
time.sleep(3600)

# GOOD
MAX_PAGE_SIZE = 100
CACHE_TTL_SECONDS = 3600

if len(items) > MAX_PAGE_SIZE:
    paginate()
time.sleep(CACHE_TTL_SECONDS)
```
### 7. Nested Conditionals
**Problem**: Hard to read and maintain.
```python
# BAD
def process(user):
    if user:
        if user.active:
            if user.verified:
                return do_work(user)
    return None

# GOOD - early returns
def process(user):
    if not user:
        return None
    if not user.active:
        return None
    if not user.verified:
        return None
    return do_work(user)
```
## Review Questions
1. Are there any mutable default arguments (list, dict, set)?
2. Is `print()` used instead of `logger`?
3. Are f-strings used consistently?
4. Are there magic numbers that should be constants?
5. Are deeply nested conditionals flattened with early returns?
FILE:references/error-handling.md
# Error Handling
## Critical Anti-Patterns
### 1. Bare Except Clause
**Problem**: Catches everything including KeyboardInterrupt, SystemExit.
```python
# BAD
try:
    process()
except:
    pass

# GOOD - specific exception
try:
    process()
except ValueError as e:
    logger.error(f"Invalid value: {e}")
    raise

# ACCEPTABLE - if you must catch all
try:
    process()
except Exception as e:  # Still allows KeyboardInterrupt
    logger.error(f"Unexpected error: {e}")
    raise
```
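The difference between a bare `except:` and `except Exception` follows from the built-in exception hierarchy, which the sketch below checks directly. The helper function names are invented for the demo.

```python
def classify(exc_type) -> str:
    # `except Exception` only matches subclasses of Exception
    if issubclass(exc_type, Exception):
        return "caught by except Exception"
    return "propagates"


results = {
    t.__name__: classify(t)
    for t in (ValueError, KeyboardInterrupt, SystemExit)
}


def swallow_with_bare_except() -> str:
    try:
        raise KeyboardInterrupt
    except:  # noqa: E722 - bare except, shown only to demonstrate the problem
        return "swallowed"


def let_interrupt_through() -> str:
    try:
        try:
            raise KeyboardInterrupt
        except Exception:
            return "swallowed"
    except KeyboardInterrupt:
        return "propagated"
```

`KeyboardInterrupt` and `SystemExit` derive from `BaseException`, not `Exception`, so only the bare clause intercepts them.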
### 2. Swallowing Exceptions
**Problem**: Hides errors, makes debugging impossible.
```python
# BAD
try:
    result = risky_operation()
except Exception:
    pass  # Error silently ignored!

# GOOD - log and handle
try:
    result = risky_operation()
except OperationError as e:
    logger.warning(f"Operation failed: {e}")
    result = default_value
```
### 3. Losing Exception Context
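**Problem**: Without `from`, the original error is only chained implicitly, so the traceback presents it as something that happened during handling rather than as the stated cause.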
```python
# BAD - original error is only implicit context, not the stated cause
try:
    parse_config()
except ValueError:
    raise ConfigError("Invalid config")

# GOOD - preserves chain
try:
    parse_config()
except ValueError as e:
    raise ConfigError("Invalid config") from e
```
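What `from e` changes can be seen on the exception object itself. In this sketch, `ConfigError` and the failing `int(...)` call are stand-ins for the demo; the two helpers return the raised exception so its attributes can be inspected.

```python
class ConfigError(Exception):
    pass


def implicit() -> ConfigError:
    try:
        try:
            int("not a number")
        except ValueError:
            raise ConfigError("Invalid config")
    except ConfigError as err:
        return err


def explicit() -> ConfigError:
    try:
        try:
            int("not a number")
        except ValueError as e:
            raise ConfigError("Invalid config") from e
    except ConfigError as err:
        return err


implicit_err = implicit()
explicit_err = explicit()
```

Both exceptions keep the original error in `__context__`, but only the explicit form sets `__cause__`, which is what tools and tracebacks report as the direct cause.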
### 4. Missing Context in Error Messages
**Problem**: Can't diagnose issue from logs.
```python
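# BAD
try:
    value = data[key]  # hypothetical lookup that may fail
except KeyError:
    raise ValueError("Missing key")

# GOOD - include context
try:
    value = data[key]  # hypothetical lookup that may fail
except KeyError as e:
    raise ValueError(f"Missing required key: {e.args[0]}") from e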
```
### 5. Not Logging Before Re-raising
**Problem**: Exception might be caught elsewhere without logging.
```python
# BAD - no record if caught upstream
try:
    process(item)
except ProcessError:
    raise

# GOOD - log before re-raising
try:
    process(item)
except ProcessError as e:
    logger.error(f"Failed to process item {item.id}: {e}")
    raise
```
## Logging Best Practices
```python
from loguru import logger
# BAD
print(f"Processing {item}")
print(f"Error: {e}")
# GOOD
logger.debug(f"Processing item {item.id}")
logger.info(f"Completed batch of {count} items")
logger.warning(f"Retry {attempt}/3 for {operation}")
logger.error(f"Failed to process {item.id}: {e}")
# With exception info
logger.exception(f"Unexpected error processing {item.id}")
```
## Review Questions
1. Are there any bare `except:` clauses?
2. Is exception context preserved with `raise ... from e`?
3. Do error messages include enough context to diagnose?
4. Is logging used instead of print statements?
FILE:references/pep8-style.md
# PEP8 Style Guide
## Indentation
**Rule**: Use 4 spaces per indentation level. Never use tabs.
```python
# BAD - 2 spaces
def foo():
  return bar

# BAD - tabs (shown with → for visibility)
def foo():
→   return bar  # ← actual code would have tab here

# GOOD - 4 spaces
def foo():
    return bar
```
### Continuation Lines
```python
# GOOD - aligned with opening delimiter
result = function_name(arg_one, arg_two,
                       arg_three, arg_four)

# GOOD - hanging indent
result = function_name(
    arg_one, arg_two,
    arg_three, arg_four,
)

# BAD - no alignment
result = function_name(arg_one, arg_two,
    arg_three, arg_four)
```
## Line Length
**Rule**: Maximum 79 characters for code, 72 for docstrings/comments.
```python
# BAD - too long
result = some_function(argument_one, argument_two, argument_three, argument_four, argument_five)
# GOOD - broken across lines
result = some_function(
argument_one,
argument_two,
argument_three,
argument_four,
argument_five,
)
```
## Blank Lines
**Rule**: Two blank lines around top-level definitions, one blank line around methods.
```python
# GOOD
import os


class MyClass:
    """Docstring."""

    def method_one(self):
        pass

    def method_two(self):
        pass


def top_level_function():
    pass


def another_function():
    pass
```
## Imports
**Rule**: Group imports in order: stdlib → third-party → local. One blank line between groups.
```python
# GOOD
import os
import sys
from pathlib import Path

import requests
from pydantic import BaseModel

from myapp.models import User
from myapp.utils import helper
```
**Rule**: Avoid wildcard imports.
```python
# BAD
from module import *
# GOOD
from module import specific_function, SpecificClass
```
## Whitespace
### Inside Brackets
```python
# BAD
spam( ham[ 1 ], { eggs: 2 } )
# GOOD
spam(ham[1], {eggs: 2})
```
### Before Colons and Commas
```python
# BAD
if x == 4 :
    print(x , y)

# GOOD
if x == 4:
    print(x, y)
```
### Around Operators
```python
# BAD
x=1
y = x+1
z = x +1
# GOOD
x = 1
y = x + 1
# Exception: indicate precedence
result = x*2 + y*3
```
### Function Arguments
```python
# BAD
def function(arg1 = None, arg2 = 0):
    pass

# GOOD
def function(arg1=None, arg2=0):
    pass
```
## Naming Conventions
| Type | Convention | Example |
|------|------------|---------|
| Functions/variables | `snake_case` | `my_function`, `user_count` |
| Classes | `CamelCase` | `MyClass`, `HttpClient` |
| Constants | `UPPER_CASE` | `MAX_SIZE`, `DEFAULT_TIMEOUT` |
| Non-public (internal) | Single leading underscore | `_internal_method` |
| Name-mangled ("private") | Double leading underscore | `__name_mangled` |
```python
# BAD
def MyFunction():  # should be snake_case
    pass

class my_class:  # should be CamelCase
    maxSize = 100  # should be MAX_SIZE if constant

# GOOD
def my_function():
    pass

class MyClass:
    MAX_SIZE = 100
```
## Comments
### Inline Comments
**Rule**: Separate by at least two spaces. Use sparingly.
```python
# BAD
x = x + 1# increment

# GOOD
x = x + 1  # compensate for border
```
### Block Comments
```python
# GOOD - aligned with code, complete sentences
# This is a block comment explaining the
# following code section. Each sentence
# ends with a period.
result = complex_operation()
```
### Docstrings
```python
# GOOD
def fetch_users(limit: int = 100) -> list[User]:
    """Fetch users from the database.

    Args:
        limit: Maximum number of users to return.

    Returns:
        List of User objects.

    Raises:
        DatabaseError: If connection fails.
    """
    pass
```
## Review Questions
1. Is indentation consistently 4 spaces (no tabs)?
2. Are lines ≤79 characters (≤72 for docstrings)?
3. Are there two blank lines around top-level definitions?
4. Are imports grouped correctly with blank lines between groups?
5. Is there extraneous whitespace inside brackets or around operators?
6. Do names follow conventions (snake_case, CamelCase, UPPER_CASE)?
7. Are inline comments separated by at least two spaces?
FILE:references/type-safety.md
# Type Safety
## Critical Anti-Patterns
### 1. Missing Return Type
**Problem**: Callers don't know what to expect.
```python
# BAD
def get_user(id: int):
return User.query.get(id)
# GOOD
def get_user(id: int) -> User | None:
return User.query.get(id)
```
### 2. Using Any Without Justification
**Problem**: Defeats the purpose of type checking.
```python
# BAD
def process(data: Any) -> Any:
return data
# GOOD - with justification
def process(data: Any) -> dict: # Any: accepts JSON from external API
return json.loads(data)
# BETTER - use proper types
def process(data: str | bytes) -> dict:
return json.loads(data)
```
### 3. Optional vs Union Syntax
**Problem**: Inconsistent syntax, less readable.
```python
# OLD (pre-3.10)
from typing import Optional, Union
def find(id: int) -> Optional[User]: ...
def parse(val: Union[str, int]) -> str: ...
# GOOD (3.10+)
def find(id: int) -> User | None: ...
def parse(val: str | int) -> str: ...
```
### 4. Missing Generic Types
**Problem**: Loses type information in collections.
```python
# BAD
def get_items() -> list:
return [Item(...)]
# GOOD
def get_items() -> list[Item]:
return [Item(...)]
# BAD
def get_config() -> dict:
return {"key": "value"}
# GOOD
def get_config() -> dict[str, str]:
return {"key": "value"}
```
### 5. TypedDict for Structured Dicts
**Problem**: Plain dict loses key/value type information.
```python
# BAD
def get_user_data() -> dict:
return {"name": "Alice", "age": 30}
# GOOD
from typing import TypedDict
class UserData(TypedDict):
name: str
age: int
def get_user_data() -> UserData:
return {"name": "Alice", "age": 30}
```
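One point worth keeping in mind during review: a `TypedDict` constrains what a static type checker accepts, but it adds no runtime validation. A minimal sketch, reusing the `UserData` shape from the example above (the `unchecked` value is a deliberately wrong, hypothetical input):

```python
from typing import TypedDict


class UserData(TypedDict):
    name: str
    age: int


def get_user_data() -> UserData:
    return {"name": "Alice", "age": 30}


data = get_user_data()
print(type(data) is dict)                   # a TypedDict value is a plain dict at runtime
print(sorted(UserData.__required_keys__))   # keys the type checker will require

# Nothing fails here at runtime; only a static type checker would flag the str age.
unchecked: UserData = {"name": "Bob", "age": "unknown"}  # type: ignore
print(unchecked["age"])
```

If the data comes from an untrusted source, the annotation alone is not enough; some form of runtime validation is still needed.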
## Review Questions
1. Are all function parameters typed?
2. Are all return types specified?
3. Is `Any` used only when necessary with a comment?
4. Are collection types generic (`list[T]`, `dict[K, V]`)?
5. Is `T | None` used instead of `Optional[T]`?
Reviews pytest test code for async patterns, fixtures, parametrize, and mocking. Use when reviewing test_*.py files, checking async test functions, fixture u...
---
name: pytest-code-review
description: Reviews pytest test code for async patterns, fixtures, parametrize, and mocking. Use when reviewing test_*.py files, checking async test functions, fixture usage, or mock patterns.
---
# Pytest Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| async def test_*, AsyncMock, await patterns | [references/async-testing.md](references/async-testing.md) |
| conftest.py, factory fixtures, scope, cleanup | [references/fixtures.md](references/fixtures.md) |
| @pytest.mark.parametrize, DRY patterns | [references/parametrize.md](references/parametrize.md) |
| AsyncMock tracking, patch patterns, when to mock | [references/mocking.md](references/mocking.md) |
## Review gates
Work in order. Do not assert pytest-specific problems until each applicable gate passes.
1. **Scoped files** — **Pass when:** You list every `test_*.py` and any `conftest.py` you will cite; no findings for files outside that list.
2. **Async vs sync** — **Pass when:** Per scoped file, you note whether it uses `async def test_*` / `await`; if yes, open [references/async-testing.md](references/async-testing.md) before criticizing async usage.
3. **Fixtures** — **Pass when:** If shared setup matters, you name the `conftest.py` path(s) or state none; for yield fixtures, confirm cleanup exists before claiming resource leaks.
4. **patch / mocks** — **Pass when:** For any `patch` or mock critique, you give the import path where the symbol is **used** (call site), or mark N/A; open [references/mocking.md](references/mocking.md) when mocking is central to the review.
5. **Findings** — **Pass when:** Each finding includes a file path and line(s) or test node id, not a generic rule restatement.
## Review Checklist
- [ ] Test functions are `async def test_*` for async code under test
- [ ] AsyncMock used for async dependencies, not Mock
- [ ] All async mocks and coroutines are awaited
- [ ] Fixtures in conftest.py for shared setup
- [ ] Fixture scope appropriate (function, class, module, session)
- [ ] Yield fixtures have proper cleanup in finally block
- [ ] @pytest.mark.parametrize for similar test cases
- [ ] No duplicated test logic across multiple test functions
- [ ] Mocks track calls properly (assert_called_once_with)
- [ ] patch() targets correct location (where used, not defined)
- [ ] No mocking of internals that should be tested
- [ ] Test isolation (no shared mutable state between tests)
## When to Load References
- Reviewing async test functions → async-testing.md
- Reviewing fixtures or conftest.py → fixtures.md
- Reviewing similar test cases → parametrize.md
- Reviewing mocks and patches → mocking.md
## Review Questions
1. Are all async functions tested with async def test_*?
2. Are fixtures properly scoped with appropriate cleanup?
3. Can similar test cases be parametrized to reduce duplication?
4. Are mocks tracking calls and used at the right locations?
FILE:references/async-testing.md
# Async Testing
## Critical Anti-Patterns
### 1. Using Mock Instead of AsyncMock
**Problem**: Mock returns a regular Mock object, not a coroutine. Tests pass but don't actually test async behavior.
```python
# BAD - Mock doesn't work with async
from unittest.mock import Mock
@pytest.mark.asyncio
async def test_fetch_data():
mock_client = Mock()
mock_client.get.return_value = {"data": "test"}
# This won't work! mock_client.get() is not awaitable
result = await fetch_data(mock_client) # TypeError!
# GOOD - AsyncMock for async functions
from unittest.mock import AsyncMock
@pytest.mark.asyncio
async def test_fetch_data():
mock_client = AsyncMock()
mock_client.get.return_value = {"data": "test"}
result = await fetch_data(mock_client)
assert result == {"data": "test"}
```
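The difference can be reproduced outside pytest with only the standard library. In this sketch `fetch_data` is a hypothetical stand-in for the code under test, assumed to `await client.get(...)`:

```python
import asyncio
from unittest.mock import AsyncMock, Mock


async def fetch_data(client):
    """Hypothetical code under test: awaits the client's async get()."""
    return await client.get("/items")


async def main():
    async_client = AsyncMock()
    async_client.get.return_value = {"data": "test"}
    result = await fetch_data(async_client)
    async_client.get.assert_awaited_once_with("/items")

    sync_client = Mock()
    sync_client.get.return_value = {"data": "test"}
    try:
        await fetch_data(sync_client)   # get() returns a dict, not an awaitable
    except TypeError:
        mock_failed = True
    else:
        mock_failed = False
    return result, mock_failed


result, mock_failed = asyncio.run(main())
print(result, mock_failed)
```

`assert_awaited_once_with` is specific to `AsyncMock` and checks that the call was actually awaited, not just made.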
### 2. Forgetting @pytest.mark.asyncio
**Problem**: Without the marker (and without `asyncio_mode = "auto"`), pytest does not run the test as a coroutine, so the async code never executes.
```python
# BAD - missing decorator
async def test_process_data():
result = await process_data() # Never actually awaited!
assert result == expected
# GOOD - proper async test
@pytest.mark.asyncio
async def test_process_data():
result = await process_data()
assert result == expected
```
### 3. Not Awaiting Async Mocks
**Problem**: Mock returns coroutine object instead of actual value.
```python
# BAD - not awaiting AsyncMock
@pytest.mark.asyncio
async def test_service():
mock_db = AsyncMock()
mock_db.query.return_value = [{"id": 1}]
service = UserService(mock_db)
result = service.get_users() # Returns coroutine, not list!
assert len(result) == 1 # TypeError!
# GOOD - await AsyncMock
@pytest.mark.asyncio
async def test_service():
mock_db = AsyncMock()
mock_db.query.return_value = [{"id": 1}]
service = UserService(mock_db)
result = await service.get_users()
assert len(result) == 1
```
### 4. Mixing Sync and Async in Tests
**Problem**: Calling sync blocking code in async test defeats purpose.
```python
# BAD - mixing sync and async
@pytest.mark.asyncio
async def test_user_flow():
user = create_user_sync() # Blocking call!
time.sleep(1) # Blocks event loop!
result = await process_user(user)
assert result.processed
# GOOD - fully async
@pytest.mark.asyncio
async def test_user_flow():
user = await create_user_async()
await asyncio.sleep(1)
result = await process_user(user)
assert result.processed
```
### 5. Not Cleaning Up Async Resources
**Problem**: Background tasks, connections, or coroutines left running after test.
```python
# BAD - no cleanup
@pytest.mark.asyncio
async def test_background_task():
task = asyncio.create_task(long_running_operation())
# Task still running after test ends!
result = await some_other_operation()
assert result
# GOOD - proper cleanup
@pytest.mark.asyncio
async def test_background_task():
task = asyncio.create_task(long_running_operation())
try:
result = await some_other_operation()
assert result
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
```
### 6. Not Testing Concurrent Behavior
**Problem**: Tests run sequentially, missing race conditions and timing issues.
```python
# BAD - sequential testing of concurrent code
@pytest.mark.asyncio
async def test_concurrent_updates():
await update_counter()
await update_counter()
# Doesn't test actual concurrent access!
assert get_counter() == 2
# GOOD - actually test concurrency
@pytest.mark.asyncio
async def test_concurrent_updates():
results = await asyncio.gather(
update_counter(),
update_counter(),
update_counter()
)
# Tests actual concurrent behavior
assert get_counter() == 3
```
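To see why the sequential version can hide a bug, consider a sketch with a hypothetical `Counter` whose update is not atomic: it reads the value, yields to the event loop, then writes. The class and function names are assumptions for illustration only.

```python
import asyncio


class Counter:
    """Hypothetical shared state with a non-atomic async update."""

    def __init__(self) -> None:
        self.value = 0

    async def update(self) -> None:
        current = self.value
        await asyncio.sleep(0)      # yields to the event loop mid-update
        self.value = current + 1


async def run_sequential() -> int:
    counter = Counter()
    await counter.update()
    await counter.update()
    await counter.update()
    return counter.value


async def run_concurrent() -> int:
    counter = Counter()
    await asyncio.gather(counter.update(), counter.update(), counter.update())
    return counter.value


sequential = asyncio.run(run_sequential())
concurrent = asyncio.run(run_concurrent())
print(sequential, concurrent)
```

The sequential run counts every update, while the concurrent run loses updates because each task reads the old value before any of them writes. Only a test built on `asyncio.gather` exposes that.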
### 7. Using pytest-asyncio Without Configuration
**Problem**: Tests may not run in correct mode or fail silently.
```python
# BAD - no configuration, ambiguous mode
# test_something.py
async def test_feature():  # Might not run as async!
    result = await process()
    assert result

# GOOD - explicit configuration
# pyproject.toml:
#   [tool.pytest.ini_options]
#   asyncio_mode = "auto"
# pytest.ini equivalent:
#   [pytest]
#   asyncio_mode = auto

# test_something.py
import pytest

@pytest.mark.asyncio  # Redundant in auto mode; required in strict mode
async def test_feature():
    result = await process()
    assert result
```
### 8. Not Testing Exception Paths in Async Code
**Problem**: Failures such as timeouts only surface when the coroutine is awaited, so error paths need explicit tests.
```python
# BAD - not testing async exceptions
@pytest.mark.asyncio
async def test_error_handling():
# Doesn't verify exception is properly raised
result = await fetch_data_with_retry()
assert result
# GOOD - test async exception handling
@pytest.mark.asyncio
async def test_error_handling():
mock_client = AsyncMock()
mock_client.get.side_effect = asyncio.TimeoutError()
with pytest.raises(asyncio.TimeoutError):
await fetch_data(mock_client)
```
## Review Questions
1. Are all async test functions marked with `@pytest.mark.asyncio`?
2. Are `AsyncMock` objects used instead of `Mock` for async dependencies?
3. Are all coroutines and async mocks properly awaited?
4. Are async resources (tasks, connections) cleaned up after tests?
5. Do concurrent code tests actually run operations concurrently?
6. Is pytest-asyncio configured correctly in pyproject.toml or pytest.ini?
FILE:references/fixtures.md
# Fixtures
## Critical Anti-Patterns
### 1. Duplicated Setup Across Tests
**Problem**: Same setup repeated in every test function instead of using fixtures.
```python
# BAD - duplicated setup
@pytest.mark.asyncio
async def test_create_user():
db = await create_db_connection()
await db.setup_schema()
result = await create_user(db, "Alice")
await db.close()
assert result.name == "Alice"
@pytest.mark.asyncio
async def test_delete_user():
db = await create_db_connection()
await db.setup_schema()
result = await delete_user(db, 1)
await db.close()
assert result.success
# GOOD - fixture handles setup/teardown
@pytest.fixture
async def db():
connection = await create_db_connection()
await connection.setup_schema()
yield connection
await connection.close()
@pytest.mark.asyncio
async def test_create_user(db):
result = await create_user(db, "Alice")
assert result.name == "Alice"
@pytest.mark.asyncio
async def test_delete_user(db):
result = await delete_user(db, 1)
assert result.success
```
### 2. Missing Cleanup in Fixtures
**Problem**: Resources leak when tests fail or fixtures don't clean up.
```python
# BAD - no cleanup
@pytest.fixture
async def temp_file():
file_path = "/tmp/test_file.txt"
async with aiofiles.open(file_path, "w") as f:
await f.write("test data")
return file_path
# File never deleted!
# GOOD - cleanup with yield
@pytest.fixture
async def temp_file():
file_path = "/tmp/test_file.txt"
async with aiofiles.open(file_path, "w") as f:
await f.write("test data")
yield file_path
# Cleanup always runs
if os.path.exists(file_path):
os.remove(file_path)
```
### 3. Wrong Fixture Scope
**Problem**: Expensive setup repeated unnecessarily or shared state causes test coupling.
```python
# BAD - function scope for expensive operation
@pytest.fixture(scope="function") # Runs for EVERY test!
def database_with_seed_data():
db = create_database()
seed_large_dataset(db) # Takes 10 seconds!
return db
# GOOD - module scope for expensive, read-only setup
@pytest.fixture(scope="module")
def database_with_seed_data():
db = create_database()
seed_large_dataset(db)
yield db
db.cleanup()
# BAD - session scope for mutable state
@pytest.fixture(scope="session")
def user_cache():
    return {}  # Shared across ALL tests - state leaks from one test to the next!

# GOOD - function scope for mutable state
@pytest.fixture(scope="function")
def user_cache():
    return {}  # Fresh cache per test
```
### 4. Not Using conftest.py
**Problem**: Fixtures duplicated across test files.
```python
# BAD - fixture in test_users.py
@pytest.fixture
def db_session():
session = create_session()
yield session
session.close()
# BAD - same fixture duplicated in test_posts.py
@pytest.fixture
def db_session():
session = create_session()
yield session
session.close()
# GOOD - shared fixture in conftest.py
# conftest.py
@pytest.fixture
def db_session():
session = create_session()
yield session
session.close()
# test_users.py and test_posts.py can both use db_session
```
### 5. Factory Fixtures Not Used for Variations
**Problem**: Creating multiple similar fixtures instead of one factory.
```python
# BAD - separate fixture for each variation
@pytest.fixture
def user_alice():
return User(name="Alice", role="admin")
@pytest.fixture
def user_bob():
return User(name="Bob", role="user")
@pytest.fixture
def user_charlie():
return User(name="Charlie", role="guest")
# GOOD - factory fixture
@pytest.fixture
def make_user():
def _make_user(name: str, role: str = "user"):
return User(name=name, role=role)
return _make_user
def test_admin_access(make_user):
admin = make_user("Alice", role="admin")
assert admin.can_delete()
def test_user_access(make_user):
user = make_user("Bob")
assert not user.can_delete()
```
### 6. Fixture Dependencies Not Leveraged
**Problem**: Manually composing dependencies instead of using fixture chaining.
```python
# BAD - manual composition
@pytest.fixture
def authenticated_client():
app = create_app()
client = TestClient(app)
user = create_user()
token = generate_token(user)
client.headers["Authorization"] = f"Bearer {token}"
return client
# GOOD - fixture chaining
@pytest.fixture
def app():
return create_app()
@pytest.fixture
def client(app):
return TestClient(app)
@pytest.fixture
def user():
return create_user()
@pytest.fixture
def auth_token(user):
return generate_token(user)
@pytest.fixture
def authenticated_client(client, auth_token):
client.headers["Authorization"] = f"Bearer {auth_token}"
return client
```
### 7. Autouse Fixtures Overused
**Problem**: Autouse fixtures run even when not needed, slowing tests.
```python
# BAD - autouse for specific setup
@pytest.fixture(autouse=True)
def setup_database():
# Runs for EVERY test, even ones that don't use database
db = setup_test_db()
yield
db.teardown()
# GOOD - explicit fixture dependency
@pytest.fixture
def database():
db = setup_test_db()
yield db
db.teardown()
def test_with_db(database):
# Only runs when explicitly requested
assert database.is_connected()
def test_without_db():
# Doesn't pay database setup cost
assert 1 + 1 == 2
```
### 8. Async Fixtures Without Proper Cleanup
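**Problem**: Teardown is not wrapped in try/finally. pytest runs the code after `yield` whether the test passes or fails, but that code is skipped if the fixture generator is closed early; `try/finally` covers both cases.
```python
# BAD - no try/finally in async fixture
@pytest.fixture
async def api_client():
    client = AsyncClient(base_url="http://test")
    yield client
    await client.close()  # Runs after a failed test, but not if the generator is closed early

# GOOD - try/finally ensures cleanup
@pytest.fixture
async def api_client():
    client = AsyncClient(base_url="http://test")
    try:
        yield client
    finally:
        await client.close()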
```
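As a rough model (an assumption for illustration, not pytest's actual implementation), a yield fixture is a generator that the runner resumes once for teardown. The sketch below uses plain synchronous generators and a hypothetical `drive` helper to show when code after `yield` runs with and without `try/finally`.

```python
log = []


def plain_fixture():
    log.append("plain setup")
    yield "resource"
    log.append("plain teardown")        # only runs if the generator is resumed


def guarded_fixture():
    log.append("guarded setup")
    try:
        yield "resource"
    finally:
        log.append("guarded teardown")  # runs on resume and on close()


def drive(fixture_func, close_early: bool) -> None:
    """Simplified stand-in for a test runner driving a yield fixture."""
    gen = fixture_func()
    next(gen)                 # setup, up to the yield
    if close_early:
        gen.close()           # raises GeneratorExit at the yield
    else:
        next(gen, None)       # normal teardown: resume after the yield


drive(plain_fixture, close_early=False)
drive(plain_fixture, close_early=True)
drive(guarded_fixture, close_early=False)
drive(guarded_fixture, close_early=True)
print(log)
```

In this model the unguarded teardown runs only on the normal resume path, while the guarded one also runs when the generator is closed early.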
### 9. Using Fixtures as Data Instead of Setup
**Problem**: Fixtures return data instead of managing resources.
```python
# BAD - fixture just returns data
@pytest.fixture
def sample_users():
return [
{"name": "Alice", "role": "admin"},
{"name": "Bob", "role": "user"}
]
# GOOD - use module-level constant
SAMPLE_USERS = [
{"name": "Alice", "role": "admin"},
{"name": "Bob", "role": "user"}
]
# ACCEPTABLE - fixture when setup/teardown needed
@pytest.fixture
async def sample_users(db):
users = await db.create_users([
{"name": "Alice", "role": "admin"},
{"name": "Bob", "role": "user"}
])
yield users
await db.delete_users([u.id for u in users])
```
## Review Questions
1. Are fixtures in conftest.py for cross-file reuse?
2. Do all fixtures with resources have proper yield + cleanup?
3. Is fixture scope appropriate (function/module/session)?
4. Are factory fixtures used for creating test data variations?
5. Are fixture dependencies chained instead of manually composed?
6. Are autouse fixtures limited to truly universal setup?
7. Do async fixtures wrap cleanup in try/finally blocks?
FILE:references/mocking.md
# Mocking
## Critical Anti-Patterns
### 1. Patching Where Defined Instead of Where Used
**Problem**: Mock doesn't affect the code under test because patch location is wrong.
```python
# module_a.py
def external_api_call():
return "real data"
# module_b.py
from module_a import external_api_call
def process_data():
return external_api_call()
# BAD - patching where defined
from unittest.mock import patch
@patch("module_a.external_api_call") # Wrong location!
def test_process_data(mock_api):
mock_api.return_value = "mocked data"
result = process_data()
assert result == "mocked data" # FAILS! Uses real function
# GOOD - patch where used
@patch("module_b.external_api_call") # Patch in module_b namespace
def test_process_data(mock_api):
mock_api.return_value = "mocked data"
result = process_data()
assert result == "mocked data" # Works!
```
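The effect of the patch target can be checked in a single runnable script. This sketch builds `module_a` and `module_b` in memory with `types.ModuleType` purely so it is self-contained; in a real project they would be ordinary files.

```python
import sys
import types
from unittest.mock import patch

# Build the two modules from the example in memory so the sketch is runnable.
module_a = types.ModuleType("module_a")
exec('def external_api_call():\n    return "real data"\n', module_a.__dict__)
sys.modules["module_a"] = module_a

module_b = types.ModuleType("module_b")
exec(
    "from module_a import external_api_call\n"
    "def process_data():\n"
    "    return external_api_call()\n",
    module_b.__dict__,
)
sys.modules["module_b"] = module_b

# Patching the defining module leaves module_b's own reference untouched.
with patch("module_a.external_api_call", return_value="mocked data"):
    patched_where_defined = module_b.process_data()

# Patching the name in the module that uses it takes effect.
with patch("module_b.external_api_call", return_value="mocked data"):
    patched_where_used = module_b.process_data()

print(patched_where_defined, patched_where_used)
```

`from module_a import external_api_call` copies the reference into `module_b`'s namespace at import time, which is why replacing the attribute on `module_a` afterwards has no effect on `process_data`.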
### 2. Not Verifying Mock Calls
**Problem**: Mock used but never verified, test doesn't validate behavior.
```python
# BAD - mock not verified
@pytest.mark.asyncio
async def test_user_creation(mocker):
mock_db = mocker.AsyncMock()
mock_db.insert.return_value = {"id": 1}
await create_user(mock_db, "Alice")
# No verification! Did it call insert? With what args?
# GOOD - verify mock calls
@pytest.mark.asyncio
async def test_user_creation(mocker):
mock_db = mocker.AsyncMock()
mock_db.insert.return_value = {"id": 1}
result = await create_user(mock_db, "Alice")
mock_db.insert.assert_called_once_with({"name": "Alice"})
assert result["id"] == 1
```
### 3. Over-Mocking Internal Implementation
**Problem**: Mocking internal details that should be tested, not mocked.
```python
# BAD - mocking internal helper that should be tested
class UserService:
def _validate_email(self, email):
return "@" in email
def create_user(self, email):
if self._validate_email(email):
return User(email=email)
raise ValueError("Invalid email")
@patch.object(UserService, "_validate_email")
def test_create_user(mock_validate):
mock_validate.return_value = True
service = UserService()
user = service.create_user("invalid") # Should fail but doesn't!
assert user.email == "invalid"
# GOOD - test the actual behavior
def test_create_user_valid():
    service = UserService()
    user = service.create_user("user@example.com")
    assert user.email == "user@example.com"

def test_create_user_invalid():
    service = UserService()
    with pytest.raises(ValueError, match="Invalid email"):
        service.create_user("invalid")
```
### 4. Using Mock Instead of AsyncMock for Async
**Problem**: Regular Mock doesn't work properly with async code.
```python
# BAD - Mock for async function
from unittest.mock import Mock

@pytest.mark.asyncio
async def test_fetch_data():
    mock_client = Mock()
    mock_client.get = Mock(return_value={"data": "test"})
    result = await fetch_data(mock_client)  # TypeError: get() returns a dict, which can't be awaited

# GOOD - AsyncMock for async functions
from unittest.mock import AsyncMock

@pytest.mark.asyncio
async def test_fetch_data():
    mock_client = AsyncMock()
    mock_client.get.return_value = {"data": "test"}
    result = await fetch_data(mock_client)
    assert result == {"data": "test"}
```
### 5. Not Resetting Mocks Between Tests
**Problem**: Mock state leaks between tests, causing flaky failures.
```python
# BAD - shared mock across tests
mock_api = Mock()
def test_first_call():
mock_api.fetch.return_value = "data1"
result = process(mock_api)
assert result == "data1"
def test_second_call():
# Mock still has state from test_first_call!
mock_api.fetch.return_value = "data2"
assert mock_api.fetch.call_count == 0 # FAILS! call_count is 1
# GOOD - fresh mock per test with fixture
@pytest.fixture
def mock_api():
return Mock()
def test_first_call(mock_api):
mock_api.fetch.return_value = "data1"
result = process(mock_api)
assert result == "data1"
def test_second_call(mock_api):
mock_api.fetch.return_value = "data2"
result = process(mock_api)
assert mock_api.fetch.call_count == 1 # Works!
```
### 6. Not Using side_effect for Complex Behavior
**Problem**: Using return_value when mock needs to raise exceptions or vary responses.
```python
# BAD - can't test retry logic with simple return_value
@pytest.mark.asyncio
async def test_retry_on_failure():
mock_client = AsyncMock()
mock_client.get.return_value = {"data": "test"} # Always succeeds!
result = await fetch_with_retry(mock_client)
# Can't test retry behavior!
# GOOD - use side_effect for sequence of responses
@pytest.mark.asyncio
async def test_retry_on_failure():
mock_client = AsyncMock()
mock_client.get.side_effect = [
asyncio.TimeoutError(), # First call fails
asyncio.TimeoutError(), # Second call fails
{"data": "test"} # Third call succeeds
]
result = await fetch_with_retry(mock_client, max_retries=3)
assert result == {"data": "test"}
assert mock_client.get.call_count == 3
# GOOD - use side_effect function for dynamic behavior
def test_dynamic_behavior():
def side_effect_fn(user_id):
if user_id == 1:
return {"name": "Alice"}
raise ValueError("User not found")
mock_db = Mock()
mock_db.get_user.side_effect = side_effect_fn
assert mock_db.get_user(1) == {"name": "Alice"}
with pytest.raises(ValueError):
mock_db.get_user(2)
```
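For reference, here is a self-contained version of the retry scenario that runs without pytest. The `fetch_with_retry` implementation is a hypothetical sketch (the real helper is not shown in this guide); it retries `client.get()` on `asyncio.TimeoutError` up to `max_retries` times.

```python
import asyncio
from unittest.mock import AsyncMock


async def fetch_with_retry(client, max_retries: int = 3):
    """Hypothetical retry helper: retries client.get() on timeout."""
    for attempt in range(1, max_retries + 1):
        try:
            return await client.get("/data")
        except asyncio.TimeoutError:
            if attempt == max_retries:
                raise   # out of attempts: surface the last timeout


mock_client = AsyncMock()
mock_client.get.side_effect = [
    asyncio.TimeoutError(),   # first await raises
    asyncio.TimeoutError(),   # second await raises
    {"data": "test"},         # third await returns this value
]

result = asyncio.run(fetch_with_retry(mock_client, max_retries=3))
print(result, mock_client.get.await_count)
```

When `side_effect` is a list, exception instances are raised and any other item is returned, one per call, in order.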
### 7. Not Using spec or spec_set
**Problem**: Mock accepts any attribute, allowing tests that pass but code that fails.
```python
# BAD - mock without spec
def test_user_service():
    mock_db = Mock()
    service = UserService(mock_db)
    service.process()  # Suppose process() mistakenly calls db.exectue()
    # A bare Mock accepts any attribute, so the typo goes unnoticed
    mock_db.exectue.assert_called_once()  # Test passes! But real code would fail

# GOOD - use spec to catch attribute errors
def test_user_service():
    mock_db = Mock(spec=Database)
    service = UserService(mock_db)
    # With spec, a call to db.exectue() inside process() raises
    # AttributeError: Mock object has no attribute 'exectue'
    service.process()
    mock_db.execute.assert_called_once()  # Forces correct spelling
```
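The behaviour of `spec` can be verified in isolation. In this sketch `Database` is a hypothetical one-method interface and `exectue` is the deliberate typo:

```python
from unittest.mock import Mock


class Database:
    """Hypothetical interface with a single execute() method."""

    def execute(self, sql: str) -> None: ...


loose = Mock()
loose.exectue("SELECT 1")            # typo is silently accepted
print(loose.exectue.called)

strict = Mock(spec=Database)
try:
    strict.exectue("SELECT 1")       # not on Database, so attribute access fails
except AttributeError:
    typo_caught = True
else:
    typo_caught = False
print(typo_caught)

strict.execute("SELECT 1")           # the real method name still works
strict.execute.assert_called_once_with("SELECT 1")
```

`spec_set` goes one step further and also rejects setting attributes that the spec class does not define.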
### 8. Patching with Context Manager but Not Using It
**Problem**: Using patch as decorator when context manager is clearer for partial mocking.
```python
# BAD - decorator for partial test mocking
@patch("module.external_call")
def test_process(mock_call):
mock_call.return_value = "mocked"
# First part of test uses mock
result1 = process_with_external()
assert result1 == "mocked"
# Second part wants real call, but can't!
result2 = process_with_external() # Still mocked!
# GOOD - context manager for scoped mocking
def test_process():
# First part uses mock
with patch("module.external_call") as mock_call:
mock_call.return_value = "mocked"
result1 = process_with_external()
assert result1 == "mocked"
# Second part uses real function
result2 = process_with_external()
assert result2 != "mocked"
```
### 9. Not Checking Call Arguments Precisely
**Problem**: Using assert_called() instead of assert_called_with().
```python
# BAD - only checks if called, not what arguments
@pytest.mark.asyncio
async def test_create_user():
    mock_db = AsyncMock()
    await create_user(mock_db, name="Alice", email="alice@example.com")
    mock_db.insert.assert_called()  # Called, but with what args?

# GOOD - verify exact arguments
@pytest.mark.asyncio
async def test_create_user():
    mock_db = AsyncMock()
    await create_user(mock_db, name="Alice", email="alice@example.com")
    mock_db.insert.assert_called_once_with(
        name="Alice",
        email="alice@example.com"
    )

# ALSO GOOD - use call_args for partial matching
@pytest.mark.asyncio
async def test_create_user():
    mock_db = AsyncMock()
    await create_user(mock_db, name="Alice", email="alice@example.com")
    call_args = mock_db.insert.call_args
    assert call_args.kwargs["name"] == "Alice"
    assert "email" in call_args.kwargs
```
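Another option for partial matching is `unittest.mock.ANY`, which compares equal to any value. A minimal sketch with a plain `Mock` (the `insert` call and its arguments are illustrative):

```python
from unittest.mock import ANY, Mock

mock_db = Mock()
mock_db.insert(name="Alice", email="alice@example.com")

# Exact match on name; ANY accepts whatever was passed for email.
mock_db.insert.assert_called_once_with(name="Alice", email=ANY)

# A wrong name still fails, so the check stays meaningful.
try:
    mock_db.insert.assert_called_once_with(name="Bob", email=ANY)
except AssertionError:
    mismatch_detected = True
else:
    mismatch_detected = False
print(mismatch_detected)
```

This keeps the assertion in a single call while still ignoring values that are generated or irrelevant to the test.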
### 10. Mocking Entire Objects Instead of Interfaces
**Problem**: Mocking concrete class when interface would be more accurate.
```python
# BAD - mocking concrete class
from unittest.mock import Mock
class PostgresDatabase:
def query(self, sql): ...
def execute(self, sql): ...
def internal_connection_pool(self): ...
def test_service():
mock_db = Mock(spec=PostgresDatabase)
# Test knows about PostgreSQL specifics!
service = UserService(mock_db)
# GOOD - mock interface/protocol
from typing import Protocol
class Database(Protocol):
async def query(self, sql: str) -> list: ...
async def execute(self, sql: str) -> None: ...
@pytest.fixture
def mock_db():
mock = AsyncMock(spec=Database)
return mock
def test_service(mock_db):
# Test only depends on interface
service = UserService(mock_db)
```
### 11. Not Using pytest-mock Plugin
**Problem**: Using unittest.mock directly when pytest-mock provides better integration.
```python
# BAD - manual patch cleanup
from unittest.mock import patch
def test_feature():
patcher = patch("module.function")
mock_fn = patcher.start()
mock_fn.return_value = "test"
result = use_function()
patcher.stop() # Easy to forget!
assert result == "test"
# GOOD - using pytest-mock mocker fixture
def test_feature(mocker):
mock_fn = mocker.patch("module.function")
mock_fn.return_value = "test"
result = use_function()
# Automatic cleanup!
assert result == "test"
```
## Review Questions
1. Are patches applied where the function is used, not where it's defined?
2. Are mock calls verified with assert_called_once_with() or similar?
3. Are internal implementation details tested rather than mocked?
4. Is AsyncMock used for all async functions and methods?
5. Are mocks fresh for each test (via fixtures or setUp)?
6. Is side_effect used for exceptions or varying responses?
7. Do mocks use spec or spec_set to catch attribute errors?
8. Are call arguments verified precisely, not just call presence?
9. Is pytest-mock used instead of raw unittest.mock?
FILE:references/parametrize.md
# Parametrize
## Critical Anti-Patterns
### 1. Duplicated Test Functions for Similar Cases
**Problem**: Copy-pasted test functions that differ only in input values.
```python
# BAD - duplicated test logic
@pytest.mark.asyncio
async def test_validate_email_valid():
    result = await validate_email("user@example.com")
    assert result.is_valid is True

@pytest.mark.asyncio
async def test_validate_email_valid_subdomain():
    result = await validate_email("user@mail.example.com")
    assert result.is_valid is True

@pytest.mark.asyncio
async def test_validate_email_invalid_no_at():
    result = await validate_email("userexample.com")
    assert result.is_valid is False

@pytest.mark.asyncio
async def test_validate_email_invalid_no_domain():
    result = await validate_email("user@")
    assert result.is_valid is False

# GOOD - parametrized test
@pytest.mark.asyncio
@pytest.mark.parametrize("email,expected", [
    ("user@example.com", True),
    ("user@mail.example.com", True),
    ("userexample.com", False),
    ("user@", False),
])
async def test_validate_email(email, expected):
    result = await validate_email(email)
    assert result.is_valid is expected
```
### 2. Unclear Parametrize Names
**Problem**: Using generic names like "input" and "output" instead of descriptive names.
```python
# BAD - unclear parameter names
@pytest.mark.parametrize("input,output", [
(10, 100),
(5, 25),
(0, 0),
])
def test_calculation(input, output):
assert calculate(input) == output
# GOOD - descriptive parameter names
@pytest.mark.parametrize("radius,expected_area", [
(10, 314.159),
(5, 78.539),
(0, 0),
])
def test_circle_area(radius, expected_area):
assert calculate_area(radius) == pytest.approx(expected_area, rel=1e-3)
```
### 3. Not Using pytest.param for IDs
**Problem**: Test output shows cryptic parameter values instead of meaningful descriptions.
```python
# BAD - unclear test IDs in output
@pytest.mark.parametrize("user_role,can_access", [
("admin", True),
("user", False),
("guest", False),
])
def test_access_control(user_role, can_access):
assert check_access(user_role) == can_access
# Output: test_access_control[admin-True], test_access_control[user-False]
# GOOD - descriptive test IDs
@pytest.mark.parametrize("user_role,can_access", [
pytest.param("admin", True, id="admin_has_access"),
pytest.param("user", False, id="user_denied"),
pytest.param("guest", False, id="guest_denied"),
])
def test_access_control(user_role, can_access):
assert check_access(user_role) == can_access
# Output: test_access_control[admin_has_access], test_access_control[user_denied]
```
### 4. Not Combining Multiple Parametrize Decorators
**Problem**: Creating cartesian product manually instead of stacking decorators.
```python
# BAD - manual combinations
@pytest.mark.parametrize("method,status,role", [
("GET", 200, "admin"),
("GET", 200, "user"),
("POST", 200, "admin"),
("POST", 403, "user"),
("DELETE", 200, "admin"),
("DELETE", 403, "user"),
])
def test_api_access(method, status, role):
assert api_call(method, role).status_code == status
# GOOD - stacked parametrize for cartesian product
@pytest.mark.parametrize("method", ["GET", "POST", "DELETE"])
@pytest.mark.parametrize("role,expected_statuses", [
("admin", {"GET": 200, "POST": 200, "DELETE": 200}),
("user", {"GET": 200, "POST": 403, "DELETE": 403}),
])
def test_api_access(method, role, expected_statuses):
assert api_call(method, role).status_code == expected_statuses[method]
# ALTERNATIVE - if all admins succeed and users fail writes
@pytest.mark.parametrize("method", ["GET", "POST", "DELETE"])
@pytest.mark.parametrize("role,can_write", [
("admin", True),
("user", False),
])
def test_api_access(method, role, can_write):
response = api_call(method, role)
if method in ["POST", "DELETE"] and not can_write:
assert response.status_code == 403
else:
assert response.status_code == 200
```
### 5. Parametrizing Fixtures Instead of Tests
**Problem**: Complex parametrized fixtures when parametrized tests would be clearer.
```python
# BAD - parametrized fixture is hard to read
@pytest.fixture(params=[
{"name": "Alice", "role": "admin", "can_delete": True},
{"name": "Bob", "role": "user", "can_delete": False},
])
def user(request):
return User(**request.param)
def test_user_permissions(user):
assert user.can_delete() == user.expected_can_delete
# GOOD - parametrize the test
@pytest.fixture
def make_user():
def _make(name: str, role: str):
return User(name=name, role=role)
return _make
@pytest.mark.parametrize("name,role,can_delete", [
("Alice", "admin", True),
("Bob", "user", False),
])
def test_user_permissions(make_user, name, role, can_delete):
user = make_user(name, role)
assert user.can_delete() == can_delete
```
### 6. Not Marking Expected Failures
**Problem**: Including known failing cases without marking them.
```python
# BAD - test fails on known edge case
@pytest.mark.parametrize("input,expected", [
("valid", True),
("also_valid", True),
("edge_case", True), # This actually fails but is being worked on
])
def test_validator(input, expected):
assert validate(input) == expected
# GOOD - mark expected failures
@pytest.mark.parametrize("input,expected", [
("valid", True),
("also_valid", True),
pytest.param("edge_case", True, marks=pytest.mark.xfail(reason="Issue #123")),
])
def test_validator(input, expected):
assert validate(input) == expected
```
### 7. Large Parametrize Tables in Test File
**Problem**: Test file cluttered with large data tables.
```python
# BAD - 100 lines of test data inline
@pytest.mark.parametrize("input,expected", [
("case1", "result1"),
("case2", "result2"),
# ... 100 more lines ...
])
def test_parser(input, expected):
assert parse(input) == expected
# GOOD - externalize large datasets
# test_data/parser_cases.json
[
{"input": "case1", "expected": "result1"},
{"input": "case2", "expected": "result2"}
]
# test_parser.py
import json
from pathlib import Path
def load_test_cases():
path = Path(__file__).parent / "test_data" / "parser_cases.json"
with open(path) as f:
cases = json.load(f)
return [(c["input"], c["expected"]) for c in cases]
@pytest.mark.parametrize("input,expected", load_test_cases())
def test_parser(input, expected):
assert parse(input) == expected
```
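Once cases live in an external file, malformed records and unreadable test IDs become easier to miss. The sketch below validates each record on load and derives an ID list; the `load_cases` name and the optional `"id"` field are assumptions for illustration, and the JSON is parsed from a string so the example stays self-contained.

```python
import json


def load_cases(raw_json: str) -> tuple[list[tuple[str, str]], list[str]]:
    """Parse case records and return (argvalues, ids) for parametrize."""
    records = json.loads(raw_json)
    argvalues, ids = [], []
    for index, record in enumerate(records):
        missing = {"input", "expected"} - record.keys()
        if missing:
            # Fail at collection time with a message that names the bad record.
            raise ValueError(f"case {index} is missing keys: {sorted(missing)}")
        argvalues.append((record["input"], record["expected"]))
        # Fall back to a positional ID when the record has no "id" field.
        ids.append(record.get("id", f"case-{index}"))
    return argvalues, ids


RAW = (
    '[{"input": "case1", "expected": "result1", "id": "simple"},'
    ' {"input": "case2", "expected": "result2"}]'
)
argvalues, ids = load_cases(RAW)
print(argvalues, ids)
```

The two lists can then be passed to `@pytest.mark.parametrize("input,expected", argvalues, ids=ids)`.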
### 8. Not Using Indirect Parametrization
**Problem**: Creating expensive test data for every parameter combination.
```python
# BAD - creating full database for each test
@pytest.mark.parametrize("user_id,expected_name", [
(1, "Alice"),
(2, "Bob"),
(3, "Charlie"),
])
async def test_get_user(user_id, expected_name):
db = await create_full_database() # Expensive! Runs 3 times!
user = await db.get_user(user_id)
assert user.name == expected_name
# GOOD - move the expensive setup into a fixture (widen its scope, e.g. scope="module", to build the database only once)
@pytest.fixture
async def db_with_users():
db = await create_full_database()
yield db
await db.cleanup()
@pytest.mark.parametrize("user_id,expected_name", [
(1, "Alice"),
(2, "Bob"),
(3, "Charlie"),
])
async def test_get_user(db_with_users, user_id, expected_name):
user = await db_with_users.get_user(user_id)
assert user.name == expected_name
# EVEN BETTER - indirect parametrization
@pytest.fixture
async def user(request, db):
user_id = request.param
return await db.get_user(user_id)
@pytest.mark.parametrize("user,expected_name", [
(1, "Alice"),
(2, "Bob"),
(3, "Charlie"),
], indirect=["user"])
async def test_user_name(user, expected_name):
assert user.name == expected_name
```
### 9. Testing Multiple Assertions Instead of Separating
**Problem**: Multiple unrelated assertions in one parametrized test.
```python
# BAD - multiple unrelated assertions
@pytest.mark.parametrize("user_data", [
{"name": "Alice", "age": 30, "role": "admin"},
{"name": "Bob", "age": 25, "role": "user"},
])
def test_user(user_data):
user = User(**user_data)
assert user.name == user_data["name"]
assert user.age == user_data["age"]
assert user.role == user_data["role"]
assert user.is_valid() # Unrelated to data validation
assert len(user.name) > 0 # Different concern
# GOOD - separate tests for different concerns
@pytest.mark.parametrize("name,age,role", [
("Alice", 30, "admin"),
("Bob", 25, "user"),
])
def test_user_creation(name, age, role):
user = User(name=name, age=age, role=role)
assert user.name == name
assert user.age == age
assert user.role == role
@pytest.mark.parametrize("name", ["Alice", "Bob", ""])
def test_user_name_validation(name):
if name:
user = User(name=name, age=30, role="user")
assert user.is_valid()
else:
with pytest.raises(ValueError):
User(name=name, age=30, role="user")
```
## Review Questions
1. Can duplicated test functions be combined with parametrize?
2. Do parametrized tests use descriptive parameter names?
3. Are test IDs meaningful using pytest.param(id="...")?
4. Should multiple parametrize decorators be stacked for combinations?
5. Are large test datasets externalized to separate files?
6. Is indirect parametrization used for expensive fixture setup?
---
name: postgres-code-review
description: Reviews PostgreSQL code for indexing strategies, JSONB operations, connection pooling, and transaction safety. Use when reviewing SQL queries, database schemas, JSONB usage, or connection management.
---
# PostgreSQL Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Missing indexes, wrong index type, query performance | [references/indexes.md](references/indexes.md) |
| JSONB queries, operators, GIN indexes | [references/jsonb.md](references/jsonb.md) |
| Connection leaks, pool configuration, timeouts | [references/connections.md](references/connections.md) |
| Isolation levels, deadlocks, advisory locks | [references/transactions.md](references/transactions.md) |
## Review Checklist
- [ ] WHERE/JOIN columns have appropriate indexes
- [ ] Composite indexes match query patterns (column order matters)
- [ ] JSONB columns use GIN indexes when queried
- [ ] Using proper JSONB operators (`->`, `->>`, `@>`, `?`)
- [ ] Connection pool configured with appropriate limits
- [ ] Connections properly released (context managers, try/finally)
- [ ] Appropriate transaction isolation level for use case
- [ ] No long-running transactions holding locks
- [ ] Advisory locks used for application-level coordination
- [ ] Queries use parameterized statements (no SQL injection)
## Gates (before reporting findings)
Use this sequence so conclusions stay evidence-bound (not “I checked mentally”):
1. **Scope** — Record the concrete paths (and line ranges or symbols if helpful) for the SQL, DDL/migrations, and connection code under review. **Pass:** every subsystem you critique (queries, JSONB, pool, transactions) has at least one cited path.
2. **SQL/DDL citation for performance claims** — Index, sequential-scan, JSONB-operator, and plan-related findings must point to the exact statement or schema (quoted excerpt or `file:line`). **Pass:** each such finding includes that citation.
3. **Binding check before injection flags** — Only assert SQL-injection risk after locating how SQL and values are combined (bound parameters vs string concat/format/f-strings). **Pass:** you name the mechanism you saw in code for each flagged callsite.
Then load the relevant reference doc from [Quick Reference](#quick-reference) and walk the [Review Checklist](#review-checklist).
## When to Load References
- Reviewing SELECT queries with WHERE/JOIN → indexes.md
- Reviewing JSONB columns or JSON operations → jsonb.md
- Reviewing database connection code → connections.md
- Reviewing BEGIN/COMMIT or concurrent updates → transactions.md
## Review Questions
1. Will this query use an index or perform a sequential scan?
2. Are JSONB operations using appropriate operators and indexes?
3. Are database connections properly managed and released?
4. Is the transaction isolation level appropriate for this operation?
5. Could this cause deadlocks or long-running locks?
FILE:references/connections.md
# Connections
## Critical Anti-Patterns
### 1. Not Using Connection Pooling
**Problem**: Creating a new connection for every request is slow and can exhaust the database's connection limit.
```python
# BAD: New connection every time
def get_user(user_id: int):
conn = psycopg2.connect(
host='localhost',
database='mydb',
user='user',
password='password'
)
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
result = cursor.fetchone()
conn.close()
return result
# GOOD: Use connection pool
from psycopg2.pool import ThreadedConnectionPool
pool = ThreadedConnectionPool(
minconn=5,
maxconn=20,
host='localhost',
database='mydb',
user='user',
password='password'
)
def get_user(user_id: int):
conn = pool.getconn()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
return cursor.fetchone()
finally:
pool.putconn(conn)
```
### 2. Connection Leaks
**Problem**: Not releasing connections back to pool causes starvation.
```python
# BAD: Connection leaked on error
def get_user(user_id: int):
conn = pool.getconn()
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
result = cursor.fetchone()
pool.putconn(conn) # Not called if error occurs!
return result
# GOOD: Always release in finally block
def get_user(user_id: int):
conn = pool.getconn()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
return cursor.fetchone()
finally:
pool.putconn(conn)
# BETTER: Use context manager
from contextlib import contextmanager
@contextmanager
def get_db_connection():
conn = pool.getconn()
try:
yield conn
finally:
pool.putconn(conn)
def get_user(user_id: int):
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
return cursor.fetchone()
```
### 3. Wrong Pool Size
**Problem**: Pool too small (connection starvation) or too large (resource waste).
```python
# BAD: Pool size not based on workload
pool = ThreadedConnectionPool(minconn=1, maxconn=100)
# GOOD: Size based on concurrent requests and database limits
# Rule of thumb: (num_cores * 2) + effective_spindle_count
# For web server: match number of worker threads/processes
pool = ThreadedConnectionPool(
minconn=5, # Keep warm connections
maxconn=20, # Max concurrent requests
host='localhost',
database='mydb'
)
# Check PostgreSQL connection limit
# SHOW max_connections; -- Default 100
# Ensure the sum of pool maxconn across all app instances stays below max_connections
```
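The connection-limit check above is simple arithmetic, sketched here under stated assumptions. The function name and the `reserved` headroom (connections kept free for migrations, monitoring, and superuser access) are illustrative, not part of any library.

```python
def max_pool_size_per_instance(
    max_connections: int, app_instances: int, reserved: int = 10
) -> int:
    """Largest per-instance pool maximum that keeps the total under the server limit."""
    if app_instances < 1:
        raise ValueError("app_instances must be at least 1")
    available = max_connections - reserved
    if available < app_instances:
        raise ValueError("not enough connections for one per instance")
    # Integer division so that instances * result never exceeds the budget.
    return available // app_instances


# With the default max_connections of 100 and three app servers:
print(max_pool_size_per_instance(max_connections=100, app_instances=3))
```

If the result is smaller than the concurrency each instance needs, that is a signal to consider a server-side pooler rather than raising `maxconn`.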
### 4. No Connection Timeout
**Problem**: Application hangs waiting for connections.
```python
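# BAD: No connect timeout, pool exhaustion not handled
pool = ThreadedConnectionPool(minconn=5, maxconn=20)
conn = pool.getconn()  # Opening a connection can hang; psycopg2 raises PoolError when the pool is exhausted
# GOOD: Set a connect timeout and handle exhaustion
# (psycopg2's getconn() has no timeout argument; connect_timeout is passed through to libpq)
from psycopg2.pool import PoolError
pool = ThreadedConnectionPool(minconn=5, maxconn=20, connect_timeout=5)  # seconds
try:
    conn = pool.getconn()
except PoolError:
    raise RuntimeError("Could not get database connection")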
# BETTER: Use asyncpg with async/await
import asyncpg
pool = await asyncpg.create_pool(
host='localhost',
database='mydb',
min_size=5,
max_size=20,
timeout=5,
command_timeout=30 # Query timeout
)
async def get_user(user_id: int):
async with pool.acquire() as conn:
return await conn.fetchrow(
"SELECT * FROM users WHERE id = $1",
user_id
)
```
### 5. Not Setting Statement Timeout
**Problem**: Long-running queries hold connections and locks.
```python
# BAD: No query timeout
async def expensive_query():
async with pool.acquire() as conn:
# Could run for hours, holding connection
return await conn.fetch("SELECT * FROM huge_table")
# GOOD: Set statement timeout
async def expensive_query():
async with pool.acquire() as conn:
await conn.execute("SET statement_timeout = '30s'")
try:
return await conn.fetch("SELECT * FROM huge_table")
except asyncpg.QueryCanceledError:
raise TimeoutError("Query took too long")
# BETTER: Set at connection level
pool = await asyncpg.create_pool(
host='localhost',
database='mydb',
command_timeout=30, # 30 second timeout for all queries
server_settings={'statement_timeout': '30000'} # milliseconds
)
```
### 6. Not Using PgBouncer
**Problem**: Per-instance application pools do not cap the total; database connections multiply as app instances are added.
```yaml
# BAD: Each app instance has its own pool
# 3 app servers * 20 connections = 60 database connections
# GOOD: Use PgBouncer for connection pooling
# pgbouncer.ini
[databases]
mydb = host=localhost port=5432 dbname=mydb
[pgbouncer]
listen_addr = *
listen_port = 6432
auth_type = md5
auth_file = /etc/pgbouncer/userlist.txt
# pool_mode: transaction or session
pool_mode = transaction
# max_client_conn: application connections
max_client_conn = 1000
# default_pool_size: database connections
default_pool_size = 20
reserve_pool_size = 5
```
```python
# Application connects to PgBouncer instead of PostgreSQL
pool = await asyncpg.create_pool(
host='localhost',
port=6432, # PgBouncer port
database='mydb'
)
# Now 3 app servers * 20 connections = 60 app connections
# But only 20 database connections via PgBouncer
```
### 7. Holding Connections During I/O
**Problem**: Holding database connection while doing network/file I/O.
```python
# BAD: Holding connection during API call
async def process_user(user_id: int):
async with pool.acquire() as conn:
user = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
# Holding connection during external API call!
async with httpx.AsyncClient() as client:
response = await client.get(f"https://api.example.com/users/{user_id}")
await conn.execute(
"UPDATE users SET api_data = $1 WHERE id = $2",
response.json(), user_id
)
# GOOD: Release connection during I/O
async def process_user(user_id: int):
async with pool.acquire() as conn:
user = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
# Connection released during API call
async with httpx.AsyncClient() as client:
response = await client.get(f"https://api.example.com/users/{user_id}")
async with pool.acquire() as conn:
await conn.execute(
"UPDATE users SET api_data = $1 WHERE id = $2",
response.json(), user_id
)
```
### 8. Not Monitoring Connection Pool
**Problem**: Can't diagnose connection starvation or leaks.
```python
# BAD: No visibility into pool state
pool = ThreadedConnectionPool(minconn=5, maxconn=20)
# GOOD: Monitor pool metrics
import logging
logger = logging.getLogger(__name__)
@contextmanager
def get_db_connection():
logger.info(f"Pool: {len(pool._used)}/{pool.maxconn} connections used")  # _used is a private dict of checked-out connections
conn = pool.getconn()
try:
yield conn
finally:
pool.putconn(conn)
# BETTER: Use metrics library
from prometheus_client import Gauge
db_connections_used = Gauge('db_connections_used', 'Database connections in use')
db_connections_max = Gauge('db_connections_max', 'Max database connections')
@contextmanager
def get_db_connection():
conn = pool.getconn()
db_connections_used.inc()
try:
yield conn
finally:
pool.putconn(conn)
db_connections_used.dec()
db_connections_max.set(pool.maxconn)
```
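Reading private attributes such as `_used` ties monitoring to a pool's internals. One alternative is to count checkouts in the wrapper itself. The sketch below does that with the standard library only; `FakePool` is a stand-in for a real pool and `PoolUsage` is a hypothetical helper, not part of psycopg2.

```python
import threading
from contextlib import contextmanager


class FakePool:
    """Stand-in for a real pool; only getconn/putconn are needed here."""

    def getconn(self):
        return object()

    def putconn(self, conn):
        pass


class PoolUsage:
    """Counts connections currently checked out and the peak seen so far."""

    def __init__(self, pool):
        self._pool = pool
        self._lock = threading.Lock()
        self.in_use = 0
        self.peak = 0

    @contextmanager
    def connection(self):
        conn = self._pool.getconn()
        with self._lock:
            self.in_use += 1
            self.peak = max(self.peak, self.in_use)
        try:
            yield conn
        finally:
            # Return the connection first, then update the counter.
            self._pool.putconn(conn)
            with self._lock:
                self.in_use -= 1


usage = PoolUsage(FakePool())
with usage.connection():
    with usage.connection():
        pass
print(usage.in_use, usage.peak)
```

A peak that keeps touching the pool maximum suggests starvation; an `in_use` value that never returns to zero when idle suggests a leak.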
## Review Questions
1. Is connection pooling used?
2. Are connections always released (try/finally or context manager)?
3. Is pool size appropriate for workload?
4. Are connection and statement timeouts configured?
5. Would PgBouncer help reduce database connections?
6. Are connections released during I/O operations?
7. Is connection pool health monitored?
8. Are connection errors handled and logged?
FILE:references/indexes.md
# Indexes
## Critical Anti-Patterns
### 1. Missing Index on WHERE Clause
**Problem**: Sequential scan on large tables causes slow queries.
```sql
-- BAD: No index on email
SELECT * FROM users WHERE email = 'user@example.com';
-- GOOD: Create index
CREATE INDEX idx_users_email ON users(email);
SELECT * FROM users WHERE email = 'user@example.com';
```
```sql
-- Check query plan
EXPLAIN ANALYZE SELECT * FROM users WHERE email = 'user@example.com';
-- Look for "Seq Scan" (bad on large tables) vs "Index Scan" (good)
```
### 2. Wrong Column Order in Composite Index
**Problem**: Index not used if query doesn't match leftmost columns.
```sql
-- BAD: Index doesn't match query pattern
CREATE INDEX idx_orders_wrong ON orders(status, user_id);
SELECT * FROM orders WHERE user_id = 123; -- Can't use the index efficiently: user_id is not the leading column
-- GOOD: Match query pattern
CREATE INDEX idx_orders_user_status ON orders(user_id, status);
SELECT * FROM orders WHERE user_id = 123; -- Uses index
SELECT * FROM orders WHERE user_id = 123 AND status = 'pending'; -- Uses index
```
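**Rule**: Lead with the columns your queries filter on by equality (most selective first), then range and sort columns. The order of conditions inside the WHERE clause itself does not matter.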
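The leftmost-prefix behaviour behind this rule can be sketched in a few lines of Python. This is a deliberate simplification (it ignores range conditions and the planner's ability to scan a whole index), and `usable_prefix` is a hypothetical helper for reasoning about index definitions, not a PostgreSQL API.

```python
def usable_prefix(index_columns: list[str], equality_columns: set[str]) -> list[str]:
    """Return the leading index columns that equality filters can use."""
    prefix = []
    for column in index_columns:
        # Stop at the first index column the query does not filter on.
        if column not in equality_columns:
            break
        prefix.append(column)
    return prefix


# Index (status, user_id) with a query filtering only on user_id: nothing usable.
print(usable_prefix(["status", "user_id"], {"user_id"}))
# Index (user_id, status) with the same query: the leading column is usable.
print(usable_prefix(["user_id", "status"], {"user_id"}))
```

An empty prefix is a hint to reorder the index columns or add a separate index; always confirm with `EXPLAIN ANALYZE`.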
### 3. Not Using Partial Indexes
**Problem**: Indexing entire table when only subset is queried.
```sql
-- BAD: Index includes all rows
CREATE INDEX idx_orders_status ON orders(status);
-- GOOD: Only index active orders
CREATE INDEX idx_orders_active ON orders(user_id, created_at)
WHERE status = 'active';
SELECT * FROM orders
WHERE status = 'active' AND user_id = 123
ORDER BY created_at DESC; -- Uses partial index
```
### 4. Missing Index on Foreign Keys
**Problem**: Slow JOINs and cascading deletes.
```sql
-- BAD: No index on foreign key
CREATE TABLE order_items (
id SERIAL PRIMARY KEY,
order_id INTEGER REFERENCES orders(id),
product_id INTEGER REFERENCES products(id)
);
-- GOOD: Index foreign keys
CREATE TABLE order_items (
id SERIAL PRIMARY KEY,
order_id INTEGER REFERENCES orders(id),
product_id INTEGER REFERENCES products(id)
);
CREATE INDEX idx_order_items_order_id ON order_items(order_id);
CREATE INDEX idx_order_items_product_id ON order_items(product_id);
```
### 5. Not Using EXPLAIN ANALYZE
**Problem**: Guessing instead of measuring query performance.
```python
# BAD: Assuming query is fast
cursor.execute("SELECT * FROM orders WHERE user_id = %s", (user_id,))
# GOOD: Verify with EXPLAIN
cursor.execute("""
EXPLAIN ANALYZE
SELECT * FROM orders WHERE user_id = %s
""", (user_id,))
print(cursor.fetchall())
# Check: Index Scan vs Seq Scan, actual time, rows
# Then run actual query
cursor.execute("SELECT * FROM orders WHERE user_id = %s", (user_id,))
```
### 6. Over-Indexing
**Problem**: Slows down writes, wastes space.
```sql
-- BAD: Too many indexes on rarely-queried columns
CREATE INDEX idx_users_created_at ON users(created_at);
CREATE INDEX idx_users_updated_at ON users(updated_at);
CREATE INDEX idx_users_last_login ON users(last_login);
CREATE INDEX idx_users_email ON users(email);
CREATE INDEX idx_users_username ON users(username);
-- GOOD: Only index frequently-queried columns
CREATE INDEX idx_users_email ON users(email); -- Used for login
CREATE INDEX idx_users_username ON users(username); -- Used for lookup
-- Skip indexes on created_at, updated_at, last_login if rarely queried
```
### 7. Not Using Covering Indexes
**Problem**: Index scan followed by table lookup (heap fetch).
```sql
-- BAD: Index on email only, must fetch id and name from table
CREATE INDEX idx_users_email ON users(email);
SELECT id, name FROM users WHERE email = 'user@example.com';
-- GOOD: Include id and name in index (covering index)
CREATE INDEX idx_users_email_covering ON users(email) INCLUDE (id, name);
SELECT id, name FROM users WHERE email = 'user@example.com';
-- "Index Only Scan" - no heap fetch needed once the visibility map is up to date
```
### 8. String Pattern Matching Without Index
**Problem**: LIKE with leading wildcard can't use B-tree index.
```sql
-- BAD: Can't use standard index
CREATE INDEX idx_users_email ON users(email);
SELECT * FROM users WHERE email LIKE '%@example.com'; -- Seq Scan
-- GOOD: Use trigram index for pattern matching
CREATE EXTENSION IF NOT EXISTS pg_trgm;
CREATE INDEX idx_users_email_trgm ON users USING gin(email gin_trgm_ops);
SELECT * FROM users WHERE email LIKE '%@example.com'; -- Uses GIN index
```
## Review Questions
1. Do all WHERE and JOIN columns have indexes?
2. Are composite index column orders optimized for queries?
3. Would partial indexes reduce index size and improve performance?
4. Are foreign keys indexed?
5. Has EXPLAIN ANALYZE been used to verify query plans?
6. Are there redundant or unused indexes?
7. Would covering indexes eliminate heap fetches?
8. Are pattern matching queries using appropriate index types (GIN, trigram)?
FILE:references/jsonb.md
# JSONB
## Critical Anti-Patterns
### 1. Using JSON Instead of JSONB
**Problem**: JSON is stored as text, slower to query, no indexing.
```sql
-- BAD: Using JSON type
CREATE TABLE events (
id SERIAL PRIMARY KEY,
metadata JSON
);
-- GOOD: Use JSONB for querying and indexing
CREATE TABLE events (
id SERIAL PRIMARY KEY,
metadata JSONB
);
```
**Rule**: Always use JSONB unless you need to preserve exact formatting/whitespace.
### 2. Wrong JSONB Operator
**Problem**: `->` returns JSONB, `->>` returns text. Using wrong one breaks queries.
```sql
-- BAD: Relying on implicit coercion of the literal to JSONB
SELECT * FROM users WHERE metadata->'age' = '25'; -- Matches the JSON number 25, but not the string "25"
-- GOOD: Use ->> for text comparison
SELECT * FROM users WHERE metadata->>'age' = '25';
-- GOOD: Use -> for JSONB comparison
SELECT * FROM users WHERE metadata->'age' = '25'::jsonb;
-- GOOD: Cast to integer for numeric comparison
SELECT * FROM users WHERE (metadata->>'age')::int = 25;
```
**Operators**:
- `->` extracts as JSONB: `metadata->'address'` → `{"city": "NYC"}`
- `->>` extracts as text: `metadata->>'name'` → `Alice` (no JSON quotes)
- `@>` contains: `metadata @> '{"role": "admin"}'`
- `?` key exists: `metadata ? 'email'`
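The difference between `->` and `->>` is easier to see side by side. The following is a rough Python analogy, not PostgreSQL's implementation: `arrow` and `arrow_text` are hypothetical names, `None` stands in for SQL NULL, and only top-level object keys are handled.

```python
import json
from typing import Optional


def arrow(doc: dict, key: str) -> Optional[str]:
    """Analogy for `->`: the value stays JSON, so strings keep their quotes."""
    if key not in doc:
        return None
    return json.dumps(doc[key])


def arrow_text(doc: dict, key: str) -> Optional[str]:
    """Analogy for `->>`: the value becomes plain text, so strings lose their quotes."""
    if key not in doc or doc[key] is None:
        return None
    value = doc[key]
    return value if isinstance(value, str) else json.dumps(value)


metadata = {"name": "Alice", "age": 25, "address": {"city": "NYC"}}
print(arrow(metadata, "name"), arrow_text(metadata, "name"))
print(arrow(metadata, "age"), arrow_text(metadata, "age"))
```

This is why a comparison against a text literal wants `->>`, while a comparison against a JSONB value wants `->`.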
### 3. Missing GIN Index on JSONB
**Problem**: JSONB queries without indexes perform sequential scans.
```sql
-- BAD: Querying JSONB without index
SELECT * FROM users WHERE metadata @> '{"role": "admin"}'; -- Seq Scan
-- GOOD: Create GIN index
CREATE INDEX idx_users_metadata ON users USING gin(metadata);
SELECT * FROM users WHERE metadata @> '{"role": "admin"}'; -- Uses index
-- GOOD: GIN index on specific path
CREATE INDEX idx_users_metadata_role ON users USING gin((metadata->'role'));
```
### 4. Not Using Containment Operator
**Problem**: Extracting and comparing is slower than using `@>`.
```sql
-- BAD: Extracting then comparing
SELECT * FROM events
WHERE metadata->>'type' = 'click' AND metadata->>'source' = 'mobile';
-- GOOD: Use containment operator
SELECT * FROM events
WHERE metadata @> '{"type": "click", "source": "mobile"}';
-- Much faster with GIN index
```
### 5. Storing Arrays as JSON Strings
**Problem**: Can't use array operators, must parse JSON every time.
```python
# BAD: Storing array as JSON string
cursor.execute("""
INSERT INTO users (tags) VALUES (%s)
""", (json.dumps(['python', 'postgres']),))
cursor.execute("""
SELECT * FROM users WHERE tags::jsonb @> '"python"'
""")
# GOOD: Use PostgreSQL array type for simple arrays
cursor.execute("""
INSERT INTO users (tags) VALUES (%s)
""", (['python', 'postgres'],))
cursor.execute("""
SELECT * FROM users WHERE 'python' = ANY(tags)
""")
# Use JSONB only for complex nested structures
```
### 6. Deep Nesting Without Indexes
**Problem**: Querying deep paths is slow without expression indexes.
```sql
-- BAD: Querying deep path without index
SELECT * FROM events
WHERE metadata->'user'->'profile'->>'country' = 'US';
-- GOOD: Create expression index
CREATE INDEX idx_events_country ON events(
(metadata->'user'->'profile'->>'country')
);
SELECT * FROM events
WHERE metadata->'user'->'profile'->>'country' = 'US';
```
### 7. Not Validating JSONB Structure
**Problem**: No schema validation leads to inconsistent data.
```python
# BAD: No validation
cursor.execute("""
INSERT INTO users (metadata) VALUES (%s)
""", (json.dumps({'age': 'twenty-five'}),)) # Should be integer!
# GOOD: Validate before insert
def validate_user_metadata(metadata: dict) -> dict:
    # Raise explicitly: assert statements are stripped when Python runs with -O
    if not isinstance(metadata.get('age'), int):
        raise ValueError("age must be integer")
    if not isinstance(metadata.get('email'), str):
        raise ValueError("email must be string")
    return metadata
metadata = validate_user_metadata({'age': 25, 'email': 'user@example.com'})
cursor.execute("""
    INSERT INTO users (metadata) VALUES (%s)
""", (json.dumps(metadata),))
# BETTER: Enforce the structure in the database with CHECK constraints
# Note: a missing key yields NULL, which passes a CHECK constraint
cursor.execute("""
    CREATE TABLE users (
        id SERIAL PRIMARY KEY,
        metadata JSONB,
        CHECK (jsonb_typeof(metadata->'age') = 'number'),
        CHECK (jsonb_typeof(metadata->'email') = 'string')
    )
""")
```
### 8. JSONB for Relational Data
**Problem**: Using JSONB when proper columns/foreign keys are better.
```sql
-- BAD: Storing relational data in JSONB
CREATE TABLE orders (
id SERIAL PRIMARY KEY,
data JSONB -- Contains user_id, product_id, quantity
);
-- GOOD: Use proper columns and foreign keys
CREATE TABLE orders (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id),
product_id INTEGER REFERENCES products(id),
quantity INTEGER,
metadata JSONB -- Only for truly unstructured data
);
```
**Rule**: Use JSONB for truly dynamic/unstructured data, not for avoiding schema design.
### 9. Inefficient JSONB Aggregation
**Problem**: Not using jsonb_agg or jsonb_object_agg.
```python
# BAD: Fetching and building JSON in application code
cursor.execute("SELECT id, name FROM products WHERE category_id = %s", (cat_id,))
products = [{'id': row[0], 'name': row[1]} for row in cursor.fetchall()]
result = {'products': products}
# GOOD: Build JSON in database
cursor.execute("""
SELECT jsonb_build_object(
'products', jsonb_agg(jsonb_build_object('id', id, 'name', name))
)
FROM products
WHERE category_id = %s
""", (cat_id,))
result = cursor.fetchone()[0]
```
## Review Questions
1. Is JSONB used instead of JSON?
2. Are the correct operators used (`->` vs `->>`, `@>` for containment)?
3. Do JSONB columns have GIN indexes?
4. Are containment operators (`@>`) used instead of extracting and comparing?
5. Is JSONB used appropriately (not for relational data)?
6. Are deep paths indexed with expression indexes?
7. Is JSONB structure validated?
8. Are JSONB aggregation functions used instead of application-side building?
FILE:references/transactions.md
# Transactions
## Critical Anti-Patterns
### 1. Wrong Isolation Level
**Problem**: Using default isolation level when stronger guarantees needed.
```python
# BAD: Default READ COMMITTED allows non-repeatable reads
async def transfer_money(from_id: int, to_id: int, amount: int):
async with pool.acquire() as conn:
async with conn.transaction():
balance = await conn.fetchval(
"SELECT balance FROM accounts WHERE id = $1", from_id
)
if balance < amount:
raise ValueError("Insufficient funds")
# Another transaction could modify balance here!
await conn.execute(
"UPDATE accounts SET balance = balance - $1 WHERE id = $2",
amount, from_id
)
await conn.execute(
"UPDATE accounts SET balance = balance + $1 WHERE id = $2",
amount, to_id
)
# GOOD: Use SERIALIZABLE for critical operations
async def transfer_money(from_id: int, to_id: int, amount: int):
async with pool.acquire() as conn:
async with conn.transaction(isolation='serializable'):
balance = await conn.fetchval(
"SELECT balance FROM accounts WHERE id = $1", from_id
)
if balance < amount:
raise ValueError("Insufficient funds")
await conn.execute(
"UPDATE accounts SET balance = balance - $1 WHERE id = $2",
amount, from_id
)
await conn.execute(
"UPDATE accounts SET balance = balance + $1 WHERE id = $2",
amount, to_id
)
# BETTER: Use SELECT FOR UPDATE to lock row
async def transfer_money(from_id: int, to_id: int, amount: int):
async with pool.acquire() as conn:
async with conn.transaction():
balance = await conn.fetchval(
"SELECT balance FROM accounts WHERE id = $1 FOR UPDATE",
from_id
)
if balance < amount:
raise ValueError("Insufficient funds")
await conn.execute(
"UPDATE accounts SET balance = balance - $1 WHERE id = $2",
amount, from_id
)
await conn.execute(
"UPDATE accounts SET balance = balance + $1 WHERE id = $2",
amount, to_id
)
```
**Isolation Levels**:
- `READ COMMITTED` (default): Prevents dirty reads, allows non-repeatable reads
- `REPEATABLE READ`: Prevents dirty and non-repeatable reads
- `SERIALIZABLE`: Full isolation, prevents all anomalies
### 2. Long-Running Transactions
**Problem**: Holds locks, blocks other queries, and keeps VACUUM from removing dead rows (table bloat).
```python
# BAD: Long transaction holding locks
async def process_orders():
async with pool.acquire() as conn:
async with conn.transaction():
orders = await conn.fetch("SELECT * FROM orders WHERE status = 'pending'")
for order in orders:
# External API call inside transaction!
result = await external_api.process(order)
await conn.execute(
"UPDATE orders SET status = $1, result = $2 WHERE id = $3",
'processed', result, order['id']
)
# GOOD: Keep transactions short
async def process_orders():
async with pool.acquire() as conn:
orders = await conn.fetch("SELECT * FROM orders WHERE status = 'pending'")
# Process outside transaction
for order in orders:
result = await external_api.process(order)
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
"UPDATE orders SET status = $1, result = $2 WHERE id = $3",
'processed', result, order['id']
)
```
### 3. Deadlocks from Lock Order
**Problem**: Different transactions acquire locks in different orders.
```python
# BAD: Different lock order causes deadlocks
# Transaction 1
async with conn.transaction():
await conn.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1")
await conn.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2")
# Transaction 2 (at same time)
async with conn.transaction():
await conn.execute("UPDATE accounts SET balance = balance - 50 WHERE id = 2")
await conn.execute("UPDATE accounts SET balance = balance + 50 WHERE id = 1")
# DEADLOCK: T1 locks account 1, T2 locks account 2, both wait for each other
# GOOD: Always acquire locks in same order
async def transfer(from_id: int, to_id: int, amount: int):
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Lock both rows in ascending id order, regardless of transfer direction
            await conn.execute(
                "SELECT id FROM accounts WHERE id IN ($1, $2) ORDER BY id FOR UPDATE",
                from_id, to_id
            )
            # Both rows are now locked, so the order of the updates no longer matters
            await conn.execute(
                "UPDATE accounts SET balance = balance - $1 WHERE id = $2",
                amount, from_id
            )
            await conn.execute(
                "UPDATE accounts SET balance = balance + $1 WHERE id = $2",
                amount, to_id
            )
```
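The same lock-ordering idea can be tried without a database. The sketch below uses in-memory balances and `threading.Lock` objects as stand-ins for row locks (all names are illustrative); two workers transfer in opposite directions, which is the pattern that deadlocks when each side locks its own source account first.

```python
import threading

balances = {1: 1000, 2: 1000}
locks = {account_id: threading.Lock() for account_id in balances}


def transfer(from_id: int, to_id: int, amount: int) -> None:
    if from_id == to_id:
        return  # Nothing to do, and taking the same lock twice would block
    # Always take the lock for the lower id first, whatever the direction.
    first, second = sorted((from_id, to_id))
    with locks[first], locks[second]:
        balances[from_id] -= amount
        balances[to_id] += amount


def worker(from_id: int, to_id: int, count: int) -> None:
    for _ in range(count):
        transfer(from_id, to_id, 1)


threads = [
    threading.Thread(target=worker, args=(1, 2, 500)),
    threading.Thread(target=worker, args=(2, 1, 500)),
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
print(balances)
```

Because both workers agree on the acquisition order, neither can hold one lock while waiting for the other.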
### 4. Not Using Advisory Locks
**Problem**: Application-level coordination requires database support.
```python
# BAD: Race condition on external resource
async def process_unique_job(job_id: int):
async with pool.acquire() as conn:
job = await conn.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
if job['status'] == 'pending':
# Another process could process same job!
result = await expensive_operation(job)
await conn.execute(
"UPDATE jobs SET status = 'complete', result = $1 WHERE id = $2",
result, job_id
)
# GOOD: Use advisory lock
async def process_unique_job(job_id: int):
async with pool.acquire() as conn:
# Try to acquire advisory lock (non-blocking)
locked = await conn.fetchval(
"SELECT pg_try_advisory_lock($1)",
job_id
)
if not locked:
return # Another process is handling this job
try:
job = await conn.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
if job['status'] == 'pending':
result = await expensive_operation(job)
await conn.execute(
"UPDATE jobs SET status = 'complete', result = $1 WHERE id = $2",
result, job_id
)
finally:
# Release advisory lock
await conn.execute("SELECT pg_advisory_unlock($1)", job_id)
```
**Advisory Lock Functions**:
- `pg_advisory_lock(key)`: Blocking lock
- `pg_try_advisory_lock(key)`: Non-blocking, returns true/false
- `pg_advisory_unlock(key)`: Release lock
- `pg_advisory_xact_lock(key)`: Auto-released at transaction end
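The single-argument advisory lock functions take a 64-bit integer key, so jobs identified by a string need a stable mapping to an integer. One possible sketch, assuming a truncated SHA-256 digest is acceptable for your collision risk (`advisory_key` is a hypothetical helper):

```python
import hashlib


def advisory_key(name: str) -> int:
    """Map a string to a signed 64-bit integer usable as an advisory lock key."""
    digest = hashlib.sha256(name.encode("utf-8")).digest()
    # Take the first 8 bytes and interpret them as a signed big-endian integer.
    return int.from_bytes(digest[:8], byteorder="big", signed=True)


key = advisory_key("nightly-report")
print(key)
```

Python's built-in `hash()` is unsuitable here because string hashes are randomized per process, so two workers would compute different keys for the same name.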
### 5. Not Handling Serialization Failures
**Problem**: SERIALIZABLE transactions can fail and need retry.
```python
# BAD: No retry on serialization failure
async def increment_counter(counter_id: int):
async with pool.acquire() as conn:
async with conn.transaction(isolation='serializable'):
count = await conn.fetchval(
"SELECT count FROM counters WHERE id = $1", counter_id
)
await conn.execute(
"UPDATE counters SET count = $1 WHERE id = $2",
count + 1, counter_id
)
# Raises SerializationError under contention
# GOOD: Retry on serialization failure
import asyncio
import asyncpg
async def increment_counter(counter_id: int, max_retries: int = 3):
for attempt in range(max_retries):
try:
async with pool.acquire() as conn:
async with conn.transaction(isolation='serializable'):
count = await conn.fetchval(
"SELECT count FROM counters WHERE id = $1", counter_id
)
await conn.execute(
"UPDATE counters SET count = $1 WHERE id = $2",
count + 1, counter_id
)
return # Success
except asyncpg.SerializationError:
if attempt == max_retries - 1:
raise
# Retry with exponential backoff
await asyncio.sleep(0.1 * (2 ** attempt))
```
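The retry loop can be factored into a reusable helper. The sketch below runs without a database: `TransientError` is a stand-in for a retryable error such as a serialization failure, `retry` and `flaky` are hypothetical names, and the added jitter is an assumption that goes beyond the plain exponential backoff shown above.

```python
import asyncio
import random


class TransientError(Exception):
    """Stand-in for a retryable error such as a serialization failure."""


async def retry(operation, *, retry_on=(TransientError,), max_retries=3, base_delay=0.01):
    """Await operation(), retrying on the given exceptions with backoff."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except retry_on:
            if attempt == max_retries - 1:
                raise
            # Exponential backoff with jitter so competing workers do not retry in lockstep.
            await asyncio.sleep(base_delay * (2 ** attempt) * random.uniform(0.5, 1.5))


attempts = 0


async def flaky():
    """Fails twice, then succeeds, to exercise the retry path."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise TransientError("simulated conflict")
    return "committed"


result = asyncio.run(retry(flaky))
print(result, attempts)
```

The operation passed in should open its own transaction on each call, since a failed transaction cannot be resumed.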
### 6. Missing ROLLBACK on Error
**Problem**: Transaction left open on error, holds locks.
```python
# BAD: Transaction not rolled back on error
conn = pool.getconn()
cursor = conn.cursor()
cursor.execute("BEGIN")
try:
cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1")
# Error here leaves transaction open!
cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2")
conn.commit()
finally:
pool.putconn(conn)
# GOOD: Use context manager (auto rollback)
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1")
await conn.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2")
# Automatically rolls back on exception
# GOOD: Explicit rollback
conn = pool.getconn()
try:
cursor = conn.cursor()
cursor.execute("BEGIN")
try:
cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1")
cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2")
conn.commit()
except Exception:
conn.rollback()
raise
finally:
pool.putconn(conn)
```
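The auto-rollback behavior of a context manager can be observed without a PostgreSQL server. This sketch swaps in the standard-library `sqlite3` module, whose connection object commits on a clean exit from a `with` block and rolls back if the block raises. The `accounts` table and the `make_accounts`, `transfer`, and `balances` helpers are illustrative assumptions, not part of the original examples.

```python
import sqlite3


def make_accounts():
    """Create an in-memory database with two accounts."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER)")
    conn.executemany("INSERT INTO accounts VALUES (?, ?)", [(1, 500), (2, 100)])
    conn.commit()
    return conn


def transfer(conn, src, dst, amount, fail_midway=False):
    # `with conn` commits on success and rolls back if the block raises.
    with conn:
        conn.execute(
            "UPDATE accounts SET balance = balance - ? WHERE id = ?", (amount, src)
        )
        if fail_midway:
            raise RuntimeError("simulated failure between the two updates")
        conn.execute(
            "UPDATE accounts SET balance = balance + ? WHERE id = ?", (amount, dst)
        )


def balances(conn):
    """Return {account_id: balance} for all accounts."""
    return dict(conn.execute("SELECT id, balance FROM accounts ORDER BY id"))
```

Calling `transfer(conn, 1, 2, 100, fail_midway=True)` raises, and the debit from account 1 is undone rather than left half-applied.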
### 7. Nested Transactions Without Savepoints
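**Problem**: PostgreSQL has no true nested transactions. A second `BEGIN` inside an open transaction only emits a warning, and a later `ROLLBACK` discards the outer transaction's work as well.
```python
# BAD: Manual BEGIN/ROLLBACK inside an open transaction does not nest
async with pool.acquire() as conn:
    async with conn.transaction():
        await conn.execute("INSERT INTO logs (message) VALUES ('start')")
        try:
            await conn.execute("BEGIN")  # Warning only: a transaction is already in progress
            await conn.execute("INSERT INTO data (value) VALUES (123)")
            raise ValueError("Error")
        except ValueError:
            await conn.execute("ROLLBACK")  # Rolls back the whole transaction
        await conn.execute("INSERT INTO logs (message) VALUES ('end')")
# The 'start' log is lost along with the data insert
# GOOD: Use a savepoint for the inner unit of work
async with pool.acquire() as conn:
    async with conn.transaction():
        await conn.execute("INSERT INTO logs (message) VALUES ('start')")
        # Name the savepoint something other than `inner`, a reserved word in PostgreSQL
        await conn.execute("SAVEPOINT data_insert")
        try:
            await conn.execute("INSERT INTO data (value) VALUES (123)")
            raise ValueError("Error")
        except ValueError:
            # Undo only the work done since the savepoint
            await conn.execute("ROLLBACK TO SAVEPOINT data_insert")
        await conn.execute("INSERT INTO logs (message) VALUES ('end')")
# Both logs are committed, data insert is rolled back
# GOOD (asyncpg): a nested conn.transaction() block creates the savepoint for you
async with pool.acquire() as conn:
    async with conn.transaction():
        await conn.execute("INSERT INTO logs (message) VALUES ('start')")
        try:
            async with conn.transaction():  # Savepoint under the hood
                await conn.execute("INSERT INTO data (value) VALUES (123)")
                raise ValueError("Error")
        except ValueError:
            pass  # Only the inner block was rolled back
        await conn.execute("INSERT INTO logs (message) VALUES ('end')")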
```
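Savepoint semantics can be checked locally. The sketch below is an assumption-laden stand-in: it uses the standard-library `sqlite3` module instead of PostgreSQL (SQLite also supports `SAVEPOINT` and `ROLLBACK TO SAVEPOINT`), and the `run_with_savepoint` helper and table layout are made up for illustration.

```python
import sqlite3


def run_with_savepoint():
    # isolation_level=None puts sqlite3 in autocommit mode, so BEGIN/COMMIT are explicit.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("CREATE TABLE logs (message TEXT)")
    conn.execute("CREATE TABLE data (value INTEGER)")

    conn.execute("BEGIN")
    conn.execute("INSERT INTO logs (message) VALUES ('start')")
    conn.execute("SAVEPOINT data_insert")
    try:
        conn.execute("INSERT INTO data (value) VALUES (123)")
        raise ValueError("simulated failure in the inner unit of work")
    except ValueError:
        # Undo only what happened after the savepoint; the outer transaction survives.
        conn.execute("ROLLBACK TO SAVEPOINT data_insert")
    conn.execute("INSERT INTO logs (message) VALUES ('end')")
    conn.execute("COMMIT")

    logs = [row[0] for row in conn.execute("SELECT message FROM logs ORDER BY rowid")]
    data = [row[0] for row in conn.execute("SELECT value FROM data")]
    conn.close()
    return logs, data
```

Both log rows survive the commit while the `data` insert is discarded, which is the outcome the savepoint pattern is meant to guarantee.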
### 8. Not Using FOR UPDATE SKIP LOCKED
**Problem**: Queue processing blocked by locked rows.
```python
# BAD: Workers block on locked rows
async def process_next_job():
async with pool.acquire() as conn:
async with conn.transaction():
# Blocks if another worker locked this row
job = await conn.fetchrow("""
SELECT * FROM jobs
WHERE status = 'pending'
ORDER BY created_at
LIMIT 1
FOR UPDATE
""")
if job:
await process_job(job)
await conn.execute(
"UPDATE jobs SET status = 'complete' WHERE id = $1",
job['id']
)
# GOOD: Skip locked rows
async def process_next_job():
async with pool.acquire() as conn:
async with conn.transaction():
# Skip rows locked by other workers
job = await conn.fetchrow("""
SELECT * FROM jobs
WHERE status = 'pending'
ORDER BY created_at
LIMIT 1
FOR UPDATE SKIP LOCKED
""")
if job:
await process_job(job)
await conn.execute(
"UPDATE jobs SET status = 'complete' WHERE id = $1",
job['id']
)
```
## Review Questions
1. Is the isolation level appropriate for the operation?
2. Are transactions kept short (no I/O inside)?
3. Are locks always acquired in consistent order to prevent deadlocks?
4. Would advisory locks help with application-level coordination?
5. Are serialization failures caught and retried?
6. Are transactions properly rolled back on error (context managers)?
7. Are savepoints used for nested transaction semantics?
8. Is `FOR UPDATE SKIP LOCKED` used for queue processing?
---
name: fastapi-code-review
description: Reviews FastAPI code for routing patterns, dependency injection, validation, and async handlers. Use when reviewing FastAPI apps, checking APIRouter setup, Depends() usage, or response models.
---
# FastAPI Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| APIRouter setup, response_model, status codes | [references/routes.md](references/routes.md) |
| Depends(), yield deps, cleanup, shared deps | [references/dependencies.md](references/dependencies.md) |
| Pydantic models, HTTPException, 422 handling | [references/validation.md](references/validation.md) |
| Async handlers, blocking I/O, background tasks | [references/async.md](references/async.md) |
## Review Checklist
- [ ] APIRouter with proper prefix and tags
- [ ] All routes specify `response_model` for type safety
- [ ] Correct HTTP methods (GET, POST, PUT, DELETE, PATCH)
- [ ] Proper status codes (200, 201, 204, 404, etc.)
- [ ] Dependencies use `Depends()` not manual calls
- [ ] Yield dependencies have proper cleanup
- [ ] Request/Response models use Pydantic
- [ ] HTTPException with status code and detail
- [ ] All route handlers are `async def`
- [ ] No blocking I/O (`requests`, `time.sleep`, `open()`)
- [ ] Background tasks for non-blocking operations
- [ ] No bare `except` in route handlers
## Valid Patterns (Do NOT Flag)
These are idiomatic FastAPI patterns that may appear problematic but are correct:
- **Pydantic validates request body automatically** - No manual validation needed when using typed Pydantic models as parameters
- **Dependency injection for database sessions** - Sessions come from `Depends()`, not passed as function arguments
- **HTTPException for all HTTP errors** - FastAPI handles conversion to proper HTTP responses
- **Async def endpoint without await** - May be using sync dependencies or simple operations; FastAPI handles this
- **Type annotation on Depends()** - This is documentation/IDE support, not a type assertion
- **Query/Path/Body defaults** - FastAPI processes these at runtime, not traditional Python defaults
- **Returning dict from endpoint** - Pydantic converts automatically if `response_model` is set
## Context-Sensitive Rules
Only flag issues when the context warrants it:
- **Flag missing validation** ONLY IF the field isn't already in a Pydantic model with validators
- **Flag missing auth** ONLY IF the endpoint isn't using `Depends()` with an auth dependency
- **Flag missing error handling** ONLY IF HTTPException isn't raised appropriately for error cases
- **Flag sync in async** ONLY IF the operation is actually blocking (file I/O, network calls, CPU-bound), not just non-async
## Gates (FastAPI-specific)
Run **once per FastAPI-related finding**, after you can anchor **`file:line`** for the handler (see [review-verification-protocol](../review-verification-protocol/SKILL.md)) and **before** the finding text ships. If a step’s pass condition is not met, **do not** assert the finding as written—gather evidence, withdraw, downgrade severity, or rephrase as a question.
### Gate 1 — Route decorator and response surface
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 1a | Open the handler’s route decorator in the repo (not from memory). | **`file:line`** for `@router.*` / `@app.*` (or the site that registers this handler). |
| 1b | Record HTTP method, `response_model=`, and `status_code=` on that decorator (or note they are absent). | **Snippet from that line** or **explicit absent** with the same **`file:line`**. |
### Gate 2 — Blocking or “should be async”
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 2a | Read the full handler body. | **`file:line` range** covering the body. |
| 2b | If claiming blocking I/O: name each blocking call (e.g. `requests.`, `open(`, `time.sleep`, sync DB/ORM). | **Each** call has **`file:line`**, or withdraw the finding if **none** after the read. |
### Gate 3 — Depends, validation, auth
| Step | Action | **Pass condition** |
|------|--------|---------------------|
| 3a | List parameters: `Depends` / `Annotated[..., Depends]`, Pydantic models, `Body`/`Query`/`Path`, `Request`/`Response`. | **Names + mechanism** tied to **`file:line`** on the signature. |
| 3b | If claiming missing auth: search the handler file (and its `APIRouter` module if separate) for `Depends`, `Security`, `HTTPBearer`, or project auth dependencies. | **Citation** to an existing hook, or **search result**: paths searched + **N matches** (zero is allowed). |
| 3c | If claiming missing validation: confirm the argument is not already a Pydantic model or constrained `Query`/`Path`/`Body`. | **Type/source** with **`file:line`**, or withdraw if validation already applies. |
## FastAPI Framework Behaviors
FastAPI + Pydantic handle many concerns automatically:
- Request validation via Pydantic models
- Response serialization via response_model
- Dependency injection for cross-cutting concerns
- Exception handling via exception handlers
Before flagging "missing" functionality, verify FastAPI isn't handling it.
## When to Load References
- Reviewing route definitions → routes.md
- Reviewing dependency injection → dependencies.md
- Reviewing Pydantic models/validation → validation.md
- Reviewing async route handlers → async.md
## Review Questions
1. Do all routes have explicit response models and status codes?
2. Are dependencies injected via Depends() with proper cleanup?
3. Do all Pydantic models validate inputs correctly?
4. Are all route handlers async and non-blocking?
## Before Submitting Findings
1. For each FastAPI-related finding, complete **Gates (FastAPI-specific)** above.
2. Load and follow [review-verification-protocol](../review-verification-protocol/SKILL.md) (Pre-Report checklist and **Verification by Issue Type**) before reporting any issue.
FILE:references/async.md
# Async
## Critical Anti-Patterns
### 1. Blocking I/O in Async Handlers
**Problem**: Blocks the event loop, prevents concurrent request handling.
```python
# BAD - blocking HTTP client
import requests
@router.get("/external")
async def fetch_external():
response = requests.get("https://api.example.com") # BLOCKS!
return response.json()
# GOOD - async HTTP client
import httpx
@router.get("/external")
async def fetch_external():
async with httpx.AsyncClient() as client:
response = await client.get("https://api.example.com")
return response.json()
```
### 2. Blocking Database Calls
**Problem**: Synchronous DB driver blocks event loop.
```python
# BAD - sync SQLAlchemy
from sqlalchemy.orm import Session
@router.get("/users", response_model=list[UserResponse])
async def list_users(db: Session = Depends(get_db)):
users = db.query(User).all() # BLOCKS!
return users
# GOOD - async SQLAlchemy
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
@router.get("/users", response_model=list[UserResponse])
async def list_users(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User))
users = result.scalars().all()
return users
```
### 3. Using time.sleep Instead of asyncio.sleep
**Problem**: Blocks event loop during sleep.
```python
# BAD - blocking sleep
import time
@router.post("/jobs")
async def create_job():
time.sleep(5) # BLOCKS for 5 seconds!
return {"status": "done"}
# GOOD - async sleep
import asyncio
@router.post("/jobs")
async def create_job():
await asyncio.sleep(5) # Yields control
return {"status": "done"}
# BETTER - use background tasks for long operations
from fastapi import BackgroundTasks
async def process_job():
await asyncio.sleep(5)
# Do actual work
@router.post("/jobs")
async def create_job(background_tasks: BackgroundTasks):
background_tasks.add_task(process_job)
return {"status": "processing"}
```
### 4. Sync File I/O in Async Handlers
**Problem**: File operations block event loop.
```python
import json
# BAD - blocking file I/O
@router.get("/config")
async def get_config():
with open("config.json") as f: # BLOCKS!
return json.load(f)
# GOOD - async file I/O
import aiofiles
@router.get("/config")
async def get_config():
async with aiofiles.open("config.json") as f:
content = await f.read()
return json.loads(content)
# ACCEPTABLE - small files in executor
import asyncio
def read_config_sync():
    with open("config.json") as f:
        return json.load(f)
@router.get("/config")
async def get_config():
    # get_running_loop() is preferred over get_event_loop() inside coroutines
    loop = asyncio.get_running_loop()
    config = await loop.run_in_executor(None, read_config_sync)
    return config
```
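The executor variant can also be written with `asyncio.to_thread` (Python 3.9+). The sketch below uses only the standard library; `read_config_blocking`, `heartbeat`, and `count_ticks` are hypothetical names. It counts how often a heartbeat task gets to run while a blocking call is in progress: not at all when the call runs directly on the event loop, several times when it is offloaded to a thread.

```python
import asyncio
import time


def read_config_blocking():
    time.sleep(0.2)  # Stands in for slow file or network I/O
    return {"debug": False}


async def heartbeat(ticks, interval=0.01):
    """Record a timestamp every `interval` seconds while the loop is free."""
    while True:
        ticks.append(time.monotonic())
        await asyncio.sleep(interval)


async def count_ticks(offload):
    """Return how many heartbeats ran while the blocking call was in progress."""
    ticks = []
    beat = asyncio.create_task(heartbeat(ticks))
    await asyncio.sleep(0)  # Let the heartbeat record its first tick
    before = len(ticks)
    if offload:
        await asyncio.to_thread(read_config_blocking)  # Event loop stays free
    else:
        read_config_blocking()  # Event loop is stuck for the full 0.2 s
    during = len(ticks) - before
    beat.cancel()
    try:
        await beat
    except asyncio.CancelledError:
        pass
    return during
```

The same measurement idea can be used in a test suite to catch handlers that accidentally block the loop.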
### 5. Not Using Background Tasks
**Problem**: Long operations block response, timeout issues.
```python
# BAD - blocks response
@router.post("/emails")
async def send_email(email: EmailCreate):
await send_email_via_smtp(email) # Takes 5 seconds!
await log_email_sent(email) # Takes 1 second!
return {"status": "sent"}
# GOOD - use background tasks
from fastapi import BackgroundTasks
async def send_email_background(email: EmailCreate):
await send_email_via_smtp(email)
await log_email_sent(email)
@router.post("/emails", status_code=202)
async def send_email(
email: EmailCreate,
background_tasks: BackgroundTasks
):
background_tasks.add_task(send_email_background, email)
return {"status": "queued"}
```
### 6. Sequential Instead of Concurrent Calls
**Problem**: Misses parallelization opportunity.
```python
# BAD - sequential (slow)
@router.get("/dashboard")
async def get_dashboard(user_id: int):
user = await get_user(user_id)
posts = await get_user_posts(user_id)
stats = await get_user_stats(user_id)
return {"user": user, "posts": posts, "stats": stats}
# GOOD - concurrent (fast)
import asyncio
@router.get("/dashboard")
async def get_dashboard(user_id: int):
user, posts, stats = await asyncio.gather(
get_user(user_id),
get_user_posts(user_id),
get_user_stats(user_id)
)
return {"user": user, "posts": posts, "stats": stats}
```
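The difference between the two versions can be measured with stand-in coroutines. In this sketch, `fetch` is a hypothetical helper that simulates a 50 ms call with `asyncio.sleep`; the sequential version should take roughly the sum of the delays, the concurrent one roughly the longest single delay.

```python
import asyncio
import time


async def fetch(name, delay=0.05):
    """Simulate an independent I/O-bound call."""
    await asyncio.sleep(delay)
    return name


async def load_sequential():
    start = time.perf_counter()
    results = [await fetch("user"), await fetch("posts"), await fetch("stats")]
    return results, time.perf_counter() - start


async def load_concurrent():
    start = time.perf_counter()
    # gather() returns results in the order the awaitables were passed in.
    results = await asyncio.gather(fetch("user"), fetch("posts"), fetch("stats"))
    return results, time.perf_counter() - start
```

Note that `gather()` only helps when the calls are independent; if `get_user_posts` needed the result of `get_user`, the sequential form would be the correct one.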
### 7. Mixing Sync and Async Route Handlers
**Problem**: Inconsistent patterns; FastAPI runs sync (`def`) handlers in a threadpool, so slow sync handlers can tie up its workers.
```python
# BAD - mixing sync and async
@router.get("/sync-route")
def sync_handler(): # Blocks thread pool
return db.query(User).all()
@router.get("/async-route")
async def async_handler():
return await db.query_async(User)
# GOOD - all async
@router.get("/route1")
async def handler1():
result = await db.execute(select(User))
return result.scalars().all()
@router.get("/route2")
async def handler2():
result = await db.execute(select(Post))
return result.scalars().all()
```
### 8. Not Awaiting Coroutines
**Problem**: Coroutine never executes, silent failures.
```python
# BAD - missing await
@router.post("/users")
async def create_user(user: UserCreate):
db.create_user(user) # Returns coroutine, doesn't execute!
return {"status": "created"} # User not actually created!
# GOOD - await coroutines
@router.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user: UserCreate):
created_user = await db.create_user(user)
return created_user
```
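Why the missing `await` fails silently can be shown in a few lines of standard-library Python. In this sketch, `create_user` is a hypothetical stand-in for the database call: calling it without `await` only builds a coroutine object, and its body never runs.

```python
import asyncio


async def create_user(created, name):
    """Stand-in for an async database insert."""
    await asyncio.sleep(0)
    created.append(name)
    return name


async def demo():
    created = []
    pending = create_user(created, "alice")  # No await: the body has not run
    is_coroutine = asyncio.iscoroutine(pending)
    ran_without_await = "alice" in created
    pending.close()  # Discard explicitly to avoid a "never awaited" RuntimeWarning
    await create_user(created, "bob")  # Awaited: the insert actually happens
    return is_coroutine, ran_without_await, created
```

Running Python with warnings enabled (or a linter rule for unawaited coroutines) is a practical way to surface this class of bug during review.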
### 9. Blocking External API Calls
**Problem**: Synchronous requests library blocks event loop.
```python
# BAD - requests blocks
import requests
@router.get("/weather")
async def get_weather(city: str):
response = requests.get(f"https://api.weather.com/{city}") # BLOCKS!
return response.json()
# GOOD - httpx async
import httpx
@router.get("/weather")
async def get_weather(city: str):
async with httpx.AsyncClient() as client:
response = await client.get(f"https://api.weather.com/{city}")
return response.json()
# GOOD - with timeout
from fastapi import HTTPException
@router.get("/weather")
async def get_weather(city: str):
async with httpx.AsyncClient(timeout=5.0) as client:
try:
response = await client.get(f"https://api.weather.com/{city}")
return response.json()
except httpx.TimeoutException:
raise HTTPException(504, detail="Weather API timeout")
```
## Review Questions
1. Are all route handlers `async def`?
2. Are there any `requests`, `time.sleep`, or `open()` calls?
3. Is the database driver async (AsyncSession, asyncpg, etc.)?
4. Are background tasks used for long operations?
5. Are independent async calls parallelized with `gather()`?
6. Are all coroutines properly awaited?
7. Are external API calls using async HTTP clients?
FILE:references/dependencies.md
# Dependencies
## Critical Anti-Patterns
### 1. Manual Dependency Calls
**Problem**: Bypasses FastAPI's injection system, no automatic cleanup.
```python
# BAD - manually calling dependency
async def get_db_session():
session = SessionLocal()
return session
@router.get("/users")
async def list_users():
db = await get_db_session() # Manual call!
users = await db.query(User).all()
return users
# GOOD - using Depends()
from fastapi import Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
async def get_db_session():
    session = SessionLocal()  # Assumed to be an async session factory
    try:
        yield session
    finally:
        await session.close()
@router.get("/users", response_model=list[UserResponse])
async def list_users(db: AsyncSession = Depends(get_db_session)):
    result = await db.execute(select(User))
    return result.scalars().all()
```
### 2. Missing Cleanup in Yield Dependencies
**Problem**: Resources leak, connections not closed.
```python
# BAD - no cleanup
async def get_db():
db = DatabaseConnection()
yield db
# Connection never closed!
# GOOD - proper cleanup
async def get_db():
db = DatabaseConnection()
try:
yield db
finally:
await db.close()
```
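As far as cleanup is concerned, a yield dependency behaves like a context manager: the code after `yield` runs once the request is done. The sketch below illustrates why `try/finally` matters, using `contextlib.asynccontextmanager` from the standard library rather than FastAPI itself; `FakeConnection` and `handler_that_fails` are hypothetical names.

```python
import asyncio
from contextlib import asynccontextmanager


class FakeConnection:
    """Hypothetical connection that records whether it was closed."""

    def __init__(self):
        self.closed = False

    async def close(self):
        self.closed = True


@asynccontextmanager
async def get_db(opened):
    db = FakeConnection()
    opened.append(db)
    try:
        yield db
    finally:
        # Runs even when the code using the connection raises.
        await db.close()


async def handler_that_fails(opened):
    async with get_db(opened):
        raise RuntimeError("handler error")


async def demo():
    opened = []
    try:
        await handler_that_fails(opened)
    except RuntimeError:
        pass
    return opened[0].closed
```

Without the `try/finally`, the exception would propagate through the `yield` and the `close()` call would be skipped.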
### 3. Shared State Without Proper Scope
**Problem**: Dependencies create shared mutable state across requests.
```python
# BAD - shared mutable state
cache = {} # Shared across all requests!
async def get_cache():
return cache
@router.get("/items/{id}")
async def get_item(id: int, cache: dict = Depends(get_cache)):
# Multiple requests share same dict - race conditions!
if id not in cache:
cache[id] = await fetch_item(id)
return cache[id]
# GOOD - request-scoped state
from contextvars import ContextVar
request_cache: ContextVar[dict] = ContextVar('request_cache')
async def get_cache():
cache = {}
request_cache.set(cache)
return cache
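# BETTER - use an async-aware caching library
# (functools.lru_cache would cache the coroutine object, which can be awaited only once)
from async_lru import alru_cache  # Third-party package: async-lru
@alru_cache(maxsize=128)
async def get_item_cached(id: int):
    return await fetch_item(id)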
```
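Caching async calls needs some care: `functools.lru_cache` applied to an `async def` function stores the coroutine object rather than its result, and a coroutine can be awaited only once. The following is a minimal standard-library sketch of an async-aware cache; `AsyncCache` is a hypothetical class with no eviction or size limit, so it illustrates the idea rather than replacing a caching library.

```python
import asyncio


class AsyncCache:
    """Memoize an async loader; concurrent callers for one key share one task."""

    def __init__(self, loader):
        self._loader = loader
        self._tasks = {}

    async def get(self, key):
        task = self._tasks.get(key)
        if task is None:
            task = asyncio.create_task(self._loader(key))
            self._tasks[key] = task
        try:
            # shield() keeps one cancelled caller from cancelling the shared task.
            return await asyncio.shield(task)
        except Exception:
            # Do not cache failures; the next caller retries the loader.
            self._tasks.pop(key, None)
            raise


async def demo():
    calls = []

    async def fetch_item(item_id):
        calls.append(item_id)
        await asyncio.sleep(0.01)
        return {"id": item_id}

    cache = AsyncCache(fetch_item)
    results = await asyncio.gather(cache.get(1), cache.get(1), cache.get(1))
    return results, calls
```

Storing the task instead of the result also removes the check-then-fetch race in the BAD example, where several requests could miss the cache at once and each call `fetch_item`.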
### 4. Nested Depends Not Utilized
**Problem**: Duplicate code, no composition of dependencies.
```python
# BAD - duplicated logic
async def get_current_user(token: str):
# Verify token, decode, fetch user
return user
async def get_admin_user(token: str):
# Same verification, then check admin
user = await verify_and_decode(token)
if not user.is_admin:
raise HTTPException(403)
return user
# GOOD - compose dependencies
async def get_current_user(token: str = Depends(oauth2_scheme)):
user = await verify_token(token)
if not user:
raise HTTPException(401, detail="Invalid token")
return user
async def get_admin_user(user: User = Depends(get_current_user)):
if not user.is_admin:
raise HTTPException(403, detail="Admin required")
return user
```
### 5. Dependencies with Side Effects
**Problem**: Dependencies modify state instead of providing resources.
```python
# BAD - dependency has side effects
async def log_request(request: Request):
# Side effect: writes to database
await db.log_request(request)
return None
@router.get("/users")
async def list_users(_: None = Depends(log_request)):
return users
# GOOD - use middleware for cross-cutting concerns
from starlette.middleware.base import BaseHTTPMiddleware
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await db.log_request(request)
response = await call_next(request)
return response
app.add_middleware(LoggingMiddleware)
# OR - dependency returns resource
async def get_logger(request: Request):
logger = RequestLogger(request)
return logger
@router.get("/users")
async def list_users(logger: RequestLogger = Depends(get_logger)):
logger.info("Listing users")
return users
```
### 6. Class-Based Dependencies Without Caching
**Problem**: New instance created unnecessarily.
```python
# BAD - new instance every time
class DatabaseService:
def __init__(self):
self.connection_pool = create_pool() # Expensive!
@router.get("/users")
async def list_users(db: DatabaseService = Depends(DatabaseService)):
return await db.query_users()
# GOOD - use singleton or app state
class DatabaseService:
def __init__(self, pool):
self.pool = pool
async def get_db_service(
pool = Depends(lambda: app.state.db_pool)
) -> DatabaseService:
return DatabaseService(pool)
# OR - use dependency with cache
async def get_db_service() -> DatabaseService:
return app.state.db_service
@router.get("/users")
async def list_users(db: DatabaseService = Depends(get_db_service)):
return await db.query_users()
```
### 7. Security Dependencies Not Applied Globally
**Problem**: Easy to forget security on new routes.
```python
# BAD - must remember to add auth to every route
@router.get("/users", dependencies=[Depends(verify_token)])
async def list_users(): ...
@router.get("/posts") # Forgot auth!
async def list_posts(): ...
# GOOD - apply at router level
router = APIRouter(
prefix="/api/v1",
dependencies=[Depends(verify_token)]
)
@router.get("/users")
async def list_users(): ...
@router.get("/posts")
async def list_posts(): ...
```
## Review Questions
1. Are all dependencies injected via `Depends()` not manually called?
2. Do yield dependencies have proper `try/finally` cleanup?
3. Is there any shared mutable state across requests?
4. Are nested dependencies used to compose common patterns?
5. Do dependencies provide resources, not perform side effects?
6. Are security dependencies applied at router or app level?
FILE:references/routes.md
# Routes
## Critical Anti-Patterns
### 1. Missing response_model
**Problem**: No type safety, documentation unclear, response not validated.
```python
# BAD
@router.get("/users/{user_id}")
async def get_user(user_id: int):
return {"id": user_id, "name": "Alice"}
# GOOD
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: int):
return {"id": user_id, "name": "Alice"}
```
### 2. No APIRouter Prefix/Tags
**Problem**: Routes not organized, duplicated path prefixes, unclear docs.
```python
# BAD
@app.get("/api/v1/users")
async def list_users(): ...
@app.get("/api/v1/users/{id}")
async def get_user(id: int): ...
# GOOD
router = APIRouter(prefix="/api/v1/users", tags=["users"])
@router.get("")
async def list_users(): ...
@router.get("/{id}")
async def get_user(id: int): ...
app.include_router(router)
```
### 3. Wrong HTTP Methods
**Problem**: Violates REST conventions, confusing semantics.
```python
# BAD - using GET for mutations
@router.get("/users/{id}/delete")
async def delete_user(id: int): ...
# BAD - using POST for retrieval
@router.post("/users/{id}")
async def get_user(id: int): ...
# GOOD
@router.delete("/users/{id}", status_code=204)
async def delete_user(id: int): ...
@router.get("/users/{id}", response_model=UserResponse)
async def get_user(id: int): ...
```
### 4. Missing Status Codes
**Problem**: Always returns 200, even for creates/deletes.
```python
# BAD - creates should return 201
@router.post("/users")
async def create_user(user: UserCreate):
return created_user
# BAD - deletes should return 204
@router.delete("/users/{id}")
async def delete_user(id: int):
return {"message": "deleted"}
# GOOD
@router.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user: UserCreate):
return created_user
@router.delete("/users/{id}", status_code=204)
async def delete_user(id: int):
# 204 returns no content
return None
```
### 5. Direct Exception Raising
**Problem**: Returns generic 500 errors instead of proper HTTP status codes.
```python
# BAD
@router.get("/users/{id}")
async def get_user(id: int):
user = await db.get_user(id)
if not user:
raise ValueError("User not found")
return user
# GOOD
from fastapi import HTTPException
@router.get("/users/{id}", response_model=UserResponse)
async def get_user(id: int):
user = await db.get_user(id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
```
### 6. Multiple Response Models
**Problem**: Same endpoint returns different schemas.
```python
# BAD
@router.get("/users/{id}")
async def get_user(id: int, full: bool = False):
if full:
return UserDetailResponse(...)
return UserSummaryResponse(...)
# GOOD - use separate endpoints
@router.get("/users/{id}", response_model=UserSummaryResponse)
async def get_user(id: int):
return UserSummaryResponse(...)
@router.get("/users/{id}/full", response_model=UserDetailResponse)
async def get_user_full(id: int):
return UserDetailResponse(...)
# ALTERNATIVE - use response_model with Union
from typing import Union
@router.get("/users/{id}", response_model=Union[UserSummaryResponse, UserDetailResponse])
async def get_user(id: int, full: bool = False):
if full:
return UserDetailResponse(...)
return UserSummaryResponse(...)
```
### 7. Path Parameter Validation
**Problem**: No validation on path parameters.
```python
# BAD
@router.get("/users/{user_id}")
async def get_user(user_id: int):
# What if user_id is negative or zero?
return await db.get_user(user_id)
# GOOD
from fastapi import Path
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: int = Path(..., gt=0)):
return await db.get_user(user_id)
```
## Review Questions
1. Does every route have an explicit `response_model`?
2. Are routes organized with APIRouter using prefix and tags?
3. Are HTTP methods semantically correct (GET for read, POST for create, etc.)?
4. Do create operations return 201? Do deletes return 204?
5. Are HTTPExceptions used instead of generic exceptions?
6. Are path parameters validated?
FILE:references/validation.md
# Validation
## Critical Anti-Patterns
### 1. Manual Validation Instead of Pydantic
**Problem**: Duplicate validation logic, inconsistent errors.
```python
# BAD - manual validation
@router.post("/users")
async def create_user(request: Request):
data = await request.json()
if "email" not in data:
raise HTTPException(400, "Email required")
if "@" not in data["email"]:
raise HTTPException(400, "Invalid email")
return await db.create_user(data)
# GOOD - Pydantic validation
from pydantic import BaseModel, EmailStr
class UserCreate(BaseModel):
email: EmailStr
name: str
age: int | None = None
@router.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user: UserCreate):
return await db.create_user(user)
```
### 2. Missing Field Validators
**Problem**: Invalid data passes through.
```python
# BAD - no validation on age
class UserCreate(BaseModel):
name: str
age: int # Can be negative!
# GOOD - field validation
from pydantic import BaseModel, EmailStr, Field
class UserCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100)
age: int = Field(..., ge=0, le=150)
email: EmailStr
```
### 3. Generic HTTPException Messages
**Problem**: Users don't know what's wrong.
```python
# BAD - vague error
@router.get("/users/{user_id}")
async def get_user(user_id: int):
user = await db.get_user(user_id)
if not user:
raise HTTPException(404) # No detail!
return user
# GOOD - specific error
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: int):
user = await db.get_user(user_id)
if not user:
raise HTTPException(
status_code=404,
detail=f"User {user_id} not found"
)
return user
```
### 4. Not Using Pydantic Config
**Problem**: Models accept extra fields, expose internal fields.
```python
# BAD - accepts any extra fields
class UserCreate(BaseModel):
name: str
email: str
# {"name": "Alice", "email": "alice@example.com", "is_admin": true} accepted!
# GOOD - strict validation (Pydantic v2 style; `class Config` is the deprecated v1 form)
from pydantic import BaseModel, ConfigDict, EmailStr
class UserCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")  # Reject unknown fields
    name: str
    email: EmailStr
# GOOD - control ORM exposure
class UserResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)  # Formerly orm_mode
    id: int
    name: str
    email: str
    # Don't expose password_hash, created_at, etc.
```
### 5. Missing Custom Validators
**Problem**: Business rules not enforced.
```python
# BAD - no validation
class PasswordReset(BaseModel):
password: str
confirm_password: str
# Passwords might not match!
# GOOD - custom validator
from pydantic import BaseModel, Field, model_validator
class PasswordReset(BaseModel):
password: str = Field(..., min_length=8)
confirm_password: str
@model_validator(mode='after')
def passwords_match(self):
if self.password != self.confirm_password:
raise ValueError('Passwords do not match')
return self
```
### 6. Not Handling 422 Validation Errors
**Problem**: Default 422 responses unclear to clients.
```python
# BAD - default 422 response is verbose and unclear
# (No custom handler)
# GOOD - custom 422 handler
from fastapi import Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError
):
errors = []
for error in exc.errors():
errors.append({
"field": ".".join(str(x) for x in error["loc"][1:]),
"message": error["msg"],
"type": error["type"]
})
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": errors}
)
```
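The reshaping step inside the handler can be pulled into a pure function and unit-tested without starting an app. This is a sketch under assumptions: `format_validation_errors` is a hypothetical helper, the `source` key is an addition to the handler's output, and the sample error dicts used to exercise it are hand-written to mimic the `loc` / `msg` / `type` shape used in the handler, not captured from a real request.

```python
def format_validation_errors(errors):
    """Flatten validation error dicts into client-friendly entries."""
    formatted = []
    for error in errors:
        location = error["loc"]
        formatted.append({
            # loc[0] names the request part ("body", "query", "path").
            "source": str(location[0]) if location else "",
            # The remaining parts form the field path, including list indexes.
            "field": ".".join(str(part) for part in location[1:]),
            "message": error["msg"],
            "type": error["type"],
        })
    return formatted


# Hand-written sample input for illustration only.
SAMPLE_ERRORS = [
    {"loc": ("body", "user", "email"), "msg": "Field required", "type": "missing"},
    {"loc": ("body", "items", 0, "price"), "msg": "Must be positive", "type": "value_error"},
]
```

Keeping the formatting separate from the exception handler means the response contract can be tested with plain dictionaries.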
### 7. Using Dict Instead of Models
**Problem**: No validation, no type safety, unclear API.
```python
# BAD - dict responses
@router.get("/users/{id}")
async def get_user(id: int) -> dict:
return {
"id": id,
"name": "Alice",
"extra_field": "oops" # Inconsistent!
}
# GOOD - Pydantic response model
class UserResponse(BaseModel):
id: int
name: str
email: str
@router.get("/users/{id}", response_model=UserResponse)
async def get_user(id: int):
user = await db.get_user(id)
if not user:
raise HTTPException(404, detail="User not found")
return user # Auto-validates and filters fields
```
### 8. Missing Query Parameter Validation
**Problem**: Invalid query parameters not validated.
```python
# BAD - no validation
@router.get("/users")
async def list_users(page: int = 1, size: int = 10):
# What if page is 0 or negative?
# What if size is 10000?
return await db.get_users(page, size)
# GOOD - validated query params
from fastapi import Query
@router.get("/users", response_model=list[UserResponse])
async def list_users(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100)
):
return await db.get_users(page, size)
```
## Review Questions
1. Are all request bodies defined as Pydantic models?
2. Do fields have proper validators (min_length, ge, EmailStr, etc.)?
3. Do HTTPExceptions include detailed error messages?
4. Are models configured with `extra = "forbid"` to reject unknown fields?
5. Are custom validators used for business rules?
6. Are query parameters validated with `Query()`?
7. Are response models used instead of plain dicts?
---
name: wish-ssh-code-review
description: Reviews Wish SSH server code for proper middleware, session handling, and security patterns. Use when reviewing SSH server code using charmbracelet/wish.
---
# Wish SSH Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Server setup, middleware | [references/server.md](references/server.md) |
| Session handling, security | [references/sessions.md](references/sessions.md) |
## Review gates
Run these **in order** when producing a written review. Do not claim a defect in a later step until the **Pass when** for the current step is satisfied for the code under review.
1. **Locate Wish entry points** — **Pass when:** you have at least one repo path per server surface that calls `wish.NewServer`, `wish.WithMiddleware`, registers `bubbletea.Middleware`, or defines the top-level `ssh.Handler` chain (list the paths explicitly).
2. **Capture server-setup evidence** — **Pass when:** for each path from step 1, you have the actual `wish.WithHostKey*` / host-key configuration and the **full middleware list in source order** as written (not recalled from memory). If graceful shutdown exists, note the file(s) where `ListenAndServe` and `Shutdown` run.
3. **Capture session / TUI evidence** — **Pass when:** for each `teaHandler` (or equivalent), you have noted from source whether `s.Pty()` is checked before using window size, and whether per-session renderers (`bubbletea.MakeRenderer`) are used where Lipgloss styles apply.
4. **Write findings** — **Pass when:** each finding uses `[FILE:LINE] ISSUE_TITLE` (line range allowed where needed) and points to the relevant row in **Quick Reference** (or the matching section in `references/`).
## Review Checklist
Use alongside **Review gates**; for a written review, complete the gates first so each item below can be tied to cited source.
- [ ] Host keys are loaded from file or generated securely
- [ ] Middleware order is correct (logging executes first, auth early; Wish runs the last-listed middleware first)
- [ ] Session context is used for per-connection state
- [ ] Graceful shutdown handles active sessions
- [ ] PTY requests are handled for terminal apps
- [ ] Connection limits prevent resource exhaustion
- [ ] Timeout middleware prevents hung connections
- [ ] BubbleTea middleware correctly configured
## Critical Patterns
### Server Setup
```go
// GOOD - complete server setup
s, err := wish.NewServer(
wish.WithAddress(fmt.Sprintf("%s:%d", host, port)),
wish.WithHostKeyPath(".ssh/id_ed25519"),
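wish.WithMiddleware(
// Wish composes middleware first to last, so the LAST entry runs first.
bubbletea.Middleware(teaHandler), // runs last: the app itself
activeterm.Middleware(), // rejects sessions without a PTY
logging.Middleware(), // runs first: log all connections
),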
)
if err != nil {
return fmt.Errorf("creating server: %w", err)
}
```
### Graceful Shutdown
```go
// BAD - abrupt shutdown
log.Fatal(s.ListenAndServe())
// GOOD - graceful shutdown
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
go func() {
if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Error("server error", "error", err)
}
}()
<-done
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.Shutdown(ctx); err != nil {
log.Error("shutdown error", "error", err)
}
```
### BubbleTea Handler
```go
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
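// activeterm.Middleware() rejects non-PTY sessions before this runs;
// checking ok anyway avoids a 0x0 window if the chain is reordered.
pty, _, ok := s.Pty()
width, height := 80, 24
if ok {
width, height = pty.Window.Width, pty.Window.Height
}
model := NewModel(width, height)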
return model, []tea.ProgramOption{
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
}
}
```
## When to Load References
- Reviewing server initialization → server.md
- Reviewing authentication, session state → sessions.md
## Review Questions
1. Are host keys handled securely?
2. Is middleware order correct?
3. Is graceful shutdown implemented?
4. Are PTY window sizes passed to the TUI?
5. Are connection timeouts configured?
FILE:references/server.md
# Server Setup
## Host Key Management
### 1. Use Persistent Keys
```go
// BAD - generates new key each start (fingerprint changes)
s, err := wish.NewServer(
wish.WithAddress(":22"),
// no host key specified - generates random
)
// GOOD - load from file
s, err := wish.NewServer(
wish.WithAddress(":22"),
wish.WithHostKeyPath("/data/ssh_host_ed25519_key"),
)
// GOOD - generate if missing, persist for reuse
func ensureHostKey(path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
// save to file...
}
return nil
}
```
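The `// save to file...` step above is left open in the source. The following is a minimal sketch of one way to fill it in using only the standard library. The PKCS#8 PEM encoding, the `0600` file mode, and the helper `keySurvivesRestart` are assumptions of this sketch; whether the resulting file loads with `wish.WithHostKeyPath` should be verified against the SSH library version in use.

```go
package main

import (
	"bytes"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
)

// ensureHostKey writes a new ed25519 private key to path if none exists.
// It reports whether a key was generated. The PKCS#8 PEM encoding is an
// assumption of this sketch, not something the source prescribes.
func ensureHostKey(path string) (created bool, err error) {
	if _, err := os.Stat(path); err == nil {
		return false, nil // key already present: keep the fingerprint stable
	} else if !errors.Is(err, fs.ErrNotExist) {
		return false, fmt.Errorf("checking host key: %w", err)
	}
	_, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return false, fmt.Errorf("generating host key: %w", err)
	}
	der, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return false, fmt.Errorf("encoding host key: %w", err)
	}
	block := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
	// 0600: the private key must not be readable by other users.
	if err := os.WriteFile(path, block, 0o600); err != nil {
		return false, fmt.Errorf("writing host key: %w", err)
	}
	return true, nil
}

// keySurvivesRestart simulates two server starts against the same path and
// reports whether the key bytes are unchanged on the second start.
func keySurvivesRestart() (bool, error) {
	dir, err := os.MkdirTemp("", "hostkey")
	if err != nil {
		return false, err
	}
	defer os.RemoveAll(dir)
	path := filepath.Join(dir, "ssh_host_ed25519_key")

	first, err := ensureHostKey(path)
	if err != nil {
		return false, err
	}
	before, err := os.ReadFile(path)
	if err != nil {
		return false, err
	}
	second, err := ensureHostKey(path)
	if err != nil {
		return false, err
	}
	after, err := os.ReadFile(path)
	if err != nil {
		return false, err
	}
	return first && !second && bytes.Equal(before, after), nil
}

func main() {
	ok, err := keySurvivesRestart()
	fmt.Println(ok, err) // prints "true <nil>"
}
```

Returning `created` lets the caller log the first-start case, which is the only time clients should see a new fingerprint.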
### 2. Support Multiple Key Types
```go
s, err := wish.NewServer(
wish.WithAddress(":22"),
wish.WithHostKeyPath("/data/ssh_host_ed25519_key"),
wish.WithHostKeyPEM(rsaKeyBytes), // additional key type
)
```
## Middleware Configuration
### 1. Correct Middleware Order
```go
// Wish composes middleware first to last, so the LAST entry runs first.
// Timeouts are server options (ssh.Option), not middleware.
s, err := wish.NewServer(
wish.WithIdleTimeout(10*time.Minute), // prevent hung connections
wish.WithMaxTimeout(30*time.Minute),
wish.WithMiddleware(
// 3. Runs last: your app handler - BubbleTea or custom
bubbletea.Middleware(teaHandler),
// 2. Active terminal - rejects sessions without a PTY
activeterm.Middleware(),
// 1. Runs first: logging - see all connections
logging.Middleware(),
),
)
```
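Execution order is easy to get backwards, so it can help to model it. The sketch below reproduces the wrapping loop that `wish.WithMiddleware` documents (each middleware wraps the handler built so far) using stand-in types. `handler`, `middleware`, `tracing`, and `executionOrder` are hypothetical names for illustration and are not part of Wish.

```go
package main

import (
	"fmt"
	"strings"
)

// handler and middleware are simplified stand-ins for ssh.Handler and
// wish.Middleware; a string plays the role of the session.
type handler func(session string)
type middleware func(next handler) handler

// compose wraps handlers first to last, so the last middleware ends up
// outermost and therefore runs first.
func compose(mw ...middleware) handler {
	h := handler(func(string) {})
	for _, m := range mw {
		h = m(h)
	}
	return h
}

// tracing returns a middleware that records its name before calling next.
func tracing(name string, trace *[]string) middleware {
	return func(next handler) handler {
		return func(s string) {
			*trace = append(*trace, name)
			next(s)
		}
	}
}

// executionOrder lists middleware in the given order and returns the order
// in which they actually run.
func executionOrder(listed ...string) string {
	var trace []string
	mws := make([]middleware, 0, len(listed))
	for _, name := range listed {
		mws = append(mws, tracing(name, &trace))
	}
	compose(mws...)("session")
	return strings.Join(trace, " -> ")
}

func main() {
	fmt.Println(executionOrder("app", "activeterm", "logging"))
	// prints "logging -> activeterm -> app"
}
```

When reviewing, read a `wish.WithMiddleware(...)` list bottom-up to get the runtime order.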
### 2. Custom Middleware
```go
func customMiddleware() wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
// Before handling
log.Info("connection", "user", s.User(), "remote", s.RemoteAddr())
// Call next handler
next(s)
// After handling (session ended)
log.Info("disconnected", "user", s.User())
}
}
}
```
### 3. Metrics Middleware
```go
func metricsMiddleware(metrics *Metrics) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
metrics.ActiveConnections.Inc()
start := time.Now()
defer func() {
metrics.ActiveConnections.Dec()
metrics.SessionDuration.Observe(time.Since(start).Seconds())
}()
next(s)
}
}
}
```
## Server Lifecycle
### 1. Graceful Shutdown
```go
func run() error {
s, err := wish.NewServer(...)
if err != nil {
return err
}
// Start server in goroutine
errCh := make(chan error, 1)
go func() {
errCh <- s.ListenAndServe()
}()
// Wait for shutdown signal
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-errCh:
return err
case <-quit:
}
// Graceful shutdown with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return s.Shutdown(ctx)
}
```
### 2. Health Checks
```go
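// Run HTTP health endpoint alongside SSH
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
})
go func() {
// Log bind/serve failures instead of dropping them
if err := http.ListenAndServe(":8080", mux); err != nil {
log.Error("health server error", "error", err)
}
}()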
```
## Connection Handling
### 1. Connection Limits
```go
// Limit concurrent connections
var connLimiter = make(chan struct{}, 100)
func connectionLimitMiddleware() wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
select {
case connLimiter <- struct{}{}:
defer func() { <-connLimiter }()
next(s)
default:
s.Exit(1)
}
}
}
}
```
### 2. Rate Limiting
```go
import "golang.org/x/time/rate"
// rate.Every(time.Second) would allow only 1 event/sec; rate.Limit(10) is 10/sec.
var limiter = rate.NewLimiter(rate.Limit(10), 10) // 10/sec sustained, burst of 10
func rateLimitMiddleware() wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
if !limiter.Allow() {
io.WriteString(s, "Too many connections, try again later\n")
s.Exit(1)
return
}
next(s)
}
}
}
```
## Anti-Patterns
### 1. No Error Handling on ListenAndServe
```go
// BAD
go s.ListenAndServe()
// GOOD
go func() {
if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Fatal("server error", "error", err)
}
}()
```
### 2. Ignoring Context in Shutdown
```go
// BAD - no timeout
s.Shutdown(context.Background()) // could hang forever
// GOOD - bounded wait, error checked
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.Shutdown(ctx); err != nil {
log.Error("shutdown error", "error", err)
}
```
## Review Questions
1. Are host keys persisted (not regenerated on restart)?
2. Is middleware order correct (logging first)?
3. Is graceful shutdown implemented with timeout?
4. Are connection/rate limits in place?
5. Is there a health check endpoint?
FILE:references/sessions.md
# Sessions & Security
## Session Handling
### 1. Access Session Info
```go
func handler(s ssh.Session) {
// User info
user := s.User()
remoteAddr := s.RemoteAddr()
// Public key (if key auth)
key := s.PublicKey()
// Environment variables
env := s.Environ()
// Command (if not interactive)
cmd := s.Command()
// PTY info (if allocated)
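// The second return value is the window-resize channel (unused here)
pty, _, isPty := s.Pty()
width, height := 0, 0
if isPty {
width, height = pty.Window.Width, pty.Window.Height
}
// Use the values so the snippet compiles
log.Info("session", "user", user, "remote", remoteAddr, "hasKey", key != nil,
"env", len(env), "cmd", cmd, "width", width, "height", height)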
}
```
### 2. Handle Window Resize
```go
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
pty, _, ok := s.Pty() // initial size only
width, height := 80, 24
if ok {
width, height = pty.Window.Width, pty.Window.Height
}
model := NewModel(width, height)
// Later resizes arrive as tea.WindowSizeMsg in the model's Update;
// bubbletea.Middleware forwards them from the session's window-change channel.
return model, []tea.ProgramOption{tea.WithAltScreen()}
}
```
### 3. Session Context for State
```go
// Store per-session state using context
type contextKey string
const sessionDataKey contextKey = "sessionData"
type SessionData struct {
User string
ConnectAt time.Time
PageViews int
}
func sessionMiddleware() wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
data := &SessionData{
User: s.User(),
ConnectAt: time.Now(),
}
// ssh.Context is mutable: SetValue attaches data to this connection.
// Assumes s.Context() returns ssh.Context (as in charmbracelet/ssh);
// otherwise keep a sync.Map keyed by session ID.
s.Context().SetValue(sessionDataKey, data)
next(s)
}
}
}
```
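The `sync.Map` approach mentioned in the comment above could look like the following. This is a sketch under stated assumptions: `sessionStore` and `lifecycle` are hypothetical names, and the session ID is passed as a plain string rather than read from a real `ssh.Session`.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// SessionData mirrors the struct in the example above.
type SessionData struct {
	User      string
	ConnectAt time.Time
	PageViews int
}

// sessionStore maps session IDs to per-session state. In a Wish middleware
// the ID would come from the session's context; here it is a plain string.
type sessionStore struct {
	m sync.Map // string -> *SessionData
}

// Put registers state for a session ID.
func (st *sessionStore) Put(id string, d *SessionData) { st.m.Store(id, d) }

// Get returns the state for a session ID, if any.
func (st *sessionStore) Get(id string) (*SessionData, bool) {
	v, ok := st.m.Load(id)
	if !ok {
		return nil, false
	}
	return v.(*SessionData), true
}

// Delete must run when the session ends, or entries leak for the lifetime
// of the process.
func (st *sessionStore) Delete(id string) { st.m.Delete(id) }

// lifecycle walks one session through connect, lookup, and disconnect and
// reports the user seen mid-session and whether the entry is gone afterwards.
func lifecycle(st *sessionStore, id, user string) (seen string, cleaned bool) {
	st.Put(id, &SessionData{User: user, ConnectAt: time.Now()})
	if d, ok := st.Get(id); ok {
		d.PageViews++
		seen = d.User
	}
	st.Delete(id)
	_, still := st.Get(id)
	return seen, !still
}

func main() {
	seen, cleaned := lifecycle(&sessionStore{}, "abc123", "alice")
	fmt.Println(seen, cleaned) // prints "alice true"
}
```

In a middleware, `Put` would run before `next(s)` and `Delete` in a `defer`, so cleanup also happens when the handler panics or the client disconnects.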
## Security
### 1. Authentication
```go
// Public key authentication
wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
// Check against authorized keys
authorized := loadAuthorizedKeys()
for _, authKey := range authorized {
if ssh.KeysEqual(key, authKey) {
return true
}
}
return false
}),
// Password authentication (not recommended for production)
wish.WithPasswordAuth(func(ctx ssh.Context, password string) bool {
// Never do this - use public key auth
return password == os.Getenv("SSH_PASSWORD")
}),
```
### 2. Authorization
```go
func authorizationMiddleware(allowedUsers map[string]bool) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
if !allowedUsers[s.User()] {
io.WriteString(s, "Access denied\n")
s.Exit(1)
return
}
next(s)
}
}
}
```
### 3. Secure Defaults
```go
s, err := wish.NewServer(
wish.WithAddress(":22"),
wish.WithHostKeyPath("./host_key"),
// Timeouts prevent hung connections
wish.WithIdleTimeout(10*time.Minute),
wish.WithMaxTimeout(60*time.Minute),
// Require public key auth
wish.WithPublicKeyAuth(authHandler),
wish.WithMiddleware(
// Wish runs the last-listed middleware first
bubbletea.Middleware(teaHandler),
activeterm.Middleware(),
logging.Middleware(), // audit trail, runs first
),
)
```
## BubbleTea Integration
### 1. Basic Handler
```go
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
pty, _, _ := s.Pty() // unchecked: relies on activeterm.Middleware() rejecting non-PTY sessions
renderer := bubbletea.MakeRenderer(s)
model := NewModel(renderer, pty.Window.Width, pty.Window.Height)
return model, []tea.ProgramOption{
tea.WithAltScreen(),
}
}
```
### 2. Passing Session to Model
```go
type Model struct {
renderer *lipgloss.Renderer
user string
width int
height int
}
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
pty, _, _ := s.Pty() // unchecked: relies on activeterm.Middleware() rejecting non-PTY sessions
renderer := bubbletea.MakeRenderer(s)
model := Model{
renderer: renderer,
user: s.User(),
width: pty.Window.Width,
height: pty.Window.Height,
}
return model, []tea.ProgramOption{tea.WithAltScreen()}
}
```
### 3. Per-Session Styles
```go
// Each session needs its own renderer for correct color detection
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
renderer := bubbletea.MakeRenderer(s)
// Create styles with session's renderer
styles := NewStyles(renderer)
model := Model{
styles: styles,
}
return model, nil
}
type Styles struct {
Title lipgloss.Style
Item lipgloss.Style
}
func NewStyles(r *lipgloss.Renderer) Styles {
return Styles{
Title: r.NewStyle().Bold(true).Foreground(lipgloss.Color("205")),
Item: r.NewStyle().PaddingLeft(2),
}
}
```
## Anti-Patterns
### 1. Ignoring PTY
```go
// BAD - assumes PTY always exists
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
pty, _, _ := s.Pty() // zero-value Pty when the client requested none
model := NewModel(pty.Window.Width, pty.Window.Height) // 0x0 window: broken layout
return model, nil
}
// GOOD - handle non-PTY connections
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
pty, _, hasPty := s.Pty()
width, height := 80, 24 // sensible defaults
if hasPty {
width = pty.Window.Width
height = pty.Window.Height
}
model := NewModel(width, height)
return model, nil
}
```
### 2. Global Lipgloss Styles
```go
// BAD - global styles don't detect terminal capabilities per-session
var titleStyle = lipgloss.NewStyle().Bold(true)
// GOOD - per-session renderer
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
renderer := bubbletea.MakeRenderer(s)
titleStyle := renderer.NewStyle().Bold(true)
// ...
}
```
## Review Questions
1. Is PTY presence checked before accessing window size?
2. Are per-session renderers used for Lipgloss?
3. Is authentication configured (public key preferred)?
4. Are session timeouts set?
5. Is logging middleware capturing connection info?
Reviews Prometheus instrumentation in Go code for proper metric types, labels, and patterns. Use when reviewing code with prometheus/client_golang metrics.
---
name: prometheus-go-code-review
description: Reviews Prometheus instrumentation in Go code for proper metric types, labels, and patterns. Use when reviewing code with prometheus/client_golang metrics.
---
# Prometheus Go Code Review
## Review Checklist
- [ ] Metric types match measurement semantics (Counter/Gauge/Histogram)
- [ ] Labels have low cardinality (no user IDs, timestamps, paths)
- [ ] Metric names follow conventions (snake_case, unit suffix)
- [ ] Histograms use appropriate bucket boundaries
- [ ] Metrics registered once, not per-request
- [ ] Collectors don't panic on race conditions
- [ ] /metrics endpoint exposed and accessible
## Hard gates (sequenced)
Complete in order before recording a **finding**. Skip gates that clearly do not apply to the diff.
1. **Evidence scope** — Enumerate the files you are reviewing that touch Prometheus (`prometheus/client_golang`, `promauto`, `promhttp`, or `MustRegister`). **Pass:** you have a concrete path list (from the diff or an explicit file set); no repo-wide claim without at least one path.
2. **Label cardinality** — For each `*Vec` or labeled metric in scope, list label names and where values come from (constants, bounded codes, vs request-derived strings). **Pass:** no label uses unbounded values (e.g. raw `user_id`, full URL path, timestamps) unless the code uses a bounded mapping and you cite it.
3. **Registration lifecycle** — For metric definitions in scope, confirm constructors run once (package-level `var`, `init`, or `sync.Once`), not inside per-request handlers. **Pass:** no pattern that allocates/registers a new `Counter`/`Histogram`/`*Vec` on every request for the same logical metric.
4. **Finding shape** — Each finding names a file (and line or symbol where possible), states which gate (2 or 3) would fail if the issue is real, and ties to observed code. **Pass:** no standalone style nit when gates 2–3 are satisfied for that code.
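Gate 2 allows request-derived label values only behind a bounded mapping. A minimal sketch of what such a mapping might look like follows; `statusClass`, `routeLabel`, and the `knownRoutes` allowlist are hypothetical names chosen for illustration, and the code deliberately uses only the standard library.

```go
package main

import (
	"fmt"
	"strings"
)

// statusClass collapses an HTTP status code into one of six label values.
func statusClass(code int) string {
	if code < 100 || code > 599 {
		return "unknown"
	}
	return fmt.Sprintf("%dxx", code/100)
}

// knownRoutes is a hypothetical allowlist of route prefixes; a real service
// would take these from its router.
var knownRoutes = []string{"/api/users/", "/api/orders/", "/healthz"}

// routeLabel maps a raw URL path to a route template, or "other" when no
// template matches, so the label set stays fixed regardless of traffic.
func routeLabel(path string) string {
	for _, prefix := range knownRoutes {
		if strings.HasPrefix(path, prefix) {
			return strings.TrimSuffix(prefix, "/")
		}
	}
	return "other"
}

func main() {
	fmt.Println(statusClass(404), routeLabel("/api/users/8f3a"), routeLabel("/wp-admin"))
	// prints "4xx /api/users other"
}
```

With mappings like these, the worst-case series count is the product of the fixed label sets, which a reviewer can cite when passing gate 2.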
## Metric Type Selection
| Measurement | Type | Example |
|-------------|------|---------|
| Requests processed | Counter | `requests_total` |
| Items in queue | Gauge | `queue_length` |
| Request duration | Histogram | `request_duration_seconds` |
| Concurrent connections | Gauge | `active_connections` |
| Errors since start | Counter | `errors_total` |
| Memory usage | Gauge | `memory_bytes` |
## Critical Anti-Patterns
### 1. High Cardinality Labels
```go
// BAD - unique per user/request
counter := promauto.NewCounterVec(
prometheus.CounterOpts{Name: "requests_total"},
[]string{"user_id", "path"}, // millions of series!
)
counter.WithLabelValues(userID, request.URL.Path).Inc()
// GOOD - bounded label values
counter := promauto.NewCounterVec(
prometheus.CounterOpts{Name: "requests_total"},
[]string{"method", "status_code"}, // <100 series
)
counter.WithLabelValues(r.Method, statusCode).Inc()
```
### 2. Wrong Metric Type
```go
// BAD - using gauge for monotonic value
requestCount := promauto.NewGauge(prometheus.GaugeOpts{
Name: "http_requests",
})
requestCount.Inc() // should be Counter!
// GOOD
requestCount := promauto.NewCounter(prometheus.CounterOpts{
Name: "http_requests_total",
})
requestCount.Inc()
```
### 3. Registering Per-Request
```go
// BAD - new metric per request
func handler(w http.ResponseWriter, r *http.Request) {
counter := prometheus.NewCounter(...) // creates new each time!
prometheus.MustRegister(counter) // panics on duplicate!
}
// GOOD - register once
var requestCounter = promauto.NewCounter(prometheus.CounterOpts{
Name: "http_requests_total",
})
func handler(w http.ResponseWriter, r *http.Request) {
requestCounter.Inc()
}
```
### 4. Missing Unit Suffix
```go
// BAD
duration := promauto.NewHistogram(prometheus.HistogramOpts{
Name: "request_duration", // no unit!
})
// GOOD
duration := promauto.NewHistogram(prometheus.HistogramOpts{
Name: "request_duration_seconds", // unit in name
})
```
## Good Patterns
### Metric Definition
```go
var (
httpRequests = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "myapp",
Subsystem: "http",
Name: "requests_total",
Help: "Total HTTP requests processed",
},
[]string{"method", "status"},
)
httpDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "myapp",
Subsystem: "http",
Name: "request_duration_seconds",
Help: "HTTP request latencies",
Buckets: []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10},
},
[]string{"method"},
)
)
```
### Middleware Pattern
```go
func metricsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := prometheus.NewTimer(httpDuration.WithLabelValues(r.Method))
defer timer.ObserveDuration()
wrapped := &responseWriter{ResponseWriter: w, status: 200}
next.ServeHTTP(wrapped, r)
httpRequests.WithLabelValues(r.Method, strconv.Itoa(wrapped.status)).Inc()
})
}
```
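The middleware above relies on a `responseWriter` type that is not defined in this document. One possible definition is sketched below. The `status` field matches the usage above; the `wroteHeader` field, the method set, and the helpers `observedStatus`, `notFound`, and `implicitOK` are assumptions added for illustration.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// responseWriter records the status code a handler sends. The status field
// matches the middleware above; everything else is this sketch's assumption.
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool
}

// WriteHeader captures the first status code and forwards it.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.status = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write marks the header as sent: net/http implies 200 on first Write.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap lets http.ResponseController reach the underlying writer.
func (rw *responseWriter) Unwrap() http.ResponseWriter { return rw.ResponseWriter }

var (
	// notFound replies 404 explicitly.
	notFound http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	})
	// implicitOK writes a body without calling WriteHeader.
	implicitOK http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "hi")
	})
)

// observedStatus runs h behind the wrapper and returns the recorded status.
func observedStatus(h http.Handler) int {
	rw := &responseWriter{ResponseWriter: httptest.NewRecorder(), status: http.StatusOK}
	h.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/", nil))
	return rw.status
}

func main() {
	fmt.Println(observedStatus(notFound), observedStatus(implicitOK)) // prints "404 200"
}
```

Initializing `status` to 200 covers handlers that never call `WriteHeader`, which is why the middleware above constructs the wrapper with `status: 200`.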
### Exposing Metrics
```go
import (
"log"
"net/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func main() {
http.Handle("/metrics", promhttp.Handler())
log.Fatal(http.ListenAndServe(":9090", nil)) // surface bind/serve errors
}
```
## Review Questions
1. Are metric types correct (Counter vs Gauge vs Histogram)?
2. Are label values bounded (no UUIDs, timestamps, paths)?
3. Do metric names include units (_seconds, _bytes)?
4. Are metrics registered once (not per-request)?
5. Is /metrics endpoint properly exposed?
Comprehensive Go web development persona enforcing zero global state, explicit error handling, input validation, testability, and documentation conventions....
---
name: go-web-expert
description: Comprehensive Go web development persona enforcing zero global state, explicit error handling, input validation, testability, and documentation conventions. Use when building Go web applications to ensure production-quality code from the start.
---
# Go Web Expert System
Five non-negotiable rules for production-quality Go web applications. Every handler, every service, every line of code must satisfy all five.
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Validation tags, custom validators, nested structs, error formatting | [references/validation.md](references/validation.md) |
| httptest patterns, middleware testing, integration tests, fixtures | [references/testing-handlers.md](references/testing-handlers.md) |
## Rules of Engagement
| # | Rule | One-Liner |
|---|------|-----------|
| 1 | Zero Global State | All handlers are methods on a struct; no package-level `var` for mutable state |
| 2 | Explicit Error Handling | Every error is checked, wrapped with `fmt.Errorf("doing X: %w", err)` |
| 3 | Validation First | All incoming JSON validated with `go-playground/validator` at the boundary |
| 4 | Testability | Every handler has a `_test.go` using `httptest` with table-driven tests |
| 5 | Documentation | Every exported symbol has a Go doc comment starting with its name |
### Hard gates (new HTTP handler)
Apply **in order**. Do not treat the next step as done until the **Pass when** for the current step is satisfied (objective evidence on disk or in test output—not “I checked mentally”).
1. **Dependencies (Rule 1)** — **Pass when:** the handler is a method on a struct that holds every mutable dependency (`db`, logger, HTTP clients, caches); any new package-level `var` is only in the allowlist under [What Is Allowed at Package Level](#what-is-allowed-at-package-level). *Evidence:* constructor wires deps; no new forbidden globals from that list.
2. **Boundary (Rule 3)** — **Pass before** calling service/domain code: **Pass when:** the request decodes into a tagged struct and `validate.Struct` (or equivalent) runs; invalid JSON and validation failures have defined HTTP status bodies (e.g. 400/422). *Evidence:* decode + `validate.Struct` appear in the handler; tests or manual run show 422/400 for bad input.
3. **Errors (Rule 2)** — **Pass when:** no `_` discards on the handler path; `json.NewEncoder(w).Encode` errors are handled; errors passed up or logged use wrapping (`%w`) or mapped `AppError` as this skill prescribes. *Evidence:* review the diff for ignored errors and bare `return err` without context where wrapping is required.
4. **Tests (Rule 4)** — **Pass when:** a `_test.go` exists for the handler package and calls `ServeHTTP` with `httptest`, including at least one success case and one non-2xx case (validation, not found, or domain error). *Evidence:* test file path exists; `go test` for that package passes.
5. **Documentation (Rule 5)** — **Pass when:** every **new or changed** exported identifier in the change has a doc comment whose first line starts with that identifier’s name. *Evidence:* `go doc <pkg>` or the IDE/doc preview shows summaries for new exports.
---
## Rule 1: Zero Global State
All handlers must be methods on a server struct. No package-level `var` for databases, loggers, clients, or any mutable state.
```go
// FORBIDDEN
var db *sql.DB
var logger *slog.Logger
func handleGetUser(w http.ResponseWriter, r *http.Request) {
row := db.QueryRow(...) // global state -- untestable, unsafe
}
// REQUIRED
type Server struct {
db *sql.DB
logger *slog.Logger
router *http.ServeMux
}
func (s *Server) handleGetUser(w http.ResponseWriter, r *http.Request) {
row := s.db.QueryRow(...) // explicit dependency; QueryRow returns only *sql.Row
}
```
### What Is Allowed at Package Level
- **Constants** -- `const maxPageSize = 100`
- **Pure functions** -- functions with no side effects that depend only on their arguments
- **Sentinel errors** -- `var ErrNotFound = errors.New("not found")`
- **Validator instance** -- `var validate = validator.New()` (stateless after init)
### What Is Forbidden at Package Level
- Database connections (`*sql.DB`, `*pgxpool.Pool`)
- Loggers (`*slog.Logger`)
- HTTP clients configured with timeouts or transport
- Configuration structs read from environment
- Caches, rate limiters, or any mutable shared resource
### Constructor Pattern
```go
func NewServer(db *sql.DB, logger *slog.Logger) *Server {
s := &Server{
db: db,
logger: logger,
router: http.NewServeMux(),
}
s.routes()
return s
}
func (s *Server) routes() {
s.router.HandleFunc("GET /api/users/{id}", s.handleGetUser)
s.router.HandleFunc("POST /api/users", s.handleCreateUser)
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
```
---
## Rule 2: Explicit Error Handling
Never ignore errors. Every error must be wrapped with context describing what was being attempted when the error occurred.
```go
// FORBIDDEN
result, _ := doSomething()
json.NewEncoder(w).Encode(data) // error ignored
// REQUIRED
result, err := doSomething()
if err != nil {
return fmt.Errorf("doing something for user %s: %w", userID, err)
}
if err := json.NewEncoder(w).Encode(data); err != nil {
s.logger.Error("encoding response", "err", err, "request_id", reqID)
}
```
### Error Wrapping Convention
Format: `"<verb>ing <noun>: %w"` -- lowercase, no period, provides call-chain context.
```go
// Good wrapping -- each layer adds context
return fmt.Errorf("creating user: %w", err)
return fmt.Errorf("inserting user into database: %w", err)
return fmt.Errorf("hashing password for user %s: %w", email, err)
// Bad wrapping
return fmt.Errorf("error: %w", err) // no context
return fmt.Errorf("Failed to create user: %w", err) // uppercase, verbose
return err // no wrapping at all
```
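To see what the convention buys, the following sketch wraps one sentinel error through two layers and prints the result. `insertUser`, `createUser`, and `isDuplicate` are hypothetical stand-ins for a storage layer, a service layer, and a caller-side check.

```go
package main

import (
	"errors"
	"fmt"
)

// ErrDuplicateEmail is a sentinel the lowest layer returns.
var ErrDuplicateEmail = errors.New("duplicate email")

// insertUser stands in for the storage layer.
func insertUser(email string) error {
	return fmt.Errorf("inserting user into database: %w", ErrDuplicateEmail)
}

// createUser stands in for the service layer; it adds its own context.
func createUser(email string) error {
	if err := insertUser(email); err != nil {
		return fmt.Errorf("creating user: %w", err)
	}
	return nil
}

// isDuplicate shows that %w keeps the sentinel reachable through the chain.
func isDuplicate(err error) bool {
	return errors.Is(err, ErrDuplicateEmail)
}

func main() {
	err := createUser("alice@example.com")
	fmt.Println(err)
	// prints "creating user: inserting user into database: duplicate email"
	fmt.Println(isDuplicate(err)) // prints "true"
}
```

The message reads as a call chain from outermost to innermost, and `errors.Is` still matches the original sentinel because every layer used `%w` rather than `%v`.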
### Structured Error Type for HTTP APIs
```go
type AppError struct {
Code int `json:"-"`
Message string `json:"error"`
Detail string `json:"detail,omitempty"`
}
func (e *AppError) Error() string {
return fmt.Sprintf("%d: %s", e.Code, e.Message)
}
// Map domain errors to HTTP errors in one place.
// A Server method so it logs through s.logger, not the global slog default (Rule 1).
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, err error) {
var appErr *AppError
if errors.As(err, &appErr) {
writeJSON(w, appErr.Code, appErr)
return
}
s.logger.Error("unhandled error",
"err", err,
"path", r.URL.Path,
)
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "internal server error"})
}
```
### Common Mistakes
```go
// MISTAKE: not checking Close errors on writers
defer f.Close() // at minimum, log Close errors for writable resources
// BETTER for writable resources:
defer func() {
if err := f.Close(); err != nil {
s.logger.Error("closing file", "err", err)
}
}()
// OK for read-only resources where Close rarely fails:
defer resp.Body.Close()
```
---
## Rule 3: Validation First
Use `go-playground/validator` for all incoming JSON. Validate at the boundary, trust internal data.
```go
import "github.com/go-playground/validator/v10"
var validate = validator.New()
type CreateUserRequest struct {
Name string `json:"name" validate:"required,min=1,max=100"`
Email string `json:"email" validate:"required,email"`
Age int `json:"age" validate:"omitempty,gte=0,lte=150"`
}
func (s *Server) handleCreateUser(w http.ResponseWriter, r *http.Request) error {
var req CreateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return &AppError{Code: 400, Message: "invalid JSON", Detail: err.Error()}
}
if err := validate.Struct(req); err != nil {
return &AppError{Code: 422, Message: "validation failed", Detail: formatValidationErrors(err)}
}
// From here, req is trusted
user, err := s.userService.Create(r.Context(), req.Name, req.Email)
if err != nil {
return fmt.Errorf("creating user: %w", err)
}
writeJSON(w, http.StatusCreated, user)
return nil
}
```
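The handler above returns an `error`, while `http.ServeMux.HandleFunc` expects a function with no return value, and `writeJSON` is used but not defined in this document. The sketch below shows one way the two pieces might fit together. The adapter name `handle`, the method form of `writeJSON`, and the helpers `respond` and `errOpaque` are assumptions of this sketch, not the source's definitions.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
)

// AppError matches the type defined earlier in this document.
type AppError struct {
	Code    int    `json:"-"`
	Message string `json:"error"`
	Detail  string `json:"detail,omitempty"`
}

func (e *AppError) Error() string { return fmt.Sprintf("%d: %s", e.Code, e.Message) }

// Server holds only the logger here; a real server carries more dependencies.
type Server struct{ logger *slog.Logger }

// errOpaque stands in for an unexpected failure such as a database outage.
var errOpaque = errors.New("db down")

// writeJSON encodes v with the given status and logs encoding failures.
func (s *Server) writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(v); err != nil {
		s.logger.Error("encoding response", "err", err)
	}
}

// handle adapts an error-returning handler to http.HandlerFunc and maps
// the returned error to a response in one place.
func (s *Server) handle(h func(http.ResponseWriter, *http.Request) error) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		err := h(w, r)
		if err == nil {
			return
		}
		var appErr *AppError
		if errors.As(err, &appErr) {
			s.writeJSON(w, appErr.Code, appErr)
			return
		}
		// Unknown errors are logged in full but never leaked to the client.
		s.logger.Error("unhandled error", "err", err, "path", r.URL.Path)
		s.writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "internal server error"})
	}
}

// respond runs a handler that returns err and reports status and body.
func respond(err error) (int, string) {
	s := &Server{logger: slog.Default()}
	w := httptest.NewRecorder()
	h := s.handle(func(http.ResponseWriter, *http.Request) error { return err })
	h(w, httptest.NewRequest(http.MethodPost, "/api/users", nil))
	return w.Code, strings.TrimSpace(w.Body.String())
}

func main() {
	fmt.Println(respond(&AppError{Code: 422, Message: "validation failed"}))
	// prints: 422 {"error":"validation failed"}
	fmt.Println(respond(errOpaque))
	// prints: 500 {"error":"internal server error"}
}
```

Routes would then be registered as `s.router.HandleFunc("POST /api/users", s.handle(s.handleCreateUser))`, keeping error-to-status mapping out of individual handlers.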
### Validation Error Formatting
```go
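func formatValidationErrors(err error) string {
var verrs validator.ValidationErrors
if !errors.As(err, &verrs) {
// e.g. *validator.InvalidValidationError; an unchecked type assertion would panic here
return err.Error()
}
msgs := make([]string, 0, len(verrs))
for _, e := range verrs {
msgs = append(msgs, fmt.Sprintf("field '%s' failed on '%s'", e.Field(), e.Tag()))
}
return strings.Join(msgs, "; ")
}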
```
### Validation Boundary Rule
- **Validate at the edge** -- HTTP handlers, message consumers, CLI input
- **Trust internal data** -- service layer receives already-validated types
- **Never validate twice** -- if the handler validated, the service does not re-validate the same fields
See [references/validation.md](references/validation.md) for custom validators, nested struct validation, slice validation, and cross-field validation.
---
## Rule 4: Testability
Every handler must have a corresponding `_test.go` file using `httptest`. Test through the HTTP layer, not by calling handler methods directly.
```go
func TestServer_handleGetUser(t *testing.T) {
mockStore := &MockUserStore{
GetUserFunc: func(ctx context.Context, id string) (*User, error) {
if id == "123" {
return &User{ID: "123", Name: "Alice"}, nil
}
return nil, ErrNotFound
},
}
srv := NewServer(mockStore, slog.Default())
tests := []struct {
name string
path string
wantStatus int
wantBody string
}{
{
name: "existing user",
path: "/api/users/123",
wantStatus: http.StatusOK,
wantBody: `"name":"Alice"`,
},
{
name: "not found",
path: "/api/users/999",
wantStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.path, nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != tt.wantStatus {
t.Errorf("status = %d, want %d", w.Code, tt.wantStatus)
}
if tt.wantBody != "" && !strings.Contains(w.Body.String(), tt.wantBody) {
t.Errorf("body = %q, want to contain %q", w.Body.String(), tt.wantBody)
}
})
}
}
```
### Key Testing Principles
- **Test through HTTP** -- use `httptest.NewRequest` and `httptest.NewRecorder`, call `srv.ServeHTTP`
- **Interface-based mocks** -- define narrow interfaces at the consumer, create mock implementations for tests
- **Table-driven tests** -- one `[]struct` with test cases, one `t.Run` loop
- **Error paths matter** -- test 400s, 404s, 422s, and 500s, not just 200s
- **No global test state** -- each test creates its own server with its own mocks
See [references/testing-handlers.md](references/testing-handlers.md) for middleware testing, integration tests with real databases, file upload testing, and streaming response testing.
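The test above uses `MockUserStore` with a `GetUserFunc` field, but neither the interface nor the mock is defined in this document. A minimal sketch of the "interface-based mocks" principle follows. The `UserStore` method set, the nil-function fallback, and the helpers `lookupName` and `aliceStore` are assumptions made for illustration.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// User is a trimmed-down domain type for this sketch.
type User struct {
	ID   string
	Name string
}

// ErrNotFound is returned when a requested resource does not exist.
var ErrNotFound = errors.New("not found")

// UserStore is the narrow interface the handler layer depends on.
// The method set is an assumption; declare only what the consumer calls.
type UserStore interface {
	GetUser(ctx context.Context, id string) (*User, error)
}

// MockUserStore implements UserStore by delegating to a function field,
// so each test case can supply its own behavior.
type MockUserStore struct {
	GetUserFunc func(ctx context.Context, id string) (*User, error)
}

// GetUser calls GetUserFunc, or reports ErrNotFound when it is unset.
func (m *MockUserStore) GetUser(ctx context.Context, id string) (*User, error) {
	if m.GetUserFunc == nil {
		return nil, ErrNotFound
	}
	return m.GetUserFunc(ctx, id)
}

// Compile-time check that the mock satisfies the interface.
var _ UserStore = (*MockUserStore)(nil)

// lookupName is a stand-in consumer that only sees the interface.
func lookupName(store UserStore, id string) string {
	u, err := store.GetUser(context.Background(), id)
	if err != nil {
		return "error: " + err.Error()
	}
	return u.Name
}

// aliceStore returns a mock that knows one user, as in the test above.
func aliceStore() *MockUserStore {
	return &MockUserStore{GetUserFunc: func(_ context.Context, id string) (*User, error) {
		if id == "123" {
			return &User{ID: "123", Name: "Alice"}, nil
		}
		return nil, ErrNotFound
	}}
}

func main() {
	fmt.Println(lookupName(aliceStore(), "123"))   // prints "Alice"
	fmt.Println(lookupName(aliceStore(), "999"))   // prints "error: not found"
	fmt.Println(lookupName(&MockUserStore{}, "1")) // prints "error: not found"
}
```

Because each test builds its own mock value, cases stay independent and can run with `t.Parallel()` without sharing state.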
---
## Rule 5: Documentation
Every exported function, type, method, and constant must have a Go doc comment following standard conventions.
```go
// CreateUser creates a new user with the given name and email.
// It returns ErrDuplicateEmail if a user with the same email already exists.
func (s *UserService) CreateUser(ctx context.Context, name, email string) (*User, error) {
// ...
}
// Server handles HTTP requests for the user API.
type Server struct {
// ...
}
// NewServer creates a Server with the given dependencies.
// The logger must not be nil.
func NewServer(store UserStore, logger *slog.Logger) *Server {
// ...
}
// ErrNotFound is returned when a requested resource does not exist.
var ErrNotFound = errors.New("not found")
```
### Doc Comment Conventions
- **Start with the name** -- `// CreateUser creates...` not `// This function creates...`
- **First sentence is the summary** -- shown in `go doc` listings and IDE tooltips
- **Mention important error returns** -- callers need to know which errors to check
- **Don't document the obvious** -- `// SetName sets the name` adds no value
- **Document why, not what** -- when behavior is non-obvious, explain the reasoning
### Package Documentation
```go
// Package user provides user management for the application.
// It handles creation, retrieval, and deletion of user accounts,
// with email uniqueness enforced at the database level.
package user
```
---
## Cross-Cutting Concerns
The five rules reinforce each other. Here is how they interact.
### Zero Global State Enables Testability
Because all dependencies are on the struct, tests can inject mocks:
```go
// Production
srv := NewServer(realDB, prodLogger)
// Test
srv := NewServer(mockStore, slog.Default())
```
If `db` were a global `var`, tests would need to mutate package state, causing race conditions in parallel tests.
### Validation First Simplifies Error Handling
When handlers validate at the boundary, the service layer can assume valid input. This means service-layer errors are always unexpected (database failures, network issues), and error handling becomes simpler:
```go
func (s *UserService) Create(ctx context.Context, name, email string) (*User, error) {
// No need to check if name is empty -- handler already validated
user := &User{Name: name, Email: email}
if err := s.store.Insert(ctx, user); err != nil {
return nil, fmt.Errorf("inserting user: %w", err)
}
return user, nil
}
```
### Documentation Makes Error Handling Discoverable
Doc comments that mention error returns tell callers what to handle:
```go
// Delete removes a user by ID.
// It returns ErrNotFound if the user does not exist.
// It returns ErrHasActiveOrders if the user has unfinished orders.
func (s *UserService) Delete(ctx context.Context, id string) error {
// ...
}
```
---
## Self-Review Checklist
Before considering any handler or service complete, verify all five rules:
### Zero Global State
- [ ] No package-level `var` for mutable state (db, logger, clients)
- [ ] All handlers are methods on a struct
- [ ] Dependencies injected through constructor
### Explicit Error Handling
- [ ] No `_` ignoring returned errors
- [ ] All errors wrapped with `fmt.Errorf("doing X: %w", err)`
- [ ] `json.NewEncoder(w).Encode(...)` error checked or logged
- [ ] Structured `AppError` used for HTTP error responses
### Validation First
- [ ] All request structs have `validate` tags
- [ ] `validate.Struct(req)` called before any business logic
- [ ] Validation errors return 422 with field-level detail
- [ ] Service layer does not re-validate handler-validated data
### Testability
- [ ] `_test.go` file exists for every handler file
- [ ] Tests use `httptest.NewRequest` and `httptest.NewRecorder`
- [ ] Table-driven tests cover happy path and error paths
- [ ] Mocks implement narrow interfaces, not concrete types
### Documentation
- [ ] Every exported function has a doc comment starting with its name
- [ ] Error return values are documented
- [ ] Package has a doc comment
## When to Load References
Load **validation.md** when:
- Adding new request types with validation tags
- Creating custom validators
- Validating nested structs, slices, or maps
- Formatting validation errors for API responses
Load **testing-handlers.md** when:
- Writing handler tests for the first time in a project
- Testing middleware chains or authentication
- Setting up integration tests with a real database
- Testing file uploads or streaming responses
FILE:references/testing-handlers.md
# Testing Go HTTP Handlers
## httptest Fundamentals
Every handler test follows the same three-step pattern: build a request, record the response, assert on the result.
### Basic Pattern
```go
func TestServer_handleHealth(t *testing.T) {
srv := NewServer(nil, slog.Default())
req := httptest.NewRequest("GET", "/healthz", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
}
```
### httptest.NewRequest vs http.NewRequest
```go
// httptest.NewRequest -- panics on error, never returns one.
// Use in tests where a bad URL is a programming error.
req := httptest.NewRequest("GET", "/api/users/123", nil)
// http.NewRequest -- returns an error. Use when constructing
// from dynamic test data that could be invalid.
req, err := http.NewRequest("POST", "/api/users", body)
if err != nil {
t.Fatal(err)
}
```
### httptest.NewRecorder
`httptest.NewRecorder` returns a `*httptest.ResponseRecorder` that implements `http.ResponseWriter` and captures everything the handler writes.
```go
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
w.Code // status code (int)
w.Body.String() // response body as string
w.Body.Bytes() // response body as []byte
w.Header().Get("Content-Type") // response headers
w.Result() // *http.Response for more detailed inspection
```
---
## Testing with Real JSON Payloads
### POST with JSON Body
```go
func TestServer_handleCreateUser(t *testing.T) {
mockStore := &MockUserStore{
CreateUserFunc: func(ctx context.Context, u *User) error {
u.ID = "generated-id"
return nil
},
}
srv := NewServer(mockStore, slog.Default())
body := strings.NewReader(`{"name":"Alice","email":"[email protected]"}`)
req := httptest.NewRequest("POST", "/api/users", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusCreated, w.Body.String())
}
var resp User
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if resp.ID == "" {
t.Error("expected non-empty user ID")
}
if resp.Name != "Alice" {
t.Errorf("name = %q, want %q", resp.Name, "Alice")
}
}
```
### Testing Validation Errors
```go
func TestServer_handleCreateUser_validation(t *testing.T) {
srv := NewServer(&MockUserStore{}, slog.Default())
tests := []struct {
name string
body string
wantStatus int
wantErr string
}{
{
name: "missing name",
body: `{"email":"[email protected]"}`,
wantStatus: 422,
wantErr: "name",
},
{
name: "invalid email",
body: `{"name":"Alice","email":"not-an-email"}`,
wantStatus: 422,
wantErr: "email",
},
{
name: "malformed JSON",
body: `{bad json`,
wantStatus: 400,
wantErr: "invalid JSON",
},
{
name: "empty body",
body: ``,
wantStatus: 400,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/users", strings.NewReader(tt.body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != tt.wantStatus {
t.Errorf("status = %d, want %d; body = %s", w.Code, tt.wantStatus, w.Body.String())
}
if tt.wantErr != "" && !strings.Contains(w.Body.String(), tt.wantErr) {
t.Errorf("body = %q, want to contain %q", w.Body.String(), tt.wantErr)
}
})
}
}
```
### Decoding JSON Responses with a Helper
```go
func decodeJSON[T any](t *testing.T, w *httptest.ResponseRecorder) T {
t.Helper()
var result T
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
t.Fatalf("decoding response body: %v", err)
}
return result
}
// Usage
resp := decodeJSON[User](t, w)
if resp.Name != "Alice" {
t.Errorf("name = %q, want %q", resp.Name, "Alice")
}
```
---
## Testing Middleware Chains
### Testing a Single Middleware
Test middleware in isolation by wrapping a known inner handler:
```go
func TestRequestIDMiddleware(t *testing.T) {
// Inner handler that captures the request ID from context
var gotID string
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
handler := RequestID(inner)
t.Run("generates ID when missing", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if gotID == "" {
t.Error("expected non-empty request ID in context")
}
if w.Header().Get("X-Request-ID") == "" {
t.Error("expected X-Request-ID response header")
}
})
t.Run("preserves existing ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Request-ID", "test-id-123")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if gotID != "test-id-123" {
t.Errorf("request ID = %q, want %q", gotID, "test-id-123")
}
})
}
```
### Testing the Full Middleware Stack
Test through the complete stack to verify middleware ordering and interaction:
```go
func TestMiddlewareChain(t *testing.T) {
mockStore := &MockUserStore{
GetUserFunc: func(ctx context.Context, id string) (*User, error) {
return &User{ID: id, Name: "Alice"}, nil
},
}
srv := NewServer(mockStore, slog.Default())
// Apply the same middleware stack as production
handler := Chain(srv, Recovery, RequestID, Logger(slog.Default()))
req := httptest.NewRequest("GET", "/api/users/123", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
if w.Header().Get("X-Request-ID") == "" {
t.Error("middleware chain did not set X-Request-ID")
}
}
```
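The test above calls `Chain`, which this reference does not define. A minimal sketch of one possible implementation follows, assuming the first middleware listed should be the outermost one (it sees the request first). The `Middleware` type name and the `tag` and `callOrder` helpers are hypothetical and exist only to make the ordering visible.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// Middleware wraps a handler with extra behavior. (Hypothetical type name.)
type Middleware func(http.Handler) http.Handler

// Chain wraps h so that the first middleware listed is the outermost one.
// Assumption: this is the ordering the production stack expects.
func Chain(h http.Handler, mws ...Middleware) http.Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// tag returns a middleware that records its name before calling next.
func tag(name string, trace *[]string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*trace = append(*trace, name)
			next.ServeHTTP(w, r)
		})
	}
}

// callOrder sends one request through a three-layer chain and
// reports the order in which the layers ran.
func callOrder() string {
	var trace []string
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		trace = append(trace, "handler")
	})
	h := Chain(inner, tag("recovery", &trace), tag("request-id", &trace), tag("logger", &trace))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
	return strings.Join(trace, " > ")
}

func main() {
	fmt.Println(callOrder()) // recovery > request-id > logger > handler
}
```

Recording the call order this way is also a cheap regression test for middleware ordering: if someone reorders the arguments to `Chain`, the expected string changes.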
### Testing Recovery Middleware
```go
func TestRecoveryMiddleware(t *testing.T) {
panicking := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("something went wrong")
})
handler := Recovery(panicking)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// Should not panic
handler.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
}
}
```
---
## Testing Authentication and Authorization
### Testing Auth Middleware
```go
func TestAuthMiddleware(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := UserFromContext(r.Context())
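if !ok {
// Use Error, not Fatal: this closure captures the parent t but runs
// inside subtests, and FailNow must be called from the test's own goroutine.
t.Error("expected user in context")
return
}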
fmt.Fprintf(w, "hello %s", user.Name)
})
tokenValidator := &MockTokenValidator{
ValidateFunc: func(token string) (*User, error) {
if token == "Bearer valid-token" {
return &User{ID: "1", Name: "Alice", Roles: []string{"admin"}}, nil
}
return nil, errors.New("invalid token")
},
}
handler := AuthMiddleware(tokenValidator)(inner)
tests := []struct {
name string
authHeader string
wantStatus int
wantBody string
}{
{
name: "valid token",
authHeader: "Bearer valid-token",
wantStatus: http.StatusOK,
wantBody: "hello Alice",
},
{
name: "invalid token",
authHeader: "Bearer bad-token",
wantStatus: http.StatusUnauthorized,
},
{
name: "missing header",
authHeader: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != tt.wantStatus {
t.Errorf("status = %d, want %d", w.Code, tt.wantStatus)
}
if tt.wantBody != "" && !strings.Contains(w.Body.String(), tt.wantBody) {
t.Errorf("body = %q, want to contain %q", w.Body.String(), tt.wantBody)
}
})
}
}
```
### Testing Role-Based Authorization
```go
func TestRequireRole(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := RequireRole("admin")(inner)
tests := []struct {
name string
user *User
wantStatus int
}{
{
name: "admin user",
user: &User{Roles: []string{"admin"}},
wantStatus: http.StatusOK,
},
{
name: "regular user",
user: &User{Roles: []string{"user"}},
wantStatus: http.StatusForbidden,
},
{
name: "no user in context",
user: nil,
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/admin", nil)
if tt.user != nil {
ctx := context.WithValue(req.Context(), userKey, tt.user)
req = req.WithContext(ctx)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != tt.wantStatus {
t.Errorf("status = %d, want %d", w.Code, tt.wantStatus)
}
})
}
}
```
---
## Integration Tests with Real Database
### Pattern: Test Database with t.Cleanup
```go
func setupTestDB(t *testing.T) *sql.DB {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
t.Skip("TEST_DATABASE_URL not set")
}
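db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("opening test database: %v", err)
}
// sql.Open may not connect; Ping fails fast if the database is unreachable.
if err := db.Ping(); err != nil {
t.Fatalf("pinging test database: %v", err)
}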
t.Cleanup(func() {
db.Close()
})
return db
}
```
### Transaction Rollback for Test Isolation
Each test runs in a transaction that is rolled back, leaving the database unchanged:
```go
func setupTestTx(t *testing.T, db *sql.DB) *sql.Tx {
t.Helper()
tx, err := db.Begin()
if err != nil {
t.Fatalf("beginning transaction: %v", err)
}
t.Cleanup(func() {
tx.Rollback() // always rollback -- test data never persists
})
return tx
}
```
### Full Integration Test
```go
func TestUserStore_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
db := setupTestDB(t)
tx := setupTestTx(t, db)
store := NewUserStore(tx)
t.Run("create and retrieve", func(t *testing.T) {
user := &User{Name: "Alice", Email: "[email protected]"}
err := store.Create(context.Background(), user)
if err != nil {
t.Fatalf("creating user: %v", err)
}
if user.ID == "" {
t.Fatal("expected non-empty ID after create")
}
got, err := store.GetByID(context.Background(), user.ID)
if err != nil {
t.Fatalf("getting user: %v", err)
}
if got.Name != "Alice" {
t.Errorf("name = %q, want %q", got.Name, "Alice")
}
})
t.Run("duplicate email", func(t *testing.T) {
user1 := &User{Name: "Bob", Email: "[email protected]"}
if err := store.Create(context.Background(), user1); err != nil {
t.Fatalf("creating first user: %v", err)
}
user2 := &User{Name: "Bob2", Email: "[email protected]"}
err := store.Create(context.Background(), user2)
if !errors.Is(err, ErrDuplicateEmail) {
t.Errorf("err = %v, want ErrDuplicateEmail", err)
}
})
}
```
### HTTP Integration Test
Test the full HTTP stack against a real database:
```go
func TestServer_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
db := setupTestDB(t)
tx := setupTestTx(t, db)
store := NewUserStore(tx)
srv := NewServer(store, slog.Default())
// Create
body := strings.NewReader(`{"name":"Alice","email":"[email protected]"}`)
req := httptest.NewRequest("POST", "/api/users", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("create: status = %d, want %d; body = %s", w.Code, http.StatusCreated, w.Body.String())
}
var created User
if err := json.NewDecoder(w.Body).Decode(&created); err != nil {
t.Fatalf("decoding create response: %v", err)
}
// Retrieve
req = httptest.NewRequest("GET", "/api/users/"+created.ID, nil)
w = httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("get: status = %d, want %d", w.Code, http.StatusOK)
}
var fetched User
if err := json.NewDecoder(w.Body).Decode(&fetched); err != nil {
t.Fatalf("decoding get response: %v", err)
}
if fetched.Name != "Alice" {
t.Errorf("name = %q, want %q", fetched.Name, "Alice")
}
}
```
---
## Testing File Uploads
### Multipart Form Data
```go
func TestServer_handleUpload(t *testing.T) {
srv := NewServer(&MockFileStore{}, slog.Default())
// Build multipart body
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("file", "test.txt")
if err != nil {
t.Fatal(err)
}
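if _, err := part.Write([]byte("hello world")); err != nil {
t.Fatal(err)
}
// Add a form field alongside the file
if err := writer.WriteField("description", "test file upload"); err != nil {
t.Fatal(err)
}
// Close writes the trailing boundary; the body is incomplete without it.
if err := writer.Close(); err != nil {
t.Fatal(err)
}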
req := httptest.NewRequest("POST", "/api/upload", &buf)
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
}
```
### Testing File Size Limits
```go
func TestServer_handleUpload_tooLarge(t *testing.T) {
srv := NewServer(&MockFileStore{}, slog.Default())
// Create a file that exceeds the size limit
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
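part, err := writer.CreateFormFile("file", "large.bin")
if err != nil {
t.Fatal(err)
}
// 11MB, exceeding a 10MB limit
if _, err := part.Write(make([]byte, 11<<20)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}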
req := httptest.NewRequest("POST", "/api/upload", &buf)
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("status = %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
}
```
---
## Testing Streaming Responses
### Server-Sent Events
```go
func TestServer_handleSSE(t *testing.T) {
events := make(chan string, 3)
events <- "event 1"
events <- "event 2"
events <- "event 3"
close(events)
srv := NewServer(&MockEventSource{Events: events}, slog.Default())
req := httptest.NewRequest("GET", "/api/events", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
contentType := w.Header().Get("Content-Type")
if contentType != "text/event-stream" {
t.Errorf("Content-Type = %q, want %q", contentType, "text/event-stream")
}
body := w.Body.String()
for _, want := range []string{"event 1", "event 2", "event 3"} {
if !strings.Contains(body, want) {
t.Errorf("body missing %q", want)
}
}
}
```
### Testing with httptest.Server for Long-Lived Connections
For streaming tests that need an actual HTTP connection (e.g., to read events as they arrive or to exercise client disconnects, which the buffering `httptest.NewRecorder` cannot simulate), use `httptest.NewServer`:
```go
func TestServer_handleSSE_live(t *testing.T) {
events := make(chan string, 3)
events <- "event 1"
events <- "event 2"
events <- "event 3"
close(events)
srv := NewServer(&MockEventSource{Events: events}, slog.Default())
ts := httptest.NewServer(srv)
defer ts.Close()
resp, err := http.Get(ts.URL + "/api/events")
if err != nil {
t.Fatalf("GET /api/events: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
for _, want := range []string{"event 1", "event 2", "event 3"} {
if !strings.Contains(string(body), want) {
t.Errorf("body missing %q", want)
}
}
}
```
---
## Test Fixtures in testdata/
Go's testing toolchain ignores directories named `testdata`. Use it to store JSON fixtures, SQL seed files, and golden files.
### Directory Structure
```
mypackage/
handler.go
handler_test.go
testdata/
create_user_valid.json
create_user_invalid.json
golden/
user_response.json
sql/
seed_users.sql
```
### Loading Fixtures
```go
func loadFixture(t *testing.T, path string) []byte {
t.Helper()
data, err := os.ReadFile(filepath.Join("testdata", path))
if err != nil {
t.Fatalf("loading fixture %s: %v", path, err)
}
return data
}
func TestServer_handleCreateUser_fromFixture(t *testing.T) {
srv := NewServer(&MockUserStore{
CreateUserFunc: func(ctx context.Context, u *User) error {
u.ID = "test-id"
return nil
},
}, slog.Default())
body := loadFixture(t, "create_user_valid.json")
req := httptest.NewRequest("POST", "/api/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusCreated, w.Body.String())
}
}
```
### Golden File Testing
Compare handler output against a stored golden file. Regenerate golden files by running the tests with the `-update` flag.
```go
var update = flag.Bool("update", false, "update golden files")
func TestServer_handleGetUser_golden(t *testing.T) {
srv := NewServer(&MockUserStore{
GetUserFunc: func(ctx context.Context, id string) (*User, error) {
return &User{ID: "123", Name: "Alice", Email: "[email protected]"}, nil
},
}, slog.Default())
req := httptest.NewRequest("GET", "/api/users/123", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
goldenPath := filepath.Join("testdata", "golden", "user_response.json")
if *update {
// Pretty-print for readable golden files
var pretty bytes.Buffer
if err := json.Indent(&pretty, w.Body.Bytes(), "", " "); err != nil {
t.Fatalf("indenting response: %v", err)
}
if err := os.WriteFile(goldenPath, pretty.Bytes(), 0644); err != nil {
t.Fatalf("writing golden file: %v", err)
}
return
}
want, err := os.ReadFile(goldenPath)
if err != nil {
t.Fatalf("reading golden file: %v (run with -update to create)", err)
}
// Normalize both for comparison
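var gotPretty, wantPretty bytes.Buffer
if err := json.Indent(&gotPretty, w.Body.Bytes(), "", " "); err != nil {
t.Fatalf("indenting response: %v", err)
}
if err := json.Indent(&wantPretty, want, "", " "); err != nil {
t.Fatalf("indenting golden file: %v", err)
}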
if gotPretty.String() != wantPretty.String() {
t.Errorf("response does not match golden file.\ngot:\n%s\nwant:\n%s",
gotPretty.String(), wantPretty.String())
}
}
```
---
## Interface-Based Mocking Pattern
Define narrow interfaces at the consumer and create mock implementations for tests.
### Define the Interface
```go
// In the handler/server package -- not in the store package
type UserStore interface {
GetUser(ctx context.Context, id string) (*User, error)
CreateUser(ctx context.Context, u *User) error
ListUsers(ctx context.Context, limit, offset int) ([]*User, error)
}
```
### Create the Mock
```go
type MockUserStore struct {
GetUserFunc func(ctx context.Context, id string) (*User, error)
CreateUserFunc func(ctx context.Context, u *User) error
ListUsersFunc func(ctx context.Context, limit, offset int) ([]*User, error)
}
func (m *MockUserStore) GetUser(ctx context.Context, id string) (*User, error) {
return m.GetUserFunc(ctx, id)
}
func (m *MockUserStore) CreateUser(ctx context.Context, u *User) error {
return m.CreateUserFunc(ctx, u)
}
func (m *MockUserStore) ListUsers(ctx context.Context, limit, offset int) ([]*User, error) {
return m.ListUsersFunc(ctx, limit, offset)
}
```
### Use in Tests
```go
store := &MockUserStore{
GetUserFunc: func(ctx context.Context, id string) (*User, error) {
if id == "123" {
return &User{ID: "123", Name: "Alice"}, nil
}
return nil, ErrNotFound
},
// CreateUserFunc and ListUsersFunc will panic if called --
// this is intentional. If a test triggers an unexpected call,
// you want to know.
}
srv := NewServer(store, slog.Default())
```
---
## Testing Anti-Patterns
### Calling handler methods directly
```go
// BAD: bypasses routing, middleware, and ServeHTTP
srv.handleGetUser(w, req)
// GOOD: test through the full HTTP stack
srv.ServeHTTP(w, req)
```
### Shared mutable test state
```go
// BAD: tests interfere with each other
var testDB *sql.DB
func TestA(t *testing.T) { /* uses testDB */ }
func TestB(t *testing.T) { /* uses testDB, fails if TestA runs first */ }
// GOOD: each test creates its own dependencies
func TestA(t *testing.T) {
store := &MockUserStore{...}
srv := NewServer(store, slog.Default())
// ...
}
```
### Not testing error responses
```go
// BAD: only tests the happy path
func TestCreateUser(t *testing.T) {
// ... only tests 201 Created
}
// GOOD: tests all outcomes
func TestCreateUser(t *testing.T) {
tests := []struct{...}{
{"valid", ..., 201, ""},
{"missing name", ..., 422, "name"},
{"duplicate email", ..., 409, "email already exists"},
{"store error", ..., 500, "internal server error"},
}
}
```
### Asserting on exact JSON strings
```go
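// BAD: brittle -- breaks if field order changes or whitespace differs
if w.Body.String() != `{"id":"123","name":"Alice"}` {
t.Error("unexpected body")
}
// GOOD: decode and compare structs or check individual fields
var resp User
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if resp.Name != "Alice" {
t.Errorf("name = %q, want %q", resp.Name, "Alice")
}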
```
FILE:references/validation.md
# Input Validation with go-playground/validator
## Common Validation Tags
The `go-playground/validator` package uses struct tags to declare constraints. Here are the most frequently used tags for web APIs.
### String Constraints
```go
type CreatePostRequest struct {
Title string `json:"title" validate:"required,min=1,max=200"`
Slug string `json:"slug" validate:"required,alphanum"`
Body string `json:"body" validate:"required,min=10,max=50000"`
Status string `json:"status" validate:"required,oneof=draft published archived"`
Website string `json:"website" validate:"omitempty,url"`
}
```
| Tag | Description |
|-----|-------------|
| `required` | Field must be present and non-zero |
| `omitempty` | Skip validation if field is zero value |
| `min=N` | Minimum length (string) or value (number) |
| `max=N` | Maximum length (string) or value (number) |
| `len=N` | Exact length |
| `oneof=a b c` | Value must be one of the listed options (space-separated) |
| `alpha` | Letters only |
| `alphanum` | Letters and numbers only |
| `ascii` | ASCII characters only |
### Format Validators
```go
type ContactRequest struct {
Email string `json:"email" validate:"required,email"`
Phone string `json:"phone" validate:"omitempty,e164"`
Website string `json:"website" validate:"omitempty,url"`
IP string `json:"ip" validate:"omitempty,ip"`
}
```
| Tag | Description |
|-----|-------------|
| `email` | Valid email address |
| `url` | Valid URL |
| `uri` | Valid URI |
| `uuid` | Valid UUID (any version) |
| `uuid4` | Valid UUID v4 |
| `ip` | Valid IPv4 or IPv6 address |
| `ipv4` | Valid IPv4 address |
| `e164` | Valid E.164 phone number |
| `json` | Valid JSON string |
### Numeric Constraints
```go
type PaginationRequest struct {
Page int `json:"page" validate:"required,gte=1"`
PageSize int `json:"page_size" validate:"required,gte=1,lte=100"`
}
type ProductRequest struct {
Price float64 `json:"price" validate:"required,gt=0"`
Quantity int `json:"quantity" validate:"required,gte=0,lte=10000"`
Weight float64 `json:"weight" validate:"omitempty,gte=0"`
}
```
| Tag | Description |
|-----|-------------|
| `gt=N` | Greater than N |
| `gte=N` | Greater than or equal to N |
| `lt=N` | Less than N |
| `lte=N` | Less than or equal to N |
| `ne=N` | Not equal to N |
---
## Custom Validators
Register custom validation functions for domain-specific rules.
### Simple Custom Validator
```go
// Compile patterns once at package level rather than on every validation call.
var (
	slugRe  = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
	upperRe = regexp.MustCompile(`[A-Z]`)
	lowerRe = regexp.MustCompile(`[a-z]`)
	digitRe = regexp.MustCompile(`[0-9]`)
)

func setupValidator() *validator.Validate {
	v := validator.New()
	// Register a custom "slug" validator
	v.RegisterValidation("slug", func(fl validator.FieldLevel) bool {
		return slugRe.MatchString(fl.Field().String())
	})
	// Register a custom "strong_password" validator
	v.RegisterValidation("strong_password", func(fl validator.FieldLevel) bool {
		val := fl.Field().String()
		if len(val) < 8 {
			return false
		}
		return upperRe.MatchString(val) && lowerRe.MatchString(val) && digitRe.MatchString(val)
	})
	return v
}
```
Usage:
```go
type CreatePostRequest struct {
Slug string `json:"slug" validate:"required,slug"`
}
type RegisterRequest struct {
Password string `json:"password" validate:"required,strong_password"`
}
```
### Custom Validator with Parameters
```go
// Usage: validate:"not_reserved=admin root system"
v.RegisterValidation("not_reserved", func(fl validator.FieldLevel) bool {
val := fl.Field().String()
param := fl.Param() // "admin root system"
reserved := strings.Fields(param)
for _, r := range reserved {
if strings.EqualFold(val, r) {
return false
}
}
return true
})
```
### Using JSON Tag Names in Error Messages
By default, validator uses Go struct field names in errors. Register the JSON tag name function to get API-friendly field names:
```go
v := validator.New()
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
```
Now `e.Field()` returns `"email"` instead of `"Email"` in validation errors.
---
## Nested Struct Validation
Validator descends into nested struct fields automatically and applies their tags. Use `dive` to validate the elements of a slice, array, or map.
### Required Nested Struct
```go
type CreateOrderRequest struct {
Items []OrderItem `json:"items" validate:"required,min=1,dive"`
Address ShippingAddress `json:"address" validate:"required"`
}
type OrderItem struct {
ProductID string `json:"product_id" validate:"required,uuid"`
Quantity int `json:"quantity" validate:"required,gte=1,lte=100"`
}
type ShippingAddress struct {
Street string `json:"street" validate:"required,min=1,max=200"`
City string `json:"city" validate:"required,min=1,max=100"`
State string `json:"state" validate:"required,len=2"`
ZipCode string `json:"zip" validate:"required,numeric,len=5"`
Country string `json:"country" validate:"required,iso3166_1_alpha2"`
}
```
Key points:
- `dive` tells the validator to validate each element inside a slice
- Without `dive`, only the slice itself is checked (length, required)
- Nested struct fields such as `Address` are validated recursively; they do not need `dive`
### Optional Nested Struct
Use a pointer for optional nested structs:
```go
type UpdateProfileRequest struct {
Name string `json:"name" validate:"omitempty,min=1,max=100"`
Address *ShippingAddress `json:"address" validate:"omitempty"`
}
```
When `Address` is `nil`, validation is skipped. When present, all its field rules apply.
---
## Slice and Map Validation
### Slice Validation
```go
type BulkCreateRequest struct {
// Validate the slice itself (1-50 items) AND each element
Users []CreateUserRequest `json:"users" validate:"required,min=1,max=50,dive"`
}
```
The `dive` tag means: after validating the slice-level constraints (`min=1,max=50`), validate each element according to its own struct tags.
### Slice of Primitives
```go
type TagRequest struct {
Tags []string `json:"tags" validate:"required,min=1,max=10,dive,required,min=1,max=50"`
}
```
Reading left to right:
1. `required` -- slice must be present
2. `min=1,max=10` -- slice must have 1-10 elements
3. `dive` -- now validate each element
4. `required,min=1,max=50` -- each string must be non-empty and max 50 chars
### Map Validation
```go
type MetadataRequest struct {
// Validate keys and values separately
Metadata map[string]string `json:"metadata" validate:"required,max=20,dive,keys,min=1,max=50,endkeys,required,max=500"`
}
```
Reading left to right:
1. `required,max=20` -- map is required, max 20 entries
2. `dive` -- enter the map
3. `keys,min=1,max=50,endkeys` -- each key must be 1-50 chars
4. `required,max=500` -- each value must be non-empty and max 500 chars
---
## Cross-Field Validation
Validate fields relative to each other using `eqfield`, `nefield`, `gtfield`, etc.
### Password Confirmation
```go
type RegisterRequest struct {
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required,min=8,max=72"`
ConfirmPassword string `json:"confirm_password" validate:"required,eqfield=Password"`
}
```
### Date Range Validation
```go
type DateRangeRequest struct {
StartDate time.Time `json:"start_date" validate:"required"`
EndDate time.Time `json:"end_date" validate:"required,gtfield=StartDate"`
}
```
### Cross-Field Tags
| Tag | Description |
|-----|-------------|
| `eqfield=Other` | Must equal the value of `Other` |
| `nefield=Other` | Must not equal the value of `Other` |
| `gtfield=Other` | Must be greater than `Other` |
| `gtefield=Other` | Must be greater than or equal to `Other` |
| `ltfield=Other` | Must be less than `Other` |
| `ltefield=Other` | Must be less than or equal to `Other` |
### Struct-Level Validation
For complex cross-field rules that cannot be expressed with tags, use struct-level validation:
```go
v.RegisterStructValidation(func(sl validator.StructLevel) {
req := sl.Current().Interface().(CreateEventRequest)
if req.EndDate.Before(req.StartDate) {
sl.ReportError(req.EndDate, "end_date", "EndDate", "after_start", "")
}
if req.MaxAttendees > 0 && req.MinAttendees > req.MaxAttendees {
sl.ReportError(req.MinAttendees, "min_attendees", "MinAttendees", "lte_max", "")
}
}, CreateEventRequest{})
```
---
## Error Message Formatting for API Responses
### Basic Formatting
```go
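func formatValidationErrors(err error) string {
var verrs validator.ValidationErrors
if !errors.As(err, &verrs) {
// Not a field-level failure (e.g. *validator.InvalidValidationError)
return err.Error()
}
var msgs []string
for _, e := range verrs {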
msgs = append(msgs, fmt.Sprintf("field '%s' failed on '%s'", e.Field(), e.Tag()))
}
return strings.Join(msgs, "; ")
}
```
### Structured JSON Error Response
For richer API responses, return field-level errors as a list of objects:
```go
type ValidationError struct {
Field string `json:"field"`
Message string `json:"message"`
}
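func formatValidationErrorsJSON(err error) []ValidationError {
var verrs validator.ValidationErrors
if !errors.As(err, &verrs) {
// Not a field-level failure; the caller should treat this as a server error.
return nil
}
var errs []ValidationError
for _, e := range verrs {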
errs = append(errs, ValidationError{
Field: e.Field(),
Message: msgForTag(e),
})
}
return errs
}
func msgForTag(e validator.FieldError) string {
switch e.Tag() {
case "required":
return "this field is required"
case "email":
return "must be a valid email address"
case "min":
return fmt.Sprintf("must be at least %s characters", e.Param())
case "max":
return fmt.Sprintf("must be at most %s characters", e.Param())
case "oneof":
return fmt.Sprintf("must be one of: %s", e.Param())
case "uuid":
return "must be a valid UUID"
case "gte":
return fmt.Sprintf("must be at least %s", e.Param())
case "lte":
return fmt.Sprintf("must be at most %s", e.Param())
case "eqfield":
return fmt.Sprintf("must match %s", e.Param())
default:
return fmt.Sprintf("failed validation: %s", e.Tag())
}
}
```
### Usage in Handler
```go
func (s *Server) handleCreateUser(w http.ResponseWriter, r *http.Request) error {
var req CreateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return &AppError{Code: 400, Message: "invalid JSON"}
}
if err := validate.Struct(req); err != nil {
errs := formatValidationErrorsJSON(err)
writeJSON(w, 422, map[string]any{
"error": "validation failed",
"fields": errs,
})
return nil
}
// proceed with validated request...
return nil
}
```
Example API response:
```json
{
"error": "validation failed",
"fields": [
{"field": "email", "message": "must be a valid email address"},
{"field": "name", "message": "this field is required"}
]
}
```
---
## Validation Anti-Patterns
### Validating in the service layer
```go
// BAD: validation scattered across layers
func (s *UserService) Create(ctx context.Context, name, email string) (*User, error) {
if name == "" {
return nil, errors.New("name required") // should be caught at handler
}
}
```
Validate once at the boundary. Services receive trusted data.
### Using validate tags without dive on slices
```go
// BAD: only checks slice length, not element contents
Items []OrderItem `json:"items" validate:"required,min=1"`
// GOOD: dive validates each element
Items []OrderItem `json:"items" validate:"required,min=1,dive"`
```
### Ignoring the difference between required and omitempty
```go
// Required: field must be present and non-zero
Name string `validate:"required"` // "" is invalid
// Omitempty: skip validation if zero, validate if present
Bio string `validate:"omitempty,min=10,max=1000"` // "" is valid, "short" is invalid
```
### Not limiting request body size
```go
// BAD: attacker can send gigabytes
json.NewDecoder(r.Body).Decode(&req)
// GOOD: limit body size
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1MB limit
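if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// *http.MaxBytesError (Go 1.19+) is returned if the limit is exceeded
var tooLarge *http.MaxBytesError
if errors.As(err, &tooLarge) {
return &AppError{Code: 413, Message: "request body too large"}
}
return &AppError{Code: 400, Message: "invalid JSON"}
}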
```
Reviews Go test code for proper table-driven tests, assertions, and coverage patterns. Use when reviewing *_test.go files.
---
name: go-testing-code-review
description: Reviews Go test code for proper table-driven tests, assertions, and coverage patterns. Use when reviewing *_test.go files.
---
# Go Testing Code Review
## Review Workflow
Follow this sequence **in order**. Do not emit findings until every **Pass** below is satisfied.
1. **Baseline `go.mod`** — Open `go.mod` for the module under review and read the `go` directive.
**Pass:** You can state the exact `go X.YY` value (in the review preamble or working notes). Apply version-gated advice only when it matches this baseline (e.g. fuzz tests Go 1.18+, loop-variable capture pre-Go 1.22).
2. **Read surrounding tests** — For each `*_test.go` (or benchmark/fuzz file) in scope, read full test functions and any table `struct{...}` / helpers they use, not only the diff hunk.
**Pass:** At least one full `func Test...` / `func Benchmark...` / `func Fuzz...` (or helper it calls) containing the change was read per in-scope file.
3. **Scope the checklist** — Decide which [Review Checklist](#review-checklist) rows apply (table-driven structure, parallelism, HTTP, golden files, mocks). Open [references/structure.md](references/structure.md) and/or [references/mocking.md](references/mocking.md) for those topics; skip rows N/A to the diff with a one-line reason (e.g. “no `t.Parallel` in change”).
**Pass:** The review (or working notes) lists which checklist themes you applied, or marks themes N/A with a diff-tied reason.
4. **Pre-report verification** — Load and follow [review-verification-protocol](../review-verification-protocol/SKILL.md).
**Pass:** The protocol’s **Pre-Report Verification Checklist** is satisfied for each finding you will report (actual test code read, surrounding context checked, “wrong” vs “different style” distinguished, etc.).
## Hard gates (same sequence, shorter)
| Step | Objective pass condition |
| --- | --- |
| 1 | `go X.YY` from `go.mod` is recorded before version-specific test advice. |
| 2 | Full enclosing test (or helper it uses) read per in-scope test file, not diff-only. |
| 3 | In-scope checklist themes listed or N/A with diff-tied reason; references opened as needed. |
| 4 | `review-verification-protocol` completed for every reported issue. |
## Output Format
Report findings as:
```text
[FILE:LINE] ISSUE_TITLE
Severity: Critical | Major | Minor | Informational
Description of the issue and why it matters.
```
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Test structure, naming | [references/structure.md](references/structure.md) |
| Mocking, interfaces | [references/mocking.md](references/mocking.md) |
## Review Checklist
- [ ] Tests are table-driven with clear case names
- [ ] Subtests use `t.Run`; `t.Parallel()` is called inside the subtest only when cases are independent
- [ ] Test names describe behavior, not implementation
- [ ] Errors include got/want with descriptive message
- [ ] Cleanup registered with t.Cleanup
- [ ] Parallel tests don't share mutable state
- [ ] Mocks implement small interfaces defined by the consuming package; mock types live in the test file
- [ ] Coverage includes edge cases and error paths
- [ ] Performance-critical functions have `Benchmark*` tests
- [ ] Input parsers/validators have `Fuzz*` tests (Go 1.18+)
- [ ] HTTP handlers tested with `httptest.NewRequest`/`httptest.NewRecorder`
- [ ] Golden file tests use `testdata/*.golden` pattern with `-update` flag
## Critical Patterns
### Table-Driven Tests
```go
// BAD - repetitive
func TestAdd(t *testing.T) {
if Add(1, 2) != 3 {
t.Error("wrong")
}
if Add(0, 0) != 0 {
t.Error("wrong")
}
}
// GOOD
func TestAdd(t *testing.T) {
tests := []struct {
name string
a, b int
want int
}{
{"positive numbers", 1, 2, 3},
{"zeros", 0, 0, 0},
{"negative", -1, 1, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Add(tt.a, tt.b)
if got != tt.want {
t.Errorf("Add(%d, %d) = %d, want %d", tt.a, tt.b, got, tt.want)
}
})
}
}
```
### Error Messages
```go
// BAD
if got != want {
t.Error("wrong result")
}
// GOOD
if got != want {
t.Errorf("GetUser(%d) = %v, want %v", id, got, want)
}
// For complex types
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("GetUser() mismatch (-want +got):\n%s", diff)
}
```
### Parallel Tests
```go
func TestFoo(t *testing.T) {
tests := []struct{...}
for _, tt := range tests {
tt := tt // capture (not needed Go 1.22+)
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// test code
})
}
}
```
### Cleanup
```go
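// BAD - defer runs when the test function returns, which is before parallel
// subtests finish; it also cannot be registered from a setup helper
func TestWithTempFile(t *testing.T) {
	f, _ := os.CreateTemp("", "test")
	defer os.Remove(f.Name()) // fires too early if subtests call t.Parallel()
}
// GOOD - runs after the test and all of its subtests complete
func TestWithTempFile(t *testing.T) {
	f, err := os.CreateTemp("", "test")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	t.Cleanup(func() {
		os.Remove(f.Name())
	})
}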
```
## Additional Patterns
### Benchmarks
```go
func BenchmarkProcess(b *testing.B) {
data := generateTestData(1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Process(data)
}
}
// Run: go test -bench=BenchmarkProcess -benchmem
```
### Fuzz Tests (Go 1.18+)
```go
func FuzzParseInput(f *testing.F) {
// Seed corpus
f.Add(`{"name": "test"}`)
f.Add(``)
f.Add(`{invalid}`)
f.Fuzz(func(t *testing.T, input string) {
result, err := ParseInput(input)
if err != nil {
return // invalid input is expected
}
// If parsing succeeded, re-encoding should work
if _, err := json.Marshal(result); err != nil {
t.Errorf("Marshal after Parse: %v", err)
}
})
}
// Run: go test -fuzz=FuzzParseInput -fuzztime=30s
```
### HTTP Handler Tests
```go
func TestHandler(t *testing.T) {
srv := NewServer(mockDeps)
req := httptest.NewRequest("GET", "/api/users/123", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
}
```
### Golden Files
```go
var update = flag.Bool("update", false, "update golden files")
func TestRender(t *testing.T) {
got := Render(input)
golden := filepath.Join("testdata", t.Name()+".golden")
if *update {
if err := os.WriteFile(golden, got, 0644); err != nil {
t.Fatalf("writing golden file: %v", err)
}
}
want, err := os.ReadFile(golden)
if err != nil {
t.Fatalf("reading golden file: %v (run with -update to create)", err)
}
if !bytes.Equal(got, want) {
t.Errorf("output mismatch:\ngot:\n%s\nwant:\n%s", got, want)
}
}
```
## Anti-Patterns
### 1. Testing Internal Implementation
```go
// BAD - tests private state
func TestUser(t *testing.T) {
u := NewUser("alice")
if u.id != 1 { // testing internal field
t.Error("wrong id")
}
}
// GOOD - tests behavior
func TestUser(t *testing.T) {
u := NewUser("alice")
if u.ID() != 1 {
t.Error("wrong ID")
}
}
```
### 2. Shared Mutable State
```go
// BAD - tests interfere with each other
var testDB = setupDB()
func TestA(t *testing.T) {
t.Parallel()
testDB.Insert(...) // race!
}
// GOOD - isolated per test
func TestA(t *testing.T) {
db := setupTestDB(t)
t.Cleanup(func() { db.Close() })
db.Insert(...)
}
```
### 3. Assertions Without Context
```go
// BAD
assert.Equal(t, want, got) // "expected X got Y" - which test?
// GOOD
assert.Equal(t, want, got, "user name after update")
```
## When to Load References
- Reviewing test file structure → structure.md
- Reviewing mock implementations → mocking.md
## Review Questions
1. Are tests table-driven with named cases?
2. Do error messages include input, got, and want?
3. Are parallel tests isolated (no shared state)?
4. Is cleanup done via t.Cleanup?
5. Do tests verify behavior, not implementation?
FILE:references/mocking.md
# Mocking
## Interface-Based Mocking
### 1. Define Interface in Consumer
```go
// service.go
type UserStore interface {
Get(id int) (*User, error)
}
type UserService struct {
store UserStore
}
func (s *UserService) GetUser(id int) (*User, error) {
return s.store.Get(id)
}
```
### 2. Create Mock in Test File
```go
// service_test.go
type mockUserStore struct {
users map[int]*User
err error
}
func (m *mockUserStore) Get(id int) (*User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[id]
if !ok {
return nil, ErrNotFound
}
return user, nil
}
func TestGetUser(t *testing.T) {
mock := &mockUserStore{
users: map[int]*User{
1: {ID: 1, Name: "Alice"},
},
}
svc := &UserService{store: mock}
user, err := svc.GetUser(1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user.Name != "Alice" {
t.Errorf("name = %s, want Alice", user.Name)
}
}
```
### 3. Functional Mock Pattern
```go
// More flexible for varying behavior per test
type mockUserStore struct {
getFn func(id int) (*User, error)
}
func (m *mockUserStore) Get(id int) (*User, error) {
return m.getFn(id)
}
func TestGetUser_Error(t *testing.T) {
mock := &mockUserStore{
getFn: func(id int) (*User, error) {
return nil, errors.New("db error")
},
}
svc := &UserService{store: mock}
_, err := svc.GetUser(1)
if err == nil {
t.Error("expected error, got nil")
}
}
```
## Testing HTTP Clients
### 1. httptest Server
```go
func TestFetchUser(t *testing.T) {
// Create test server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/users/1" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"id": 1, "name": "Alice"}`))
}))
defer ts.Close()
// Use test server URL
client := NewClient(ts.URL)
user, err := client.FetchUser(1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user.Name != "Alice" {
t.Errorf("name = %s, want Alice", user.Name)
}
}
```
### 2. RoundTripper Mock
```go
type mockTransport struct {
response *http.Response
err error
}
func (m *mockTransport) RoundTrip(*http.Request) (*http.Response, error) {
return m.response, m.err
}
func TestClient_Error(t *testing.T) {
client := &http.Client{
Transport: &mockTransport{
err: errors.New("network error"),
},
}
_, err := FetchData(client, "http://example.com")
if err == nil {
t.Error("expected error")
}
}
```
## Testing Time
### 1. Inject Time Function
```go
// Code
type Service struct {
now func() time.Time
}
func (s *Service) IsExpired(expiry time.Time) bool {
return s.now().After(expiry)
}
// Test
func TestIsExpired(t *testing.T) {
fixedTime := time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC)
svc := &Service{
now: func() time.Time { return fixedTime },
}
tests := []struct {
name string
expiry time.Time
want bool
}{
{"past", fixedTime.Add(-time.Hour), true},
{"future", fixedTime.Add(time.Hour), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.IsExpired(tt.expiry)
if got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
```
## Testing Filesystem
### 1. fstest.MapFS
```go
import "testing/fstest"
func TestReadConfig(t *testing.T) {
fs := fstest.MapFS{
"config.json": &fstest.MapFile{
Data: []byte(`{"key": "value"}`),
},
}
cfg, err := ReadConfig(fs, "config.json")
if err != nil {
t.Fatal(err)
}
if cfg.Key != "value" {
t.Errorf("key = %s, want value", cfg.Key)
}
}
```
### 2. t.TempDir
```go
func TestWriteFile(t *testing.T) {
dir := t.TempDir() // automatically cleaned up
path := filepath.Join(dir, "test.txt")
err := WriteFile(path, "content")
if err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(path)
if err != nil {
	t.Fatalf("reading %s: %v", path, err)
}
if string(data) != "content" {
	t.Errorf("got %q, want %q", data, "content")
}
}
```
## Verifying Calls
### 1. Call Recording
```go
type mockStore struct {
getCalls []int
}
func (m *mockStore) Get(id int) (*User, error) {
m.getCalls = append(m.getCalls, id)
return &User{ID: id}, nil
}
func TestBatchGet(t *testing.T) {
mock := &mockStore{}
svc := &Service{store: mock}
svc.BatchGet([]int{1, 2, 3})
if len(mock.getCalls) != 3 {
t.Errorf("Get called %d times, want 3", len(mock.getCalls))
}
if !slices.Equal(mock.getCalls, []int{1, 2, 3}) {
t.Errorf("Get called with %v, want [1,2,3]", mock.getCalls)
}
}
```
## Anti-Patterns
### 1. Over-Mocking
```go
// BAD - mocking everything
func TestAdd(t *testing.T) {
mockCalc := &mockCalculator{}
// just test the actual function!
}
// GOOD - only mock external dependencies
func TestService(t *testing.T) {
mockDB := &mockDB{} // external dependency
svc := NewService(mockDB)
// test service logic
}
```
### 2. Mocking Concrete Types
```go
// BAD - can't inject mock
type Service struct {
store *PostgresStore
}
// GOOD - interface allows mocking
type Service struct {
store Store // interface
}
```
## Review Questions
1. Are interfaces defined by consumers, not producers?
2. Are mocks minimal (only implement what's tested)?
3. Are test servers used for HTTP testing?
4. Is time injected for time-dependent tests?
5. Are call recordings used to verify interactions?
FILE:references/structure.md
# Test Structure
## File Organization
### 1. Test File Location
```text
package/
├── user.go
├── user_test.go           # same package tests
├── user_internal_test.go  # internal tests if needed
└── testdata/              # test fixtures
    └── users.json
```
### 2. Test Naming Convention
```go
// Function test
func TestFunctionName(t *testing.T) {}
// Method test
func TestTypeName_MethodName(t *testing.T) {}
// Scenario test
func TestGetUser_WhenNotFound_ReturnsError(t *testing.T) {}
```
## Test Patterns
### 1. Setup and Teardown
```go
func TestMain(m *testing.M) {
// Global setup
setup()
code := m.Run()
// Global teardown
teardown()
os.Exit(code)
}
// Per-test setup
func TestFoo(t *testing.T) {
db := setupTestDB(t)
t.Cleanup(func() {
db.Close()
})
}
```
### 2. Helper Functions
```go
// Mark as helper for better stack traces
func assertNoError(t *testing.T, err error) {
t.Helper() // marks this as helper
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func createTestUser(t *testing.T, name string) *User {
t.Helper()
u, err := NewUser(name)
if err != nil {
t.Fatalf("creating test user: %v", err)
}
return u
}
```
### 3. Testdata Directory
```go
func TestParseConfig(t *testing.T) {
// Load from testdata directory
data, err := os.ReadFile("testdata/config.json")
if err != nil {
t.Fatal(err)
}
cfg, err := ParseConfig(data)
// ...
}
```
## Table-Driven Tests
### 1. Basic Structure
```go
func TestParse(t *testing.T) {
tests := []struct {
name string
input string
want int
wantErr bool
}{
{
name: "valid number",
input: "42",
want: 42,
},
{
name: "invalid input",
input: "abc",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Parse(tt.input)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("Parse(%q) = %d, want %d", tt.input, got, tt.want)
}
})
}
}
```
### 2. With Setup Function
```go
func TestHandler(t *testing.T) {
tests := []struct {
name string
setup func() *Handler
input Request
wantStatus int
}{
{
name: "authorized user",
setup: func() *Handler {
return NewHandler(WithAuth(true))
},
input: Request{UserID: 1},
wantStatus: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := tt.setup()
resp := h.Handle(tt.input)
if resp.Status != tt.wantStatus {
t.Errorf("status = %d, want %d", resp.Status, tt.wantStatus)
}
})
}
}
```
### 3. With Assertions
```go
func TestProcess(t *testing.T) {
tests := []struct {
name string
input []int
check func(t *testing.T, result []int)
}{
{
name: "sorts ascending",
input: []int{3, 1, 2},
check: func(t *testing.T, result []int) {
if !slices.Equal(result, []int{1, 2, 3}) {
t.Errorf("Process() = %v, want [1 2 3]", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Process(tt.input)
tt.check(t, result)
})
}
}
```
## Parallel Testing
### 1. Top-Level Parallel
```go
func TestFoo(t *testing.T) {
t.Parallel() // this test runs in parallel with others
// test code
}
```
### 2. Subtests Parallel
```go
func TestAll(t *testing.T) {
tests := []struct{...}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel() // subtests run in parallel
// test code using tt
})
}
}
```
### 3. Avoiding Race Conditions
```go
// Before Go 1.22, capture loop variable
for _, tt := range tests {
tt := tt // capture!
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// use tt safely
})
}
// Go 1.22+: not needed, loop variable is per-iteration
```
## Error Assertions
### 1. Using errors.Is
```go
func TestGetUser_NotFound(t *testing.T) {
_, err := GetUser(999)
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
```
### 2. Using errors.As
```go
func TestValidate(t *testing.T) {
err := Validate(invalidInput)
var validErr *ValidationError
if !errors.As(err, &validErr) {
t.Fatalf("expected ValidationError, got %T", err)
}
if validErr.Field != "email" {
t.Errorf("field = %s, want email", validErr.Field)
}
}
```
## Benchmarks and Fuzzing
### Benchmark File Organization
Benchmarks can live in the same `*_test.go` file as unit tests, or in a dedicated `*_bench_test.go` file for large suites:
```text
package/
├── parser.go
├── parser_test.go          # unit tests
├── parser_bench_test.go    # benchmarks (optional, for large suites)
└── testdata/
    └── fuzz/               # fuzz seed corpus, one directory per fuzz test
```
### Benchmark Naming
```go
// Function benchmark
func BenchmarkFunctionName(b *testing.B) {}
// Method benchmark
func BenchmarkTypeName_Method(b *testing.B) {}
```
### Sub-Benchmarks for Input Sizes
```go
func BenchmarkProcess(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}
for _, size := range sizes {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
data := generateTestData(size)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Process(data)
}
})
}
}
```
### Fuzz Test Seed Corpus
Place seed corpus files in `testdata/fuzz/<FuzzTestName>/`:
```text
package/
└── testdata/
    └── fuzz/
        └── FuzzParseInput/
            ├── seed1    # each file contains one corpus entry
            └── seed2
```
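During fuzzing runs, Go caches newly generated inputs under `$GOCACHE/fuzz/`. Inputs that cause a failure are written to `testdata/fuzz/<FuzzTestName>/`; commit those files so they run as regression cases on every `go test`.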
### Running Benchmarks in CI
```bash
# Run all benchmarks with memory stats
go test -bench=. -benchmem ./...
# Compare benchmarks across commits (using benchstat)
go test -bench=. -benchmem -count=5 ./... > old.txt
# make changes
go test -bench=. -benchmem -count=5 ./... > new.txt
benchstat old.txt new.txt
```
## Golden Files
### testdata Directory for Golden Files
Store expected outputs as golden files in the `testdata/` directory:
```text
package/
├── render.go
├── render_test.go
└── testdata/
    ├── TestRender/simple.golden
    ├── TestRender/complex.golden
    └── TestRender/empty.golden
```
### The `-update` Flag Pattern
```go
var update = flag.Bool("update", false, "update golden files")
func TestRender(t *testing.T) {
got := Render(input)
golden := filepath.Join("testdata", t.Name()+".golden")
if *update {
if err := os.MkdirAll(filepath.Dir(golden), 0755); err != nil {
t.Fatalf("creating golden dir: %v", err)
}
if err := os.WriteFile(golden, got, 0644); err != nil {
t.Fatalf("writing golden file: %v", err)
}
}
want, err := os.ReadFile(golden)
if err != nil {
t.Fatalf("reading golden file: %v (run with -update to create)", err)
}
if !bytes.Equal(got, want) {
t.Errorf("output mismatch:\ngot:\n%s\nwant:\n%s", got, want)
}
}
```
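Run `go test . -update` from the package directory to regenerate golden files after intentional changes. Put the package list before `-update`, because `go test` requires the package list to precede any flag it does not recognize. Use `go test ./... -update` only when every package under test defines the flag; otherwise those test binaries fail with "flag provided but not defined". Review the regenerated files in the diff before committing.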
### When to Use Golden Files
- **Complex output**: Rendered templates, formatted text, serialized data
- **Serialization formats**: JSON, YAML, protobuf text format
- **Code generation**: Generated source files, SQL migrations
- **Snapshot testing**: CLI output, error messages, log formatting
Golden files are preferable to inline expected values when output is large, multi-line, or changes infrequently.
## Review Questions
1. Are test files colocated with source files?
2. Do test names describe the scenario?
3. Are helper functions marked with t.Helper()?
4. Are parallel tests properly isolated?
5. Are fixtures in testdata directory?
Idiomatic Go HTTP middleware patterns with context propagation, structured logging via slog, centralized error handling, and panic recovery. Use when writing...
---
name: go-middleware
description: Idiomatic Go HTTP middleware patterns with context propagation, structured logging via slog, centralized error handling, and panic recovery. Use when writing middleware, adding request tracing, or implementing cross-cutting concerns.
---
# Go HTTP Middleware
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Context keys, request IDs, user metadata | [references/context-propagation.md](references/context-propagation.md) |
| slog setup, logging middleware, child loggers | [references/structured-logging.md](references/structured-logging.md) |
| AppHandler pattern, domain errors, recovery | [references/error-handling-middleware.md](references/error-handling-middleware.md) |
## Middleware Signature
All middleware follows the standard `func(http.Handler) http.Handler` pattern. This is the composable building block for cross-cutting concerns in Go HTTP servers.
```go
// Standard middleware signature
func RequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-ID")
if id == "" {
id = uuid.New().String()
}
ctx := context.WithValue(r.Context(), requestIDKey, id)
w.Header().Set("X-Request-ID", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// Type-safe context keys
type contextKey string
const requestIDKey contextKey = "request_id"
func RequestIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(requestIDKey).(string)
return id
}
```
Key points:
- Accept `http.Handler`, return `http.Handler` -- always
- Call `next.ServeHTTP(w, r)` to pass control to the next handler
- Work before the call (pre-processing) or after (post-processing) or both
- Use `r.WithContext(ctx)` to propagate new context values downstream
## Context Propagation
Use `context.WithValue` for request-scoped data that crosses API boundaries (request IDs, authenticated users, tenant IDs). Always use typed keys to avoid collisions.
```go
type contextKey string
const (
requestIDKey contextKey = "request_id"
userKey contextKey = "user"
)
```
Provide typed helper functions for extraction:
```go
func RequestIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(requestIDKey).(string)
return id
}
```
See [references/context-propagation.md](references/context-propagation.md) for user metadata patterns, downstream propagation, and timeouts.
## Structured Logging
Use `log/slog` (standard library, Go 1.21+) for structured logging in middleware. Wrap `http.ResponseWriter` to capture the status code, because the standard interface does not expose the status once the handler has written it.
```go
func Logger(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
wrapped := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(wrapped, r)
logger.Info("request completed",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.status,
"duration_ms", time.Since(start).Milliseconds(),
"request_id", RequestIDFromContext(r.Context()),
)
})
}
}
```
See [references/structured-logging.md](references/structured-logging.md) for JSON/text handler setup, log levels, and child loggers.
## Centralized Error Handling
Define a custom handler type that returns `error` so handlers don't need to write error responses themselves:
```go
type AppHandler func(w http.ResponseWriter, r *http.Request) error
func (fn AppHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
handleError(w, r, err)
}
}
```
Map domain errors to HTTP status codes in a single `handleError` function. Never leak internal error details to clients.
See [references/error-handling-middleware.md](references/error-handling-middleware.md) for the full pattern with `AppError`, `errors.As`, and JSON responses.
## Recovery Middleware
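Catch panics so that a failing request is logged with its request ID and answered with a JSON 500 (without this middleware, `net/http` recovers the handler panic itself, logs it, and aborts the request without completing the response):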
```go
func Recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
slog.Error("panic recovered",
"panic", rec,
"stack", string(debug.Stack()),
"request_id", RequestIDFromContext(r.Context()),
)
writeJSON(w, 500, map[string]string{"error": "internal server error"})
}
}()
next.ServeHTTP(w, r)
})
}
```
Recovery must be the **outermost** middleware so it catches panics from all inner middleware and handlers. See [references/error-handling-middleware.md](references/error-handling-middleware.md) for details.
## Middleware Chain Ordering
Apply middleware outermost-first. The first middleware in the chain wraps all others.
```go
// Nested style (outermost first)
handler := Recovery(
RequestID(
Logger(
Auth(
router,
),
),
),
)
// Or with a chain helper
func Chain(h http.Handler, middleware ...func(http.Handler) http.Handler) http.Handler {
for i := len(middleware) - 1; i >= 0; i-- {
h = middleware[i](h)
}
return h
}
handler := Chain(router, Recovery, RequestID, Logger(slog.Default()), Auth)
```
### Recommended Order
1. **Recovery** -- outermost; catches panics from all inner middleware
2. **RequestID** -- assign early so all subsequent middleware can reference it
3. **Logger** -- logs the completed request with ID and status
4. **Auth** -- after logging so failed auth attempts are recorded
5. **Application-specific middleware** -- rate limiting, CORS, etc.
## Gates (check before merge or review)
Use these **sequenced** checks for objective pass/fail; do not replace them with “I verified mentally.”
1. **Recovery position**
- Locate where the server builds the middleware chain (e.g. `main`, router `Use`, or a `Chain` helper).
- **Pass:** Recovery wraps all other middleware and the final handler per [Middleware Chain Ordering](#middleware-chain-ordering) (outermost in nested style, or correct `Chain` argument order for your helper). Cite file path and the full chain snippet.
2. **Status-aware middleware uses a wrapped `ResponseWriter`**
- If middleware logs or records HTTP status after the handler runs, it must pass a wrapper into `next.ServeHTTP`, not the original writer alone.
- **Pass:** snippet shows `next.ServeHTTP(wrapped, r)` (or equivalent) when status is observed after `next` returns.
3. **Every forward path calls `next`**
- Scan each middleware’s control flow.
- **Pass:** no branch drops the request without calling `next.ServeHTTP` unless that branch intentionally sends a response (e.g. auth failure); those short-circuits are obvious in code review.
## Anti-patterns
### Using string or int context keys
```go
// BAD: collisions with other packages
ctx = context.WithValue(ctx, "user", user)
// GOOD: unexported typed key
type contextKey string
const userKey contextKey = "user"
ctx = context.WithValue(ctx, userKey, user)
```
### Writing response before calling next
```go
// BAD: writes response then continues chain
func Bad(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) // too early!
next.ServeHTTP(w, r)
})
}
```
### Forgetting to call next.ServeHTTP
```go
// BAD: swallows the request
func Bad(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("got request")
// forgot next.ServeHTTP(w, r)
})
}
```
### Storing large objects in context
Context values should be small, request-scoped metadata (IDs, tokens, user structs). Never store database connections, file handles, or large payloads.
### Using context.WithValue for function parameters
If a function needs a value to do its job, pass it as an explicit parameter. Context is for cross-cutting metadata that passes through APIs, not for avoiding function signatures.
### Recovery middleware in the wrong position
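If recovery is not the outermost middleware, panics in outer middleware bypass it: `net/http` aborts the request and the client receives no JSON 500, and the panic is logged without the request ID. Always apply recovery first.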
FILE:references/context-propagation.md
# Context Propagation in Go Middleware
## Type-Safe Context Keys
Never use plain `string` or `int` as context keys. Define an unexported type so keys from different packages cannot collide.
```go
// Define in your middleware package
type contextKey string
const (
requestIDKey contextKey = "request_id"
userKey contextKey = "user"
tenantIDKey contextKey = "tenant_id"
)
```
Why this matters:
- `context.WithValue` uses interface equality for key comparison
- Two packages using `"user"` as a string key would overwrite each other
- An unexported `contextKey` type is unique to your package
## Request ID Propagation
Assign a request ID early in the middleware chain. Propagate it through context so every layer can include it in logs, error reports, and outgoing requests.
```go
func RequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-ID")
if id == "" {
id = uuid.New().String()
}
ctx := context.WithValue(r.Context(), requestIDKey, id)
w.Header().Set("X-Request-ID", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func RequestIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(requestIDKey).(string)
return id
}
```
Usage in downstream code:
```go
func handleOrder(w http.ResponseWriter, r *http.Request) {
reqID := RequestIDFromContext(r.Context())
slog.Info("processing order", "request_id", reqID)
// Pass to outgoing HTTP calls
outReq, _ := http.NewRequestWithContext(r.Context(), "GET", url, nil)
outReq.Header.Set("X-Request-ID", reqID)
}
```
## User Metadata
Store authenticated user information in context after validation in auth middleware.
```go
type User struct {
ID string
Email string
Roles []string
}
const userKey contextKey = "user"
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
user, err := validateToken(token)
if err != nil {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), userKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func UserFromContext(ctx context.Context) (*User, bool) {
u, ok := ctx.Value(userKey).(*User)
return u, ok
}
```
### Multi-Tenant Context
For multi-tenant applications, propagate the tenant ID alongside the user:
```go
const tenantIDKey contextKey = "tenant_id"
func TenantMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := UserFromContext(r.Context())
if !ok {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
tenantID := extractTenantID(user)
ctx := context.WithValue(r.Context(), tenantIDKey, tenantID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func TenantIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(tenantIDKey).(string)
return id
}
```
## Typed Helper Functions
Always provide exported helper functions for extracting context values. This encapsulates the key and type assertion in one place.
```go
// Good: callers use typed helpers
user, ok := UserFromContext(ctx)
reqID := RequestIDFromContext(ctx)
tenantID := TenantIDFromContext(ctx)
// Bad: callers reach into context directly
user := ctx.Value("user").(*User) // unsafe, untyped key
```
Always check the `ok` return from type assertions:
```go
func UserFromContext(ctx context.Context) (*User, bool) {
u, ok := ctx.Value(userKey).(*User)
return u, ok
}
// In handlers
user, ok := UserFromContext(r.Context())
if !ok {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
```
## Passing Context to Downstream Services
### Database Queries
Pass `r.Context()` to database calls so they respect request cancellation:
```go
func getUser(ctx context.Context, db *sql.DB, id string) (*User, error) {
row := db.QueryRowContext(ctx, "SELECT id, email FROM users WHERE id = $1", id)
var u User
if err := row.Scan(&u.ID, &u.Email); err != nil {
return nil, fmt.Errorf("querying user %s: %w", id, err)
}
return &u, nil
}
```
### Outgoing HTTP Requests
Use `http.NewRequestWithContext` to propagate cancellation and pass along tracing headers:
```go
func callDownstream(ctx context.Context, url string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("X-Request-ID", RequestIDFromContext(ctx))
return http.DefaultClient.Do(req)
}
```
## Context Timeout and Cancellation
For long operations, derive a context with a timeout to prevent requests from hanging:
```go
func slowHandler(w http.ResponseWriter, r *http.Request) {
// Give the operation 5 seconds max
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
result, err := longRunningQuery(ctx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
http.Error(w, "request timed out", http.StatusGatewayTimeout)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
json.NewEncoder(w).Encode(result)
}
```
### Timeout Middleware
Apply a blanket timeout to all requests:
```go
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), duration)
defer cancel()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
```
Note: this cancels the context but does not stop the handler goroutine. Handlers must check `ctx.Done()` or use context-aware I/O to actually stop work.
## Anti-patterns
### Using context.WithValue for function parameters
```go
// BAD: hiding dependencies in context
ctx = context.WithValue(ctx, "db", db)
// ...later...
db := ctx.Value("db").(*sql.DB)
// GOOD: explicit parameter
func handleOrder(ctx context.Context, db *sql.DB, orderID string) error {
// ...
}
```
Context is for request-scoped metadata that crosses API boundaries, not for dependency injection.
### Storing large objects in context
```go
// BAD: large payload in context
ctx = context.WithValue(ctx, "body", largeRequestBody)
// GOOD: pass as parameter or store a reference/ID
ctx = context.WithValue(ctx, requestIDKey, reqID)
```
### Not checking ok from type assertion
```go
// BAD: panics if value is nil or wrong type
user := ctx.Value(userKey).(*User)
// GOOD: always check
user, ok := ctx.Value(userKey).(*User)
if !ok {
return ErrUnauthorized
}
```
FILE:references/error-handling-middleware.md
# Centralized Error Handling and Recovery Middleware
## The AppHandler Pattern
Standard `http.HandlerFunc` has no return value, forcing each handler to write its own error responses. The `AppHandler` pattern lets handlers return errors, with a single centralized function mapping errors to HTTP responses.
### Custom Handler Type
```go
// AppHandler is an http.HandlerFunc that returns an error
type AppHandler func(w http.ResponseWriter, r *http.Request) error
// ServeHTTP implements http.Handler, calling the function and handling errors
func (fn AppHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
handleError(w, r, err)
}
}
```
Usage with a router (method-and-wildcard patterns such as `GET /users/{id}` and `r.PathValue` require Go 1.22+):
```go
mux := http.NewServeMux()
mux.Handle("GET /users/{id}", AppHandler(getUser))
mux.Handle("POST /users", AppHandler(createUser))
func getUser(w http.ResponseWriter, r *http.Request) error {
id := r.PathValue("id")
user, err := db.FindUser(r.Context(), id)
if err != nil {
return fmt.Errorf("finding user %s: %w", id, err)
}
if user == nil {
return ErrNotFound
}
return writeJSON(w, http.StatusOK, user)
}
```
Handlers focus on the happy path and return errors. The centralized `handleError` function takes care of logging and response formatting.
## Domain Errors
Define typed errors that map to HTTP status codes:
```go
type AppError struct {
Code int `json:"-"`
Message string `json:"error"`
Detail string `json:"detail,omitempty"`
}
func (e *AppError) Error() string { return e.Message }
var (
ErrNotFound = &AppError{Code: 404, Message: "resource not found"}
ErrUnauthorized = &AppError{Code: 401, Message: "unauthorized"}
ErrForbidden = &AppError{Code: 403, Message: "forbidden"}
ErrBadRequest = &AppError{Code: 400, Message: "bad request"}
ErrConflict = &AppError{Code: 409, Message: "conflict"}
)
```
### Creating Errors with Detail
```go
func NewBadRequest(detail string) *AppError {
return &AppError{
Code: 400,
Message: "bad request",
Detail: detail,
}
}
// In a handler
func createUser(w http.ResponseWriter, r *http.Request) error {
var input CreateUserInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
return NewBadRequest("invalid JSON body")
}
if input.Email == "" {
return NewBadRequest("email is required")
}
// ...
}
```
### Wrapping Domain Errors
Use `fmt.Errorf` with `%w` to add context while preserving the original error for `errors.As`:
```go
func getOrder(w http.ResponseWriter, r *http.Request) error {
id := r.PathValue("id")
order, err := db.FindOrder(r.Context(), id)
if err != nil {
return fmt.Errorf("finding order %s: %w", id, err)
}
if order == nil {
return fmt.Errorf("order %s: %w", id, ErrNotFound)
}
return writeJSON(w, http.StatusOK, order)
}
```
## Centralized Error Handler
The `handleError` function maps errors to HTTP responses. Known `AppError` types get their specific status code; everything else is a 500.
```go
func handleError(w http.ResponseWriter, r *http.Request, err error) {
logger := slog.Default()
reqID := RequestIDFromContext(r.Context())
var appErr *AppError
if errors.As(err, &appErr) {
logger.Warn("handled error",
"error", appErr.Message,
"detail", appErr.Detail,
"status", appErr.Code,
"request_id", reqID,
"method", r.Method,
"path", r.URL.Path,
)
writeJSON(w, appErr.Code, appErr)
return
}
// Unexpected error -- do not leak internals
logger.Error("unhandled error",
"error", err.Error(),
"request_id", reqID,
"method", r.Method,
"path", r.URL.Path,
)
writeJSON(w, 500, map[string]string{"error": "internal server error"})
}
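// writeJSON returns the encoding error so handlers can end with
// `return writeJSON(...)`; callers that cannot act on it may ignore it.
func writeJSON(w http.ResponseWriter, status int, v any) error {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	return json.NewEncoder(w).Encode(v)
}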
```
Key principles:
- Known errors (AppError) are logged at Warn level with their detail
- Unknown errors are logged at Error level with the full message
- Clients never see internal error messages for unknown errors
- Every error log includes the request ID for correlation
## JSON Error Response Format
All error responses follow a consistent structure:
```json
{
"error": "resource not found",
"detail": "order abc-123"
}
```
The `detail` field is optional and omitted when empty. This consistency makes it easy for API clients to parse errors.
## Recovery Middleware
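A panic in a Go HTTP handler does not crash the server by itself: `net/http` recovers it, logs the stack trace, and closes the connection, so the client receives no response at all. (A panic in a goroutine that the handler started is not covered by this and does crash the process; middleware cannot catch it either.) Recovery middleware catches handler panics earlier, logs them with request metadata, and returns a proper error response.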
```go
func Recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
slog.Error("panic recovered",
"panic", rec,
"stack", string(debug.Stack()),
"request_id", RequestIDFromContext(r.Context()),
"method", r.Method,
"path", r.URL.Path,
)
writeJSON(w, 500, map[string]string{"error": "internal server error"})
}
}()
next.ServeHTTP(w, r)
})
}
```
### Why Recovery Must Be Outermost
Recovery catches panics by wrapping the call to `next.ServeHTTP` in a deferred `recover()`. If any middleware outside of recovery panics, it won't be caught:
```go
// CORRECT: recovery wraps everything
handler := Recovery(RequestID(Logger(router)))
// WRONG: panics in RequestID or Logger are not caught
handler := RequestID(Logger(Recovery(router)))
```
### Stack Trace Logging
`runtime/debug.Stack()` returns the goroutine's stack trace at the point of the panic. Log this at Error level for debugging, but never include it in the HTTP response.
```go
import "runtime/debug"
slog.Error("panic recovered",
"panic", rec,
"stack", string(debug.Stack()),
)
```
### Never Expose Panic Details to Clients
The panic value (`rec`) often contains internal information -- file paths, memory addresses, or internal state. Always return a generic error message:
```go
// GOOD: generic message to client
writeJSON(w, 500, map[string]string{"error": "internal server error"})
// BAD: leaking panic info
writeJSON(w, 500, map[string]string{"error": fmt.Sprintf("%v", rec)})
```
## Combining AppHandler with Recovery
The `AppHandler` pattern handles returned errors; recovery handles panics. Together they cover all failure modes:
```go
// AppHandler catches returned errors
func getUser(w http.ResponseWriter, r *http.Request) error {
user, err := db.FindUser(r.Context(), r.PathValue("id"))
if err != nil {
return fmt.Errorf("finding user: %w", err) // caught by AppHandler
}
return writeJSON(w, 200, user)
}
// Recovery catches panics (e.g., nil pointer dereference)
// Applied as outermost middleware
handler := Recovery(
RequestID(
Logger(router),
),
)
```
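Handlers should return errors, not panic. Recovery is a safety net for unexpected situations (nil pointer dereference, index out of range, third-party library panics). One caveat with this ordering: recovery sees the request before `RequestID` has run, so if `RequestID` attaches the ID with `r.WithContext`, the `request_id` field in the panic log stays empty unless the ID was set further upstream.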
## Testing Error Handling
```go
func TestHandleError_AppError(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/test", nil)
handleError(w, r, ErrNotFound)
if w.Code != 404 {
t.Errorf("expected 404, got %d", w.Code)
}
var body map[string]string
json.NewDecoder(w.Body).Decode(&body)
if body["error"] != "resource not found" {
t.Errorf("expected 'resource not found', got %q", body["error"])
}
}
func TestHandleError_UnknownError(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/test", nil)
handleError(w, r, fmt.Errorf("database connection refused"))
if w.Code != 500 {
t.Errorf("expected 500, got %d", w.Code)
}
var body map[string]string
json.NewDecoder(w.Body).Decode(&body)
if body["error"] != "internal server error" {
t.Errorf("expected 'internal server error', got %q", body["error"])
}
}
func TestRecoveryMiddleware(t *testing.T) {
panicking := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("something went wrong")
})
handler := Recovery(panicking)
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/test", nil)
handler.ServeHTTP(w, r)
if w.Code != 500 {
t.Errorf("expected 500, got %d", w.Code)
}
}
```
FILE:references/structured-logging.md
# Structured Logging with slog
`log/slog` is the standard library structured logging package (Go 1.21+). It replaces the older `log` package for production services.
## Setting Up slog
### Production: JSON Handler
```go
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
slog.SetDefault(logger)
```
Output:
```json
{"time":"2024-01-15T10:30:00Z","level":"INFO","msg":"request completed","method":"GET","path":"/api/users","status":200,"duration_ms":42}
```
### Development: Text Handler
```go
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
AddSource: true,
}))
slog.SetDefault(logger)
```
Output:
```
time=2024-01-15T10:30:00Z level=INFO source=main.go:42 msg="request completed" method=GET path=/api/users status=200 duration_ms=42
```
### Choosing Based on Environment
```go
func setupLogger(env string) *slog.Logger {
var handler slog.Handler
switch env {
case "production":
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelInfo,
})
default:
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
AddSource: true,
})
}
return slog.New(handler)
}
```
## Log Levels
slog provides four levels:
| Level | Value | Use for |
|-------|-------|---------|
| `slog.LevelDebug` | -4 | Verbose diagnostic info, disabled in production |
| `slog.LevelInfo` | 0 | Normal operations (request completed, job started) |
| `slog.LevelWarn` | 4 | Handled errors, degraded operation, approaching limits |
| `slog.LevelError` | 8 | Unhandled errors, panics, failed critical operations |
```go
slog.Debug("cache miss", "key", cacheKey)
slog.Info("request completed", "method", r.Method, "status", 200)
slog.Warn("rate limit approaching", "current", count, "limit", max)
slog.Error("database connection failed", "error", err)
```
## Logging Middleware
Capture HTTP method, path, response status, and request duration for every request.
```go
func Logger(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Wrap ResponseWriter to capture status code
wrapped := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(wrapped, r)
logger.Info("request completed",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.status,
"duration_ms", time.Since(start).Milliseconds(),
"request_id", RequestIDFromContext(r.Context()),
)
})
}
}
type statusWriter struct {
http.ResponseWriter
status int
}
func (w *statusWriter) WriteHeader(code int) {
w.status = code
w.ResponseWriter.WriteHeader(code)
}
```
### Logging Errors vs Success at Different Levels
```go
func Logger(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
wrapped := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(wrapped, r)
attrs := []any{
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.status,
"duration_ms", time.Since(start).Milliseconds(),
"request_id", RequestIDFromContext(r.Context()),
}
switch {
case wrapped.status >= 500:
logger.Error("server error", attrs...)
case wrapped.status >= 400:
logger.Warn("client error", attrs...)
default:
logger.Info("request completed", attrs...)
}
})
}
}
```
## Adding Request ID to All Log Entries
Use the logger's `With` method to create a child logger that includes the request ID in every log call within that request's scope:
```go
func LoggerWithContext(baseLogger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqID := RequestIDFromContext(r.Context())
// Create a child logger with request_id baked in
logger := baseLogger.With("request_id", reqID)
// Store logger in context for use in handlers
ctx := context.WithValue(r.Context(), loggerKey, logger)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
type contextKey string
const loggerKey contextKey = "logger"
func LoggerFromContext(ctx context.Context) *slog.Logger {
if logger, ok := ctx.Value(loggerKey).(*slog.Logger); ok {
return logger
}
return slog.Default()
}
```
Usage in handlers:
```go
func handleOrder(w http.ResponseWriter, r *http.Request) {
	logger := LoggerFromContext(r.Context())
	orderID := r.PathValue("id") // assumes a route such as "GET /orders/{id}"
	logger.Info("processing order", "order_id", orderID)
	// Output includes request_id automatically
}
```
## Child Loggers with Additional Context
Build up context as you go deeper into the call stack:
```go
func processOrder(ctx context.Context, order *Order) error {
logger := LoggerFromContext(ctx).With(
"order_id", order.ID,
"customer_id", order.CustomerID,
)
logger.Info("validating order")
if err := validate(order); err != nil {
logger.Warn("validation failed", "error", err)
return fmt.Errorf("validating order: %w", err)
}
logger.Info("charging payment")
// ...
return nil
}
```
## Structured Logging Best Practices
### Use consistent key names
```go
// Good: consistent naming across the codebase
slog.Info("query executed", "duration_ms", dur, "row_count", count)
slog.Info("request completed", "duration_ms", dur, "status", code)
// Bad: inconsistent naming
slog.Info("query executed", "elapsed", dur, "rows", count)
slog.Info("request completed", "time_ms", dur, "statusCode", code)
```
### Use slog.Group for namespaced attributes
```go
slog.Info("request",
slog.Group("http",
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", status),
),
slog.Group("timing",
slog.Int64("duration_ms", dur),
),
)
// JSON: {"msg":"request","http":{"method":"GET","path":"/api","status":200},"timing":{"duration_ms":42}}
```
### Never log sensitive data
```go
// BAD
slog.Info("user login", "password", password, "token", authToken)
// GOOD
slog.Info("user login", "user_id", userID)
```
### Log errors with the "error" key
```go
// Consistent error key makes searching/filtering easy
slog.Error("database query failed", "error", err, "query", queryName)
slog.Warn("cache miss", "error", err, "key", cacheKey)
```
## StatusWriter Considerations
The basic `statusWriter` does not implement optional `http.ResponseWriter` interfaces. If you need `http.Flusher`, `http.Hijacker`, or `http.Pusher` support, implement them explicitly:
```go
type statusWriter struct {
http.ResponseWriter
status int
wroteHeader bool
}
func (w *statusWriter) WriteHeader(code int) {
if !w.wroteHeader {
w.status = code
w.wroteHeader = true
}
w.ResponseWriter.WriteHeader(code)
}
func (w *statusWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("hijack not supported")
}
```
Data persistence patterns in Go covering raw SQL with sqlx/pgx, ORMs like Ent and GORM, connection pooling, migrations with golang-migrate, and transaction m...
---
name: go-data-persistence
description: Data persistence patterns in Go covering raw SQL with sqlx/pgx, ORMs like Ent and GORM, connection pooling, migrations with golang-migrate, and transaction management. Use when implementing database access, designing repositories, or managing schema migrations.
---
# Data Persistence in Go
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Connection pool internals, sizing, pgx pools, monitoring | [references/connection-pooling.md](references/connection-pooling.md) |
| golang-migrate setup, file conventions, CI/CD integration | [references/migrations.md](references/migrations.md) |
| Transaction helpers, service-layer transactions, isolation levels | [references/transactions.md](references/transactions.md) |
## Choosing Your Approach
Pick the right tool based on your project's needs:
| Factor | Raw SQL (sqlx/pgx) | ORM (Ent/GORM) |
|--------|-------------------|-----------------|
| Complex queries | Preferred | Awkward |
| Type safety | Manual | Auto-generated |
| Performance control | Full | Limited |
| Rapid prototyping | Slower | Faster |
| Schema migrations | golang-migrate | Built-in (Ent) |
| Learning curve | SQL knowledge | ORM API |
### When to Use Raw SQL (sqlx/pgx)
- You need full control over query performance and execution plans
- Your domain has complex joins, CTEs, window functions, or recursive queries
- You want zero abstraction overhead and direct access to PostgreSQL features
- Your team is comfortable writing and maintaining SQL
- You need advanced PostgreSQL features like `LISTEN/NOTIFY`, advisory locks, or `COPY`
**pgx** is the recommended PostgreSQL driver for Go. Used through its native interface, it speaks the PostgreSQL protocol directly, avoids the overhead of the generic `database/sql` layer, and gives access to PostgreSQL-specific features. Use **sqlx** when you need `database/sql` compatibility or work with multiple database backends.
### When to Use an ORM (Ent/GORM)
- You want type-safe, generated query builders and avoid writing SQL
- Your schema is mostly CRUD with straightforward relationships
- You value generated code, schema-as-code, and automatic migrations (Ent)
- You are prototyping quickly and want to iterate on the schema fast
**Ent** is preferred over GORM for new projects. It uses code generation for type safety, has a declarative schema DSL, built-in migration support, and integrates with GraphQL. GORM is suitable if the team already knows it or if the project is small.
## Connection Setup
Every Go application connecting to a database needs a properly configured connection pool. The `database/sql` package manages pooling automatically, but its defaults (no limit on open connections, at most two idle connections, no lifetime limits) are not suitable for production.
```go
db, err := sql.Open("postgres", connStr)
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
// Connection pool configuration
db.SetMaxOpenConns(25) // Max simultaneous connections
db.SetMaxIdleConns(10) // Connections kept alive when idle
db.SetConnMaxLifetime(5 * time.Minute) // Recycle connections
db.SetConnMaxIdleTime(1 * time.Minute) // Close idle connections
// Verify connection
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("pinging db: %w", err)
}
```
### Pool Settings Explained
**MaxOpenConns** -- The maximum number of open connections to the database. This prevents your application from overwhelming the database with too many concurrent connections. Set to approximately 25 for typical web apps. To calculate: divide your database's `max_connections` (minus a reserve for admin and replication) by the number of application instances. If your DB allows 100 connections, you have 3 app instances, and you reserve 10 for admin, set this to `(100 - 10) / 3 = 30`.
**MaxIdleConns** -- The number of connections kept alive in the pool when not in use. These warm connections avoid the latency of establishing new connections for each request. Set to approximately 10, or roughly 40% of `MaxOpenConns`. Setting this too high wastes database connections; setting it too low causes frequent reconnections.
**ConnMaxLifetime** -- The maximum amount of time a connection can be reused. After this duration, the connection is closed and a new one is created on the next request. This helps pick up DNS changes (important for cloud databases that failover to new IPs), rebalance load across read replicas, and prevent connections from becoming stale. A value of 5 minutes is typical. Set shorter (1-2 min) if your infrastructure uses DNS-based failover.
**ConnMaxIdleTime** -- The maximum amount of time a connection can sit idle before it is closed. This releases connections back to the database during low-traffic periods, freeing resources. A value of 1 minute is typical. This should be shorter than `ConnMaxLifetime`.
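The sizing arithmetic above can be captured in a small helper. This is a sketch under the stated rule of thumb; `maxOpenConns` and the floor of 1 are assumptions for illustration, not part of `database/sql`.

```go
package main

import "fmt"

// maxOpenConns applies the rule described above:
// (database max_connections - reserved) / application instances.
// It never returns less than 1, because database/sql treats a limit of 0
// (or less) as "unlimited".
func maxOpenConns(dbMaxConns, reserved, instances int) int {
	if instances < 1 {
		instances = 1
	}
	n := (dbMaxConns - reserved) / instances
	if n < 1 {
		return 1
	}
	return n
}

func main() {
	open := maxOpenConns(100, 10, 3)
	idle := open * 40 / 100 // roughly 40% of MaxOpenConns

	fmt.Println(open, idle) // values to pass to SetMaxOpenConns / SetMaxIdleConns
}
```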
For pgx-specific pooling with native PostgreSQL support, see [references/connection-pooling.md](references/connection-pooling.md).
## Repository Pattern
Define a store interface at the consumer for testability. Implement against a concrete database driver. This pattern keeps your domain logic decoupled from the database.
```go
// Store interface for testability
type UserStore interface {
GetUser(ctx context.Context, id string) (*User, error)
ListUsers(ctx context.Context, limit, offset int) ([]*User, error)
CreateUser(ctx context.Context, u *User) error
}
// sqlx implementation
type PostgresUserStore struct {
db *sqlx.DB
}
func NewPostgresUserStore(db *sqlx.DB) *PostgresUserStore {
return &PostgresUserStore{db: db}
}
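func (s *PostgresUserStore) GetUser(ctx context.Context, id string) (*User, error) {
	var u User
	// List columns explicitly so schema changes do not break the scan.
	err := s.db.GetContext(ctx, &u,
		"SELECT id, email, name, created_at, updated_at FROM users WHERE id = $1", id)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("getting user %s: %w", id, err)
	}
	return &u, nil
}
func (s *PostgresUserStore) ListUsers(ctx context.Context, limit, offset int) ([]*User, error) {
	var users []*User
	err := s.db.SelectContext(ctx, &users,
		`SELECT id, email, name, created_at, updated_at
		FROM users ORDER BY created_at DESC LIMIT $1 OFFSET $2`,
		limit, offset,
	)
	if err != nil {
		return nil, fmt.Errorf("listing users: %w", err)
	}
	return users, nil
}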
func (s *PostgresUserStore) CreateUser(ctx context.Context, u *User) error {
_, err := s.db.NamedExecContext(ctx,
`INSERT INTO users (id, email, name, created_at, updated_at)
VALUES (:id, :email, :name, :created_at, :updated_at)`, u)
return err
}
```
### Model Struct Tags
Use `db` tags for sqlx column mapping and keep models close to the store:
```go
type User struct {
ID string `db:"id"`
Email string `db:"email"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
```
### Sentinel Errors
Define domain-specific errors that callers can check without importing database packages:
```go
var (
ErrNotFound = errors.New("not found")
ErrConflict = errors.New("conflict")
)
```
Map database errors to domain errors in the store layer:
```go
func (s *PostgresUserStore) CreateUser(ctx context.Context, u *User) error {
_, err := s.db.NamedExecContext(ctx, query, u)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return ErrConflict
}
return fmt.Errorf("inserting user: %w", err)
}
return nil
}
```
## Migrations
Use **golang-migrate** for managing schema changes. Migrations are pairs of SQL files: one for applying changes (`up`) and one for reverting them (`down`).
```
migrations/
├── 000001_create_users.up.sql
├── 000001_create_users.down.sql
├── 000002_add_user_roles.up.sql
└── 000002_add_user_roles.down.sql
```
Run migrations at application startup:
```go
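import (
	"errors"
	"fmt"

	"github.com/golang-migrate/migrate/v4"
	_ "github.com/golang-migrate/migrate/v4/database/postgres" // registers the postgres database driver
	_ "github.com/golang-migrate/migrate/v4/source/file"       // registers the file:// source driver
)

func runMigrations(dbURL string) error {
	m, err := migrate.New("file://migrations", dbURL)
	if err != nil {
		return fmt.Errorf("creating migrator: %w", err)
	}
	defer m.Close() // release the source and database handles
	// ErrNoChange means the schema is already up to date, which is not a failure.
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("running migrations: %w", err)
	}
	return nil
}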
Key rules: always write both up and down migrations, use `IF NOT EXISTS` / `IF EXISTS` for idempotency, never modify a migration that has been applied in production. For full migration patterns, CI/CD integration, and safe migration strategies, see [references/migrations.md](references/migrations.md).
## Transactions
Use a transaction helper to ensure consistent commit/rollback handling. Transactions should be managed at the **service layer**, not the store layer, so that multiple store operations can be composed into a single atomic unit.
```go
func WithTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("beginning transaction: %w", err)
}
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
}
return err
}
return tx.Commit()
}
```
Store methods accept a `*sql.Tx` parameter so they can participate in a caller-controlled transaction:
```go
func (s *OrderService) PlaceOrder(ctx context.Context, order *Order) error {
return WithTx(ctx, s.db, func(tx *sql.Tx) error {
if err := s.orderStore.CreateWithTx(ctx, tx, order); err != nil {
return fmt.Errorf("creating order: %w", err)
}
if err := s.inventoryStore.DecrementWithTx(ctx, tx, order.Items); err != nil {
return fmt.Errorf("updating inventory: %w", err)
}
return nil
})
}
```
For isolation levels, deadlock prevention, context propagation, and testing strategies, see [references/transactions.md](references/transactions.md).
## When to Load References
Load **connection-pooling.md** when:
- Configuring pgx native pools (`pgxpool.Pool`)
- Sizing connection pools for production workloads
- Working with cloud databases, PgBouncer, or connection limits
- Monitoring pool health and metrics
Load **migrations.md** when:
- Setting up golang-migrate for the first time
- Writing new migration files
- Integrating migrations into CI/CD pipelines
- Dealing with migration failures or rollbacks
Load **transactions.md** when:
- Implementing multi-step operations that must be atomic
- Designing service-layer transaction boundaries
- Choosing transaction isolation levels
- Debugging deadlocks or long-running transactions
## Gates (objective checks before merge)
Run these in order; do not rationalize past a failed step.
1. **Migrations**
1. List the migration file paths you are adding or relying on (both `.up.sql` and `.down.sql` per version).
2. **Pass:** Each new version has a matching pair on disk with consistent naming (see [Migrations](#migrations)).
3. **Pass:** You did not rewrite migration content that is already applied anywhere you care about (production or shared dev); you added a new version instead.
2. **Query safety**
1. Scan the diff for dynamic SQL built with `fmt.Sprintf`, `+`, or string concatenation involving request fields, JSON, or other external input.
2. **Pass:** Every such query uses bind parameters (`$1`, `:name`) or an ORM/query builder that emits parameterized statements; identifiers (table/column names) that must be dynamic use an explicit allowlist, not raw strings from users.
3. **Pool and context**
1. Confirm database pool construction (`sql.Open`, `pgxpool.New`, etc.) runs once at process startup and is shared, not inside per-request handlers.
2. **Pass:** Code paths that should respect cancellation/timeouts use `QueryContext`, `ExecContext`, `GetContext`, or equivalent—not `Query`/`Exec` without context—for work tied to `context.Context`.
## Anti-Patterns
### Using string concatenation for queries
```go
// BAD -- SQL injection vulnerability
query := "SELECT * FROM users WHERE name = '" + name + "'"
```
Always use parameterized queries (`$1`, `$2`, etc.) or named parameters (`:name`).
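Bind parameters cannot stand in for identifiers such as column names, so a dynamic `ORDER BY` needs a different safeguard. A minimal sketch of the allowlist approach follows; the `sortColumns` map and `listUsersQuery` function are hypothetical.

```go
package main

import "fmt"

// sortColumns is an allowlist mapping API sort keys to real column names.
var sortColumns = map[string]string{
	"created": "created_at",
	"email":   "email",
	"name":    "name",
}

// listUsersQuery builds the query text. Values still travel as bind
// parameters ($1, $2); only an allowlisted column name is interpolated.
func listUsersQuery(sortKey string) (string, error) {
	col, ok := sortColumns[sortKey]
	if !ok {
		return "", fmt.Errorf("unsupported sort key %q", sortKey)
	}
	return "SELECT id, email, name FROM users ORDER BY " + col + " LIMIT $1 OFFSET $2", nil
}

func main() {
	q, err := listUsersQuery("created")
	fmt.Println(q, err)

	// User-controlled text never reaches the SQL string.
	_, err = listUsersQuery("name; DROP TABLE users")
	fmt.Println(err)
}
```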
### Leaking database types into handlers
```go
// BAD -- handler depends on sql.ErrNoRows
func (s *Server) handleGetUser(w http.ResponseWriter, r *http.Request) {
user, err := s.store.GetUser(ctx, id)
if errors.Is(err, sql.ErrNoRows) { // handler knows about sql package
http.NotFound(w, r)
return
}
}
```
Return domain errors (`ErrNotFound`) from the store and check those in handlers instead.
### Opening a new connection per request
```go
// BAD -- bypasses connection pooling entirely
func (s *Server) handleGetUser(w http.ResponseWriter, r *http.Request) {
db, _ := sql.Open("postgres", connStr) // new pool per request!
defer db.Close()
}
```
Open the database connection once at startup and share the pool across the application.
### SELECT * in production code
```go
// BAD -- fragile, breaks when columns change
err := db.GetContext(ctx, &u, "SELECT * FROM users WHERE id = $1", id)
```
Explicitly list the columns you need. This makes the query resilient to schema changes and avoids fetching unnecessary data.
### Not handling context cancellation
```go
// BAD -- ignores context, query runs even if client disconnects
rows, err := db.Query("SELECT * FROM large_table")
```
Always use the `Context` variants (`QueryContext`, `ExecContext`, `GetContext`) and pass the request context so that queries are cancelled when the caller gives up.
### Transactions in store methods
```go
// BAD -- store controls transaction, caller cannot compose
func (s *UserStore) CreateUser(ctx context.Context, u *User) error {
tx, _ := s.db.BeginTx(ctx, nil)
// ... insert user ...
return tx.Commit()
}
```
Let the service layer manage transactions and pass `*sql.Tx` into store methods. See [references/transactions.md](references/transactions.md) for the correct pattern.
FILE:references/connection-pooling.md
# Connection Pooling in Go
## How database/sql Pooling Works
Go's `database/sql` package manages a pool of connections internally. When you call `db.QueryContext()` or `db.ExecContext()`, the pool:
1. Checks for an available idle connection
2. If none available and under `MaxOpenConns`, creates a new connection
3. If at `MaxOpenConns`, blocks until a connection is returned to the pool
4. After the query completes, returns the connection to the idle pool
5. If the idle pool is full (`MaxIdleConns`), closes the connection instead
This means `sql.Open()` does not actually open a connection -- it only validates the DSN and prepares the pool. The first real connection happens on the first query or `Ping()`.
### Pool Lifecycle
```
Request arrives
      |
      v
Pool has idle conn? --yes--> Use it --> Return to idle pool
      |                                        |
      no                                Idle pool full?
      |                                   /        \
      v                                 yes         no
Under MaxOpenConns?                 Close conn   Keep idle
    |        |
   yes       no
    |        |
    v        v
Open new   Block until
connection one is returned
```
## Sizing Guidelines
### Formula
```
MaxOpenConns = (DB max_connections - reserved_connections) / app_instances
```
Where:
- `max_connections` is the database server's maximum connection limit (check with `SHOW max_connections` in PostgreSQL)
- `reserved_connections` are connections reserved for superuser access, replication, monitoring, and migrations (typically 10-20)
- `app_instances` is the number of running application replicas
### Workload-Based Sizing
| Workload Type | MaxOpenConns | MaxIdleConns | Notes |
|---------------|-------------|-------------|-------|
| Low-traffic API (< 100 rps) | 10 | 5 | Minimal resources |
| Typical web app (100-1000 rps) | 25 | 10 | Good default |
| High-traffic service (1000+ rps) | 50-100 | 20-40 | Monitor DB CPU |
| Background worker | 5-10 | 2-5 | Few concurrent queries |
| Batch processing | 10-25 | 5-10 | Depends on parallelism |
### Important Considerations
- More connections does not mean more throughput. PostgreSQL performance degrades significantly above ~100 active connections due to lock contention and context switching.
- If your app needs more than ~50 connections per instance, consider using PgBouncer or a similar connection pooler between your app and the database.
- Monitor actual connection usage before tuning. Use `db.Stats()` to check pool utilization.
## pgx Native Pooling
For PostgreSQL-only applications, use `pgxpool.Pool` instead of `database/sql`. It provides better performance, native PostgreSQL protocol support, and additional features like health checks.
```go
import (
"context"
"fmt"
"os"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
func NewPool(ctx context.Context) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
if err != nil {
return nil, fmt.Errorf("parsing db config: %w", err)
}
config.MaxConns = 25
config.MinConns = 5
config.MaxConnLifetime = 5 * time.Minute
config.MaxConnIdleTime = 1 * time.Minute
config.HealthCheckPeriod = 30 * time.Second
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("creating pool: %w", err)
}
return pool, nil
}
```
### pgxpool vs database/sql
| Feature | pgxpool.Pool | database/sql |
|---------|-------------|-------------|
| Protocol | Native PostgreSQL | Generic driver interface |
| Performance | Faster (no interface overhead) | Slightly slower |
| MinConns | Supported | Not available |
| Health checks | Built-in periodic | Manual via Ping |
| COPY protocol | Native support | Not available |
| LISTEN/NOTIFY | Native support | Driver-dependent |
| Multi-database | PostgreSQL only | Any database |
| Ecosystem | pgx-specific | Universal Go packages |
### pgx with database/sql Compatibility
If you need `database/sql` compatibility (for libraries that require it) but still want pgx as the driver:
```go
import (
"database/sql"
"fmt"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
)
func NewDB(connStr string) (*sql.DB, error) {
db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, fmt.Errorf("opening db: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(10)
db.SetConnMaxLifetime(5 * time.Minute)
db.SetConnMaxIdleTime(1 * time.Minute)
return db, nil
}
```
## Health Checks and Connection Validation
### pgxpool Health Checks
`pgxpool` runs a background health check at the interval set by `HealthCheckPeriod`. In pgx v5 this check closes idle connections that have exceeded `MaxConnLifetime` or `MaxConnIdleTime` and opens new ones to maintain `MinConns`, so stale connections are recycled before they are handed out for a real query.
```go
config.HealthCheckPeriod = 30 * time.Second
```
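Connections that fail the check are removed from the pool, and replacements are created to satisfy `MinConns` or on demand. A connection that breaks between checks can still surface as an error on the query that uses it, so keep retry handling for transient failures.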
### Manual Health Check Endpoint
Expose a health check endpoint that verifies database connectivity:
```go
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
defer cancel()
if err := s.db.PingContext(ctx); err != nil {
http.Error(w, "database unreachable", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}
```
Use a short timeout (1-2 seconds) for health check pings. If the database does not respond within that window, the instance should be marked unhealthy.
## Monitoring Pool Metrics
### database/sql Stats
```go
func (s *Server) handleDBStats(w http.ResponseWriter, r *http.Request) {
stats := s.db.Stats()
fmt.Fprintf(w, "Open connections: %d\n", stats.OpenConnections)
fmt.Fprintf(w, "In use: %d\n", stats.InUse)
fmt.Fprintf(w, "Idle: %d\n", stats.Idle)
fmt.Fprintf(w, "Wait count: %d\n", stats.WaitCount)
fmt.Fprintf(w, "Wait duration: %s\n", stats.WaitDuration)
fmt.Fprintf(w, "Max idle closed: %d\n", stats.MaxIdleClosed)
fmt.Fprintf(w, "Max lifetime closed: %d\n", stats.MaxLifetimeClosed)
}
```
### Key Metrics to Watch
| Metric | Healthy | Warning |
|--------|---------|---------|
| `WaitCount` | Low/zero | Increasing over time |
| `WaitDuration` | < 10ms avg | > 100ms avg |
| `InUse` | < 80% of MaxOpenConns | Consistently near max |
| `MaxIdleClosed` | Low | Very high (raise MaxIdleConns) |
| `MaxLifetimeClosed` | Proportional to traffic | Unexpectedly high |
If `WaitCount` is steadily increasing, your application is running out of connections. Either increase `MaxOpenConns` (if the database can handle it) or reduce query duration.
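`WaitCount` and `WaitDuration` are cumulative counters, so the averages in the table come from the difference between two snapshots. The sketch below shows one way to derive them; `avgWait`, `utilization`, and `demo` are hypothetical helper names, and the snapshot values are made up so the example runs without a database.

```go
package main

import (
	"database/sql"
	"fmt"
	"time"
)

// avgWait returns the mean time callers waited for a connection between two
// db.Stats() snapshots, computed from the deltas of the cumulative counters.
func avgWait(prev, cur sql.DBStats) time.Duration {
	waits := cur.WaitCount - prev.WaitCount
	if waits <= 0 {
		return 0
	}
	return (cur.WaitDuration - prev.WaitDuration) / time.Duration(waits)
}

// utilization returns InUse as a fraction of MaxOpenConnections.
// MaxOpenConnections == 0 means "unlimited", so no ratio is reported.
func utilization(s sql.DBStats) float64 {
	if s.MaxOpenConnections <= 0 {
		return 0
	}
	return float64(s.InUse) / float64(s.MaxOpenConnections)
}

// demo builds two hand-written snapshots instead of calling db.Stats().
func demo() (time.Duration, float64) {
	prev := sql.DBStats{WaitCount: 100, WaitDuration: 2 * time.Second}
	cur := sql.DBStats{
		MaxOpenConnections: 25,
		InUse:              20,
		WaitCount:          140,
		WaitDuration:       4 * time.Second,
	}
	return avgWait(prev, cur), utilization(cur)
}

func main() {
	wait, util := demo()
	fmt.Printf("avg wait %s, utilization %.0f%%\n", wait, util*100)
}
```

In a real service the two snapshots would come from calling `db.Stats()` on a ticker, for example once per scrape interval.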
### Prometheus Integration
```go
import "github.com/prometheus/client_golang/prometheus"
func registerDBMetrics(db *sql.DB) {
prometheus.MustRegister(prometheus.NewGaugeFunc(
prometheus.GaugeOpts{
Name: "db_open_connections",
Help: "Number of open database connections",
},
func() float64 { return float64(db.Stats().OpenConnections) },
))
prometheus.MustRegister(prometheus.NewGaugeFunc(
prometheus.GaugeOpts{
Name: "db_in_use_connections",
Help: "Number of in-use database connections",
},
func() float64 { return float64(db.Stats().InUse) },
))
prometheus.MustRegister(prometheus.NewGaugeFunc(
prometheus.GaugeOpts{
Name: "db_idle_connections",
Help: "Number of idle database connections",
},
func() float64 { return float64(db.Stats().Idle) },
))
prometheus.MustRegister(prometheus.NewCounterFunc(
prometheus.CounterOpts{
Name: "db_wait_count_total",
Help: "Total number of connections waited for",
},
func() float64 { return float64(db.Stats().WaitCount) },
))
}
```
### pgxpool Stats
```go
func logPoolStats(pool *pgxpool.Pool) {
stat := pool.Stat()
slog.Info("pool stats",
"total_conns", stat.TotalConns(),
"acquired_conns", stat.AcquiredConns(),
"idle_conns", stat.IdleConns(),
"constructing_conns", stat.ConstructingConns(),
"max_conns", stat.MaxConns(),
"new_conns_count", stat.NewConnsCount(),
"max_lifetime_destroy_count", stat.MaxLifetimeDestroyCount(),
"max_idle_destroy_count", stat.MaxIdleDestroyCount(),
)
}
```
## Cloud Database Considerations
### Connection Limits by Provider
| Provider | Free/Dev Tier (max connections) | Standard (max connections) | Notes |
|----------|-------------|----------|-------|
| AWS RDS (db.t3.micro) | 87 | Scales with instance | Based on instance memory |
| Google Cloud SQL | 25 (basic) | Up to 4000 | Depends on tier |
| Supabase | 60 (free) | 200-500 | Uses PgBouncer |
| Neon | 100 (free) | 300-500 | Serverless, auto-scales |
| Railway | Varies | Varies | Shared resources on free |
### PgBouncer
When using PgBouncer (common in managed PostgreSQL services like Supabase), adjust your application settings:
```go
// With PgBouncer in transaction mode
db.SetMaxOpenConns(50) // Can be higher -- PgBouncer multiplexes
db.SetMaxIdleConns(5) // Keep low -- PgBouncer handles idle
db.SetConnMaxLifetime(0) // Disable -- PgBouncer manages lifetime
```
Important PgBouncer considerations:
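- **Transaction pooling mode** (most common): connections are assigned per transaction, not per session. Prepared statements do not carry over between transactions unless the pooler tracks them itself (PgBouncer 1.21+ can do this when `max_prepared_statements` is set).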
- **Session pooling mode**: connections are assigned per session. Prepared statements work normally but you get less multiplexing benefit.
- If using pgx with PgBouncer in transaction mode, disable prepared statements:
```go
config, _ := pgxpool.ParseConfig(connStr)
config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
```
### DNS-Based Failover
Cloud databases often use DNS to point to the current primary. Set `ConnMaxLifetime` to a short value so your application picks up DNS changes after failover:
```go
db.SetConnMaxLifetime(1 * time.Minute) // Short lifetime for fast failover
```
Without this, long-lived connections may keep pointing to the old primary after a failover event, causing errors.
FILE:references/migrations.md
# Database Migrations with golang-migrate
## Overview
**golang-migrate** is a widely used migration tool for Go applications that use raw SQL. It manages versioned migration files, tracks which migrations have been applied, and supports both programmatic and CLI usage.
Install the CLI:
```bash
go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
```
## File Naming Convention
Migrations are pairs of SQL files stored in a `migrations/` directory:
```
migrations/
├── 000001_create_users.up.sql
├── 000001_create_users.down.sql
├── 000002_add_user_roles.up.sql
├── 000002_add_user_roles.down.sql
├── 000003_create_orders.up.sql
└── 000003_create_orders.down.sql
```
Format: `{version}_{description}.{direction}.sql`
- **version**: zero-padded sequential number (6 digits recommended for sorting)
- **description**: snake_case description of the change
- **direction**: `up` (apply) or `down` (revert)
Generate a new migration pair with the CLI:
```bash
migrate create -ext sql -dir migrations -seq add_user_roles
```
This creates both `up.sql` and `down.sql` files with the next sequential version number.
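To make the naming format concrete, here is a minimal sketch that splits a file name into its three parts. The regular expression and the `parseMigrationName` helper are assumptions for illustration; the pattern is deliberately stricter than whatever golang-migrate itself accepts.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// migrationName matches {version}_{description}.{direction}.sql with a
// numeric version and a snake_case description (an assumed, strict pattern).
var migrationName = regexp.MustCompile(`^(\d+)_([a-z0-9_]+)\.(up|down)\.sql$`)

type migrationFile struct {
	Version     uint64
	Description string
	Direction   string
}

// parseMigrationName extracts version, description, and direction.
func parseMigrationName(name string) (migrationFile, error) {
	m := migrationName.FindStringSubmatch(name)
	if m == nil {
		return migrationFile{}, fmt.Errorf("%q does not match {version}_{description}.{direction}.sql", name)
	}
	v, err := strconv.ParseUint(m[1], 10, 64)
	if err != nil {
		return migrationFile{}, fmt.Errorf("parsing version in %q: %w", name, err)
	}
	return migrationFile{Version: v, Description: m[2], Direction: m[3]}, nil
}

func main() {
	f, err := parseMigrationName("000002_add_user_roles.up.sql")
	fmt.Println(f.Version, f.Description, f.Direction, err)
}
```

The zero padding matters only for lexical sorting in a directory listing; once parsed, `000002` is simply version 2.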
## Example Migrations
### Creating a table
```sql
-- 000001_create_users.up.sql
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
email TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- No separate index on email: the UNIQUE constraint already creates one
-- 000001_create_users.down.sql
DROP TABLE IF EXISTS users;
```
### Adding columns
```sql
-- 000002_add_user_roles.up.sql
ALTER TABLE users ADD COLUMN IF NOT EXISTS role TEXT NOT NULL DEFAULT 'user';
CREATE INDEX IF NOT EXISTS idx_users_role ON users(role);
-- 000002_add_user_roles.down.sql
DROP INDEX IF EXISTS idx_users_role;
ALTER TABLE users DROP COLUMN IF EXISTS role;
```
### Creating a related table
```sql
-- 000003_create_orders.up.sql
CREATE TABLE IF NOT EXISTS orders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
total_cents BIGINT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_orders_user_id ON orders(user_id);
CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status);
-- 000003_create_orders.down.sql
DROP TABLE IF EXISTS orders;
```
## Running Migrations in Code
Embed migrations in your binary and run them at application startup:
```go
import (
	"embed"
	"errors"
	"fmt"
	"log/slog"
	"github.com/golang-migrate/migrate/v4"
	_ "github.com/golang-migrate/migrate/v4/database/postgres"
	"github.com/golang-migrate/migrate/v4/source/iofs"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
func runMigrations(dbURL string) error {
	source, err := iofs.New(migrationsFS, "migrations")
	if err != nil {
		return fmt.Errorf("creating migration source: %w", err)
	}
	m, err := migrate.NewWithSourceInstance("iofs", source, dbURL)
	if err != nil {
		return fmt.Errorf("creating migrator: %w", err)
	}
	// Release the source and the migrator's own database connection
	defer m.Close()
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("running migrations: %w", err)
	}
	// ErrNilVersion only means no migration has been applied yet
	version, dirty, err := m.Version()
	if err != nil && !errors.Is(err, migrate.ErrNilVersion) {
		return fmt.Errorf("reading migration version: %w", err)
	}
	slog.Info("migrations complete", "version", version, "dirty", dirty)
	return nil
}
```
### File-based migrations (without embedding)
```go
func runMigrations(dbURL string) error {
m, err := migrate.New("file://migrations", dbURL)
if err != nil {
return fmt.Errorf("creating migrator: %w", err)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("running migrations: %w", err)
}
return nil
}
```
### Integration in main()
```go
func run(ctx context.Context) error {
dbURL := os.Getenv("DATABASE_URL")
// Run migrations before opening the connection pool
if err := runMigrations(dbURL); err != nil {
return fmt.Errorf("running migrations: %w", err)
}
db, err := sql.Open("postgres", dbURL)
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close()
// ... rest of app setup ...
}
```
## CLI Usage
```bash
# Apply all pending migrations
migrate -database "$DATABASE_URL" -path migrations up
# Apply the next N migrations
migrate -database "$DATABASE_URL" -path migrations up 2
# Roll back the last migration
migrate -database "$DATABASE_URL" -path migrations down 1
# Roll back all migrations (asks for confirmation; pass -all after "down" to skip the prompt)
migrate -database "$DATABASE_URL" -path migrations down
# Go to a specific version
migrate -database "$DATABASE_URL" -path migrations goto 3
# Show current migration version
migrate -database "$DATABASE_URL" -path migrations version
# Force a version (useful for fixing dirty state)
migrate -database "$DATABASE_URL" -path migrations force 3
```
## Writing Safe Migrations
### Idempotency
Always use `IF NOT EXISTS` and `IF EXISTS` so that migrations can be retried safely after partial failures:
```sql
-- Good: idempotent
CREATE TABLE IF NOT EXISTS users (...);
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
ALTER TABLE users ADD COLUMN IF NOT EXISTS role TEXT DEFAULT 'user';
-- Bad: fails on re-run
CREATE TABLE users (...);
CREATE INDEX idx_users_email ON users(email);
```
### Use Transactions
Wrap DDL statements in transactions when the database supports transactional DDL (PostgreSQL does):
```sql
-- 000004_add_audit_fields.up.sql
BEGIN;
ALTER TABLE users ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ;
ALTER TABLE users ADD COLUMN IF NOT EXISTS login_count INTEGER NOT NULL DEFAULT 0;
COMMIT;
```
If any statement within the transaction fails, all changes are rolled back, leaving the schema in a consistent state.
### Large Table Migrations
For tables with millions of rows, some operations hold locks that block writes (and in some cases reads) for as long as they run. Use these strategies:
```sql
-- Bad: blocks INSERT/UPDATE/DELETE on the table while the index builds
CREATE INDEX idx_orders_created_at ON orders(created_at);
-- Good: builds the index without blocking writes (slower, and must run outside a transaction)
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_orders_created_at ON orders(created_at);
```
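Note: `CREATE INDEX CONCURRENTLY` cannot run inside a transaction block. Put it in its own migration file as the only statement, and do not wrap it in `BEGIN/COMMIT`. If a concurrent build fails, it leaves an invalid index behind; drop that index before retrying, because `IF NOT EXISTS` will otherwise skip the rebuild.
For adding columns with defaults on large tables (PostgreSQL 11+), `ALTER TABLE ADD COLUMN ... DEFAULT` with a constant (non-volatile) default is safe and fast because PostgreSQL stores the default value in the catalog rather than rewriting the table. A volatile default such as `clock_timestamp()` still rewrites the table.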
### Separate Data and Schema Migrations
Keep data transformations in separate migration files from schema changes:
```
migrations/
├── 000005_add_full_name_column.up.sql # Schema: add column
├── 000005_add_full_name_column.down.sql
├── 000006_populate_full_name.up.sql # Data: backfill
├── 000006_populate_full_name.down.sql
├── 000007_drop_first_last_name.up.sql # Schema: remove old columns
└── 000007_drop_first_last_name.down.sql
```
This three-step approach (add new, backfill, remove old) allows zero-downtime deployments because the application can read from either the old or new columns during the transition.
## Rolling Back Migrations
### When to Roll Back
- A migration introduced a bug that affects production
- A deployment failed partway through and the database is in a dirty state
- You need to revert a schema change before deploying a fix
### How to Roll Back
```bash
# Roll back the last applied migration
migrate -database "$DATABASE_URL" -path migrations down 1
```
### Handling Dirty State
If a migration fails partway through, golang-migrate marks the migration version as "dirty." You cannot apply further migrations until the dirty flag is cleared.
```bash
# Check current state
migrate -database "$DATABASE_URL" -path migrations version
# Output: 5 (dirty)
# Option 1: Fix the issue and force the version
migrate -database "$DATABASE_URL" -path migrations force 4 # Revert to last clean version
# Option 2: Manually fix the database, then force to the current version
migrate -database "$DATABASE_URL" -path migrations force 5 # Mark as clean
```
### Writing Reversible Down Migrations
Not all migrations are easily reversible. For destructive operations, the down migration should be a best-effort approximation:
```sql
-- 000005_drop_legacy_column.up.sql
ALTER TABLE users DROP COLUMN IF EXISTS legacy_field;
-- 000005_drop_legacy_column.down.sql
-- Cannot restore data, but can restore the column
ALTER TABLE users ADD COLUMN IF NOT EXISTS legacy_field TEXT;
```
Document in comments when a down migration cannot fully restore the previous state.
## Migration Rules
1. **Never modify a migration that has been applied in production.** Create a new migration to make corrections.
2. **Always write both up and down migrations.** Even if the down migration is imperfect, it provides a rollback path.
3. **Use `IF NOT EXISTS` / `IF EXISTS`** for idempotent, retriable migrations.
4. **Use transactions** for multi-statement migrations (except when using `CONCURRENTLY`).
5. **Separate data migrations from schema migrations.** This keeps each migration focused and allows staged rollouts.
6. **Add indexes concurrently** on large tables to avoid blocking writes while the index builds.
7. **Test migrations** against a copy of production data before deploying. Schema changes that work on an empty table may lock or fail on a table with millions of rows.
8. **Version control your migrations.** They are part of the codebase and should be reviewed in pull requests.
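Rule 2 is easy to check mechanically, for example in CI. The following sketch reports migrations that have an `.up.sql` file without a matching `.down.sql` file, or the reverse. `unpairedMigrations` is a hypothetical helper, not part of golang-migrate.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// unpairedMigrations returns the base names (e.g. "000002_add_user_roles")
// that are missing either their up or their down file.
func unpairedMigrations(files []string) []string {
	seen := map[string]map[string]bool{} // base name -> directions found
	for _, f := range files {
		for _, dir := range []string{"up", "down"} {
			suffix := "." + dir + ".sql"
			if !strings.HasSuffix(f, suffix) {
				continue
			}
			base := strings.TrimSuffix(f, suffix)
			if seen[base] == nil {
				seen[base] = map[string]bool{}
			}
			seen[base][dir] = true
		}
	}
	var missing []string
	for base, dirs := range seen {
		if !dirs["up"] || !dirs["down"] {
			missing = append(missing, base)
		}
	}
	sort.Strings(missing) // map iteration order is random; sort for stable output
	return missing
}

func main() {
	files := []string{
		"000001_create_users.up.sql",
		"000001_create_users.down.sql",
		"000002_add_user_roles.up.sql",
	}
	fmt.Println(unpairedMigrations(files)) // [000002_add_user_roles]
}
```

In practice the file list would come from reading the `migrations/` directory, and a non-empty result would fail the build.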
## Migrations in CI/CD
### CI Pipeline
Run migrations as part of your test pipeline against a test database:
```yaml
# GitHub Actions example (assumes the migrate CLI is available on the runner)
jobs:
  test:
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:16
        env:
          POSTGRES_USER: test
          POSTGRES_PASSWORD: test
          POSTGRES_DB: testdb
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - uses: actions/checkout@v4
      - name: Run migrations
        run: |
          migrate -database "postgres://test:test@localhost:5432/testdb?sslmode=disable" \
            -path migrations up
      - name: Run tests
        run: go test ./...
        env:
          DATABASE_URL: postgres://test:test@localhost:5432/testdb?sslmode=disable
```
### CD Pipeline
For production deployments, run migrations before deploying the new application version:
```bash
# 1. Run migrations against production database
migrate -database "$PROD_DATABASE_URL" -path migrations up
# 2. Deploy new application version
# (only after migrations succeed)
```
If migrations fail, do not deploy the new application version. Fix the migration issue first, then retry.
### Multi-Instance Deployments
golang-migrate uses an advisory lock in PostgreSQL to prevent concurrent migration runs. This means it is safe to run migrations from multiple instances simultaneously -- only one will execute, the others will wait or skip.
However, it is cleaner to run migrations as a separate step (e.g., a Kubernetes Job or an init container) rather than having every application instance attempt to migrate on startup.
FILE:references/transactions.md
# Transaction Management in Go
## Why Service-Layer Transactions
Transactions should be managed at the **service layer**, not the store (repository) layer. The service layer knows which operations must be atomic. Individual store methods should not start their own transactions because:
- The caller cannot compose multiple store operations into a single transaction
- Each store method would commit independently, breaking atomicity
- Error handling becomes inconsistent -- some operations commit, others roll back
The pattern: the service begins a transaction, passes it to store methods, and commits or rolls back based on the outcome of all operations.
## Basic Transaction Pattern
```go
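func transferFunds(ctx context.Context, db *sql.DB, from, to string, amount int64) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Debit source account; the balance check lives in the WHERE clause
	res, err := tx.ExecContext(ctx,
		"UPDATE accounts SET balance = balance - $1 WHERE id = $2 AND balance >= $1",
		amount, from,
	)
	if err != nil {
		tx.Rollback()
		return fmt.Errorf("debiting account: %w", err)
	}
	// Zero rows affected means the account is missing or underfunded.
	// Without this check the credit below would still run.
	if n, err := res.RowsAffected(); err != nil || n != 1 {
		tx.Rollback()
		return fmt.Errorf("debiting account %s: insufficient funds or unknown account", from)
	}
	// Credit destination account
	_, err = tx.ExecContext(ctx,
		"UPDATE accounts SET balance = balance + $1 WHERE id = $2",
		amount, to,
	)
	if err != nil {
		tx.Rollback()
		return fmt.Errorf("crediting account: %w", err)
	}
	return tx.Commit()
}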
```
This works but has problems: repetitive rollback handling, and forgetting `tx.Rollback()` on any error path causes a connection leak.
## Transaction Helper Function
Encapsulate the begin/commit/rollback lifecycle in a helper:
```go
func WithTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("beginning transaction: %w", err)
}
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
}
return err
}
return tx.Commit()
}
```
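The helper above does not roll back if `fn` panics, which would leave the transaction open until the connection is reclaimed. One possible variant is sketched below. To keep it runnable without a database it works against a small `txn` interface (an assumption for this example; `*sql.Tx` satisfies it) and a `fakeTx` stand-in that only records which method was called.

```go
package main

import (
	"errors"
	"fmt"
)

// txn is the subset of *sql.Tx that this sketch needs.
type txn interface {
	Commit() error
	Rollback() error
}

// runTx commits when fn succeeds, rolls back when fn fails, and also rolls
// back (then re-panics) when fn panics.
func runTx(tx txn, fn func() error) error {
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort; the panic is the primary failure
			panic(p)
		}
	}()
	if err := fn(); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
		}
		return err
	}
	return tx.Commit()
}

// fakeTx records which method ran; it stands in for a real transaction.
type fakeTx struct{ committed, rolledBack bool }

func (f *fakeTx) Commit() error   { f.committed = true; return nil }
func (f *fakeTx) Rollback() error { f.rolledBack = true; return nil }

var errBoom = errors.New("boom")

func main() {
	tx := &fakeTx{}
	err := runTx(tx, func() error { return errBoom })
	fmt.Println(err, tx.committed, tx.rolledBack) // boom false true
}
```

Re-panicking after the rollback keeps the original stack trace visible to whatever recovery middleware the application already uses.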
### With Custom Isolation Level
```go
func WithTxOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, fn func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return fmt.Errorf("beginning transaction: %w", err)
}
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
}
return err
}
return tx.Commit()
}
// Usage with serializable isolation
err := WithTxOptions(ctx, db, &sql.TxOptions{
Isolation: sql.LevelSerializable,
}, func(tx *sql.Tx) error {
// operations that require serializable isolation
return nil
})
```
## Service-Layer Transactions
The service layer coordinates multiple store operations within a single transaction:
```go
type OrderService struct {
db *sql.DB
orderStore *OrderStore
inventoryStore *InventoryStore
paymentStore *PaymentStore
}
func (s *OrderService) PlaceOrder(ctx context.Context, order *Order) error {
return WithTx(ctx, s.db, func(tx *sql.Tx) error {
// All operations share the same transaction
if err := s.orderStore.CreateWithTx(ctx, tx, order); err != nil {
return fmt.Errorf("creating order: %w", err)
}
if err := s.inventoryStore.DecrementWithTx(ctx, tx, order.Items); err != nil {
return fmt.Errorf("updating inventory: %w", err)
}
if err := s.paymentStore.ChargeWithTx(ctx, tx, order.Payment); err != nil {
return fmt.Errorf("charging payment: %w", err)
}
return nil
})
}
```
## Store Pattern Accepting Transactions
Store methods should accept a transaction parameter so they can participate in caller-controlled transactions:
```go
type OrderStore struct{}
func (s *OrderStore) CreateWithTx(ctx context.Context, tx *sql.Tx, order *Order) error {
_, err := tx.ExecContext(ctx,
"INSERT INTO orders (id, user_id, total) VALUES ($1, $2, $3)",
order.ID, order.UserID, order.Total,
)
return err
}
```
### Dual Interface Pattern
Some store methods need to work both with and without an explicit transaction. Use an interface that both `*sql.DB` and `*sql.Tx` satisfy:
```go
// DBTX is satisfied by both *sql.DB and *sql.Tx
type DBTX interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type UserStore struct {
db DBTX
}
func NewUserStore(db DBTX) *UserStore {
return &UserStore{db: db}
}
func (s *UserStore) GetUser(ctx context.Context, id string) (*User, error) {
var u User
err := s.db.QueryRowContext(ctx,
"SELECT id, email, name FROM users WHERE id = $1", id,
).Scan(&u.ID, &u.Email, &u.Name)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("querying user %s: %w", id, err)
}
return &u, nil
}
// Usage without transaction
store := NewUserStore(db)
user, err := store.GetUser(ctx, "123")
// Usage within transaction
WithTx(ctx, db, func(tx *sql.Tx) error {
store := NewUserStore(tx)
user, err := store.GetUser(ctx, "123")
// ...
return nil
})
```
## Context-Based Transaction Propagation
For deeply nested call chains, propagate the transaction through context:
```go
type ctxKey struct{}
// TxFromContext retrieves a transaction from context, if present.
func TxFromContext(ctx context.Context) *sql.Tx {
tx, _ := ctx.Value(ctxKey{}).(*sql.Tx)
return tx
}
// ContextWithTx stores a transaction in the context.
func ContextWithTx(ctx context.Context, tx *sql.Tx) context.Context {
return context.WithValue(ctx, ctxKey{}, tx)
}
// Store uses transaction from context if available, otherwise uses db.
func (s *UserStore) GetUser(ctx context.Context, id string) (*User, error) {
var querier DBTX = s.db
if tx := TxFromContext(ctx); tx != nil {
querier = tx
}
var u User
err := querier.QueryRowContext(ctx,
"SELECT id, email, name FROM users WHERE id = $1", id,
).Scan(&u.ID, &u.Email, &u.Name)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, err
}
return &u, nil
}
```
Use this pattern sparingly. It makes the transaction boundary less visible in the code. Prefer explicit `*sql.Tx` parameters when the call chain is shallow.
## Isolation Levels
PostgreSQL supports four isolation levels. Choose based on your consistency requirements:
| Level | Dirty Reads | Non-Repeatable Reads | Phantom Reads | Use Case |
|-------|-------------|---------------------|---------------|----------|
| Read Uncommitted | Prevented* | Possible | Possible | Rarely used in PostgreSQL |
| Read Committed (default) | Prevented | Possible | Possible | Most CRUD operations |
| Repeatable Read | Prevented | Prevented | Prevented** | Reports, aggregations |
| Serializable | Prevented | Prevented | Prevented | Financial transactions |
*PostgreSQL treats Read Uncommitted as Read Committed.
**PostgreSQL's Repeatable Read also prevents phantom reads (unlike the SQL standard minimum).
### When to Change Isolation Level
**Read Committed** (default): Suitable for most web application queries. Each statement sees the latest committed data. Use this unless you have a specific reason not to.
**Repeatable Read**: Use when a transaction reads the same data multiple times and needs consistent results (e.g., generating a report where totals must be consistent across queries).
```go
tx, err := db.BeginTx(ctx, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
})
```
**Serializable**: Use for operations where concurrent transactions could produce inconsistent results (e.g., checking inventory and placing an order). Serializable transactions may fail with serialization errors and must be retried.
```go
func PlaceOrderSerializable(ctx context.Context, db *sql.DB, order *Order) error {
	const maxAttempts = 3
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		lastErr = WithTxOptions(ctx, db, &sql.TxOptions{
			Isolation: sql.LevelSerializable,
		}, func(tx *sql.Tx) error {
			// Check inventory, place order, etc.
			return nil
		})
		if lastErr == nil {
			return nil
		}
		// Retry only on serialization failure (SQLSTATE 40001).
		// *pgconn.PgError comes from github.com/jackc/pgx/v5/pgconn and is what
		// the pgx driver returns; other drivers use their own error types.
		var pgErr *pgconn.PgError
		if errors.As(lastErr, &pgErr) && pgErr.Code == "40001" {
			continue // Retry
		}
		return lastErr // Non-retryable error
	}
	return fmt.Errorf("transaction failed after %d attempts: %w", maxAttempts, lastErr)
}
```
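Retrying immediately tends to collide with the same competing transaction again, so a short backoff between attempts usually helps. The sketch below separates the retry policy from the transaction body. `retryTx`, `simulate`, and the `errSerialization` sentinel are hypothetical names; the sentinel stands in for a driver error carrying SQLSTATE 40001 so the example runs without a database.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// errSerialization stands in for a driver error with SQLSTATE 40001.
var errSerialization = errors.New("serialization failure")

// retryTx runs op up to maxAttempts times. It retries only when
// isRetryable(err) is true and waits baseDelay, 2*baseDelay, ... in between.
func retryTx(ctx context.Context, maxAttempts int, baseDelay time.Duration,
	isRetryable func(error) bool, op func(context.Context) error) error {
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			delay := baseDelay << (attempt - 1) // exponential backoff
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		lastErr = op(ctx)
		if lastErr == nil {
			return nil
		}
		if !isRetryable(lastErr) {
			return lastErr
		}
	}
	return fmt.Errorf("giving up after %d attempts: %w", maxAttempts, lastErr)
}

// simulate runs an op that fails `failures` times before succeeding and
// reports how many times it was called.
func simulate(failures int) (calls int, err error) {
	err = retryTx(context.Background(), 3, time.Millisecond,
		func(e error) bool { return errors.Is(e, errSerialization) },
		func(context.Context) error {
			calls++
			if calls <= failures {
				return errSerialization
			}
			return nil
		})
	return calls, err
}

func main() {
	calls, err := simulate(2)
	fmt.Println(calls, err) // 3 <nil>
}
```

With a real database, `op` would wrap the `WithTxOptions` call and `isRetryable` would inspect the driver's error code.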
## Deadlock Prevention
Deadlocks occur when two transactions wait for each other to release locks. PostgreSQL detects deadlocks and aborts one of the transactions.
### Consistent Lock Ordering
The primary strategy for preventing deadlocks is to always acquire locks in the same order:
```go
// BAD -- Transaction A locks user then order, Transaction B locks order then user
// This can deadlock
// GOOD -- Always lock in the same order (e.g., alphabetical by table, ascending by ID)
func (s *OrderService) PlaceOrder(ctx context.Context, order *Order) error {
return WithTx(ctx, s.db, func(tx *sql.Tx) error {
// Sort items by ID to ensure consistent lock ordering
sort.Slice(order.Items, func(i, j int) bool {
return order.Items[i].ProductID < order.Items[j].ProductID
})
for _, item := range order.Items {
_, err := tx.ExecContext(ctx,
"UPDATE inventory SET quantity = quantity - $1 WHERE product_id = $2",
item.Quantity, item.ProductID,
)
if err != nil {
return err
}
}
return nil
})
}
```
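The same idea applies when a transaction touches exactly two rows, such as a transfer between accounts. A minimal sketch, assuming string account IDs and a hypothetical `lockOrder` helper:

```go
package main

import "fmt"

// lockOrder returns the two IDs in ascending order so that every transaction
// locks the lower ID first, regardless of the direction of the transfer.
func lockOrder(a, b string) (first, second string) {
	if b < a {
		return b, a
	}
	return a, b
}

func main() {
	// A transfer alice->bob and a concurrent transfer bob->alice both end up
	// locking "alice" before "bob", so neither can hold one row while
	// waiting on the other in the opposite order.
	f1, s1 := lockOrder("alice", "bob")
	f2, s2 := lockOrder("bob", "alice")
	fmt.Println(f1, s1)
	fmt.Println(f2, s2)
}
```

Inside the transaction, the rows would then be locked in that order (for instance with `SELECT ... FOR UPDATE`) before any balance is changed.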
### Advisory Locks
For application-level locking (e.g., ensuring only one instance processes a job):
```go
func withAdvisoryLock(ctx context.Context, tx *sql.Tx, lockID int64, fn func() error) error {
// Acquire lock (released when transaction ends)
_, err := tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", lockID)
if err != nil {
return fmt.Errorf("acquiring advisory lock: %w", err)
}
return fn()
}
```
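`pg_advisory_xact_lock` takes a `bigint` key, while application code usually thinks in names such as "nightly-report". One way to bridge the two is to hash the name, as sketched here. `advisoryLockID` is a hypothetical helper, and FNV-1a is just one reasonable choice of hash.

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// advisoryLockID maps a lock name to a 64-bit key suitable for passing as
// the lockID argument of an advisory lock call.
func advisoryLockID(name string) int64 {
	h := fnv.New64a()
	h.Write([]byte(name)) // Write on a hash.Hash never returns an error
	// Reinterpret the unsigned hash as a signed 64-bit value (bigint).
	return int64(h.Sum64())
}

func main() {
	fmt.Println(advisoryLockID("nightly-report") == advisoryLockID("nightly-report"))
	fmt.Println(advisoryLockID("nightly-report") == advisoryLockID("invoice-export"))
}
```

Two different names can in principle hash to the same key. For a handful of lock names that is very unlikely, but named constants are the safer option when the set of locks is fixed.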
## Long-Running Transactions
Long-running transactions hold connections from the pool and can cause problems:
- **Connection starvation**: Other requests wait for a connection while the transaction holds one
- **Lock contention**: Rows locked by the transaction block other writes
- **Table bloat**: dead row versions pile up in tables and indexes for as long as the transaction's snapshot stays open
- **Vacuum blocking**: `VACUUM` cannot clean up rows visible to the long transaction
### Mitigation
1. **Set a statement timeout** to prevent runaway queries within a transaction:
```go
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
// Limit each statement in this transaction to 30 seconds (SET LOCAL resets at commit/rollback)
_, err = tx.ExecContext(ctx, "SET LOCAL statement_timeout = '30s'")
if err != nil {
tx.Rollback()
return err
}
```
2. **Use context with timeout** so the entire transaction is bounded:
```go
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
err := WithTx(ctx, db, func(tx *sql.Tx) error {
// If context expires, the transaction is automatically rolled back
return nil
})
```
3. **Break large operations into batches** instead of processing everything in one transaction:
```go
// BAD -- single transaction updating millions of rows
WithTx(ctx, db, func(tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "UPDATE users SET status = 'active'")
return err
})
// GOOD -- batch processing
for {
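// PostgreSQL has no UPDATE ... LIMIT, so pick each batch of ids in a subquery
result, err := db.ExecContext(ctx, `
UPDATE users SET status = 'active'
WHERE id IN (SELECT id FROM users WHERE status = 'pending' LIMIT 1000)`,
)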
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
break
}
}
```
4. **Never call external APIs** inside a transaction. If you need to coordinate with an external service, use the saga pattern or outbox pattern instead.
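For item 3, when the rows to process are already known as a list of IDs, the batching can happen in Go before any SQL runs. A minimal sketch with a hypothetical generic helper `batches`:

```go
package main

import "fmt"

// batches splits ids into consecutive slices of at most size elements.
// The returned slices share memory with ids; nothing is copied.
func batches[T any](ids []T, size int) [][]T {
	if size <= 0 {
		return nil
	}
	var out [][]T
	for start := 0; start < len(ids); start += size {
		end := start + size
		if end > len(ids) {
			end = len(ids)
		}
		out = append(out, ids[start:end])
	}
	return out
}

func main() {
	for _, b := range batches([]int{1, 2, 3, 4, 5}, 2) {
		fmt.Println(b)
	}
}
```

Each batch can then be handled in its own short transaction, for example with a `WHERE id = ANY($1)` clause, so locks are released between batches.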
## Testing Transactions
### Test with Real Database
Use a test database and roll back after each test:
```go
func TestPlaceOrder(t *testing.T) {
db := setupTestDB(t)
// Start a transaction for the test
tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
tx.Rollback() // Undo all changes after the test
})
// Create store using the test transaction
store := NewOrderStore(tx)
// ... run test assertions ...
}
```
### Test Helper with Automatic Rollback
For tests that should never leave data behind, wrap the test body in a transaction that is always rolled back:
```go
func setupTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("postgres", os.Getenv("TEST_DATABASE_URL"))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })
return db
}
func withTestTx(t *testing.T, db *sql.DB, fn func(tx *sql.Tx)) {
t.Helper()
tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
fn(tx)
// Transaction is always rolled back -- test data is never committed
}
```
### Testing Transaction Rollback
```go
func TestPlaceOrder_RollsBackOnPaymentFailure(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	// users.id is a UUID column, so the fixture needs a valid UUID
	const userID = "11111111-1111-1111-1111-111111111111"
	// Setup: commit the user on db. PlaceOrder opens its own transaction on db,
	// so a row inserted in an uncommitted test transaction would be invisible to it.
	_, err := db.ExecContext(ctx,
		"INSERT INTO users (id, email, name) VALUES ($1, $2, $3)",
		userID, "test@example.com", "Test",
	)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		db.ExecContext(context.Background(), "DELETE FROM users WHERE id = $1", userID)
	})
	// Create a service that will fail on payment
	svc := &OrderService{
		db:           db,
		orderStore:   NewOrderStore(),
		paymentStore: &FailingPaymentStore{}, // Always returns error
	}
	err = svc.PlaceOrder(ctx, &Order{
		UserID: userID,
		Total:  1000,
	})
	// Verify the error
	if err == nil {
		t.Fatal("expected error from failing payment")
	}
	// Verify the order was NOT created (transaction rolled back)
	var count int
	if err := db.QueryRowContext(ctx,
		"SELECT COUNT(*) FROM orders WHERE user_id = $1", userID,
	).Scan(&count); err != nil {
		t.Fatal(err)
	}
	if count != 0 {
		t.Errorf("expected 0 orders after rollback, got %d", count)
	}
}
```
## Anti-Patterns
### Starting transactions in store methods
```go
// BAD -- store controls transaction, caller cannot compose
func (s *UserStore) CreateUser(ctx context.Context, u *User) error {
tx, _ := s.db.BeginTx(ctx, nil)
_, err := tx.ExecContext(ctx, "INSERT INTO users ...", ...)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
```
The caller cannot add this operation to a larger transaction. Move transaction management to the service layer.
### Forgetting to handle rollback errors
```go
// BAD -- rollback error is silently ignored
if err := fn(tx); err != nil {
tx.Rollback() // What if this fails?
return err
}
```
Log or wrap the rollback error so you know if cleanup failed.
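One way to keep both errors is `errors.Join` (Go 1.20+), sketched below. `withRollbackErr` is a hypothetical helper, and the two sentinel errors are made up for the demonstration. The sketch also assumes that `sql.ErrTxDone` from `Rollback` can be ignored, since it only signals that the transaction was already finished, for example after the context was cancelled.

```go
package main

import (
	"database/sql"
	"errors"
	"fmt"
)

// withRollbackErr returns err alone when the rollback succeeded (or the
// transaction was already finished), and both errors joined otherwise.
func withRollbackErr(err, rbErr error) error {
	if rbErr == nil || errors.Is(rbErr, sql.ErrTxDone) {
		return err
	}
	return errors.Join(err, fmt.Errorf("rollback failed: %w", rbErr))
}

// Made-up sentinel errors for the demonstration.
var (
	errQuery = errors.New("query failed")
	errConn  = errors.New("connection reset")
)

func main() {
	combined := withRollbackErr(errQuery, errConn)
	// Both causes stay reachable through errors.Is.
	fmt.Println(errors.Is(combined, errQuery), errors.Is(combined, errConn)) // true true
}
```

Inside a helper like `WithTx`, the call site would look like `return withRollbackErr(err, tx.Rollback())`.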
### Holding transactions open during external API calls
```go
// BAD -- holds a connection and locks while waiting for HTTP response
WithTx(ctx, db, func(tx *sql.Tx) error {
order, _ := orderStore.CreateWithTx(ctx, tx, order)
// This HTTP call might take seconds or time out
paymentResult, err := paymentAPI.Charge(order.Total)
if err != nil {
return err // Transaction held open the entire time
}
return orderStore.UpdateStatusWithTx(ctx, tx, order.ID, "paid")
})
```
Make external calls outside the transaction. Use an outbox pattern if you need to coordinate:
```go
// GOOD -- transaction is short, external call is outside
var order *Order
err := WithTx(ctx, db, func(tx *sql.Tx) error {
var err error
order, err = orderStore.CreateWithTx(ctx, tx, newOrder)
return err
})
if err != nil {
return err
}
// External call outside the transaction
if _, err := paymentAPI.Charge(order.Total); err != nil {
// Mark order as failed in a separate transaction, then surface the charge error
if updateErr := orderStore.UpdateStatus(ctx, order.ID, "payment_failed"); updateErr != nil {
return errors.Join(err, updateErr)
}
return err
}
return orderStore.UpdateStatus(ctx, order.ID, "paid")
```
### Not passing context to transaction operations
```go
// BAD -- query is not cancellable
rows, err := tx.Query("SELECT * FROM large_table")
// GOOD -- respects context cancellation
rows, err := tx.QueryContext(ctx, "SELECT * FROM large_table")
```
Always use the `Context` variants so queries are cancelled when the request context is done.
Go concurrency patterns for high-throughput web applications including worker pools, rate limiting, race detection, and safe shared state management. Use whe...
---
name: go-concurrency-web
description: Go concurrency patterns for high-throughput web applications including worker pools, rate limiting, race detection, and safe shared state management. Use when implementing background task processing, rate limiters, or concurrent request handling.
---
# Go Concurrency for Web Applications
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Worker Pools & errgroup | [references/worker-pools.md](references/worker-pools.md) |
| Rate Limiting | [references/rate-limiting.md](references/rate-limiting.md) |
| Race Detection & Fixes | [references/race-detection.md](references/race-detection.md) |
## Core Rules
1. **Goroutines are cheap but not free** — each goroutine starts with a stack of about 2 KB that grows on demand. Unbounded spawning under load leads to OOM.
2. **Always have a shutdown path** — every goroutine you start must have a way to exit. Use `context.Context` or channel closing to signal exit, and `sync.WaitGroup` to wait for it.
3. **Prefer channels for communication** — use channels to coordinate work between goroutines and signal completion.
4. **Use mutexes for state protection** — when goroutines share mutable state, protect it with `sync.Mutex`, `sync.RWMutex`, or `sync/atomic`.
5. **Never spawn raw goroutines in HTTP handlers** — use worker pools, `errgroup`, or other bounded concurrency primitives.
## Gates (check before merge or review)
Use these **sequenced** checks for objective pass/fail; do not replace them with “I verified mentally.”
1. **Race detector**
- Run `go test -race ./...` on packages that changed concurrent code, or `go build -race` for binaries under test.
- **Pass:** exit code `0`. If you report “no races,” attach or cite CI output / saved terminal transcript—do not assert cleanliness without that artifact.
2. **Bounded background work from HTTP**
- Inspect handlers and middleware that start work beyond the request goroutine.
- **Pass:** every such path uses a bounded primitive (worker pool, buffered channel with documented capacity, `errgroup` with an explicit concurrency cap)—not unbounded `go` per incoming request.
3. **Graceful teardown**
- For processes that start long-lived goroutines, trace from shutdown signal (or test `defer`) to `Wait()` / channel close / `context` cancel for each goroutine family.
- **Pass:** you can point to the call chain or a test that proves shutdown completes without hang (no orphan goroutines).
## Worker Pool Pattern
Use worker pools for background tasks dispatched from HTTP handlers. This bounds concurrency and provides graceful shutdown.
```go
// Worker pool for background tasks (e.g., sending emails)
type WorkerPool struct {
jobs chan Job
wg sync.WaitGroup
logger *slog.Logger
}
type Job struct {
ID string
Execute func(ctx context.Context) error
}
func NewWorkerPool(numWorkers int, queueSize int, logger *slog.Logger) *WorkerPool {
wp := &WorkerPool{
jobs: make(chan Job, queueSize),
logger: logger,
}
for i := 0; i < numWorkers; i++ {
wp.wg.Add(1)
go wp.worker(i)
}
return wp
}
func (wp *WorkerPool) worker(id int) {
defer wp.wg.Done()
for job := range wp.jobs {
wp.logger.Info("processing job", "worker", id, "job_id", job.ID)
if err := job.Execute(context.Background()); err != nil {
wp.logger.Error("job failed", "worker", id, "job_id", job.ID, "err", err)
}
}
}
func (wp *WorkerPool) Submit(job Job) {
wp.jobs <- job
}
func (wp *WorkerPool) Shutdown() {
close(wp.jobs)
wp.wg.Wait()
}
```
### Usage in HTTP Handler
```go
func (s *Server) handleCreateUser(w http.ResponseWriter, r *http.Request) {
user, err := s.userService.Create(r.Context(), decodeUser(r))
if err != nil {
handleError(w, r, err)
return
}
// Dispatch background task — never spawn raw goroutines in handlers
s.workers.Submit(Job{
ID: "welcome-email-" + user.ID,
Execute: func(ctx context.Context) error {
return s.emailService.SendWelcome(ctx, user)
},
})
writeJSON(w, http.StatusCreated, user)
}
```
See [references/worker-pools.md](references/worker-pools.md) for sizing guidance, backpressure, error handling, retry patterns, and `errgroup` as a simpler alternative.
## Rate Limiting
Use `golang.org/x/time/rate` for token bucket rate limiting. Apply as middleware for global limits or per-IP/per-user limits.
Key points:
- Global rate limiting protects overall service capacity
- Per-IP rate limiting prevents individual clients from monopolizing resources
- Always return `429 Too Many Requests` with a `Retry-After` header
See [references/rate-limiting.md](references/rate-limiting.md) for middleware implementation, per-IP limiting, stale limiter cleanup, and API key-based limiting.
## Race Detection
Run the race detector in development and CI:
```bash
go test -race ./...
go build -race -o myserver ./cmd/server
```
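The race detector catches data races: unsynchronized concurrent accesses to the same memory where at least one access is a write. It only reports races on code paths that actually run, and it does not catch logical races (e.g., TOCTOU bugs) or deadlocks.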
See [references/race-detection.md](references/race-detection.md) for common web handler races, fixing strategies, and CI integration.
## Handler Safety
Every incoming HTTP request runs in its own goroutine. Any shared mutable state on the server struct is a potential data race.
```go
// BAD — shared state without protection
type Server struct {
requestCount int // data race!
}
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
s.requestCount++ // concurrent writes = race condition
}
// GOOD — use atomic or mutex
type Server struct {
requestCount atomic.Int64
}
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
s.requestCount.Add(1)
}
// GOOD — use mutex for complex state
type Server struct {
mu sync.RWMutex
cache map[string]*CachedItem
}
func (s *Server) handleGetCached(w http.ResponseWriter, r *http.Request) {
s.mu.RLock()
item, ok := s.cache[r.PathValue("key")]
s.mu.RUnlock()
// ...
}
```
### Rules for Handler Safety
- **Request-scoped data is safe** — `r.Context()`, request body, URL params are isolated per request.
- **Server struct fields are shared** — any field on `*Server` accessed by handlers needs synchronization.
- **Database connections are safe** — `*sql.DB` manages its own connection pool with internal locking.
- **Maps are not safe** — use `sync.Map` or protect with a mutex.
- **Slices are not safe** — concurrent append or read/write requires a mutex.
## Anti-Patterns
### Unbounded goroutine spawning
```go
// BAD — no limit on concurrent goroutines
func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) {
go func() {
// What if 10,000 requests arrive at once?
s.processWebhook(r.Context(), decodeWebhook(r))
}()
w.WriteHeader(http.StatusAccepted)
}
// GOOD — use a worker pool
func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) {
webhook := decodeWebhook(r)
s.workers.Submit(Job{
ID: "webhook-" + webhook.ID,
Execute: func(ctx context.Context) error {
return s.processWebhook(ctx, webhook)
},
})
w.WriteHeader(http.StatusAccepted)
}
```
### Forgetting to propagate context
```go
// BAD — loses cancellation signal
func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
results, err := s.search(context.Background(), r.URL.Query().Get("q"))
// ...
}
// GOOD — use request context
func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
results, err := s.search(r.Context(), r.URL.Query().Get("q"))
// ...
}
```
### Goroutine leak from missing channel receiver
```go
// BAD — goroutine blocks forever if nobody reads the channel
func fetchWithTimeout(ctx context.Context, url string) (*http.Response, error) {
	ch := make(chan *http.Response)
	go func() {
		resp, _ := http.Get(url) // http.Get ignores ctx, so the request keeps running after cancellation
		ch <- resp               // stuck here forever once the caller has returned
	}()
	select {
	case resp := <-ch:
		return resp, nil
	case <-ctx.Done():
		return nil, ctx.Err() // goroutine leaked!
	}
}

// GOOD — bind the request to ctx and buffer the channel so the goroutine can always exit
func fetchWithTimeout(ctx context.Context, url string) (*http.Response, error) {
	type result struct {
		resp *http.Response
		err  error
	}
	ch := make(chan result, 1) // buffered: the send never blocks
	go func() {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			ch <- result{err: err}
			return
		}
		resp, err := http.DefaultClient.Do(req) // aborted when ctx is cancelled
		ch <- result{resp: resp, err: err}
	}()
	select {
	case res := <-ch:
		return res.resp, res.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
```
### Using `time.Sleep` for coordination
```go
// BAD — sleeping to wait for goroutines
go doWork()
time.Sleep(5 * time.Second) // hoping it finishes
// GOOD — use sync primitives
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
doWork()
}()
wg.Wait()
```
FILE:references/race-detection.md
# Race Detection
## Running the Race Detector
The Go race detector instruments memory accesses at compile time and detects concurrent unsynchronized access at runtime.
```bash
# Run tests with race detection
go test -race ./...
# Build a binary with race detection (for integration testing)
go build -race -o myserver ./cmd/server
# Run a specific test with race detection
go test -race -run TestHandlerConcurrency ./internal/server/
```
**Always run `-race` in CI.** Race conditions are intermittent; the race detector increases the chance of catching them.
## What the Race Detector Catches
The race detector detects **data races**: two goroutines access the same memory location concurrently, and at least one access is a write.
It catches:
- Concurrent map reads and writes
- Concurrent struct field modifications
- Concurrent slice access
- Concurrent variable increments
It does **not** catch:
- **Logical races (TOCTOU)** — checking a condition and acting on it non-atomically
- **Deadlocks** — goroutines waiting on each other forever
- **Starvation** — a goroutine never gets scheduled
- **Race conditions that don't execute** — it only detects races that actually occur during the test run
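To make the first limitation concrete, here is a minimal sketch of a check-then-act (TOCTOU) bug that `-race` stays silent about, because every individual access is under the mutex. `Inventory`, `ReserveRacy`, and `ReserveAtomic` are hypothetical names used only for illustration.

```go
package main

import (
	"fmt"
	"sync"
)

// Inventory guards stock with a mutex, so there is no data race on the field.
type Inventory struct {
	mu    sync.Mutex
	stock int
}

func (inv *Inventory) available() int {
	inv.mu.Lock()
	defer inv.mu.Unlock()
	return inv.stock
}

func (inv *Inventory) take() {
	inv.mu.Lock()
	defer inv.mu.Unlock()
	inv.stock--
}

// ReserveRacy checks and acts in two separately locked steps. Another
// goroutine can take the last item between them, so stock can go negative
// even though the race detector reports nothing.
func (inv *Inventory) ReserveRacy() bool {
	if inv.available() > 0 {
		inv.take()
		return true
	}
	return false
}

// ReserveAtomic performs the check and the update under one lock.
func (inv *Inventory) ReserveAtomic() bool {
	inv.mu.Lock()
	defer inv.mu.Unlock()
	if inv.stock > 0 {
		inv.stock--
		return true
	}
	return false
}

func main() {
	inv := &Inventory{stock: 10}
	var (
		wg      sync.WaitGroup
		mu      sync.Mutex
		granted int
	)
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if inv.ReserveAtomic() {
				mu.Lock()
				granted++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	fmt.Println(granted, inv.available()) // 10 0
}
```

The fix is structural rather than tool-driven: keep the condition and the action inside the same critical section.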
## Common Race Conditions in Web Handlers
### Race 1: Shared Map Without Lock
```go
// BAD — concurrent map writes crash the program
var cache = map[string]string{}
func handler(w http.ResponseWriter, r *http.Request) {
cache[r.URL.Path] = "value" // concurrent map write!
}
// Fix: use sync.Map or mutex
var cache sync.Map
func handler(w http.ResponseWriter, r *http.Request) {
cache.Store(r.URL.Path, "value") // safe
}
```
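Note: When the Go runtime detects concurrent map writes it aborts the program with an unrecoverable fatal error (`fatal error: concurrent map writes`), so `recover` cannot save the process. This is worse than just incorrect data.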
### Race 2: Incrementing a Counter
```go
// BAD — data race on counter
var count int
func handler(w http.ResponseWriter, r *http.Request) {
count++ // data race!
}
// Fix: use atomic
var count atomic.Int64
func handler(w http.ResponseWriter, r *http.Request) {
count.Add(1) // safe
}
```
### Race 3: Slice Append
```go
// BAD — concurrent append is not safe
type Server struct {
events []Event
}
func (s *Server) handleEvent(w http.ResponseWriter, r *http.Request) {
event := decodeEvent(r)
s.events = append(s.events, event) // data race!
}
// Fix: protect with mutex
type Server struct {
mu sync.Mutex
events []Event
}
func (s *Server) handleEvent(w http.ResponseWriter, r *http.Request) {
event := decodeEvent(r)
s.mu.Lock()
s.events = append(s.events, event)
s.mu.Unlock()
}
```
### Race 4: Lazy Initialization
```go
// BAD — multiple goroutines may initialize simultaneously
type Server struct {
client *http.Client
}
func (s *Server) getClient() *http.Client {
if s.client == nil { // race: read
s.client = &http.Client{ // race: write
Timeout: 10 * time.Second,
}
}
return s.client
}
// Fix: use sync.Once
type Server struct {
clientOnce sync.Once
client *http.Client
}
func (s *Server) getClient() *http.Client {
s.clientOnce.Do(func() {
s.client = &http.Client{
Timeout: 10 * time.Second,
}
})
return s.client
}
```
### Race 5: Unsynchronized Read and Write of a Struct Field
```go
// BAD — one goroutine reads the flag while another writes it, with no synchronization
type Server struct {
healthy bool
}
func (s *Server) healthCheck(w http.ResponseWriter, r *http.Request) {
if s.healthy { // race: read
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
}
func (s *Server) setHealthy(healthy bool) {
s.healthy = healthy // race: write
}
// Fix: use atomic.Bool
type Server struct {
healthy atomic.Bool
}
func (s *Server) healthCheck(w http.ResponseWriter, r *http.Request) {
if s.healthy.Load() {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
}
func (s *Server) setHealthy(healthy bool) {
s.healthy.Store(healthy)
}
```
## Fixing Races: When to Use What
| Scenario | Use | Why |
|----------|-----|-----|
| Simple counter | `sync/atomic` (`atomic.Int64`) | Lock-free, minimal overhead |
| Boolean flag | `sync/atomic` (`atomic.Bool`) | Lock-free, minimal overhead |
| Read-heavy cache | `sync.RWMutex` | Multiple concurrent readers, exclusive writers |
| Write-heavy map | `sync.Mutex` | Simple exclusive access |
| Cross-goroutine communication | Channels | Idiomatic Go, naturally synchronizes |
| One-time initialization | `sync.Once` | Guaranteed single execution |
| Concurrent-safe map (simple keys) | `sync.Map` | Built-in safety, good for append-only or key-stable maps |
### sync.Mutex vs sync.RWMutex
Use `sync.RWMutex` when reads significantly outnumber writes:
```go
type Cache struct {
mu sync.RWMutex
data map[string]*Entry
}
// Multiple goroutines can read concurrently
func (c *Cache) Get(key string) (*Entry, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, ok := c.data[key]
return entry, ok
}
// Only one goroutine can write at a time
func (c *Cache) Set(key string, entry *Entry) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[key] = entry
}
```
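If reads and writes are roughly equal, use `sync.Mutex`: it is simpler and avoids the extra reader bookkeeping that `sync.RWMutex` performs. Note that Go's `RWMutex` has no lock upgrade; a goroutine holding a read lock must release it before taking the write lock.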
### sync.Map vs Mutex-Protected Map
`sync.Map` is optimized for two patterns:
1. Write-once, read-many (append-only maps)
2. Multiple goroutines read/write disjoint key sets
For everything else, a mutex-protected `map` is usually faster and provides type safety.
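As a rough sketch of the mutex-protected alternative, a small generic wrapper (Go 1.18+) keeps the compile-time types that `sync.Map` gives up. `SafeMap` is a hypothetical name, and this is an illustration rather than a tuned implementation.

```go
package main

import (
	"fmt"
	"sync"
)

// SafeMap is a typed map guarded by an RWMutex.
type SafeMap[K comparable, V any] struct {
	mu sync.RWMutex
	m  map[K]V
}

func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
	return &SafeMap[K, V]{m: make(map[K]V)}
}

// Get returns the value for key and whether it was present.
func (s *SafeMap[K, V]) Get(key K) (V, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	v, ok := s.m[key]
	return v, ok
}

// Set stores value under key, replacing any previous value.
func (s *SafeMap[K, V]) Set(key K, value V) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.m[key] = value
}

// Len reports the number of entries.
func (s *SafeMap[K, V]) Len() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.m)
}

func main() {
	hits := NewSafeMap[string, int]()
	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			hits.Set(fmt.Sprintf("key-%d", i), i)
		}(i)
	}
	wg.Wait()
	v, ok := hits.Get("key-7")
	fmt.Println(hits.Len(), v, ok) // 50 7 true
}
```

Callers get `int` back directly, with no `any` type assertions at each call site.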
## Testing for Race Conditions
### Write Concurrent Tests
```go
func TestHandlerConcurrency(t *testing.T) {
srv := NewServer()
ts := httptest.NewServer(srv)
defer ts.Close()
// Hammer the endpoint from multiple goroutines
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := http.Get(ts.URL + "/api/data")
if err != nil {
t.Errorf("request failed: %v", err)
return
}
resp.Body.Close()
}()
}
wg.Wait()
}
```
### Test Specific Race Scenarios
```go
func TestCacheRace(t *testing.T) {
cache := NewCache()
var wg sync.WaitGroup
// Concurrent writes
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
cache.Set(fmt.Sprintf("key-%d", i), &Entry{Value: i})
}(i)
}
// Concurrent reads
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
cache.Get(fmt.Sprintf("key-%d", i))
}(i)
}
wg.Wait()
}
```
## CI Integration
### GitHub Actions
```yaml
name: Test
on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.22'
- name: Test with race detector
run: go test -race -count=1 ./...
- name: Build with race detector (integration)
run: go build -race -o ./bin/server ./cmd/server
```
### Key Flags
- `-race` enables the race detector
- `-count=1` disables test caching (ensures fresh run for race detection)
- The race detector typically adds ~2-20x runtime overhead and ~5-10x memory overhead
- Acceptable for CI and testing; do **not** ship race-enabled binaries to production
### Race Detector Environment Variables
```bash
# Customize race detector behavior
GORACE="log_path=/tmp/race.log" go test -race ./...
# Halt on first race (useful in CI)
GORACE="halt_on_error=1" go test -race ./...
# Increase history size for complex programs
GORACE="history_size=7" go test -race ./...
```
FILE:references/rate-limiting.md
# Rate Limiting
## Token Bucket Algorithm
`golang.org/x/time/rate` implements a token bucket rate limiter:
- A bucket holds up to **burst** tokens
- Tokens are added at a rate of **rps** (requests per second)
- Each request consumes one token
- If no tokens are available, the request is rejected (or waits)
Example: `rate.NewLimiter(10, 20)` allows 10 requests/second sustained with bursts up to 20.
## Global Rate Limiting
Protect the entire service from being overwhelmed:
```go
// Rate limit middleware using x/time/rate
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
limiter := rate.NewLimiter(rate.Limit(rps), burst)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.Allow() {
w.Header().Set("Retry-After", "1")
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// Usage
mux := http.NewServeMux()
mux.HandleFunc("GET /api/data", s.handleGetData)
handler := RateLimit(100, 200)(mux) // 100 rps, burst of 200
```
## Per-IP Rate Limiting
Prevent individual clients from monopolizing resources:
```go
// Per-IP rate limiting
type IPRateLimiter struct {
mu sync.RWMutex
limiters map[string]*rate.Limiter
rps float64
burst int
}
func NewIPRateLimiter(rps float64, burst int) *IPRateLimiter {
return &IPRateLimiter{
limiters: make(map[string]*rate.Limiter),
rps: rps,
burst: burst,
}
}
func (l *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
l.mu.RLock()
limiter, exists := l.limiters[ip]
l.mu.RUnlock()
if exists {
return limiter
}
l.mu.Lock()
defer l.mu.Unlock()
// Double-check after acquiring write lock
if limiter, exists = l.limiters[ip]; exists {
return limiter
}
limiter = rate.NewLimiter(rate.Limit(l.rps), l.burst)
l.limiters[ip] = limiter
return limiter
}
```
### Per-IP Middleware
```go
func PerIPRateLimit(rps float64, burst int, trustProxy bool) func(http.Handler) http.Handler {
limiter := NewIPRateLimiter(rps, burst)
// Start cleanup goroutine
go limiter.cleanup(5 * time.Minute)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := extractIP(r, trustProxy)
if !limiter.GetLimiter(ip).Allow() {
w.Header().Set("Retry-After", "1")
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// extractIP returns the client IP address from the request.
// WARNING: X-Forwarded-For and X-Real-IP headers can be spoofed by clients.
// Only trust these headers when behind a known reverse proxy that strips/overwrites them.
func extractIP(r *http.Request, trustProxy bool) string {
if trustProxy {
// Check X-Forwarded-For for proxied requests
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (client IP)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
}
// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
```
## Cleaning Up Stale Limiters
Without cleanup, the limiter map grows indefinitely. Remove entries that haven't been used recently:
```go
type trackedLimiter struct {
limiter *rate.Limiter
lastSeen time.Time
}
type IPRateLimiter struct {
mu sync.RWMutex
limiters map[string]*trackedLimiter
rps float64
burst int
}
func (l *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
// lastSeen is updated on every call, so take the write lock once.
// Read-locking first and re-locking leaves a window in which cleanup
// can delete the entry between the two critical sections.
l.mu.Lock()
defer l.mu.Unlock()
if tracked, exists := l.limiters[ip]; exists {
tracked.lastSeen = time.Now()
return tracked.limiter
}
limiter := rate.NewLimiter(rate.Limit(l.rps), l.burst)
l.limiters[ip] = &trackedLimiter{
limiter: limiter,
lastSeen: time.Now(),
}
return limiter
}
func (l *IPRateLimiter) cleanup(maxAge time.Duration) {
ticker := time.NewTicker(maxAge)
defer ticker.Stop()
for range ticker.C {
l.mu.Lock()
cutoff := time.Now().Add(-maxAge)
for ip, tracked := range l.limiters {
if tracked.lastSeen.Before(cutoff) {
delete(l.limiters, ip)
}
}
l.mu.Unlock()
}
}
```
## Rate Limiting by API Key or User
For authenticated endpoints, rate limit by user identity instead of IP:
```go
type KeyRateLimiter struct {
mu sync.RWMutex
limiters map[string]*trackedLimiter
tiers map[string]Tier // API key -> tier
}
type Tier struct {
RPS float64
Burst int
}
var defaultTiers = map[string]Tier{
"free": {RPS: 10, Burst: 20},
"pro": {RPS: 100, Burst: 200},
"enterprise": {RPS: 1000, Burst: 2000},
}
func (l *KeyRateLimiter) GetLimiter(apiKey string) *rate.Limiter {
// lastSeen is updated on every call, so take the write lock once
l.mu.Lock()
defer l.mu.Unlock()
if tracked, exists := l.limiters[apiKey]; exists {
tracked.lastSeen = time.Now()
return tracked.limiter
}
tier, ok := l.tiers[apiKey]
if !ok {
tier = defaultTiers["free"]
}
limiter := rate.NewLimiter(rate.Limit(tier.RPS), tier.Burst)
l.limiters[apiKey] = &trackedLimiter{
limiter: limiter,
lastSeen: time.Now(),
}
return limiter
}
```
### API Key Middleware
```go
func APIKeyRateLimit(keyLimiter *KeyRateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
http.Error(w, "missing API key", http.StatusUnauthorized)
return
}
limiter := keyLimiter.GetLimiter(apiKey)
if !limiter.Allow() {
w.Header().Set("Retry-After", "1")
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%.0f", float64(limiter.Limit())))
w.Header().Set("X-RateLimit-Remaining", "0")
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
```
## Returning Proper 429 Responses
Always include informative headers when rejecting rate-limited requests:
```go
func rateLimitResponse(w http.ResponseWriter, limiter *rate.Limiter) {
reservation := limiter.Reserve()
if !reservation.OK() {
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
delay := reservation.Delay()
reservation.Cancel() // We're rejecting, not waiting
// Retry-After takes whole seconds, so round up (needs "math" and "strconv")
retryAfter := int(math.Ceil(delay.Seconds()))
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%.0f", float64(limiter.Limit())))
w.Header().Set("X-RateLimit-Remaining", "0")
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]any{
"error": "rate limit exceeded",
"retry_after": retryAfter,
})
}
```
## Combining Rate Limiters
Layer global and per-IP limits for defense in depth:
```go
mux := http.NewServeMux()
mux.HandleFunc("GET /api/data", s.handleGetData)
// The global limiter is the outer middleware, so it runs first; the per-IP limiter runs second
handler := RateLimit(1000, 2000)( // global: 1000 rps
PerIPRateLimit(10, 20, false)( // per-IP: 10 rps; trustProxy=false uses RemoteAddr
mux,
),
)
```
This ensures no single IP can use more than 10 rps, while the service overall caps at 1000 rps.
FILE:references/worker-pools.md
# Worker Pools
## Why Worker Pools
Worker pools prevent goroutine leaks and OOM by bounding concurrency. Without a pool, every incoming request that spawns a goroutine can create unbounded parallelism:
- 10,000 requests/second with tasks that each run for about a second = ~10,000 live goroutines = ~20 MB minimum (stacks start at ~2 KB, grow as needed)
- Each goroutine may hold open database connections, file descriptors, or network sockets
- The Go scheduler slows down with millions of goroutines
A worker pool with N workers guarantees at most N concurrent background tasks, regardless of request volume.
## Sizing Workers
### CPU-Bound Tasks
For tasks that primarily consume CPU (compression, hashing, image processing):
- **Workers = `runtime.NumCPU()`** or slightly more
- More workers than CPUs adds context-switching overhead with no throughput gain
### I/O-Bound Tasks
For tasks that wait on external services (HTTP calls, database queries, email sending):
- **Workers = 10x to 100x the number of CPUs** is common
- The bottleneck is the external service, not CPU
- Tune based on the external service's capacity and latency
### General Guidance
```go
// CPU-bound: match CPU count
pool := NewWorkerPool(runtime.NumCPU(), 1000, logger)
// I/O-bound: more workers, they spend most time waiting
pool := NewWorkerPool(50, 5000, logger)
```
## Queue Sizing and Backpressure
The buffered channel acts as a queue. Its size determines backpressure behavior:
```go
// Small queue = fast backpressure signal
jobs: make(chan Job, 10)
// Large queue = absorbs bursts but uses more memory
jobs: make(chan Job, 10000)
```
### Handling a Full Queue
When the queue is full, `Submit` blocks. For HTTP handlers, blocking is usually unacceptable. Use a non-blocking submit:
```go
func (wp *WorkerPool) TrySubmit(job Job) bool {
select {
case wp.jobs <- job:
return true
default:
wp.logger.Warn("worker pool full, dropping job", "job_id", job.ID)
return false
}
}
// In handler
func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) {
webhook := decodeWebhook(r)
if !s.workers.TrySubmit(Job{
ID: "webhook-" + webhook.ID,
Execute: func(ctx context.Context) error {
return s.processWebhook(ctx, webhook)
},
}) {
http.Error(w, "server busy, try again later", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusAccepted)
}
```
## Graceful Shutdown Integration
Worker pools must integrate with your server's shutdown sequence. Process in-flight jobs before exiting:
```go
func main() {
logger := slog.Default()
pool := NewWorkerPool(10, 1000, logger)
srv := &http.Server{
Addr: ":8080",
Handler: newRouter(pool),
}
// Start server
go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
logger.Error("server error", "err", err)
}
}()
// Wait for interrupt
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Info("shutting down server...")
// 1. Stop accepting new requests
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
logger.Error("server shutdown", "err", err)
}
// 2. Drain the worker pool (finish in-flight jobs)
logger.Info("draining worker pool...")
pool.Shutdown()
logger.Info("shutdown complete")
}
```
## Error Handling and Retry Patterns
### Logging Errors
At minimum, log all job failures:
```go
func (wp *WorkerPool) worker(id int) {
defer wp.wg.Done()
for job := range wp.jobs {
start := time.Now()
if err := job.Execute(context.Background()); err != nil {
wp.logger.Error("job failed",
"worker", id,
"job_id", job.ID,
"duration", time.Since(start),
"err", err,
)
} else {
wp.logger.Info("job completed",
"worker", id,
"job_id", job.ID,
"duration", time.Since(start),
)
}
}
}
```
### Retry with Backoff
For transient failures, wrap jobs with retry logic:
```go
// WithRetry wraps a job so that transient failures are retried with linear backoff.
// maxAttempts is the total number of tries; values below 1 are treated as 1.
// The pool and its chan Job stay unchanged because the result is still a Job.
func WithRetry(job Job, maxAttempts int, backoff time.Duration, logger *slog.Logger) Job {
	if maxAttempts < 1 {
		maxAttempts = 1
	}
	return Job{
		ID: job.ID,
		Execute: func(ctx context.Context) error {
			var lastErr error
			for attempt := 0; attempt < maxAttempts; attempt++ {
				if attempt > 0 {
					// Wait before retrying, but give up early if ctx is cancelled
					select {
					case <-time.After(backoff * time.Duration(attempt)):
					case <-ctx.Done():
						return ctx.Err()
					}
				}
				if lastErr = job.Execute(ctx); lastErr == nil {
					return nil
				}
				logger.Warn("job attempt failed",
					"job_id", job.ID,
					"attempt", attempt+1,
					"err", lastErr,
				)
			}
			// The worker logs this final error as "job failed"
			return fmt.Errorf("job %s exhausted %d attempts: %w", job.ID, maxAttempts, lastErr)
		},
	}
}
```
## Context Propagation
Workers should use their own context, not the request context. The request context is cancelled when the handler returns (or the client disconnects), which usually happens before the background job runs:
```go
// BAD — request context is cancelled when the handler returns
s.workers.Submit(Job{
ID: "send-email",
Execute: func(ctx context.Context) error {
return s.email.Send(r.Context(), user) // r.Context() is most likely cancelled by now!
},
})
// GOOD — use background context with timeout
s.workers.Submit(Job{
ID: "send-email",
Execute: func(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
return s.email.Send(ctx, user)
},
})
```
If you need to pass values from the request context (like trace IDs), extract them before submitting:
```go
traceID := traceIDFromContext(r.Context())
userID := userIDFromContext(r.Context())
s.workers.Submit(Job{
ID: "audit-log-" + userID,
Execute: func(ctx context.Context) error {
ctx = withTraceID(ctx, traceID)
return s.audit.Log(ctx, userID, action)
},
})
```
## errgroup as a Simpler Alternative
For fan-out/fan-in within a single request (parallel API calls, batch processing), `golang.org/x/sync/errgroup` is simpler than a worker pool:
```go
// errgroup for bounded parallel work
func (s *Server) handleBatchProcess(w http.ResponseWriter, r *http.Request) {
items := decodeItems(r)
g, ctx := errgroup.WithContext(r.Context())
g.SetLimit(10) // max 10 concurrent goroutines
results := make([]*Result, len(items))
for i, item := range items {
i, item := i, item
g.Go(func() error {
result, err := s.processItem(ctx, item)
if err != nil {
return fmt.Errorf("processing item %d: %w", i, err)
}
results[i] = result
return nil
})
}
if err := g.Wait(); err != nil {
handleError(w, r, err)
return
}
writeJSON(w, http.StatusOK, results)
}
```
### When to Use errgroup vs Worker Pool
| Scenario | Use |
|----------|-----|
| Parallel work within a single request | `errgroup` |
| Background tasks that outlive the request | Worker pool |
| Fan-out to multiple APIs then combine results | `errgroup` |
| Fire-and-forget tasks (emails, webhooks) | Worker pool |
| Batch processing an upload | `errgroup` |
| Long-running async processing | Worker pool |
### errgroup Key Points
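- `g.SetLimit(n)` bounds concurrency. It ships with the `golang.org/x/sync` module rather than the standard library, so availability depends on the module version in `go.mod`, not on the Go release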
- Context cancellation propagates automatically — if one goroutine returns an error, the context is cancelled for all others
- `g.Wait()` blocks until all goroutines complete and returns the first error
- Safe to write to `results[i]` from goroutine `i` without a mutex because each goroutine writes to a distinct index
Reviews Go code for idiomatic patterns, error handling, concurrency safety, and common mistakes. Use when reviewing .go files, checking error handling, gorou...
---
name: go-code-review
description: Reviews Go code for idiomatic patterns, error handling, concurrency safety, and common mistakes. Use when reviewing .go files, checking error handling, goroutine usage, or interface design. Covers generics (Go 1.18+), errors.Join (Go 1.20+), slog (Go 1.21+), and Go 1.22 loop variable semantics.
---
# Go Code Review
## Review Workflow
Follow this sequence **in order**. Do not emit findings until every **Pass** below is satisfied.
1. **Baseline `go.mod`** — Open `go.mod` and read the `go` directive.
**Pass:** You can state the exact `go X.YY` value (in the review preamble or working notes). Apply version-gated advice only when it matches this baseline (loop capture pre-1.22, `slog`/structured logging from 1.21, `errors.Join` from 1.20).
2. **Read surrounding code** — For each changed `.go` file, read full functions or logical units that contain the edits, not only the diff hunk.
**Pass:** At least one full enclosing function (or package-level `init`/var block) containing the change was read per changed file.
3. **Scope the checklist** — Decide which [Review Checklist](#review-checklist) blocks apply (error handling, concurrency, interfaces/types, resources, naming). Load [references](#quick-reference) for those blocks; skip blocks that are irrelevant to the diff.
**Pass:** The review (or working notes) lists which checklist blocks you applied, or marks blocks N/A with a one-line reason tied to the diff (e.g. “no concurrency in change”).
4. **Pre-report verification** — Load and follow [review-verification-protocol](../review-verification-protocol/SKILL.md).
**Pass:** The protocol’s **Pre-Report Verification Checklist** is satisfied for each finding you will report (actual code read, surrounding context checked, “wrong” vs “different style” distinguished, etc.).
## Hard gates (same sequence, shorter)
| Step | Objective pass condition |
| --- | --- |
| 1 | `go X.YY` from `go.mod` is recorded before version-specific advice. |
| 2 | Full enclosing context read per changed file, not diff-only. |
| 3 | In-scope checklist blocks listed or N/A with diff-tied reason; references opened as needed. |
| 4 | `review-verification-protocol` completed for every reported issue. |
## Output Format
Report findings as:
```text
[FILE:LINE] ISSUE_TITLE
Severity: Critical | Major | Minor | Informational
Description of the issue and why it matters.
```
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Missing error checks, wrapping, errors.Join | [references/error-handling.md](references/error-handling.md) |
| Race conditions, channel misuse, goroutine lifecycle | [references/concurrency.md](references/concurrency.md) |
| Interface pollution, naming, generics | [references/interfaces.md](references/interfaces.md) |
| Resource leaks, defer misuse, slog, naming | [references/common-mistakes.md](references/common-mistakes.md) |
## Review Checklist
### Error Handling
- [ ] All errors checked (no `_ = err` without justifying comment)
- [ ] Errors wrapped with context (`fmt.Errorf("...: %w", err)`)
- [ ] `errors.Is`/`errors.As` used instead of string matching
- [ ] `errors.Join` used for aggregating multiple errors (Go 1.20+)
- [ ] Zero values returned alongside errors
### Concurrency
- [ ] No goroutine leaks (context cancellation or shutdown signal exists)
- [ ] Channels closed by sender only, exactly once
- [ ] Shared state protected by mutex or sync types
- [ ] WaitGroups used to wait for goroutine completion
- [ ] Context propagated through call chain
- [ ] Loop variable capture handled (pre-Go 1.22 codebases only)
### Interfaces and Types
- [ ] Interfaces defined by consumers, not producers
- [ ] Interface names follow `-er` convention
- [ ] Interfaces minimal (1-3 methods)
- [ ] Concrete types returned from constructors
- [ ] `any` preferred over `interface{}` (Go 1.18+)
- [ ] Generics used where appropriate instead of `any` or code generation
### Resources and Lifecycle
- [ ] Resources closed with `defer` immediately after creation
- [ ] HTTP response bodies always closed
- [ ] No `defer` in loops without closure wrapping
- [ ] `init()` functions avoided in favor of explicit initialization
### Naming and Style
- [ ] Exported names have doc comments
- [ ] No stuttering names (`user.UserService` → `user.Service`)
- [ ] No naked returns in functions > 5 lines
- [ ] Context passed as first parameter
- [ ] `slog` used over `log` for structured logging (Go 1.21+)
## Severity Calibration
### Critical (Block Merge)
- Unchecked errors on I/O, network, or database operations
- Goroutine leaks (no shutdown path)
- Race conditions on shared state (concurrent map access without sync)
- Unbounded resource accumulation (defer in loop, unclosed connections)
### Major (Should Fix)
- Errors returned without context (bare `return err`)
- Missing WaitGroup for spawned goroutines
- `panic` for recoverable errors
- Context not propagated to downstream calls
### Minor (Consider Fixing)
- `interface{}` instead of `any` in Go 1.18+ codebases
- Missing doc comments on exports
- Stuttering names
- Slice not preallocated when size is known
### Informational (Note Only)
- Suggestions to add generics where code generation exists
- Refactoring ideas for interface design
- Performance optimizations without measured impact
## When to Load References
- Reviewing error return patterns → error-handling.md
- Reviewing goroutines, channels, or sync types → concurrency.md
- Reviewing type definitions, interfaces, or generics → interfaces.md
- General review (resources, naming, init, performance) → common-mistakes.md
## Valid Patterns (Do NOT Flag)
These are acceptable Go patterns — reporting them wastes developer time:
- **`_ = err` with reason comment** — Intentionally ignored errors with explanation
- **Empty interface / `any`** — For truly generic code or interop with untyped APIs
- **Naked returns in short functions** — Acceptable in functions < 5 lines with named returns
- **Channel without close** — When consumer stops via context cancellation, not channel close
- **Mutex protecting struct fields** — Even if accessed only via methods, this is correct encapsulation
- **`//nolint` directives with reason** — Acceptable when accompanied by explanation
- **Defer in loop** — When function scope cleanup is intentional (e.g., processing files in batches)
- **Functional options pattern** — `type Option func(*T)` with `With*` constructors is idiomatic
- **`sync.Pool` for hot paths** — Acceptable for reducing allocation pressure in performance-critical code
- **`context.Background()` in main/tests** — Valid root context for top-level calls
- **`select` with `default`** — Non-blocking channel operation, intentional pattern
- **Short variable names in small scope** — `i`, `err`, `ctx`, `ok` are idiomatic Go
## Context-Sensitive Rules
Only flag these issues when the specific conditions apply:
| Issue | Flag ONLY IF |
|-------|--------------|
| Missing error check | Error return is actionable (can retry, log, or propagate) |
| Goroutine leak | No context cancellation path exists for the goroutine |
| Missing defer | Resource isn't explicitly closed before next acquisition or return |
| Interface pollution | Interface has > 1 method AND only one consumer exists |
| Loop variable capture | `go.mod` specifies Go < 1.22 |
| Missing slog | `go.mod` specifies Go >= 1.21 AND code uses `log` package for structured output |
## Before Submitting Findings
Satisfy **step 4** in [Review Workflow](#review-workflow): load [review-verification-protocol](../review-verification-protocol/SKILL.md) and complete its pre-report checks for each issue.
FILE:references/common-mistakes.md
# Common Mistakes
## Resource Leaks
### 1. Missing defer for Close
Resources leaked on early return. The `defer` should come immediately after the error check for the open/create call.
```go
// BAD
func readFile(path string) ([]byte, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
data, err := io.ReadAll(f)
if err != nil {
return nil, err // file never closed!
}
f.Close()
return data, nil
}
// GOOD - defer immediately
func readFile(path string) ([]byte, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
```
### 2. Defer in Loop
`defer` runs at function exit, not loop iteration exit. In a loop, resources accumulate until the function returns.
```go
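// BAD - files stay open until function returns
for _, path := range paths {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	process(f)
}
// GOOD - wrap in closure for per-iteration cleanup
for _, path := range paths {
	err := func() error {
		f, err := os.Open(path)
		if err != nil {
			return err
		}
		defer f.Close() // runs when the closure returns, once per iteration
		process(f)
		return nil
	}()
	if err != nil {
		return err
	}
}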
```
### 3. HTTP Response Body Not Closed
Every `http.Client` call that returns a non-nil response has a body that must be closed, even if you don't read it. Failing to close it leaks the underlying TCP connection.
```go
// BAD
resp, err := http.Get(url)
if err != nil {
return err
}
data, _ := io.ReadAll(resp.Body)
// GOOD
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
```
## Naming and Style
### 4. Stuttering Names
Package names are part of the identifier at the call site. Repeating the package name in the type or function name creates redundancy.
```go
// BAD
package user
type UserService struct { ... } // user.UserService
// GOOD
package user
type Service struct { ... } // user.Service
```
### 5. Missing Doc Comments on Exports
Exported names without doc comments show up in `godoc`/`pkgsite` with no explanation. The comment should be a full sentence that starts with the name being documented.
```go
// BAD
func NewServer(addr string) *Server { ... }
// GOOD
// NewServer creates a new HTTP server listening on addr.
func NewServer(addr string) *Server { ... }
```
### 6. Naked Returns in Long Functions
Naked returns (a bare `return` that relies on named results) are convenient in short functions, but in longer functions they obscure what's being returned. The threshold is roughly 5 lines; beyond that, return values explicitly.
```go
// BAD
func process(data []byte) (result string, err error) {
// 50 lines of code...
return // what's being returned?
}
// GOOD - explicit returns
func process(data []byte) (string, error) {
// 50 lines of code...
return processedString, nil
}
```
## Initialization
### 7. Init Function Overuse
`init()` functions run before `main()`, create hidden dependencies, make testing harder, and can cause subtle ordering issues when multiple packages have init functions.
```go
// BAD - global state via init
var db *sql.DB
func init() {
var err error
db, err = sql.Open("postgres", os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatal(err)
}
}
// GOOD - explicit initialization
type App struct {
db *sql.DB
}
func NewApp(dbURL string) (*App, error) {
db, err := sql.Open("postgres", dbURL)
if err != nil {
return nil, fmt.Errorf("opening db: %w", err)
}
return &App{db: db}, nil
}
```
### 8. Global Mutable State
Package-level mutable variables create race conditions in concurrent code and make testing unreliable because tests share state.
```go
// BAD
var config Config
func GetConfig() Config {
return config
}
// GOOD - dependency injection
type Server struct {
config Config
}
func NewServer(cfg Config) *Server {
return &Server{config: cfg}
}
```
## Structured Logging (Go 1.21+)
### 9. Using `log` Instead of `slog`
The `log/slog` package (Go 1.21+) provides structured, leveled logging that's far more useful in production than unstructured `log.Println` output.
```go
// OLD - unstructured, hard to parse
log.Printf("failed to load user %d: %v", userID, err)
// MODERN - structured, machine-parseable
slog.Error("failed to load user",
"user_id", userID,
"error", err,
)
// Derived logger: slog.With attaches common attributes to every record
logger := slog.With("service", "auth")
logger.Info("user logged in",
"user_id", userID,
"ip", req.RemoteAddr,
)
```
Key `slog` patterns:
- Use `slog.With()` to add common attributes to a logger
- Pass `*slog.Logger` as a dependency, don't use the global default in libraries
- Implement `slog.LogValuer` for custom types that appear frequently in logs
- Use `slog.Group()` to namespace related attributes
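The four patterns above can be combined in one small program. This is a minimal sketch, assuming Go 1.21+; `User`, `Service`, `NewService`, `render`, and `has` are hypothetical names invented for the illustration, and the in-memory text handler is only there so the output can be inspected.

```go
package main

import (
	"bytes"
	"fmt"
	"log/slog"
	"strings"
)

// User is a hypothetical domain type used only in this sketch.
type User struct {
	ID    int
	Email string
}

// LogValue implements slog.LogValuer so that only the ID is logged, never the email.
func (u User) LogValue() slog.Value {
	return slog.GroupValue(slog.Int("id", u.ID))
}

// Service receives its logger as a dependency instead of using slog.Default().
type Service struct {
	logger *slog.Logger
}

// NewService derives a logger whose records all carry service=auth.
func NewService(base *slog.Logger) *Service {
	return &Service{logger: base.With("service", "auth")}
}

// Login logs one record; slog.Group namespaces the request attributes.
func (s *Service) Login(u User, method string) {
	s.logger.Info("user logged in",
		"user", u, // resolved through LogValue
		slog.Group("request", slog.String("method", method)),
	)
}

// render runs Login against an in-memory text handler and returns the output.
func render() string {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	NewService(logger).Login(User{ID: 42, Email: "a@example.com"}, "POST")
	return buf.String()
}

// has reports whether the rendered log line contains the given fragment.
func has(out, fragment string) bool {
	return strings.Contains(out, fragment)
}

func main() {
	out := render()
	fmt.Print(out)
	fmt.Println(has(out, "user.id=42"), has(out, "a@example.com")) // true false
}
```

The text handler flattens groups with a dot, which is why the attributes appear as `user.id` and `request.method`. A JSON handler would nest them as objects instead.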
## Performance
### 10. String Concatenation in Loop
String concatenation with `+` in a loop allocates a new string and copies the accumulated contents on every iteration, so building a string of length n copies O(n^2) bytes in total.
```go
// BAD
var result string
for _, s := range items {
result += s + ", "
}
// GOOD
var b strings.Builder
for _, s := range items {
b.WriteString(s)
b.WriteString(", ")
}
result := b.String()
```
### 11. Slice Preallocation
When you know the final size, preallocate to avoid repeated backing array copies as the slice grows.
```go
// BAD - grows dynamically
var results []Result
for _, item := range items {
results = append(results, process(item))
}
// GOOD - preallocate known size
results := make([]Result, 0, len(items))
for _, item := range items {
results = append(results, process(item))
}
```
### 12. Range Over Integer (Go 1.22+)
Go 1.22 added `range` over integers, replacing the classic C-style for loop for simple counting:
```go
// OLD
for i := 0; i < n; i++ {
process(i)
}
// MODERN (Go 1.22+)
for i := range n {
process(i)
}
```
## Sync and Performance
### 13. sync.Pool Misuse
Objects returned to a `sync.Pool` must be reset first, otherwise the next consumer gets stale data.
```go
// BAD - not resetting before Put
buf := bufPool.Get().(*bytes.Buffer)
buf.WriteString("data")
bufPool.Put(buf) // still has "data"!
// GOOD - reset before returning to pool
buf := bufPool.Get().(*bytes.Buffer)
defer func() {
buf.Reset()
bufPool.Put(buf)
}()
buf.WriteString("data")
```
### 14. Functional Options
Constructors with many parameters are hard to read and painful to extend. The functional options pattern provides a clean API with sensible defaults.
```go
// BAD - parameter bloat
func NewServer(addr string, timeout time.Duration, logger *slog.Logger, maxConns int) *Server
// GOOD - functional options
type Option func(*Server)
func WithTimeout(d time.Duration) Option {
return func(s *Server) { s.timeout = d }
}
func NewServer(addr string, opts ...Option) *Server {
s := &Server{addr: addr, timeout: 30 * time.Second}
for _, opt := range opts {
opt(s)
}
return s
}
```
## Testing
### 15. Table-Driven Tests Missing
Table-driven tests reduce repetition and make it easy to add new cases.
```go
// BAD
func TestAdd(t *testing.T) {
if Add(1, 2) != 3 {
t.Error("1+2 should be 3")
}
if Add(0, 0) != 0 {
t.Error("0+0 should be 0")
}
}
// GOOD
func TestAdd(t *testing.T) {
tests := []struct {
a, b, want int
}{
{1, 2, 3},
{0, 0, 0},
{-1, 1, 0},
}
for _, tt := range tests {
got := Add(tt.a, tt.b)
if got != tt.want {
t.Errorf("Add(%d, %d) = %d, want %d", tt.a, tt.b, got, tt.want)
}
}
}
```
## Review Questions
1. Is `defer Close()` called immediately after opening resources?
2. Are HTTP response bodies always closed?
3. Are package-level names not stuttering with package name?
4. Do exported symbols have doc comments?
5. Is mutable global state avoided?
6. Are slices preallocated when size is known?
7. Is `slog` used instead of `log` for structured output (Go 1.21+)?
FILE:references/concurrency.md
# Concurrency
## Critical Anti-Patterns
### 1. Goroutine Leak
Goroutines that block forever consume memory and can accumulate over time, eventually exhausting resources.
```go
// BAD - no way to stop the goroutine
func startWorker() {
go func() {
for {
doWork()
}
}()
}
// GOOD - context cancellation
func startWorker(ctx context.Context) {
go func() {
for {
select {
case <-ctx.Done():
return
default:
doWork()
}
}
}()
}
```
### 2. Unbounded Channel Send
If the receiver dies or falls behind, the sender blocks forever. Always provide an escape hatch via context.
```go
// BAD - blocks if nobody reads
ch <- result
// GOOD - respect context
select {
case ch <- result:
case <-ctx.Done():
return ctx.Err()
}
```
### 3. Closing Channel Multiple Times
Closing a closed channel panics at runtime. The rule: only the sender closes the channel, and only once.
```go
// BAD - potential double close
close(ch)
close(ch) // panic!
// GOOD - only sender closes, once
func produce(ch chan<- int) {
defer close(ch)
for i := 0; i < 10; i++ {
ch <- i
}
}
```
### 4. Race Condition on Shared State
Concurrent reads and writes to maps, slices, or structs without synchronization cause data corruption and crashes.
```go
// BAD - concurrent map access
var cache = make(map[string]int)
func Get(key string) int {
return cache[key] // race!
}
func Set(key string, val int) {
cache[key] = val // race!
}
// GOOD - mutex protection
var (
cache = make(map[string]int)
cacheMu sync.RWMutex
)
func Get(key string) int {
cacheMu.RLock()
defer cacheMu.RUnlock()
return cache[key]
}
func Set(key string, val int) {
cacheMu.Lock()
defer cacheMu.Unlock()
cache[key] = val
}
// ALTERNATIVE - sync.Map, best for read-mostly caches or goroutines working on disjoint keys
var cache sync.Map
func Get(key string) (int, bool) {
v, ok := cache.Load(key)
if !ok {
return 0, false
}
return v.(int), true
}
```
### 5. Missing WaitGroup
Without synchronization, the calling function may return before spawned goroutines finish their work.
```go
// BAD - may exit before done
for _, item := range items {
go process(item)
}
return // goroutines may not finish
// GOOD
var wg sync.WaitGroup
for _, item := range items {
wg.Add(1)
go func(item Item) {
defer wg.Done()
process(item)
}(item)
}
wg.Wait()
```
### 6. Loop Variable Capture (Pre-Go 1.22)
**Go 1.22+ fixed this** — each iteration gets its own variable. Only flag in codebases with `go.mod` specifying Go < 1.22.
```go
// ISSUE in Go < 1.22 - all goroutines see the last item
for _, item := range items {
go func() {
process(item) // captures loop variable
}()
}
// FIX for Go < 1.22 - capture in closure parameter
for _, item := range items {
go func(item Item) {
process(item)
}(item)
}
// Go 1.22+ - this is fine, each iteration has its own variable
for _, item := range items {
go func() {
process(item) // safe
}()
}
```
### 7. Context Not Propagated
When context isn't passed to downstream calls, cancellation signals don't reach them. This means timeouts and cancellation from the caller have no effect.
```go
// BAD
func Handler(ctx context.Context) error {
	return doWork() // ignores ctx: the caller's deadline never reaches doWork
}
// GOOD
func Handler(ctx context.Context) error {
	if err := doWork(ctx); err != nil {
		return fmt.Errorf("doing work: %w", err)
	}
	return nil
}
```
## sync.OnceValue and sync.OnceFunc (Go 1.21+)
These replace the common `sync.Once` + package-level variable pattern with a cleaner API:
```go
// OLD PATTERN
var (
dbOnce sync.Once
db *sql.DB
)
func getDB() *sql.DB {
dbOnce.Do(func() {
db, _ = sql.Open("postgres", os.Getenv("DATABASE_URL"))
})
return db
}
// NEW PATTERN (Go 1.21+) - type-safe, no separate Once and result variables
var getDB = sync.OnceValue(func() *sql.DB {
db, _ := sql.Open("postgres", os.Getenv("DATABASE_URL"))
return db
})
// With error handling
var getDB = sync.OnceValues(func() (*sql.DB, error) {
return sql.Open("postgres", os.Getenv("DATABASE_URL"))
})
```
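One consequence worth knowing when reviewing `sync.OnceValues`: the error is cached together with the value, so a failed initialization is not retried. A minimal sketch, assuming Go 1.21+; `loadConfig` and `initCalls` are hypothetical names used only to make the behavior visible.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// initCalls counts how often the initializer body actually runs.
var initCalls int

// loadConfig is a hypothetical initializer that always fails.
// sync.OnceValues runs it once and caches both results.
var loadConfig = sync.OnceValues(func() (string, error) {
	initCalls++
	return "", errors.New("config unavailable")
})

func main() {
	_, err1 := loadConfig()
	_, err2 := loadConfig()
	// The body ran once and the same error value is returned both times.
	fmt.Println(initCalls, err1 == err2) // 1 true
}
```

If the initializer can fail transiently (for example a network dial), a reviewer may reasonably ask for explicit initialization with retry at startup rather than a cached one-shot.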
## Worker Pool Pattern
```go
func processItems(ctx context.Context, items []Item) error {
	const workers = 5
	jobs := make(chan Item)
	errs := make(chan error, 1) // holds only the first error
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for item := range jobs {
				if err := process(ctx, item); err != nil {
					select {
					case errs <- err:
					default: // an earlier error is already recorded
					}
					return
				}
			}
		}()
	}
	go func() {
		wg.Wait()
		close(errs) // lets the final receive return nil when no worker failed
	}()
	for _, item := range items {
		select {
		case jobs <- item:
		case err := <-errs:
			close(jobs) // release workers still ranging over jobs
			return err
		case <-ctx.Done():
			close(jobs) // without this, idle workers would leak
			return ctx.Err()
		}
	}
	close(jobs)
	return <-errs
}
```
## errgroup Pattern
The `golang.org/x/sync/errgroup` package simplifies the worker pool pattern with built-in context cancellation:
```go
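func processItems(ctx context.Context, items []Item) error {
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(5) // at most 5 goroutines at once; g.Go blocks while the limit is reached
	for _, item := range items {
		// Capturing item is safe with Go 1.22+ loop semantics;
		// on older versions add `item := item` before g.Go.
		g.Go(func() error {
			return process(ctx, item)
		})
	}
	return g.Wait()
}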
```
## Review Questions
1. Are all goroutines stoppable via context?
2. Are channels always closed by the sender?
3. Is shared state protected by mutex or sync types?
4. Are WaitGroups used to wait for goroutine completion?
5. Is context passed through the call chain?
6. Is loop variable capture handled correctly for the target Go version?
7. Are `sync.OnceValue`/`sync.OnceFunc` used instead of `sync.Once` + variable (Go 1.21+)?
FILE:references/error-handling.md
# Error Handling
## Critical Anti-Patterns
### 1. Ignoring Errors
Silent failures are impossible to debug.
```go
// BAD
file, _ := os.Open("config.json")
data, _ := io.ReadAll(file)
// GOOD
file, err := os.Open("config.json")
if err != nil {
return fmt.Errorf("opening config: %w", err)
}
defer file.Close()
```
### 2. Unwrapped Errors
Loses context for debugging. When an error bubbles up through multiple layers, each layer should add context about what it was trying to do.
```go
// BAD - raw error
if err != nil {
return err
}
// GOOD - wrapped with context
if err != nil {
return fmt.Errorf("loading user %d: %w", userID, err)
}
```
### 3. String Errors Instead of Wrapping
Formatting the error with `%s`, `%v`, or `.Error()` flattens it to text and breaks the error chain: callers can no longer use `errors.Is` or `errors.As` to inspect the underlying cause.
```go
// BAD - breaks error inspection
return fmt.Errorf("failed: %s", err.Error())
return fmt.Errorf("failed: %v", err)
// GOOD - preserves error chain
return fmt.Errorf("failed: %w", err)
```
### 4. Panic for Recoverable Errors
Panics crash the program and bypass normal error handling. Reserve them for truly unrecoverable situations (programmer bugs, violated invariants), not for expected failures like I/O errors.
```go
// BAD
func GetConfig(path string) Config {
data, err := os.ReadFile(path)
if err != nil {
panic(err)
}
...
}
// GOOD
func GetConfig(path string) (Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return Config{}, fmt.Errorf("reading config: %w", err)
}
...
}
```
### 5. Checking Error String Instead of Type
Error messages can change between releases. Comparing against sentinel values or error types with `errors.Is`/`errors.As` is stable.
```go
// BAD
if err.Error() == "file not found" {
...
}
// GOOD
if errors.Is(err, os.ErrNotExist) {
...
}
// For custom errors
var ErrNotFound = errors.New("not found")
if errors.Is(err, ErrNotFound) {
...
}
```
### 6. Returning Error and Valid Value
Callers expect zero values when errors are returned. Returning a meaningful value alongside an error creates ambiguity about whether the value is usable.
```go
// BAD - -1 is a valid integer, confuses callers
func Parse(s string) (int, error) {
if s == "" {
return -1, errors.New("empty string")
}
...
}
// GOOD - zero value on error
func Parse(s string) (int, error) {
if s == "" {
return 0, errors.New("empty string")
}
...
}
```
## Multi-Error Aggregation (Go 1.20+)
When a function encounters multiple independent errors (cleanup, batch processing, parallel operations), combine them with `errors.Join` instead of dropping all but one.
```go
// BAD - loses the first error
func cleanup(db *sql.DB, f *os.File) error {
err := db.Close()
err = f.Close() // overwrites db error
return err
}
// GOOD - preserves both errors
func cleanup(db *sql.DB, f *os.File) error {
return errors.Join(db.Close(), f.Close())
}
```
`errors.Join` returns `nil` when all errors are `nil`, and the joined error supports `errors.Is`/`errors.As` for each constituent error:
```go
err := errors.Join(ErrNotFound, ErrTimeout)
errors.Is(err, ErrNotFound) // true
errors.Is(err, ErrTimeout) // true
```
This is especially useful in defer chains:
```go
func processFile(path string) (retErr error) {
	f, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("opening %s: %w", path, err)
	}
	defer func() {
		// Join keeps a processing error and a Close error side by side.
		retErr = errors.Join(retErr, f.Close())
	}()
	// ... process file
	return nil
}
```
## Sentinel Errors Pattern
```go
// Define at package level
var (
ErrNotFound = errors.New("not found")
ErrUnauthorized = errors.New("unauthorized")
)
// Usage
func GetUser(id int) (*User, error) {
user := db.Find(id)
if user == nil {
return nil, ErrNotFound
}
return user, nil
}
// Caller checks
if errors.Is(err, ErrNotFound) {
http.Error(w, "User not found", http.StatusNotFound)
}
```
## Custom Error Types
When you need to carry structured data with an error, implement the `error` interface:
```go
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return fmt.Sprintf("validation failed on %s: %s", e.Field, e.Message)
}
// Caller extracts structured data
var ve *ValidationError
if errors.As(err, &ve) {
log.Printf("field %s: %s", ve.Field, ve.Message)
}
```
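`errors.As` also finds the custom type after intermediate layers have wrapped it with `%w`. The sketch below assumes two hypothetical layers, `validate` and `save`, plus a helper `failedField` that extracts the structured data.

```go
package main

import (
	"errors"
	"fmt"
)

type ValidationError struct {
	Field   string
	Message string
}

func (e *ValidationError) Error() string {
	return fmt.Sprintf("validation failed on %s: %s", e.Field, e.Message)
}

// validate is a hypothetical inner layer that produces the typed error.
func validate(email string) error {
	if email == "" {
		return &ValidationError{Field: "email", Message: "must not be empty"}
	}
	return nil
}

// save is a hypothetical outer layer that adds context with %w.
func save(email string) error {
	if err := validate(email); err != nil {
		return fmt.Errorf("saving user: %w", err)
	}
	return nil
}

// failedField returns the offending field, or "" if err carries no ValidationError.
func failedField(err error) string {
	var ve *ValidationError
	if errors.As(err, &ve) {
		return ve.Field
	}
	return ""
}

func main() {
	fmt.Println(failedField(save("")))                    // email
	fmt.Println(failedField(save("a@example.com")) == "") // true
}
```

Note that the target passed to `errors.As` is a pointer to `*ValidationError`, matching the pointer receiver on `Error`.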
## Review Questions
1. Are all error returns checked (no `_`)?
2. Are errors wrapped with context using `%w`?
3. Are sentinel errors used for expected error conditions?
4. Does the code use `errors.Is/As` instead of string matching?
5. Does it return zero values alongside errors?
6. Are multiple independent errors aggregated with `errors.Join`?
FILE:references/interfaces.md
# Interfaces and Types
## Critical Anti-Patterns
### 1. Premature Interface Definition
Interfaces should be defined where they're consumed, not where the implementation lives. Defining them in the producer package couples the abstraction to a specific implementation.
```go
// BAD - interface in producer package
package storage
type UserRepository interface {
Get(id int) (*User, error)
Save(user *User) error
}
type PostgresUserRepository struct { ... }
// GOOD - interface in consumer package
package service
type UserGetter interface {
Get(id int) (*User, error)
}
func NewUserService(users UserGetter) *UserService {
return &UserService{users: users}
}
```
### 2. Interface Pollution (Too Many Methods)
Fat interfaces are hard to implement, hard to mock, and force consumers to depend on methods they don't use.
```go
// BAD - fat interface
type UserStore interface {
Get(id int) (*User, error)
GetAll() ([]*User, error)
Save(user *User) error
Delete(id int) error
Search(query string) ([]*User, error)
Count() (int, error)
}
// GOOD - focused interfaces composed as needed
type UserGetter interface {
Get(id int) (*User, error)
}
type UserSaver interface {
Save(user *User) error
}
type UserStore interface {
UserGetter
UserSaver
}
```
### 3. Wrong Interface Names
Go convention: single-method interfaces are named after the method with an `-er` suffix.
```go
// BAD
type IUserService interface { ... } // Java-style prefix
type UserServiceInterface interface { ... } // redundant suffix
type UserManager interface { ... } // vague noun
// GOOD - verb forms ending in -er
type UserReader interface {
ReadUser(id int) (*User, error)
}
type UserWriter interface {
WriteUser(user *User) error
}
```
### 4. Returning Interface Instead of Concrete Type
Returning interfaces from constructors hides information from callers and prevents them from accessing implementation-specific methods. Accept interfaces, return structs.
```go
// BAD - returns interface
func NewServer(addr string) Server {
return &httpServer{addr: addr}
}
// GOOD - returns concrete type
func NewServer(addr string) *HTTPServer {
return &HTTPServer{addr: addr}
}
```
### 5. Interface for Single Implementation
An interface with only one implementation adds indirection without benefit. Introduce interfaces when you actually need them (testing, multiple implementations, package boundary decoupling).
```go
// BAD - interface with only one implementation and no tests mocking it
type ConfigLoader interface {
Load() (*Config, error)
}
type fileConfigLoader struct { ... }
// GOOD - just use the concrete type until you need the abstraction
type ConfigLoader struct { ... }
func (c *ConfigLoader) Load() (*Config, error) { ... }
```
## Generics (Go 1.18+)
### Prefer `any` over `interface{}`
`any` is a predeclared identifier, introduced in Go 1.18 as an alias for `interface{}`. It's clearer and more idiomatic in modern Go code.
```go
// OLD
func Process(data interface{}) interface{} { ... }
// MODERN
func Process(data any) any { ... }
```
### Use Type Constraints Instead of `any`
When you know the set of types you need, use constraints to preserve type safety. `any` as a constraint permits no operators on the type parameter (no `<`, `+`, or `==`), so the function body can do little beyond storing and passing values.
```go
// BAD - any constraint means no useful operations
func Max[T any](a, b T) T {
// Can't compare a and b!
}
// GOOD - constrained to ordered types (cmp.Ordered, Go 1.21+)
func Max[T cmp.Ordered](a, b T) T {
if a > b {
return a
}
return b
}
```
### Common Generic Anti-Patterns
```go
// BAD - generic function that only works with one type
func ParseUserID[T ~string](s T) (int, error) {
return strconv.Atoi(string(s))
}
// Just use string directly
// BAD - over-genericized struct
type Cache[K comparable, V any] struct { ... }
// Only used as Cache[string, *User] throughout the codebase
// Generics add value when there are multiple instantiations
// GOOD - generics for truly reusable code
func Map[T, U any](slice []T, fn func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = fn(v)
}
return result
}
```
### Type Constraints with `~` (Underlying Types)
The `~` prefix matches types with the same underlying type, which is important for custom types:
```go
type UserID int64
// Without ~: only accepts int64, not UserID
func Format[T int64](id T) string { ... }
// With ~: accepts int64 AND UserID
func Format[T ~int64](id T) string { ... }
```
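A runnable version of the `~` case, as a minimal sketch: `FormatID` is a hypothetical function and the `id-` prefix is arbitrary.

```go
package main

import (
	"fmt"
	"strconv"
)

type UserID int64

// FormatID accepts int64 and any named type whose underlying type is int64.
func FormatID[T ~int64](id T) string {
	return "id-" + strconv.FormatInt(int64(id), 10)
}

func main() {
	fmt.Println(FormatID(UserID(42))) // id-42
	fmt.Println(FormatID(int64(7)))   // id-7
}
```

With the constraint written as `int64` (no `~`), the first call would fail to compile because `UserID` does not satisfy it.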
## Accept Interfaces, Return Structs
```go
// Function accepts interface (flexible)
func WriteData(w io.Writer, data []byte) error {
_, err := w.Write(data)
return err
}
// Function returns concrete type (explicit)
func NewBuffer() *bytes.Buffer {
return &bytes.Buffer{}
}
// Usage
buf := NewBuffer()
WriteData(buf, []byte("hello")) // Buffer implements io.Writer
```
## Standard Library Interfaces to Use
Prefer these over custom interfaces when your use case matches:
| Interface | Package | Use When |
|-----------|---------|----------|
| `io.Reader` | io | Anything that provides bytes |
| `io.Writer` | io | Anything that accepts bytes |
| `io.Closer` | io | Anything that releases resources |
| `fmt.Stringer` | fmt | Custom string representation |
| `error` | builtin | Any error condition |
| `sort.Interface` | sort | Custom sort ordering (pre-generics; prefer `slices.SortFunc` in Go 1.21+) |
| `encoding.TextMarshaler` | encoding | Custom text serialization |
| `slog.LogValuer` | log/slog | Custom structured log values (Go 1.21+) |
## Review Questions
1. Are interfaces defined where they're used (consumer side)?
2. Are interfaces minimal (1-3 methods)?
3. Do interface names end in `-er`?
4. Are concrete types returned from constructors?
5. Is `any` used instead of `interface{}` (Go 1.18+)?
6. Are generics used where they add real value (multiple instantiations)?
7. Are type constraints specific enough (not just `any`)?
Go application architecture with net/http 1.22+ routing, project structure patterns, graceful shutdown, and dependency injection. Use when building Go web se...
---
name: go-architect
description: Go application architecture with net/http 1.22+ routing, project structure patterns, graceful shutdown, and dependency injection. Use when building Go web servers, designing project layout, or structuring application dependencies.
---
# Lead Go Architect
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Flat vs modular project layout, migration signals | [references/project-structure.md](references/project-structure.md) |
| Graceful shutdown with signal handling | [references/graceful-shutdown.md](references/graceful-shutdown.md) |
| Dependency injection patterns, testing seams | [references/dependency-injection.md](references/dependency-injection.md) |
## Core Principles
1. **Standard library first** -- Use `net/http` and the Go 1.22+ enhanced `ServeMux` for routing. Only reach for a framework (chi, echo, gin) when you have a concrete need the stdlib cannot satisfy (e.g., complex middleware chains, regex routes).
2. **Dependency injection over globals** -- Pass databases, loggers, and services through struct fields and constructors, never package-level `var`.
3. **Explicit over magic** -- No `init()` side effects, no framework auto-wiring. `main.go` is the composition root where everything is assembled visibly.
4. **Small interfaces, big structs** -- Define interfaces at the consumer, keep them narrow (1-3 methods). Concrete types carry the implementation.
## Hard gates
Use this sequence when implementing or reviewing work that claims to follow this skill. Do not skip ahead; each step has a **pass condition** you can answer with tooling or a concrete file path.
1. **Toolchain vs APIs** — If the code uses Go 1.22+ `ServeMux` features (method+path patterns like `"GET /x/{id}"`, `r.PathValue`, or `{path...}`): run `go version` and **pass** only if the reported toolchain is **go1.22+**. If the project must stay on an older Go, **pass** only by not using those APIs (use a compatible router or older patterns) and say so in the review or PR.
2. **Composition root** — **Pass** when `main.go` or `cmd/.../main.go` visibly constructs the server and injects shared dependencies (DB, logger, config). **Fail** if shared dependencies are wired in `init()` or package-level `var` instead of explicit construction in `main` (or a `run()` called from `main`).
3. **Production HTTP shutdown** — For a long-lived HTTP service, **pass** only if shutdown uses `http.Server.Shutdown` with a **bounded** context (e.g. `context.WithTimeout`) after waiting on `signal.NotifyContext` (or equivalent). Cite the file path when reporting; see [references/graceful-shutdown.md](references/graceful-shutdown.md) for the full pattern.
4. **No env/globals in handlers** — **Pass** when handlers and domain code take dependencies via structs/arguments. **Fail** if handlers read `os.Getenv` for secrets or use package-level `var` for DB/clients (loading env in `main` or a dedicated config package is fine).
## Go 1.22+ Enhanced Routing
Go 1.22 upgraded `http.ServeMux` with method-based routing and path parameters, eliminating the most common reason for third-party routers.
### Method-Based Routing and Path Parameters
```go
mux := http.NewServeMux()
mux.HandleFunc("GET /api/users", s.handleListUsers)
mux.HandleFunc("GET /api/users/{id}", s.handleGetUser)
mux.HandleFunc("POST /api/users", s.handleCreateUser)
mux.HandleFunc("DELETE /api/users/{id}", s.handleDeleteUser)
```
### Extracting Path Parameters
```go
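// handleGetUser looks up one user by the {id} path parameter and writes it as JSON.
func (s *Server) handleGetUser(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		http.Error(w, "missing id", http.StatusBadRequest)
		return
	}
	user, err := s.users.GetUser(r.Context(), id)
	if err != nil {
		s.logger.Error("getting user", "err", err, "id", id)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	// The status line is already sent at this point, so an encode failure can only be logged.
	if err := json.NewEncoder(w).Encode(user); err != nil {
		s.logger.Error("encoding user", "err", err, "id", id)
	}
}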
```
### Wildcard and Exact Match
```go
// Exact match -- {$} anchors the pattern so it serves /api/files/ only.
// Without {$}, a pattern ending in "/" matches every path under that prefix.
mux.HandleFunc("GET /api/files/{$}", s.handleListFiles)
// Wildcard to end of path -- /api/files/path/to/doc.txt
mux.HandleFunc("GET /api/files/{path...}", s.handleGetFile)
```
### Routing Precedence
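The new `ServeMux` picks the most specific matching pattern regardless of registration order, and panics at registration if two patterns overlap without either being more specific: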
- `GET /api/users/{id}` is more specific than `GET /api/users/`
- `GET /api/users/me` is more specific than `GET /api/users/{id}`
- Method routes take precedence over method-less routes
## Server Struct Pattern
The Server struct is the central dependency container for your application. It holds all shared dependencies and implements `http.Handler`.
```go
type Server struct {
db *sql.DB
logger *slog.Logger
router *http.ServeMux
}
func NewServer(db *sql.DB, logger *slog.Logger) *Server {
s := &Server{
db: db,
logger: logger,
router: http.NewServeMux(),
}
s.routes()
return s
}
func (s *Server) routes() {
s.router.HandleFunc("GET /api/users/{id}", s.handleGetUser)
s.router.HandleFunc("POST /api/users", s.handleCreateUser)
s.router.HandleFunc("GET /healthz", s.handleHealth)
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
```
### Middleware Wrapping
Apply middleware at the `http.Server` level or per-route:
```go
// Wrap entire server
httpServer := &http.Server{
Addr: ":8080",
Handler: requestLogger(s),
}
// Or per-route
s.router.Handle("GET /api/admin/", adminOnly(http.HandlerFunc(s.handleAdmin)))
```
### Middleware Signature
```go
func requestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
next.ServeHTTP(w, r)
slog.Info("request", "method", r.Method, "path", r.URL.Path, "dur", time.Since(start))
})
}
```
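
When several middleware share this signature, a small helper can compose them in a readable order. This is a sketch rather than an established API: `Middleware`, `chain`, `tag`, and `traceOrder` are hypothetical names introduced only for illustration.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// Middleware is the standard wrapper shape: take a handler, return a handler.
type Middleware func(http.Handler) http.Handler

// chain wraps h so that the first middleware listed runs first on each request.
func chain(h http.Handler, mws ...Middleware) http.Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// tag returns a middleware that records its name before calling the next handler.
func tag(name string, trace *[]string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*trace = append(*trace, name)
			next.ServeHTTP(w, r)
		})
	}
}

// traceOrder runs one request through two tagged middleware and reports the call order.
func traceOrder() string {
	var trace []string
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		trace = append(trace, "handler")
	})
	h := chain(final, tag("logging", &trace), tag("auth", &trace))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
	return strings.Join(trace, " -> ")
}

func main() {
	fmt.Println(traceOrder()) // logging -> auth -> handler
}
```

Applying the slice in reverse is what makes the listed order match the execution order; wrapping front to back would run the last middleware first.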
## Project Structure
Choose based on project size:
- **Flat structure** -- single package, all files in root. Best for CLIs, small services, < ~10 handlers. See [references/project-structure.md](references/project-structure.md).
- **Modular/domain-driven** -- `cmd/`, `internal/` with domain packages. For larger apps with multiple bounded contexts. See [references/project-structure.md](references/project-structure.md).
Start flat. Migrate when you see the signs described in the reference.
## Graceful Shutdown
Every production Go server needs graceful shutdown. The pattern uses `signal.NotifyContext` to listen for OS signals and `http.Server.Shutdown` to drain connections.
```go
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer cancel()
// ... start server in goroutine ...
<-ctx.Done()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
httpServer.Shutdown(shutdownCtx)
```
Full pattern with cleanup ordering in [references/graceful-shutdown.md](references/graceful-shutdown.md).
## When to Load References
Load **project-structure.md** when:
- Scaffolding a new Go project
- Discussing package layout or directory organization
- The project is growing and needs restructuring
Load **graceful-shutdown.md** when:
- Setting up a production HTTP server
- Implementing signal handling or clean shutdown
- Discussing deployment or container readiness
Load **dependency-injection.md** when:
- Designing how services, stores, and handlers connect
- Making code testable with interfaces
- Reviewing constructor functions or wiring logic
## Anti-Patterns
### Global database variables
```go
// BAD -- untestable, hidden dependency
var db *sql.DB
func handleGetUser(w http.ResponseWriter, r *http.Request) {
db.QueryRow(...)
}
```
Pass `db` through a Server or Service struct instead.
### Framework-first thinking
Do not start with `gin.Default()` or `echo.New()`. Start with `http.NewServeMux()`. Only introduce a framework if you hit a real limitation of the stdlib that justifies the dependency.
### God packages
A single `handlers` package with 50 files is not organization. Group by domain (`user`, `order`, `billing`), not by technical layer.
### Using init() for setup
```go
// BAD -- invisible side effects, untestable
func init() {
db, _ = sql.Open("postgres", os.Getenv("DATABASE_URL"))
}
```
All initialization belongs in `main()` or a `run()` function so it can be tested and errors can be handled.
### Reading config in business logic
```go
// BAD -- couples handler to environment
func (s *Server) handleSendEmail(w http.ResponseWriter, r *http.Request) {
apiKey := os.Getenv("SENDGRID_API_KEY") // don't do this
}
```
Inject configuration values or clients through constructors.
FILE:references/dependency-injection.md
# Dependency Injection in Go
Go does not need a DI framework. The language's interfaces, structs, and constructor functions provide everything necessary for clean dependency injection.
## Server Struct as Dependency Container
The Server struct holds all shared dependencies and exposes HTTP handlers as methods. This is the simplest form of DI in Go.
```go
type Server struct {
users *user.Service
orders *order.Service
logger *slog.Logger
router *http.ServeMux
}
func NewServer(users *user.Service, orders *order.Service, logger *slog.Logger) *Server {
s := &Server{
users: users,
orders: orders,
logger: logger,
router: http.NewServeMux(),
}
s.routes()
return s
}
```
Dependencies are explicit: you can see exactly what the server needs by looking at its struct fields and constructor signature.
## Constructor Functions
Every component provides a `New*` constructor that accepts its dependencies and returns a ready-to-use instance.
```go
// user/store.go
func NewPostgresStore(db *sql.DB) *PostgresStore {
return &PostgresStore{db: db}
}
// user/service.go
func NewService(store Store) *Service {
return &Service{store: store}
}
```
Constructors should:
- Accept only what the component actually uses
- Return a concrete type (not an interface)
- Not perform I/O (no database pings, no HTTP calls)
- Panic only if the dependency is nil and the component cannot function without it
```go
func NewService(store Store) *Service {
if store == nil {
panic("user: store is required")
}
return &Service{store: store}
}
```
## Layered Dependency Injection
`main.go` (or the `run()` function) is the composition root. It creates all dependencies in order and wires them together. No other part of the application creates its own dependencies.
```go
// cmd/server/main.go
func run(ctx context.Context) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("loading config: %w", err)
}
// Layer 1: Infrastructure
db, err := sql.Open("postgres", cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close()
// Layer 2: Stores (depend on infrastructure)
userStore := user.NewPostgresStore(db)
orderStore := order.NewPostgresStore(db)
// Layer 3: Services (depend on stores)
userService := user.NewService(userStore)
orderService := order.NewService(orderStore, userService)
// Layer 4: HTTP server (depends on services)
srv := NewServer(userService, orderService, slog.Default())
// ... start server ...
return nil
}
```
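The dependency graph is built from the bottom (infrastructure) to the top (HTTP layer). It is a directed acyclic graph rather than a strict tree: `orderService` also depends on `userService` in the same layer. Each layer depends only on the layers below it or on peers constructed earlier.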
## Interface-Based Dependencies for Testability
Define interfaces at the consumer, not the producer. Keep them small.
```go
// user/service.go
package user
// Store is defined where it is used, not where it is implemented
type Store interface {
GetByID(ctx context.Context, id string) (*User, error)
Create(ctx context.Context, u *User) error
List(ctx context.Context, limit, offset int) ([]User, error)
}
type Service struct {
store Store
}
func NewService(store Store) *Service {
return &Service{store: store}
}
```
The concrete implementation lives in a separate file or package:
```go
// user/postgres_store.go
package user
type PostgresStore struct {
db *sql.DB
}
func NewPostgresStore(db *sql.DB) *PostgresStore {
return &PostgresStore{db: db}
}
func (s *PostgresStore) GetByID(ctx context.Context, id string) (*User, error) {
row := s.db.QueryRowContext(ctx, "SELECT id, name, email FROM users WHERE id = $1", id)
var u User
if err := row.Scan(&u.ID, &u.Name, &u.Email); err != nil {
return nil, fmt.Errorf("querying user %s: %w", id, err)
}
return &u, nil
}
// ... other Store methods ...
```
### Testing with Mock Implementations
```go
// user/service_test.go
package user
type mockStore struct {
users map[string]*User
}
func (m *mockStore) GetByID(ctx context.Context, id string) (*User, error) {
u, ok := m.users[id]
if !ok {
return nil, fmt.Errorf("user not found: %s", id)
}
return u, nil
}
func (m *mockStore) Create(ctx context.Context, u *User) error {
m.users[u.ID] = u
return nil
}
func (m *mockStore) List(ctx context.Context, limit, offset int) ([]User, error) {
var result []User
for _, u := range m.users {
result = append(result, *u)
}
return result, nil
}
func TestGetUser(t *testing.T) {
store := &mockStore{
users: map[string]*User{
"1": {ID: "1", Name: "Alice", Email: "alice@example.com"},
},
}
svc := NewService(store)
u, err := svc.GetUser(context.Background(), "1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if u.Name != "Alice" {
t.Errorf("got name %q, want %q", u.Name, "Alice")
}
}
```
No mocking library needed. Go interfaces make manual test doubles straightforward.
## Configuration as a Dependency
Business logic should never read environment variables or config files directly. Configuration is loaded once in `main.go` and passed as explicit values to constructors.
```go
// config/config.go
package config
type Config struct {
DatabaseURL string `env:"DATABASE_URL,required"`
Addr string `env:"ADDR" default:":8080"`
ShutdownTimeout time.Duration `env:"SHUTDOWN_TIMEOUT" default:"10s"`
SendGrid SendGridConfig
}
type SendGridConfig struct {
APIKey string `env:"SENDGRID_API_KEY,required"`
FromAddr string `env:"SENDGRID_FROM" default:"noreply@example.com"`
}
```
Pass only what each component needs, not the entire config:
```go
// GOOD -- emailer gets only its own config
emailer := email.NewSendGridEmailer(cfg.SendGrid.APIKey, cfg.SendGrid.FromAddr)
// BAD -- emailer receives entire application config
emailer := email.NewSendGridEmailer(cfg)
```
This keeps components decoupled from the config structure and makes their requirements visible.
## Functional Options for Optional Dependencies
When a constructor has many optional parameters, use the functional options pattern:
```go
type Server struct {
db *sql.DB
logger *slog.Logger
cache Cache
metrics MetricsRecorder
}
type Option func(*Server)
func WithCache(c Cache) Option {
return func(s *Server) {
s.cache = c
}
}
func WithMetrics(m MetricsRecorder) Option {
return func(s *Server) {
s.metrics = m
}
}
func NewServer(db *sql.DB, logger *slog.Logger, opts ...Option) *Server {
s := &Server{
db: db,
logger: logger,
cache: noopCache{}, // sensible default
metrics: noopMetrics{}, // sensible default
}
for _, opt := range opts {
opt(s)
}
s.routes()
return s
}
```
Usage:
```go
// Minimal -- uses defaults for cache and metrics
srv := NewServer(db, logger)
// With optional dependencies
srv := NewServer(db, logger,
WithCache(redisCache),
WithMetrics(promMetrics),
)
```
Use functional options when:
- There are more than 3-4 optional parameters
- You want sensible defaults that can be overridden
- The constructor signature is growing unwieldy
Do not use functional options for required dependencies. Those belong as regular constructor parameters.
## Cross-Domain Dependencies
When one domain needs data from another, define an interface in the consuming package:
```go
// internal/order/service.go
package order
// UserLookup is what the order domain needs from the user domain
type UserLookup interface {
GetByID(ctx context.Context, id string) (*UserInfo, error)
}
// UserInfo contains only what orders need -- not the full user model
type UserInfo struct {
ID string
Name string
Email string
}
type Service struct {
store Store
userLookup UserLookup
}
func NewService(store Store, userLookup UserLookup) *Service {
return &Service{store: store, userLookup: userLookup}
}
```
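`user.Service` does not satisfy this interface as written: Go matches method names and signatures exactly, and `GetByID` returns `*order.UserInfo`, a type the user package should not import. Bridge the two with a thin adapter in the composition root that wraps `user.Service` and maps `user.User` to `order.UserInfo`:
```go
// cmd/server/main.go
// userLookupAdapter is a small type defined here that implements order.UserLookup
orderService := order.NewService(orderStore, userLookupAdapter{users: userService})
```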
This keeps domains decoupled. The order package never imports the user package.
## Anti-Patterns
### Global Database Variable
```go
// BAD
package db
var DB *sql.DB
func init() {
var err error
DB, err = sql.Open("postgres", os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatal(err)
}
}
```
Problems:
- Impossible to test with a different database
- Hidden dependency -- callers don't declare they need a DB
- `init()` runs during package initialization, before `main()`, so startup order is dictated by the import graph rather than by code you can read in one place
- `log.Fatal` in `init()` prevents graceful error handling
Fix: Pass `*sql.DB` through constructors.
### Reading Environment Variables in Handlers
```go
// BAD
func (s *Server) handleSendEmail(w http.ResponseWriter, r *http.Request) {
apiKey := os.Getenv("SENDGRID_API_KEY")
client := sendgrid.NewClient(apiKey)
// ...
}
```
Problems:
- Creates a new client on every request
- Cannot test without setting env vars
- Handler does infrastructure work
Fix: Inject a pre-configured email client through the Server struct.
```go
// GOOD
type Server struct {
emailer EmailSender
}
func (s *Server) handleSendEmail(w http.ResponseWriter, r *http.Request) {
err := s.emailer.Send(r.Context(), to, subject, body)
// ...
}
```
### Passing Entire Config Struct
```go
// BAD -- emailer knows about database config, server port, etc.
func NewEmailer(cfg *config.Config) *Emailer {
return &Emailer{apiKey: cfg.SendGrid.APIKey}
}
```
Problems:
- Component knows about the entire configuration shape
- Cannot tell what the emailer actually needs without reading its code
- Refactoring config structure breaks unrelated components
Fix: Pass individual values or a small, focused config struct.
```go
// GOOD
func NewEmailer(apiKey string, fromAddr string) *Emailer {
return &Emailer{apiKey: apiKey, fromAddr: fromAddr}
}
```
### Using init() for Dependency Setup
```go
// BAD
var userService *UserService
func init() {
db := connectDB()
store := NewPostgresStore(db)
userService = NewService(store)
}
```
Problems:
- Runs before `main()`, no error handling
- Global state, untestable
- Invisible side effects at import time
- Order of `init()` across packages is hard to reason about
Fix: Build the dependency graph explicitly in `main.go` or `run()`.
FILE:references/graceful-shutdown.md
# Graceful Shutdown
Every production Go HTTP server must handle shutdown gracefully: finish in-flight requests, close database connections, and flush buffers before exiting. An abrupt `os.Exit` or unhandled signal drops active requests and can corrupt data.
## Full Pattern
```go
package main
import (
"context"
"database/sql"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
func run(ctx context.Context) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer cancel()
db, err := sql.Open("postgres", os.Getenv("DATABASE_URL"))
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close()
srv := NewServer(db, slog.Default())
httpServer := &http.Server{
Addr: ":8080",
Handler: srv,
}
errCh := make(chan error, 1)
go func() {
slog.Info("server starting", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errCh <- err
}
}()
// Wait for interrupt signal or server error
select {
case <-ctx.Done():
case err := <-errCh:
return fmt.Errorf("server listen: %w", err)
}
slog.Info("shutting down gracefully")
// Give outstanding requests time to complete
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown: %w", err)
}
return nil
}
func main() {
if err := run(context.Background()); err != nil {
slog.Error("application error", "err", err)
os.Exit(1)
}
}
```
## Why `run()` Returns an Error
Separating `run()` from `main()` provides several benefits:
1. **Testability** -- You can call `run()` in tests with a cancelable context and verify behavior without starting a real process.
2. **Clean error handling** -- `run()` uses normal Go error returns instead of `log.Fatal()`, which calls `os.Exit(1)` and skips deferred cleanup.
3. **Deferred cleanup runs** -- Since `run()` returns instead of exiting, all `defer` statements (db.Close(), cancel(), etc.) execute in LIFO order.
4. **Single exit point** -- `main()` is the only place that calls `os.Exit`, making the exit path predictable.
```go
// BAD -- defers never run, no cleanup
func main() {
db, err := sql.Open(...)
if err != nil {
log.Fatal(err) // calls os.Exit(1), skips defer db.Close()
}
defer db.Close()
// ...
}
// GOOD -- all defers run, clean exit
func main() {
if err := run(context.Background()); err != nil {
slog.Error("application error", "err", err)
os.Exit(1)
}
}
```
## signal.NotifyContext vs signal.Notify
### signal.NotifyContext (preferred)
```go
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer cancel()
<-ctx.Done() // blocks until signal received
```
Benefits:
- Returns a standard `context.Context` that integrates with the rest of the application
- Cancelation propagates to all child contexts automatically
- `defer cancel()` cleans up signal registration
- Idiomatic for modern Go code
### signal.Notify (older pattern)
```go
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
<-quit // blocks until signal received
```
Use `signal.Notify` only when you need to handle the same signal multiple times or perform special signal-specific logic. For typical graceful shutdown, `signal.NotifyContext` is cleaner.
## Shutdown Timeout Configuration
The shutdown timeout bounds how long `Shutdown` waits for in-flight requests to complete before it gives up and returns.
```go
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
```
### Choosing a Timeout Value
| Scenario | Recommended Timeout |
|----------|-------------------|
| API with fast queries | 5-10 seconds |
| Long-polling / SSE | 30 seconds |
| File uploads | 60 seconds |
| WebSocket connections | 30-60 seconds |
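`Shutdown` neither waits for nor closes hijacked connections such as WebSockets, so the WebSocket row is a budget for your own drain logic (for example, hooks registered with `RegisterOnShutdown`), not something `Shutdown` enforces.
Make the timeout configurable: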
```go
type Config struct {
ShutdownTimeout time.Duration `env:"SHUTDOWN_TIMEOUT" default:"10s"`
}
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout)
```
### What Happens When the Timeout Expires
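If the context expires before draining completes, `httpServer.Shutdown(shutdownCtx)` returns the context's error (`context.DeadlineExceeded`). At that point:
- The server has already stopped accepting new connections (listeners are closed as soon as `Shutdown` is called)
- Connections with in-flight requests are still open; `Shutdown` does not close them. Call `httpServer.Close()` to force them closed, otherwise they are dropped when the process exits
- Clients whose requests are cut off this way see connection-reset or unexpected-EOF errors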
## Cleanup Ordering
Resources must be cleaned up in reverse order of creation. The server should stop accepting new requests before closing the resources those requests depend on.
```go
func run(ctx context.Context) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer cancel()
// 1. Open database (first resource created)
db, err := sql.Open("postgres", os.Getenv("DATABASE_URL"))
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close() // 4. Close database LAST (after server is done)
// 2. Create cache client
cache := redis.NewClient(...)
defer cache.Close() // 3. Close cache after server, before database
srv := NewServer(db, cache, slog.Default())
httpServer := &http.Server{
Addr: ":8080",
Handler: srv,
}
go func() {
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("server error", "err", err)
}
}()
<-ctx.Done()
// Shutdown server FIRST -- drains in-flight requests that use db and cache
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown: %w", err)
}
// Then defers run in LIFO order: cache.Close(), then db.Close()
return nil
}
```
The ordering is:
1. Stop accepting new connections (`Shutdown` called)
2. Wait for in-flight requests to finish (up to timeout)
3. Close cache (deferred, LIFO)
4. Close database (deferred, LIFO)
### With Background Workers
If you have background goroutines (job processors, consumers), shut them down after the HTTP server but before closing shared resources:
```go
<-ctx.Done()
// 1. Stop HTTP server
httpServer.Shutdown(shutdownCtx)
// 2. Stop background workers (they may still use db)
workerCancel()
workerWg.Wait()
// 3. Defers close db, cache, etc.
return nil
```
## Health Check Endpoint
In container orchestrators (Kubernetes, ECS), the health check should start failing before the server shuts down. This tells the load balancer to stop sending new traffic.
```go
type Server struct {
db *sql.DB
logger *slog.Logger
router *http.ServeMux
healthy atomic.Bool
}
func NewServer(db *sql.DB, logger *slog.Logger) *Server {
s := &Server{
db: db,
logger: logger,
router: http.NewServeMux(),
}
s.healthy.Store(true)
s.routes()
return s
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
if !s.healthy.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
fmt.Fprintln(w, "shutting down")
return
}
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "ok")
}
// Call before starting Shutdown
func (s *Server) SetUnhealthy() {
s.healthy.Store(false)
}
```
### Shutdown Sequence with Health Check
```go
<-ctx.Done()
slog.Info("shutting down gracefully")
// 1. Mark unhealthy -- load balancer stops sending new traffic
srv.SetUnhealthy()
// 2. Wait for load balancer to detect unhealthy status
// This depends on your health check interval (typically 5-10s)
time.Sleep(5 * time.Second)
// 3. Shut down server -- drain remaining in-flight requests
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown: %w", err)
}
```
The sleep between marking unhealthy and calling Shutdown gives the load balancer time to route traffic elsewhere. Without this, new requests may arrive at a server that is already draining.
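
The switch from healthy to unhealthy can be checked without a running server. Below is a minimal sketch with a reduced `Server` that keeps only the health flag; `probe` is a hypothetical helper, and `atomic.Bool` assumes Go 1.19 or newer.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

// Server is reduced to the health flag for this illustration.
type Server struct {
	healthy atomic.Bool
}

func NewServer() *Server {
	s := &Server{}
	s.healthy.Store(true)
	return s
}

func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	if !s.healthy.Load() {
		http.Error(w, "shutting down", http.StatusServiceUnavailable)
		return
	}
	fmt.Fprintln(w, "ok")
}

// SetUnhealthy makes subsequent health checks fail; call it before Shutdown.
func (s *Server) SetUnhealthy() { s.healthy.Store(false) }

// probe returns the status code a health checker would see right now.
func probe(s *Server) int {
	rec := httptest.NewRecorder()
	s.handleHealth(rec, httptest.NewRequest("GET", "/healthz", nil))
	return rec.Code
}

func main() {
	s := NewServer()
	fmt.Println(probe(s)) // 200
	s.SetUnhealthy()
	fmt.Println(probe(s)) // 503
}
```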
## Complete Production Template
```go
func run(ctx context.Context) error {
cfg, err := loadConfig()
if err != nil {
return fmt.Errorf("loading config: %w", err)
}
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer cancel()
db, err := sql.Open("postgres", cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close()
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("pinging db: %w", err)
}
srv := NewServer(db, slog.Default())
httpServer := &http.Server{
Addr: cfg.Addr,
Handler: srv,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 60 * time.Second,
}
errCh := make(chan error, 1)
go func() {
slog.Info("server starting", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errCh <- fmt.Errorf("server listen: %w", err)
}
}()
// Wait for signal or server error
select {
case err := <-errCh:
return err
case <-ctx.Done():
}
slog.Info("shutting down gracefully")
srv.SetUnhealthy()
time.Sleep(cfg.HealthDrainDelay)
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown: %w", err)
}
slog.Info("server stopped")
return nil
}
```
FILE:references/project-structure.md
# Go Project Structure
## Flat Structure
Best for small applications, CLIs, microservices with fewer than ~10 handlers, and projects where a single developer or small team owns the entire codebase.
```text
myapp/
├── main.go
├── server.go
├── handlers.go
├── middleware.go
├── models.go
├── store.go
├── server_test.go
├── handlers_test.go
└── go.mod
```
### When to Use
- CLI tools and small utilities
- Single-purpose microservices (one bounded context)
- Prototypes and proofs of concept
- Fewer than ~10 HTTP handlers
- One or two developers working on the codebase
### Benefits
- Zero cognitive overhead for navigation -- everything is in one place
- No circular dependency issues (single package)
- Easy to refactor -- just move functions between files
- `go test ./...` covers everything in one pass
- New contributors can understand the layout immediately
### File Responsibilities
| File | Contains |
|------|----------|
| `main.go` | `func main()`, wiring, configuration loading, `run()` function |
| `server.go` | `Server` struct, `NewServer()`, `routes()`, `ServeHTTP()` |
| `handlers.go` | All HTTP handler methods on `Server` |
| `middleware.go` | Middleware functions (`requestLogger`, `authenticate`, etc.) |
| `models.go` | Domain types, request/response structs |
| `store.go` | Database access layer (queries, store struct) |
For very small apps, `server.go` and `handlers.go` can be the same file.
### Example: Flat Server
```go
// main.go
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log/slog"
	"os"
)

func main() {
	if err := run(context.Background()); err != nil {
		slog.Error("application error", "err", err)
		os.Exit(1)
	}
}

// run is the composition root: it builds dependencies and returns errors to main.
func run(ctx context.Context) error {
	db, err := sql.Open("postgres", os.Getenv("DATABASE_URL"))
	if err != nil {
		return fmt.Errorf("opening db: %w", err)
	}
	defer db.Close()

	srv := NewServer(db, slog.Default())
	// ... start and graceful shutdown ...
	return nil
}
```
```go
// server.go
package main

import (
	"database/sql"
	"log/slog"
	"net/http"
)

// Server holds shared dependencies and implements http.Handler.
type Server struct {
	db     *sql.DB
	logger *slog.Logger
	router *http.ServeMux
}

func NewServer(db *sql.DB, logger *slog.Logger) *Server {
	s := &Server{db: db, logger: logger, router: http.NewServeMux()}
	s.routes()
	return s
}

// routes registers every handler in one place.
func (s *Server) routes() {
	s.router.HandleFunc("GET /api/items", s.handleListItems)
	s.router.HandleFunc("GET /api/items/{id}", s.handleGetItem)
	s.router.HandleFunc("POST /api/items", s.handleCreateItem)
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.router.ServeHTTP(w, r)
}
```
---
## Modular / Domain-Driven Structure
For larger applications with multiple bounded contexts, multiple teams, or significant growth expected.
```text
myapp/
├── cmd/
│   └── server/
│       └── main.go
├── internal/
│   ├── user/
│   │   ├── handler.go
│   │   ├── service.go
│   │   ├── store.go
│   │   ├── model.go
│   │   └── handler_test.go
│   ├── order/
│   │   ├── handler.go
│   │   ├── service.go
│   │   ├── store.go
│   │   ├── model.go
│   │   └── handler_test.go
│   └── platform/
│       ├── middleware/
│       │   ├── auth.go
│       │   └── logging.go
│       ├── database/
│       │   └── postgres.go
│       └── config/
│           └── config.go
├── migrations/
│   ├── 001_create_users.up.sql
│   └── 001_create_users.down.sql
├── go.mod
└── go.sum
```
### Directory Conventions
#### `cmd/`
Entry points for the application. Each subdirectory produces one binary.
```text
cmd/
├── server/
│   └── main.go          # HTTP server
├── worker/
│   └── main.go          # Background job processor
└── migrate/
    └── main.go          # Database migration tool
```
Each `main.go` is the composition root: it reads config, creates dependencies, wires them together, and starts the program. Keep `main.go` small -- delegate to a `run()` function that returns an error.
#### `internal/`
The `internal/` directory is enforced by the Go toolchain. Code inside `internal/` cannot be imported by external modules. Use it for all application-specific code.
```go
// This import is only allowed from within the same module:
import "myapp/internal/user"
```
This gives you the freedom to refactor internal packages without worrying about breaking external consumers.
#### Domain Packages (`internal/user/`, `internal/order/`)
Each domain package owns its:
- **Models** -- domain types and validation
- **Store** -- database queries, implements a store interface
- **Service** -- business logic, orchestrates store calls
- **Handler** -- HTTP handlers, request parsing, response writing
```go
// internal/user/service.go
package user
type Service struct {
store Store
}
type Store interface {
GetByID(ctx context.Context, id string) (*User, error)
Create(ctx context.Context, u *User) error
List(ctx context.Context, limit, offset int) ([]User, error)
}
func NewService(store Store) *Service {
return &Service{store: store}
}
func (s *Service) GetUser(ctx context.Context, id string) (*User, error) {
if id == "" {
return nil, fmt.Errorf("user id is required")
}
return s.store.GetByID(ctx, id)
}
```
#### `internal/platform/`
Shared infrastructure code that is not domain-specific:
- `middleware/` -- HTTP middleware (logging, auth, CORS)
- `database/` -- connection helpers, migration runners
- `config/` -- configuration loading and validation
Platform packages are imported by domain packages and `cmd/`, but never import domain packages.
### Package Design Principles
**Dependencies flow inward.** Domain packages should not import other domain packages. If `order` needs user data, it defines its own interface:
```go
// internal/order/service.go
package order
type UserLookup interface {
GetByID(ctx context.Context, id string) (*User, error)
}
type Service struct {
store Store
userLookup UserLookup
}
```
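`cmd/server/main.go` wires the user domain into the order service, either by passing `user.Service` directly when its method set matches `order.UserLookup` exactly, or through a small adapter when method names or return types differ.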
**Avoid circular dependencies.** If package A imports package B, package B cannot import package A. This is a compile error in Go. Solutions:
1. Extract shared types into a separate package (e.g., `internal/domain`)
2. Use interfaces at the consumer side
3. Merge the packages if they are tightly coupled
**Keep packages focused.** A package named `utils` or `helpers` is a code smell. If a function doesn't belong to a domain, it belongs in `platform/` with a descriptive package name.
**Export only what is needed.** Start with unexported types and functions. Export them only when another package needs access.
### Wiring in main.go
```go
// cmd/server/main.go
package main
import (
"context"
"fmt"
"net/http"
"myapp/internal/order"
"myapp/internal/platform/config"
"myapp/internal/platform/database"
"myapp/internal/user"
)
func run(ctx context.Context) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("loading config: %w", err)
}
db, err := database.Open(cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("opening db: %w", err)
}
defer db.Close()
// Build dependency graph
userStore := user.NewPostgresStore(db)
userService := user.NewService(userStore)
orderStore := order.NewPostgresStore(db)
orderService := order.NewService(orderStore, userService)
// Build server
mux := http.NewServeMux()
user.RegisterRoutes(mux, userService)
order.RegisterRoutes(mux, orderService)
// ... start HTTP server with graceful shutdown ...
return nil
}
```
### Route Registration in Domain Packages
Each domain package provides a `RegisterRoutes` function:
```go
// internal/user/handler.go
package user
func RegisterRoutes(mux *http.ServeMux, svc *Service) {
h := &handler{svc: svc}
mux.HandleFunc("GET /api/users", h.list)
mux.HandleFunc("GET /api/users/{id}", h.get)
mux.HandleFunc("POST /api/users", h.create)
}
type handler struct {
svc *Service
}
func (h *handler) get(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
u, err := h.svc.GetUser(r.Context(), id)
// ...
}
```
---
## Migration Signals: Flat to Modular
Move from flat to modular when you notice:
1. **File length** -- `handlers.go` exceeds ~500 lines or contains unrelated handlers
2. **Naming collisions** -- You prefix functions like `userGetHandler`, `orderGetHandler` to avoid confusion
3. **Multiple developers** -- Merge conflicts in shared files become frequent
4. **Distinct domains** -- The application clearly has separate bounded contexts (users, orders, billing)
5. **Separate deployment needs** -- You want a CLI tool and an HTTP server from the same codebase (`cmd/server/`, `cmd/cli/`)
6. **Test isolation** -- You want to test one domain without loading all the others
### How to Migrate
1. Create `cmd/server/main.go` and move wiring code there
2. Create `internal/` and make one domain package for the most independent domain
3. Move its models, handlers, store, and tests into the new package
4. Update imports in `main.go`
5. Repeat for each domain
6. Extract shared infrastructure into `internal/platform/`
Migrate incrementally. Do not restructure everything in one commit.
Reviews BubbleTea TUI code for proper Elm architecture, model/update/view patterns, and Lipgloss styling. Use when reviewing terminal UI code using charmbrac...
---
name: bubbletea-code-review
description: Reviews BubbleTea TUI code for proper Elm architecture, model/update/view patterns, and Lipgloss styling. Use when reviewing terminal UI code using charmbracelet/bubbletea.
---
# BubbleTea Code Review
## Hard gates (sequence)
Advance only when each **pass condition** is objectively true (reduces false positives on `tea.Cmd` and unsubstantiated blocking claims):
| Gate | Pass condition |
|------|----------------|
| **G1 — Anti–false-positive** | You skimmed **NOT Issues** below **or** read [references/elm-architecture.md](references/elm-architecture.md) **before** recording a finding about `tea.Cmd` returns, value receivers on `Update`, or nested child `Update`. |
| **G2 — Evidence for blocking / suspicious I/O** | Each Critical/Major finding names **file path + line** (or a short quoted snippet) showing the blocking call, `huh.Form.Run` in the wrong place, or other asserted anti-pattern—not a hypothetical. |
| **G3 — Verification** | Before publishing review output, you applied **beagle-go:review-verification-protocol** to each proposed finding. |
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Elm architecture, tea.Cmd as data | [references/elm-architecture.md](references/elm-architecture.md) |
| Model state, message handling | [references/model-update.md](references/model-update.md) |
| View rendering, Lipgloss styling | [references/view-styling.md](references/view-styling.md) |
| Component composition, Huh forms | [references/composition.md](references/composition.md) |
| Bubbles components (list, table, etc.) | [references/bubbles-components.md](references/bubbles-components.md) |
## CRITICAL: Avoid False Positives
**Read [elm-architecture.md](references/elm-architecture.md) first!** The most common review mistake is flagging correct patterns as bugs.
### NOT Issues (Do NOT Flag These)
| Pattern | Why It's Correct |
|---------|------------------|
| `return m, m.loadData()` | `tea.Cmd` is returned immediately; the runtime executes it asynchronously |
| Value receiver on `Update()` | Standard BubbleTea pattern; the model is returned by value |
| Nested `m.child, cmd = m.child.Update(msg)` | Normal component composition |
| Helper functions returning `tea.Cmd` | Creates a command descriptor; no I/O happens in Update |
| `tea.Batch(cmd1, cmd2)` | The runtime executes the commands concurrently |
### ACTUAL Issues (DO Flag These)
| Pattern | Why It's Wrong |
|---------|----------------|
| `os.ReadFile()` in Update | Blocks UI thread |
| `http.Get()` in Update | Network I/O blocks |
| `time.Sleep()` in Update | Freezes UI |
| `<-channel` in Update (blocking) | May block indefinitely |
| `huh.Form.Run()` in Update | Blocking call |
## Review Checklist
### Architecture
- [ ] **No blocking I/O in Update()** (file, network, sleep)
- [ ] Helper functions returning `tea.Cmd` are NOT flagged as blocking
- [ ] Commands used for all async operations
### Model & Update
- [ ] Model is treated as immutable (Update returns an updated copy instead of mutating shared state)
- [ ] Init returns proper initial command (or nil)
- [ ] Update handles all expected message types
- [ ] WindowSizeMsg handled for responsive layout
- [ ] tea.Batch used for multiple commands
- [ ] tea.Quit used correctly for exit
### View & Styling
- [ ] View is a pure function (no side effects)
- [ ] Lipgloss styles defined once, not in View
- [ ] Key bindings use key.Matches with help.KeyMap
### Components
- [ ] Sub-component updates propagated correctly
- [ ] Bubbles components initialized with dimensions
- [ ] Huh forms embedded via Update loop (not Run())
## Critical Patterns
### Model Must Be Immutable
```go
// BAD - mutates model
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.items = append(m.items, newItem) // mutation!
return m, nil
}
// GOOD - returns new model
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
newItems := make([]Item, len(m.items)+1)
copy(newItems, m.items)
newItems[len(m.items)] = newItem
m.items = newItems
return m, nil
}
```
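The reason a plain `append` counts as mutation even with a value receiver is that the copied model still shares the slice's backing array. The sketch below shows this with the standard library only; `Model`, `appendInPlace`, and `appendCopy` are illustrative names, not part of BubbleTea.

```go
package main

// Model mirrors the shape of a BubbleTea model; only the slice matters here.
type Model struct {
	items []string
}

// appendInPlace appends directly. If the slice has spare capacity, the new
// element is written into the backing array shared with every other copy.
func appendInPlace(m Model, item string) Model {
	m.items = append(m.items, item)
	return m
}

// appendCopy allocates a fresh backing array first, so earlier copies of the
// model are never affected by the append.
func appendCopy(m Model, item string) Model {
	next := make([]string, len(m.items), len(m.items)+1)
	copy(next, m.items)
	m.items = append(next, item)
	return m
}
```

Deriving two models from the same base with `appendInPlace` can leave both pointing at the same second element, which is the kind of surprise the copy-first version avoids.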
### Commands for Async/IO
```go
// BAD - blocking in Update
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
data, _ := os.ReadFile("config.json") // blocks UI!
m.config = parse(data)
return m, nil
}
// GOOD - use commands
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, loadConfigCmd()
}
func loadConfigCmd() tea.Cmd {
return func() tea.Msg {
data, err := os.ReadFile("config.json")
if err != nil {
return errMsg{err}
}
return configLoadedMsg{parse(data)}
}
}
```
### Styles Defined Once
```go
// BAD - creates new style each render
func (m Model) View() string {
style := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205"))
return style.Render("Hello")
}
// GOOD - define styles at package level or in model
var titleStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205"))
func (m Model) View() string {
return titleStyle.Render("Hello")
}
```
## When to Load References
- **First time reviewing BubbleTea** → [elm-architecture.md](references/elm-architecture.md) (prevents false positives)
- Reviewing Update function logic → [model-update.md](references/model-update.md)
- Reviewing View function, styling → [view-styling.md](references/view-styling.md)
- Reviewing component hierarchy → [composition.md](references/composition.md)
- Using Bubbles components → [bubbles-components.md](references/bubbles-components.md)
## Review Questions
1. Is Update() free of blocking I/O? (NOT: "is the cmd helper blocking?")
2. Is the model immutable in Update?
3. Are Lipgloss styles defined once, not in View?
4. Is WindowSizeMsg handled for resizing?
5. Are key bindings documented with help.KeyMap?
6. Are Bubbles components sized correctly?
FILE:references/bubbles-components.md
# Bubbles Component Reference
Complete reference for all charmbracelet/bubbles components.
## Component Overview
| Component | Package | Purpose |
|-----------|---------|---------|
| list | `bubbles/list` | Scrollable list with filtering |
| table | `bubbles/table` | Tabular data display |
| viewport | `bubbles/viewport` | Scrollable content area |
| textinput | `bubbles/textinput` | Single-line text input |
| textarea | `bubbles/textarea` | Multi-line text input |
| spinner | `bubbles/spinner` | Loading indicator |
| progress | `bubbles/progress` | Progress bar |
| paginator | `bubbles/paginator` | Page navigation |
| filepicker | `bubbles/filepicker` | File/directory selection |
| timer | `bubbles/timer` | Countdown timer |
| stopwatch | `bubbles/stopwatch` | Elapsed time counter |
| help | `bubbles/help` | Key binding help display |
| key | `bubbles/key` | Key binding definitions |
| cursor | `bubbles/cursor` | Text cursor management |
---
## List
Full-featured list with filtering, pagination, and custom delegates.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/list"
// Items must implement list.Item
type item struct {
title, desc string
}
func (i item) Title() string { return i.title }
func (i item) Description() string { return i.desc }
func (i item) FilterValue() string { return i.title }
// Create list
items := []list.Item{
item{title: "Raspberry Pi", desc: "A small computer"},
item{title: "Arduino", desc: "A microcontroller"},
}
l := list.New(items, list.NewDefaultDelegate(), 0, 0)
l.Title = "My List"
```
### Common Patterns
```go
// Update list size on window resize
case tea.WindowSizeMsg:
h, v := docStyle.GetFrameSize()
m.list.SetSize(msg.Width-h, msg.Height-v)
// Get selected item
if i, ok := m.list.SelectedItem().(item); ok {
return i.title
}
// Set items dynamically
m.list.SetItems(newItems)
// Custom delegate for styling
delegate := list.NewDefaultDelegate()
delegate.Styles.SelectedTitle = selectedTitleStyle
delegate.Styles.SelectedDesc = selectedDescStyle
```
### Anti-Patterns
```go
// ❌ BAD - reaching into internals
selected := m.list.Items()[m.list.Index()]
// ✅ GOOD - use provided methods
selected := m.list.SelectedItem()
```
---
## Table
Tabular data with column definitions and row selection.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/table"
columns := []table.Column{
{Title: "Name", Width: 20},
{Title: "Email", Width: 30},
{Title: "Role", Width: 15},
}
rows := []table.Row{
{"Alice", "alice@example.com", "Admin"},
{"Bob", "bob@example.com", "User"},
}
t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(true),
table.WithHeight(10),
)
// Apply styles
s := table.DefaultStyles()
s.Header = s.Header.BorderStyle(lipgloss.NormalBorder())
s.Selected = s.Selected.Foreground(lipgloss.Color("229"))
t.SetStyles(s)
```
### Common Patterns
```go
// Get selected row
selectedRow := m.table.SelectedRow()
// Update rows
m.table.SetRows(newRows)
// Handle selection
case tea.KeyMsg:
switch msg.String() {
case "enter":
row := m.table.SelectedRow()
return m, selectRowCmd(row)
}
```
---
## Viewport
Scrollable content area for large text or rendered content.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/viewport"
vp := viewport.New(80, 20)
vp.SetContent(longContent)
// In Update
case tea.WindowSizeMsg:
vp.Width = msg.Width
vp.Height = msg.Height - headerHeight - footerHeight
```
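On very small terminals the subtraction `msg.Height - headerHeight - footerHeight` can drop below zero, and a negative viewport height may misbehave. A small clamp is one way to guard against that; `bodyHeight` is a hypothetical helper name, and the sketch uses no library code.

```go
package main

// bodyHeight returns the rows left for a scrollable body after reserving
// fixed header and footer rows. The result is never negative.
func bodyHeight(total, header, footer int) int {
	h := total - header - footer
	if h < 0 {
		return 0
	}
	return h
}
```

In the resize handler this would read `vp.Height = bodyHeight(msg.Height, headerHeight, footerHeight)`.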
### Common Patterns
```go
// Track scroll position
func (m Model) footerView() string {
return fmt.Sprintf("%3.f%%", m.viewport.ScrollPercent()*100)
}
// Programmatic scrolling
m.viewport.GotoTop()
m.viewport.GotoBottom()
m.viewport.LineDown(5)
m.viewport.LineUp(5)
// Update content
m.viewport.SetContent(newContent)
```
### Anti-Patterns
```go
// ❌ BAD - setting content in View
func (m Model) View() string {
m.viewport.SetContent(m.renderContent()) // Side effect!
return m.viewport.View()
}
// ✅ GOOD - set content in Update
case contentLoadedMsg:
m.viewport.SetContent(msg.content)
return m, nil
```
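Besides being a side effect, setting content inside `View` is also ineffective: `View` has a value receiver, so the change lands on a copy that is thrown away. The sketch below reproduces this with stand-in types (`viewportModel` and `model` are illustrative, not the Bubbles types).

```go
package main

// viewportModel is a stand-in for a scrollable component with a setter.
type viewportModel struct {
	content string
}

// SetContent mutates the receiver, like a pointer-receiver setter would.
func (v *viewportModel) SetContent(s string) { v.content = s }

// View renders the stored content.
func (v viewportModel) View() string { return v.content }

type model struct {
	vp viewportModel
}

// viewMutating shows the anti-pattern: m is a copy, so the content set here
// is visible for this one call and then discarded.
func (m model) viewMutating() string {
	m.vp.SetContent("fresh")
	return m.vp.View()
}

// View is the pure version: it only reads state.
func (m model) View() string { return m.vp.View() }

// withContent plays the role of Update: it changes state and returns the
// updated model so the caller can keep it.
func (m model) withContent(s string) model {
	m.vp.SetContent(s)
	return m
}
```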
---
## TextInput
Single-line text input with placeholder and validation.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/textinput"
ti := textinput.New()
ti.Placeholder = "Enter username"
ti.CharLimit = 32
ti.Width = 20
ti.Focus()
```
### Common Patterns
```go
// Password input
ti.EchoMode = textinput.EchoPassword
ti.EchoCharacter = '*'
// Validation (the result of the last check is stored in ti.Err)
ti.Validate = func(s string) error {
if len(s) < 3 {
return errors.New("too short")
}
return nil
}
// Get value
value := m.textinput.Value()
// Clear input
m.textinput.Reset()
// Focus management
m.textinput.Focus()
m.textinput.Blur()
```
### Multiple Inputs
```go
type Model struct {
inputs []textinput.Model
focused int
}
func (m *Model) nextInput() {
m.inputs[m.focused].Blur()
m.focused = (m.focused + 1) % len(m.inputs)
m.inputs[m.focused].Focus()
}
```
---
## TextArea
Multi-line text input with line wrapping.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/textarea"
ta := textarea.New()
ta.Placeholder = "Type your message..."
ta.SetWidth(60)
ta.SetHeight(10)
ta.Focus()
```
### Common Patterns
```go
// Get/set value
content := m.textarea.Value()
m.textarea.SetValue("Initial content")
// Line count
lines := m.textarea.LineCount()
// Cursor position (row from Line(), column from LineInfo())
row := m.textarea.Line()
col := m.textarea.LineInfo().ColumnOffset
// Resize
case tea.WindowSizeMsg:
m.textarea.SetWidth(msg.Width - 4)
m.textarea.SetHeight(msg.Height - 6)
```
---
## Spinner
Loading indicator with multiple styles.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/spinner"
s := spinner.New()
s.Spinner = spinner.Dot // or Line, MiniDot, Jump, Pulse, Points, Globe, Moon, Monkey, Meter, Hamburger
// In Init
return s.Tick
// In Update
case spinner.TickMsg:
m.spinner, cmd = m.spinner.Update(msg)
return m, cmd
```
### Spinner Styles
```go
// Available spinners
spinner.Line // |/-\
spinner.Dot // ⣾⣽⣻⢿⡿⣟⣯⣷
spinner.MiniDot // ⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏
spinner.Jump // ⢄⢂⢁⡁⡈⡐⡠
spinner.Pulse // █▓▒░
spinner.Points // ∙∙∙
spinner.Globe // 🌍🌎🌏
spinner.Moon // 🌑🌒🌓🌔🌕🌖🌗🌘
spinner.Monkey // 🙈🙉🙊
spinner.Meter // ▱▰▰▰▰▰▰
spinner.Hamburger // ☰☲☴
```
---
## Progress
Progress bar with percentage and custom styling.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/progress"
p := progress.New(progress.WithDefaultGradient())
// or
p := progress.New(progress.WithScaledGradient("#FF7CCB", "#FDFF8C"))
// In View
return p.ViewAs(0.5) // 50%
// Animated progress
return p.View() // uses internal percentage
```
### Common Patterns
```go
// Update progress
m.progress.SetPercent(0.75)
// Width adjustment
case tea.WindowSizeMsg:
m.progress.Width = msg.Width - padding
// Animated increment
case progressMsg:
cmd := m.progress.SetPercent(msg.percent)
return m, cmd
```
---
## Paginator
Page navigation for paginated content.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/paginator"
p := paginator.New()
p.Type = paginator.Dots // or Arabic (1/10)
p.PerPage = 5
p.SetTotalPages(len(items)) // takes the item count and derives TotalPages from PerPage
```
### Common Patterns
```go
// Get current page items
start, end := m.paginator.GetSliceBounds(len(items))
pageItems := items[start:end]
// Navigation
if m.paginator.OnLastPage() {
// handle end
}
// In Update - paginator handles arrow keys
m.paginator, cmd = m.paginator.Update(msg)
```
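To make the paging arithmetic concrete, here is a standalone sketch of what `SetTotalPages` and `GetSliceBounds` are understood to compute. `totalPages` and `sliceBounds` are hypothetical re-implementations for illustration, simplified relative to the library, and assume a zero-based page index.

```go
package main

// totalPages returns ceil(items / perPage), or 0 when either input is
// non-positive.
func totalPages(items, perPage int) int {
	if items < 1 || perPage < 1 {
		return 0
	}
	return (items + perPage - 1) / perPage
}

// sliceBounds returns the half-open range [start, end) of items shown on a
// zero-based page, with both bounds clamped to length.
func sliceBounds(page, perPage, length int) (start, end int) {
	start = page * perPage
	end = start + perPage
	if end > length {
		end = length
	}
	if start > end {
		start = end
	}
	return start, end
}
```

With 12 items and 5 per page this gives three pages, and the last page covers items 10 and 11 only.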
---
## FilePicker
File and directory selection.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/filepicker"
fp := filepicker.New()
fp.CurrentDirectory, _ = os.UserHomeDir()
fp.AllowedTypes = []string{".go", ".md", ".txt"}
fp.ShowHidden = false
```
### Common Patterns
```go
// Check for selection
case tea.KeyMsg:
m.filepicker, cmd = m.filepicker.Update(msg)
if didSelect, path := m.filepicker.DidSelectFile(msg); didSelect {
m.selectedFile = path
return m, fileSelectedCmd(path)
}
if didSelect, _ := m.filepicker.DidSelectDisabledFile(msg); didSelect {
m.err = errors.New("file type not allowed")
}
```
---
## Timer
Countdown timer with start/stop/reset.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/timer"
t := timer.NewWithInterval(5*time.Minute, time.Second)
// In Init
return t.Init()
// In Update
case timer.TickMsg:
m.timer, cmd = m.timer.Update(msg)
return m, cmd
case timer.TimeoutMsg:
// Timer finished
return m, nil
```
### Control
```go
// Toggle
cmd := m.timer.Toggle()
// Stop
cmd := m.timer.Stop()
// Start
cmd := m.timer.Start()
```
---
## Stopwatch
Elapsed time counter.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/stopwatch"
sw := stopwatch.NewWithInterval(time.Millisecond * 100)
// In Init
return sw.Init()
// In Update
case stopwatch.TickMsg:
m.stopwatch, cmd = m.stopwatch.Update(msg)
return m, cmd
```
---
## Help
Display key bindings to users.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/help"
import "github.com/charmbracelet/bubbles/key"
type keyMap struct {
Up key.Binding
Down key.Binding
Quit key.Binding
}
func (k keyMap) ShortHelp() []key.Binding {
return []key.Binding{k.Up, k.Down, k.Quit}
}
func (k keyMap) FullHelp() [][]key.Binding {
return [][]key.Binding{
{k.Up, k.Down},
{k.Quit},
}
}
var keys = keyMap{
Up: key.NewBinding(
key.WithKeys("up", "k"),
key.WithHelp("↑/k", "up"),
),
Down: key.NewBinding(
key.WithKeys("down", "j"),
key.WithHelp("↓/j", "down"),
),
Quit: key.NewBinding(
key.WithKeys("q", "ctrl+c"),
key.WithHelp("q", "quit"),
),
}
h := help.New()
// In View
return h.View(keys)
```
### Expand/Collapse
```go
// Toggle full help
case tea.KeyMsg:
if msg.String() == "?" {
m.help.ShowAll = !m.help.ShowAll
}
```
---
## Key
Key binding definitions for consistent input handling.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/key"
var quitKey = key.NewBinding(
key.WithKeys("q", "ctrl+c", "esc"),
key.WithHelp("q", "quit"),
)
// In Update
case tea.KeyMsg:
if key.Matches(msg, quitKey) {
return m, tea.Quit
}
```
### Enable/Disable Bindings
```go
// Disable a binding
quitKey.SetEnabled(false)
// Check if enabled
if quitKey.Enabled() {
// ...
}
```
---
## Cursor
Text cursor management for custom text inputs.
### Basic Setup
```go
import "github.com/charmbracelet/bubbles/cursor"
c := cursor.New()
c.SetMode(cursor.CursorBlink)
// Modes
cursor.CursorBlink
cursor.CursorStatic
cursor.CursorHide
```
---
## Integration Patterns
### Multiple Components
```go
type Model struct {
list list.Model
spinner spinner.Model
help help.Model
keys keyMap
loading bool
}
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
// Always update spinner when loading
if m.loading {
m.spinner, cmd = m.spinner.Update(msg)
cmds = append(cmds, cmd)
}
// Update list when not loading
if !m.loading {
m.list, cmd = m.list.Update(msg)
cmds = append(cmds, cmd)
}
return m, tea.Batch(cmds...)
}
```
### Component Communication
```go
// Custom message for cross-component communication
type itemSelectedMsg struct {
item Item
}
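// Child component emits message
case tea.KeyMsg:
	if msg.String() == "enter" {
		// Checked assertion: SelectedItem() returns nil when the list is empty
		if selected, ok := m.list.SelectedItem().(Item); ok {
			return m, func() tea.Msg {
				return itemSelectedMsg{item: selected}
			}
		}
	}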
// Parent handles message
case itemSelectedMsg:
m.selectedItem = msg.item
m.state = viewDetail
```
## Review Questions
1. Are components initialized with proper dimensions?
2. Are components updated on WindowSizeMsg?
3. Is focus managed correctly between components?
4. Are component methods used instead of reaching into internals?
5. Are tick messages handled for animated components (spinner, timer)?
FILE:references/composition.md
# Component Composition
## Bubbles Integration
### 1. Using Standard Bubbles
```go
import (
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
"github.com/charmbracelet/bubbles/spinner"
)
type Model struct {
list list.Model
input textinput.Model
viewport viewport.Model
spinner spinner.Model
}
```
### 2. Initialize Sub-Components
```go
func NewModel() Model {
// List
items := []list.Item{...}
l := list.New(items, list.NewDefaultDelegate(), 0, 0)
l.Title = "My List"
// Text input
ti := textinput.New()
ti.Placeholder = "Type here..."
ti.Focus()
// Spinner
s := spinner.New()
s.Spinner = spinner.Dot
return Model{
list: l,
input: ti,
spinner: s,
}
}
```
### 3. Update Sub-Components
```go
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
// Always update active sub-components
switch m.state {
case stateList:
m.list, cmd = m.list.Update(msg)
cmds = append(cmds, cmd)
case stateInput:
m.input, cmd = m.input.Update(msg)
cmds = append(cmds, cmd)
}
// Handle window size for all components
if msg, ok := msg.(tea.WindowSizeMsg); ok {
m.list.SetSize(msg.Width, msg.Height-4)
m.viewport.Width = msg.Width
m.viewport.Height = msg.Height - 4
}
return m, tea.Batch(cmds...)
}
```
## Custom Components
### 1. Component Interface Pattern
```go
// Component interface for consistent sub-components
type Component interface {
Init() tea.Cmd
Update(tea.Msg) (Component, tea.Cmd)
View() string
SetSize(width, height int)
}
```
### 2. Self-Contained Component
```go
// menu/menu.go
package menu
type Model struct {
items []Item
cursor int
width int
height int
}
func New(items []Item) Model {
return Model{items: items}
}
func (m Model) Init() tea.Cmd {
return nil
}
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "up", "k":
if m.cursor > 0 {
m.cursor--
}
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
}
}
return m, nil
}
func (m Model) View() string {
var b strings.Builder
for i, item := range m.items {
cursor := " "
if i == m.cursor {
cursor = "> "
}
b.WriteString(cursor + item.Title + "\n")
}
return b.String()
}
func (m *Model) SetSize(w, h int) {
m.width = w
m.height = h
}
func (m Model) Selected() Item {
return m.items[m.cursor]
}
```
### 3. Using Custom Component
```go
import "myapp/menu"
type Model struct {
menu menu.Model
}
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.menu, cmd = m.menu.Update(msg)
// React to menu selection
if key, ok := msg.(tea.KeyMsg); ok && key.String() == "enter" {
selected := m.menu.Selected()
// handle selection
}
return m, cmd
}
```
## State Machine Pattern
### 1. View States
```go
type viewState int
const (
viewLoading viewState = iota
viewList
viewDetail
viewEdit
)
type Model struct {
state viewState
// sub-components for each state
list list.Model
detail detailModel
edit editModel
}
```
### 2. State Transitions
```go
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Global key handling
if key, ok := msg.(tea.KeyMsg); ok {
switch key.String() {
case "esc":
// Go back based on current state
switch m.state {
case viewDetail:
m.state = viewList
return m, nil
case viewEdit:
m.state = viewDetail
return m, nil
}
}
}
// Delegate to current state's component
var cmd tea.Cmd
switch m.state {
case viewList:
m.list, cmd = m.list.Update(msg)
// Check for selection
if key, ok := msg.(tea.KeyMsg); ok && key.String() == "enter" {
m.state = viewDetail
m.detail = newDetailModel(m.list.SelectedItem())
}
case viewDetail:
m.detail, cmd = m.detail.Update(msg)
case viewEdit:
m.edit, cmd = m.edit.Update(msg)
}
return m, cmd
}
```
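One way to keep transitions like the `esc` handling easy to review is to pull them into a pure function that can be unit-tested without a running program. This is a sketch under that assumption; `back` is a hypothetical helper, and the state names repeat the ones defined earlier in this section.

```go
package main

type viewState int

const (
	viewLoading viewState = iota
	viewList
	viewDetail
	viewEdit
)

// back returns the state that Esc should lead to from s, and whether a
// transition applies at all.
func back(s viewState) (viewState, bool) {
	switch s {
	case viewDetail:
		return viewList, true
	case viewEdit:
		return viewDetail, true
	default:
		return s, false
	}
}
```

`Update` would then call `back(m.state)` and only return early when the second value is true.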
### 3. View Routing
```go
func (m Model) View() string {
switch m.state {
case viewLoading:
return m.spinner.View() + " Loading..."
case viewList:
return m.list.View()
case viewDetail:
return m.detail.View()
case viewEdit:
return m.edit.View()
default:
return "Unknown state"
}
}
```
## Focus Management
### 1. Track Focus
```go
type focusState int
const (
focusList focusState = iota
focusInput
focusButtons
)
type Model struct {
focus focusState
list list.Model
input textinput.Model
}
func (m *Model) nextFocus() {
m.focus = (m.focus + 1) % 3
m.updateFocus()
}
func (m *Model) updateFocus() {
switch m.focus {
case focusInput:
m.input.Focus()
default:
m.input.Blur()
}
}
```
### 2. Tab Navigation
```go
case tea.KeyMsg:
switch msg.String() {
case "tab":
m.nextFocus()
return m, nil
case "shift+tab":
m.prevFocus()
return m, nil
}
// Only handle keys for focused component
switch m.focus {
case focusList:
m.list, cmd = m.list.Update(msg)
case focusInput:
m.input, cmd = m.input.Update(msg)
}
```
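The tab-navigation snippet calls `m.prevFocus()`, which this reference does not define. A minimal sketch of the wrap-around arithmetic follows; the `focusCount` sentinel and the free functions `nextFocus` and `prevFocus` are assumptions made for illustration, replacing the hard-coded `% 3`.

```go
package main

type focusState int

const (
	focusList focusState = iota
	focusInput
	focusButtons
	focusCount // sentinel: number of focusable areas
)

// nextFocus moves focus forward and wraps from the last area to the first.
func nextFocus(f focusState) focusState {
	return (f + 1) % focusCount
}

// prevFocus moves focus backward. Adding focusCount before taking the
// remainder keeps the result non-negative when f is the first area.
func prevFocus(f focusState) focusState {
	return (f - 1 + focusCount) % focusCount
}
```

Using a sentinel constant means adding a fourth focus area does not require touching the arithmetic.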
## Anti-Patterns
### 1. Not Propagating Updates
```go
// BAD - sub-component never updates
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg.(type) {
case tea.KeyMsg:
// only handles own keys, ignores sub-component
}
return m, nil
}
// GOOD - always update sub-components
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.list, cmd = m.list.Update(msg) // always propagate
return m, cmd
}
```
### 2. Nested Component Access
```go
// BAD - reaches into component internals
func (m Model) View() string {
return m.list.items[m.list.cursor].Title // breaks encapsulation
}
// GOOD - use component methods
func (m Model) View() string {
return m.list.SelectedItem().(Item).Title
}
```
## Huh Forms Integration
[Huh](https://github.com/charmbracelet/huh) is a form library built on BubbleTea.
### Basic Form
```go
import "github.com/charmbracelet/huh"
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Key("name").
Title("What's your name?").
Validate(func(s string) error {
if s == "" {
return errors.New("name required")
}
return nil
}),
huh.NewSelect[string]().
Key("role").
Title("Select role").
Options(
huh.NewOption("Admin", "admin"),
huh.NewOption("User", "user"),
),
huh.NewConfirm().
Key("confirm").
Title("Continue?"),
),
)
// Run standalone (blocking)
err := form.Run()
// Get values
name := form.GetString("name")
role := form.GetString("role")
confirmed := form.GetBool("confirm")
```
### Embedding in BubbleTea
```go
type Model struct {
form *huh.Form
done bool
}
func NewModel() Model {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Key("name").Title("Name"),
),
).WithTheme(huh.ThemeDracula())
return Model{form: form}
}
func (m Model) Init() tea.Cmd {
return m.form.Init()
}
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	// Update the form first so the message that completes it is processed
	form, cmd := m.form.Update(msg)
	if f, ok := form.(*huh.Form); ok {
		m.form = f
	}
	// Then check completion, so done is set on the same message
	if m.form.State == huh.StateCompleted {
		m.done = true
	}
	return m, cmd
}
func (m Model) View() string {
if m.done {
return fmt.Sprintf("Hello, %s!", m.form.GetString("name"))
}
return m.form.View()
}
```
### Field Types
```go
// Text input
huh.NewInput().Key("name").Title("Name").Placeholder("Enter name")
// Multi-line text
huh.NewText().Key("bio").Title("Bio").Lines(5)
// Single select
huh.NewSelect[string]().Key("color").Title("Color").
Options(
huh.NewOption("Red", "red"),
huh.NewOption("Blue", "blue"),
)
// Multi select
huh.NewMultiSelect[string]().Key("tags").Title("Tags").
Options(
huh.NewOption("Go", "go"),
huh.NewOption("Rust", "rust"),
)
// Confirmation
huh.NewConfirm().Key("agree").Title("Agree?")
// File picker
huh.NewFilePicker().Key("file").Title("Select file")
```
### Theming
```go
form := huh.NewForm(...).
WithTheme(huh.ThemeDracula()). // Built-in themes
WithWidth(60).
WithShowHelp(true).
WithShowErrors(true)
// Built-in themes
huh.ThemeBase()
huh.ThemeCharm()
huh.ThemeDracula()
huh.ThemeCatppuccin()
huh.ThemeBase16()
```
### Multi-Page Forms
```go
form := huh.NewForm(
// Page 1
huh.NewGroup(
huh.NewInput().Key("name").Title("Name"),
huh.NewInput().Key("email").Title("Email"),
).Title("Personal Info"),
// Page 2
huh.NewGroup(
huh.NewSelect[string]().Key("plan").Title("Plan").
Options(
huh.NewOption("Free", "free"),
huh.NewOption("Pro", "pro"),
),
).Title("Subscription"),
)
```
### Anti-Patterns
```go
// ❌ BAD - calling Run() inside BubbleTea (blocks)
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.form.Run() // BLOCKS THE UI!
return m, nil
}
// ✅ GOOD - use Update loop
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
form, cmd := m.form.Update(msg)
m.form = form.(*huh.Form)
return m, cmd
}
```
## Review Questions
1. Are sub-components properly initialized?
2. Are sub-component updates propagated?
3. Is WindowSizeMsg passed to all components needing resize?
4. Is there a clear state machine for view transitions?
5. Is focus tracked and components blurred/focused correctly?
6. Are Huh forms embedded correctly (not using blocking Run())?
FILE:references/elm-architecture.md
# Understanding the Elm Architecture
## The Core Principle: Commands Are Data
The most important concept in BubbleTea (and Elm) is that **commands describe effects, they don't execute them**.
```go
// tea.Cmd is just a function signature
type Cmd func() Msg
```
When you return a `tea.Cmd` from `Update()`, you're returning a *description* of work to do. The BubbleTea runtime executes it asynchronously after `Update()` returns.
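The claim that a command is only a description can be checked without the library. The sketch below uses local stand-in types (`Msg`, `Cmd`, `model`, and `dataLoadedMsg` are illustrative, shaped like their BubbleTea counterparts) and a call counter to show that the closure does not run while `update` executes.

```go
package main

// Msg and Cmd are local stand-ins with the same shape as tea.Msg and tea.Cmd.
type Msg any
type Cmd func() Msg

type dataLoadedMsg struct {
	payload string
}

type model struct {
	calls *int // counts how often the command body actually ran
	data  string
}

// loadData builds a command. Nothing inside the closure runs until someone
// (normally the runtime) calls the returned function.
func (m model) loadData() Cmd {
	return func() Msg {
		*m.calls++
		return dataLoadedMsg{payload: "ok"}
	}
}

// update returns immediately in both branches; slow work lives in the Cmd.
func (m model) update(msg Msg) (model, Cmd) {
	switch msg := msg.(type) {
	case dataLoadedMsg:
		m.data = msg.payload
		return m, nil
	default:
		return m, m.loadData()
	}
}
```

Calling `update` with any other message returns a non-nil command while the counter stays at zero; the counter only moves once the command is invoked, which is the role the runtime plays.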
## Common False Positive: "Synchronous Execution"
**This is NOT blocking:**
```go
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case NavigateMsg:
return m, m.loadData() // ← NOT synchronous execution!
}
return m, nil
}
func (m *Model) loadData() tea.Cmd {
return func() tea.Msg {
// This closure is NOT executed during Update()
// The runtime schedules it for async execution
data, _ := http.Get("https://api.example.com/data")
return DataLoadedMsg{data}
}
}
```
**Why this is correct:**
1. `m.loadData()` is called synchronously, but it only *creates* the command
2. The `http.Get` inside the closure does NOT run during `Update()`
3. `Update()` returns immediately with the command
4. BubbleTea's runtime executes the command in a separate goroutine
5. When complete, the runtime sends `DataLoadedMsg` back to `Update()`
## The Execution Model
```
┌────────────────────────────────────────────────────────────────┐
│                       BubbleTea Runtime                        │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  User Input ──┐                                                │
│               ▼                                                │
│          ┌──────────┐      returns       ┌──────────────┐      │
│  Msg →   │  Update  │ ─────────────────→ │  Model, Cmd  │      │
│          └──────────┘    immediately     └──────┬───────┘      │
│                                                 │              │
│               ┌─────────────────────────────────┘              │
│               ▼                                                │
│          ┌──────────┐                                          │
│          │ Runtime  │  executes Cmd                            │
│          │ executes │  in background                           │
│          │   Cmd    │  goroutine                               │
│          └────┬─────┘                                          │
│               │                                                │
│               ▼ sends Msg                                      │
│          ┌──────────┐                                          │
│  Msg →   │  Update  │ ← cycle continues                        │
│          └──────────┘                                          │
│                                                                │
└────────────────────────────────────────────────────────────────┘
```
## NOT Issues (Avoid These False Positives)
### 1. Helper Functions Returning tea.Cmd
```go
// ✅ CORRECT - this is NOT blocking
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, m.fetchItems()
}
func (m *Model) fetchItems() tea.Cmd {
return func() tea.Msg {
items, _ := api.GetItems() // Runs LATER, by runtime
return ItemsMsg{items}
}
}
```
**Why OK:** The helper creates and returns a command descriptor. No I/O happens in Update().
### 2. Value Receivers on Update
```go
// ✅ CORRECT - standard BubbleTea pattern
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.counter++
return m, nil
}
```
**Why OK:** `Update` returns the model by value, so the runtime receives the modified copy.
### 3. Nested Model Updates
```go
// ✅ CORRECT - normal component composition
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.child, cmd = m.child.Update(msg) // Updates child synchronously
return m, cmd
}
```
**Why OK:** Child's Update() is also non-blocking. Commands bubble up.
### 4. Batch Commands
```go
// ✅ CORRECT - commands execute concurrently
return m, tea.Batch(
m.loadUser(),
m.loadPosts(),
m.loadSettings(),
)
```
**Why OK:** The runtime runs all three commands concurrently.
### 5. Immediate Message Return
```go
// ✅ CORRECT - synchronous state transition
func (m *Model) navigateToMenu() tea.Cmd {
return func() tea.Msg {
return ShowMenuMsg{} // No I/O, just returns a message
}
}
```
**Why OK:** Even though this returns immediately, it's still async from Update()'s perspective.
## ACTUAL Issues to Flag
### 1. Blocking I/O Directly in Update
```go
// ❌ BAD - blocks the UI
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
data, _ := os.ReadFile("config.json") // BLOCKS!
m.config = parse(data)
return m, nil
}
```
**Fix:** Move to a command:
```go
return m, loadConfigCmd()
```
### 2. Sleep in Update
```go
// ❌ BAD - freezes UI for 2 seconds
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
time.Sleep(2 * time.Second)
return m, nil
}
```
**Fix:** Use tea.Tick:
```go
return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
return DelayCompleteMsg{}
})
```
### 3. HTTP Calls in Update
```go
// ❌ BAD - network I/O in Update
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
resp, _ := http.Get("https://api.example.com")
// ...
}
```
**Fix:** Wrap in a command function.
### 4. Channel Operations That Block
```go
// ❌ BAD - may block indefinitely
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
data := <-m.dataChan // Could block!
return m, nil
}
```
**Fix:** Use a non-blocking `select` with a `default` case, or move the receive into a command.
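A minimal sketch of the non-blocking option, assuming the channel carries strings. `tryReceive` is a hypothetical helper name, not a BubbleTea API; it uses only the standard `select` statement with a `default` case.

```go
package main

// tryReceive performs a non-blocking receive. It returns immediately with
// ok == false when no value is ready (or the channel is closed) instead of
// waiting on the channel.
func tryReceive(ch <-chan string) (value string, ok bool) {
	select {
	case v, open := <-ch:
		return v, open // open is false when the channel has been closed
	default:
		return "", false // nothing ready right now; do not block
	}
}
```

Inside Update, a false `ok` means nothing was ready, so the model can be returned unchanged. When the program genuinely has to wait for data, a command that blocks on the channel and returns a message is the better fit.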
## Quick Reference: Is It Blocking?
| Code Pattern | Blocking? | Why |
|--------------|-----------|-----|
| `return m, m.loadData()` | No | Returns cmd descriptor |
| `data := fetchData()` (in Update) | **Yes** | Direct I/O call |
| `return m, func() tea.Msg { ... }` | No | Closure runs later |
| `time.Sleep(d)` (in Update) | **Yes** | Blocks goroutine |
| `<-channel` (in Update) | **Maybe** | Blocks if empty |
| `return m, tea.Tick(d, ...)` | No | Runtime handles delay |
## Review Guidance
When reviewing BubbleTea code:
1. **Look for I/O in Update()** - file, network, database calls directly in Update are bugs
2. **Ignore cmd helper patterns** - `return m, m.someHelper()` where helper returns `tea.Cmd` is correct
3. **Check what's INSIDE commands** - the closure body is where blocking ops belong
4. **Value receivers are fine** - BubbleTea's design expects this
The rule is simple: **Update() must return quickly. Commands do the slow work.**
FILE:references/model-update.md
# Model & Update
## Model Design
### 1. Model Must Implement tea.Model
```go
type Model struct {
// State
items []Item
cursor int
selected map[int]struct{}
// Dimensions (for responsive layout)
width int
height int
// Sub-components
list list.Model
viewport viewport.Model
// Error state
err error
}
// Verify interface implementation
var _ tea.Model = (*Model)(nil)
```
### 2. Init Returns Initial Command
```go
// BAD - blocking operation
func (m Model) Init() tea.Cmd {
data := loadData() // blocks!
return nil
}
// GOOD - async via command
func (m Model) Init() tea.Cmd {
return tea.Batch(
loadDataCmd(),
tea.EnterAltScreen,
)
}
```
## Update Patterns
### 1. Switch on Message Type
```go
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
return m.handleKey(msg)
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
return m, nil
case dataLoadedMsg:
m.items = msg.items
return m, nil
case errMsg:
m.err = msg.err
return m, nil
}
return m, nil
}
```
### 2. Always Handle WindowSizeMsg
```go
// BAD - ignores window size
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// no WindowSizeMsg handling
}
// GOOD
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
// Update sub-components
m.viewport.Width = msg.Width
m.viewport.Height = msg.Height - 4 // reserve for header/footer
return m, nil
```
### 3. Key Handling with key.Matches
```go
// BAD - string comparison
case tea.KeyMsg:
	if msg.String() == "q" {
		return m, tea.Quit
	}

// GOOD - use key bindings
type keyMap struct {
	Quit key.Binding
	Up   key.Binding
	Down key.Binding
}

var keys = keyMap{
	Quit: key.NewBinding(
		key.WithKeys("q", "ctrl+c"),
		key.WithHelp("q", "quit"),
	),
	Up: key.NewBinding(
		key.WithKeys("up", "k"),
		key.WithHelp("↑/k", "up"),
	),
	Down: key.NewBinding(
		key.WithKeys("down", "j"),
		key.WithHelp("↓/j", "down"),
	),
}

case tea.KeyMsg:
	switch {
	case key.Matches(msg, keys.Quit):
		return m, tea.Quit
	case key.Matches(msg, keys.Up):
		if m.cursor > 0 { // keep the cursor in bounds
			m.cursor--
		}
	}
```
### 4. Sub-Component Updates
```go
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
// Update sub-components
var cmd tea.Cmd
m.list, cmd = m.list.Update(msg)
cmds = append(cmds, cmd)
m.viewport, cmd = m.viewport.Update(msg)
cmds = append(cmds, cmd)
// Handle our own messages
switch msg := msg.(type) {
case tea.KeyMsg:
// ...
}
return m, tea.Batch(cmds...)
}
```
## Commands
### 1. Commands Return Messages
```go
// Command that performs I/O
func fetchItemsCmd(url string) tea.Cmd {
	return func() tea.Msg {
		resp, err := http.Get(url)
		if err != nil {
			return errMsg{err}
		}
		defer resp.Body.Close()

		var items []Item
		if err := json.NewDecoder(resp.Body).Decode(&items); err != nil {
			return errMsg{err} // report malformed responses instead of dropping the error
		}
		return itemsFetchedMsg{items}
	}
}
```
### 2. Tick Commands for Animation
```go
type tickMsg time.Time
func tickCmd() tea.Cmd {
return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
case tickMsg:
m.frame++
return m, tickCmd() // schedule next tick
```
### 3. Batch Multiple Commands
```go
// BAD - returns only last command
func (m Model) Init() tea.Cmd {
loadConfig()
return loadData() // loadConfig result lost!
}
// GOOD - batch them
func (m Model) Init() tea.Cmd {
return tea.Batch(
loadConfigCmd(),
loadDataCmd(),
startSpinnerCmd(),
)
}
```
## Anti-Patterns
### 1. Side Effects in View
```go
// BAD
func (m Model) View() string {
log.Printf("rendering") // side effect!
m.renderCount++ // mutation!
return "..."
}
// GOOD - View is pure
func (m Model) View() string {
return "..."
}
```
### 2. Blocking in Update
```go
// BAD
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
time.Sleep(2 * time.Second) // freezes UI!
return m, nil
}
// GOOD - use commands for delays
return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
return delayCompleteMsg{}
})
```
## Review Questions
1. Does Init return a command for initial I/O?
2. Does Update handle all relevant message types?
3. Is WindowSizeMsg handled for responsive layout?
4. Are key bindings using key.Matches?
5. Are sub-component updates propagated correctly?
6. Are commands used for all async/I/O operations?
FILE:references/view-styling.md
# View & Styling
## View Function
### 1. View Must Be Pure
```go
// BAD - side effects
func (m Model) View() string {
m.lastRender = time.Now() // mutation!
log.Println("rendering") // I/O!
return "..."
}
// GOOD - pure function
func (m Model) View() string {
if m.loading {
return m.spinner.View() + " Loading..."
}
return m.renderContent()
}
```
### 2. Handle Loading/Error States
```go
func (m Model) View() string {
if m.err != nil {
return errorStyle.Render(fmt.Sprintf("Error: %v", m.err))
}
if m.loading {
return m.spinner.View() + " Loading..."
}
return m.renderContent()
}
```
### 3. Compose Views Cleanly
```go
func (m Model) View() string {
var b strings.Builder
b.WriteString(m.renderHeader())
b.WriteString("\n")
b.WriteString(m.renderContent())
b.WriteString("\n")
b.WriteString(m.renderFooter())
return b.String()
}
```
## Lipgloss Styling
### 1. Define Styles at Package Level
```go
// BAD - created every render
func (m Model) View() string {
style := lipgloss.NewStyle().Bold(true)
return style.Render("Hello")
}
// GOOD - defined once
var (
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("205"))
itemStyle = lipgloss.NewStyle().
PaddingLeft(2)
)
func (m Model) View() string {
return titleStyle.Render("Hello")
}
```
### 2. Use Color Palette
```go
// Define a consistent color palette
var (
colorPrimary = lipgloss.Color("205") // magenta
colorSecondary = lipgloss.Color("241") // gray
colorSuccess = lipgloss.Color("78") // green
colorError = lipgloss.Color("196") // red
)
var (
titleStyle = lipgloss.NewStyle().Foreground(colorPrimary)
errorStyle = lipgloss.NewStyle().Foreground(colorError)
)
```
### 3. Adaptive Colors for Themes
```go
var (
// Adaptive colors work with light and dark terminals
subtle = lipgloss.AdaptiveColor{Light: "#D9DCCF", Dark: "#383838"}
highlight = lipgloss.AdaptiveColor{Light: "#874BFD", Dark: "#7D56F4"}
)
var titleStyle = lipgloss.NewStyle().
Foreground(highlight).
Background(subtle)
```
### 4. Responsive Width
```go
func (m Model) View() string {
// Adjust style based on window width
doc := lipgloss.NewStyle().
Width(m.width).
MaxWidth(m.width)
return doc.Render(m.content)
}
```
### 5. Layout with Place and Join
```go
func (m Model) View() string {
// Horizontal join
row := lipgloss.JoinHorizontal(
lipgloss.Top,
leftPanel.Render(m.menu),
rightPanel.Render(m.content),
)
// Vertical join
return lipgloss.JoinVertical(
lipgloss.Left,
m.header(),
row,
m.footer(),
)
}
// Center content
func (m Model) View() string {
return lipgloss.Place(
m.width, m.height,
lipgloss.Center, lipgloss.Center,
m.content,
)
}
```
### 6. Borders and Padding
```go
var boxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("63")).
Padding(1, 2).
Margin(1)
var selectedStyle = lipgloss.NewStyle().
Border(lipgloss.DoubleBorder()).
BorderForeground(lipgloss.Color("205"))
```
## Common Patterns
### Selected Item Highlighting
```go
func (m Model) renderItems() string {
var b strings.Builder
for i, item := range m.items {
cursor := " "
if i == m.cursor {
cursor = "▸ "
}
style := itemStyle
if i == m.cursor {
style = selectedStyle
}
b.WriteString(style.Render(cursor + item.Title))
b.WriteString("\n")
}
return b.String()
}
```
### Help Footer
```go
func (m Model) helpView() string {
return helpStyle.Render("↑/↓: navigate • enter: select • q: quit")
}
// Or use the help bubble from github.com/charmbracelet/bubbles/help.
// m.help is a help.Model; m.keys must implement help.KeyMap (ShortHelp and FullHelp).
func (m Model) View() string {
return m.content + "\n" + m.help.View(m.keys)
}
```
### Status Bar
```go
var statusStyle = lipgloss.NewStyle().
Background(lipgloss.Color("235")).
Foreground(lipgloss.Color("255")).
Padding(0, 1)
func (m Model) statusBar() string {
status := fmt.Sprintf("Items: %d | Selected: %d", len(m.items), len(m.selected))
return statusStyle.Width(m.width).Render(status)
}
```
## Anti-Patterns
### 1. ANSI Codes Instead of Lipgloss
```go
// BAD - raw ANSI
func (m Model) View() string {
return "\033[1;31mError\033[0m"
}
// GOOD - Lipgloss
var errorStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("196"))
func (m Model) View() string {
return errorStyle.Render("Error")
}
```
### 2. Hardcoded Dimensions
```go
// BAD - ignores terminal size
var boxStyle = lipgloss.NewStyle().Width(80)
// GOOD - responsive
func (m Model) renderBox() string {
return boxStyle.Width(m.width - 4).Render(m.content)
}
```
## Review Questions
1. Is View a pure function with no side effects?
2. Are styles defined once, not in View?
3. Are colors using AdaptiveColor for light/dark themes?
4. Is layout responsive to WindowSizeMsg?
5. Are lipgloss.Join/Place used for layout composition?
Reviews Phoenix code for controller patterns, context boundaries, routing, and plugs. Use when reviewing Phoenix apps, checking controllers, routers, or cont...
---
name: phoenix-code-review
description: Reviews Phoenix code for controller patterns, context boundaries, routing, and plugs. Use when reviewing Phoenix apps, checking controllers, routers, or context modules.
---
# Phoenix Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| Bounded contexts, Ecto integration | [references/contexts.md](references/contexts.md) |
| Actions, params, error handling | [references/controllers.md](references/controllers.md) |
| Pipelines, scopes, verified routes | [references/routing.md](references/routing.md) |
| Custom plugs, authentication | [references/plugs.md](references/plugs.md) |
## Review Checklist
### Controllers
- [ ] Business logic in contexts, not controllers
- [ ] Controllers return proper HTTP status codes
- [ ] Action clauses handle all expected patterns
- [ ] Fallback controllers handle errors consistently
### Contexts
- [ ] Contexts are bounded by domain, not technical layer
- [ ] Public functions have clear, domain-focused names
- [ ] Changesets validate all user input
- [ ] No Ecto queries in controllers
### Routing
- [ ] Verified routes (~p sigil) used, not string paths
- [ ] Pipelines group related plugs
- [ ] Resources use only needed actions
- [ ] Scopes group related routes
### Plugs
- [ ] Authentication/authorization via plugs
- [ ] Plugs are composable and single-purpose
- [ ] Halt called after sending response in plugs
### JSON APIs
- [ ] Proper content negotiation
- [ ] Consistent error response format
- [ ] Pagination for list endpoints
## Valid Patterns (Do NOT Flag)
- **Controller calling multiple contexts** - Valid for orchestration
- **Inline Ecto query in context** - Context owns its data access
- **Using `action_fallback`** - Centralized error handling pattern
- **Multiple pipelines per route** - Composition is intentional
- **`Plug.Conn.halt/1` without send** - May be handled by fallback
## Context-Sensitive Rules
| Issue | Flag ONLY IF |
|-------|--------------|
| Missing changeset validation | Field accepts user input AND no validation exists |
| Controller too large | More than 7 actions OR actions > 20 lines |
| Missing authorization | Route is not public AND no auth plug in pipeline |
## Gates (run in order; each step has a pass condition)
1. **Anchored evidence** — For every planned finding, open the source and note **file path + line number** from that read (not from memory or diff snippets alone). **Pass:** each finding cites `path:line` that you opened.
2. **“Handled elsewhere” sweep** — Before reporting “missing validation,” “missing auth,” or “wrong status,” search the router (pipelines/scopes), controller (`action_fallback`, `plug`), and relevant context for existing checks. **Pass:** you recorded whether handling exists elsewhere (yes + where, or no after search).
3. **Verification protocol** — Load and apply [review-verification-protocol](../review-verification-protocol/SKILL.md) for the issue type. **Pass:** that skill’s pre-report checks for that finding class are satisfied before you write the finding.
4. **Finding shape** — Emit each issue as `[FILE:LINE] ISSUE_TITLE` with a one-line rationale tied to the cited code. **Pass:** every line matches that pattern.
## Before Submitting Findings
Do not report until **Gates** above pass. For full anti-false-positive steps, follow [review-verification-protocol](../review-verification-protocol/SKILL.md).
FILE:references/contexts.md
# Phoenix Contexts
## Bounded Contexts
### Domain Boundaries
```text
# GOOD - contexts bounded by domain
lib/my_app/
├── accounts/ # User identity & auth
│ ├── user.ex
│ └── accounts.ex
├── catalog/ # Product information
│ ├── product.ex
│ └── catalog.ex
└── orders/ # Purchase workflow
├── order.ex
└── orders.ex
# BAD - contexts bounded by technical layer
lib/my_app/
├── models/
├── queries/
└── services/
```
### Public API Design
```elixir
# GOOD - domain-focused function names
defmodule MyApp.Accounts do
def register_user(attrs)
def authenticate_user(email, password)
def reset_password(user, new_password)
end
# BAD - CRUD-focused names
defmodule MyApp.Accounts do
def create_user(attrs)
def get_user(id)
def update_user(user, attrs)
end
```
## Ecto Integration
### Changesets in Contexts
```elixir
defmodule MyApp.Accounts do
  alias MyApp.Repo
  alias MyApp.Accounts.User

  def register_user(attrs) do
    %User{}
    |> User.registration_changeset(attrs)
    |> Repo.insert()
  end

  def update_user(%User{} = user, attrs) do
    user
    |> User.update_changeset(attrs)
    |> Repo.update()
  end
end
### Schema Definitions
```elixir
defmodule MyApp.Accounts.User do
use Ecto.Schema
import Ecto.Changeset
schema "users" do
field :email, :string
field :password_hash, :string
field :password, :string, virtual: true
timestamps()
end
def registration_changeset(user, attrs) do
user
|> cast(attrs, [:email, :password])
|> validate_required([:email, :password])
|> validate_format(:email, ~r/@/)
|> validate_length(:password, min: 8)
|> unique_constraint(:email)
|> hash_password() # private helper (not shown) that fills :password_hash
end
end
```
## Cross-Context Communication
```elixir
# GOOD - contexts communicate through public APIs
defmodule MyApp.Orders do
alias MyApp.Accounts
def create_order(user_id, items) do
with {:ok, user} <- Accounts.get_user(user_id),
:ok <- Accounts.verify_can_purchase(user) do
# Create order
end
end
end
# BAD - reaching into another context's internals
defmodule MyApp.Orders do
alias MyApp.Accounts.User
alias MyApp.Repo
def create_order(user_id, items) do
user = Repo.get!(User, user_id) # Bypasses Accounts context!
# ...
end
end
```
## Review Questions
1. Are contexts bounded by business domain, not technical layer?
2. Do public functions have domain-focused names?
3. Are changesets used for all data validation?
4. Do contexts communicate through public APIs only?
FILE:references/controllers.md
# Phoenix Controllers
## Action Structure
### Keep Controllers Thin
```elixir
# GOOD - delegates to context
defmodule MyAppWeb.UserController do
use MyAppWeb, :controller
alias MyApp.Accounts
def create(conn, %{"user" => user_params}) do
case Accounts.register_user(user_params) do
{:ok, user} ->
conn
|> put_status(:created)
|> render(:show, user: user)
{:error, changeset} ->
conn
|> put_status(:unprocessable_entity)
|> render(:error, changeset: changeset)
end
end
end
# BAD - business logic in controller
defmodule MyAppWeb.UserController do
def create(conn, %{"user" => params}) do
changeset = User.changeset(%User{}, params)
if changeset.valid? do
# Validation logic here...
# Email verification logic here...
# Password hashing here...
end
end
end
```
### Action Fallback
```elixir
defmodule MyAppWeb.UserController do
use MyAppWeb, :controller
action_fallback MyAppWeb.FallbackController
def show(conn, %{"id" => id}) do
with {:ok, user} <- Accounts.get_user(id) do
render(conn, :show, user: user)
end
end
end
defmodule MyAppWeb.FallbackController do
use MyAppWeb, :controller
def call(conn, {:error, :not_found}) do
conn
|> put_status(:not_found)
|> put_view(MyAppWeb.ErrorJSON)
|> render(:"404")
end
def call(conn, {:error, %Ecto.Changeset{} = changeset}) do
conn
|> put_status(:unprocessable_entity)
|> put_view(MyAppWeb.ChangesetJSON)
|> render(:error, changeset: changeset)
end
end
```
## Parameter Handling
### Pattern Match in Function Head
```elixir
# GOOD - pattern match expected params
def update(conn, %{"id" => id, "user" => user_params}) do
# ...
end
# GOOD - handle missing params explicitly
def update(conn, %{"id" => id}) do
conn
|> put_status(:bad_request)
|> json(%{error: "Missing user params"})
end
```
### Strong Parameters via Changesets
```elixir
# Changeset controls which fields are accepted
def registration_changeset(user, attrs) do
user
|> cast(attrs, [:email, :password, :name]) # Only these fields
|> validate_required([:email, :password])
end
```
## HTTP Status Codes
| Action | Success | Common Errors |
|--------|---------|---------------|
| create | 201 Created | 422 Unprocessable |
| show | 200 OK | 404 Not Found |
| update | 200 OK | 404, 422 |
| delete | 204 No Content | 404 |
| index | 200 OK | - |
## Review Questions
1. Is business logic delegated to contexts?
2. Do actions use appropriate HTTP status codes?
3. Is action_fallback used for consistent error handling?
4. Are parameters validated via changesets?
FILE:references/plugs.md
# Phoenix Plugs
## Custom Plugs
### Module Plug Structure
```elixir
defmodule MyAppWeb.Plugs.RequireAuth do
import Plug.Conn
import Phoenix.Controller
def init(opts), do: opts
def call(conn, _opts) do
if conn.assigns[:current_user] do
conn
else
conn
|> put_status(:unauthorized)
|> put_view(MyAppWeb.ErrorJSON)
|> render(:"401")
|> halt() # IMPORTANT: halt after sending response
end
end
end
```
### Function Plug
```elixir
defmodule MyAppWeb.UserController do
  use MyAppWeb, :controller
  alias MyApp.Accounts

  plug :load_user when action in [:show, :edit, :update]

  defp load_user(conn, _opts) do
    case Accounts.get_user(conn.params["id"]) do
      {:ok, user} ->
        assign(conn, :user, user)

      {:error, :not_found} ->
        conn
        |> put_status(:not_found)
        |> render(:not_found)
        |> halt()
    end
  end
end
```
## Authentication Pattern
```elixir
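defmodule MyAppWeb.Plugs.LoadCurrentUser do
  import Plug.Conn
  alias MyApp.Accounts

  def init(opts), do: opts

  def call(conn, _opts) do
    user_id = get_session(conn, :user_id)

    cond do
      conn.assigns[:current_user] ->
        conn # Already loaded

      user_id ->
        # Assumes Accounts.get_user/1 returns {:ok, user} | {:error, :not_found},
        # so a stale session ID does not raise the way get_user!/1 would
        case Accounts.get_user(user_id) do
          {:ok, user} -> assign(conn, :current_user, user)
          {:error, :not_found} -> assign(conn, :current_user, nil)
        end

      true ->
        assign(conn, :current_user, nil)
    end
  end
end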
```
## Authorization Pattern
```elixir
defmodule MyAppWeb.Plugs.RequireAdmin do
import Plug.Conn
import Phoenix.Controller
def init(opts), do: opts
def call(conn, _opts) do
user = conn.assigns[:current_user]
if user && user.admin do
conn
else
conn
|> put_status(:forbidden)
|> put_view(MyAppWeb.ErrorJSON)
|> render(:"403")
|> halt()
end
end
end
```
## Plug Composition
```elixir
# In router
pipeline :authenticated do
plug MyAppWeb.Plugs.LoadCurrentUser
plug MyAppWeb.Plugs.RequireAuth
end
pipeline :admin do
plug MyAppWeb.Plugs.RequireAdmin
end
scope "/admin", MyAppWeb.Admin do
pipe_through [:browser, :authenticated, :admin]
# ...
end
```
## Common Mistakes
### Forgetting to Halt
```elixir
# BAD - continues to controller after sending response
def call(conn, _opts) do
if unauthorized?(conn) do
conn
|> send_resp(401, "Unauthorized")
# Missing halt()! Controller still runs
else
conn
end
end
# GOOD
def call(conn, _opts) do
if unauthorized?(conn) do
conn
|> send_resp(401, "Unauthorized")
|> halt()
else
conn
end
end
```
### Modifying Halted Conn
```elixir
# BAD - ignores halted status
def call(conn, _opts) do
conn = maybe_halt(conn)
assign(conn, :data, load_data()) # Runs even if halted!
end
# GOOD - check halted status
def call(conn, _opts) do
conn = maybe_halt(conn)
if conn.halted do
conn
else
assign(conn, :data, load_data())
end
end
```
## Review Questions
1. Do plugs call halt() after sending a response?
2. Is authentication handled via plugs, not controller logic?
3. Are plugs composable and single-purpose?
4. Is halted status checked before further processing?
FILE:references/routing.md
# Phoenix Routing
## Verified Routes
### Use ~p Sigil
```elixir
# GOOD - verified at compile time
~p"/users/#{user.id}"
~p"/users/#{user}/edit"
# BAD - string interpolation (no compile-time check)
"/users/#{user.id}"
"/users/#{user.id}/edit"
```
### In Templates
```heex
<%# GOOD %>
<.link navigate={~p"/users/#{@user}"}>Profile</.link>
<%# BAD %>
<.link navigate={"/users/#{@user.id}"}>Profile</.link>
```
## Pipelines
### Group Related Plugs
```elixir
pipeline :browser do
plug :accepts, ["html"]
plug :fetch_session
plug :fetch_live_flash
plug :put_root_layout, html: {MyAppWeb.Layouts, :root}
plug :protect_from_forgery
plug :put_secure_browser_headers
end
pipeline :api do
plug :accepts, ["json"]
end
pipeline :authenticated do
  plug MyAppWeb.Plugs.LoadCurrentUser # must run first: assigns :current_user
  plug MyAppWeb.Plugs.RequireAuth # then rejects requests without a user
end
```
### Compose Pipelines
```elixir
scope "/", MyAppWeb do
pipe_through [:browser, :authenticated]
resources "/settings", SettingsController, only: [:edit, :update]
end
scope "/admin", MyAppWeb.Admin do
pipe_through [:browser, :authenticated, :require_admin]
resources "/users", UserController
end
```
## Resources
### Limit Actions
```elixir
# GOOD - only needed actions
resources "/users", UserController, only: [:index, :show, :create]
resources "/sessions", SessionController, only: [:new, :create, :delete]
# BAD - all actions when not needed
resources "/users", UserController # Generates all 7 actions (8 routes: update answers PATCH and PUT)
```
### Nested Resources
```elixir
# GOOD - shallow nesting
resources "/posts", PostController do
resources "/comments", CommentController, only: [:create]
end
resources "/comments", CommentController, only: [:show, :update, :delete]
# BAD - deep nesting
resources "/users", UserController do
resources "/posts", PostController do
resources "/comments", CommentController # Too deep!
end
end
```
## Scopes
```elixir
# API versioning
scope "/api", MyAppWeb.API do
scope "/v1", V1 do
pipe_through :api
resources "/users", UserController
end
scope "/v2", V2 do
pipe_through [:api, :v2_transforms]
resources "/users", UserController
end
end
```
## Review Questions
1. Are verified routes (~p) used instead of string paths?
2. Are pipelines composed for authentication/authorization?
3. Do resources specify only needed actions?
4. Is nesting kept shallow (max 1 level)?
Reviews WidgetKit code for timeline management, view composition, configurable intents, and performance. Use when reviewing code with import WidgetKit, Timel...
---
name: widgetkit-code-review
description: Reviews WidgetKit code for timeline management, view composition, configurable intents, and performance. Use when reviewing code with import WidgetKit, TimelineProvider, Widget protocol, or @main struct Widget.
---
# WidgetKit Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| TimelineProvider, entries, reload policies | [references/timeline.md](references/timeline.md) |
| Widget families, containerBackground, deep linking | [references/views.md](references/views.md) |
| AppIntentConfiguration, EntityQuery, @Parameter | [references/intents.md](references/intents.md) |
| Refresh budget, memory limits, caching | [references/performance.md](references/performance.md) |
## Review Checklist
- [ ] `placeholder(in:)` returns immediately without async work
- [ ] Timeline entries spaced at least 5 minutes apart
- [ ] `getSnapshot` checks `context.isPreview` for gallery previews
- [ ] `containerBackground(for:)` used for iOS 17+ compatibility
- [ ] `widgetURL` used for systemSmall (not Link)
- [ ] No Button views (use Link or widgetURL)
- [ ] No AsyncImage or UIViewRepresentable in widget views
- [ ] Images downsampled to widget display size (~30MB limit)
- [ ] App Groups configured for data sharing between app and widget
- [ ] EntityQuery implements `defaultResult()` for non-optional parameters
- [ ] New intent parameters handle nil for existing widgets after updates
- [ ] `reloadTimelines` called strategically (not on every data change)
## When to Load References
- TimelineProvider implementation or refresh issues -> timeline.md
- Widget sizes, Lock Screen, containerBackground -> views.md
- Configurable widgets, AppIntent migration -> intents.md
- Memory issues, caching, budget management -> performance.md
## Review Questions
1. Does the widget provide fallback entries for when system delays refresh?
2. Are Lock Screen families (accessoryCircular/Rectangular/Inline) handled appropriately?
3. Would migrating from IntentConfiguration break existing user widgets?
4. Is timeline populated with future entries or does it rely on frequent refreshes?
5. Is data cached via App Groups for widget access?
## Hard gates (before reporting)
Complete **in order** for each finding you intend to report. Do not advance until the pass condition is satisfied.
1. **Location artifact** — The finding includes `[FILE:LINE]` (or a line range) copied from the current file contents; the path resolves in this repo.
2. **Scope read** — You read the full surrounding implementation: the `TimelineProvider` (including `placeholder`, `getSnapshot`, and `getTimeline` when relevant), the `@main` `Widget` / widget bundle, or the configurable widget’s `AppIntentConfiguration` / intent types—not only a diff hunk or snippet.
3. **Platform or system claim** (only if the finding depends on refresh budget, ~30MB memory guidance, Lock Screen accessory families, iOS 17+ `containerBackground`, App Groups data sharing, or migration from `IntentConfiguration` to `AppIntentConfiguration`) — You name one concrete artifact you inspected (for example `.entitlements` / App Group id in project, `WidgetFamily` handling in source, `IPHONEOS_DEPLOYMENT_TARGET`, or the exact reference subsection you used) **or** you drop or downgrade the finding to an open question.
4. **Protocol** — Pre-report steps in [review-verification-protocol](../review-verification-protocol/SKILL.md) are satisfied for this item (no finding if they are not).
Use the issue format `[FILE:LINE] ISSUE_TITLE` for each reported finding. Hard gate 4 is the full pre-report checklist for this skill’s review type.
FILE:references/intents.md
# Configurable Widgets
## Configuration Approaches
| Approach | iOS | Status |
|----------|-----|--------|
| `StaticConfiguration` | 14+ | Non-configurable widgets |
| `IntentConfiguration` | 14+ | Legacy SiriKit intents |
| `AppIntentConfiguration` | 17+ | Modern App Intents |
**Migration warning**: Changing from `IntentConfiguration` to `AppIntentConfiguration` can cause existing user widgets to disappear or freeze.
## AppIntentTimelineProvider
```swift
struct ConfigurableProvider: AppIntentTimelineProvider {
    func placeholder(in context: Context) -> MyEntry { .placeholder } // Sync, instant

    func snapshot(for config: MyWidgetIntent, in context: Context) async -> MyEntry {
        MyEntry(date: .now, item: config.selectedItem)
    }

    func timeline(for config: MyWidgetIntent, in context: Context) async -> Timeline<MyEntry> {
        let entry = MyEntry(date: .now, item: config.selectedItem)
        return Timeline(entries: [entry], policy: .after(.now.addingTimeInterval(900)))
    }
}
## Widget Configuration
```swift
struct MyWidgetIntent: WidgetConfigurationIntent {
static var title: LocalizedStringResource = "Configure Widget"
@Parameter(title: "Name", default: "Default") var name: String
@Parameter(title: "Style") var style: DisplayStyle // AppEnum
@Parameter(title: "Item") var selectedItem: ItemEntity? // AppEntity
}
struct MyWidget: Widget {
var body: some WidgetConfiguration {
AppIntentConfiguration(kind: "com.app.widget", intent: MyWidgetIntent.self,
provider: ConfigurableProvider()) { entry in
MyWidgetView(entry: entry)
}
}
}
```
## Dynamic Options
### EntityStringQuery for Custom Types
```swift
struct ItemEntity: AppEntity {
    static var typeDisplayRepresentation: TypeDisplayRepresentation = "Item" // required by AppEntity
    static var defaultQuery = ItemQuery()
    var id: String
    var name: String
    var displayRepresentation: DisplayRepresentation { DisplayRepresentation(title: "\(name)") }
}
struct ItemQuery: EntityStringQuery {
func entities(for identifiers: [String]) async throws -> [ItemEntity] { /* fetch by IDs */ }
func entities(matching string: String) async throws -> [ItemEntity] { /* search */ }
func suggestedEntities() async throws -> [ItemEntity] { /* default list */ }
func defaultResult() async -> ItemEntity? { /* REQUIRED for non-optional params */ }
}
```
### DynamicOptionsProvider for Simple Types
```swift
@Parameter(title: "Hour", optionsProvider: HourOptionsProvider()) var hour: Int
struct HourOptionsProvider: DynamicOptionsProvider {
func results() async throws -> [Int] { Array(0..<24) }
func defaultResult() async -> Int? { 12 }
}
```
## Critical Anti-Patterns
### Missing defaultResult()
```swift
// BAD: Widget shows "Select" instead of value
struct ItemQuery: EntityStringQuery { /* no defaultResult() */ }
// GOOD: Always implement for non-optional entity parameters
func defaultResult() async -> ItemEntity? { items.first }
```
### Ignoring Nil After App Updates
```swift
// BAD: Parameters added in updates are nil for existing widgets
let name = config.newParameter!.name // Crash! Force-unwraps a nil parameter
// GOOD: Handle optional parameters
let name = config.newParameter?.name ?? "Default"
```
### Heavy Work in Placeholder
```swift
// BAD: Blocks UI
func placeholder(in context: Context) -> Entry { Entry(data: fetchSync()) }
// GOOD: Return static data instantly
func placeholder(in context: Context) -> Entry { .placeholder }
```
### Breaking Migration
```swift
// BAD: Same kind causes widget disappearance
AppIntentConfiguration(kind: "widget", ...) // Was IntentConfiguration
// GOOD: Use new kind for new configuration type
AppIntentConfiguration(kind: "widget.v2", ...)
```
## Review Questions
1. **Does EntityQuery implement `defaultResult()`?** Missing causes "Select" UI instead of default.
2. **Are new parameters optional-safe?** Parameters added in updates are nil for existing widgets.
3. **Is placeholder instant?** Must be synchronous with static data only.
4. **Does migration use new kind?** Same kind string breaks existing widgets.
5. **Is configuration stored in timeline entry?** Entry must hold intent for view access.
6. **Are AppEntity types Codable?** Required for WidgetKit to persist configuration.
FILE:references/performance.md
# Widget Performance
## Budget System
Widgets operate under strict refresh budgets to conserve battery:
- **Daily budget**: 40-70 refreshes for frequently viewed widgets
- **Refresh interval**: Every 15-60 minutes in production
- **Debug mode**: No limits during development
### Timeline Policies
```swift
Timeline(entries: entries, policy: .atEnd) // Refresh when timeline exhausted
Timeline(entries: entries, policy: .after(date)) // Refresh after specific date
Timeline(entries: entries, policy: .never) // Manual refresh via reloadTimelines()
```
Populate timelines with as many future entries as possible. Keep entries at least 5 minutes apart.
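The spacing rule can be sketched as a small date generator. This is a minimal illustration that uses only Foundation; `entryDates(from:count:interval:)` and `minimumSpacing` are hypothetical names, not WidgetKit API, and the 5-minute floor is taken from the guidance above.

```swift
import Foundation

/// Smallest spacing between entries, per the guidance above (assumed constant).
let minimumSpacing: TimeInterval = 5 * 60

/// Returns `count` display dates starting at `start`, `interval` seconds apart.
/// Intervals shorter than `minimumSpacing` are raised to it.
func entryDates(from start: Date, count: Int, interval: TimeInterval) -> [Date] {
    let spacing = max(interval, minimumSpacing)
    return (0..<max(count, 0)).map { start.addingTimeInterval(Double($0) * spacing) }
}
```

Each returned date would then be mapped to one timeline entry, for example 24 hourly entries paired with `.atEnd`.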
## Memory Limits
Widgets are constrained to approximately **30MB** - this applies collectively across all timeline entries.
```swift
import ImageIO
import UIKit
// BAD: Loading full-resolution images
let image = UIImage(contentsOfFile: path)
// GOOD: Downsample to widget display size
func downsample(imageAt url: URL, to size: CGSize, scale: CGFloat) -> UIImage? {
    // Do not cache the full-size source; only the thumbnail is decoded
    let options = [kCGImageSourceShouldCache: false] as CFDictionary
    guard let source = CGImageSourceCreateWithURL(url as CFURL, options) else { return nil }
    let maxDim = max(size.width, size.height) * scale // Longest side in pixels
    let downsampleOptions = [
        kCGImageSourceCreateThumbnailFromImageAlways: true,
        kCGImageSourceThumbnailMaxPixelSize: maxDim
    ] as CFDictionary
    guard let cg = CGImageSourceCreateThumbnailAtIndex(source, 0, downsampleOptions) else { return nil }
    return UIImage(cgImage: cg)
}
```
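A rough calculation shows why downsampling matters against a budget of about 30MB. The sketch below assumes 4 bytes per pixel (8-bit RGBA), a 4032×3024 source photo, and a 510-pixel-square target; all three figures are illustrative assumptions, and real decoded sizes vary with pixel format.

```swift
/// Approximate decoded size of a bitmap, assuming 4 bytes per pixel (8-bit RGBA).
func decodedBytes(pixelWidth: Int, pixelHeight: Int, bytesPerPixel: Int = 4) -> Int {
    pixelWidth * pixelHeight * bytesPerPixel
}

// Hypothetical 12-megapixel photo: 48_771_072 bytes, already above ~30MB.
let fullResolution = decodedBytes(pixelWidth: 4032, pixelHeight: 3024)

// Hypothetical 510x510 pixel thumbnail: 1_040_400 bytes, roughly 1MB.
let downsampled = decodedBytes(pixelWidth: 510, pixelHeight: 510)
```

Under these assumptions a single undownsampled photo exceeds the limit on its own, while the thumbnail leaves room for many entries.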
## Data Fetching
Network calls must complete within timeline generation. Never call APIs in `getSnapshot()`:
```swift
func getTimeline(in context: Context, completion: @escaping (Timeline<Entry>) -> Void) {
    Task {
        guard let data = try? await fetchData() else {
            // Fetch failed: show cached data and retry in 15 minutes
            completion(Timeline(entries: [Entry(date: Date(), data: cachedData)], policy: .after(Date().addingTimeInterval(900))))
            return
        }
        // Fetch succeeded: show fresh data and refresh in an hour
        completion(Timeline(entries: [Entry(date: Date(), data: data)], policy: .after(Date().addingTimeInterval(3600))))
    }
}
func getSnapshot(in context: Context, completion: @escaping (Entry) -> Void) {
    // No network here: sample data for the gallery, cached data otherwise
    completion(context.isPreview ? .sample : Entry(date: Date(), data: cachedData ?? .sample))
}
```
For background downloads, use `onBackgroundURLSessionEvents` modifier on the widget configuration.
## Caching Strategies
### App Groups for Shared Data
```swift
let sharedDefaults = UserDefaults(suiteName: "group.com.yourapp.widgets")
// Main app: save and notify widget
func saveWidgetData(_ data: WidgetData) {
if let encoded = try? JSONEncoder().encode(data) {
sharedDefaults?.set(encoded, forKey: "widgetData")
WidgetCenter.shared.reloadAllTimelines()
}
}
// Widget: read cached data
func loadWidgetData() -> WidgetData? {
guard let data = sharedDefaults?.data(forKey: "widgetData") else { return nil }
return try? JSONDecoder().decode(WidgetData.self, from: data)
}
```
Both app and widget extension must have the same App Group in Signing & Capabilities.
## Critical Anti-Patterns
### AsyncImage Not Supported
```swift
// BAD: Widgets render synchronously
AsyncImage(url: imageURL)
// GOOD: Pre-fetch in timeline provider
Image(uiImage: cachedImage)
```
### Excessive Reloads
```swift
// BAD: Burns budget quickly
WidgetCenter.shared.reloadAllTimelines()
// GOOD: Reload specific widget strategically
WidgetCenter.shared.reloadTimelines(ofKind: "specificWidget")
```
### UIKit Components
```swift
// BAD: UIViewRepresentable not supported
MapViewRepresentable()
// GOOD: Use MKMapSnapshotter for map images
Image(uiImage: mapSnapshot)
```
### Keychain Access
Keychain can fail with `errSecInteractionNotAllowed` after extended periods. Use App Groups instead.
### Sparse Timelines
```swift
// BAD: Forces frequent refreshes
Timeline(entries: [entry], policy: .after(Date().addingTimeInterval(60)))
// GOOD: Pre-computed entries
let entries = (0..<24).map { Entry(date: Date().addingTimeInterval(Double($0) * 3600), data: data) }
Timeline(entries: entries, policy: .atEnd)
```
## Review Questions
1. Does the widget downsample images to display size, or load full-resolution assets?
2. Are timeline entries pre-computed for future dates to minimize refresh frequency?
3. Does `getSnapshot()` avoid network calls and use cached/sample data?
4. Is App Groups configured correctly for both app and widget extension targets?
5. Are `reloadTimelines()` calls strategic, or does every data update trigger a reload?
6. Does the widget view avoid AsyncImage and other async loading patterns?
FILE:references/timeline.md
# Timeline Management
## Core Concepts
WidgetKit renders widgets as static snapshots at predetermined times. The system controls refresh timing to optimize battery life, allowing 40-70 refreshes per day (every 15-60 minutes). Timeline entries should be at least 5 minutes apart. The system may delay refreshes significantly beyond requested times.
## TimelineProvider Protocol
Three required methods with distinct purposes:
| Method | Sync/Async | Purpose |
|--------|-----------|---------|
| `placeholder(in:)` | Synchronous | Redacted loading state; return immediately |
| `getSnapshot(in:completion:)` | Async | Widget gallery preview; check `context.isPreview` |
| `getTimeline(in:completion:)` | Async | Primary content; returns entries array + reload policy |
```swift
struct Provider: TimelineProvider {
func placeholder(in context: Context) -> Entry {
Entry(date: .now, data: .placeholder) // Must be instant
}
func getSnapshot(in context: Context, completion: @escaping (Entry) -> ()) {
completion(Entry(date: .now, data: context.isPreview ? .sample : .current))
}
func getTimeline(in context: Context, completion: @escaping (Timeline<Entry>) -> ()) {
let entries = (0..<12).map { hour in
Entry(date: Calendar.current.date(byAdding: .hour, value: hour, to: .now)!, data: .forHour(hour))
}
completion(Timeline(entries: entries, policy: .atEnd))
}
}
```
## TimelineEntry
Requires only `date` property. Add custom properties for widget data:
```swift
struct MyEntry: TimelineEntry {
let date: Date // Required: when to display
let relevance: TimelineEntryRelevance? // Optional: Smart Stack ranking
let title: String // Custom data
}
```
## Reload Policies
| Policy | Behavior | Use Case |
|--------|----------|----------|
| `.atEnd` | Request new timeline after last entry expires | Regularly changing content |
| `.after(Date)` | Wait until specified date | Known future update time |
| `.never` | No auto-refresh; requires `reloadTimelines` call | App-driven updates only |
All policies are suggestions. System decides actual timing based on budget and battery.
## App-Driven Reloads
```swift
WidgetCenter.shared.reloadTimelines(ofKind: "MyWidget") // Specific widget
WidgetCenter.shared.reloadAllTimelines() // All widgets
```
Limitations: Not immediate; may only update when app backgrounds; subject to daily budget.
## Critical Anti-Patterns
```swift
// BAD: Entries too close together
for minute in 0..<60 {
let date = Calendar.current.date(byAdding: .minute, value: minute, to: now)!
entries.append(Entry(date: date))
}
// GOOD: Reasonable intervals (5+ minutes minimum)
for hour in 0..<24 {
let date = Calendar.current.date(byAdding: .hour, value: hour, to: now)!
entries.append(Entry(date: date))
}
```
```swift
// BAD: Heavy work in synchronous placeholder
func placeholder(in context: Context) -> Entry {
let data = fetchDataSync() // Blocks UI, may timeout
return Entry(date: .now, data: data)
}
```
```swift
// BAD: Ignoring isPreview
func getSnapshot(in context: Context, completion: @escaping (Entry) -> ()) {
fetchRealData { completion(Entry(date: .now, data: $0)) } // Slow for gallery
}
// GOOD: Sample data for previews, real data otherwise
func getSnapshot(in context: Context, completion: @escaping (Entry) -> ()) {
if context.isPreview {
completion(Entry(date: .now, data: .sample))
} else {
completion(Entry(date: .now, data: .current))
}
}
```
```swift
// BAD: Expecting exact refresh timing
Timeline(entries: entries, policy: .after(exactDeadline)) // May refresh hours late
// GOOD: Include fallback entries past critical times
```
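The "fallback entries" advice in the last snippet can be sketched as follows. This is a Foundation-only illustration under stated assumptions: `DeadlineState` and `deadlineEntries(now:deadline:step:)` are hypothetical names, and the returned pairs stand in for real `TimelineEntry` values. The idea is that the timeline itself contains an entry dated at the deadline, so the widget changes state even if the next reload arrives late.

```swift
import Foundation

/// What the widget should show at a given moment (hypothetical states).
enum DeadlineState: Equatable {
    case countingDown(secondsLeft: TimeInterval)
    case expired
}

/// Builds (date, state) pairs leading up to `deadline`, plus an `.expired` pair dated
/// at the deadline itself, so the state flips without waiting for a reload.
func deadlineEntries(now: Date, deadline: Date, step: TimeInterval = 15 * 60)
    -> [(date: Date, state: DeadlineState)] {
    let spacing = max(step, 5 * 60) // Keep entries at least 5 minutes apart
    var result: [(date: Date, state: DeadlineState)] = []
    var cursor = now
    while cursor < deadline {
        let left = deadline.timeIntervalSince(cursor)
        result.append((date: cursor, state: .countingDown(secondsLeft: left)))
        cursor = cursor.addingTimeInterval(spacing)
    }
    result.append((date: deadline, state: .expired))
    return result
}
```

A provider could map these pairs to entries and still request a reload with `.after(deadline)`; the pre-computed `.expired` entry covers the gap if the system delays that reload.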
## Review Questions
1. **Does `placeholder(in:)` return immediately without async work?**
2. **Are timeline entries spaced at least 5 minutes apart?**
3. **Does `getSnapshot` check `context.isPreview` for gallery previews?**
4. **Is the reload policy appropriate?** Static: `.never`; Dynamic: `.atEnd`/`.after`
5. **Are there fallback entries past critical times?** System may delay refreshes.
6. **Is `reloadTimelines` called only when necessary?** Each call consumes budget.
FILE:references/views.md
# Widget Views
## Widget Families
**Home Screen:** `systemSmall`, `systemMedium`, `systemLarge`, `systemExtraLarge` (iPad only)
**Lock Screen (iOS 16+):** `accessoryCircular`, `accessoryRectangular`, `accessoryInline`
```swift
.supportedFamilies([.systemSmall, .systemMedium, .accessoryCircular, .accessoryRectangular])
```
## View Composition
Use `@Environment(\.widgetFamily)` for adaptive layouts:
```swift
@Environment(\.widgetFamily) var widgetFamily
var body: some View {
switch widgetFamily {
case .systemSmall: CompactView()
case .accessoryCircular: CircularWidgetView()
case .accessoryInline: Text(entry.summary)
default: DetailedView()
}
}
```
- Use `@Environment(\.widgetRenderingMode)` to detect Lock Screen vibrant mode
- `AccessoryWidgetBackground()` works for `accessoryCircular`/`accessoryRectangular` only
- Use `ViewThatFits` for content that may truncate
## containerBackground
**Required for iOS 17+.** Without this modifier, the widget shows an error instead of its content.
```swift
Text("Content")
.containerBackground(for: .widget) { Color.blue }
```
**Backwards compatibility:**
```swift
extension View {
    // @ViewBuilder lets the two branches return different view types
    @ViewBuilder
    func widgetBackground(_ bg: some View) -> some View {
        if #available(iOSApplicationExtension 17.0, *) {
            containerBackground(for: .widget) { bg }
        } else {
            background(bg)
        }
    }
}
```
**Configuration modifiers:**
- `.containerBackgroundRemovable(false)` - Prevent removal in StandBy
- `.contentMarginsDisabled()` - Opt out of automatic margins
## Deep Linking
| Size | Method | Notes |
|------|--------|-------|
| `systemSmall` | `widgetURL()` | Entire widget is one tap target |
| `systemMedium`/`Large` | `Link` or `widgetURL()` | Multiple tappable regions |
```swift
// Small widgets: entire widget taps
.widgetURL(URL(string: "myapp://item/\(entry.id)")!)
// Medium/Large: multiple targets
Link(destination: URL(string: "myapp://section1")!) { Text("Section 1") }
```
Handle in app with `.onOpenURL { url in handleDeepLink(url) }`
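Since a custom URL scheme can be opened from outside the widget, the handler is safer when it treats the URL as untrusted input. The sketch below is one possible shape for that check, using only Foundation. `DeepLink` and `parseDeepLink(_:)` are hypothetical names, the accepted routes mirror the `myapp://` examples above, and the item identifier is assumed to be a positive integer.

```swift
import Foundation

/// Destinations the app accepts from widget taps (hypothetical routes).
enum DeepLink: Equatable {
    case item(id: Int)
    case section(name: String)
}

/// Returns a route only for URLs that match a known shape; everything else is rejected.
func parseDeepLink(_ url: URL) -> DeepLink? {
    guard url.scheme == "myapp", let host = url.host else { return nil }
    let parts = url.pathComponents.filter { $0 != "/" }
    switch (host, parts.count) {
    case ("item", 1):
        // Assumes item identifiers are positive integers
        guard let id = Int(parts[0]), id > 0 else { return nil }
        return .item(id: id)
    case ("section1", 0), ("section2", 0):
        return .section(name: host)
    default:
        return nil
    }
}
```

Inside `.onOpenURL`, the app would navigate only when `parseDeepLink(url)` returns a value and ignore the URL otherwise.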
## Critical Anti-Patterns
| Issue | Problem |
|-------|---------|
| Missing `containerBackground` | iOS 17 shows error instead of widget |
| `Link` in `systemSmall` | Silently fails, only `widgetURL` works |
| `Button` without an App Intent | Does nothing in widgets; use `Link` or `widgetURL`, or `Button(intent:)` on iOS 17+ |
| Same view for all families | Content truncated or wasted space |
| `AccessoryWidgetBackground` in `accessoryInline` | Renders empty view |
| No URL validation in `onOpenURL` | Security risk from malformed deep links |
## Review Questions
1. Does the widget use `containerBackground(for:)` for iOS 17+ compatibility?
2. Are Lock Screen families handled with appropriate compact layouts?
3. Is `widgetURL` used for `systemSmall` instead of `Link`?
4. Does the code avoid plain `Button` views (only `Button(intent:)` on iOS 17+ works in widgets)?
5. Is `AccessoryWidgetBackground` excluded from `accessoryInline` contexts?
6. Are deep link URLs validated before navigation?
Reviews watchOS code for app lifecycle, complications (ClockKit/WidgetKit), WatchConnectivity, and performance constraints. Use when reviewing code with impo...
---
name: watchos-code-review
description: Reviews watchOS code for app lifecycle, complications (ClockKit/WidgetKit), WatchConnectivity, and performance constraints. Use when reviewing code with import WatchKit, WKExtension, WKApplicationDelegate, WCSession, or watchOS-specific patterns.
---
# watchOS Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| App lifecycle, scenes, background modes, extended runtime | [references/lifecycle.md](references/lifecycle.md) |
| ClockKit, WidgetKit, timeline providers, Smart Stack | [references/complications.md](references/complications.md) |
| WCSession, message passing, file transfer, reachability | [references/connectivity.md](references/connectivity.md) |
| Memory limits, background refresh, battery optimization | [references/performance.md](references/performance.md) |
## Review Checklist
- [ ] SwiftUI App protocol used with `@WKApplicationDelegateAdaptor` for lifecycle events
- [ ] `scenePhase` read from root view (not sheets/modals where it's always `.active`)
- [ ] `WKExtendedRuntimeSession` started only while app is active (not from background)
- [ ] Workout sessions recovered in `applicationDidFinishLaunching` (not just delegate)
- [ ] Background tasks scheduled at least 5 minutes apart; next scheduled before completing current
- [ ] `URLSessionDownloadTask` (not `DataTask`) used for background network requests
- [ ] WidgetKit used instead of ClockKit for watchOS 9+ complications
- [ ] Timeline includes future entries (not just current state); gaps avoided
- [ ] `TimelineEntryRelevance` implemented for Smart Stack prioritization
- [ ] WCSession delegate set before `activate()`; singleton pattern used
- [ ] `isReachable` checked before `sendMessage`; `transferUserInfo` for critical data
- [ ] Received files moved synchronously before delegate callback returns
## When to Load References
- Reviewing app lifecycle, background modes, or extended sessions -> lifecycle.md
- Reviewing complications, widgets, or timeline providers -> complications.md
- Reviewing WCSession, iPhone-Watch communication -> connectivity.md
- Reviewing memory, battery, or performance issues -> performance.md
## Output Format
Report issues using: `[FILE:LINE] ISSUE_TITLE`
Examples:
- `[WatchApp.swift:18] WKExtendedRuntimeSession started while app not active`
- `[ConnectivityManager.swift:42] WCSession.activate() before delegate assignment`
- `[ComplicationTimeline.swift:67] Timeline has no future entries`
## Hard gates (before reporting)
Complete **in order** for each finding you intend to report. Do not advance until the pass condition is satisfied.
1. **Location artifact** — The finding includes `[FILE:LINE]` (or a line range) copied from the current file contents; the path resolves in this repo.
2. **Scope read** — You read the full surrounding unit: the `View` body, `WKApplicationDelegate` / scene method, `TimelineProvider` implementation, `WCSessionDelegate` callback, or workout/background task handler that owns the behavior—not only a diff hunk.
3. **watchOS or pairing claim** (only if the finding depends on background modes, complication/timeline contracts, `WCSession` reachability or transfer semantics, workout or extended runtime rules, or device-specific limits) — You name one concrete artifact you inspected (for example `Info.plist` / target capabilities for background modes, the `WK*` / `WCSession` call order in source, entitlements, or a subsection you read in the matching doc from [Quick Reference](#quick-reference)) **or** you downgrade the item to an open question in [Review Questions](#review-questions).
4. **Protocol** — Pre-report steps in [review-verification-protocol](../review-verification-protocol/SKILL.md) are satisfied for this item (no finding if they are not).
Use the issue format `[FILE:LINE] ISSUE_TITLE` for each reported finding. Hard gate 4 is the full pre-report checklist for this skill’s review type.
## Review Questions
1. Is the app using modern SwiftUI lifecycle with delegate adaptor?
2. Are background tasks completing properly (calling `setTaskCompletedWithSnapshot`)?
3. Is UI update frequency reduced when `isLuminanceReduced` is true?
4. Are WatchConnectivity delegate callbacks dispatching to main thread?
5. Is `TabView` nested within another `TabView`? (Memory leak on watchOS)
FILE:references/complications.md
# watchOS Complications
## Evolution
- **ClockKit**: Deprecated framework (watchOS 2-8)
- **WidgetKit**: Modern replacement (watchOS 9+)
- **watchOS 10**: Smart Stack with relevance-based prioritization
- **watchOS 11**: `RelevantContext` API for context-aware widgets
## Widget Families (WidgetKit)
| Family | Use Case | ClockKit Equivalent |
|--------|----------|---------------------|
| `accessoryRectangular` | Multiple lines, graphs | `graphicRectangular` |
| `accessoryCircular` | Gauges, progress | `graphicCircular` variants |
| `accessoryInline` | Single text line | `utilitarianSmallFlat` |
| `accessoryCorner` | Icon + curved label (watchOS only) | `utilitarianSmall` |
## Timeline Provider Types
### Static Widget
```swift
struct Provider: TimelineProvider {
func placeholder(in context: Context) -> SimpleEntry
func getSnapshot(in context: Context, completion: @escaping (SimpleEntry) -> ())
func getTimeline(in context: Context, completion: @escaping (Timeline<SimpleEntry>) -> ())
}
```
### Configurable Widget (AppIntents)
```swift
struct Provider: AppIntentTimelineProvider {
func placeholder(in context: Context) -> SimpleEntry
func snapshot(for configuration: ConfigIntent, in context: Context) async -> SimpleEntry
func timeline(for configuration: ConfigIntent, in context: Context) async -> Timeline<SimpleEntry>
func recommendations() -> [AppIntentRecommendation<ConfigIntent>]
}
```
## Smart Stack Relevance
```swift
struct SimpleEntry: TimelineEntry {
var date: Date
var event: Event?
var relevance: TimelineEntryRelevance? {
guard let event = event else {
return TimelineEntryRelevance(score: 0)
}
return TimelineEntryRelevance(
score: 10,
duration: event.endDate.timeIntervalSince(date)
)
}
}
```
## Critical Anti-Patterns
### 1. Exceeding Refresh Budget
```swift
// BAD: Called on every data change
func dataDidUpdate() {
WidgetCenter.shared.reloadTimelines(ofKind: "MyWidget")
}
// GOOD: Throttle reloads, use timeline entries
func getTimeline(...) {
var entries: [Entry] = []
for hourOffset in 0..<24 {
let date = Calendar.current.date(byAdding: .hour, value: hourOffset, to: Date())!
entries.append(Entry(date: date, data: predictedData(for: date)))
}
completion(Timeline(entries: entries, policy: .atEnd))
}
```
**Budget**: ~40-70 refreshes/day (~every 15-60 minutes)
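The "throttle reloads" half of the advice is not shown in the snippet above, so here is a minimal sketch of one way to do it. `ReloadThrottle` is a hypothetical helper built on Foundation only; the 15-minute interval used in the usage note is an assumption derived from the budget figures, not a system constant.

```swift
import Foundation

/// Decides whether a reload request should go through, based on when the last one did.
struct ReloadThrottle {
    let minimumInterval: TimeInterval
    private(set) var lastReload: Date?

    init(minimumInterval: TimeInterval) {
        self.minimumInterval = minimumInterval
    }

    /// Returns true (and records `now`) only if enough time has passed since the last reload.
    mutating func shouldReload(now: Date = Date()) -> Bool {
        if let last = lastReload, now.timeIntervalSince(last) < minimumInterval {
            return false
        }
        lastReload = now
        return true
    }
}
```

In `dataDidUpdate()`, the call to `WidgetCenter.shared.reloadTimelines(ofKind:)` would run only when `shouldReload()` returns true, for example with `ReloadThrottle(minimumInterval: 15 * 60)`.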
### 2. Gaps in Timeline
```swift
// BAD: Only entries for events
func getTimeline(...) {
for event in events {
entries.append(Entry(date: event.startDate, event: event))
}
}
// GOOD: Entries for state changes
func getTimeline(...) {
entries.append(Entry(date: Date(), event: currentEvent))
for event in upcomingEvents {
entries.append(Entry(date: event.startDate, event: event))
entries.append(Entry(date: event.endDate, event: nil)) // End state
}
}
```
### 3. Expensive Operations in Placeholder
```swift
// BAD: Blocks UI
func placeholder(in context: Context) -> Entry {
let data = fetchLatestData() // Network call!
return Entry(date: Date(), data: data)
}
// GOOD: Return static data immediately
func placeholder(in context: Context) -> Entry {
return Entry(date: Date(), data: .placeholder)
}
```
### 4. AsyncImage in Widget
```swift
// BAD: Won't work
var body: some View {
AsyncImage(url: imageURL) // Widgets can't do async in view
}
// GOOD: Fetch in timeline provider
func getTimeline(...) {
let imageData = try? Data(contentsOf: imageURL)
let entry = Entry(date: Date(), imageData: imageData)
completion(Timeline(entries: [entry], policy: .atEnd))
}
```
### 5. Not Implementing Migration
```swift
// BAD: User complications become blank
class ComplicationController: NSObject, CLKComplicationDataSource {
// Missing: var widgetMigrator: CLKComplicationWidgetMigrator
}
// GOOD: Implement migration
extension ComplicationController: CLKComplicationWidgetMigrator {
func widgetConfiguration(
from descriptor: CLKComplicationDescriptor
) async -> CLKComplicationWidgetMigrationConfiguration? {
return CLKComplicationStaticWidgetMigrationConfiguration(
kind: "MyWidget",
extensionBundleIdentifier: "com.myapp.widget"
)
}
}
```
## Key Modifiers
| Modifier | Purpose |
|----------|---------|
| `.widgetAccentable()` | Mark for accent coloring |
| `.widgetLabel { }` | Curved text for corner/circular |
| `.containerBackground(for: .widget)` | Smart Stack background |
| `.privacySensitive()` | Redact in Always-On |
| `AccessoryWidgetBackground()` | Consistent backdrop |
## Always-On Display
```swift
@Environment(\.isLuminanceReduced) var isLuminanceReduced
var body: some View {
VStack {
Image(systemName: "heart.fill")
.widgetAccentable()
if isLuminanceReduced {
Text("\(value)")
.redacted(reason: .placeholder) // Hide sensitive
} else {
Text("\(value) BPM")
.privacySensitive()
}
}
}
```
## Review Questions
1. Is WidgetKit used instead of ClockKit (watchOS 9+)?
2. Does `placeholder()` return immediately without async work?
3. Does the timeline include future entries (not just current)?
4. Is `TimelineEntryRelevance` implemented for Smart Stack?
5. Is `.privacySensitive()` applied to sensitive content?
6. Is `@Environment(\.isLuminanceReduced)` checked for Always-On?
7. Are images pre-fetched (not using AsyncImage)?
8. Is ClockKit migration implemented if updating from older app?
FILE:references/connectivity.md
# WatchConnectivity
## Communication Methods
| Method | Use Case | Guaranteed | Queuing |
|--------|----------|------------|---------|
| `sendMessage(_:)` | Real-time, immediate | No | None |
| `transferUserInfo(_:)` | Critical data | Yes | FIFO |
| `updateApplicationContext(_:)` | State sync, latest only | Yes (latest) | Overwrites |
| `transferFile(_:)` | Large files | Yes | FIFO |
| `transferCurrentComplicationUserInfo(_:)` | Complication data | Yes | Budget limited |
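
The table can be read as a decision function. The sketch below models that choice without importing WatchConnectivity; `Transport`, `PayloadTraits`, and `chooseTransport(for:isReachable:)` are hypothetical names, the priority order is one reasonable reading of the table, and the complication-specific method is left out for brevity.

```swift
/// The transfer methods from the table, as plain cases.
enum Transport: Equatable {
    case sendMessage
    case transferUserInfo
    case updateApplicationContext
    case transferFile
}

/// What the caller knows about a payload (hypothetical model).
struct PayloadTraits {
    var isFile = false
    var onlyLatestMatters = false
    var mustBeDelivered = false
}

/// Picks a method following the table; returns nil when a best-effort
/// message cannot be sent because the counterpart is not reachable.
func chooseTransport(for traits: PayloadTraits, isReachable: Bool) -> Transport? {
    if traits.isFile { return .transferFile }
    if traits.onlyLatestMatters { return .updateApplicationContext }
    if traits.mustBeDelivered { return .transferUserInfo }
    return isReachable ? .sendMessage : nil
}
```

Returning `nil` for an unreachable best-effort message makes the dropped send explicit, so the caller can decide whether to queue or discard it.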
## Session Setup
```swift
final class WatchConnectivityService: NSObject, WCSessionDelegate {
static let shared = WatchConnectivityService()
override private init() {
super.init()
#if !os(watchOS)
guard WCSession.isSupported() else { return }
#endif
WCSession.default.delegate = self
WCSession.default.activate()
}
}
```
## Required Delegate Methods
**iOS (all three required):**
- `session(_:activationDidCompleteWith:error:)`
- `sessionDidBecomeInactive(_:)`
- `sessionDidDeactivate(_:)`
**watchOS (one required):**
- `session(_:activationDidCompleteWith:error:)`
## Pre-Send Validation
```swift
private func canSendToPeer() -> Bool {
guard WCSession.default.activationState == .activated else { return false }
#if os(watchOS)
guard WCSession.default.isCompanionAppInstalled else { return false }
#else
guard WCSession.default.isWatchAppInstalled else { return false }
#endif
return true
}
// For sendMessage only
if WCSession.default.isReachable {
WCSession.default.sendMessage(message, replyHandler: nil, errorHandler: nil)
}
```
## Critical Anti-Patterns
### 1. Setup in View Controller
```swift
// BAD: Won't be called during background launches
class MyViewController: UIViewController {
override func viewDidLoad() {
WCSession.default.delegate = self
WCSession.default.activate()
}
}
// GOOD: Singleton in early lifecycle
// In AppDelegate
func application(...) -> Bool {
_ = WatchConnectivityService.shared
return true
}
```
### 2. Using sendMessage for Critical Data
```swift
// BAD: Lost when counterpart not reachable
func sendToWatch(_ data: [String: Any]) {
WCSession.default.sendMessage(data, replyHandler: nil, errorHandler: nil)
}
// GOOD: Use appropriate method based on criticality
func sendToWatch(_ data: [String: Any], critical: Bool) {
guard canSendToPeer() else { return }
if critical {
WCSession.default.transferUserInfo(data)
} else if WCSession.default.isReachable {
WCSession.default.sendMessage(data, replyHandler: nil, errorHandler: nil)
}
}
```
### 3. UI Updates on Background Thread
```swift
// BAD: Delegate runs on background thread
func session(_ session: WCSession, didReceiveMessage message: [String: Any]) {
self.label.text = message["text"] as? String // Crash!
}
// GOOD: Dispatch to main
func session(_ session: WCSession, didReceiveMessage message: [String: Any]) {
DispatchQueue.main.async {
self.label.text = message["text"] as? String
}
}
```
### 4. Async File Handling
```swift
// BAD: File deleted before async completes
func session(_ session: WCSession, didReceive file: WCSessionFile) {
DispatchQueue.global().async {
try? FileManager.default.moveItem(at: file.fileURL, to: destination)
}
}
// GOOD: Synchronous move first
func session(_ session: WCSession, didReceive file: WCSessionFile) {
do {
try FileManager.default.moveItem(at: file.fileURL, to: destination)
DispatchQueue.main.async {
self.processFile(at: destination)
}
} catch {
print("Failed: \(error)")
}
}
```
### 5. Not Reactivating After Deactivation
```swift
// BAD: Session unusable after watch swap
func sessionDidDeactivate(_ session: WCSession) {
// Nothing
}
// GOOD: Reactivate for watch swaps
func sessionDidDeactivate(_ session: WCSession) {
WCSession.default.activate()
}
```
### 6. Reply Handler When Not Expecting Reply
```swift
// BAD: OS generates errors
WCSession.default.sendMessage(data, replyHandler: { _ in }, errorHandler: nil)
// GOOD: nil when no reply expected
WCSession.default.sendMessage(data, replyHandler: nil, errorHandler: { error in
print("Error: \(error)")
})
```
## Data Type Requirements
Only Plist-encodable types allowed:
- String, Int, Double, Bool
- Data
- Array, Dictionary (of above types)
```swift
// BAD: Custom types
WCSession.default.sendMessage(["user": myUser], ...)
// GOOD: Encode first
let data = try JSONEncoder().encode(myUser)
WCSession.default.sendMessage(["userData": data], ...)
```
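A pre-send check can catch unsupported payloads before they reach `WCSession`. The sketch below is a simplified validator covering only the types listed above. `isPlistValue(_:)` and `User` are hypothetical, and this is not the validation WatchConnectivity performs internally.

```swift
import Foundation

/// Rough check that a value is built only from the types listed above.
func isPlistValue(_ value: Any) -> Bool {
    switch value {
    case is String, is Int, is Double, is Bool, is Data:
        return true
    case let array as [Any]:
        return array.allSatisfy(isPlistValue)
    case let dict as [String: Any]:
        return dict.values.allSatisfy(isPlistValue)
    default:
        return false
    }
}

/// Hypothetical model type used to illustrate the encode-first pattern.
struct User: Codable {
    let id: Int
    let name: String
}
```

A dictionary holding a raw `User` fails this check, while the same user encoded to `Data` with `JSONEncoder` passes, which matches the BAD and GOOD cases above.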
## Review Questions
1. Is `WCSession.isSupported()` checked on iOS before setup?
2. Is delegate set before `activate()` (use singleton)?
3. Is `activationState == .activated` checked before sending?
4. Is `isReachable` checked for `sendMessage` calls?
5. Is `transferUserInfo` used for data that must be delivered?
6. Are delegate callbacks dispatching UI updates to main thread?
7. Are received files moved synchronously before delegate returns?
8. Is `sessionDidDeactivate` reactivating the session on iOS?
9. Are only Plist-encodable types being sent?
FILE:references/lifecycle.md
# WatchKit App Lifecycle
## Lifecycle Architecture
watchOS uses two lifecycle models:
### SwiftUI App Protocol (Modern)
```swift
@main
struct MyWatchApp: App {
@WKApplicationDelegateAdaptor var appDelegate: MyAppDelegate
var body: some Scene {
WindowGroup {
ContentView()
}
}
}
```
### WKApplicationDelegate
Use for lifecycle events not covered by SwiftUI's `scenePhase`. Note: `WKExtensionDelegate` was renamed to `WKApplicationDelegate` in Xcode 14.
### Scene Phase States
| State | Description |
|-------|-------------|
| `.active` | App in foreground, user can interact |
| `.inactive` | Visible but no interaction (wrist lowered, screen on) |
| `.background` | Not visible, may be terminated |
**Note**: On watchOS, `.inactive` does NOT mean the app isn't running.
## Background Execution Modes
| Mode | Use Case | Constraints |
|------|----------|-------------|
| `WKApplicationRefreshBackgroundTask` | Data updates | 4 per hour; 4s CPU, 15s total |
| `HKWorkoutSession` | Workout tracking | Continuous; use for workouts only |
| `WKExtendedRuntimeSession` | Self-care, mindfulness | Start while active only |
| Background URLSession | Downloads | Requires complication or dock |
### WKExtendedRuntimeSession Types
| Type | Duration | Notes |
|------|----------|-------|
| Self Care | 10 minutes | |
| Mindfulness | 1 hour | |
| Physical Therapy | 1 hour | Allows background multitasking |
| Health Monitoring | Variable | Requires entitlement |
| Alarm | 30 minutes | Use `startAtDate()` to schedule |
## Critical Anti-Patterns
### 1. Heavy Work in Lifecycle Methods
```swift
// BAD: Slows resume time
func applicationDidBecomeActive() {
loadAllDataFromDisk()
syncWithServer()
}
// GOOD: Defer to background
func applicationDidBecomeActive() {
Task.detached(priority: .background) {
await self.prefetchData()
}
}
```
### 2. Reading scenePhase in Sheets
```swift
// BAD: Always returns .active in sheets
struct SettingsSheet: View {
    @Environment(\.scenePhase) var scenePhase // Broken!
}
// GOOD: Pass from root view
struct ContentView: View {
    @Environment(\.scenePhase) var scenePhase
    @State private var showSettings = false
    var body: some View {
        Button("Settings") { showSettings = true }
            .sheet(isPresented: $showSettings) {
                SettingsSheet(scenePhase: scenePhase)
            }
    }
}
```
### 3. Starting Extended Sessions from Background
```swift
// BAD: Cannot start from background
func applicationDidEnterBackground() {
let session = WKExtendedRuntimeSession()
session.start() // Error!
}
// GOOD: Start while active
func startMindfulnessSession() {
guard WKApplication.shared().applicationState == .active else { return }
extendedSession = WKExtendedRuntimeSession()
extendedSession?.start()
}
```
### 4. Not Recovering Workout Sessions
```swift
// BAD: handleActiveWorkoutRecovery NOT called on reboot
class AppDelegate: NSObject, WKApplicationDelegate {
    func handleActiveWorkoutRecovery() {
        recoverWorkout()
    }
}
// GOOD: Check in applicationDidFinishLaunching
func applicationDidFinishLaunching() {
    Task {
        do {
            // The async call returns an optional session; the builder comes from the session
            if let session = try await HKHealthStore().recoverActiveWorkoutSession() {
                workoutManager.resume(session: session, builder: session.associatedWorkoutBuilder())
            }
        } catch {
            // Recovery failed; nothing to resume
        }
    }
}
```
### 5. Network Calls During Background Transition
```swift
// BAD: Not enough time
func applicationWillResignActive() {
URLSession.shared.dataTask(with: url) { ... } // Won't complete
}
// GOOD: Use expiring activity
func applicationWillResignActive() {
ProcessInfo.processInfo.performExpiringActivity(withReason: "Sync") { expired in
guard !expired else { return }
self.quickSync()
}
}
```
## Background App Refresh
### Correct Pattern
```swift
func handle(_ backgroundTasks: Set<WKRefreshBackgroundTask>) {
    for task in backgroundTasks {
        if let refreshTask = task as? WKApplicationRefreshBackgroundTask {
            // 1. Schedule next FIRST
            scheduleNextRefresh()
            // 2. Use download task (not data task)
            let config = URLSessionConfiguration.background(withIdentifier: "com.app.refresh")
            let session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
            session.downloadTask(with: url).resume()
            // 3. Complete this task
            refreshTask.setTaskCompletedWithSnapshot(false)
        }
    }
}
func scheduleNextRefresh() {
    // At least 5 minutes in future
    let preferredDate = Date().addingTimeInterval(5 * 60)
    WKApplication.shared().scheduleBackgroundRefresh(
        withPreferredDate: preferredDate,
        userInfo: nil
    ) { _ in }
}
```
## Review Questions
1. Is SwiftUI App protocol used with `@WKApplicationDelegateAdaptor` for lifecycle events?
2. Is `scenePhase` read from root view (not sheets/modals)?
3. Are extended runtime sessions started only while app is active?
4. Is `HKHealthStore().recoverActiveWorkoutSession()` called in `applicationDidFinishLaunching`?
5. Are background tasks scheduled at least 5 minutes apart?
6. Is `URLSessionDownloadTask` (not `DataTask`) used for background network?
7. Is next refresh scheduled BEFORE completing current task?
FILE:references/performance.md
# watchOS Performance
## Constraints
### Memory
| Constraint | Limit |
|------------|-------|
| Device RAM | ~1 GB (Series 9/10/Ultra 2) |
| App bundle | ~50 MB |
| Widget/Complication images | ~30 MB |
| Background task memory | Limited |
### CPU and Battery
| Constraint | Limit |
|------------|-------|
| CPU usage threshold | <80% sustained |
| Background task duration | ~3 min when backgrounding; ~30s when resumed |
| Background refresh | 4 per hour with complication; 15+ min apart |
| Extended runtime | Battery-intensive; end promptly |
### Network
| Consideration | Details |
|---------------|---------|
| Connection | URLSession abstracts Bluetooth/Wi-Fi/cellular |
| WebSocket/Stream | Not supported |
| Background minimum interval | 10+ minutes recommended |
## Critical Anti-Patterns
### 1. Nested TabViews (Memory Leak)
```swift
// BAD: Causes memory leaks
NavigationStack {
TabView {
TabView { // DON'T NEST!
ContentView()
}
}
}
// GOOD: Single level
NavigationStack {
TabView {
ContentView()
}
}
```
### 2. Not Completing Background Tasks
```swift
// BAD: Missing completion handler
func handle(_ backgroundTasks: Set<WKRefreshBackgroundTask>) {
for task in backgroundTasks {
if let refreshTask = task as? WKApplicationRefreshBackgroundTask {
doWork()
// MISSING: setTaskCompletedWithSnapshot!
}
}
}
// GOOD: Always complete every task, including types you don't handle
func handle(_ backgroundTasks: Set<WKRefreshBackgroundTask>) {
for task in backgroundTasks {
if let refreshTask = task as? WKApplicationRefreshBackgroundTask {
defer { refreshTask.setTaskCompletedWithSnapshot(false) }
doWork()
} else {
task.setTaskCompletedWithSnapshot(false) // Unhandled task types must be completed too
}
}
}
```
### 3. Protected File Access in Background
```swift
// BAD: Fails when screen locked
func backgroundHandler() {
let data = try? Data(contentsOf: protectedFileURL) // Fails!
}
// GOOD: Use no file protection for background data
try data.write(to: url, options: .noFileProtection)
```
### 4. WKInterface Property Updates
```swift
// BAD: Each property = ~200ms message
func updateUI() {
label.setText(newText) // Update 1
label.setTextColor(.red) // Update 2
image.setImage(newImage) // Update 3
}
// GOOD: Only set when values change
func updateUI() {
if textChanged {
label.setText(newText)
}
if colorChanged {
label.setTextColor(.red)
}
}
```
### 5. Constant UI Updates During Workout
```swift
// BAD: Updates even when dimmed
struct WorkoutView: View {
let timer = Timer.publish(every: 1, on: .main, in: .common).autoconnect()
var body: some View {
Text("\(heartRate)")
.onReceive(timer) { _ in updateUI() }
}
}
// GOOD: Adaptive update rate
struct WorkoutView: View {
@Environment(\.isLuminanceReduced) var isLuminanceReduced
var body: some View {
TimelineView(.periodic(from: .now, by: updateInterval)) { _ in
Text("\(heartRate)")
}
}
var updateInterval: TimeInterval {
isLuminanceReduced ? 10.0 : 1.0 // Slower when dimmed
}
}
```
### 6. Large WKInterfaceTable
```swift
// BAD: All cells load upfront (no reuse)
func loadTable(items: [Item]) {
table.setNumberOfRows(items.count, withRowType: "Row") // 100+ rows = bad
}
// GOOD: Keep under 20 rows, use incremental updates
func loadTable(items: [Item]) {
let limitedItems = Array(items.prefix(20))
table.setNumberOfRows(limitedItems.count, withRowType: "Row")
}
func addRows(at indexes: IndexSet) {
table.insertRows(at: indexes, withRowType: "Row") // Incremental
}
```
### 7. Loading All Data
```swift
// BAD: Load everything
func loadRecords() async -> [Record] {
return await database.fetchAll()
}
// GOOD: Load what's displayed
func loadRecords(limit: Int = 10) async -> [Record] {
return await database.fetch(limit: limit)
}
```
## Battery Optimization
### Extended Runtime Sessions
```swift
// Always end when activity completes
class MindfulnessManager {
var session: WKExtendedRuntimeSession?
func startSession(duration: TimeInterval) {
session = WKExtendedRuntimeSession()
session?.start()
DispatchQueue.main.asyncAfter(deadline: .now() + duration) { [weak self] in
self?.session?.invalidate()
self?.session = nil
}
}
}
```
### Image Optimization
```swift
// Downsample to display size
func displayImage(_ image: UIImage, targetSize: CGSize) {
let renderer = UIGraphicsImageRenderer(size: targetSize)
let downsampledImage = renderer.image { _ in
image.draw(in: CGRect(origin: .zero, size: targetSize))
}
imageView.setImage(downsampledImage)
}
```
### HealthKit Queries
```swift
// Store and stop long-running queries
class HealthManager {
let healthStore = HKHealthStore()
var observerQuery: HKObserverQuery?
deinit {
if let query = observerQuery {
healthStore.stop(query)
}
}
}
```
## Review Questions
1. Is nesting a `TabView` inside another `TabView` avoided? (Nesting leaks memory)
2. Are all `WKRefreshBackgroundTask` completion handlers called?
3. Are files using `.noFileProtection` if accessed in background?
4. Is UI update frequency reduced when `isLuminanceReduced` is true?
5. Is `WKExtendedRuntimeSession` invalidated when activity completes?
6. Are WKInterface properties only set when values change?
7. Are WKInterfaceTables kept under 20 rows?
8. Are images downsampled to display size?
9. Are long-running queries stored and stopped in `deinit`?
Reviews URLSession networking code for iOS/macOS. Covers async/await patterns, request building, error handling, caching, and background sessions.
---
name: urlsession-code-review
description: Reviews URLSession networking code for iOS/macOS. Covers async/await patterns, request building, error handling, caching, and background sessions.
triggers:
- URLSession
- URLRequest
- URLCache
- URLError
- iOS networking
---
# URLSession Code Review
## Quick Reference
| Topic | Reference |
|-------|-----------|
| Async/Await | [async-networking.md](references/async-networking.md) |
| Requests | [request-building.md](references/request-building.md) |
| Errors | [error-handling.md](references/error-handling.md) |
| Caching | [caching.md](references/caching.md) |
## Review Checklist
### Response Validation
- [ ] HTTP status codes validated - URLSession does NOT throw on 404/500
- [ ] Response cast to HTTPURLResponse before checking status
- [ ] Both transport errors (URLError) and HTTP errors handled
### Memory & Resources
- [ ] Downloaded files moved/deleted (async API doesn't auto-delete)
- [ ] Sessions with delegates call `finishTasksAndInvalidate()`
- [ ] Long-running tasks use `[weak self]`
- [ ] Stored Task references cancelled when appropriate
### Configuration
- [ ] `timeoutIntervalForResource` set (default is 7 days!)
- [ ] URLCache sized adequately (default 512KB too small)
- [ ] Sessions reused for connection pooling
### Background Sessions
- [ ] Unique identifier (especially with app extensions)
- [ ] File-based uploads (not data-based)
- [ ] Delegate methods used (not completion handlers)
### Security
- [ ] No hardcoded secrets (use Keychain)
- [ ] Header values sanitized for CRLF injection
- [ ] Query params via URLComponents (not string concat)
## Hard gates (before reporting findings)
Complete in order. Do not advance while a prior gate is open.
1. **Scope** — **Pass:** You name at least one file under review where `URLSession`, `URLRequest`, `HTTPURLResponse` / `URLResponse`, `URLCache`, or `URLError` appears on a networking path. If none apply, stop with “out of scope.”
2. **HTTP vs transport** — **Pass:** Before claiming missing HTTP status handling or “404 treated as success,” you cite `file:line` for the completion/async/`for await` path that receives `response` and state whether `HTTPURLResponse` is cast and `statusCode` is checked (or cite the helper that does). If you cannot see the handler, say **unknown** and ask for it—do not assume.
3. **Session lifecycle** — **Pass:** For a custom `URLSession` with a delegate, you cite `finishTasksAndInvalidate()` or the documented long-lived/singleton pattern you rely on; for `.shared`, say so if the finding depends on configuration. Skip if only ad hoc `URLSession.shared` one-shots with no delegate issues.
4. **Background or file transfer (if applicable)** — **Pass:** If `URLSessionConfiguration.background`, `downloadTask`, or app-extension–scoped sessions appear, findings cite identifier uniqueness, delegate vs completion-handler usage, or file URLs as required. If none of those APIs appear, mark **N/A** and continue.
5. **Severity and checklist** — **Pass:** Every **Critical** item includes `file:line` and names which **Review Checklist** subsection it violates (e.g. Response Validation, Background Sessions). Lower-severity items still name the file(s) they are drawn from.
## Output Format
```markdown
### Critical
1. [FILE:LINE] Missing HTTP status validation
- Issue: 404/500 responses not treated as errors
- Fix: Check `httpResponse.statusCode` is 200-299
```
FILE:references/async-networking.md
# URLSession Async/Await Reference
> Minimum deployment: iOS 15+, macOS 12+
## Quick Reference
### Core Async Methods
| Method | Returns | Use Case |
|--------|---------|----------|
| `data(from: URL)` | `(Data, URLResponse)` | Simple GET requests |
| `data(for: URLRequest)` | `(Data, URLResponse)` | Configured requests (POST, headers) |
| `download(from: URL)` | `(URL, URLResponse)` | Large files to disk |
| `download(for: URLRequest)` | `(URL, URLResponse)` | Large files with custom request |
| `upload(for: URLRequest, from: Data)` | `(Data, URLResponse)` | Upload data in memory |
| `upload(for: URLRequest, fromFile: URL)` | `(Data, URLResponse)` | Upload file from disk |
| `bytes(from: URL)` | `(AsyncBytes, URLResponse)` | Streaming response body |
## Data Tasks
```swift
// Basic GET
func fetchData(from url: URL) async throws -> Data {
let (data, response) = try await URLSession.shared.data(from: url)
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw NetworkError.invalidResponse
}
return data
}
// POST with URLRequest
func postData<T: Encodable>(_ body: T, to url: URL) async throws -> Data {
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.httpBody = try JSONEncoder().encode(body)
let (data, response) = try await URLSession.shared.data(for: request)
// Validate response...
return data
}
```
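The `// Validate response...` placeholder in the POST example can be filled with one shared check. The following is a minimal sketch, not an established API: `HTTPStatusError` and `validatedHTTPResponse(_:)` are hypothetical names introduced here for illustration.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking // HTTPURLResponse lives here on Linux
#endif

// Hypothetical error type for this sketch; not part of Foundation.
enum HTTPStatusError: Error, Equatable {
    case notHTTP
    case status(Int)
}

// Returns the response as HTTPURLResponse when the status is 2xx, otherwise throws.
func validatedHTTPResponse(_ response: URLResponse) throws -> HTTPURLResponse {
    guard let http = response as? HTTPURLResponse else {
        throw HTTPStatusError.notHTTP
    }
    guard (200...299).contains(http.statusCode) else {
        throw HTTPStatusError.status(http.statusCode)
    }
    return http
}
```

With a helper like this, the POST function would call `_ = try validatedHTTPResponse(response)` right after `data(for:)` returns and before `return data`.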
## Download Tasks
```swift
func downloadFile(from url: URL, to destination: URL) async throws {
let (tempURL, response) = try await URLSession.shared.download(from: url)
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw NetworkError.invalidResponse
}
// CRITICAL: Move or delete the file - it is NOT auto-deleted
try FileManager.default.moveItem(at: tempURL, to: destination)
}
```
**Key difference**: The async/await download API does NOT automatically delete temporary files.
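`moveItem(at:to:)` throws when a file already exists at the destination, so a repeated download can fail at the very last step. A minimal sketch of a replace-then-move helper follows; `moveReplacing(from:to:)` is a hypothetical name, and the remove-then-move sequence is not atomic.

```swift
import Foundation

// Moves a downloaded temp file to `destination`, replacing any existing file.
// Creates the destination's parent directory if it does not exist yet.
func moveReplacing(from tempURL: URL, to destination: URL) throws {
    let fm = FileManager.default
    try fm.createDirectory(at: destination.deletingLastPathComponent(),
                           withIntermediateDirectories: true)
    if fm.fileExists(atPath: destination.path) {
        try fm.removeItem(at: destination)
    }
    try fm.moveItem(at: tempURL, to: destination)
}
```

In `downloadFile(from:to:)`, this would stand in for the direct `moveItem` call once the status check has passed.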
## Streaming with AsyncBytes
```swift
// Line-by-line processing
func streamLines(from url: URL) async throws {
let (bytes, _) = try await URLSession.shared.bytes(from: url)
for try await line in bytes.lines {
processLine(line)
}
}
// Server-Sent Events
func subscribeToEvents(url: URL) async throws {
let (bytes, _) = try await URLSession.shared.bytes(from: url)
for try await line in bytes.lines {
if line.hasPrefix("data: ") {
let jsonString = String(line.dropFirst(6))
// Parse and handle event
}
}
}
```
## Task Cancellation
Task cancellation **automatically propagates** to URLSession requests.
```swift
class DataLoader {
private var loadTask: Task<Data, Error>?
func load(from url: URL) {
loadTask?.cancel() // Cancel previous request
loadTask = Task {
try await URLSession.shared.data(from: url).0
}
}
}
```
SwiftUI's `.task` modifier automatically cancels when the view disappears.
## Memory Management
Tasks implicitly capture `self` strongly. Use `[weak self]` for long-running tasks:
```swift
downloadTask = Task { [weak self] in
guard let url = self?.downloadURL else { return }
let (data, _) = try await URLSession.shared.data(from: url)
self?.processData(data)
}
```
## Critical Anti-Patterns
### 1. Not Checking HTTP Status Codes
```swift
// BAD: 404 does not throw an error
let (data, _) = try await URLSession.shared.data(from: url)
let decoded = try JSONDecoder().decode(Model.self, from: data) // Throws a misleading DecodingError on an HTML error page
// GOOD: Validate response
let (data, response) = try await URLSession.shared.data(from: url)
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw NetworkError.serverError
}
```
### 2. Forgetting to Delete Downloaded Files
```swift
// BAD: Temporary file wastes storage
let (tempURL, _) = try await URLSession.shared.download(from: url)
// File is never moved or deleted
// GOOD: Always handle temporary file
let (tempURL, _) = try await URLSession.shared.download(from: url)
defer { try? FileManager.default.removeItem(at: tempURL) }
let data = try Data(contentsOf: tempURL)
```
### 3. Upload Without HTTP Method
```swift
// BAD: GET cannot have a body
var request = URLRequest(url: url)
try await URLSession.shared.upload(for: request, from: data) // Fails
// GOOD: Set HTTP method
request.httpMethod = "POST"
try await URLSession.shared.upload(for: request, from: data)
```
### 4. Storing Tasks Without Cancellation
```swift
// BAD: Tasks accumulate
func search(query: String) {
Task { let results = try await performSearch(query) }
}
// GOOD: Cancel previous task
private var searchTask: Task<Void, Never>?
func search(query: String) {
searchTask?.cancel()
searchTask = Task {
let results = try? await performSearch(query)
guard !Task.isCancelled else { return } // Drop results from a superseded search
// Use results
}
}
```
### 5. Strong Self in Infinite Loops
```swift
// BAD: Permanent memory leak
listenerTask = Task {
for try await line in bytes.lines {
self.handleEvent(line) // Never deallocates
}
}
// GOOD: Weak self
listenerTask = Task { [weak self] in
for try await line in bytes.lines {
guard let self else { return }
self.handleEvent(line)
}
}
```
## Review Questions
- [ ] Are HTTP status codes validated (not just assuming success)?
- [ ] Are downloaded files moved/deleted after use?
- [ ] Are upload requests setting HTTP method (POST/PUT)?
- [ ] Are long-running tasks using `[weak self]`?
- [ ] Are stored Task references cancelled when appropriate?
- [ ] Is cancellation handled in `viewWillDisappear`?
- [ ] Is SwiftUI's `.task` modifier used instead of manual Task management?
- [ ] For streaming, is response status checked before iterating?
FILE:references/caching.md
# URLSession Caching and Configuration Reference
## Quick Reference
### URLSessionConfiguration Types
| Type | Persistence | Use Case |
|------|-------------|----------|
| `.default` | Disk cache, cookies | Normal networking |
| `.ephemeral` | Memory only | Privacy-sensitive |
| `.background(withIdentifier:)` | System-managed | Large transfers |
### Cache Policies
| Policy | Behavior | When to Use |
|--------|----------|-------------|
| `.useProtocolCachePolicy` | Follows HTTP headers | Default |
| `.reloadIgnoringLocalCacheData` | Always fetch fresh | Fresh data required |
| `.returnCacheDataElseLoad` | Cache first | Offline-first |
| `.returnCacheDataDontLoad` | Cache only | Strict offline |
### Timeout Defaults
| Property | Default | Typical Setting |
|----------|---------|-----------------|
| `timeoutIntervalForRequest` | 60s | 30-60s |
| `timeoutIntervalForResource` | **7 days** | 2-5 minutes |
### URLCache Sizing
| Type | Default | Recommended |
|------|---------|-------------|
| Memory | 512 KB | 20 MB |
| Disk | 10 MB | 100 MB |
## URLCache Configuration
```swift
// Default cache is too small - configure early in app lifecycle
func application(_ application: UIApplication,
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool {
URLCache.shared = URLCache(
memoryCapacity: 20 * 1024 * 1024, // 20 MB
diskCapacity: 100 * 1024 * 1024, // 100 MB
directory: nil
)
return true
}
```
**Cache rules**: A response is cached only if it is no larger than about 5% of the disk cache size. ([Apple Developer Documentation](https://developer.apple.com/documentation/foundation/urlsessiondatadelegate/urlsession(_:datatask:willcacheresponse:completionhandler:)))
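As a quick worked check of that rule, the sketch below assumes a flat 5% threshold; Apple's documentation describes the figure as approximate, so treat the result as an estimate. `fitsCacheSizeRule` is a hypothetical helper, not a URLCache API.

```swift
import Foundation

// Rough check of the "about 5% of disk capacity" guideline.
// The exact threshold is an implementation detail of URLCache.
func fitsCacheSizeRule(responseBytes: Int, diskCapacity: Int) -> Bool {
    responseBytes * 20 <= diskCapacity
}

let defaultDisk = 10 * 1024 * 1024      // 10 MB default from the sizing table
let recommendedDisk = 100 * 1024 * 1024 // 100 MB recommended size
let imageBytes = 2 * 1024 * 1024        // a 2 MB response

print(fitsCacheSizeRule(responseBytes: imageBytes, diskCapacity: defaultDisk))     // false
print(fitsCacheSizeRule(responseBytes: imageBytes, diskCapacity: recommendedDisk)) // true
```

Under this assumption a 10 MB disk cache would skip anything above roughly 512 KB, which is one reason the default sizes are flagged as too small.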
## Session Configuration
### Default Configuration
```swift
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30.0
config.timeoutIntervalForResource = 300.0 // Not 7 days!
config.waitsForConnectivity = true
let session = URLSession(configuration: config)
```
### Ephemeral (Privacy Mode)
```swift
// No disk persistence - RAM only
let session = URLSession(configuration: .ephemeral)
```
### Background Configuration
```swift
let config = URLSessionConfiguration.background(
withIdentifier: "com.yourapp.backgroundSession"
)
config.isDiscretionary = false // Start immediately
config.sessionSendsLaunchEvents = true
let session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
```
## Background Session Implementation
### AppDelegate Handler (Required)
```swift
var backgroundCompletionHandler: (() -> Void)?
func application(_ application: UIApplication,
handleEventsForBackgroundURLSession identifier: String,
completionHandler: @escaping () -> Void) {
backgroundCompletionHandler = completionHandler
}
```
### Delegate Methods
```swift
func urlSession(_ session: URLSession,
downloadTask: URLSessionDownloadTask,
didFinishDownloadingTo location: URL) {
// Move file immediately - location deleted after method returns
try? FileManager.default.moveItem(at: location, to: permanentURL)
}
func urlSessionDidFinishEvents(forBackgroundURLSession session: URLSession) {
DispatchQueue.main.async {
self.backgroundCompletionHandler?()
self.backgroundCompletionHandler = nil
}
}
```
## Connection Pooling
```swift
// CORRECT: Reuse sessions for HTTP/2 multiplexing
class NetworkService {
static let shared = NetworkService()
private let session = URLSession(configuration: .default)
}
// ANTI-PATTERN: New session per request - loses pooling
func badFetch(_ url: URL) async throws -> Data {
let session = URLSession(configuration: .default) // Bad!
return try await session.data(from: url).0
}
```
## Critical Anti-Patterns
### 1. Memory Leak from Strong Delegate
```swift
// BUG: URLSession retains delegate forever
class LeakyManager {
var session: URLSession!
init() {
session = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
}
// deinit never called - memory leak
}
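// CORRECT: Invalidate explicitly. deinit cannot do it, because the session
// retains its delegate (self) until invalidated, so deinit never runs.
class CorrectManager: NSObject, URLSessionDelegate {
var session: URLSession!
override init() {
super.init()
session = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
}
func invalidate() {
session.finishTasksAndInvalidate() // Call when the owner is done with the manager
}
}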
```
### 2. Background Session Identifier Conflicts
```swift
// BUG: Same identifier in app and extension
// Main app
let config = URLSessionConfiguration.background(withIdentifier: "downloads")
// Extension (CONFLICT!)
let config = URLSessionConfiguration.background(withIdentifier: "downloads")
// CORRECT: Unique per process
"com.yourapp.main.downloads"
"com.yourapp.extension.downloads"
```
### 3. Data-Based Background Uploads
```swift
// BUG: Data uploads don't persist in background
backgroundSession.uploadTask(with: request, from: data) // Fails!
// CORRECT: File-based uploads
let fileURL = saveDataToFile(data)
backgroundSession.uploadTask(with: request, fromFile: fileURL)
```
### 4. Completion Handlers in Background Sessions
```swift
// BUG: Completion-handler task methods are not supported in background sessions
backgroundSession.dataTask(with: url) { data, _, _ in
// Never reached: creating the task raises a runtime exception
}
// CORRECT: Use delegate methods only
```
### 5. Inadequate Cache Size
```swift
// BUG: Default 512KB memory, 10MB disk - too small
let session = URLSession.shared
// CORRECT: Configure adequate cache
URLCache.shared = URLCache(
memoryCapacity: 20 * 1024 * 1024,
diskCapacity: 100 * 1024 * 1024,
directory: nil
)
```
### 6. Battery Drain from Immediate Transfers
```swift
// ANTI-PATTERN: Non-urgent but immediate
config.isDiscretionary = false // Immediate regardless of conditions
// CORRECT: Let system optimize
config.isDiscretionary = true // Waits for WiFi, charging
```
### 7. Missing Background Event Handler
```swift
// BUG: No handleEventsForBackgroundURLSession
class IncompleteAppDelegate: UIResponder, UIApplicationDelegate {
// App never notified of completion
}
```
### 8. Unresumed Tasks
```swift
// BUG: Task created but never resumed
task = session.dataTask(with: url) { ... }
// MISSING: task.resume()
// Completion handler retained indefinitely
// CORRECT
task.resume() // Always call
```
## Review Questions
### Cache
- [ ] Is URLCache configured with adequate capacity?
- [ ] Is cache configured before network calls?
- [ ] Is ephemeral config used for sensitive data?
### Session Management
- [ ] Are sessions reused (not created per request)?
- [ ] Is session invalidated when done?
- [ ] Are timeouts configured (not 7-day default)?
### Background Sessions
- [ ] Is identifier unique (especially with extensions)?
- [ ] Is `handleEventsForBackgroundURLSession` implemented?
- [ ] Is `urlSessionDidFinishEvents` calling completion handler?
- [ ] Are uploads file-based (not data-based)?
- [ ] Are delegate methods used (not completion handlers)?
- [ ] Is `isDiscretionary` set for non-urgent transfers?
- [ ] Is background session at app level (not ViewController)?
### Memory
- [ ] Is session delegate invalidated to break retain cycle?
- [ ] Are tasks always resumed after creation?
FILE:references/error-handling.md
# URLSession Error Handling Reference
## Quick Reference
### URLError Codes
| Code | Name | Retryable | User Message |
|------|------|-----------|--------------|
| -1009 | `notConnectedToInternet` | No* | "You're offline" |
| -1001 | `timedOut` | Yes | "Request timed out" |
| -999 | `cancelled` | No | (Silent) |
| -1003 | `cannotFindHost` | Yes | "Unable to reach server" |
| -1004 | `cannotConnectToHost` | Yes | "Unable to connect" |
| -1005 | `networkConnectionLost` | Yes | "Connection lost" |
| -1200 | `secureConnectionFailed` | No | "Security error" |
*Wait for network to reconnect
### HTTP Status Codes
| Range | Category | Retryable | Handling |
|-------|----------|-----------|----------|
| 200-299 | Success | N/A | Process response |
| 400 | Bad Request | No | Show validation error |
| 401 | Unauthorized | No | Re-authenticate |
| 404 | Not Found | No | Show not found |
| 429 | Too Many Requests | Yes | Respect Retry-After |
| 500-599 | Server Error | Yes | Retry with backoff |
## Transport vs HTTP Errors
**Critical**: URLSession does NOT treat non-2xx status codes as errors automatically.
```swift
// Transport errors via error parameter
if let error = error as? URLError {
switch error.code {
case .notConnectedToInternet: break // Device offline
case .timedOut: break // Request timed out
case .cancelled: break // User cancelled
default: break
}
}
// HTTP errors via status code (MUST check manually)
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
// Server returned error - data may contain error body
return
}
```
## Response Validation
```swift
func validateResponse(_ data: Data?, _ response: URLResponse?, _ error: Error?) throws -> Data {
// 1. Transport errors first
if let error = error {
throw NetworkError.transport(error)
}
// 2. Validate response type
guard let httpResponse = response as? HTTPURLResponse else {
throw NetworkError.invalidResponse
}
// 3. Check status code
guard (200...299).contains(httpResponse.statusCode) else {
throw NetworkError.httpError(httpResponse.statusCode, data)
}
// 4. Validate data
guard let data = data, !data.isEmpty else {
throw NetworkError.noData
}
return data
}
```
## Retry Strategy
### Determining Retryability
```swift
extension URLError.Code {
var isRetryable: Bool {
switch self {
case .timedOut, .cannotFindHost, .cannotConnectToHost,
.networkConnectionLost, .dnsLookupFailed:
return true
case .notConnectedToInternet, .cancelled,
.secureConnectionFailed, .userAuthenticationRequired:
return false
default:
return false
}
}
}
extension Int {
var isRetryableStatusCode: Bool {
[408, 429, 500, 502, 503, 504].contains(self)
}
}
```
### Exponential Backoff with Jitter
```swift
struct RetryConfiguration {
let maxRetries: Int = 3
let baseDelay: TimeInterval = 1.0
let maxDelay: TimeInterval = 30.0
func delay(for attempt: Int) -> TimeInterval {
let exponential = baseDelay * pow(2.0, Double(attempt))
let clamped = min(exponential, maxDelay)
let jitter = Double.random(in: 0...(0.1 * clamped))
return clamped + jitter
}
}
```
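To see the schedule those constants produce, the jitter can be made injectable so the deterministic part is easy to inspect. This is a sketch using the same defaults; `backoffDelay` and its `jitterFraction` parameter are assumptions made for illustration.

```swift
import Foundation

// Exponential backoff: base * 2^attempt, capped at maxDelay, plus up to 10% jitter.
// `jitterFraction` is in 0...1 and scales the jitter; pass Double.random(in: 0...1) in production.
func backoffDelay(attempt: Int,
                  baseDelay: TimeInterval = 1.0,
                  maxDelay: TimeInterval = 30.0,
                  jitterFraction: Double) -> TimeInterval {
    let exponential = baseDelay * pow(2.0, Double(attempt))
    let clamped = min(exponential, maxDelay)
    return clamped + 0.1 * clamped * jitterFraction
}

// Delays without jitter for attempts 0...5
let schedule = (0...5).map { backoffDelay(attempt: $0, jitterFraction: 0) }
print(schedule) // [1.0, 2.0, 4.0, 8.0, 16.0, 30.0]
```

The sixth attempt would be 32 seconds uncapped, so `maxDelay` clamps it to 30; with full jitter the same attempt waits at most 33 seconds.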
### Retry-After Header
```swift
func retryDelay(from response: HTTPURLResponse, fallback: TimeInterval) -> TimeInterval {
if let retryAfter = response.value(forHTTPHeaderField: "Retry-After"),
let seconds = Double(retryAfter) {
return seconds
}
return fallback
}
```
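`Retry-After` may also carry an HTTP-date instead of a number of seconds, which the `Double(retryAfter)` conversion would silently skip. The sketch below handles both forms; the name `retryDelay(fromHeader:now:fallback:)` and the injectable `now` are assumptions made to keep it testable, and only the common IMF-fixdate layout is parsed.

```swift
import Foundation

// Parses a Retry-After header value: either delay-seconds ("120")
// or an HTTP-date ("Wed, 21 Oct 2015 07:28:00 GMT").
func retryDelay(fromHeader value: String?,
                now: Date = Date(),
                fallback: TimeInterval) -> TimeInterval {
    guard let value = value?.trimmingCharacters(in: .whitespaces), !value.isEmpty else {
        return fallback
    }
    if let seconds = TimeInterval(value), seconds.isFinite, seconds >= 0 {
        return seconds
    }
    let formatter = DateFormatter()
    formatter.locale = Locale(identifier: "en_US_POSIX")
    formatter.timeZone = TimeZone(secondsFromGMT: 0)
    formatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss zzz"
    if let date = formatter.date(from: value) {
        return max(0, date.timeIntervalSince(now)) // A date in the past means "retry now"
    }
    return fallback
}
```

A caller would pass `response.value(forHTTPHeaderField: "Retry-After")` as the first argument and its computed backoff delay as `fallback`.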
## Network Conditions
### waitsForConnectivity (Recommended)
```swift
let config = URLSessionConfiguration.default
config.waitsForConnectivity = true // Wait instead of failing
config.timeoutIntervalForResource = 300 // Don't use 7-day default
// Delegate for UI feedback
func urlSession(_ session: URLSession,
taskIsWaitingForConnectivity task: URLSessionTask) {
// Show "waiting for network" UI
}
```
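**Important**: Don't pre-check reachability before sending a request. Connectivity can change between the check and the request (a race condition), so send the request and handle the error, or rely on `waitsForConnectivity`.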
## Critical Anti-Patterns
### 1. Silent Error Swallowing
```swift
// DANGEROUS
URLSession.shared.dataTask(with: request) { data, _, error in
guard let data = data else { return } // Error ignored!
}
// CORRECT
URLSession.shared.dataTask(with: request) { data, response, error in
if let error = error { handleError(error); return }
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
handleHTTPError(response); return
}
guard let data = data else { handleNoData(); return }
}
```
### 2. Missing Status Code Validation
```swift
// DANGEROUS: Assumes nil error means success
let (data, _) = try await URLSession.shared.data(for: request)
return data // Could be 404 error page!
// CORRECT
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw NetworkError.httpError
}
```
### 3. Retrying Non-Retryable Errors
```swift
// DANGEROUS: Retrying 401 won't help
for _ in 0..<3 {
do { return try await fetch() }
catch { continue } // Retries ALL errors
}
// CORRECT
for _ in 0..<3 {
do { return try await fetch() }
catch let error as URLError where error.code.isRetryable {
continue // Only retry network issues; other errors propagate
}
}
```
### 4. Blocking Retry Without Backoff
```swift
// DANGEROUS: Hammers server
while true {
do { return try await fetch() }
catch { continue } // Immediate retry
}
// CORRECT: Exponential backoff
for attempt in 0..<maxRetries {
do { return try await fetch() }
catch {
let delay = baseDelay * pow(2.0, Double(attempt))
try await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
}
}
```
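The backoff loop still retries every error, including ones that can never succeed. One way to combine the retryability check, the retry cap, and the backoff is a single pure decision function. This is a sketch under stated assumptions: `nextRetryDelay` is a hypothetical name, the retryable codes mirror the `isRetryable` extension, and jitter is left out to keep the result deterministic.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

// Returns how long to wait before the next attempt, or nil when the caller should give up.
// `attempt` is zero-based.
func nextRetryDelay(after error: Error,
                    attempt: Int,
                    maxRetries: Int = 3,
                    baseDelay: TimeInterval = 1.0,
                    maxDelay: TimeInterval = 30.0) -> TimeInterval? {
    guard attempt < maxRetries else { return nil }
    guard let urlError = error as? URLError else { return nil }
    switch urlError.code {
    case .timedOut, .cannotFindHost, .cannotConnectToHost,
         .networkConnectionLost, .dnsLookupFailed:
        return min(baseDelay * pow(2.0, Double(attempt)), maxDelay)
    default:
        return nil // Cancellation, offline, TLS failures, and unknown codes are not retried
    }
}
```

Inside the `catch` block, the loop would rethrow when this returns `nil` and otherwise sleep for the returned delay before the next attempt.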
### 5. Technical Errors to Users
```swift
// DANGEROUS
showAlert(String(describing: error))
// "Error Domain=NSURLErrorDomain Code=-1004..."
// CORRECT
showAlert(userFriendlyMessage(for: error))
// "Unable to connect. Please check your internet."
```
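`userFriendlyMessage(for:)` is not defined in this reference. One possible shape reuses the messages from the URLError Codes table; the generic fallback string and the optional return type (so that cancellation shows nothing) are assumptions of this sketch.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

// Maps transport errors to short messages; nil means "show nothing" (user cancelled).
func userFriendlyMessage(for error: Error) -> String? {
    let generic = "Something went wrong. Please try again."
    guard let urlError = error as? URLError else { return generic }
    switch urlError.code {
    case .cancelled: return nil
    case .notConnectedToInternet: return "You're offline"
    case .timedOut: return "Request timed out"
    case .cannotFindHost: return "Unable to reach server"
    case .cannotConnectToHost: return "Unable to connect"
    case .networkConnectionLost: return "Connection lost"
    case .secureConnectionFailed: return "Security error"
    default: return generic
    }
}
```

Because the result is optional here, the call site would become `if let message = userFriendlyMessage(for: error) { showAlert(message) }`.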
### 6. Ignoring Cancellation
```swift
// DANGEROUS: Shows error for user cancel
catch { showError(error) }
// CORRECT
catch let error as URLError where error.code == .cancelled {
return // Silent - user initiated
}
catch { showError(error) }
```
## Review Questions
### Error Handling
- [ ] Are both transport errors and HTTP status codes handled?
- [ ] Is there a centralized error handling strategy?
- [ ] Are error types mapped to user-friendly messages?
### Response Validation
- [ ] Is response cast to HTTPURLResponse?
- [ ] Are non-2xx status codes treated as errors?
- [ ] Are error response bodies parsed for messages?
### Retry Logic
- [ ] Are only appropriate errors retried (not 4xx)?
- [ ] Is exponential backoff with jitter implemented?
- [ ] Is there a maximum retry count?
- [ ] Is Retry-After header respected for 429/503?
### User Experience
- [ ] Are cancellation errors handled silently?
- [ ] Is there a retry option for recoverable errors?
- [ ] Are authentication errors handled separately?
FILE:references/request-building.md
# URLRequest Building Reference
## Quick Reference
### URLRequest Configuration
| Property | Type | Default | Description |
|----------|------|---------|-------------|
| `url` | `URL?` | nil | Request URL |
| `httpMethod` | `String?` | "GET" | HTTP method |
| `httpBody` | `Data?` | nil | Request body |
| `timeoutInterval` | `TimeInterval` | 60.0 | Timeout in seconds |
| `cachePolicy` | `CachePolicy` | `.useProtocolCachePolicy` | Cache behavior |
### Cache Policies
| Policy | Use Case |
|--------|----------|
| `.useProtocolCachePolicy` | Default; respects server headers |
| `.reloadIgnoringLocalCacheData` | Always fetch fresh |
| `.returnCacheDataElseLoad` | Offline-first apps |
| `.returnCacheDataDontLoad` | Strictly offline |
### Content-Types
| Content Type | Use Case |
|--------------|----------|
| `application/json` | JSON body |
| `application/x-www-form-urlencoded` | Form data |
| `multipart/form-data; boundary=xxx` | File uploads |
## HTTP Headers
```swift
// Set headers
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
// Set all at once
request.allHTTPHeaderFields = [
"Content-Type": "application/json",
"Accept": "application/json"
]
```
## Body Encoding
### JSON (Recommended)
```swift
struct CreateUserRequest: Encodable {
let name: String
let email: String
}
let body = CreateUserRequest(name: "John", email: "john@example.com")
request.httpBody = try JSONEncoder().encode(body)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
```
### Form URL Encoded
```swift
// CORRECT: Use URLComponents for proper encoding
var components = URLComponents()
components.queryItems = [
URLQueryItem(name: "username", value: "john"),
URLQueryItem(name: "password", value: "secret")
]
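// percentEncodedQuery encodes spaces as %20 and escapes "&" in values, but leaves "+" as-is.
// Form decoders read "+" as a space, so escape it explicitly.
request.httpBody = components.percentEncodedQuery?
.replacingOccurrences(of: "+", with: "%2B")
.data(using: .utf8)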
request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")
```
> **Warning**: Don't use `.urlQueryAllowed` for form-encoded values. It includes reserved characters (`&`, `=`, `+`, `/`, `?`) that must be escaped in parameter values. Use `URLComponents` or a custom charset with only RFC 3986 unreserved characters.
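The custom-charset route mentioned in the warning can be sketched as follows. `formAllowed` and `formEncodedBody(_:)` are hypothetical names; the allowed set contains only the RFC 3986 unreserved characters, so everything else, including `+`, `&`, `=`, and spaces, is percent-encoded.

```swift
import Foundation

// RFC 3986 unreserved characters: ALPHA / DIGIT / "-" / "." / "_" / "~"
let formAllowed = CharacterSet(charactersIn:
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")

// Encodes ordered name/value pairs as an application/x-www-form-urlencoded body.
func formEncodedBody(_ fields: [(name: String, value: String)]) -> Data {
    let pairs = fields.map { field -> String in
        let name = field.name.addingPercentEncoding(withAllowedCharacters: formAllowed) ?? ""
        let value = field.value.addingPercentEncoding(withAllowedCharacters: formAllowed) ?? ""
        return "\(name)=\(value)"
    }
    return Data(pairs.joined(separator: "&").utf8)
}

let body = formEncodedBody([(name: "username", value: "john+doe"),
                            (name: "password", value: "p&ss w=rd")])
print(String(decoding: body, as: UTF8.self))
// username=john%2Bdoe&password=p%26ss%20w%3Drd
```

Spaces come out as `%20` rather than `+`; form decoders accept both, and avoiding `+` removes the ambiguity entirely.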
### Multipart Form Data
```swift
let boundary = UUID().uuidString
request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
var body = Data()
body.append("--\(boundary)\r\n".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"image.jpg\"\r\n".data(using: .utf8)!)
body.append("Content-Type: image/jpeg\r\n\r\n".data(using: .utf8)!)
body.append(imageData)
body.append("\r\n--\(boundary)--\r\n".data(using: .utf8)!)
request.httpBody = body
```
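The same body can be assembled without force unwraps and returned together with its matching header value. A minimal single-file sketch follows; `multipartBody` and its parameter names are hypothetical, and it assumes `fieldName` and `fileName` contain no quotes or line breaks (those would need escaping first).

```swift
import Foundation

// Builds a single-file multipart/form-data body.
// Returns the body and the matching Content-Type header value.
func multipartBody(fieldName: String,
                   fileName: String,
                   mimeType: String,
                   fileData: Data,
                   boundary: String = UUID().uuidString) -> (body: Data, contentType: String) {
    var body = Data()
    // Data(string.utf8) cannot fail, unlike data(using:) followed by "!"
    body.append(Data("--\(boundary)\r\n".utf8))
    body.append(Data("Content-Disposition: form-data; name=\"\(fieldName)\"; filename=\"\(fileName)\"\r\n".utf8))
    body.append(Data("Content-Type: \(mimeType)\r\n\r\n".utf8))
    body.append(fileData)
    body.append(Data("\r\n--\(boundary)--\r\n".utf8))
    return (body, "multipart/form-data; boundary=\(boundary)")
}
```

Returning both values from one call keeps the boundary in the header and the body in sync, which is the usual source of multipart parsing failures on the server side.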
## URL Query Parameters
```swift
// CORRECT: Use URLComponents
var components = URLComponents(string: "https://api.example.com/search")!
components.queryItems = [
URLQueryItem(name: "query", value: "swift programming"),
URLQueryItem(name: "page", value: "1")
]
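// Plus signs are not encoded by default; escape them on the already-encoded query
// (escaping the raw value instead would double-encode the "%")
components.percentEncodedQuery = components.percentEncodedQuery?
.replacingOccurrences(of: "+", with: "%2B")
let request = URLRequest(url: components.url!)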
```
## Timeout Configuration
```swift
// Request-level
var request = URLRequest(url: url)
request.timeoutInterval = 30.0
// Session-level
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30.0 // Resets on each packet
config.timeoutIntervalForResource = 300.0 // Total time (default: 7 days!)
```
> **Note**: A per-request `timeoutInterval` is respected unless the session's `timeoutIntervalForRequest` is more restrictive, in which case the session's limit applies.
## Critical Anti-Patterns
### 1. CRLF Injection (CVE-2022-3918)
> **Note**: This vulnerability affects swift-corelibs-foundation versions before 5.7.3. In 5.7.3+, URLRequest rejects CR/LF in header values at the framework level. Manual sanitization is only needed for projects that cannot upgrade.
```swift
// DANGEROUS: User input in headers (affects swift-corelibs-foundation < 5.7.3)
let userInput = "value\r\nEvil-Header: injected"
request.setValue(userInput, forHTTPHeaderField: "X-Custom")
// SAFE: Sanitize header values (for pre-5.7.3 or as defense-in-depth)
let sanitized = userInput.replacingOccurrences(of: "\r", with: "")
.replacingOccurrences(of: "\n", with: "")
request.setValue(sanitized, forHTTPHeaderField: "X-Custom")
```
### 2. Hardcoded Secrets
```swift
// DANGEROUS
request.setValue("sk_live_abc123xyz", forHTTPHeaderField: "Authorization")
// SAFE: From Keychain
let token = KeychainService.shared.getAPIToken()
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
```
### 3. Content-Type/Body Mismatch
```swift
// BUG: JSON body but wrong Content-Type
request.httpBody = try JSONEncoder().encode(user)
request.setValue("text/plain", forHTTPHeaderField: "Content-Type")
// CORRECT
request.httpBody = try JSONEncoder().encode(user)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
```
### 4. Manual URL Concatenation
```swift
// DANGEROUS: Injection risk
let url = URL(string: "https://api.com/search?q=\(userQuery)")!
// SAFE: URLComponents
var components = URLComponents(string: "https://api.com/search")!
components.queryItems = [URLQueryItem(name: "q", value: userQuery)]
```
### 5. Memory Issues with Large Files
```swift
// DANGEROUS: Loads entire file into memory
let largeFileData = try Data(contentsOf: largeFileURL)
request.httpBody = largeFileData
// SAFE: Use file-based upload
session.uploadTask(with: request, fromFile: largeFileURL).resume()
```
### 6. Creating Sessions Per Request
```swift
// INEFFICIENT
func makeRequest() {
let session = URLSession(configuration: .default) // New each time!
session.dataTask(with: request).resume()
}
// EFFICIENT: Reuse session
class NetworkManager {
private let session = URLSession(configuration: .default)
}
```
## Review Questions
### Security
- [ ] Are header values sanitized for CRLF characters?
- [ ] Are secrets from Keychain, not hardcoded?
- [ ] Is SSL/TLS validation proper (no blanket trust)?
- [ ] Are credentials excluded from URLs and logs?
### Correctness
- [ ] Does Content-Type match body encoding?
- [ ] Is HTTP method appropriate (no body on GET)?
- [ ] Are query parameters built with URLComponents?
- [ ] Are special characters (+, ;, ,) encoded correctly?
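
The last item is easy to miss: `URLComponents` percent-encodes spaces in query items but leaves a literal `+` untouched, and many servers decode `+` as a space. A minimal sketch of one way to handle it, assuming a hypothetical `api.example.com` endpoint and a helper named `makeSearchURL`:

```swift
import Foundation

// Builds a search URL and additionally escapes "+" in the query string.
func makeSearchURL(query: String) -> URL? {
    var components = URLComponents(string: "https://api.example.com/search")
    components?.queryItems = [URLQueryItem(name: "q", value: query)]
    // queryItems leaves "+" as-is; escape it so servers do not read it as a space.
    components?.percentEncodedQuery = components?.percentEncodedQuery?
        .replacingOccurrences(of: "+", with: "%2B")
    return components?.url
}
```

Whether the extra escaping is needed depends on how the receiving server decodes query strings, so treat this as a check to make rather than a rule to apply blindly.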
### Performance
- [ ] Is URLSession reused across requests?
- [ ] Are timeouts configured appropriately?
- [ ] Are large uploads using file-based API?
- [ ] Is `timeoutIntervalForResource` set (not 7-day default)?
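
The session-reuse and timeout items can be covered in one place. The sketch below assumes a `NetworkManager` singleton and illustrative timeout values (30 s per request, 300 s per resource); neither the values nor the singleton shape come from a specific project.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

final class NetworkManager {
    static let shared = NetworkManager()

    // One session, created once and reused for every request.
    let session: URLSession

    init(requestTimeout: TimeInterval = 30, resourceTimeout: TimeInterval = 300) {
        let configuration = URLSessionConfiguration.default
        // Maximum wait for additional data to arrive.
        configuration.timeoutIntervalForRequest = requestTimeout
        // Maximum time for the whole transfer; replaces the 7-day default.
        configuration.timeoutIntervalForResource = resourceTimeout
        session = URLSession(configuration: configuration)
    }
}
```

Call sites then use `NetworkManager.shared.session` instead of constructing a new `URLSession` per request.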
Reviews SwiftUI code for view composition, state management, performance, and accessibility. Use when reviewing .swift files containing SwiftUI views, proper...
---
name: swiftui-code-review
description: Reviews SwiftUI code for view composition, state management, performance, and accessibility. Use when reviewing .swift files containing SwiftUI views, property wrappers (@State, @Binding, @Observable), or UI code.
---
# SwiftUI Code Review
## Quick Reference
| Issue Type | Reference |
|------------|-----------|
| View extraction, modifiers, body complexity | [references/view-composition.md](references/view-composition.md) |
| @State, @Binding, @Observable, @Bindable | [references/state-management.md](references/state-management.md) |
| LazyStacks, AnyView, ForEach, identity | [references/performance.md](references/performance.md) |
| VoiceOver, Dynamic Type, labels, traits | [references/accessibility.md](references/accessibility.md) |
## Gates (review workflow)
Complete in order; do not skip ahead.
1. **Anchor scope** — Pass when: every reviewed file is listed as a repo-relative `.swift` path (or the review explicitly states “none opened / N/A” with reason).
2. **Reference before critique** — Pass when: for any non-trivial body, modifier chain, or wrapper-ownership question, you have opened the matching `references/*.md` row from the table above *or* you state “not needed” with one line why.
3. **Evidence-bound findings** — Pass when: each substantive issue includes **`[FILE:LINE]`** (or a bounded line range) before recommendations; symbols/snippets may supplement but not replace the location anchor; no finding that rests only on “typical SwiftUI” without pointing at this code.
## Review Checklist
- [ ] View body under 10 composed elements (extract subviews)
- [ ] Modifiers in correct order (padding before background)
- [ ] @StateObject for view-owned objects, @ObservedObject for passed objects
- [ ] @Bindable used for two-way bindings to @Observable (iOS 17+)
- [ ] LazyVStack/LazyHStack for scrolling lists with 50+ items
- [ ] No AnyView (use @ViewBuilder or generics instead)
- [ ] ForEach uses stable Identifiable IDs (not array indices)
- [ ] All images/icons have accessibilityLabel
- [ ] Custom controls have accessibilityAddTraits(.isButton)
- [ ] Dynamic Type supported (no fixed font sizes)
- [ ] .task modifier for async work (not onAppear + Task)
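
The last checklist item can be sketched as follows. `ProfileView`, `loadName(for:)`, and the simulated delay are hypothetical stand-ins for a real view and network call.

```swift
import SwiftUI

// Hypothetical loader standing in for a real network call.
func loadName(for userID: String) async throws -> String {
    try await Task.sleep(nanoseconds: 200_000_000)
    return "User \(userID)"
}

struct ProfileView: View {
    let userID: String
    @State private var name = "Loading..."

    var body: some View {
        Text(name)
            // .task starts when the view appears and is cancelled
            // automatically when it disappears; with `id:` it restarts
            // whenever userID changes.
            .task(id: userID) {
                do {
                    name = try await loadName(for: userID)
                } catch {
                    // Cancellation also lands here; only report real failures.
                    if !Task.isCancelled { name = "Failed to load" }
                }
            }
    }
}
```

With `onAppear` plus an unstructured `Task`, the cancellation on disappear would have to be written by hand.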
## When to Load References
- Complex view bodies or modifier chains -> view-composition.md
- Property wrapper usage (@State, @Observable) -> state-management.md
- List performance or view identity issues -> performance.md
- VoiceOver or accessibility implementation -> accessibility.md
## Review Questions
1. Could this large view body be split into smaller, reusable Views?
2. Is modifier order intentional? (padding -> background -> frame)
3. Is @StateObject/@ObservedObject usage correct for ownership?
4. Could LazyVStack improve this ScrollView's performance?
5. Would VoiceOver users understand this interface?
FILE:references/accessibility.md
# Accessibility
## Accessibility Labels
All interactive elements need descriptive labels.
```swift
// BAD - VoiceOver says "heart"
Button(action: { addToFavorites() }) {
Image(systemName: "heart")
}
// GOOD - VoiceOver says "Add to favorites"
Button(action: { addToFavorites() }) {
Image(systemName: "heart")
}
.accessibilityLabel("Add to favorites")
.accessibilityHint("Double tap to add to your favorites")
```
## Decorative Images
Hide non-informative images from VoiceOver.
```swift
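// BAD - VoiceOver reads "star fill"
HStack {
    Image(systemName: "star.fill")
    Text("Premium Feature")
}
// GOOD - decorative SF Symbol hidden from VoiceOver
HStack {
    Image(systemName: "star.fill")
        .accessibilityHidden(true)
    Text("Premium Feature")
}
// Alternative for asset-catalog images ("premium-badge" is a placeholder name);
// Image(decorative:) looks up an asset, not an SF Symbol
Image(decorative: "premium-badge")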
```
## Custom Control Traits
Custom interactive views need accessibility traits.
```swift
// BAD - VoiceOver doesn't know it's tappable
struct CustomCheckbox: View {
@Binding var isChecked: Bool
var body: some View {
Image(systemName: isChecked ? "checkmark.square" : "square")
.onTapGesture { isChecked.toggle() }
}
}
// GOOD - proper accessibility
struct CustomCheckbox: View {
@Binding var isChecked: Bool
var body: some View {
Image(systemName: isChecked ? "checkmark.square" : "square")
.onTapGesture { isChecked.toggle() }
.accessibilityLabel("Agreement")
.accessibilityAddTraits(.isButton)
.accessibilityAddTraits(isChecked ? .isSelected : [])
.accessibilityValue(isChecked ? "Checked" : "Unchecked")
}
}
```
## Common Traits
| Trait | Use Case |
|-------|----------|
| `.isButton` | Custom tappable views |
| `.isHeader` | Section headers |
| `.isSelected` | Current selection state |
| `.isLink` | External navigation |
## Grouping Elements
Combine related elements for VoiceOver.
```swift
// BAD - read separately
VStack {
Text("John Doe")
Text("Senior Developer")
Text("San Francisco")
}
// GOOD - single announcement
VStack {
Text("John Doe")
Text("Senior Developer")
Text("San Francisco")
}
.accessibilityElement(children: .combine)
```
## Dynamic Type
Use semantic fonts, not fixed sizes.
```swift
// BAD - ignores user preference
Text("Settings")
.font(.system(size: 17))
// GOOD - scales with preference
Text("Settings")
.font(.body)
```
## Environment Properties
Check accessibility settings.
```swift
@Environment(\.accessibilityReduceMotion) var reduceMotion
@Environment(\.accessibilityDifferentiateWithoutColor) var noColor
@Environment(\.dynamicTypeSize) var typeSize
```
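Reading a setting is only half of the pattern; the view also has to act on it. A minimal sketch for Reduce Motion, assuming a hypothetical `PulsingBadge` view and a helper named `badgeAnimation(reduceMotion:)`:

```swift
import SwiftUI

// Returns nil (no animation) when the user has enabled Reduce Motion.
func badgeAnimation(reduceMotion: Bool) -> Animation? {
    reduceMotion ? nil : .easeInOut(duration: 0.6)
}

struct PulsingBadge: View {
    @Environment(\.accessibilityReduceMotion) private var reduceMotion
    @State private var isExpanded = false

    var body: some View {
        Circle()
            .fill(Color.blue)
            .frame(width: 24, height: 24)
            .scaleEffect(isExpanded ? 1.3 : 1.0)
            // The scale change still happens, but without motion when requested.
            .animation(badgeAnimation(reduceMotion: reduceMotion), value: isExpanded)
            .onTapGesture { isExpanded.toggle() }
    }
}
```

Keeping the decision in a small function makes it checkable without rendering the view.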
## Critical Anti-Patterns
| Pattern | Issue |
|---------|-------|
| Interactive element without label | VoiceOver can't describe it |
| Decorative image not hidden | Clutters VoiceOver reading |
| Custom control without traits | VoiceOver doesn't indicate interactivity |
| Fixed font sizes | Ignores Dynamic Type |
| Color-only information | Excludes color blind users |
| Touch target < 44pt | Hard to tap |
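
The last two rows have no example elsewhere in this reference. A minimal sketch that addresses both, assuming a hypothetical `StatusRow` view: the state is conveyed by icon and text with color as reinforcement, and the retry button's tappable area is padded out to 44 points.

```swift
import SwiftUI

struct StatusRow: View {
    let isOnline: Bool
    let onRetry: () -> Void

    var body: some View {
        HStack {
            // Icon shape and text carry the state; color only reinforces it.
            Label(isOnline ? "Online" : "Offline",
                  systemImage: isOnline ? "checkmark.circle.fill" : "xmark.circle.fill")
                .foregroundStyle(isOnline ? Color.green : Color.red)
            Spacer()
            Button(action: onRetry) {
                Image(systemName: "arrow.clockwise")
                    // Expand the tappable area to at least 44x44 points.
                    .frame(minWidth: 44, minHeight: 44)
                    .contentShape(Rectangle())
            }
            .accessibilityLabel("Retry connection")
        }
    }
}
```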
## Review Questions
1. Do all interactive elements have accessibilityLabel?
2. Are decorative images hidden from VoiceOver?
3. Do custom controls have proper accessibility traits?
4. Are related UI elements grouped for VoiceOver?
5. Does the UI support Dynamic Type (no fixed font sizes)?
6. Is color paired with other indicators (shape, text, icon)?
FILE:references/performance.md
# Performance
## Lazy Stacks
Use LazyVStack/LazyHStack for scrolling content with many items.
```swift
// BAD - all 1000 items rendered immediately
ScrollView {
VStack {
ForEach(items) { ItemView(item: $0) }
}
}
// GOOD - only visible items rendered
ScrollView {
LazyVStack {
ForEach(items) { ItemView(item: $0) }
}
}
```
## AnyView Avoidance
AnyView defeats SwiftUI's type-based diffing.
```swift
// BAD - SwiftUI can't diff
func makeView(type: ViewType) -> some View {
switch type {
case .a: return AnyView(ViewA())
case .b: return AnyView(ViewB())
}
}
// GOOD - preserves type information
@ViewBuilder
func makeView(type: ViewType) -> some View {
switch type {
case .a: ViewA()
case .b: ViewB()
}
}
```
## ForEach Identity
Use stable Identifiable IDs, never array indices.
```swift
// BAD - index changes when array changes
ForEach(items.indices, id: \.self) { index in
ItemView(item: items[index])
}
// GOOD - stable ID from Identifiable
ForEach(items) { item in
ItemView(item: item)
}
// BAD - this initializer expects a constant range
@State var count = 5
List(0..<count) { i in Text("Row \(i)") } // Rows go stale or crash when count changes
// GOOD - dynamic range with id
List(0..<count, id: \.self) { i in Text("Row \(i)") }
```
## Equatable Views
For complex views, implement Equatable to optimize diffing.
```swift
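struct CalendarView: View, Equatable {
    let events: [Event]
    let selectedDate: Date
    // SwiftUI can skip body when this returns true.
    // Comparing only counts is cheap but misses edits to existing events;
    // compare whatever actually affects rendering.
    static func == (lhs: Self, rhs: Self) -> Bool {
        lhs.events.count == rhs.events.count &&
            lhs.selectedDate == rhs.selectedDate
    }
    var body: some View { /* expensive */ }
}
// Usage: apply .equatable() to the Equatable view itself
CalendarView(events: events, selectedDate: selectedDate)
    .equatable()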
## View Body Efficiency
Avoid expensive operations in view body.
```swift
// BAD - formatter created on every rebuild
var body: some View {
let formatter = DateFormatter()
formatter.dateStyle = .long
return Text(formatter.string(from: date))
}
// GOOD - cached formatter
private static let formatter: DateFormatter = {
let f = DateFormatter()
f.dateStyle = .long
return f
}()
```
## Expensive Visual Effects
Blur, shadow, and mask cause offscreen rendering.
```swift
// CAUTION - expensive in lists
ForEach(items) { item in
ItemView(item: item)
.blur(radius: 5) // Expensive
.shadow(radius: 10) // Expensive
}
```
## Critical Anti-Patterns
| Pattern | Issue |
|---------|-------|
| VStack in ScrollView with 100+ items | All items rendered at once |
| AnyView | Defeats type-based diffing |
| ForEach with array index as id | View recreation on array change |
| .id() modifier inside List | Prevents List optimization |
| DateFormatter in view body | Recreated on every rebuild |
## Review Questions
1. Should this VStack/HStack be Lazy for scrolling performance?
2. Is AnyView used? Can @ViewBuilder replace it?
3. Does ForEach use stable Identifiable IDs?
4. Are expensive computations cached outside view body?
5. Are visual effects (blur, shadow) used sparingly in lists?
FILE:references/state-management.md
# State Management
## Property Wrapper Quick Reference
| iOS Version | View Creates Object | View Receives Object | Two-Way Binding |
|-------------|---------------------|---------------------|-----------------|
| < iOS 17 | `@StateObject` | `@ObservedObject` | `@Binding` |
| iOS 17+ | `@State` (for @Observable) | Plain property | `@Bindable` |
## @StateObject vs @ObservedObject
Use @StateObject when the view creates and owns the object. Use @ObservedObject when passed from a parent.
```swift
// BAD - recreated on every view rebuild
struct ContentView: View {
@ObservedObject var viewModel = ViewModel() // Wrong!
}
// GOOD - survives view rebuilds
struct ContentView: View {
@StateObject var viewModel = ViewModel()
}
// GOOD - received from parent
struct ChildView: View {
@ObservedObject var viewModel: ViewModel // Not creating it
}
```
## iOS 17+ @Observable Pattern
With @Observable, use @State at app level, plain properties in children.
```swift
@Observable class AppStore { /* ... */ }
@main
struct MyApp: App {
@State private var store = AppStore()
var body: some Scene {
WindowGroup { ContentView(store: store) }
}
}
// Child receives without wrapper
struct ChildView: View {
var store: AppStore // No wrapper needed; property reads are tracked automatically
}
// For two-way bindings, use @Bindable
struct EditView: View {
@Bindable var user: User
var body: some View {
TextField("Name", text: $user.name)
}
}
```
## @State Mistakes
```swift
// BAD - @State ignores external updates
struct ChildView: View {
    @State var user: User // Updates from parent ignored!
}
// BAD - didSet is bypassed when the value changes through a Binding ($count)
@State var count = 0 {
    didSet { print("Changed") } // Not called for writes via $count
}
// GOOD - use .onChange modifier (two-parameter closure requires iOS 17+)
.onChange(of: count) { oldValue, newValue in
    print("Changed from \(oldValue) to \(newValue)")
}
```
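The first mistake above (a child keeping passed-in data in `@State`) has no corrected counterpart in the block. A minimal sketch of the usual fixes, using a hypothetical value-type `User` and made-up view names: a plain `let` for read-only children and `@Binding` for children that edit.

```swift
import SwiftUI

// Hypothetical value-type model for this sketch.
struct User { var name: String }

// Read-only child: a plain `let` is re-supplied on every parent update.
struct UserBadge: View {
    let user: User
    var body: some View { Text(user.name) }
}

// Editing child: @Binding reads and writes the parent's storage.
struct UserEditor: View {
    @Binding var user: User
    var body: some View { TextField("Name", text: $user.name) }
}

struct ParentView: View {
    @State private var user = User(name: "Ada") // single source of truth
    var body: some View {
        VStack {
            UserBadge(user: user)
            UserEditor(user: $user)
        }
    }
}
```

Only `ParentView` owns the value, so both children always see the current `user`.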
## Environment Usage (iOS 17+)
```swift
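// Key-path syntax for EnvironmentValues (still current; modelContext itself needs iOS 17+)
@Environment(\.modelContext) private var context
// Type-based syntax for custom @Observable types (iOS 17+)
@Environment(AuthService.self) private var authService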
```
## Critical Anti-Patterns
| Pattern | Issue |
|---------|-------|
| `@ObservedObject var vm = ViewModel()` | Creates new object on rebuild |
| `@State var model: SomeClass` | `ObservableObject` classes need @StateObject (@State owns only `@Observable` classes, iOS 17+) |
| `@State` in child for passed data | Ignores parent updates |
| didSet on @State | Doesn't fire for changes made through a Binding |
| Missing @Bindable for bindings | Can't bind to @Observable properties |
## Review Questions
1. Is @StateObject used when the view creates the object?
2. Is @ObservedObject only used for objects passed from parent?
3. For iOS 17+, is @State with @Observable in a stable parent view?
4. Is @Bindable used where two-way bindings are needed?
5. Is .onChange used instead of didSet on @State?
FILE:references/view-composition.md
# View Composition
## View Body Complexity
Keep view bodies under 10 composed elements. Extract subviews when bodies grow large.
```swift
// BAD - massive body
struct ProductView: View {
let product: Product
var body: some View {
VStack {
// 50+ lines of inline views
}
}
}
// GOOD - extracted subviews
struct ProductView: View {
let product: Product
var body: some View {
VStack {
PriceSection(price: product.price)
DetailsSection(description: product.description)
}
}
}
```
## Computed Property Views
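Prefer View structs over computed property views for non-trivial sections. A computed property is evaluated inline as part of the parent's body, so it re-runs on every parent update; a separate View struct gives SwiftUI its own unit to compare and skip.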
```swift
// AVOID - evaluated inline, so it re-runs with every parent body update
var priceSection: some View {
    VStack { /* content */ }
}
// GOOD - View struct whose body SwiftUI can skip when inputs are unchanged
struct PriceSection: View {
    let price: Decimal
    var body: some View { /* content */ }
}
```
## Modifier Ordering
Modifiers apply in order. Common patterns:
```swift
// BAD - background only covers text
Text("Hello")
    .background(.blue)
    .padding()
// GOOD - background covers padded area
Text("Hello")
    .padding()
    .background(.blue)
// BAD - shadow clipped away
RoundedRectangle(cornerRadius: 12)
    .shadow(radius: 5)
    .clipShape(Circle())
// GOOD - shadow visible
RoundedRectangle(cornerRadius: 12)
    .clipShape(Circle())
    .shadow(radius: 5)
```
## Subview Parameters
Pass simple parameters, not domain models, for reusability.
```swift
// BAD - coupled to model
struct PriceLabel: View {
let product: Product // Tightly coupled
}
// GOOD - model-agnostic
struct PriceLabel: View {
let price: Decimal
let discount: Decimal?
}
```
## Critical Anti-Patterns
| Pattern | Issue |
|---------|-------|
| `var section: some View {}` | Re-evaluated with every parent update; no independent diffing |
| Body over 10 composed elements | Hard to maintain and test |
| `.background().padding()` | Wrong modifier order |
| Inline closures with complex logic | Extract to methods or views |
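
The last row has no example elsewhere in this reference. A minimal sketch, assuming a hypothetical `CartView` with a per-item quantity limit: the button closure stays a single call, and the logic moves into a pure function that can be tested without rendering the view.

```swift
import SwiftUI

struct CartView: View {
    @State private var quantities: [String: Int] = ["apple": 1]

    var body: some View {
        // The closure is one line; the rules live in `incremented`.
        Button("Add apple") {
            quantities = Self.incremented(quantities, item: "apple")
        }
    }

    // Pure function: returns a new dictionary with the item's count
    // raised by one, capped at `limit`.
    static func incremented(_ quantities: [String: Int],
                            item: String,
                            limit: Int = 10) -> [String: Int] {
        var result = quantities
        result[item] = min((result[item] ?? 0) + 1, limit)
        return result
    }
}
```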
## Review Questions
1. Could this view body be split into smaller View structs?
2. Are there computed property views that should be View structs?
3. Is the modifier order intentional and correct?
4. Are subviews model-agnostic and reusable?