
I’ve implemented a Connect 4 AI in JavaScript using a Negamax search. My goal is to solve the game completely from the starting position (depth 42). I’m trying to make it as fast as possible and would like advice on improving performance.

Board:

  • Bitboard representation using two 32-bit integers (BigInt was slower in JS)
  • Faster win check by only examining from the last move played
  • Additional function to check if a move would cause a win without actually playing it
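A minimal sketch of the indexing these bullets rely on. It assumes the same layout as board.js below: bit index = row * COLS + col, with bits 0 to 31 in the first integer and bits 32 to 41 in the second. The helper names hasBit and withBit are hypothetical. withBit returns a copy, so asking "would this move win?" does not have to touch the real board.

```javascript
const COLS = 7;

// Hypothetical helper: is cell (row, col) set in a [lo, hi] bitboard?
function hasBit(bb, row, col) {
    const idx = row * COLS + col;
    return idx < 32
        ? (bb[0] & (1 << idx)) !== 0
        : (bb[1] & (1 << (idx - 32))) !== 0;
}

// Hypothetical helper: returns a NEW [lo, hi] pair with the cell set.
// The input array is not modified.
function withBit(bb, row, col) {
    const idx = row * COLS + col;
    return idx < 32
        ? [bb[0] | (1 << idx), bb[1]]
        : [bb[0], bb[1] | (1 << (idx - 32))];
}

// Example: row 4, column 5 is index 33, which lands in the high half.
const example = withBit([0, 0], 4, 5);
console.log(example, hasBit(example, 4, 5));
```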

AI:

  • Alpha/beta pruning
  • Transposition table with variable size depending on depth
  • Move ordering: winning moves first, then center priority (huge impact)
  • Board symmetry: if the position is symmetric about the center column, only search half of the columns and mirror the result
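One possible extension of the symmetry idea, sketched under the assumption that zobristHash and zobristHashMirror are maintained incrementally as in board.js below. When a position is itself symmetric, both hashes are equal, so choosing between them changes nothing. Keying the transposition table on the smaller of the two instead would let a position and its mirror image share one entry. The function names are hypothetical, and a best move read from such an entry would need flipping (6 - col) when the mirror hash was the one chosen.

```javascript
// Hypothetical helper: a symmetry-canonical transposition-table key.
// `hash` is the Zobrist hash of the position, `mirrorHash` the hash of its
// left-right mirror image (zobristHash / zobristHashMirror in board.js).
function canonicalKey(hash, mirrorHash) {
    // Normalize both to unsigned so the comparison is well-defined,
    // then pick the smaller: P and mirror(P) produce the same key.
    const a = hash >>> 0;
    const b = mirrorHash >>> 0;
    return a <= b ? { key: a, mirrored: false } : { key: b, mirrored: true };
}

// If the key came from the mirror image, columns stored with it are flipped.
function toRealMove(col, mirrored, cols = 7) {
    return mirrored ? cols - 1 - col : col;
}
```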

Here is the result I have from the start position:

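| Depth | Time (hh:mm:ss) | Score | Nodes         | Nodes/s   |
|-------|-----------------|-------|---------------|-----------|
| 34    | 00:01:44        | 0     | 303,352,701   | 2,914,331 |
| 36    | 00:03:20        | 0     | 573,822,574   | 2,864,630 |
| 38    | 00:07:53        | 0     | 1,265,387,699 | 2,671,409 |
| 40    | 00:17:52        | 0     | 3,066,005,415 | 2,859,962 |
| 42    | 00:46:13        | 1     | 8,571,378,203 | 3,090,289 |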
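The Nodes/s column is Nodes divided by the wall-clock time in seconds. The small sketch below recomputes it from the table and also shows how much the node count grows per two extra plies. Times are rounded to whole seconds, so the recomputed rates differ slightly from the reported ones.

```javascript
// Rows copied from the table above: [depth, "hh:mm:ss", nodes].
const rows = [
    [34, "00:01:44", 303352701],
    [36, "00:03:20", 573822574],
    [38, "00:07:53", 1265387699],
    [40, "00:17:52", 3066005415],
    [42, "00:46:13", 8571378203],
];

// Convert "hh:mm:ss" to a number of seconds.
function toSeconds(hms) {
    const [h, m, s] = hms.split(":").map(Number);
    return h * 3600 + m * 60 + s;
}

function summarize(rows) {
    return rows.map(([depth, time, nodes], i) => ({
        depth,
        nodesPerSec: Math.round(nodes / toSeconds(time)),
        // Node-count growth relative to the previous (2 plies shallower) row.
        growth: i === 0 ? null : nodes / rows[i - 1][2],
    }));
}

for (const r of summarize(rows)) {
    console.log(r.depth, r.nodesPerSec, r.growth === null ? "-" : r.growth.toFixed(2));
}
```

With these numbers, the node count grows by roughly 1.9× to 2.8× for each two extra plies, while the rate stays between about 2.7 and 3.1 million nodes per second.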

Things I tried that only slowed things down

  • BigInt board representation
  • Killer & history move ordering
  • Different transposition table sizes (I think the current size is optimal)

What I am looking for

  1. Ways to increase nodes/sec using JavaScript tricks or better data structures
  2. Techniques to reduce node count while maintaining speed

I could implement an opening book, but that feels a bit like "cheating", since it assumes certain openings are good. I want to solve the game myself.

Code

Here is the current version of the bitboard logic and the AI.

board.js:

import { COLS, ROWS } from "../utils/constants.js";
import { printBoardHelper } from "./boardHelpers.js";

const halfBoard = [2, 1, 0];

// Create a simple PRNG for Zobrist hashing to always get the same result when running tests
function makePRNG(seed) {
    let state = seed >>> 0;
    return () => {
        state ^= state << 13;
        state ^= state >>> 17;
        state ^= state << 5;
        return state >>> 0;
    };
}

export class Board {
    constructor(seed = 123456789) {        
        this.bitboards = { 1: [0, 0], 2: [0, 0] }; // 64-bit integers simulated with two 32-bit ints, bottom and top bits
        this.currentPlayer = 1;
        this.moveHistory = new Uint8Array(COLS * ROWS);
        this.colHeights = new Uint8Array(COLS);
        this.lastMove = null;
        this.moveCount = 0;

        // Zobrist (32-bit per entry), incremental mirrored hash kept too
        this.zobristHash = 0;
        this.zobristHashMirror = 0;

        const rand = makePRNG(seed);
        this.zobrist = Array.from({ length: ROWS * COLS }, () => [rand(), rand()]);

        // Per-column bit masks for incremental symmetry
        this.cols = [
            new Uint8Array(COLS), // unused to make indexing easier and faster
            new Uint8Array(COLS), // player 1
            new Uint8Array(COLS)  // player 2
        ];

        // If isMirrored is true, the board is symmetric around the center column
        this.isMirrored = true;
    }

    makeMove(col) {
        const row = this.colHeights[col];
        const index = row * COLS + col;
        const player = this.currentPlayer;

        // Update column mask
        this.cols[player][col] |= 1 << row;

        // Recompute symmetry
        this.isMirrored = true;
        for (let c of halfBoard) {
            const r = 6 - c;
            if (this.cols[1][c] !== this.cols[1][r] || this.cols[2][c] !== this.cols[2][r]) {
                this.isMirrored = false;
                break;
            }
        }

        // Bitboards
        this.bitboards[player][index < 32 ? 0 : 1] |= 1 << (index % 32);

        // History & heights
        this.lastMove = col;
        this.moveHistory[this.moveCount] = col;
        this.colHeights[col]++;
        this.moveCount++;

        // Zobrist
        this.zobristHash ^= this.zobrist[index][player - 1];
        this.zobristHashMirror ^= this.zobrist[row * COLS + (6 - col)][player - 1];

        // Switch player
        this.currentPlayer = 3 - player;
    }

    unmakeMove() {
        const col = this.moveHistory[this.moveCount - 1];
        const row = this.colHeights[col] - 1;
        const index = row * COLS + col;
        const player = 3 - this.currentPlayer;

        // Clear column mask
        this.cols[player][col] &= ~(1 << row);

        // Recompute symmetry
        this.isMirrored = true;
        for (let c of halfBoard) {
            const r = 6 - c;
            if (this.cols[1][c] !== this.cols[1][r] || this.cols[2][c] !== this.cols[2][r]) {
                this.isMirrored = false;
                break;
            }
        }

        // Clear bitboards
        this.bitboards[player][index < 32 ? 0 : 1] &= ~(1 << (index % 32));

        // Undo history/heights
        this.colHeights[col]--;
        this.moveHistory[this.moveCount - 1] = 0;
        this.moveCount--;
        this.lastMove = this.moveHistory[this.moveCount - 1] ?? null;

        // Zobrist
        this.zobristHash ^= this.zobrist[index][player - 1];
        this.zobristHashMirror ^= this.zobrist[row * COLS + (6 - col)][player - 1];

        this.currentPlayer = player;
    }

    checkWin() {
        // Check using lastMove only
        const player = 3 - this.currentPlayer;
        const col = this.lastMove;
        const row = this.colHeights[col] - 1;

        const bbLo = this.bitboards[player][0];
        const bbHi = this.bitboards[player][1];

        const rowCols = row * COLS;

        // Horizontal
        let count = 1;
        for (let c = col + 1; c < COLS && c <= col + 3; c++) {
            const idx = rowCols + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let c = col - 1; c >= 0 && c >= col - 3; c--) {
            const idx = rowCols + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        // Vertical
        count = 1;
        for (let r = row + 1; r < ROWS && r <= row + 3; r++) {
            const idx = r * COLS + col;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1; r >= 0 && r >= row - 3; r--) {
            const idx = r * COLS + col;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        // Diagonal \
        count = 1;
        for (let r = row + 1, c = col + 1; r < ROWS && c < COLS && r <= row + 3 && c <= col + 3; r++, c++) {
            const idx = r * COLS + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1, c = col - 1; r >= 0 && c >= 0 && r >= row - 3 && c >= col - 3; r--, c--) {
            const idx = r * COLS + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        // Diagonal /
        count = 1;
        for (let r = row + 1, c = col - 1; r < ROWS && c >= 0 && r <= row + 3 && c >= col - 3; r++, c--) {
            const idx = r * COLS + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1, c = col + 1; r >= 0 && c < COLS && r >= row - 3 && c <= col + 3; r--, c++) {
            const idx = r * COLS + c;
            if (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0) {
                count++;
                if (count >= 4) return true;
            } else break;
        }

        return false;
    }

    winForColumn(col) {
        const player = this.currentPlayer;
        const row = this.colHeights[col];

        const bbLo = this.bitboards[player][0];
        const bbHi = this.bitboards[player][1];
        const idxThisMove = row * COLS + col;
        const rowCols = row * COLS;

        // Horizontal
        let count = 1;
        for (let c = col + 1; c < COLS && c <= col + 3; c++) {
            const idx = rowCols + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let c = col - 1; c >= 0 && c >= col - 3; c--) {
            const idx = rowCols + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        // Vertical
        count = 1;
        for (let r = row + 1; r < ROWS && r <= row + 3; r++) {
            const idx = r * COLS + col;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1; r >= 0 && r >= row - 3; r--) {
            const idx = r * COLS + col;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        // Diagonals (same pattern as checkWin)
        count = 1;
        for (let r = row + 1, c = col + 1; r < ROWS && c < COLS && r <= row + 3 && c <= col + 3; r++, c++) {
            const idx = r * COLS + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1, c = col - 1; r >= 0 && c >= 0 && r >= row - 3 && c >= col - 3; r--, c--) {
            const idx = r * COLS + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        if (count >= 4) return true;

        count = 1;
        for (let r = row + 1, c = col - 1; r < ROWS && c >= 0 && r <= row + 3 && c >= col - 3; r++, c--) {
            const idx = r * COLS + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                if (++count >= 4) return true;
            } else break;
        }
        for (let r = row - 1, c = col + 1; r >= 0 && c < COLS && r >= row - 3 && c <= col + 3; r--, c++) {
            const idx = r * COLS + c;
            if (idx === idxThisMove || (idx < 32 ? (bbLo & (1 << idx)) !== 0 : (bbHi & (1 << (idx - 32))) !== 0)) {
                count++;
                if (count >= 4) return true;
            } else break;
        }

        return false;
    }

    printBoard() {
        printBoardHelper(this);
    }
}
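The eight near-identical loops in checkWin and winForColumn could be folded into one direction-generic counter. Below is a minimal sketch that assumes the same [lo, hi] layout, with row 0 at the bottom. The names occupied, countLine and isWinAt are hypothetical, and the sketch has not been benchmarked against the unrolled loops. Each scan starts one cell away from the stone, so the stone's own bit is never read. As a result, the same function serves both checkWin (stone already placed) and winForColumn (stone not yet placed).

```javascript
const COLS = 7;
const ROWS = 6;

// Is cell (r, c) set in the [lo, hi] bitboard (bit index = r * COLS + c)?
function occupied(lo, hi, r, c) {
    const idx = r * COLS + c;
    return idx < 32 ? (lo & (1 << idx)) !== 0 : (hi & (1 << (idx - 32))) !== 0;
}

// Count consecutive stones starting next to (row, col) and stepping by (dr, dc),
// stopping at the board edge, at an empty/opponent cell, or after 3 steps.
function countLine(lo, hi, row, col, dr, dc) {
    let n = 0;
    let r = row + dr;
    let c = col + dc;
    while (n < 3 && r >= 0 && r < ROWS && c >= 0 && c < COLS && occupied(lo, hi, r, c)) {
        n++;
        r += dr;
        c += dc;
    }
    return n;
}

// Horizontal, vertical and the two diagonals; each is scanned in both directions.
const DIRS = [[0, 1], [1, 0], [1, 1], [1, -1]];

// True if a stone at (row, col) is part of four in a row for the player whose
// stones are in [lo, hi]. The origin cell itself is never read.
function isWinAt(lo, hi, row, col) {
    for (const [dr, dc] of DIRS) {
        const total = 1 + countLine(lo, hi, row, col, dr, dc)
                        + countLine(lo, hi, row, col, -dr, -dc);
        if (total >= 4) return true;
    }
    return false;
}
```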

ai.js:

import { COLS, ROWS } from '../utils/constants.js';

let tt;
const boardSize = COLS * ROWS;
const CENTER_ORDER = [3, 2, 4, 1, 5, 0, 6];
const CENTER_ORDER_MIRROR = [3, 2, 1, 0];

// Variable size of TT depending on depth
function getTTSizeForDepth(depth) {
    if (depth >= 38) return 1 << 28;
    if (depth >= 36) return 1 << 26;
    if (depth >= 18) return 1 << 23;
    if (depth >= 10) return 1 << 18;
    return 1 << 16;
}

class TranspositionTable {
    constructor(size = 1 << 22) {
        this.size = size;
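        this.keys = new Int32Array(size); // signed: zobristHash is built with ^, which yields signed 32-bit values, so === only matches if keys are stored signed too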
        this.scores = new Int16Array(size);
        this.depths = new Int8Array(size);
        this.flags = new Uint8Array(size);
    }
    put(hash, score, depth, flag) {
        const idx = hash & (this.size - 1);
        this.keys[idx] = hash;
        this.scores[idx] = score;
        this.depths[idx] = depth;
        this.flags[idx] = flag;
    }
    getScore(hash, depth, alpha, beta) {
        const idx = hash & (this.size - 1);
        if (this.keys[idx] === hash && this.depths[idx] >= depth) {
            const score = this.scores[idx];
            const flag = this.flags[idx];
            if (flag === 1) return score;
            if (flag === 2 && score >= beta) return score;
            if (flag === 3 && score <= alpha) return score;
        }
        return null;
    }
}

export function negamax(board, depth, alpha, beta) {
    let nodes = 1;
    const originalAlpha = alpha;
    const moveCount = board.moveCount;

    // Check for symmetry and get appropriate hash
    const hash = board.isMirrored ? board.zobristHashMirror : board.zobristHash;

    // Check for cached result
    const cached = tt.getScore(hash, depth, alpha, beta);
    if (cached !== null) return { score: cached, nodes: 1, move: null };

    if (board.checkWin()) return { score: ((moveCount + 1) >> 1) - 22, nodes, move: null };
    if (moveCount >= boardSize || depth === 0) return { score: 0, nodes, move: null };

    let bestScore = -100;
    let bestMove = null;
    let flag = 1;
    const colHeights = board.colHeights;

    // Immediate win: use symmetric ordering if mirror
    const colOrder = board.isMirrored ? CENTER_ORDER_MIRROR : CENTER_ORDER;
    for (const col of colOrder) {
        if (colHeights[col] >= ROWS) continue;
        if (board.winForColumn(col)) {
            return { score: 22 - ((moveCount + 2) >> 1), nodes: 1, move: col };
        }
    }

    // Recursive search
    for (const col of colOrder) {
        if (colHeights[col] >= ROWS) continue;

        board.makeMove(col);
        const child = negamax(board, depth - 1, -beta, -alpha);
        board.unmakeMove();

        nodes += child.nodes;
        const score = -child.score;

        if (score >= beta) {
            bestScore = score;
            bestMove = col;
            flag = 2;
            break;
        }
        if (score > bestScore) {
            bestScore = score;
            bestMove = col;
        }
        if (score > alpha) alpha = score;
    }

    if (bestScore <= originalAlpha) flag = 3;
    else if (bestScore >= beta) flag = 2;

    tt.put(hash, bestScore, depth, flag);

    // Mirror move if board was symmetric
    if (board.isMirrored && bestMove < 3) bestMove = 6 - bestMove;

    return { score: bestScore, nodes, move: bestMove };
}

export function findBestMove(board, depth) {
    tt = new TranspositionTable(getTTSizeForDepth(depth));
    const result = negamax(board, depth, -100, 100);
    return { move: result.move, score: result.score, nodes: result.nodes };
}
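One detail that is easy to miss with the transposition-table keys: zobristHash is accumulated with ^, and JavaScript's bitwise operators always return signed 32-bit integers. If such a hash is stored in a Uint32Array and later compared with === against the signed value, the comparison only succeeds when the top bit is clear, so about half of the entries would never be hit. Storing keys in an Int32Array, or normalizing the hash with >>> 0 before both put and getScore, keeps the two sides consistent. A small demonstration:

```javascript
// Demonstrates the signed/unsigned mismatch between a ^-built hash and typed arrays.
const hash = 0x9e3779b9 ^ 0x1;            // ^ yields a signed 32-bit result
const asUint = new Uint32Array(1);
const asInt = new Int32Array(1);
asUint[0] = hash;                         // stored as the unsigned equivalent
asInt[0] = hash;                          // stored unchanged

console.log(hash < 0);                    // true: the top bit is set
console.log(asUint[0] === hash);          // false: unsigned store never matches
console.log(asInt[0] === hash);           // true: signed store matches
console.log(asUint[0] === (hash >>> 0));  // true: normalizing with >>> 0 also works
```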

EDIT: I have made these two optimizations since posting this question (they are not included in the code above):

  1. Only call checkWin in negamax when moveCount is larger than 6, since no win is possible before that.
  2. In checkWin, I now check for vertical wins only if row is larger than 4, and then only in the downward direction.
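A sketch of both optimizations, assuming rows are 0-indexed from the bottom (as colHeights in board.js implies). The newest stone is always the top of its column, so a vertical four can only run downward from it. It also needs three stones of the same player below it, i.e. row >= 3 (the column holds at least four stones). The function names and signatures are hypothetical.

```javascript
const COLS = 7;

// Vertical four-in-a-row test for the stone just placed at (row, col),
// assuming bit index = row * COLS + col in a [lo, hi] pair and row 0 at the bottom.
function verticalWinDown(lo, hi, row, col) {
    if (row < 3) return false;              // fewer than 3 stones below: impossible
    for (let r = row - 1; r >= row - 3; r--) {
        const idx = r * COLS + col;
        const set = idx < 32 ? (lo & (1 << idx)) !== 0 : (hi & (1 << (idx - 32))) !== 0;
        if (!set) return false;             // cells below are filled, so this is an opponent stone
    }
    return true;
}

// The first player's 4th stone is move 7, so no win can exist before that.
function winPossible(moveCount) {
    return moveCount >= 7;
}
```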
  • "EDIT: I made these two optimizations since this question:" Please see the rules on changing the question after you receive an answer: "Do not change the code in the question after receiving an answer"... but it looks like new code was added just before an answer was posted. If the optimizations aren't part of the question, I'd refrain from mentioning them, since it's a bit confusing to know whether they're in there or not. Can you roll back that addition? (Commented Sep 15 at 22:47)
  • @ggorlen They are not in the code; I just posted them so people know what has been implemented. (Commented Sep 16 at 4:08)
  • Generally, when I read "ai" I think of something that at least has some sort of learning functionality etc. I'd call it a bot so everyone immediately knows what to expect from the file/project. (Commented Sep 16 at 12:40)

1 Answer


The checks for 4 in a row don't really make use of the bitboard representation, but they could. Looping over the bitboards and testing them bit by bit wastes the representation. I realize you only check a region around the last move, which is better than checking the whole board, but it still adds up to a lot of tests. That approach is fine when it's the only option you have, but here you have better ones.

For example, you can detect 4 in a row horizontally by ANDing together shifted copies of the board (3 shifted copies plus the original) such that any set bit in the result marks the position where a row of 4 ends (or starts, depending on how you look at it), which naturally also indicates a win. The same can be done for the vertical test (though it involves shifting bits across the low and high halves), and the diagonal tests are trickier, but still doable.

Be careful not to cause false positives due to bits wrapping around the edge of the board in shifts. For example, you could AND together 5 things: the original board, 3 shifted versions of the board, and one mask that has ones everywhere except the places where a set bit can only be a false positive. Masking off every shifted copy of the board (so that pieces never wrap around the edges of the board) would work but it would be overkill here.
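A sketch of the shift-and-AND test described above. It assumes the board.js layout (bit index = row * COLS + col, split into lo and hi halves), tests the whole board for one player, and uses hypothetical helper names. It has not been benchmarked against the loop-based check. Each direction ANDs the board with three copies shifted by 1 (horizontal), COLS (vertical), COLS + 1 or COLS - 1 (diagonals). The result is then ANDed with a mask of start cells from which a run cannot wrap past a board edge, and any surviving bit marks the lowest-index cell of a four. With this particular layout, row 4 straddles the two halves, so the shift has to carry bits from hi into lo even for horizontal runs.

```javascript
const COLS = 7;
const ROWS = 6;

// Logical right shift of a 64-bit value held as (lo, hi), for 0 < k < 32.
function shr(lo, hi, k) {
    return [((lo >>> k) | (hi << (32 - k))) >>> 0, hi >>> k];
}

// Build a (lo, hi) mask of every cell whose column satisfies `pred`.
function colMask(pred) {
    let lo = 0, hi = 0;
    for (let row = 0; row < ROWS; row++) {
        for (let col = 0; col < COLS; col++) {
            if (!pred(col)) continue;
            const idx = row * COLS + col;
            if (idx < 32) lo |= 1 << idx; else hi |= 1 << (idx - 32);
        }
    }
    return [lo >>> 0, hi >>> 0];
}

// Valid start cells per direction. Runs that leave the top of the board only
// meet zero bits, so only the left/right edges need masking.
const DIRECTIONS = [
    { shift: 1, mask: colMask(c => c <= COLS - 4) },        // horizontal
    { shift: COLS, mask: colMask(() => true) },             // vertical
    { shift: COLS + 1, mask: colMask(c => c <= COLS - 4) }, // diagonal up-right
    { shift: COLS - 1, mask: colMask(c => c >= 3) },        // diagonal up-left
];

// True if the [lo, hi] bitboard contains four in a row anywhere.
function hasFour(lo, hi) {
    for (const { shift, mask } of DIRECTIONS) {
        let aLo = lo, aHi = hi;   // running AND of board and shifted copies
        let sLo = lo, sHi = hi;   // board shifted by step * shift
        for (let step = 1; step <= 3; step++) {
            [sLo, sHi] = shr(sLo, sHi, shift);
            aLo &= sLo;
            aHi &= sHi;
        }
        if (((aLo & mask[0]) | (aHi & mask[1])) !== 0) return true;
    }
    return false;
}
```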

  • I tried this, but it somehow involved BigInt, which was slower, or I couldn't get it to work properly. Thank you for the tips, though. I put some of the improvements I made in the EDIT of the original question. (Commented Sep 15 at 8:02)
  • @eligolf It is possible to do this without BigInt; it's only a bit annoying to shift the bits from one half into the other (necessary for the vertical and diagonal tests, not the horizontal test), which BigInt would do automatically. (Commented Sep 15 at 8:15)
  • Yes, I didn't get that to work; it was a bit tricky to get right :) (Commented Sep 15 at 8:26)
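A minimal sketch of shifting a 64-bit value held as two 32-bit halves without BigInt. It assumes unsigned halves [lo, hi] and a shift amount 0 <= k < 64, and the names shl64 and shr64 are hypothetical. Bits that cross the 32-bit boundary are carried over with a complementary shift of the other half.

```javascript
// Left shift of a 64-bit value stored as unsigned 32-bit halves [lo, hi].
function shl64(lo, hi, k) {
    if (k === 0) return [lo >>> 0, hi >>> 0];
    if (k >= 32) return [0, (lo << (k - 32)) >>> 0];           // lo moves entirely into hi
    return [(lo << k) >>> 0, ((hi << k) | (lo >>> (32 - k))) >>> 0];
}

// Logical right shift of the same representation.
function shr64(lo, hi, k) {
    if (k === 0) return [lo >>> 0, hi >>> 0];
    if (k >= 32) return [hi >>> (k - 32), 0];                   // hi moves entirely into lo
    return [((lo >>> k) | (hi << (32 - k))) >>> 0, hi >>> k];
}
```

For example, shifting a board up one row in this layout would be shl64(lo, hi, 7), and down one row shr64(lo, hi, 7).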
