"""
Learned signalling with different weight-update rules

learn takes a signalling system, a meaning, a signal, and a learning rule of
the form [alpha, beta, gamma, delta] and adjusts the weights in that system
appropriately.

NB. alpha in the weight-update rule refers to the change made to connection
weights between active meaning and signal nodes, beta is the update when the
meaning node is active and the signal node is inactive, gamma applies
when the meaning node is inactive and the signal node is active, and delta
applies when neither node is active.


pop_learn uses a list of meaning-signal pairs to train a whole population of
learners. pop_produce uses a population to produce a list of meaning-signal
pairs

ca_monte_pop lets us test the communicative accuracy of the whole population

Usage example:

population = [[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]]]
data = pop_produce(population, 100)
pop_learn(population, data, 100, [1, 0, 0, 0])
ca_monte_pop(population,10000)

Takes a population of two agents, has them generate 100 meaning-signal pairs,
then trains the same population on that data 100 times, using the
frequency-counting learning rule from learning1.py, and then tests the
communicative accuracy of the population after learning.
"""

import random

def reception_weights(system, signal):
    """Return the column of weights associated with a signal.

    system is indexed as system[meaning][signal]; the result is a new list
    holding, for every meaning row in order, the weight stored at position
    signal.
    """
    return [meaning_row[signal] for meaning_row in system]

def production_weights(system, meaning):
    """Return the row of weights associated with a meaning.

    The row is returned as-is (not copied), so it is the same object that is
    stored inside system.
    """
    meaning_row = system[meaning]
    return meaning_row

def wta(items):
    """Winner-take-all: return the index of the largest entry in items.

    When several entries share the maximum value, one of their indices is
    picked at random with random.choice.
    """
    best = max(items)
    winners = [index for index, weight in enumerate(items) if weight == best]
    return random.choice(winners)


def communicate(speaker_system, hearer_system, meaning):
    """Run one communication episode and report whether it succeeded.

    The speaker picks a signal for meaning by winner-take-all over its
    production weights; the hearer then picks a meaning for that signal by
    winner-take-all over its reception weights.

    Returns 1 if the hearer's meaning equals the original meaning, else 0.
    """
    signal = wta(production_weights(speaker_system, meaning))
    interpretation = wta(reception_weights(hearer_system, signal))
    return 1 if meaning == interpretation else 0


# ----- new code below -----

def learn(system, meaning, signal, rule):
    """Update every weight in system, in place, for one observed pair.

    rule is indexed as [alpha, beta, gamma, delta]:
      rule[0] is added where the cell's meaning and signal both match,
      rule[1] where only the meaning matches,
      rule[2] where only the signal matches,
      rule[3] where neither matches.

    Nothing is returned; system is modified directly.
    """
    for m, row in enumerate(system):
        # Cells on the observed meaning's row use rule[0]/rule[1]; all other
        # rows use rule[2]/rule[3].
        row_offset = 0 if m == meaning else 2
        for s in range(len(row)):
            column_offset = 0 if s == signal else 1
            row[s] += rule[row_offset + column_offset]

def pop_learn(population, data, no_learning_episodes, rule):
    """Train a population on meaning-signal pairs.

    For each of no_learning_episodes episodes, one pair is drawn at random
    from data and one agent is drawn at random from population; that agent's
    weights are then updated in place by learn using rule.
    """
    for _ in range(no_learning_episodes):
        observation = random.choice(data)
        learner = random.choice(population)
        learn(learner, observation[0], observation[1], rule)

def pop_produce(population, no_productions):
    """Have a population produce a list of [meaning, signal] pairs.

    Each production draws a random speaker from population, a random meaning
    index for that speaker, and the signal that speaker selects for the
    meaning by winner-take-all over its production weights.

    Returns a list of no_productions two-element lists.
    """
    def produce_one():
        speaker = random.choice(population)
        meaning = random.randrange(len(speaker))
        signal = wta(production_weights(speaker, meaning))
        return [meaning, signal]

    return [produce_one() for _ in range(no_productions)]

def ca_monte_pop(population, trials):
    """Estimate the communicative accuracy of a population by sampling.

    Each trial draws a random speaker and a random hearer from population
    (the same agent may be drawn for both), draws a random meaning index for
    the speaker, and scores the exchange with communicate.

    Returns the proportion of successful trials as a float. Division by
    trials means trials == 0 raises ZeroDivisionError.
    """
    def run_trial():
        speaker = random.choice(population)
        hearer = random.choice(population)
        topic = random.randrange(len(speaker))
        return communicate(speaker, hearer, topic)

    # Start the sum at 0.0 so the final division is always float division.
    successes = sum((run_trial() for _ in range(trials)), 0.0)
    return successes / trials
