;; --------------------------------------------------------
;; Author: Charles Young, Solidsoft - August 2007
;;
;; Jess/CLIPS rule set to demonstrate the use of Bayes'
;; theorem in handling dependencies between uncertain
;; beliefs.   This rule set uses a very simple scenario in
;; which a number of jars hold combinations of black and
;; white jelly beans.   As the test progresses, jars are
;; randomly selected, and beans are removed.   The rule
;; set processes a sequence of these bean draw events.
;; Before each draw, the system uses Bayes' theorem to
;; determine the probability of a black or a white bean
;; being drawn from specific jars.  The program answers
;; questions in the form of "if the next bean drawn is
;; black, what is the probability that it will be drawn
;; from jar 3?"
;;
;; This is example code only.   Please feel free to use
;; as you wish.
;;
;; THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
;; ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
;; TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
;; PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
;; SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
;; CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
;; OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
;; IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
;; DEALINGS IN THE SOFTWARE.
;; --------------------------------------------------------

;; --------------------------------------------------------
;; Define the fact types (templates)
;; --------------------------------------------------------

;; A jar of black and white jelly beans.  Besides its bean counts it
;; carries the intermediate probabilities and the yes/no progress flags
;; used by the processing phases of this rule set.
(deftemplate jar            (slot id)                  ; matched against bean.jarId and drawEvent.jarId
                            (slot marginal)            ; prior P(jar), set to 1 / jarCount
                            (slot blackBeanCount)      ; black beans currently in the jar
                            (slot whiteBeanCount)      ; white beans currently in the jar
                            (slot blackBeanMarginal)   ; P(jar) * P(black | jar); reset to 0.0 once summed
                            (slot whiteBeanMarginal)   ; P(jar) * P(white | jar); reset to 0.0 once summed
                            (slot blackBeanPosterior)  ; despite the name, holds the likelihood P(black | jar)
                            (slot whiteBeanPosterior)  ; despite the name, holds the likelihood P(white | jar)
                            (slot counted)             ; yes once included in the context jarCount
                            (slot marginalSet)         ; yes once the prior (marginal) has been set
                            (slot posteriorsSet)       ; yes once likelihoods are set; no again after a draw from this jar
                            (slot onDrawBlack)         ; P(jar | next bean is black)
                            (slot onDrawWhite)         ; P(jar | next bean is white)
                            (slot probabilitiesSet)    ; yes once onDrawBlack/onDrawWhite are set; reset to no each cycle
)
;; A single jelly bean held in a jar.
(deftemplate bean           (slot id)
                            (slot colour)      ; the rules handle black and white
                            (slot jarId)       ; id of the jar holding this bean
                            (slot counted)     ; yes once added to the blackBeans/whiteBeans count
                            (slot jarCounted)  ; yes once added to its jar's colour count
)
;; A request to draw one bean of a given colour from a given jar.
;; An event is processed when its id equals the context's
;; lastEventId + 1, so ids are expected to form a sequence 1, 2, 3...
(deftemplate drawEvent      (slot id)
                            (slot jarId)   ; jar to draw from
                            (slot colour)  ; colour of the bean drawn
)

;; Global accumulator for black beans across all jars.
(deftemplate blackBeans     (slot count)     ; total black beans remaining
                            (slot marginal)  ; evidence marginal P(black), summed over jars each cycle
)
;; Global accumulator for white beans across all jars.
(deftemplate whiteBeans     (slot count)     ; total white beans remaining
                            (slot marginal)  ; evidence marginal P(white), summed over jars each cycle
)

;; Control fact (asserted once by assert_context) whose state slot
;; sequences the processing phases of the rule set.
(deftemplate context        (slot jarCount)     ; number of jar facts counted
                            (slot state)        ; name of the current processing phase
                            (slot lastEventId)  ; incremented once per draw phase; next event expected is lastEventId + 1
)

;; --------------------------------------------------------
;; RULE SET
;; --------------------------------------------------------

;; --------------------------------------------------------
;; Assert and initialise the main context.
;; --------------------------------------------------------

;; Bootstrap: when no context fact exists, create one in the 'created'
;; phase with no jars counted and no draw events handled yet.
(defrule assert_context
  (not (context))
  =>
  (assert (context (state created)
                   (jarCount 0)
                   (lastEventId 0)))
)

;; During the 'created' phase, count each jar exactly once: flag the jar
;; as counted and bump the context's jar total.
(defrule set_jar_count
  ?ctx <- (context (state created) (jarCount ?n))
  ?j   <- (jar (counted ~yes))
  =>
  (modify ?j (counted yes))
  (modify ?ctx (jarCount (+ 1 ?n)))
)

;; Low-salience phase transition: once nothing else is pending in the
;; 'created' phase (all jars counted), move on to 'start'.
(defrule set_state_start
  (declare (salience -10))
  ?ctx <- (context (state created))
  =>
  (modify ?ctx (state start))
)

;; --------------------------------------------------------
;; Assert and initialise blackBean and whiteBean contexts.
;; --------------------------------------------------------

;; On entering 'start', create the global black and white bean
;; accumulators (count 0, marginal 0.0) unless they already exist.
;; Salience 1 lets this run ahead of the default-salience counting rules.
(defrule assert_beans_context
  (declare (salience 1))
  (context (state start))
  (not (blackBeans))
  (not (whiteBeans))
  =>
  (assert (blackBeans (marginal 0.0) (count 0)))
  (assert (whiteBeans (marginal 0.0) (count 0)))
)

;; Add each not-yet-counted black bean to the global black bean count.
(defrule set_blackbeans_count
  (context (state start))
  ?b     <- (bean (counted ~yes) (colour black))
  ?total <- (blackBeans (count ?n))
  =>
  (modify ?b (counted yes))
  (modify ?total (count (+ 1 ?n)))
)

;; Add each not-yet-counted white bean to the global white bean count.
(defrule set_whitebeans_count
  (context (state start))
  ?b     <- (bean (counted ~yes) (colour white))
  ?total <- (whiteBeans (count ?n))
  =>
  (modify ?b (counted yes))
  (modify ?total (count (+ 1 ?n)))
)

;; --------------------------------------------------------
;; Initialise each jar: set its bean counts, evidence
;; marginals and draw probabilities to 0.
;; --------------------------------------------------------

;; Zero a jar's bean counts, evidence marginals and draw probabilities.
;; The pattern only matches while none of these slots already holds
;; 0 / 0.0 (presumably jars asserted with these slots left at nil), so
;; the modify below cannot re-trigger the rule.
(defrule initialise_jar
  (context (state start))
  ?j <- (jar (blackBeanCount ~0)      (whiteBeanCount ~0)
             (blackBeanMarginal ~0.0) (whiteBeanMarginal ~0.0)
             (onDrawBlack ~0.0)       (onDrawWhite ~0.0))
  =>
  (modify ?j (blackBeanCount 0)      (whiteBeanCount 0)
             (blackBeanMarginal 0.0) (whiteBeanMarginal 0.0)
             (onDrawBlack 0.0)       (onDrawWhite 0.0))
)

;; --------------------------------------------------------
;; Calculate the marginal probability of selecting each
;; jar.
;; --------------------------------------------------------

;; Give every jar the same prior probability of being selected,
;; P(jar) = 1 / jarCount, once at least one jar has been counted.
(defrule initialise_hyp_marginals
  (context (state start) (jarCount ?n))
  (test (> ?n 0))
  ?j <- (jar (marginalSet ~yes))
  =>
  (modify ?j (marginalSet yes) (marginal (/ 1 ?n)))
)

;; Low-salience phase transition: when the 'start' phase rules have run
;; out, move from 'start' to 'initialised'.
(defrule set_state_initialised
  (declare (salience -10))
  ?ctx <- (context (state start))
  =>
  (modify ?ctx (state initialised))
)

;; --------------------------------------------------------
;; Initialise the white and black bean count for each jar.
;; --------------------------------------------------------

;; Tally each black bean, once, against the jar it belongs to.
(defrule initialise_jar_blackbean_count
  (context (state initialised))
  ?j <- (jar (id ?jarId) (blackBeanCount ?n))
  ?b <- (bean (colour black) (jarId ?jarId) (jarCounted ~yes))
  =>
  (modify ?j (blackBeanCount (+ 1 ?n)))
  (modify ?b (jarCounted yes))
)

;; Tally each white bean, once, against the jar it belongs to.
(defrule initialise_jar_whitebean_count
  (context (state initialised))
  ?j <- (jar (id ?jarId) (whiteBeanCount ?n))
  ?b <- (bean (colour white) (jarId ?jarId) (jarCounted ~yes))
  =>
  (modify ?j (whiteBeanCount (+ 1 ?n)))
  (modify ?b (jarCounted yes))
)

;; Low-salience phase transition: when all beans have been tallied
;; against their jars, move to 'jar_bean_counts_set'.
(defrule set_state_jar_bean_counts
  (declare (salience -10))
  ?ctx <- (context (state initialised))
  =>
  (modify ?ctx (state jar_bean_counts_set))
)

;; --------------------------------------------------------
;; Calculate each jar's contribution to the marginal
;; probabilities of the evidence (the probability of
;; selecting a bean of a given colour), then set the
;; context state to 'jar_bean_marginals_set'.  Each
;; contribution is the product of the probability of
;; selecting a bean of that colour from the jar and the
;; probability of selecting the jar.  The contributions
;; are summed in the next section.
;; --------------------------------------------------------

;; Compute each jar's contribution to the evidence marginals:
;;   blackBeanMarginal = P(jar) * P(black | jar)
;;   whiteBeanMarginal = P(jar) * P(white | jar)
;; where P(colour | jar) = beans of that colour / total beans in the jar.
;; The pattern requires both contributions to be 0.0, i.e. not yet
;; computed this cycle (the summing rules reset them to 0.0).
;; Empty jars are skipped.  They contribute nothing to either marginal,
;; so their 0.0 contributions are already correct.  Without this guard
;; the division below fails with a division by zero as soon as a jar
;; has been emptied by draws, or if a jar starts with no beans.
(defrule initialise_ev_jar_bean_marginals
  (context (state jar_bean_counts_set))
  ?jar <- (jar (marginal ?m)
               (blackBeanCount ?bbc)
               (whiteBeanCount ?wbc)
               (blackBeanMarginal 0.0)
               (whiteBeanMarginal 0.0))
  (test (> (+ ?bbc ?wbc) 0))
  =>
  (bind ?total (+ ?bbc ?wbc))
  (modify ?jar (blackBeanMarginal (* ?m (/ ?bbc ?total)))
               (whiteBeanMarginal (* ?m (/ ?wbc ?total))))
)

;; Low-salience phase transition: after every jar's contributions have
;; been computed, move to 'jar_bean_marginals_set'.
(defrule set_state_jar_bean_marginals
  (declare (salience -10))
  ?ctx <- (context (state jar_bean_counts_set))
  =>
  (modify ?ctx (state jar_bean_marginals_set))
)

;; --------------------------------------------------------
;; Sum the partial marginals calculated in the section
;; above.
;; --------------------------------------------------------

;; Fold one jar's positive black-bean contribution into the global
;; P(black), then zero that contribution so it is added only once.
(defrule sum_ev_blackbean_marginals
  (context (state jar_bean_marginals_set))
  ?j <- (jar (blackBeanMarginal ?share))
  (test (and (neq ?share nil) (> ?share 0.0)))
  ?total <- (blackBeans (marginal ?sum))
  =>
  (modify ?total (marginal (+ ?sum ?share)))
  (modify ?j (blackBeanMarginal 0.0))
)

;; Fold one jar's positive white-bean contribution into the global
;; P(white), then zero that contribution so it is added only once.
(defrule sum_ev_whitebean_marginals
  (context (state jar_bean_marginals_set))
  ?j <- (jar (whiteBeanMarginal ?share))
  (test (and (neq ?share nil) (> ?share 0.0)))
  ?total <- (whiteBeans (marginal ?sum))
  =>
  (modify ?total (marginal (+ ?sum ?share)))
  (modify ?j (whiteBeanMarginal 0.0))
)

;; Low-salience phase transition: once every contribution has been
;; summed, the evidence marginals are complete.
(defrule set_state_marginals_initialised
  (declare (salience -10))
  ?ctx <- (context (state jar_bean_marginals_set))
  =>
  (modify ?ctx (state evidence_marginals_calculated))
)

;; --------------------------------------------------------
;; Set the likelihood of each bean colour on each jar,
;; P(colour | jar), stored in the blackBeanPosterior and
;; whiteBeanPosterior slots.
;; --------------------------------------------------------

;; For each jar that is not empty, store the likelihood of each colour,
;; P(colour | jar) = beans of that colour / total beans, in the
;; *Posterior slots and flag the jar as done.
(defrule initialise_posterior
  (context (state evidence_marginals_calculated))
  ?j <- (jar (posteriorsSet ~yes)
             (blackBeanCount ?black)
             (whiteBeanCount ?white))
  (test (not (and (eq ?black 0) (eq ?white 0))))
  =>
  (bind ?total (+ ?black ?white))
  (modify ?j (blackBeanPosterior (/ ?black ?total))
             (whiteBeanPosterior (/ ?white ?total))
             (posteriorsSet yes))
)

;; An empty jar (no beans of either colour) gets zero likelihoods
;; instead of a division by zero.
(defrule initialise_posterior_for_zeros
  (context (state evidence_marginals_calculated))
  ?j <- (jar (posteriorsSet ~yes)
             (blackBeanCount ?black&:(= ?black 0))
             (whiteBeanCount ?white&:(= ?white 0)))
  =>
  (modify ?j (posteriorsSet yes)
             (blackBeanPosterior 0)
             (whiteBeanPosterior 0))
)

;; Clear last cycle's 'probabilities set' flag so that the Bayes rule
;; recalculates every jar in the coming phase.
(defrule reset_jar_probabilities
  (context (state evidence_marginals_calculated))
  ?j <- (jar (probabilitiesSet yes))
  =>
  (modify ?j (probabilitiesSet no))
)

;; Low-salience phase transition: with likelihoods set and flags
;; cleared, the jars are ready for Bayes' theorem.
(defrule set_state_ready_to_calculate_pobabilities
  (declare (salience -10))
  ?ctx <- (context (state evidence_marginals_calculated))
  =>
  (modify ?ctx (state ready_to_calculate_probabilities))
)

;; --------------------------------------------------------
;; Calculate the probability of the hypotheses using Bayes'
;; theorem.
;; --------------------------------------------------------

;; Apply Bayes' theorem to each jar:
;;   onDrawBlack = P(jar | black) = P(black | jar) * P(jar) / P(black)
;;   onDrawWhite = P(jar | white) = P(white | jar) * P(jar) / P(white)
;; When an evidence marginal is 0.0, the matching probability is set to
;; 0.0 instead of dividing by zero.
;; Both results and the probabilitiesSet flag are written in ONE modify.
;; In CLIPS a modify is equivalent to retracting the fact and asserting
;; the modified copy, so a second modify of ?jar in the same RHS would
;; refer to a fact that has already been retracted.
(defrule calc_baysian_probability
  (context (state ready_to_calculate_probabilities))
  ?jar <- (jar (marginal ?jarMarginal)
               (blackBeanPosterior ?blackBeanPosterior)
               (whiteBeanPosterior ?whiteBeanPosterior)
               (probabilitiesSet ~yes))
  (blackBeans (marginal ?bbMarginal))
  (whiteBeans (marginal ?wbMarginal))
  =>
  (if (= ?bbMarginal 0.0)
      then (bind ?onDrawBlack 0.0)
      else (bind ?onDrawBlack (/ (* ?blackBeanPosterior ?jarMarginal) ?bbMarginal)))
  (if (= ?wbMarginal 0.0)
      then (bind ?onDrawWhite 0.0)
      else (bind ?onDrawWhite (/ (* ?whiteBeanPosterior ?jarMarginal) ?wbMarginal)))
  (modify ?jar (onDrawBlack ?onDrawBlack)
               (onDrawWhite ?onDrawWhite)
               (probabilitiesSet yes))
)

;; Low-salience phase transition: every jar's Bayes probabilities are
;; now available for output.
(defrule set_state_probabilities_calculated
  (declare (salience -10))
  ?ctx <- (context (state ready_to_calculate_probabilities))
  =>
  (modify ?ctx (state probabilities_calculated))
)

;; --------------------------------------------------------
;; Use StdIO to output the probabilities of drawing black
;; and white beans from each jar for next draw event.
;; --------------------------------------------------------

;; For each jar, print the probability that the next bean comes from
;; that jar, given that the bean is black or white.
(defrule output_probabilities
  (context (state probabilities_calculated))
  (jar (id ?jarId) (onDrawBlack ?pBlack) (onDrawWhite ?pWhite))
  =>
  (printout t "The probabilities of the next bean being drawn from jar " ?jarId " are:" crlf
              "   if next bean is black, " ?pBlack crlf
              "   if next bean is white, " ?pWhite crlf)
)

;; Low-salience phase transition: after the report has been printed,
;; get ready to process the next draw event.
(defrule set_state_ready_to_draw_beans
  (declare (salience -10))
  ?ctx <- (context (state probabilities_calculated))
  =>
  (modify ?ctx (state ready_to_draw_bean))
)

;; --------------------------------------------------------
;; Draw the next bean from the jars and amend the bean
;; counts.   If an invalid draw event is detected, show
;; a message and halt the engine.
;; --------------------------------------------------------

;; Handle the next draw event (id = lastEventId + 1) for a black bean:
;; decrement the global and per-jar black counts, clear the global
;; P(black) accumulator, mark the jar's likelihoods for recalculation,
;; and remove both the drawn bean and the consumed event.
(defrule draw_next_black_bean
  (context (state ready_to_draw_bean) (lastEventId ?last))
  ?total <- (blackBeans (count ?n))
  ?event <- (drawEvent (id ?eventId&:(= ?eventId (+ ?last 1)))
                       (colour black)
                       (jarId ?jarId))
  ?drawn <- (bean (colour black) (jarId ?jarId))
  ?j     <- (jar (id ?jarId) (blackBeanCount ?inJar))
  =>
  (modify ?total (marginal 0.0) (count (- ?n 1)))
  (modify ?j (posteriorsSet no) (blackBeanCount (- ?inJar 1)))
  (retract ?drawn)
  (retract ?event)
  (printout t crlf "Drawing a black bean from jar " ?jarId crlf crlf)
)

;; Handle the next draw event (id = lastEventId + 1) for a white bean:
;; decrement the global and per-jar white counts, clear the global
;; P(white) accumulator, mark the jar's likelihoods for recalculation,
;; and remove both the drawn bean and the consumed event.
(defrule draw_next_white_bean
  (context (state ready_to_draw_bean) (lastEventId ?last))
  ?total <- (whiteBeans (count ?n))
  ?event <- (drawEvent (id ?eventId&:(= ?eventId (+ ?last 1)))
                       (colour white)
                       (jarId ?jarId))
  ?drawn <- (bean (colour white) (jarId ?jarId))
  ?j     <- (jar (id ?jarId) (whiteBeanCount ?inJar))
  =>
  (modify ?total (marginal 0.0) (count (- ?n 1)))
  (modify ?j (posteriorsSet no) (whiteBeanCount (- ?inJar 1)))
  (retract ?drawn)
  (retract ?event)
  (printout t crlf "Drawing a white bean from jar " ?jarId crlf crlf)
)

;; Abort the run when the next draw event names a jar that holds no bean
;; of the requested colour.
(defrule report_bad_event
  (context (state ready_to_draw_bean) (lastEventId ?last))
  (drawEvent (id ?eventId&:(= ?eventId (+ ?last 1)))
             (jarId ?jarId)
             (colour ?wanted))
  (not (bean (jarId ?jarId) (colour ?wanted)))
  =>
  (printout t crlf "EVENT ERROR: An invalid draw event has been detected." crlf
              "             Event id " ?eventId " cannot draw a " ?wanted " bean" crlf
              "             from jar " ?jarId "." crlf
              crlf "             Test has been aborted." crlf)
  (halt)
)

;; On entering the draw phase, clear the global evidence marginals
;; P(black) and P(white) so they can be re-accumulated next cycle.
;; The rule only matches while at least one marginal is not yet 0.0.
;; The original unconstrained patterns also matched the facts this rule
;; modifies, so it could re-activate itself indefinitely (depending on
;; whether the engine re-matches a modify that leaves values unchanged).
;; That would starve the salience -10 phase transition.
(defrule reset_bean_context
  (context (state ready_to_draw_bean))
  ?blackBeans <- (blackBeans (marginal ?bbm))
  ?whiteBeans <- (whiteBeans (marginal ?wbm))
  (test (or (neq ?bbm 0.0) (neq ?wbm 0.0)))
  =>
  (modify ?blackBeans (marginal 0.0))
  (modify ?whiteBeans (marginal 0.0))
)

;; Low-salience phase transition: after the draw phase, advance the
;; expected event id and loop back to recompute the marginals.
(defrule set_state_jar_bean_counts_on_draw
  (declare (salience -10))
  ?ctx <- (context (state ready_to_draw_bean) (lastEventId ?last))
  =>
  (modify ?ctx (state jar_bean_counts_set) (lastEventId (+ ?last 1)))
)

;; --------------------------------------------------------
;; Halt processing if all events have been processed.
;; --------------------------------------------------------

;; Stop the engine when the context returns to 'jar_bean_counts_set'
;; after at least one draw phase and no draw event with a higher id
;; than lastEventId remains.  Salience 1 lets it run before the marginal
;; recalculation rules of the new cycle.
(defrule halt_Processing
  (declare (salience 1))
  (context (state jar_bean_counts_set) (lastEventId ?last&:(> ?last 0)))
  (not (drawEvent (id ?pending&:(> ?pending ?last))))
  =>
  (halt)
)
