#include #include #include #include #include /* network runs for 3 ticks */ #define TIME 3 int main() { Net *net; Group *input,*hidden,*output,*bias; Connections *c1,*c2,*c3,*c4; float error; int i,j,count; Example *ex; ExampleSet *examples; /* pick a random number seed */ mikenet_set_seed(5); /* set some defaults */ default_errorComputation=CROSS_ENTROPY_ERROR; default_epsilon=0.5; net=create_net(TIME); /* create some groups */ input=init_group("Input",2,TIME); hidden=init_group("hidden",2,TIME); output=init_group("Output",1,TIME); bias=init_bias(1.0,TIME); /* now add our groups to the network object */ bind_group_to_net(net,input); bind_group_to_net(net,hidden); bind_group_to_net(net,output); bind_group_to_net(net,bias); /* now connect our groups, instantiating */ /* connection objects c1 through c4 */ c1=connect_groups(input,hidden); c2=connect_groups(hidden,output); c3=connect_groups(bias,hidden); c4=connect_groups(bias,output); /* add connections to our network */ bind_connection_to_net(net,c1); bind_connection_to_net(net,c2); bind_connection_to_net(net,c3); bind_connection_to_net(net,c4); /* randomize the weights in the connection objects. 2nd argument is weight range. */ randomize_connections(c1,0.5); randomize_connections(c2,0.5); randomize_connections(c3,0.5); randomize_connections(c4,0.5); /* load in our example set */ examples=load_examples("xor.ex",TIME); error=0.0; count=0; /* loop for up to 10000 times */ for(i=0;i<10000;i++) { /* loop over all examples */ for(j=0;jnumExamples;j++) { /* get j'th example from exampleset */ ex=&examples->examples[j]; /* do forward propagation */ bptt_forward(net,ex); /* backward pass: compute gradients */ bptt_compute_gradients(net,ex); /* sum up error for this example */ error+=compute_error(net,ex); } bptt_apply_deltas(net); /* is it time to write status? */ if (count==100) { /* average error over last 'count' iterations */ error = error/(float)count; count=0; /* print a message about average error so far */ printf("%d\t%f\n",i,error); /* are we done? */ if (error < 0.05) { printf("quitting... error low enough\n"); /* pop out of loop */ break; } /* zero error; start counting again */ error=0.0; } count++; } /* finished. write out results for each example */ for(i=0;inumExamples;i++) { ex=&examples->examples[i]; bptt_forward(net,ex); printf("example %d\tinputs %f\t%f\toutput %f\ttarget %f\n", i, get_value(ex->inputs,input->index,TIME-1,0), get_value(ex->inputs,input->index,TIME-1,1), output->outputs[TIME-1][0], get_value(ex->targets,output->index,TIME-1,0)); } return 0; }