summaryrefslogtreecommitdiff
path: root/scripts/train/getth.py
blob: b0214d102e2711a27156408ba70c7016697b18a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env pygnubg
"""getth [-v LEVEL] [--np PLY] [--th VAL] [-1] dat-file

Find positions in dat-file whose gnubg evaluation disagrees with the
probabilities stored for them.

Lines of dat-file starting with '#', and blank lines, are ignored.  Every
other line holds a position key followed by (normally five) probabilities.
Each position is evaluated with gnubg at ply PLY and reported when
bgutil's eqError(evaluation, stored probabilities) is at least VAL.

Options:
  --np PLY   evaluation ply (default 0)
  --th VAL   reporting threshold (default 0.2)
  -1         ignore the stored probabilities and instead compare the ply-0
             and ply-1 evaluations (absolute difference of eq())
  -v LEVEL   verbosity level (accepted but currently unused)

Each reported position is written to stdout as '#'-prefixed comment lines
(error, board, both probability vectors) followed by the original dat-file
line."""

import sys, getopt

from bgutil import *

# Defaults; each may be overridden by a command-line option below.
th = 0.2        # --th: report positions whose error is >= th
verbose = 0     # -v: parsed but not used anywhere in this script
np = 0          # --np: ply used to evaluate positions
do01Dif = 0     # -1: compare ply 0 against ply 1 instead of dat-file values

_usage = "Usage: %s [-v LEVEL] [--np PLY] [--th VAL] [-1] dat-file\n" % sys.argv[0]

try:
  optlist, args = getopt.getopt(sys.argv[1:], "v:1", ['np=', 'th='])
  for o, a in optlist:
    if o == '-v':
      verbose = int(a)
    elif o == '--np':
      np = int(a)
    elif o == '--th':
      th = float(a)
    elif o == '-1':
      do01Dif = 1
except (getopt.GetoptError, ValueError):
  # Unknown option, missing option argument or non-numeric option value:
  # report it briefly instead of dumping a traceback.
  sys.stderr.write("%s: %s\n" % (sys.argv[0], sys.exc_info()[1]))
  sys.stderr.write(_usage)
  sys.exit(1)

# Exactly one positional argument: the dat-file to scan.
if len(args) != 1:
  sys.stderr.write(_usage)
  sys.exit(1)
  
# Scan the dat-file and report every position whose evaluation error reaches
# th.  Reports go to stdout as '#' comment lines followed by the offending
# input line verbatim; problems with the input itself go to stderr.
datName = args[0]
f = open(datName)
try:
  for lineno, line in enumerate(f):
    # Skip comment and blank lines.
    if line[0] == '#' or line.isspace():
      continue

    fields = line.split()
    pos = gnubg.boardfromkey(fields[0])
    probs = [float(p) for p in fields[1:]]

    if do01Dif:
      # The stored probabilities are not used in this mode: compare gnubg's
      # own ply-0 and ply-1 evaluations of the position.
      p0 = gnubg.probs(pos, 0)
      p1 = gnubg.probs(pos, 1)
      err = abs(eq(p0) - eq(p1))
      details = [("probs(0) =", p0), ("probs(1) =", p1)]
    else:
      if len(probs) != 5:
        # Malformed record: report it on stderr (stdout carries dat-file
        # formatted output) and skip it rather than compare against a
        # vector of the wrong length.
        sys.stderr.write("%s:%d: expected 5 probabilities, found %d\n"
                         % (datName, lineno + 1, len(probs)))
        continue
      p = gnubg.probs(pos, np)
      err = eqError(p, probs)
      details = [("probs =", probs), ("ply%d  =" % np, p)]

    if err >= th:
      sys.stdout.write("# err %.5f, pos = {%s}\n" % (err, listToString(pos)))
      for label, values in details:
        sys.stdout.write("# %s %s\n" % (label, formatedp(values)))
      # Echo the record exactly once, newline-terminated even when it is the
      # last line of a file without a trailing newline.
      sys.stdout.write(line.rstrip('\n') + '\n')
finally:
  f.close()