
Chapter 12: Chains of Memory (LSTM and RNN)

Mr. Pallas's Cat's Warm Welcome
In the previous chapter, we explored the wisdom of backpropagation — how networks learn from error. But learning is not only correcting mistakes; it also requires remembering the past. Today, we answer a crucial question: how do we give neural networks memory? When information flows through time, how does a network maintain state, connecting past and present? Let's take it slow and explore the chains of memory.


Core Question: Patterns in Time

"Professor," Piglet pointed at a passage of text on the screen, "the network I trained can recognize individual words, but it always seems to 'forget' what came before. Like 'I like eating apples because they are very...' — it should fill in 'sweet,' but sometimes it fills in 'red' or 'round.'"

It was a winter morning in Kangle Garden at Sun Yat-sen University. Morning mist shrouded the red-brick buildings. In the Black Stone House study, the radiator hummed softly and water droplets condensed on the windowpanes. Outside, the Pearl River lay still as a mirror; the occasional student jogged past, their breath forming white puffs in the cold air.

By the window, Little Seal looked up from A Brief History of Cognitive Science. "That's a profound problem. Historically, the study of memory dates back to ancient Greece, where Aristotle wrote about memory and recollection. The familiar division into 'sensory memory,' 'short-term memory,' and 'long-term memory' came much later, with the Atkinson-Shiffrin model of 1968, and modern cognitive science treats memory as the dynamic retention and retrieval of information."

Mr. Pallas's Cat gently set down his teacup and smiled. "You've raised the core challenge of sequence learning. Traditional neural networks are like 'goldfish' — each time they process a new input, they 'forget' the previous context. Today, we're going to give our network chains of memory."

The Dilemma of Time: The Amnesia of Traditional Networks

Piglet walked to the whiteboard and drew a traditional feedforward neural network.

"Professor, feedforward networks process one input at a time, output a result, and then 'forget.' But many problems require context — language understanding, stock prediction, music generation..."

Little Seal added: "Historically, in the 1980s, researchers began thinking about how to make neural networks process sequential data. One naive idea was: unfold time into space."

Mr. Pallas's Cat nodded: "Yes, this is the idea of unfolding through time. But this approach has fundamental flaws: the number of parameters explodes with sequence length, and it can't handle variable-length sequences."

He drew a diagram of unfolding on the whiteboard:

Time step 1: x₁ → network → h₁
Time step 2: x₂ → network → h₂  
Time step 3: x₃ → network → h₃
...

"Worse still," Mr. Pallas's Cat said, "this unfolded network needs to store intermediate states for all time steps during training — enormous memory requirements. We need a more elegant solution."

Piglet thought: "What if we make the network 'recurrent'? Feed the previous moment's output as the next moment's input?"

"That is the core idea of Recurrent Neural Networks (RNNs)," Mr. Pallas's Cat smiled. "You've found the key."

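The parameter explosion mentioned above can be checked with a little arithmetic. The sketch below assumes that "unfolding time into space" means giving every time step its own input weights, hidden weights, and bias; the function names are hypothetical and chosen only for this illustration.

```python
def unfolded_param_count(seq_len, input_size, hidden_size):
    """Parameter count if every time step owns a separate set of weights."""
    per_step = input_size * hidden_size + hidden_size * hidden_size + hidden_size
    return seq_len * per_step


def recurrent_param_count(input_size, hidden_size):
    """Parameter count if one set of weights is reused at every time step."""
    return input_size * hidden_size + hidden_size * hidden_size + hidden_size


for seq_len in (10, 100, 1000):
    unfolded = unfolded_param_count(seq_len, input_size=8, hidden_size=16)
    shared = recurrent_param_count(input_size=8, hidden_size=16)
    print(f"T={seq_len:5d}  unfolded={unfolded:7d}  recurrent={shared}")
```

Under this assumption the unfolded count grows linearly with the sequence length, while the recurrent count does not depend on it at all, which is also why a recurrent network can accept sequences of any length.
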

The Wisdom of Recurrence: The Basic Structure of RNNs

Outside, the sunlight grew stronger, casting dappled patterns through the condensation-covered glass onto the red-brick floor.

"Professor," Piglet asked, "how exactly does an RNN 'recur'?"

Mr. Pallas's Cat wrote the RNN computation formula on the whiteboard:

$$h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$

"Look at this formula," he said. "$h_t$ is the hidden state at the current moment, determined by two parts: the current input $x_t$ and the previous hidden state $h_{t-1}$."

Little Seal studied the formula carefully: "So the hidden state $h$ is like a 'memory container'? It carries information from all past moments?"

"Yes," Mr. Pallas's Cat nodded. "In theory, $h_t$ contains all historical information from the start of the sequence up to time $t$. This is achieved through the recurrent connection $W_{hh}$."

Piglet grasped the design: "$W_{hh}$ determines 'how much of the past to remember.' If $W_{hh}$ is large, it values history; if small, it focuses more on the present?"

"Good intuition," Mr. Pallas's Cat said approvingly. "But in practice, simple RNNs face a fundamental problem: vanishing and exploding gradients."

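The recurrence can be sketched in a few lines of NumPy. This is a minimal illustration, not a full implementation: it assumes that the activation f is tanh, uses small random weights, and follows the row-vector convention (x @ W) rather than the column-vector form written on the whiteboard. The name rnn_step is hypothetical.

```python
import numpy as np


def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One step of a simple RNN: h_t = tanh(x_t W_xh + h_prev W_hh + b_h)."""
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)


rng = np.random.default_rng(0)
input_size, hidden_size, seq_len = 3, 4, 5

W_xh = rng.normal(scale=0.1, size=(input_size, hidden_size))
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

xs = rng.normal(size=(seq_len, input_size))  # one input vector per time step
h = np.zeros(hidden_size)                    # initial hidden state h_0

states = []
for x_t in xs:
    # The same weights are reused at every step; only h changes.
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)
    states.append(h)

states = np.stack(states)
print(states.shape)  # (seq_len, hidden_size)
```

Because h is fed back into rnn_step on every iteration, each row of states depends on all the inputs seen so far, which is the sense in which the hidden state acts as a memory container.
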
Vanishing Gradients: The Decay of Memory

Mr. Pallas's Cat drew an unrolled diagram of a long sequence on the whiteboard:

h₀ → h₁ → h₂ → ... → h₅₀

"When error propagates backward from time step 50 to time step 1," he explained, "it is multiplied by $W_{hh}$, and by the derivative of the activation, at every step, roughly fifty times in succession. If all eigenvalues of $W_{hh}$ have magnitude less than 1, the gradient decays exponentially to near zero. This is the vanishing gradient problem."

Little Seal mused: "Vanishing gradients mean... early time steps get almost no learning signal? The network 'forgets' how to adjust early parameters?"

"Exactly," said Mr. Pallas's Cat. "This makes it hard for simple RNNs to learn long-range dependencies. Conversely, if $W_{hh}$ has eigenvalues with magnitude greater than 1, gradients can grow exponentially, making training unstable. That is the exploding gradient problem."

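The effect of repeated multiplication can be seen numerically. The sketch below is a simplification: it ignores the activation derivative (treating it as 1) and uses a scaled random orthogonal matrix as a stand-in for $W_{hh}$, so that every eigenvalue has magnitude equal to the chosen scale. The function name backprop_norms is hypothetical.

```python
import numpy as np


def backprop_norms(W_hh, steps=50):
    """Track the norm of a gradient vector multiplied by W_hh^T at each step."""
    size = W_hh.shape[0]
    grad = np.ones(size) / np.sqrt(size)  # unit-norm starting gradient
    norms = []
    for _ in range(steps):
        grad = W_hh.T @ grad
        norms.append(np.linalg.norm(grad))
    return norms


rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix

for scale in (0.9, 1.0, 1.1):
    norms = backprop_norms(scale * Q)
    print(f"scale={scale}: norm after 50 steps = {norms[-1]:.4g}")
```

With an orthogonal matrix the norm after n steps is simply the scale raised to the power n, so a scale of 0.9 shrinks the gradient to well under one percent of its size after fifty steps, while a scale of 1.1 inflates it more than a hundredfold.
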
Piglet thought: "So we need a mechanism that can selectively remember — remember what's important, forget what's not?"

"A concise summary," Mr. Pallas's Cat smiled. "This is the philosophy of Long Short-Term Memory (LSTM)."


Long Short-Term Memory: The Gating Philosophy of LSTM

Mr. Pallas's Cat turned on the projector, and a neat computation graph appeared on the screen.

[Figure: LSTM orthogonal computation graph]

"This is the orthogonal computation graph for LSTM," Mr. Pallas's Cat said, pointing. "LSTM achieves selective memory through three gating mechanisms: the forget gate, the input gate, and the output gate."

Little Seal studied the gating symbols in the diagram carefully: "These gates... are like 'valves' for information? Controlling what information can enter, stay, or leave?"

"Yes," Mr. Pallas's Cat explained. "Each gate is a sigmoid unit, outputting a value between 0 and 1, representing the 'pass-through ratio.'"

He wrote the core LSTM equations on the whiteboard:

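Forget gate: $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$
Input gate: $i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$
Candidate memory: $\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C)$
Memory update: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
Output gate: $o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$
Hidden state: $h_t = o_t \odot \tanh(C_t)$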

Piglet studied the formulas seriously: "$\odot$ is element-wise multiplication? So the forget gate $f_t$ controls how much old memory is retained, and the input gate $i_t$ controls how much new memory is added?"

"Exactly," said Mr. Pallas's Cat. "$C_t$ is the cell state, LSTM's long-term memory container. It's like a conveyor belt, carrying information through the entire sequence, subject only to mild adjustments by the gates."

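A small worked example may make the memory update concrete. The gate values below are picked by hand purely for illustration; in a real LSTM they would be computed from $h_{t-1}$ and $x_t$ by the sigmoid and tanh layers.

```python
import numpy as np

# Hand-picked values for three memory slots (illustrative only).
c_prev = np.array([2.0, -1.0, 0.5])    # old cell state C_{t-1}
f_t = np.array([1.0, 0.0, 0.5])        # forget gate: keep all, erase all, keep half
i_t = np.array([0.0, 1.0, 0.5])        # input gate: write nothing, write all, write half
c_tilde = np.array([0.3, 0.8, -0.4])   # candidate memory
o_t = np.array([1.0, 0.5, 0.0])        # output gate: expose all, half, nothing

# Memory update and hidden state, both element-wise.
c_t = f_t * c_prev + i_t * c_tilde
h_t = o_t * np.tanh(c_t)

print("c_t =", c_t)
print("h_t =", h_t.round(3))
```

The first slot keeps its old value untouched, the second is overwritten entirely by the candidate, and the third blends the two. The third slot also shows that a closed output gate hides a memory from $h_t$ without erasing it from $C_t$.
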
Little Seal thought: "This design resembles human memory. We don't remember every detail; instead, we selectively strengthen important memories and let irrelevant information fade."

"Good connection," said Mr. Pallas's Cat. "LSTM echoes the working memory model from cognitive science: limited capacity, dynamic updating, and selective access to what is stored."

The Wisdom of the Forget Gate

Mr. Pallas's Cat highlighted the forget gate equation prominently on the whiteboard.

"The forget gate $f_t$ is one of the most profound ideas in LSTM," he said. "It allows the network to actively forget. For useless information, the gate value drops toward 0 and the old memory decays; for important information, the gate stays close to 1 and the memory is kept."

Piglet understood: "So LSTM doesn't passively let gradients vanish. It actively controls forgetting? This solves the vanishing gradient problem?"

"Partially," Mr. Pallas's Cat explained. "Along the cell state $C_t$, LSTM creates what is called a constant error carousel. The gradient flowing through the cell state is scaled element-wise by the forget gate instead of being multiplied by a weight matrix at every step, so as long as the forget gate stays near 1 it can propagate relatively smoothly, without exponential decay."

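The constant error carousel can be made a little more precise. The following is a simplified derivation, assuming we follow only the direct path through the cell state and ignore the indirect dependence of the gates and the candidate on $h_{t-1}$:

```latex
\begin{aligned}
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
\frac{\partial C_t}{\partial C_{t-1}} &\approx \operatorname{diag}(f_t) \\
\frac{\partial C_T}{\partial C_k} &\approx \prod_{t=k+1}^{T} \operatorname{diag}(f_t)
\end{aligned}
```

For comparison, each step of a simple RNN contributes a factor of the form $\operatorname{diag}\!\big(f'(\cdot)\big)\, W_{hh}$, whose repeated product shrinks or grows according to the eigenvalues of $W_{hh}$. Along the cell state the per-step factor is just the forget gate, so a memory slot whose gate stays near 1 passes its gradient back almost unchanged, while a slot whose gate is near 0 deliberately cuts it off.
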
Little Seal added: "Historically, LSTM was proposed by Sepp Hochreiter and Jürgen Schmidhuber in 1997. Their key insight was to use multiplicative gating to control information flow, rather than letting gradients passively decay. The forget gate itself was added a few years later by Felix Gers, Schmidhuber, and Fred Cummins."

"Yes," said Mr. Pallas's Cat. "This design was so elegant that LSTM dominated sequence modeling for nearly 20 years — until the emergence of attention mechanisms."


Orthogonal Computation Graphs: Seeing Memory Flow

Mr. Pallas's Cat zoomed in on the details of the LSTM computation graph.

"In the orthogonal computation graph," he said, "we can clearly see three information pathways: the horizontally flowing cell state (long-term memory), the vertically flowing hidden state (short-term memory), and the gating signal flow (control mechanism)."

Piglet studied the layout of the diagram carefully: "The cell state flows horizontally from left to right, like a conveyor belt. The hidden state updates at each time step, and the gating signals are like 'traffic officers' directing the flow?"

"A very vivid analogy," Mr. Pallas's Cat smiled. "The orthogonal computation graph helps us understand LSTM's triple structure: memory storage, memory updating, and memory retrieval."

Little Seal mused: "This visualization reveals LSTM's modular design. The computation at each time step can be decomposed into independent sub-computations, combined through orthogonal connections."

"Yes," said Mr. Pallas's Cat. "This modularity makes LSTM easy to understand and implement. In code, we can clearly see the computational flow of the three gates."


Mental Model: Memory as Selective Retention

Little Seal took a psychology book from the shelf. "Professor, this reminds me of the 'decay and interference' theory of memory."

"Good connection," said Mr. Pallas's Cat. "LSTM implements a model of active memory management."

He wrote the mental models on the whiteboard:

Mental Model: The Threefold Management of Memory

  1. Selective encoding (input gate): decides what new information is worth remembering
  2. Selective retention (forget gate): decides what old information is worth keeping
  3. Selective retrieval (output gate): decides what memory is useful at the current moment

"These three mechanisms," Mr. Pallas's Cat explained, "correspond to the active management processes of human memory. We don't passively receive and store all information; we actively encode, retain, and retrieve."

Piglet thought: "So LSTM is not only a technical solution, but also a cognitive model? It shows how 'memory' can be realized through computation?"

"Exactly," Mr. Pallas's Cat answered. "LSTM tells us: memory is not static storage, but a dynamic process. It requires constant updating, filtering, and reorganizing."

Memory and Attention: A Comparison

Mr. Pallas's Cat drew diagrams of LSTM and attention mechanisms side by side on the whiteboard.

"LSTM and attention are two philosophies of sequence processing," he said. "LSTM is incremental memory: information gradually accumulates and updates over time. Attention is selective focus: at each moment, the model directly accesses all historical information and weighs what is relevant."

Little Seal compared the two structures: "LSTM is like a diary that is constantly updated; attention is like a carefully indexed archive?"

"A brilliant analogy," Mr. Pallas's Cat smiled. "We will compare these two philosophies in depth in the next chapter. But today, let us first master the wisdom of LSTM."

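To make the "indexed archive" side of the comparison concrete, here is a minimal sketch of scaled dot-product attention over a stored history of hidden states. It is only one common form of attention, chosen here as an assumption for illustration; the function name attend and the toy data are hypothetical.

```python
import numpy as np


def attend(query, history):
    """Weight every stored state by its similarity to the query, then average.

    query:   vector of shape (d,)
    history: matrix of shape (T, d), one stored state per past time step
    """
    scores = history @ query / np.sqrt(query.shape[0])  # one score per past step
    scores = scores - scores.max()                      # for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()     # softmax over time steps
    context = weights @ history                         # weighted summary of the past
    return context, weights


rng = np.random.default_rng(0)
history = rng.normal(size=(6, 4))  # six past states, each of size 4
query = history[2]                 # the current step "asks about" step 2

context, weights = attend(query, history)
print("attention weights:", weights.round(3))
print("context vector:  ", context.round(3))
```

Unlike the LSTM, nothing here is carried forward step by step: every past state stays available, and the weights decide at each moment which ones matter.
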

Key Takeaways

Mr. Pallas's Cat's Summary: The Wisdom of Memory

  1. The nature of sequence learning: many real-world problems have a temporal dimension, requiring models to maintain historical context — understanding "now" depends on "then"
  2. RNN's basic architecture: introduces the time dimension through recurrent connections; the hidden state serves as a memory container carrying historical information
  3. The vanishing gradient challenge: simple RNNs struggle to learn long-range dependencies; gradients decay or explode exponentially through time
  4. LSTM's gating philosophy: achieves selective memory through forget, input, and output gates, creating a constant error carousel
  5. Active memory management: LSTM not only solves a technical problem but also provides a computational model of memory — a dynamic balance of encoding, retention, and retrieval

Code Practice: Implementing LSTM in Python

"Let's use Python to practice LSTM," said Mr. Pallas's Cat. "Code not only helps us understand the abstract gating equations but also lets us 'run' this memory system."

Implementing a Simple LSTM Cell

python
import numpy as np
import matplotlib.pyplot as plt

class LSTMCell:
    """Basic implementation of an LSTM cell"""
    
    def __init__(self, input_size, hidden_size):
        """Initialize LSTM cell
        
        Args:
            input_size: input feature dimension
            hidden_size: hidden state dimension
        """
        # Parameter initialization
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Combined weight matrix: [h_{t-1}, x_t] → four gates
        # Order: forget gate, input gate, candidate memory, output gate
        self.W = np.random.randn(hidden_size + input_size, 4 * hidden_size) * 0.01
        self.b = np.zeros((1, 4 * hidden_size))
        
    def sigmoid(self, x):
        """Sigmoid activation function"""
        return 1 / (1 + np.exp(-x))
    
    def forward(self, x, h_prev, c_prev):
        """Forward pass for one time step
        
        Args:
            x: current input (batch_size, input_size)
            h_prev: previous hidden state (batch_size, hidden_size)
            c_prev: previous cell state (batch_size, hidden_size)
            
        Returns:
            h_next: next hidden state
            c_next: next cell state
            cache: cached intermediate results for backpropagation
        """
        batch_size = x.shape[0]
        
        # Concatenate input and previous hidden state
        combined = np.hstack([h_prev, x])  # (batch_size, hidden_size+input_size)
        
        # Compute all gates and candidate memory
        gates = np.dot(combined, self.W) + self.b
        
        # Split into four parts
        f_gate = self.sigmoid(gates[:, :self.hidden_size])                     # forget gate
        i_gate = self.sigmoid(gates[:, self.hidden_size:2*self.hidden_size])   # input gate
        c_candidate = np.tanh(gates[:, 2*self.hidden_size:3*self.hidden_size]) # candidate memory
        o_gate = self.sigmoid(gates[:, 3*self.hidden_size:])                   # output gate
        
        # Update cell state
        c_next = f_gate * c_prev + i_gate * c_candidate
        
        # Compute hidden state
        h_next = o_gate * np.tanh(c_next)
        
        # Cache intermediate results for backpropagation
        cache = {
            'x': x, 'h_prev': h_prev, 'c_prev': c_prev,
            'f_gate': f_gate, 'i_gate': i_gate,
            'c_candidate': c_candidate, 'o_gate': o_gate,
            'c_next': c_next, 'combined': combined
        }
        
        return h_next, c_next, cache
    
    def describe(self):
        """Describe the parameters of the LSTM cell"""
        print(f"LSTM Cell:")
        print(f"  Input dimension: {self.input_size}")
        print(f"  Hidden dimension: {self.hidden_size}")
        print(f"  Weight shape: {self.W.shape}")
        print(f"  Bias shape: {self.b.shape}")
        print(f"  Total parameters: {self.W.size + self.b.size}")

# Create and test LSTM cell
print("LSTM Cell Test:")
print("=" * 60)

# Create LSTM cell
lstm_cell = LSTMCell(input_size=3, hidden_size=5)
lstm_cell.describe()

# Test data
batch_size = 2
x_t = np.random.randn(batch_size, 3)       # current input
h_prev = np.zeros((batch_size, 5))         # initial hidden state
c_prev = np.zeros((batch_size, 5))         # initial cell state

# Forward pass
h_next, c_next, cache = lstm_cell.forward(x_t, h_prev, c_prev)

print(f"\nInput shape: {x_t.shape}")
print(f"Previous hidden state shape: {h_prev.shape}")
print(f"Previous cell state shape: {c_prev.shape}")
print(f"Next hidden state shape: {h_next.shape}")
print(f"Next cell state shape: {c_next.shape}")

# View gate values
print(f"\nGate values example (first sample):")
print(f"  Forget gate: {cache['f_gate'][0].round(3)}")
print(f"  Input gate: {cache['i_gate'][0].round(3)}")
print(f"  Output gate: {cache['o_gate'][0].round(3)}")
print(f"  Candidate memory: {cache['c_candidate'][0].round(3)}")
print(f"  Cell state change: {(c_next[0] - c_prev[0]).round(3)}")

LSTM Layer for Sequence Processing

python
class LSTMLayer:
    """LSTM layer: processes entire sequences"""
    
    def __init__(self, input_size, hidden_size):
        """Initialize LSTM layer
        
        Args:
            input_size: input feature dimension
            hidden_size: hidden state dimension
        """
        self.cell = LSTMCell(input_size, hidden_size)
        self.hidden_size = hidden_size
        
    def forward_sequence(self, X):
        """Process an entire sequence
        
        Args:
            X: input sequence (seq_length, batch_size, input_size)
            
        Returns:
            H: hidden states at all time steps (seq_length, batch_size, hidden_size)
            C: cell states at all time steps (seq_length, batch_size, hidden_size)
            caches: cache for each time step
        """
        seq_length, batch_size, input_size = X.shape
        
        # Initialize hidden state and cell state
        h = np.zeros((batch_size, self.hidden_size))
        c = np.zeros((batch_size, self.hidden_size))
        
        # Store outputs at all time steps
        H = np.zeros((seq_length, batch_size, self.hidden_size))
        C = np.zeros((seq_length, batch_size, self.hidden_size))
        caches = []
        
        # Loop through each time step
        for t in range(seq_length):
            h, c, cache = self.cell.forward(X[t], h, c)
            H[t] = h
            C[t] = c
            caches.append(cache)
        
        return H, C, caches
    
    def describe_sequence_processing(self, seq_length):
        """Describe the dimensional changes of sequence processing"""
        print(f"LSTM Sequence Processing:")
        print(f"  Input sequence shape: ({seq_length}, batch_size, input_size)")
        print(f"  Output sequence shape: ({seq_length}, batch_size, {self.hidden_size})")
        print(f"  Cell state shape: ({seq_length}, batch_size, {self.hidden_size})")

# Sequence processing demo
print("\nLSTM Sequence Processing Demo:")
print("=" * 60)

# Create LSTM layer
lstm_layer = LSTMLayer(input_size=4, hidden_size=6)
lstm_layer.describe_sequence_processing(seq_length=8)

# Generate test sequence
seq_length = 8
batch_size = 3
input_size = 4

X_seq = np.random.randn(seq_length, batch_size, input_size)
print(f"\nInput sequence shape: {X_seq.shape}")

# Process the entire sequence
H_seq, C_seq, caches = lstm_layer.forward_sequence(X_seq)
print(f"Output hidden state shape: {H_seq.shape}")
print(f"Output cell state shape: {C_seq.shape}")

# Visualize gate values over time
def visualize_gates_over_time(caches, sample_idx=0):
    """Visualize how gate values change over time"""
    seq_length = len(caches)
    hidden_size = caches[0]['f_gate'].shape[1]
    
    # Extract gate values
    f_gates = np.zeros((seq_length, hidden_size))
    i_gates = np.zeros((seq_length, hidden_size))
    o_gates = np.zeros((seq_length, hidden_size))
    
    for t in range(seq_length):
        f_gates[t] = caches[t]['f_gate'][sample_idx]
        i_gates[t] = caches[t]['i_gate'][sample_idx]
        o_gates[t] = caches[t]['o_gate'][sample_idx]
    
    # Visualize
    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    
    # Forget gate
    im1 = axes[0].imshow(f_gates.T, aspect='auto', cmap='Reds', vmin=0, vmax=1)
    axes[0].set_xlabel('Time Step')
    axes[0].set_ylabel('Hidden Unit')
    axes[0].set_title('Forget Gate Values (f_t) — Controls Retention of Old Memory')
    plt.colorbar(im1, ax=axes[0])
    
    # Input gate
    im2 = axes[1].imshow(i_gates.T, aspect='auto', cmap='Blues', vmin=0, vmax=1)
    axes[1].set_xlabel('Time Step')
    axes[1].set_ylabel('Hidden Unit')
    axes[1].set_title('Input Gate Values (i_t) — Controls Incorporation of New Memory')
    plt.colorbar(im2, ax=axes[1])
    
    # Output gate
    im3 = axes[2].imshow(o_gates.T, aspect='auto', cmap='Greens', vmin=0, vmax=1)
    axes[2].set_xlabel('Time Step')
    axes[2].set_ylabel('Hidden Unit')
    axes[2].set_title('Output Gate Values (o_t) — Controls Memory Retrieval')
    plt.colorbar(im3, ax=axes[2])
    
    plt.tight_layout()
    plt.savefig('/tmp/lstm_gates_over_time.png', dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"Gate value visualization saved to /tmp/lstm_gates_over_time.png")

# Run visualization
visualize_gates_over_time(caches, sample_idx=0)

# Analyze cell state propagation
print("\nCell State Propagation Analysis:")
print("=" * 60)

# Compute cumulative changes in cell state
cell_state_changes = np.zeros(seq_length)
for t in range(seq_length):
    if t == 0:
        cell_state_changes[t] = 0
    else:
        # Compute L2 change in cell state
        change = np.linalg.norm(C_seq[t, 0] - C_seq[t-1, 0])
        cell_state_changes[t] = change

print(f"Cell state changes (L2 norm):")
for t in range(seq_length):
    print(f"  Time step {t}: {cell_state_changes[t]:.4f}")

# Visualize cell state changes
plt.figure(figsize=(10, 5))
plt.plot(range(seq_length), cell_state_changes, 'b-o', linewidth=2, markersize=8)
plt.xlabel('Time Step')
plt.ylabel('Cell State Change (L2 Norm)')
plt.title('LSTM Cell State Changes Over Time')
plt.grid(True, alpha=0.3)
plt.savefig('/tmp/lstm_cell_state_changes.png', dpi=150, bbox_inches='tight')
plt.close()

print(f"Cell state change plot saved to /tmp/lstm_cell_state_changes.png")

Text Generation Demo: LSTM's Memory Capacity

python
class TextGenerationLSTM:
    """Simple text generation LSTM demo"""
    
    def __init__(self, vocab_size, hidden_size=128):
        """Initialize text generation model
        
        Args:
            vocab_size: vocabulary size
            hidden_size: hidden state dimension
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        
        # Embedding layer (converts word indices to vectors)
        self.embedding = np.random.randn(vocab_size, hidden_size) * 0.01
        
        # LSTM layer
        self.lstm = LSTMLayer(input_size=hidden_size, hidden_size=hidden_size)
        
        # Output layer
        self.W_out = np.random.randn(hidden_size, vocab_size) * 0.01
        self.b_out = np.zeros((1, vocab_size))
    
    def forward(self, input_indices, h_prev=None, c_prev=None):
        """Forward pass
        
        Args:
            input_indices: list of input word indices
            h_prev: initial hidden state (optional)
            c_prev: initial cell state (optional)
            
        Returns:
            logits: unnormalized scores at each time step
            h_next: final hidden state
            c_next: final cell state
        """
        seq_length = len(input_indices)
        batch_size = 1  # single-sample generation
        
        # Initialize states
        if h_prev is None:
            h_prev = np.zeros((batch_size, self.hidden_size))
        if c_prev is None:
            c_prev = np.zeros((batch_size, self.hidden_size))
        
        # Word embeddings
        embedded = np.zeros((seq_length, batch_size, self.hidden_size))
        for t, idx in enumerate(input_indices):
            embedded[t, 0] = self.embedding[idx]
        
        # LSTM forward pass: step the cell manually so that the supplied
        # initial states (h_prev, c_prev) are actually used
        h, c = h_prev, c_prev
        logits = np.zeros((seq_length, self.vocab_size))
        for t in range(seq_length):
            h, c, _ = self.lstm.cell.forward(embedded[t], h, c)
            # Output layer: project the hidden state onto the vocabulary
            logits[t] = h[0] @ self.W_out + self.b_out[0]
        
        # Final states keep shape (1, hidden_size) so they can be fed back in
        return logits, h, c
    
    def generate_text(self, seed_text, char_to_idx, idx_to_char, length=50, temperature=1.0):
        """Generate text
        
        Args:
            seed_text: seed text
            char_to_idx: character-to-index mapping
            idx_to_char: index-to-character mapping
            length: generated text length
            temperature: temperature parameter (controls randomness)
        """
        # Initialize states
        h = np.zeros((1, self.hidden_size))
        c = np.zeros((1, self.hidden_size))
        
        # Convert seed text to indices
        indices = [char_to_idx[ch] for ch in seed_text]
        generated_text = seed_text
        
        # Warm up the state on every seed character except the last one
        if len(indices) > 1:
            _, h, c = self.forward(indices[:-1], h, c)
        
        # Generation loop
        for _ in range(length):
            # Forward pass (process only the last character, carrying h and c)
            logits, h, c = self.forward([indices[-1]], h, c)
            
            # Apply temperature, then a numerically stable softmax
            scaled = logits[0] / temperature
            scaled = scaled - np.max(scaled)
            probs = np.exp(scaled) / np.sum(np.exp(scaled))
            
            # Sample next character
            next_idx = np.random.choice(len(probs), p=probs)
            next_char = idx_to_char[next_idx]
            
            # Append to generated text
            generated_text += next_char
            indices.append(next_idx)
        
        return generated_text

# Text generation demo
print("\nLSTM Text Generation Demo:")
print("=" * 60)

# Create a simple character-level vocabulary
text_corpus = "hello world from lstm network for sequence modeling "
chars = sorted(list(set(text_corpus)))
vocab_size = len(chars)

char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

print(f"Vocabulary: {chars}")
print(f"Vocabulary size: {vocab_size}")

# Create model (note: this is an untrained demo model)
text_lstm = TextGenerationLSTM(vocab_size=vocab_size, hidden_size=32)

# Generate text example
seed_text = "hello"
generated = text_lstm.generate_text(
    seed_text=seed_text,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    length=30,
    temperature=0.8
)

print(f"\nSeed text: '{seed_text}'")
print(f"Generated text: '{generated}'")
print(f"\nNote: This is an untrained model, so the generated result is random.")
print(f"      In practice, LSTM needs to be trained on large amounts of text to generate meaningful sequences.")

# Demo: LSTM's memory capability
print("\nLSTM Memory Capability Demo:")
print("=" * 60)

# Create a simple memory task
def simple_memory_task():
    """Simple memory task: the network needs to remember characters from the beginning of a sequence"""
    # Create a smaller model
    memory_lstm = TextGenerationLSTM(vocab_size=vocab_size, hidden_size=16)
    
    # Construct a task that requires memory: repeat the first character
    test_sequence = [char_to_idx['h'], char_to_idx['e'], char_to_idx['l'], char_to_idx['l'], char_to_idx['o']]
    
    # Forward pass
    logits, h_final, c_final = memory_lstm.forward(test_sequence)
    
    print(f"Test sequence: 'hello'")
    print(f"Sequence length: {len(test_sequence)}")
    print(f"Final hidden state shape: {h_final.shape}")
    print(f"Final cell state shape: {c_final.shape}")
    
    # Analyze information retention in cell state
    print(f"\nCell state analysis:")
    print(f"  Cell state norm: {np.linalg.norm(c_final):.4f}")
    print(f"  Cell state mean absolute value: {np.mean(np.abs(c_final)):.4f}")
    
    # Simulation: if the cell state indeed retains information, changing the input should affect the output
    altered_sequence = [char_to_idx['w'], char_to_idx['o'], char_to_idx['r'], char_to_idx['l'], char_to_idx['d']]
    logits2, h_final2, c_final2 = memory_lstm.forward(altered_sequence)
    
    # Compare the difference between two cell states
    state_diff = np.linalg.norm(c_final - c_final2)
    print(f"  Cell state difference between different sequences: {state_diff:.4f}")
    print(f"  (The larger the difference, the more sequence-specific information the cell state retains)")

simple_memory_task()

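The temperature parameter used in generate_text can also be explored on its own. The sketch below applies a temperature-scaled softmax to a small set of made-up logits and reports the entropy of the result; the helper names and the logit values are assumptions for illustration only.

```python
import numpy as np


def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a stable softmax."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled = scaled - scaled.max()
    exp = np.exp(scaled)
    return exp / exp.sum()


def entropy(p):
    """Shannon entropy in nats; higher means a more spread-out distribution."""
    return float(-(p * np.log(p + 1e-12)).sum())


logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate characters
for T in (0.5, 1.0, 2.0):
    p = softmax_with_temperature(logits, T)
    print(f"T={T}: probs={p.round(3)}, entropy={entropy(p):.3f}")
```

Lower temperatures concentrate probability on the highest-scoring character, so sampling becomes more conservative; higher temperatures flatten the distribution and make the output more random.
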
"Remember," Mr. Pallas's Cat summarized, "LSTM is a computational model of memory — it shows how gating mechanisms can achieve selective memory. Through the forget gate, input gate, and output gate, LSTM actively manages information flow, balancing old and new memories. Most importantly, LSTM not only solves a technical problem (vanishing gradients) but also provides a computational framework for understanding memory itself. It reminds us: memory is not passive storage but an active process; forgetting is not a defect but a function; learning is not only accumulation but also selection."


Mr. Pallas's Cat's Reflection Questions

Hands-On Exploration (for Piglet)

  1. Gating experiment: modify the LSTM code and try fixing certain gate values (e.g., set the forget gate always to 1). Observe the effect on sequence modeling.
  2. Gradient propagation: implement backpropagation for LSTM. Observe how gradients propagate through the time dimension and compare with simple RNNs.
  3. Architecture variant: implement GRU (Gated Recurrent Unit) and compare it with LSTM. GRU has only two gates and fewer parameters — how does it perform?

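As a possible starting point for the architecture variant exercise, here is a minimal forward-only GRU cell. It is a sketch under stated assumptions: it uses one common formulation of the GRU (some references swap the roles of z and 1 - z), the class name GRUCell and its layout are hypothetical, and it mirrors the [h_prev, x] concatenation used by the chapter's LSTMCell.

```python
import numpy as np


def sigmoid(x):
    """Logistic sigmoid, used for both GRU gates."""
    return 1.0 / (1.0 + np.exp(-x))


class GRUCell:
    """Illustrative GRU cell (forward pass only)."""

    def __init__(self, input_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        concat = hidden_size + input_size
        self.W_z = rng.normal(scale=0.01, size=(concat, hidden_size))  # update gate
        self.W_r = rng.normal(scale=0.01, size=(concat, hidden_size))  # reset gate
        self.W_h = rng.normal(scale=0.01, size=(concat, hidden_size))  # candidate state
        self.b_z = np.zeros(hidden_size)
        self.b_r = np.zeros(hidden_size)
        self.b_h = np.zeros(hidden_size)

    def forward(self, x, h_prev):
        """x: (batch, input_size), h_prev: (batch, hidden_size)."""
        combined = np.hstack([h_prev, x])
        z = sigmoid(combined @ self.W_z + self.b_z)  # how much to update
        r = sigmoid(combined @ self.W_r + self.b_r)  # how much history the candidate sees
        h_tilde = np.tanh(np.hstack([r * h_prev, x]) @ self.W_h + self.b_h)
        # Blend the old state and the candidate; there is no separate cell state.
        return (1.0 - z) * h_prev + z * h_tilde

    def num_params(self):
        params = (self.W_z, self.W_r, self.W_h, self.b_z, self.b_r, self.b_h)
        return sum(p.size for p in params)


gru = GRUCell(input_size=3, hidden_size=5)
x = np.random.default_rng(1).normal(size=(2, 3))
h = gru.forward(x, np.zeros((2, 5)))
print(h.shape, gru.num_params())
```

With input size 3 and hidden size 5 this sketch has 135 parameters, against 180 for the chapter's LSTMCell of the same size, because it needs three weight blocks instead of four.
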
Historical Investigation (for Little Seal)

  1. The birth of LSTM: study Hochreiter and Schmidhuber's original 1997 paper. What inspired them? (Neuroscience, control theory, computer science.)
  2. Memory model evolution: trace the development from simple RNNs to LSTM to attention mechanisms. What problem did each architecture solve?
  3. Cross-disciplinary connection: compare LSTM with Baddeley and Hitch's working memory model in psychology. What are the similarities and differences?

Integrated Reflection

  1. Philosophical reflection: LSTM's "forget gate" embodies the philosophy of "active forgetting." What is the value of active forgetting in human cognition? In machine learning?
  2. Ethical challenge: when LSTM is used for text generation, it may "remember" and reproduce biases from training data. How can we detect and mitigate this memory bias?
  3. Creative exercise: design a "meta-memory" LSTM that learns how to adjust its own gating strategy. How would you design this architecture?
  4. Limit challenge: prove that LSTM can simulate any Turing machine (theoretically). What conditions are required? What does this say about LSTM's capabilities?

Coming Up Next

The fragrance of tea filled the Black Stone House; night had grown deep.

"Today we explored chains of memory," said Mr. Pallas's Cat. "LSTM achieves selective memory through gating, but it remains a form of incremental memory — information accumulates gradually over time. Is there a more direct approach?"

Piglet asked curiously: "More direct? Like... directly accessing all historical information?"

"Yes," Mr. Pallas's Cat explained. "In the next chapter, we'll explore the contest between forgetting and causality — comparing LSTM's incremental memory with the selective attention of attention mechanisms. Which approach is better suited for processing long sequences?"

Little Seal flipped through his notebook. "This leads to two philosophies of sequence modeling. Historically, how did attention mechanisms challenge LSTM's dominance?"

Mr. Pallas's Cat smiled: "We'll take it slow. See you in the next chapter."


Piglet's note: I implemented an LSTM cell and visualized how gate values change over time. I found that different hidden units indeed have different "memory strategies": some keep the forget gate high throughout (long-term memory), while others have dynamically changing input gates (short-term buffer). Most interestingly, by adjusting the temperature parameter, you can control the "creativity" of text generation — high temperature is more random, low temperature more conservative.

Little Seal's note: I researched the history of LSTM and was amazed it was born in 1997, 20 years before the Transformer. The original paper is long and dense, but its core idea is simple and profound. Most fascinatingly, LSTM was originally designed to solve the "long-term dependency problem," but its gating mechanism unexpectedly provided a cognitive model of memory. Scientific discoveries often have unexpected byproducts.

Mr. Pallas's Cat's closing words: LSTM teaches a profound lesson about memory: memory requires choice, forgetting requires wisdom, retrieval requires timing. In this architecture, we see the perfect union of engineering necessity and cognitive insight. Most importantly, it reminds us: good design often arises from deep understanding of the problem's nature, not from mere technical accumulation. On this path, insight matters more than technique, understanding more than implementation. We'll take it slow — understanding is what matters most.